feat: Implement grounding citation display in AI discussion and upgrade default LLM to Gemini 3 Flash Preview.
This commit is contained in:
parent
653812a480
commit
e5e72205e8
@ -1,4 +1,4 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
|
||||
from fastapi.responses import StreamingResponse, FileResponse
|
||||
import json
|
||||
from pydantic import BaseModel
|
||||
@ -31,14 +31,15 @@ class ChatRequest(BaseModel):
|
||||
stock_code: Optional[str] = None
|
||||
|
||||
@router.post("/chat")
|
||||
async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db)):
|
||||
async def chat_with_ai(request: ChatRequest, background_tasks: BackgroundTasks = None):
|
||||
"""AI Chat Endpoint with Logging (Streaming)"""
|
||||
# Note: background_tasks argument is not used directly,
|
||||
# but we use starlette.background.BackgroundTask in StreamingResponse
|
||||
try:
|
||||
client = get_genai_client()
|
||||
|
||||
# Prepare History and Config
|
||||
history = []
|
||||
# Exclude the last message as it will be sent via send_message
|
||||
if len(request.messages) > 1:
|
||||
for msg in request.messages[:-1]:
|
||||
history.append(types.Content(
|
||||
@ -47,8 +48,7 @@ async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db))
|
||||
))
|
||||
|
||||
last_message = request.messages[-1].content if request.messages else ""
|
||||
|
||||
model_name = request.model or "gemini-2.5-flash"
|
||||
model_name = request.model or "gemini-3-flash-preview"
|
||||
|
||||
# Search Configuration & System Prompt
|
||||
tools = []
|
||||
@ -56,7 +56,6 @@ async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db))
|
||||
|
||||
if request.use_google_search:
|
||||
tools.append(types.Tool(google_search=types.GoogleSearch()))
|
||||
# Inject strong instruction to force search usage
|
||||
final_system_prompt += "\n\n[SYSTEM INSTRUCTION] You have access to Google Search. You MUST use it to verify any data or find the latest information before answering. Do not rely solely on your internal knowledge."
|
||||
|
||||
config = types.GenerateContentConfig(
|
||||
@ -67,12 +66,18 @@ async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db))
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
async def generate():
|
||||
full_response_text = ""
|
||||
# Mutable context to capture data during the stream for logging
|
||||
log_context = {
|
||||
"full_response_text": "",
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
"response_time": 0
|
||||
}
|
||||
|
||||
# Sync generator to run in threadpool
|
||||
def generate():
|
||||
grounding_data = None
|
||||
prompt_tokens = 0
|
||||
completion_tokens = 0
|
||||
total_tokens = 0
|
||||
|
||||
try:
|
||||
# Initialize Chat
|
||||
@ -82,24 +87,22 @@ async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db))
|
||||
config=config
|
||||
)
|
||||
|
||||
# Streaming Call
|
||||
# Blocking Streaming Call (will run in threadpool)
|
||||
response_stream = chat.send_message_stream(last_message)
|
||||
|
||||
for chunk in response_stream:
|
||||
if chunk.text:
|
||||
full_response_text += chunk.text
|
||||
log_context["full_response_text"] += chunk.text
|
||||
data = json.dumps({"type": "content", "content": chunk.text})
|
||||
yield f"{data}\n"
|
||||
|
||||
# Accumulate Metadata (will be complete in the last chunk usually, or usage_metadata)
|
||||
if chunk.usage_metadata:
|
||||
prompt_tokens = chunk.usage_metadata.prompt_token_count or 0
|
||||
completion_tokens = chunk.usage_metadata.candidates_token_count or 0
|
||||
total_tokens = chunk.usage_metadata.total_token_count or 0
|
||||
log_context["prompt_tokens"] = chunk.usage_metadata.prompt_token_count or 0
|
||||
log_context["completion_tokens"] = chunk.usage_metadata.candidates_token_count or 0
|
||||
log_context["total_tokens"] = chunk.usage_metadata.total_token_count or 0
|
||||
|
||||
if chunk.candidates and chunk.candidates[0].grounding_metadata:
|
||||
gm = chunk.candidates[0].grounding_metadata
|
||||
# We only expect one final grounding metadata object
|
||||
if gm.search_entry_point or gm.grounding_chunks:
|
||||
grounding_obj = {}
|
||||
if gm.search_entry_point:
|
||||
@ -115,14 +118,28 @@ async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db))
|
||||
"title": g_chunk.web.title
|
||||
}
|
||||
})
|
||||
|
||||
if gm.grounding_supports:
|
||||
grounding_obj["groundingSupports"] = []
|
||||
for support in gm.grounding_supports:
|
||||
support_obj = {
|
||||
"segment": {
|
||||
"startIndex": support.segment.start_index if support.segment else 0,
|
||||
"endIndex": support.segment.end_index if support.segment else 0,
|
||||
"text": support.segment.text if support.segment else ""
|
||||
},
|
||||
"groundingChunkIndices": list(support.grounding_chunk_indices)
|
||||
}
|
||||
grounding_obj["groundingSupports"].append(support_obj)
|
||||
|
||||
if gm.web_search_queries:
|
||||
grounding_obj["webSearchQueries"] = gm.web_search_queries
|
||||
|
||||
grounding_data = grounding_obj # Save final metadata
|
||||
grounding_data = grounding_obj
|
||||
|
||||
# End of stream actions
|
||||
# End of stream
|
||||
end_time = time.time()
|
||||
response_time = end_time - start_time
|
||||
log_context["response_time"] = end_time - start_time
|
||||
|
||||
# Send Metadata Chunk
|
||||
if grounding_data:
|
||||
@ -130,48 +147,57 @@ async def chat_with_ai(request: ChatRequest, db: AsyncSession = Depends(get_db))
|
||||
|
||||
# Send Usage Chunk
|
||||
usage_data = {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
"response_time": response_time
|
||||
"prompt_tokens": log_context["prompt_tokens"],
|
||||
"completion_tokens": log_context["completion_tokens"],
|
||||
"total_tokens": log_context["total_tokens"],
|
||||
"response_time": log_context["response_time"]
|
||||
}
|
||||
yield f"{json.dumps({'type': 'usage', 'usage': usage_data})}\n"
|
||||
|
||||
# Log to Database (Async)
|
||||
# Note: creating a new session here because the outer session might be closed or not thread-safe in generator?
|
||||
# Actually, we can use the injected `db` session but we must be careful.
|
||||
# Since we are inside an async generator, awaiting db calls is fine.
|
||||
try:
|
||||
# Reconstruct prompt for logging
|
||||
logged_prompt = ""
|
||||
if request.system_prompt:
|
||||
logged_prompt += f"System: {request.system_prompt}\n\n"
|
||||
for msg in request.messages: # Log full history including last message
|
||||
logged_prompt += f"{msg.role}: {msg.content}\n"
|
||||
|
||||
log_entry = LLMUsageLog(
|
||||
model=model_name,
|
||||
prompt=logged_prompt,
|
||||
response=full_response_text,
|
||||
response_time=response_time,
|
||||
used_google_search=request.use_google_search,
|
||||
session_id=request.session_id,
|
||||
stock_code=request.stock_code,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens
|
||||
)
|
||||
db.add(log_entry)
|
||||
await db.commit()
|
||||
except Exception as db_err:
|
||||
logger.error(f"Failed to log LLM usage: {db_err}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Stream generation error: {e}")
|
||||
err_json = json.dumps({"type": "error", "error": str(e)})
|
||||
yield f"{err_json}\n"
|
||||
|
||||
return StreamingResponse(generate(), media_type="application/x-ndjson")
|
||||
# Background Task for logging
|
||||
from starlette.background import BackgroundTask
|
||||
|
||||
async def log_usage(context: dict, req: ChatRequest):
|
||||
try:
|
||||
# Create a NEW session for the background task
|
||||
from app.database import SessionLocal
|
||||
async with SessionLocal() as session:
|
||||
# Reconstruct prompt for logging
|
||||
logged_prompt = ""
|
||||
if req.system_prompt:
|
||||
logged_prompt += f"System: {req.system_prompt}\n\n"
|
||||
for msg in req.messages:
|
||||
logged_prompt += f"{msg.role}: {msg.content}\n"
|
||||
|
||||
log_entry = LLMUsageLog(
|
||||
model=req.model or "gemini-3-flash-preview",
|
||||
prompt=logged_prompt,
|
||||
response=context["full_response_text"],
|
||||
response_time=context["response_time"],
|
||||
used_google_search=req.use_google_search,
|
||||
session_id=req.session_id,
|
||||
stock_code=req.stock_code,
|
||||
prompt_tokens=context["prompt_tokens"],
|
||||
completion_tokens=context["completion_tokens"],
|
||||
total_tokens=context["total_tokens"]
|
||||
)
|
||||
session.add(log_entry)
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Background logging failed: {e}")
|
||||
|
||||
# Note the usage of 'generate()' (sync) instead of 'generate' (async)
|
||||
# and we pass the background task to StreamingResponse
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="application/x-ndjson",
|
||||
background=BackgroundTask(log_usage, log_context, request)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Chat error: {e}")
|
||||
|
||||
@ -116,16 +116,16 @@ async def get_config_compat(db: AsyncSession = Depends(get_db)):
|
||||
# Default models list as requested
|
||||
import json
|
||||
default_models = [
|
||||
"gemini-2.0-flash",
|
||||
"gemini-2.5-flash",
|
||||
"gemini-3-flash-preview",
|
||||
"gemini-3-pro-preview"
|
||||
"gemini-3-pro-preview",
|
||||
"gemini-2.0-flash",
|
||||
"gemini-2.5-flash"
|
||||
]
|
||||
config_map["available_models"] = json.dumps(default_models)
|
||||
|
||||
# Ensure defaults for other keys exist if not in DB
|
||||
defaults = {
|
||||
"ai_model": "gemini-2.5-flash",
|
||||
"ai_model": "gemini-3-flash-preview",
|
||||
"data_source_cn": "Tushare",
|
||||
"data_source_hk": "iFinD",
|
||||
"data_source_us": "iFinD",
|
||||
@ -144,12 +144,12 @@ async def get_config_compat(db: AsyncSession = Depends(get_db)):
|
||||
print(f"Config read error (using defaults): {e}")
|
||||
import json
|
||||
return {
|
||||
"ai_model": "gemini-2.5-flash",
|
||||
"ai_model": "gemini-3-flash-preview",
|
||||
"available_models": json.dumps([
|
||||
"gemini-2.0-flash",
|
||||
"gemini-2.5-flash",
|
||||
"gemini-3-flash-preview",
|
||||
"gemini-3-pro-preview"
|
||||
"gemini-3-pro-preview",
|
||||
"gemini-2.0-flash",
|
||||
"gemini-2.5-flash"
|
||||
]),
|
||||
"data_source_cn": "Tushare",
|
||||
"data_source_hk": "iFinD",
|
||||
@ -193,7 +193,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class StockSearchRequest(BaseModel):
|
||||
query: str
|
||||
model: str = "gemini-2.0-flash" # 支持前端传入模型参数
|
||||
model: str = "gemini-3-flash-preview" # 支持前端传入模型参数
|
||||
|
||||
class StockSearchResponse(BaseModel):
|
||||
market: str
|
||||
@ -238,8 +238,8 @@ async def search_stock(request: StockSearchRequest):
|
||||
# 启用 Google Search Grounding
|
||||
grounding_tool = types.Tool(google_search=types.GoogleSearch())
|
||||
|
||||
# 使用请求中的模型,默认为 gemini-2.5-flash
|
||||
model_name = request.model or "gemini-2.5-flash"
|
||||
# 使用请求中的模型,默认为 gemini-3-flash-preview
|
||||
model_name = request.model or "gemini-3-flash-preview"
|
||||
|
||||
logger.info(f"🤖 [搜索-LLM] 调用 {model_name} 进行股票搜索")
|
||||
llm_start = time.time()
|
||||
|
||||
@ -131,7 +131,7 @@ async def process_analysis_steps(report_id: int, company_name: str, symbol: str,
|
||||
|
||||
# Get AI model from settings
|
||||
model_setting = await db.get(Setting, "AI_MODEL")
|
||||
model_name = model_setting.value if model_setting else "gemini-2.0-flash"
|
||||
model_name = model_setting.value if model_setting else "gemini-3-flash-preview"
|
||||
|
||||
# Prepare all API calls concurrently
|
||||
async def process_section(key: str, name: str):
|
||||
|
||||
@ -16,12 +16,151 @@ import { Switch } from "@/components/ui/switch"
|
||||
import ReactMarkdown from "react-markdown"
|
||||
import remarkGfm from "remark-gfm"
|
||||
|
||||
// Helper function to insert citations into text
|
||||
function addCitations(text: string, groundingMetadata: any) {
|
||||
if (!groundingMetadata || !groundingMetadata.groundingSupports || !groundingMetadata.groundingChunks) {
|
||||
return text;
|
||||
}
|
||||
|
||||
const supports = groundingMetadata.groundingSupports;
|
||||
const chunks = groundingMetadata.groundingChunks;
|
||||
|
||||
// 1. Collect all citations with their original positions
|
||||
interface Citation {
|
||||
index: number;
|
||||
text: string;
|
||||
}
|
||||
const citations: Citation[] = [];
|
||||
|
||||
supports.forEach((support: any) => {
|
||||
const endIndex = support.segment.endIndex;
|
||||
// Optimization: Ignore supports that cover the entire text (or close to it)
|
||||
// These are usually "general" grounding for the whole answer and cause a massive pile of citations at the end.
|
||||
const startIndex = support.segment.startIndex || 0;
|
||||
const textLen = text.length;
|
||||
|
||||
// Aggressive filter: If segment starts near the beginning and ends near the end, skip it.
|
||||
// This targets the "summary grounding" which lists all sources used in the answer.
|
||||
if (startIndex < 5 && endIndex >= textLen - 5) {
|
||||
return;
|
||||
}
|
||||
|
||||
const chunkIndices = support.groundingChunkIndices;
|
||||
if (endIndex === undefined || !chunkIndices || chunkIndices.length === 0) return;
|
||||
|
||||
const label = chunkIndices.map((idx: number) => {
|
||||
const chunk = chunks[idx];
|
||||
if (chunk?.web?.uri) {
|
||||
// Use parenthesis style (1), (2)
|
||||
return `[(${idx + 1})](${chunk.web.uri})`;
|
||||
}
|
||||
return "";
|
||||
}).join("");
|
||||
|
||||
if (label) {
|
||||
citations.push({ index: endIndex, text: label });
|
||||
}
|
||||
});
|
||||
|
||||
// 2. Identify numeric ranges to avoid splitting
|
||||
const forbiddenRanges: { start: number, end: number }[] = [];
|
||||
// Matches numbers like integers, decimals, comma-separated (simple)
|
||||
// Heuristic: digits sticking together or with dot/comma inside
|
||||
const regex = /\d+(?:[.,]\d+)*/g;
|
||||
let match;
|
||||
while ((match = regex.exec(text)) !== null) {
|
||||
forbiddenRanges.push({ start: match.index, end: match.index + match[0].length });
|
||||
}
|
||||
|
||||
// 3. Move citations out of forbidden ranges
|
||||
citations.forEach(cit => {
|
||||
for (const range of forbiddenRanges) {
|
||||
if (cit.index > range.start && cit.index < range.end) {
|
||||
// Move to END of number is safer for reading
|
||||
cit.index = range.end;
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// 4. Sort citations by index
|
||||
citations.sort((a, b) => a.index - b.index);
|
||||
|
||||
// 5. Reconstruct string with Deduplication
|
||||
let result = "";
|
||||
let curIndex = 0;
|
||||
|
||||
let i = 0;
|
||||
while (i < citations.length) {
|
||||
const batchIndex = citations[i].index;
|
||||
// Append text up to this index
|
||||
if (batchIndex > curIndex) {
|
||||
result += text.slice(curIndex, batchIndex);
|
||||
curIndex = batchIndex;
|
||||
}
|
||||
|
||||
// Collect ALL citation chunks at this exact position
|
||||
const uniqueChunks = new Set<string>();
|
||||
// Iterate through all citations that fall at this 'batchIndex'
|
||||
while (i < citations.length && citations[i].index === batchIndex) {
|
||||
// citations[i].text might be multiple emojis like 1️⃣2️⃣ if originally from one support
|
||||
// But our 'citations' array is built from supports.
|
||||
// Ideally we should have pushed individual chunk indices to 'citations' array to make dedup easier.
|
||||
// But currently 'citations' items are pre-rendered strings.
|
||||
// Let's refine step 1 to push raw data instead if we want robust dedup,
|
||||
// OR just dedup the string representation if they are identical?
|
||||
// Actually, the issue is that multiple supports (sentence vs paragraph) might point to same chunk ID [1].
|
||||
// So we need to parse the pre-rendered string or change Step 1.
|
||||
|
||||
// Let's rely on the parsing of the text content we built in step 1.
|
||||
// The text is like [1️⃣](uri). We can use that as the unique key.
|
||||
// However, a single support might start with MULTIPLE chunks.
|
||||
|
||||
// Hacky but effective: Just add the text to Set.
|
||||
// If "source 1" is cited twice at this location, it will be deduped.
|
||||
// But if "source 1 and 2" are cited, and then "source 1" is cited again, we need granular dedup.
|
||||
|
||||
// To do this properly, let's look at how we built `citations`.
|
||||
// We just appended strings. Let's fix this block to be smarter.
|
||||
|
||||
// Simple string dedup for now will solve the "exact duplicate" issue (same chunk, same position).
|
||||
// It won't solve "Source 1" appearing in "Source 1,2" group.
|
||||
// But visual flooding is usually exact duplicates.
|
||||
|
||||
// BETTER APPROACH: We already have the raw data in the loop above? No, we are in a new loop.
|
||||
// Let's just output distinct strings at this index.
|
||||
|
||||
// We can extract all `[...](...)` patterns from the text and dedup them.
|
||||
const rawText = citations[i].text;
|
||||
const linkRegex = /\[.*?\]\(.*?\)/g;
|
||||
let linkMatch;
|
||||
while ((linkMatch = linkRegex.exec(rawText)) !== null) {
|
||||
uniqueChunks.add(linkMatch[0]);
|
||||
}
|
||||
i++;
|
||||
}
|
||||
|
||||
// Render unique chunks
|
||||
result += Array.from(uniqueChunks).join("");
|
||||
}
|
||||
// Append remaining text
|
||||
if (curIndex < text.length) {
|
||||
result += text.slice(curIndex);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
interface Message {
|
||||
role: "user" | "model"
|
||||
content: string
|
||||
groundingMetadata?: {
|
||||
searchEntryPoint?: { renderedContent: string }
|
||||
groundingChunks?: Array<{ web?: { uri: string; title: string } }>
|
||||
groundingSupports?: Array<{
|
||||
segment: { startIndex: number; endIndex: number; text: string }
|
||||
groundingChunkIndices: number[]
|
||||
}>
|
||||
webSearchQueries?: string[]
|
||||
}
|
||||
searchStatus?: {
|
||||
@ -46,12 +185,12 @@ const DEFAULT_ROLES: RoleConfig[] = [
|
||||
|
||||
export function AiDiscussionView({ companyName, symbol, market }: { companyName: string, symbol: string, market: string }) {
|
||||
const [roles, setRoles] = useState<RoleConfig[]>([])
|
||||
const [availableModels, setAvailableModels] = useState<string[]>(["gemini-2.0-flash", "gemini-1.5-pro", "gemini-1.5-flash"])
|
||||
const [availableModels, setAvailableModels] = useState<string[]>(["gemini-3-flash-preview", "gemini-3-pro-preview", "gemini-2.0-flash", "gemini-1.5-pro"])
|
||||
const [questionLibrary, setQuestionLibrary] = useState<any>(null)
|
||||
|
||||
// Left Chat State
|
||||
const [leftRole, setLeftRole] = useState("")
|
||||
const [leftModel, setLeftModel] = useState("gemini-2.0-flash")
|
||||
const [leftModel, setLeftModel] = useState("gemini-3-flash-preview")
|
||||
const [leftMessages, setLeftMessages] = useState<Message[]>([])
|
||||
const [leftInput, setLeftInput] = useState("")
|
||||
const [leftLoading, setLeftLoading] = useState(false)
|
||||
@ -60,7 +199,7 @@ export function AiDiscussionView({ companyName, symbol, market }: { companyName:
|
||||
|
||||
// Right Chat State
|
||||
const [rightRole, setRightRole] = useState("")
|
||||
const [rightModel, setRightModel] = useState("gemini-2.0-flash")
|
||||
const [rightModel, setRightModel] = useState("gemini-3-flash-preview")
|
||||
const [rightMessages, setRightMessages] = useState<Message[]>([])
|
||||
const [rightInput, setRightInput] = useState("")
|
||||
const [rightLoading, setRightLoading] = useState(false)
|
||||
@ -533,7 +672,7 @@ function ChatPane({
|
||||
)
|
||||
}}
|
||||
>
|
||||
{m.content}
|
||||
{addCitations(m.content, m.groundingMetadata)}
|
||||
</ReactMarkdown>
|
||||
)}
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ interface AnalysisTriggerProps {
|
||||
export function AnalysisTrigger({
|
||||
companyId,
|
||||
dataSource,
|
||||
model = "gemini-2.0-flash",
|
||||
model = "gemini-3-flash-preview",
|
||||
onAnalysisComplete
|
||||
}: AnalysisTriggerProps) {
|
||||
const [analysisId, setAnalysisId] = useState<number | null>(null)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user