from fastapi import APIRouter, HTTPException, Depends
from fastapi.responses import StreamingResponse, FileResponse
import json
from pydantic import BaseModel
from typing import List, Optional
import logging
import time
import tempfile
import markdown
from weasyprint import HTML
from urllib.parse import quote
from sqlalchemy.ext.asyncio import AsyncSession

from app.services.analysis_service import get_genai_client
from google.genai import types
from app.database import get_db
from app.models import LLMUsageLog

router = APIRouter()
logger = logging.getLogger(__name__)


class ChatMessage(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    model: str
    system_prompt: Optional[str] = None
    use_google_search: bool = False
    session_id: Optional[str] = None
    stock_code: Optional[str] = None


@router.post("/chat")
async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db)):
    """AI chat endpoint with usage logging (streaming)."""
    try:
        client = get_genai_client()

        # Prepare history and config. The last message is excluded from the
        # history because it is sent separately via send_message_stream below.
        history = []
        if len(request.messages) > 1:
            for msg in request.messages[:-1]:
                history.append(types.Content(
                    role="user" if msg.role == "user" else "model",
                    parts=[types.Part(text=msg.content)]
                ))

        last_message = request.messages[-1].content if request.messages else ""
        model_name = request.model or "gemini-2.5-flash"

        # Search configuration and system prompt.
        tools = []
        final_system_prompt = request.system_prompt or ""
        if request.use_google_search:
            tools.append(types.Tool(google_search=types.GoogleSearch()))
            # Inject a strong instruction to force search usage.
            final_system_prompt += (
                "\n\n[SYSTEM INSTRUCTION] You have access to Google Search. "
                "You MUST use it to verify any data or find the latest "
                "information before answering. Do not rely solely on your "
                "internal knowledge."
            )
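        # The endpoint streams newline-delimited JSON (NDJSON): each line emitted
        # by the generator below is one JSON object. Illustratively, the event
        # shapes are:
        #   {"type": "content",  "content": "<partial text>"}
        #   {"type": "metadata", "groundingMetadata": {...}}   (at most once)
        #   {"type": "usage",    "usage": {"prompt_tokens": ..., "response_time": ...}}
        #   {"type": "error",    "error": "<message>"}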
        config = types.GenerateContentConfig(
            tools=tools if tools else None,
            temperature=0.1,
            system_instruction=final_system_prompt
        )

        start_time = time.time()

        async def generate():
            full_response_text = ""
            grounding_data = None
            prompt_tokens = 0
            completion_tokens = 0
            total_tokens = 0
            try:
                # Initialize the chat session.
                chat = client.chats.create(
                    model=model_name,
                    history=history,
                    config=config
                )

                # Streaming call.
                response_stream = chat.send_message_stream(last_message)
                for chunk in response_stream:
                    if chunk.text:
                        full_response_text += chunk.text
                        data = json.dumps({"type": "content", "content": chunk.text})
                        yield f"{data}\n"

                    # Accumulate usage metadata. It is usually only complete in
                    # the final chunk, so later chunks overwrite earlier values.
                    if chunk.usage_metadata:
                        prompt_tokens = chunk.usage_metadata.prompt_token_count or 0
                        completion_tokens = chunk.usage_metadata.candidates_token_count or 0
                        total_tokens = chunk.usage_metadata.total_token_count or 0

                    if chunk.candidates and chunk.candidates[0].grounding_metadata:
                        gm = chunk.candidates[0].grounding_metadata
                        # Only one final grounding metadata object is expected.
                        if gm.search_entry_point or gm.grounding_chunks:
                            grounding_obj = {}
                            if gm.search_entry_point:
                                grounding_obj["searchEntryPoint"] = {
                                    "renderedContent": gm.search_entry_point.rendered_content
                                }
                            if gm.grounding_chunks:
                                grounding_obj["groundingChunks"] = []
                                for g_chunk in gm.grounding_chunks:
                                    if g_chunk.web:
                                        grounding_obj["groundingChunks"].append({
                                            "web": {
                                                "uri": g_chunk.web.uri,
                                                "title": g_chunk.web.title
                                            }
                                        })
                            if gm.web_search_queries:
                                grounding_obj["webSearchQueries"] = gm.web_search_queries
                            # Save the final metadata.
                            grounding_data = grounding_obj

                # End-of-stream actions.
                end_time = time.time()
                response_time = end_time - start_time

                # Send the metadata chunk.
                if grounding_data:
                    yield f"{json.dumps({'type': 'metadata', 'groundingMetadata': grounding_data})}\n"

                # Send the usage chunk.
                usage_data = {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": total_tokens,
                    "response_time": response_time
                }
                yield f"{json.dumps({'type': 'usage', 'usage': usage_data})}\n"

                # Log to the database. The injected `db` session is reused here:
                # awaiting session calls inside an async generator is fine, but
                # the dependency may close the session once the response
                # completes, so logging happens before the generator finishes.
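                # (Sketch of an alternative, in case the injected session is ever
                # closed before this point: open a short-lived session just for
                # logging. `async_session_factory` is a hypothetical name for the
                # app's async sessionmaker.)
                #
                #   async with async_session_factory() as log_db:
                #       log_db.add(log_entry)
                #       await log_db.commit()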
                try:
                    # Reconstruct the full prompt for logging.
                    logged_prompt = ""
                    if request.system_prompt:
                        logged_prompt += f"System: {request.system_prompt}\n\n"
                    # Log the full history, including the last message.
                    for msg in request.messages:
                        logged_prompt += f"{msg.role}: {msg.content}\n"

                    log_entry = LLMUsageLog(
                        model=model_name,
                        prompt=logged_prompt,
                        response=full_response_text,
                        response_time=response_time,
                        used_google_search=request.use_google_search,
                        session_id=request.session_id,
                        stock_code=request.stock_code,
                        prompt_tokens=prompt_tokens,
                        completion_tokens=completion_tokens,
                        total_tokens=total_tokens
                    )
                    db.add(log_entry)
                    await db.commit()
                except Exception as db_err:
                    logger.error(f"Failed to log LLM usage: {db_err}")

            except Exception as e:
                logger.error(f"Stream generation error: {e}")
                err_json = json.dumps({"type": "error", "error": str(e)})
                yield f"{err_json}\n"

        return StreamingResponse(generate(), media_type="application/x-ndjson")

    except Exception as e:
        logger.error(f"Chat error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/chat/history")
async def get_chat_history(db: AsyncSession = Depends(get_db)):
    """Fetch recent chat history."""
    from sqlalchemy import select, desc

    # Fetch all logs ordered by time (newest first); the frontend filters by
    # session. Capped at 100 rows for now.
    query = select(LLMUsageLog).order_by(desc(LLMUsageLog.timestamp)).limit(100)
    result = await db.execute(query)
    logs = result.scalars().all()

    return [{
        "id": log.id,
        "timestamp": log.timestamp.isoformat() if log.timestamp else None,
        "model": log.model,
        "prompt": log.prompt,
        "response": log.response,
        "session_id": log.session_id,
        "stock_code": log.stock_code,
        "total_tokens": log.total_tokens,
        "used_google_search": log.used_google_search
    } for log in logs]


class ExportPDFRequest(BaseModel):
    session_ids: List[str]


@router.post("/chat/export-pdf")
async def export_chat_pdf(request: ExportPDFRequest, db: AsyncSession = Depends(get_db)):
    """Export selected chat sessions to PDF."""
    from sqlalchemy import select, asc
    from datetime import datetime

    if not request.session_ids:
        raise HTTPException(status_code=400, detail="No sessions selected")

    # Fetch logs for the selected sessions, sorted strictly by timestamp ASC so
    # that messages from different sessions are interleaved chronologically.
    query = (
        select(LLMUsageLog)
        .where(LLMUsageLog.session_id.in_(request.session_ids))
        .order_by(asc(LLMUsageLog.timestamp))
    )
    result = await db.execute(query)
    logs = result.scalars().all()

    if not logs:
        raise HTTPException(status_code=404, detail="No logs found for selected sessions")

    # Build the HTML content. Messages are no longer grouped by session; they
    # form a flat chronological stream.
    messages_html = ""
    for log in logs:
        timestamp = log.timestamp.strftime('%Y-%m-%d %H:%M:%S') if log.timestamp else ""
        response_html = markdown.markdown(
            log.response or "", extensions=['tables', 'fenced_code']
        )
        # Add context (stock code / session) to each message header, since
        # sessions are mixed together.
        stock_info = log.stock_code or "Unknown Stock"
        session_info = log.session_id or "Unknown Session"
        messages_html += f'''
        <div class="message">
            <div class="message-header">
                <span class="stock">{stock_info}</span>
                <span class="time">{timestamp}</span>
            </div>
            <div class="message-meta">Session: {session_info} | Model: {log.model}</div>
            <div class="message-body">{response_html}</div>
        </div>
        '''

    # Complete HTML document. The font-family stack includes common
    # Linux/Windows CJK fonts so Chinese text renders correctly in the PDF.
    complete_html = f'''<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>AI Research Discussion Export</title>
<style>
    body {{
        font-family: "Noto Sans CJK SC", "WenQuanYi Micro Hei",
                     "Microsoft YaHei", "SimSun", "DejaVu Sans", sans-serif;
        margin: 2em;
    }}
    h1 {{ font-size: 1.6em; margin-bottom: 0.2em; }}
    .export-time {{ color: #666; margin-bottom: 1.5em; }}
    .message {{ margin-bottom: 1.5em; page-break-inside: avoid; }}
    .message-header {{ font-weight: bold; border-bottom: 1px solid #ccc; }}
    .message-header .time {{ float: right; font-weight: normal; color: #888; }}
    .message-meta {{ color: #888; font-size: 0.9em; margin: 0.3em 0; }}
</style>
</head>
<body>
<h1>AI 研究讨论导出</h1><!-- "AI Research Discussion Export" -->
<p class="export-time">导出时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p><!-- "Export time" -->
{messages_html}
</body>
</html>'''

    # Generate the PDF.
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
            pdf_path = tmp_file.name
            HTML(string=complete_html).write_pdf(pdf_path)

        # "AI研究讨论汇总" = "AI research discussion digest".
        filename = f"AI研究讨论汇总_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pdf"
        filename_encoded = quote(filename)

        return FileResponse(
            path=pdf_path,
            media_type='application/pdf',
            headers={
                # RFC 5987 encoding keeps the non-ASCII filename intact.
                'Content-Disposition': f"attachment; filename*=UTF-8''{filename_encoded}"
            }
        )
    except Exception as e:
        logger.error(f"Failed to generate PDF: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to generate PDF: {str(e)}")
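

# Illustrative client usage (a sketch, not part of the router): consuming the
# /chat NDJSON stream with httpx. The base URL and any router prefix are
# assumptions; adjust them to wherever this router is actually mounted.
#
#   import asyncio, json
#   import httpx
#
#   async def demo():
#       payload = {
#           "messages": [{"role": "user", "content": "Hello"}],
#           "model": "gemini-2.5-flash",
#           "use_google_search": True,
#       }
#       async with httpx.AsyncClient(timeout=None) as client:
#           async with client.stream(
#               "POST", "http://localhost:8000/chat", json=payload
#           ) as resp:
#               async for line in resp.aiter_lines():
#                   if not line:
#                       continue
#                   event = json.loads(line)
#                   if event["type"] == "content":
#                       print(event["content"], end="", flush=True)
#
#   asyncio.run(demo())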