FA3-Datafetch/backend/app/main.py

291 lines
9.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
import os
import sys
from pathlib import Path
from dotenv import load_dotenv
import logging
from typing import List, Optional
# Configure logging before anything else so every module imported afterwards
# shares this setup. Output goes to stdout only; in containers the logs are
# collected via `docker logs`.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Make the 'src' directory importable. The fetchers use
# `from storage.file_io import ...`, which only resolves when 'src' is on
# sys.path. ROOT_DIR is two directories above the one holding this file.
ROOT_DIR = Path(__file__).resolve().parent.parent.parent
SRC_DIR = ROOT_DIR.joinpath("src")
if str(SRC_DIR) not in sys.path:
    # Prepend so 'src' packages take precedence over same-named ones.
    sys.path[:0] = [str(SRC_DIR)]
# Load .env BEFORE importing application modules, so any configuration they
# read from the environment at import time already sees the .env values.
# NOTE(review): presumably app.api / app.database read settings such as the
# DB URL at import time — confirm; loading early is harmless either way since
# load_dotenv() does not override variables that are already set.
load_dotenv()
# Import the new API routers.
from app.api import data_routes, analysis_routes, chat_routes
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook passed to ``FastAPI(lifespan=...)``.

    Both phases are intentionally empty: the refactored architecture uses an
    ``AsyncSession`` and needs no explicit initialization here.
    """
    # Code before the yield would run at startup, code after it at shutdown.
    yield None
app = FastAPI(
    title="FA3 Stock Analysis API",
    version="2.0.0",
    description="架构重构后的股票分析 API",
    lifespan=lifespan,
)

# Allowed CORS origins, read from the comma-separated CORS_ALLOW_ORIGINS
# environment variable (e.g. "https://app.example.com,http://localhost:3000").
# Defaults to "*", the previous hard-coded development-friendly value.
# With "*" plus allow_credentials=True, Starlette answers credentialed requests
# by echoing the caller's Origin, i.e. any site may make authenticated
# cross-origin calls — set an explicit list for production deployments.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount the API routers under the /api prefix.
app.include_router(data_routes.router, prefix="/api")
app.include_router(analysis_routes.router, prefix="/api")
app.include_router(chat_routes.router, prefix="/api")
@app.get("/")
def read_root():
    """Service banner: status, version message and the main route groups."""
    route_groups = dict(
        data="/api/data/*",
        analysis="/api/analysis/*",
        config="/api/config",
    )
    return dict(
        status="ok",
        message="FA3 Stock Analysis API v2.0",
        architecture="refactored",
        endpoints=route_groups,
        docs="/docs",
    )
@app.get("/health")
def health_check():
    """Static health response.

    Note: this does not query the database; the "database" field is a fixed
    label, not the result of a connectivity check.
    """
    body = {"status": "healthy"}
    body["database"] = "fa3"
    body["version"] = "2.0.0"
    return body
# 配置端点
from fastapi import Request, Depends
from app.database import get_db
from app.models import Setting
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from pydantic import BaseModel
# Body of POST /api/config: one setting to insert or overwrite.
class ConfigUpdateRequest(BaseModel):
    # Setting name; matched against and stored as ``Setting.key``.
    key: str
    # New value, stored verbatim as ``Setting.value``.
    value: str
# Model list returned when the settings table has no ``available_models``
# (or ``AVAILABLE_MODELS``) entry.
_DEFAULT_AVAILABLE_MODELS = (
    "gemini-3-flash-preview",
    "gemini-3-pro-preview",
    "gemini-2.0-flash",
    "gemini-2.5-flash",
)

# Values returned for keys that are missing from the settings table.
_DEFAULT_CONFIG = {
    "ai_model": "gemini-3-flash-preview",
    "data_source_cn": "Tushare",
    "data_source_hk": "iFinD",
    "data_source_us": "iFinD",
    "data_source_jp": "iFinD",
    "data_source_vn": "iFinD",
}

# Legacy UPPERCASE setting keys mirrored to the lowercase names used by the
# frontend (only when the lowercase key is not stored itself).
_CONFIG_KEY_ALIASES = {
    "AI_MODEL": "ai_model",
    "AVAILABLE_MODELS": "available_models",
    "AI_DISCUSSION_ROLES": "ai_discussion_roles",
    "AI_QUESTION_LIBRARY": "ai_question_library",
}


def _fill_config_defaults(config_map: dict) -> dict:
    """Add lowercase aliases and default values to *config_map* in place.

    Stored values always win: aliases and defaults are only written for keys
    that are still missing. Returns the same dict for convenience.
    """
    import json

    for upper_key, lower_key in _CONFIG_KEY_ALIASES.items():
        if upper_key in config_map and lower_key not in config_map:
            config_map[lower_key] = config_map[upper_key]
    # The model list is returned as a JSON-encoded string, not a JSON array.
    config_map.setdefault("available_models", json.dumps(list(_DEFAULT_AVAILABLE_MODELS)))
    for key, value in _DEFAULT_CONFIG.items():
        config_map.setdefault(key, value)
    return config_map


@app.get("/api/config")
async def get_config_compat(db: AsyncSession = Depends(get_db)):
    """Return all settings from PostgreSQL as a flat ``key -> value`` dict.

    UPPERCASE legacy keys are mirrored to lowercase names and missing keys are
    filled with built-in defaults. If reading the settings fails (e.g. the
    table does not exist yet) the defaults alone are returned instead of an
    error.
    """
    try:
        result = await db.execute(select(Setting))
        settings = result.scalars().all()
    except Exception as exc:
        # Deliberate best-effort: a missing/unreachable settings table must not
        # break the frontend, so serve defaults and log why.
        logging.getLogger(__name__).warning(
            "Config read error (using defaults): %s", exc
        )
        return _fill_config_defaults({})
    return _fill_config_defaults({setting.key: setting.value for setting in settings})
@app.post("/api/config")
async def update_config_compat(
    request: ConfigUpdateRequest,
    db: AsyncSession = Depends(get_db),
):
    """Insert or overwrite one setting (keyed by ``request.key``) in PostgreSQL.

    Best-effort: if the database write fails, the endpoint still answers
    ``status: ok`` plus a ``note`` (the value then lives only on the frontend),
    and the failure is logged with its traceback.
    """
    try:
        result = await db.execute(
            select(Setting).where(Setting.key == request.key)
        )
        setting = result.scalar_one_or_none()
        if setting is None:
            db.add(Setting(key=request.key, value=request.value))
        else:
            setting.value = request.value
        await db.commit()
    except Exception:
        log = logging.getLogger(__name__)
        log.exception("Config update error for key %r", request.key)
        # A failed execute/commit leaves the session's transaction unusable
        # until it is rolled back; reset it so later use/cleanup works.
        try:
            await db.rollback()
        except Exception:
            log.warning("Rollback after failed config update also failed", exc_info=True)
        return {"status": "ok", "key": request.key, "value": request.value, "note": "saved in memory only"}
    return {"status": "ok", "key": request.key, "value": request.value}
# 搜索端点
from app.services.analysis_service import get_genai_client
import json
import logging
import time
# Module-level logger used by the search endpoint.
logger = logging.getLogger(__name__)


# Request body for POST /api/search.
class StockSearchRequest(BaseModel):
    # Free-text company name or ticker to look up.
    query: str
    # Gemini model to use; the frontend may pass its own choice per request.
    model: str = "gemini-3-flash-preview"


# One stock match returned by POST /api/search.
class StockSearchResponse(BaseModel):
    # Market code such as CH/HK/US/JP/VN/IN/EU (as requested in the search prompt).
    market: str
    # Bloomberg-style ticker without exchange suffix (as requested in the search prompt).
    symbol: str
    # Short company name.
    company_name: str
# Fields every search hit must carry; mirrors StockSearchResponse.
_SEARCH_RESULT_FIELDS = ("market", "symbol", "company_name")


def _parse_search_results(text: Optional[str]) -> list[dict]:
    """Decode the stock matches from the model's raw reply text.

    The model is asked for a bare JSON array but may wrap it in a Markdown
    code fence or surround it with prose, so the fence is stripped and the
    outermost ``[...]`` span is decoded. A single JSON object is treated as a
    one-item list. Entries that are not objects or lack a required field are
    dropped; field values are converted to ``str`` so numeric tickers such as
    700 still validate against StockSearchResponse.

    Raises:
        ValueError: if no JSON can be decoded or it is neither an array nor an
            object (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    cleaned = (text or "").strip()
    # Remove a surrounding ```json ... ``` or ``` ... ``` fence.
    cleaned = cleaned.removeprefix("```json").removeprefix("```").removesuffix("```").strip()
    start, end = cleaned.find("["), cleaned.rfind("]")
    if start != -1 and end > start:
        cleaned = cleaned[start:end + 1]
    payload = json.loads(cleaned)
    if isinstance(payload, dict):
        payload = [payload]
    if not isinstance(payload, list):
        raise ValueError(f"expected a JSON array of results, got {type(payload).__name__}")
    results = []
    for item in payload:
        if not isinstance(item, dict) or any(item.get(f) is None for f in _SEARCH_RESULT_FIELDS):
            continue
        results.append({f: str(item[f]).strip() for f in _SEARCH_RESULT_FIELDS})
    return results


@app.post("/api/search", response_model=list[StockSearchResponse])
async def search_stock(request: StockSearchRequest):
    """Search listed stocks matching ``request.query`` using Gemini with Google Search grounding.

    Returns a list of ``{market, symbol, company_name}`` dicts.

    Raises:
        HTTPException: 500 when the model call fails or its reply cannot be parsed.
    """
    import asyncio

    from fastapi import HTTPException

    logger.info(f"🔍 [搜索] 开始搜索股票: {request.query}")
    start_time = time.perf_counter()
    try:
        from google.genai import types
        client = get_genai_client()
        prompt = f"""请利用Google搜索查找 "{request.query}" 对应的上市股票信息。
返回最匹配的股票,优先返回准确匹配的公司。
**关键要求Symbol 必须遵循 Bloomberg 终端专用代码格式 (Ticker Only)。**
请注意区分不同市场的代码后缀(用于确定market字段),以及某些市场的特殊代码规则(例如印度股票通常以 IN 结尾,代码可能是缩写也可能是全称,请仔细核对 Bloomberg 标准)。
返回格式必须是 JSON 数组,每个元素包含:
- market: 市场代码CH/HK/US/JP/VN/IN/EU等
- symbol: **Bloomberg标准代码** (仅Ticker不要包含后缀)。
- 香港: "700" (不要写00700)
- 中国: "600519"
- 美国: "NVDA"
- 日本: "7203"
- 印度: "UBBL" (例如 United Breweries不要只写 UBL)
- 欧洲: "NESN"
- company_name: 公司简称。
示例:
[{{"market": "HK", "symbol": "700", "company_name": "腾讯控股"}},
{{"market": "US", "symbol": "NVDA", "company_name": "英伟达"}},
{{"market": "IN", "symbol": "UBBL", "company_name": "United Breweries"}}]
请直接返回 JSON不要添加任何其他文字。最多返回5个结果。"""
        # Enable Google Search grounding.
        grounding_tool = types.Tool(google_search=types.GoogleSearch())
        # Use the model requested by the frontend, defaulting to gemini-3-flash-preview.
        model_name = request.model or "gemini-3-flash-preview"
        logger.info(f"🤖 [搜索-LLM] 调用 {model_name} 进行股票搜索")
        llm_start = time.perf_counter()
        # generate_content is a blocking network call; running it in a worker
        # thread keeps the event loop free to serve other requests meanwhile.
        response = await asyncio.to_thread(
            client.models.generate_content,
            model=model_name,
            contents=prompt,
            config=types.GenerateContentConfig(
                tools=[grounding_tool],
                temperature=0.1,
            ),
        )
        llm_elapsed = time.perf_counter() - llm_start
        usage = response.usage_metadata
        # The SDK declares the token counts as optional; treat None as 0 so a
        # missing count cannot fail an otherwise successful search.
        prompt_tokens = (usage.prompt_token_count or 0) if usage else 0
        completion_tokens = (usage.candidates_token_count or 0) if usage else 0
        total_tokens = prompt_tokens + completion_tokens
        logger.info(f"✅ [搜索-LLM] 模型响应完成, 耗时: {llm_elapsed:.2f}秒, Tokens: prompt={prompt_tokens}, completion={completion_tokens}, total={total_tokens}")
        # response.text may be None (e.g. no text parts); the parser handles it.
        results = _parse_search_results(response.text)
        total_elapsed = time.perf_counter() - start_time
        logger.info(f"✅ [搜索] 搜索完成, 找到 {len(results)} 个结果, 总耗时: {total_elapsed:.2f}")
        return results
    except Exception as e:
        elapsed = time.perf_counter() - start_time
        logger.exception(f"❌ [搜索] 搜索失败: {e}, 耗时: {elapsed:.2f}")
        # Previously this path returned None, which then failed response_model
        # validation with an opaque error; report the failure explicitly.
        raise HTTPException(status_code=500, detail="Stock search failed") from e
# ----------------------------------------------------------------------
# Chat Endpoint - Configured via app.api.chat_routes
# ----------------------------------------------------------------------