298 lines
9.6 KiB
Python
298 lines
9.6 KiB
Python
"""
|
|
报告相关API路由
|
|
"""
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status, Query, BackgroundTasks
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import selectinload
|
|
from typing import Optional, List
|
|
from uuid import UUID
|
|
import logging
|
|
|
|
from ..core.dependencies import get_database_session
|
|
from ..models.report import Report
|
|
from ..schemas.report import ReportResponse, RegenerateRequest
|
|
from ..services.report_generator import ReportGenerator
|
|
from ..services.config_manager import ConfigManager
|
|
from ..core.exceptions import ReportGenerationError, DatabaseError
|
|
|
|
# Router that collects every report endpoint declared in this module.
router = APIRouter()

# Per-module logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@router.get("/{symbol}", response_model=ReportResponse)
async def get_or_create_report(
    symbol: str,
    background_tasks: BackgroundTasks,
    market: str = Query(..., description="交易市场"),
    db: AsyncSession = Depends(get_database_session),
):
    """Return the stored report for ``symbol``/``market``, creating one if absent.

    When no report exists, a placeholder row with ``status="generating"`` is
    committed. Then ``ReportGenerator.generate_report_async(symbol, market)``
    is scheduled as a FastAPI background task.

    Args:
        symbol: Security code from the path; upper-cased and stripped.
        background_tasks: FastAPI-injected background task queue.
        market: Trading market query parameter; lower-cased and stripped.
            Must be one of ``china``, ``hongkong``, ``usa`` or ``japan``.
        db: Async database session from ``get_database_session``.

    Returns:
        A ``ReportResponse`` built from the existing or newly created report.

    Raises:
        HTTPException: 400 for an empty symbol or an unsupported market.
            500 for any other failure; the session is rolled back first.
    """
    try:
        # Normalise inputs so lookups ignore case and surrounding whitespace.
        symbol = symbol.upper().strip()
        market = market.lower().strip()

        if not symbol:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="证券代码不能为空",
            )

        if market not in ["china", "hongkong", "usa", "japan"]:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="不支持的交易市场",
            )

        # Eager-load analysis_modules. AsyncSession cannot perform implicit
        # lazy loads when the response model reads the relationship.
        result = await db.execute(
            select(Report)
            .options(selectinload(Report.analysis_modules))
            .where(
                Report.symbol == symbol,
                Report.market == market,
            )
        )
        # NOTE(review): two concurrent first requests for the same
        # symbol/market can both miss here and insert two rows. After that,
        # scalar_one_or_none() raises MultipleResultsFound. Confirm whether
        # the model has a unique constraint on (symbol, market).
        existing_report = result.scalar_one_or_none()

        if existing_report:
            logger.info(f"找到现有报告: {symbol}-{market}, 状态: {existing_report.status}")
            # NOTE(review): `from_attributes` is a Pydantic v2 *config key*, not
            # a BaseModel classmethod. Unless ReportResponse defines it, this
            # should be `ReportResponse.model_validate(obj)` — confirm in schemas.
            return ReportResponse.from_attributes(existing_report)

        logger.info(f"开始生成新报告: {symbol}-{market}")

        # Commit the placeholder row before wiring up generation, the same
        # order regenerate_report uses.
        new_report = Report(
            symbol=symbol,
            market=market,
            status="generating",
        )
        db.add(new_report)
        await db.commit()
        await db.refresh(new_report)

        # NOTE(review): the generator receives this request's session, but
        # FastAPI runs background tasks after the response is sent. Confirm
        # that get_database_session keeps the session usable at that point.
        config_manager = ConfigManager(db)
        report_generator = ReportGenerator(db, config_manager)
        background_tasks.add_task(
            report_generator.generate_report_async,
            symbol,
            market,
        )

        logger.info(f"创建报告记录: {new_report.id}")
        return ReportResponse.from_attributes(new_report)

    except HTTPException:
        raise
    except Exception as e:
        # Discard the failed transaction (e.g. a failed commit) so the session
        # is not left in an unusable state.
        await db.rollback()
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception(f"获取或创建报告失败: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取或创建报告失败: {str(e)}",
        ) from e
