414 lines
15 KiB
Python
414 lines
15 KiB
Python
"""
|
|
报告相关API路由
|
|
"""
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status, Query, BackgroundTasks
|
|
from fastapi.responses import StreamingResponse
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import selectinload
|
|
from typing import Optional, List, AsyncGenerator
|
|
from uuid import UUID
|
|
import logging
|
|
import json
|
|
import asyncio
|
|
|
|
from ..core.dependencies import get_database_session
|
|
from ..models.report import Report
|
|
from ..schemas.report import ReportResponse, RegenerateRequest
|
|
from ..schemas.progress import ProgressResponse
|
|
from ..services.report_generator import ReportGenerator
|
|
from ..services.config_manager import ConfigManager
|
|
from ..services.progress_tracker import ProgressTracker
|
|
from ..core.exceptions import ReportGenerationError, DatabaseError
|
|
|
|
# Router collecting every report endpoint declared in this module.
router: APIRouter = APIRouter()

# Module-scoped logger; its records carry this module's dotted name.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
@router.get("/{symbol}", response_model=ReportResponse)
async def get_or_create_report(
    symbol: str,
    background_tasks: BackgroundTasks,
    market: str = Query(..., description="交易市场"),
    db: AsyncSession = Depends(get_database_session)
):
    """Return the stored report for ``symbol``/``market``, creating one if absent.

    If a report already exists it is returned as-is, whatever its status.
    Otherwise a new ``Report`` row with status ``"generating"`` is committed
    and ``ReportGenerator.generate_report_async(symbol, market)`` is queued as
    a background task.

    Args:
        symbol: Security code from the path; upper-cased and stripped.
        background_tasks: FastAPI background task queue.
        market: Trading market; lower-cased and stripped, must be one of
            ``china``, ``hongkong``, ``usa``, ``japan``.
        db: Async database session provided by ``get_database_session``.

    Returns:
        The existing or newly created report as a ``ReportResponse``.

    Raises:
        HTTPException: 400 for an empty symbol or unsupported market,
            500 for any other failure.
    """
    try:
        # Normalise input so lookups are case/whitespace-insensitive.
        symbol = symbol.upper().strip()
        market = market.lower().strip()

        if not symbol:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="证券代码不能为空"
            )

        if market not in ["china", "hongkong", "usa", "japan"]:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="不支持的交易市场"
            )

        # Look up an existing report, eagerly loading its analysis modules.
        result = await db.execute(
            select(Report)
            .options(selectinload(Report.analysis_modules))
            .where(
                Report.symbol == symbol,
                Report.market == market
            )
        )
        # NOTE(review): two concurrent requests can both miss here and create
        # duplicate rows (which would make scalar_one_or_none() raise later);
        # confirm a unique constraint on (symbol, market) exists.
        existing_report = result.scalar_one_or_none()

        if existing_report:
            logger.info(f"找到现有报告: {symbol}-{market}, 状态: {existing_report.status}")
            return ReportResponse.from_attributes(existing_report)

        logger.info(f"开始生成新报告: {symbol}-{market}")

        # Build the generator before writing anything, so a construction
        # failure does not leave a "generating" row that nothing will fill.
        # NOTE(review): the generator reuses this request's ``db`` session;
        # confirm the session is still open when the background task runs.
        config_manager = ConfigManager(db)
        report_generator = ReportGenerator(db, config_manager)

        # Persist the placeholder record first...
        new_report = Report(
            symbol=symbol,
            market=market,
            status="generating"
        )
        db.add(new_report)
        await db.commit()
        await db.refresh(new_report)

        # ...and only then queue generation (same ordering as regenerate_report).
        background_tasks.add_task(
            report_generator.generate_report_async,
            symbol,
            market
        )

        logger.info(f"创建报告记录: {new_report.id}")
        return ReportResponse.from_attributes(new_report)

    except HTTPException:
        raise
    except Exception as e:
        # logger.exception keeps the traceback that logger.error discarded.
        logger.exception(f"获取或创建报告失败: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取或创建报告失败: {str(e)}"
        ) from e
