import asyncio
import os
import time

import markdown
import google.genai as genai
from google.genai import types
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.models import Report, ReportSection, Setting


async def load_prompts(db: AsyncSession, prompt_dir: str) -> dict[str, str]:
    """Load the prompt template for every analysis section.

    For each section the ``PROMPT_<KEY>`` row of the ``Setting`` table is
    preferred. If that row is absent, empty, or cannot be read, the matching
    Markdown file under *prompt_dir* is used instead.

    Args:
        db: Async SQLAlchemy session used to look up ``Setting`` rows.
        prompt_dir: Directory that holds the fallback prompt files.

    Returns:
        Mapping of section key to prompt text. When neither source provides a
        prompt, the value is a placeholder string beginning with ``"Error:"``.
    """
    prompts = {}
    mapping = {
        "company_profile": "公司简介.md",
        "fundamental_analysis": "基本面分析.md",
        "insider_analysis": "内部人与机构动向分析.md",
        "bullish_analysis": "看涨分析.md",
        "bearish_analysis": "看跌分析.md",
    }
    for key, filename in mapping.items():
        # Try the DB override first.
        setting_key = f"PROMPT_{key.upper()}"
        try:
            setting = await db.get(Setting, setting_key)
        except Exception as e:
            print(f"Error reading prompt setting {setting_key}: {e}")
            setting = None
        # An empty override would yield a blank prompt and silently skip the
        # section later on, so treat it like a missing row and use the file.
        if setting and setting.value:
            prompts[key] = setting.value
            continue

        # Fall back to the prompt file on disk.
        try:
            with open(os.path.join(prompt_dir, filename), 'r', encoding='utf-8') as f:
                prompts[key] = f.read()
        except FileNotFoundError:
            print(f"Warning: Prompt file {filename} not found.")
            # NOTE(review): this placeholder is truthy, so the caller will still
            # send it to the model as a prompt. Consider skipping such sections.
            prompts[key] = f"Error: Prompt {filename} not found."
    return prompts


async def call_llm(api_key: str, model_name: str, system_prompt: str, user_prompt: str, context: str, enable_search: bool = True) -> str:
    """Send one prompt to a Gemini model and return the generated text.

    The system prompt, user prompt and report context are concatenated into a
    single text prompt. Failures are never raised. Instead an inline error
    string is returned, so one failed section does not abort the others.

    Args:
        api_key: Google GenAI API key.
        model_name: Model identifier passed to ``generate_content``.
        system_prompt: Role/instruction text placed first in the prompt.
        user_prompt: Task instruction placed after the system prompt.
        context: Existing report data appended at the end of the prompt.
        enable_search: When true, attach the Google Search grounding tool.

    Returns:
        The model's text, or a string starting with
        ``"\\n\\nError generating section:"`` on failure or empty output.
    """
    full_prompt = f"{system_prompt}\n\n{user_prompt}\n\nExisting Report Data for context:\n{context}"

    config_params = {}
    if enable_search:
        config_params['tools'] = [types.Tool(google_search=types.GoogleSearch())]
    config = types.GenerateContentConfig(**config_params)

    try:
        # Client construction is inside the try so that a bad or missing key
        # is reported like any other API failure instead of escaping.
        client = genai.Client(api_key=api_key)
        # generate_content is blocking; run it in a worker thread so the
        # concurrently scheduled sections do not serialize on the event loop.
        response = await asyncio.to_thread(
            client.models.generate_content,
            model=model_name,
            contents=full_prompt,
            config=config,
        )
        text = response.text
    except Exception as e:
        print(f"API Call Failed: {e}")
        return f"\n\nError generating section: {e}\n\n"

    # response.text may be None when the model returns no text part. Never
    # hand None back to be stored as section content.
    if not text:
        print("API Call Failed: empty response")
        return "\n\nError generating section: empty response from model\n\n"
    return text


def _resolve_symbol_dir(base_dir: str, symbol: str) -> str:
    """Return the data directory for *symbol* under *base_dir*.

    Prefers an exact ``base_dir/symbol`` directory. Otherwise it uses the
    alphabetically first sub-directory whose name starts with *symbol*. If
    *base_dir* is missing or nothing matches, the exact path is returned
    anyway, so callers see a missing-file warning rather than a crash.
    """
    symbol_dir = os.path.join(base_dir, symbol)
    if os.path.exists(symbol_dir) or not os.path.isdir(base_dir):
        return symbol_dir
    # Sorted so the choice is deterministic; os.listdir order is arbitrary.
    candidates = sorted(
        d for d in os.listdir(base_dir)
        if d.startswith(symbol) and os.path.isdir(os.path.join(base_dir, d))
    )
    return os.path.join(base_dir, candidates[0]) if candidates else symbol_dir