|
|
|
|
|
|
@router.post("/{symbol}/regenerate", response_model=ReportResponse)
async def regenerate_report(
    symbol: str,
    request: RegenerateRequest,
    background_tasks: BackgroundTasks,
    market: str = Query(..., description="交易市场"),
    db: AsyncSession = Depends(get_database_session),
):
    """Regenerate the report for ``symbol``/``market``.

    If a report exists and ``request.force`` is falsy, that report is returned
    unchanged. Otherwise, any existing report is deleted and a new placeholder
    row with ``status="generating"`` is inserted, in a single transaction.
    ``ReportGenerator.generate_report_async(symbol, market)`` is then scheduled
    as a FastAPI background task.

    Args:
        symbol: Security code from the path; upper-cased and stripped.
        request: Request body; its ``force`` flag requests regeneration even
            when a report already exists.
        background_tasks: FastAPI-injected background task queue.
        market: Trading market query parameter; lower-cased and stripped.
            Must be one of ``china``, ``hongkong``, ``usa`` or ``japan``.
        db: Async database session from ``get_database_session``.

    Returns:
        A ``ReportResponse`` for the existing or the newly created report.

    Raises:
        HTTPException: 400 for an empty symbol or an unsupported market.
            500 for any other failure; the transaction is rolled back, so the
            old report survives.
    """
    try:
        # Normalise inputs so lookups ignore case and surrounding whitespace.
        symbol = symbol.upper().strip()
        market = market.lower().strip()

        if not symbol:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="证券代码不能为空",
            )

        if market not in ["china", "hongkong", "usa", "japan"]:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="不支持的交易市场",
            )

        result = await db.execute(
            select(Report)
            .options(selectinload(Report.analysis_modules))
            .where(
                Report.symbol == symbol,
                Report.market == market,
            )
        )
        existing_report = result.scalar_one_or_none()

        if existing_report and not request.force:
            # Report exists and regeneration was not forced: return it as-is.
            logger.info(f"返回现有报告: {symbol}-{market}")
            # NOTE(review): confirm ReportResponse actually defines a
            # `from_attributes` classmethod. In Pydantic v2 it is only a config
            # key, and the call would be `model_validate`.
            return ReportResponse.from_attributes(existing_report)

        # Previously the delete was committed on its own. A failure while
        # inserting the replacement then left the symbol with no report at
        # all. Now the delete and the insert share one commit.
        if existing_report:
            logger.info(f"删除现有报告: {existing_report.id}")
            await db.delete(existing_report)
            # SQLAlchemy's unit of work emits INSERTs before DELETEs within a
            # flush. Flushing here sends the DELETE first, which matters if
            # (symbol, market) is unique — presumably; confirm in the model.
            await db.flush()

        new_report = Report(
            symbol=symbol,
            market=market,
            status="generating",
        )
        db.add(new_report)
        await db.commit()
        await db.refresh(new_report)

        # NOTE(review): the generator receives this request's session but runs
        # after the response is sent. Confirm the session is still usable then.
        config_manager = ConfigManager(db)
        report_generator = ReportGenerator(db, config_manager)
        background_tasks.add_task(
            report_generator.generate_report_async,
            symbol,
            market,
        )

        logger.info(f"开始重新生成报告: {new_report.id}")
        return ReportResponse.from_attributes(new_report)

    except HTTPException:
        raise
    except Exception as e:
        # Roll back so a failure mid-replacement keeps the original report.
        await db.rollback()
        logger.exception(f"重新生成报告失败: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"重新生成报告失败: {str(e)}",
        ) from e
|
|
|
|
|
|
@router.get("/{report_id}/details", response_model=ReportResponse)
async def get_report_details(
    report_id: UUID,
    db: AsyncSession = Depends(get_database_session),
):
    """Look up one report by its id, including its analysis modules.

    Args:
        report_id: Primary key of the report (path parameter).
        db: Async database session from ``get_database_session``.

    Returns:
        A ``ReportResponse`` for the matching report.

    Raises:
        HTTPException: 404 if no report has ``report_id``.
            500 wrapping any other error.
    """
    try:
        stmt = (
            select(Report)
            .options(selectinload(Report.analysis_modules))
            .where(Report.id == report_id)
        )
        found = (await db.execute(stmt)).scalar_one_or_none()

        # Guard clause: bail out with 404 before doing anything else.
        if not found:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="报告不存在",
            )

        logger.info(f"获取报告详情: {report_id}")
        return ReportResponse.from_attributes(found)

    except HTTPException:
        # Client-facing errors (the 404 above) propagate untouched.
        raise
    except Exception as exc:
        # One message is shared by the log line and the response detail.
        message = f"获取报告详情失败: {exc}"
        logger.error(message)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=message,
        )
