"""
Progress tracking service.

Tracks and manages the progress of report generation, storing one
``ProgressTracking`` row per generation step.
"""
from typing import List, Optional, Sequence
from uuid import UUID
from datetime import datetime

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select

from ..models.progress_tracking import ProgressTracking
from ..schemas.progress import ProgressResponse, StepTiming


class ProgressTracker:
    """Record and query per-step progress of report generation."""

    # Default ordered report-generation steps. The names are persisted as
    # ``step_name`` and are the lookup keys used by start_step/complete_step.
    DEFAULT_STEPS = (
        "初始化报告",
        "获取财务数据",
        "生成业务信息",
        "执行基本面分析",
        "执行看涨分析",
        "执行看跌分析",
        "执行市场分析",
        "执行新闻分析",
        "执行交易分析",
        "执行内部人分析",
        "生成最终结论",
        "保存报告",
    )

    def __init__(self, db_session: AsyncSession):
        self.db = db_session

    async def initialize_progress(
        self, report_id: UUID, steps: Optional[Sequence[str]] = None
    ):
        """Create one ``pending`` progress record per step for *report_id*.

        Args:
            report_id: The report whose progress is being tracked.
            steps: Ordered step names; defaults to ``DEFAULT_STEPS``.
                ``step_order`` is assigned 1-based in the given order.
        """
        step_names = self.DEFAULT_STEPS if steps is None else steps
        self.db.add_all(
            [
                ProgressTracking(
                    report_id=report_id,
                    step_name=name,
                    step_order=order,
                    status="pending",
                )
                for order, name in enumerate(step_names, 1)
            ]
        )
        await self.db.flush()

    async def _get_step(
        self, report_id: UUID, step_name: str
    ) -> Optional[ProgressTracking]:
        """Return the progress record for one step of a report, or None."""
        result = await self.db.execute(
            select(ProgressTracking).where(
                ProgressTracking.report_id == report_id,
                ProgressTracking.step_name == step_name,
            )
        )
        # NOTE(review): raises MultipleResultsFound if initialize_progress was
        # called twice for the same report.
        return result.scalar_one_or_none()

    async def start_step(self, report_id: UUID, step_name: str):
        """Mark a step as ``running`` and stamp its start time.

        Unknown steps are silently ignored.
        """
        progress = await self._get_step(report_id, step_name)
        if progress is None:
            return
        progress.status = "running"
        # NOTE(review): naive UTC timestamp; datetime.utcnow() is deprecated in
        # Python 3.12 — switching to aware datetimes needs DB column review.
        progress.started_at = datetime.utcnow()
        await self.db.flush()

    async def complete_step(
        self,
        report_id: UUID,
        step_name: str,
        success: bool = True,
        error_message: Optional[str] = None,
    ):
        """Mark a step as ``completed`` (or ``failed``) and record its duration.

        Args:
            report_id: The report the step belongs to.
            step_name: Name of the step to finish; unknown steps are ignored.
            success: ``True`` -> status ``completed``, ``False`` -> ``failed``.
            error_message: Optional error text stored on the record.
        """
        progress = await self._get_step(report_id, step_name)
        if progress is None:
            return
        progress.status = "completed" if success else "failed"
        progress.completed_at = datetime.utcnow()
        progress.error_message = error_message
        # Duration is only known if start_step was called for this step.
        if progress.started_at:
            elapsed = progress.completed_at - progress.started_at
            progress.duration_ms = int(elapsed.total_seconds() * 1000)
        await self.db.flush()

    async def get_progress(self, report_id: UUID) -> ProgressResponse:
        """Build a progress snapshot for *report_id*.

        Raises:
            ValueError: if no progress records exist for the report.
        """
        result = await self.db.execute(
            select(ProgressTracking)
            .where(ProgressTracking.report_id == report_id)
            .order_by(ProgressTracking.step_order)
        )
        progress_records = result.scalars().all()
        if not progress_records:
            raise ValueError(f"未找到报告 {report_id} 的进度信息")

        statuses = [record.status for record in progress_records]
        if "failed" in statuses:
            overall_status = "failed"
        elif all(status == "completed" for status in statuses):
            overall_status = "completed"
        else:
            overall_status = "running"

        current = self._current_record(progress_records)

        step_timings = [
            StepTiming(
                step_name=record.step_name,
                step_order=record.step_order,
                status=record.status,
                started_at=record.started_at,
                completed_at=record.completed_at,
                duration_ms=record.duration_ms,
                error_message=record.error_message,
            )
            for record in progress_records
        ]

        return ProgressResponse(
            report_id=report_id,
            current_step=current.step_order,
            total_steps=len(progress_records),
            current_step_name=current.step_name,
            status=overall_status,
            step_timings=step_timings,
            estimated_remaining=self._estimate_remaining_time(step_timings),
        )

    @staticmethod
    def _current_record(progress_records):
        """Pick the record to report as the current step.

        Preference: the (last) running step; otherwise the first step that is
        not completed (a failed or next pending step); otherwise — every step
        completed — the final step. *progress_records* must be non-empty and
        ordered by ``step_order``.
        """
        for record in reversed(progress_records):
            if record.status == "running":
                return record
        for record in progress_records:
            if record.status != "completed":
                return record
        return progress_records[-1]

    def _estimate_remaining_time(
        self, step_timings: List[StepTiming]
    ) -> Optional[int]:
        """Estimate remaining time in seconds from completed-step durations.

        Returns ``None`` when no completed step has a recorded duration;
        otherwise the average completed duration times the number of
        ``pending`` steps.
        """
        # `is not None` so a legitimately 0 ms step still counts.
        completed_durations = [
            timing.duration_ms
            for timing in step_timings
            if timing.status == "completed" and timing.duration_ms is not None
        ]
        if not completed_durations:
            return None

        avg_duration_ms = sum(completed_durations) / len(completed_durations)
        remaining_steps = sum(1 for t in step_timings if t.status == "pending")
        return int((avg_duration_ms * remaining_steps) / 1000)  # ms -> seconds