- 添加 WeasyPrint 依赖以支持 PDF 导出功能 - 新增 docker-entrypoint.sh 统一管理容器启动流程 - 添加容器健康检查机制(/health 端点) - 配置容器自动重启策略(unless-stopped) - 优化日志输出,仅使用 stdout 适配容器环境 - 改进 update-and-run.sh 添加健康状态检查 - 统一脚本中的 sudo 使用规范
280 lines
9.2 KiB
Python
280 lines
9.2 KiB
Python
from fastapi import FastAPI
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from contextlib import asynccontextmanager
|
||
import os
|
||
import sys
|
||
from pathlib import Path
|
||
from dotenv import load_dotenv
|
||
import logging
|
||
from typing import List, Optional
|
||
|
||
# Configure logging first, before any project module is imported, so every
# logger created afterwards propagates to this root setup.
# Output goes to stdout only; in a container `docker logs` collects it.
_stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[_stdout_handler],
    level=logging.INFO,
)
|
||
|
||
# Put the repository's 'src' directory on the import path. Fetcher modules use
# `from storage.file_io import ...`, which only resolves when 'src' is on
# sys.path. It is inserted at the front so it wins over same-named packages.
_THIS_FILE = Path(__file__).resolve()
ROOT_DIR = _THIS_FILE.parent.parent.parent
SRC_DIR = ROOT_DIR.joinpath("src")

_src_dir_str = str(SRC_DIR)
if _src_dir_str not in sys.path:
    sys.path.insert(0, _src_dir_str)
|
||
|
||
# Load variables from a .env file BEFORE importing the routers. The router
# modules (or anything they import transitively) may read environment
# variables at import time, and those values must already be present then.
# load_dotenv() does not override variables already set in the real
# environment, so loading earlier cannot clobber container-provided config.
load_dotenv()

