196 lines
6.1 KiB
Python
196 lines
6.1 KiB
Python
"""
|
||
LLM 分析服务 (新架构)
|
||
负责调用 LLM 生成分析报告
|
||
"""
|
||
import os
|
||
from typing import Dict, Optional
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy import select
|
||
from google import genai
|
||
from google.genai import types
|
||
|
||
from app.models import Company
|
||
from app.services.data_fetcher_service import get_financial_data_from_db
|
||
|
||
# Module-wide Gemini client, created lazily on the first get_genai_client() call.
_client: Optional[genai.Client] = None


def get_genai_client() -> genai.Client:
    """Return the shared Gemini API client, creating it on first use.

    The API key is read from the ``GEMINI_API_KEY`` environment variable only
    when the client does not exist yet; later calls return the cached client.

    Raises:
        ValueError: if the client has not been created yet and
            ``GEMINI_API_KEY`` is unset or empty.
    """
    global _client

    # Fast path: already initialised.
    if _client is not None:
        return _client

    key = os.getenv('GEMINI_API_KEY')
    if not key:
        raise ValueError("GEMINI_API_KEY environment variable is not set")

    _client = genai.Client(api_key=key)
    return _client
|
||
|
||
|
||
async def run_llm_analysis(
    company_id: int,
    data_source: str,
    model: str,
    db: AsyncSession
) -> Dict:
    """
    Run the LLM analysis for one company and collect every report section.

    Args:
        company_id: id of the Company row to analyse.
        data_source: data source name, forwarded to get_financial_data_from_db.
        model: LLM model name, forwarded to generate_section.
        db: async database session used for the lookups.

    Returns:
        A dict with one key per report section ('company_profile',
        'fundamental_analysis', 'insider_analysis', 'bullish_analysis',
        'bearish_analysis') mapping to that section's generated content, plus
        'total_tokens' (sum over all sections) and 'tokens_by_section'
        (section name -> token count).

    Raises:
        sqlalchemy NoResultFound: raised by ``scalar_one()`` when no Company
            with ``company_id`` exists.
    """
    # Sections in generation (and result-dict) order; each name is also the
    # prompt key understood by generate_section.
    section_names = (
        'company_profile',
        'fundamental_analysis',
        'insider_analysis',
        'bullish_analysis',
        'bearish_analysis',
    )

    # 1. Load the company row.
    result = await db.execute(
        select(Company).where(Company.id == company_id)
    )
    company = result.scalar_one()

    # 2. Load its financial data.
    financial_data = await get_financial_data_from_db(
        company_id=company_id,
        data_source=data_source,
        db=db
    )

    # 3. Build the prompt context once; it is shared by every section.
    context = format_financial_data_for_llm(financial_data)

    # 4. Generate the sections one after another.
    sections = {}
    tokens_by_section = {}
    for section_name in section_names:
        section_result = await generate_section(
            model=model,
            company=company,
            section_name=section_name,
            context=context
        )
        sections[section_name] = section_result['content']
        # A missing token count (None) is recorded as 0 so the total can
        # always be summed.
        tokens_by_section[section_name] = section_result['tokens'] or 0

    return {
        **sections,
        'total_tokens': sum(tokens_by_section.values()),
        'tokens_by_section': tokens_by_section
    }
|
||
|
||
|
||
def format_financial_data_for_llm(financial_data: Dict) -> str:
    """
    Format financial data as a plain-text context block for the LLM prompt.

    This is a simplified summary: it lists the company's identity and the
    number of records per statement, not the figures themselves.

    Args:
        financial_data: dict with a 'company' mapping (containing
            'company_name', 'market' and 'symbol'), a 'data_source' value,
            and optional record collections under 'income_statement',
            'balance_sheet', 'cash_flow' and 'daily_basic'. A missing key or
            a None value counts as 0 records.

    Returns:
        The context string; it starts and ends with a newline.
    """
    company = financial_data['company']

    def record_count(key: str) -> int:
        # `.get(key, [])` alone would still hit len(None) when the key is
        # present with a None value.
        records = financial_data.get(key)
        return 0 if records is None else len(records)

    # Built line by line so the prompt does not pick up source indentation.
    lines = [
        "",
        f"公司: {company['company_name']}",
        f"市场: {company['market']}",
        f"代码: {company['symbol']}",
        f"数据源: {financial_data['data_source']}",
        "",
        "财务数据:",
        f"- 利润表记录数: {record_count('income_statement')}",
        f"- 资产负债表记录数: {record_count('balance_sheet')}",
        f"- 现金流量表记录数: {record_count('cash_flow')}",
        f"- 估值数据记录数: {record_count('daily_basic')}",
        "",
    ]
    return "\n".join(lines)
|
||
|
||
|
||
async def generate_section(
    model: str,
    company: Company,
    section_name: str,
    context: str
) -> Dict:
    """
    Generate one section of the analysis report with Gemini.

    Args:
        model: Gemini model name.
        company: the company being analysed; its company_name (and symbol,
            for the profile section) is interpolated into the prompt.
        section_name: one of 'company_profile', 'fundamental_analysis',
            'insider_analysis', 'bullish_analysis', 'bearish_analysis'; any
            other value falls back to a generic analysis prompt.
        context: formatted financial-data context appended to the prompt.

    Returns:
        A dict {'content': generated text, 'tokens': total token count}.
        'tokens' is always an int (0 when the response reports no usage).
        This function is best-effort: on any exception (including a missing
        GEMINI_API_KEY) the error is logged and 'content' holds a failure
        message containing the exception text, with 'tokens' set to 0.
    """
    import logging

    # Prompt per section.
    prompts = {
        'company_profile': f"请为 {company.company_name} ({company.symbol}) 生成公司简介,包括主营业务、市场地位等。\n\n财务数据:\n{context}",
        'fundamental_analysis': f"请为 {company.company_name} 进行基本面分析,包括财务指标、盈利能力、成长性等。\n\n财务数据:\n{context}",
        'insider_analysis': f"请分析 {company.company_name} 的内部人士交易情况、股权结构等。\n\n财务数据:\n{context}",
        'bullish_analysis': f"请分析 {company.company_name} 的看涨因素和投资亮点。\n\n财务数据:\n{context}",
        'bearish_analysis': f"请分析 {company.company_name} 的风险因素和潜在问题。\n\n财务数据:\n{context}"
    }

    prompt = prompts.get(section_name, f"请分析 {company.company_name}。\n\n{context}")

    try:
        client = get_genai_client()

        # Use the async client: client.models.generate_content is synchronous
        # and would block the event loop for the whole request.
        response = await client.aio.models.generate_content(
            model=model,
            contents=prompt
        )

        content = response.text

        # usage_metadata and its total_token_count are Optional in the
        # google-genai response types; fall back to 0 so a successful
        # generation is not discarded and callers can sum the counts.
        usage = getattr(response, 'usage_metadata', None)
        tokens = getattr(usage, 'total_token_count', None) or 0

        return {
            'content': content,
            'tokens': tokens
        }
    except Exception as e:
        # Deliberate best-effort: one failed section must not abort the whole
        # report, but keep the traceback for debugging.
        logging.getLogger(__name__).exception(
            "LLM generation failed for section %s", section_name
        )
        return {
            'content': f"分析生成失败: {str(e)}",
            'tokens': 0
        }
|