# FA3-Datafetch/backend/app/api/chat_routes.py
#
# Chat API routes: streaming Gemini chat (POST /chat) and LLM usage-log history (GET /chat/history).

from fastapi import APIRouter, HTTPException, Depends
from fastapi.responses import StreamingResponse
import json
from pydantic import BaseModel
from typing import List, Optional
import logging
import time
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.analysis_service import get_genai_client
from google.genai import types
from app.database import get_db
from app.models import LLMUsageLog
logger = logging.getLogger(__name__)
# Route table for this module; the endpoints below attach themselves via @router decorators.
router = APIRouter()
class ChatMessage(BaseModel):
    """One turn of a chat conversation as sent by the client."""

    # Speaker of this turn. chat_with_ai maps "user" to the Gemini "user" role and
    # every other value to "model".
    role: str
    # Plain-text body of the turn.
    content: str
class ChatRequest(BaseModel):
    """Request body for ``POST /chat``."""

    # Whole conversation: the final entry is the new turn sent to the model and all
    # earlier entries are replayed as history.
    messages: List[ChatMessage]
    # Gemini model name; chat_with_ai substitutes its default when this is empty.
    model: str
    # Optional system instruction for the model.
    system_prompt: Optional[str] = None
    # Enables the Google Search tool for this request when true.
    use_google_search: bool = False
    # Both identifiers are stored on the LLMUsageLog row written for the exchange.
    session_id: Optional[str] = None
    stock_code: Optional[str] = None
# Model used when the request's ``model`` field is empty.
_DEFAULT_MODEL = "gemini-2.5-flash"

# Appended to the system prompt when Google Search grounding is enabled, to push the
# model toward calling the search tool instead of answering purely from memory.
_SEARCH_INSTRUCTION = (
    "\n\n[SYSTEM INSTRUCTION] You have access to Google Search. You MUST use it to verify any data "
    "or find the latest information before answering. Do not rely solely on your internal knowledge."
)


def _to_genai_history(messages: List[ChatMessage]) -> List[types.Content]:
    """Convert prior chat turns into google-genai ``Content`` objects.

    Role ``"user"`` is kept; every other role is sent as ``"model"``.
    """
    return [
        types.Content(
            role="user" if msg.role == "user" else "model",
            parts=[types.Part(text=msg.content)],
        )
        for msg in messages
    ]


def _build_generate_config(
    system_prompt: Optional[str], use_google_search: bool
) -> types.GenerateContentConfig:
    """Build the generation config (tools, temperature, system instruction) for one chat call."""
    tools = []
    instruction = system_prompt or ""
    if use_google_search:
        tools.append(types.Tool(google_search=types.GoogleSearch()))
        instruction += _SEARCH_INSTRUCTION
    return types.GenerateContentConfig(
        tools=tools or None,
        temperature=0.1,
        # Omit the system instruction entirely rather than sending an empty text part.
        system_instruction=instruction or None,
    )


def _grounding_to_dict(gm) -> Optional[dict]:
    """Convert a candidate's grounding metadata into the camelCase JSON sent to the client.

    Returns ``None`` when ``gm`` has neither a search entry point nor grounding chunks.
    Only web-sourced grounding chunks are included in ``groundingChunks``.
    """
    if not (gm.search_entry_point or gm.grounding_chunks):
        return None
    grounding: dict = {}
    if gm.search_entry_point:
        grounding["searchEntryPoint"] = {
            "renderedContent": gm.search_entry_point.rendered_content
        }
    if gm.grounding_chunks:
        grounding["groundingChunks"] = [
            {"web": {"uri": g_chunk.web.uri, "title": g_chunk.web.title}}
            for g_chunk in gm.grounding_chunks
            if g_chunk.web
        ]
    if gm.web_search_queries:
        grounding["webSearchQueries"] = gm.web_search_queries
    return grounding


def _format_prompt_for_log(system_prompt: Optional[str], messages: List[ChatMessage]) -> str:
    """Flatten the system prompt and the full conversation (including the new turn) for logging."""
    header = f"System: {system_prompt}\n\n" if system_prompt else ""
    return header + "".join(f"{msg.role}: {msg.content}\n" for msg in messages)


async def _save_usage_log(db: AsyncSession, **fields) -> None:
    """Insert one ``LLMUsageLog`` row built from ``fields``, best-effort.

    Failures are logged and never raised, so a logging problem cannot break the chat stream.
    """
    try:
        db.add(LLMUsageLog(**fields))
        await db.commit()
    except Exception as db_err:
        logger.exception("Failed to log LLM usage: %s", db_err)
        # A failed commit leaves the session needing a rollback before it can be reused.
        try:
            await db.rollback()
        except Exception:
            logger.exception("Rollback after failed LLM usage log also failed")


