Fundamental_Analysis/services/workflow-orchestrator-service/src/dag_scheduler.rs
Lv, Qi ca1eddd244 Refactor Orchestrator to DAG-based scheduler with Git Context merging
- Implemented `DagScheduler` in `workflow-orchestrator` to manage task dependencies and commit history.
- Added `vgcs.merge_commits` in `workflow-context` for smart 3-way merge of parallel task branches.
- Introduced generic `WorkflowTaskCommand` and `WorkflowTaskEvent` in `common-contracts`.
- Adapted `tushare-provider-service` with `generic_worker` to support Git-based context read/write.
- Updated NATS subjects to support wildcard routing for generic workflow commands.
2025-11-27 02:42:25 +08:00

265 lines
9.2 KiB
Rust

use std::collections::HashMap;
use uuid::Uuid;
use common_contracts::workflow_types::{TaskStatus, TaskContext};
use common_contracts::messages::TaskType;
use workflow_context::{Vgcs, ContextStore};
use anyhow::Result;
use tracing::info;
use serde::{Serialize, Deserialize};
/// Bookkeeping of the Git commits produced by individual workflow tasks.
///
/// Owned by [`DagScheduler`]; `resolve_context` reads `task_commits` to build
/// the base commit for downstream tasks, and falls back to `head_commit`
/// when a task has no committing parents.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommitTracker {
    /// Maps task_id to the commit hash it produced.
    pub task_commits: HashMap<String, String>,
    /// The latest merged commit for the whole workflow (if linear).
    /// Or just a reference to the "main" branch tip.
    pub head_commit: String,
}
impl CommitTracker {
    /// Create a tracker whose head points at the workflow's initial commit
    /// (typically an empty string or a base snapshot — see `resolve_context`).
    pub fn new(initial_commit: String) -> Self {
        Self {
            task_commits: HashMap::new(),
            head_commit: initial_commit,
        }
    }

    /// Record the commit hash produced by `task_id`.
    // Note: head_commit update strategy depends on whether we want to track
    // a single "main" branch or just use task_commits for DAG resolution.
    // For now, we don't eagerly update head_commit unless it's a final task.
    pub fn record_commit(&mut self, task_id: &str, commit: String) {
        // `commit` is already owned; move it into the map directly instead of
        // cloning it first (the original `.clone()` was a redundant allocation).
        self.task_commits.insert(task_id.to_string(), commit);
    }
}
/// Dependency-graph scheduler for a single workflow request.
///
/// Holds the task nodes, both directions of the dependency edges, and the
/// commit bookkeeping used to resolve each task's Git context.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DagScheduler {
    /// Workflow request this DAG belongs to; its string form is also used as
    /// the VGCS repository id (see `merge_commits`).
    pub request_id: Uuid,
    /// All task nodes, keyed by task id.
    pub nodes: HashMap<String, DagNode>,
    /// TaskID -> List of downstream TaskIDs
    pub forward_deps: HashMap<String, Vec<String>>,
    /// TaskID -> List of upstream TaskIDs
    pub reverse_deps: HashMap<String, Vec<String>>,
    /// Commit hashes produced by completed tasks, plus the workflow head.
    pub commit_tracker: CommitTracker,
}
/// A single task node in the workflow DAG.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DagNode {
    /// Unique task identifier (key into `DagScheduler::nodes`).
    pub id: String,
    pub task_type: TaskType, // Kept for UI/Observability, not for logic
    /// Current lifecycle state; scheduling only reacts to Pending/Completed.
    pub status: TaskStatus,
    /// Opaque per-task configuration forwarded to the worker.
    pub config: serde_json::Value,
    // NOTE(review): presumably the NATS subject used to route the task's
    // command to a worker — confirm against the orchestrator dispatch code.
    pub routing_key: String,
}
impl DagScheduler {
    /// Create an empty scheduler for `request_id`, rooted at `initial_commit`.
    pub fn new(request_id: Uuid, initial_commit: String) -> Self {
        Self {
            request_id,
            nodes: HashMap::new(),
            forward_deps: HashMap::new(),
            reverse_deps: HashMap::new(),
            commit_tracker: CommitTracker::new(initial_commit),
        }
    }

