// Fundamental_Analysis/services/report-generator-service/src/llm_client.rs

use crate::error::ProviderError;
use tracing::{debug, error, info, warn};
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
use serde_json::{json, Value};
use futures_util::{Stream, StreamExt};
use std::pin::Pin;
use std::time::Duration;
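
/// Client for an OpenAI-compatible chat-completions endpoint.
///
/// Holds a reusable `reqwest` client together with the connection settings:
/// base URL, API key, model name, and a hard per-request timeout.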
pub struct LlmClient {
    http_client: reqwest::Client,
    api_base_url: String,
    api_key: String,
    model: String,
    timeout: Duration,
}

impl LlmClient {
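    /// Builds a client from an API URL, key, model name, and optional timeout in seconds.
    ///
    /// The URL may be either a base URL or a full `/chat/completions` endpoint;
    /// either way it is normalized to a base URL.
    ///
    /// A minimal usage sketch; the endpoint, key source, and model below are
    /// illustrative placeholders, not values from this repository's configuration:
    ///
    /// ```ignore
    /// let client = LlmClient::new(
    ///     "https://api.example.com/v1/chat/completions".to_string(), // normalized to .../v1
    ///     std::env::var("LLM_API_KEY").unwrap_or_default(),          // hypothetical env var
    ///     "some-model-name".to_string(),
    ///     Some(120), // 2-minute timeout instead of the 300 s default
    /// );
    /// ```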
    pub fn new(api_url: String, api_key: String, model: String, timeout_secs: Option<u64>) -> Self {
        let api_url = api_url.trim();
        // Normalize the base URL: strip a trailing "/chat/completions" or "/completions"
        // segment plus any trailing slash, so callers may pass either form.
        let base_url = if api_url.ends_with("/chat/completions") {
            api_url.trim_end_matches("/chat/completions").trim_end_matches('/').to_string()
        } else if api_url.ends_with("/completions") {
            api_url.trim_end_matches("/completions").trim_end_matches('/').to_string()
        } else {
            api_url.trim_end_matches('/').to_string()
        };
        debug!("Initializing LlmClient with base_url: {}", base_url);
        // Create a reusable HTTP client; fall back to the default client if the builder fails.
        let http_client = reqwest::Client::builder()
            .build()
            .unwrap_or_default();
        Self {
            http_client,
            api_base_url: base_url,
            api_key,
            model,
            timeout: Duration::from_secs(timeout_secs.unwrap_or(300)), // Default 5-minute timeout for full generation
        }
    }

    /// Unified streaming interface.
    /// Always uses streaming under the hood.
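    ///
    /// A consumption sketch, assuming an async runtime and the `StreamExt`
    /// import above; the prompt text is illustrative:
    ///
    /// ```ignore
    /// let mut stream = client.stream_text("Summarize the latest filing.".to_string()).await?;
    /// while let Some(chunk) = stream.next().await {
    ///     // Each Ok item is one decoded content delta from the SSE stream.
    ///     print!("{}", chunk?);
    /// }
    /// ```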
    pub async fn stream_text(
        &self,
        prompt: String,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<String, ProviderError>> + Send>>, ProviderError> {
        let url = format!("{}/chat/completions", self.api_base_url);
        info!("Sending streaming request to LLM model: {} at URL: {}", self.model, url);
        let mut headers = HeaderMap::new();
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
        // Note: if the API key contains characters that are invalid in an HTTP header,
        // the Authorization header is silently omitted and the request goes out unauthenticated.
        let auth_value = format!("Bearer {}", self.api_key);
        if let Ok(val) = HeaderValue::from_str(&auth_value) {
            headers.insert(AUTHORIZATION, val);
        }
        let body = json!({
            "model": self.model,
            "messages": [
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "temperature": 0.7,
            "stream": true // Force streaming
        });
        let request_builder = self.http_client.post(&url)
            .headers(headers)
            .json(&body)
            .timeout(self.timeout); // Apply a hard timeout to the whole request
        info!("Dispatching HTTP request to LLM provider...");
        let start_time = std::time::Instant::now();
        let response = request_builder.send()
            .await
            .map_err(|e| ProviderError::LlmApi(format!("Network request failed (timeout {}s?): {}", self.timeout.as_secs(), e)))?;
        let duration = start_time.elapsed();
        info!("Received response headers from LLM provider. Status: {}. Time taken: {:?}", response.status(), duration);
        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            error!("LLM API returned error status: {}. Body: {}", status, error_text);
            return Err(ProviderError::LlmApi(format!("LLM API error ({}): {}", status, error_text)));
        }
        // Convert the raw byte stream into UTF-8 text chunks.
        // Note: from_utf8_lossy will mangle a multi-byte character that happens to be
        // split across two network chunks; a more robust parser would buffer raw bytes.
        let stream = response.bytes_stream()
            .map(|result| {
                match result {
                    Ok(bytes) => {
                        // DEBUG log for bytes received - be careful with volume
                        // tracing::debug!("Received bytes chunk: {} bytes", bytes.len());
                        let chunk_str = String::from_utf8_lossy(&bytes);
                        Ok(chunk_str.to_string())
                    },
                    Err(e) => {
                        error!("Stream read error: {}", e);
                        Err(ProviderError::LlmApi(format!("Stream read error: {}", e)))
                    }
                }
            });
        // Process the raw text stream into clean content chunks.
        // This is a simplified SSE parser. In production, use a proper SSE crate or more robust parsing logic.
        // For now, we assume the standard OpenAI format: "data: {...}\n\n".
        let processed_stream = async_stream::try_stream! {
            let mut buffer = String::new();
            let mut chunk_count = 0;
            for await chunk_res in stream {
                let chunk = chunk_res?;
                chunk_count += 1;
                if chunk_count % 10 == 0 {
                    // Log every 10th raw chunk to show liveness without spamming
                    info!("Stream processing alive. Received raw chunk #{}", chunk_count);
                }
                buffer.push_str(&chunk);
                // Split the buffer into complete lines; anything after the last '\n'
                // stays in the buffer until the next chunk arrives.
                while let Some(line_end) = buffer.find('\n') {
                    let line = buffer[..line_end].trim().to_string();
                    buffer = buffer[line_end + 1..].to_string();
                    if line.starts_with("data: ") {
                        let data_str = &line["data: ".len()..];
                        if data_str == "[DONE]" {
                            info!("Received [DONE] signal from LLM stream.");
                            // This only exits the inner line loop; the outer loop ends
                            // once the provider closes the connection.
                            break;
                        }
                        if let Ok(json) = serde_json::from_str::<Value>(data_str) {
                            if let Some(content) = json["choices"][0]["delta"]["content"].as_str() {
                                yield content.to_string();
                            }
                        } else {
                            // Only log parse errors for non-empty payloads (skip keepalives)
                            if !data_str.is_empty() {
                                warn!("Failed to parse JSON from data line: {}", data_str);
                            }
                        }
                    }
                }
            }
            info!("Stream finished. Total raw chunks: {}", chunk_count);
        };
        Ok(Box::pin(processed_stream))
    }

    /// Wrapper for non-streaming usage (one-shot).
    /// Consumes the stream and returns the full string.
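    ///
    /// A sketch of the call, assuming an async runtime is already running;
    /// the prompt is a placeholder:
    ///
    /// ```ignore
    /// let report = client.generate_text("Write a brief fundamentals summary.".to_string()).await?;
    /// println!("{report}");
    /// ```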
    pub async fn generate_text(&self, prompt: String) -> Result<String, ProviderError> {
        let mut stream = self.stream_text(prompt).await?;
        let mut full_response = String::new();
        while let Some(chunk_res) = stream.next().await {
            match chunk_res {
                Ok(chunk) => full_response.push_str(&chunk),
                Err(e) => return Err(e),
            }
        }
        if full_response.is_empty() {
            return Err(ProviderError::LlmApi("LLM returned empty response".to_string()));
        }
        Ok(full_response)
    }
}
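
// A minimal integration-test sketch for this client. It assumes `tokio` (with the
// `macros` feature) is available as a dev-dependency and that LLM_API_URL, LLM_API_KEY,
// and LLM_MODEL are hypothetical environment variables supplied at test time; the test
// is #[ignore]d because it requires a live, OpenAI-compatible endpoint.
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    #[ignore]
    async fn generate_text_returns_non_empty_response() {
        let client = LlmClient::new(
            std::env::var("LLM_API_URL").expect("LLM_API_URL not set"),
            std::env::var("LLM_API_KEY").expect("LLM_API_KEY not set"),
            std::env::var("LLM_MODEL").expect("LLM_MODEL not set"),
            Some(60),
        );
        let text = client
            .generate_text("Reply with a single word.".to_string())
            .await
            .expect("LLM call failed");
        assert!(!text.is_empty());
    }
}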