import json
import logging
import time
from typing import List, Optional

from fastapi import APIRouter, HTTPException, Depends
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from google.genai import types

from app.database import get_db
from app.models import LLMUsageLog
from app.services.analysis_service import get_genai_client

router = APIRouter()
logger = logging.getLogger(__name__)


class ChatMessage(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    model: str
    system_prompt: Optional[str] = None
    use_google_search: bool = False
    session_id: Optional[str] = None
    stock_code: Optional[str] = None
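
# Example request body for POST /chat (illustrative values only; field names
# mirror ChatRequest above):
#
#   {
#       "messages": [{"role": "user", "content": "Summarize the latest filing."}],
#       "model": "gemini-2.5-flash",
#       "system_prompt": "You are a financial analyst.",
#       "use_google_search": true,
#       "session_id": "abc123",
#       "stock_code": "7203"
#   }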
@router.post("/chat")
async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db)):
    """Stream an AI chat reply as NDJSON and log token usage to the database."""
    try:
        client = get_genai_client()

        # Prepare history and config
        history = []
        # Exclude the last message; it is sent via send_message_stream below
        if len(request.messages) > 1:
            for msg in request.messages[:-1]:
                history.append(types.Content(
                    role="user" if msg.role == "user" else "model",
                    parts=[types.Part(text=msg.content)]
                ))
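        # e.g. messages [user, assistant, user] -> history = [user, model] as
        # types.Content; the trailing user message is streamed separately below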

        last_message = request.messages[-1].content if request.messages else ""

        model_name = request.model or "gemini-2.5-flash"

        # Search configuration and system prompt
        tools = []
        final_system_prompt = request.system_prompt or ""

        if request.use_google_search:
            tools.append(types.Tool(google_search=types.GoogleSearch()))
            # Inject a strong instruction to force search usage
            final_system_prompt += "\n\n[SYSTEM INSTRUCTION] You have access to Google Search. You MUST use it to verify any data or find the latest information before answering. Do not rely solely on your internal knowledge."

        config = types.GenerateContentConfig(
            tools=tools if tools else None,
            temperature=0.1,
            system_instruction=final_system_prompt
        )

        start_time = time.time()

        async def generate():
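            # Each yielded line is one self-contained JSON object (NDJSON).
            # Events emitted on this stream:
            #   {"type": "content",  "content": "<text delta>"}
            #   {"type": "metadata", "groundingMetadata": {...}}  (grounded runs only)
            #   {"type": "usage",    "usage": {...token counts and response_time...}}
            #   {"type": "error",    "error": "<message>"}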
            full_response_text = ""
            grounding_data = None
            prompt_tokens = 0
            completion_tokens = 0
            total_tokens = 0

            try:
                # Initialize the chat on the SDK's async surface (client.aio)
                # so the streaming iteration below does not block the event loop
                chat = client.aio.chats.create(
                    model=model_name,
                    history=history,
                    config=config
                )

                # Streaming call
                async for chunk in await chat.send_message_stream(last_message):
                    if chunk.text:
                        full_response_text += chunk.text
                        data = json.dumps({"type": "content", "content": chunk.text})
                        yield f"{data}\n"

                    # usage_metadata counts are cumulative, so keep overwriting;
                    # the final chunk carries the complete totals
                    if chunk.usage_metadata:
                        prompt_tokens = chunk.usage_metadata.prompt_token_count or 0
                        completion_tokens = chunk.usage_metadata.candidates_token_count or 0
                        total_tokens = chunk.usage_metadata.total_token_count or 0

                    if chunk.candidates and chunk.candidates[0].grounding_metadata:
                        gm = chunk.candidates[0].grounding_metadata
                        # Only one final grounding metadata object is expected
                        if gm.search_entry_point or gm.grounding_chunks:
                            grounding_obj = {}
                            if gm.search_entry_point:
                                grounding_obj["searchEntryPoint"] = {"renderedContent": gm.search_entry_point.rendered_content}

                            if gm.grounding_chunks:
                                grounding_obj["groundingChunks"] = []
                                for g_chunk in gm.grounding_chunks:
                                    if g_chunk.web:
                                        grounding_obj["groundingChunks"].append({
                                            "web": {
                                                "uri": g_chunk.web.uri,
                                                "title": g_chunk.web.title
                                            }
                                        })
                            if gm.web_search_queries:
                                grounding_obj["webSearchQueries"] = gm.web_search_queries

                            grounding_data = grounding_obj  # Save final metadata
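                            # The resulting "metadata" event looks like (illustrative):
                            #   {"type": "metadata", "groundingMetadata": {
                            #       "searchEntryPoint": {"renderedContent": "<html>"},
                            #       "groundingChunks": [{"web": {"uri": "https://...", "title": "..."}}],
                            #       "webSearchQueries": ["latest 7203 earnings"]}}
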
                # End-of-stream actions
                end_time = time.time()
                response_time = end_time - start_time

                # Send metadata chunk
                if grounding_data:
                    yield f"{json.dumps({'type': 'metadata', 'groundingMetadata': grounding_data})}\n"

                # Send usage chunk
                usage_data = {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": total_tokens,
                    "response_time": response_time
                }
                yield f"{json.dumps({'type': 'usage', 'usage': usage_data})}\n"

                # Log to the database. The request-scoped `db` session is still
                # usable here: FastAPI keeps yield-dependencies open until the
                # streamed response has finished, and awaiting DB calls inside
                # an async generator is fine.
                try:
                    # Reconstruct the full prompt for logging
                    logged_prompt = ""
                    if request.system_prompt:
                        logged_prompt += f"System: {request.system_prompt}\n\n"
                    for msg in request.messages:  # Log full history, including the last message
                        logged_prompt += f"{msg.role}: {msg.content}\n"

                    log_entry = LLMUsageLog(
                        model=model_name,
                        prompt=logged_prompt,
                        response=full_response_text,
                        response_time=response_time,
                        used_google_search=request.use_google_search,
                        session_id=request.session_id,
                        stock_code=request.stock_code,
                        prompt_tokens=prompt_tokens,
                        completion_tokens=completion_tokens,
                        total_tokens=total_tokens
                    )
                    db.add(log_entry)
                    await db.commit()
                except Exception as db_err:
                    logger.error(f"Failed to log LLM usage: {db_err}")

            except Exception as e:
                logger.error(f"Stream generation error: {e}")
                err_json = json.dumps({"type": "error", "error": str(e)})
                yield f"{err_json}\n"

        return StreamingResponse(generate(), media_type="application/x-ndjson")

    except Exception as e:
        logger.error(f"Chat error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
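
# Consuming the stream (illustrative sketch): each line of the response body is
# one JSON event. This assumes `httpx` is installed and that this router is
# mounted so the endpoint is reachable at /chat; adjust the URL/prefix to your
# app:
#
#   import httpx, json
#
#   async def consume_chat():
#       payload = {"messages": [{"role": "user", "content": "Hello"}],
#                  "model": "gemini-2.5-flash"}
#       async with httpx.AsyncClient(timeout=None) as http:
#           async with http.stream("POST", "http://localhost:8000/chat",
#                                  json=payload) as resp:
#               async for line in resp.aiter_lines():
#                   if not line:
#                       continue
#                   event = json.loads(line)
#                   if event["type"] == "content":
#                       print(event["content"], end="")
#                   elif event["type"] == "error":
#                       raise RuntimeError(event["error"])
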
@router.get("/chat/history")
async def get_chat_history(db: AsyncSession = Depends(get_db)):
    """Fetch recent chat history."""
    from sqlalchemy import select, desc

    # Return the most recent logs (capped at 100 for now); any per-session
    # filtering happens on the frontend
    query = select(LLMUsageLog).order_by(desc(LLMUsageLog.timestamp)).limit(100)

    result = await db.execute(query)
    logs = result.scalars().all()

    return [{
        "id": log.id,
        "timestamp": log.timestamp.isoformat() if log.timestamp else None,
        "model": log.model,
        "prompt": log.prompt,
        "response": log.response,
        "session_id": log.session_id,
        "stock_code": log.stock_code,
        "total_tokens": log.total_tokens,
        "used_google_search": log.used_google_search
    } for log in logs]
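
# Example item returned by GET /chat/history (illustrative values only):
#
#   {
#       "id": 42,
#       "timestamp": "2025-01-01T12:00:00",
#       "model": "gemini-2.5-flash",
#       "prompt": "user: Hello\n",
#       "response": "Hi! How can I help?",
#       "session_id": "abc123",
#       "stock_code": "7203",
#       "total_tokens": 123,
#       "used_google_search": false
#   }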