FA3-Datafetch/backend/app/api/analysis_routes.py
2026-01-11 21:33:47 +08:00

183 lines
5.3 KiB
Python

"""
AI 分析 API 路由
处理 LLM 分析的启动、状态查询和结果获取
"""
from typing import List, Optional

from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from sqlalchemy import and_, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.database import get_db
from app.models import AIAnalysis, Company
from app.schemas import (
    AnalysisStartRequest,
    AnalysisStartResponse,
    AnalysisStatusResponse,
    AnalysisResultResponse
)
router = APIRouter(prefix="/analysis", tags=["analysis"])
@router.post("/start", response_model=AnalysisStartResponse)
async def start_analysis(
    request: AnalysisStartRequest,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db)
):
    """
    Start an AI analysis for a company/data-source pair.

    Persists a 'pending' AIAnalysis record, hands the long-running LLM
    call off to a FastAPI background task, and immediately returns the
    queued-job metadata so the client can poll by ``analysis_id``.
    """
    # Create the record first so the client gets an id to poll with.
    record = AIAnalysis(
        company_id=request.company_id,
        data_source=request.data_source,
        ai_model=request.model,
        status='pending'
    )
    db.add(record)
    await db.commit()
    await db.refresh(record)  # load the DB-generated primary key

    # Schedule the actual LLM work after the response is sent.
    background_tasks.add_task(
        run_analysis_background,
        analysis_id=record.id,
        company_id=request.company_id,
        data_source=request.data_source,
        model=request.model
    )

    return AnalysisStartResponse(
        analysis_id=record.id,
        company_id=request.company_id,
        data_source=request.data_source,
        status='pending',
        message='AI 分析已加入队列'
    )
async def run_analysis_background(
    analysis_id: int,
    company_id: int,
    data_source: str,
    model: str
):
    """
    Background task that runs the LLM analysis for one AIAnalysis row.

    Transitions the record pending -> in_progress -> completed, storing
    the per-section report text and token usage. On any failure the
    record is marked 'failed' with the stringified exception.

    Args:
        analysis_id: Primary key of the AIAnalysis row created by /start.
        company_id:  Company whose financial data is analyzed.
        data_source: Data source the analysis is based on.
        model:       LLM model identifier forwarded to the service layer.
    """
    # Imported lazily to avoid circular imports at module load time.
    from app.database import SessionLocal
    from app.services.analysis_service import run_llm_analysis
    from datetime import datetime

    async with SessionLocal() as db:
        try:
            # Mark the record in progress so status pollers see movement.
            result = await db.execute(
                select(AIAnalysis).where(AIAnalysis.id == analysis_id)
            )
            analysis = result.scalar_one()
            analysis.status = 'in_progress'
            await db.commit()

            # Run the (potentially long) LLM analysis.
            analysis_result = await run_llm_analysis(
                company_id=company_id,
                data_source=data_source,
                model=model,
                db=db
            )

            # Persist the section reports and token accounting.
            analysis.status = 'completed'
            analysis.company_profile = analysis_result.get('company_profile')
            analysis.fundamental_analysis = analysis_result.get('fundamental_analysis')
            analysis.insider_analysis = analysis_result.get('insider_analysis')
            analysis.bullish_analysis = analysis_result.get('bullish_analysis')
            analysis.bearish_analysis = analysis_result.get('bearish_analysis')
            analysis.total_tokens = analysis_result.get('total_tokens', 0)
            analysis.tokens_by_section = analysis_result.get('tokens_by_section')
            analysis.completed_at = datetime.now()
            await db.commit()
        except Exception as e:
            # BUGFIX: after an error mid-transaction the session holds an
            # invalid/pending transaction; reusing it without a rollback
            # raises PendingRollbackError, leaving the row stuck in
            # 'in_progress' and masking the original error. Roll back
            # first, then record the failure in a fresh transaction.
            await db.rollback()
            result = await db.execute(
                select(AIAnalysis).where(AIAnalysis.id == analysis_id)
            )
            analysis = result.scalar_one()
            analysis.status = 'failed'
            analysis.error_message = str(e)
            await db.commit()
@router.get("/status/{analysis_id}", response_model=AnalysisStatusResponse)
async def get_analysis_status(
    analysis_id: int,
    db: AsyncSession = Depends(get_db)
):
    """
    Look up the current status of an AI analysis task by its id.

    Raises:
        HTTPException: 404 when no record exists for *analysis_id*.
    """
    query = select(AIAnalysis).where(AIAnalysis.id == analysis_id)
    record = (await db.execute(query)).scalar_one_or_none()
    if record is None:
        raise HTTPException(status_code=404, detail="分析记录不存在")
    return record
@router.get("/result/{analysis_id}", response_model=AnalysisResultResponse)
async def get_analysis_result(
    analysis_id: int,
    db: AsyncSession = Depends(get_db)
):
    """
    Fetch the full report content of a completed AI analysis.

    Raises:
        HTTPException: 404 when the record does not exist;
                       400 when the analysis has not finished yet.
    """
    query = select(AIAnalysis).where(AIAnalysis.id == analysis_id)
    record = (await db.execute(query)).scalar_one_or_none()
    if record is None:
        raise HTTPException(status_code=404, detail="分析记录不存在")
    # Only surface the report once the background task has finished.
    if record.status != 'completed':
        raise HTTPException(status_code=400, detail="分析尚未完成")
    return record
@router.get("/history", response_model=List[AnalysisStatusResponse])
async def get_analysis_history(
    company_id: int,
    # FIX: was `data_source: str = None` — an implicit Optional, which is
    # invalid under PEP 484 and rejected by strict type checkers.
    data_source: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """
    List all historical analysis records for a company, newest first.

    Args:
        company_id:  Company whose analyses are listed (required query param).
        data_source: Optional filter; when given, only records from this
                     data source are returned.
    """
    query = select(AIAnalysis).where(AIAnalysis.company_id == company_id)
    if data_source:
        query = query.where(AIAnalysis.data_source == data_source)
    query = query.order_by(AIAnalysis.created_at.desc())
    result = await db.execute(query)
    return result.scalars().all()