""" 报告相关API路由 """ from fastapi import APIRouter, Depends, HTTPException, status, Query, BackgroundTasks from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select from sqlalchemy.orm import selectinload from typing import Optional, List from uuid import UUID import logging from ..core.dependencies import get_database_session from ..models.report import Report from ..schemas.report import ReportResponse, RegenerateRequest from ..services.report_generator import ReportGenerator from ..services.config_manager import ConfigManager from ..core.exceptions import ReportGenerationError, DatabaseError logger = logging.getLogger(__name__) router = APIRouter() @router.get("/{symbol}", response_model=ReportResponse) async def get_or_create_report( symbol: str, background_tasks: BackgroundTasks, market: str = Query(..., description="交易市场"), db: AsyncSession = Depends(get_database_session) ): """获取或创建股票报告""" try: # 验证输入参数 symbol = symbol.upper().strip() market = market.lower().strip() if not symbol: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="证券代码不能为空" ) if market not in ["china", "hongkong", "usa", "japan"]: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="不支持的交易市场" ) # 查询现有报告(包含关联的分析模块) result = await db.execute( select(Report) .options(selectinload(Report.analysis_modules)) .where( Report.symbol == symbol, Report.market == market ) ) existing_report = result.scalar_one_or_none() if existing_report: logger.info(f"找到现有报告: {symbol}-{market}, 状态: {existing_report.status}") return ReportResponse.from_attributes(existing_report) # 创建新报告 logger.info(f"开始生成新报告: {symbol}-{market}") config_manager = ConfigManager(db) report_generator = ReportGenerator(db, config_manager) # 在后台任务中生成报告 background_tasks.add_task( report_generator.generate_report_async, symbol, market ) # 创建初始报告记录 new_report = Report( symbol=symbol, market=market, status="generating" ) db.add(new_report) await db.commit() await db.refresh(new_report) 
logger.info(f"创建报告记录: {new_report.id}") return ReportResponse.from_attributes(new_report) except HTTPException: raise except Exception as e: logger.error(f"获取或创建报告失败: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取或创建报告失败: {str(e)}" ) @router.post("/{symbol}/regenerate", response_model=ReportResponse) async def regenerate_report( symbol: str, request: RegenerateRequest, background_tasks: BackgroundTasks, market: str = Query(..., description="交易市场"), db: AsyncSession = Depends(get_database_session) ): """重新生成报告""" try: # 验证输入参数 symbol = symbol.upper().strip() market = market.lower().strip() if not symbol: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="证券代码不能为空" ) if market not in ["china", "hongkong", "usa", "japan"]: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="不支持的交易市场" ) # 查询现有报告 result = await db.execute( select(Report) .options(selectinload(Report.analysis_modules)) .where( Report.symbol == symbol, Report.market == market ) ) existing_report = result.scalar_one_or_none() if existing_report and not request.force: # 如果报告存在且不强制重新生成,返回现有报告 logger.info(f"返回现有报告: {symbol}-{market}") return ReportResponse.from_attributes(existing_report) # 删除现有报告(如果存在) if existing_report: logger.info(f"删除现有报告: {existing_report.id}") await db.delete(existing_report) await db.commit() # 创建新报告记录 new_report = Report( symbol=symbol, market=market, status="generating" ) db.add(new_report) await db.commit() await db.refresh(new_report) # 在后台任务中生成报告 config_manager = ConfigManager(db) report_generator = ReportGenerator(db, config_manager) background_tasks.add_task( report_generator.generate_report_async, symbol, market ) logger.info(f"开始重新生成报告: {new_report.id}") return ReportResponse.from_attributes(new_report) except HTTPException: raise except Exception as e: logger.error(f"重新生成报告失败: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"重新生成报告失败: {str(e)}" ) 
@router.get("/{report_id}/details", response_model=ReportResponse)
async def get_report_details(
    report_id: UUID,
    db: AsyncSession = Depends(get_database_session),
):
    """Return one report, with its analysis modules eagerly loaded, by id.

    Raises:
        HTTPException: 404 when no report has ``report_id``; 500 for any other failure.
    """
    try:
        result = await db.execute(
            select(Report)
            .options(selectinload(Report.analysis_modules))
            .where(Report.id == report_id)
        )
        report = result.scalar_one_or_none()

        if report is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="报告不存在",
            )

        logger.info(f"获取报告详情: {report_id}")
        # NOTE(review): Pydantic exposes ``model_validate`` (v2) / ``from_orm`` (v1);
        # ``from_attributes`` is only a config flag -- confirm ReportResponse
        # defines this classmethod.
        return ReportResponse.from_attributes(report)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"获取报告详情失败: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取报告详情失败: {str(e)}",
        ) from e


