use std::sync::Arc;

use anyhow::Result;
use serde_json::json;
use tokio::sync::Mutex;
use tracing::{error, info, warn};

use common_contracts::configs::AnalysisModuleConfig;
use common_contracts::messages::{
    StartWorkflowCommand, SyncStateCommand, TaskStatus as MsgTaskStatus, TaskType, WorkflowEvent,
};
use common_contracts::subjects::SubjectMessage;
use common_contracts::symbol_utils::CanonicalSymbol;
use common_contracts::workflow_types::{
    StorageConfig, TaskStatus, WorkflowTaskCommand, WorkflowTaskEvent,
};
use workflow_context::{ContextStore, Vgcs};

use crate::context_resolver::ContextResolver;
use crate::dag_scheduler::DagScheduler;
use crate::io_binder::IOBinder;
use crate::state::AppState;

/// Drives workflow execution: builds the task DAG from a template, dispatches
/// tasks over NATS, and reacts to task lifecycle events.
pub struct WorkflowEngine {
    state: Arc<AppState>,
    nats: async_nats::Client,
}

impl WorkflowEngine {
    pub fn new(state: Arc<AppState>, nats: async_nats::Client) -> Self {
        Self { state, nats }
    }

    /// Initializes a new workflow: sets up the VGCS repository, builds the
    /// DAG from the requested template, and dispatches the initial tasks.
    pub async fn handle_start_workflow(&self, cmd: StartWorkflowCommand) -> Result<()> {
        let req_id = cmd.request_id;
        info!("Starting workflow {}", req_id);
        self.publish_log(req_id, "workflow", "INFO", "Workflow started. Initializing...").await;

        // 1. Init VGCS repo
        self.state.vgcs.init_repo(&req_id.to_string())?;

        // 2. Create initial commit (context baseline)
        let mut tx = self.state.vgcs.begin_transaction(&req_id.to_string(), "")?;
        let request_info = json!({
            "request_id": req_id,
            "symbol": cmd.symbol.as_str(),
            "market": cmd.market,
            "template_id": cmd.template_id,
            "timestamp": chrono::Utc::now().to_rfc3339()
        });
        tx.write("request.json", serde_json::to_vec_pretty(&request_info)?.as_slice())?;
        let initial_commit = Box::new(tx).commit("Initial Workflow Setup", "System")?;

        // 3. Create scheduler seeded with the initial commit
        let mut dag = DagScheduler::new(req_id, initial_commit);

        // 4. Fetch template config
        self.publish_log(req_id, "workflow", "INFO", "Fetching template configuration...").await;
        let template_sets = self.state.persistence_client.get_analysis_template_sets().await?;
        let template = template_sets
            .get(&cmd.template_id)
            .ok_or_else(|| anyhow::anyhow!("Template {} not found", cmd.template_id))?;

        // 5. Fetch data sources config
        let data_sources = self.state.persistence_client.get_data_sources_config().await?;

        // 6. Build DAG
        self.build_dag(&mut dag, template, &data_sources, &cmd.template_id, &cmd.market, &cmd.symbol);

        // 7. Save state
        self.state.workflows.insert(req_id, Arc::new(Mutex::new(dag.clone())));

        // 8. Publish WorkflowStarted event
        let event = WorkflowEvent::WorkflowStarted {
            timestamp: chrono::Utc::now().timestamp_millis(),
            task_graph: dag.to_dto(),
        };
        let subject = common_contracts::subjects::NatsSubject::WorkflowProgress(req_id).to_string();
        if let Ok(payload) = serde_json::to_vec(&event) {
            if let Err(e) = self.nats.publish(subject, payload.into()).await {
                error!("Failed to publish WorkflowStarted event: {}", e);
            }
        }
        self.publish_log(req_id, "workflow", "INFO", "DAG built and workflow initialized.").await;
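        // For a minimal template the DAG built above looks like the sketch
        // below (illustrative only; actual node ids depend on the enabled
        // providers and the template's modules):
        //
        //   fetch:provider_a ──┐
        //   fetch:provider_b ──┼──> analysis:fundamental ──> analysis:summary
        //   fetch:mock       ──┘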
        // 9. Trigger initial tasks
        let initial_tasks = dag.get_initial_tasks();
        // Lock the DAG for updates (the entry was inserted just above, so the lookup cannot fail)
        let dag_arc = self.state.workflows.get(&req_id).unwrap().clone();
        let mut dag_guard = dag_arc.lock().await;
        for task_id in initial_tasks {
            self.dispatch_task(&mut dag_guard, &task_id, &self.state.vgcs).await?;
        }
        Ok(())
    }

    /// Replays the current workflow state as a snapshot event, e.g. for a
    /// client that (re)connects mid-run.
    pub async fn handle_sync_state(&self, cmd: SyncStateCommand) -> Result<()> {
        let req_id = cmd.request_id;
        info!("Handling SyncStateCommand for {}", req_id);
        let dag_arc = match self.state.workflows.get(&req_id) {
            Some(d) => d.clone(),
            None => {
                warn!("Received sync request for unknown workflow {}", req_id);
                return Ok(());
            }
        };
        let dag = dag_arc.lock().await;

        // Map internal status to DTO status
        let mut tasks_status = std::collections::HashMap::new();
        for task_id in dag.nodes.keys() {
            let status = match dag.get_status(task_id) {
                TaskStatus::Pending => MsgTaskStatus::Pending,
                TaskStatus::Scheduled => MsgTaskStatus::Scheduled,
                TaskStatus::Running => MsgTaskStatus::Running,
                TaskStatus::Completed => MsgTaskStatus::Completed,
                TaskStatus::Failed => MsgTaskStatus::Failed,
                TaskStatus::Skipped => MsgTaskStatus::Skipped,
                // Cancelled is internal only; clients see Skipped.
                TaskStatus::Cancelled => MsgTaskStatus::Skipped,
            };
            tasks_status.insert(task_id.clone(), status);
        }

        // Create snapshot event
        let event = WorkflowEvent::WorkflowStateSnapshot {
            timestamp: chrono::Utc::now().timestamp_millis(),
            task_graph: dag.to_dto(),
            tasks_status,
            tasks_output: dag
                .commit_tracker
                .task_commits
                .clone()
                .into_iter()
                .map(|(k, v)| (k, Some(v)))
                .collect(),
        };
        let subject = common_contracts::subjects::NatsSubject::WorkflowProgress(req_id).to_string();
        if let Ok(payload) = serde_json::to_vec(&event) {
            if let Err(e) = self.nats.publish(subject, payload.into()).await {
                error!("Failed to publish WorkflowStateSnapshot: {}", e);
            }
        }
        Ok(())
    }
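    // An illustrative payload for the snapshot published above (a sketch; the
    // exact layout depends on the serde attributes on WorkflowEvent, which are
    // defined in common_contracts, not here):
    //
    //   {
    //     "WorkflowStateSnapshot": {
    //       "timestamp": 1700000000000,
    //       "task_graph": { ... },
    //       "tasks_status": { "fetch:mock": "Completed", "analysis:summary": "Running" },
    //       "tasks_output": { "fetch:mock": "a1b2c3" }
    //     }
    //   }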
    /// Handles a task lifecycle event from a worker: records the result,
    /// dispatches newly unblocked downstream tasks, and detects completion.
    pub async fn handle_task_completed(&self, evt: WorkflowTaskEvent) -> Result<()> {
        let req_id = evt.request_id;
        let dag_arc = match self.state.workflows.get(&req_id) {
            Some(d) => d.clone(),
            None => {
                warn!("Received event for unknown workflow {}", req_id);
                return Ok(());
            }
        };
        let mut dag = dag_arc.lock().await;

        // 1. Update status & record commit
        dag.update_status(&evt.task_id, evt.status);
        self.publish_log(req_id, &evt.task_id, "INFO", &format!("Task status changed to {:?}", evt.status)).await;

        // Look up the task type
        let task_type = dag.nodes.get(&evt.task_id).map(|n| n.task_type).unwrap_or(TaskType::DataFetch);

        // Convert status
        let msg_status = match evt.status {
            TaskStatus::Pending => MsgTaskStatus::Pending,
            TaskStatus::Scheduled => MsgTaskStatus::Scheduled,
            TaskStatus::Running => MsgTaskStatus::Running,
            TaskStatus::Completed => MsgTaskStatus::Completed,
            TaskStatus::Failed => MsgTaskStatus::Failed,
            TaskStatus::Skipped => MsgTaskStatus::Skipped,
            TaskStatus::Cancelled => MsgTaskStatus::Skipped, // Map Cancelled to Skipped
        };

        // Extract the error message, if any
        let error_message = evt.result.as_ref().and_then(|r| r.error.clone());

        // Resolve commits
        let input_commit = dag.nodes.get(&evt.task_id).and_then(|n| n.input_commit.clone());
        let output_commit = evt.result.as_ref().and_then(|r| r.new_commit.clone());

        // Publish TaskStateChanged event
        let progress_event = WorkflowEvent::TaskStateChanged {
            task_id: evt.task_id.clone(),
            task_type,
            status: msg_status,
            message: error_message,
            timestamp: chrono::Utc::now().timestamp_millis(),
            progress: None,
            input_commit,
            output_commit,
        };
        let subject = common_contracts::subjects::NatsSubject::WorkflowProgress(req_id).to_string();
        if let Ok(payload) = serde_json::to_vec(&progress_event) {
            if let Err(e) = self.nats.publish(subject, payload.into()).await {
                error!("Failed to publish progress event: {}", e);
            }
        }

        if let Some(result) = evt.result {
            if let Some(commit) = result.new_commit {
                info!("Task {} produced commit {}", evt.task_id, commit);
                dag.record_result(&evt.task_id, Some(commit));
            }
            if let Some(err) = result.error {
                warn!("Task {} failed with error: {}", evt.task_id, err);
            }
        }

        // 2. Check for downstream tasks: any terminal status unblocks dependents
        if evt.status == TaskStatus::Completed
            || evt.status == TaskStatus::Failed
            || evt.status == TaskStatus::Skipped
        {
            let ready_tasks = dag.get_ready_downstream_tasks(&evt.task_id);
            for task_id in ready_tasks {
                if let Err(e) = self.dispatch_task(&mut dag, &task_id, &self.state.vgcs).await {
                    error!("Failed to dispatch task {}: {}", task_id, e);
                    self.publish_log(req_id, &task_id, "ERROR", &format!("Failed to dispatch task: {}", e)).await;
                }
            }
        }

        // 3. Check workflow completion
        if dag.is_workflow_finished() {
            let timestamp = chrono::Utc::now().timestamp_millis();
            let event = if dag.has_failures() {
                info!("Workflow {} failed (some tasks failed)", req_id);
                self.publish_log(req_id, "workflow", "ERROR", "Workflow finished with failures.").await;
                WorkflowEvent::WorkflowFailed {
                    end_timestamp: timestamp,
                    reason: "Some tasks failed".to_string(),
                    is_fatal: false,
                }
            } else {
                info!("Workflow {} completed successfully", req_id);
                self.publish_log(req_id, "workflow", "INFO", "Workflow completed successfully.").await;
                WorkflowEvent::WorkflowCompleted {
                    end_timestamp: timestamp,
                    result_summary: Some(json!({})),
                }
            };
            let subject = common_contracts::subjects::NatsSubject::WorkflowProgress(req_id).to_string();
            if let Ok(payload) = serde_json::to_vec(&event) {
                if let Err(e) = self.nats.publish(subject, payload.into()).await {
                    error!("Failed to publish completion event: {}", e);
                }
            }
        }
        Ok(())
    }
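    // Task status lifecycle as driven by this engine (a sketch; Running is
    // reported by the worker, not set here):
    //
    //   Pending ──dispatch──> Scheduled ──worker──> Running ──┬──> Completed
    //                                                         ├──> Failed
    //                                                         └──> Skipped
    //
    // Cancelled exists internally but is surfaced to clients as Skipped.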
    async fn dispatch_task(&self, dag: &mut DagScheduler, task_id: &str, vgcs: &Vgcs) -> Result<()> {
        // 1. Resolve context (merging if needed)
        let mut context = dag.resolve_context(task_id, vgcs)?;
        // Store the input commit in the node for observability
        if let Some(base_commit) = &context.base_commit {
            dag.set_input_commit(task_id, base_commit.clone());
        }

        // 2. Update status
        dag.update_status(task_id, TaskStatus::Scheduled);
        self.publish_log(dag.request_id, task_id, "INFO", "Task scheduled and dispatched.").await;

        // 3. Construct command
        let (routing_key, task_type, mut config) = {
            let node = dag.nodes.get(task_id).ok_or_else(|| anyhow::anyhow!("Node not found"))?;
            (node.routing_key.clone(), node.task_type, node.config.clone())
        };

        // --- Resolution phase ---
        let symbol = self.get_symbol_from_config(&config);

        // 3.1 IO binding
        let io_binder = IOBinder::new();
        let output_path = io_binder.allocate_output_path(task_type, &symbol, task_id);
        if let Some(obj) = config.as_object_mut() {
            obj.insert("output_path".to_string(), serde_json::Value::String(output_path.clone()));
        }

        // 3.2 Context resolution (Analysis tasks only)
        if task_type == TaskType::Analysis {
            // The base commit may advance below if we write a trace file.
            let mut current_base_commit = context.base_commit.clone().unwrap_or_default();
            if let Some(module_config_val) = config.get("_module_config") {
                if let Ok(module_config) = serde_json::from_value::<AnalysisModuleConfig>(module_config_val.clone()) {
                    let mut variables = std::collections::HashMap::new();
                    variables.insert("symbol".to_string(), symbol.clone());
                    if let Some(market) = config.get("market").and_then(|v| v.as_str()) {
                        variables.insert("market".to_string(), market.to_string());
                    }
                    let resolver = ContextResolver::new(self.state.vgcs.clone());
                    // Fetch LLM providers for context resolution (Hybrid/Auto modes)
                    let llm_providers = match self.state.persistence_client.get_llm_providers_config().await {
                        Ok(p) => p,
                        Err(e) => {
                            warn!("Failed to fetch LLM providers config (non-fatal for Manual mode): {}", e);
                            common_contracts::config_models::LlmProvidersConfig::default()
                        }
                    };
                    match resolver
                        .resolve_input(
                            &module_config.context_selector,
                            &dag.request_id.to_string(),
                            &current_base_commit,
                            &variables,
                            &llm_providers,
                            &module_config.analysis_prompt,
                        )
                        .await
                    {
                        Ok(resolution) => {
                            // 1. Inject input bindings
                            if let Some(obj) = config.as_object_mut() {
                                obj.insert("input_bindings".to_string(), serde_json::to_value(&resolution.paths)?);
                            }
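                            // The trace commit written below keeps the history
                            // linear (a sketch of the intended commit chain):
                            //
                            //   base ──[trace sidecar]──> base' ──[worker output]──> result
                            //
                            // The worker is handed base', so it sees the trace file.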
                            // 2. Write trace sidecar to VGCS
                            let trace_path = io_binder.allocate_trace_path(task_type, &symbol, task_id);
                            // Run the VGCS write/commit on a blocking task, since it does standard (synchronous) IO
                            let vgcs = self.state.vgcs.clone();
                            let req_id_str = dag.request_id.to_string();
                            let base_commit_for_write = current_base_commit.clone();
                            let trace_content = resolution.trace.clone();
                            let task_id_str = task_id.to_string();
                            let trace_path_clone = trace_path.clone();
                            let trace_commit_res = tokio::task::spawn_blocking(move || -> Result<String> {
                                let mut tx = vgcs.begin_transaction(&req_id_str, &base_commit_for_write)?;
                                tx.write(&trace_path_clone, trace_content.as_bytes())?;
                                let new_commit = Box::new(tx)
                                    .commit(&format!("Context Resolution Trace for {}", task_id_str), "Orchestrator")?;
                                Ok(new_commit)
                            })
                            .await;
                            match trace_commit_res {
                                Ok(Ok(new_commit)) => {
                                    info!("Written context resolution trace to {} (Commit: {})", trace_path, new_commit);
                                    // Advance the base commit so the worker sees the trace file (linear history)
                                    current_base_commit = new_commit;
                                    // Update the context passed to the worker
                                    context.base_commit = Some(current_base_commit.clone());
                                    // Also update the DAG node's input commit for observability
                                    // (the DAG lock is held by this function, so mutating here is safe)
                                    dag.set_input_commit(task_id, current_base_commit);
                                }
                                Ok(Err(e)) => error!("Failed to write trace file: {}", e),
                                Err(e) => error!("Failed to join trace write task: {}", e),
                            }
                        }
                        Err(e) => {
                            error!("Context resolution failed for task {}: {}", task_id, e);
                            // Proceed anyway; the worker may fail if it relies on the missing inputs.
                        }
                    }
                }
            }
        }

        let cmd = WorkflowTaskCommand {
            request_id: dag.request_id,
            task_id: task_id.to_string(),
            routing_key: routing_key.clone(),
            config, // the modified config, including output_path / input_bindings
            context,
            storage: StorageConfig {
                root_path: self.state.config.workflow_data_path.clone(),
            },
        };
        // Note: WorkflowTaskCommand is a generic envelope. Workers such as
        // report-generator-service are expected to consume it directly and
        // derive any task-specific command (e.g. GenerateReportCommand) from
        // the embedded config.

        // 4. Publish (the subject is derived from the routing_key)
        let subject = cmd.subject().to_string();
        let payload = serde_json::to_vec(&cmd)?;
        info!("Dispatching task {} to subject {}", task_id, subject);
        self.nats.publish(subject, payload.into()).await?;
        Ok(())
    }

    fn get_symbol_from_config(&self, config: &serde_json::Value) -> String {
        config.get("symbol").and_then(|v| v.as_str()).unwrap_or("unknown").to_string()
    }

    /// Builds the task DAG: one fetch node per enabled data source, plus one
    /// analysis node per template module, wired by the declared dependencies.
    fn build_dag(
        &self,
        dag: &mut DagScheduler,
        template: &common_contracts::config_models::AnalysisTemplateSet,
        data_sources: &common_contracts::config_models::DataSourcesConfig,
        template_id: &str,
        market: &str,
        symbol: &CanonicalSymbol,
    ) {
        // 1. Data fetch nodes
        let mut fetch_tasks = Vec::new();
        // Use all enabled data sources regardless of market; each provider
        // decides for itself whether to skip or process based on market support.
        // Keys are sorted to keep DAG generation deterministic.
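        // Illustrative data_sources entry (shape assumed from the fields read
        // below; the actual schema lives in common_contracts::config_models):
        //
        //   {
        //     "alphavantage": { "enabled": true,  "provider": "AlphaVantage", ... },
        //     "yfinance":     { "enabled": false, "provider": "YFinance", ... }
        //   }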
        let mut source_keys: Vec<&String> = data_sources.keys().collect();
        source_keys.sort();
        for key in source_keys {
            let config = &data_sources[key];
            if config.enabled {
                // Special handling for the MOCK market: skip real providers
                if market == "MOCK" && key.to_lowercase() != "mock" {
                    continue;
                }
                let provider_key = key.to_lowercase();
                let task_id = format!("fetch:{}", provider_key);
                fetch_tasks.push(task_id.clone());
                let display_name = format!("Data Fetch ({:?})", config.provider);
                let routing_key = format!("provider.{}", provider_key);
                dag.add_node(
                    task_id.clone(),
                    Some(display_name),
                    TaskType::DataFetch,
                    routing_key,
                    json!({ "symbol": symbol.as_str(), "market": market }),
                );
            }
        }

        // Fallback for MOCK when absent from config (mock is usually not in
        // data_sources.json but hardcoded for tests)
        if market == "MOCK" && fetch_tasks.is_empty() {
            let task_id = "fetch:mock".to_string();
            fetch_tasks.push(task_id.clone());
            dag.add_node(
                task_id,
                Some("Data Fetch (Mock)".to_string()),
                TaskType::DataFetch,
                "provider.mock".to_string(),
                json!({ "symbol": symbol.as_str(), "market": market }),
            );
        }
        if fetch_tasks.is_empty() {
            warn!("No enabled data providers found in configuration.");
        }

        // 2. Analysis nodes (dynamic, from the template)
        for (module_id, module_config) in &template.modules {
            let task_id = format!("analysis:{}", module_id);
            // Pass template_id and module_id so the worker knows what to run.
            // The worker currently re-fetches the template rather than relying
            // on an embedded module config; passing module_id also keeps
            // single-module execution possible.
            let mut node_config = json!({
                "template_id": template_id,
                "module_id": module_id,
                "symbol": symbol.as_str(),
                "market": market
            });
            // Embed the internal module config for Orchestrator-side use (context resolution)
            if let Some(obj) = node_config.as_object_mut() {
                obj.insert(
                    "_module_config".to_string(),
                    serde_json::to_value(module_config).unwrap_or(serde_json::Value::Null),
                );
            }
            dag.add_node(
                task_id.clone(),
                Some(module_config.name.clone()),
                TaskType::Analysis,
                "analysis.report".to_string(), // routing_key matches what report-generator consumes
                node_config,
            );
            // Dependencies
            if module_config.dependencies.is_empty() {
                // No analysis dependencies: depend on the data fetch nodes
                for fetch_task in &fetch_tasks {
                    dag.add_dependency(fetch_task, &task_id);
                }
            } else {
                // Depend on other analysis modules
                for dep_module_id in &module_config.dependencies {
                    let dep_task_id = format!("analysis:{}", dep_module_id);
                    dag.add_dependency(&dep_task_id, &task_id);
                }
            }
        }
    }

    async fn publish_log(&self, req_id: uuid::Uuid, task_id: &str, level: &str, message: &str) {
        let event = WorkflowEvent::TaskLog {
            task_id: task_id.to_string(),
            level: level.to_string(),
            message: message.to_string(),
            timestamp: chrono::Utc::now().timestamp_millis(),
        };
        let subject = common_contracts::subjects::NatsSubject::WorkflowProgress(req_id).to_string();
        if let Ok(payload) = serde_json::to_vec(&event) {
            // Fire and forget
            let nats = self.nats.clone();
            tokio::spawn(async move {
                if let Err(e) = nats.publish(subject, payload.into()).await {
                    error!("Failed to publish TaskLog: {}", e);
                }
            });
        }
    }
}
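#[cfg(test)]
mod tests {
    use super::*;

    // A minimal sketch of the dependency wiring build_dag produces, exercising
    // DagScheduler directly. Assumptions (not verified against dag_scheduler):
    // the initial commit argument is a String, and get_initial_tasks returns
    // the ids of dependency-free nodes as a Vec<String>.
    #[test]
    fn fetch_nodes_are_initial_tasks() {
        let mut dag = DagScheduler::new(uuid::Uuid::new_v4(), "initial-commit".to_string());
        dag.add_node(
            "fetch:mock".to_string(),
            Some("Data Fetch (Mock)".to_string()),
            TaskType::DataFetch,
            "provider.mock".to_string(),
            json!({ "symbol": "MOCK", "market": "MOCK" }),
        );
        dag.add_node(
            "analysis:summary".to_string(),
            Some("Summary".to_string()),
            TaskType::Analysis,
            "analysis.report".to_string(),
            json!({}),
        );
        dag.add_dependency("fetch:mock", "analysis:summary");

        // Only the dependency-free fetch node should be ready at start.
        let initial = dag.get_initial_tasks();
        assert!(initial.contains(&"fetch:mock".to_string()));
        assert!(!initial.contains(&"analysis:summary".to_string()));
    }
}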