feat: Implement grounding citation display in AI discussion and upgrade default LLM to Gemini 3 Flash Preview.

This commit is contained in:
xucheng 2026-01-20 10:27:38 +08:00
parent 653812a480
commit e5e72205e8
5 changed files with 237 additions and 72 deletions

View File

@ -1,4 +1,4 @@
from fastapi import APIRouter, HTTPException, Depends
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
from fastapi.responses import StreamingResponse, FileResponse
import json
from pydantic import BaseModel
@ -31,14 +31,15 @@ class ChatRequest(BaseModel):
stock_code: Optional[str] = None
@router.post("/chat")
async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db)):
async def chat_with_ai(request: ChatRequest, background_tasks: BackgroundTasks = None):
"""AI Chat Endpoint with Logging (Streaming)"""
# Note: background_tasks argument is not used directly,
# but we use starlette.background.BackgroundTask in StreamingResponse
try:
client = get_genai_client()
# Prepare History and Config
history = []
# Exclude the last message as it will be sent via send_message
if len(request.messages) > 1:
for msg in request.messages[:-1]:
history.append(types.Content(
@ -47,8 +48,7 @@ async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db))
))
last_message = request.messages[-1].content if request.messages else ""
model_name = request.model or "gemini-2.5-flash"
model_name = request.model or "gemini-3-flash-preview"
# Search Configuration & System Prompt
tools = []
@ -56,7 +56,6 @@ async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db))
if request.use_google_search:
tools.append(types.Tool(google_search=types.GoogleSearch()))
# Inject strong instruction to force search usage
final_system_prompt += "\n\n[SYSTEM INSTRUCTION] You have access to Google Search. You MUST use it to verify any data or find the latest information before answering. Do not rely solely on your internal knowledge."
config = types.GenerateContentConfig(
@ -67,12 +66,18 @@ async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db))
start_time = time.time()
async def generate():
full_response_text = ""
# Mutable context to capture data during the stream for logging
log_context = {
"full_response_text": "",
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
"response_time": 0
}
# Sync generator to run in threadpool
def generate():
grounding_data = None
prompt_tokens = 0
completion_tokens = 0
total_tokens = 0
try:
# Initialize Chat
@ -82,24 +87,22 @@ async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db))
config=config
)
# Streaming Call
# Blocking Streaming Call (will run in threadpool)
response_stream = chat.send_message_stream(last_message)
for chunk in response_stream:
if chunk.text:
full_response_text += chunk.text
log_context["full_response_text"] += chunk.text
data = json.dumps({"type": "content", "content": chunk.text})
yield f"{data}\n"
# Accumulate Metadata (will be complete in the last chunk usually, or usage_metadata)
if chunk.usage_metadata:
prompt_tokens = chunk.usage_metadata.prompt_token_count or 0
completion_tokens = chunk.usage_metadata.candidates_token_count or 0
total_tokens = chunk.usage_metadata.total_token_count or 0
log_context["prompt_tokens"] = chunk.usage_metadata.prompt_token_count or 0
log_context["completion_tokens"] = chunk.usage_metadata.candidates_token_count or 0
log_context["total_tokens"] = chunk.usage_metadata.total_token_count or 0
if chunk.candidates and chunk.candidates[0].grounding_metadata:
gm = chunk.candidates[0].grounding_metadata
# We only expect one final grounding metadata object
if gm.search_entry_point or gm.grounding_chunks:
grounding_obj = {}
if gm.search_entry_point:
@ -115,14 +118,28 @@ async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db))
"title": g_chunk.web.title
}
})
if gm.grounding_supports:
grounding_obj["groundingSupports"] = []
for support in gm.grounding_supports:
support_obj = {
"segment": {
"startIndex": support.segment.start_index if support.segment else 0,
"endIndex": support.segment.end_index if support.segment else 0,
"text": support.segment.text if support.segment else ""
},
"groundingChunkIndices": list(support.grounding_chunk_indices)
}
grounding_obj["groundingSupports"].append(support_obj)
if gm.web_search_queries:
grounding_obj["webSearchQueries"] = gm.web_search_queries
grounding_data = grounding_obj # Save final metadata
grounding_data = grounding_obj
# End of stream actions
# End of stream
end_time = time.time()
response_time = end_time - start_time
log_context["response_time"] = end_time - start_time
# Send Metadata Chunk
if grounding_data:
@ -130,48 +147,57 @@ async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db))
# Send Usage Chunk
usage_data = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
"response_time": response_time
"prompt_tokens": log_context["prompt_tokens"],
"completion_tokens": log_context["completion_tokens"],
"total_tokens": log_context["total_tokens"],
"response_time": log_context["response_time"]
}
yield f"{json.dumps({'type': 'usage', 'usage': usage_data})}\n"
# Log to Database (Async)
# Note: creating a new session here because the outer session might be closed or not thread-safe in generator?
# Actually, we can use the injected `db` session but we must be careful.
# Since we are inside an async generator, awaiting db calls is fine.
try:
# Reconstruct prompt for logging
logged_prompt = ""
if request.system_prompt:
logged_prompt += f"System: {request.system_prompt}\n\n"
for msg in request.messages: # Log full history including last message
logged_prompt += f"{msg.role}: {msg.content}\n"
log_entry = LLMUsageLog(
model=model_name,
prompt=logged_prompt,
response=full_response_text,
response_time=response_time,
used_google_search=request.use_google_search,
session_id=request.session_id,
stock_code=request.stock_code,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens
)
db.add(log_entry)
await db.commit()
except Exception as db_err:
logger.error(f"Failed to log LLM usage: {db_err}")
except Exception as e:
logger.error(f"Stream generation error: {e}")
err_json = json.dumps({"type": "error", "error": str(e)})
yield f"{err_json}\n"
return StreamingResponse(generate(), media_type="application/x-ndjson")
# Background Task for logging
from starlette.background import BackgroundTask
async def log_usage(context: dict, req: ChatRequest):
try:
# Create a NEW session for the background task
from app.database import SessionLocal
async with SessionLocal() as session:
# Reconstruct prompt for logging
logged_prompt = ""
if req.system_prompt:
logged_prompt += f"System: {req.system_prompt}\n\n"
for msg in req.messages:
logged_prompt += f"{msg.role}: {msg.content}\n"
log_entry = LLMUsageLog(
model=req.model or "gemini-3-flash-preview",
prompt=logged_prompt,
response=context["full_response_text"],
response_time=context["response_time"],
used_google_search=req.use_google_search,
session_id=req.session_id,
stock_code=req.stock_code,
prompt_tokens=context["prompt_tokens"],
completion_tokens=context["completion_tokens"],
total_tokens=context["total_tokens"]
)
session.add(log_entry)
await session.commit()
except Exception as e:
logger.error(f"Background logging failed: {e}")
# Note the usage of 'generate()' (sync) instead of 'generate' (async)
# and we pass the background task to StreamingResponse
return StreamingResponse(
generate(),
media_type="application/x-ndjson",
background=BackgroundTask(log_usage, log_context, request)
)
except Exception as e:
logger.error(f"Chat error: {e}")

