Fundamental_Analysis/backend/app/routers/reports.py
2025-10-21 14:30:08 +08:00

414 lines
15 KiB
Python

"""
报告相关API路由
"""
from fastapi import APIRouter, Depends, HTTPException, status, Query, BackgroundTasks
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from typing import Optional, List, AsyncGenerator
from uuid import UUID
import logging
import json
import asyncio
from ..core.dependencies import get_database_session
from ..models.report import Report
from ..schemas.report import ReportResponse, RegenerateRequest
from ..schemas.progress import ProgressResponse
from ..services.report_generator import ReportGenerator
from ..services.config_manager import ConfigManager
from ..services.progress_tracker import ProgressTracker
from ..core.exceptions import ReportGenerationError, DatabaseError
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Router that the report endpoints below are registered on. No prefix is set
# here; any URL prefix presumably comes from where the router is included —
# confirm in the application setup.
router = APIRouter()
@router.get("/{symbol}", response_model=ReportResponse)
async def get_or_create_report(
    symbol: str,
    background_tasks: BackgroundTasks,
    market: str = Query(..., description="交易市场"),
    db: AsyncSession = Depends(get_database_session)
):
    """Get the report for a security, creating and scheduling one if absent.

    If a report for ``symbol``/``market`` already exists it is returned as-is.
    Otherwise a new row with status ``"generating"`` is committed,
    ``ReportGenerator.generate_report_async(symbol, market)`` is added as a
    background task, and the new row is returned.

    Args:
        symbol: Security code; upper-cased and stripped before use.
        background_tasks: FastAPI background-task queue used for generation.
        market: Trading market; lower-cased and stripped, must be one of
            ``china``, ``hongkong``, ``usa``, ``japan``.
        db: Database session provided by ``get_database_session``.

    Returns:
        The existing or newly created report as a ``ReportResponse``.

    Raises:
        HTTPException: 400 for an empty symbol or unsupported market; 500 for
            any other error (the session is rolled back first and the error
            text is included in ``detail``).
    """
    try:
        # Normalize and validate input
        symbol = symbol.upper().strip()
        market = market.lower().strip()
        if not symbol:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="证券代码不能为空"
            )
        if market not in ["china", "hongkong", "usa", "japan"]:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="不支持的交易市场"
            )

        # Look up an existing report, eager-loading its analysis modules
        result = await db.execute(
            select(Report)
            .options(selectinload(Report.analysis_modules))
            .where(
                Report.symbol == symbol,
                Report.market == market
            )
        )
        existing_report = result.scalar_one_or_none()
        if existing_report:
            logger.info(f"找到现有报告: {symbol}-{market}, 状态: {existing_report.status}")
            return ReportResponse.from_attributes(existing_report)

        logger.info(f"开始生成新报告: {symbol}-{market}")

        # Persist the placeholder row before scheduling any generation work.
        new_report = Report(
            symbol=symbol,
            market=market,
            status="generating"
        )
        db.add(new_report)
        await db.commit()
        await db.refresh(new_report)
        # refresh() reloads column attributes but leaves lazy relationships
        # unloaded. Re-select with the same eager option as the lookup above so
        # reading ``analysis_modules`` during serialization does not require an
        # implicit lazy load, which AsyncSession does not support.
        result = await db.execute(
            select(Report)
            .options(selectinload(Report.analysis_modules))
            .where(Report.id == new_report.id)
            .execution_options(populate_existing=True)
        )
        new_report = result.scalar_one()
        logger.info(f"创建报告记录: {new_report.id}")

        # Schedule generation only after the row is committed (same order as
        # regenerate_report).
        # NOTE(review): the generator is handed the request-scoped ``db``
        # session; confirm it is still usable when the background task runs,
        # given how get_database_session tears the session down.
        config_manager = ConfigManager(db)
        report_generator = ReportGenerator(db, config_manager)
        background_tasks.add_task(
            report_generator.generate_report_async,
            symbol,
            market
        )
        return ReportResponse.from_attributes(new_report)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"获取或创建报告失败: {str(e)}")
        # A failed statement or commit leaves the session unusable until it is
        # rolled back; do so before reporting the error.
        await db.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取或创建报告失败: {str(e)}"
        ) from e
