169 lines
6.3 KiB
Python
169 lines
6.3 KiB
Python
import asyncio
import os
import time

import google.genai as genai
import markdown
from google.genai import types
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.models import Report, ReportSection, Setting
async def load_prompts(db: AsyncSession, prompt_dir: str) -> dict:
    """Load the per-section prompt templates, preferring DB overrides.

    For each known section key, first look for a ``Setting`` row whose key is
    ``PROMPT_<SECTION>`` (uppercased); if that row is absent or the lookup
    fails, fall back to the corresponding markdown file under *prompt_dir*.

    Args:
        db: Async SQLAlchemy session used to read ``Setting`` overrides.
        prompt_dir: Directory holding the default ``*.md`` prompt files.

    Returns:
        Mapping of section key -> prompt text. When neither source is
        available, an ``Error: ...`` placeholder string is stored instead.
    """
    prompts = {}
    # Section key -> default prompt filename (Chinese filenames are intentional).
    mapping = {
        "company_profile": "公司简介.md",
        "fundamental_analysis": "基本面分析.md",
        "insider_analysis": "内部人与机构动向分析.md",
        "bullish_analysis": "看涨分析.md",
        "bearish_analysis": "看跌分析.md"
    }

    for key, filename in mapping.items():
        # Try DB first: a Setting row overrides the file on disk.
        setting_key = f"PROMPT_{key.upper()}"
        try:
            result = await db.get(Setting, setting_key)
            if result:
                prompts[key] = result.value
                continue
        except Exception as e:
            # A DB failure must not prevent the file fallback below.
            print(f"Error reading prompt setting {setting_key}: {e}")

        # Fallback to the bundled prompt file.
        try:
            with open(os.path.join(prompt_dir, filename), 'r', encoding='utf-8') as f:
                prompts[key] = f.read()
        except FileNotFoundError:
            # BUG FIX: both messages were f-strings with no placeholder, so
            # the log never said WHICH prompt was missing. Interpolate the
            # actual filename / section key.
            print(f"Warning: Prompt file {filename} not found.")
            prompts[key] = f"Error: Prompt {key} not found."

    return prompts
async def call_llm(api_key: str, model_name: str, system_prompt: str, user_prompt: str, context: str, enable_search: bool = True):
    """Send a single prompt to the Gemini API and return the generated text.

    The blocking SDK call runs on a worker thread via ``asyncio.to_thread``
    so the event loop is not stalled. On any failure the error is embedded
    in the returned text instead of being raised, so one failed section does
    not abort the whole report.

    Args:
        api_key: Gemini API key.
        model_name: Model identifier passed to ``generate_content``.
        system_prompt: Role/instruction text placed first in the prompt.
        user_prompt: Task text appended after the system prompt.
        context: Existing report data appended as grounding context.
        enable_search: Attach the Google Search grounding tool when True.

    Returns:
        dict with keys ``text``, ``prompt_tokens``, ``completion_tokens``.
    """
    combined_prompt = f"{system_prompt}\n\n{user_prompt}\n\nExisting Report Data for context:\n{context}"

    client = genai.Client(api_key=api_key)

    # Optionally attach Google Search grounding so the model can pull fresh data.
    cfg_kwargs = {}
    if enable_search:
        cfg_kwargs['tools'] = [types.Tool(google_search=types.GoogleSearch())]
    generation_config = types.GenerateContentConfig(**cfg_kwargs)

    try:
        def _invoke():
            # Synchronous SDK call; executed off the event loop below.
            return client.models.generate_content(
                model=model_name,
                contents=combined_prompt,
                config=generation_config,
            )

        response = await asyncio.to_thread(_invoke)

        usage = response.usage_metadata
        if usage:
            in_tokens = usage.prompt_token_count
            out_tokens = usage.candidates_token_count
        else:
            in_tokens = 0
            out_tokens = 0

        return {
            "text": response.text,
            "prompt_tokens": in_tokens,
            "completion_tokens": out_tokens,
        }
    except Exception as e:
        # Best-effort: surface the error inside the section body rather
        # than failing the entire report generation.
        print(f"API Call Failed: {e}")
        return {
            "text": f"\n\nError generating section: {e}\n\n",
            "prompt_tokens": 0,
            "completion_tokens": 0,
        }
