use std::collections::HashMap; use uuid::Uuid; use common_contracts::workflow_types::{TaskStatus, TaskContext}; use common_contracts::messages::TaskType; use workflow_context::{Vgcs, ContextStore}; use anyhow::Result; use tracing::info; use serde::{Serialize, Deserialize}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CommitTracker { /// Maps task_id to the commit hash it produced. pub task_commits: HashMap, /// The latest merged commit for the whole workflow (if linear). /// Or just a reference to the "main" branch tip. pub head_commit: String, } impl CommitTracker { pub fn new(initial_commit: String) -> Self { Self { task_commits: HashMap::new(), head_commit: initial_commit, } } pub fn record_commit(&mut self, task_id: &str, commit: String) { self.task_commits.insert(task_id.to_string(), commit.clone()); // Note: head_commit update strategy depends on whether we want to track // a single "main" branch or just use task_commits for DAG resolution. // For now, we don't eagerly update head_commit unless it's a final task. } } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DagScheduler { pub request_id: Uuid, pub nodes: HashMap, /// TaskID -> List of downstream TaskIDs pub forward_deps: HashMap>, /// TaskID -> List of upstream TaskIDs pub reverse_deps: HashMap>, pub commit_tracker: CommitTracker, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DagNode { pub id: String, pub task_type: TaskType, // Kept for UI/Observability, not for logic pub status: TaskStatus, pub config: serde_json::Value, pub routing_key: String, } impl DagScheduler { pub fn new(request_id: Uuid, initial_commit: String) -> Self { Self { request_id, nodes: HashMap::new(), forward_deps: HashMap::new(), reverse_deps: HashMap::new(), commit_tracker: CommitTracker::new(initial_commit), } } pub fn add_node(&mut self, id: String, task_type: TaskType, routing_key: String, config: serde_json::Value) { self.nodes.insert(id.clone(), DagNode { id, task_type, status: TaskStatus::Pending, config, routing_key, }); } pub fn add_dependency(&mut self, from: &str, to: &str) { self.forward_deps.entry(from.to_string()).or_default().push(to.to_string()); self.reverse_deps.entry(to.to_string()).or_default().push(from.to_string()); } /// Get all tasks that have no dependencies (roots) pub fn get_initial_tasks(&self) -> Vec { self.nodes.values() .filter(|n| n.status == TaskStatus::Pending && self.reverse_deps.get(&n.id).map_or(true, |deps| deps.is_empty())) .map(|n| n.id.clone()) .collect() } pub fn update_status(&mut self, task_id: &str, status: TaskStatus) { if let Some(node) = self.nodes.get_mut(task_id) { node.status = status; } } pub fn record_result(&mut self, task_id: &str, new_commit: Option) { if let Some(c) = new_commit { self.commit_tracker.record_commit(task_id, c); } } /// Determine which tasks are ready to run given that `completed_task_id` just finished. pub fn get_ready_downstream_tasks(&self, completed_task_id: &str) -> Vec { let mut ready = Vec::new(); if let Some(downstream) = self.forward_deps.get(completed_task_id) { for next_id in downstream { if self.is_ready(next_id) { ready.push(next_id.clone()); } } } ready } fn is_ready(&self, task_id: &str) -> bool { let node = match self.nodes.get(task_id) { Some(n) => n, None => return false, }; if node.status != TaskStatus::Pending { return false; } if let Some(deps) = self.reverse_deps.get(task_id) { for dep_id in deps { match self.nodes.get(dep_id).map(|n| n.status) { Some(TaskStatus::Completed) => continue, _ => return false, // Dependency not completed } } } true } /// Resolve the context (Base Commit) for a task. /// If multiple dependencies, perform Merge or Fast-Forward. pub fn resolve_context(&self, task_id: &str, vgcs: &Vgcs) -> Result { let deps = self.reverse_deps.get(task_id).cloned().unwrap_or_default(); if deps.is_empty() { // Root task: Use initial commit (usually empty string or base snapshot) return Ok(TaskContext { base_commit: Some(self.commit_tracker.head_commit.clone()), mount_path: None, }); } // Collect parent commits let mut parent_commits = Vec::new(); for dep_id in &deps { if let Some(c) = self.commit_tracker.task_commits.get(dep_id) { if !c.is_empty() { parent_commits.push(c.clone()); } } } if parent_commits.is_empty() { // All parents produced no commit? Fallback to head or empty. return Ok(TaskContext { base_commit: Some(self.commit_tracker.head_commit.clone()), mount_path: None, }); } // Merge Strategy let final_commit = self.merge_commits(vgcs, parent_commits)?; Ok(TaskContext { base_commit: Some(final_commit), mount_path: None, // Or determine based on config }) } /// Merge logic: /// 1 parent -> Return it. /// 2+ parents -> Iteratively merge using smart merge_commits fn merge_commits(&self, vgcs: &Vgcs, commits: Vec) -> Result { if commits.is_empty() { return Ok(String::new()); } if commits.len() == 1 { return Ok(commits[0].clone()); } let mut current_head = commits[0].clone(); for i in 1..commits.len() { let next_commit = &commits[i]; if current_head == *next_commit { continue; } // Use the smart merge_commits which finds the common ancestor automatically // Note: This handles Fast-Forward implicitly (merge_commits checks for ancestry) info!("Merging commits: Ours={}, Theirs={}", current_head, next_commit); current_head = vgcs.merge_commits(&self.request_id.to_string(), ¤t_head, next_commit)?; } Ok(current_head) } } #[cfg(test)] mod tests { use super::*; use tempfile::TempDir; use workflow_context::{Vgcs, ContextStore, Transaction}; use common_contracts::messages::TaskType; use serde_json::json; #[test] fn test_dag_merge_strategy() -> Result<()> { let temp_dir = TempDir::new()?; let vgcs = Vgcs::new(temp_dir.path()); let req_id = Uuid::new_v4(); let req_id_str = req_id.to_string(); vgcs.init_repo(&req_id_str)?; // 0. Create Initial Commit (Common Ancestor) let mut tx = vgcs.begin_transaction(&req_id_str, "")?; let init_commit = Box::new(tx).commit("Initial Commit", "system")?; // 1. Setup DAG let mut dag = DagScheduler::new(req_id, init_commit.clone()); dag.add_node("A".to_string(), TaskType::DataFetch, "key.a".into(), json!({})); dag.add_node("B".to_string(), TaskType::DataFetch, "key.b".into(), json!({})); dag.add_node("C".to_string(), TaskType::Analysis, "key.c".into(), json!({})); // C depends on A and B dag.add_dependency("A", "C"); dag.add_dependency("B", "C"); // 2. Simulate Task A Execution -> Commit A (Based on Init) let mut tx = vgcs.begin_transaction(&req_id_str, &init_commit)?; tx.write("file_a.txt", b"Content A")?; let commit_a = Box::new(tx).commit("Task A", "worker")?; dag.record_result("A", Some(commit_a.clone())); dag.update_status("A", TaskStatus::Completed); // 3. Simulate Task B Execution -> Commit B (Based on Init) let mut tx = vgcs.begin_transaction(&req_id_str, &init_commit)?; tx.write("file_b.txt", b"Content B")?; let commit_b = Box::new(tx).commit("Task B", "worker")?; dag.record_result("B", Some(commit_b.clone())); dag.update_status("B", TaskStatus::Completed); // 4. Resolve Context for C // Should merge A and B let ctx = dag.resolve_context("C", &vgcs)?; let merged_commit = ctx.base_commit.expect("Should have base commit"); // Verify merged content let files = vgcs.list_dir(&req_id_str, &merged_commit, "")?; let file_names: Vec = files.iter().map(|f| f.name.clone()).collect(); assert!(file_names.contains(&"file_a.txt".to_string())); assert!(file_names.contains(&"file_b.txt".to_string())); Ok(()) } }