From e5e72205e85c8ca8bb40a1742fe0397560bdecf3 Mon Sep 17 00:00:00 2001 From: xucheng Date: Tue, 20 Jan 2026 10:27:38 +0800 Subject: [PATCH] feat: Implement grounding citation display in AI discussion and upgrade default LLM to Gemini 3 Flash Preview. --- backend/app/api/chat_routes.py | 136 +++++++++------- backend/app/main.py | 22 +-- backend/app/services/llm_engine.py | 2 +- .../src/components/ai-discussion-view.tsx | 147 +++++++++++++++++- frontend/src/components/analysis-trigger.tsx | 2 +- 5 files changed, 237 insertions(+), 72 deletions(-) diff --git a/backend/app/api/chat_routes.py b/backend/app/api/chat_routes.py index c408a22..a761f43 100644 --- a/backend/app/api/chat_routes.py +++ b/backend/app/api/chat_routes.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, HTTPException, Depends +from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks from fastapi.responses import StreamingResponse, FileResponse import json from pydantic import BaseModel @@ -31,14 +31,15 @@ class ChatRequest(BaseModel): stock_code: Optional[str] = None @router.post("/chat") -async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db)): +async def chat_with_ai(request: ChatRequest, background_tasks: BackgroundTasks = None): """AI Chat Endpoint with Logging (Streaming)""" + # Note: background_tasks argument is not used directly, + # but we use starlette.background.BackgroundTask in StreamingResponse try: client = get_genai_client() # Prepare History and Config history = [] - # Exclude the last message as it will be sent via send_message if len(request.messages) > 1: for msg in request.messages[:-1]: history.append(types.Content( @@ -47,8 +48,7 @@ async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db)) )) last_message = request.messages[-1].content if request.messages else "" - - model_name = request.model or "gemini-2.5-flash" + model_name = request.model or "gemini-3-flash-preview" # Search Configuration & System Prompt tools = [] @@ -56,7 +56,6 @@ async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db)) if request.use_google_search: tools.append(types.Tool(google_search=types.GoogleSearch())) - # Inject strong instruction to force search usage final_system_prompt += "\n\n[SYSTEM INSTRUCTION] You have access to Google Search. You MUST use it to verify any data or find the latest information before answering. Do not rely solely on your internal knowledge." 
config = types.GenerateContentConfig( @@ -67,12 +66,18 @@ async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db)) start_time = time.time() - async def generate(): - full_response_text = "" + # Mutable context to capture data during the stream for logging + log_context = { + "full_response_text": "", + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + "response_time": 0 + } + + # Sync generator to run in threadpool + def generate(): grounding_data = None - prompt_tokens = 0 - completion_tokens = 0 - total_tokens = 0 try: # Initialize Chat @@ -82,24 +87,22 @@ async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db)) config=config ) - # Streaming Call + # Blocking Streaming Call (will run in threadpool) response_stream = chat.send_message_stream(last_message) for chunk in response_stream: if chunk.text: - full_response_text += chunk.text + log_context["full_response_text"] += chunk.text data = json.dumps({"type": "content", "content": chunk.text}) yield f"{data}\n" - # Accumulate Metadata (will be complete in the last chunk usually, or usage_metadata) if chunk.usage_metadata: - prompt_tokens = chunk.usage_metadata.prompt_token_count or 0 - completion_tokens = chunk.usage_metadata.candidates_token_count or 0 - total_tokens = chunk.usage_metadata.total_token_count or 0 + log_context["prompt_tokens"] = chunk.usage_metadata.prompt_token_count or 0 + log_context["completion_tokens"] = chunk.usage_metadata.candidates_token_count or 0 + log_context["total_tokens"] = chunk.usage_metadata.total_token_count or 0 if chunk.candidates and chunk.candidates[0].grounding_metadata: gm = chunk.candidates[0].grounding_metadata - # We only expect one final grounding metadata object if gm.search_entry_point or gm.grounding_chunks: grounding_obj = {} if gm.search_entry_point: @@ -115,14 +118,28 @@ async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db)) "title": g_chunk.web.title } }) + + if gm.grounding_supports: + grounding_obj["groundingSupports"] = [] + for support in gm.grounding_supports: + support_obj = { + "segment": { + "startIndex": support.segment.start_index if support.segment else 0, + "endIndex": support.segment.end_index if support.segment else 0, + "text": support.segment.text if support.segment else "" + }, + "groundingChunkIndices": list(support.grounding_chunk_indices) + } + grounding_obj["groundingSupports"].append(support_obj) + if gm.web_search_queries: grounding_obj["webSearchQueries"] = gm.web_search_queries - grounding_data = grounding_obj # Save final metadata + grounding_data = grounding_obj - # End of stream actions + # End of stream end_time = time.time() - response_time = end_time - start_time + log_context["response_time"] = end_time - start_time # Send Metadata Chunk if grounding_data: @@ -130,48 +147,57 @@ async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db)) # Send Usage Chunk usage_data = { - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": total_tokens, - "response_time": response_time + "prompt_tokens": log_context["prompt_tokens"], + "completion_tokens": log_context["completion_tokens"], + "total_tokens": log_context["total_tokens"], + "response_time": log_context["response_time"] } yield f"{json.dumps({'type': 'usage', 'usage': usage_data})}\n" - # Log to Database (Async) - # Note: creating a new session here because the outer session might be closed or not thread-safe in generator? 
- # Actually, we can use the injected `db` session but we must be careful. - # Since we are inside an async generator, awaiting db calls is fine. - try: - # Reconstruct prompt for logging - logged_prompt = "" - if request.system_prompt: - logged_prompt += f"System: {request.system_prompt}\n\n" - for msg in request.messages: # Log full history including last message - logged_prompt += f"{msg.role}: {msg.content}\n" - - log_entry = LLMUsageLog( - model=model_name, - prompt=logged_prompt, - response=full_response_text, - response_time=response_time, - used_google_search=request.use_google_search, - session_id=request.session_id, - stock_code=request.stock_code, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=total_tokens - ) - db.add(log_entry) - await db.commit() - except Exception as db_err: - logger.error(f"Failed to log LLM usage: {db_err}") - except Exception as e: logger.error(f"Stream generation error: {e}") err_json = json.dumps({"type": "error", "error": str(e)}) yield f"{err_json}\n" - return StreamingResponse(generate(), media_type="application/x-ndjson") + # Background Task for logging + from starlette.background import BackgroundTask + + async def log_usage(context: dict, req: ChatRequest): + try: + # Create a NEW session for the background task + from app.database import SessionLocal + async with SessionLocal() as session: + # Reconstruct prompt for logging + logged_prompt = "" + if req.system_prompt: + logged_prompt += f"System: {req.system_prompt}\n\n" + for msg in req.messages: + logged_prompt += f"{msg.role}: {msg.content}\n" + + log_entry = LLMUsageLog( + model=req.model or "gemini-3-flash-preview", + prompt=logged_prompt, + response=context["full_response_text"], + response_time=context["response_time"], + used_google_search=req.use_google_search, + session_id=req.session_id, + stock_code=req.stock_code, + prompt_tokens=context["prompt_tokens"], + completion_tokens=context["completion_tokens"], + total_tokens=context["total_tokens"] + ) + session.add(log_entry) + await session.commit() + except Exception as e: + logger.error(f"Background logging failed: {e}") + + # Note the usage of 'generate()' (sync) instead of 'generate' (async) + # and we pass the background task to StreamingResponse + return StreamingResponse( + generate(), + media_type="application/x-ndjson", + background=BackgroundTask(log_usage, log_context, request) + ) except Exception as e: logger.error(f"Chat error: {e}") diff --git a/backend/app/main.py b/backend/app/main.py index 997468e..0c5884b 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -116,16 +116,16 @@ async def get_config_compat(db: AsyncSession = Depends(get_db)): # Default models list as requested import json default_models = [ - "gemini-2.0-flash", - "gemini-2.5-flash", "gemini-3-flash-preview", - "gemini-3-pro-preview" + "gemini-3-pro-preview", + "gemini-2.0-flash", + "gemini-2.5-flash" ] config_map["available_models"] = json.dumps(default_models) # Ensure defaults for other keys exist if not in DB defaults = { - "ai_model": "gemini-2.5-flash", + "ai_model": "gemini-3-flash-preview", "data_source_cn": "Tushare", "data_source_hk": "iFinD", "data_source_us": "iFinD", @@ -144,12 +144,12 @@ async def get_config_compat(db: AsyncSession = Depends(get_db)): print(f"Config read error (using defaults): {e}") import json return { - "ai_model": "gemini-2.5-flash", + "ai_model": "gemini-3-flash-preview", "available_models": json.dumps([ - "gemini-2.0-flash", - "gemini-2.5-flash", "gemini-3-flash-preview", 
- "gemini-3-pro-preview" + "gemini-3-pro-preview", + "gemini-2.0-flash", + "gemini-2.5-flash" ]), "data_source_cn": "Tushare", "data_source_hk": "iFinD", @@ -193,7 +193,7 @@ logger = logging.getLogger(__name__) class StockSearchRequest(BaseModel): query: str - model: str = "gemini-2.0-flash" # 支持前端传入模型参数 + model: str = "gemini-3-flash-preview" # 支持前端传入模型参数 class StockSearchResponse(BaseModel): market: str @@ -238,8 +238,8 @@ async def search_stock(request: StockSearchRequest): # 启用 Google Search Grounding grounding_tool = types.Tool(google_search=types.GoogleSearch()) - # 使用请求中的模型,默认为 gemini-2.5-flash - model_name = request.model or "gemini-2.5-flash" + # 使用请求中的模型,默认为 gemini-3-flash-preview + model_name = request.model or "gemini-3-flash-preview" logger.info(f"🤖 [搜索-LLM] 调用 {model_name} 进行股票搜索") llm_start = time.time() diff --git a/backend/app/services/llm_engine.py b/backend/app/services/llm_engine.py index 952aef1..1685129 100644 --- a/backend/app/services/llm_engine.py +++ b/backend/app/services/llm_engine.py @@ -131,7 +131,7 @@ async def process_analysis_steps(report_id: int, company_name: str, symbol: str, # Get AI model from settings model_setting = await db.get(Setting, "AI_MODEL") - model_name = model_setting.value if model_setting else "gemini-2.0-flash" + model_name = model_setting.value if model_setting else "gemini-3-flash-preview" # Prepare all API calls concurrently async def process_section(key: str, name: str): diff --git a/frontend/src/components/ai-discussion-view.tsx b/frontend/src/components/ai-discussion-view.tsx index b75c540..09ad42c 100644 --- a/frontend/src/components/ai-discussion-view.tsx +++ b/frontend/src/components/ai-discussion-view.tsx @@ -16,12 +16,151 @@ import { Switch } from "@/components/ui/switch" import ReactMarkdown from "react-markdown" import remarkGfm from "remark-gfm" +// Helper function to insert citations into text +function addCitations(text: string, groundingMetadata: any) { + if (!groundingMetadata || !groundingMetadata.groundingSupports || !groundingMetadata.groundingChunks) { + return text; + } + + const supports = groundingMetadata.groundingSupports; + const chunks = groundingMetadata.groundingChunks; + + // 1. Collect all citations with their original positions + interface Citation { + index: number; + text: string; + } + const citations: Citation[] = []; + + supports.forEach((support: any) => { + const endIndex = support.segment.endIndex; + // Optimization: Ignore supports that cover the entire text (or close to it) + // These are usually "general" grounding for the whole answer and cause a massive pile of citations at the end. + const startIndex = support.segment.startIndex || 0; + const textLen = text.length; + + // Aggressive filter: If segment starts near the beginning and ends near the end, skip it. + // This targets the "summary grounding" which lists all sources used in the answer. + if (startIndex < 5 && endIndex >= textLen - 5) { + return; + } + + const chunkIndices = support.groundingChunkIndices; + if (endIndex === undefined || !chunkIndices || chunkIndices.length === 0) return; + + const label = chunkIndices.map((idx: number) => { + const chunk = chunks[idx]; + if (chunk?.web?.uri) { + // Use parenthesis style (1), (2) + return `[(${idx + 1})](${chunk.web.uri})`; + } + return ""; + }).join(""); + + if (label) { + citations.push({ index: endIndex, text: label }); + } + }); + + // 2. 
Identify numeric ranges to avoid splitting
+  const forbiddenRanges: { start: number, end: number }[] = [];
+  // Heuristic: match integers, decimals and comma-grouped numbers so that a
+  // citation marker is never inserted in the middle of a figure.
+  const regex = /\d+(?:[.,]\d+)*/g;
+  let match;
+  while ((match = regex.exec(text)) !== null) {
+    forbiddenRanges.push({ start: match.index, end: match.index + match[0].length });
+  }
+
+  // 3. Move citations out of forbidden ranges
+  citations.forEach(cit => {
+    for (const range of forbiddenRanges) {
+      if (cit.index > range.start && cit.index < range.end) {
+        // Shifting the citation to the end of the number reads better.
+        cit.index = range.end;
+        break;
+      }
+    }
+  });
+
+  // 4. Sort citations by index
+  citations.sort((a, b) => a.index - b.index);
+
+  // 5. Reconstruct the string, deduplicating citations that land at the same index
+  let result = "";
+  let curIndex = 0;
+
+  let i = 0;
+  while (i < citations.length) {
+    const batchIndex = citations[i].index;
+    // Append text up to this index
+    if (batchIndex > curIndex) {
+      result += text.slice(curIndex, batchIndex);
+      curIndex = batchIndex;
+    }
+
+    // Collect all citation links that fall at this exact position.
+    // Each citations[i].text was pre-rendered in step 1 and may contain several
+    // "[(n)](uri)" links, because one support can reference multiple chunks and
+    // overlapping supports (sentence-level vs paragraph-level) often repeat the
+    // same chunk. Extracting the individual link strings and deduplicating them
+    // through a Set prevents the same source from being rendered twice here.
+    const uniqueChunks = new Set<string>();
+    while (i < citations.length && citations[i].index === batchIndex) {
+ const rawText = citations[i].text; + const linkRegex = /\[.*?\]\(.*?\)/g; + let linkMatch; + while ((linkMatch = linkRegex.exec(rawText)) !== null) { + uniqueChunks.add(linkMatch[0]); + } + i++; + } + + // Render unique chunks + result += Array.from(uniqueChunks).join(""); + } + // Append remaining text + if (curIndex < text.length) { + result += text.slice(curIndex); + } + + return result; +} + interface Message { role: "user" | "model" content: string groundingMetadata?: { searchEntryPoint?: { renderedContent: string } groundingChunks?: Array<{ web?: { uri: string; title: string } }> + groundingSupports?: Array<{ + segment: { startIndex: number; endIndex: number; text: string } + groundingChunkIndices: number[] + }> webSearchQueries?: string[] } searchStatus?: { @@ -46,12 +185,12 @@ const DEFAULT_ROLES: RoleConfig[] = [ export function AiDiscussionView({ companyName, symbol, market }: { companyName: string, symbol: string, market: string }) { const [roles, setRoles] = useState([]) - const [availableModels, setAvailableModels] = useState(["gemini-2.0-flash", "gemini-1.5-pro", "gemini-1.5-flash"]) + const [availableModels, setAvailableModels] = useState(["gemini-3-flash-preview", "gemini-3-pro-preview", "gemini-2.0-flash", "gemini-1.5-pro"]) const [questionLibrary, setQuestionLibrary] = useState(null) // Left Chat State const [leftRole, setLeftRole] = useState("") - const [leftModel, setLeftModel] = useState("gemini-2.0-flash") + const [leftModel, setLeftModel] = useState("gemini-3-flash-preview") const [leftMessages, setLeftMessages] = useState([]) const [leftInput, setLeftInput] = useState("") const [leftLoading, setLeftLoading] = useState(false) @@ -60,7 +199,7 @@ export function AiDiscussionView({ companyName, symbol, market }: { companyName: // Right Chat State const [rightRole, setRightRole] = useState("") - const [rightModel, setRightModel] = useState("gemini-2.0-flash") + const [rightModel, setRightModel] = useState("gemini-3-flash-preview") const [rightMessages, setRightMessages] = useState([]) const [rightInput, setRightInput] = useState("") const [rightLoading, setRightLoading] = useState(false) @@ -533,7 +672,7 @@ function ChatPane({ ) }} > - {m.content} + {addCitations(m.content, m.groundingMetadata)} )} diff --git a/frontend/src/components/analysis-trigger.tsx b/frontend/src/components/analysis-trigger.tsx index adfd186..1cace31 100644 --- a/frontend/src/components/analysis-trigger.tsx +++ b/frontend/src/components/analysis-trigger.tsx @@ -18,7 +18,7 @@ interface AnalysisTriggerProps { export function AnalysisTrigger({ companyId, dataSource, - model = "gemini-2.0-flash", + model = "gemini-3-flash-preview", onAnalysisComplete }: AnalysisTriggerProps) { const [analysisId, setAnalysisId] = useState(null)
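
For reviewers, a minimal TypeScript sketch (not part of the patch) of how the new `addCitations` helper in ai-discussion-view.tsx is expected to behave. The shape of `groundingMetadata` (groundingChunks with `web.uri`/`web.title`, groundingSupports with `segment` and `groundingChunkIndices`) mirrors what chat_routes.py now streams; the sample text, URIs, indices, and the import path are invented for illustration, and the import assumes the helper were exported, which the patch does not actually do.

```ts
// Illustrative only; addCitations is module-local in ai-discussion-view.tsx,
// so this import assumes it were exported for testing purposes.
import { addCitations } from "@/components/ai-discussion-view"

const text = "AAPL closed at 195.3 yesterday. Revenue grew 8% year over year."

// Hand-built metadata in the shape emitted by the backend's grounding payload:
// groundingChunks hold the sources, groundingSupports map text segments to them.
const groundingMetadata = {
  groundingChunks: [
    { web: { uri: "https://example.com/quote", title: "Quote" } },
    { web: { uri: "https://example.com/earnings", title: "Earnings" } },
  ],
  groundingSupports: [
    {
      // First sentence, backed by chunk 0 -> rendered as [(1)](...)
      segment: { startIndex: 0, endIndex: 31, text: "AAPL closed at 195.3 yesterday." },
      groundingChunkIndices: [0],
    },
    {
      // Second sentence, backed by chunk 1 -> rendered as [(2)](...)
      segment: { startIndex: 32, endIndex: 63, text: "Revenue grew 8% year over year." },
      groundingChunkIndices: [1],
    },
  ],
}

// Expected result: markdown with an inline link after each grounded sentence, e.g.
// "AAPL closed at 195.3 yesterday.[(1)](https://example.com/quote) Revenue grew 8% ..."
// ReactMarkdown in ChatPane then renders these as clickable citation markers.
console.log(addCitations(text, groundingMetadata))
```

Note that a support spanning essentially the whole answer would be dropped by the startIndex/endIndex filter at the top of `addCitations`; that filter is what prevents a pile-up of citation markers at the very end of the message.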