Fundamental_Analysis/services/api-gateway/src/api.rs
Lv, Qi 03b53aed71 feat: Refactor Analysis Context Mechanism and Generic Worker
- Implemented Unified Context Mechanism (Task 20251127):
  - Decoupled intent (Module) from resolution (Orchestrator).
  - Added ContextResolver for resolving input bindings (Manual Glob/Auto LLM).
  - Added IOBinder for managing physical paths.
  - Updated GenerateReportCommand to support explicit input bindings and output paths.

- Refactored Report Worker to Generic Execution (Task 20251128):
  - Removed hardcoded financial DTOs and specific formatting logic.
  - Implemented Generic YAML-based context assembly for better LLM readability.
  - Added detailed execution tracing (Sidecar logs).
  - Fixed input data collision bug by using full paths as context keys.

- Updated Tushare Provider to support dynamic output paths.
- Updated Common Contracts with new configuration models.
2025-11-28 20:11:17 +08:00


use crate::error::Result;
use crate::state::AppState;
use axum::{
Router,
extract::{Path, Query, State},
http::StatusCode,
response::{IntoResponse, Json},
routing::{get, post},
};
use common_contracts::config_models::{
AnalysisTemplateSets, DataSourceProvider,
DataSourcesConfig, LlmProvider, LlmProvidersConfig,
};
use common_contracts::messages::GenerateReportCommand;
use common_contracts::observability::{TaskProgress, ObservabilityTaskStatus};
use common_contracts::registry::ProviderMetadata;
use common_contracts::subjects::{NatsSubject, SubjectMessage};
use common_contracts::symbol_utils::{CanonicalSymbol, Market};
use futures_util::future::join_all;
use futures_util::stream::StreamExt;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tokio::try_join;
use tracing::{error, info, warn};
use uuid::Uuid;
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
use service_kit::api_dto;
mod registry;
// --- Request/Response Structs ---
#[api_dto]
pub struct DataRequest {
pub symbol: String,
pub market: Option<String>,
pub template_id: String, // Required: the workflow cannot start without a template
}
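// Illustrative request body for POST /api/v1/workflow/start (all values are examples only):
// { "symbol": "AAPL", "market": "US", "template_id": "default" }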
#[api_dto]
pub struct RequestAcceptedResponse {
pub request_id: Uuid,
pub symbol: String,
pub market: String,
}
#[derive(Deserialize)]
pub struct AnalysisRequest {
pub template_id: String,
}
#[derive(Deserialize)]
pub struct AnalysisResultQuery {
pub symbol: String,
}
#[api_dto]
pub struct SymbolResolveRequest {
pub symbol: String,
pub market: Option<String>,
}
#[api_dto]
pub struct SymbolResolveResponse {
pub symbol: String,
pub market: String,
}
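// The request and response share the { symbol, market } shape; the response carries the
// canonicalized symbol together with the explicitly supplied or inferred market.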
// --- Dynamic Schema Structs (replaced by the dynamic registry) ---
// The legacy endpoint /configs/data_sources/schema has been removed;
// the frontend should now use /registry/providers to obtain provider metadata.
// --- Router Definition ---
pub fn create_router(app_state: AppState) -> Router {
use crate::openapi::ApiDoc;
let mut router = Router::new()
.route("/health", get(health_check))
.route("/tasks/{request_id}", get(get_task_progress))
// Context Inspector Proxies
.route("/api/context/{req_id}/tree/{commit_hash}", get(proxy_context_tree))
.route("/api/context/{req_id}/blob/{commit_hash}/{*path}", get(proxy_context_blob))
.route("/api/context/{req_id}/diff/{from_commit}/{to_commit}", get(proxy_context_diff))
.nest("/api/v1", create_v1_router())
.with_state(app_state);
// Mount Swagger UI
router = router.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()));
router
}
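/// [POST /api/v1/mock/chat/completions]
/// Returns a canned OpenAI-style streaming (SSE) chat completion; used as a mock LLM backend in E2E tests.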
async fn mock_chat_completion() -> impl IntoResponse {
use axum::http::header;
let body = "data: {\"id\":\"chatcmpl-mock\",\"object\":\"chat.completion.chunk\",\"created\":1677652288,\"model\":\"gpt-3.5-turbo-0613\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"This is a mocked response.\"},\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-mock\",\"object\":\"chat.completion.chunk\",\"created\":1677652288,\"model\":\"gpt-3.5-turbo-0613\",\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"stop\"}]}\n\ndata: [DONE]\n\n";
(StatusCode::OK, [(header::CONTENT_TYPE, "text/event-stream")], body)
}
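/// [GET /api/v1/mock/models]
/// Returns a static model-list payload (a single Gemini Flash entry) for the mock LLM backend.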
async fn mock_models() -> impl IntoResponse {
use axum::http::header;
let body = serde_json::json!({
"data": [
{
"id": "google/gemini-flash-1.5",
"name": "Gemini Flash 1.5",
"pricing": {
"prompt": "0",
"completion": "0"
},
"context_length": 32000,
"architecture": {
"modality": "text+image->text",
"tokenizer": "Gemini",
"instruct_type": null
},
"top_provider": {
"max_completion_tokens": null,
"is_moderated": false
},
"per_request_limits": null
}
]
});
(StatusCode::OK, [(header::CONTENT_TYPE, "application/json")], Json(body))
}
use common_contracts::messages::{StartWorkflowCommand, SyncStateCommand, WorkflowEvent};
fn create_v1_router() -> Router<AppState> {
Router::new()
// Mock LLM endpoints for E2E tests
.route("/mock/chat/completions", post(mock_chat_completion))
.route("/mock/models", get(mock_models))
// New Workflow API
.route("/workflow/start", post(start_workflow))
.route("/workflow/events/{request_id}", get(workflow_events_stream))
.route("/workflow/{request_id}/graph", get(get_workflow_graph_proxy))
// Tools
.route("/tools/resolve-symbol", post(resolve_symbol))
// Legacy routes (kept for compatibility; slated for removal)
.route("/data-requests", post(trigger_data_fetch_legacy))
.route("/session-data/{request_id}", get(proxy_get_session_data))
.route("/analysis-results/stream", get(proxy_analysis_stream))
.route(
"/analysis-requests/{symbol}",
post(trigger_analysis_generation),
)
.route("/analysis-results", get(get_analysis_results_by_symbol))
.route("/companies/{symbol}/profile", get(get_company_profile))
.route(
"/market-data/financial-statements/{symbol}",
get(get_financials_by_symbol),
)
// Configuration routes (unchanged)
.route(
"/configs/llm_providers",
get(get_llm_providers_config).put(update_llm_providers_config),
)
.route(
"/configs/analysis_template_sets",
get(get_analysis_template_sets).put(update_analysis_template_sets),
)
.route(
"/configs/data_sources",
get(get_data_sources_config).put(update_data_sources_config),
)
.route("/configs/test", post(test_data_source_config))
.route("/configs/llm/test", post(test_llm_config))
.route("/config", get(get_legacy_system_config))
.route("/discover-models/{provider_id}", get(discover_models))
.route("/discover-models", post(discover_models_preview))
.route("/registry/register", post(registry::register_service))
.route("/registry/heartbeat", post(registry::heartbeat))
.route("/registry/deregister", post(registry::deregister_service))
.route("/registry/providers", get(get_registered_providers))
}
// --- Legacy Config Compatibility ---
#[derive(Serialize, Default)]
struct LegacyDatabaseConfig {
url: Option<String>,
}
#[derive(Serialize, Default)]
struct LegacyNewApiConfig {
provider_id: Option<String>,
provider_name: Option<String>,
api_key: Option<String>,
base_url: Option<String>,
model_count: usize,
}
#[derive(Serialize, Default)]
struct LegacyDataSourceConfig {
provider: String,
api_key: Option<String>,
api_url: Option<String>,
enabled: bool,
}
#[derive(Serialize)]
struct LegacySystemConfigResponse {
database: LegacyDatabaseConfig,
new_api: LegacyNewApiConfig,
data_sources: HashMap<String, LegacyDataSourceConfig>,
llm_providers: LlmProvidersConfig,
analysis_template_sets: AnalysisTemplateSets,
}
async fn get_legacy_system_config(State(state): State<AppState>) -> Result<impl IntoResponse> {
let persistence = state.persistence_client.clone();
let (llm_providers, analysis_template_sets, data_sources) = try_join!(
persistence.get_llm_providers_config(),
persistence.get_analysis_template_sets(),
persistence.get_data_sources_config()
)?;
let new_api = derive_primary_provider(&llm_providers);
let ds_map = project_data_sources(data_sources);
let database_url = std::env::var("DATABASE_URL").ok();
let response = LegacySystemConfigResponse {
database: LegacyDatabaseConfig { url: database_url },
new_api,
data_sources: ds_map,
llm_providers,
analysis_template_sets,
};
Ok(Json(response))
}
fn derive_primary_provider(providers: &LlmProvidersConfig) -> LegacyNewApiConfig {
const PREFERRED_IDS: [&str; 3] = ["new_api", "openrouter", "default"];
let mut selected_id: Option<String> = None;
let mut selected_provider: Option<&LlmProvider> = None;
for preferred in PREFERRED_IDS {
if let Some(provider) = providers.get(preferred) {
selected_id = Some(preferred.to_string());
selected_provider = Some(provider);
break;
}
}
if selected_provider.is_none() {
if let Some((fallback_id, provider)) = providers.iter().next() {
selected_id = Some(fallback_id.clone());
selected_provider = Some(provider);
}
}
if let Some(provider) = selected_provider {
LegacyNewApiConfig {
provider_id: selected_id,
provider_name: Some(provider.name.clone()),
api_key: Some(provider.api_key.clone()),
base_url: Some(provider.api_base_url.clone()),
model_count: provider.models.len(),
}
} else {
LegacyNewApiConfig::default()
}
}
fn project_data_sources(
configs: DataSourcesConfig,
) -> HashMap<String, LegacyDataSourceConfig> {
configs
.0
.into_iter()
.map(|(key, cfg)| {
let provider = provider_id(&cfg.provider).to_string();
let entry = LegacyDataSourceConfig {
provider,
api_key: cfg.api_key.clone(),
api_url: cfg.api_url.clone(),
enabled: cfg.enabled,
};
(key, entry)
})
.collect()
}
fn provider_id(provider: &DataSourceProvider) -> &'static str {
match provider {
DataSourceProvider::Tushare => "tushare",
DataSourceProvider::Finnhub => "finnhub",
DataSourceProvider::Alphavantage => "alphavantage",
DataSourceProvider::Yfinance => "yfinance",
}
}
// --- Helper Functions ---
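/// Best-effort market inference from common ticker suffixes: .SS/.SH => CN, .HK => HK, anything else => US.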
fn infer_market(symbol: &str) -> String {
if symbol.ends_with(".SS") || symbol.ends_with(".SH") {
"CN".to_string()
} else if symbol.ends_with(".HK") {
"HK".to_string()
} else {
"US".to_string()
}
}
// --- New Workflow Handlers ---
/// [POST /api/v1/tools/resolve-symbol]
/// Resolves and normalizes a symbol without starting a workflow.
#[utoipa::path(
post,
path = "/api/v1/tools/resolve-symbol",
request_body = SymbolResolveRequest,
responses(
(status = 200, description = "Symbol resolved", body = SymbolResolveResponse)
)
)]
async fn resolve_symbol(Json(payload): Json<SymbolResolveRequest>) -> Result<impl IntoResponse> {
let market = if let Some(m) = payload.market {
if m.is_empty() {
infer_market(&payload.symbol)
} else {
m
}
} else {
infer_market(&payload.symbol)
};
let market_enum = Market::from(market.as_str());
let normalized_symbol = CanonicalSymbol::new(&payload.symbol, &market_enum);
Ok(Json(SymbolResolveResponse {
symbol: normalized_symbol.into(),
market,
}))
}
/// [POST /api/v1/workflow/start]
/// Initiates a new analysis workflow via the Orchestrator.
#[utoipa::path(
post,
path = "/api/v1/workflow/start",
request_body = DataRequest,
responses(
(status = 202, description = "Workflow started", body = RequestAcceptedResponse)
)
)]
async fn start_workflow(
State(state): State<AppState>,
Json(payload): Json<DataRequest>,
) -> Result<impl IntoResponse> {
let request_id = Uuid::new_v4();
let market = if let Some(m) = payload.market {
if m.is_empty() {
infer_market(&payload.symbol)
} else {
m
}
} else {
infer_market(&payload.symbol)
};
let market_enum = Market::from(market.as_str());
let normalized_symbol = CanonicalSymbol::new(&payload.symbol, &market_enum);
let command = StartWorkflowCommand {
request_id,
symbol: normalized_symbol.clone(),
market: market.clone(),
template_id: payload.template_id,
};
info!(request_id = %request_id, "Publishing StartWorkflowCommand to Orchestrator");
state
.nats_client
.publish(
command.subject().to_string(),
serde_json::to_vec(&command).unwrap().into(),
)
.await?;
Ok((
StatusCode::ACCEPTED,
Json(RequestAcceptedResponse {
request_id,
symbol: normalized_symbol.into(),
market,
}),
))
}
/// [GET /api/v1/workflow/events/:request_id]
/// SSE endpoint that proxies events from NATS to the frontend.
async fn workflow_events_stream(
State(state): State<AppState>,
Path(request_id): Path<Uuid>,
) -> Result<impl IntoResponse> {
info!("Client connected to event stream for {}", request_id);
// 1. Subscribe to the NATS topic FIRST to avoid a race condition:
// if we synced before subscribing, we could miss the snapshot response from a fast Orchestrator.
let topic = NatsSubject::WorkflowProgress(request_id).to_string();
let mut subscriber = state.nats_client.subscribe(topic).await?;
// 2. Send a SyncStateCommand asking the Orchestrator for a state snapshot,
// so a client that connects (or reconnects) receives the latest state immediately.
let sync_cmd = SyncStateCommand { request_id };
if let Err(e) = state
.nats_client
.publish(
sync_cmd.subject().to_string(),
serde_json::to_vec(&sync_cmd).unwrap().into(),
)
.await
{
error!("Failed to send SyncStateCommand: {}", e);
}
// 3. Convert NATS stream to SSE stream
let stream = async_stream::stream! {
while let Some(msg) = subscriber.next().await {
if let Ok(event) = serde_json::from_slice::<WorkflowEvent>(&msg.payload) {
match axum::response::sse::Event::default().json_data(event) {
Ok(sse_event) => yield Ok::<_, anyhow::Error>(sse_event),
Err(e) => error!("Failed to serialize SSE event: {}", e),
}
}
}
};
Ok(axum::response::Sse::new(stream).keep_alive(axum::response::sse::KeepAlive::default()))
}
// --- Legacy Handler (Renamed) ---
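/// [POST /api/v1/data-requests]
/// Legacy entry point; delegates to `start_workflow` so existing clients keep working.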
async fn trigger_data_fetch_legacy(
State(state): State<AppState>,
Json(payload): Json<DataRequest>,
) -> Result<impl IntoResponse> {
// Delegate to the new workflow entry point so legacy clients are migrated transparently.
start_workflow(State(state), Json(payload)).await
}
#[utoipa::path(
get,
path = "/health",
responses(
(status = 200, description = "Service healthy")
)
)]
async fn health_check() -> impl IntoResponse {
(StatusCode::OK, "OK")
}
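/// [GET /api/v1/session-data/:request_id]
/// Placeholder proxy; currently responds with 501 Not Implemented.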
async fn proxy_get_session_data(
State(_state): State<AppState>,
Path(_request_id): Path<Uuid>,
) -> Result<impl IntoResponse> {
Ok((
StatusCode::NOT_IMPLEMENTED,
Json(serde_json::json!({"error": "Not implemented"})),
))
}
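/// [GET /api/v1/analysis-results/stream]
/// Placeholder proxy; currently responds with 501 Not Implemented.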
async fn proxy_analysis_stream(State(_state): State<AppState>) -> Result<impl IntoResponse> {
Ok((
StatusCode::NOT_IMPLEMENTED,
Json(serde_json::json!({"error": "Not implemented"})),
))
}
/// [POST /api/v1/analysis-requests/:symbol]
/// Triggers the analysis report generation workflow by publishing a command.
async fn trigger_analysis_generation(
State(state): State<AppState>,
Path(symbol): Path<String>,
Json(payload): Json<AnalysisRequest>,
) -> Result<impl IntoResponse> {
let request_id = Uuid::new_v4();
// Infer the market to aid symbol normalization; defaults to US when the suffix is ambiguous.
let market_str = infer_market(&symbol);
let market_enum = Market::from(market_str.as_str());
let normalized_symbol = CanonicalSymbol::new(&symbol, &market_enum);
if normalized_symbol.as_str() != symbol {
info!(
"Normalized analysis request symbol '{}' to '{}'",
symbol, normalized_symbol
);
}
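// Legacy path: the context-mechanism fields introduced by the refactor (task_id, module_id,
// commit_hash, input_bindings, output_path, llm_config, analysis_prompt) are left as None here.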
let command = GenerateReportCommand {
request_id,
symbol: normalized_symbol.clone(),
template_id: payload.template_id,
task_id: None,
module_id: None,
commit_hash: None,
input_bindings: None,
output_path: None,
llm_config: None,
analysis_prompt: None,
};
info!(request_id = %request_id, "Publishing analysis generation command");
state
.nats_client
.publish(
command.subject().to_string(),
serde_json::to_vec(&command).unwrap().into(),
)
.await?;
// Infer market for response consistency
let market = infer_market(normalized_symbol.as_str());
Ok((
StatusCode::ACCEPTED,
Json(RequestAcceptedResponse {
request_id,
symbol: normalized_symbol.into(),
market,
}),
))
}
/// [GET /api/v1/analysis-results?symbol=...]
async fn get_analysis_results_by_symbol(
State(state): State<AppState>,
Query(query): Query<AnalysisResultQuery>,
) -> Result<impl IntoResponse> {
let results = state
.persistence_client
.get_analysis_results(&query.symbol)
.await?;
Ok(Json(results))
}
/// [GET /api/v1/companies/:symbol/profile]
/// Queries the persisted company profile from the data-persistence-service.
async fn get_company_profile(
State(state): State<AppState>,
Path(symbol): Path<String>,
) -> Result<impl IntoResponse> {
let profile = state
.persistence_client
.get_company_profile(&symbol)
.await?;
Ok(Json(profile))
}
/// [GET /api/v1/market-data/financial-statements/:symbol]
async fn get_financials_by_symbol(
State(state): State<AppState>,
Path(symbol): Path<String>,
) -> Result<impl IntoResponse> {
let financials = state.persistence_client.get_financials(&symbol).await?;
Ok(Json(financials))
}
/// [GET /tasks/:request_id]
/// Aggregates task progress from all downstream provider services.
#[utoipa::path(
get,
path = "/tasks/{request_id}",
params(
("request_id" = Uuid, Path, description = "Request ID to query tasks for")
),
responses(
(status = 200, description = "Task progress list", body = Vec<TaskProgress>),
(status = 404, description = "Tasks not found")
)
)]
async fn get_task_progress(
State(state): State<AppState>,
Path(request_id): Path<Uuid>,
) -> Result<impl IntoResponse> {
let client = reqwest::Client::new();
let services = state.get_all_services();
let fetches = services.iter().map(|(service_id, service_url)| {
let client = client.clone();
let url = format!("{}/tasks", service_url);
let service_id_clone = service_id.clone();
async move {
match client.get(&url).send().await {
Ok(resp) => match resp.json::<Vec<TaskProgress>>().await {
Ok(tasks) => Some(tasks),
Err(e) => {
warn!("Failed to decode tasks from {}: {}", url, e);
// Return a synthetic error task for this provider
Some(vec![TaskProgress {
request_id,
task_name: format!("{}:unreachable", service_id_clone),
status: ObservabilityTaskStatus::Failed,
progress_percent: 0,
details: "Invalid response format".to_string(),
started_at: chrono::Utc::now(),
}])
}
},
Err(e) => {
warn!("Failed to fetch tasks from {}: {}", url, e);
// Return a synthetic error task for this provider
Some(vec![TaskProgress {
request_id,
task_name: format!("{}:unreachable", service_id_clone),
status: ObservabilityTaskStatus::Failed,
progress_percent: 0,
details: format!("Connection Error: {}", e),
started_at: chrono::Utc::now(),
}])
}
}
}
});
let results = join_all(fetches).await;
let mut merged: Vec<TaskProgress> = Vec::new();
for maybe_tasks in results {
if let Some(tasks) = maybe_tasks {
merged.extend(tasks);
}
}
let tasks_for_req: Vec<TaskProgress> = merged
.into_iter()
.filter(|t| t.request_id == request_id)
.collect();
if tasks_for_req.is_empty() {
// No tasks were found for this request, not even synthetic error entries, so return 404.
// Because unreachable providers yield synthetic failure tasks, this branch is rare unless
// no providers are registered at all.
if services.is_empty() {
warn!("No providers registered to query for tasks.");
}
return Ok((
StatusCode::NOT_FOUND,
Json(serde_json::json!({"error": "Task not found"})),
)
.into_response());
}
Ok(Json(tasks_for_req).into_response())
}
// --- New Config Test Handler ---
#[api_dto]
pub struct TestConfigRequest {
pub r#type: String,
pub data: serde_json::Value,
}
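// Illustrative payload for POST /api/v1/configs/test (the shape of `data` is provider-specific):
// { "type": "tushare", "data": { /* provider-specific config fields */ } }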
#[api_dto]
pub struct TestConnectionResponse {
pub success: bool,
pub message: String,
}
/// [POST /api/v1/configs/test]
/// Forwards a configuration test request to the appropriate downstream service.
#[utoipa::path(
post,
path = "/api/v1/configs/test",
request_body = TestConfigRequest,
responses(
(status = 200, description = "Configuration test result", body = TestConnectionResponse)
)
)]
async fn test_data_source_config(
State(state): State<AppState>,
Json(payload): Json<TestConfigRequest>,
) -> Result<impl IntoResponse> {
info!("test_data_source_config: type={}", payload.r#type);
// Dynamically discover the downstream service that owns this config type.
let target_service_url = state.get_service_url(&payload.r#type);
if let Some(base_url) = target_service_url {
let client = reqwest::Client::new();
// Remove trailing slash from base_url
let clean_base = base_url.trim_end_matches('/');
// Provider services conventionally expose their configuration test endpoint at /test.
let target_url = format!("{}/test", clean_base);
info!(
"Forwarding test request for '{}' to {}",
payload.r#type, target_url
);
let response = client.post(&target_url).json(&payload.data).send().await?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await?;
warn!(
"Downstream test for '{}' failed: status={} body={}",
payload.r#type, status, error_text
);
return Ok((
StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY),
Json(serde_json::json!({
"error": "Downstream service returned an error",
"details": error_text,
})),
)
.into_response());
}
let response_json: serde_json::Value = response.json().await?;
Ok((StatusCode::OK, Json(response_json)).into_response())
} else {
warn!(
"No downstream service registered for config type: {}",
payload.r#type
);
Ok((
StatusCode::NOT_IMPLEMENTED,
Json(serde_json::json!({ "error": "No downstream service registered for this type" })),
)
.into_response())
}
}
#[api_dto]
pub struct TestLlmConfigRequest {
pub api_base_url: String,
pub api_key: String,
pub model_id: String,
}
/// [POST /api/v1/configs/llm/test]
#[utoipa::path(
post,
path = "/api/v1/configs/llm/test",
request_body = TestLlmConfigRequest,
responses(
(status = 200, description = "LLM config test result (JSON)")
)
)]
async fn test_llm_config(
State(state): State<AppState>,
Json(payload): Json<TestLlmConfigRequest>,
) -> Result<impl IntoResponse> {
let client = reqwest::Client::new();
let target_url = format!(
"{}/test-llm",
state
.config
.report_generator_service_url
.trim_end_matches('/')
);
let response = client.post(&target_url).json(&payload).send().await?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await?;
return Ok((
StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY),
Json(serde_json::json!({
"error": "LLM test failed",
"details": error_text,
})),
)
.into_response());
}
let response_json: serde_json::Value = response.json().await?;
Ok((StatusCode::OK, Json(response_json)).into_response())
}
// --- Config API Handlers (Proxy to data-persistence-service) ---
/// [GET /api/v1/configs/llm_providers]
#[utoipa::path(
get,
path = "/api/v1/configs/llm_providers",
responses(
(status = 200, description = "LLM providers configuration", body = LlmProvidersConfig)
)
)]
async fn get_llm_providers_config(State(state): State<AppState>) -> Result<impl IntoResponse> {
let config = state.persistence_client.get_llm_providers_config().await?;
Ok(Json(config))
}
/// [PUT /api/v1/configs/llm_providers]
#[utoipa::path(
put,
path = "/api/v1/configs/llm_providers",
request_body = LlmProvidersConfig,
responses(
(status = 200, description = "Updated LLM providers configuration", body = LlmProvidersConfig)
)
)]
async fn update_llm_providers_config(
State(state): State<AppState>,
Json(payload): Json<LlmProvidersConfig>,
) -> Result<impl IntoResponse> {
let updated_config = state
.persistence_client
.update_llm_providers_config(&payload)
.await?;
Ok(Json(updated_config))
}
/// [GET /api/v1/configs/analysis_template_sets]
#[utoipa::path(
get,
path = "/api/v1/configs/analysis_template_sets",
responses(
(status = 200, description = "Analysis template sets configuration", body = AnalysisTemplateSets)
)
)]
async fn get_analysis_template_sets(State(state): State<AppState>) -> Result<impl IntoResponse> {
let config = state
.persistence_client
.get_analysis_template_sets()
.await?;
Ok(Json(config))
}
/// [PUT /api/v1/configs/analysis_template_sets]
#[utoipa::path(
put,
path = "/api/v1/configs/analysis_template_sets",
request_body = AnalysisTemplateSets,
responses(
(status = 200, description = "Updated analysis template sets configuration", body = AnalysisTemplateSets)
)
)]
async fn update_analysis_template_sets(
State(state): State<AppState>,
Json(payload): Json<AnalysisTemplateSets>,
) -> Result<impl IntoResponse> {
let updated_config = state
.persistence_client
.update_analysis_template_sets(&payload)
.await?;
Ok(Json(updated_config))
}
/// [GET /api/v1/configs/data_sources]
#[utoipa::path(
get,
path = "/api/v1/configs/data_sources",
responses(
(status = 200, description = "Data sources configuration", body = DataSourcesConfig)
)
)]
async fn get_data_sources_config(State(state): State<AppState>) -> Result<impl IntoResponse> {
let config = state.persistence_client.get_data_sources_config().await?;
Ok(Json(config))
}
/// [PUT /api/v1/configs/data_sources]
#[utoipa::path(
put,
path = "/api/v1/configs/data_sources",
request_body = DataSourcesConfig,
responses(
(status = 200, description = "Updated data sources configuration", body = DataSourcesConfig)
)
)]
async fn update_data_sources_config(
State(state): State<AppState>,
Json(payload): Json<DataSourcesConfig>,
) -> Result<impl IntoResponse> {
let updated_config = state
.persistence_client
.update_data_sources_config(&payload)
.await?;
Ok(Json(updated_config))
}
/// [GET /api/v1/registry/providers]
/// Returns metadata for all registered data providers.
#[utoipa::path(
get,
path = "/api/v1/registry/providers",
responses(
(status = 200, description = "Registered providers metadata", body = Vec<ProviderMetadata>)
)
)]
async fn get_registered_providers(State(state): State<AppState>) -> Result<impl IntoResponse> {
let entries = state.registry.get_entries();
let providers: Vec<ProviderMetadata> = entries
.into_iter()
.filter_map(|entry| {
// Only return DataProvider services that have metadata
if entry.registration.role == common_contracts::registry::ServiceRole::DataProvider {
entry.registration.metadata
} else {
None
}
})
.collect();
Ok(Json(providers))
}
/// [GET /api/v1/discover-models/:provider_id]
#[utoipa::path(
get,
path = "/api/v1/discover-models/{provider_id}",
params(
("provider_id" = String, Path, description = "Provider ID to discover models for")
),
responses(
(status = 200, description = "Discovered models (JSON)"),
(status = 404, description = "Provider not found"),
(status = 502, description = "Provider error")
)
)]
async fn discover_models(
State(state): State<AppState>,
Path(provider_id): Path<String>,
) -> Result<impl IntoResponse> {
info!("discover_models: provider_id={}", provider_id);
let providers = state.persistence_client.get_llm_providers_config().await?;
if let Some(provider) = providers.get(&provider_id) {
let client = reqwest::Client::new();
let url = format!("{}/models", provider.api_base_url.trim_end_matches('/'));
info!(
"discover_models: target_url={} (provider_id={})",
url, provider_id
);
let response = client
.get(&url)
.bearer_auth(&provider.api_key)
.send()
.await?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await?;
warn!(
"discover_models failed: provider_id={} status={} body={}",
provider_id, status, error_text
);
// Return a structured error to the frontend
return Ok((
StatusCode::BAD_GATEWAY,
Json(serde_json::json!({
"error": "Failed to fetch models from provider",
"provider_error": error_text,
})),
)
.into_response());
}
let models_json: serde_json::Value = response.json().await?;
Ok((StatusCode::OK, Json(models_json)).into_response())
} else {
warn!("discover_models: provider not found: {}", provider_id);
Ok((
StatusCode::NOT_FOUND,
Json(serde_json::json!({ "error": "Provider not found" })),
)
.into_response())
}
}
#[api_dto]
pub struct DiscoverPreviewRequest {
pub api_base_url: String,
pub api_key: String,
}
/// [POST /api/v1/discover-models]
/// Preview discovery without persisting provider configuration.
#[utoipa::path(
post,
path = "/api/v1/discover-models",
request_body = DiscoverPreviewRequest,
responses(
(status = 200, description = "Discovered models (JSON)"),
(status = 502, description = "Provider error")
)
)]
async fn discover_models_preview(
Json(payload): Json<DiscoverPreviewRequest>,
) -> Result<impl IntoResponse> {
let redacted_key = if payload.api_key.is_empty() {
"<empty>"
} else {
"<redacted>"
};
info!(
"discover_models_preview: target_url={}/models api_key={}",
payload.api_base_url.trim_end_matches('/'),
redacted_key
);
let client = reqwest::Client::new();
let url = format!("{}/models", payload.api_base_url.trim_end_matches('/'));
let response = client
.get(&url)
.bearer_auth(&payload.api_key)
.send()
.await?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await?;
warn!(
"discover_models_preview failed: status={} body={}",
status, error_text
);
return Ok((
StatusCode::BAD_GATEWAY,
Json(serde_json::json!({
"error": "Failed to fetch models from provider",
"provider_error": error_text,
})),
)
.into_response());
}
let models_json: serde_json::Value = response.json().await?;
Ok((StatusCode::OK, Json(models_json)).into_response())
}
/// [GET /api/v1/workflow/:request_id/graph]
async fn get_workflow_graph_proxy(
State(state): State<AppState>,
Path(request_id): Path<Uuid>,
) -> Result<impl IntoResponse> {
let url = format!(
"{}/workflows/{}/graph",
state.config.workflow_orchestrator_service_url.trim_end_matches('/'),
request_id
);
let client = reqwest::Client::new();
let resp = client.get(&url).send().await?;
let status = resp.status();
let body = resp.bytes().await?;
Ok((
StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
axum::body::Body::from(body),
))
}
// --- Context Inspector Proxies ---
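/// [GET /api/context/:req_id/tree/:commit_hash]
/// Proxies the context tree listing from the workflow-orchestrator service, forwarding query parameters as-is.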
async fn proxy_context_tree(
State(state): State<AppState>,
Path((req_id, commit_hash)): Path<(String, String)>,
Query(params): Query<HashMap<String, String>>,
) -> Result<impl IntoResponse> {
let url = format!(
"{}/context/{}/tree/{}",
state.config.workflow_orchestrator_service_url.trim_end_matches('/'),
req_id,
commit_hash
);
let client = reqwest::Client::new();
let resp = client.get(&url).query(&params).send().await?;
let status = resp.status();
let body = resp.bytes().await?;
Ok((
StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
axum::body::Body::from(body),
))
}
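/// [GET /api/context/:req_id/blob/:commit_hash/*path]
/// Proxies a single context blob from the workflow-orchestrator service, preserving the upstream Content-Type.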
async fn proxy_context_blob(
State(state): State<AppState>,
Path((req_id, commit_hash, path)): Path<(String, String, String)>,
) -> Result<impl IntoResponse> {
let url = format!(
"{}/context/{}/blob/{}/{}",
state.config.workflow_orchestrator_service_url.trim_end_matches('/'),
req_id,
commit_hash,
path
);
let client = reqwest::Client::new();
let resp = client.get(&url).send().await?;
let status = resp.status();
let headers = resp.headers().clone();
let body = resp.bytes().await?;
let mut response_builder = axum::http::Response::builder()
.status(StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR));
if let Some(ct) = headers.get(axum::http::header::CONTENT_TYPE) {
response_builder = response_builder.header(axum::http::header::CONTENT_TYPE, ct);
}
Ok(response_builder.body(axum::body::Body::from(body)).unwrap())
}
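/// [GET /api/context/:req_id/diff/:from_commit/:to_commit]
/// Proxies the diff between two context commits from the workflow-orchestrator service.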
async fn proxy_context_diff(
State(state): State<AppState>,
Path((req_id, from_commit, to_commit)): Path<(String, String, String)>,
) -> Result<impl IntoResponse> {
let url = format!(
"{}/context/{}/diff/{}/{}",
state.config.workflow_orchestrator_service_url.trim_end_matches('/'),
req_id,
from_commit,
to_commit
);
let client = reqwest::Client::new();
let resp = client.get(&url).send().await?;
let status = resp.status();
let body = resp.bytes().await?;
Ok((
StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
axum::body::Body::from(body),
))
}