""" 进度追踪API路由 """ from fastapi import APIRouter, Depends, HTTPException, status from fastapi.responses import StreamingResponse from sqlalchemy.ext.asyncio import AsyncSession from uuid import UUID import logging import json import asyncio from typing import AsyncGenerator from ..core.dependencies import get_database_session from ..schemas.progress import ProgressResponse from ..services.progress_tracker import ProgressTracker from ..core.exceptions import DatabaseError logger = logging.getLogger(__name__) router = APIRouter() @router.get("/{report_id}", response_model=ProgressResponse) async def get_report_progress( report_id: UUID, db: AsyncSession = Depends(get_database_session) ): """获取报告生成进度""" try: progress_tracker = ProgressTracker(db) progress = await progress_tracker.get_progress(report_id) logger.info(f"获取进度成功: {report_id}, 当前步骤: {progress.current_step}/{progress.total_steps}") return progress except ValueError as e: logger.warning(f"报告不存在或无进度记录: {report_id}") raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=str(e) ) except DatabaseError as e: logger.error(f"获取进度时数据库错误: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"数据库错误: {str(e)}" ) except Exception as e: logger.error(f"获取进度失败: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取进度失败: {str(e)}" ) @router.get("/{report_id}/stream") async def stream_report_progress( report_id: UUID, db: AsyncSession = Depends(get_database_session) ): """实时流式获取报告生成进度 (Server-Sent Events)""" async def generate_progress_stream() -> AsyncGenerator[str, None]: """生成进度流""" progress_tracker = ProgressTracker(db) last_progress = None try: while True: try: # 获取当前进度 current_progress = await progress_tracker.get_progress(report_id) # 只有进度发生变化时才发送更新 if current_progress != last_progress: progress_data = { "reportId": str(current_progress.report_id), "currentStep": current_progress.current_step, "totalSteps": current_progress.total_steps, "currentStepName": current_progress.current_step_name, "status": current_progress.status, "estimatedRemaining": current_progress.estimated_remaining, "steps": [ { "id": str(step.step_order), "name": step.step_name, "status": step.status, "startedAt": step.started_at.isoformat() if step.started_at else None, "completedAt": step.completed_at.isoformat() if step.completed_at else None, "durationMs": step.duration_ms, "errorMessage": step.error_message } for step in current_progress.step_timings ] } # 发送SSE格式的数据 yield f"data: {json.dumps(progress_data, ensure_ascii=False)}\n\n" last_progress = current_progress # 如果报告已完成或失败,发送完成事件并结束流 if current_progress.status in ["completed", "failed"]: yield f"event: complete\ndata: {json.dumps({'status': current_progress.status})}\n\n" break # 等待2秒后再次检查 await asyncio.sleep(2) except ValueError: # 报告不存在,发送错误事件 yield f"event: error\ndata: {json.dumps({'error': '报告不存在'})}\n\n" break except Exception as e: logger.error(f"流式进度获取错误: {str(e)}") yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n" break except Exception as e: logger.error(f"进度流生成错误: {str(e)}") yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n" return StreamingResponse( generate_progress_stream(), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "Cache-Control" } ) @router.post("/{report_id}/reset") async def reset_report_progress( report_id: UUID, db: AsyncSession = Depends(get_database_session) ): """重置报告生成进度""" try: progress_tracker = ProgressTracker(db) await progress_tracker.reset_progress(report_id) logger.info(f"重置进度成功: {report_id}") return {"message": "进度重置成功", "report_id": str(report_id)} except ValueError as e: logger.warning(f"报告不存在: {report_id}") raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=str(e) ) except DatabaseError as e: logger.error(f"重置进度时数据库错误: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"数据库错误: {str(e)}" ) except Exception as e: logger.error(f"重置进度失败: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"重置进度失败: {str(e)}" )