FA3-Datafetch/backend/app/main.py
2026-01-11 21:33:47 +08:00

236 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
import os
import sys
from pathlib import Path
from dotenv import load_dotenv
import logging
# Configure logging before anything else so every module imported afterwards
# inherits it: INFO and above is written both to stdout and to "server.log"
# (a path relative to the process working directory), encoded as UTF-8.
_log_handlers = [
    logging.StreamHandler(sys.stdout),
    logging.FileHandler("server.log", encoding='utf-8'),
]
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=_log_handlers,
)
# Ensure the repository's top-level 'src' directory is importable: fetchers use
# absolute imports such as 'from storage.file_io import ...'.
# ROOT_DIR is three levels above this file (app/main.py -> app -> backend -> root).
ROOT_DIR = Path(__file__).resolve().parent.parent.parent
SRC_DIR = ROOT_DIR / "src"
if str(SRC_DIR) not in sys.path:
    sys.path.insert(0, str(SRC_DIR))

# Load .env BEFORE importing application modules, so any environment variables
# they read at import time (rather than at call time) are already populated.
load_dotenv()

# Import the (new-architecture) API route modules.
from app.api import data_routes, analysis_routes
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application startup/shutdown hook.

    Nothing is initialised or torn down here: in the current architecture
    database access goes through AsyncSession, which needs no explicit setup.
    """
    # Code before the yield would run at startup, code after it at shutdown.
    yield
# The ASGI application. title/description/version are the metadata shown in
# the generated API docs (/docs); lifespan plugs in the hook defined above.
app = FastAPI(
    lifespan=lifespan,
    title="FA3 Stock Analysis API",
    description="架构重构后的股票分析 API",
    version="2.0.0",
)
# CORS is left wide open for development convenience: any origin, method and
# header, with credentials allowed.
# NOTE(review): "*" origins combined with allow_credentials=True is very
# permissive - restrict allow_origins before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Mount the new API routers under the shared "/api" prefix (data first, then
# analysis, matching the original registration order).
for _routes_module in (data_routes, analysis_routes):
    app.include_router(_routes_module.router, prefix="/api")
del _routes_module
@app.get("/")
def read_root():
    """Service banner: status, API version, architecture tag and an endpoint map."""
    endpoint_map = dict(
        data="/api/data/*",
        analysis="/api/analysis/*",
        config="/api/config",
    )
    return dict(
        status="ok",
        message="FA3 Stock Analysis API v2.0",
        architecture="refactored",
        endpoints=endpoint_map,
        docs="/docs",
    )
@app.get("/health")
def health_check():
    """Liveness probe returning a static payload.

    NOTE(review): the database is not contacted here; "database" is a fixed
    label, so this only proves the process is serving requests.
    """
    payload = {"status": "healthy"}
    payload["database"] = "fa3"
    payload["version"] = "2.0.0"
    return payload
# 配置端点
from fastapi import Request, Depends
from app.database import get_db
from app.models import Setting
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from pydantic import BaseModel
class ConfigUpdateRequest(BaseModel):
    """Request body of ``POST /api/config``: a single setting to store."""

    # Name of the setting to create or overwrite.
    key: str
    # New value for that setting (always a string).
    value: str
# Configuration served when nothing is stored (or the settings table cannot be
# read). Keys stored in the database override these one by one.
_DEFAULT_CONFIG = {
    "ai_model": "gemini-2.0-flash",
    "data_source_cn": "Tushare",
    "data_source_hk": "iFinD",
    "data_source_us": "iFinD",
    "data_source_jp": "iFinD",
    "data_source_vn": "iFinD",
}


@app.get("/api/config")
async def get_config_compat(db: AsyncSession = Depends(get_db)):
    """Return the application configuration read from PostgreSQL.

    Every key of ``_DEFAULT_CONFIG`` is always present in the result: values
    stored in the ``Setting`` table override the defaults, and extra stored
    keys are passed through unchanged. If the query fails (e.g. the table does
    not exist yet) the defaults alone are returned.

    Returns:
        dict mapping setting key to value.
    """
    try:
        result = await db.execute(select(Setting))
        stored = {setting.key: setting.value for setting in result.scalars().all()}
    except Exception as e:
        # Deliberate best effort: a missing or unreadable settings table must
        # not break the frontend, so fall back to the defaults.
        logger.warning(f"Config read error (using defaults): {e}")
        return dict(_DEFAULT_CONFIG)
    # Merge rather than replace, so keys that were never saved (e.g. only
    # "ai_model" has been updated so far) still get their default value.
    return {**_DEFAULT_CONFIG, **stored}
@app.post("/api/config")
async def update_config_compat(
    request: ConfigUpdateRequest,
    db: AsyncSession = Depends(get_db)
):
    """Create or update one configuration entry in PostgreSQL (upsert by key).

    Returns:
        ``{"status": "ok", "key": ..., "value": ...}`` on success. On any
        database error the same payload plus a ``"note"`` field is returned
        instead of an error response (best effort by design).
    """
    try:
        result = await db.execute(
            select(Setting).where(Setting.key == request.key)
        )
        setting = result.scalar_one_or_none()
        if setting is not None:
            setting.value = request.value
        else:
            db.add(Setting(key=request.key, value=request.value))
        await db.commit()
        return {"status": "ok", "key": request.key, "value": request.value}
    except Exception as e:
        # If the database operation fails, still report success: per the
        # original design the frontend keeps the configuration itself.
        # NOTE(review): despite the "note" below, the server does not keep the
        # value anywhere - confirm the wording with the frontend.
        logger.error(f"Config update error: {e}", exc_info=True)
        try:
            # A failed execute/commit leaves the session unusable until it is
            # rolled back; do so instead of handing a broken session back.
            await db.rollback()
        except Exception:
            logger.warning("Config update rollback failed", exc_info=True)
        return {"status": "ok", "key": request.key, "value": request.value, "note": "saved in memory only"}
# 搜索端点
from app.services.analysis_service import get_genai_client
import json
import logging
import time
# Module-level logger; its records propagate to the root handlers installed by
# logging.basicConfig at the top of this file (stdout and server.log).
logger = logging.getLogger(__name__)
class StockSearchRequest(BaseModel):
    """Request body of ``POST /api/search``."""

    # Free-text description of the stock to look up (e.g. a company name).
    query: str
    # Gemini model to use; the frontend may pass its own choice.
    model: str = "gemini-2.0-flash"
class StockSearchResponse(BaseModel):
    """One candidate stock in the ``POST /api/search`` response list."""

    # Market code the model is asked for: CH, HK, US, JP or VN.
    market: str
    # Ticker within that market, e.g. "00700" for an HK listing.
    symbol: str
    # Short company name.
    company_name: str
# Fields every search result must carry (mirrors StockSearchResponse).
_SEARCH_RESULT_FIELDS = ("market", "symbol", "company_name")


def _strip_code_fences(text: str) -> str:
    """Remove a leading ```json / ``` fence and a trailing ``` fence, then trim."""
    if text.startswith("```json"):
        text = text[7:]
    if text.startswith("```"):
        text = text[3:]
    if text.endswith("```"):
        text = text[:-3]
    return text.strip()