    /// Register a task node in `Pending` state. Re-using an existing `id`
    /// replaces the previous node.
    pub fn add_node(&mut self, id: String, task_type: TaskType, routing_key: String, config: serde_json::Value) {
        self.nodes.insert(
            id.clone(),
            DagNode {
                id,
                task_type,
                status: TaskStatus::Pending,
                config,
                routing_key,
            },
        );
    }

    /// Declare an edge `from -> to` (i.e. `to` depends on `from`).
    ///
    /// Duplicate registrations of the same edge are ignored: previously they
    /// produced duplicate entries in both adjacency lists, which made
    /// `get_ready_downstream_tasks` report — and therefore dispatch — the
    /// same downstream task more than once.
    pub fn add_dependency(&mut self, from: &str, to: &str) {
        let fwd = self.forward_deps.entry(from.to_string()).or_default();
        if !fwd.iter().any(|t| t == to) {
            fwd.push(to.to_string());
        }
        let rev = self.reverse_deps.entry(to.to_string()).or_default();
        if !rev.iter().any(|f| f == from) {
            rev.push(from.to_string());
        }
    }

    /// Get all tasks that have no dependencies (roots)
    pub fn get_initial_tasks(&self) -> Vec<String> {
        self.nodes
            .values()
            .filter(|n| {
                n.status == TaskStatus::Pending
                    && self.reverse_deps.get(&n.id).map_or(true, |deps| deps.is_empty())
            })
            .map(|n| n.id.clone())
            .collect()
    }

    /// Set the status of `task_id`; unknown ids are silently ignored.
    pub fn update_status(&mut self, task_id: &str, status: TaskStatus) {
        if let Some(node) = self.nodes.get_mut(task_id) {
            node.status = status;
        }
    }

    /// Record the commit a finished task produced, if any. Tasks that
    /// produced no commit (`None`) leave the tracker untouched.
    pub fn record_result(&mut self, task_id: &str, new_commit: Option<String>) {
        if let Some(c) = new_commit {
            self.commit_tracker.record_commit(task_id, c);
        }
    }

    /// Determine which tasks are ready to run given that `completed_task_id` just finished.
    pub fn get_ready_downstream_tasks(&self, completed_task_id: &str) -> Vec<String> {
        self.forward_deps
            .get(completed_task_id)
            .map(|downstream| {
                downstream
                    .iter()
                    .filter(|next_id| self.is_ready(next_id))
                    .cloned()
                    .collect()
            })
            .unwrap_or_default()
    }

    /// A task is ready iff it exists, is still `Pending`, and every upstream
    /// dependency is `Completed`.
    fn is_ready(&self, task_id: &str) -> bool {
        let node = match self.nodes.get(task_id) {
            Some(n) => n,
            None => return false,
        };
        if node.status != TaskStatus::Pending {
            return false;
        }
        // A missing reverse-deps entry means the task has no dependencies.
        self.reverse_deps.get(task_id).map_or(true, |deps| {
            deps.iter().all(|dep_id| {
                matches!(
                    self.nodes.get(dep_id).map(|n| n.status),
                    Some(TaskStatus::Completed)
                )
            })
        })
    }

    /// Resolve the context (Base Commit) for a task.
    /// If multiple dependencies, perform Merge or Fast-Forward.
    pub fn resolve_context(&self, task_id: &str, vgcs: &Vgcs) -> Result<TaskContext> {
        // Borrow the dependency list instead of cloning it: the original
        // `.cloned()` allocated a fresh Vec on every resolution.
        let deps: &[String] = self
            .reverse_deps
            .get(task_id)
            .map_or(&[][..], |v| v.as_slice());
        if deps.is_empty() {
            // Root task: Use initial commit (usually empty string or base snapshot)
            return Ok(TaskContext {
                base_commit: Some(self.commit_tracker.head_commit.clone()),
                mount_path: None,
            });
        }

        // Collect the non-empty commits produced by the parents.
        let parent_commits: Vec<String> = deps
            .iter()
            .filter_map(|dep_id| self.commit_tracker.task_commits.get(dep_id))
            .filter(|c| !c.is_empty())
            .cloned()
            .collect();

        if parent_commits.is_empty() {
            // All parents produced no commit? Fallback to head or empty.
            return Ok(TaskContext {
                base_commit: Some(self.commit_tracker.head_commit.clone()),
                mount_path: None,
            });
        }

        // Merge Strategy
        let final_commit = self.merge_commits(vgcs, parent_commits)?;
        Ok(TaskContext {
            base_commit: Some(final_commit),
            mount_path: None, // Or determine based on config
        })
    }