View File

@ -116,16 +116,16 @@ async def get_config_compat(db: AsyncSession = Depends(get_db)):
# Default models list as requested
import json
default_models = [
"gemini-2.0-flash",
"gemini-2.5-flash",
"gemini-3-flash-preview",
"gemini-3-pro-preview"
"gemini-3-pro-preview",
"gemini-2.0-flash",
"gemini-2.5-flash"
]
config_map["available_models"] = json.dumps(default_models)
# Ensure defaults for other keys exist if not in DB
defaults = {
"ai_model": "gemini-2.5-flash",
"ai_model": "gemini-3-flash-preview",
"data_source_cn": "Tushare",
"data_source_hk": "iFinD",
"data_source_us": "iFinD",
@ -144,12 +144,12 @@ async def get_config_compat(db: AsyncSession = Depends(get_db)):
print(f"Config read error (using defaults): {e}")
import json
return {
"ai_model": "gemini-2.5-flash",
"ai_model": "gemini-3-flash-preview",
"available_models": json.dumps([
"gemini-2.0-flash",
"gemini-2.5-flash",
"gemini-3-flash-preview",
"gemini-3-pro-preview"
"gemini-3-pro-preview",
"gemini-2.0-flash",
"gemini-2.5-flash"
]),
"data_source_cn": "Tushare",
"data_source_hk": "iFinD",
@ -193,7 +193,7 @@ logger = logging.getLogger(__name__)
class StockSearchRequest(BaseModel):
query: str
model: str = "gemini-2.0-flash" # 支持前端传入模型参数
model: str = "gemini-3-flash-preview" # 支持前端传入模型参数
class StockSearchResponse(BaseModel):
market: str
@ -238,8 +238,8 @@ async def search_stock(request: StockSearchRequest):
# 启用 Google Search Grounding
grounding_tool = types.Tool(google_search=types.GoogleSearch())
# 使用请求中的模型,默认为 gemini-2.5-flash
model_name = request.model or "gemini-2.5-flash"
# 使用请求中的模型,默认为 gemini-3-flash-preview
model_name = request.model or "gemini-3-flash-preview"
logger.info(f"🤖 [搜索-LLM] 调用 {model_name} 进行股票搜索")
llm_start = time.time()

View File

@ -131,7 +131,7 @@ async def process_analysis_steps(report_id: int, company_name: str, symbol: str,
# Get AI model from settings
model_setting = await db.get(Setting, "AI_MODEL")
model_name = model_setting.value if model_setting else "gemini-2.0-flash"
model_name = model_setting.value if model_setting else "gemini-3-flash-preview"
# Prepare all API calls concurrently
async def process_section(key: str, name: str):

View File

@ -16,12 +16,151 @@ import { Switch } from "@/components/ui/switch"
import ReactMarkdown from "react-markdown"
import remarkGfm from "remark-gfm"
// Helper function to insert citations into text
function addCitations(text: string, groundingMetadata: any) {
if (!groundingMetadata || !groundingMetadata.groundingSupports || !groundingMetadata.groundingChunks) {
return text;
}
const supports = groundingMetadata.groundingSupports;
const chunks = groundingMetadata.groundingChunks;
// 1. Collect all citations with their original positions
interface Citation {
index: number;
text: string;
}
const citations: Citation[] = [];
supports.forEach((support: any) => {
const endIndex = support.segment.endIndex;
// Optimization: Ignore supports that cover the entire text (or close to it)
// These are usually "general" grounding for the whole answer and cause a massive pile of citations at the end.
const startIndex = support.segment.startIndex || 0;
const textLen = text.length;
// Aggressive filter: If segment starts near the beginning and ends near the end, skip it.
// This targets the "summary grounding" which lists all sources used in the answer.
if (startIndex < 5 && endIndex >= textLen - 5) {
return;
}
const chunkIndices = support.groundingChunkIndices;
if (endIndex === undefined || !chunkIndices || chunkIndices.length === 0) return;
const label = chunkIndices.map((idx: number) => {
const chunk = chunks[idx];
if (chunk?.web?.uri) {
// Use parenthesis style (1), (2)
return `[(${idx + 1})](${chunk.web.uri})`;
}
return "";
}).join("");
if (label) {
citations.push({ index: endIndex, text: label });
}
});
// 2. Identify numeric ranges to avoid splitting
const forbiddenRanges: { start: number, end: number }[] = [];
// Matches numbers like integers, decimals, comma-separated (simple)
// Heuristic: digits sticking together or with dot/comma inside
const regex = /\d+(?:[.,]\d+)*/g;
let match;
while ((match = regex.exec(text)) !== null) {
forbiddenRanges.push({ start: match.index, end: match.index + match[0].length });
}
// 3. Move citations out of forbidden ranges
citations.forEach(cit => {
for (const range of forbiddenRanges) {
if (cit.index > range.start && cit.index < range.end) {
// Move to END of number is safer for reading
cit.index = range.end;
break;
}
}
});
// 4. Sort citations by index
citations.sort((a, b) => a.index - b.index);
// 5. Reconstruct string with Deduplication
let result = "";
let curIndex = 0;
let i = 0;
while (i < citations.length) {
const batchIndex = citations[i].index;
// Append text up to this index
if (batchIndex > curIndex) {
result += text.slice(curIndex, batchIndex);
curIndex = batchIndex;
}
// Collect ALL citation chunks at this exact position
const uniqueChunks = new Set<string>();
// Iterate through all citations that fall at this 'batchIndex'
while (i < citations.length && citations[i].index === batchIndex) {
// citations[i].text might be multiple emojis like 1⃣2⃣ if originally from one support
// But our 'citations' array is built from supports.
// Ideally we should have pushed individual chunk indices to 'citations' array to make dedup easier.
// But currently 'citations' items are pre-rendered strings.
// Let's refine step 1 to push raw data instead if we want robust dedup,
// OR just dedup the string representation if they are identical?
// Actually, the issue is that multiple supports (sentence vs paragraph) might point to same chunk ID [1].
// So we need to parse the pre-rendered string or change Step 1.
// Let's rely on the parsing of the text content we built in step 1.
// The text is like [1⃣](uri). We can use that as the unique key.
// However, a single support might start with MULTIPLE chunks.
// Hacky but effective: Just add the text to Set.
// If "source 1" is cited twice at this location, it will be deduped.
// But if "source 1 and 2" are cited, and then "source 1" is cited again, we need granular dedup.
// To do this properly, let's look at how we built `citations`.
// We just appended strings. Let's fix this block to be smarter.
// Simple string dedup for now will solve the "exact duplicate" issue (same chunk, same position).
// It won't solve "Source 1" appearing in "Source 1,2" group.
// But visual flooding is usually exact duplicates.
// BETTER APPROACH: We already have the raw data in the loop above? No, we are in a new loop.
// Let's just output distinct strings at this index.
// We can extract all `[...](...)` patterns from the text and dedup them.
const rawText = citations[i].text;
const linkRegex = /\[.*?\]\(.*?\)/g;
let linkMatch;
while ((linkMatch = linkRegex.exec(rawText)) !== null) {
uniqueChunks.add(linkMatch[0]);
}
i++;
}
// Render unique chunks
result += Array.from(uniqueChunks).join("");
}
// Append remaining text
if (curIndex < text.length) {
result += text.slice(curIndex);
}
return result;
}
interface Message {
role: "user" | "model"
content: string
groundingMetadata?: {
searchEntryPoint?: { renderedContent: string }
groundingChunks?: Array<{ web?: { uri: string; title: string } }>
groundingSupports?: Array<{
segment: { startIndex: number; endIndex: number; text: string }
groundingChunkIndices: number[]
}>
webSearchQueries?: string[]
}
searchStatus?: {
@ -46,12 +185,12 @@ const DEFAULT_ROLES: RoleConfig[] = [
export function AiDiscussionView({ companyName, symbol, market }: { companyName: string, symbol: string, market: string }) {
const [roles, setRoles] = useState<RoleConfig[]>([])
const [availableModels, setAvailableModels] = useState<string[]>(["gemini-2.0-flash", "gemini-1.5-pro", "gemini-1.5-flash"])
const [availableModels, setAvailableModels] = useState<string[]>(["gemini-3-flash-preview", "gemini-3-pro-preview", "gemini-2.0-flash", "gemini-1.5-pro"])
const [questionLibrary, setQuestionLibrary] = useState<any>(null)
// Left Chat State
const [leftRole, setLeftRole] = useState("")
const [leftModel, setLeftModel] = useState("gemini-2.0-flash")
const [leftModel, setLeftModel] = useState("gemini-3-flash-preview")
const [leftMessages, setLeftMessages] = useState<Message[]>([])
const [leftInput, setLeftInput] = useState("")
const [leftLoading, setLeftLoading] = useState(false)
@ -60,7 +199,7 @@ export function AiDiscussionView({ companyName, symbol, market }: { companyName:
// Right Chat State
const [rightRole, setRightRole] = useState("")
const [rightModel, setRightModel] = useState("gemini-2.0-flash")
const [rightModel, setRightModel] = useState("gemini-3-flash-preview")
const [rightMessages, setRightMessages] = useState<Message[]>([])
const [rightInput, setRightInput] = useState("")
const [rightLoading, setRightLoading] = useState(false)
@ -533,7 +672,7 @@ function ChatPane({
)
}}
>
{m.content}
{addCitations(m.content, m.groundingMetadata)}
</ReactMarkdown>
)}

View File

@ -18,7 +18,7 @@ interface AnalysisTriggerProps {
export function AnalysisTrigger({
companyId,
dataSource,
model = "gemini-2.0-flash",
model = "gemini-3-flash-preview",
onAnalysisComplete
}: AnalysisTriggerProps) {
const [analysisId, setAnalysisId] = useState<number | null>(null)