183 lines
5.3 KiB
Python
183 lines
5.3 KiB
Python
"""
AI 分析 API 路由 (AI analysis API routes)

处理 LLM 分析的启动、状态查询和结果获取
(Handles starting LLM analyses, querying their status, and fetching results.)
"""
|
|
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from sqlalchemy import select, and_
from sqlalchemy.ext.asyncio import AsyncSession

from app.database import get_db
from app.schemas import (
    AnalysisStartRequest,
    AnalysisStartResponse,
    AnalysisStatusResponse,
    AnalysisResultResponse,
)
from app.models import AIAnalysis, Company
|
|
|
|
# Router for every AI-analysis endpoint in this module: all routes below are
# mounted under the "/analysis" path prefix and grouped under the "analysis"
# OpenAPI tag.
router = APIRouter(prefix="/analysis", tags=["analysis"])
|
|
|
|
|
|
@router.post("/start", response_model=AnalysisStartResponse)
async def start_analysis(
    request: AnalysisStartRequest,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db)
):
    """
    Start an AI analysis.

    Creates a ``pending`` AIAnalysis record for the requested company and
    data source, then schedules ``run_analysis_background`` to run the LLM
    analysis as a FastAPI background task.

    Args:
        request: Carries ``company_id``, ``data_source`` and ``model``.
        background_tasks: FastAPI background-task queue for this request.
        db: Request-scoped async database session.

    Returns:
        AnalysisStartResponse holding the new ``analysis_id`` and the
        ``pending`` status.

    Raises:
        HTTPException(404): if no Company with ``request.company_id`` exists.
    """
    # 1. Reject unknown companies up front. Without this check the endpoint
    #    would create a record pointing at a non-existent company (or, if the
    #    schema has a foreign key, fail with an opaque 500 on commit).
    company = await db.get(Company, request.company_id)
    if company is None:
        raise HTTPException(status_code=404, detail="公司不存在")

    # 2. Persist the analysis record so the client gets an id it can poll.
    analysis = AIAnalysis(
        company_id=request.company_id,
        data_source=request.data_source,
        ai_model=request.model,
        status='pending'
    )
    db.add(analysis)
    await db.commit()
    # Reload so server-generated columns (including the primary key) are set.
    await db.refresh(analysis)

    # 3. Hand the actual analysis off to a background task.
    background_tasks.add_task(
        run_analysis_background,
        analysis_id=analysis.id,
        company_id=request.company_id,
        data_source=request.data_source,
        model=request.model
    )

    return AnalysisStartResponse(
        analysis_id=analysis.id,
        company_id=request.company_id,
        data_source=request.data_source,
        status='pending',
        message='AI 分析已加入队列'
    )
|
|
|
|
|
|
async def run_analysis_background(
    analysis_id: int,
    company_id: int,
    data_source: str,
    model: str
):
    """
    Background AI analysis task.

    Opens its own session via ``SessionLocal`` rather than reusing the
    request-scoped one. Moves the AIAnalysis record ``analysis_id`` to
    ``in_progress``, runs ``run_llm_analysis``, and stores its output with
    status ``completed``. On any exception the record is marked ``failed``
    and ``error_message`` is set to the exception text.

    Args:
        analysis_id: Primary key of the AIAnalysis record to update.
        company_id: Company to analyse.
        data_source: Data source whose financial data is analysed.
        model: LLM model identifier forwarded to ``run_llm_analysis``.
    """
    import logging
    from datetime import datetime

    from app.database import SessionLocal
    from app.services.analysis_service import run_llm_analysis

    logger = logging.getLogger(__name__)

    # Result keys copied verbatim onto the record (missing keys become None).
    copied_fields = (
        'company_profile',
        'fundamental_analysis',
        'insider_analysis',
        'bullish_analysis',
        'bearish_analysis',
        'tokens_by_section',
    )

    async with SessionLocal() as db:
        try:
            # Mark the record as in progress.
            result = await db.execute(
                select(AIAnalysis).where(AIAnalysis.id == analysis_id)
            )
            analysis = result.scalar_one()
            analysis.status = 'in_progress'
            await db.commit()

            # Run the LLM analysis.
            analysis_result = await run_llm_analysis(
                company_id=company_id,
                data_source=data_source,
                model=model,
                db=db
            )

            # Store the results.
            analysis.status = 'completed'
            for field in copied_fields:
                setattr(analysis, field, analysis_result.get(field))
            analysis.total_tokens = analysis_result.get('total_tokens', 0)
            # NOTE(review): naive local time; confirm the column expects that.
            analysis.completed_at = datetime.now()
            await db.commit()

        except Exception as e:
            # The exception is caught and not re-raised, so log the traceback;
            # otherwise only str(e) survives, in error_message.
            logger.exception("AI analysis %s failed", analysis_id)

            # After a failed flush/commit the session refuses further work
            # until it is rolled back. Rolling back also discards any
            # half-applied result fields (e.g. status='completed') before the
            # record is marked as failed.
            await db.rollback()

            result = await db.execute(
                select(AIAnalysis).where(AIAnalysis.id == analysis_id)
            )
            analysis = result.scalar_one_or_none()
            if analysis is None:
                # The record does not exist (e.g. it was deleted); nothing to mark.
                return
            analysis.status = 'failed'
            analysis.error_message = str(e)
            await db.commit()