@router.get("/", response_model=List[ReportResponse])
async def list_reports(
    skip: int = Query(0, ge=0, description="跳过的记录数"),
    limit: int = Query(10, ge=1, le=100, description="返回的记录数"),
    market: Optional[str] = Query(None, description="按市场筛选"),
    status: Optional[str] = Query(None, description="按状态筛选"),
    db: AsyncSession = Depends(get_database_session),
):
    """List reports, newest first, with optional market/status filters and paging.

    Args:
        skip: Number of rows to skip (offset).
        limit: Maximum number of rows to return (1-100).
        market: Optional market filter; case-insensitive, surrounding whitespace ignored.
        status: Optional status filter (``generating``/``completed``/``failed``);
            case-insensitive, surrounding whitespace ignored.

    Raises:
        HTTPException: 400 for an unsupported market or status; 500 for any
            other failure.
    """
    # The ``status`` query parameter (kept for API compatibility) shadows the
    # module-level ``fastapi.status`` inside this function, so the HTTP status
    # constants are imported here under a different name.
    from fastapi import status as http_status

    try:
        query = select(Report).options(selectinload(Report.analysis_modules))

        # Optional filters.
        if market:
            market = market.lower().strip()
            if market not in {"china", "hongkong", "usa", "japan"}:
                raise HTTPException(
                    status_code=http_status.HTTP_400_BAD_REQUEST,
                    detail="不支持的交易市场",
                )
            query = query.where(Report.market == market)

        if status:
            status_value = status.lower().strip()
            if status_value not in {"generating", "completed", "failed"}:
                raise HTTPException(
                    status_code=http_status.HTTP_400_BAD_REQUEST,
                    detail="不支持的状态值",
                )
            query = query.where(Report.status == status_value)

        # Newest first, then paginate.
        query = query.order_by(Report.created_at.desc()).offset(skip).limit(limit)

        result = await db.execute(query)
        reports = result.scalars().all()

        logger.info(f"获取报告列表: {len(reports)} 条记录")
        return [ReportResponse.from_attributes(report) for report in reports]
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"获取报告列表失败: {str(e)}")
        raise HTTPException(
            status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取报告列表失败: {str(e)}",
        ) from e


@router.delete("/{report_id}")
async def delete_report(
    report_id: UUID,
    db: AsyncSession = Depends(get_database_session),
):
    """Delete the report with ``report_id`` and commit.

    Returns:
        ``{"message": ..., "report_id": <id as str>}`` on success.

    Raises:
        HTTPException: 404 when no report has ``report_id``; 500 for any other
            failure (the session is rolled back first).
    """
    try:
        result = await db.execute(select(Report).where(Report.id == report_id))
        report = result.scalar_one_or_none()

        if report is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="报告不存在",
            )

        await db.delete(report)
        await db.commit()

        logger.info(f"删除报告成功: {report_id}")
        return {"message": "报告删除成功", "report_id": str(report_id)}
    except HTTPException:
        raise
    except Exception as e:
        # Log first so the original error survives even if the rollback fails.
        logger.exception(f"删除报告失败: {str(e)}")
        await db.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"删除报告失败: {str(e)}",
        ) from e