@router.post("/{symbol}/regenerate", response_model=ReportResponse)
async def regenerate_report(
    symbol: str,
    request: RegenerateRequest,
    background_tasks: BackgroundTasks,
    market: str = Query(..., description="交易市场"),
    db: AsyncSession = Depends(get_database_session)
):
    """Regenerate the report for a security.

    If a report exists and ``request.force`` is falsy, that report is returned
    unchanged. Otherwise any existing report is deleted and a new
    ``"generating"`` row is created in the same transaction.
    ``ReportGenerator.generate_report_async(symbol, market)`` is then added as
    a background task. Because both changes commit together, a failure while
    creating the replacement leaves the old report in place.

    Args:
        symbol: Security code; upper-cased and stripped before use.
        request: Body whose ``force`` flag requests regeneration of an
            existing report.
        background_tasks: FastAPI background-task queue used for generation.
        market: Trading market; lower-cased and stripped, must be one of
            ``china``, ``hongkong``, ``usa``, ``japan``.
        db: Database session provided by ``get_database_session``.

    Returns:
        The existing report (when not forced) or the new placeholder report.

    Raises:
        HTTPException: 400 for an empty symbol or unsupported market; 500 for
            any other error (the session is rolled back first and the error
            text is included in ``detail``).
    """
    try:
        # Normalize and validate input
        symbol = symbol.upper().strip()
        market = market.lower().strip()
        if not symbol:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="证券代码不能为空"
            )
        if market not in ["china", "hongkong", "usa", "japan"]:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="不支持的交易市场"
            )

        # Look up an existing report
        result = await db.execute(
            select(Report)
            .options(selectinload(Report.analysis_modules))
            .where(
                Report.symbol == symbol,
                Report.market == market
            )
        )
        existing_report = result.scalar_one_or_none()
        if existing_report and not request.force:
            # Report exists and regeneration is not forced: return it as-is
            logger.info(f"返回现有报告: {symbol}-{market}")
            return ReportResponse.from_attributes(existing_report)

        if existing_report:
            logger.info(f"删除现有报告: {existing_report.id}")
            await db.delete(existing_report)
            # Flush so the DELETE reaches the database before the INSERT below.
            # Within a single flush the unit of work may emit inserts first,
            # which would collide with a unique (symbol, market) constraint if
            # the schema has one. Nothing is committed yet, so a failure later
            # on is rolled back in the except branch and the old report survives.
            await db.flush()

        new_report = Report(
            symbol=symbol,
            market=market,
            status="generating"
        )
        db.add(new_report)
        await db.commit()
        await db.refresh(new_report)
        # refresh() reloads column attributes but leaves lazy relationships
        # unloaded. Re-select with the same eager option as the lookup above so
        # reading ``analysis_modules`` during serialization does not require an
        # implicit lazy load, which AsyncSession does not support.
        result = await db.execute(
            select(Report)
            .options(selectinload(Report.analysis_modules))
            .where(Report.id == new_report.id)
            .execution_options(populate_existing=True)
        )
        new_report = result.scalar_one()

        # Schedule generation now that the new row is committed.
        # NOTE(review): the generator is handed the request-scoped ``db``
        # session; confirm it is still usable when the background task runs.
        config_manager = ConfigManager(db)
        report_generator = ReportGenerator(db, config_manager)
        background_tasks.add_task(
            report_generator.generate_report_async,
            symbol,
            market
        )
        logger.info(f"开始重新生成报告: {new_report.id}")
        return ReportResponse.from_attributes(new_report)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"重新生成报告失败: {str(e)}")
        # Undo any uncommitted delete/insert and make the session usable again.
        await db.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"重新生成报告失败: {str(e)}"
        ) from e
