import sys import os import subprocess import asyncio import json from sqlalchemy.ext.asyncio import AsyncSession from app.models import Report, ReportSection, AnalysisStatus from app.services import llm_engine from datetime import datetime import google.genai as genai from google.genai import types async def search_stock(query: str, api_key: str, model: str = "gemini-2.0-flash"): if not api_key: return {"error": "API Key not provided"} client = genai.Client(api_key=api_key) prompt = f""" 你是一个专业的股票代码查询助手。请识别公司 '{query}' 的股票市场和代码。 **重要提示**: 1. 用户输入可能是公司全称、简称、别名或股票代码 2. 请仔细匹配,优先完全匹配,避免返回不相关的公司 3. 中国公司的简称经常省略"股份有限公司"等后缀 4. 例如:"茅台" = "贵州茅台酒股份有限公司" (600519.SH) 请返回一个 JSON 数组,包含所有匹配的公司。每个对象包含以下字段: - 'market': 'CN' (中国), 'US' (美国), 'HK' (香港), 或 'JP' (日本) 之一 - 'symbol': 完整的股票代码 (例如 'AAPL', '600519.SH', '00700.HK', '688778.SH', '2503.T') - 'company_name': 公司的中文简称(如果有的话,优先使用中文;如果只有英文名,则使用英文名) **匹配规则**: - 如果查询词与某公司简称、全称或股票代码完全匹配,返回该公司 - 如果有多个可能的匹配,返回所有相关公司 - 如果公司在中国但用户没有指定市场,默认为 CN(上海/深圳/北京交易所) - 如果完全没找到匹配,返回 {{ "error": "未找到相关公司" }} 示例响应(单个结果): [ {{ "market": "CN", "symbol": "600519.SH", "company_name": "贵州茅台" }} ] 示例响应(多个结果): [ {{ "market": "HK", "symbol": "00700.HK", "company_name": "腾讯控股" }}, {{ "market": "US", "symbol": "TCEHY", "company_name": "Tencent Holdings ADR" }} ] 现在请处理查询: '{query}' """ try: # Enable Google Search for more accurate results grounding_tool = types.Tool(google_search=types.GoogleSearch()) # Note: Cannot use response_mime_type with tools config = types.GenerateContentConfig(tools=[grounding_tool]) response = client.models.generate_content( model=model, contents=prompt, config=config ) response_text = response.text.strip() print(f"Search API raw response: {response_text[:500]}") # Extract JSON from response (may be wrapped in markdown code blocks) if "```json" in response_text: # Extract JSON from code block start = response_text.find("```json") + 7 end = response_text.find("```", start) json_str = response_text[start:end].strip() elif "```" in response_text: # Extract from generic code block start = response_text.find("```") + 3 end = response_text.find("```", start) json_str = response_text[start:end].strip() else: json_str = response_text result = json.loads(json_str) # Ensure result is always an array for consistent handling if not isinstance(result, list): if isinstance(result, dict) and "error" in result: return result # Return error as-is result = [result] # Wrap single object in array return result except json.JSONDecodeError as e: print(f"JSON decode error: {e}, Response text: {response_text}") return {"error": f"无法解析搜索结果: {str(e)}"} except Exception as e: print(f"Search error: {e}") return {"error": f"搜索失败: {str(e)}"} async def run_analysis_task(report_id: int, market: str, symbol: str, api_key: str): """ Background task to run the full analysis pipeline. Creates its own DB session. """ print(f"Starting analysis for report {report_id}: {market} {symbol}") # Create new session from app.database import AsyncSessionLocal async with AsyncSessionLocal() as session: try: report = await session.get(Report, report_id) if not report: print(f"Report {report_id} not found in background task") return report.status = AnalysisStatus.IN_PROGRESS await session.commit() company_name_for_prompt = report.company_name # 2. Run Main Data Fetching Script (main.py) root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")) cmd = [sys.executable, "main.py", market, symbol] print(f"Executing data fetch command: {cmd} in {root_dir}") process = await asyncio.create_subprocess_exec( *cmd, cwd=root_dir, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) stdout, stderr = await process.communicate() if process.returncode != 0: error_msg = stderr.decode() print(f"Data fetch failed: {error_msg}") report.status = AnalysisStatus.FAILED await session.commit() return print("Data fetch successful.") # 3. Perform Analysis Logic await llm_engine.process_analysis_steps( report_id=report_id, company_name=company_name_for_prompt, symbol=symbol, market=market, db=session, api_key=api_key ) # 4. Finalize report.status = AnalysisStatus.COMPLETED await session.commit() print(f"Analysis for report {report_id} completed.") except Exception as e: print(f"Analysis task exception: {e}") try: report = await session.get(Report, report_id) if report: report.status = AnalysisStatus.FAILED await session.commit() except: pass