164 lines
6.3 KiB
Python
164 lines
6.3 KiB
Python
"""
|
|
进度追踪API路由
|
|
"""
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from fastapi.responses import StreamingResponse
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from uuid import UUID
|
|
import logging
|
|
import json
|
|
import asyncio
|
|
from typing import AsyncGenerator
|
|
|
|
from ..core.dependencies import get_database_session
|
|
from ..schemas.progress import ProgressResponse
|
|
from ..services.progress_tracker import ProgressTracker
|
|
from ..core.exceptions import DatabaseError
|
|
|
|
# Router object that the progress endpoints below are registered on.
router: APIRouter = APIRouter()

# Logger named after this module's dotted import path.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
@router.get("/{report_id}", response_model=ProgressResponse)
async def get_report_progress(
    report_id: UUID,
    db: AsyncSession = Depends(get_database_session)
):
    """Return the current generation progress of a report.

    Args:
        report_id: ID of the report whose progress is requested.
        db: Async database session injected via ``get_database_session``.

    Returns:
        The object returned by ``ProgressTracker.get_progress``, serialized
        through the route's ``ProgressResponse`` response model.

    Raises:
        HTTPException: 404 if ``get_progress`` raises ``ValueError`` (logged
            as "report missing or has no progress record"); 500 on
            ``DatabaseError`` or any other exception.
    """
    try:
        progress_tracker = ProgressTracker(db)
        progress = await progress_tracker.get_progress(report_id)
        logger.info(
            "获取进度成功: %s, 当前步骤: %s/%s",
            report_id, progress.current_step, progress.total_steps,
        )
        return progress

    except ValueError as e:
        # Expected "not found" path: a warning without a traceback is enough.
        logger.warning("报告不存在或无进度记录: %s", report_id)
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(e)
        ) from e

    except DatabaseError as e:
        # logger.exception keeps the traceback that logger.error discarded.
        logger.exception("获取进度时数据库错误: %s", e)
        # NOTE(review): echoing str(e) to the client may expose internal
        # database details; consider a generic message.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"数据库错误: {str(e)}"
        ) from e

    except Exception as e:
        # Catch-all boundary: log with traceback so unexpected failures are debuggable.
        logger.exception("获取进度失败: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取进度失败: {str(e)}"
        ) from e
|
|
|
|
|
|
def _progress_to_payload(progress) -> dict:
    """Build the JSON-serializable SSE payload for one progress snapshot.

    Keys are camelCase. ``steps`` holds one entry per item of
    ``progress.step_timings``; ``started_at`` / ``completed_at`` are rendered
    with ``isoformat()`` when truthy and as ``None`` otherwise.
    """
    return {
        "reportId": str(progress.report_id),
        "currentStep": progress.current_step,
        "totalSteps": progress.total_steps,
        "currentStepName": progress.current_step_name,
        "status": progress.status,
        "estimatedRemaining": progress.estimated_remaining,
        "steps": [
            {
                "id": str(step.step_order),
                "name": step.step_name,
                "status": step.status,
                "startedAt": step.started_at.isoformat() if step.started_at else None,
                "completedAt": step.completed_at.isoformat() if step.completed_at else None,
                "durationMs": step.duration_ms,
                "errorMessage": step.error_message,
            }
            for step in progress.step_timings
        ],
    }


@router.get("/{report_id}/stream")
async def stream_report_progress(
    report_id: UUID,
    db: AsyncSession = Depends(get_database_session)
):
    """Stream report-generation progress as Server-Sent Events (SSE).

    Polls ``ProgressTracker.get_progress`` every 2 seconds and sends a
    ``data:`` frame only when the result differs from the previous poll.
    Frames emitted:

    * ``data: {...}`` - a progress snapshot built by ``_progress_to_payload``.
    * ``event: complete`` - once status is ``"completed"`` or ``"failed"``;
      the stream then ends.
    * ``event: error`` - when ``get_progress`` raises ``ValueError``
      (message "报告不存在") or any other exception (its message); the
      stream then ends.

    NOTE(review): the generator keeps using ``db`` after this handler has
    returned the response. Whether the session stays open for the whole
    stream depends on ``get_database_session`` and on the FastAPI version's
    teardown order for ``yield`` dependencies - confirm.
    """

    async def generate_progress_stream() -> AsyncGenerator[str, None]:
        """Yield SSE frames until the report finishes or an error occurs."""
        progress_tracker = ProgressTracker(db)
        last_progress = None

        while True:
            try:
                current_progress = await progress_tracker.get_progress(report_id)

                # Only push a frame when something actually changed.
                if current_progress != last_progress:
                    payload = _progress_to_payload(current_progress)
                    yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
                    last_progress = current_progress

                # Terminal status: announce completion and close the stream.
                if current_progress.status in ["completed", "failed"]:
                    yield f"event: complete\ndata: {json.dumps({'status': current_progress.status})}\n\n"
                    break

                await asyncio.sleep(2)

                # The same session serves every poll. Expiring it makes the next
                # query reload ORM rows instead of returning objects cached in the
                # identity map, which would otherwise hide progress updates.
                # NOTE(review): under REPEATABLE READ isolation the still-open
                # transaction may hide newer commits too - confirm isolation level.
                db.expire_all()

            except ValueError:
                # ValueError from the tracker is treated as "report does not exist".
                yield f"event: error\ndata: {json.dumps({'error': '报告不存在'})}\n\n"
                break

            except Exception as e:
                # logger.exception keeps the traceback that logger.error discarded.
                logger.exception("流式进度获取错误: %s", e)
                yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
                break

    return StreamingResponse(
        generate_progress_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            # NOTE(review): "Connection" is a hop-by-hop header (disallowed in
            # HTTP/2), and a hard-coded wildcard origin bypasses any
            # CORSMiddleware policy - confirm both are intended.
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "Cache-Control"
        }
    )
|
|
|
|
|
|
@router.post("/{report_id}/reset")
async def reset_report_progress(
    report_id: UUID,
    db: AsyncSession = Depends(get_database_session)
):
    """Reset the generation progress of a report.

    Delegates to ``ProgressTracker.reset_progress`` and returns a
    confirmation message together with the report ID.

    NOTE(review): no commit happens here; presumably ``reset_progress`` or
    the session dependency commits - confirm.

    Raises:
        HTTPException: 404 if ``reset_progress`` raises ``ValueError``
            (logged as "report not found"); 500 on ``DatabaseError`` or any
            other exception.
    """
    try:
        progress_tracker = ProgressTracker(db)
        await progress_tracker.reset_progress(report_id)
        logger.info("重置进度成功: %s", report_id)
        return {"message": "进度重置成功", "report_id": str(report_id)}

    except ValueError as e:
        # Expected "not found" path: a warning without a traceback is enough.
        logger.warning("报告不存在: %s", report_id)
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(e)
        ) from e

    except DatabaseError as e:
        # logger.exception keeps the traceback that logger.error discarded.
        logger.exception("重置进度时数据库错误: %s", e)
        # NOTE(review): echoing str(e) to the client may expose internal
        # database details; consider a generic message.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"数据库错误: {str(e)}"
        ) from e

    except Exception as e:
        # Catch-all boundary: log with traceback so unexpected failures are debuggable.
        logger.exception("重置进度失败: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"重置进度失败: {str(e)}"
        ) from e