feat: fix error propagation in report-generator and workflow-orchestrator
parent a59b994a92
commit a68a95338b
@@ -1,9 +1,10 @@
 # 1. Build Stage
-FROM rust:1.90 as builder
+FROM rust:1.90-bookworm as builder

 WORKDIR /usr/src/app
 # Copy necessary crates for compilation
 COPY ./services/common-contracts /usr/src/app/services/common-contracts
+COPY ./crates/workflow-context /usr/src/app/crates/workflow-context
 COPY ./services/workflow-orchestrator-service /usr/src/app/services/workflow-orchestrator-service

 WORKDIR /usr/src/app/services/workflow-orchestrator-service
@@ -18,7 +19,7 @@ ENV TZ=Asia/Shanghai
 RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone

 # Minimal runtime deps
-RUN apt-get update && apt-get install -y --no-install-recommends ca-certificates curl && rm -rf /var/lib/apt/lists/*
+RUN apt-get update && apt-get install -y --no-install-recommends ca-certificates curl libssl3 && rm -rf /var/lib/apt/lists/*

 # Copy the built binary
 COPY --from=builder /usr/src/app/services/workflow-orchestrator-service/target/debug/workflow-orchestrator-service /usr/local/bin/

@@ -2,15 +2,17 @@ use axum::{
     routing::get,
     Router,
     Json,
-    extract::State,
+    extract::{State, Path},
 };
 use std::sync::Arc;
 use crate::state::AppState;
 use serde_json::json;
+use uuid::Uuid;

 pub fn create_router(state: Arc<AppState>) -> Router {
     Router::new()
         .route("/health", get(health_check))
+        .route("/workflows/{id}/graph", get(get_workflow_graph))
         .with_state(state)
 }

@@ -21,3 +23,20 @@ async fn health_check(State(_state): State<Arc<AppState>>) -> Json<serde_json::V
     }))
 }

+async fn get_workflow_graph(
+    State(state): State<Arc<AppState>>,
+    Path(id): Path<Uuid>,
+) -> Json<serde_json::Value> {
+    if let Some(dag_arc) = state.workflows.get(&id) {
+        let dag = dag_arc.lock().await;
+        let dto = dag.to_dto();
+        Json(serde_json::to_value(dto).unwrap_or_else(|e| json!({
+            "error": format!("Serialization error: {}", e)
+        })))
+    } else {
+        Json(json!({
+            "error": "Workflow not found"
+        }))
+    }
+}
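
For illustration only, not part of this commit: a minimal client-side sketch of the new graph endpoint. The use of reqwest and the base URL are assumptions; the diff only defines the server-side route, which returns either the serialized DAG DTO or a JSON object carrying an "error" field.

// Hypothetical client sketch; `reqwest` (with its `json` feature) is an assumed dependency.
use uuid::Uuid;

async fn fetch_workflow_graph(base_url: &str, id: Uuid) -> anyhow::Result<serde_json::Value> {
    let url = format!("{}/workflows/{}/graph", base_url, id);
    let body = reqwest::get(&url).await?.json::<serde_json::Value>().await?;
    // The handler signals "unknown workflow" or serialization problems via an "error" field.
    if let Some(err) = body.get("error") {
        anyhow::bail!("server reported: {}", err);
    }
    Ok(body)
}
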
@@ -44,6 +44,37 @@ pub struct DagScheduler {
     pub commit_tracker: CommitTracker,
 }

+impl DagScheduler {
+    pub fn to_dto(&self) -> common_contracts::messages::WorkflowDag {
+        let edges = self.forward_deps.iter().flat_map(|(from, tos)| {
+            tos.iter().map(move |to| common_contracts::messages::TaskDependency {
+                from: from.clone(),
+                to: to.clone(),
+            })
+        }).collect();
+
+        let nodes = self.nodes.values().map(|n| common_contracts::messages::TaskNode {
+            id: n.id.clone(),
+            name: n.id.clone(), // Use ID as name for now, or add name field to DagNode
+            r#type: n.task_type,
+            initial_status: match n.status {
+                TaskStatus::Pending => common_contracts::messages::TaskStatus::Pending,
+                TaskStatus::Scheduled => common_contracts::messages::TaskStatus::Scheduled,
+                TaskStatus::Running => common_contracts::messages::TaskStatus::Running,
+                TaskStatus::Completed => common_contracts::messages::TaskStatus::Completed,
+                TaskStatus::Failed => common_contracts::messages::TaskStatus::Failed,
+                TaskStatus::Skipped => common_contracts::messages::TaskStatus::Skipped,
+                TaskStatus::Cancelled => common_contracts::messages::TaskStatus::Skipped,
+            },
+        }).collect();
+
+        common_contracts::messages::WorkflowDag {
+            nodes,
+            edges,
+        }
+    }
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct DagNode {
     pub id: String,
@@ -93,15 +124,36 @@ impl DagScheduler {
         }
     }

+    pub fn get_status(&self, task_id: &str) -> TaskStatus {
+        self.nodes.get(task_id).map(|n| n.status).unwrap_or(TaskStatus::Pending)
+    }
+
     pub fn record_result(&mut self, task_id: &str, new_commit: Option<String>) {
         if let Some(c) = new_commit {
             self.commit_tracker.record_commit(task_id, c);
         }
     }

+    /// Check if all tasks in the DAG have reached a terminal state.
+    pub fn is_workflow_finished(&self) -> bool {
+        self.nodes.values().all(|n| matches!(n.status,
+            TaskStatus::Completed |
+            TaskStatus::Failed |
+            TaskStatus::Skipped |
+            TaskStatus::Cancelled
+        ))
+    }
+
+    /// Check if any task has failed, indicating the workflow is partially or fully failed.
+    /// Note: Depending on requirements, some failures might be tolerant.
+    /// Here we assume any failure means the workflow has failed components.
+    pub fn has_failures(&self) -> bool {
+        self.nodes.values().any(|n| n.status == TaskStatus::Failed)
+    }
+
     /// Determine which tasks are ready to run given that `completed_task_id` just finished.
     pub fn get_ready_downstream_tasks(&self, completed_task_id: &str) -> Vec<String> {
-        let mut ready = Vec::new();
+        let mut ready: Vec<String> = Vec::new();
         if let Some(downstream) = self.forward_deps.get(completed_task_id) {
             for next_id in downstream {
                 if self.is_ready(next_id) {
@@ -125,8 +177,8 @@ impl DagScheduler {
         if let Some(deps) = self.reverse_deps.get(task_id) {
             for dep_id in deps {
                 match self.nodes.get(dep_id).map(|n| n.status) {
-                    Some(TaskStatus::Completed) => continue,
-                    _ => return false, // Dependency not completed
+                    Some(TaskStatus::Completed) | Some(TaskStatus::Failed) | Some(TaskStatus::Skipped) | Some(TaskStatus::Cancelled) => continue,
+                    _ => return false, // Dependency not finished
                 }
             }
         }
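
Taken together, the two new queries give the orchestrator a cheap way to decide the terminal event once every node is in a terminal state. A minimal sketch of that decision (the same one handle_task_completed makes further down), assuming the crate layout used by the new tests at the end of this commit:

use workflow_orchestrator_service::dag_scheduler::DagScheduler;

// Illustrative helper, not part of this commit.
fn terminal_outcome(dag: &DagScheduler) -> Option<&'static str> {
    if !dag.is_workflow_finished() {
        return None; // at least one task is still Pending/Scheduled/Running
    }
    Some(if dag.has_failures() { "WorkflowFailed" } else { "WorkflowCompleted" })
}
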
@@ -1,6 +1,5 @@
use anyhow::Result;
use tracing::info;
use tracing_subscriber::EnvFilter;
use std::sync::Arc;
use workflow_orchestrator_service::{config, state, message_consumer, api};

@@ -4,7 +4,7 @@ use anyhow::Result;
 use tracing::{info, error};
 use futures::StreamExt;
 use crate::state::AppState;
-use common_contracts::messages::StartWorkflowCommand;
+use common_contracts::messages::{StartWorkflowCommand, SyncStateCommand};
 use common_contracts::workflow_types::WorkflowTaskEvent;
 use common_contracts::subjects::NatsSubject;
 use crate::workflow::WorkflowEngine;
@@ -16,6 +19,9 @@ pub async fn run(state: Arc<AppState>, nats: Client) -> Result<()> {
     // Note: NatsSubject::WorkflowCommandStart string representation is "workflow.commands.start"
     let mut start_sub = nats.subscribe(NatsSubject::WorkflowCommandStart.to_string()).await?;

+    // Topic 1b: Workflow Commands (Sync State)
+    let mut sync_sub = nats.subscribe(NatsSubject::WorkflowCommandSyncState.to_string()).await?;
+
     // Topic 2: Workflow Task Events (Generic)
     // Note: NatsSubject::WorkflowEventTaskCompleted string representation is "workflow.evt.task_completed"
     let mut task_sub = nats.subscribe(NatsSubject::WorkflowEventTaskCompleted.to_string()).await?;
@@ -37,6 +40,21 @@ pub async fn run(state: Arc<AppState>, nats: Client) -> Result<()> {
         }
     });

+    // --- Task 1b: Sync State ---
+    let engine_sync = engine.clone();
+    tokio::spawn(async move {
+        while let Some(msg) = sync_sub.next().await {
+            if let Ok(cmd) = serde_json::from_slice::<SyncStateCommand>(&msg.payload) {
+                info!("Received SyncStateCommand: request_id={}", cmd.request_id);
+                if let Err(e) = engine_sync.handle_sync_state(cmd).await {
+                    error!("Failed to handle SyncStateCommand: {}", e);
+                }
+            } else {
+                error!("Failed to parse SyncStateCommand");
+            }
+        }
+    });
+
     // --- Task 2: Task Completed Events ---
     let engine2 = engine.clone();
     tokio::spawn(async move {
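
The consumer above expects some other service to publish a SyncStateCommand on the WorkflowCommandSyncState subject. A hedged sketch of that producer side, assuming request_id is the only field of SyncStateCommand (it is the only one this diff shows) and that the type derives Serialize:

use async_nats::Client;
use common_contracts::messages::SyncStateCommand;
use common_contracts::subjects::NatsSubject;
use uuid::Uuid;

// Illustrative producer, not part of this commit; adjust the struct literal to the
// real SyncStateCommand definition.
async fn request_state_sync(nats: &Client, request_id: Uuid) -> anyhow::Result<()> {
    let cmd = SyncStateCommand { request_id };
    let payload = serde_json::to_vec(&cmd)?;
    nats.publish(NatsSubject::WorkflowCommandSyncState.to_string(), payload.into()).await?;
    Ok(())
}
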
@@ -3,7 +3,7 @@ use common_contracts::workflow_types::{
     WorkflowTaskCommand, WorkflowTaskEvent, TaskStatus, StorageConfig
 };
 use common_contracts::messages::{
-    StartWorkflowCommand, TaskType
+    StartWorkflowCommand, SyncStateCommand, TaskType, WorkflowEvent, TaskStatus as MsgTaskStatus
 };
 use common_contracts::subjects::SubjectMessage;
 use common_contracts::symbol_utils::CanonicalSymbol;
@@ -37,14 +37,31 @@ impl WorkflowEngine {
         // Initial commit is empty for a fresh workflow
         let mut dag = DagScheduler::new(req_id, String::new());

-        // 3. Build DAG (Simplified Hardcoded logic for now, matching old build_dag)
-        // In a real scenario, we fetch Template from DB/Service using cmd.template_id
-        self.build_dag(&mut dag, &cmd.template_id, &cmd.market, &cmd.symbol);
+        // 3. Fetch Template Config
+        let template_sets = self.state.persistence_client.get_analysis_template_sets().await?;
+        let template = template_sets.get(&cmd.template_id).ok_or_else(|| {
+            anyhow::anyhow!("Template {} not found", cmd.template_id)
+        })?;

-        // 4. Save State
+        // 4. Build DAG
+        self.build_dag(&mut dag, template, &cmd.template_id, &cmd.market, &cmd.symbol);
+
+        // 5. Save State
         self.state.workflows.insert(req_id, Arc::new(Mutex::new(dag.clone())));

-        // 5. Trigger Initial Tasks
+        // 6. Publish WorkflowStarted Event
+        let event = WorkflowEvent::WorkflowStarted {
+            timestamp: chrono::Utc::now().timestamp_millis(),
+            task_graph: dag.to_dto(),
+        };
+        let subject = common_contracts::subjects::NatsSubject::WorkflowProgress(req_id).to_string();
+        if let Ok(payload) = serde_json::to_vec(&event) {
+            if let Err(e) = self.nats.publish(subject, payload.into()).await {
+                error!("Failed to publish WorkflowStarted event: {}", e);
+            }
+        }
+
+        // 6. Trigger Initial Tasks
         let initial_tasks = dag.get_initial_tasks();

         // Lock the DAG for updates
@@ -58,6 +75,52 @@ impl WorkflowEngine {
         Ok(())
     }

+    pub async fn handle_sync_state(&self, cmd: SyncStateCommand) -> Result<()> {
+        let req_id = cmd.request_id;
+        info!("Handling SyncStateCommand for {}", req_id);
+
+        let dag_arc = match self.state.workflows.get(&req_id) {
+            Some(d) => d.clone(),
+            None => {
+                warn!("Received sync request for unknown workflow {}", req_id);
+                return Ok(());
+            }
+        };
+
+        let dag = dag_arc.lock().await;
+
+        // Map internal status to DTO status
+        let mut tasks_status = std::collections::HashMap::new();
+        for task_id in dag.nodes.keys() {
+            let status = match dag.get_status(task_id) {
+                TaskStatus::Pending => MsgTaskStatus::Pending,
+                TaskStatus::Scheduled => MsgTaskStatus::Scheduled,
+                TaskStatus::Running => MsgTaskStatus::Running,
+                TaskStatus::Completed => MsgTaskStatus::Completed,
+                TaskStatus::Failed => MsgTaskStatus::Failed,
+                TaskStatus::Skipped => MsgTaskStatus::Skipped,
+                TaskStatus::Cancelled => MsgTaskStatus::Skipped,
+            };
+            tasks_status.insert(task_id.clone(), status);
+        }
+
+        // Create Snapshot Event
+        let event = WorkflowEvent::WorkflowStateSnapshot {
+            timestamp: chrono::Utc::now().timestamp_millis(),
+            task_graph: dag.to_dto(),
+            tasks_status,
+            tasks_output: std::collections::HashMap::new(), // TODO: Populate output if needed
+        };
+
+        let subject = common_contracts::subjects::NatsSubject::WorkflowProgress(req_id).to_string();
+        if let Ok(payload) = serde_json::to_vec(&event) {
+            if let Err(e) = self.nats.publish(subject, payload.into()).await {
+                error!("Failed to publish WorkflowStateSnapshot: {}", e);
+            }
+        }
+        Ok(())
+    }
+
     pub async fn handle_task_completed(&self, evt: WorkflowTaskEvent) -> Result<()> {
         let req_id = evt.request_id;

@@ -74,6 +137,43 @@ impl WorkflowEngine {
         // 1. Update Status & Record Commit
         dag.update_status(&evt.task_id, evt.status);

+        // Lookup task_type
+        let task_type = dag.nodes.get(&evt.task_id).map(|n| n.task_type).unwrap_or(TaskType::DataFetch);
+
+        // Convert status
+        let msg_status = match evt.status {
+            TaskStatus::Pending => MsgTaskStatus::Pending,
+            TaskStatus::Scheduled => MsgTaskStatus::Scheduled,
+            TaskStatus::Running => MsgTaskStatus::Running,
+            TaskStatus::Completed => MsgTaskStatus::Completed,
+            TaskStatus::Failed => MsgTaskStatus::Failed,
+            TaskStatus::Skipped => MsgTaskStatus::Skipped,
+            TaskStatus::Cancelled => MsgTaskStatus::Skipped, // Map Cancelled to Skipped
+        };
+
+        // Extract error message if any
+        let error_message = if let Some(ref result) = evt.result {
+            result.error.clone()
+        } else {
+            None
+        };
+
+        // Publish TaskStateChanged event
+        let progress_event = WorkflowEvent::TaskStateChanged {
+            task_id: evt.task_id.clone(),
+            task_type,
+            status: msg_status,
+            message: error_message,
+            timestamp: chrono::Utc::now().timestamp_millis(),
+            progress: None,
+        };
+        let subject = common_contracts::subjects::NatsSubject::WorkflowProgress(req_id).to_string();
+        if let Ok(payload) = serde_json::to_vec(&progress_event) {
+            if let Err(e) = self.nats.publish(subject, payload.into()).await {
+                error!("Failed to publish progress event: {}", e);
+            }
+        }
+
         if let Some(result) = evt.result {
             if let Some(commit) = result.new_commit {
                 info!("Task {} produced commit {}", evt.task_id, commit);
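
The TaskStatus to messages::TaskStatus mapping above now appears three times in this commit (to_dto, handle_sync_state, handle_task_completed). A small refactoring sketch, not part of the commit, that hoists it into one helper using only the variants the diff shows:

use common_contracts::messages::TaskStatus as MsgTaskStatus;
use common_contracts::workflow_types::TaskStatus;

// Suggested helper; mirrors the inline match arms used in this commit.
fn to_msg_status(status: TaskStatus) -> MsgTaskStatus {
    match status {
        TaskStatus::Pending => MsgTaskStatus::Pending,
        TaskStatus::Scheduled => MsgTaskStatus::Scheduled,
        TaskStatus::Running => MsgTaskStatus::Running,
        TaskStatus::Completed => MsgTaskStatus::Completed,
        TaskStatus::Failed => MsgTaskStatus::Failed,
        // Cancelled has no wire-level counterpart here, so it is reported as Skipped,
        // exactly as the inline mappings do.
        TaskStatus::Skipped | TaskStatus::Cancelled => MsgTaskStatus::Skipped,
    }
}
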
@@ -85,20 +185,40 @@ impl WorkflowEngine {
         }

         // 2. Check for downstream tasks
-        if evt.status == TaskStatus::Completed {
+        if evt.status == TaskStatus::Completed || evt.status == TaskStatus::Failed || evt.status == TaskStatus::Skipped {
             let ready_tasks = dag.get_ready_downstream_tasks(&evt.task_id);
             for task_id in ready_tasks {
                 if let Err(e) = self.dispatch_task(&mut dag, &task_id, &self.state.vgcs).await {
                     error!("Failed to dispatch task {}: {}", task_id, e);
                 }
             }
-        } else if evt.status == TaskStatus::Failed {
-            // Handle failure propagation (skip downstream?)
-            // For now, just log.
-            error!("Task {} failed. Workflow might stall.", evt.task_id);
         }

-        // 3. Check Workflow Completion (TODO)
+        // 3. Check Workflow Completion
+        if dag.is_workflow_finished() {
+            let timestamp = chrono::Utc::now().timestamp_millis();
+            let event = if dag.has_failures() {
+                info!("Workflow {} failed (some tasks failed)", req_id);
+                WorkflowEvent::WorkflowFailed {
+                    end_timestamp: timestamp,
+                    reason: "Some tasks failed".to_string(),
+                    is_fatal: false,
+                }
+            } else {
+                info!("Workflow {} completed successfully", req_id);
+                WorkflowEvent::WorkflowCompleted {
+                    end_timestamp: timestamp,
+                    result_summary: Some(json!({})),
+                }
+            };
+
+            let subject = common_contracts::subjects::NatsSubject::WorkflowProgress(req_id).to_string();
+            if let Ok(payload) = serde_json::to_vec(&event) {
+                if let Err(e) = self.nats.publish(subject, payload.into()).await {
+                    error!("Failed to publish completion event: {}", e);
+                }
+            }
+        }

         Ok(())
     }
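
Every state change now reaches read-side services through the WorkflowProgress(req_id) subject: WorkflowStarted, TaskStateChanged, WorkflowStateSnapshot, and finally WorkflowCompleted or WorkflowFailed. A sketch of a consumer (for example a gateway pushing updates to a UI), assuming WorkflowEvent and the status enum carry the usual Deserialize/Debug derives, which this diff does not show:

use async_nats::Client;
use common_contracts::messages::WorkflowEvent;
use common_contracts::subjects::NatsSubject;
use futures::StreamExt;
use tracing::info;
use uuid::Uuid;

// Illustrative consumer, not part of this commit.
async fn follow_progress(nats: Client, req_id: Uuid) -> anyhow::Result<()> {
    let mut sub = nats.subscribe(NatsSubject::WorkflowProgress(req_id).to_string()).await?;
    while let Some(msg) = sub.next().await {
        match serde_json::from_slice::<WorkflowEvent>(&msg.payload)? {
            WorkflowEvent::TaskStateChanged { task_id, status, .. } => {
                info!("task {} -> {:?}", task_id, status);
            }
            WorkflowEvent::WorkflowCompleted { .. } | WorkflowEvent::WorkflowFailed { .. } => break,
            _ => {} // WorkflowStarted / WorkflowStateSnapshot: refresh the rendered graph
        }
    }
    Ok(())
}
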
@@ -124,6 +244,15 @@ impl WorkflowEngine {
             },
         };

+        // Special handling for Analysis Report Task to inject task_id into the specific command payload
+        // (If the node config is used to build GenerateReportCommand downstream)
+        // Actually, WorkflowTaskCommand is generic. The specific worker (e.g. report-generator)
+        // usually consumes a specific command.
+        // BUT, the current architecture seems to have Orchestrator send `WorkflowTaskCommand`
+        // and the worker receives THAT?
+
+        // Let's check `report-generator-service` consumer.
+
         // 4. Publish
         let subject = cmd.subject().to_string(); // This uses the routing_key
         let payload = serde_json::to_vec(&cmd)?;
@@ -134,48 +263,67 @@ impl WorkflowEngine {
         Ok(())
     }

-    // Helper to build DAG (Migrated from old workflow.rs)
-    fn build_dag(&self, dag: &mut DagScheduler, template_id: &str, market: &str, symbol: &CanonicalSymbol) {
-        // Logic copied/adapted from old WorkflowStateMachine::build_dag
+    // Helper to build DAG
+    fn build_dag(&self, dag: &mut DagScheduler, template: &common_contracts::config_models::AnalysisTemplateSet, template_id: &str, market: &str, symbol: &CanonicalSymbol) {
         let mut providers = Vec::new();
         match market {
             "CN" => {
                 providers.push("tushare");
                 // providers.push("yfinance");
             },
             "US" => providers.push("yfinance"),
             "MOCK" => providers.push("mock"),
             _ => providers.push("yfinance"),
         }

         // 1. Data Fetch Nodes
+        let mut fetch_tasks = Vec::new();
         for p in &providers {
             let task_id = format!("fetch:{}", p);
+            fetch_tasks.push(task_id.clone());
             dag.add_node(
                 task_id.clone(),
                 TaskType::DataFetch,
-                format!("provider.{}", p), // routing_key: workflow.cmd.provider.tushare
+                format!("provider.{}", p),
                 json!({
-                    "symbol": symbol.as_str(), // Simplification
+                    "symbol": symbol.as_str(),
                     "market": market
                 })
             );
         }

-        // 2. Analysis Node (Simplified)
-        let report_task_id = "analysis:report";
-        dag.add_node(
-            report_task_id.to_string(),
-            TaskType::Analysis,
-            "analysis.report".to_string(), // routing_key: workflow.cmd.analysis.report
-            json!({
-                "template_id": template_id
-            })
-        );
+        // 2. Analysis Nodes (Dynamic from Template)
+        for (module_id, module_config) in &template.modules {
+            let task_id = format!("analysis:{}", module_id);

-        // 3. Edges
-        for p in &providers {
-            dag.add_dependency(&format!("fetch:{}", p), report_task_id);
+            // Pass module_id and template_id so the worker knows what to do
+            // We pass the FULL module config here if we want the worker to be stateless,
+            // BUT existing worker logic fetches template again.
+            // To support "Single Module Execution", we should probably pass the module_id.
+            dag.add_node(
+                task_id.clone(),
+                TaskType::Analysis,
+                "analysis.report".to_string(), // routing_key matches what report-generator consumes
+                json!({
+                    "template_id": template_id,
+                    "module_id": module_id,
+                    "symbol": symbol.as_str(),
+                    "market": market
+                })
+            );
+
+            // Dependencies
+            if module_config.dependencies.is_empty() {
+                // If no analysis dependencies, depend on Data Fetch
+                for fetch_task in &fetch_tasks {
+                    dag.add_dependency(fetch_task, &task_id);
+                }
+            } else {
+                // Depend on other analysis modules
+                for dep_module_id in &module_config.dependencies {
+                    let dep_task_id = format!("analysis:{}", dep_module_id);
+                    dag.add_dependency(&dep_task_id, &task_id);
+                }
+            }
+        }
     }
 }
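
A worked example of the DAG shape the new build_dag produces, assuming a CN workflow and a template with two modules, "fundamental" (no dependencies) and "summary" (depends on "fundamental"). The module names, template id and symbol are illustrative only; the sketch builds the same graph by hand with the calls the loop above makes:

use common_contracts::messages::TaskType;
use serde_json::json;
use uuid::Uuid;
use workflow_orchestrator_service::dag_scheduler::DagScheduler;

// Equivalent DAG for that hypothetical template, not part of this commit.
fn example_dag() -> DagScheduler {
    let mut dag = DagScheduler::new(Uuid::new_v4(), String::new());
    dag.add_node("fetch:tushare".to_string(), TaskType::DataFetch, "provider.tushare".into(),
        json!({"symbol": "600519", "market": "CN"}));
    dag.add_node("analysis:fundamental".to_string(), TaskType::Analysis, "analysis.report".into(),
        json!({"template_id": "default", "module_id": "fundamental", "symbol": "600519", "market": "CN"}));
    dag.add_node("analysis:summary".to_string(), TaskType::Analysis, "analysis.report".into(),
        json!({"template_id": "default", "module_id": "summary", "symbol": "600519", "market": "CN"}));
    dag.add_dependency("fetch:tushare", "analysis:fundamental"); // no module deps -> depends on data fetch
    dag.add_dependency("analysis:fundamental", "analysis:summary"); // declared module dependency
    dag
}
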
services/workflow-orchestrator-service/tests/logic_scenarios.rs (new file, 185 lines)
@@ -0,0 +1,185 @@
use anyhow::Result;
use tempfile::TempDir;
use workflow_context::{Vgcs, ContextStore, Transaction};
use common_contracts::messages::TaskType;
use common_contracts::workflow_types::TaskStatus;
use serde_json::json;
use uuid::Uuid;
use workflow_orchestrator_service::dag_scheduler::DagScheduler;

#[test]
fn test_scenario_a_happy_path() -> Result<()> {
    // Scenario A: Happy Path (A -> B)
    // 1. Setup
    let temp_dir = TempDir::new()?;
    let vgcs = Vgcs::new(temp_dir.path());
    let req_id = Uuid::new_v4();
    let req_id_str = req_id.to_string();
    vgcs.init_repo(&req_id_str)?;

    // Initial Commit
    let mut tx = vgcs.begin_transaction(&req_id_str, "")?;
    let init_commit = Box::new(tx).commit("Init", "system")?;

    // 2. Build DAG
    let mut dag = DagScheduler::new(req_id, init_commit.clone());
    dag.add_node("A".to_string(), TaskType::DataFetch, "key.a".into(), json!({}));
    dag.add_node("B".to_string(), TaskType::Analysis, "key.b".into(), json!({}));
    dag.add_dependency("A", "B");

    // 3. Run Task A
    // Dispatch A (In real engine: Resolve Context -> Send NATS)
    let ctx_a = dag.resolve_context("A", &vgcs)?;
    assert_eq!(ctx_a.base_commit.as_ref().unwrap(), &init_commit);

    // Execute A (Worker Logic)
    let mut tx = vgcs.begin_transaction(&req_id_str, &init_commit)?;
    tx.write("data_a.json", b"{\"val\": 1}")?;
    let commit_a = Box::new(tx).commit("Task A Result", "worker")?;

    // Complete A
    dag.record_result("A", Some(commit_a.clone()));
    dag.update_status("A", TaskStatus::Completed);

    // 4. Run Task B
    // Check Ready
    let ready = dag.get_ready_downstream_tasks("A");
    assert_eq!(ready, vec!["B"]);

    // Resolve Context B (Should be Commit A)
    let ctx_b = dag.resolve_context("B", &vgcs)?;
    assert_eq!(ctx_b.base_commit.as_ref().unwrap(), &commit_a);

    // Execute B
    let mut tx = vgcs.begin_transaction(&req_id_str, &commit_a)?;
    tx.write("report.md", b"# Report")?;
    let commit_b = Box::new(tx).commit("Task B Result", "worker")?;

    // Complete B
    dag.record_result("B", Some(commit_b.clone()));
    dag.update_status("B", TaskStatus::Completed);

    // 5. Verify Final State
    // Orchestrator would snapshot here. We check file existence.
    let files = vgcs.list_dir(&req_id_str, &commit_b, "")?;
    let names: Vec<String> = files.iter().map(|f| f.name.clone()).collect();
    assert!(names.contains(&"data_a.json".to_string()));
    assert!(names.contains(&"report.md".to_string()));

    Ok(())
}

#[test]
fn test_scenario_c_partial_failure() -> Result<()> {
    // Scenario C: Parallel Tasks (A, B) -> C. A fails.
    // 1. Setup
    let temp_dir = TempDir::new()?;
    let vgcs = Vgcs::new(temp_dir.path());
    let req_id = Uuid::new_v4();
    let req_id_str = req_id.to_string();
    vgcs.init_repo(&req_id_str)?;
    let mut tx = vgcs.begin_transaction(&req_id_str, "")?;
    let init_commit = Box::new(tx).commit("Init", "system")?;

    // 2. DAG: A, B independent. C depends on BOTH.
    let mut dag = DagScheduler::new(req_id, init_commit.clone());
    dag.add_node("A".to_string(), TaskType::DataFetch, "key.a".into(), json!({}));
    dag.add_node("B".to_string(), TaskType::DataFetch, "key.b".into(), json!({}));
    dag.add_node("C".to_string(), TaskType::Analysis, "key.c".into(), json!({}));
    dag.add_dependency("A", "C");
    dag.add_dependency("B", "C");

    // 3. Run A -> Failed
    dag.update_status("A", TaskStatus::Failed);
    // A produced no commit.

    // 4. Run B -> Success
    let mut tx = vgcs.begin_transaction(&req_id_str, &init_commit)?;
    tx.write("data_b.json", b"{}")?;
    let commit_b = Box::new(tx).commit("Task B", "worker")?;
    dag.record_result("B", Some(commit_b.clone()));
    dag.update_status("B", TaskStatus::Completed);

    // 5. Check C
    // C should NOT be ready because A is failed (not Completed).
    // is_ready checks: reverse_deps.all(|d| status == Completed)
    // A is Failed.

    // Triggering readiness check from B completion
    let ready_from_b = dag.get_ready_downstream_tasks("B");
    // C is downstream of B, but is_ready("C") should be false
    assert!(ready_from_b.is_empty());

    // Triggering readiness check from A completion (Failed)
    // Orchestrator logic for failure usually doesn't trigger downstream positive flow.

    assert_eq!(dag.nodes["C"].status, TaskStatus::Pending);

    Ok(())
}

#[test]
fn test_scenario_e_module_logic_check() -> Result<()> {
    // Scenario E: Parallel Branch Merge
    // A -> B
    // A -> C
    // B, C -> D
    // Verify 3-way merge logic in D

    let temp_dir = TempDir::new()?;
    let vgcs = Vgcs::new(temp_dir.path());
    let req_id = Uuid::new_v4();
    let req_id_str = req_id.to_string();
    vgcs.init_repo(&req_id_str)?;
    let mut tx = vgcs.begin_transaction(&req_id_str, "")?;
    let init_commit = Box::new(tx).commit("Init", "system")?;

    let mut dag = DagScheduler::new(req_id, init_commit.clone());
    dag.add_node("A".to_string(), TaskType::DataFetch, "key.a".into(), json!({}));
    dag.add_node("B".to_string(), TaskType::Analysis, "key.b".into(), json!({}));
    dag.add_node("C".to_string(), TaskType::Analysis, "key.c".into(), json!({}));
    dag.add_node("D".to_string(), TaskType::Analysis, "key.d".into(), json!({}));

    dag.add_dependency("A", "B");
    dag.add_dependency("A", "C");
    dag.add_dependency("B", "D");
    dag.add_dependency("C", "D");

    // Run A
    let mut tx = vgcs.begin_transaction(&req_id_str, &init_commit)?;
    tx.write("common.json", b"base")?;
    let commit_a = Box::new(tx).commit("A", "worker")?;
    dag.record_result("A", Some(commit_a.clone()));
    dag.update_status("A", TaskStatus::Completed);

    // Run B (Modify common, add b)
    let ctx_b = dag.resolve_context("B", &vgcs)?;
    let mut tx = vgcs.begin_transaction(&req_id_str, ctx_b.base_commit.as_ref().unwrap())?;
    tx.write("file_b.txt", b"B")?;
    let commit_b = Box::new(tx).commit("B", "worker")?;
    dag.record_result("B", Some(commit_b.clone()));
    dag.update_status("B", TaskStatus::Completed);

    // Run C (Modify common, add c)
    let ctx_c = dag.resolve_context("C", &vgcs)?;
    let mut tx = vgcs.begin_transaction(&req_id_str, ctx_c.base_commit.as_ref().unwrap())?;
    tx.write("file_c.txt", b"C")?;
    let commit_c = Box::new(tx).commit("C", "worker")?;
    dag.record_result("C", Some(commit_c.clone()));
    dag.update_status("C", TaskStatus::Completed);

    // Run D (Should Merge B and C)
    let ctx_d = dag.resolve_context("D", &vgcs)?;
    let merge_commit = ctx_d.base_commit.unwrap();

    // Verify Merge
    let files = vgcs.list_dir(&req_id_str, &merge_commit, "")?;
    let names: Vec<String> = files.iter().map(|f| f.name.clone()).collect();

    assert!(names.contains(&"common.json".to_string()));
    assert!(names.contains(&"file_b.txt".to_string()));
    assert!(names.contains(&"file_c.txt".to_string()));

    Ok(())
}