FA3-Datafetch/backend/app/api/chat_routes.py
2026-01-29 11:49:48 +00:00

431 lines
16 KiB
Python

from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
from fastapi.responses import StreamingResponse, FileResponse
import json
from pydantic import BaseModel
from typing import List, Optional
import logging
import time
import tempfile
import markdown
from weasyprint import HTML
from urllib.parse import quote
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.analysis_service import get_genai_client
from google.genai import types
from app.database import get_db
from app.models import LLMUsageLog
# Router that the chat endpoints in this module register themselves on.
router = APIRouter()
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class ChatMessage(BaseModel):
    """One turn of a chat conversation as sent by the client."""

    # Speaker of the turn; the chat handler treats "user" as the user and
    # every other value as the model.
    role: str
    # Plain-text body of the turn.
    content: str
class ChatRequest(BaseModel):
    """Request payload accepted by the ``/chat`` endpoint."""

    # Full conversation; the last entry is the message to send, the ones
    # before it are replayed as history.
    messages: List[ChatMessage]
    # Model identifier passed to the GenAI client.
    model: str
    # Optional system instruction for the model.
    system_prompt: Optional[str] = None
    # When true, the Google Search tool is attached to the request.
    use_google_search: bool = False
    # Optional identifiers that are stored alongside the usage log entry.
    session_id: Optional[str] = None
    stock_code: Optional[str] = None
# Model used when the request does not name one.
_DEFAULT_CHAT_MODEL = "gemini-3-flash-preview"

# Appended to the system prompt whenever the Google Search tool is enabled.
_GOOGLE_SEARCH_INSTRUCTION = (
    "\n\n[SYSTEM INSTRUCTION] You have access to Google Search. "
    "You MUST use it to verify any data or find the latest information "
    "before answering. Do not rely solely on your internal knowledge."
)


def _serialize_grounding_metadata(gm):
    """Convert SDK grounding metadata into the camelCase dict sent to the client.

    Returns ``None`` when the metadata has neither a search entry point nor
    grounding chunks, so the caller can keep whatever it captured earlier.
    """
    if not (gm.search_entry_point or gm.grounding_chunks):
        return None

    payload = {}
    if gm.search_entry_point:
        payload["searchEntryPoint"] = {
            "renderedContent": gm.search_entry_point.rendered_content
        }
    if gm.grounding_chunks:
        # NOTE(review): non-web chunks are dropped here, which may shift the
        # positions that groundingChunkIndices refer to - confirm with the
        # frontend before changing.
        payload["groundingChunks"] = [
            {"web": {"uri": g_chunk.web.uri, "title": g_chunk.web.title}}
            for g_chunk in gm.grounding_chunks
            if g_chunk.web
        ]
    if gm.grounding_supports:
        supports = []
        for support in gm.grounding_supports:
            segment = support.segment
            supports.append({
                "segment": {
                    "startIndex": segment.start_index if segment else 0,
                    "endIndex": segment.end_index if segment else 0,
                    "text": segment.text if segment else "",
                },
                # The indices may be absent; list(None) would raise and
                # abort the whole stream.
                "groundingChunkIndices": list(support.grounding_chunk_indices or []),
            })
        payload["groundingSupports"] = supports
    if gm.web_search_queries:
        payload["webSearchQueries"] = list(gm.web_search_queries)
    return payload


