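"""Chat API routes.

Provides a streaming Gemini chat endpoint with usage logging, a chat history
endpoint, and PDF export of selected chat sessions.
"""
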
import json
import logging
import os
import tempfile
import time
from typing import List, Optional
from urllib.parse import quote

import markdown
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from fastapi.responses import FileResponse, StreamingResponse
from google.genai import types
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from weasyprint import HTML

from app.database import get_db
from app.models import LLMUsageLog
from app.services.analysis_service import get_genai_client

router = APIRouter()
logger = logging.getLogger(__name__)


class ChatMessage(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    model: str
    system_prompt: Optional[str] = None
    use_google_search: bool = False
    session_id: Optional[str] = None
    stock_code: Optional[str] = None
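
# A request to /chat looks roughly like this (illustrative values, not part of
# any fixed API contract):
#   {
#       "messages": [{"role": "user", "content": "How did the stock close?"}],
#       "model": "gemini-3-flash-preview",
#       "use_google_search": true,
#       "session_id": "sess-001",
#       "stock_code": "600519"
#   }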


@router.post("/chat")
async def chat_with_ai(request: ChatRequest, background_tasks: Optional[BackgroundTasks] = None):
    """AI chat endpoint with usage logging (streaming NDJSON)."""
    # Note: the background_tasks argument is not used directly; logging is
    # attached via starlette.background.BackgroundTask on the StreamingResponse.
    try:
        client = get_genai_client()

        # Prepare history and config: every message except the last becomes history.
        history = []
        if len(request.messages) > 1:
            for msg in request.messages[:-1]:
                history.append(types.Content(
                    role="user" if msg.role == "user" else "model",
                    parts=[types.Part(text=msg.content)]
                ))

        last_message = request.messages[-1].content if request.messages else ""
        model_name = request.model or "gemini-3-flash-preview"

        # Search configuration & system prompt
        tools = []
        final_system_prompt = request.system_prompt or ""

        if request.use_google_search:
            tools.append(types.Tool(google_search=types.GoogleSearch()))
            final_system_prompt += (
                "\n\n[SYSTEM INSTRUCTION] You have access to Google Search. "
                "You MUST use it to verify any data or find the latest "
                "information before answering. Do not rely solely on your "
                "internal knowledge."
            )

        # Build config - only set system_instruction if it has content
        config_params = {
            "tools": tools if tools else None,
            "temperature": 0.1
        }
        if final_system_prompt and final_system_prompt.strip():
            config_params["system_instruction"] = final_system_prompt

        config = types.GenerateContentConfig(**config_params)

        start_time = time.time()

        # Mutable context to capture data during the stream for logging
        log_context = {
            "full_response_text": "",
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
            "response_time": 0
        }

        # Sync generator; StreamingResponse runs it in a threadpool.
        def generate():
            grounding_data = None

            try:
                # Initialize chat
                chat = client.chats.create(
                    model=model_name,
                    history=history,
                    config=config
                )

                # Blocking streaming call (runs in the threadpool)
                response_stream = chat.send_message_stream(last_message)

                for chunk in response_stream:
                    if chunk.text:
                        log_context["full_response_text"] += chunk.text
                        data = json.dumps({"type": "content", "content": chunk.text})
                        yield f"{data}\n"

                    if chunk.usage_metadata:
                        log_context["prompt_tokens"] = chunk.usage_metadata.prompt_token_count or 0
                        log_context["completion_tokens"] = chunk.usage_metadata.candidates_token_count or 0
                        log_context["total_tokens"] = chunk.usage_metadata.total_token_count or 0

                    if chunk.candidates and chunk.candidates[0].grounding_metadata:
                        gm = chunk.candidates[0].grounding_metadata
                        if gm.search_entry_point or gm.grounding_chunks:
                            grounding_obj = {}
                            if gm.search_entry_point:
                                grounding_obj["searchEntryPoint"] = {"renderedContent": gm.search_entry_point.rendered_content}

                            if gm.grounding_chunks:
                                grounding_obj["groundingChunks"] = []
                                for g_chunk in gm.grounding_chunks:
                                    if g_chunk.web:
                                        grounding_obj["groundingChunks"].append({
                                            "web": {
                                                "uri": g_chunk.web.uri,
                                                "title": g_chunk.web.title
                                            }
                                        })

                            if gm.grounding_supports:
                                grounding_obj["groundingSupports"] = []
                                for support in gm.grounding_supports:
                                    support_obj = {
                                        "segment": {
                                            "startIndex": support.segment.start_index if support.segment else 0,
                                            "endIndex": support.segment.end_index if support.segment else 0,
                                            "text": support.segment.text if support.segment else ""
                                        },
                                        # Guard against a missing index list
                                        "groundingChunkIndices": list(support.grounding_chunk_indices or [])
                                    }
                                    grounding_obj["groundingSupports"].append(support_obj)

                            if gm.web_search_queries:
                                grounding_obj["webSearchQueries"] = gm.web_search_queries

                            grounding_data = grounding_obj

                # End of stream
                end_time = time.time()
                log_context["response_time"] = end_time - start_time

                # Send metadata chunk
                if grounding_data:
                    yield f"{json.dumps({'type': 'metadata', 'groundingMetadata': grounding_data})}\n"

                # Send usage chunk
                usage_data = {
                    "prompt_tokens": log_context["prompt_tokens"],
                    "completion_tokens": log_context["completion_tokens"],
                    "total_tokens": log_context["total_tokens"],
                    "response_time": log_context["response_time"]
                }
                yield f"{json.dumps({'type': 'usage', 'usage': usage_data})}\n"

            except Exception as e:
                logger.error(f"Stream generation error: {e}")
                err_json = json.dumps({"type": "error", "error": str(e)})
                yield f"{err_json}\n"

        # Background task for logging
        from starlette.background import BackgroundTask

        async def log_usage(context: dict, req: ChatRequest):
            try:
                # Create a NEW session for the background task
                from app.database import SessionLocal
                async with SessionLocal() as session:
                    # Reconstruct prompt for logging
                    logged_prompt = ""
                    if req.system_prompt:
                        logged_prompt += f"System: {req.system_prompt}\n\n"
                    for msg in req.messages:
                        logged_prompt += f"{msg.role}: {msg.content}\n"

                    log_entry = LLMUsageLog(
                        model=req.model or "gemini-3-flash-preview",
                        prompt=logged_prompt,
                        response=context["full_response_text"],
                        response_time=context["response_time"],
                        used_google_search=req.use_google_search,
                        session_id=req.session_id,
                        stock_code=req.stock_code,
                        prompt_tokens=context["prompt_tokens"],
                        completion_tokens=context["completion_tokens"],
                        total_tokens=context["total_tokens"]
                    )
                    session.add(log_entry)
                    await session.commit()
            except Exception as e:
                logger.error(f"Background logging failed: {e}")

        # Note: generate() is passed as a sync generator (not an async one),
        # and the logging task is attached to the StreamingResponse so it runs
        # only after the stream has been fully consumed.
        return StreamingResponse(
            generate(),
            media_type="application/x-ndjson",
            background=BackgroundTask(log_usage, log_context, request)
        )

    except Exception as e:
        logger.error(f"Chat error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
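

# The /chat response is newline-delimited JSON (NDJSON), one event per line:
#   {"type": "content", "content": "...partial text..."}
#   {"type": "metadata", "groundingMetadata": {...}}   # only when grounding was used
#   {"type": "usage", "usage": {"prompt_tokens": ..., "completion_tokens": ..., ...}}
#   {"type": "error", "error": "..."}                  # emitted if the stream fails
#
# A minimal client sketch, assuming the app is served at localhost:8000 with
# this router mounted at the root (host and prefix are assumptions here):
#
#   import httpx, json
#   payload = {"messages": [{"role": "user", "content": "hi"}],
#              "model": "gemini-3-flash-preview"}
#   with httpx.stream("POST", "http://localhost:8000/chat", json=payload) as r:
#       for line in r.iter_lines():
#           if line:
#               event = json.loads(line)
#               if event["type"] == "content":
#                   print(event["content"], end="")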


@router.get("/chat/history")
async def get_chat_history(db: AsyncSession = Depends(get_db)):
    """Fetch chat history."""
    from sqlalchemy import select, desc

    # Fetch all logs ordered by time (newest first); the user can filter by
    # session on the frontend. Limited to 100 rows for now.
    query = select(LLMUsageLog).order_by(desc(LLMUsageLog.timestamp)).limit(100)

    result = await db.execute(query)
    logs = result.scalars().all()

    return [{
        "id": log.id,
        "timestamp": log.timestamp.isoformat() if log.timestamp else None,
        "model": log.model,
        "prompt": log.prompt,
        "response": log.response,
        "session_id": log.session_id,
        "stock_code": log.stock_code,
        "total_tokens": log.total_tokens,
        "used_google_search": log.used_google_search
    } for log in logs]


class ExportPDFRequest(BaseModel):
    session_ids: List[str]


@router.post("/chat/export-pdf")
async def export_chat_pdf(request: ExportPDFRequest, db: AsyncSession = Depends(get_db)):
    """Export selected chat sessions to PDF."""
    from sqlalchemy import select, asc
    from datetime import datetime, timezone, timedelta

    # Timezone definition (Shanghai, UTC+8)
    shanghai_tz = timezone(timedelta(hours=8))

    if not request.session_ids:
        raise HTTPException(status_code=400, detail="No sessions selected")

    # Fetch logs for the selected sessions, sorted strictly by timestamp ASC
    # so that messages from different sessions interleave chronologically.
    query = (
        select(LLMUsageLog)
        .where(LLMUsageLog.session_id.in_(request.session_ids))
        .order_by(asc(LLMUsageLog.timestamp))
    )
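
    # The query above is roughly equivalent to the following SQL (table and
    # column names assumed from the LLMUsageLog model, for illustration only):
    #   SELECT * FROM llm_usage_log
    #   WHERE session_id IN (:id1, :id2, ...)
    #   ORDER BY timestamp ASC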
    result = await db.execute(query)
    logs = result.scalars().all()

    if not logs:
        raise HTTPException(status_code=404, detail="No logs found for selected sessions")

    # Build HTML content. No grouping by session: just a flat, chronological
    # stream of messages.
    messages_html = ""
    for log in logs:
        # Convert timestamp to Shanghai time; naive timestamps are treated as UTC.
        if log.timestamp:
            ts_shanghai = (
                log.timestamp.astimezone(shanghai_tz)
                if log.timestamp.tzinfo
                else log.timestamp.replace(tzinfo=timezone.utc).astimezone(shanghai_tz)
            )
            timestamp = ts_shanghai.strftime('%Y-%m-%d %H:%M:%S')
        else:
            timestamp = ""

        response_html = markdown.markdown(log.response or "", extensions=['tables', 'fenced_code'])

        # Add context info (stock code / session) to the header since sessions are mixed
        stock_info = log.stock_code or "Unknown Stock"
        session_info = log.session_id or "Unknown Session"

        messages_html += f'''
        <div class="message">
            <div class="message-meta">
                <span class="context-tag">{stock_info}</span>
                <span class="time">{timestamp}</span>
            </div>
            <div class="message-sub-meta">
                <span class="session-id">Session: {session_info}</span>
            </div>
            <div class="message-content">
                {response_html}
            </div>
        </div>
        '''

    # Complete HTML
    now_shanghai = datetime.now(shanghai_tz)

    # Font-family stack includes common Linux/Windows CJK fonts
    complete_html = f'''
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="utf-8">
        <title>AI Research Discussion Export</title>
        <style>
            @page {{
                size: A4;
                margin: 2cm 1.5cm;
            }}
            body {{
                font-family: "Noto Sans CJK SC", "Noto Sans SC", "PingFang SC", "Microsoft YaHei", "SimHei", "WenQuanYi Micro Hei", sans-serif;
                line-height: 1.6;
                color: #333;
                font-size: 10pt;
                background-color: white;
            }}
            h1 {{
                text-align: center;
                color: #1a1a1a;
                border-bottom: 2px solid #4a90e2;
                padding-bottom: 10px;
                margin-bottom: 30px;
            }}
            .message {{
                margin: 15px 0;
                padding: 12px;
                background: #f9f9f9;
                border-radius: 6px;
                border-left: 3px solid #4a90e2;
                page-break-inside: avoid;
            }}
            .message-meta {{
                display: flex;
                justify-content: space-between;
                align-items: center;
                margin-bottom: 4px;
            }}
            .context-tag {{
                font-weight: bold;
                color: #2c3e50;
                font-size: 11pt;
            }}
            .time {{
                color: #666;
                font-size: 9pt;
            }}
            .message-sub-meta {{
                display: flex;
                gap: 15px;
                font-size: 8pt;
                color: #888;
                margin-bottom: 8px;
                border-bottom: 1px solid #eee;
                padding-bottom: 4px;
            }}
            .message-content {{
                font-size: 10pt;
            }}
            .message-content p {{
                margin: 5px 0;
            }}
            .message-content ul, .message-content ol {{
                margin: 5px 0;
                padding-left: 20px;
            }}
            table {{
                border-collapse: collapse;
                width: 100%;
                margin: 10px 0;
                font-size: 9pt;
            }}
            th, td {{
                border: 1px solid #ddd;
                padding: 6px;
                text-align: left;
            }}
            th {{
                background-color: #f3f4f6;
            }}
            code {{
                background-color: #f4f4f4;
                padding: 1px 4px;
                border-radius: 3px;
                font-size: 9pt;
            }}
            pre {{
                background-color: #f4f4f4;
                padding: 10px;
                border-radius: 5px;
                overflow-x: auto;
                font-size: 8pt;
            }}
        </style>
    </head>
    <body>
        <h1>AI 研究讨论导出</h1>
        <p style="text-align: center; color: #666; margin-bottom: 30px;">导出时间: {now_shanghai.strftime('%Y-%m-%d %H:%M:%S')}</p>
        {messages_html}
    </body>
    </html>
    '''

    # Generate PDF
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
            pdf_path = tmp_file.name
            HTML(string=complete_html).write_pdf(pdf_path)

        # Filename prefix: the first stock code found in the logs, falling
        # back to a generic title.
        company_name = "AI研究讨论"
        if logs and logs[0].stock_code:
            company_name = logs[0].stock_code

        filename = f"{company_name}_{now_shanghai.strftime('%Y%m%d')}_讨论记录.pdf"
        filename_encoded = quote(filename)

        from starlette.background import BackgroundTask
        return FileResponse(
            path=pdf_path,
            media_type='application/pdf',
            headers={
                # RFC 5987 filename* form so the non-ASCII filename survives the header
                'Content-Disposition': f"attachment; filename*=UTF-8''{filename_encoded}"
            },
            # Delete the temp file once the response has been sent; with
            # delete=False above it would otherwise accumulate on disk.
            background=BackgroundTask(os.remove, pdf_path)
        )
    except Exception as e:
        logger.error(f"Failed to generate PDF: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to generate PDF: {str(e)}")
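

# Example export call (illustrative; adjust the host and any router mount
# prefix for your deployment):
#   curl -X POST http://localhost:8000/chat/export-pdf \
#        -H "Content-Type: application/json" \
#        -d '{"session_ids": ["sess-001", "sess-002"]}' \
#        -o discussion.pdf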