"""Google Gemini API client for company profile generation."""
import asyncio
import logging
import time
from typing import Dict, List, Optional

import google.generativeai as genai
class CompanyProfileClient:
    """Client for generating company profiles via the Google Gemini API.

    Wraps a single ``genai.GenerativeModel`` and exposes both a one-shot
    (async) and a streaming (sync generator) profile-generation API.
    """

    # Default Gemini model identifier; duplicated literals were consolidated here.
    DEFAULT_MODEL = "gemini-2.5-flash"

    def __init__(self, api_key: str, model_name: str = DEFAULT_MODEL) -> None:
        """Initialize the Gemini client.

        Args:
            api_key: Google AI Studio API key.
            model_name: Gemini model identifier; defaults to ``DEFAULT_MODEL``
                (backward compatible with the previous hard-coded model).
        """
        genai.configure(api_key=api_key)
        self.model_name = model_name
        self.model = genai.GenerativeModel(model_name)

    async def generate_profile(
        self,
        company_name: str,
        ts_code: str,
        financial_data: Optional[Dict] = None,
    ) -> Dict:
        """Generate a company profile using the Gemini API (non-streaming).

        Args:
            company_name: Company name.
            ts_code: Stock code.
            financial_data: Optional financial data appended to the prompt.

        Returns:
            Dict with keys ``content``, ``model``, ``tokens``, ``elapsed_ms``,
            ``success`` and, on failure, ``error``. Errors are reported in the
            result dict rather than raised, so callers never see an exception.
        """
        start_time = time.perf_counter_ns()
        prompt = self._build_prompt(company_name, ts_code, financial_data)

        try:
            # The google.generativeai SDK call is synchronous; run it in the
            # default executor so the event loop is not blocked.
            # NOTE: get_event_loop() inside a coroutine is deprecated since
            # Python 3.10 — get_running_loop() is the correct call here.
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(
                None,
                lambda: self.model.generate_content(prompt),
            )

            # usage_metadata may be absent on some responses; treat it as optional.
            usage = getattr(response, "usage_metadata", None)
            if usage is not None:
                tokens = {
                    "prompt_tokens": usage.prompt_token_count,
                    "completion_tokens": usage.candidates_token_count,
                    "total_tokens": usage.total_token_count,
                }
            else:
                tokens = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}

            elapsed_ms = (time.perf_counter_ns() - start_time) // 1_000_000

            return {
                "content": response.text,
                "model": self.model_name,
                "tokens": tokens,
                "elapsed_ms": elapsed_ms,
                "success": True,
            }
        except Exception as e:  # API boundary: surface failure in the result dict
            elapsed_ms = (time.perf_counter_ns() - start_time) // 1_000_000
            return {
                "content": "",
                "model": self.model_name,
                "tokens": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
                "elapsed_ms": elapsed_ms,
                "success": False,
                "error": str(e),
            }

    def generate_profile_stream(
        self,
        company_name: str,
        ts_code: str,
        financial_data: Optional[Dict] = None,
    ):
        """Generate a company profile using the Gemini API with streaming.

        Args:
            company_name: Company name.
            ts_code: Stock code.
            financial_data: Optional financial data appended to the prompt.

        Yields:
            Chunks of generated content. On error, a final Markdown-formatted
            error message is yielded instead of raising.
        """
        logger = logging.getLogger(__name__)

        logger.info(
            "[CompanyProfile] Starting stream generation for %s (%s)",
            company_name, ts_code,
        )

        prompt = self._build_prompt(company_name, ts_code, financial_data)
        logger.info("[CompanyProfile] Prompt built, length: %d chars", len(prompt))

        try:
            logger.info("[CompanyProfile] Calling Gemini API with stream=True")
            # Synchronous streaming call; the SDK yields chunks as they arrive.
            response_stream = self.model.generate_content(prompt, stream=True)
            logger.info("[CompanyProfile] Gemini API stream object created")

            chunk_count = 0
            logger.info("[CompanyProfile] Starting to iterate response stream")
            for chunk in response_stream:
                logger.info(
                    "[CompanyProfile] Received chunk from Gemini, has text: %s",
                    hasattr(chunk, 'text'),
                )
                # Chunks without text (e.g. safety/finish metadata) are skipped.
                if hasattr(chunk, 'text') and chunk.text:
                    chunk_count += 1
                    logger.info(
                        "[CompanyProfile] Chunk %d: %d chars",
                        chunk_count, len(chunk.text),
                    )
                    yield chunk.text
                else:
                    logger.warning(
                        "[CompanyProfile] Chunk has no text attribute or empty, chunk: %s",
                        chunk,
                    )

            logger.info(
                "[CompanyProfile] Stream iteration completed. Total chunks: %d",
                chunk_count,
            )
        except Exception as e:  # stream boundary: emit error text instead of raising
            logger.error(
                "[CompanyProfile] Error during streaming: %s: %s",
                type(e).__name__, str(e), exc_info=True,
            )
            yield f"\n\n---\n\n**错误**: {type(e).__name__}: {str(e)}"

    def _build_prompt(self, company_name: str, ts_code: str, financial_data: Optional[Dict] = None) -> str:
        """Build the Chinese-language analyst prompt for profile generation.

        The prompt instructs the model to act as a securities analyst and
        produce a structured Markdown profile; optional financial data is
        appended verbatim as reference context.
        """
        prompt = f"""您是一位专业的证券市场分析师。请为公司 {company_name} (股票代码: {ts_code}) 生成一份详细且专业的公司介绍。开头不要自我介绍,直接开始正文。正文用MarkDown输出,尽量说明信息来源,用斜体显示信息来源。在生成内容时,请严格遵循以下要求并采用清晰、结构化的格式:

1. **公司概览**:
   * 简要介绍公司的性质、核心业务领域及其在行业中的定位。
   * 提炼并阐述公司的核心价值理念。

2. **主营业务**:
   * 详细描述公司主要的**产品或服务**。
   * **重要提示**:如果能获取到公司最新的官方**年报**或**财务报告**,请从中提取各主要产品/服务线的**收入金额**和其占公司总收入的**百分比**。请**明确标注数据来源**(例如:"数据来源于XX年年度报告")。
   * **严格禁止**编造或估算任何财务数据。若无法找到公开、准确的财务数据,请**不要**在这一点中提及具体金额或比例,仅描述业务内容。

3. **发展历程**:
   * 以时间线或关键事件的形式,概述公司自成立以来的主要**里程碑事件**、重大发展阶段、战略转型或重要成就。

4. **核心团队**:
   * 介绍公司**主要管理层和核心技术团队成员**。
   * 对于每位核心成员,提供其**职务、主要工作履历、教育背景**。
   * 如果公开可查,可补充其**出生年份**。

5. **供应链**:
   * 描述公司的**主要原材料、部件或服务来源**。
   * 如果公开信息中包含,请列出**主要供应商名称**,并**明确其在总采购金额中的大致占比**。若无此数据,则仅描述采购模式。

6. **主要客户及销售模式**:
   * 阐明公司的**销售模式**(例如:直销、经销、线上销售、代理等)。
   * 列出公司的**主要客户群体**或**代表性大客户**。
   * 如果公开信息中包含,请标明**主要客户(或前五大客户)的销售额占公司总销售额的比例**。若无此数据,则仅描述客户类型。

7. **未来展望**:
   * 基于公司**公开的官方声明、管理层访谈或战略规划**,总结公司未来的发展方向、战略目标、重点项目或市场预期。请确保此部分内容有可靠的信息来源支持。"""

        if financial_data:
            prompt += f"\n\n参考财务数据:\n{financial_data}"

        return prompt
|