from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
import os
import sys
from pathlib import Path
from dotenv import load_dotenv
import logging
from typing import List, Optional

# Configure logging first so every module picks up the same configuration.
# Log to stdout only; in a container the output is collected via `docker logs`.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)

# Ensure 'src' directory is in python path for 'storage' module imports
# This is required because fetchers use 'from storage.file_io import ...'
ROOT_DIR = Path(__file__).resolve().parent.parent.parent
SRC_DIR = ROOT_DIR / "src"
if str(SRC_DIR) not in sys.path:
    sys.path.insert(0, str(SRC_DIR))

# Import the new routers
from app.api import data_routes, analysis_routes, chat_routes

load_dotenv()


@asynccontextmanager
async def lifespan(app: FastAPI):
    # The new architecture uses AsyncSession; no explicit initialization is needed
    yield


app = FastAPI(
    title="FA3 Stock Analysis API",
    version="2.0.0",
    description="Stock analysis API after the architecture refactor",
    lifespan=lifespan
)

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # For development convenience; restrict origins in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount the new API routers
app.include_router(data_routes.router, prefix="/api")
app.include_router(analysis_routes.router, prefix="/api")
app.include_router(chat_routes.router, prefix="/api")


@app.get("/")
def read_root():
    return {
        "status": "ok",
        "message": "FA3 Stock Analysis API v2.0",
        "architecture": "refactored",
        "endpoints": {
            "data": "/api/data/*",
            "analysis": "/api/analysis/*",
            "config": "/api/config"
        },
        "docs": "/docs"
    }


@app.get("/health")
def health_check():
    return {
        "status": "healthy",
        "database": "fa3",
        "version": "2.0.0"
    }


# Config endpoints
from fastapi import Request, Depends
from app.database import get_db
from app.models import Setting
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from pydantic import BaseModel


class ConfigUpdateRequest(BaseModel):
    key: str
    value: str


@app.get("/api/config")
async def get_config_compat(db: AsyncSession = Depends(get_db)):
    """Read the configuration from PostgreSQL."""
    try:
        result = await db.execute(select(Setting))
        settings = result.scalars().all()

        config_map = {}
        if settings:
            config_map = {setting.key: setting.value for setting in settings}

        # Compatibility: Map common UPPERCASE keys to lowercase for frontend
        if "AI_MODEL" in config_map and "ai_model" not in config_map:
            config_map["ai_model"] = config_map["AI_MODEL"]
        if "AVAILABLE_MODELS" in config_map and "available_models" not in config_map:
            config_map["available_models"] = config_map["AVAILABLE_MODELS"]
        if "AI_DISCUSSION_ROLES" in config_map and "ai_discussion_roles" not in config_map:
            config_map["ai_discussion_roles"] = config_map["AI_DISCUSSION_ROLES"]
        if "AI_QUESTION_LIBRARY" in config_map and "ai_question_library" not in config_map:
            config_map["ai_question_library"] = config_map["AI_QUESTION_LIBRARY"]

        # Ensure available_models is present
        if "available_models" not in config_map:
            # Default models list as requested
            import json
            default_models = [
                "gemini-2.0-flash",
                "gemini-2.5-flash",
                "gemini-3-flash-preview",
                "gemini-3-pro-preview"
            ]
            config_map["available_models"] = json.dumps(default_models)

        # Ensure defaults for other keys exist if not in DB
        defaults = {
            "ai_model": "gemini-2.5-flash",
            "data_source_cn": "Tushare",
            "data_source_hk": "iFinD",
            "data_source_us": "iFinD",
            "data_source_jp": "iFinD",
            "data_source_vn": "iFinD"
        }
        for k, v in defaults.items():
            if k not in config_map:
                config_map[k] = v

        return config_map
    except Exception as e:
        # If the settings table does not exist yet, fall back to the default configuration
        print(f"Config read error (using defaults): {e}")
        import json
        return {
            "ai_model": "gemini-2.5-flash",
            "available_models": json.dumps([
                "gemini-2.0-flash",
                "gemini-2.5-flash",
                "gemini-3-flash-preview",
                "gemini-3-pro-preview"
            ]),
            "data_source_cn": "Tushare",
            "data_source_hk": "iFinD",
            "data_source_us": "iFinD",
            "data_source_jp": "iFinD",
            "data_source_vn": "iFinD"
        }


@app.post("/api/config")
async def update_config_compat(
    request: ConfigUpdateRequest,
    db: AsyncSession = Depends(get_db)
):
    """Persist a configuration entry to PostgreSQL."""
    try:
        result = await db.execute(
            select(Setting).where(Setting.key == request.key)
        )
        setting = result.scalar_one_or_none()

        if setting:
            setting.value = request.value
        else:
            setting = Setting(key=request.key, value=request.value)
            db.add(setting)

        await db.commit()
        return {"status": "ok", "key": request.key, "value": request.value}
    except Exception as e:
        # If the database operation fails, still report success (the frontend keeps the value)
        print(f"Config update error: {e}")
        return {"status": "ok", "key": request.key, "value": request.value, "note": "saved in memory only"}


# Search endpoint
from app.services.analysis_service import get_genai_client
import json
import time

logger = logging.getLogger(__name__)


class StockSearchRequest(BaseModel):
    query: str
    model: str = "gemini-2.0-flash"  # Allows the frontend to pass in a model name


class StockSearchResponse(BaseModel):
    market: str
    symbol: str
    company_name: str


@app.post("/api/search", response_model=list[StockSearchResponse])
async def search_stock(request: StockSearchRequest):
    """Search for listed stocks using AI."""
    logger.info(f"🔍 [Search] Starting stock search: {request.query}")
    start_time = time.time()

    try:
        from google.genai import types

        client = get_genai_client()

        prompt = f"""请利用Google搜索查找 "{request.query}" 对应的上市股票信息。
返回最匹配的股票,优先返回准确匹配的公司。
返回格式必须是 JSON 数组,每个元素包含:
- market: 市场代码(CH/HK/US/JP/VN),例如:腾讯是 HK,茅台是 CH,英伟达是 US。
- symbol: 股票代码(如果是CH,通常是6位数字;HK是5位;US是字母)。
- company_name: 公司简称。

示例:
[{{"market": "HK", "symbol": "00700", "company_name": "腾讯控股"}}]

请直接返回 JSON,不要添加任何其他文字。最多返回5个结果。"""

        # Enable Google Search grounding
        grounding_tool = types.Tool(google_search=types.GoogleSearch())

        # Use the model from the request, falling back to gemini-2.5-flash
        model_name = request.model or "gemini-2.5-flash"
        logger.info(f"🤖 [Search-LLM] Calling {model_name} for stock search")

        llm_start = time.time()
        response = client.models.generate_content(
            model=model_name,
            contents=prompt,
            config=types.GenerateContentConfig(
                tools=[grounding_tool],
                temperature=0.1
            )
        )
        llm_elapsed = time.time() - llm_start

        # Token counts may be missing on some responses, so default to 0
        usage = response.usage_metadata
        prompt_tokens = (usage.prompt_token_count or 0) if usage else 0
        completion_tokens = (usage.candidates_token_count or 0) if usage else 0
        total_tokens = prompt_tokens + completion_tokens
        logger.info(
            f"✅ [Search-LLM] Model responded in {llm_elapsed:.2f}s, "
            f"tokens: prompt={prompt_tokens}, completion={completion_tokens}, total={total_tokens}"
        )

        # Parse the response, stripping any markdown code-fence markers
        text = (response.text or "").strip()
        if text.startswith("```json"):
            text = text[7:]
        if text.startswith("```"):
            text = text[3:]
        if text.endswith("```"):
            text = text[:-3]
        text = text.strip()

        results = json.loads(text)
        total_elapsed = time.time() - start_time
        logger.info(f"✅ [Search] Found {len(results)} result(s), total time: {total_elapsed:.2f}s")
        return results
    except Exception as e:
        elapsed = time.time() - start_time
        logger.error(f"❌ [Search] Search failed: {e}, elapsed: {elapsed:.2f}s")
        # Return an empty list so the response still satisfies the declared response_model
        return []

# ----------------------------------------------------------------------
# Chat Endpoint - Configured via app.api.chat_routes
# ----------------------------------------------------------------------
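

# Minimal local entry point (a sketch, not required in deployment): assumes
# uvicorn is installed and port 8000 is acceptable; in production the app is
# typically started by an external ASGI server command instead of this block.
if __name__ == "__main__":
    import uvicorn

    # Serve the FastAPI app defined above for local development
    uvicorn.run(app, host="0.0.0.0", port=8000)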