@router.post("/chat")
async def chat_with_ai(request: ChatRequest, background_tasks: BackgroundTasks = None):
    """AI Chat Endpoint with Logging (Streaming).

    Streams the model's reply as newline-delimited JSON. Each line is an
    object with a ``type`` of ``content``, ``metadata`` (grounding data),
    ``usage`` (token counts and response time) or ``error``. After the
    response has been sent, a background task stores the exchange in
    ``LLMUsageLog``.

    Args:
        request: Conversation, model name and options for this call.
        background_tasks: Unused; kept for interface compatibility. Logging
            is attached to the ``StreamingResponse`` via a starlette
            ``BackgroundTask`` instead.

    Raises:
        HTTPException: 400 when ``request.messages`` is empty; 500 when the
            request cannot be prepared.
    """
    # Validate outside the try-block so the 400 is not rewrapped as a 500.
    if not request.messages:
        raise HTTPException(status_code=400, detail="messages must not be empty")

    try:
        from starlette.background import BackgroundTask

        client = get_genai_client()

        # Everything but the last message is replayed as chat history.
        history = [
            types.Content(
                role="user" if msg.role == "user" else "model",
                parts=[types.Part(text=msg.content)],
            )
            for msg in request.messages[:-1]
        ]
        last_message = request.messages[-1].content
        model_name = request.model or _DEFAULT_CHAT_MODEL

        # Search Configuration & System Prompt
        tools = []
        final_system_prompt = request.system_prompt or ""
        if request.use_google_search:
            tools.append(types.Tool(google_search=types.GoogleSearch()))
            final_system_prompt += _GOOGLE_SEARCH_INSTRUCTION

        # Build config - only set system_instruction if it has content
        config_params = {
            "tools": tools if tools else None,
            "temperature": 0.1,
        }
        if final_system_prompt and final_system_prompt.strip():
            config_params["system_instruction"] = final_system_prompt
        config = types.GenerateContentConfig(**config_params)

        start_time = time.time()

        # Mutable context shared between the stream and the logging task.
        log_context = {
            "full_response_text": "",
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
            "response_time": 0,
        }

        # Sync generator: StreamingResponse iterates it in a threadpool, so
        # the blocking SDK call does not stall the event loop.
        def generate():
            text_parts = []
            grounding_data = None
            try:
                chat = client.chats.create(
                    model=model_name,
                    history=history,
                    config=config,
                )
                for chunk in chat.send_message_stream(last_message):
                    text = chunk.text
                    if text:
                        text_parts.append(text)
                        yield json.dumps({"type": "content", "content": text}) + "\n"

                    usage = chunk.usage_metadata
                    if usage:
                        log_context["prompt_tokens"] = usage.prompt_token_count or 0
                        log_context["completion_tokens"] = usage.candidates_token_count or 0
                        log_context["total_tokens"] = usage.total_token_count or 0

                    if chunk.candidates and chunk.candidates[0].grounding_metadata:
                        serialized = _serialize_grounding_metadata(
                            chunk.candidates[0].grounding_metadata
                        )
                        if serialized is not None:
                            grounding_data = serialized

                # End of stream
                log_context["response_time"] = time.time() - start_time

                if grounding_data:
                    yield json.dumps({"type": "metadata", "groundingMetadata": grounding_data}) + "\n"

                usage_data = {
                    "prompt_tokens": log_context["prompt_tokens"],
                    "completion_tokens": log_context["completion_tokens"],
                    "total_tokens": log_context["total_tokens"],
                    "response_time": log_context["response_time"],
                }
                yield json.dumps({"type": "usage", "usage": usage_data}) + "\n"
            except Exception as e:
                logger.exception("Stream generation error: %s", e)
                yield json.dumps({"type": "error", "error": str(e)}) + "\n"
            finally:
                # Runs on success, error and early close alike, so the log
                # entry always has the (possibly partial) text and a timing.
                log_context["full_response_text"] = "".join(text_parts)
                if not log_context["response_time"]:
                    log_context["response_time"] = time.time() - start_time

        async def log_usage(context: dict, req: ChatRequest):
            """Persist the finished exchange in its own DB session."""
            try:
                # Create a NEW session for the background task
                from app.database import SessionLocal

                async with SessionLocal() as session:
                    # Reconstruct prompt for logging
                    prompt_parts = []
                    if req.system_prompt:
                        prompt_parts.append(f"System: {req.system_prompt}\n\n")
                    prompt_parts.extend(
                        f"{msg.role}: {msg.content}\n" for msg in req.messages
                    )
                    log_entry = LLMUsageLog(
                        model=req.model or _DEFAULT_CHAT_MODEL,
                        prompt="".join(prompt_parts),
                        response=context["full_response_text"],
                        response_time=context["response_time"],
                        used_google_search=req.use_google_search,
                        session_id=req.session_id,
                        stock_code=req.stock_code,
                        prompt_tokens=context["prompt_tokens"],
                        completion_tokens=context["completion_tokens"],
                        total_tokens=context["total_tokens"],
                    )
                    session.add(log_entry)
                    await session.commit()
            except Exception as e:
                # Logging is best-effort; never let it surface to the client.
                logger.exception("Background logging failed: %s", e)

        # Pass the generator object (sync) and attach logging as a background task.
        return StreamingResponse(
            generate(),
            media_type="application/x-ndjson",
            background=BackgroundTask(log_usage, log_context, request),
        )
    except Exception as e:
        logger.exception("Chat error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/chat/history")
