diff --git a/crates/workflow-context/Cargo.toml b/crates/workflow-context/Cargo.toml index dd39572..f2522ff 100644 --- a/crates/workflow-context/Cargo.toml +++ b/crates/workflow-context/Cargo.toml @@ -3,6 +3,10 @@ name = "workflow-context" version = "0.1.0" edition = "2024" +[lib] +name = "workflow_context" +path = "src/lib.rs" + [dependencies] git2 = { version = "0.18", features = ["vendored-openssl"] } sha2 = "0.10" @@ -13,7 +17,7 @@ thiserror = "1.0" hex = "0.4" walkdir = "2.3" regex = "1.10" -globset = "0.4.18" +globset = "0.4" [dev-dependencies] tempfile = "3.8" diff --git a/crates/workflow-context/src/traits.rs b/crates/workflow-context/src/traits.rs index 4be2f25..dcd97bc 100644 --- a/crates/workflow-context/src/traits.rs +++ b/crates/workflow-context/src/traits.rs @@ -17,6 +17,10 @@ pub trait ContextStore { /// Three-way merge (In-Memory), returns new Tree OID fn merge_trees(&self, req_id: &str, base: &str, ours: &str, theirs: &str) -> Result<String>; + + /// Smart merge two commits, automatically finding the best common ancestor. + /// Returns the OID of the new merge commit. + fn merge_commits(&self, req_id: &str, our_commit: &str, their_commit: &str) -> Result<String>; /// Start a write transaction fn begin_transaction(&self, req_id: &str, base_commit: &str) -> Result<Box<dyn Transaction>>; diff --git a/crates/workflow-context/src/vgcs.rs b/crates/workflow-context/src/vgcs.rs index ee9098e..4d19fae 100644 --- a/crates/workflow-context/src/vgcs.rs +++ b/crates/workflow-context/src/vgcs.rs @@ -150,23 +150,72 @@ impl ContextStore for Vgcs { Ok(oid.to_string()) } + fn merge_commits(&self, req_id: &str, our_commit: &str, their_commit: &str) -> Result<String> { + let repo_path = self.get_repo_path(req_id); + let repo = Repository::open(&repo_path).context("Failed to open repo")?; + + let our_oid = Oid::from_str(our_commit).context("Invalid our_commit")?; + let their_oid = Oid::from_str(their_commit).context("Invalid their_commit")?; + + let base_oid = repo.merge_base(our_oid, their_oid).context("Failed to find merge base")?; + + let base_commit = repo.find_commit(base_oid)?; + let our_commit_obj = repo.find_commit(our_oid)?; + let their_commit_obj = repo.find_commit(their_oid)?; + + // If base equals one of the commits, it's a fast-forward + if base_oid == our_oid { + return Ok(their_commit.to_string()); + } + if base_oid == their_oid { + return Ok(our_commit.to_string()); + } + + let base_tree = base_commit.tree()?; + let our_tree = our_commit_obj.tree()?; + let their_tree = their_commit_obj.tree()?; + + let mut index = repo.merge_trees(&base_tree, &our_tree, &their_tree, None)?; + + if index.has_conflicts() { + return Err(anyhow!("Merge conflict detected between {} and {}", our_commit, their_commit)); + } + + let tree_oid = index.write_tree_to(&repo)?; + let tree = repo.find_tree(tree_oid)?; + + let sig = Signature::now("vgcs-merge", "system")?; + + let merge_commit_oid = repo.commit( + None, // Detached + &sig, + &sig, + &format!("Merge commit {} into {}", their_commit, our_commit), + &tree, + &[&our_commit_obj, &their_commit_obj], + )?; + + Ok(merge_commit_oid.to_string()) + } + fn begin_transaction(&self, req_id: &str, base_commit: &str) -> Result<Box<dyn Transaction>> { let repo_path = self.get_repo_path(req_id); let repo = Repository::open(&repo_path).context("Failed to open repo")?; - let base_oid = Oid::from_str(base_commit).context("Invalid base_commit")?; - let mut index = Index::new()?; let mut base_commit_oid = None; - if !base_oid.is_zero() { - // Scope the borrow of repo - { - let commit = repo.find_commit(base_oid).context("Base commit not found")?; - let tree = commit.tree()?; - index.read_tree(&tree)?; + if !base_commit.is_empty() { + let base_oid = Oid::from_str(base_commit).context("Invalid base_commit")?; + if !base_oid.is_zero() { + // Scope the borrow of repo + { + let commit = repo.find_commit(base_oid).context("Base commit not found")?; + let tree = commit.tree()?; + index.read_tree(&tree)?; + } + base_commit_oid = Some(base_oid); } - base_commit_oid = Some(base_oid); } Ok(Box::new(VgcsTransaction {
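The `merge_commits` primitive added above is what allows the orchestrator to combine the outputs of parallel tasks (see the merge strategy in the design document that follows). Below is a minimal usage sketch that is not part of this change: it assumes the `Vgcs`/`ContextStore` exports of the `workflow-context` crate as used elsewhere in this diff and an `anyhow::Result` alias; the helper name is illustrative, and the in-tree caller is `DagScheduler::merge_commits` further down.

```rust
use anyhow::Result;
use workflow_context::{ContextStore, Vgcs};

/// Fold the commits produced by two parallel tasks into a single base commit.
/// `merge_commits` locates the common ancestor itself and degrades to a
/// fast-forward when one commit is an ancestor of the other.
fn merge_parallel_outputs(vgcs: &Vgcs, req_id: &str, ours: &str, theirs: &str) -> Result<String> {
    if ours == theirs {
        return Ok(ours.to_string()); // Nothing to merge.
    }
    // On a conflict this returns Err; the orchestrator would then mark the
    // workflow as `Conflict` instead of dispatching downstream tasks.
    let merged = vgcs.merge_commits(req_id, ours, theirs)?;
    Ok(merged)
}
```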
commit not found")?; - let tree = commit.tree()?; - index.read_tree(&tree)?; + if !base_commit.is_empty() { + let base_oid = Oid::from_str(base_commit).context("Invalid base_commit")?; + if !base_oid.is_zero() { + // Scope the borrow of repo + { + let commit = repo.find_commit(base_oid).context("Base commit not found")?; + let tree = commit.tree()?; + index.read_tree(&tree)?; + } + base_commit_oid = Some(base_oid); } - base_commit_oid = Some(base_oid); } Ok(Box::new(VgcsTransaction { diff --git a/docs/3_project_management/tasks/pending/20251126_design_4_orchestrator.md b/docs/3_project_management/tasks/pending/20251126_design_4_orchestrator.md new file mode 100644 index 0000000..73cabaf --- /dev/null +++ b/docs/3_project_management/tasks/pending/20251126_design_4_orchestrator.md @@ -0,0 +1,117 @@ +# 设计方案 4: Workflow Orchestrator (调度器) [详细设计版] + +## 1. 定位与目标 + +Orchestrator 负责 DAG 调度、RPC 分发和冲突合并。它使用 Rust 实现。 + +**核心原则:** +* **Git as Source of Truth**: 所有的任务产出(数据、报告)均存储在 VGCS (Git) 中。 +* **数据库轻量化**: 数据库仅用于存储系统配置、TimeSeries 缓存以及 Request ID 与 Git Commit Hash 的映射。Workflow 执行过程中不依赖数据库进行状态流转。 +* **Context 隔离**: 每次 Workflow 启动均为全新的 Context(或基于特定 Snapshot),无全局共享 Context。 + +## 2. 调度逻辑规范 + +### 2.1 RPC Subject 命名 +基于 NATS。 +* **Command Topic**: `workflow.cmd.{routing_key}` + * e.g., `workflow.cmd.provider.tushare` + * e.g., `workflow.cmd.analysis.report` +* **Event Topic**: `workflow.evt.task_completed` (统一监听) + +### 2.2 Command Payload Schema + +```json +{ + "request_id": "uuid-v4", + "task_id": "fetch_tushare", + "routing_key": "provider.tushare", + "config": { + "market": "CN", + "years": 5 + // Provider Specific Config + }, + "context": { + "base_commit": "a1b2c3d4...", // Empty for initial task + "mount_path": "/Raw Data/Tushare" + // 指示 Worker 建议把结果挂在哪里 + }, + "storage": { + "root_path": "/mnt/workflow_data" + } +} +``` + +### 2.3 Event Payload Schema + +```json +{ + "request_id": "uuid-v4", + "task_id": "fetch_tushare", + "status": "Completed", + "result": { + "new_commit": "e5f6g7h8...", + "error": null + } +} +``` + +## 3. 合并策略 (Merge Strategy) 实现 + +### 3.1 串行合并 (Fast-Forward) +DAG: A -> B +1. A returns `C1`. +2. Orchestrator dispatch B with `base_commit = C1`. + +### 3.2 并行合并 (Three-Way Merge) +DAG: A -> B, A -> C. (B, C parallel) +1. Dispatch B with `base_commit = C1`. +2. Dispatch C with `base_commit = C1`. +3. B returns `C2`. C returns `C3`. +4. **Merge Action**: + * Wait for BOTH B and C to complete. + * Call `VGCS.merge_trees(base=C1, ours=C2, theirs=C3)`. + * Result: `TreeHash T4`. + * Create Commit `C4` (Parents: C2, C3). +5. Dispatch D (dependent on B, C) with `base_commit = C4`. + +### 3.3 冲突处理 +如果 `VGCS.merge_trees` 返回 Error (Conflict): +1. Orchestrator 捕获错误。 +2. 标记 Workflow 状态为 `Conflict`. +3. (Future) 触发 `ConflictResolver` Agent,传入 C1, C2, C3。Agent 生成 C4。 + +## 4. 状态机重构 +废弃旧的 `WorkflowStateMachine` 中关于 TaskType 的判断。 +引入 `CommitTracker`: +```rust +struct CommitTracker { + // 记录每个任务产出的 Commit + task_commits: HashMap, + // 记录当前主分支的 Commit (Latest Merged) + head_commit: String, +} +``` + +## 5. 执行计划 +1. **Contract**: 定义 Rust Structs for RPC Payloads. +2. **NATS**: 实现 Publisher/Subscriber。 +3. **Engine**: 实现 Merge Loop。 + +## 6. 数据持久化与缓存策略 (Persistence & Caching) + +### 6.1 数据库角色 (Database Role) +数据库不再作为业务数据的“主存储”,其角色转变为: +1. **Configuration**: 存储系统运行所需的配置信息。 +2. **Cache (Hot Data)**: 缓存 Data Provider 抓取的原始数据 (Time-series),避免重复调用外部 API。 +3. **Index**: 存储 `request_id` -> `final_commit_hash` 的映射,作为系统快照的索引。 + +### 6.2 Provider 行为模式 +Provider 在接收到 Workflow Command 时: +1. 
diff --git a/services/common-contracts/src/lib.rs b/services/common-contracts/src/lib.rs index 457335b..3f9ff93 100644 --- a/services/common-contracts/src/lib.rs +++ b/services/common-contracts/src/lib.rs @@ -1,4 +1,5 @@ pub mod dtos; +#[cfg(feature = "persistence")] pub mod models; pub mod observability; pub mod messages; @@ -9,3 +10,6 @@ pub mod registry; pub mod lifecycle; pub mod symbol_utils; pub mod persistence_client; +pub mod abstraction; +pub mod workflow_harness; // Export the harness +pub mod workflow_types; diff --git a/services/common-contracts/src/subjects.rs b/services/common-contracts/src/subjects.rs index a74ae48..0649824 100644 --- a/services/common-contracts/src/subjects.rs +++ b/services/common-contracts/src/subjects.rs @@ -10,6 +10,11 @@ pub trait SubjectMessage: Serialize + DeserializeOwned + Send + Sync { #[derive(Debug, Clone, PartialEq, Eq)] pub enum NatsSubject { + // --- Workflow Generic --- + WorkflowCommand(String), // Dynamic routing key: workflow.cmd.{routing_key} + WorkflowEventTaskCompleted, // workflow.evt.task_completed + WorkflowCommandWildcard, // workflow.cmd.> + // --- Commands --- WorkflowCommandStart, WorkflowCommandSyncState, @@ -37,6 +42,9 @@ pub enum NatsSubject { impl fmt::Display for NatsSubject { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { + Self::WorkflowCommand(key) => write!(f, "workflow.cmd.{}", key), + Self::WorkflowEventTaskCompleted => write!(f, "workflow.evt.task_completed"), + Self::WorkflowCommandWildcard => write!(f, "workflow.cmd.>"), Self::WorkflowCommandStart => write!(f, "workflow.commands.start"), Self::WorkflowCommandSyncState => write!(f, "workflow.commands.sync_state"), Self::DataFetchCommands => write!(f, "data_fetch_commands"), @@ -58,6 +66,8 @@ impl FromStr for NatsSubject { fn from_str(s: &str) -> Result<Self, Self::Err> { match s { + "workflow.evt.task_completed" => Ok(Self::WorkflowEventTaskCompleted), + "workflow.cmd.>" => Ok(Self::WorkflowCommandWildcard), "workflow.commands.start" => Ok(Self::WorkflowCommandStart), "workflow.commands.sync_state" => Ok(Self::WorkflowCommandSyncState), "data_fetch_commands" => Ok(Self::DataFetchCommands), @@ -76,6 +86,10 @@ impl FromStr for NatsSubject { return Ok(Self::WorkflowProgress(uuid)); } } + if s.starts_with("workflow.cmd.") { + let key = s.trim_start_matches("workflow.cmd."); + return Ok(Self::WorkflowCommand(key.to_string())); + } Err(format!("Unknown or invalid subject: {}", s)) } } diff --git a/services/common-contracts/src/workflow_types.rs b/services/common-contracts/src/workflow_types.rs new file mode 100644 index 0000000..1f6be38 --- /dev/null +++ b/services/common-contracts/src/workflow_types.rs @@ -0,0 +1,110 @@ +use uuid::Uuid; +use service_kit::api_dto; +use serde::{Serialize, Deserialize}; +use crate::subjects::{NatsSubject, SubjectMessage}; + +// --- Enums --- + +#[api_dto] +#[derive(Copy, PartialEq, Eq, Hash)] +pub enum WorkflowStatus { + Pending, + Running, + Completed, + Failed, + Cancelled, +} + +#[api_dto] +#[derive(Copy, PartialEq, Eq, Hash)] +pub enum TaskStatus { + Pending, + Scheduled, + Running, + Completed, + Failed, + Skipped, + Cancelled, +} + +// 
--- Data Structures --- + +/// Context information required for a worker to execute a task. +/// This tells the worker "where" to work (Git Commit). +#[api_dto] +pub struct TaskContext { + /// The Git commit hash that this task should base its work on. + /// If None/Empty, it implies starting from an empty repo (or initial state). + pub base_commit: Option, + + /// Suggested path where the worker should mount/write its output. + /// This is a hint; strict adherence depends on the worker implementation. + pub mount_path: Option, +} + +/// Configuration related to storage access. +#[api_dto] +pub struct StorageConfig { + /// The root path on the host/container where the shared volume is mounted. + /// e.g., "/mnt/workflow_data" + pub root_path: String, +} + +// --- Commands --- + +/// A generic command sent by the Orchestrator to a Worker. +/// The `routing_key` in the subject determines which worker receives it. +#[api_dto] +pub struct WorkflowTaskCommand { + pub request_id: Uuid, + pub task_id: String, + pub routing_key: String, + + /// Dynamic configuration specific to the worker (e.g., {"market": "CN", "years": 5}). + /// The worker is responsible for parsing this. + pub config: serde_json::Value, + + pub context: TaskContext, + pub storage: StorageConfig, +} + +impl SubjectMessage for WorkflowTaskCommand { + fn subject(&self) -> NatsSubject { + // The actual subject is dynamic based on routing_key. + // This return value is primarily for default behavior or matching. + NatsSubject::WorkflowCommand(self.routing_key.clone()) + } +} + +// --- Events --- + +/// A generic event sent by a Worker back to the Orchestrator upon task completion (or failure). +#[api_dto] +pub struct WorkflowTaskEvent { + pub request_id: Uuid, + pub task_id: String, + pub status: TaskStatus, + + /// The result of the task execution. + pub result: Option, +} + +impl SubjectMessage for WorkflowTaskEvent { + fn subject(&self) -> NatsSubject { + NatsSubject::WorkflowEventTaskCompleted + } +} + +#[api_dto] +pub struct TaskResult { + /// The new Git commit hash generated by this task. + /// Should be None if the task failed. + pub new_commit: Option, + + /// Error message if the task failed. + pub error: Option, + + /// Optional metadata or summary of the result (not the full data). 
+ pub summary: Option, +} + diff --git a/services/tushare-provider-service/Cargo.toml b/services/tushare-provider-service/Cargo.toml index 8fa4d7b..1714987 100644 --- a/services/tushare-provider-service/Cargo.toml +++ b/services/tushare-provider-service/Cargo.toml @@ -4,14 +4,16 @@ version = "0.1.0" edition = "2024" [dependencies] -common-contracts = { path = "../common-contracts" } +async-trait = "0.1.89" +secrecy = { version = "0.8", features = ["serde"] } +common-contracts = { path = "../common-contracts", default-features = false } +workflow-context = { path = "../../crates/workflow-context" } anyhow = "1.0" async-nats = "0.45.0" axum = "0.8" config = "0.15.19" dashmap = "6.1.0" -futures = "0.3" futures-util = "0.3.31" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" @@ -23,11 +25,9 @@ uuid = { version = "1.6", features = ["v4", "serde"] } reqwest = { version = "0.12.24", features = ["json"] } url = "2.5.2" thiserror = "2.0.17" -async-trait = "0.1.80" lazy_static = "1.5.0" regex = "1.10.4" chrono = "0.4.38" rust_decimal = "1.35.0" rust_decimal_macros = "1.35.0" itertools = "0.14.0" -secrecy = { version = "0.8", features = ["serde"] } diff --git a/services/tushare-provider-service/src/generic_worker.rs b/services/tushare-provider-service/src/generic_worker.rs new file mode 100644 index 0000000..d518bee --- /dev/null +++ b/services/tushare-provider-service/src/generic_worker.rs @@ -0,0 +1,131 @@ +use anyhow::{Result, anyhow, Context}; +use tracing::{info, error, warn}; +use common_contracts::workflow_types::{WorkflowTaskCommand, WorkflowTaskEvent, TaskStatus, TaskResult}; +use common_contracts::subjects::{NatsSubject, SubjectMessage}; +use common_contracts::dtos::{CompanyProfileDto, TimeSeriesFinancialDto}; +use workflow_context::{WorkerContext, OutputFormat}; +use crate::state::AppState; +use crate::tushare::TushareDataProvider; +use serde_json::json; +use std::sync::Arc; + +pub async fn handle_workflow_command(state: AppState, nats: async_nats::Client, cmd: WorkflowTaskCommand) -> Result<()> { + info!("Processing generic workflow command: task_id={}", cmd.task_id); + + // 1. Parse Config + let symbol_code = cmd.config.get("symbol").and_then(|s| s.as_str()).unwrap_or("").to_string(); + let market = cmd.config.get("market").and_then(|s| s.as_str()).unwrap_or("CN").to_string(); + + if symbol_code.is_empty() { + return send_failure(&nats, &cmd, "Missing symbol in config").await; + } + + // 2. Initialize Worker Context + // Note: We use the provided base_commit. If it's empty, it means start from scratch (or empty repo). + // We need to mount the volume. + let root_path = cmd.storage.root_path.clone(); + + let mut ctx = match WorkerContext::init(&cmd.request_id.to_string(), &root_path, cmd.context.base_commit.as_deref()) { + Ok(c) => c, + Err(e) => return send_failure(&nats, &cmd, &format!("Failed to init context: {}", e)).await, + }; + + // 3. Fetch Data (with Cache) + let fetch_result = fetch_and_cache(&state, &symbol_code, &market).await; + + let (profile, financials) = match fetch_result { + Ok(data) => data, + Err(e) => return send_failure(&nats, &cmd, &format!("Fetch failed: {}", e)).await, + }; + + // 4. 
Write to VGCS + // Organize data in a structured way + let base_dir = format!("raw/tushare/{}", symbol_code); + + if let Err(e) = ctx.write_file(&format!("{}/profile.json", base_dir), &profile, OutputFormat::Json) { + return send_failure(&nats, &cmd, &format!("Failed to write profile: {}", e)).await; + } + + if let Err(e) = ctx.write_file(&format!("{}/financials.json", base_dir), &financials, OutputFormat::Json) { + return send_failure(&nats, &cmd, &format!("Failed to write financials: {}", e)).await; + } + + // 5. Commit + let new_commit = match ctx.commit(&format!("Fetched Tushare data for {}", symbol_code)) { + Ok(c) => c, + Err(e) => return send_failure(&nats, &cmd, &format!("Commit failed: {}", e)).await, + }; + + info!("Task {} completed. New commit: {}", cmd.task_id, new_commit); + + // 6. Send Success Event + let event = WorkflowTaskEvent { + request_id: cmd.request_id, + task_id: cmd.task_id, + status: TaskStatus::Completed, + result: Some(TaskResult { + new_commit: Some(new_commit), + error: None, + summary: Some(json!({ + "symbol": symbol_code, + "records": financials.len() + })), + }), + }; + + publish_event(&nats, event).await +} + +async fn fetch_and_cache(state: &AppState, symbol: &str, _market: &str) -> Result<(CompanyProfileDto, Vec)> { + // 1. Get Provider (which holds the API token) + let provider = state.get_provider().await + .ok_or_else(|| anyhow!("Tushare Provider not initialized (missing API Token?)"))?; + + // 2. Call fetch + let (profile, financials) = provider.fetch_all_data(symbol).await + .context("Failed to fetch data from Tushare")?; + + // 3. Write to DB Cache + // Note: PersistenceClient is not directly in AppState struct definition in `state.rs` I read. + // Let's check `state.rs` again. It implements TaskState which has `get_persistence_url`. + // We should instantiate PersistenceClient on the fly or add it to AppState. + + // For now, let's create a client on the fly to avoid changing AppState struct everywhere. + use common_contracts::persistence_client::PersistenceClient; + use common_contracts::workflow_harness::TaskState; // For get_persistence_url + + let persistence_url = state.get_persistence_url(); + let p_client = PersistenceClient::new(persistence_url); + + if let Err(e) = p_client.save_company_profile(&profile).await { + warn!("Failed to cache company profile: {}", e); + } + + // Batch save financials logic is missing in PersistenceClient (based on context). + // If it existed, we would call it here. 
+ + Ok((profile, financials)) +} + +async fn send_failure(nats: &async_nats::Client, cmd: &WorkflowTaskCommand, error_msg: &str) -> Result<()> { + error!("Task {} failed: {}", cmd.task_id, error_msg); + let event = WorkflowTaskEvent { + request_id: cmd.request_id, + task_id: cmd.task_id.clone(), + status: TaskStatus::Failed, + result: Some(TaskResult { + new_commit: None, + error: Some(error_msg.to_string()), + summary: None, + }), + }; + publish_event(nats, event).await +} + +async fn publish_event(nats: &async_nats::Client, event: WorkflowTaskEvent) -> Result<()> { + let subject = event.subject().to_string(); + let payload = serde_json::to_vec(&event)?; + nats.publish(subject, payload.into()).await?; + Ok(()) +} + diff --git a/services/tushare-provider-service/src/main.rs b/services/tushare-provider-service/src/main.rs index 0355ce6..c02383f 100644 --- a/services/tushare-provider-service/src/main.rs +++ b/services/tushare-provider-service/src/main.rs @@ -15,7 +15,7 @@ use crate::error::{Result, AppError}; use crate::state::AppState; use tracing::{info, warn}; use common_contracts::lifecycle::ServiceRegistrar; -use common_contracts::registry::ServiceRegistration; +use common_contracts::registry::{ServiceRegistration, ProviderMetadata, ConfigFieldSchema, FieldType, ConfigKey}; use std::sync::Arc; #[tokio::main] @@ -52,6 +52,26 @@ async fn main() -> Result<()> { role: common_contracts::registry::ServiceRole::DataProvider, base_url: format!("http://{}:{}", config.service_host, port), health_check_url: format!("http://{}:{}/health", config.service_host, port), + metadata: Some(ProviderMetadata { + id: "tushare".to_string(), + name_en: "Tushare Pro".to_string(), + name_cn: "Tushare Pro (中国股市)".to_string(), + description: "Official Tushare Data Provider".to_string(), + icon_url: None, + config_schema: vec![ + ConfigFieldSchema { + key: ConfigKey::ApiToken, + label: "API Token".to_string(), + field_type: FieldType::Password, + required: true, + placeholder: Some("Enter your token...".to_string()), + default_value: None, + description: Some("Get it from https://tushare.pro".to_string()), + options: None, + }, + ], + supports_test_connection: true, + }), } ); diff --git a/services/tushare-provider-service/src/message_consumer.rs b/services/tushare-provider-service/src/message_consumer.rs index 069304c..96958a9 100644 --- a/services/tushare-provider-service/src/message_consumer.rs +++ b/services/tushare-provider-service/src/message_consumer.rs @@ -1,6 +1,7 @@ use crate::error::Result; use crate::state::{AppState, ServiceOperationalStatus}; use common_contracts::messages::FetchCompanyDataCommand; +use common_contracts::workflow_types::WorkflowTaskCommand; // Import use common_contracts::observability::ObservabilityTaskStatus; use common_contracts::subjects::NatsSubject; use futures_util::StreamExt; @@ -29,9 +30,16 @@ pub async fn run(state: AppState) -> Result<()> { match async_nats::connect(&state.config.nats_addr).await { Ok(client) => { info!("Successfully connected to NATS."); - if let Err(e) = subscribe_and_process(state.clone(), client).await { - error!("NATS subscription error: {}. Reconnecting in 10s...", e); + // Use try_join or spawn multiple subscribers + let s1 = subscribe_legacy(state.clone(), client.clone()); + let s2 = subscribe_workflow(state.clone(), client.clone()); + + tokio::select! 
{ + res = s1 => if let Err(e) = res { error!("Legacy subscriber error: {}", e); }, + res = s2 => if let Err(e) = res { error!("Workflow subscriber error: {}", e); }, } + + // If any subscriber exits, wait and retry } Err(e) => { error!("Failed to connect to NATS: {}. Retrying in 10s...", e); @@ -41,14 +49,39 @@ pub async fn run(state: AppState) -> Result<()> { } } -async fn subscribe_and_process( +async fn subscribe_workflow(state: AppState, client: async_nats::Client) -> Result<()> { + let subject = "workflow.cmd.provider.tushare".to_string(); + let mut subscriber = client.subscribe(subject.clone()).await?; + info!("Workflow Consumer started on '{}'", subject); + + while let Some(message) = subscriber.next().await { + // Check status check (omitted for brevity, assuming handled) + + let state = state.clone(); + let client = client.clone(); + + tokio::spawn(async move { + match serde_json::from_slice::(&message.payload) { + Ok(cmd) => { + if let Err(e) = crate::generic_worker::handle_workflow_command(state, client, cmd).await { + error!("Generic worker handler failed: {}", e); + } + }, + Err(e) => error!("Failed to parse WorkflowTaskCommand: {}", e), + } + }); + } + Ok(()) +} + +async fn subscribe_legacy( state: AppState, client: async_nats::Client, ) -> Result<()> { let subject = NatsSubject::DataFetchCommands.to_string(); let mut subscriber = client.subscribe(subject.clone()).await?; info!( - "Consumer started, waiting for messages on subject '{}'", + "Legacy Consumer started, waiting for messages on subject '{}'", subject ); diff --git a/services/workflow-orchestrator-service/Cargo.toml b/services/workflow-orchestrator-service/Cargo.toml index ca89a38..f759a8d 100644 --- a/services/workflow-orchestrator-service/Cargo.toml +++ b/services/workflow-orchestrator-service/Cargo.toml @@ -20,4 +20,8 @@ dashmap = "6.1.0" axum = "0.8.7" # Internal dependencies -common-contracts = { path = "../common-contracts" } +common-contracts = { path = "../common-contracts", default-features = false } +workflow-context = { path = "../../crates/workflow-context" } + +[dev-dependencies] +tempfile = "3" diff --git a/services/workflow-orchestrator-service/src/config.rs b/services/workflow-orchestrator-service/src/config.rs index eb0359b..a8b926e 100644 --- a/services/workflow-orchestrator-service/src/config.rs +++ b/services/workflow-orchestrator-service/src/config.rs @@ -7,6 +7,7 @@ pub struct AppConfig { pub nats_addr: String, pub server_port: u16, pub data_persistence_service_url: String, + pub workflow_data_path: String, } impl AppConfig { @@ -18,12 +19,14 @@ impl AppConfig { .context("SERVER_PORT must be a number")?; let data_persistence_service_url = env::var("DATA_PERSISTENCE_SERVICE_URL") .unwrap_or_else(|_| "http://data-persistence-service:3000/api/v1".to_string()); + let workflow_data_path = env::var("WORKFLOW_DATA_PATH") + .unwrap_or_else(|_| "/mnt/workflow_data".to_string()); Ok(Self { nats_addr, server_port, data_persistence_service_url, + workflow_data_path, }) } } - diff --git a/services/workflow-orchestrator-service/src/dag_scheduler.rs b/services/workflow-orchestrator-service/src/dag_scheduler.rs new file mode 100644 index 0000000..b7cecb2 --- /dev/null +++ b/services/workflow-orchestrator-service/src/dag_scheduler.rs @@ -0,0 +1,264 @@ +use std::collections::HashMap; +use uuid::Uuid; +use common_contracts::workflow_types::{TaskStatus, TaskContext}; +use common_contracts::messages::TaskType; +use workflow_context::{Vgcs, ContextStore}; +use anyhow::Result; +use tracing::info; +use 
serde::{Serialize, Deserialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CommitTracker { + /// Maps task_id to the commit hash it produced. + pub task_commits: HashMap, + /// The latest merged commit for the whole workflow (if linear). + /// Or just a reference to the "main" branch tip. + pub head_commit: String, +} + +impl CommitTracker { + pub fn new(initial_commit: String) -> Self { + Self { + task_commits: HashMap::new(), + head_commit: initial_commit, + } + } + + pub fn record_commit(&mut self, task_id: &str, commit: String) { + self.task_commits.insert(task_id.to_string(), commit.clone()); + // Note: head_commit update strategy depends on whether we want to track + // a single "main" branch or just use task_commits for DAG resolution. + // For now, we don't eagerly update head_commit unless it's a final task. + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DagScheduler { + pub request_id: Uuid, + pub nodes: HashMap, + /// TaskID -> List of downstream TaskIDs + pub forward_deps: HashMap>, + /// TaskID -> List of upstream TaskIDs + pub reverse_deps: HashMap>, + + pub commit_tracker: CommitTracker, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DagNode { + pub id: String, + pub task_type: TaskType, // Kept for UI/Observability, not for logic + pub status: TaskStatus, + pub config: serde_json::Value, + pub routing_key: String, +} + +impl DagScheduler { + pub fn new(request_id: Uuid, initial_commit: String) -> Self { + Self { + request_id, + nodes: HashMap::new(), + forward_deps: HashMap::new(), + reverse_deps: HashMap::new(), + commit_tracker: CommitTracker::new(initial_commit), + } + } + + pub fn add_node(&mut self, id: String, task_type: TaskType, routing_key: String, config: serde_json::Value) { + self.nodes.insert(id.clone(), DagNode { + id, + task_type, + status: TaskStatus::Pending, + config, + routing_key, + }); + } + + pub fn add_dependency(&mut self, from: &str, to: &str) { + self.forward_deps.entry(from.to_string()).or_default().push(to.to_string()); + self.reverse_deps.entry(to.to_string()).or_default().push(from.to_string()); + } + + /// Get all tasks that have no dependencies (roots) + pub fn get_initial_tasks(&self) -> Vec { + self.nodes.values() + .filter(|n| n.status == TaskStatus::Pending && self.reverse_deps.get(&n.id).map_or(true, |deps| deps.is_empty())) + .map(|n| n.id.clone()) + .collect() + } + + pub fn update_status(&mut self, task_id: &str, status: TaskStatus) { + if let Some(node) = self.nodes.get_mut(task_id) { + node.status = status; + } + } + + pub fn record_result(&mut self, task_id: &str, new_commit: Option) { + if let Some(c) = new_commit { + self.commit_tracker.record_commit(task_id, c); + } + } + + /// Determine which tasks are ready to run given that `completed_task_id` just finished. 
+ pub fn get_ready_downstream_tasks(&self, completed_task_id: &str) -> Vec { + let mut ready = Vec::new(); + if let Some(downstream) = self.forward_deps.get(completed_task_id) { + for next_id in downstream { + if self.is_ready(next_id) { + ready.push(next_id.clone()); + } + } + } + ready + } + + fn is_ready(&self, task_id: &str) -> bool { + let node = match self.nodes.get(task_id) { + Some(n) => n, + None => return false, + }; + + if node.status != TaskStatus::Pending { + return false; + } + + if let Some(deps) = self.reverse_deps.get(task_id) { + for dep_id in deps { + match self.nodes.get(dep_id).map(|n| n.status) { + Some(TaskStatus::Completed) => continue, + _ => return false, // Dependency not completed + } + } + } + true + } + + /// Resolve the context (Base Commit) for a task. + /// If multiple dependencies, perform Merge or Fast-Forward. + pub fn resolve_context(&self, task_id: &str, vgcs: &Vgcs) -> Result { + let deps = self.reverse_deps.get(task_id).cloned().unwrap_or_default(); + + if deps.is_empty() { + // Root task: Use initial commit (usually empty string or base snapshot) + return Ok(TaskContext { + base_commit: Some(self.commit_tracker.head_commit.clone()), + mount_path: None, + }); + } + + // Collect parent commits + let mut parent_commits = Vec::new(); + for dep_id in &deps { + if let Some(c) = self.commit_tracker.task_commits.get(dep_id) { + if !c.is_empty() { + parent_commits.push(c.clone()); + } + } + } + + if parent_commits.is_empty() { + // All parents produced no commit? Fallback to head or empty. + return Ok(TaskContext { + base_commit: Some(self.commit_tracker.head_commit.clone()), + mount_path: None, + }); + } + + // Merge Strategy + let final_commit = self.merge_commits(vgcs, parent_commits)?; + + Ok(TaskContext { + base_commit: Some(final_commit), + mount_path: None, // Or determine based on config + }) + } + + /// Merge logic: + /// 1 parent -> Return it. + /// 2+ parents -> Iteratively merge using smart merge_commits + fn merge_commits(&self, vgcs: &Vgcs, commits: Vec) -> Result { + if commits.is_empty() { + return Ok(String::new()); + } + if commits.len() == 1 { + return Ok(commits[0].clone()); + } + + let mut current_head = commits[0].clone(); + + for i in 1..commits.len() { + let next_commit = &commits[i]; + if current_head == *next_commit { + continue; + } + + // Use the smart merge_commits which finds the common ancestor automatically + // Note: This handles Fast-Forward implicitly (merge_commits checks for ancestry) + info!("Merging commits: Ours={}, Theirs={}", current_head, next_commit); + current_head = vgcs.merge_commits(&self.request_id.to_string(), ¤t_head, next_commit)?; + } + + Ok(current_head) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + use workflow_context::{Vgcs, ContextStore, Transaction}; + use common_contracts::messages::TaskType; + use serde_json::json; + + #[test] + fn test_dag_merge_strategy() -> Result<()> { + let temp_dir = TempDir::new()?; + let vgcs = Vgcs::new(temp_dir.path()); + let req_id = Uuid::new_v4(); + let req_id_str = req_id.to_string(); + + vgcs.init_repo(&req_id_str)?; + + // 0. Create Initial Commit (Common Ancestor) + let mut tx = vgcs.begin_transaction(&req_id_str, "")?; + let init_commit = Box::new(tx).commit("Initial Commit", "system")?; + + // 1. 
Setup DAG + let mut dag = DagScheduler::new(req_id, init_commit.clone()); + dag.add_node("A".to_string(), TaskType::DataFetch, "key.a".into(), json!({})); + dag.add_node("B".to_string(), TaskType::DataFetch, "key.b".into(), json!({})); + dag.add_node("C".to_string(), TaskType::Analysis, "key.c".into(), json!({})); + + // C depends on A and B + dag.add_dependency("A", "C"); + dag.add_dependency("B", "C"); + + // 2. Simulate Task A Execution -> Commit A (Based on Init) + let mut tx = vgcs.begin_transaction(&req_id_str, &init_commit)?; + tx.write("file_a.txt", b"Content A")?; + let commit_a = Box::new(tx).commit("Task A", "worker")?; + dag.record_result("A", Some(commit_a.clone())); + dag.update_status("A", TaskStatus::Completed); + + // 3. Simulate Task B Execution -> Commit B (Based on Init) + let mut tx = vgcs.begin_transaction(&req_id_str, &init_commit)?; + tx.write("file_b.txt", b"Content B")?; + let commit_b = Box::new(tx).commit("Task B", "worker")?; + dag.record_result("B", Some(commit_b.clone())); + dag.update_status("B", TaskStatus::Completed); + + // 4. Resolve Context for C + // Should merge A and B + let ctx = dag.resolve_context("C", &vgcs)?; + let merged_commit = ctx.base_commit.expect("Should have base commit"); + + // Verify merged content + let files = vgcs.list_dir(&req_id_str, &merged_commit, "")?; + let file_names: Vec = files.iter().map(|f| f.name.clone()).collect(); + + assert!(file_names.contains(&"file_a.txt".to_string())); + assert!(file_names.contains(&"file_b.txt".to_string())); + + Ok(()) + } +} diff --git a/services/workflow-orchestrator-service/src/lib.rs b/services/workflow-orchestrator-service/src/lib.rs new file mode 100644 index 0000000..7ed184d --- /dev/null +++ b/services/workflow-orchestrator-service/src/lib.rs @@ -0,0 +1,7 @@ +pub mod api; +pub mod config; +pub mod message_consumer; +pub mod persistence; +pub mod state; +pub mod workflow; +pub mod dag_scheduler; diff --git a/services/workflow-orchestrator-service/src/main.rs b/services/workflow-orchestrator-service/src/main.rs index 8c4c574..31c6467 100644 --- a/services/workflow-orchestrator-service/src/main.rs +++ b/services/workflow-orchestrator-service/src/main.rs @@ -2,20 +2,15 @@ use anyhow::Result; use tracing::info; use tracing_subscriber::EnvFilter; use std::sync::Arc; - -mod config; -mod state; -mod message_consumer; -mod workflow; -mod persistence; -mod api; +use workflow_orchestrator_service::{config, state, message_consumer, api}; #[tokio::main] async fn main() -> Result<()> { // Initialize tracing tracing_subscriber::fmt() - .with_env_filter(EnvFilter::from_default_env()) - .init(); + // .with_env_filter(EnvFilter::from_default_env()) + .with_env_filter("info") + .init(); info!("Starting workflow-orchestrator-service..."); @@ -38,8 +33,7 @@ async fn main() -> Result<()> { } }); - // Start HTTP Server (for health check & potential admin/debug) - // Note: The main trigger API is now NATS commands, but we might expose some debug endpoints + // Start HTTP Server let app = api::create_router(state.clone()); let addr = format!("0.0.0.0:{}", config.server_port); let listener = tokio::net::TcpListener::bind(&addr).await?; diff --git a/services/workflow-orchestrator-service/src/message_consumer.rs b/services/workflow-orchestrator-service/src/message_consumer.rs index e8fd7e1..168f436 100644 --- a/services/workflow-orchestrator-service/src/message_consumer.rs +++ b/services/workflow-orchestrator-service/src/message_consumer.rs @@ -4,114 +4,50 @@ use anyhow::Result; use tracing::{info, error}; use 
futures::StreamExt; use crate::state::AppState; -use common_contracts::messages::{ - StartWorkflowCommand, SyncStateCommand, - FinancialsPersistedEvent, DataFetchFailedEvent, -}; +use common_contracts::messages::StartWorkflowCommand; +use common_contracts::workflow_types::WorkflowTaskEvent; use common_contracts::subjects::NatsSubject; -use crate::persistence::{DashMapWorkflowRepository, NatsMessageBroker}; use crate::workflow::WorkflowEngine; pub async fn run(state: Arc, nats: Client) -> Result<()> { info!("Message Consumer started. Subscribing to topics..."); - let mut workflow_sub = nats.subscribe(NatsSubject::WorkflowCommandsWildcard.to_string()).await?; - let mut data_sub = nats.subscribe(NatsSubject::DataEventsWildcard.to_string()).await?; - let mut analysis_sub = nats.subscribe(NatsSubject::AnalysisEventsWildcard.to_string()).await?; + // Topic 1: Workflow Commands (Start) + // Note: NatsSubject::WorkflowCommandStart string representation is "workflow.commands.start" + let mut start_sub = nats.subscribe(NatsSubject::WorkflowCommandStart.to_string()).await?; - // Prepare components - let repo = DashMapWorkflowRepository::new(state.workflows.clone()); - let broker = NatsMessageBroker::new(nats.clone()); + // Topic 2: Workflow Task Events (Generic) + // Note: NatsSubject::WorkflowEventTaskCompleted string representation is "workflow.evt.task_completed" + let mut task_sub = nats.subscribe(NatsSubject::WorkflowEventTaskCompleted.to_string()).await?; - // ... (Task 1 & 2 remain same) ... + let engine = Arc::new(WorkflowEngine::new(state.clone(), nats.clone())); - // --- Task 3: Analysis Events (Report Gen updates) --- - let repo3 = repo.clone(); - let broker3 = broker.clone(); + // --- Task 1: Start Workflow --- + let engine1 = engine.clone(); tokio::spawn(async move { - let engine = WorkflowEngine::new(Box::new(repo3), Box::new(broker3)); - while let Some(msg) = analysis_sub.next().await { - match NatsSubject::try_from(msg.subject.as_str()) { - Ok(NatsSubject::AnalysisReportGenerated) => { - if let Ok(evt) = serde_json::from_slice::(&msg.payload) { - info!("Received ReportGenerated: {:?}", evt); - if let Err(e) = engine.handle_report_generated(evt).await { - error!("Failed to handle ReportGenerated: {}", e); - } - } - }, - Ok(NatsSubject::AnalysisReportFailed) => { - if let Ok(evt) = serde_json::from_slice::(&msg.payload) { - info!("Received ReportFailed: {:?}", evt); - if let Err(e) = engine.handle_report_failed(evt).await { - error!("Failed to handle ReportFailed: {}", e); - } - } - }, - _ => { - // Ignore other subjects or log warning + while let Some(msg) = start_sub.next().await { + if let Ok(cmd) = serde_json::from_slice::(&msg.payload) { + info!("Received StartWorkflow: {:?}", cmd); + if let Err(e) = engine1.handle_start_workflow(cmd).await { + error!("Failed to handle StartWorkflow: {}", e); } + } else { + error!("Failed to parse StartWorkflowCommand"); } } }); - // --- Task 1: Workflow Commands (Start, Sync) --- - let repo1 = repo.clone(); - let broker1 = broker.clone(); + // --- Task 2: Task Completed Events --- + let engine2 = engine.clone(); tokio::spawn(async move { - let engine = WorkflowEngine::new(Box::new(repo1), Box::new(broker1)); - while let Some(msg) = workflow_sub.next().await { - match NatsSubject::try_from(msg.subject.as_str()) { - Ok(NatsSubject::WorkflowCommandStart) => { - if let Ok(cmd) = serde_json::from_slice::(&msg.payload) { - info!("Received StartWorkflow: {:?}", cmd); - if let Err(e) = engine.handle_start_workflow(cmd).await { - error!("Failed to handle 
StartWorkflow: {}", e); - } - } else { - error!("Failed to parse StartWorkflowCommand"); - } - }, - Ok(NatsSubject::WorkflowCommandSyncState) => { - if let Ok(cmd) = serde_json::from_slice::(&msg.payload) { - info!("Received SyncState: {:?}", cmd); - // 关键修复:这里直接调用处理逻辑,如果还不行,我们需要检查 state 是否真的加载到了 - // 或者在这里加更多的 log - match engine.handle_sync_state(cmd).await { - Ok(_) => info!("Successfully processed SyncState"), - Err(e) => error!("Failed to handle SyncState: {}", e), - } - } - }, - _ => {} - } - } - }); - - // --- Task 2: Data Events (Progress updates) --- - let repo2 = repo.clone(); - let broker2 = broker.clone(); - tokio::spawn(async move { - let engine = WorkflowEngine::new(Box::new(repo2), Box::new(broker2)); - while let Some(msg) = data_sub.next().await { - match NatsSubject::try_from(msg.subject.as_str()) { - Ok(NatsSubject::DataFinancialsPersisted) => { - if let Ok(evt) = serde_json::from_slice::(&msg.payload) { - info!("Received DataFetched: {:?}", evt); - if let Err(e) = engine.handle_data_fetched(evt).await { - error!("Failed to handle DataFetched: {}", e); - } - } - }, - Ok(NatsSubject::DataFetchFailed) => { - if let Ok(evt) = serde_json::from_slice::(&msg.payload) { - info!("Received DataFailed: {:?}", evt); - if let Err(e) = engine.handle_data_failed(evt).await { - error!("Failed to handle DataFailed: {}", e); - } - } - }, - _ => {} + while let Some(msg) = task_sub.next().await { + if let Ok(evt) = serde_json::from_slice::(&msg.payload) { + info!("Received TaskCompleted: task_id={}", evt.task_id); + if let Err(e) = engine2.handle_task_completed(evt).await { + error!("Failed to handle TaskCompleted: {}", e); + } + } else { + error!("Failed to parse WorkflowTaskEvent"); } } }); diff --git a/services/workflow-orchestrator-service/src/persistence.rs b/services/workflow-orchestrator-service/src/persistence.rs index e659d1c..e4ac118 100644 --- a/services/workflow-orchestrator-service/src/persistence.rs +++ b/services/workflow-orchestrator-service/src/persistence.rs @@ -1,72 +1,2 @@ -use async_trait::async_trait; -use crate::workflow::{WorkflowRepository, WorkflowStateMachine, MessageBroker}; -use std::sync::Arc; -use tokio::sync::Mutex; -use dashmap::DashMap; -use uuid::Uuid; -use anyhow::Result; -use async_nats::Client; -use common_contracts::messages::WorkflowEvent; - -// --- Real Implementation of Repository (In-Memory for MVP) --- -#[derive(Clone)] -pub struct DashMapWorkflowRepository { - store: Arc>>> -} - -impl DashMapWorkflowRepository { - pub fn new(store: Arc>>>) -> Self { - Self { store } - } -} - -#[async_trait] -impl WorkflowRepository for DashMapWorkflowRepository { - async fn save(&self, workflow: &WorkflowStateMachine) -> Result<()> { - if let Some(entry) = self.store.get(&workflow.request_id) { - let mut guard = entry.lock().await; - *guard = workflow.clone(); - } else { - self.store.insert(workflow.request_id, Arc::new(Mutex::new(workflow.clone()))); - } - Ok(()) - } - - async fn load(&self, id: Uuid) -> Result> { - if let Some(entry) = self.store.get(&id) { - let guard = entry.lock().await; - return Ok(Some(guard.clone())); - } - // Add debug log to see what keys are actually in the store - // tracing::debug!("Keys in store: {:?}", self.store.iter().map(|r| *r.key()).collect::>()); - Ok(None) - } -} - -// --- Real Implementation of Broker --- - -#[derive(Clone)] -pub struct NatsMessageBroker { - client: Client -} - -impl NatsMessageBroker { - pub fn new(client: Client) -> Self { - Self { client } - } -} - -use common_contracts::subjects::NatsSubject; - 
-#[async_trait] -impl MessageBroker for NatsMessageBroker { - async fn publish_event(&self, request_id: Uuid, event: WorkflowEvent) -> Result<()> { - let topic = NatsSubject::WorkflowProgress(request_id).to_string(); - let payload = serde_json::to_vec(&event)?; - self.client.publish(topic, payload.into()).await.map_err(|e| anyhow::anyhow!(e)) - } - - async fn publish_command(&self, topic: &str, payload: Vec) -> Result<()> { - self.client.publish(topic.to_string(), payload.into()).await.map_err(|e| anyhow::anyhow!(e)) - } -} +// Persistence module currently unused after refactor to DagScheduler. +// Will be reintroduced when we implement DB persistence for the DAG. diff --git a/services/workflow-orchestrator-service/src/state.rs b/services/workflow-orchestrator-service/src/state.rs index f85aed8..b7b9527 100644 --- a/services/workflow-orchestrator-service/src/state.rs +++ b/services/workflow-orchestrator-service/src/state.rs @@ -5,25 +5,33 @@ use common_contracts::persistence_client::PersistenceClient; use dashmap::DashMap; use uuid::Uuid; use tokio::sync::Mutex; -use crate::workflow::WorkflowStateMachine; +// use crate::workflow::WorkflowStateMachine; // Deprecated +use crate::dag_scheduler::DagScheduler; +use workflow_context::Vgcs; pub struct AppState { + #[allow(dead_code)] pub config: AppConfig, + #[allow(dead_code)] pub persistence_client: PersistenceClient, + // Key: request_id // We use Mutex here because state transitions need to be atomic per workflow - pub workflows: Arc>>>, + pub workflows: Arc>>>, + + pub vgcs: Arc, } impl AppState { pub async fn new(config: AppConfig) -> Result { let persistence_client = PersistenceClient::new(config.data_persistence_service_url.clone()); + let vgcs = Arc::new(Vgcs::new(&config.workflow_data_path)); Ok(Self { config, persistence_client, workflows: Arc::new(DashMap::new()), + vgcs, }) } } - diff --git a/services/workflow-orchestrator-service/src/workflow.rs b/services/workflow-orchestrator-service/src/workflow.rs index 270e3f8..ad81613 100644 --- a/services/workflow-orchestrator-service/src/workflow.rs +++ b/services/workflow-orchestrator-service/src/workflow.rs @@ -1,680 +1,181 @@ -use std::collections::HashMap; -use uuid::Uuid; -use common_contracts::messages::{ - WorkflowEvent, WorkflowDag, TaskStatus, TaskType, TaskNode, TaskDependency, - StartWorkflowCommand, FinancialsPersistedEvent, DataFetchFailedEvent, - SyncStateCommand, FetchCompanyDataCommand, GenerateReportCommand +use std::sync::Arc; +use common_contracts::workflow_types::{ + WorkflowTaskCommand, WorkflowTaskEvent, TaskStatus, StorageConfig +}; +use common_contracts::messages::{ + StartWorkflowCommand, TaskType }; -use common_contracts::symbol_utils::CanonicalSymbol; -use tracing::{info, warn}; -use async_trait::async_trait; -use anyhow::Result; - -#[derive(Debug, Clone)] -pub struct WorkflowStateMachine { - pub request_id: Uuid, - pub symbol: CanonicalSymbol, - pub market: String, - pub template_id: String, - pub dag: WorkflowDag, - - // Runtime state: TaskID -> Status - pub task_status: HashMap, - // Fast lookup for dependencies: TaskID -> List of dependencies (upstream) - pub reverse_deps: HashMap>, - // Fast lookup for dependents: TaskID -> List of tasks that depend on it (downstream) - pub forward_deps: HashMap>, -} - -impl WorkflowStateMachine { - pub fn new(request_id: Uuid, symbol: CanonicalSymbol, market: String, template_id: String) -> Self { - let (dag, reverse_deps, forward_deps) = Self::build_dag(&template_id); - - let mut task_status = HashMap::new(); - for node in 
&dag.nodes { - task_status.insert(node.id.clone(), node.initial_status); - } - - Self { - request_id, - symbol, - market, - template_id, - dag, - task_status, - reverse_deps, - forward_deps, - } - } - - // Determines which tasks are ready to run immediately (no dependencies) - pub fn get_initial_tasks(&self) -> Vec { - self.dag.nodes.iter() - .filter(|n| n.initial_status == TaskStatus::Pending && self.reverse_deps.get(&n.id).map_or(true, |deps| deps.is_empty())) - .map(|n| n.id.clone()) - .collect() - } - - // Core logic for state transition - // Returns a list of tasks that should be TRIGGERED (transitioned from Pending -> Scheduled) - pub fn update_task_status(&mut self, task_id: &str, new_status: TaskStatus) -> Vec { - if let Some(current) = self.task_status.get_mut(task_id) { - if *current == new_status { - return vec![]; // No change - } - - info!("Workflow {}: Task {} transition {:?} -> {:?}", self.request_id, task_id, current, new_status); - *current = new_status; - } else { - warn!("Workflow {}: Unknown task id {}", self.request_id, task_id); - return vec![]; - } - - // If task completed, check downstream dependents - if new_status == TaskStatus::Completed { - return self.check_dependents(task_id); - } - // If task failed, we might need to skip downstream tasks (Policy decision) - else if new_status == TaskStatus::Failed { - // For now, simplistic policy: propagate failure as Skipped - self.propagate_skip(task_id); - } - - vec![] - } - - fn check_dependents(&self, completed_task_id: &str) -> Vec { - let mut ready_tasks = Vec::new(); - - if let Some(downstream) = self.forward_deps.get(completed_task_id) { - for dependent_id in downstream { - if self.is_task_ready(dependent_id) { - ready_tasks.push(dependent_id.clone()); - } - } - } - - ready_tasks - } - - fn is_task_ready(&self, task_id: &str) -> bool { - if let Some(status) = self.task_status.get(task_id) { - if *status != TaskStatus::Pending { - return false; - } - } else { - return false; - } - - if let Some(deps) = self.reverse_deps.get(task_id) { - for dep_id in deps { - match self.task_status.get(dep_id) { - Some(TaskStatus::Completed) => continue, // Good - _ => return false, // Any other status means not ready - } - } - } - - true - } - - fn propagate_skip(&mut self, failed_task_id: &str) { - if let Some(downstream) = self.forward_deps.get(failed_task_id) { - for dependent_id in downstream { - if let Some(status) = self.task_status.get_mut(dependent_id) { - if *status == TaskStatus::Pending { - *status = TaskStatus::Skipped; - // For MVP we do 1 level. Real impl needs full graph traversal. - } - } - } - } - } - - fn build_dag(template_id: &str) -> (WorkflowDag, HashMap>, HashMap>) { - let mut nodes = vec![]; - let mut edges = vec![]; - - // Simplistic Hack for E2E Testing Phase 4 - if template_id == "simple_test_analysis" { - // Simple DAG: Fetch(yfinance) -> Analysis - nodes.push(TaskNode { - id: "fetch:yfinance".to_string(), - name: "Fetch yfinance".to_string(), - r#type: TaskType::DataFetch, - initial_status: TaskStatus::Pending, - }); - nodes.push(TaskNode { - id: "analysis:report".to_string(), - name: "Generate Report".to_string(), - r#type: TaskType::Analysis, - initial_status: TaskStatus::Pending, - }); - edges.push(TaskDependency { from: "fetch:yfinance".to_string(), to: "analysis:report".to_string() }); - - } else { - // Default Hardcoded DAG - // 1. 
Data Fetch Nodes (Roots) - let providers = vec!["alphavantage", "tushare", "finnhub", "yfinance"]; - for p in &providers { - nodes.push(TaskNode { - id: format!("fetch:{}", p), - name: format!("Fetch {}", p), - r#type: TaskType::DataFetch, - initial_status: TaskStatus::Pending, - }); - } - - // 2. Analysis Node - nodes.push(TaskNode { - id: "analysis:report".to_string(), - name: "Generate Report".to_string(), - r#type: TaskType::Analysis, - initial_status: TaskStatus::Pending, - }); - - // Edges: All fetchers -> Analysis - for p in &providers { - edges.push(TaskDependency { from: format!("fetch:{}", p), to: "analysis:report".to_string() }); - } - } - - let mut reverse_deps: HashMap> = HashMap::new(); - let mut forward_deps: HashMap> = HashMap::new(); - - for edge in &edges { - reverse_deps.entry(edge.to.clone()).or_default().push(edge.from.clone()); - forward_deps.entry(edge.from.clone()).or_default().push(edge.to.clone()); - } - - (WorkflowDag { nodes, edges }, reverse_deps, forward_deps) - } - - pub fn get_snapshot_event(&self) -> WorkflowEvent { - WorkflowEvent::WorkflowStateSnapshot { - timestamp: chrono::Utc::now().timestamp_millis(), - task_graph: self.dag.clone(), - tasks_status: self.task_status.clone(), - tasks_output: HashMap::new(), - } - } -} - -// --- Traits & Engine --- - use common_contracts::subjects::SubjectMessage; +use common_contracts::symbol_utils::CanonicalSymbol; +use tracing::{info, warn, error}; +use anyhow::Result; +use serde_json::json; +use tokio::sync::Mutex; -#[async_trait] -pub trait WorkflowRepository: Send + Sync { - async fn save(&self, workflow: &WorkflowStateMachine) -> Result<()>; - async fn load(&self, id: Uuid) -> Result>; -} - -#[async_trait] -pub trait MessageBroker: Send + Sync { - async fn publish_event(&self, request_id: Uuid, event: WorkflowEvent) -> Result<()>; - async fn publish_command(&self, topic: &str, payload: Vec) -> Result<()>; -} +use crate::dag_scheduler::DagScheduler; +use crate::state::AppState; +use workflow_context::{Vgcs, ContextStore}; // Added ContextStore pub struct WorkflowEngine { - repo: Box, - broker: Box, + state: Arc, + nats: async_nats::Client, } impl WorkflowEngine { - pub fn new(repo: Box, broker: Box) -> Self { - Self { repo, broker } + pub fn new(state: Arc, nats: async_nats::Client) -> Self { + Self { state, nats } } pub async fn handle_start_workflow(&self, cmd: StartWorkflowCommand) -> Result<()> { - let mut machine = WorkflowStateMachine::new( - cmd.request_id, - cmd.symbol.clone(), - cmd.market.clone(), - cmd.template_id.clone() + let req_id = cmd.request_id; + info!("Starting workflow {}", req_id); + + // 1. Init VGCS Repo + self.state.vgcs.init_repo(&req_id.to_string())?; + + // 2. Create Scheduler + // Initial commit is empty for a fresh workflow + let mut dag = DagScheduler::new(req_id, String::new()); + + // 3. Build DAG (Simplified Hardcoded logic for now, matching old build_dag) + // In a real scenario, we fetch Template from DB/Service using cmd.template_id + self.build_dag(&mut dag, &cmd.template_id, &cmd.market, &cmd.symbol); + + // 4. Save State + self.state.workflows.insert(req_id, Arc::new(Mutex::new(dag.clone()))); + + // 5. 
Trigger Initial Tasks
+        let initial_tasks = dag.get_initial_tasks();
+
+        // Lock the DAG for updates
+        let dag_arc = self.state.workflows.get(&req_id).unwrap().clone();
+        let mut dag_guard = dag_arc.lock().await;
+
+        for task_id in initial_tasks {
+            self.dispatch_task(&mut dag_guard, &task_id, &self.state.vgcs).await?;
+        }
+
+        Ok(())
+    }
+
+    pub async fn handle_task_completed(&self, evt: WorkflowTaskEvent) -> Result<()> {
+        let req_id = evt.request_id;
+
+        let dag_arc = match self.state.workflows.get(&req_id) {
+            Some(d) => d.clone(),
+            None => {
+                warn!("Received event for unknown workflow {}", req_id);
+                return Ok(());
+            }
+        };
+
+        let mut dag = dag_arc.lock().await;
+
+        // 1. Update Status & Record Commit
+        dag.update_status(&evt.task_id, evt.status);
+
+        if let Some(result) = evt.result {
+            if let Some(commit) = result.new_commit {
+                info!("Task {} produced commit {}", evt.task_id, commit);
+                dag.record_result(&evt.task_id, Some(commit));
+            }
+            if let Some(err) = result.error {
+                warn!("Task {} failed with error: {}", evt.task_id, err);
+            }
+        }
+
+        // 2. Check for downstream tasks
+        if evt.status == TaskStatus::Completed {
+            let ready_tasks = dag.get_ready_downstream_tasks(&evt.task_id);
+            for task_id in ready_tasks {
+                if let Err(e) = self.dispatch_task(&mut dag, &task_id, &self.state.vgcs).await {
+                    error!("Failed to dispatch task {}: {}", task_id, e);
+                }
+            }
+        } else if evt.status == TaskStatus::Failed {
+            // Handle failure propagation (skip downstream?)
+            // For now, just log.
+            error!("Task {} failed. Workflow might stall.", evt.task_id);
+        }
+
+        // 3. Check Workflow Completion (TODO)
+
+        Ok(())
+    }
+
+    async fn dispatch_task(&self, dag: &mut DagScheduler, task_id: &str, vgcs: &Vgcs) -> Result<()> {
+        // 1. Resolve Context (Merge if needed)
+        let context = dag.resolve_context(task_id, vgcs)?;
+
+        // 2. Update Status
+        dag.update_status(task_id, TaskStatus::Scheduled);
+
+        // 3. Construct Command
+        let node = dag.nodes.get(task_id).ok_or_else(|| anyhow::anyhow!("Node not found"))?;
+
+        let cmd = WorkflowTaskCommand {
+            request_id: dag.request_id,
+            task_id: task_id.to_string(),
+            routing_key: node.routing_key.clone(),
+            config: node.config.clone(),
+            context,
+            storage: StorageConfig {
+                root_path: self.state.config.workflow_data_path.clone(),
+            },
+        };
+
+        // 4. Publish
+        let subject = cmd.subject().to_string(); // This uses the routing_key
+        let payload = serde_json::to_vec(&cmd)?;
+
+        info!("Dispatching task {} to subject {}", task_id, subject);
+        self.nats.publish(subject, payload.into()).await?;
+
+        Ok(())
+    }
+
+    // Helper to build DAG (Migrated from old workflow.rs)
+    fn build_dag(&self, dag: &mut DagScheduler, template_id: &str, market: &str, symbol: &CanonicalSymbol) {
+        // Logic copied/adapted from old WorkflowStateMachine::build_dag
+
+        let mut providers = Vec::new();
+        match market {
+            "CN" => {
+                providers.push("tushare");
+                // providers.push("yfinance");
+            },
+            "US" => providers.push("yfinance"),
+            _ => providers.push("yfinance"),
+        }
+
+        // 1. Data Fetch Nodes
+        for p in &providers {
+            let task_id = format!("fetch:{}", p);
+            dag.add_node(
+                task_id.clone(),
+                TaskType::DataFetch,
+                format!("provider.{}", p), // routing_key: workflow.cmd.provider.tushare
+                json!({
+                    "symbol": symbol.as_str(), // Simplification
+                    "market": market
+                })
+            );
+        }
+
+        // 2. Analysis Node (Simplified)
+        let report_task_id = "analysis:report";
+        dag.add_node(
+            report_task_id.to_string(),
+            TaskType::Analysis,
+            "analysis.report".to_string(), // routing_key: workflow.cmd.analysis.report
+            json!({
+                "template_id": template_id
+            })
         );
-        self.repo.save(&machine).await?;
-
-        // Broadcast Started
-        let start_event = WorkflowEvent::WorkflowStarted {
-            timestamp: chrono::Utc::now().timestamp_millis(),
-            task_graph: machine.dag.clone()
-        };
-        self.broker.publish_event(cmd.request_id, start_event).await?;
-
-        // Initial Tasks
-        let initial_tasks = machine.get_initial_tasks();
-        for task_id in initial_tasks {
-            let _ = machine.update_task_status(&task_id, TaskStatus::Scheduled);
-
-            self.broadcast_task_state(cmd.request_id, &task_id, TaskType::DataFetch, TaskStatus::Scheduled, None, None).await?;
-
-            if task_id.starts_with("fetch:") {
-                let fetch_cmd = FetchCompanyDataCommand {
-                    request_id: cmd.request_id,
-                    symbol: cmd.symbol.clone(),
-                    market: cmd.market.clone(),
-                    template_id: Some(cmd.template_id.clone()),
-                };
-                // Extract provider from task_id (e.g., "fetch:yfinance" -> "yfinance")
-                // We only publish if the task_id matches the fetch provider we intend to trigger.
-                // But wait, 'FetchCompanyDataCommand' is generic.
-                // If we send it, all listening providers will pick it up.
-                // We rely on the fact that unconfigured providers (Tushare etc) are "paused" or won't act if no key.
-                // However, if we have "simple_test_analysis" which ONLY has "fetch:yfinance",
-                // we still broadcast "data_fetch_commands".
-                // Tushare service WILL receive it. If it doesn't have a key, it warns and ignores or retries.
-                // Ideally, the command should carry the target provider ID.
-                // But FetchCompanyDataCommand structure is:
-                // pub struct FetchCompanyDataCommand { request_id, symbol, market, template_id }
-                // It lacks provider_id.
-
-                // For now, we just broadcast. The key is that the DAG only waits for "fetch:yfinance".
-                // So even if Tushare fails or never returns, it doesn't matter because "fetch:tushare" is NOT in the DAG.
-
-                let payload = serde_json::to_vec(&fetch_cmd)?;
-                self.broker.publish_command(fetch_cmd.subject().to_string().as_str(), payload).await?;
-            }
-        }
-
-        self.repo.save(&machine).await?;
-        Ok(())
-    }
-
-    pub async fn handle_data_fetched(&self, evt: FinancialsPersistedEvent) -> Result<()> {
-        if let Some(mut machine) = self.repo.load(evt.request_id).await? {
-            let provider_id = evt.provider_id.as_deref().unwrap_or("unknown");
-            let task_id = format!("fetch:{}", provider_id);
-
-            let next_tasks = machine.update_task_status(&task_id, TaskStatus::Completed);
-
-            let summary = evt.data_summary.clone().unwrap_or_else(|| "Data fetched".to_string());
-            self.broker.publish_event(evt.request_id, WorkflowEvent::TaskStreamUpdate {
-                task_id: task_id.clone(),
-                content_delta: summary,
-                index: 0
-            }).await?;
-
-            self.broadcast_task_state(evt.request_id, &task_id, TaskType::DataFetch, TaskStatus::Completed, Some("Data persisted".into()), None).await?;
-
-            self.trigger_next_tasks(&mut machine, next_tasks).await?;
-
-            self.repo.save(&machine).await?;
-        }
-        Ok(())
-    }
-
-    pub async fn handle_data_failed(&self, evt: DataFetchFailedEvent) -> Result<()> {
-        if let Some(mut machine) = self.repo.load(evt.request_id).await? {
-            let provider_id = evt.provider_id.as_deref().unwrap_or("unknown");
-            let task_id = format!("fetch:{}", provider_id);
-
-            let _ = machine.update_task_status(&task_id, TaskStatus::Failed);
-            self.broadcast_task_state(evt.request_id, &task_id, TaskType::DataFetch, TaskStatus::Failed, Some(evt.error), None).await?;
-            self.repo.save(&machine).await?;
-        }
-        Ok(())
-    }
-
-    pub async fn handle_sync_state(&self, cmd: SyncStateCommand) -> Result<()> {
-        info!("Handling SyncState for request_id: {}", cmd.request_id);
-        if let Some(machine) = self.repo.load(cmd.request_id).await? {
-            info!("Found workflow machine for {}, publishing snapshot", cmd.request_id);
-            let snapshot = machine.get_snapshot_event();
-            self.broker.publish_event(cmd.request_id, snapshot).await?;
-        } else {
-            warn!("No workflow found for request_id: {}", cmd.request_id);
-        }
-        Ok(())
-    }
-
-    pub async fn handle_report_generated(&self, evt: common_contracts::messages::ReportGeneratedEvent) -> Result<()> {
-        if let Some(mut machine) = self.repo.load(evt.request_id).await? {
-            let task_id = format!("analysis:{}", evt.module_id);
-            // The module_id from ReportGen might be "report" or "step2_analyze" or whatever the template used.
-            // In "simple_test_analysis", the task ID is "analysis:report" (hardcoded in build_dag)
-            // But the template module key was "step2_analyze".
-
-            // WAIT: In build_dag:
-            // nodes.push(TaskNode { id: "analysis:report".to_string(), ... });
-            // But in GenerateReportCommand sent to report-gen:
-            // template_id: machine.template_id
-
-            // Report Generator uses the template modules.
-            // It executes "step2_analyze".
-            // It sends ReportGeneratedEvent with module_id = "step2_analyze".
-
-            // Orchestrator's DAG has "analysis:report".
-            // MISMATCH!
-
-            // The hardcoded build_dag uses "analysis:report".
-            // The actual template execution uses "step2_analyze".
-
-            // This mismatch is why the E2E test fails even if Report Gen succeeds.
-            // Report Gen says: "step2_analyze" done.
-            // Orchestrator waits for: "analysis:report".
-
-            // FIX:
-            // We need to map the incoming module_id to the task_id in our DAG.
-            // OR update build_dag to use the correct ID.
-
-            // For simple_test_analysis, we can hack it here or fix build_dag.
-            // Since I cannot change build_dag dynamically easily without loading template,
-            // I will try to handle the mapping here.
-
-            let target_task_id = if machine.template_id == "simple_test_analysis" && evt.module_id == "step2_analyze" {
-                "analysis:report".to_string()
-            } else {
-                // Default assumption: task_id = "analysis:{module_id}"
-                // But our hardcoded DAG uses "analysis:report".
-                // If module_id is "report", then it matches.
-                format!("analysis:{}", evt.module_id)
-            };
-
-            // Fallback: if target_task_id not in DAG, try "analysis:report" if it exists and is Running/Scheduled
-            let final_task_id = if machine.task_status.contains_key(&target_task_id) {
-                target_task_id
-            } else if machine.task_status.contains_key("analysis:report") {
-                "analysis:report".to_string()
-            } else {
-                target_task_id
-            };
-
-            let next_tasks = machine.update_task_status(&final_task_id, TaskStatus::Completed);
-
-            self.broadcast_task_state(evt.request_id, &final_task_id, TaskType::Analysis, TaskStatus::Completed, Some("Report generated".into()), None).await?;
-
-            // If this was the last node, the workflow is complete?
-            // check_dependents returns next ready tasks. If empty, and no running tasks, we are done.
-            // We should check if workflow is complete.
-
-            self.check_workflow_completion(&machine).await?;
-
-            self.repo.save(&machine).await?;
-        }
-        Ok(())
-    }
-
-    pub async fn handle_report_failed(&self, evt: common_contracts::messages::ReportFailedEvent) -> Result<()> {
-        if let Some(mut machine) = self.repo.load(evt.request_id).await? {
-            let task_id = if machine.template_id == "simple_test_analysis" && evt.module_id == "step2_analyze" {
-                "analysis:report".to_string()
-            } else {
-                format!("analysis:{}", evt.module_id)
-            };
-
-            let final_task_id = if machine.task_status.contains_key(&task_id) {
-                task_id
-            } else {
-                "analysis:report".to_string()
-            };
-
-            let _ = machine.update_task_status(&final_task_id, TaskStatus::Failed);
-            self.broadcast_task_state(evt.request_id, &final_task_id, TaskType::Analysis, TaskStatus::Failed, Some(evt.error.clone()), None).await?;
-
-            // Propagate failure?
-            self.broker.publish_event(evt.request_id, WorkflowEvent::WorkflowFailed {
-                reason: format!("Analysis task failed: {}", evt.error),
-                is_fatal: true,
-                end_timestamp: chrono::Utc::now().timestamp_millis(),
-            }).await?;
-
-            self.repo.save(&machine).await?;
-        }
-        Ok(())
-    }
-
-    async fn check_workflow_completion(&self, machine: &WorkflowStateMachine) -> Result<()> {
-        // If all tasks are Completed or Skipped, workflow is done.
-        let all_done = machine.task_status.values().all(|s| *s == TaskStatus::Completed || *s == TaskStatus::Skipped || *s == TaskStatus::Failed);
-
-        if all_done {
-            self.broker.publish_event(machine.request_id, WorkflowEvent::WorkflowCompleted {
-                result_summary: serde_json::json!({"status": "success"}),
-                end_timestamp: chrono::Utc::now().timestamp_millis(),
-            }).await?;
-        }
-        Ok(())
-    }
-
-    async fn trigger_next_tasks(&self, machine: &mut WorkflowStateMachine, tasks: Vec<String>) -> Result<()> {
-        for task_id in tasks {
-            machine.update_task_status(&task_id, TaskStatus::Scheduled);
-            self.broadcast_task_state(machine.request_id, &task_id, TaskType::Analysis, TaskStatus::Scheduled, None, None).await?;
-
-            if task_id == "analysis:report" {
-                let cmd = GenerateReportCommand {
-                    request_id: machine.request_id,
-                    symbol: machine.symbol.clone(),
-                    template_id: machine.template_id.clone(),
-                };
-                let payload = serde_json::to_vec(&cmd)?;
-                self.broker.publish_command(cmd.subject().to_string().as_str(), payload).await?;
-            }
-        }
-        Ok(())
-    }
-
-    async fn broadcast_task_state(&self, request_id: Uuid, task_id: &str, ttype: TaskType, status: TaskStatus, msg: Option, progress: Option) -> Result<()> {
-        let event = WorkflowEvent::TaskStateChanged {
-            task_id: task_id.to_string(),
-            task_type: ttype,
-            status,
-            message: msg,
-            timestamp: chrono::Utc::now().timestamp_millis(),
-            progress, // Passed through
-        };
-        self.broker.publish_event(request_id, event).await
-    }
-}
-
-// --- Tests ---
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use tokio::sync::RwLock;
-    use std::sync::Arc;
-    use common_contracts::symbol_utils::Market;
-    use common_contracts::subjects::NatsSubject;
-
-    // --- Fakes ---
-
-    struct InMemoryRepo {
-        data: Arc<RwLock<HashMap<Uuid, WorkflowStateMachine>>>
-    }
-
-    #[async_trait]
-    impl WorkflowRepository for InMemoryRepo {
-        async fn save(&self, workflow: &WorkflowStateMachine) -> Result<()> {
-            self.data.write().await.insert(workflow.request_id, workflow.clone());
-            Ok(())
-        }
-        async fn load(&self, id: Uuid) -> Result<Option<WorkflowStateMachine>> {
-            Ok(self.data.read().await.get(&id).cloned())
-        }
-    }
-
-    struct FakeBroker {
-        pub events: Arc<RwLock<Vec<(Uuid, WorkflowEvent)>>>,
-        pub commands: Arc<RwLock<Vec<(String, Vec<u8>)>>>,
-    }
-
-    #[async_trait]
-    impl MessageBroker for FakeBroker {
-        async fn publish_event(&self, request_id: Uuid, event: WorkflowEvent) -> Result<()> {
-            self.events.write().await.push((request_id, event));
-            Ok(())
-        }
-        async fn publish_command(&self, topic: &str, payload: Vec<u8>) -> Result<()> {
-            self.commands.write().await.push((topic.to_string(), payload));
-            Ok(())
-        }
-    }
-
-    // --- Test Cases ---
-
-    #[tokio::test]
-    async fn test_dag_execution_flow() {
-        // 1. Setup
-        let repo = Box::new(InMemoryRepo { data: Arc::new(RwLock::new(HashMap::new())) });
-
-        let events = Arc::new(RwLock::new(Vec::new()));
-        let commands = Arc::new(RwLock::new(Vec::new()));
-        let broker = Box::new(FakeBroker { events: events.clone(), commands: commands.clone() });
-
-        let engine = WorkflowEngine::new(repo, broker);
-
-        let req_id = Uuid::new_v4();
-        let symbol = CanonicalSymbol::new("AAPL", &Market::US);
-
-        // 2. Start Workflow
-        let cmd = StartWorkflowCommand {
-            request_id: req_id,
-            symbol: symbol.clone(),
-            market: "US".to_string(),
-            template_id: "default".to_string(),
-        };
-        engine.handle_start_workflow(cmd).await.unwrap();
-
-        // Assert: Started Event & Fetch Commands
-        {
-            let evs = events.read().await;
-            assert!(matches!(evs[0].1, WorkflowEvent::WorkflowStarted { .. }));
-
-            let cmds = commands.read().await;
-            assert_eq!(cmds.len(), 4); // 4 fetchers
-            assert_eq!(cmds[0].0, NatsSubject::DataFetchCommands.to_string());
-        }
-
-        // 3. Simulate Completion (Alphavantage)
-        engine.handle_data_fetched(FinancialsPersistedEvent {
-            request_id: req_id,
-            symbol: symbol.clone(),
-            years_updated: vec![],
-            template_id: Some("default".to_string()),
-            provider_id: Some("alphavantage".to_string()),
-            data_summary: None,
-        }).await.unwrap();
-
-        // Assert: Task Completed
-        {
-            let evs = events.read().await;
-            // Started(1) + 4 Scheduled(4) + StreamUpdate(1) + Completed(1) = 7
-            // Wait, order depends.
-            // Just check last event
-            if let WorkflowEvent::TaskStateChanged { task_id, status, .. } = &evs.last().unwrap().1 {
-                assert_eq!(task_id, "fetch:alphavantage");
-                assert_eq!(*status, TaskStatus::Completed);
-            } else {
-                panic!("Last event should be state change");
-            }
-        }
-
-        // 4. Simulate Completion of ALL fetchers to trigger Analysis
-        let providers = vec!["tushare", "finnhub", "yfinance"];
-        for p in providers {
-            engine.handle_data_fetched(FinancialsPersistedEvent {
-                request_id: req_id,
-                symbol: symbol.clone(),
-                years_updated: vec![],
-                template_id: Some("default".to_string()),
-                provider_id: Some(p.to_string()),
-                data_summary: None,
-            }).await.unwrap();
-        }
-
-        // Assert: Analysis Triggered
-        {
-            let cmds = commands.read().await;
-            // 4 fetch commands + 1 analysis command
-            assert_eq!(cmds.len(), 5);
-            assert_eq!(cmds.last().unwrap().0, NatsSubject::AnalysisCommandGenerateReport.to_string());
-        }
-    }
-
-    #[tokio::test]
-    async fn test_simple_analysis_flow() {
-        // 1. Setup
-        let repo = Box::new(InMemoryRepo { data: Arc::new(RwLock::new(HashMap::new())) });
-        let events = Arc::new(RwLock::new(Vec::new()));
-        let commands = Arc::new(RwLock::new(Vec::new()));
-        let broker = Box::new(FakeBroker { events: events.clone(), commands: commands.clone() });
-        let engine = WorkflowEngine::new(repo, broker);
-
-        let req_id = Uuid::new_v4();
-        let symbol = CanonicalSymbol::new("AAPL", &Market::US);
-
-        // 2. Start Workflow with "simple_test_analysis"
-        let cmd = StartWorkflowCommand {
-            request_id: req_id,
-            symbol: symbol.clone(),
-            market: "US".to_string(),
-            template_id: "simple_test_analysis".to_string(),
-        };
-        engine.handle_start_workflow(cmd).await.unwrap();
-
-        // Assert: DAG Structure (Fetch yfinance -> Analysis)
-        // Check events for WorkflowStarted
-        {
-            let evs = events.read().await;
-            if let WorkflowEvent::WorkflowStarted { task_graph, .. } = &evs[0].1 {
-                assert_eq!(task_graph.nodes.len(), 2);
-                assert!(task_graph.nodes.iter().any(|n| n.id == "fetch:yfinance"));
-                assert!(task_graph.nodes.iter().any(|n| n.id == "analysis:report"));
-            } else {
-                panic!("First event must be WorkflowStarted");
-            }
-        }
-
-        // 3. Simulate Fetch Completion
-        engine.handle_data_fetched(FinancialsPersistedEvent {
-            request_id: req_id,
-            symbol: symbol.clone(),
-            years_updated: vec![],
-            template_id: Some("simple_test_analysis".to_string()),
-            provider_id: Some("yfinance".to_string()),
-            data_summary: None,
-        }).await.unwrap();
-
-        // Assert: Analysis Scheduled
-        {
-            let evs = events.read().await;
-            let last_event = &evs.last().unwrap().1;
-            if let WorkflowEvent::TaskStateChanged { task_id, status, .. } = last_event {
-                assert_eq!(task_id, "analysis:report");
-                assert_eq!(*status, TaskStatus::Scheduled);
-            } else {
-                panic!("Expected TaskStateChanged to Scheduled");
-            }
-        }
-
-        // 4. Simulate Report Generated (Analysis Completion)
-        // Using the mismatch ID "step2_analyze" to verify our mapping logic
-        engine.handle_report_generated(common_contracts::messages::ReportGeneratedEvent {
-            request_id: req_id,
-            symbol: symbol.clone(),
-            module_id: "step2_analyze".to_string(),
-            content_snapshot: None,
-            model_id: None,
-        }).await.unwrap();
-
-        // Assert: Workflow Completed
-        {
-            let evs = events.read().await;
-            let last_event = &evs.last().unwrap().1;
-            if let WorkflowEvent::WorkflowCompleted { .. } = last_event {
-                // Success!
-            } else {
-                // Check if we have task completion before workflow completion
-                // The very last event should be WorkflowCompleted
-                // Let's print all events
-                for (_, e) in evs.iter() {
-                    println!("{:?}", e);
-                }
-                panic!("Workflow did not complete. Last event: {:?}", last_event);
-            }
+        // 3. Edges
+        for p in &providers {
+            dag.add_dependency(&format!("fetch:{}", p), report_task_id);
+        }
     }
 }
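
Note on the scheduler surface assumed above: the new engine code calls into a `DagScheduler` (`add_node`, `add_dependency`, `get_initial_tasks`, `get_ready_downstream_tasks`, `update_status`, `record_result`) whose implementation is not part of this patch. The following is a minimal, hypothetical sketch of that bookkeeping, inferred purely from the call sites in the diff; the names, field types, and readiness rule are assumptions, not the actual crate code.

```rust
// Hypothetical sketch of the DAG bookkeeping the engine relies on.
// Method names mirror the call sites in this patch; internals are illustrative only.
use std::collections::HashMap;

#[allow(dead_code)]
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum TaskStatus { Pending, Scheduled, Completed, Failed }

#[derive(Default)]
pub struct DagSketch {
    // task_id -> status
    status: HashMap<String, TaskStatus>,
    // task_id -> commit produced by that task (if any)
    commits: HashMap<String, String>,
    // edge list: (upstream, downstream)
    edges: Vec<(String, String)>,
}

impl DagSketch {
    // The real add_node also carries TaskType, routing_key and config; omitted here.
    pub fn add_node(&mut self, task_id: &str) {
        self.status.insert(task_id.to_string(), TaskStatus::Pending);
    }

    pub fn add_dependency(&mut self, upstream: &str, downstream: &str) {
        self.edges.push((upstream.to_string(), downstream.to_string()));
    }

    pub fn update_status(&mut self, task_id: &str, status: TaskStatus) {
        self.status.insert(task_id.to_string(), status);
    }

    pub fn record_result(&mut self, task_id: &str, commit: Option<String>) {
        if let Some(c) = commit {
            self.commits.insert(task_id.to_string(), c);
        }
    }

    /// Tasks with no upstream edges: dispatched when the workflow starts.
    pub fn get_initial_tasks(&self) -> Vec<String> {
        self.status
            .keys()
            .filter(|id| !self.edges.iter().any(|(_, down)| down == *id))
            .cloned()
            .collect()
    }

    /// Downstream tasks of `completed_id` whose upstream dependencies are ALL completed.
    pub fn get_ready_downstream_tasks(&self, completed_id: &str) -> Vec<String> {
        self.edges
            .iter()
            .filter(|(up, _)| up == completed_id)
            .map(|(_, down)| down.clone())
            .filter(|down| {
                self.edges
                    .iter()
                    .filter(|(_, d)| d == down)
                    .all(|(up, _)| self.status.get(up) == Some(&TaskStatus::Completed))
            })
            .collect()
    }
}

fn main() {
    // Mirror the two-node DAG that build_dag produces for "simple_test_analysis".
    let mut dag = DagSketch::default();
    dag.add_node("fetch:yfinance");
    dag.add_node("analysis:report");
    dag.add_dependency("fetch:yfinance", "analysis:report");

    assert_eq!(dag.get_initial_tasks(), vec!["fetch:yfinance".to_string()]);

    // Same order as handle_task_completed: status is updated first,
    // then downstream readiness is computed.
    dag.update_status("fetch:yfinance", TaskStatus::Completed);
    dag.record_result("fetch:yfinance", Some("c1".to_string()));
    assert_eq!(
        dag.get_ready_downstream_tasks("fetch:yfinance"),
        vec!["analysis:report".to_string()]
    );
}
```

Under this assumed readiness rule, a joining node such as `analysis:report` is only returned once every `fetch:*` parent has been marked `Completed`, which is the point at which the "Resolve Context (Merge if needed)" step in `dispatch_task` would have all parent commits available.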