@router.get("/{report_id}/details", response_model=ReportResponse)
async def get_report_details(
    report_id: UUID,
    db: AsyncSession = Depends(get_database_session)
):
    """Fetch one report, with its analysis modules eager-loaded, by its id.

    Raises:
        HTTPException: 404 when no report has ``report_id``; 500 for any other
            failure, with the exception text echoed in ``detail``.
    """
    try:
        lookup = (
            select(Report)
            .where(Report.id == report_id)
            .options(selectinload(Report.analysis_modules))
        )
        found = (await db.execute(lookup)).scalar_one_or_none()

        # Guard clause: unknown id -> 404
        if not found:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="报告不存在",
            )

        logger.info(f"获取报告详情: {report_id}")
        return ReportResponse.from_attributes(found)
    except HTTPException:
        # Deliberate HTTP errors pass through untouched
        raise
    except Exception as exc:
        logger.error(f"获取报告详情失败: {exc}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取报告详情失败: {exc}",
        )
@router.get("/", response_model=List[ReportResponse])
async def list_reports(
    skip: int = Query(0, ge=0, description="跳过的记录数"),
    limit: int = Query(10, ge=1, le=100, description="返回的记录数"),
    market: Optional[str] = Query(None, description="按市场筛选"),
    status: Optional[str] = Query(None, description="按状态筛选"),
    db: AsyncSession = Depends(get_database_session)
):
    """List reports ordered by ``created_at`` descending, with optional filters.

    Args:
        skip: Number of records to skip (>= 0).
        limit: Maximum number of records to return (1..100).
        market: Optional market filter; lower-cased and stripped, must be one of
            ``china``, ``hongkong``, ``usa``, ``japan``.
        status: Optional status filter; lower-cased and stripped, must be one of
            ``generating``, ``completed``, ``partial``, ``failed``.
        db: Database session provided by ``get_database_session``.

    Returns:
        A list of ``ReportResponse`` objects with analysis modules eager-loaded.

    Raises:
        HTTPException: 400 for an unsupported market or status filter; 500 for
            any other error (the error text is included in ``detail``).
    """
    # The ``status`` query parameter shadows the module-level ``fastapi.status``
    # import inside this function. As a result, every ``status.HTTP_...`` lookup
    # hit the query value (a str or None) and raised AttributeError instead of
    # building the intended 400/500 response. The parameter name is kept because
    # FastAPI derives the query-string key from it, so the HTTP status constants
    # are reached through a local alias instead.
    from fastapi import status as http_status

    try:
        query = select(Report).options(selectinload(Report.analysis_modules))

        # Apply optional filters
        if market:
            market = market.lower().strip()
            if market not in ["china", "hongkong", "usa", "japan"]:
                raise HTTPException(
                    status_code=http_status.HTTP_400_BAD_REQUEST,
                    detail="不支持的交易市场"
                )
            query = query.where(Report.market == market)
        if status:
            status_value = status.lower().strip()
            if status_value not in ["generating", "completed", "partial", "failed"]:
                raise HTTPException(
                    status_code=http_status.HTTP_400_BAD_REQUEST,
                    detail="不支持的状态值"
                )
            query = query.where(Report.status == status_value)

        # Newest first, then paginate
        query = query.order_by(Report.created_at.desc()).offset(skip).limit(limit)
        result = await db.execute(query)
        reports = result.scalars().all()
        logger.info(f"获取报告列表: {len(reports)} 条记录")
        return [ReportResponse.from_attributes(report) for report in reports]
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"获取报告列表失败: {str(e)}")
        raise HTTPException(
            status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取报告列表失败: {str(e)}"
        ) from e
@router.get("/{report_id}/progress", response_model=ProgressResponse)
async def get_report_progress(
    report_id: UUID,
    db: AsyncSession = Depends(get_database_session)
):
    """Return the current generation-progress snapshot for a report.

    Raises:
        HTTPException: 404 when ``ProgressTracker.get_progress`` raises
            ``ValueError`` (its message becomes ``detail``); 500 on
            ``DatabaseError`` or any other exception.
    """
    try:
        tracker = ProgressTracker(db)
        snapshot = await tracker.get_progress(report_id)
        step_label = f"{snapshot.current_step}/{snapshot.total_steps}"
        logger.info(f"获取进度成功: {report_id}, 当前步骤: {step_label}")
        return snapshot
    except ValueError as missing:
        # The tracker signals "unknown report / no progress yet" via ValueError
        logger.warning(f"报告不存在或无进度记录: {report_id}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(missing),
        )
    except DatabaseError as db_error:
        logger.error(f"获取进度时数据库错误: {db_error}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"数据库错误: {db_error}",
        )
    except Exception as exc:
        logger.error(f"获取进度失败: {exc}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取进度失败: {exc}",
        )