# Import the API routers.
from app.api import data_routes, analysis_routes, chat_routes
|
||
|
||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook for the FastAPI application.

    Intentionally empty: the refactored architecture uses AsyncSession for
    database access, so no explicit initialisation or teardown is needed.
    """
    # Nothing to prepare before serving requests ...
    yield
    # ... and nothing to release after shutdown.
|
||
|
||
# The ASGI application object; `lifespan` is the startup/shutdown hook
# defined above.
app = FastAPI(
    lifespan=lifespan,
    title="FA3 Stock Analysis API",
    description="架构重构后的股票分析 API",
    version="2.0.0",
)
|
||
|
||
# Configure CORS.
# Allowed origins can be restricted with the CORS_ALLOW_ORIGINS environment
# variable: a comma-separated list, e.g.
# "https://app.example.com,http://localhost:3000".
# When unset (or empty), every origin is allowed. This is convenient for
# development, but together with allow_credentials=True it effectively lets
# any site make credentialed requests, so set the variable in production.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||
|
||
# Mount the API routers, in this order, all under the /api prefix.
for _api_module in (data_routes, analysis_routes, chat_routes):
    app.include_router(_api_module.router, prefix="/api")
del _api_module
|
||
|
||
@app.get("/")
def read_root():
    """Describe the service: status, version banner and main route prefixes."""
    endpoint_map = {
        "data": "/api/data/*",
        "analysis": "/api/analysis/*",
        "config": "/api/config",
    }
    payload = {
        "status": "ok",
        "message": "FA3 Stock Analysis API v2.0",
        "architecture": "refactored",
    }
    payload["endpoints"] = endpoint_map
    payload["docs"] = "/docs"
    return payload
|
||
|
||
@app.get("/health")
def health_check():
    """Health endpoint polled by the container health check.

    Returns a static payload and does not contact the database, so it only
    shows that the process is up and answering HTTP requests.
    """
    return dict(status="healthy", database="fa3", version="2.0.0")
|
||
|
||
# 配置端点
|
||
from fastapi import Request, Depends
|
||
from app.database import get_db
|
||
from app.models import Setting
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy import select
|
||
from pydantic import BaseModel
|
||
|
||
class ConfigUpdateRequest(BaseModel):
    # Request body for POST /api/config: one key/value pair to upsert into
    # the settings table.
    key: str  # setting name
    value: str  # new value, stored as a string
|
||
|
||
# Model list offered when the settings table has no "available_models" entry.
_DEFAULT_AVAILABLE_MODELS = (
    "gemini-2.0-flash",
    "gemini-2.5-flash",
    "gemini-3-flash-preview",
    "gemini-3-pro-preview",
)

# Fallback values for keys missing from the settings table.
_DEFAULT_CONFIG = {
    "ai_model": "gemini-2.5-flash",
    "data_source_cn": "Tushare",
    "data_source_hk": "iFinD",
    "data_source_us": "iFinD",
    "data_source_jp": "iFinD",
    "data_source_vn": "iFinD",
}

# UPPERCASE keys that may be stored in the DB but that the frontend reads in
# lowercase form.
_UPPERCASE_CONFIG_ALIASES = (
    "AI_MODEL",
    "AVAILABLE_MODELS",
    "AI_DISCUSSION_ROLES",
    "AI_QUESTION_LIBRARY",
)


def _with_config_defaults(config_map):
    """Return a copy of ``config_map`` with lowercase aliases and defaults filled in.

    Existing keys are never overwritten: an UPPERCASE key is only mirrored when
    its lowercase twin is absent, and defaults only fill keys that are missing.
    ``available_models`` defaults to a JSON-encoded string, matching how it is
    stored as a setting value.
    """
    import json

    merged = dict(config_map)
    for upper_key in _UPPERCASE_CONFIG_ALIASES:
        lower_key = upper_key.lower()
        if upper_key in merged and lower_key not in merged:
            merged[lower_key] = merged[upper_key]

    merged.setdefault("available_models", json.dumps(list(_DEFAULT_AVAILABLE_MODELS)))
    for key, value in _DEFAULT_CONFIG.items():
        merged.setdefault(key, value)
    return merged


@app.get("/api/config")
async def get_config_compat(db: AsyncSession = Depends(get_db)):
    """Return the application configuration read from the settings table (PostgreSQL).

    Rows are returned as a flat ``{key: value}`` mapping. Known UPPERCASE keys
    are mirrored to lowercase for the frontend, and missing default keys are
    filled in. If reading the table fails (for example because it does not
    exist yet), the error is logged and only the defaults are returned.
    """
    try:
        result = await db.execute(select(Setting))
        stored = {setting.key: setting.value for setting in result.scalars().all()}
        return _with_config_defaults(stored)
    except Exception as exc:
        # Best effort: the frontend must still get a usable configuration.
        logging.getLogger(__name__).warning(
            "Config read error (using defaults): %s", exc
        )
        return _with_config_defaults({})
|
||
|
||
@app.post("/api/config")
async def update_config_compat(
    request: ConfigUpdateRequest,
    db: AsyncSession = Depends(get_db)
):
    """Upsert one setting into the settings table (PostgreSQL).

    If the database operation fails, the change is NOT persisted. The endpoint
    still answers with status "ok" plus a "note", because the frontend keeps
    its own copy of the configuration. The error is logged, and the session is
    rolled back so it is not left in a failed-transaction state.
    """
    try:
        result = await db.execute(
            select(Setting).where(Setting.key == request.key)
        )
        setting = result.scalar_one_or_none()
        if setting is None:
            db.add(Setting(key=request.key, value=request.value))
        else:
            setting.value = request.value
        await db.commit()
    except Exception as exc:
        log = logging.getLogger(__name__)
        log.error("Config update error: %s", exc)
        try:
            await db.rollback()
        except Exception as rollback_exc:
            # The session may already be unusable (e.g. lost connection); the
            # original error is already logged, so do not let this mask it.
            log.debug("Rollback after config update error failed: %s", rollback_exc)
        return {"status": "ok", "key": request.key, "value": request.value, "note": "saved in memory only"}
    return {"status": "ok", "key": request.key, "value": request.value}
|
||
|
||
# 搜索端点
|
||
from app.services.analysis_service import get_genai_client
|
||
import json
|
||
import logging
|
||
import time
|
||
|
||
logger = logging.getLogger(name=__name__)  # module logger; formatted by the root handler configured at import
|
||
|
||
class StockSearchRequest(BaseModel):
    # Request body for POST /api/search.
    query: str  # free-text company name or ticker to look up
    # Gemini model to use; the frontend may pass its own choice. The handler
    # only falls back to "gemini-2.5-flash" when an empty value is sent.
    model: str = "gemini-2.0-flash"
|
||
|
||
class StockSearchResponse(BaseModel):
    # One search hit returned by POST /api/search.
    market: str  # market code as requested in the prompt (CH/HK/US/JP/VN)
    symbol: str  # stock code within that market
    company_name: str  # short company name
|
||
|
||
def _parse_search_results(raw_text):
    """Turn the model's raw reply into a list of match dicts.

    Strips an optional surrounding markdown code fence (``` or ```json). If the
    remaining text is not valid JSON, falls back to the outermost ``[...]`` span,
    in case the model wrapped the array in extra prose despite the prompt. A
    single JSON object is wrapped into a one-element list so it still fits the
    list response model.

    Raises:
        json.JSONDecodeError: no parseable JSON was found.
        ValueError: the JSON is neither an array nor an object.
    """
    text = raw_text.strip()
    # Remove a possible markdown code-fence wrapper.
    if text.startswith("```json"):
        text = text[7:]
    if text.startswith("```"):
        text = text[3:]
    if text.endswith("```"):
        text = text[:-3]
    text = text.strip()

    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        start, end = text.find("["), text.rfind("]")
        if start == -1 or end <= start:
            raise
        parsed = json.loads(text[start:end + 1])

    if isinstance(parsed, dict):
        parsed = [parsed]
    if not isinstance(parsed, list):
        raise ValueError(
            f"Expected a JSON array from the model, got {type(parsed).__name__}"
        )
    return parsed


@app.post("/api/search", response_model=list[StockSearchResponse])
async def search_stock(request: StockSearchRequest):
    """Search listed stocks matching ``request.query`` using Gemini with Google Search grounding.

    Returns the parsed list of matches (market, symbol, company_name).

    Raises:
        HTTPException(500): the model call or the parsing of its reply failed.
            Previously this path returned ``None``, which then failed response
            validation with an opaque error.
    """
    import asyncio

    from fastapi import HTTPException

    logger.info(f"🔍 [搜索] 开始搜索股票: {request.query}")
    start_time = time.time()

    try:
        from google.genai import types
        client = get_genai_client()

        prompt = f"""请利用Google搜索查找 "{request.query}" 对应的上市股票信息。

