import json
import logging
import os
import tempfile
import time
from html import escape
from typing import List, Optional
from urllib.parse import quote

import markdown
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.background import BackgroundTask
from weasyprint import HTML

from google.genai import types

from app.database import get_db
from app.models import LLMUsageLog
from app.services.analysis_service import get_genai_client

router = APIRouter()
logger = logging.getLogger(__name__)


class ChatMessage(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    model: str
    system_prompt: Optional[str] = None
    use_google_search: bool = False
    session_id: Optional[str] = None
    stock_code: Optional[str] = None


@router.post("/chat")
async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db)):
    """AI Chat Endpoint with Logging (Streaming)"""
    try:
        client = get_genai_client()

        # Prepare history and config.
        history = []
        # Exclude the last message; it is sent separately via send_message_stream.
        if len(request.messages) > 1:
            for msg in request.messages[:-1]:
                history.append(types.Content(
                    role="user" if msg.role == "user" else "model",
                    parts=[types.Part(text=msg.content)]
                ))

        last_message = request.messages[-1].content if request.messages else ""

        model_name = request.model or "gemini-2.5-flash"

        # Search configuration and system prompt.
        tools = []
        final_system_prompt = request.system_prompt or ""

        if request.use_google_search:
            tools.append(types.Tool(google_search=types.GoogleSearch()))
            # Inject a strong instruction to force search usage.
            final_system_prompt += (
                "\n\n[SYSTEM INSTRUCTION] You have access to Google Search. "
                "You MUST use it to verify any data or find the latest information "
                "before answering. Do not rely solely on your internal knowledge."
            )

        config = types.GenerateContentConfig(
            tools=tools if tools else None,
            temperature=0.1,
            system_instruction=final_system_prompt
        )

        start_time = time.time()

        async def generate():
            full_response_text = ""
            grounding_data = None
            prompt_tokens = 0
            completion_tokens = 0
            total_tokens = 0

            try:
                # Initialize the chat session.
                chat = client.chats.create(
                    model=model_name,
                    history=history,
                    config=config
                )

                # Streaming call. Note: this is the synchronous SDK client, so
                # each chunk briefly blocks the event loop; the SDK's client.aio
                # namespace offers an async variant if that becomes an issue.
                response_stream = chat.send_message_stream(last_message)

                for chunk in response_stream:
                    if chunk.text:
                        full_response_text += chunk.text
                        data = json.dumps({"type": "content", "content": chunk.text})
                        yield f"{data}\n"

                    # Usage metadata is cumulative; the final chunk carries the
                    # complete counts, so keeping the latest values is enough.
                    if chunk.usage_metadata:
                        prompt_tokens = chunk.usage_metadata.prompt_token_count or 0
                        completion_tokens = chunk.usage_metadata.candidates_token_count or 0
                        total_tokens = chunk.usage_metadata.total_token_count or 0

                    if chunk.candidates and chunk.candidates[0].grounding_metadata:
                        gm = chunk.candidates[0].grounding_metadata
                        # Only one final grounding metadata object is expected.
                        if gm.search_entry_point or gm.grounding_chunks:
                            grounding_obj = {}
                            if gm.search_entry_point:
                                grounding_obj["searchEntryPoint"] = {
                                    "renderedContent": gm.search_entry_point.rendered_content
                                }

                            if gm.grounding_chunks:
                                grounding_obj["groundingChunks"] = []
                                for g_chunk in gm.grounding_chunks:
                                    if g_chunk.web:
                                        grounding_obj["groundingChunks"].append({
                                            "web": {
                                                "uri": g_chunk.web.uri,
                                                "title": g_chunk.web.title
                                            }
                                        })
                            if gm.web_search_queries:
                                grounding_obj["webSearchQueries"] = gm.web_search_queries

                            grounding_data = grounding_obj  # Save the final metadata.

                # End-of-stream actions.
                end_time = time.time()
                response_time = end_time - start_time

                # Send the metadata chunk.
                if grounding_data:
                    yield f"{json.dumps({'type': 'metadata', 'groundingMetadata': grounding_data})}\n"

                # Send the usage chunk.
                usage_data = {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": total_tokens,
                    "response_time": response_time
                }
                yield f"{json.dumps({'type': 'usage', 'usage': usage_data})}\n"

                # Log to the database. The injected `db` session is still usable
                # here: FastAPI tears down yield-based dependencies only after
                # the streaming response has finished sending, and awaiting
                # database calls inside an async generator is fine.
                try:
                    # Reconstruct the full prompt for logging.
                    logged_prompt = ""
                    if request.system_prompt:
                        logged_prompt += f"System: {request.system_prompt}\n\n"
                    for msg in request.messages:  # Log the full history, including the last message.
                        logged_prompt += f"{msg.role}: {msg.content}\n"

                    log_entry = LLMUsageLog(
                        model=model_name,
                        prompt=logged_prompt,
                        response=full_response_text,
                        response_time=response_time,
                        used_google_search=request.use_google_search,
                        session_id=request.session_id,
                        stock_code=request.stock_code,
                        prompt_tokens=prompt_tokens,
                        completion_tokens=completion_tokens,
                        total_tokens=total_tokens
                    )
                    db.add(log_entry)
                    await db.commit()
                except Exception as db_err:
                    logger.error(f"Failed to log LLM usage: {db_err}")

            except Exception as e:
                logger.error(f"Stream generation error: {e}")
                err_json = json.dumps({"type": "error", "error": str(e)})
                yield f"{err_json}\n"

        return StreamingResponse(generate(), media_type="application/x-ndjson")

    except Exception as e:
        logger.error(f"Chat error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
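
# A minimal client-side sketch of consuming the NDJSON stream above
# (illustrative, not part of the router; assumes the app is served at
# http://localhost:8000 with this router mounted at the root, and that
# httpx is installed):
#
#     import httpx, json
#
#     async def consume_chat():
#         payload = {"messages": [{"role": "user", "content": "Hello"}],
#                    "model": "gemini-2.5-flash"}
#         async with httpx.AsyncClient(timeout=None) as client:
#             async with client.stream("POST", "http://localhost:8000/chat",
#                                      json=payload) as resp:
#                 async for line in resp.aiter_lines():
#                     if not line:
#                         continue
#                     event = json.loads(line)
#                     if event["type"] == "content":
#                         print(event["content"], end="", flush=True)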


@router.get("/chat/history")
async def get_chat_history(db: AsyncSession = Depends(get_db)):
    """Fetch chat history."""
    from sqlalchemy import select, desc

    # Fetch all logs ordered by time (newest first); the frontend filters by
    # session. Capped at 100 rows for now.
    query = select(LLMUsageLog).order_by(desc(LLMUsageLog.timestamp)).limit(100)

    result = await db.execute(query)
    logs = result.scalars().all()

    return [{
        "id": log.id,
        "timestamp": log.timestamp.isoformat() if log.timestamp else None,
        "model": log.model,
        "prompt": log.prompt,
        "response": log.response,
        "session_id": log.session_id,
        "stock_code": log.stock_code,
        "total_tokens": log.total_tokens,
        "used_google_search": log.used_google_search
    } for log in logs]
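
# Example item in the history response (illustrative values only):
#   {"id": 1, "timestamp": "2025-01-01T12:00:00", "model": "gemini-2.5-flash",
#    "prompt": "user: ...", "response": "...", "session_id": "abc123",
#    "stock_code": "600519", "total_tokens": 1234, "used_google_search": true}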


class ExportPDFRequest(BaseModel):
    session_ids: List[str]


@router.post("/chat/export-pdf")
async def export_chat_pdf(request: ExportPDFRequest, db: AsyncSession = Depends(get_db)):
    """Export selected chat sessions to PDF"""
    from sqlalchemy import select, asc
    from datetime import datetime

    if not request.session_ids:
        raise HTTPException(status_code=400, detail="No sessions selected")

    # Fetch logs for the selected sessions, oldest first.
    query = (
        select(LLMUsageLog)
        .where(LLMUsageLog.session_id.in_(request.session_ids))
        .order_by(asc(LLMUsageLog.timestamp))
    )

    result = await db.execute(query)
    logs = result.scalars().all()

    if not logs:
        raise HTTPException(status_code=404, detail="No logs found for selected sessions")

    # Group logs by session, producing
    # {"<session_id>": {"stock_code": ..., "logs": [...]}}.
    sessions = {}
    for log in logs:
        sid = log.session_id or "unknown"
        if sid not in sessions:
            sessions[sid] = {
                "stock_code": log.stock_code or "Unknown",
                "logs": []
            }
        sessions[sid]["logs"].append(log)

    # Build the HTML content. User-derived values are escaped before being
    # interpolated into the document.
    sections_html = ""
    for session_id in request.session_ids:
        if session_id not in sessions:
            continue
        session_data = sessions[session_id]
        stock_code = session_data["stock_code"]

        sections_html += f'''
        <div class="session-section">
            <h2>{escape(stock_code)}</h2>
            <p class="session-id">Session ID: {escape(session_id)}</p>
        '''

        for log in session_data["logs"]:
            timestamp = log.timestamp.strftime('%Y-%m-%d %H:%M:%S') if log.timestamp else ""
            response_html = markdown.markdown(log.response or "", extensions=['tables', 'fenced_code'])

            sections_html += f'''
            <div class="message">
                <div class="message-meta">
                    <span class="model">Model: {escape(log.model or "")}</span>
                    <span class="time">{timestamp}</span>
                </div>
                <div class="message-content">
                    {response_html}
                </div>
            </div>
            '''

        sections_html += '</div>'

    # Complete HTML document.
    complete_html = f'''
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="utf-8">
        <title>AI Research Discussion Export</title>
        <style>
            @page {{
                size: A4;
                margin: 2cm 1.5cm;
            }}
            body {{
                font-family: "PingFang SC", "Microsoft YaHei", "SimHei", sans-serif;
                line-height: 1.6;
                color: #333;
                font-size: 10pt;
                background-color: white;
            }}
            h1 {{
                text-align: center;
                color: #1a1a1a;
                border-bottom: 2px solid #4a90e2;
                padding-bottom: 10px;
                margin-bottom: 30px;
            }}
            h2 {{
                color: #2c3e50;
                font-size: 14pt;
                margin-top: 20px;
                margin-bottom: 5px;
                border-left: 4px solid #4a90e2;
                padding-left: 10px;
            }}
            .session-id {{
                color: #888;
                font-size: 9pt;
                margin-bottom: 15px;
            }}
            .session-section {{
                margin-bottom: 30px;
                page-break-inside: avoid;
            }}
            .message {{
                margin: 15px 0;
                padding: 12px;
                background: #f9f9f9;
                border-radius: 6px;
                border-left: 3px solid #4a90e2;
            }}
            .message-meta {{
                display: flex;
                justify-content: space-between;
                font-size: 9pt;
                color: #666;
                margin-bottom: 8px;
            }}
            .model {{
                font-weight: bold;
            }}
            .message-content {{
                font-size: 10pt;
            }}
            .message-content p {{
                margin: 5px 0;
            }}
            .message-content ul, .message-content ol {{
                margin: 5px 0;
                padding-left: 20px;
            }}
            table {{
                border-collapse: collapse;
                width: 100%;
                margin: 10px 0;
                font-size: 9pt;
            }}
            th, td {{
                border: 1px solid #ddd;
                padding: 6px;
                text-align: left;
            }}
            th {{
                background-color: #f3f4f6;
            }}
            code {{
                background-color: #f4f4f4;
                padding: 1px 4px;
                border-radius: 3px;
                font-size: 9pt;
            }}
            pre {{
                background-color: #f4f4f4;
                padding: 10px;
                border-radius: 5px;
                overflow-x: auto;
                font-size: 8pt;
            }}
        </style>
    </head>
    <body>
        <h1>AI Research Discussion Export</h1>
        <p style="text-align: center; color: #666; margin-bottom: 30px;">Export time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
        {sections_html}
    </body>
    </html>
    '''

    # Generate the PDF.
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
            pdf_path = tmp_file.name
            HTML(string=complete_html).write_pdf(pdf_path)

        filename = f"AI_Research_Discussion_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pdf"
        filename_encoded = quote(filename)

        return FileResponse(
            path=pdf_path,
            media_type='application/pdf',
            # RFC 5987 filename* form so non-ASCII filenames survive intact.
            headers={
                'Content-Disposition': f"attachment; filename*=UTF-8''{filename_encoded}"
            },
            # Delete the temp file once the response has been sent.
            background=BackgroundTask(os.unlink, pdf_path)
        )
    except Exception as e:
        logger.error(f"Failed to generate PDF: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to generate PDF: {str(e)}")
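
# A minimal sketch of calling the export endpoint (illustrative; assumes the
# router is mounted at the application root and httpx is installed):
#
#     import httpx
#
#     def download_export(session_ids):
#         resp = httpx.post("http://localhost:8000/chat/export-pdf",
#                           json={"session_ids": session_ids}, timeout=120)
#         resp.raise_for_status()
#         with open("discussion.pdf", "wb") as f:
#             f.write(resp.content)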