def _parse_search_results(raw_text: str) -> list[dict]:
    """Turn the model's reply into result dicts valid for StockSearchResponse.

    Accepts a JSON array (optionally fenced as Markdown or surrounded by
    prose) or a single JSON object. Entries that are not objects or lack a
    non-empty market/symbol/company_name are dropped; field values are
    converted to ``str``.

    Raises:
        json.JSONDecodeError: if no JSON can be parsed from the reply.
        ValueError: if the JSON is neither an array nor an object.
    """
    text = _strip_code_fences(raw_text.strip())
    if not text.startswith(("[", "{")):
        # The model wrapped the JSON in prose despite the instructions: keep
        # only the outermost [...] span, if there is one.
        start, end = text.find("["), text.rfind("]")
        if start != -1 and end > start:
            text = text[start:end + 1]
    data = json.loads(text)
    if isinstance(data, dict):
        # A single match returned as a bare object instead of a 1-element array.
        data = [data]
    if not isinstance(data, list):
        raise ValueError(f"unexpected JSON type from model: {type(data).__name__}")
    results = []
    for item in data:
        if not isinstance(item, dict):
            continue
        if any(item.get(field) in (None, "") for field in _SEARCH_RESULT_FIELDS):
            continue
        # str(): a numeric code such as 600519 would otherwise fail the str
        # fields of the response model, which is validated after we return.
        results.append({field: str(item[field]) for field in _SEARCH_RESULT_FIELDS})
    return results