async def get_chat_history(db: AsyncSession = Depends(get_db)):
    """Return the 100 most recent usage-log rows, newest first.

    Rows from all sessions are returned as plain dicts; any filtering by
    session is left to the caller.
    """
    from sqlalchemy import select, desc

    def _as_dict(entry):
        # Timestamps are emitted in ISO-8601 form, or None when missing.
        return {
            "id": entry.id,
            "timestamp": entry.timestamp.isoformat() if entry.timestamp else None,
            "model": entry.model,
            "prompt": entry.prompt,
            "response": entry.response,
            "session_id": entry.session_id,
            "stock_code": entry.stock_code,
            "total_tokens": entry.total_tokens,
            "used_google_search": entry.used_google_search,
        }

    newest_first = desc(LLMUsageLog.timestamp)
    statement = select(LLMUsageLog).order_by(newest_first).limit(100)  # Limit 100 for now
    rows = (await db.execute(statement)).scalars().all()

    history = []
    for row in rows:
        history.append(_as_dict(row))
    return history
class ExportPDFRequest(BaseModel):
    """Request payload accepted by the ``/chat/export-pdf`` endpoint."""

    # Session ids whose logged messages should be included in the PDF.
    session_ids: List[str]
# Stylesheet for the exported PDF. The font stack includes common
# Linux/Windows/macOS CJK fonts so Chinese text renders on any host.
_EXPORT_PDF_CSS = '''
@page {
    size: A4;
    margin: 2cm 1.5cm;
}
body {
    font-family: "Noto Sans CJK SC", "Noto Sans SC", "PingFang SC", "Microsoft YaHei", "SimHei", "WenQuanYi Micro Hei", sans-serif;
    line-height: 1.6;
    color: #333;
    font-size: 10pt;
    background-color: white;
}
h1 {
    text-align: center;
    color: #1a1a1a;
    border-bottom: 2px solid #4a90e2;
    padding-bottom: 10px;
    margin-bottom: 30px;
}
.message {
    margin: 15px 0;
    padding: 12px;
    background: #f9f9f9;
    border-radius: 6px;
    border-left: 3px solid #4a90e2;
    page-break-inside: avoid;
}
.message-meta {
    display: flex;
    justify-content: space-between;
    align-items: center;
    margin-bottom: 4px;
}
.context-tag {
    font-weight: bold;
    color: #2c3e50;
    font-size: 11pt;
}
.time {
    color: #666;
    font-size: 9pt;
}
.message-sub-meta {
    display: flex;
    gap: 15px;
    font-size: 8pt;
    color: #888;
    margin-bottom: 8px;
    border-bottom: 1px solid #eee;
    padding-bottom: 4px;
}
.message-content {
    font-size: 10pt;
}
.message-content p {
    margin: 5px 0;
}
.message-content ul, .message-content ol {
    margin: 5px 0;
    padding-left: 20px;
}
table {
    border-collapse: collapse;
    width: 100%;
    margin: 10px 0;
    font-size: 9pt;
}
th, td {
    border: 1px solid #ddd;
    padding: 6px;
    text-align: left;
}
th {
    background-color: #f3f4f6;
}
code {
    background-color: #f4f4f4;
    padding: 1px 4px;
    border-radius: 3px;
    font-size: 9pt;
}
pre {
    background-color: #f4f4f4;
    padding: 10px;
    border-radius: 5px;
    overflow-x: auto;
    font-size: 8pt;
}
'''