def _format_prompt(template: str, company_name: str, symbol: str) -> str:
    """Fill the ``{company_name}`` and ``{ts_code}`` placeholders in *template*.

    ``str.format`` rejects any other brace expression, such as a JSON example
    inside a Markdown prompt. In that case the two placeholders are
    substituted literally, so one malformed prompt cannot abort every section.
    """
    try:
        return template.format(company_name=company_name, ts_code=symbol)
    except (KeyError, IndexError, ValueError, AttributeError):
        return (
            template
            .replace("{company_name}", str(company_name))
            .replace("{ts_code}", str(symbol))
        )


async def process_analysis_steps(report_id: int, company_name: str, symbol: str, market: str, db: AsyncSession, api_key: str):
    """Generate the LLM-written report sections concurrently and persist them.

    Reads ``data/<market>/<symbol>/report.md`` (written by main.py) as shared
    context. The bearish section additionally receives
    ``raw_balance_sheet_raw.csv`` when that file exists. One ``ReportSection``
    row is added per generated section, and all rows are committed together.

    Args:
        report_id: Id of the ``Report`` the sections belong to.
        company_name: Substituted for ``{company_name}`` in prompts.
        symbol: Ticker; substituted for ``{ts_code}`` and used to locate data.
        market: Market sub-directory under ``data/``.
        db: Async session used for settings lookups and saving sections.
        api_key: Google GenAI API key.

    Returns:
        True once the sections have been committed.
    """
    # 1. Load prompts (DB overrides first, then files under <root>/Prompt).
    root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))
    prompt_dir = os.path.join(root_dir, "Prompt")
    prompts = await load_prompts(db, prompt_dir)

    # 2. Read the data context (report.md generated by main.py).
    symbol_dir = _resolve_symbol_dir(os.path.join(root_dir, "data", market), symbol)
    data_path = os.path.join(symbol_dir, "report.md")
    if os.path.exists(data_path):
        with open(data_path, 'r', encoding='utf-8') as f:
            data_context = f.read()
    else:
        # main.py may have failed or its output layout changed; proceed
        # without financial data rather than failing the whole report.
        print(f"Warning: {data_path} not found.")
        data_context = "No financial data available."

    # Optional raw balance sheet, only fed to the bearish section.
    csv_path = os.path.join(symbol_dir, "raw_balance_sheet_raw.csv")
    csv_context = ""
    if os.path.exists(csv_path):
        with open(csv_path, 'r', encoding='utf-8') as f:
            csv_content = f.read()
        csv_context = f"\n\nRaw Balance Sheet Data (CSV):\n{csv_content}\n"

    steps = [
        ("company_profile", "3. 公司简介 (Company Profile)"),
        ("fundamental_analysis", "4. 基本面分析 (Fundamental Analysis)"),
        ("insider_analysis", "5. 内部人与机构动向 (Insider Analysis)"),
        ("bullish_analysis", "6. 看涨分析 (Bullish Analysis)"),
        ("bearish_analysis", "7. 看跌分析 (Bearish Analysis)"),
    ]

    # AI model from settings; an empty value falls back to the default too.
    model_setting = await db.get(Setting, "AI_MODEL")
    model_name = model_setting.value if model_setting and model_setting.value else "gemini-2.0-flash"

    async def process_section(key: str, name: str):
        # Deliberately does not touch `db`: a single AsyncSession must not be
        # used by concurrently running tasks.
        print(f"Processing {name}...")
        prompt_template = prompts.get(key)
        if not prompt_template:
            return None

        system_content = _format_prompt(prompt_template, company_name, symbol)
        user_content = "请根据上述角色设定和要求,结合提供的财务数据,撰写本章节的分析报告。"
        section_context = data_context + csv_context if key == "bearish_analysis" else data_context

        content = await call_llm(api_key, model_name, system_content, user_content, section_context, enable_search=True)
        return key, content

    # Run all sections concurrently.
    print(f"Starting concurrent analysis with {len(steps)} sections...")
    results = await asyncio.gather(*(process_section(key, name) for key, name in steps))

    # Save the results sequentially on the shared session, then commit once.
    saved = 0
    for result in results:
        if result is None:
            continue
        key, content = result
        db.add(ReportSection(report_id=report_id, section_name=key, content=content))
        saved += 1

    await db.commit()
    print(f"{saved}/{len(steps)} sections completed and saved!")
    return True