// Fundamental_Analysis/services/workflow-orchestrator-service/src/workflow.rs

use std::sync::Arc;
use common_contracts::workflow_types::{
WorkflowTaskCommand, WorkflowTaskEvent, TaskStatus, StorageConfig
};
use common_contracts::messages::{
StartWorkflowCommand, SyncStateCommand, TaskType, WorkflowEvent, TaskStatus as MsgTaskStatus,
TaskStateSnapshot
};
use common_contracts::ack::TaskAcknowledgement;
use common_contracts::subjects::SubjectMessage;
use common_contracts::symbol_utils::CanonicalSymbol;
use common_contracts::dtos::{SessionDataDto, NewWorkflowHistory};
use tracing::{info, warn, error};
use anyhow::Result;
use serde_json::json;
use tokio::sync::Mutex;
use std::collections::HashMap;
use uuid::Uuid;
use crate::dag_scheduler::DagScheduler;
use crate::state::AppState;
use crate::context_resolver::ContextResolver;
use crate::io_binder::IOBinder;
use common_contracts::configs::AnalysisModuleConfig;
use workflow_context::{Vgcs, ContextStore};
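/// Coordinates a single analysis workflow: builds the task DAG from a template and the
/// enabled data sources, dispatches tasks over NATS, and folds task events back into
/// DAG and VGCS state.
///
/// A minimal usage sketch (assuming a fully initialized `AppState` and a connected
/// NATS client; the variable names are illustrative):
/// ```ignore
/// let engine = WorkflowEngine::new(app_state.clone(), nats_client.clone());
/// engine.handle_start_workflow(start_cmd).await?;
/// ```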
pub struct WorkflowEngine {
state: Arc<AppState>,
nats: async_nats::Client,
}
impl WorkflowEngine {
pub fn new(state: Arc<AppState>, nats: async_nats::Client) -> Self {
Self { state, nats }
}
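/// Handles a `StartWorkflowCommand`: initializes the VGCS repository, writes the request
/// baseline as the initial commit, builds the DAG from the template and data-source
/// configuration, publishes `WorkflowStarted`, and dispatches the root tasks.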
pub async fn handle_start_workflow(&self, cmd: StartWorkflowCommand) -> Result<()> {
let req_id = cmd.request_id;
info!(request_id = %req_id, "Starting workflow");
// 1. Init VGCS Repo
self.state.vgcs.init_repo(&req_id.to_string())?;
// 2. Create Initial Commit (Context Baseline)
let mut tx = self.state.vgcs.begin_transaction(&req_id.to_string(), "")?;
let request_info = json!({
"request_id": req_id,
"symbol": cmd.symbol.as_str(),
"market": cmd.market,
"template_id": cmd.template_id,
"timestamp": chrono::Utc::now().to_rfc3339()
});
tx.write("request.json", serde_json::to_vec_pretty(&request_info)?.as_slice())?;
let initial_commit = Box::new(tx).commit("Initial Workflow Setup", "System")?;
// 3. Create Scheduler with Initial Commit
let mut dag = DagScheduler::new(req_id, initial_commit);
// 4. Fetch Template Config
info!("Fetching template configuration...");
let template = self.state.persistence_client.get_template_by_id(&cmd.template_id).await.map_err(|e| {
anyhow::anyhow!("Failed to fetch template {}: {}", cmd.template_id, e)
})?;
// 5. Fetch Data Sources Config
let data_sources = self.state.persistence_client.get_data_sources_config().await?;
// 6. Build DAG
self.build_dag(&mut dag, &template, &data_sources, &cmd.template_id, &cmd.market, &cmd.symbol);
// 7. Save State
self.state.workflows.insert(req_id, Arc::new(Mutex::new(dag.clone())));
// 8. Publish WorkflowStarted Event
let event = WorkflowEvent::WorkflowStarted {
timestamp: chrono::Utc::now().timestamp_millis(),
task_graph: dag.to_dto(),
};
let subject = common_contracts::subjects::NatsSubject::WorkflowProgress(req_id).to_string();
if let Ok(payload) = serde_json::to_vec(&event) {
if let Err(e) = self.nats.publish(subject, payload.into()).await {
error!("Failed to publish WorkflowStarted event: {}", e);
}
}
info!("DAG built and workflow initialized.");
// 9. Trigger Initial Tasks
let initial_tasks = dag.get_initial_tasks();
// Lock the DAG for updates
let dag_arc = self.state.workflows.get(&req_id).unwrap().clone();
let mut dag_guard = dag_arc.lock().await;
for task_id in initial_tasks {
info!("Starting dispatch for initial task: {}", task_id);
self.dispatch_task(&mut dag_guard, &task_id, &self.state.vgcs).await?;
}
Ok(())
}
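/// Handles a `SyncStateCommand` by publishing a full `WorkflowStateSnapshot`
/// (task graph, per-task status, truncated content/log buffers, commit metadata)
/// so reconnecting consumers can rebuild their view of the workflow.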
pub async fn handle_sync_state(&self, cmd: SyncStateCommand) -> Result<()> {
let req_id = cmd.request_id;
// Avoid holding a tracing span guard across the await points below
info!(request_id = %req_id, "Handling SyncStateCommand");
let dag_arc = match self.state.workflows.get(&req_id) {
Some(d) => d.clone(),
None => {
warn!("Received sync request for unknown workflow {}", req_id);
return Ok(());
}
};
let dag = dag_arc.lock().await;
// Map internal status to DTO status
let mut tasks_status = std::collections::HashMap::new();
for task_id in dag.nodes.keys() {
let status = match dag.get_status(task_id) {
TaskStatus::Pending => MsgTaskStatus::Pending,
TaskStatus::Scheduled => MsgTaskStatus::Scheduled,
TaskStatus::Running => MsgTaskStatus::Running,
TaskStatus::Completed => MsgTaskStatus::Completed,
TaskStatus::Failed => MsgTaskStatus::Failed,
TaskStatus::Skipped => MsgTaskStatus::Skipped,
TaskStatus::Cancelled => MsgTaskStatus::Skipped,
};
tasks_status.insert(task_id.clone(), status);
}
// Construct Comprehensive Task States
let mut task_states = HashMap::new();
// 1. Add active buffers
for (task_id, buffer) in &dag.task_execution_states {
// If task is not in nodes for some reason, skip
let status = dag.get_status(task_id);
let input_commit = dag.nodes.get(task_id).and_then(|n| n.input_commit.clone());
let output_commit = dag.commit_tracker.task_commits.get(task_id).cloned();
let metadata = dag.commit_tracker.task_metadata.get(task_id).cloned();
let status_dto = match status {
TaskStatus::Pending => MsgTaskStatus::Pending,
TaskStatus::Scheduled => MsgTaskStatus::Scheduled,
TaskStatus::Running => MsgTaskStatus::Running,
TaskStatus::Completed => MsgTaskStatus::Completed,
TaskStatus::Failed => MsgTaskStatus::Failed,
TaskStatus::Skipped => MsgTaskStatus::Skipped,
TaskStatus::Cancelled => MsgTaskStatus::Skipped,
};
// Safety: Truncate content if too large to avoid NATS payload limits (1MB default)
// We reserve some space for JSON overhead.
// If content is > 100KB, we truncate and append a warning.
// The full content should be fetched via separate API if needed.
let content_snapshot = if buffer.content_buffer.len() > 100_000 {
// Truncate on a char boundary at (or just below) 100 KB so the snapshot honours the byte budget.
let mut cut = 100_000;
while !buffer.content_buffer.is_char_boundary(cut) { cut -= 1; }
let mut s = buffer.content_buffer[..cut].to_string();
s.push_str("\n... [Content Truncated in Snapshot] ...");
Some(s)
} else {
Some(buffer.content_buffer.clone())
};
// Safety: Limit logs in snapshot
let logs_snapshot = if buffer.logs.len() > 100 {
buffer.logs.iter().rev().take(100).rev().cloned().collect()
} else {
buffer.logs.clone()
};
task_states.insert(task_id.clone(), TaskStateSnapshot {
task_id: task_id.clone(),
status: status_dto,
logs: logs_snapshot,
content: content_snapshot,
input_commit,
output_commit,
metadata,
});
}
// 2. Add remaining tasks (no buffer)
for task_id in dag.nodes.keys() {
if !task_states.contains_key(task_id) {
let status = dag.get_status(task_id);
let input_commit = dag.nodes.get(task_id).and_then(|n| n.input_commit.clone());
let output_commit = dag.commit_tracker.task_commits.get(task_id).cloned();
let metadata = dag.commit_tracker.task_metadata.get(task_id).cloned();
let status_dto = match status {
TaskStatus::Pending => MsgTaskStatus::Pending,
TaskStatus::Scheduled => MsgTaskStatus::Scheduled,
TaskStatus::Running => MsgTaskStatus::Running,
TaskStatus::Completed => MsgTaskStatus::Completed,
TaskStatus::Failed => MsgTaskStatus::Failed,
TaskStatus::Skipped => MsgTaskStatus::Skipped,
TaskStatus::Cancelled => MsgTaskStatus::Skipped,
};
task_states.insert(task_id.clone(), TaskStateSnapshot {
task_id: task_id.clone(),
status: status_dto,
logs: vec![],
content: None,
input_commit,
output_commit,
metadata,
});
}
}
// Read buffered logs for replay
let logs = self.state.log_manager.read_current_logs(&req_id.to_string()).unwrap_or_default();
// Create Snapshot Event
let event = WorkflowEvent::WorkflowStateSnapshot {
timestamp: chrono::Utc::now().timestamp_millis(),
task_graph: dag.to_dto(),
tasks_status,
tasks_output: dag.commit_tracker.task_commits.clone().into_iter().map(|(k, v)| (k, Some(v))).collect(),
tasks_metadata: dag.commit_tracker.task_metadata.clone(),
task_states,
logs,
};
let subject = common_contracts::subjects::NatsSubject::WorkflowProgress(req_id).to_string();
if let Ok(payload) = serde_json::to_vec(&event) {
if let Err(e) = self.nats.publish(subject, payload.into()).await {
error!("Failed to publish WorkflowStateSnapshot: {}", e);
}
}
Ok(())
}
// --- New Handler Methods for Stream Capture ---
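/// Accumulates streamed task content into the in-memory execution buffer; the buffer is
/// only used for state snapshots and history, so nothing is re-published here.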
pub async fn handle_task_stream_update(&self, task_id: String, content: String, req_id: Uuid) -> Result<()> {
if let Some(dag_arc) = self.state.workflows.get(&req_id) {
let mut dag = dag_arc.lock().await;
dag.append_content(&task_id, &content);
// We do NOT re-publish here. The Orchestrator listens to the public event stream
// merely to accumulate state for history/resume. The frontend receives the
// original event directly from the provider.
// Re-publishing would cause an infinite loop if the consumer listens to the same topic.
}
Ok(())
}
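/// Accumulates a single task log line into the in-memory execution buffer (snapshot/history only).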
pub async fn handle_task_log(&self, task_id: String, log: String, req_id: Uuid) -> Result<()> {
if let Some(dag_arc) = self.state.workflows.get(&req_id) {
let mut dag = dag_arc.lock().await;
dag.append_log(&task_id, log.clone());
// We do NOT re-publish here. See handle_task_stream_update.
}
Ok(())
}
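/// Handles a task status event: records the resulting commit and metadata, publishes
/// `TaskStateChanged`, dispatches any downstream tasks that became ready, and, once the
/// DAG reports completion, persists logs, the final snapshot, and the workflow history.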
pub async fn handle_task_completed(&self, evt: WorkflowTaskEvent) -> Result<()> {
let req_id = evt.request_id;
// Avoid holding a tracing span guard across the await points below
let dag_arc = match self.state.workflows.get(&req_id) {
Some(d) => d.clone(),
None => {
warn!("Received event for unknown workflow {}", req_id);
return Ok(());
}
};
let mut dag = dag_arc.lock().await;
// 1. Update Status & Record Commit
dag.update_status(&evt.task_id, evt.status);
info!("Task {} status changed to {:?}", evt.task_id, evt.status);
// Lookup task_type
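// Falling back to DataFetch only matters if the event references an unknown node; it keeps
// the progress event best-effort instead of aborting the handler.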
let task_type = dag.nodes.get(&evt.task_id).map(|n| n.task_type).unwrap_or(TaskType::DataFetch);
// Convert status
let msg_status = match evt.status {
TaskStatus::Pending => MsgTaskStatus::Pending,
TaskStatus::Scheduled => MsgTaskStatus::Scheduled,
TaskStatus::Running => MsgTaskStatus::Running,
TaskStatus::Completed => MsgTaskStatus::Completed,
TaskStatus::Failed => MsgTaskStatus::Failed,
TaskStatus::Skipped => MsgTaskStatus::Skipped,
TaskStatus::Cancelled => MsgTaskStatus::Skipped, // Map Cancelled to Skipped
};
// Extract error message if any
let error_message = if let Some(ref result) = evt.result {
result.error.clone()
} else {
None
};
// Resolve commits
let input_commit = dag.nodes.get(&evt.task_id).and_then(|n| n.input_commit.clone());
let output_commit = evt.result.as_ref().and_then(|r| r.new_commit.clone());
// Publish TaskStateChanged event
let progress_event = WorkflowEvent::TaskStateChanged {
task_id: evt.task_id.clone(),
task_type,
status: msg_status,
message: error_message,
timestamp: chrono::Utc::now().timestamp_millis(),
progress: None,
input_commit,
output_commit,
};
let subject = common_contracts::subjects::NatsSubject::WorkflowProgress(req_id).to_string();
if let Ok(payload) = serde_json::to_vec(&progress_event) {
if let Err(e) = self.nats.publish(subject, payload.into()).await {
error!("Failed to publish progress event: {}", e);
}
}
if let Some(result) = evt.result {
if let Some(commit) = result.new_commit {
info!("Task {} produced commit {}", evt.task_id, commit);
dag.record_result(&evt.task_id, Some(commit));
}
if let Some(summary) = result.summary {
dag.record_metadata(&evt.task_id, summary);
}
if let Some(err) = result.error {
warn!("Task {} failed with error: {}", evt.task_id, err);
}
}
// 2. Check for downstream tasks
if evt.status == TaskStatus::Completed || evt.status == TaskStatus::Failed || evt.status == TaskStatus::Skipped {
let ready_tasks = dag.get_ready_downstream_tasks(&evt.task_id);
for task_id in ready_tasks {
if let Err(e) = self.dispatch_task(&mut dag, &task_id, &self.state.vgcs).await {
error!("Failed to dispatch task {}: {}", task_id, e);
info!("Failed to dispatch task {}: {}", task_id, e);
}
}
}
// 3. Check Workflow Completion
if dag.try_finish_workflow() {
let end_time = chrono::Utc::now();
let timestamp = end_time.timestamp_millis();
// --- Log Persistence (New) ---
let req_id_clone_for_log = req_id;
let vgcs_for_log = self.state.vgcs.clone();
let log_manager = self.state.log_manager.clone();
// Spawn the log finalization so it does not block the event loop; the completion
// events published below do not depend on its result.
tokio::spawn(async move {
match log_manager.finalize(&req_id_clone_for_log.to_string()) {
Ok(log_content) => {
if !log_content.is_empty() {
let result = tokio::task::spawn_blocking(move || -> Result<String> {
let mut tx = vgcs_for_log.begin_transaction(&req_id_clone_for_log.to_string(), "")?;
tx.write("workflow.log", log_content.as_bytes())?;
// We use "System" as author
tx.commit("Persist Workflow Logs", "System")
}).await;
match result {
Ok(Ok(commit)) => info!("Persisted workflow logs to VGCS commit: {}", commit),
Ok(Err(e)) => error!("Failed to commit workflow logs: {}", e),
Err(e) => error!("Failed to join log persistence task: {}", e),
}
}
},
Err(e) => error!("Failed to finalize logs: {}", e),
}
});
// --- Snapshot Persistence ---
let tasks_status_map = dag.nodes.iter().map(|(k, n)| {
let status = match n.status {
TaskStatus::Pending => MsgTaskStatus::Pending,
TaskStatus::Scheduled => MsgTaskStatus::Scheduled,
TaskStatus::Running => MsgTaskStatus::Running,
TaskStatus::Completed => MsgTaskStatus::Completed,
TaskStatus::Failed => MsgTaskStatus::Failed,
TaskStatus::Skipped => MsgTaskStatus::Skipped,
TaskStatus::Cancelled => MsgTaskStatus::Skipped,
};
(k.clone(), status)
}).collect::<std::collections::HashMap<_,_>>();
let tasks_output_map = dag.commit_tracker.task_commits.clone().into_iter().map(|(k, v)| (k, Some(v))).collect::<std::collections::HashMap<_,_>>();
let tasks_metadata_map = dag.commit_tracker.task_metadata.clone();
// Construct Comprehensive Task States for final snapshot
let mut task_states = HashMap::new();
for (task_id, buffer) in &dag.task_execution_states {
let status = dag.get_status(task_id);
let input_commit = dag.nodes.get(task_id).and_then(|n| n.input_commit.clone());
let output_commit = dag.commit_tracker.task_commits.get(task_id).cloned();
let metadata = dag.commit_tracker.task_metadata.get(task_id).cloned();
let status_dto = match status {
TaskStatus::Pending => MsgTaskStatus::Pending,
TaskStatus::Scheduled => MsgTaskStatus::Scheduled,
TaskStatus::Running => MsgTaskStatus::Running,
TaskStatus::Completed => MsgTaskStatus::Completed,
TaskStatus::Failed => MsgTaskStatus::Failed,
TaskStatus::Skipped => MsgTaskStatus::Skipped,
TaskStatus::Cancelled => MsgTaskStatus::Skipped,
};
task_states.insert(task_id.clone(), TaskStateSnapshot {
task_id: task_id.clone(),
status: status_dto,
logs: buffer.logs.clone(),
content: Some(buffer.content_buffer.clone()),
input_commit,
output_commit,
metadata,
});
}
// Add remaining tasks
for task_id in dag.nodes.keys() {
if !task_states.contains_key(task_id) {
let status = dag.get_status(task_id);
let input_commit = dag.nodes.get(task_id).and_then(|n| n.input_commit.clone());
let output_commit = dag.commit_tracker.task_commits.get(task_id).cloned();
let metadata = dag.commit_tracker.task_metadata.get(task_id).cloned();
let status_dto = match status {
TaskStatus::Pending => MsgTaskStatus::Pending,
TaskStatus::Scheduled => MsgTaskStatus::Scheduled,
TaskStatus::Running => MsgTaskStatus::Running,
TaskStatus::Completed => MsgTaskStatus::Completed,
TaskStatus::Failed => MsgTaskStatus::Failed,
TaskStatus::Skipped => MsgTaskStatus::Skipped,
TaskStatus::Cancelled => MsgTaskStatus::Skipped,
};
task_states.insert(task_id.clone(), TaskStateSnapshot {
task_id: task_id.clone(),
status: status_dto,
logs: vec![],
content: None,
input_commit,
output_commit,
metadata,
});
}
}
let snapshot_event = WorkflowEvent::WorkflowStateSnapshot {
timestamp,
task_graph: dag.to_dto(),
tasks_status: tasks_status_map,
tasks_output: tasks_output_map,
tasks_metadata: tasks_metadata_map,
logs: self.state.log_manager.read_current_logs(&req_id.to_string()).unwrap_or_default(),
task_states,
};
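// Every node's config carries symbol/market (and, except for the mock fallback, template_id),
// so reading them from an arbitrary node is sufficient for the history record.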
let symbol = dag.nodes.values().next()
.and_then(|n| n.config.get("symbol"))
.and_then(|v| v.as_str())
.unwrap_or("unknown")
.to_string();
let market = dag.nodes.values().next()
.and_then(|n| n.config.get("market"))
.and_then(|v| v.as_str())
.unwrap_or("unknown")
.to_string();
let template_id = dag.nodes.values().next()
.and_then(|n| n.config.get("template_id"))
.and_then(|v| v.as_str())
.map(|s| s.to_string());
if let Ok(payload) = serde_json::to_value(&snapshot_event) {
// 1. Save Legacy Session Data (Workflow Snapshot)
let session_data = SessionDataDto {
request_id: req_id,
symbol: symbol.clone(),
provider: "orchestrator".to_string(),
data_type: "workflow_snapshot".to_string(),
data_payload: payload.clone(),
created_at: None,
};
// 2. Save New Workflow History
// Use the DAG's recorded start time rather than the completion time.
let start_time_val = dag.start_time;
let has_failures = dag.has_failures();
let status_str = if has_failures { "Failed" } else { "Completed" }.to_string();
let history = NewWorkflowHistory {
request_id: req_id,
symbol: symbol.clone(),
market: market.clone(),
template_id,
status: status_str,
start_time: start_time_val,
end_time: Some(end_time),
snapshot_data: payload,
};
let persistence = self.state.persistence_client.clone();
let req_id_clone = req_id;
tokio::spawn(async move {
// Save session data (Legacy/Raw)
if let Err(e) = persistence.insert_session_data(&session_data).await {
error!("Failed to save workflow snapshot (session_data) for {}: {}", req_id_clone, e);
}
// Save Workflow History (New)
if let Err(e) = persistence.create_workflow_history(&history).await {
error!("Failed to save workflow history for {}: {}", req_id_clone, e);
} else {
info!("Workflow history saved for {}", req_id_clone);
}
});
}
let event = if dag.has_failures() {
warn!("Workflow {} finished with failures", req_id);
WorkflowEvent::WorkflowFailed {
end_timestamp: timestamp,
reason: "Some tasks failed".to_string(),
is_fatal: false,
}
} else {
info!("Workflow {} completed successfully", req_id);
info!("Workflow completed successfully.");
WorkflowEvent::WorkflowCompleted {
end_timestamp: timestamp,
result_summary: Some(json!({})),
}
};
let subject = common_contracts::subjects::NatsSubject::WorkflowProgress(req_id).to_string();
if let Ok(payload) = serde_json::to_vec(&event) {
if let Err(e) = self.nats.publish(subject, payload.into()).await {
error!("Failed to publish completion event: {}", e);
}
}
}
Ok(())
}
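/// Resolves the task's input context, injects IO bindings (and, for analysis tasks, resolved
/// context paths plus a trace commit layered on the base commit), then dispatches the task
/// over NATS request-reply, treating a missing or rejected acknowledgement as failure.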
async fn dispatch_task(&self, dag: &mut DagScheduler, task_id: &str, vgcs: &Vgcs) -> Result<()> {
// 1. Resolve Context (Merge if needed)
let mut context = dag.resolve_context(task_id, vgcs)?;
// Store the input commit in the node for observability
if let Some(base_commit) = &context.base_commit {
dag.set_input_commit(task_id, base_commit.clone());
}
// 2. Update Status
dag.update_status(task_id, TaskStatus::Scheduled);
info!("Task {} scheduled and dispatched.", task_id);
// 3. Construct Command
let (routing_key, task_type, mut config, display_name) = {
let node = dag.nodes.get(task_id).ok_or_else(|| anyhow::anyhow!("Node not found"))?;
(node.routing_key.clone(), node.task_type, node.config.clone(), node.display_name.clone())
};
// --- Resolution Phase ---
let symbol = self.get_symbol_from_config(&config);
// 3.1 IO Binding
let io_binder = IOBinder::new();
let output_path = io_binder.allocate_output_path(task_type, &symbol, task_id, display_name.as_deref());
if let Some(obj) = config.as_object_mut() {
obj.insert("output_path".to_string(), serde_json::Value::String(output_path.clone()));
}
// 3.2 Context Resolution (for Analysis)
if task_type == TaskType::Analysis {
// We might update the base_commit if we write a trace file
let mut current_base_commit = context.base_commit.clone().unwrap_or_default();
if let Some(module_config_val) = config.get("_module_config") {
if let Ok(module_config) = serde_json::from_value::<AnalysisModuleConfig>(module_config_val.clone()) {
let mut variables = std::collections::HashMap::new();
variables.insert("symbol".to_string(), symbol.clone());
if let Some(market) = config.get("market").and_then(|v| v.as_str()) {
variables.insert("market".to_string(), market.to_string());
}
let resolver = ContextResolver::new(self.state.vgcs.clone());
// Fetch LLM providers for context resolution (Hybrid/Auto modes)
let llm_providers = match self.state.persistence_client.get_llm_providers_config().await {
Ok(p) => p,
Err(e) => {
warn!("Failed to fetch LLM providers config (non-fatal for Manual mode): {}", e);
common_contracts::config_models::LlmProvidersConfig::default()
}
};
match resolver.resolve_input(&module_config.context_selector, &dag.request_id.to_string(), &current_base_commit, &variables, &llm_providers, &module_config.analysis_prompt).await {
Ok(resolution) => {
// 1. Inject Input Bindings
if let Some(obj) = config.as_object_mut() {
obj.insert("input_bindings".to_string(), serde_json::to_value(&resolution.paths)?);
}
info!("Context resolution successful. Injecting bindings.");
// 2. Write Trace Sidecar to VGCS
let trace_path = io_binder.allocate_trace_path(task_type, &symbol, task_id, display_name.as_deref());
// Run the VGCS write/commit on a blocking thread so the synchronous IO does not stall the async runtime
let vgcs = self.state.vgcs.clone();
let req_id_str = dag.request_id.to_string();
let base_commit_for_write = current_base_commit.clone();
let trace_content = resolution.trace.clone();
let task_id_str = task_id.to_string();
let trace_path_clone = trace_path.clone();
let trace_commit_res = tokio::task::spawn_blocking(move || -> Result<String> {
let mut tx = vgcs.begin_transaction(&req_id_str, &base_commit_for_write)?;
tx.write(&trace_path_clone, trace_content.as_bytes())?;
let new_commit = Box::new(tx).commit(&format!("Context Resolution Trace for {}", task_id_str), "Orchestrator")?;
Ok(new_commit)
}).await;
match trace_commit_res {
Ok(Ok(new_commit)) => {
info!("Written context resolution trace to {} (Commit: {})", trace_path, new_commit);
// Update the base commit for the worker, so it sees the trace file (linear history)
current_base_commit = new_commit;
// Update the context passed to the worker
context.base_commit = Some(current_base_commit.clone());
// Also update the DAG node's input commit for observability.
// The caller holds the DAG lock and we are outside the spawned closure here,
// so mutating the node directly is safe.
dag.set_input_commit(task_id, current_base_commit);
},
Ok(Err(e)) => {
error!("Failed to write trace file: {}", e);
warn!("Failed to write trace file: {}", e);
},
Err(e) => error!("Failed to join trace write task: {}", e),
}
},
Err(e) => {
error!("Context resolution failed for task {}: {}", task_id, e);
warn!("Context resolution failed: {}", e);
// We proceed, but the worker might fail if it relies on inputs
}
}
}
}
}
// Capture for event
let input_commit_for_event = context.base_commit.clone();
let cmd = WorkflowTaskCommand {
request_id: dag.request_id,
task_id: task_id.to_string(),
routing_key: routing_key.clone(),
config, // Use modified config
context,
storage: StorageConfig {
root_path: self.state.config.workflow_data_path.clone(),
},
};
// 4. Publish with Handshake (Request-Reply)
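// The provider must reply with a TaskAcknowledgement within the timeout; a rejection,
// malformed ack, NATS error, or timeout all mark the task Failed here.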
let subject = cmd.subject().to_string();
let payload = serde_json::to_vec(&cmd)?;
info!("Dispatching task {} to subject {} (waiting for ack)", task_id, subject);
let request_timeout = std::time::Duration::from_secs(5);
let request_future = self.nats.request(subject.clone(), payload.into());
match tokio::time::timeout(request_timeout, request_future).await {
Ok(Ok(msg)) => {
// Parse Ack
match serde_json::from_slice::<TaskAcknowledgement>(&msg.payload) {
Ok(TaskAcknowledgement::Accepted) => {
info!("Task {} accepted by provider.", task_id);
// Task proceeds normally
},
Ok(TaskAcknowledgement::Rejected { reason }) => {
let err_msg = format!("Task rejected by provider: {}", reason);
warn!("Task {} rejected: {}", task_id, reason);
// Mark failed immediately
dag.update_status(task_id, TaskStatus::Failed);
error!("{}", err_msg);
// Emit failure event so frontend knows
let failure_event = WorkflowEvent::TaskStateChanged {
task_id: task_id.to_string(),
task_type,
status: MsgTaskStatus::Failed,
message: Some(err_msg.clone()),
timestamp: chrono::Utc::now().timestamp_millis(),
progress: None,
input_commit: input_commit_for_event,
output_commit: None,
};
let subject_prog = common_contracts::subjects::NatsSubject::WorkflowProgress(dag.request_id).to_string();
if let Ok(p) = serde_json::to_vec(&failure_event) {
let _ = self.nats.publish(subject_prog, p.into()).await;
}
return Err(anyhow::anyhow!(err_msg));
},
Err(e) => {
// Invalid Ack format, assume failure
let err_msg = format!("Invalid Ack from provider: {}", e);
error!("{}", err_msg);
dag.update_status(task_id, TaskStatus::Failed);
return Err(anyhow::anyhow!(err_msg));
}
}
},
Ok(Err(e)) => {
let err_msg = format!("NATS Request failed: {}", e);
error!("Task {} dispatch error: {}", task_id, e);
dag.update_status(task_id, TaskStatus::Failed);
return Err(anyhow::anyhow!(err_msg));
},
Err(_) => {
let err_msg = "Dispatch timeout (no ack from provider in 5s)";
error!("Task {} {}", task_id, err_msg);
dag.update_status(task_id, TaskStatus::Failed);
return Err(anyhow::anyhow!(err_msg));
}
}
Ok(())
}
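/// Reads `symbol` from a task config, falling back to "unknown" when absent.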
fn get_symbol_from_config(&self, config: &serde_json::Value) -> String {
config.get("symbol").and_then(|v| v.as_str()).unwrap_or("unknown").to_string()
}
/// Builds the workflow DAG: one data-fetch node per enabled data source (with a mock
/// fallback for the MOCK market) plus one analysis node per template module, wiring
/// each module either to its declared dependencies or to all data-fetch tasks.
fn build_dag(
&self,
dag: &mut DagScheduler,
template: &common_contracts::config_models::AnalysisTemplateSet,
data_sources: &common_contracts::config_models::DataSourcesConfig,
template_id: &str,
market: &str,
symbol: &CanonicalSymbol
) {
// 1. Data Fetch Nodes
let mut fetch_tasks = Vec::new();
// Use all enabled data sources regardless of market
// The provider itself will decide whether to skip or process based on market support.
// We sort keys to ensure deterministic DAG generation.
let mut source_keys: Vec<&String> = data_sources.keys().collect();
source_keys.sort();
for key in source_keys {
let config = &data_sources[key];
if config.enabled {
// Special handling for MOCK market: skip real providers
if market == "MOCK" && key.to_lowercase() != "mock" {
continue;
}
let provider_key = key.to_lowercase();
let task_id = format!("fetch:{}", provider_key);
fetch_tasks.push(task_id.clone());
let display_name = format!("Data Fetch ({:?})", config.provider);
let routing_key = format!("provider.{}", provider_key);
dag.add_node(
task_id.clone(),
Some(display_name),
TaskType::DataFetch,
routing_key,
json!({
"symbol": symbol.as_str(),
"market": market,
"template_id": template_id
})
);
}
}
// Fallback for MOCK if not in config (usually mock is not in data_sources.json but hardcoded for tests)
if market == "MOCK" && fetch_tasks.is_empty() {
let task_id = "fetch:mock".to_string();
fetch_tasks.push(task_id.clone());
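// For mock runs the symbol may be encoded as "SYMBOL|simulation_mode"; split it so the
// provider receives a clean symbol plus an explicit simulation_mode field.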
let (actual_symbol, sim_mode) = if symbol.as_str().contains('|') {
let parts: Vec<&str> = symbol.as_str().split('|').collect();
(parts[0], Some(parts[1]))
} else {
(symbol.as_str(), None)
};
let mut config = json!({
"symbol": actual_symbol,
"market": market
});
if let Some(mode) = sim_mode {
if let Some(obj) = config.as_object_mut() {
obj.insert("simulation_mode".to_string(), serde_json::Value::String(mode.to_string()));
}
}
dag.add_node(
task_id,
Some("Data Fetch (Mock)".to_string()),
TaskType::DataFetch,
"provider.mock".to_string(),
config
);
}
if fetch_tasks.is_empty() {
warn!("No enabled data providers found in configuration.");
}
// 2. Analysis Nodes (Dynamic from Template)
for (module_id, module_config) in &template.modules {
let task_id = format!("analysis:{}", module_id);
let mut node_config = json!({
"template_id": template_id,
"module_id": module_id,
"symbol": symbol.as_str(),
"market": market
});
// Embed internal module config for Orchestrator use (Context Resolution)
if let Some(obj) = node_config.as_object_mut() {
obj.insert("_module_config".to_string(), serde_json::to_value(module_config).unwrap_or(serde_json::Value::Null));
}
dag.add_node(
task_id.clone(),
Some(module_config.name.clone()),
TaskType::Analysis,
"analysis.report".to_string(), // routing_key matches what report-generator consumes
node_config
);
// Dependencies
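// Dependency entries are module ids from the same template, mapped to their "analysis:<id>"
// task ids; DagScheduler owns the edge itself, so unknown ids are not validated here.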
if module_config.dependencies.is_empty() {
// If no analysis dependencies, depend on Data Fetch
for fetch_task in &fetch_tasks {
dag.add_dependency(fetch_task, &task_id);
}
} else {
// Depend on other analysis modules
for dep_module_id in &module_config.dependencies {
let dep_task_id = format!("analysis:{}", dep_module_id);
dag.add_dependency(&dep_task_id, &task_id);
}
}
}
}
}