170 lines
6.5 KiB
Python
170 lines
6.5 KiB
Python
import asyncio
import html
import os

import markdown
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from fastapi.responses import HTMLResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.database import get_db
from app.models import Report, Setting
from app.schemas import (
    AnalysisRequest,
    AnalysisStatus,
    ConfigUpdateRequest,
    ReportResponse,
    StockSearchRequest,
    StockSearchResponse,
)
from app.services import analysis_service
|
|
|
|
# Router that every endpoint in this module registers itself on.
router: APIRouter = APIRouter()
|
|
|
|
@router.get("/health")
def health_check():
    """Return a static payload indicating that the API is up."""
    status_payload = {"status": "healthy"}
    return status_payload
|
|
|
|
@router.post("/search", response_model=list[StockSearchResponse])
async def search_stock(request: StockSearchRequest, db: AsyncSession = Depends(get_db)):
    """Search for stocks matching ``request.query`` via the analysis service.

    The Gemini API key is read from the ``GEMINI_API_KEY`` setting row, or from
    the environment variable of the same name when no such row exists. The
    model name is read from the ``AI_MODEL`` setting and defaults to
    ``gemini-2.0-flash`` when the row is missing or its value is empty.

    Raises:
        HTTPException: 500 when no API key is configured; 400 when the service
            returns a dict containing an ``"error"`` key.
    """
    setting = await db.get(Setting, "GEMINI_API_KEY")
    api_key = setting.value if setting else os.getenv("GEMINI_API_KEY")
    if not api_key:
        raise HTTPException(status_code=500, detail="API Key not configured")

    model_setting = await db.get(Setting, "AI_MODEL")
    # Treat an empty stored value like a missing one instead of forwarding "".
    model = (model_setting.value if model_setting else None) or "gemini-2.0-flash"

    result = await analysis_service.search_stock(request.query, api_key, model)
    # The service reports failure by returning {"error": ...} rather than
    # raising; surface just that message instead of the dict's Python repr.
    if isinstance(result, dict) and "error" in result:
        raise HTTPException(status_code=400, detail=str(result["error"]))
    return result
|
|
|
|
@router.post("/analyze", response_model=ReportResponse)
async def start_analysis(request: AnalysisRequest, background_tasks: BackgroundTasks, db: AsyncSession = Depends(get_db)):
    """Create a PENDING report and schedule its analysis as a background task.

    The report row is committed first. If no Gemini API key can then be found
    (``GEMINI_API_KEY`` setting row, else the environment variable), the row is
    marked FAILED and a 500 is raised; otherwise
    ``analysis_service.run_analysis_task`` is scheduled for it.

    Returns:
        The newly created report, with its ``sections`` relationship loaded.

    Raises:
        HTTPException: 500 when no API key is configured.
    """
    model_setting = await db.get(Setting, "AI_MODEL")
    # Treat an empty stored value like a missing one so the report never
    # records "" as its model.
    model = (model_setting.value if model_setting else None) or "gemini-2.0-flash"

    new_report = Report(
        market=request.market,
        symbol=request.symbol,
        company_name=request.company_name,
        status=AnalysisStatus.PENDING,
        ai_model=model,
    )
    db.add(new_report)
    await db.commit()
    await db.refresh(new_report)

    setting = await db.get(Setting, "GEMINI_API_KEY")
    api_key = setting.value if setting else os.getenv("GEMINI_API_KEY")
    if not api_key:
        # The row is already committed; mark it FAILED so it is not left
        # PENDING with no background task scheduled for it.
        new_report.status = AnalysisStatus.FAILED
        await db.commit()
        raise HTTPException(status_code=500, detail="API Key not configured")

    # NOTE(review): `model` is not passed to the task; presumably the task
    # reads `ai_model` from the stored report — confirm in analysis_service.
    background_tasks.add_task(
        analysis_service.run_analysis_task,
        new_report.id,
        request.market,
        request.symbol,
        api_key,
    )

    # Re-fetch with sections eagerly loaded so serialising the response does
    # not trigger a lazy load on the async session.
    result = await db.execute(
        select(Report)
        .options(selectinload(Report.sections))
        .where(Report.id == new_report.id)
    )
    return result.scalar_one()
|
|
|
|
@router.get("/reports", response_model=list[ReportResponse])
async def get_reports(db: AsyncSession = Depends(get_db)):
    """List all reports, newest first, with their sections eagerly loaded."""
    stmt = (
        select(Report)
        .options(selectinload(Report.sections))
        .order_by(Report.created_at.desc())
    )
    rows = await db.execute(stmt)
    return rows.scalars().all()
|
|
|
|
@router.get("/reports/{report_id}", response_model=ReportResponse)
async def get_report(report_id: int, db: AsyncSession = Depends(get_db)):
    """Fetch one report (with sections) by id.

    Raises:
        HTTPException: 404 when no report has the given id.
    """
    stmt = (
        select(Report)
        .options(selectinload(Report.sections))
        .where(Report.id == report_id)
    )
    found = (await db.execute(stmt)).scalar_one_or_none()
    if found:
        return found
    raise HTTPException(status_code=404, detail="Report not found")
|
|
|
|
_CHARTS_NOT_READY_HTML = "<p>财务图表尚未生成,数据获取可能仍在进行中。</p>"