|
|
|
|
|
|
@router.get("/status/{analysis_id}", response_model=AnalysisStatusResponse)
async def get_analysis_status(
    analysis_id: int,
    db: AsyncSession = Depends(get_db)
):
    """
    Report the current state of an AI analysis task.

    Args:
        analysis_id: Primary key of the AIAnalysis record to look up.
        db: Request-scoped async database session.

    Returns:
        The AIAnalysis row, serialized through AnalysisStatusResponse.

    Raises:
        HTTPException(404): if no record with ``analysis_id`` exists.
    """
    lookup = select(AIAnalysis).where(AIAnalysis.id == analysis_id)
    found = (await db.execute(lookup)).scalar_one_or_none()

    if not found:
        raise HTTPException(status_code=404, detail="分析记录不存在")
    return found
|
|
|
|
|
|
@router.get("/result/{analysis_id}", response_model=AnalysisResultResponse)
async def get_analysis_result(
    analysis_id: int,
    db: AsyncSession = Depends(get_db)
):
    """
    Fetch the full report of a completed AI analysis.

    Args:
        analysis_id: Primary key of the AIAnalysis record.
        db: Request-scoped async database session.

    Returns:
        The AIAnalysis row, serialized through AnalysisResultResponse.

    Raises:
        HTTPException(404): if no record with ``analysis_id`` exists.
        HTTPException(400): if the analysis failed or has not completed yet.
    """
    result = await db.execute(
        select(AIAnalysis).where(AIAnalysis.id == analysis_id)
    )
    analysis = result.scalar_one_or_none()

    if not analysis:
        raise HTTPException(status_code=404, detail="分析记录不存在")

    # 'failed' is terminal. Reporting it as "not finished yet" would make
    # clients keep polling for a result that will never arrive.
    if analysis.status == 'failed':
        raise HTTPException(status_code=400, detail="分析失败")

    if analysis.status != 'completed':
        raise HTTPException(status_code=400, detail="分析尚未完成")

    return analysis
|
|
|
|
|
|
@router.get("/history", response_model=List[AnalysisStatusResponse])
async def get_analysis_history(
    company_id: int,
    data_source: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """
    List a company's past AI analyses, newest first.

    Args:
        company_id: Company whose analyses are listed.
        data_source: Optional filter. When given and non-empty, only
            analyses for that data source are returned.
        db: Request-scoped async database session.

    Returns:
        All matching AIAnalysis rows ordered by ``created_at`` descending,
        with ``id`` descending as a tie-breaker.
    """
    query = select(AIAnalysis).where(AIAnalysis.company_id == company_id)

    # An empty string is treated the same as "no filter".
    if data_source:
        query = query.where(AIAnalysis.data_source == data_source)

    # created_at values may tie. Ordering by id as well makes the listing
    # deterministic across calls.
    query = query.order_by(AIAnalysis.created_at.desc(), AIAnalysis.id.desc())

    result = await db.execute(query)
    return result.scalars().all()
|