返回最匹配的股票,优先返回准确匹配的公司。

返回格式必须是 JSON 数组,每个元素包含:
- market: 市场代码(CH/HK/US/JP/VN),例如:腾讯是 HK,茅台是 CH,英伟达是 US。
- symbol: 股票代码(如果是CH,通常是6位数字;HK是5位;US是字母)。
- company_name: 公司简称。

示例:
[{{"market": "HK", "symbol": "00700", "company_name": "腾讯控股"}}]

请直接返回 JSON,不要添加任何其他文字。最多返回5个结果。"""

        # Enable Google Search grounding.
        grounding_tool = types.Tool(google_search=types.GoogleSearch())

        # Use the model from the request; an empty value falls back to gemini-2.5-flash.
        model_name = request.model or "gemini-2.5-flash"

        logger.info(f"🤖 [搜索-LLM] 调用 {model_name} 进行股票搜索")
        llm_start = time.time()

        # generate_content is synchronous. Run it in a worker thread so a slow
        # grounded search does not block the event loop, which would stall every
        # other request (including /health) for its whole duration.
        response = await asyncio.to_thread(
            client.models.generate_content,
            model=model_name,
            contents=prompt,
            config=types.GenerateContentConfig(
                tools=[grounding_tool],
                temperature=0.1,
            ),
        )

        llm_elapsed = time.time() - llm_start
        usage = response.usage_metadata
        # Individual token counts may be None; treat them as 0 so the addition
        # below cannot fail after a successful model call.
        prompt_tokens = (usage.prompt_token_count or 0) if usage else 0
        completion_tokens = (usage.candidates_token_count or 0) if usage else 0
        total_tokens = prompt_tokens + completion_tokens

        logger.info(f"✅ [搜索-LLM] 模型响应完成, 耗时: {llm_elapsed:.2f}秒, Tokens: prompt={prompt_tokens}, completion={completion_tokens}, total={total_tokens}")

        # response.text can be None when the reply has no text parts.
        results = _parse_search_results(response.text or "")

        total_elapsed = time.time() - start_time
        logger.info(f"✅ [搜索] 搜索完成, 找到 {len(results)} 个结果, 总耗时: {total_elapsed:.2f}秒")

        return results

    except Exception as e:
        elapsed = time.time() - start_time
        logger.exception(f"❌ [搜索] 搜索失败: {e}, 耗时: {elapsed:.2f}秒")
        raise HTTPException(status_code=500, detail=f"Stock search failed: {e}") from e
|
||
|
||
# ----------------------------------------------------------------------
|
||
# Chat Endpoint - Configured via app.api.chat_routes
|
||
# ----------------------------------------------------------------------
|
||
|