@app.post("/api/search", response_model=list[StockSearchResponse])
async def search_stock(request: StockSearchRequest):
    """Look up listed stocks matching ``request.query`` via Gemini + Google Search.

    Returns:
        A list of {market, symbol, company_name} dicts (the prompt asks the
        model for at most 5). On any failure an empty list is returned instead
        of raising, so a failed lookup does not surface as an error in the
        frontend.
    """
    import asyncio

    logger.info(f"🔍 [搜索] 开始搜索股票: {request.query}")
    start_time = time.perf_counter()
    try:
        from google.genai import types
        client = get_genai_client()
        prompt = f"""请利用Google搜索查找 "{request.query}" 对应的上市股票信息。
返回最匹配的股票,优先返回准确匹配的公司。
返回格式必须是 JSON 数组,每个元素包含:
- market: 市场代码CH/HK/US/JP/VN例如腾讯是 HK茅台是 CH英伟达是 US。
- symbol: 股票代码如果是CH通常是6位数字HK是5位US是字母
- company_name: 公司简称。
示例:
[{{"market": "HK", "symbol": "00700", "company_name": "腾讯控股"}}]
请直接返回 JSON不要添加任何其他文字。最多返回5个结果。"""
        # Enable Google Search grounding.
        grounding_tool = types.Tool(google_search=types.GoogleSearch())
        # The request schema defaults to "gemini-2.0-flash"; this fallback only
        # applies when the client explicitly sends an empty model name.
        model_name = request.model or "gemini-2.5-flash"
        logger.info(f"🤖 [搜索-LLM] 调用 {model_name} 进行股票搜索")
        llm_start = time.perf_counter()
        # generate_content is a synchronous (blocking) call; run it in a worker
        # thread so the event loop keeps serving other requests meanwhile.
        response = await asyncio.to_thread(
            client.models.generate_content,
            model=model_name,
            contents=prompt,
            config=types.GenerateContentConfig(
                tools=[grounding_tool],
                temperature=0.1
            ),
        )
        llm_elapsed = time.perf_counter() - llm_start
        usage = response.usage_metadata
        # Individual token counts can be None; count them as 0 so that logging
        # can never turn a successful response into an exception.
        prompt_tokens = (usage.prompt_token_count or 0) if usage else 0
        completion_tokens = (usage.candidates_token_count or 0) if usage else 0
        total_tokens = prompt_tokens + completion_tokens
        logger.info(f"✅ [搜索-LLM] 模型响应完成, 耗时: {llm_elapsed:.2f}秒, Tokens: prompt={prompt_tokens}, completion={completion_tokens}, total={total_tokens}")
        # response.text may be None (e.g. no text parts); parsing "" then fails
        # with JSONDecodeError and is handled by the except below.
        results = _parse_search_results(response.text or "")
        total_elapsed = time.perf_counter() - start_time
        logger.info(f"✅ [搜索] 搜索完成, 找到 {len(results)} 个结果, 总耗时: {total_elapsed:.2f}")
        return results
    except Exception as e:
        elapsed = time.perf_counter() - start_time
        logger.error(f"❌ [搜索] 搜索失败: {e}, 耗时: {elapsed:.2f}", exc_info=True)
        # Return an empty result rather than an error so the frontend does not break.
        return []