    /// Merge logic:
    /// * no parents  -> empty string,
    /// * 1 parent    -> returned as-is,
    /// * 2+ parents  -> iteratively merged via the smart `merge_commits`,
    ///   which finds the common ancestor automatically and handles
    ///   fast-forwards implicitly.
    fn merge_commits(&self, vgcs: &Vgcs, commits: Vec<String>) -> Result<String> {
        let mut iter = commits.into_iter();
        let mut current_head = match iter.next() {
            Some(first) => first,
            None => return Ok(String::new()),
        };
        // Hoisted out of the loop: the repo id is loop-invariant, and the
        // original re-stringified the Uuid on every merge iteration.
        let repo_id = self.request_id.to_string();
        for next_commit in iter {
            // Identical commits need no merge (e.g. diamond over one parent).
            if current_head == next_commit {
                continue;
            }
            info!("Merging commits: Ours={}, Theirs={}", current_head, next_commit);
            current_head = vgcs.merge_commits(&repo_id, &current_head, &next_commit)?;
        }
        Ok(current_head)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;
    use workflow_context::{Vgcs, ContextStore, Transaction};
    use common_contracts::messages::TaskType;
    use serde_json::json;

    /// Two parallel tasks (A, B) branch off the same initial commit; the
    /// downstream task C must resolve to a merged context containing both
    /// task outputs.
    #[test]
    fn test_dag_merge_strategy() -> Result<()> {
        let workdir = TempDir::new()?;
        let store = Vgcs::new(workdir.path());
        let request_id = Uuid::new_v4();
        let repo = request_id.to_string();
        store.init_repo(&repo)?;

        // 0. Common ancestor for both task branches.
        let tx = store.begin_transaction(&repo, "")?;
        let base_commit = Box::new(tx).commit("Initial Commit", "system")?;

        // 1. Build the DAG: C depends on both A and B.
        let mut scheduler = DagScheduler::new(request_id, base_commit.clone());
        scheduler.add_node("A".to_string(), TaskType::DataFetch, "key.a".into(), json!({}));
        scheduler.add_node("B".to_string(), TaskType::DataFetch, "key.b".into(), json!({}));
        scheduler.add_node("C".to_string(), TaskType::Analysis, "key.c".into(), json!({}));
        scheduler.add_dependency("A", "C");
        scheduler.add_dependency("B", "C");

        // 2. Task A commits its output on top of the base commit.
        let mut tx = store.begin_transaction(&repo, &base_commit)?;
        tx.write("file_a.txt", b"Content A")?;
        let commit_a = Box::new(tx).commit("Task A", "worker")?;
        scheduler.record_result("A", Some(commit_a));
        scheduler.update_status("A", TaskStatus::Completed);

        // 3. Task B does the same, also based on the base commit (parallel).
        let mut tx = store.begin_transaction(&repo, &base_commit)?;
        tx.write("file_b.txt", b"Content B")?;
        let commit_b = Box::new(tx).commit("Task B", "worker")?;
        scheduler.record_result("B", Some(commit_b));
        scheduler.update_status("B", TaskStatus::Completed);

        // 4. Resolving C's context must merge the two parent commits.
        let ctx = scheduler.resolve_context("C", &store)?;
        let merged = ctx.base_commit.expect("Should have base commit");

        // Both task outputs must be visible in the merged tree.
        let entries = store.list_dir(&repo, &merged, "")?;
        let names: Vec<String> = entries.iter().map(|e| e.name.clone()).collect();
        assert!(names.contains(&"file_a.txt".to_string()));
        assert!(names.contains(&"file_b.txt".to_string()));
        Ok(())
    }
}