Fundamental_Analysis/backend/app/routers/progress.py
2025-10-21 14:30:08 +08:00

164 lines
6.3 KiB
Python

"""
进度追踪API路由
"""
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from uuid import UUID
import logging
import json
import asyncio
from typing import AsyncGenerator
from ..core.dependencies import get_database_session
from ..schemas.progress import ProgressResponse
from ..services.progress_tracker import ProgressTracker
from ..core.exceptions import DatabaseError
# Router that holds this module's progress endpoints; presumably included into
# the FastAPI application (with a path prefix) elsewhere — not visible here.
router: APIRouter = APIRouter()
# Module-scoped logger named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)
@router.get("/{report_id}", response_model=ProgressResponse)
async def get_report_progress(
    report_id: UUID,
    db: AsyncSession = Depends(get_database_session),
):
    """Return the current generation progress of a report.

    Args:
        report_id: ID of the report to look up.
        db: Async database session provided by ``get_database_session``.

    Returns:
        The object returned by ``ProgressTracker.get_progress``, validated and
        serialized by FastAPI through the ``ProgressResponse`` response model.

    Raises:
        HTTPException: 404 when the tracker raises ``ValueError`` (report
            missing or no progress recorded); 500 on ``DatabaseError`` or any
            other unexpected exception.
    """
    try:
        progress_tracker = ProgressTracker(db)
        progress = await progress_tracker.get_progress(report_id)
        logger.info(
            "获取进度成功: %s, 当前步骤: %s/%s",
            report_id,
            progress.current_step,
            progress.total_steps,
        )
        return progress
    except ValueError as e:
        # Expected "not found" path: a warning without traceback is enough.
        logger.warning("报告不存在或无进度记录: %s (%s)", report_id, e)
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(e),
        ) from e
    except DatabaseError as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("获取进度时数据库错误: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"数据库错误: {str(e)}",
        ) from e
    except Exception as e:
        logger.exception("获取进度失败: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取进度失败: {str(e)}",
        ) from e
# Seconds to wait between two progress polls while streaming.
_PROGRESS_POLL_INTERVAL_SECONDS = 2
# Progress statuses that end the SSE stream with a "complete" event.
_TERMINAL_STATUSES = ("completed", "failed")


def _sse_event(payload: dict, event: str = "") -> str:
    """Format one Server-Sent Events frame carrying *payload* as JSON.

    ``json.dumps`` escapes CR/LF inside strings, so the JSON always fits on a
    single ``data:`` line. When *event* is empty, no ``event:`` line is
    written, which makes it a default ("message") event for the client.
    """
    data = json.dumps(payload, ensure_ascii=False)
    event_line = f"event: {event}\n" if event else ""
    return f"{event_line}data: {data}\n\n"


def _progress_to_payload(progress) -> dict:
    """Convert a progress object into the camelCase dict sent to SSE clients.

    *progress* is whatever ``ProgressTracker.get_progress`` returns —
    presumably a ``ProgressResponse``; only the attributes read below are used.
    """
    return {
        "reportId": str(progress.report_id),
        "currentStep": progress.current_step,
        "totalSteps": progress.total_steps,
        "currentStepName": progress.current_step_name,
        "status": progress.status,
        "estimatedRemaining": progress.estimated_remaining,
        "steps": [
            {
                "id": str(step.step_order),
                "name": step.step_name,
                "status": step.status,
                "startedAt": step.started_at.isoformat() if step.started_at else None,
                "completedAt": step.completed_at.isoformat() if step.completed_at else None,
                "durationMs": step.duration_ms,
                "errorMessage": step.error_message,
            }
            for step in progress.step_timings
        ],
    }


@router.get("/{report_id}/stream")
async def stream_report_progress(
    report_id: UUID,
    db: AsyncSession = Depends(get_database_session),
):
    """Stream a report's generation progress as Server-Sent Events.

    The progress is polled every ``_PROGRESS_POLL_INTERVAL_SECONDS`` seconds
    and the stream emits:

    * a default ``data:`` event with the full progress payload whenever it
      differs from the last payload sent;
    * ``event: complete`` once the status is "completed" or "failed", after
      which the stream ends;
    * ``event: error`` when the tracker raises ``ValueError`` (report not
      found) or polling fails for any other reason, after which the stream ends.

    Args:
        report_id: ID of the report to follow.
        db: Async database session provided by ``get_database_session``.

    Returns:
        A ``text/event-stream`` ``StreamingResponse``.
    """
    # NOTE(review): the generator keeps using the request-scoped ``db`` after
    # this handler returns. Whether the session is still open while the body
    # streams depends on get_database_session and on FastAPI's teardown timing
    # for dependencies with ``yield`` — confirm. Also verify that repeated
    # get_progress calls on one session see fresh rows (ORM identity map /
    # transaction snapshot) rather than the first poll's state.

    async def generate_progress_stream() -> AsyncGenerator[str, None]:
        """Yield SSE frames until the report finishes or an error occurs."""
        progress_tracker = ProgressTracker(db)
        last_progress = None
        while True:
            try:
                current_progress = await progress_tracker.get_progress(report_id)
                # Only push an update when the progress actually changed.
                if current_progress != last_progress:
                    yield _sse_event(_progress_to_payload(current_progress))
                    last_progress = current_progress
                if current_progress.status in _TERMINAL_STATUSES:
                    yield _sse_event(
                        {"status": current_progress.status}, event="complete"
                    )
                    return
            except ValueError:
                # The tracker signals a missing report with ValueError.
                yield _sse_event({"error": "报告不存在"}, event="error")
                return
            except Exception as e:
                # logger.exception keeps the traceback that logger.error dropped.
                logger.exception("流式进度获取错误: %s", e)
                yield _sse_event({"error": str(e)}, event="error")
                return
            await asyncio.sleep(_PROGRESS_POLL_INTERVAL_SECONDS)

    return StreamingResponse(
        generate_progress_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # NOTE(review): hard-coded wildcard CORS header bypasses any
            # app-level CORS configuration — confirm this is intended.
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "Cache-Control",
            # Ask nginx-style reverse proxies not to buffer the stream, so
            # events reach the client as soon as they are yielded.
            "X-Accel-Buffering": "no",
        },
    )
@router.post("/{report_id}/reset")
async def reset_report_progress(
    report_id: UUID,
    db: AsyncSession = Depends(get_database_session),
):
    """Reset the recorded generation progress of a report.

    Args:
        report_id: ID of the report whose progress is reset.
        db: Async database session provided by ``get_database_session``.

    Returns:
        A confirmation dict with a ``message`` and the ``report_id`` as a string.

    Raises:
        HTTPException: 404 when the tracker raises ``ValueError`` (report
            not found); 500 on ``DatabaseError`` or any other unexpected
            exception.
    """
    try:
        progress_tracker = ProgressTracker(db)
        await progress_tracker.reset_progress(report_id)
        logger.info("重置进度成功: %s", report_id)
        return {"message": "进度重置成功", "report_id": str(report_id)}
    except ValueError as e:
        # Expected "not found" path: a warning without traceback is enough.
        logger.warning("报告不存在: %s (%s)", report_id, e)
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(e),
        ) from e
    except DatabaseError as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("重置进度时数据库错误: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"数据库错误: {str(e)}",
        ) from e
    except Exception as e:
        logger.exception("重置进度失败: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"重置进度失败: {str(e)}",
        ) from e