@router.get("/{report_id}/progress/stream")
async def stream_report_progress(
    report_id: UUID,
    db: AsyncSession = Depends(get_database_session)
):
    """Stream report-generation progress as Server-Sent Events.

    Polls ``ProgressTracker.get_progress`` every 2 seconds and emits a
    ``data:`` frame whenever the serialized snapshot changes. The stream ends
    with ``event: complete`` once the status is terminal. It ends with
    ``event: error`` if ``get_progress`` raises ``ValueError`` (report not
    found) or if polling fails for any other reason.

    NOTE(review): the generator runs while the response is streamed and keeps
    using the request-scoped ``db`` session; confirm that session is still open
    at that point given how ``get_database_session`` is torn down.
    """
    # "partial" is one of the report statuses accepted by list_reports. Treat it
    # as final as well, otherwise a partially completed report would be polled
    # forever.
    terminal_statuses = ("completed", "failed", "partial")

    def _to_payload(progress) -> dict:
        """Convert a progress snapshot into the camelCase SSE JSON payload."""
        return {
            "reportId": str(progress.report_id),
            "currentStep": progress.current_step,
            "totalSteps": progress.total_steps,
            "currentStepName": progress.current_step_name,
            "status": progress.status,
            "estimatedRemaining": progress.estimated_remaining,
            "steps": [
                {
                    "id": str(step.step_order),
                    "name": step.step_name,
                    "status": step.status,
                    "startedAt": step.started_at.isoformat() if step.started_at else None,
                    "completedAt": step.completed_at.isoformat() if step.completed_at else None,
                    "durationMs": step.duration_ms,
                    "errorMessage": step.error_message
                }
                for step in progress.step_timings
            ]
        }

    async def generate_progress_stream() -> AsyncGenerator[str, None]:
        """Yield SSE frames until the report reaches a terminal status or fails."""
        progress_tracker = ProgressTracker(db)
        # Compare what was actually sent rather than the progress objects, so
        # the dedupe does not depend on how ProgressResponse implements equality.
        last_payload = None
        while True:
            try:
                # End the previous read transaction and expire cached ORM state
                # so each poll can see rows committed by the report generator
                # since the last poll. Assumes ProgressTracker reads through this
                # session and writes nothing that must survive the poll.
                await db.rollback()
                current_progress = await progress_tracker.get_progress(report_id)
                payload = _to_payload(current_progress)

                # Only send an update when the snapshot has changed
                if payload != last_payload:
                    yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
                    last_payload = payload

                # Terminal status: send the completion event and stop
                if current_progress.status in terminal_statuses:
                    yield f"event: complete\ndata: {json.dumps({'status': current_progress.status})}\n\n"
                    break

                await asyncio.sleep(2)
            except ValueError:
                # Report not found: send an error event
                yield f"event: error\ndata: {json.dumps({'error': '报告不存在'})}\n\n"
                break
            except Exception as e:
                logger.exception(f"流式进度获取错误: {str(e)}")
                yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
                break

    return StreamingResponse(
        generate_progress_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "Cache-Control"
        }
    )
@router.delete("/{report_id}")
async def delete_report(
    report_id: UUID,
    db: AsyncSession = Depends(get_database_session)
):
    """Delete a report by id.

    Returns:
        A confirmation dict with ``message`` and the stringified ``report_id``.

    Raises:
        HTTPException: 404 if no report has ``report_id``; 500 for any other
            failure (the session is rolled back first and the error text is
            included in ``detail``).
    """
    try:
        result = await db.execute(
            select(Report).where(Report.id == report_id)
        )
        report = result.scalar_one_or_none()
        if not report:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="报告不存在"
            )
        await db.delete(report)
        await db.commit()
        logger.info(f"删除报告成功: {report_id}")
        return {"message": "报告删除成功", "report_id": str(report_id)}
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"删除报告失败: {str(e)}")
        # A failed delete/commit leaves the session unusable until rolled back.
        await db.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"删除报告失败: {str(e)}"
        ) from e