|
|
|
|
|
|
@router.get("/", response_model=List[ReportResponse])
async def list_reports(
    skip: int = Query(0, ge=0, description="跳过的记录数"),
    limit: int = Query(10, ge=1, le=100, description="返回的记录数"),
    market: Optional[str] = Query(None, description="按市场筛选"),
    status: Optional[str] = Query(None, description="按状态筛选"),
    db: AsyncSession = Depends(get_database_session),
):
    """List reports, newest first, with optional market/status filters.

    Args:
        skip: Number of rows to skip (offset), >= 0.
        limit: Maximum number of rows to return, 1..100.
        market: Optional market filter; lower-cased and stripped. Must be one
            of ``china``, ``hongkong``, ``usa`` or ``japan``.
        status: Optional status filter; lower-cased and stripped. Must be one
            of ``generating``, ``completed`` or ``failed``.
        db: Async database session from ``get_database_session``.

    Returns:
        A list of ``ReportResponse`` ordered by ``created_at`` descending.

    Raises:
        HTTPException: 400 for an unsupported market or status value.
            500 for any other failure.
    """
    # The `status` query parameter shadows the module-level `fastapi.status`
    # import inside this function. Previously `status.HTTP_400_BAD_REQUEST`
    # raised AttributeError on a str/None, so invalid filters never produced
    # a 400 and the 500 handler crashed too. Bind the constants under another
    # name; the public parameter name stays `status`.
    from fastapi import status as http_status

    try:
        query = select(Report).options(selectinload(Report.analysis_modules))

        # Optional filters. An empty string is treated like "not provided".
        if market:
            market = market.lower().strip()
            if market not in ["china", "hongkong", "usa", "japan"]:
                raise HTTPException(
                    status_code=http_status.HTTP_400_BAD_REQUEST,
                    detail="不支持的交易市场",
                )
            query = query.where(Report.market == market)

        if status:
            status_value = status.lower().strip()
            if status_value not in ["generating", "completed", "failed"]:
                raise HTTPException(
                    status_code=http_status.HTTP_400_BAD_REQUEST,
                    detail="不支持的状态值",
                )
            query = query.where(Report.status == status_value)

        # Newest first, then paginate.
        query = query.order_by(Report.created_at.desc()).offset(skip).limit(limit)

        result = await db.execute(query)
        reports = result.scalars().all()

        logger.info(f"获取报告列表: {len(reports)} 条记录")
        # NOTE(review): confirm ReportResponse defines `from_attributes` as a
        # classmethod. In Pydantic v2 this would be `model_validate`.
        return [ReportResponse.from_attributes(report) for report in reports]

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"获取报告列表失败: {str(e)}")
        raise HTTPException(
            status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取报告列表失败: {str(e)}",
        ) from e
|
|
|
|
|
|
@router.delete("/{report_id}")
async def delete_report(
    report_id: UUID,
    db: AsyncSession = Depends(get_database_session),
):
    """Delete the report with ``report_id``.

    Args:
        report_id: Primary key of the report (path parameter).
        db: Async database session from ``get_database_session``.

    Returns:
        ``{"message": "报告删除成功", "report_id": <id as str>}`` on success.

    Raises:
        HTTPException: 404 if no report has ``report_id``.
            500 for any other failure; the transaction is rolled back.
    """
    try:
        result = await db.execute(
            select(Report).where(Report.id == report_id)
        )
        report = result.scalar_one_or_none()

        if not report:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="报告不存在",
            )

        # NOTE(review): whether related analysis_modules are removed too
        # depends on the relationship's cascade settings — confirm in the model.
        await db.delete(report)
        await db.commit()

        logger.info(f"删除报告成功: {report_id}")
        return {"message": "报告删除成功", "report_id": str(report_id)}

    except HTTPException:
        raise
    except Exception as e:
        # A failed delete/commit leaves the transaction unusable; discard it.
        await db.rollback()
        logger.exception(f"删除报告失败: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"删除报告失败: {str(e)}",
        ) from e