FA3-Datafetch/backend/app/services/analysis_service.py
2026-01-29 11:49:48 +00:00

209 lines
6.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
LLM 分析服务 (新架构)
负责调用 LLM 生成分析报告
"""
import os
from typing import Dict, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from google import genai
from google.genai import types
from app.models import Company
from app.services.data_fetcher_service import get_financial_data_from_db
# Lazily-initialized Gemini client (module-level singleton).
_client: Optional[genai.Client] = None


def get_genai_client() -> genai.Client:
    """Return the shared Gemini API client, creating it on first use.

    Reads the API key from the GEMINI_API_KEY environment variable. If an
    https proxy is configured (https_proxy / HTTPS_PROXY), globally
    monkey-patches ``httpx.Client.__init__`` so that every httpx client —
    including the one the genai SDK builds internally — routes through the
    proxy.

    Returns:
        The singleton ``genai.Client`` instance.

    Raises:
        ValueError: if GEMINI_API_KEY is not set.
    """
    global _client
    if _client is None:
        api_key = os.getenv('GEMINI_API_KEY')
        if not api_key:
            raise ValueError("GEMINI_API_KEY environment variable is not set")

        import httpx
        proxy_url = os.getenv('https_proxy') or os.getenv('HTTPS_PROXY')
        if proxy_url:
            original_init = httpx.Client.__init__

            # Accept *args as well as **kwargs: the original patch took
            # **kwargs only, so any httpx.Client constructed with a
            # positional argument would raise TypeError once patched.
            def patched_init(self, *args, **kwargs):
                kwargs.setdefault('proxy', proxy_url)
                # SECURITY: TLS verification is disabled for all httpx
                # clients when a proxy is set — this exposes the API key
                # to man-in-the-middle attacks. Prefer pointing 'verify'
                # at the proxy's CA bundle instead of False.
                kwargs.setdefault('verify', False)
                return original_init(self, *args, **kwargs)

            httpx.Client.__init__ = patched_init
        _client = genai.Client(api_key=api_key)
    return _client
async def run_llm_analysis(
    company_id: int,
    data_source: str,
    model: str,
    db: AsyncSession
) -> Dict:
    """Run the full LLM analysis pipeline for one company.

    Args:
        company_id: Primary key of the company to analyse.
        data_source: Which financial data source to read from the DB.
        model: LLM model name passed through to the Gemini API.
        db: Async SQLAlchemy session.

    Returns:
        Dict with one key per report section (content string) plus
        'total_tokens' (int) and 'tokens_by_section' (dict section -> int).
    """
    # 1. Load the company row; scalar_one() raises if it does not exist.
    result = await db.execute(
        select(Company).where(Company.id == company_id)
    )
    company = result.scalar_one()

    # 2. Load the financial data for the requested source.
    financial_data = await get_financial_data_from_db(
        company_id=company_id,
        data_source=data_source,
        db=db
    )

    # 3. Format the data as LLM context text.
    context = format_financial_data_for_llm(financial_data)

    # 4. Generate each report section. The original code repeated the same
    # call/accumulate stanza five times; a loop over the ordered section
    # names produces identical results (dict insertion order is preserved).
    section_names = (
        'company_profile',       # 公司简介
        'fundamental_analysis',  # 基本面分析
        'insider_analysis',      # 内部人士分析
        'bullish_analysis',      # 看涨分析
        'bearish_analysis',      # 看跌分析
    )
    sections: Dict = {}
    tokens_by_section: Dict = {}
    total_tokens = 0
    for section_name in section_names:
        section_result = await generate_section(
            model=model,
            company=company,
            section_name=section_name,
            context=context
        )
        sections[section_name] = section_result['content']
        tokens_by_section[section_name] = section_result['tokens']
        total_tokens += section_result['tokens']

    return {
        **sections,
        'total_tokens': total_tokens,
        'tokens_by_section': tokens_by_section
    }
def format_financial_data_for_llm(financial_data: Dict) -> str:
    """Render the fetched financial data as a plain-text LLM context.

    Simplified rendering: company identity fields plus the record count
    of each financial statement; a fuller implementation would format the
    actual figures.
    """
    company = financial_data['company']
    n_income = len(financial_data.get('income_statement', []))
    n_balance = len(financial_data.get('balance_sheet', []))
    n_cash = len(financial_data.get('cash_flow', []))
    n_basic = len(financial_data.get('daily_basic', []))
    return f"""
公司: {company['company_name']}
市场: {company['market']}
代码: {company['symbol']}
数据源: {financial_data['data_source']}
财务数据:
- 利润表记录数: {n_income}
- 资产负债表记录数: {n_balance}
- 现金流量表记录数: {n_cash}
- 估值数据记录数: {n_basic}
"""
async def generate_section(
    model: str,
    company: Company,
    section_name: str,
    context: str
) -> Dict:
    """Generate one section of the analysis report via the Gemini API.

    Args:
        model: Gemini model name.
        company: Company ORM row; company_name/symbol are used in prompts.
        section_name: One of the known section keys; unknown names fall
            back to a generic analysis prompt.
        context: Pre-formatted financial data context string.

    Returns:
        ``{'content': str, 'tokens': int}``. On any API failure the error
        message is returned in 'content' with tokens=0 instead of raising
        (best-effort behaviour the caller relies on).
    """
    # Per-section prompt templates.
    prompts = {
        'company_profile': f"请为 {company.company_name} ({company.symbol}) 生成公司简介,包括主营业务、市场地位等。\n\n财务数据:\n{context}",
        'fundamental_analysis': f"请为 {company.company_name} 进行基本面分析,包括财务指标、盈利能力、成长性等。\n\n财务数据:\n{context}",
        'insider_analysis': f"请分析 {company.company_name} 的内部人士交易情况、股权结构等。\n\n财务数据:\n{context}",
        'bullish_analysis': f"请分析 {company.company_name} 的看涨因素和投资亮点。\n\n财务数据:\n{context}",
        'bearish_analysis': f"请分析 {company.company_name} 的风险因素和潜在问题。\n\n财务数据:\n{context}"
    }
    prompt = prompts.get(section_name, f"请分析 {company.company_name}\n\n{context}")
    try:
        client = get_genai_client()
        response = client.models.generate_content(
            model=model,
            contents=prompt
        )
        content = response.text

        # usage_metadata (and its total_token_count) may be None on some
        # responses; coerce to 0 so the result is always an int. The
        # original hasattr() check could return tokens=None, which would
        # crash the caller's `total_tokens += tokens`.
        usage = getattr(response, 'usage_metadata', None)
        tokens = (usage.total_token_count if usage is not None else None) or 0

        return {
            'content': content,
            'tokens': tokens
        }
    except Exception as e:
        # Best-effort: surface the failure in the section text rather than
        # aborting the whole report.
        return {
            'content': f"分析生成失败: {str(e)}",
            'tokens': 0
        }