"""Progress tracking service.

Tracks and manages the progress of report generation.
"""

import logging
from datetime import datetime, timezone
from typing import List, Optional, Sequence
from uuid import UUID

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from ..models.progress_tracking import ProgressTracking
from ..schemas.progress import ProgressResponse, StepTiming


class ProgressTracker:
    """Track the per-step progress of a report generation run.

    One ``ProgressTracking`` row is stored per (report, step). Rows are
    created as ``"pending"`` by :meth:`initialize_progress`, moved to
    ``"running"`` by :meth:`start_step`, moved to ``"completed"`` or
    ``"failed"`` by :meth:`complete_step`, and summarized by
    :meth:`get_progress`.

    All writes are only flushed to the session; this class never commits.
    """

    # Default pipeline, in execution order. These names are stored in the
    # database and used as lookup keys, so they must stay byte-identical.
    DEFAULT_STEPS = (
        "初始化报告",
        "获取财务数据",
        "生成业务信息",
        "执行基本面分析",
        "执行看涨分析",
        "执行看跌分析",
        "执行市场分析",
        "执行新闻分析",
        "执行交易分析",
        "执行内部人分析",
        "生成最终结论",
        "保存报告",
    )

    _logger = logging.getLogger(__name__)

    def __init__(self, db_session: AsyncSession):
        """Store the async session used for every query and write."""
        self.db = db_session

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Produces the same kind of value as the deprecated
        ``datetime.utcnow()`` so stored timestamps keep their format.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    async def _find_step(self, report_id: UUID, step_name: str) -> Optional[ProgressTracking]:
        """Load the progress row for one step of one report.

        Returns ``None`` (after logging a warning) when no row matches.
        ``scalar_one_or_none`` raises if more than one row matches.
        """
        result = await self.db.execute(
            select(ProgressTracking).where(
                ProgressTracking.report_id == report_id,
                ProgressTracking.step_name == step_name,
            )
        )
        record = result.scalar_one_or_none()
        if record is None:
            # Best-effort semantics are kept (no exception), but the miss is
            # no longer silent.
            self._logger.warning(
                "No progress record for report %s, step %r", report_id, step_name
            )
        return record

    async def initialize_progress(self, report_id: UUID, steps: Optional[Sequence[str]] = None):
        """Create one ``"pending"`` progress row per step for a report.

        Args:
            report_id: Report the rows belong to.
            steps: Step names in execution order. Defaults to
                :attr:`DEFAULT_STEPS` when omitted.
        """
        step_names = self.DEFAULT_STEPS if steps is None else steps

        # step_order is 1-based.
        for order, step_name in enumerate(step_names, 1):
            self.db.add(
                ProgressTracking(
                    report_id=report_id,
                    step_name=step_name,
                    step_order=order,
                    status="pending",
                )
            )

        await self.db.flush()

    async def start_step(self, report_id: UUID, step_name: str):
        """Mark a step as ``"running"`` and record its start time.

        Does nothing (besides logging a warning) if the step row is missing.
        """
        record = await self._find_step(report_id, step_name)
        if record is None:
            return

        record.status = "running"
        record.started_at = self._utcnow()
        await self.db.flush()

    async def complete_step(self, report_id: UUID, step_name: str, success: bool = True, error_message: Optional[str] = None):
        """Mark a step as finished and record its completion time.

        Args:
            report_id: Report the step belongs to.
            step_name: Name of the step to finish.
            success: ``True`` stores ``"completed"``, ``False`` stores
                ``"failed"``.
            error_message: Stored on the row as-is (``None`` clears it).

        Does nothing (besides logging a warning) if the step row is missing.
        """
        record = await self._find_step(report_id, step_name)
        if record is None:
            return

        record.status = "completed" if success else "failed"
        record.completed_at = self._utcnow()
        record.error_message = error_message

        # A duration can only be derived when the step was actually started.
        if record.started_at:
            elapsed = record.completed_at - record.started_at
            record.duration_ms = int(elapsed.total_seconds() * 1000)

        await self.db.flush()

    @staticmethod
    def _summarize(records):
        """Pick the current step and the overall status from ordered rows.

        Args:
            records: Non-empty sequence of objects with a ``status``
                attribute, ordered by step.

        Returns:
            ``(current_record, overall_status)``. The status is ``"failed"``
            if any step failed, ``"completed"`` if every step completed, and
            ``"running"`` otherwise. The current record is the last running
            step; without one, the first failed step; without one, the first
            step that is not completed; and the last step when all are done.
        """
        running = [r for r in records if r.status == "running"]
        failed = [r for r in records if r.status == "failed"]

        if failed:
            overall_status = "failed"
        elif all(r.status == "completed" for r in records):
            overall_status = "completed"
        else:
            overall_status = "running"

        if running:
            current = running[-1]
        elif failed:
            current = failed[0]
        else:
            # Between steps: point at the next step to run. When everything
            # is completed there is none, so report the final step.
            current = next(
                (r for r in records if r.status != "completed"), records[-1]
            )
        return current, overall_status

    async def get_progress(self, report_id: UUID) -> ProgressResponse:
        """Build the progress summary for a report.

        Args:
            report_id: Report to summarize.

        Returns:
            A ``ProgressResponse`` with the current step, overall status,
            per-step timings and an estimate of the remaining seconds.

        Raises:
            ValueError: If the report has no progress rows.
        """
        result = await self.db.execute(
            select(ProgressTracking)
            .where(ProgressTracking.report_id == report_id)
            .order_by(ProgressTracking.step_order)
        )
        records = result.scalars().all()

        if not records:
            raise ValueError(f"未找到报告 {report_id} 的进度信息")

        current, overall_status = self._summarize(records)

        step_timings = [
            StepTiming(
                step_name=record.step_name,
                step_order=record.step_order,
                status=record.status,
                started_at=record.started_at,
                completed_at=record.completed_at,
                duration_ms=record.duration_ms,
                error_message=record.error_message,
            )
            for record in records
        ]

        return ProgressResponse(
            report_id=report_id,
            current_step=current.step_order,
            total_steps=len(records),
            current_step_name=current.step_name,
            status=overall_status,
            step_timings=step_timings,
            estimated_remaining=self._estimate_remaining_time(step_timings),
        )

    def _estimate_remaining_time(self, step_timings: List[StepTiming]) -> Optional[int]:
        """Estimate the remaining time in whole seconds.

        The estimate is the mean duration of completed steps multiplied by
        the number of ``"pending"`` steps.

        Returns:
            Seconds (truncated to ``int``), or ``None`` when no completed
            step has a recorded duration.
        """
        # ``is not None`` rather than truthiness: a step that finished in
        # 0 ms is valid data and must count towards the average.
        durations = [
            timing.duration_ms
            for timing in step_timings
            if timing.status == "completed" and timing.duration_ms is not None
        ]
        if not durations:
            return None

        average_ms = sum(durations) / len(durations)
        pending_count = sum(1 for timing in step_timings if timing.status == "pending")

        return int((average_ms * pending_count) / 1000)  # milliseconds -> seconds