def _is_within(parent: str, child: str) -> bool:
    """Return True when absolute path *child* equals *parent* or lies beneath it."""
    try:
        return os.path.commonpath([parent, child]) == parent
    except ValueError:
        # e.g. paths on different drives on Windows.
        return False


def _load_financial_html(root_dir: str, market: str, symbol: str) -> str:
    """Return the pre-rendered financial-chart HTML for one report.

    Looks for ``<root_dir>/data/<market>/<symbol>/report.html``. When the exact
    symbol directory does not exist, the alphabetically first directory under
    the market folder whose name starts with ``symbol`` is used instead.

    Args:
        root_dir: Directory that contains the ``data`` folder.
        market: Market sub-folder name as stored on the report.
        symbol: Symbol sub-folder name (or name prefix) as stored on the report.

    Returns:
        The file contents; a "not generated yet" placeholder paragraph when the
        file is missing or the path would escape the data folder; or an
        HTML-escaped error paragraph when filesystem access fails.
    """
    data_root = os.path.abspath(os.path.join(root_dir, "data"))
    base_dir = os.path.abspath(os.path.join(data_root, market))
    symbol_dir = os.path.abspath(os.path.join(base_dir, symbol))
    # market/symbol come from the client's AnalysisRequest: refuse anything
    # ("..", absolute paths) that would resolve outside the data folder.
    if not (_is_within(data_root, base_dir) and _is_within(base_dir, symbol_dir)):
        return _CHARTS_NOT_READY_HTML

    try:
        # Fuzzy match: folders may be named "<symbol><suffix>". Sort so the
        # choice does not depend on os.listdir's arbitrary ordering.
        if not os.path.exists(symbol_dir) and os.path.exists(base_dir):
            candidates = sorted(
                d
                for d in os.listdir(base_dir)
                if d.startswith(symbol) and os.path.isdir(os.path.join(base_dir, d))
            )
            if candidates:
                symbol_dir = os.path.join(base_dir, candidates[0])

        report_path = os.path.join(symbol_dir, "report.html")
        if not os.path.exists(report_path):
            return _CHARTS_NOT_READY_HTML
        with open(report_path, "r", encoding="utf-8") as f:
            return f.read()
    except (OSError, ValueError) as e:
        # ValueError also covers UnicodeDecodeError and embedded-NUL paths.
        return f"<p>加载财务图表时出错: {html.escape(str(e))}</p>"


@router.get("/reports/{report_id}/html", response_class=HTMLResponse)
async def get_report_html(report_id: int, db: AsyncSession = Depends(get_db)):
    """Render a standalone HTML page containing only the report's financial charts.

    The chart markup is read from the report's ``report.html`` file under the
    project's ``data/<market>/<symbol>/`` folder; analysis sections are not
    included.

    Raises:
        HTTPException: 404 when no report has the given id.
    """
    # Sections are not rendered here, so they are not eagerly loaded.
    result = await db.execute(select(Report).where(Report.id == report_id))
    report = result.scalar_one_or_none()
    if not report:
        raise HTTPException(status_code=404, detail="Report not found")

    root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))
    # Directory scans and file reads block; run them off the event loop.
    financial_html = await asyncio.to_thread(
        _load_financial_html, root_dir, report.market, report.symbol
    )
    # company_name is client-supplied; escape it before embedding in HTML.
    title = html.escape(str(report.company_name))

    final_html = f"""
<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>{title} - 财务数据</title>
    <style>
        body {{ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif; padding: 20px; line-height: 1.6; max-width: 1200px; margin: 0 auto; }}
        table {{ border-collapse: collapse; width: 100%; margin: 20px 0; }}
        th, td {{ border: 1px solid #ddd; padding: 12px; }}
        th {{ background-color: #f5f5f5; }}
        img {{ max-width: 100%; }}
    </style>
</head>
<body>
    {financial_html}
</body>
</html>
"""
    return final_html
|
|
|
|
def _mask_secret(value) -> str:
    """Mask a secret for display, revealing at most its last four characters.

    The tail is shown only for values longer than eight characters, so a short
    secret is never echoed back in full. ``None`` is treated as empty.
    """
    secret = value or ""
    tail = secret[-4:] if len(secret) > 8 else ""
    return "********" + tail


@router.get("/config")
async def get_config(db: AsyncSession = Depends(get_db)):
    """Return every stored setting as a ``{key: value}`` map.

    ``GEMINI_API_KEY`` is always present in the result: the stored value
    masked, else the masked environment variable, else ``""``.
    """
    result = await db.execute(select(Setting))
    config_map = {s.key: s.value for s in result.scalars().all()}

    if "GEMINI_API_KEY" in config_map:
        config_map["GEMINI_API_KEY"] = _mask_secret(config_map["GEMINI_API_KEY"])
    elif env_key := os.getenv("GEMINI_API_KEY"):
        config_map["GEMINI_API_KEY"] = _mask_secret(env_key)
    else:
        config_map["GEMINI_API_KEY"] = ""

    return config_map
|
|
|
|
@router.post("/config")
async def update_config(request: ConfigUpdateRequest, db: AsyncSession = Depends(get_db)):
    """Insert or overwrite the setting named ``request.key`` with ``request.value``."""
    existing = await db.get(Setting, request.key)
    if existing:
        existing.value = request.value
    else:
        db.add(Setting(key=request.key, value=request.value))

    await db.commit()
    return {"status": "updated", "key": request.key}
|