async def process_analysis_steps(report_id: int, company_name: str, symbol: str, market: str, db: AsyncSession, api_key: str):
    """Generate all AI analysis sections for one report and persist them.

    Loads section prompts (DB overrides, then files), builds the financial
    data context from files produced by ``run_fetcher.py``, fires all LLM
    calls concurrently, and stores one ``ReportSection`` row per section.

    Args:
        report_id: ID of the parent ``Report`` row the sections attach to.
        company_name: Human-readable company name interpolated into prompts.
        symbol: Stock symbol; also the expected data directory name.
        market: Market sub-directory under ``data/``.
        db: Async SQLAlchemy session for settings lookup and persistence.
        api_key: Gemini API key forwarded to ``call_llm``.

    Returns:
        True once all sections were processed and committed.
    """
    # 1. Load Prompts from <project_root>/Prompt (with DB overrides).
    root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))
    prompt_dir = os.path.join(root_dir, "Prompt")
    prompts = await load_prompts(db, prompt_dir)

    # 2. Read Data Context (report.md generated by run_fetcher.py)
    base_dir = os.path.join(root_dir, "data", market)
    symbol_dir = os.path.join(base_dir, symbol)

    if not os.path.exists(symbol_dir) and os.path.isdir(base_dir):
        # BUG FIX: os.listdir raised FileNotFoundError when the market
        # directory itself was missing; guard with isdir so this fallback
        # stays best-effort. The fetcher may create a directory whose name
        # merely starts with the symbol — pick the first such match.
        candidates = [d for d in os.listdir(base_dir)
                      if d.startswith(symbol) and os.path.isdir(os.path.join(base_dir, d))]
        if candidates:
            symbol_dir = os.path.join(base_dir, candidates[0])

    data_path = os.path.join(symbol_dir, "report.md")
    if not os.path.exists(data_path):
        # If report.md is missing, run_fetcher.py may have failed or its
        # output structure changed — proceed with a placeholder context.
        print(f"Warning: {data_path} not found.")
        data_context = "No financial data available."
    else:
        with open(data_path, 'r', encoding='utf-8') as f:
            data_context = f.read()

    # CSV context: the raw balance sheet is appended only to the bearish
    # section (see process_section below).
    csv_path = os.path.join(symbol_dir, "raw_balance_sheet_raw.csv")
    csv_context = ""
    if os.path.exists(csv_path):
        with open(csv_path, 'r', encoding='utf-8') as f:
            csv_content = f.read()
        csv_context = f"\n\nRaw Balance Sheet Data (CSV):\n{csv_content}\n"

    # (section key, display name) pairs — order matches the report layout.
    steps = [
        ("company_profile", "3. 公司简介 (Company Profile)"),
        ("fundamental_analysis", "4. 基本面分析 (Fundamental Analysis)"),
        ("insider_analysis", "5. 内部人与机构动向 (Insider Analysis)"),
        ("bullish_analysis", "6. 看涨分析 (Bullish Analysis)"),
        ("bearish_analysis", "7. 看跌分析 (Bearish Analysis)")
    ]

    # Get AI model from settings, with a default when unset.
    model_setting = await db.get(Setting, "AI_MODEL")
    model_name = model_setting.value if model_setting else "gemini-2.0-flash"

    async def process_section(key: str, name: str):
        """Generate one section; return (key, result dict) or None when its prompt is missing."""
        print(f"Processing {name}...")
        prompt_template = prompts.get(key)
        if not prompt_template:
            return None

        # NOTE(review): str.format will raise KeyError if a template contains
        # stray braces beyond {company_name}/{ts_code} — confirm templates.
        formatted_prompt = prompt_template.format(company_name=company_name, ts_code=symbol)

        system_content = formatted_prompt
        user_content = "请根据上述角色设定和要求,结合提供的财务数据,撰写本章节的分析报告。"

        current_data_context = data_context
        if key == "bearish_analysis" and csv_context:
            # Only the bearish analysis receives the raw balance-sheet CSV.
            current_data_context += csv_context

        result = await call_llm(api_key, model_name, system_content, user_content, current_data_context, enable_search=True)

        return (key, result)

    # Run all sections concurrently
    print(f"Starting concurrent analysis with {len(steps)} sections...")
    results = await asyncio.gather(*[process_section(key, name) for key, name in steps])

    # Save all results to DB — one ReportSection row per generated section.
    for result in results:
        if result is None:
            continue
        key, result_data = result
        section = ReportSection(
            report_id=report_id,
            section_name=key,
            content=result_data["text"],
            prompt_tokens=result_data["prompt_tokens"],
            completion_tokens=result_data["completion_tokens"],
            total_tokens=result_data["prompt_tokens"] + result_data["completion_tokens"]
        )
        db.add(section)

    await db.commit()
    print(f"All {len(steps)} sections completed and saved!")

    return True