- Sync updates for provider services (AlphaVantage, Finnhub, YFinance, Tushare)
- Update Frontend components and pages for recent config changes
- Update API Gateway and Registry
- Include design docs and tasks status
use crate::error::ProviderError;
use tracing::{debug, error, info, warn};
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
use serde_json::{json, Value};
use futures_util::{Stream, StreamExt};
use std::pin::Pin;
use std::time::Duration;

pub struct LlmClient {
    http_client: reqwest::Client,
    api_base_url: String,
    api_key: String,
    model: String,
    timeout: Duration,
}

impl LlmClient {
    pub fn new(api_url: String, api_key: String, model: String, timeout_secs: Option<u64>) -> Self {
        let api_url = api_url.trim();

        // Normalize base URL
        let base_url = if api_url.ends_with("/chat/completions") {
            api_url.trim_end_matches("/chat/completions").trim_end_matches('/').to_string()
        } else if api_url.ends_with("/completions") {
            api_url.trim_end_matches("/completions").trim_end_matches('/').to_string()
        } else {
            api_url.trim_end_matches('/').to_string()
        };

        debug!("Initializing LlmClient with base_url: {}", base_url);

        // Create a reusable http client
        let http_client = reqwest::Client::builder()
            .build()
            .unwrap_or_default();

        Self {
            http_client,
            api_base_url: base_url,
            api_key,
            model,
            timeout: Duration::from_secs(timeout_secs.unwrap_or(300)), // Default 5 min timeout for full generation
        }
    }

    /// Unified Streaming Interface
    /// Always uses streaming under the hood.
    pub async fn stream_text(
        &self,
        prompt: String,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<String, ProviderError>> + Send>>, ProviderError> {
        let url = format!("{}/chat/completions", self.api_base_url);
        info!("Sending streaming request to LLM model: {} at URL: {}", self.model, url);

        let mut headers = HeaderMap::new();
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));

        let auth_value = format!("Bearer {}", self.api_key);
        if let Ok(val) = HeaderValue::from_str(&auth_value) {
            headers.insert(AUTHORIZATION, val);
        }

        let body = json!({
            "model": self.model,
            "messages": [
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "temperature": 0.7,
            "stream": true // FORCE STREAMING
        });

        let request_builder = self.http_client.post(&url)
            .headers(headers)
            .json(&body)
            .timeout(self.timeout); // Apply hard timeout to request

        info!("Dispatching HTTP request to LLM provider...");
        let start_time = std::time::Instant::now();
        let response = request_builder.send()
            .await
            .map_err(|e| ProviderError::LlmApi(format!("Network request failed (timeout {}s?): {}", self.timeout.as_secs(), e)))?;

        let duration = start_time.elapsed();
        info!("Received response headers from LLM provider. Status: {}. Time taken: {:?}", response.status(), duration);

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            error!("LLM API returned error status: {}. Body: {}", status, error_text);
            return Err(ProviderError::LlmApi(format!("LLM API error ({}): {}", status, error_text)));
        }

        // Convert the response body into a stream of raw UTF-8 string chunks (SSE parsing happens below)
        let stream = response.bytes_stream()
            .map(|result| {
                match result {
                    Ok(bytes) => {
                        // DEBUG log for bytes received - be careful with volume
                        // tracing::debug!("Received bytes chunk: {} bytes", bytes.len());
                        let chunk_str = String::from_utf8_lossy(&bytes);
                        Ok(chunk_str.to_string())
                    },
                    Err(e) => {
                        error!("Stream read error: {}", e);
                        Err(ProviderError::LlmApi(format!("Stream read error: {}", e)))
                    }
                }
            });

        // Process the raw stream into clean content chunks.
        // This is a simplified SSE parser; in production, use a proper SSE crate or more robust parsing logic.
        // For now, we assume the standard OpenAI format: `data: {...}\n\n` (a small parsing sketch is
        // included in a test module at the end of this file).
        let processed_stream = async_stream::try_stream! {
            let mut buffer = String::new();
            let mut chunk_count = 0;

            for await chunk_res in stream {
                let chunk = chunk_res?;
                chunk_count += 1;
                if chunk_count % 10 == 0 {
                    // Log every 10th raw chunk to show the stream is still alive without spamming the logs
info!("Stream processing alive. Received raw chunk #{}", chunk_count);
|
|
}
|
|
|
|
buffer.push_str(&chunk);
|
|
|
|
while let Some(line_end) = buffer.find('\n') {
|
|
let line = buffer[..line_end].trim().to_string();
|
|
buffer = buffer[line_end + 1..].to_string();
|
|
|
|
if line.starts_with("data: ") {
|
|
let data_str = &line["data: ".len()..];
|
|
if data_str == "[DONE]" {
|
|
info!("Received [DONE] signal from LLM stream.");
|
|
break;
|
|
}
|
|
|
|
if let Ok(json) = serde_json::from_str::<Value>(data_str) {
|
|
if let Some(content) = json["choices"][0]["delta"]["content"].as_str() {
|
|
yield content.to_string();
|
|
}
|
|
} else {
|
|
// Only log parse errors if it's not empty/keepalive
|
|
if !data_str.is_empty() {
|
|
warn!("Failed to parse JSON from data line: {}", data_str);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
info!("Stream finished. Total raw chunks: {}", chunk_count);
|
|
};
|
|
|
|
Ok(Box::pin(processed_stream))
|
|
}
|
|
|
|
/// Wrapper for non-streaming usage (One-shot)
|
|
/// Consumes the stream and returns the full string.
|
|
pub async fn generate_text(&self, prompt: String) -> Result<String, ProviderError> {
|
|
let mut stream = self.stream_text(prompt).await?;
|
|
let mut full_response = String::new();
|
|
|
|
while let Some(chunk_res) = stream.next().await {
|
|
match chunk_res {
|
|
Ok(chunk) => full_response.push_str(&chunk),
|
|
Err(e) => return Err(e),
|
|
}
|
|
}
|
|
|
|
if full_response.is_empty() {
|
|
return Err(ProviderError::LlmApi("LLM returned empty response".to_string()));
|
|
}
|
|
|
|
Ok(full_response)
|
|
}
|
|
}
|
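
// ---------------------------------------------------------------------------
// Illustrative usage sketch: how a caller might drive the client, assuming an
// OpenAI-compatible `/chat/completions` endpoint and an async runtime such as
// Tokio around it. The URL, API key, model name, and timeout below are
// placeholders, not values from the original service configuration.
// ---------------------------------------------------------------------------
#[allow(dead_code)]
async fn usage_sketch() -> Result<(), ProviderError> {
    let client = LlmClient::new(
        "https://llm.example.com/v1".to_string(), // placeholder base URL
        "dummy-api-key".to_string(),              // placeholder credential
        "example-model".to_string(),              // placeholder model id
        Some(120),                                // optional timeout override (seconds)
    );

    // Streaming: consume content deltas as they arrive.
    let mut stream = client.stream_text("Say hello".to_string()).await?;
    while let Some(chunk) = stream.next().await {
        print!("{}", chunk?);
    }

    // One-shot: collect the full completion into a single string.
    let full = client.generate_text("Say hello".to_string()).await?;
    println!("{}", full);

    Ok(())
}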
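
// ---------------------------------------------------------------------------
// Illustrative sketch of the SSE payload format assumed by `stream_text`:
// these tests exercise the same `data: {...}` extraction path used above on a
// hand-written, made-up chunk in the OpenAI streaming format.
// ---------------------------------------------------------------------------
#[cfg(test)]
mod sse_parsing_sketch {
    use serde_json::Value;

    #[test]
    fn extracts_delta_content_from_openai_style_data_line() {
        // A typical SSE line from an OpenAI-compatible streaming endpoint (sample data).
        let line = r#"data: {"choices":[{"delta":{"content":"Hello"},"index":0}]}"#;
        assert!(line.starts_with("data: "));

        let data_str = &line["data: ".len()..];
        let json: Value = serde_json::from_str(data_str).expect("valid JSON payload");

        // Same lookup path as the parser in `stream_text`.
        let content = json["choices"][0]["delta"]["content"].as_str();
        assert_eq!(content, Some("Hello"));
    }

    #[test]
    fn done_marker_is_not_json() {
        // The terminating marker is plain text, not JSON, so the parser handles it before parsing.
        assert!(serde_json::from_str::<Value>("[DONE]").is_err());
    }
}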