- Implemented Unified Context Mechanism (Task 20251127):
  - Decoupled intent (Module) from resolution (Orchestrator).
  - Added ContextResolver for resolving input bindings (Manual Glob / Auto LLM).
  - Added IOBinder for managing physical paths.
  - Updated GenerateReportCommand to support explicit input bindings and output paths.
- Refactored Report Worker to Generic Execution (Task 20251128):
  - Removed hardcoded financial DTOs and specific formatting logic.
  - Implemented generic YAML-based context assembly for better LLM readability.
  - Added detailed execution tracing (sidecar logs).
  - Fixed an input data collision bug by using full paths as context keys.
  - Updated Tushare Provider to support dynamic output paths.
  - Updated Common Contracts with new configuration models.
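The mechanism is visible in `dispatch_task` below: the Orchestrator allocates an `output_path` via `IOBinder` and injects `input_bindings` resolved by `ContextResolver` into the task config before dispatch. As a rough sketch (field names are taken from the code below; the concrete values and the list shape of the resolved paths are assumptions for illustration), the config handed to an analysis worker might look like:

```rust
// Hypothetical example of the config assembled for an analysis task.
// "output_path" and "input_bindings" are the fields injected in dispatch_task;
// the concrete paths and the list shape of input_bindings are assumptions.
let config = serde_json::json!({
    "template_id": "fundamental_v1",   // assumed example template id
    "module_id": "profitability",      // assumed example module id
    "symbol": "600519.SH",             // assumed example symbol
    "market": "CN",
    "output_path": "analysis/profitability/report.md",       // allocated by IOBinder
    "input_bindings": ["data/tushare/income_statement.json"]  // resolved by ContextResolver
});
```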
use std::sync::Arc;

use common_contracts::workflow_types::{
    WorkflowTaskCommand, WorkflowTaskEvent, TaskStatus, StorageConfig
};
use common_contracts::messages::{
    StartWorkflowCommand, SyncStateCommand, TaskType, WorkflowEvent, TaskStatus as MsgTaskStatus
};
use common_contracts::subjects::SubjectMessage;
use common_contracts::symbol_utils::CanonicalSymbol;
use tracing::{info, warn, error};
use anyhow::Result;
use serde_json::json;
use tokio::sync::Mutex;

use crate::dag_scheduler::DagScheduler;
use crate::state::AppState;
use crate::context_resolver::ContextResolver;
use crate::io_binder::IOBinder;
use common_contracts::configs::AnalysisModuleConfig;
use workflow_context::{Vgcs, ContextStore};

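/// Orchestrates workflow execution: builds the task DAG, dispatches tasks over
/// NATS, and reacts to task lifecycle events.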
pub struct WorkflowEngine {
    state: Arc<AppState>,
    nats: async_nats::Client,
}

impl WorkflowEngine {
    pub fn new(state: Arc<AppState>, nats: async_nats::Client) -> Self {
        Self { state, nats }
    }

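    /// Entry point for a new workflow: initializes the VGCS repository, builds the
    /// task DAG from the template configuration, and dispatches the initial tasks.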
    pub async fn handle_start_workflow(&self, cmd: StartWorkflowCommand) -> Result<()> {
        let req_id = cmd.request_id;
        info!("Starting workflow {}", req_id);

        self.publish_log(req_id, "workflow", "INFO", "Workflow started. Initializing...").await;

        // 1. Init VGCS repo
        self.state.vgcs.init_repo(&req_id.to_string())?;

        // 2. Create initial commit (context baseline)
        let mut tx = self.state.vgcs.begin_transaction(&req_id.to_string(), "")?;
        let request_info = json!({
            "request_id": req_id,
            "symbol": cmd.symbol.as_str(),
            "market": cmd.market,
            "template_id": cmd.template_id,
            "timestamp": chrono::Utc::now().to_rfc3339()
        });
        tx.write("request.json", serde_json::to_vec_pretty(&request_info)?.as_slice())?;
        let initial_commit = Box::new(tx).commit("Initial Workflow Setup", "System")?;

        // 3. Create scheduler seeded with the initial commit
        let mut dag = DagScheduler::new(req_id, initial_commit);

        // 4. Fetch template configuration
        self.publish_log(req_id, "workflow", "INFO", "Fetching template configuration...").await;
        let template_sets = self.state.persistence_client.get_analysis_template_sets().await?;
        let template = template_sets.get(&cmd.template_id).ok_or_else(|| {
            anyhow::anyhow!("Template {} not found", cmd.template_id)
        })?;

        // 5. Fetch data sources configuration
        let data_sources = self.state.persistence_client.get_data_sources_config().await?;

        // 6. Build the DAG
        self.build_dag(&mut dag, template, &data_sources, &cmd.template_id, &cmd.market, &cmd.symbol);

        // 7. Save state
        self.state.workflows.insert(req_id, Arc::new(Mutex::new(dag.clone())));

        // 8. Publish WorkflowStarted event
        let event = WorkflowEvent::WorkflowStarted {
            timestamp: chrono::Utc::now().timestamp_millis(),
            task_graph: dag.to_dto(),
        };
        let subject = common_contracts::subjects::NatsSubject::WorkflowProgress(req_id).to_string();
        if let Ok(payload) = serde_json::to_vec(&event) {
            if let Err(e) = self.nats.publish(subject, payload.into()).await {
                error!("Failed to publish WorkflowStarted event: {}", e);
            }
        }

        self.publish_log(req_id, "workflow", "INFO", "DAG built and workflow initialized.").await;

        // 9. Trigger the initial tasks
        let initial_tasks = dag.get_initial_tasks();

        // Lock the stored DAG for updates
        let dag_arc = self.state.workflows.get(&req_id).unwrap().clone();
        let mut dag_guard = dag_arc.lock().await;

        for task_id in initial_tasks {
            self.dispatch_task(&mut dag_guard, &task_id, &self.state.vgcs).await?;
        }

        Ok(())
    }

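    /// Handles a SyncStateCommand by publishing a full state snapshot
    /// (task graph, per-task status, and recorded output commits) for the workflow.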
    pub async fn handle_sync_state(&self, cmd: SyncStateCommand) -> Result<()> {
        let req_id = cmd.request_id;
        info!("Handling SyncStateCommand for {}", req_id);

        let dag_arc = match self.state.workflows.get(&req_id) {
            Some(d) => d.clone(),
            None => {
                warn!("Received sync request for unknown workflow {}", req_id);
                return Ok(());
            }
        };

        let dag = dag_arc.lock().await;

        // Map internal status to DTO status
        let mut tasks_status = std::collections::HashMap::new();
        for task_id in dag.nodes.keys() {
            let status = match dag.get_status(task_id) {
                TaskStatus::Pending => MsgTaskStatus::Pending,
                TaskStatus::Scheduled => MsgTaskStatus::Scheduled,
                TaskStatus::Running => MsgTaskStatus::Running,
                TaskStatus::Completed => MsgTaskStatus::Completed,
                TaskStatus::Failed => MsgTaskStatus::Failed,
                TaskStatus::Skipped => MsgTaskStatus::Skipped,
                TaskStatus::Cancelled => MsgTaskStatus::Skipped,
            };
            tasks_status.insert(task_id.clone(), status);
        }

        // Create the snapshot event
        let event = WorkflowEvent::WorkflowStateSnapshot {
            timestamp: chrono::Utc::now().timestamp_millis(),
            task_graph: dag.to_dto(),
            tasks_status,
            tasks_output: dag.commit_tracker.task_commits.clone().into_iter().map(|(k, v)| (k, Some(v))).collect(),
        };

        let subject = common_contracts::subjects::NatsSubject::WorkflowProgress(req_id).to_string();
        if let Ok(payload) = serde_json::to_vec(&event) {
            if let Err(e) = self.nats.publish(subject, payload.into()).await {
                error!("Failed to publish WorkflowStateSnapshot: {}", e);
            }
        }
        Ok(())
    }

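    /// Processes a task lifecycle event from a worker: records status and output
    /// commit, publishes progress, dispatches ready downstream tasks, and emits a
    /// completion or failure event once the whole workflow has finished.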
    pub async fn handle_task_completed(&self, evt: WorkflowTaskEvent) -> Result<()> {
        let req_id = evt.request_id;

        let dag_arc = match self.state.workflows.get(&req_id) {
            Some(d) => d.clone(),
            None => {
                warn!("Received event for unknown workflow {}", req_id);
                return Ok(());
            }
        };

        let mut dag = dag_arc.lock().await;

        // 1. Update status & record commit
        dag.update_status(&evt.task_id, evt.status);

        self.publish_log(req_id, &evt.task_id, "INFO", &format!("Task status changed to {:?}", evt.status)).await;

        // Look up the task type
        let task_type = dag.nodes.get(&evt.task_id).map(|n| n.task_type).unwrap_or(TaskType::DataFetch);

        // Convert internal status to the message DTO status
        let msg_status = match evt.status {
            TaskStatus::Pending => MsgTaskStatus::Pending,
            TaskStatus::Scheduled => MsgTaskStatus::Scheduled,
            TaskStatus::Running => MsgTaskStatus::Running,
            TaskStatus::Completed => MsgTaskStatus::Completed,
            TaskStatus::Failed => MsgTaskStatus::Failed,
            TaskStatus::Skipped => MsgTaskStatus::Skipped,
            TaskStatus::Cancelled => MsgTaskStatus::Skipped, // Map Cancelled to Skipped
        };

        // Extract the error message, if any
        let error_message = evt.result.as_ref().and_then(|r| r.error.clone());

        // Resolve input/output commits
        let input_commit = dag.nodes.get(&evt.task_id).and_then(|n| n.input_commit.clone());
        let output_commit = evt.result.as_ref().and_then(|r| r.new_commit.clone());

        // Publish TaskStateChanged event
        let progress_event = WorkflowEvent::TaskStateChanged {
            task_id: evt.task_id.clone(),
            task_type,
            status: msg_status,
            message: error_message,
            timestamp: chrono::Utc::now().timestamp_millis(),
            progress: None,
            input_commit,
            output_commit,
        };
        let subject = common_contracts::subjects::NatsSubject::WorkflowProgress(req_id).to_string();
        if let Ok(payload) = serde_json::to_vec(&progress_event) {
            if let Err(e) = self.nats.publish(subject, payload.into()).await {
                error!("Failed to publish progress event: {}", e);
            }
        }

        if let Some(result) = evt.result {
            if let Some(commit) = result.new_commit {
                info!("Task {} produced commit {}", evt.task_id, commit);
                dag.record_result(&evt.task_id, Some(commit));
            }
            if let Some(err) = result.error {
                warn!("Task {} failed with error: {}", evt.task_id, err);
            }
        }

        // 2. Check for downstream tasks that are now ready
        if evt.status == TaskStatus::Completed || evt.status == TaskStatus::Failed || evt.status == TaskStatus::Skipped {
            let ready_tasks = dag.get_ready_downstream_tasks(&evt.task_id);
            for task_id in ready_tasks {
                if let Err(e) = self.dispatch_task(&mut dag, &task_id, &self.state.vgcs).await {
                    error!("Failed to dispatch task {}: {}", task_id, e);
                    self.publish_log(req_id, &task_id, "ERROR", &format!("Failed to dispatch task: {}", e)).await;
                }
            }
        }

        // 3. Check for workflow completion
        if dag.is_workflow_finished() {
            let timestamp = chrono::Utc::now().timestamp_millis();
            let event = if dag.has_failures() {
                info!("Workflow {} failed (some tasks failed)", req_id);
                self.publish_log(req_id, "workflow", "ERROR", "Workflow finished with failures.").await;
                WorkflowEvent::WorkflowFailed {
                    end_timestamp: timestamp,
                    reason: "Some tasks failed".to_string(),
                    is_fatal: false,
                }
            } else {
                info!("Workflow {} completed successfully", req_id);
                self.publish_log(req_id, "workflow", "INFO", "Workflow completed successfully.").await;
                WorkflowEvent::WorkflowCompleted {
                    end_timestamp: timestamp,
                    result_summary: Some(json!({})),
                }
            };

            let subject = common_contracts::subjects::NatsSubject::WorkflowProgress(req_id).to_string();
            if let Ok(payload) = serde_json::to_vec(&event) {
                if let Err(e) = self.nats.publish(subject, payload.into()).await {
                    error!("Failed to publish completion event: {}", e);
                }
            }
        }

        Ok(())
    }

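    /// Resolves the execution context for a single task (including input bindings
    /// for analysis tasks), injects the resolved paths into the task config, and
    /// publishes the WorkflowTaskCommand to the worker's NATS subject.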
    async fn dispatch_task(&self, dag: &mut DagScheduler, task_id: &str, vgcs: &Vgcs) -> Result<()> {
        // 1. Resolve context (merging upstream commits if needed)
        let mut context = dag.resolve_context(task_id, vgcs)?;

        // Store the input commit on the node for observability
        if let Some(base_commit) = &context.base_commit {
            dag.set_input_commit(task_id, base_commit.clone());
        }

        // 2. Update status
        dag.update_status(task_id, TaskStatus::Scheduled);
        self.publish_log(dag.request_id, task_id, "INFO", "Task scheduled and dispatched.").await;

        // 3. Construct command
        let (routing_key, task_type, mut config) = {
            let node = dag.nodes.get(task_id).ok_or_else(|| anyhow::anyhow!("Node not found"))?;
            (node.routing_key.clone(), node.task_type, node.config.clone())
        };

        // --- Resolution phase ---
        let symbol = self.get_symbol_from_config(&config);

        // 3.1 IO binding: allocate the physical output path for this task
        let io_binder = IOBinder::new();
        let output_path = io_binder.allocate_output_path(task_type, &symbol, task_id);

        if let Some(obj) = config.as_object_mut() {
            obj.insert("output_path".to_string(), serde_json::Value::String(output_path.clone()));
        }

        // 3.2 Context resolution (analysis tasks only)
        if task_type == TaskType::Analysis {
            // The base commit may advance if we write a trace file below
            let mut current_base_commit = context.base_commit.clone().unwrap_or_default();

            if let Some(module_config_val) = config.get("_module_config") {
                if let Ok(module_config) = serde_json::from_value::<AnalysisModuleConfig>(module_config_val.clone()) {
                    let mut variables = std::collections::HashMap::new();
                    variables.insert("symbol".to_string(), symbol.clone());
                    if let Some(market) = config.get("market").and_then(|v| v.as_str()) {
                        variables.insert("market".to_string(), market.to_string());
                    }

                    let resolver = ContextResolver::new(self.state.vgcs.clone());

                    // Fetch LLM providers for context resolution (Hybrid/Auto modes)
                    let llm_providers = match self.state.persistence_client.get_llm_providers_config().await {
                        Ok(p) => p,
                        Err(e) => {
                            warn!("Failed to fetch LLM providers config (non-fatal for Manual mode): {}", e);
                            common_contracts::config_models::LlmProvidersConfig::default()
                        }
                    };

                    match resolver.resolve_input(&module_config.context_selector, &dag.request_id.to_string(), &current_base_commit, &variables, &llm_providers, &module_config.analysis_prompt).await {
                        Ok(resolution) => {
                            // 1. Inject the resolved input bindings into the task config
                            if let Some(obj) = config.as_object_mut() {
                                obj.insert("input_bindings".to_string(), serde_json::to_value(&resolution.paths)?);
                            }

                            // 2. Write the resolution trace sidecar to VGCS
                            let trace_path = io_binder.allocate_trace_path(task_type, &symbol, task_id);

                            // Run the VGCS write/commit on a blocking task to keep standard IO off the async runtime
                            let vgcs = self.state.vgcs.clone();
                            let req_id_str = dag.request_id.to_string();
                            let base_commit_for_write = current_base_commit.clone();
                            let trace_content = resolution.trace.clone();
                            let task_id_str = task_id.to_string();
                            let trace_path_clone = trace_path.clone();

                            let trace_commit_res = tokio::task::spawn_blocking(move || -> Result<String> {
                                let mut tx = vgcs.begin_transaction(&req_id_str, &base_commit_for_write)?;
                                tx.write(&trace_path_clone, trace_content.as_bytes())?;
                                let new_commit = Box::new(tx).commit(&format!("Context Resolution Trace for {}", task_id_str), "Orchestrator")?;
                                Ok(new_commit)
                            }).await;

                            match trace_commit_res {
                                Ok(Ok(new_commit)) => {
                                    info!("Written context resolution trace to {} (Commit: {})", trace_path, new_commit);
                                    // Advance the base commit so the worker sees the trace file (linear history)
                                    current_base_commit = new_commit;
                                    // Update the context passed to the worker
                                    context.base_commit = Some(current_base_commit.clone());
                                    // Also update the DAG node's input commit for observability
                                    dag.set_input_commit(task_id, current_base_commit);
                                },
                                Ok(Err(e)) => error!("Failed to write trace file: {}", e),
                                Err(e) => error!("Failed to join trace write task: {}", e),
                            }
                        },
                        Err(e) => {
                            error!("Context resolution failed for task {}: {}", task_id, e);
                            // Proceed anyway; the worker may fail if it relies on these inputs
                        }
                    }
                }
            }
        }

        let cmd = WorkflowTaskCommand {
            request_id: dag.request_id,
            task_id: task_id.to_string(),
            routing_key: routing_key.clone(),
            config, // the modified config, including output_path and input_bindings
            context,
            storage: StorageConfig {
                root_path: self.state.config.workflow_data_path.clone(),
            },
        };

        // Note: the Orchestrator always publishes the generic WorkflowTaskCommand;
        // the consuming worker (e.g. report-generator-service) derives its specific
        // command (such as GenerateReportCommand) from this config.

        // 4. Publish
        let subject = cmd.subject().to_string(); // derived from the routing_key
        let payload = serde_json::to_vec(&cmd)?;

        info!("Dispatching task {} to subject {}", task_id, subject);
        self.nats.publish(subject, payload.into()).await?;

        Ok(())
    }

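    /// Extracts the `symbol` field from a task config, defaulting to "unknown".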
    fn get_symbol_from_config(&self, config: &serde_json::Value) -> String {
        config.get("symbol").and_then(|v| v.as_str()).unwrap_or("unknown").to_string()
    }

    /// Builds the workflow DAG: one data-fetch node per enabled provider and one
    /// analysis node per template module, wired together by declared dependencies.
    fn build_dag(
        &self,
        dag: &mut DagScheduler,
        template: &common_contracts::config_models::AnalysisTemplateSet,
        data_sources: &common_contracts::config_models::DataSourcesConfig,
        template_id: &str,
        market: &str,
        symbol: &CanonicalSymbol
    ) {
        // 1. Data fetch nodes
        let mut fetch_tasks = Vec::new();

        // Use all enabled data sources regardless of market; each provider decides
        // whether to skip or process based on its own market support.
        // Keys are sorted to ensure deterministic DAG generation.
        let mut source_keys: Vec<&String> = data_sources.keys().collect();
        source_keys.sort();

        for key in source_keys {
            let config = &data_sources[key];
            if config.enabled {
                // Special handling for the MOCK market: skip real providers
                if market == "MOCK" && key.to_lowercase() != "mock" {
                    continue;
                }

                let provider_key = key.to_lowercase();
                let task_id = format!("fetch:{}", provider_key);
                fetch_tasks.push(task_id.clone());

                let display_name = format!("Data Fetch ({:?})", config.provider);
                let routing_key = format!("provider.{}", provider_key);

                dag.add_node(
                    task_id.clone(),
                    Some(display_name),
                    TaskType::DataFetch,
                    routing_key,
                    json!({
                        "symbol": symbol.as_str(),
                        "market": market
                    })
                );
            }
        }

        // Fallback for MOCK if not present in config (mock is usually not listed in
        // data_sources.json but is hardcoded for tests)
        if market == "MOCK" && fetch_tasks.is_empty() {
            let task_id = "fetch:mock".to_string();
            fetch_tasks.push(task_id.clone());
            dag.add_node(
                task_id,
                Some("Data Fetch (Mock)".to_string()),
                TaskType::DataFetch,
                "provider.mock".to_string(),
                json!({
                    "symbol": symbol.as_str(),
                    "market": market
                })
            );
        }

        if fetch_tasks.is_empty() {
            warn!("No enabled data providers found in configuration.");
        }

        // 2. Analysis nodes (dynamic, from the template)
        for (module_id, module_config) in &template.modules {
            let task_id = format!("analysis:{}", module_id);

            // Pass template_id and module_id so the worker knows which module to run;
            // the worker re-fetches the template itself, so only identifiers are needed here.
            let mut node_config = json!({
                "template_id": template_id,
                "module_id": module_id,
                "symbol": symbol.as_str(),
                "market": market
            });

            // Embed the internal module config for Orchestrator use (context resolution)
            if let Some(obj) = node_config.as_object_mut() {
                obj.insert("_module_config".to_string(), serde_json::to_value(module_config).unwrap_or(serde_json::Value::Null));
            }

            dag.add_node(
                task_id.clone(),
                Some(module_config.name.clone()),
                TaskType::Analysis,
                "analysis.report".to_string(), // routing_key matches what report-generator consumes
                node_config
            );

            // Dependencies
            if module_config.dependencies.is_empty() {
                // No analysis dependencies: depend on all data-fetch tasks
                for fetch_task in &fetch_tasks {
                    dag.add_dependency(fetch_task, &task_id);
                }
            } else {
                // Depend on other analysis modules
                for dep_module_id in &module_config.dependencies {
                    let dep_task_id = format!("analysis:{}", dep_module_id);
                    dag.add_dependency(&dep_task_id, &task_id);
                }
            }
        }
    }

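    /// Publishes a TaskLog event to the workflow's progress subject, fire-and-forget.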
    async fn publish_log(&self, req_id: uuid::Uuid, task_id: &str, level: &str, message: &str) {
        let event = WorkflowEvent::TaskLog {
            task_id: task_id.to_string(),
            level: level.to_string(),
            message: message.to_string(),
            timestamp: chrono::Utc::now().timestamp_millis(),
        };
        let subject = common_contracts::subjects::NatsSubject::WorkflowProgress(req_id).to_string();
        if let Ok(payload) = serde_json::to_vec(&event) {
            // Fire and forget: publish on a spawned task so logging never blocks the engine
            let nats = self.nats.clone();
            tokio::spawn(async move {
                if let Err(e) = nats.publish(subject, payload.into()).await {
                    error!("Failed to publish TaskLog: {}", e);
                }
            });
        }
    }
}