|
|
|
|
|
|
@router.post("/{symbol}/regenerate", response_model=ReportResponse)
async def regenerate_report(
    symbol: str,
    request: RegenerateRequest,
    background_tasks: BackgroundTasks,
    market: str = Query(..., description="交易市场"),
    db: AsyncSession = Depends(get_database_session)
):
    """Regenerate the report for ``symbol``/``market``.

    When a report exists and ``request.force`` is falsy, the existing report
    is returned unchanged. Otherwise any existing report is deleted and a new
    ``"generating"`` report is inserted in the same transaction. Then
    ``ReportGenerator.generate_report_async(symbol, market)`` is queued as a
    background task.

    Args:
        symbol: Security code from the path; upper-cased and stripped.
        request: Request body; its ``force`` flag forces regeneration.
        background_tasks: FastAPI background task queue.
        market: Trading market; lower-cased and stripped, must be one of
            ``china``, ``hongkong``, ``usa``, ``japan``.
        db: Async database session provided by ``get_database_session``.

    Returns:
        The existing report (no force) or the newly created report.

    Raises:
        HTTPException: 400 for an empty symbol or unsupported market,
            500 for any other failure.
    """
    try:
        # Normalise input so lookups are case/whitespace-insensitive.
        symbol = symbol.upper().strip()
        market = market.lower().strip()

        if not symbol:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="证券代码不能为空"
            )

        if market not in ["china", "hongkong", "usa", "japan"]:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="不支持的交易市场"
            )

        result = await db.execute(
            select(Report)
            .options(selectinload(Report.analysis_modules))
            .where(
                Report.symbol == symbol,
                Report.market == market
            )
        )
        existing_report = result.scalar_one_or_none()

        if existing_report and not request.force:
            # A report exists and regeneration is not forced: return it as-is.
            logger.info(f"返回现有报告: {symbol}-{market}")
            return ReportResponse.from_attributes(existing_report)

        # Build the generator before any destructive DB work, so a
        # construction failure cannot leave the old report deleted.
        # NOTE(review): the generator reuses this request's ``db`` session;
        # confirm the session is still open when the background task runs.
        config_manager = ConfigManager(db)
        report_generator = ReportGenerator(db, config_manager)

        if existing_report:
            logger.info(f"删除现有报告: {existing_report.id}")
            await db.delete(existing_report)
            # Flush so the DELETE reaches the database before the INSERT
            # below; within a single flush the unit of work may order the
            # INSERT first, which would collide with a unique (symbol, market)
            # constraint if one exists.
            await db.flush()

        new_report = Report(
            symbol=symbol,
            market=market,
            status="generating"
        )
        db.add(new_report)
        # One commit for delete + insert: if creating the replacement fails,
        # the deletion is not committed and the previous report survives.
        await db.commit()
        await db.refresh(new_report)

        background_tasks.add_task(
            report_generator.generate_report_async,
            symbol,
            market
        )

        logger.info(f"开始重新生成报告: {new_report.id}")
        return ReportResponse.from_attributes(new_report)

    except HTTPException:
        raise
    except Exception as e:
        # logger.exception keeps the traceback that logger.error discarded.
        logger.exception(f"重新生成报告失败: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"重新生成报告失败: {str(e)}"
        ) from e
|
|
|
|
|
|
@router.get("/{report_id}/details", response_model=ReportResponse)
async def get_report_details(
    report_id: UUID,
    db: AsyncSession = Depends(get_database_session)
):
    """Fetch one report by id, including its analysis modules.

    Args:
        report_id: Primary key of the report.
        db: Async database session provided by ``get_database_session``.

    Returns:
        The matching report as a ``ReportResponse``.

    Raises:
        HTTPException: 404 when no report has this id, 500 on any other error.
    """
    try:
        stmt = (
            select(Report)
            .where(Report.id == report_id)
            .options(selectinload(Report.analysis_modules))
        )
        found = (await db.execute(stmt)).scalar_one_or_none()

        # Guard clause: unknown id.
        if not found:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="报告不存在",
            )

        logger.info(f"获取报告详情: {report_id}")
        return ReportResponse.from_attributes(found)

    except HTTPException:
        raise
    except Exception as exc:
        reason = str(exc)
        logger.error(f"获取报告详情失败: {reason}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取报告详情失败: {reason}",
        )
