feat: Implement grounding citation display in AI discussion and upgrade default LLM to Gemini 3 Flash Preview.

xucheng 2026-01-20 10:27:38 +08:00
parent 653812a480
commit e5e72205e8
5 changed files with 237 additions and 72 deletions

View File

@@ -1,4 +1,4 @@
-from fastapi import APIRouter, HTTPException, Depends
+from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
from fastapi.responses import StreamingResponse, FileResponse
import json
from pydantic import BaseModel
@@ -31,14 +31,15 @@ class ChatRequest(BaseModel):
stock_code: Optional[str] = None
@router.post("/chat")
-async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db)):
+async def chat_with_ai(request: ChatRequest, background_tasks: BackgroundTasks = None):
"""AI Chat Endpoint with Logging (Streaming)"""
+# Note: background_tasks argument is not used directly,
+# but we use starlette.background.BackgroundTask in StreamingResponse
try:
client = get_genai_client()
# Prepare History and Config
history = []
-# Exclude the last message as it will be sent via send_message
if len(request.messages) > 1:
for msg in request.messages[:-1]:
history.append(types.Content(
@@ -47,8 +48,7 @@ async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db))
))
last_message = request.messages[-1].content if request.messages else ""
-model_name = request.model or "gemini-2.5-flash"
+model_name = request.model or "gemini-3-flash-preview"
# Search Configuration & System Prompt
tools = []
@@ -56,7 +56,6 @@ async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db))
if request.use_google_search:
tools.append(types.Tool(google_search=types.GoogleSearch()))
-# Inject strong instruction to force search usage
final_system_prompt += "\n\n[SYSTEM INSTRUCTION] You have access to Google Search. You MUST use it to verify any data or find the latest information before answering. Do not rely solely on your internal knowledge."
config = types.GenerateContentConfig(
@@ -67,12 +66,18 @@ async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db))
start_time = time.time()
-async def generate():
-full_response_text = ""
+# Mutable context to capture data during the stream for logging
+log_context = {
+"full_response_text": "",
+"prompt_tokens": 0,
+"completion_tokens": 0,
+"total_tokens": 0,
+"response_time": 0
+}
+# Sync generator to run in threadpool
+def generate():
grounding_data = None
-prompt_tokens = 0
-completion_tokens = 0
-total_tokens = 0
try:
# Initialize Chat
@@ -82,24 +87,22 @@ async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db))
config=config
)
-# Streaming Call
+# Blocking Streaming Call (will run in threadpool)
response_stream = chat.send_message_stream(last_message)
for chunk in response_stream:
if chunk.text:
-full_response_text += chunk.text
+log_context["full_response_text"] += chunk.text
data = json.dumps({"type": "content", "content": chunk.text})
yield f"{data}\n"
-# Accumulate Metadata (will be complete in the last chunk usually, or usage_metadata)
if chunk.usage_metadata:
-prompt_tokens = chunk.usage_metadata.prompt_token_count or 0
-completion_tokens = chunk.usage_metadata.candidates_token_count or 0
-total_tokens = chunk.usage_metadata.total_token_count or 0
+log_context["prompt_tokens"] = chunk.usage_metadata.prompt_token_count or 0
+log_context["completion_tokens"] = chunk.usage_metadata.candidates_token_count or 0
+log_context["total_tokens"] = chunk.usage_metadata.total_token_count or 0
if chunk.candidates and chunk.candidates[0].grounding_metadata:
gm = chunk.candidates[0].grounding_metadata
-# We only expect one final grounding metadata object
if gm.search_entry_point or gm.grounding_chunks:
grounding_obj = {}
if gm.search_entry_point:
@@ -115,14 +118,28 @@ async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db))
"title": g_chunk.web.title
}
})
+if gm.grounding_supports:
+grounding_obj["groundingSupports"] = []
+for support in gm.grounding_supports:
+support_obj = {
+"segment": {
+"startIndex": support.segment.start_index if support.segment else 0,
+"endIndex": support.segment.end_index if support.segment else 0,
+"text": support.segment.text if support.segment else ""
+},
+"groundingChunkIndices": list(support.grounding_chunk_indices)
+}
+grounding_obj["groundingSupports"].append(support_obj)
if gm.web_search_queries:
grounding_obj["webSearchQueries"] = gm.web_search_queries
-grounding_data = grounding_obj # Save final metadata
+grounding_data = grounding_obj
-# End of stream actions
+# End of stream
end_time = time.time()
-response_time = end_time - start_time
+log_context["response_time"] = end_time - start_time
# Send Metadata Chunk
if grounding_data:
@@ -130,48 +147,57 @@ async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db))
# Send Usage Chunk
usage_data = {
-"prompt_tokens": prompt_tokens,
-"completion_tokens": completion_tokens,
-"total_tokens": total_tokens,
-"response_time": response_time
+"prompt_tokens": log_context["prompt_tokens"],
+"completion_tokens": log_context["completion_tokens"],
+"total_tokens": log_context["total_tokens"],
+"response_time": log_context["response_time"]
}
yield f"{json.dumps({'type': 'usage', 'usage': usage_data})}\n"
-# Log to Database (Async)
-# Note: creating a new session here because the outer session might be closed or not thread-safe in generator?
-# Actually, we can use the injected `db` session but we must be careful.
-# Since we are inside an async generator, awaiting db calls is fine.
-try:
-# Reconstruct prompt for logging
-logged_prompt = ""
-if request.system_prompt:
-logged_prompt += f"System: {request.system_prompt}\n\n"
-for msg in request.messages: # Log full history including last message
-logged_prompt += f"{msg.role}: {msg.content}\n"
-log_entry = LLMUsageLog(
-model=model_name,
-prompt=logged_prompt,
-response=full_response_text,
-response_time=response_time,
-used_google_search=request.use_google_search,
-session_id=request.session_id,
-stock_code=request.stock_code,
-prompt_tokens=prompt_tokens,
-completion_tokens=completion_tokens,
-total_tokens=total_tokens
-)
-db.add(log_entry)
-await db.commit()
-except Exception as db_err:
-logger.error(f"Failed to log LLM usage: {db_err}")
except Exception as e:
logger.error(f"Stream generation error: {e}")
err_json = json.dumps({"type": "error", "error": str(e)})
yield f"{err_json}\n"
-return StreamingResponse(generate(), media_type="application/x-ndjson")
+# Background Task for logging
+from starlette.background import BackgroundTask
+async def log_usage(context: dict, req: ChatRequest):
+try:
+# Create a NEW session for the background task
+from app.database import SessionLocal
+async with SessionLocal() as session:
+# Reconstruct prompt for logging
+logged_prompt = ""
+if req.system_prompt:
+logged_prompt += f"System: {req.system_prompt}\n\n"
+for msg in req.messages:
+logged_prompt += f"{msg.role}: {msg.content}\n"
+log_entry = LLMUsageLog(
+model=req.model or "gemini-3-flash-preview",
+prompt=logged_prompt,
+response=context["full_response_text"],
+response_time=context["response_time"],
+used_google_search=req.use_google_search,
+session_id=req.session_id,
+stock_code=req.stock_code,
+prompt_tokens=context["prompt_tokens"],
+completion_tokens=context["completion_tokens"],
+total_tokens=context["total_tokens"]
+)
+session.add(log_entry)
+await session.commit()
+except Exception as e:
+logger.error(f"Background logging failed: {e}")
+# Note the usage of 'generate()' (sync) instead of 'generate' (async)
+# and we pass the background task to StreamingResponse
+return StreamingResponse(
+generate(),
+media_type="application/x-ndjson",
+background=BackgroundTask(log_usage, log_context, request)
+)
except Exception as e:
logger.error(f"Chat error: {e}")

View File

@@ -116,16 +116,16 @@ async def get_config_compat(db: AsyncSession = Depends(get_db)):
# Default models list as requested
import json
default_models = [
-"gemini-2.0-flash",
-"gemini-2.5-flash",
"gemini-3-flash-preview",
-"gemini-3-pro-preview"
+"gemini-3-pro-preview",
+"gemini-2.0-flash",
+"gemini-2.5-flash"
]
config_map["available_models"] = json.dumps(default_models)
# Ensure defaults for other keys exist if not in DB
defaults = {
-"ai_model": "gemini-2.5-flash",
+"ai_model": "gemini-3-flash-preview",
"data_source_cn": "Tushare",
"data_source_hk": "iFinD",
"data_source_us": "iFinD",
@@ -144,12 +144,12 @@ async def get_config_compat(db: AsyncSession = Depends(get_db)):
print(f"Config read error (using defaults): {e}")
import json
return {
-"ai_model": "gemini-2.5-flash",
+"ai_model": "gemini-3-flash-preview",
"available_models": json.dumps([
-"gemini-2.0-flash",
-"gemini-2.5-flash",
"gemini-3-flash-preview",
-"gemini-3-pro-preview"
+"gemini-3-pro-preview",
+"gemini-2.0-flash",
+"gemini-2.5-flash"
]),
"data_source_cn": "Tushare",
"data_source_hk": "iFinD",
@@ -193,7 +193,7 @@ logger = logging.getLogger(__name__)
class StockSearchRequest(BaseModel):
query: str
-model: str = "gemini-2.0-flash"  # the frontend may pass a model parameter
+model: str = "gemini-3-flash-preview"  # the frontend may pass a model parameter
class StockSearchResponse(BaseModel):
market: str
@@ -238,8 +238,8 @@ async def search_stock(request: StockSearchRequest):
# Enable Google Search grounding
grounding_tool = types.Tool(google_search=types.GoogleSearch())
-# Use the model from the request, defaulting to gemini-2.5-flash
-model_name = request.model or "gemini-2.5-flash"
+# Use the model from the request, defaulting to gemini-3-flash-preview
+model_name = request.model or "gemini-3-flash-preview"
logger.info(f"🤖 [搜索-LLM] 调用 {model_name} 进行股票搜索")
llm_start = time.time()

View File

@@ -131,7 +131,7 @@ async def process_analysis_steps(report_id: int, company_name: str, symbol: str,
# Get AI model from settings
model_setting = await db.get(Setting, "AI_MODEL")
-model_name = model_setting.value if model_setting else "gemini-2.0-flash"
+model_name = model_setting.value if model_setting else "gemini-3-flash-preview"
# Prepare all API calls concurrently
async def process_section(key: str, name: str):

View File

@@ -16,12 +16,151 @@ import { Switch } from "@/components/ui/switch"
import ReactMarkdown from "react-markdown"
import remarkGfm from "remark-gfm"
// Helper function to insert citations into text
function addCitations(text: string, groundingMetadata: any) {
if (!groundingMetadata || !groundingMetadata.groundingSupports || !groundingMetadata.groundingChunks) {
return text;
}
const supports = groundingMetadata.groundingSupports;
const chunks = groundingMetadata.groundingChunks;
// 1. Collect all citations with their original positions
interface Citation {
index: number;
text: string;
}
const citations: Citation[] = [];
supports.forEach((support: any) => {
const endIndex = support.segment.endIndex;
// Optimization: Ignore supports that cover the entire text (or close to it)
// These are usually "general" grounding for the whole answer and cause a massive pile of citations at the end.
const startIndex = support.segment.startIndex || 0;
const textLen = text.length;
// Aggressive filter: If segment starts near the beginning and ends near the end, skip it.
// This targets the "summary grounding" which lists all sources used in the answer.
if (startIndex < 5 && endIndex >= textLen - 5) {
return;
}
const chunkIndices = support.groundingChunkIndices;
if (endIndex === undefined || !chunkIndices || chunkIndices.length === 0) return;
const label = chunkIndices.map((idx: number) => {
const chunk = chunks[idx];
if (chunk?.web?.uri) {
// Use parenthesis style (1), (2)
return `[(${idx + 1})](${chunk.web.uri})`;
}
return "";
}).join("");
if (label) {
citations.push({ index: endIndex, text: label });
}
});
// 2. Identify numeric ranges to avoid splitting
const forbiddenRanges: { start: number, end: number }[] = [];
// Matches numbers like integers, decimals, comma-separated (simple)
// Heuristic: digits sticking together or with dot/comma inside
const regex = /\d+(?:[.,]\d+)*/g;
let match;
while ((match = regex.exec(text)) !== null) {
forbiddenRanges.push({ start: match.index, end: match.index + match[0].length });
}
// 3. Move citations out of forbidden ranges
citations.forEach(cit => {
for (const range of forbiddenRanges) {
if (cit.index > range.start && cit.index < range.end) {
// Move to END of number is safer for reading
cit.index = range.end;
break;
}
}
});
// 4. Sort citations by index
citations.sort((a, b) => a.index - b.index);
// 5. Reconstruct string with Deduplication
let result = "";
let curIndex = 0;
let i = 0;
while (i < citations.length) {
const batchIndex = citations[i].index;
// Append text up to this index
if (batchIndex > curIndex) {
result += text.slice(curIndex, batchIndex);
curIndex = batchIndex;
}
// Collect ALL citation chunks at this exact position
const uniqueChunks = new Set<string>();
// Iterate through all citations that fall at this 'batchIndex'
while (i < citations.length && citations[i].index === batchIndex) {
// Multiple supports (e.g. sentence-level vs paragraph-level) can cite the same
// chunk at the same position, which would render duplicate labels. The citation
// labels were pre-rendered as strings in step 1, so extract each individual
// [...](...) link from them and dedupe via a Set. This removes exact duplicates
// at this position; it does not merge a lone "source 1" into a "source 1, 2"
// group, but the visual flooding comes from exact duplicates, which this handles.
const rawText = citations[i].text;
const linkRegex = /\[.*?\]\(.*?\)/g;
let linkMatch;
while ((linkMatch = linkRegex.exec(rawText)) !== null) {
uniqueChunks.add(linkMatch[0]);
}
i++;
}
// Render unique chunks
result += Array.from(uniqueChunks).join("");
}
// Append remaining text
if (curIndex < text.length) {
result += text.slice(curIndex);
}
return result;
}
interface Message {
role: "user" | "model"
content: string
groundingMetadata?: {
searchEntryPoint?: { renderedContent: string }
groundingChunks?: Array<{ web?: { uri: string; title: string } }>
+groundingSupports?: Array<{
+segment: { startIndex: number; endIndex: number; text: string }
+groundingChunkIndices: number[]
+}>
webSearchQueries?: string[]
}
searchStatus?: {
@@ -46,12 +185,12 @@ const DEFAULT_ROLES: RoleConfig[] = [
export function AiDiscussionView({ companyName, symbol, market }: { companyName: string, symbol: string, market: string }) {
const [roles, setRoles] = useState<RoleConfig[]>([])
-const [availableModels, setAvailableModels] = useState<string[]>(["gemini-2.0-flash", "gemini-1.5-pro", "gemini-1.5-flash"])
+const [availableModels, setAvailableModels] = useState<string[]>(["gemini-3-flash-preview", "gemini-3-pro-preview", "gemini-2.0-flash", "gemini-1.5-pro"])
const [questionLibrary, setQuestionLibrary] = useState<any>(null)
// Left Chat State
const [leftRole, setLeftRole] = useState("")
-const [leftModel, setLeftModel] = useState("gemini-2.0-flash")
+const [leftModel, setLeftModel] = useState("gemini-3-flash-preview")
const [leftMessages, setLeftMessages] = useState<Message[]>([])
const [leftInput, setLeftInput] = useState("")
const [leftLoading, setLeftLoading] = useState(false)
@@ -60,7 +199,7 @@ export function AiDiscussionView({ companyName, symbol, market }: { companyName:
// Right Chat State
const [rightRole, setRightRole] = useState("")
-const [rightModel, setRightModel] = useState("gemini-2.0-flash")
+const [rightModel, setRightModel] = useState("gemini-3-flash-preview")
const [rightMessages, setRightMessages] = useState<Message[]>([])
const [rightInput, setRightInput] = useState("")
const [rightLoading, setRightLoading] = useState(false)
@@ -533,7 +672,7 @@ function ChatPane({
)
}}
>
-{m.content}
+{addCitations(m.content, m.groundingMetadata)}
</ReactMarkdown>
)}
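To illustrate the addCitations helper added in this file: it takes the groundingSupports offsets serialized by the backend and appends markdown links such as [(1)](uri) after each supported segment, skipping whole-answer supports and nudging insertion points out of numbers. A small illustrative call with made-up data (not part of the commit):

// Illustrative only: hypothetical text and grounding metadata.
const sampleText = "Revenue grew 12.5% year over year. Margins were stable."
const sampleMetadata = {
  groundingChunks: [
    { web: { uri: "https://example.com/report", title: "Example Report" } }
  ],
  groundingSupports: [
    {
      segment: { startIndex: 0, endIndex: 34, text: "Revenue grew 12.5% year over year." },
      groundingChunkIndices: [0]
    }
  ]
}
// Expected output: the citation link lands after the supported sentence, e.g.
// "Revenue grew 12.5% year over year.[(1)](https://example.com/report) Margins were stable."
console.log(addCitations(sampleText, sampleMetadata))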

View File

@@ -18,7 +18,7 @@ interface AnalysisTriggerProps {
export function AnalysisTrigger({
companyId,
dataSource,
-model = "gemini-2.0-flash",
+model = "gemini-3-flash-preview",
onAnalysisComplete
}: AnalysisTriggerProps) {
const [analysisId, setAnalysisId] = useState<number | null>(null)