FA3-Datafetch/backend/app/services/llm_engine.py
2026-01-11 21:33:47 +08:00

184 lines
7.0 KiB
Python

import os
import time
import markdown
import google.genai as genai
from google.genai import types
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.models import Report, ReportSection, Setting
import asyncio
import logging
logger = logging.getLogger(__name__)
async def load_prompts(db: AsyncSession, prompt_dir: str) -> dict:
    """Load the prompt template for every analysis section.

    For each known section key, first look for a DB override in the
    ``Setting`` table under the key ``PROMPT_<SECTION>``; if the row is
    absent (or the DB read fails), fall back to the bundled markdown file
    in *prompt_dir*.

    Args:
        db: Async SQLAlchemy session used to read ``Setting`` rows.
        prompt_dir: Directory containing the fallback ``.md`` prompt files.

    Returns:
        dict mapping section key -> prompt text. A key whose prompt is
        missing from both sources maps to an ``"Error: ..."`` placeholder
        so downstream formatting never sees ``None``.
    """
    prompts = {}
    # Section key -> fallback prompt filename (Chinese titles on disk).
    mapping = {
        "company_profile": "公司简介.md",
        "fundamental_analysis": "基本面分析.md",
        "insider_analysis": "内部人与机构动向分析.md",
        "bullish_analysis": "看涨分析.md",
        "bearish_analysis": "看跌分析.md"
    }
    for key, filename in mapping.items():
        # 1) Prefer a DB-stored override.
        setting_key = f"PROMPT_{key.upper()}"
        try:
            result = await db.get(Setting, setting_key)
            if result:
                prompts[key] = result.value
                continue
        except Exception as e:
            # A DB hiccup must not block the file fallback below.
            logger.warning("Error reading prompt setting %s: %s", setting_key, e)
        # 2) Fall back to the bundled prompt file.
        try:
            with open(os.path.join(prompt_dir, filename), 'r', encoding='utf-8') as f:
                prompts[key] = f.read()
        except FileNotFoundError:
            # Bug fix: the original message had no interpolation at all, so
            # the warning/placeholder never said WHICH prompt was missing.
            logger.warning("Prompt file %s not found.", filename)
            prompts[key] = f"Error: Prompt {filename} not found."
    return prompts
async def call_llm(api_key: str, model_name: str, system_prompt: str, user_prompt: str, context: str, enable_search: bool = True) -> dict:
    """Call the Gemini model off the event loop and return text + token usage.

    Builds one prompt from system/user parts plus the report context,
    optionally enables Google Search grounding, and runs the synchronous
    google-genai client in a worker thread via ``asyncio.to_thread``.

    Args:
        api_key: Gemini API key.
        model_name: Model identifier, e.g. ``gemini-2.0-flash``.
        system_prompt: Role/instruction text placed first in the prompt.
        user_prompt: Task instruction appended after the system prompt.
        context: Existing report data appended as grounding context.
        enable_search: When True, attach the Google Search grounding tool.

    Returns:
        dict with ``text``, ``prompt_tokens`` and ``completion_tokens``.
        On any API failure, ``text`` carries an error marker and both
        token counts are 0 (best-effort: the section still renders).
    """
    full_prompt = f"{system_prompt}\n\n{user_prompt}\n\nExisting Report Data for context:\n{context}"
    logger.info("🤖 [LLM] 开始调用模型: %s, 启用搜索: %s", model_name, enable_search)
    start_time = time.time()
    client = genai.Client(api_key=api_key)
    config_params = {}
    if enable_search:
        # Ground the response with live Google Search results.
        grounding_tool = types.Tool(google_search=types.GoogleSearch())
        config_params['tools'] = [grounding_tool]
    config = types.GenerateContentConfig(**config_params)
    try:
        def run_sync():
            # The google-genai client is synchronous; keep it off the loop.
            return client.models.generate_content(
                model=model_name,
                contents=full_prompt,
                config=config
            )
        response = await asyncio.to_thread(run_sync)
        usage = response.usage_metadata
        # Bug fix: individual token counts can be None even when
        # usage_metadata is present (e.g. grounded responses); coercing
        # with `or 0` prevents a TypeError on the addition below from
        # turning a successful response into an error result.
        prompt_tokens = (usage.prompt_token_count or 0) if usage else 0
        completion_tokens = (usage.candidates_token_count or 0) if usage else 0
        total_tokens = prompt_tokens + completion_tokens
        elapsed = time.time() - start_time
        logger.info(
            "✅ [LLM] 模型响应完成, 耗时: %.2f秒, Tokens: prompt=%d, completion=%d, total=%d",
            elapsed, prompt_tokens, completion_tokens, total_tokens
        )
        return {
            "text": response.text,
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens
        }
    except Exception as e:
        # Best-effort fallback: surface the error inside the section text
        # instead of failing the whole report. logger.exception keeps the
        # traceback (the original print() lost it).
        logger.exception("API Call Failed: %s", e)
        return {
            "text": f"\n\nError generating section: {e}\n\n",
            "prompt_tokens": 0,
            "completion_tokens": 0
        }