|
|
|
|
|
|
@router.get("/", response_model=List[ReportResponse])
async def list_reports(
    skip: int = Query(0, ge=0, description="跳过的记录数"),
    limit: int = Query(10, ge=1, le=100, description="返回的记录数"),
    market: Optional[str] = Query(None, description="按市场筛选"),
    status: Optional[str] = Query(None, description="按状态筛选"),
    db: AsyncSession = Depends(get_database_session)
):
    """List reports, newest first, with optional market/status filters.

    Args:
        skip: Number of rows to skip (offset), >= 0.
        limit: Maximum number of rows to return, 1..100.
        market: Optional market filter; lower-cased/stripped and must be one
            of ``china``, ``hongkong``, ``usa``, ``japan``.
        status: Optional status filter; lower-cased/stripped and must be one
            of ``generating``, ``completed``, ``partial``, ``failed``.
        db: Async database session provided by ``get_database_session``.

    Returns:
        A list of ``ReportResponse`` ordered by ``created_at`` descending.

    Raises:
        HTTPException: 400 for an unsupported filter value, 500 otherwise.
    """
    # The ``status`` query parameter shadows the module-level ``fastapi.status``
    # import inside this function (``status.HTTP_...`` would be looked up on a
    # str or None and raise AttributeError). Bind the constants under another
    # name instead of renaming the parameter, which is part of the public API.
    from fastapi import status as http_status

    try:
        query = select(Report).options(selectinload(Report.analysis_modules))

        # Optional filters.
        if market:
            market = market.lower().strip()
            if market not in ["china", "hongkong", "usa", "japan"]:
                raise HTTPException(
                    status_code=http_status.HTTP_400_BAD_REQUEST,
                    detail="不支持的交易市场"
                )
            query = query.where(Report.market == market)

        if status:
            status_value = status.lower().strip()
            if status_value not in ["generating", "completed", "partial", "failed"]:
                raise HTTPException(
                    status_code=http_status.HTTP_400_BAD_REQUEST,
                    detail="不支持的状态值"
                )
            query = query.where(Report.status == status_value)

        # Newest first, then paginate.
        query = query.order_by(Report.created_at.desc()).offset(skip).limit(limit)

        result = await db.execute(query)
        reports = result.scalars().all()

        logger.info(f"获取报告列表: {len(reports)} 条记录")
        return [ReportResponse.from_attributes(report) for report in reports]

    except HTTPException:
        raise
    except Exception as e:
        # logger.exception keeps the traceback that logger.error discarded.
        logger.exception(f"获取报告列表失败: {str(e)}")
        raise HTTPException(
            status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取报告列表失败: {str(e)}"
        ) from e
|
|
|
|
|
|
@router.get("/{report_id}/progress", response_model=ProgressResponse)
async def get_report_progress(
    report_id: UUID,
    db: AsyncSession = Depends(get_database_session)
):
    """Return the current generation progress of a report.

    Args:
        report_id: Primary key of the report.
        db: Async database session provided by ``get_database_session``.

    Returns:
        Whatever ``ProgressTracker.get_progress`` returns for this report.

    Raises:
        HTTPException: 404 when the tracker raises ``ValueError`` (the
            exception text becomes the detail); 500 on ``DatabaseError`` or
            any other exception.
    """
    try:
        tracker = ProgressTracker(db)
        snapshot = await tracker.get_progress(report_id)
        logger.info(
            f"获取进度成功: {report_id}, "
            f"当前步骤: {snapshot.current_step}/{snapshot.total_steps}"
        )
        return snapshot

    except ValueError as missing:
        logger.warning(f"报告不存在或无进度记录: {report_id}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(missing),
        )
    except DatabaseError as db_err:
        db_reason = str(db_err)
        logger.error(f"获取进度时数据库错误: {db_reason}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"数据库错误: {db_reason}",
        )
    except Exception as other:
        other_reason = str(other)
        logger.error(f"获取进度失败: {other_reason}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取进度失败: {other_reason}",
        )
|
|
|
|
|
|
# Progress statuses after which no further updates are expected; the stream
# sends a ``complete`` event and closes when one is reached.
# NOTE(review): "partial" is included because list_reports accepts it as a
# report status alongside "completed"/"failed" — without it a partially
# finished report would be polled forever. Confirm the progress tracker uses
# the same status vocabulary.
_TERMINAL_PROGRESS_STATUSES = ("completed", "failed", "partial")