@router.post("/chat/export-pdf")
async def export_chat_pdf(request: ExportPDFRequest, db: AsyncSession = Depends(get_db)):
    """Export selected chat sessions to PDF.

    Loads every logged response for the requested sessions, orders them by
    timestamp (so sessions are interleaved), renders them to HTML with
    timestamps shown in UTC+8, and returns the result as a PDF download.
    The temporary PDF file is deleted once the response has been sent.

    Args:
        request: The session ids to export.
        db: Async database session provided by FastAPI.

    Raises:
        HTTPException: 400 when no session ids are given, 404 when no log
            rows match, 500 when PDF generation fails.
    """
    import os
    from contextlib import suppress
    from datetime import datetime, timezone, timedelta
    from html import escape

    from fastapi.concurrency import run_in_threadpool
    from sqlalchemy import select, asc
    from starlette.background import BackgroundTask

    def _remove_file(path):
        # Best-effort cleanup; the file may already be gone.
        with suppress(OSError):
            os.unlink(path)

    # Timezone definition (Shanghai)
    shanghai_tz = timezone(timedelta(hours=8))

    if not request.session_ids:
        raise HTTPException(status_code=400, detail="No sessions selected")

    # Fetch logs for selected sessions, sorted strictly by timestamp ASC so
    # that messages from different sessions are interleaved chronologically.
    query = (
        select(LLMUsageLog)
        .where(LLMUsageLog.session_id.in_(request.session_ids))
        .order_by(asc(LLMUsageLog.timestamp))
    )
    result = await db.execute(query)
    logs = result.scalars().all()

    if not logs:
        raise HTTPException(status_code=404, detail="No logs found for selected sessions")

    # Build HTML content: a flat stream of messages, not grouped by session.
    message_blocks = []
    for log in logs:
        if log.timestamp:
            ts = log.timestamp
            if ts.tzinfo is None:
                # Naive timestamps are treated as UTC.
                ts = ts.replace(tzinfo=timezone.utc)
            timestamp = ts.astimezone(shanghai_tz).strftime('%Y-%m-%d %H:%M:%S')
        else:
            timestamp = ""

        # NOTE(review): raw HTML inside the stored response passes through
        # markdown unchanged and is rendered by WeasyPrint - consider
        # sanitising it if responses can contain untrusted markup.
        response_html = markdown.markdown(log.response or "", extensions=['tables', 'fenced_code'])

        # stock_code / session_id originate from client requests, so escape
        # them before embedding in HTML.
        stock_info = escape(log.stock_code or "Unknown Stock")
        session_info = escape(log.session_id or "Unknown Session")

        message_blocks.append(f'''
<div class="message">
<div class="message-meta">
<span class="context-tag">{stock_info}</span>
<span class="time">{timestamp}</span>
</div>
<div class="message-sub-meta">
<span class="session-id">Session: {session_info}</span>
</div>
<div class="message-content">
{response_html}
</div>
</div>
''')
    messages_html = "".join(message_blocks)

    # Complete HTML
    now_shanghai = datetime.now(shanghai_tz)
    exported_at = now_shanghai.strftime('%Y-%m-%d %H:%M:%S')
    complete_html = f'''
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>AI Research Discussion Export</title>
<style>
{_EXPORT_PDF_CSS}
</style>
</head>
<body>
<h1>AI 研究讨论导出</h1>
<p style="text-align: center; color: #666; margin-bottom: 30px;">导出时间: {exported_at}</p>
{messages_html}
</body>
</html>
'''

    # Generate PDF
    pdf_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
            pdf_path = tmp_file.name

        def _render_pdf():
            HTML(string=complete_html).write_pdf(pdf_path)

        # WeasyPrint rendering is blocking and CPU-heavy; keep it off the
        # event loop.
        await run_in_threadpool(_render_pdf)

        # Filename prefix: the first stock code found in the logs, falling
        # back to a generic title.
        company_name = next(
            (log.stock_code for log in logs if log.stock_code),
            "AI研究讨论",
        )
        filename = f"{company_name}_{now_shanghai.strftime('%Y%m%d')}_讨论记录.pdf"
        filename_encoded = quote(filename)

        return FileResponse(
            path=pdf_path,
            media_type='application/pdf',
            headers={
                'Content-Disposition': f"attachment; filename*=UTF-8''{filename_encoded}"
            },
            # Delete the temp file after the response body has been sent.
            background=BackgroundTask(_remove_file, pdf_path),
        )
    except Exception as e:
        if pdf_path is not None:
            _remove_file(pdf_path)
        logger.exception("Failed to generate PDF: %s", e)
        raise HTTPException(status_code=500, detail=f"Failed to generate PDF: {str(e)}") from e