# Fundamental_Analysis/backend/app/services/progress_tracker.py
#
# 163 lines
# 5.5 KiB
# Python

"""
进度追踪服务
处理报告生成进度的追踪和管理
"""
from typing import List, Optional
from uuid import UUID
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from ..models.progress_tracking import ProgressTracking
from ..schemas.progress import ProgressResponse, StepTiming
class ProgressTracker:
    """Record and report the progress of a report-generation run.

    One ``ProgressTracking`` row is kept per (report, step). A step starts as
    ``"pending"``, becomes ``"running"`` via :meth:`start_step`, and ends as
    ``"completed"`` or ``"failed"`` via :meth:`complete_step`. Every write is
    flushed to the session; nothing is committed here.
    """

    # Ordered step names. These strings are persisted and are the lookup keys
    # used by start_step / complete_step, so they must stay byte-identical.
    STEP_NAMES = (
        "初始化报告",
        "获取财务数据",
        "生成业务信息",
        "执行基本面分析",
        "执行看涨分析",
        "执行看跌分析",
        "执行市场分析",
        "执行新闻分析",
        "执行交易分析",
        "执行内部人分析",
        "生成最终结论",
        "保存报告",
    )

    def __init__(self, db_session: "AsyncSession"):
        """Bind the tracker to the async session used for all queries."""
        self.db = db_session

    async def initialize_progress(self, report_id: UUID):
        """Create one ``"pending"`` tracking row per step for ``report_id``.

        Rows are numbered from 1 in the order of ``STEP_NAMES`` and flushed
        in a single batch.
        """
        for step_order, step_name in enumerate(self.STEP_NAMES, 1):
            self.db.add(
                ProgressTracking(
                    report_id=report_id,
                    step_name=step_name,
                    step_order=step_order,
                    status="pending",
                )
            )
        await self.db.flush()

    async def _get_step(self, report_id: UUID, step_name: str):
        """Return the tracking row for one step of a report, or ``None``."""
        result = await self.db.execute(
            select(ProgressTracking).where(
                ProgressTracking.report_id == report_id,
                ProgressTracking.step_name == step_name,
            )
        )
        return result.scalar_one_or_none()

    async def start_step(self, report_id: UUID, step_name: str):
        """Mark a step as ``"running"`` and stamp its start time.

        Does nothing when no matching row exists.
        """
        progress = await self._get_step(report_id, step_name)
        if not progress:
            return
        progress.status = "running"
        # NOTE(review): datetime.utcnow() yields a naive timestamp and is
        # deprecated since Python 3.12; confirm the column types before
        # switching to an aware datetime.
        progress.started_at = datetime.utcnow()
        await self.db.flush()

    async def complete_step(self, report_id: UUID, step_name: str, success: bool = True, error_message: Optional[str] = None):
        """Mark a step as finished and record its outcome.

        Args:
            report_id: Report the step belongs to.
            step_name: Name of the step to close.
            success: ``True`` stores ``"completed"``, ``False`` stores
                ``"failed"``.
            error_message: Stored on the row as given (``None`` clears it).

        Does nothing when no matching row exists. ``duration_ms`` is only
        computed when the step has a recorded start time.
        """
        progress = await self._get_step(report_id, step_name)
        if not progress:
            return
        progress.status = "completed" if success else "failed"
        progress.completed_at = datetime.utcnow()
        progress.error_message = error_message
        if progress.started_at:
            elapsed = progress.completed_at - progress.started_at
            progress.duration_ms = int(elapsed.total_seconds() * 1000)
        await self.db.flush()

    async def get_progress(self, report_id: UUID) -> "ProgressResponse":
        """Build the progress summary for a report.

        Returns:
            A ``ProgressResponse`` with the current step, overall status,
            per-step timings and an estimate of the remaining seconds.

        Raises:
            ValueError: If the report has no tracking rows.
        """
        result = await self.db.execute(
            select(ProgressTracking)
            .where(ProgressTracking.report_id == report_id)
            .order_by(ProgressTracking.step_order)
        )
        records = result.scalars().all()
        if not records:
            raise ValueError(f"未找到报告 {report_id} 的进度信息")

        current_step, current_step_name, overall_status = self._summarize(records)

        step_timings = [
            StepTiming(
                step_name=record.step_name,
                step_order=record.step_order,
                status=record.status,
                started_at=record.started_at,
                completed_at=record.completed_at,
                duration_ms=record.duration_ms,
                error_message=record.error_message,
            )
            for record in records
        ]
        return ProgressResponse(
            report_id=report_id,
            current_step=current_step,
            total_steps=len(records),
            current_step_name=current_step_name,
            status=overall_status,
            step_timings=step_timings,
            estimated_remaining=self._estimate_remaining_time(step_timings),
        )

    @staticmethod
    def _summarize(records):
        """Derive ``(current_step, current_step_name, overall_status)``.

        ``records`` must be non-empty and ordered by ``step_order``.

        Overall status is ``"failed"`` if any step failed, ``"completed"`` if
        every step completed, otherwise ``"running"``.

        The current step is the last ``"running"`` row. When nothing is
        running it is the first row that has not completed (the failed step,
        or the next pending one), and when every row has completed it is the
        final row, so the position never snaps back to step 1.
        """
        statuses = [record.status for record in records]
        if "failed" in statuses:
            overall_status = "failed"
        elif all(status == "completed" for status in statuses):
            overall_status = "completed"
        else:
            overall_status = "running"

        running = [record for record in records if record.status == "running"]
        if running:
            current = running[-1]
        else:
            current = next(
                (record for record in records if record.status != "completed"),
                records[-1],
            )
        return current.step_order, current.step_name, overall_status

    def _estimate_remaining_time(self, step_timings: "List[StepTiming]") -> Optional[int]:
        """Estimate the remaining time in whole seconds.

        The estimate is the mean duration of completed steps multiplied by
        the number of ``"pending"`` steps. Returns ``None`` when no completed
        step has a recorded duration.
        """
        # Compare against None so a step that finished in 0 ms still counts
        # towards the average instead of being silently dropped.
        completed_durations = [
            timing.duration_ms
            for timing in step_timings
            if timing.status == "completed" and timing.duration_ms is not None
        ]
        if not completed_durations:
            return None
        avg_duration_ms = sum(completed_durations) / len(completed_durations)
        remaining_steps = sum(1 for timing in step_timings if timing.status == "pending")
        return int((avg_duration_ms * remaining_steps) / 1000)  # ms -> s