def _serialize_progress(progress) -> dict:
    """Convert a progress snapshot into the camelCase dict sent over SSE.

    Reads ``report_id``, ``current_step``, ``total_steps``,
    ``current_step_name``, ``status``, ``estimated_remaining`` and
    ``step_timings`` from ``progress``. Each step timing contributes
    ``step_order`` (as the string ``id``), ``step_name``, ``status``,
    ``started_at``/``completed_at`` (``isoformat()`` when set, else None),
    ``duration_ms`` and ``error_message``.
    """
    return {
        "reportId": str(progress.report_id),
        "currentStep": progress.current_step,
        "totalSteps": progress.total_steps,
        "currentStepName": progress.current_step_name,
        "status": progress.status,
        "estimatedRemaining": progress.estimated_remaining,
        "steps": [
            {
                "id": str(step.step_order),
                "name": step.step_name,
                "status": step.status,
                "startedAt": step.started_at.isoformat() if step.started_at else None,
                "completedAt": step.completed_at.isoformat() if step.completed_at else None,
                "durationMs": step.duration_ms,
                "errorMessage": step.error_message
            }
            for step in progress.step_timings
        ]
    }


@router.get("/{report_id}/progress/stream")
async def stream_report_progress(
    report_id: UUID,
    db: AsyncSession = Depends(get_database_session)
):
    """Stream report generation progress as Server-Sent Events.

    Polls ``ProgressTracker.get_progress`` every 2 seconds and emits a
    ``data:`` frame whenever the snapshot differs from the previous one.
    When the status is terminal (see ``_TERMINAL_PROGRESS_STATUSES``) a
    ``complete`` event is sent and the stream ends. A ``ValueError`` from the
    tracker yields an ``error`` event ("报告不存在"); any other exception
    yields an ``error`` event carrying the exception text. Both end the stream.

    Args:
        report_id: Primary key of the report to follow.
        db: Async database session provided by ``get_database_session``.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.
    """

    async def generate_progress_stream() -> AsyncGenerator[str, None]:
        """Yield SSE frames until a terminal status or an error occurs."""
        last_progress = None

        try:
            # Constructed inside the try so a failure here is reported to the
            # client as an error event instead of breaking the stream.
            # NOTE(review): this generator runs while the response is being
            # streamed and reuses the request's ``db`` session; confirm
            # get_database_session keeps it open for the stream's lifetime.
            progress_tracker = ProgressTracker(db)

            while True:
                try:
                    current_progress = await progress_tracker.get_progress(report_id)

                    # Only push a frame when the snapshot has changed.
                    if current_progress != last_progress:
                        progress_data = _serialize_progress(current_progress)
                        yield f"data: {json.dumps(progress_data, ensure_ascii=False)}\n\n"
                        last_progress = current_progress

                    # Terminal status: announce completion and close the stream.
                    if current_progress.status in _TERMINAL_PROGRESS_STATUSES:
                        yield f"event: complete\ndata: {json.dumps({'status': current_progress.status})}\n\n"
                        break

                    # Poll again in 2 seconds.
                    await asyncio.sleep(2)

                except ValueError:
                    # The tracker signals an unknown report with ValueError.
                    yield f"event: error\ndata: {json.dumps({'error': '报告不存在'})}\n\n"
                    break
                except Exception as e:
                    logger.exception(f"流式进度获取错误: {str(e)}")
                    yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
                    break

        except Exception as e:
            logger.exception(f"进度流生成错误: {str(e)}")
            yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"

    return StreamingResponse(
        generate_progress_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "Cache-Control"
        }
    )
|
|
|
|
|
|
@router.delete("/{report_id}")
async def delete_report(
    report_id: UUID,
    db: AsyncSession = Depends(get_database_session)
):
    """Delete a report by id and commit.

    Args:
        report_id: Primary key of the report to remove.
        db: Async database session provided by ``get_database_session``.

    Returns:
        A dict with a confirmation ``message`` and the ``report_id`` as a string.

    Raises:
        HTTPException: 404 when no report has this id, 500 on any other error.
    """
    try:
        lookup = await db.execute(select(Report).where(Report.id == report_id))
        target = lookup.scalar_one_or_none()

        # Guard clause: unknown id.
        if not target:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="报告不存在",
            )

        await db.delete(target)
        await db.commit()

        logger.info(f"删除报告成功: {report_id}")
        payload = {"message": "报告删除成功", "report_id": str(report_id)}
        return payload

    except HTTPException:
        raise
    except Exception as exc:
        reason = str(exc)
        logger.error(f"删除报告失败: {reason}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"删除报告失败: {reason}",
        )