async def process_analysis_steps(report_id: int, company_name: str, symbol: str, market: str, db: AsyncSession, api_key: str) -> bool:
    """Generate all analysis sections for a report and persist them.

    Loads the section prompts (DB overrides first, then ``Prompt/*.md``),
    reads the pre-fetched financial data from
    ``data/<market>/<symbol>/report.md``, runs every LLM section
    concurrently, and saves one ``ReportSection`` row per section.

    Args:
        report_id: Primary key of the parent ``Report`` row.
        company_name: Display name interpolated into the prompt templates.
        symbol: Ticker / ts_code used for prompts and the data lookup.
        market: Market sub-directory name under ``data/``.
        db: Async SQLAlchemy session for settings reads and persistence.
        api_key: Gemini API key forwarded to :func:`call_llm`.

    Returns:
        True once all sections have been generated and committed.
    """
    # 1. Load prompt templates.
    root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))
    prompt_dir = os.path.join(root_dir, "Prompt")
    prompts = await load_prompts(db, prompt_dir)

    # 2. Locate the data directory written by run_fetcher.py. The fetcher
    # may suffix the directory name, so fall back to a prefix match.
    base_dir = os.path.join(root_dir, "data", market)
    symbol_dir = os.path.join(base_dir, symbol)
    if not os.path.exists(symbol_dir) and os.path.isdir(base_dir):
        # Bug fix: guard os.listdir — it raised FileNotFoundError when the
        # market directory itself was missing, instead of degrading to the
        # "No financial data available." path below.
        candidates = [
            d for d in os.listdir(base_dir)
            if d.startswith(symbol) and os.path.isdir(os.path.join(base_dir, d))
        ]
        if candidates:
            symbol_dir = os.path.join(base_dir, candidates[0])

    # 3. Read the markdown data context (report.md from run_fetcher.py).
    data_path = os.path.join(symbol_dir, "report.md")
    if not os.path.exists(data_path):
        # report.md missing: run_fetcher.py failed or its output layout
        # changed. Proceed with a placeholder rather than aborting.
        logger.warning("Warning: %s not found.", data_path)
        data_context = "No financial data available."
    else:
        with open(data_path, 'r', encoding='utf-8') as f:
            data_context = f.read()

    # Optional raw balance-sheet CSV, appended only for the bearish section.
    csv_path = os.path.join(symbol_dir, "raw_balance_sheet_raw.csv")
    csv_context = ""
    if os.path.exists(csv_path):
        with open(csv_path, 'r', encoding='utf-8') as f:
            csv_content = f.read()
        csv_context = f"\n\nRaw Balance Sheet Data (CSV):\n{csv_content}\n"

    # (section key, human-readable chapter title) in report order.
    steps = [
        ("company_profile", "3. 公司简介 (Company Profile)"),
        ("fundamental_analysis", "4. 基本面分析 (Fundamental Analysis)"),
        ("insider_analysis", "5. 内部人与机构动向 (Insider Analysis)"),
        ("bullish_analysis", "6. 看涨分析 (Bullish Analysis)"),
        ("bearish_analysis", "7. 看跌分析 (Bearish Analysis)")
    ]

    # Model name comes from settings, with a default fallback.
    model_setting = await db.get(Setting, "AI_MODEL")
    model_name = model_setting.value if model_setting else "gemini-2.0-flash"

    async def process_section(key: str, name: str):
        """Run one LLM section; return (key, result dict) or None when the
        section has no prompt template."""
        logger.info("📝 [LLM] 开始处理章节: %s", name)
        section_start = time.time()
        prompt_template = prompts.get(key)
        if not prompt_template:
            return None
        formatted_prompt = prompt_template.format(company_name=company_name, ts_code=symbol)
        system_content = formatted_prompt
        user_content = "请根据上述角色设定和要求,结合提供的财务数据,撰写本章节的分析报告。"
        current_data_context = data_context
        if key == "bearish_analysis" and csv_context:
            current_data_context += csv_context
        result = await call_llm(api_key, model_name, system_content, user_content,
                                current_data_context, enable_search=True)
        section_elapsed = time.time() - section_start
        # Consistency fix: elapsed time now carries the 秒 (seconds) unit,
        # matching the call_llm completion log line.
        logger.info("✅ [LLM] 章节 %s 处理完成, 总耗时: %.2f秒", name, section_elapsed)
        return (key, result)

    # Run all sections concurrently; gather preserves `steps` order.
    logger.info("Starting concurrent analysis with %d sections...", len(steps))
    results = await asyncio.gather(*[process_section(key, name) for key, name in steps])

    # Persist each generated section, then commit once for all of them.
    for item in results:
        if item is None:
            continue
        key, result_data = item
        section = ReportSection(
            report_id=report_id,
            section_name=key,
            content=result_data["text"],
            prompt_tokens=result_data["prompt_tokens"],
            completion_tokens=result_data["completion_tokens"],
            total_tokens=result_data["prompt_tokens"] + result_data["completion_tokens"]
        )
        db.add(section)
    await db.commit()
    logger.info("All %d sections completed and saved!", len(steps))
    return True