@router.post("/chat")
async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db)):
    """AI Chat Endpoint with Logging (Streaming).

    Streams a Gemini chat reply as newline-delimited JSON (``application/x-ndjson``).
    Each line is one of:

    * ``{"type": "content", "content": <text>}``: a piece of the model's reply;
    * ``{"type": "metadata", "groundingMetadata": {...}}``: search grounding info, sent
      once after the reply if any was returned;
    * ``{"type": "usage", "usage": {...}}``: token counts and wall-clock response time;
    * ``{"type": "error", "error": <message>}``: sent instead of the remaining lines if the
      upstream stream fails.

    After the usage line the exchange is stored as an ``LLMUsageLog`` row; storage errors
    are only logged.

    Raises:
        HTTPException: 400 if ``messages`` is empty, 500 if the call cannot be set up.
    """
    if not request.messages:
        # There is no turn to send; reject up front instead of sending an empty message.
        raise HTTPException(status_code=400, detail="messages must not be empty")

    try:
        client = get_genai_client()
        # All but the last message are prior context; the last one is the new turn.
        history = _to_genai_history(request.messages[:-1])
        last_message = request.messages[-1].content
        model_name = request.model or _DEFAULT_MODEL
        config = _build_generate_config(request.system_prompt, request.use_google_search)
    except Exception as e:
        logger.exception("Chat error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    start_time = time.time()

    async def generate():
        text_parts: List[str] = []
        grounding_data = None
        prompt_tokens = completion_tokens = total_tokens = 0
        try:
            # Use the async client: iterating the sync stream inside this coroutine would
            # block the event loop (and every other request) while waiting for each chunk.
            chat = client.aio.chats.create(model=model_name, history=history, config=config)
            async for chunk in await chat.send_message_stream(last_message):
                text = chunk.text
                if text:
                    text_parts.append(text)
                    yield json.dumps({"type": "content", "content": text}) + "\n"
                # Usage metadata may appear on several chunks; the last one seen wins.
                usage = chunk.usage_metadata
                if usage:
                    prompt_tokens = usage.prompt_token_count or 0
                    completion_tokens = usage.candidates_token_count or 0
                    total_tokens = usage.total_token_count or 0
                if chunk.candidates and chunk.candidates[0].grounding_metadata:
                    grounding = _grounding_to_dict(chunk.candidates[0].grounding_metadata)
                    if grounding is not None:
                        grounding_data = grounding  # keep the latest non-empty metadata

            response_time = time.time() - start_time

            if grounding_data:
                yield json.dumps({"type": "metadata", "groundingMetadata": grounding_data}) + "\n"

            usage_data = {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": total_tokens,
                "response_time": response_time,
            }
            yield json.dumps({"type": "usage", "usage": usage_data}) + "\n"

            # NOTE(review): `db` comes from Depends(get_db) but is used here, after streaming
            # has started. Depending on the FastAPI version, the dependency's cleanup may already
            # have run by now; confirm the session is still usable here, or open a dedicated
            # session for this write.
            await _save_usage_log(
                db,
                model=model_name,
                prompt=_format_prompt_for_log(request.system_prompt, request.messages),
                response="".join(text_parts),
                response_time=response_time,
                used_google_search=request.use_google_search,
                session_id=request.session_id,
                stock_code=request.stock_code,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
            )
        except Exception as e:
            logger.exception("Stream generation error: %s", e)
            yield json.dumps({"type": "error", "error": str(e)}) + "\n"

    return StreamingResponse(generate(), media_type="application/x-ndjson")
@router.get("/chat/history")
async def get_chat_history(db: AsyncSession = Depends(get_db)):
    """Return up to the 100 most recent LLM usage log rows, newest first.

    No grouping or de-duplication by session is done here; filtering (for example by
    ``session_id`` or ``stock_code``) is left to the frontend.
    """
    from sqlalchemy import select

    # Hard cap on returned rows for now.
    newest_first = select(LLMUsageLog).order_by(LLMUsageLog.timestamp.desc()).limit(100)
    rows = (await db.execute(newest_first)).scalars().all()

    def _serialize(entry):
        # Flatten one ORM row into the JSON shape the frontend expects.
        stamp = entry.timestamp
        return {
            "id": entry.id,
            "timestamp": None if not stamp else stamp.isoformat(),
            "model": entry.model,
            "prompt": entry.prompt,
            "response": entry.response,
            "session_id": entry.session_id,
            "stock_code": entry.stock_code,
            "total_tokens": entry.total_tokens,
            "used_google_search": entry.used_google_search,
        }

    return list(map(_serialize, rows))