feat: Refactor Analysis Context Mechanism and Generic Worker

- Implemented Unified Context Mechanism (Task 20251127):
  - Decoupled intent (Module) from resolution (Orchestrator).
  - Added ContextResolver for resolving input bindings (Manual Glob/Auto LLM).
  - Added IOBinder for managing physical paths.
  - Updated GenerateReportCommand to support explicit input bindings and output paths.

- Refactored Report Worker to Generic Execution (Task 20251128):
  - Removed hardcoded financial DTOs and specific formatting logic.
  - Implemented Generic YAML-based context assembly for better LLM readability.
  - Added detailed execution tracing (Sidecar logs).
  - Fixed input data collision bug by using full paths as context keys.

- Updated Tushare Provider to support dynamic output paths.
- Updated Common Contracts with new configuration models.
Lv, Qi 2025-11-28 20:11:17 +08:00
parent 91a6dfc4c1
commit 03b53aed71
32 changed files with 1886 additions and 880 deletions

File diff suppressed because one or more lines are too long

View File

@@ -11,6 +11,10 @@ pub struct DirEntry {
     pub name: String,
     pub kind: EntryKind,
     pub object_id: String,
+    // New metadata fields
+    pub size: Option<u64>,
+    pub line_count: Option<usize>,
+    pub word_count: Option<usize>,
 }
 #[derive(Debug, Clone, Serialize, Deserialize)]

View File

@@ -1,7 +1,6 @@
 use std::path::{Path, PathBuf};
 use std::fs::{self, File};
-use std::io::{Cursor, Read, Write};
+use std::io::{Cursor, Read, Write, BufRead, BufReader};
 use anyhow::{Context, Result, anyhow};
 use git2::{Repository, Oid, ObjectType, Signature, Index, IndexEntry, IndexTime};
 use sha2::{Sha256, Digest};
@@ -97,7 +96,29 @@ impl ContextStore for Vgcs {
                 _ => EntryKind::File,
             };
             let object_id = entry.id().to_string();
-            entries.push(DirEntry { name, kind, object_id });
+            // Metadata extraction (Expensive but necessary for the prompt)
+            let mut size = None;
+            let mut line_count = None;
+            let mut word_count = None;
+            if kind == EntryKind::File {
+                if let Ok(object) = entry.to_object(&repo) {
+                    if let Some(blob) = object.as_blob() {
+                        let content = blob.content();
+                        size = Some(content.len() as u64);
+                        // Check for binary content or just use heuristic
+                        if !content.contains(&0) {
+                            let s = String::from_utf8_lossy(content);
+                            line_count = Some(s.lines().count());
+                            word_count = Some(s.split_whitespace().count());
+                        }
+                    }
+                }
+            }
+            entries.push(DirEntry { name, kind, object_id, size, line_count, word_count });
         }
         Ok(entries)
@@ -338,4 +359,3 @@ fn create_index_entry(path: &str, mode: u32) -> IndexEntry {
         path: path.as_bytes().to_vec(),
     }
 }

View File

@@ -0,0 +1,160 @@
# Task: Refactor the Analysis Module Context Mechanism (Fusing Two-Stage Selection with Unified I/O Binding)
**Status**: In Design (Finalizing)
**Date**: 2025-11-27
**Priority**: High
**Owners**: @User / @Assistant
## 1. Core Idea: Decoupling Intent from Implementation
We went through three stages of thinking, and now need to fuse them into one coherent system:
1. **Context Projection**: a module needs to "project" the data it needs out of the global context.
2. **Two-Stage Selection**: that projection splits into "selection (what do I need?)" and "analysis (how do I process it?)", and both stages are prompt/model driven.
3. **Unified I/O Binding**: a module should not handle physical paths itself; the Orchestrator owns I/O binding.
**Fused design**:
* **Module declares intent**: the module describes "what kind of input I need" via its configuration (prompt/rules), e.g. "I need last year's financial data" or "match files against this glob rule".
* **Orchestrator resolves**: with the help of the IO Binder, the Orchestrator computes concrete **physical path** bindings from the module's intent and the current global context state.
* **Module executes**: the module receives the physical paths from the Orchestrator and performs the read, the analysis, and the write.
## 2. Architecture
### 2.1. Module Configuration: Describing "What I Need"
`AnalysisModuleConfig` keeps the two-stage structure, but the "input/context selector" here describes a **logical requirement**.
```rust
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnalysisModuleConfig {
    pub id: String,
    // Phase 1: input intent (what data do I need?)
    pub context_selector: ContextSelectorConfig,
    // Manual: explicit rules (e.g., "financials/*.json")
    // Auto: a fuzzy requirement, inferred automatically by the Orchestrator/Agent
    // Hybrid: a concrete prompt (e.g., "Find all news about 'Environment' from last year")
    // Phase 2: analysis intent (how do I process the data?)
    pub analysis_prompt: String,
    pub llm_config: Option<LlmConfig>,
    // Output intent (what is the result?)
    // The module only declares what kind of result it produces; the physical path is assigned by the Orchestrator
    pub output_type: String, // e.g., "markdown_report", "json_summary"
}
```
### 2.2. Orchestrator Runtime: Resolving "Where"
Before dispatching a task, the Orchestrator runs a **resolution step**:
* **For a Manual selector**:
  * The Orchestrator matches the glob rules against the current VGCS head commit (see the sketch after this list).
  * It produces concrete `InputBindings` (Map<FileName, PhysicalPath>).
* **For an Auto/Hybrid selector**:
  * **This is the key fusion point**: the Orchestrator (or a dedicated resolution agent) runs a lightweight LLM task.
  * Input: the current VGCS directory tree plus the module's selection prompt (or the Auto strategy).
  * Output: a concrete list of VGCS file paths.
  * The Orchestrator packs these paths into `InputBindings`.
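A minimal sketch of the Manual branch, assuming the glob rules are matched against a flat file listing of the commit (the real `ContextResolver` would obtain that list via recursive `list_dir` calls); `globset` is already a dependency of the `workflow-context` crate:
```rust
use anyhow::Result;
use globset::{Glob, GlobSetBuilder};

/// Resolve a Manual selector's glob rules against the files of a commit.
/// `all_paths` stands in for a recursive VGCS `list_dir` at `commit_hash`.
fn resolve_manual(rules: &[String], all_paths: &[String]) -> Result<Vec<String>> {
    let mut builder = GlobSetBuilder::new();
    for rule in rules {
        builder.add(Glob::new(rule)?);
    }
    let set = builder.build()?;
    // Keep only the paths matched by at least one rule.
    Ok(all_paths
        .iter()
        .filter(|p| set.is_match(p.as_str()))
        .cloned()
        .collect())
}
```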
### 2.3. Module Execution: Performing the "Transform"
When the module actually starts (the worker receives the command), the world it sees has **already been resolved** and is fully determined.
```rust
// The command ultimately sent to the worker
pub struct GenerateReportCommand {
    pub request_id: Uuid,
    pub commit_hash: String, // the pinned world state
    // Concrete I/O bindings (fully resolved by the Orchestrator)
    pub input_bindings: Vec<String>, // e.g., ["raw/tushare/AAPL/financials.json", ...]
    pub output_path: String, // e.g., "analysis/financial_v1/report.md"
    // Analysis logic (passed through to the worker)
    pub analysis_prompt: String,
    pub llm_config: Option<LlmConfig>,
}
```
**What changes**:
* **Complex selection logic moves up**: the `Select_Smart` logic originally planned for the worker fits better as an Orchestrator preprocessing step (or a standalone micro-task).
* **The worker gets lighter**: the worker becomes deliberately "dumb", responsible only for `Read(paths) -> Expand -> Prompt -> Write(output_path)`. This is what "the module focuses only on its core task" really means; see the sketch after this list.
* **Flexibility is preserved**: in Auto/Hybrid mode the Orchestrator decides the input bindings dynamically; in Manual mode they come from static rule resolution. Either way, the worker always receives a fixed list of files.
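A sketch of that pipeline under assumed signatures — the closures stand in for the real VGCS and LLM clients, and `expand` is only a placeholder for the JSON-to-table/YAML expander:
```rust
use anyhow::Result;

pub struct GenerateReportCommand {
    pub commit_hash: String,
    pub input_bindings: Vec<String>,
    pub output_path: String,
    pub analysis_prompt: String,
}

/// Placeholder expander: the real one turns structured inputs (JSON) into a
/// readable table/YAML block; here we only label the source.
fn expand(path: &str, raw: &str) -> String {
    format!("# Data Source: {path}\n{raw}\n\n")
}

/// The whole generic worker, reduced to its essence:
/// Read(paths) -> Expand -> Prompt -> Write(output_path).
fn run_worker(
    cmd: &GenerateReportCommand,
    read_file: impl Fn(&str) -> Result<String>, // vgcs.read_file at cmd.commit_hash
    call_llm: impl Fn(&str) -> Result<String>,  // chat completion
    mut write_file: impl FnMut(&str, &str) -> Result<()>,
) -> Result<()> {
    let mut context = String::new();
    for path in &cmd.input_bindings {
        let raw = read_file(path)?;            // Read
        context.push_str(&expand(path, &raw)); // Expand
    }
    let prompt = cmd.analysis_prompt.replace("{{data}}", &context); // Prompt
    let report = call_llm(&prompt)?;
    write_file(&cmd.output_path, &report)      // Write
}
```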
## 3. Implementation Roadmap (Revised)
### Phase 1: Protocol and Configuration (Contracts)
1. Define `AnalysisModuleConfig` (selector, prompt, LlmConfig).
2. Define `GenerateReportCommand` (with `input_bindings` as a physical path list, `output_path`, `commit_hash`).
### Phase 2: Orchestrator Resolution Logic
1. Implement the `ContextResolver` component:
   * Glob resolution (Manual).
   * (Later) LLM directory-tree reasoning (Auto/Hybrid).
2. In the scheduling loop, call `ContextResolver` before generating the command.
### Phase 3: Module Refactor
1. **Provider**: receives `output_path` (generated by the Orchestrator by convention, e.g. `raw/{provider}/{symbol}`) and writes to it.
2. **Generator**:
   * Remove all selection logic.
   * Read the files in `cmd.input_bindings` directly.
   * Run the expander (JSON -> table, etc.).
   * Run the prompt.
   * Write to `cmd.output_path`.
## 4. Summary
This design merges every thread of our discussion:
* **Input/Output symmetry**: both are bound explicitly in the command.
* **Two stages**:
  * Stage 1 (selection) happens at **orchestration time** (binding resolution).
  * Stage 2 (analysis) happens at **execution time** (worker run).
* **Module focus**: a module never needs to know "where to look"; it only knows "give me these files and I'll give you that result".
## 5. Implementation Checklist
### Phase 1: Contracts & Configs
- [x] **Common Contracts**: create or update `configs.rs` under `services/common-contracts/src`:
- [x] Define `SelectionMode` (Manual, Auto, Hybrid).
- [x] Define `LlmConfig` (model_id, parameters).
- [x] Define `ContextSelectorConfig` (mode, rules, prompt, llm_config).
- [x] Define `AnalysisModuleConfig` (id, selector, analysis_prompt, llm_config, output_type).
- [x] **Messages**: update `services/common-contracts/src/messages.rs`:
- [x] `GenerateReportCommand`: add `commit_hash`, `input_bindings: Vec<String>`, `output_path: String`, `llm_config`.
- [x] `FetchCompanyDataCommand`: add `output_path: Option<String>`.
- [x] **VGCS Types**: confirm the types in the `workflow-context` crate suffice for path operations. (Confirmed: the Vgcs struct has the needed methods.)
### Phase 2: Orchestrator Refactor (Resolution Logic)
- [x] **Context Resolver**: create `context_resolver.rs` in `workflow-orchestrator-service`:
- [x] Implement `resolve_input(selector, vgcs_client, commit_hash) -> Result<Vec<String>>`.
- [x] Manual mode: implement glob matching (recursive lookup via VGCS `list_dir`).
- [x] Auto/Hybrid mode: (interface stub for now) return Empty or NotImplemented; wire up the LLM later.
- [x] **IO Binder**: implement `io_binder.rs`:
- [x] Implement the convention-based `allocate_output_path(task_type, task_id) -> String` (see the sketch after this checklist).
- [x] **Scheduler**: update `dag_scheduler.rs`:
- [x] Call `ContextResolver` and `IOBinder` before dispatching a task.
- [x] Fill the resolved results into the command.
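A sketch of the path convention, with illustrative segments (presumably the real binder also folds in provider/symbol to build paths like `raw/tushare/AAPL`):
```rust
/// Allocate an output path purely by convention; the module never sees this logic.
fn allocate_output_path(task_type: &str, task_id: &str) -> String {
    match task_type {
        // Providers write raw data under their own namespace.
        "DataFetch" => format!("raw/{task_id}"),
        // Analysis modules get one report file per module.
        "Analysis" => format!("analysis/{task_id}/report.md"),
        // Fallback keeps unknown task types isolated but deterministic.
        _ => format!("misc/{task_id}/output"),
    }
}
```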
### Phase 3: Write Side (Provider Adaptation)
- [x] **Tushare Provider**: update `services/tushare-provider-service/src/generic_worker.rs`:
- [x] Read `output_path` from the command (when present).
- [x] Write data to the given path via `WorkerContext` (no more hardcoded `raw/tushare/...`; trust the command instead).
- [x] Commit and return the new commit hash.
### Phase 4: Read Side (Report Generator Adaptation)
- [x] **Worker Refactor**: rewrite `services/report-generator-service/src/worker.rs`:
- [x] **Remove**: delete `fetch_data_and_configs` (the old DB read logic).
- [x] **Checkout**: use `vgcs.checkout(cmd.commit_hash)`.
- [x] **Read Input**: iterate over `cmd.input_bindings` and read each file via `vgcs.read_file`.
- [x] **Expand**: implement a simple `Expander` (JSON -> Markdown table).
- [x] **Prompt**: render `cmd.analysis_prompt`.
- [x] **LLM Call**: initialize the client from `cmd.llm_config` and call it.
- [x] **Write Output**: write the result to `cmd.output_path`.
- [x] **Commit**: commit the changes and broadcast the event.
### Phase 5: Integration & Verification
- [x] **Config Migration**: update `config/analysis-config.json` (or the config stored in the DB) to the new `AnalysisModuleConfig` structure.
- [ ] **End-to-End Test**: run the full pipeline and verify that:
1. The provider writes files to Git.
2. The Orchestrator resolves the paths.
3. The generator reads the files and produces the report.

View File

@@ -0,0 +1,71 @@
# Task: Refactor the Report Worker into a Generic Executor (Generic Execution)
**Status**: Planning -> Ready for Implementation
**Priority**: High
**Related components**: `report-generator-service`, `common-contracts`
## 1. Background
The current `report-generator-service/src/worker.rs` has a serious design flaw: **business logic leakage**.
The worker hardcodes special handling for `financials.json` (deserializing `TimeSeriesFinancialDto` and converting it into a Markdown table). As a result the worker is no longer a generic analysis executor but is tightly coupled to one specific financial-analysis use case. This contradicts the original system design, in which the worker is only responsible for the generic `IO -> Context Assembly -> LLM` pipeline.
## 2. Goal
Refactor the worker into a thoroughly **generic analysis worker**. It should not know what "Financials" or "Profile" mean. It only knows:
1. I have input files (JSON, text, etc.).
2. I need to turn them into prompt context, preferring human-readable forms such as YAML.
3. I call the LLM.
4. I write the result.
## 3. Core Changes
### 3.1 Remove Hardcoded DTO Parsing
* **Delete every reference** to `TimeSeriesFinancialDto` in `worker.rs`. The worker must not know any business-specific data structure.
* Delete the financials-specific table generation logic in `formatter.rs`.
### 3.2 Generic Formatter Strategy -> YAML First
We need a generic way to present structured data to the LLM that also stays readable for humans while debugging.
**Approach: YAML pretty print (preferred)**
* **Rationale**: YAML is cleaner than JSON (no walls of brackets and quotes), so it reads better for humans, and LLMs understand YAML well. Since we are still in the development/debugging phase, **human readability** outranks squeezing out every last token.
* **Strategy**:
  * Try to parse the input file content as a JSON value.
  * On success, convert it to a **YAML** string.
  * On failure (unstructured text), keep it as-is (raw text).
* **Context structure**: avoid XML tags; use a more intuitive separator. A sketch of this conversion follows the example below.
```yaml
---
# Data Source: financials.json (Original Size: 1.2MB)
data:
- date: 2023-12-31
revenue: 10000
...
```
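A minimal sketch of that strategy, using `serde_yaml` as listed in section 4 (the helper name and header format are illustrative):
```rust
use serde_json::Value;

/// Try JSON -> YAML for readability; fall back to raw text for anything that
/// is not valid JSON. The real builder also records sizes for the trace.
fn to_context_block(file_name: &str, raw: &str) -> String {
    let body = match serde_json::from_str::<Value>(raw) {
        Ok(value) => serde_yaml::to_string(&value).unwrap_or_else(|_| raw.to_string()),
        Err(_) => raw.to_string(), // unstructured text stays as-is
    };
    format!(
        "---\n# Data Source: {} (Original Size: {} bytes)\n{}\n",
        file_name,
        raw.len(),
        body
    )
}
```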
### 3.3 Richer Execution Trace and Truncation Strategy
* **Sidecar log**: the execution must be traced in detail.
* **Truncation**:
  * Keep character-level truncation as a last-resort safety net.
  * **Critical logging**: whenever truncation happens, leave a loud warning in the log.
  * **Details**: always record the size before vs. after truncation (e.g., `Original: 1MB, Truncated to: 64KB`) so developers see exactly how much data was lost. A sketch follows below.
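A sketch of the safety net, assuming `tracing` for the warning and a plain string as the sidecar trace buffer; the 64 KB limit is illustrative:
```rust
use tracing::warn;

const MAX_CONTEXT_BYTES: usize = 64 * 1024; // illustrative; the real limit is configurable

/// Character-level truncation as a last resort, with a loud trace entry
/// recording the size before vs. after truncation.
fn truncate_context(context: String, trace: &mut String) -> String {
    if context.len() <= MAX_CONTEXT_BYTES {
        return context;
    }
    let original_size = context.len();
    // Cut on a char boundary so we never split a UTF-8 sequence.
    let mut cut = MAX_CONTEXT_BYTES;
    while !context.is_char_boundary(cut) {
        cut -= 1;
    }
    let truncated = context[..cut].to_string();
    warn!(original_size, truncated_size = truncated.len(), "context truncated");
    trace.push_str(&format!(
        "WARNING: context truncated. Original: {} bytes, Truncated to: {} bytes\n",
        original_size,
        truncated.len()
    ));
    truncated
}
```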
## 4. Implementation Steps
1. **Cleanup**: remove all business-specific DTO code from `worker.rs` and `formatter.rs`.
2. **Generic Implementation**:
   * Add the `serde_yaml` dependency.
   * Implement a generic `context_builder`:
     * Input -> `serde_json::Value` -> `serde_yaml::to_string`.
     * Fallback: raw text.
   * Assemble the context string.
3. **Safety & Logging**:
   * Implement the truncation logic, computing `original_size` and `truncated_size`.
   * Record per-file handling details in `execution_trace.md`.
4. **Verify**: run the tests and check that the generated context is clear and readable.
## 5. Expected Outcome
* **Decoupled**: the worker is fully cut loose from the financial domain.
* **Readable**: the context reads like a config file, making it easy to review the LLM's input by hand.
* **Transparent**: we know exactly which data was fed to the LLM and what was truncated.

View File

@@ -1,18 +1,43 @@
 import { makeApi, Zodios, type ZodiosOptions } from "@zodios/core";
 import { z } from "zod";
-export type AnalysisTemplateSet = {
-  modules: Record<string, AnalysisModuleConfig>;
-  name: string;
-};
 export type AnalysisModuleConfig = {
+  analysis_prompt: string;
+  context_selector: ContextSelectorConfig;
   dependencies: Array<string>;
-  model_id: string;
+  id?: (string | null) | undefined;
+  llm_config?: (null | LlmConfig) | undefined;
   name: string;
-  prompt_template: string;
-  provider_id: string;
+  output_type: string;
 };
-export type AnalysisTemplateSets = Record<string, AnalysisTemplateSet>;
+export type ContextSelectorConfig = SelectionMode;
+export type SelectionMode =
+  | {
+      Manual: {
+        rules: Array<string>;
+      };
+    }
+  | {
+      Auto: Partial<{
+        llm_config: null | LlmConfig;
+      }>;
+    }
+  | {
+      Hybrid: {
+        llm_config?: (null | LlmConfig) | undefined;
+        selection_prompt: string;
+      };
+    };
+export type LlmConfig = Partial<{
+  max_tokens: number | null;
+  model_id: string | null;
+  temperature: number | null;
+}>;
+export type AnalysisTemplateSet = {
+  modules: {};
+  name: string;
+};
+export type AnalysisTemplateSets = {};
 export type ConfigFieldSchema = {
   default_value?: (string | null) | undefined;
   description?: (string | null) | undefined;
@@ -41,9 +66,9 @@ export type DataSourceConfig = {
   provider: DataSourceProvider;
 };
 export type DataSourceProvider = "Tushare" | "Finnhub" | "Alphavantage" | "Yfinance";
-export type DataSourcesConfig = Record<string, DataSourceConfig>;
+export type DataSourcesConfig = {};
 export type HealthStatus = {
-  details: Record<string, string>;
+  details: {};
   module_id: string;
   status: ServiceStatus;
   version: string;
@@ -60,7 +85,7 @@ export type LlmModel = {
   model_id: string;
   name?: (string | null) | undefined;
 };
-export type LlmProvidersConfig = Record<string, LlmProvider>;
+export type LlmProvidersConfig = {};
 export type ProviderMetadata = {
   config_schema: Array<ConfigFieldSchema>;
   description: string;
@@ -119,7 +144,9 @@ export type WorkflowEvent =
     }
   | {
       payload: {
+        input_commit?: (string | null) | undefined;
         message?: (string | null) | undefined;
+        output_commit?: (string | null) | undefined;
        progress?: (number | null) | undefined;
         status: TaskStatus;
         task_id: string;
@@ -148,7 +175,7 @@
   | {
      payload: {
         end_timestamp: number;
-        result_summary?: unknown;
+        result_summary?: unknown | undefined;
       };
       type: "WorkflowCompleted";
     }
@@ -163,25 +190,58 @@
   | {
       payload: {
         task_graph: WorkflowDag;
-        tasks_output: Record<string, string | null>;
-        tasks_status: Record<string, TaskStatus>;
+        tasks_output: {};
+        tasks_status: {};
         timestamp: number;
       };
       type: "WorkflowStateSnapshot";
     };
-export const AnalysisModuleConfig: z.ZodType<AnalysisModuleConfig> = z.object({
+export const LlmConfig = z
+  .object({
+    max_tokens: z.union([z.number(), z.null()]),
+    model_id: z.union([z.string(), z.null()]),
+    temperature: z.union([z.number(), z.null()]),
+  })
+  .partial();
+export const SelectionMode = z.union([
+  z
+    .object({ Manual: z.object({ rules: z.array(z.string()) }).passthrough() })
+    .passthrough(),
+  z
+    .object({
+      Auto: z
+        .object({ llm_config: z.union([z.null(), LlmConfig]) })
+        .partial()
+        .passthrough(),
+    })
+    .passthrough(),
+  z
+    .object({
+      Hybrid: z
+        .object({
+          llm_config: z.union([z.null(), LlmConfig]).optional(),
+          selection_prompt: z.string(),
+        })
+        .passthrough(),
+    })
+    .passthrough(),
+]);
+export const ContextSelectorConfig = SelectionMode;
+export const AnalysisModuleConfig = z.object({
+  analysis_prompt: z.string(),
+  context_selector: ContextSelectorConfig,
   dependencies: z.array(z.string()),
-  model_id: z.string(),
+  id: z.union([z.string(), z.null()]).optional(),
+  llm_config: z.union([z.null(), LlmConfig]).optional(),
   name: z.string(),
-  prompt_template: z.string(),
-  provider_id: z.string(),
+  output_type: z.string(),
 });
-export const AnalysisTemplateSet: z.ZodType<AnalysisTemplateSet> = z.object({
+export const AnalysisTemplateSet = z.object({
   modules: z.record(AnalysisModuleConfig),
   name: z.string(),
 });
-export const AnalysisTemplateSets: z.ZodType<AnalysisTemplateSets> =
+export const AnalysisTemplateSets =
   z.record(AnalysisTemplateSet);
 export const DataSourceProvider = z.enum([
   "Tushare",
@@ -189,50 +249,36 @@ export const DataSourceProvider = z.enum([
   "Alphavantage",
   "Yfinance",
 ]);
-export const DataSourceConfig: z.ZodType<DataSourceConfig> = z.object({
+export const DataSourceConfig = z.object({
   api_key: z.union([z.string(), z.null()]).optional(),
   api_url: z.union([z.string(), z.null()]).optional(),
   enabled: z.boolean(),
   provider: DataSourceProvider,
 });
-export const DataSourcesConfig: z.ZodType<DataSourcesConfig> =
+export const DataSourcesConfig =
   z.record(DataSourceConfig);
-export type TestLlmConfigRequest = {
-  api_base_url: string;
-  api_key: string;
-  model_id: string;
-};
 export const TestLlmConfigRequest = z.object({
   api_base_url: z.string(),
   api_key: z.string(),
   model_id: z.string(),
 });
-export const LlmModel: z.ZodType<LlmModel> = z.object({
+export const LlmModel = z.object({
   is_active: z.boolean(),
   model_id: z.string(),
   name: z.union([z.string(), z.null()]).optional(),
 });
-export const LlmProvider: z.ZodType<LlmProvider> = z.object({
+export const LlmProvider = z.object({
   api_base_url: z.string(),
   api_key: z.string(),
   models: z.array(LlmModel),
   name: z.string(),
 });
-export const LlmProvidersConfig: z.ZodType<LlmProvidersConfig> = z.record(LlmProvider);
+export const LlmProvidersConfig = z.record(LlmProvider);
-export type TestConfigRequest = { data: unknown; type: string };
 export const TestConfigRequest = z.object({ data: z.unknown(), type: z.string() });
-export type TestConnectionResponse = {
-  message: string;
-  success: boolean;
-};
 export const TestConnectionResponse = z.object({
   message: z.string(),
   success: z.boolean(),
 });
-export type DiscoverPreviewRequest = {
-  api_base_url: string;
-  api_key: string;
-};
 export const DiscoverPreviewRequest = z.object({
   api_base_url: z.string(),
   api_key: z.string(),
@@ -249,7 +295,7 @@ export const ConfigKey = z.enum([
   "SandboxMode",
   "Region",
 ]);
-export const ConfigFieldSchema: z.ZodType<ConfigFieldSchema> = z.object({
+export const ConfigFieldSchema = z.object({
   default_value: z.union([z.string(), z.null()]).optional(),
   description: z.union([z.string(), z.null()]).optional(),
   field_type: FieldType,
@@ -259,7 +305,7 @@ export const ConfigFieldSchema: z.ZodType<ConfigFieldSchema> = z.object({
   placeholder: z.union([z.string(), z.null()]).optional(),
   required: z.boolean(),
 });
-export const ProviderMetadata: z.ZodType<ProviderMetadata> = z.object({
+export const ProviderMetadata = z.object({
   config_schema: z.array(ConfigFieldSchema),
   description: z.string(),
   icon_url: z.union([z.string(), z.null()]).optional(),
@@ -268,37 +314,19 @@ export const ProviderMetadata: z.ZodType<ProviderMetadata> = z.object({
   name_en: z.string(),
   supports_test_connection: z.boolean(),
 });
-export type SymbolResolveRequest = {
-  market?: (string | null) | undefined;
-  symbol: string;
-};
 export const SymbolResolveRequest = z.object({
   market: z.union([z.string(), z.null()]).optional(),
   symbol: z.string(),
 });
-export type SymbolResolveResponse = {
-  market: string;
-  symbol: string;
-};
 export const SymbolResolveResponse = z.object({
   market: z.string(),
   symbol: z.string(),
 });
-export type DataRequest = {
-  market?: (string | null) | undefined;
-  symbol: string;
-  template_id: string;
-};
 export const DataRequest = z.object({
   market: z.union([z.string(), z.null()]).optional(),
   symbol: z.string(),
   template_id: z.string(),
 });
-export type RequestAcceptedResponse = {
-  market: string;
-  request_id: string;
-  symbol: string;
-};
 export const RequestAcceptedResponse = z.object({
   market: z.string(),
   request_id: z.string().uuid(),
@@ -310,7 +338,7 @@ export const ObservabilityTaskStatus = z.enum([
   "Completed",
   "Failed",
 ]);
-export const TaskProgress: z.ZodType<TaskProgress> = z.object({
+export const TaskProgress = z.object({
   details: z.string(),
   progress_percent: z.number().int().gte(0),
   request_id: z.string().uuid(),
@@ -320,19 +348,19 @@ export const TaskProgress: z.ZodType<TaskProgress> = z.object({
 });
 export const CanonicalSymbol = z.string();
 export const ServiceStatus = z.enum(["Ok", "Degraded", "Unhealthy"]);
-export const HealthStatus: z.ZodType<HealthStatus> = z.object({
+export const HealthStatus = z.object({
   details: z.record(z.string()),
   module_id: z.string(),
   status: ServiceStatus,
   version: z.string(),
 });
-export const StartWorkflowCommand: z.ZodType<StartWorkflowCommand> = z.object({
+export const StartWorkflowCommand = z.object({
   market: z.string(),
   request_id: z.string().uuid(),
   symbol: CanonicalSymbol,
   template_id: z.string(),
 });
-export const TaskDependency: z.ZodType<TaskDependency> = z.object({
+export const TaskDependency = z.object({
   from: z.string(),
   to: z.string(),
 });
@@ -345,18 +373,18 @@ export const TaskStatus = z.enum([
   "Skipped",
 ]);
 export const TaskType = z.enum(["DataFetch", "DataProcessing", "Analysis"]);
-export const TaskNode: z.ZodType<TaskNode> = z.object({
+export const TaskNode = z.object({
   display_name: z.union([z.string(), z.null()]).optional(),
   id: z.string(),
   initial_status: TaskStatus,
   name: z.string(),
   type: TaskType,
 });
-export const WorkflowDag: z.ZodType<WorkflowDag> = z.object({
+export const WorkflowDag = z.object({
   edges: z.array(TaskDependency),
   nodes: z.array(TaskNode),
 });
-export const WorkflowEvent: z.ZodType<WorkflowEvent> = z.union([
+export const WorkflowEvent = z.union([
   z
     .object({
       payload: z
@@ -369,7 +397,9 @@ export const WorkflowEvent: z.ZodType<WorkflowEvent> = z.union([
     .object({
       payload: z
         .object({
+          input_commit: z.union([z.string(), z.null()]).optional(),
           message: z.union([z.string(), z.null()]).optional(),
+          output_commit: z.union([z.string(), z.null()]).optional(),
           progress: z.union([z.number(), z.null()]).optional(),
           status: TaskStatus,
           task_id: z.string(),
@@ -410,7 +440,7 @@ export const WorkflowEvent: z.ZodType<WorkflowEvent> = z.union([
       payload: z
         .object({
           end_timestamp: z.number().int(),
-          result_summary: z.unknown(),
+          result_summary: z.unknown().optional(),
         })
         .passthrough(),
       type: z.literal("WorkflowCompleted"),
@@ -444,6 +474,9 @@ export const WorkflowEvent: z.ZodType<WorkflowEvent> = z.union([
 ]);
 export const schemas = {
+  LlmConfig,
+  SelectionMode,
+  ContextSelectorConfig,
   AnalysisModuleConfig,
   AnalysisTemplateSet,
   AnalysisTemplateSets,
@@ -479,7 +512,7 @@ export const schemas = {
   WorkflowEvent,
 };
-const endpoints = makeApi([
+export const endpoints = makeApi([
   {
     method: "get",
     path: "/api/v1/configs/analysis_template_sets",

View File

@@ -45,13 +45,13 @@ export function Dashboard() {
     const missingConfigs: string[] = [];
     Object.values(selectedTemplate.modules).forEach(module => {
-      if (!llmProviders[module.provider_id]) {
-        missingConfigs.push(`Module '${module.name}': Provider '${module.provider_id}' not found`);
-      } else {
-        const provider = llmProviders[module.provider_id];
-        const modelExists = provider.models.some(m => m.model_id === module.model_id);
-        if (!modelExists) {
-          missingConfigs.push(`Module '${module.name}': Model '${module.model_id}' not found in provider '${provider.name}'`);
-        }
+      const modelId = module.llm_config?.model_id;
+      if (modelId && llmProviders) {
+        const modelExists = Object.values(llmProviders).some(provider =>
+          provider.models.some(m => m.model_id === modelId)
+        );
+        if (!modelExists) {
+          missingConfigs.push(`Module '${module.name}': Model '${modelId}' not found in any active provider`);
+        }
       }
     });

View File

@@ -1,6 +1,8 @@
-import { useState, useEffect } from "react"
+import { useState, useEffect, useMemo } from "react"
 import { useAnalysisTemplates, useUpdateAnalysisTemplates, useLlmProviders } from "@/hooks/useConfig"
 import { AnalysisTemplateSet, AnalysisModuleConfig } from "@/types/config"
+import { schemas } from "@/api/schema.gen"
+import { z } from "zod"
 import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"
 import { Button } from "@/components/ui/button"
 import { ScrollArea } from "@/components/ui/scroll-area"
@@ -67,7 +69,7 @@ export function TemplateTab() {
   });
   }
-  const activeTemplate = (templates && selectedId) ? templates[selectedId] : null;
+  const activeTemplate = (templates && selectedId) ? (templates as Record<string, AnalysisTemplateSet>)[selectedId] : null;
   return (
     <div className="flex h-[600px] border rounded-md overflow-hidden">
@@ -87,7 +89,7 @@ export function TemplateTab() {
       selectedId === id ? "bg-accent text-accent-foreground font-medium" : "hover:bg-muted"
     }`}
   >
-    <span className="truncate pr-6">{t.name}</span>
+    <span className="truncate pr-6">{(t as AnalysisTemplateSet).name}</span>
     {selectedId === id && <ArrowRight className="h-3 w-3 opacity-50" />}
   </button>
   {/* Delete button visible on hover */}
@@ -140,7 +142,7 @@ function TemplateDetailView({ template, onSave, isSaving }: { template: Analysis
   }, [template]);
   const handleRename = () => {
-    setLocalTemplate(prev => ({ ...prev, name: newName }));
+    setLocalTemplate((prev: AnalysisTemplateSet) => ({ ...prev, name: newName }));
     setIsRenaming(false);
     setIsDirty(true);
   }
@@ -149,13 +151,14 @@ function TemplateDetailView({ template, onSave, isSaving }: { template: Analysis
     const newModuleId = "module_" + Math.random().toString(36).substring(2, 9);
     const newModule: AnalysisModuleConfig = {
       name: "New Analysis Module",
-      model_id: "", // Empty to force selection
-      provider_id: "",
-      prompt_template: "Analyze the following financial data:\n\n{{data}}\n\nProvide insights on...",
-      dependencies: []
+      analysis_prompt: "Analyze the following financial data:\n\n{{data}}\n\nProvide insights on...",
+      dependencies: [],
+      context_selector: { Manual: { rules: [] } },
+      llm_config: { model_id: "" },
+      output_type: "markdown"
     };
-    setLocalTemplate(prev => ({
+    setLocalTemplate((prev: AnalysisTemplateSet) => ({
       ...prev,
       modules: {
         ...prev.modules,
@@ -166,16 +169,16 @@ function TemplateDetailView({ template, onSave, isSaving }: { template: Analysis
   }
   const handleDeleteModule = (moduleId: string) => {
-    setLocalTemplate(prev => {
+    setLocalTemplate((prev: AnalysisTemplateSet) => {
       // eslint-disable-next-line @typescript-eslint/no-unused-vars
-      const { [moduleId]: removed, ...rest } = prev.modules;
+      const { [moduleId]: removed, ...rest } = prev.modules as Record<string, AnalysisModuleConfig>;
       return { ...prev, modules: rest };
     });
     setIsDirty(true);
   }
   const handleUpdateModule = (moduleId: string, updatedModule: AnalysisModuleConfig) => {
-    setLocalTemplate(prev => ({
+    setLocalTemplate((prev: AnalysisTemplateSet) => ({
       ...prev,
       modules: {
         ...prev.modules,
@@ -269,9 +272,9 @@ function ModuleCard({ id, module, availableModules, allModules, onDelete, onUpda
   const seenKeys = new Set<string>();
   if (providers) {
-    Object.entries(providers).forEach(([pid, p]) => {
-      p.models.forEach((m: { model_id: string, name?: string | null }) => {
-        const uniqueKey = `${pid}::${m.model_id}`;
+    Object.entries(providers as Record<string, any>).forEach(([pid, p]) => {
+      (p as { models: { model_id: string, name?: string | null }[] }).models.forEach((m) => {
+        const uniqueKey = m.model_id;
         if (!seenKeys.has(uniqueKey)) {
           seenKeys.add(uniqueKey);
           allModels.push({
@@ -284,16 +287,130 @@ function ModuleCard({ id, module, availableModules, allModules, onDelete, onUpda
     });
   }
-  const handleModelChange = (uniqueId: string) => {
-    const [pid, mid] = uniqueId.split('::');
-    onUpdate({ ...module, provider_id: pid, model_id: mid });
+  const handleModelChange = (mid: string) => {
+    onUpdate({
+      ...module,
+      llm_config: { ...module.llm_config, model_id: mid }
+    });
   }
-  const currentModelUniqueId = module.provider_id && module.model_id ? `${module.provider_id}::${module.model_id}` : undefined;
+  const currentModelId = module.llm_config?.model_id;
+  // Dynamically derive modes from Zod Schema
+  const availableSelectionModes = useMemo(() => {
+    if (schemas.SelectionMode instanceof z.ZodUnion) {
+      return schemas.SelectionMode.options.map((opt: any) => {
+        // Each option is ZodObject or similar, we extract the first key of the shape
+        // shape is available on ZodObject
+        const shape = opt.shape || (opt._def && opt._def.shape && opt._def.shape());
+        if (shape) {
+          return Object.keys(shape)[0];
+        }
+        return "Unknown";
+      }).filter(Boolean) as string[];
+    }
+    return ["Manual", "Auto", "Hybrid"]; // Fallback if schema introspection fails
+  }, []);
+  const currentMode = Object.keys(module.context_selector)[0];
+  const handleModeChange = (newMode: string) => {
+    // Initialize default structure based on mode
+    // Note: While modes are dynamic, default init still requires some knowledge or a more complex generator
+    // For now, we use a robust switch, but the keys come from the schema in the UI.
+    if (newMode === 'Manual') {
+      onUpdate({ ...module, context_selector: { Manual: { rules: [] } } });
+    } else if (newMode === 'Auto') {
+      onUpdate({ ...module, context_selector: { Auto: { llm_config: null } } });
+    } else if (newMode === 'Hybrid') {
+      onUpdate({ ...module, context_selector: { Hybrid: { selection_prompt: "", llm_config: null } } });
+    }
+  }
+  // Helper to get manual rules
+  const getManualRules = () => {
+    if ('Manual' in module.context_selector) {
+      return module.context_selector.Manual.rules.join('\n');
+    }
+    return "";
+  }
+  const handleRulesChange = (text: string) => {
+    const rules = text.split('\n').filter(line => line.trim() !== "");
+    onUpdate({
+      ...module,
+      context_selector: { Manual: { rules } }
+    });
+  }
+  const getSelectionPrompt = () => {
+    if ('Hybrid' in module.context_selector) {
+      return module.context_selector.Hybrid.selection_prompt;
+    }
+    return "";
+  }
+  const handleSelectionPromptChange = (text: string) => {
+    if ('Hybrid' in module.context_selector) {
+      onUpdate({
+        ...module,
+        context_selector: {
+          Hybrid: {
+            ...module.context_selector.Hybrid,
+            selection_prompt: text
+          }
+        }
+      });
+    }
+  }
+  const getSelectorModelId = () => {
+    if ('Auto' in module.context_selector) {
+      return module.context_selector.Auto.llm_config?.model_id;
+    }
+    if ('Hybrid' in module.context_selector) {
+      return module.context_selector.Hybrid.llm_config?.model_id;
+    }
+    return undefined;
+  }
+  const handleSelectorModelChange = (mid: string) => {
+    const llmConfig = { model_id: mid };
+    if (currentMode === 'Auto') {
+      onUpdate({
+        ...module,
+        context_selector: { Auto: { llm_config: llmConfig } }
+      });
+    } else if (currentMode === 'Hybrid') {
+      if ('Hybrid' in module.context_selector) {
+        onUpdate({
+          ...module,
+          context_selector: {
+            Hybrid: {
+              ...module.context_selector.Hybrid,
+              llm_config: llmConfig
+            }
+          }
+        });
+      } else {
+        // Handle edge case where currentMode is Hybrid but selector state not yet updated
+        // This usually shouldn't happen if useEffect syncs correctly, but safety first
+        onUpdate({
+          ...module,
+          context_selector: {
+            Hybrid: {
+              selection_prompt: "",
+              llm_config: llmConfig
+            }
+          }
+        });
+      }
+    }
+  }
   const toggleDependency = (depId: string) => {
     const newDeps = module.dependencies.includes(depId)
-      ? module.dependencies.filter(d => d !== depId)
+      ? module.dependencies.filter((d: string) => d !== depId)
       : [...module.dependencies, depId];
     onUpdate({ ...module, dependencies: newDeps });
   }
@@ -314,7 +431,7 @@ function ModuleCard({ id, module, availableModules, allModules, onDelete, onUpda
   <CardDescription className="hidden"></CardDescription>
   <div className="flex items-center gap-2 mt-1">
     <Badge variant="outline" className="font-normal text-[10px] h-5">
-      {module.provider_id || "?"} / {module.model_id || "?"}
+      {currentModelId || "No Model"}
     </Badge>
     {module.dependencies.length > 0 && (
       <span className="text-[10px] text-muted-foreground">
@@ -344,14 +461,14 @@ function ModuleCard({ id, module, availableModules, allModules, onDelete, onUpda
   </div>
   <div className="space-y-2">
     <Label>Model</Label>
-    <Select value={currentModelUniqueId} onValueChange={handleModelChange}>
+    <Select value={currentModelId || ""} onValueChange={handleModelChange}>
       <SelectTrigger>
         <SelectValue placeholder="Select Model" />
       </SelectTrigger>
       <SelectContent>
         {allModels.length > 0 ? (
           allModels.map(m => (
-            <SelectItem key={`${m.providerId}::${m.modelId}`} value={`${m.providerId}::${m.modelId}`}>
+            <SelectItem key={m.modelId} value={m.modelId}>
               {m.name}
             </SelectItem>
           ))
@@ -365,15 +482,99 @@ function ModuleCard({ id, module, availableModules, allModules, onDelete, onUpda
   </div>
   </div>
-  <div className="space-y-2">
-    <Label>Prompt Template</Label>
-    <Textarea
-      value={module.prompt_template}
-      onChange={(e) => onUpdate({...module, prompt_template: e.target.value})}
-      className="font-mono text-xs min-h-[100px]"
-    />
-    <p className="text-[10px] text-muted-foreground">
-      Use <code>{"{{data}}"}</code> to inject data from dependencies or data source.
-    </p>
-  </div>
+  <div className="space-y-2 border p-4 rounded-md bg-muted/5">
+    <div className="flex items-center justify-between">
+      <Label>Context Selection Mode</Label>
+      <Select value={currentMode} onValueChange={(v) => handleModeChange(v)}>
+        <SelectTrigger className="w-[180px]">
+          <SelectValue />
+        </SelectTrigger>
+        <SelectContent>
+          {availableSelectionModes.map(m => (
+            <SelectItem key={m} value={m}>{m}</SelectItem>
+          ))}
+        </SelectContent>
+      </Select>
+    </div>
+    {currentMode === 'Manual' && (
+      <div className="space-y-2 mt-4">
+        <Label className="text-xs">Glob Patterns</Label>
+        <Textarea
+          value={getManualRules()}
+          onChange={(e) => handleRulesChange(e.target.value)}
+          className="font-mono text-xs min-h-[80px]"
+          placeholder="raw/tushare/{{symbol}}/*.json"
+        />
+        <p className="text-[10px] text-muted-foreground">
+          One rule per line. Use <code>{"{{symbol}}"}</code> as placeholder.
+        </p>
+      </div>
+    )}
+    {currentMode === 'Auto' && (
+      <div className="space-y-2 mt-4">
+        <Label className="text-xs">Selector Model</Label>
+        <Select value={getSelectorModelId() || ""} onValueChange={handleSelectorModelChange}>
+          <SelectTrigger>
+            <SelectValue placeholder="Select Model for Selection" />
+          </SelectTrigger>
+          <SelectContent>
+            {allModels.map(m => (
+              <SelectItem key={m.modelId} value={m.modelId}>
+                {m.name}
+              </SelectItem>
+            ))}
+          </SelectContent>
+        </Select>
+        <p className="text-[10px] text-muted-foreground">
+          The LLM will automatically select relevant files based on the directory tree.
+        </p>
+      </div>
+    )}
+    {currentMode === 'Hybrid' && (
+      <div className="space-y-4 mt-4">
+        <div className="space-y-2">
+          <Label className="text-xs">Selector Model</Label>
+          <Select value={getSelectorModelId() || ""} onValueChange={handleSelectorModelChange}>
+            <SelectTrigger>
+              <SelectValue placeholder="Select Model for Selection" />
+            </SelectTrigger>
+            <SelectContent>
+              {allModels.map(m => (
+                <SelectItem key={m.modelId} value={m.modelId}>
+                  {m.name}
+                </SelectItem>
+              ))}
+            </SelectContent>
+          </Select>
+        </div>
+        <div className="space-y-2">
+          <Label className="text-xs">Selection Prompt</Label>
+          <Textarea
+            value={getSelectionPrompt()}
+            onChange={(e) => handleSelectionPromptChange(e.target.value)}
+            className="font-mono text-xs min-h-[80px]"
+            placeholder="Find all financial statements from 2024..."
+          />
+          <p className="text-[10px] text-muted-foreground">
+            Describe what files you need. The LLM will interpret this and select files.
+          </p>
+        </div>
+      </div>
+    )}
+  </div>
+  <div className="space-y-2">
+    <Label>Analysis Prompt</Label>
+    <Textarea
+      value={module.analysis_prompt}
+      onChange={(e) => onUpdate({...module, analysis_prompt: e.target.value})}
+      className="font-mono text-xs min-h-[150px]"
+    />
+    <p className="text-[10px] text-muted-foreground">
+      Use <code>{"{{data}}"}</code> to inject data.
+    </p>
+  </div>

View File

@@ -415,32 +415,46 @@
     "schemas": {
       "AnalysisModuleConfig": {
         "type": "object",
-        "description": "Configuration for a single analysis module.",
         "required": [
           "name",
-          "provider_id",
-          "model_id",
-          "prompt_template",
-          "dependencies"
+          "dependencies",
+          "context_selector",
+          "analysis_prompt",
+          "output_type"
         ],
         "properties": {
+          "analysis_prompt": {
+            "type": "string"
+          },
+          "context_selector": {
+            "$ref": "#/components/schemas/ContextSelectorConfig"
+          },
           "dependencies": {
            "type": "array",
             "items": {
               "type": "string"
-            },
-            "description": "List of dependencies. Each string must be a key in the parent `modules` HashMap."
+            }
           },
-          "model_id": {
-            "type": "string"
+          "id": {
+            "type": [
+              "string",
+              "null"
+            ]
+          },
+          "llm_config": {
+            "oneOf": [
+              {
+                "type": "null"
+              },
+              {
+                "$ref": "#/components/schemas/LlmConfig"
+              }
+            ]
           },
           "name": {
             "type": "string"
           },
-          "prompt_template": {
-            "type": "string"
-          },
-          "provider_id": {
+          "output_type": {
             "type": "string"
           }
         },
@@ -554,6 +568,13 @@
           "Region"
         ]
       },
+      "ContextSelectorConfig": {
+        "allOf": [
+          {
+            "$ref": "#/components/schemas/SelectionMode"
+          }
+        ]
+      },
       "DataRequest": {
         "type": "object",
         "required": [
@@ -679,6 +700,33 @@
         },
         "additionalProperties": false
       },
+      "LlmConfig": {
+        "type": "object",
+        "properties": {
+          "max_tokens": {
+            "type": [
+              "integer",
+              "null"
+            ],
+            "format": "int32",
+            "minimum": 0
+          },
+          "model_id": {
+            "type": [
+              "string",
+              "null"
+            ]
+          },
+          "temperature": {
+            "type": [
+              "number",
+              "null"
+            ],
+            "format": "float"
+          }
+        },
+        "additionalProperties": false
+      },
       "LlmModel": {
         "type": "object",
         "required": [
@@ -811,6 +859,84 @@
         },
         "additionalProperties": false
       },
+      "SelectionMode": {
+        "oneOf": [
+          {
+            "type": "object",
+            "required": [
+              "Manual"
+            ],
+            "properties": {
+              "Manual": {
+                "type": "object",
+                "required": [
+                  "rules"
+                ],
+                "properties": {
+                  "rules": {
+                    "type": "array",
+                    "items": {
+                      "type": "string"
+                    }
+                  }
+                }
+              }
+            }
+          },
+          {
+            "type": "object",
+            "required": [
+              "Auto"
+            ],
+            "properties": {
+              "Auto": {
+                "type": "object",
+                "properties": {
+                  "llm_config": {
+                    "oneOf": [
+                      {
+                        "type": "null"
+                      },
+                      {
+                        "$ref": "#/components/schemas/LlmConfig"
+                      }
+                    ]
+                  }
+                }
+              }
+            }
+          },
+          {
+            "type": "object",
+            "required": [
+              "Hybrid"
+            ],
+            "properties": {
+              "Hybrid": {
+                "type": "object",
+                "required": [
+                  "selection_prompt"
+                ],
+                "properties": {
+                  "llm_config": {
+                    "oneOf": [
+                      {
+                        "type": "null"
+                      },
+                      {
+                        "$ref": "#/components/schemas/LlmConfig"
+                      }
+                    ]
+                  },
+                  "selection_prompt": {
+                    "type": "string"
+                  }
+                }
+              }
+            }
+          }
+        ]
+      },
       "ServiceStatus": {
         "type": "string",
         "enum": [
@@ -904,6 +1030,12 @@
           "initial_status"
         ],
         "properties": {
+          "display_name": {
+            "type": [
+              "string",
+              "null"
+            ]
+          },
           "id": {
             "type": "string"
           },
@@ -1095,12 +1227,24 @@
           "timestamp"
         ],
         "properties": {
+          "input_commit": {
+            "type": [
+              "string",
+              "null"
+            ]
+          },
           "message": {
             "type": [
               "string",
               "null"
             ]
           },
+          "output_commit": {
+            "type": [
+              "string",
+              "null"
+            ]
+          },
           "progress": {
             "type": [
               "integer",

View File

@@ -500,6 +500,11 @@ async fn trigger_analysis_generation(
         template_id: payload.template_id,
         task_id: None,
         module_id: None,
+        commit_hash: None,
+        input_bindings: None,
+        output_path: None,
+        llm_config: None,
+        analysis_prompt: None,
     };
     info!(request_id = %request_id, "Publishing analysis generation command");

View File

@@ -89,16 +89,7 @@ pub struct AnalysisTemplateSet {
 }
 /// Configuration for a single analysis module.
-#[api_dto]
-#[derive(PartialEq)]
-pub struct AnalysisModuleConfig {
-    pub name: String,
-    pub provider_id: String,
-    pub model_id: String,
-    pub prompt_template: String,
-    /// List of dependencies. Each string must be a key in the parent `modules` HashMap.
-    pub dependencies: Vec<String>,
-}
+pub use crate::configs::AnalysisModuleConfig;
 // --- Analysis Module Config (OLD DEPRECATED STRUCTURE) ---

View File

@@ -0,0 +1,46 @@
use std::collections::HashMap;
use service_kit::api_dto;

#[api_dto]
#[derive(PartialEq)]
pub struct LlmConfig {
    pub model_id: Option<String>,
    pub temperature: Option<f32>,
    pub max_tokens: Option<u32>,
    #[serde(flatten)]
    pub extra_params: HashMap<String, serde_json::Value>,
}

#[api_dto]
#[derive(PartialEq)]
pub enum SelectionMode {
    Manual {
        rules: Vec<String>, // Glob patterns
    },
    Auto {
        llm_config: Option<LlmConfig>,
    },
    Hybrid {
        selection_prompt: String,
        llm_config: Option<LlmConfig>,
    },
}

#[api_dto]
#[derive(PartialEq)]
pub struct ContextSelectorConfig {
    #[serde(flatten)]
    pub mode: SelectionMode,
}

#[api_dto]
#[derive(PartialEq)]
pub struct AnalysisModuleConfig {
    pub id: Option<String>, // Optional if key is ID
    pub name: String,
    pub dependencies: Vec<String>,
    pub context_selector: ContextSelectorConfig,
    pub analysis_prompt: String,
    pub llm_config: Option<LlmConfig>,
    pub output_type: String, // e.g. "markdown", "json"
}
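Since `ContextSelectorConfig` flattens its `mode`, a selector serializes as the externally tagged `SelectionMode` object itself, with no extra `mode` wrapper. A quick sketch of the expected wire shape, assuming `api_dto` expands to the usual serde derives (all concrete values below are illustrative):
```rust
use serde_json::json;

fn main() {
    // What a module entry is expected to look like in analysis-config.json:
    let module = json!({
        "name": "Financial Analysis",
        "dependencies": [],
        // Flattened: the selector is the tagged enum value itself.
        "context_selector": { "Manual": { "rules": ["raw/tushare/*.json"] } },
        "analysis_prompt": "Analyze the following financial data:\n\n{{data}}",
        "llm_config": { "model_id": "example-model" },
        "output_type": "markdown"
    });
    println!("{}", serde_json::to_string_pretty(&module).unwrap());
}
```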

View File

@@ -13,3 +13,4 @@ pub mod persistence_client;
 pub mod abstraction;
 pub mod workflow_harness; // Export the harness
 pub mod workflow_types;
+pub mod configs;

View File

@@ -3,6 +3,7 @@ use crate::symbol_utils::CanonicalSymbol;
 use crate::subjects::{NatsSubject, SubjectMessage};
 use std::collections::HashMap;
 use service_kit::api_dto;
+use crate::configs::LlmConfig;
 // --- Commands ---
@@ -48,6 +49,7 @@ pub struct FetchCompanyDataCommand {
     pub symbol: CanonicalSymbol,
     pub market: String,
     pub template_id: Option<String>, // Optional trigger for analysis
+    pub output_path: Option<String>, // New: Unified I/O Binding
 }
 impl SubjectMessage for FetchCompanyDataCommand {
@@ -68,6 +70,13 @@ pub struct GenerateReportCommand {
     /// Used for reporting progress/content back to the specific node.
     pub task_id: Option<String>,
     pub module_id: Option<String>,
+    // --- New Fields for Refactored Context Mechanism ---
+    pub commit_hash: Option<String>,
+    pub input_bindings: Option<Vec<String>>, // Resolved physical paths
+    pub output_path: Option<String>,
+    pub llm_config: Option<LlmConfig>,
+    pub analysis_prompt: Option<String>,
 }
 impl SubjectMessage for GenerateReportCommand {
@@ -254,5 +263,3 @@ impl SubjectMessage for ReportFailedEvent {
         NatsSubject::AnalysisReportFailed
     }
 }

View File

@@ -17,6 +17,7 @@ ENV SQLX_OFFLINE=true
 WORKDIR /app/services/data-persistence-service
 COPY --from=planner /app/services/data-persistence-service/recipe.json /app/services/data-persistence-service/recipe.json
 # To support path dependencies, copy the dependency sources first, then cook
+ENV FORCE_REBUILD=2
 COPY services/common-contracts /app/services/common-contracts
 # Copy service_kit mirror again for build
 COPY ref/service_kit_mirror /app/ref/service_kit_mirror

View File

@ -10,17 +10,7 @@ const CONFIG_KEY: &str = "analysis_template_sets";
#[derive(serde::Deserialize)] #[derive(serde::Deserialize)]
struct RawAnalysisConfig { struct RawAnalysisConfig {
analysis_modules: HashMap<String, RawModule>, analysis_modules: HashMap<String, AnalysisModuleConfig>,
}
#[derive(serde::Deserialize)]
struct RawModule {
name: String,
#[serde(default)]
dependencies: Vec<String>,
#[serde(rename = "model")]
model_id: String,
prompt_template: String,
} }
/// Seeds the database with default configurations if they don't already exist. /// Seeds the database with default configurations if they don't already exist.
@@ -39,24 +29,18 @@ pub async fn seed_data(pool: &PgPool) -> Result<(), sqlx::Error> {
     info!("No 'analysis_template_sets' config found. Seeding default analysis templates...");
     // Parse the config file structure currently in the repository:
-    // { "analysis_modules": { "<module_id>": { name, model, prompt_template, dependencies? } } }
     let raw: RawAnalysisConfig = serde_json::from_str(DEFAULT_ANALYSIS_CONFIG_JSON)
         .expect("Failed to parse embedded default analysis config JSON");
     let modules = raw
         .analysis_modules
         .into_iter()
-        .map(|(k, v)| {
-            (
-                k,
-                AnalysisModuleConfig {
-                    name: v.name,
-                    provider_id: "".to_string(), // to be configured by the user later
-                    model_id: v.model_id,
-                    prompt_template: v.prompt_template,
-                    dependencies: v.dependencies,
-                },
-            )
+        .map(|(k, mut v)| {
+            // Set ID if missing
+            if v.id.is_none() {
+                v.id = Some(k.clone());
+            }
+            (k, v)
         })
         .collect();

View File

@ -353,6 +353,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b97463e1064cb1b1c1384ad0a0b9c8abd0988e2a91f52606c80ef14aadb63e36" checksum = "b97463e1064cb1b1c1384ad0a0b9c8abd0988e2a91f52606c80ef14aadb63e36"
dependencies = [ dependencies = [
"find-msvc-tools", "find-msvc-tools",
"jobserver",
"libc",
"shlex", "shlex",
] ]
@ -785,7 +787,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [ dependencies = [
"libc", "libc",
"windows-sys 0.52.0", "windows-sys 0.61.2",
] ]
[[package]] [[package]]
@@ -997,6 +999,21 @@ dependencies = [
  "wasm-bindgen",
 ]
+[[package]]
+name = "git2"
+version = "0.18.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "232e6a7bfe35766bf715e55a88b39a700596c0ccfd88cd3680b4cdb40d66ef70"
+dependencies = [
+ "bitflags",
+ "libc",
+ "libgit2-sys",
+ "log",
+ "openssl-probe",
+ "openssl-sys",
+ "url",
+]
 [[package]]
 name = "globset"
 version = "0.4.18"
@ -1085,6 +1102,12 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "hex"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
[[package]] [[package]]
name = "http" name = "http"
version = "1.3.1" version = "1.3.1"
@ -1423,6 +1446,16 @@ version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c"
[[package]]
name = "jobserver"
version = "0.1.34"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33"
dependencies = [
"getrandom 0.3.4",
"libc",
]
[[package]] [[package]]
name = "js-sys" name = "js-sys"
version = "0.3.82" version = "0.3.82"
@ -1456,12 +1489,52 @@ version = "0.2.177"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976" checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976"
[[package]]
name = "libgit2-sys"
version = "0.16.2+1.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee4126d8b4ee5c9d9ea891dd875cfdc1e9d0950437179104b183d7d8a74d24e8"
dependencies = [
"cc",
"libc",
"libssh2-sys",
"libz-sys",
"openssl-sys",
"pkg-config",
]
[[package]] [[package]]
name = "libm" name = "libm"
version = "0.2.15" version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de"
[[package]]
name = "libssh2-sys"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "220e4f05ad4a218192533b300327f5150e809b54c4ec83b5a1d91833601811b9"
dependencies = [
"cc",
"libc",
"libz-sys",
"openssl-sys",
"pkg-config",
"vcpkg",
]
[[package]]
name = "libz-sys"
version = "1.1.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15d118bbf3771060e7311cc7bb0545b01d08a8b4a7de949198dec1fa0ca1c0f7"
dependencies = [
"cc",
"libc",
"pkg-config",
"vcpkg",
]
[[package]] [[package]]
name = "linux-raw-sys" name = "linux-raw-sys"
version = "0.11.0" version = "0.11.0"
@ -1662,6 +1735,15 @@ version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e"
[[package]]
name = "openssl-src"
version = "300.5.4+3.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a507b3792995dae9b0df8a1c1e3771e8418b7c2d9f0baeba32e6fe8b06c7cb72"
dependencies = [
"cc",
]
[[package]] [[package]]
name = "openssl-sys" name = "openssl-sys"
version = "0.9.111" version = "0.9.111"
@ -1670,6 +1752,7 @@ checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321"
dependencies = [ dependencies = [
"cc", "cc",
"libc", "libc",
"openssl-src",
"pkg-config", "pkg-config",
"vcpkg", "vcpkg",
] ]
@@ -2168,6 +2251,7 @@ dependencies = [
  "reqwest",
  "serde",
  "serde_json",
+ "serde_yaml",
  "tera",
  "thiserror 2.0.17",
  "tokio",
@@ -2175,6 +2259,7 @@ dependencies = [
  "tracing",
  "tracing-subscriber",
  "uuid",
+ "workflow-context",
 ]
@@ -2350,7 +2435,7 @@ dependencies = [
  "errno",
  "libc",
  "linux-raw-sys",
- "windows-sys 0.52.0",
+ "windows-sys 0.61.2",
 ]
 [[package]]
@@ -2670,6 +2755,19 @@ dependencies = [
  "serde",
 ]
+[[package]]
+name = "serde_yaml"
+version = "0.9.34+deprecated"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47"
+dependencies = [
+ "indexmap",
+ "itoa",
+ "ryu",
+ "serde",
+ "unsafe-libyaml",
+]
 [[package]]
 name = "service-kit-macros"
 version = "0.1.2"
@@ -2914,7 +3012,7 @@ dependencies = [
  "getrandom 0.3.4",
  "once_cell",
  "rustix",
- "windows-sys 0.52.0",
+ "windows-sys 0.61.2",
 ]
 [[package]]
@ -3357,6 +3455,12 @@ version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
[[package]]
name = "unsafe-libyaml"
version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861"
[[package]] [[package]]
name = "untrusted" name = "untrusted"
version = "0.9.0" version = "0.9.0"
@ -3829,6 +3933,22 @@ version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59"
[[package]]
name = "workflow-context"
version = "0.1.0"
dependencies = [
"anyhow",
"git2",
"globset",
"hex",
"regex",
"serde",
"serde_json",
"sha2",
"thiserror 1.0.69",
"walkdir",
]
[[package]] [[package]]
name = "writeable" name = "writeable"
version = "0.6.2" version = "0.6.2"

View File

@@ -11,6 +11,7 @@ tower-http = { version = "0.6.6", features = ["cors"] }
 # Shared Contracts
 common-contracts = { path = "../common-contracts", default-features = false }
+workflow-context = { path = "../../crates/workflow-context" }
 
 # Message Queue (NATS)
 async-nats = "0.45.0"
@@ -44,3 +45,4 @@ petgraph = "0.8.3"
 async-openai = "0.30.1"
 futures-util = "0.3"
 async-stream = "0.3"
+serde_yaml = "0.9.34"

View File

@@ -5,6 +5,11 @@ pub struct AppConfig {
     pub server_port: u16,
     pub nats_addr: String,
     pub data_persistence_service_url: String,
+    pub workflow_data_path: String,
+}
+
+fn default_workflow_data_path() -> String {
+    "/app/data".to_string()
 }
 
 impl AppConfig {
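Note: `default_workflow_data_path` is only picked up if something references it; the hunk above does not show a serde `default` attribute on the new field, so the following is a hedged sketch of the usual wiring, not necessarily what this commit relies on (it may load the value elsewhere):

```rust
use serde::Deserialize;

#[derive(Deserialize)]
pub struct AppConfig {
    pub server_port: u16,
    pub nats_addr: String,
    pub data_persistence_service_url: String,
    // Assumed wiring: without this attribute, the helper below is dead code.
    #[serde(default = "default_workflow_data_path")]
    pub workflow_data_path: String,
}

fn default_workflow_data_path() -> String {
    "/app/data".to_string()
}
```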

View File

@@ -1,72 +0,0 @@
use std::collections::{BTreeMap, HashSet};
use common_contracts::dtos::TimeSeriesFinancialDto;
use chrono::Datelike;
/// Formats a list of TimeSeriesFinancialDto into a Markdown table.
/// The table columns are years (sorted descending), and rows are metrics.
pub fn format_financials_to_markdown(financials: &[TimeSeriesFinancialDto]) -> String {
if financials.is_empty() {
return "No financial data available.".to_string();
}
// 1. Group by Year and Metric
// Map<MetricName, Map<Year, Value>>
let mut data_map: BTreeMap<String, BTreeMap<i32, f64>> = BTreeMap::new();
let mut years_set: HashSet<i32> = HashSet::new();
for item in financials {
let year = item.period_date.year();
years_set.insert(year);
data_map
.entry(item.metric_name.clone())
.or_default()
.insert(year, item.value);
}
// 2. Sort years descending (recent first)
let mut sorted_years: Vec<i32> = years_set.into_iter().collect();
sorted_years.sort_by(|a, b| b.cmp(a));
// Limit to recent 5 years to keep table readable
let display_years = if sorted_years.len() > 5 {
&sorted_years[..5]
} else {
&sorted_years
};
// 3. Build Markdown Table
let mut markdown = String::new();
// Header
markdown.push_str("| Metric |");
for year in display_years {
markdown.push_str(&format!(" {} |", year));
}
markdown.push('\n');
// Separator
markdown.push_str("| :--- |");
for _ in display_years {
markdown.push_str(" ---: |");
}
markdown.push('\n');
// Rows
for (metric, year_values) in data_map {
markdown.push_str(&format!("| {} |", metric));
for year in display_years {
if let Some(value) = year_values.get(year) {
// Format large numbers or percentages intelligently if needed.
// For now, simple float formatting.
markdown.push_str(&format!(" {:.2} |", value));
} else {
markdown.push_str(" - |");
}
}
markdown.push('\n');
}
markdown
}

View File

@@ -7,7 +7,6 @@ mod persistence;
 mod state;
 mod templates;
 mod worker;
-mod formatter;
 
 use crate::config::AppConfig;
 use crate::error::{ProviderError, Result};

View File

@@ -73,12 +73,31 @@ pub async fn subscribe_to_commands(
     let market = Market::from(m);
     let symbol = CanonicalSymbol::new(s, &market);
 
+    let commit_hash = task_cmd.context.base_commit.clone();
+    let input_bindings = task_cmd.config.get("input_bindings").and_then(|v| serde_json::from_value(v.clone()).ok());
+    let output_path = task_cmd.config.get("output_path").and_then(|v| v.as_str()).map(|s| s.to_string());
+
+    let mut llm_config = None;
+    let mut analysis_prompt = None;
+    if let Some(mc_val) = task_cmd.config.get("_module_config") {
+        if let Ok(mc) = serde_json::from_value::<common_contracts::config_models::AnalysisModuleConfig>(mc_val.clone()) {
+            llm_config = mc.llm_config;
+            analysis_prompt = Some(mc.analysis_prompt);
+        }
+    }
+
     let report_cmd = GenerateReportCommand {
         request_id: task_cmd.request_id,
         symbol: symbol.clone(),
         template_id: t.to_string(),
         task_id: Some(task_cmd.task_id.clone()),
         module_id: module_id.map(|v| v.to_string()),
+        commit_hash,
+        input_bindings,
+        output_path,
+        llm_config,
+        analysis_prompt,
     };
 
     // 2. Send TaskStatus::Running
@@ -98,13 +117,17 @@ pub async fn subscribe_to_commands(
     // 3. Run Logic
     match run_report_generation_workflow(Arc::new(state), report_cmd).await {
-        Ok(_) => {
+        Ok(new_commit_opt) => {
             // 4. Send TaskStatus::Completed
             let completed_evt = WorkflowTaskEvent {
                 request_id: task_cmd.request_id,
                 task_id: task_cmd.task_id.clone(),
                 status: TaskStatus::Completed,
-                result: None, // Future: Add commit hash here if we used VGCS
+                result: Some(common_contracts::workflow_types::TaskResult {
+                    new_commit: new_commit_opt, // Pass the commit hash back
+                    error: None,
+                    summary: None,
+                }),
             };
             if let Ok(payload) = serde_json::to_vec(&completed_evt) {
                 let subject = common_contracts::subjects::NatsSubject::WorkflowEventTaskCompleted.to_string();
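The worker now pulls `input_bindings`, `output_path`, and `_module_config` out of the task's free-form `config` value. For reference, a payload the extraction above would accept looks roughly like this (paths, prompt, and model id are hypothetical examples, not values from this commit):

```rust
fn main() {
    // Illustrative only: one shape of WorkflowTaskCommand.config that the
    // extraction logic above accepts.
    let config = serde_json::json!({
        "symbol": "AAPL",
        "market": "US",
        "output_path": "analysis/fundamental_analysis/AAPL.md",
        "input_bindings": ["raw/tushare/AAPL/financials.json"],
        "_module_config": {
            "analysis_prompt": "Analyze the fundamentals of {{symbol}}.",
            "llm_config": { "model_id": "gpt-4o-mini" }
        }
    });
    let output_path = config.get("output_path").and_then(|v| v.as_str());
    assert_eq!(output_path, Some("analysis/fundamental_analysis/AAPL.md"));
}
```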

View File

@@ -1,611 +1,244 @@
 use std::collections::HashMap;
 use std::sync::Arc;
 
-use common_contracts::config_models::{
-    AnalysisModuleConfig, AnalysisTemplateSets, LlmProvidersConfig,
-};
-use common_contracts::dtos::{CompanyProfileDto, NewAnalysisResult, TimeSeriesFinancialDto};
-use common_contracts::messages::{GenerateReportCommand, ReportGeneratedEvent, ReportFailedEvent, WorkflowEvent};
-use common_contracts::subjects::SubjectMessage;
+use common_contracts::messages::{GenerateReportCommand, WorkflowEvent};
 use futures_util::StreamExt;
-use petgraph::algo::toposort;
-use petgraph::graph::DiGraph;
-use tera::{Context, Tera};
-use tracing::{info, instrument, error, warn};
+use tracing::{info, instrument, error};
+use workflow_context::WorkerContext;
 
 use crate::error::{ProviderError, Result};
 use crate::llm_client::LlmClient;
 use crate::persistence::PersistenceClient;
 use crate::state::AppState;
-use crate::formatter::format_financials_to_markdown;
 
 #[instrument(skip_all, fields(request_id = %command.request_id, symbol = %command.symbol, template_id = %command.template_id))]
 pub async fn run_report_generation_workflow(
     state: Arc<AppState>,
     command: GenerateReportCommand,
-) -> Result<()> {
+) -> Result<Option<String>> {
     info!("Starting report generation workflow.");
-    // Get channel immediately to ensure we can broadcast errors
-    let stream_tx = state.streams.get_or_create_sender(command.request_id);
-
-    let result: Result<()> = async {
-        let persistence_client =
-            PersistenceClient::new(state.config.data_persistence_service_url.clone());
-
-        // 1. Fetch all necessary data AND configurations in parallel
-        info!("Fetching initial data and configurations from persistence service...");
-        let (profile, financials, llm_providers, template_sets) =
-            fetch_data_and_configs(&persistence_client, command.symbol.as_str(), command.request_id).await
-                .map_err(|e| {
-                    error!("Failed to fetch initial data and configs: {}", e);
-                    e
-                })?;
-        info!("Successfully fetched data. Profile: {}, Financials count: {}", profile.name, financials.len());
-
-        // 2. Select the correct template set
-        let template_set = template_sets.get(&command.template_id).ok_or_else(|| {
-            let err_msg = format!("Analysis template set with ID '{}' not found.", command.template_id);
-            error!("{}", err_msg);
-            ProviderError::Configuration(err_msg)
-        })?;
-
-        // 3. Determine Execution Plan (Single Module vs Full Workflow)
-        let sorted_modules = if let Some(target_module) = &command.module_id {
-            info!("Targeting single module execution: {}", target_module);
-            if !template_set.modules.contains_key(target_module) {
-                let err_msg = format!("Module '{}' not found in template '{}'", target_module, command.template_id);
-                error!("{}", err_msg);
-                return Err(ProviderError::Configuration(err_msg));
-            }
-            vec![target_module.clone()]
-        } else {
-            info!("Targeting full workflow execution.");
-            sort_modules_by_dependency(&template_set.modules)
-                .map_err(|e| {
-                    error!("Failed to sort analysis modules: {}", e);
-                    e
-                })?
-        };
-        info!(execution_order = ?sorted_modules, "Successfully determined module execution order.");
-
-        // 4. Execute modules in order
-        let mut generated_results: HashMap<String, String> = HashMap::new();
-
-        // If single module, preload context from persistence (previous steps)
-        if command.module_id.is_some() {
-            info!("Preloading context from previous analysis results...");
-            match persistence_client.get_analysis_results(command.symbol.as_str()).await {
-                Ok(results) => {
-                    for r in results {
-                        if r.request_id == command.request_id {
-                            generated_results.insert(r.module_id, r.content);
-                        }
-                    }
-                    info!("Preloaded {} context items.", generated_results.len());
-                },
-                Err(e) => {
-                    warn!("Failed to preload analysis results: {}", e);
-                    // Non-fatal, but might cause dependency error later
-                }
-            }
-        }
-
-        for module_id in sorted_modules {
-            let module_config = template_set.modules.get(&module_id).unwrap();
-            info!(module_id = %module_id, "All dependencies met. Generating report for module.");
-
-            // Broadcast Module Start
-            let _ = stream_tx.send(serde_json::json!({
-                "type": "module_start",
-                "module_id": module_id
-            }).to_string());
-
-            let llm_client = match create_llm_client_for_module(&llm_providers, module_config) {
-                Ok(client) => client,
-                Err(e) => {
-                    error!(module_id = %module_id, "Failed to create LLM client: {}. Aborting workflow.", e);
-                    let err_msg = format!("Error: Failed to create LLM client: {}", e);
-                    generated_results.insert(module_id.clone(), err_msg.clone());
-
-                    // Broadcast Error
-                    let _ = stream_tx.send(serde_json::json!({
-                        "type": "error",
-                        "module_id": module_id,
-                        "payload": err_msg
-                    }).to_string());
-
-                    // Publish Failed Event
-                    let fail_event = ReportFailedEvent {
-                        request_id: command.request_id,
-                        symbol: command.symbol.clone(),
-                        module_id: module_id.clone(),
-                        error: err_msg.clone(),
-                    };
-                    if let Ok(payload) = serde_json::to_vec(&fail_event) {
-                        if let Err(e) = state.nats.publish(fail_event.subject().to_string(), payload.into()).await {
-                            error!("Failed to publish ReportFailedEvent: {}", e);
-                        }
-                    }
-                    return Err(e);
-                }
-            };
-
-            let mut context = Context::new();
-            context.insert("company_name", &profile.name);
-            context.insert("ts_code", &command.symbol);
-            for dep in &module_config.dependencies {
-                if let Some(content) = generated_results.get(dep) {
-                    context.insert(dep, content);
-                }
-            }
-            // Format financial data into a markdown table
-            let formatted_financials = format_financials_to_markdown(&financials);
-            context.insert("financial_data", &formatted_financials);
-
-            info!(module_id = %module_id, "Rendering prompt template...");
-            let prompt = match Tera::one_off(&module_config.prompt_template, &context, true) {
-                Ok(p) => {
-                    let p_len = p.len();
-                    info!(module_id = %module_id, "Prompt rendered successfully. Length: {} chars", p_len);
-                    // Hard Context Limit: 64K chars (~16K tokens)
-                    // This is a temporary protection until we have a Deep Research / Summarization module.
-                    const MAX_CONTEXT_CHARS: usize = 64_000;
-                    if p_len > MAX_CONTEXT_CHARS {
-                        let trunc_msg = "\n\n[SYSTEM WARNING: Input data truncated to fit context limits. Analysis may be partial.]";
-                        let safe_len = MAX_CONTEXT_CHARS.saturating_sub(trunc_msg.len());
-                        let truncated_prompt = format!("{}{}", &p[..safe_len], trunc_msg);
-                        tracing::warn!(
-                            module_id = %module_id,
-                            "Prompt size ({} chars) exceeded limit ({} chars). Truncated.",
-                            p_len, MAX_CONTEXT_CHARS
-                        );
-                        truncated_prompt
-                    } else {
-                        p
-                    }
-                },
-                Err(e) => {
-                    let err_msg = format!("Prompt rendering failed: {}", e);
-                    error!(module_id = %module_id, "{}", err_msg);
-                    generated_results.insert(module_id.clone(), format!("Error: {}", err_msg));
-
-                    // Broadcast Error
-                    let _ = stream_tx.send(serde_json::json!({
-                        "type": "error",
-                        "module_id": module_id,
-                        "payload": err_msg
-                    }).to_string());
-
-                    // Publish Failed Event
-                    let fail_event = ReportFailedEvent {
-                        request_id: command.request_id,
-                        symbol: command.symbol.clone(),
-                        module_id: module_id.clone(),
-                        error: err_msg.clone(),
-                    };
-                    if let Ok(payload) = serde_json::to_vec(&fail_event) {
-                        let _ = state.nats.publish(fail_event.subject().to_string(), payload.into()).await;
-                    }
-                    return Err(ProviderError::Configuration(err_msg));
-                }
-            };
-
-            // Streaming Generation
-            info!(module_id = %module_id, "Initiating LLM stream...");
-            let mut stream = match llm_client.stream_text(prompt).await {
-                Ok(s) => s,
-                Err(e) => {
-                    let err_msg = format!("LLM stream init failed: {}", e);
-                    error!(module_id = %module_id, "{}", err_msg);
-                    generated_results.insert(module_id.clone(), format!("Error: {}", err_msg));
-
-                    // Broadcast Error
-                    let _ = stream_tx.send(serde_json::json!({
-                        "type": "error",
-                        "module_id": module_id,
-                        "payload": err_msg
-                    }).to_string());
-
-                    // Publish Failed Event
-                    let fail_event = ReportFailedEvent {
-                        request_id: command.request_id,
-                        symbol: command.symbol.clone(),
-                        module_id: module_id.clone(),
-                        error: err_msg.clone(),
-                    };
-                    if let Ok(payload) = serde_json::to_vec(&fail_event) {
-                        let _ = state.nats.publish(fail_event.subject().to_string(), payload.into()).await;
-                    }
-                    return Err(ProviderError::LlmApi(err_msg));
-                }
-            };
-
-            let mut full_content = String::new();
-            let mut first_chunk_received = false;
-
-            while let Some(chunk_res) = stream.next().await {
-                match chunk_res {
-                    Ok(chunk) => {
-                        if !chunk.is_empty() {
-                            if !first_chunk_received {
-                                info!(module_id = %module_id, "Received first chunk from LLM stream.");
-                                first_chunk_received = true;
-                            }
-                            full_content.push_str(&chunk);
-
-                            // Broadcast Content Chunk
-                            // Fire and forget - if no subscribers (frontend disconnected), we continue
-                            let _ = stream_tx.send(serde_json::json!({
-                                "type": "content",
-                                "module_id": module_id,
-                                "payload": chunk
-                            }).to_string());
-
-                            // Publish TaskStreamUpdate (NATS)
-                            if let Some(task_id) = &command.task_id {
-                                let stream_evt = WorkflowEvent::TaskStreamUpdate {
-                                    task_id: task_id.clone(),
-                                    content_delta: chunk.clone(),
-                                    index: 0, // Index tracking might be hard with current stream logic, frontend usually appends
-                                };
-                                if let Ok(payload) = serde_json::to_vec(&stream_evt) {
-                                    let subject = common_contracts::subjects::NatsSubject::WorkflowProgress(command.request_id).to_string();
-                                    // Fire and forget via NATS too
-                                    let nats = state.nats.clone();
-                                    tokio::spawn(async move {
-                                        let _ = nats.publish(subject, payload.into()).await;
-                                    });
-                                }
-                            }
-                        }
-                    },
-                    Err(e) => {
-                        error!(module_id = %module_id, "Stream error: {}", e);
-                        let _ = stream_tx.send(serde_json::json!({
-                            "type": "error",
-                            "module_id": module_id,
-                            "payload": format!("Stream error: {}", e)
-                        }).to_string());
-                        // We might want to break or continue?
-                        // Let's continue to try to finish
-                    }
-                }
-            }
-
-            info!(module_id = %module_id, "Successfully generated content (Length: {}).", full_content.len());
-
-            // Broadcast Module Done
-            let _ = stream_tx.send(serde_json::json!({
-                "type": "module_done",
-                "module_id": module_id
-            }).to_string());
-
-            // PERSISTENCE: Only write to DB after full generation
-            let result_to_persist = NewAnalysisResult {
-                request_id: command.request_id,
-                symbol: command.symbol.to_string(),
-                template_id: command.template_id.clone(),
-                module_id: module_id.clone(),
-                content: full_content.clone(),
-                meta_data: serde_json::json!({
-                    "model_id": module_config.model_id,
-                    "status": "success"
-                }),
-            };
-            if let Err(e) = persistence_client.create_analysis_result(result_to_persist).await {
-                error!(module_id = %module_id, "Failed to persist analysis result: {}", e);
-            }
-
-            // Publish Module Completed Event (ReportGenerated)
-            // This tells the orchestrator that this module is done.
-            let event = ReportGeneratedEvent {
-                request_id: command.request_id,
-                symbol: command.symbol.clone(),
-                module_id: module_id.clone(),
-                content_snapshot: Some(full_content.chars().take(200).collect()), // Preview
-                model_id: Some(module_config.model_id.clone()),
-            };
-            if let Ok(payload) = serde_json::to_vec(&event) {
-                info!(module_id = %module_id, "Publishing ReportGeneratedEvent");
-                if let Err(e) = state.nats.publish(event.subject().to_string(), payload.into()).await {
-                    error!("Failed to publish ReportGeneratedEvent: {}", e);
-                }
-            }
-
-            generated_results.insert(module_id.clone(), full_content);
-        }
-        Ok(())
-    }.await;
-
-    if let Err(e) = result {
-        error!("Workflow failed: {}", e);
-        let _ = stream_tx.send(serde_json::json!({
-            "type": "error",
-            "module_id": "workflow",
-            "payload": format!("Analysis workflow failed: {}", e)
-        }).to_string());
-        // Ensure we propagate the error so the message consumer sends TaskFailed
-        return Err(e);
-    }
-
-    // Broadcast Workflow Done
-    let _ = stream_tx.send(serde_json::json!({
-        "type": "done"
-    }).to_string());
-
-    // Clean up channel with delay to allow clients to connect and receive the final messages
-    let streams = state.streams.clone();
-    let req_id = command.request_id;
-    tokio::spawn(async move {
-        tokio::time::sleep(std::time::Duration::from_secs(300)).await; // Keep for 5 minutes to be safe
-        streams.remove_channel(&req_id);
-    });
-
-    info!("Report generation workflow finished.");
-    Ok(())
-}
-
-fn sort_modules_by_dependency(
-    modules: &HashMap<String, AnalysisModuleConfig>,
-) -> Result<Vec<String>> {
-    let mut graph = DiGraph::<String, ()>::new();
-    let mut node_map = HashMap::new();
-
-    for module_id in modules.keys() {
-        let index = graph.add_node(module_id.clone());
-        node_map.insert(module_id.clone(), index);
-    }
-
-    for (module_id, module_config) in modules {
-        if let Some(&module_index) = node_map.get(module_id) {
-            for dep in &module_config.dependencies {
-                if let Some(&dep_index) = node_map.get(dep) {
-                    graph.add_edge(dep_index, module_index, ());
-                } else {
-                    return Err(ProviderError::Configuration(format!(
-                        "Module '{}' has a missing dependency: '{}'",
-                        module_id, dep
-                    )));
-                }
-            }
-        }
-    }
-
-    match toposort(&graph, None) {
-        Ok(sorted_nodes) => {
-            let sorted_ids = sorted_nodes
-                .into_iter()
-                .map(|node_index| graph[node_index].clone())
-                .collect();
-            Ok(sorted_ids)
-        }
-        Err(cycle) => {
-            let cycle_id = graph[cycle.node_id()].clone();
-            Err(ProviderError::Configuration(format!(
-                "Circular dependency detected in analysis modules. Cycle involves: '{}'",
-                cycle_id
-            )))
-        }
-    }
-}
-
-fn create_llm_client_for_module(
-    llm_providers: &LlmProvidersConfig,
-    module_config: &AnalysisModuleConfig,
-) -> Result<LlmClient> {
-    if module_config.provider_id.is_empty() {
-        return Err(ProviderError::Configuration(format!(
-            "Module '{}' has empty provider_id",
-            module_config.name
-        )));
-    }
-    let provider_id = &module_config.provider_id;
-    let provider = llm_providers.get(provider_id).ok_or_else(|| {
-        ProviderError::Configuration(format!(
-            "Provider '{}' not found for module '{}'",
-            provider_id, module_config.name
-        ))
-    })?;
-
-    let api_url = provider.api_base_url.clone();
-    info!("Creating LLM client for module '{}' using provider '{}' with URL: '{}'", module_config.name, provider_id, api_url);
-
-    Ok(LlmClient::new(
-        api_url,
-        provider.api_key.clone().into(),
-        module_config.model_id.clone(),
-        None, // Default timeout
-    ))
-}
-
-async fn fetch_data_and_configs(
-    client: &PersistenceClient,
-    symbol: &str,
-    request_id: uuid::Uuid,
-) -> Result<(
-    CompanyProfileDto,
-    Vec<TimeSeriesFinancialDto>,
-    LlmProvidersConfig,
-    AnalysisTemplateSets,
-)> {
-    info!("fetch_data_and_configs: Starting fetch for request_id: {}", request_id);
-
-    // 1. Fetch Configuration (Parallel)
-    info!("fetch_data_and_configs: Fetching LLM providers and template sets...");
-    let (llm_providers, template_sets) = tokio::try_join!(
-        client.get_llm_providers_config(),
-        client.get_analysis_template_sets(),
-    ).map_err(|e| {
-        error!("fetch_data_and_configs: Failed to fetch configs: {}", e);
-        e
-    })?;
-    info!("fetch_data_and_configs: Configs fetched successfully.");
-
-    // 2. Fetch Session Data (Observational Data Snapshot)
-    info!("fetch_data_and_configs: Fetching session data for request_id: {}", request_id);
-    let session_data = match client.get_session_data(request_id).await {
-        Ok(data) => {
-            info!("fetch_data_and_configs: Fetched {} session data items.", data.len());
-            data
-        },
-        Err(e) => {
-            error!("Failed to fetch session data for {}: {}", request_id, e);
-            vec![] // Treat as empty if failed, though ideally we should fail?
-        }
-    };
-
-    let mut profile: Option<CompanyProfileDto> = None;
-    let mut financials: Vec<TimeSeriesFinancialDto> = Vec::new();
-
-    for item in session_data {
-        if item.data_type == "company_profile" {
-            if let Ok(p) = serde_json::from_value::<CompanyProfileDto>(item.data_payload.clone()) {
-                if profile.is_none() {
-                    profile = Some(p);
-                }
-                // If we have multiple profiles, we currently just take the first one found.
-                // Future: We could merge them or provide all to context.
-            }
-        } else if item.data_type == "financial_statements" {
-            if let Ok(mut f_list) = serde_json::from_value::<Vec<TimeSeriesFinancialDto>>(item.data_payload.clone()) {
-                // Tag the source for clarity in the report if not already present
-                for f in &mut f_list {
-                    if f.source.is_none() {
-                        f.source = Some(item.provider.clone());
-                    }
-                }
-                financials.extend(f_list);
-            }
-        }
-    }
-
-    // 3. Fallback for Profile (Global Reference)
-    if profile.is_none() {
-        info!("Profile not found in session data, fetching global profile for {}", symbol);
-        if let Ok(p) = client.get_company_profile(symbol).await {
-            profile = Some(p);
-        }
-    }
-
-    // Ensure we have a profile to avoid crashes, or create a dummy one
-    let final_profile = profile.unwrap_or_else(|| CompanyProfileDto {
-        symbol: symbol.to_string(),
-        name: symbol.to_string(),
-        industry: None,
-        list_date: None,
-        additional_info: None,
-        updated_at: None,
-    });
-
-    Ok((final_profile, financials, llm_providers, template_sets))
-}
-
-#[cfg(test)]
-mod integration_tests {
-    use super::*;
-    use crate::config::AppConfig;
-    use crate::state::AppState;
-    use common_contracts::dtos::SessionDataDto;
-    use common_contracts::config_models::{LlmProvider, LlmModel, AnalysisTemplateSet, AnalysisModuleConfig};
-    use common_contracts::symbol_utils::{CanonicalSymbol, Market};
-    use uuid::Uuid;
-
-    #[tokio::test]
-    async fn test_report_generation_flow() {
-        if std::env::var("NATS_ADDR").is_err() {
-            println!("Skipping integration test (no environment)");
-            return;
-        }
-
-        // 1. Env & Config
-        let config = AppConfig::load().expect("Failed to load config");
-        let nats_client = async_nats::connect(config.nats_addr.clone()).await.expect("Failed to connect to NATS in test");
-        let state = Arc::new(AppState::new(config.clone(), nats_client));
-        let persistence_client = PersistenceClient::new(config.data_persistence_service_url.clone());
-
-        // 2. Setup: LlmProvider
-        let api_key = std::env::var("OPENROUTER_API_KEY")
-            .expect("OPENROUTER_API_KEY must be set");
-        let api_url = std::env::var("OPENROUTER_API_URL")
-            .unwrap_or_else(|_| "https://openrouter.ai/api/v1".to_string());
-
-        let mut llm_config = persistence_client.get_llm_providers_config().await.unwrap_or_default();
-        llm_config.insert("openrouter".to_string(), LlmProvider {
-            api_base_url: api_url,
-            api_key: api_key,
-            name: "OpenRouter".to_string(),
-            models: vec![
-                LlmModel {
-                    model_id: "google/gemini-flash-1.5".to_string(),
-                    name: Some("Gemini Flash".to_string()),
-                    is_active: true,
-                }
-            ],
-        });
-        persistence_client.update_llm_providers_config(&llm_config).await.expect("Failed to set LLM config");
-
-        // 3. Setup: Template Set
-        let mut templates = persistence_client.get_analysis_template_sets().await.unwrap_or_default();
-        let mut modules = HashMap::new();
-        modules.insert("swot_analysis".to_string(), AnalysisModuleConfig {
-            name: "SWOT Analysis".to_string(),
-            provider_id: "openrouter".to_string(),
-            model_id: "google/gemini-flash-1.5".to_string(),
-            prompt_template: "Analyze SWOT for {{ company_name }} ({{ ts_code }}). Keep it brief.".to_string(),
-            dependencies: vec![],
-        });
-        templates.insert("test_template".to_string(), AnalysisTemplateSet {
-            name: "Test Template".to_string(),
-            modules,
-        });
-        persistence_client.update_analysis_template_sets(&templates).await.expect("Failed to set Template config");
-
-        // 4. Setup: Session Data (Mocked Financials)
-        let request_id = Uuid::new_v4();
-        let symbol = CanonicalSymbol::new("AAPL", &Market::US);
-
-        // Mock Profile
-        let profile_dto = CompanyProfileDto {
-            symbol: symbol.to_string(),
-            name: "Apple Inc.".to_string(),
-            industry: Some("Technology".to_string()),
-            list_date: None,
-            additional_info: None,
-            updated_at: None,
-        };
-        persistence_client.insert_session_data(&SessionDataDto {
-            request_id,
-            symbol: symbol.to_string(),
-            provider: "mock".to_string(),
-            data_type: "company_profile".to_string(),
-            data_payload: serde_json::to_value(&profile_dto).unwrap(),
-            created_at: None,
-        }).await.unwrap();
-
-        // 5. Construct Command
-        let cmd = GenerateReportCommand {
-            request_id,
-            symbol,
-            template_id: "test_template".to_string(),
-            task_id: None,
-            module_id: None,
-        };
-
-        // 6. Run Workflow
-        let result = run_report_generation_workflow(state.clone(), cmd).await;
-
-        // 7. Assert
-        assert!(result.is_ok(), "Report generation failed: {:?}", result.err());
-    }
-}
+    // NEW MODE: Check if we are in VGCS mode
+    let has_new_mode_params = command.input_bindings.is_some() && command.output_path.is_some() && command.commit_hash.is_some();
+
+    if has_new_mode_params {
+        let input_bindings = command.input_bindings.clone().unwrap();
+        let output_path = command.output_path.clone().unwrap();
+        let commit_hash = command.commit_hash.clone().unwrap();
+        return run_vgcs_based_generation(state, command, &input_bindings, &output_path, &commit_hash).await.map(Some);
+    } else {
+        // Deprecated Flow
+        let err_msg = "Old flow is deprecated. Please provide input_bindings, output_path, and commit_hash.";
+        error!("{}", err_msg);
+        return Err(ProviderError::Configuration(err_msg.to_string()));
+    }
+}
+
+async fn run_vgcs_based_generation(
+    state: Arc<AppState>,
+    command: GenerateReportCommand,
+    input_bindings: &[String],
+    output_path: &str,
+    commit_hash: &str
+) -> Result<String> {
+    info!("Running VGCS based generation for task {:?}", command.task_id);
+    let persistence_client = PersistenceClient::new(state.config.data_persistence_service_url.clone());
+    let llm_providers = persistence_client.get_llm_providers_config().await.map_err(|e| ProviderError::Configuration(e.to_string()))?;
+
+    let llm_config = command.llm_config.clone().ok_or_else(|| ProviderError::Configuration("Missing llm_config".into()))?;
+    let analysis_prompt = command.analysis_prompt.clone().ok_or_else(|| ProviderError::Configuration("Missing analysis_prompt".into()))?;
+
+    // 1. Read Inputs from VGCS
+    let root_path = state.config.workflow_data_path.clone();
+    let req_id = command.request_id.to_string();
+    let commit_clone = commit_hash.to_string();
+    let bindings_clone = input_bindings.to_vec();
+
+    // Clones for input reading task
+    let root_path_in = root_path.clone();
+    let req_id_in = req_id.clone();
+    let commit_in = commit_clone.clone();
+
+    // Execution Trace Log
+    let mut execution_log = String::new();
+    execution_log.push_str(&format!("# Analysis Execution Trace: {}\n\n", command.task_id.clone().unwrap_or_default()));
+    execution_log.push_str("## 1. Input Retrieval\n\n");
+
+    let inputs_res = tokio::task::spawn_blocking(move || -> Result<(HashMap<String, String>, String)> {
+        let ctx = WorkerContext::new(&root_path_in, &req_id_in, &commit_in);
+        let mut data = HashMap::new();
+        let mut log = String::new();
+        for path in bindings_clone {
+            match ctx.read_text(&path) {
+                Ok(content) => {
+                    let size = content.len();
+                    log.push_str(&format!("- **Read Success**: `{}` (Size: {} bytes)\n", path, size));
+                    // Use the full path as the key to avoid collisions between files with same name from different folders
+                    let key = path.clone();
+                    // Check for overwrite
+                    if data.contains_key(&key) {
+                        log.push_str(&format!("- **Warn**: Duplicate input key '{}' detected. Overwriting.\n", key));
+                    }
+                    data.insert(key, content);
+                },
+                Err(e) => {
+                    log.push_str(&format!("- **Read Failed**: `{}` (Error: {})\n", path, e));
+                }
+            }
+        }
+        Ok((data, log))
+    }).await.map_err(|e| ProviderError::Internal(anyhow::anyhow!("Join Error: {}", e)))??;
+
+    let (inputs, input_log) = inputs_res;
+    execution_log.push_str(&input_log);
+    execution_log.push_str("\n## 2. Data Parsing & Formatting\n\n");
+
+    // 2. Prepare Context (Generic YAML Strategy)
+    let mut context_builder = String::new();
+    context_builder.push_str("\n\n# Data Context\n\n");
+
+    // Sort keys for deterministic output
+    let mut sorted_keys: Vec<_> = inputs.keys().collect();
+    sorted_keys.sort();
+
+    for key in sorted_keys {
+        let content = &inputs[key];
+        execution_log.push_str(&format!("### Processing Source: `{}`\n", key));
+        // Generic: Try parsing as JSON -> YAML
+        if let Ok(json_val) = serde_json::from_str::<serde_json::Value>(content) {
+            match serde_yaml::to_string(&json_val) {
+                Ok(yaml_str) => {
+                    execution_log.push_str("- **Success**: Parsed JSON and converted to YAML.\n");
+                    context_builder.push_str(&format!("---\n# Source: {}\n", key));
+                    context_builder.push_str(&yaml_str);
+                    context_builder.push_str("\n");
+                },
+                Err(e) => {
+                    execution_log.push_str(&format!("- **Warn**: JSON parsed but YAML conversion failed: {}. Fallback to Raw JSON.\n", e));
+                    context_builder.push_str(&format!("---\n# Source: {}\n", key));
+                    context_builder.push_str("```json\n");
+                    // Pretty print JSON as fallback
+                    context_builder.push_str(&serde_json::to_string_pretty(&json_val).unwrap_or_else(|_| content.to_string()));
+                    context_builder.push_str("\n```\n");
+                }
+            }
+        } else {
+            // Plain text
+            execution_log.push_str("- **Note**: Not valid JSON, treating as plain text.\n");
+            context_builder.push_str(&format!("---\n# Source: {}\n", key));
+            context_builder.push_str(content);
+            context_builder.push_str("\n");
+        }
+    }
+
+    // 3. Construct Final Prompt
+    // Prompt Strategy: <System/Analysis Prompt> \n\n <Context>
+    let final_prompt = format!("{}\n{}", analysis_prompt, context_builder);
+    let prompt_len = final_prompt.len();
+
+    execution_log.push_str("\n## 3. Context & Prompt Assembly\n\n");
+    execution_log.push_str(&format!("- **Context Size**: {} chars\n", context_builder.len()));
+    execution_log.push_str(&format!("- **Total Prompt Size**: {} chars\n", prompt_len));
+
+    // Hard Context Limit Check (Safety Net)
+    const MAX_CONTEXT_CHARS: usize = 128_000; // 128k chars
+    let (final_prompt_to_send, truncated) = if prompt_len > MAX_CONTEXT_CHARS {
+        let trunc_msg = "\n\n[SYSTEM WARNING: Input data truncated to fit context limits.]";
+        let safe_len = MAX_CONTEXT_CHARS.saturating_sub(trunc_msg.len());
+        (format!("{}{}", &final_prompt[..safe_len], trunc_msg), true)
+    } else {
+        (final_prompt.clone(), false)
+    };
+
+    if truncated {
+        execution_log.push_str(&format!("- **WARNING**: Prompt exceeded limit (Original: {}, Limit: {}). **TRUNCATED** to {} chars.\n", prompt_len, MAX_CONTEXT_CHARS, final_prompt_to_send.len()));
+    } else {
+        execution_log.push_str("- Prompt fits within limits.\n");
+    }
+
+    info!("Final Prompt Length: {} chars (Context: {} chars)", final_prompt_to_send.len(), context_builder.len());
+
+    // 4. Init LLM
+    let mut target_provider: Option<&common_contracts::config_models::LlmProvider> = None;
+    let model_id = llm_config.model_id.as_deref().unwrap_or("default");
+    for provider in llm_providers.values() {
+        if provider.models.iter().any(|m| m.model_id == model_id) {
+            target_provider = Some(provider);
+            break;
+        }
+    }
+    let provider = target_provider.ok_or_else(|| ProviderError::Configuration(format!("No provider found for model {}", model_id)))?;
+
+    let llm_client = LlmClient::new(
+        provider.api_base_url.clone(),
+        provider.api_key.clone().into(),
+        model_id.to_string(),
+        None,
+    );
+
+    // 5. Stream Generation
+    let mut stream = llm_client.stream_text(final_prompt_to_send).await.map_err(|e| ProviderError::LlmApi(e.to_string()))?;
+    let mut full_content = String::new();
+
+    while let Some(chunk_res) = stream.next().await {
+        if let Ok(chunk) = chunk_res {
+            if !chunk.is_empty() {
+                full_content.push_str(&chunk);
+
+                // Publish Stream Update
+                if let Some(task_id) = &command.task_id {
+                    let stream_evt = WorkflowEvent::TaskStreamUpdate {
+                        task_id: task_id.clone(),
+                        content_delta: chunk,
+                        index: 0,
+                    };
+                    if let Ok(payload) = serde_json::to_vec(&stream_evt) {
+                        let subject = common_contracts::subjects::NatsSubject::WorkflowProgress(command.request_id).to_string();
+                        let _ = state.nats.publish(subject, payload.into()).await;
+                    }
+                }
+            }
+        }
+    }
+
+    execution_log.push_str("\n## 4. Output Generation\n\n");
+    execution_log.push_str(&format!("- **Output Size**: {} chars\n", full_content.len()));
+
+    // 6. Write Output and Commit
+    let output_path_clone = output_path.to_string();
+    let full_content_clone = full_content.clone();
+    let req_id_clone = req_id.clone();
+    let task_id_clone = command.task_id.clone().unwrap_or_default();
+
+    // Prepare Execution Log Path (Sidecar)
+    // e.g. analysis/fundamental_analysis/AAPL_execution.md
+    let log_path = if output_path.ends_with(".md") {
+        output_path.replace(".md", "_execution.md")
+    } else {
+        format!("{}_execution.md", output_path)
+    };
+    let execution_log_clone = execution_log.clone();
+
+    // We need to commit on top of base commit
+    let commit_res = tokio::task::spawn_blocking(move || -> Result<String> {
+        let mut ctx = WorkerContext::new(&root_path, &req_id_clone, &commit_clone);
+        ctx.write_file(&output_path_clone, &full_content_clone).map_err(|e| ProviderError::Internal(e))?;
+        // Write the sidecar log
+        ctx.write_file(&log_path, &execution_log_clone).map_err(|e| ProviderError::Internal(e))?;
+        ctx.commit(&format!("Analysis Result for {}", task_id_clone)).map_err(|e| ProviderError::Internal(e))
+    }).await.map_err(|e| ProviderError::Internal(anyhow::anyhow!("Join Error: {}", e)))??;
+
+    info!("Generated report committed: {}", commit_res);
+
+    Ok(commit_res)
+}
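The "Generic YAML Strategy" above converts any JSON input to YAML before splicing it into the prompt, on the premise that YAML reads better for an LLM than raw JSON; non-JSON inputs pass through as plain text. A minimal standalone sketch of that fallback chain (source path and data are illustrative; assumes `serde_json` and `serde_yaml` as dependencies):

```rust
use serde_json::json;

// Mirrors the worker's order: JSON -> YAML, else pretty JSON, else raw text.
fn to_context_block(source: &str, content: &str) -> String {
    match serde_json::from_str::<serde_json::Value>(content) {
        Ok(v) => match serde_yaml::to_string(&v) {
            Ok(yaml) => format!("---\n# Source: {}\n{}\n", source, yaml),
            Err(_) => format!(
                "---\n# Source: {}\n```json\n{}\n```\n",
                source,
                serde_json::to_string_pretty(&v).unwrap_or_else(|_| content.to_string())
            ),
        },
        // Not JSON: treat as plain text.
        Err(_) => format!("---\n# Source: {}\n{}\n", source, content),
    }
}

fn main() {
    let raw = json!({"revenue": [100, 120], "currency": "USD"}).to_string();
    println!("{}", to_context_block("raw/mock/financials.json", raw.as_str()));
}
```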

View File

@@ -37,10 +37,14 @@ pub async fn handle_workflow_command(state: AppState, nats: async_nats::Client,
     let financials_clone = financials.clone();
     let symbol_code_clone = symbol_code.clone();
 
+    // Check for output path override from Orchestrator
+    let output_path_override = cmd.config.get("output_path").and_then(|s| s.as_str()).map(|s| s.to_string());
+
     let commit_result = tokio::task::spawn_blocking(move || -> Result<String> {
         let mut ctx = WorkerContext::new(&root_path, &req_id, &base_commit);
-        let base_dir = format!("raw/tushare/{}", symbol_code_clone);
+        // Use resolved output path or fallback to default convention
+        let base_dir = output_path_override.unwrap_or_else(|| format!("raw/tushare/{}", symbol_code_clone));
 
         let profile_json = serde_json::to_string_pretty(&profile_clone)
             .context("Failed to serialize profile")?;

View File

@@ -3073,6 +3073,7 @@ dependencies = [
  "common-contracts",
  "dashmap",
  "futures",
+ "globset",
  "reqwest",
  "serde",
  "serde_json",

View File

@@ -18,6 +18,7 @@ futures = "0.3"
 reqwest = { version = "0.12", features = ["json"] }
 dashmap = "6.1.0"
 axum = "0.8.7"
+globset = "0.4"
 
 # Internal dependencies
 common-contracts = { path = "../common-contracts", default-features = false }

View File

@@ -0,0 +1,279 @@
use anyhow::{Result, anyhow};
use common_contracts::configs::{ContextSelectorConfig, SelectionMode, LlmConfig};
use common_contracts::config_models::LlmProvidersConfig;
use workflow_context::{Vgcs, EntryKind, ContextStore};
use globset::{GlobBuilder, GlobMatcher};
use std::sync::Arc;
use std::collections::HashMap;
use tracing::info;
use crate::llm_client::LlmClient;
pub struct ResolutionResult {
pub paths: Vec<String>,
pub trace: String,
}
pub struct ContextResolver {
vgcs: Arc<Vgcs>,
}
impl ContextResolver {
pub fn new(vgcs: Arc<Vgcs>) -> Self {
Self { vgcs }
}
pub async fn resolve_input(
&self,
selector: &ContextSelectorConfig,
req_id: &str,
commit_hash: &str,
variables: &HashMap<String, String>,
llm_providers: &LlmProvidersConfig,
analysis_prompt: &str,
) -> Result<ResolutionResult> {
match &selector.mode {
SelectionMode::Manual { rules } => {
let resolved_rules = rules.iter().map(|r| {
let mut rule = r.clone();
for (k, v) in variables {
rule = rule.replace(&format!("{{{{{}}}}}", k), v);
}
rule
}).collect::<Vec<_>>();
self.resolve_manual(&resolved_rules, req_id, commit_hash).await
}
SelectionMode::Auto { llm_config } => {
let system_prompt = "You are an intelligent file selector for a financial analysis system. \
Your goal is to select the specific files from the repository that are necessary to fulfill the user's analysis request.\n\
Return ONLY a JSON array of string file paths (e.g. [\"path/to/file1\", \"path/to/file2\"]). \
Do not include any explanation, markdown formatting, or code blocks.";
let user_prompt = format!(
"I need to perform the following analysis task:\n\n\"{}\"\n\n\
Below is the current file structure of the repository. Please select the files that contain the data needed for this task.\n\
(Note: 'financials.json' usually contains financial statements, 'profile.json' contains company info, 'news.json' contains news).",
analysis_prompt
);
self.resolve_with_llm(system_prompt, &user_prompt, llm_config.as_ref(), llm_providers, req_id, commit_hash, "Auto").await
}
SelectionMode::Hybrid { selection_prompt, llm_config } => {
let mut user_req = selection_prompt.clone();
// Replace variables in prompt
for (k, v) in variables {
user_req = user_req.replace(&format!("{{{{{}}}}}", k), v);
}
let system_prompt = "You are an intelligent file selector. Return ONLY a JSON array of string file paths.";
let user_prompt = format!("Request: {}\n\nTask Context: {}\n\nSelect relevant files from the tree below.", user_req, analysis_prompt);
self.resolve_with_llm(system_prompt, &user_prompt, llm_config.as_ref(), llm_providers, req_id, commit_hash, "Hybrid").await
}
}
}
async fn resolve_manual(
&self,
rules: &[String],
req_id: &str,
commit_hash: &str,
) -> Result<ResolutionResult> {
let mut matchers = Vec::new();
for rule in rules {
let glob = GlobBuilder::new(rule)
.literal_separator(true)
.build()
.map_err(|e| anyhow!("Invalid glob pattern '{}': {}", rule, e))?;
matchers.push(glob.compile_matcher());
}
let mut files = Vec::new();
self.traverse(req_id, commit_hash, "", &matchers, &mut files, true)?;
let paths = files.into_iter().map(|(p, _)| p).collect::<Vec<_>>();
// Generate Trace
let mut trace = String::new();
trace.push_str("# Context Selection Trace: Manual\n\n");
trace.push_str("## Rules\n");
for rule in rules {
trace.push_str(&format!("- `{}`\n", rule));
}
trace.push_str("\n## Matched Files\n");
for file in &paths {
trace.push_str(&format!("- `{}`\n", file));
}
Ok(ResolutionResult {
paths,
trace,
})
}
#[allow(clippy::too_many_arguments)]
async fn resolve_with_llm(
&self,
system_prompt: &str,
user_instruction: &str,
llm_config: Option<&LlmConfig>,
providers: &LlmProvidersConfig,
req_id: &str,
commit_hash: &str,
mode_name: &str,
) -> Result<ResolutionResult> {
// 1. List all files with metadata
let mut all_files = Vec::new();
// Pass empty matchers to list everything
self.traverse(req_id, commit_hash, "", &[], &mut all_files, false)?;
if all_files.is_empty() {
return Ok(ResolutionResult {
paths: vec![],
trace: format!("# Context Selection Trace: {}\n\nNo files found in repository to select from.", mode_name),
});
}
// 2. Construct Tree String with Metadata
let tree_str = all_files.iter().map(|(path, meta)| {
format!("{} ({})", path, meta)
}).collect::<Vec<_>>().join("\n");
// 3. Create Client
let client = self.create_llm_client(llm_config, providers)?;
// 4. Call LLM
let full_user_prompt = format!("{}\n\n---\nFile Tree (Format: Path (Size | Lines | Words)):\n{}", user_instruction, tree_str);
info!("Calling LLM for file selection...");
let response = client.chat_completion(system_prompt, &full_user_prompt).await?;
// 5. Parse Response
// Clean up markdown code blocks if present
let cleaned = response.trim();
let cleaned = if cleaned.starts_with("```json") {
cleaned.trim_start_matches("```json").trim_end_matches("```").trim()
} else if cleaned.starts_with("```") {
cleaned.trim_start_matches("```").trim_end_matches("```").trim()
} else {
cleaned
};
let paths: Vec<String> = serde_json::from_str(cleaned)
.map_err(|e| anyhow!("Failed to parse LLM response as JSON list: {}. Response: {}", e, response))?;
// Validate paths exist in all_files to prevent hallucination
let all_paths_set: std::collections::HashSet<&String> = all_files.iter().map(|(p, _)| p).collect();
let valid_paths: Vec<String> = paths.into_iter()
.filter(|p| all_paths_set.contains(p))
.collect();
info!("LLM selected {} valid files out of {} total.", valid_paths.len(), all_files.len());
// Generate Trace
let mut trace = String::new();
trace.push_str(&format!("# Context Selection Trace: {}\n\n", mode_name));
trace.push_str("## System Prompt\n");
trace.push_str("```\n");
trace.push_str(system_prompt);
trace.push_str("\n```\n\n");
trace.push_str("## Full User Prompt (Intent + Context)\n");
trace.push_str("```\n");
trace.push_str(&full_user_prompt);
trace.push_str("\n```\n\n");
trace.push_str("## Model Output\n");
trace.push_str(cleaned);
trace.push_str("\n\n");
trace.push_str("## Final Selected Files\n");
for path in &valid_paths {
trace.push_str(&format!("- `{}`\n", path));
}
Ok(ResolutionResult {
paths: valid_paths,
trace,
})
}
fn traverse(
&self,
req_id: &str,
commit_hash: &str,
path: &str,
matchers: &[GlobMatcher],
results: &mut Vec<(String, String)>, // (Path, MetadataString)
filter_mode: bool, // If true, use matchers. If false, list all.
) -> Result<()> {
// list_dir might fail if path is not dir, but we only recurse on dirs.
// root "" is dir.
let entries = self.vgcs.list_dir(req_id, commit_hash, path)?;
for entry in entries {
// Skip internal files
if entry.name.starts_with(".") || entry.name == "index.md" || entry.name == "_meta.json" {
continue;
}
let full_path = if path.is_empty() || path == "/" {
entry.name.clone()
} else {
format!("{}/{}", path, entry.name)
};
match entry.kind {
EntryKind::File => {
let should_include = if !filter_mode {
true
} else {
matchers.iter().any(|m| m.is_match(&full_path))
};
if should_include {
let meta_str = format!(
"Size: {}, Lines: {}, Words: {}",
entry.size.unwrap_or(0),
entry.line_count.unwrap_or(0),
entry.word_count.unwrap_or(0)
);
results.push((full_path.clone(), meta_str));
}
}
EntryKind::Dir => {
self.traverse(req_id, commit_hash, &full_path, matchers, results, filter_mode)?;
}
}
}
Ok(())
}
fn create_llm_client(&self, llm_config: Option<&LlmConfig>, providers: &LlmProvidersConfig) -> Result<LlmClient> {
let model_id = llm_config.and_then(|c| c.model_id.as_deref()).unwrap_or("default");
let mut target_provider = None;
for provider in providers.values() {
if provider.models.iter().any(|m| m.model_id == model_id) {
target_provider = Some(provider);
break;
}
}
// Fallback to first provider if default
if target_provider.is_none() && model_id == "default" {
target_provider = providers.values().next();
}
let provider = target_provider.ok_or_else(|| anyhow!("No provider found for model {}", model_id))?;
Ok(LlmClient::new(
provider.api_base_url.clone(),
provider.api_key.clone(),
model_id.to_string(),
))
}
}
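In Manual mode, each rule is compiled with `literal_separator(true)`, so a `*` stops at `/` and a rule has to spell out directory depth explicitly. A quick standalone check of that behavior with `globset` (the rule and paths are hypothetical examples, after `{{symbol}}`-style substitution has already happened):

```rust
use globset::GlobBuilder;

fn main() -> Result<(), globset::Error> {
    // literal_separator(true) makes '*' stop at '/', as in resolve_manual.
    let matcher = GlobBuilder::new("raw/*/AAPL/*.json") // hypothetical rule
        .literal_separator(true)
        .build()?
        .compile_matcher();

    assert!(matcher.is_match("raw/tushare/AAPL/financials.json"));
    // A bare '*' does not cross directory levels:
    assert!(!matcher.is_match("raw/tushare/AAPL/sub/financials.json"));
    Ok(())
}
```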

View File

@@ -0,0 +1,43 @@
use common_contracts::messages::TaskType;
pub struct IOBinder;
impl IOBinder {
pub fn new() -> Self {
Self
}
pub fn allocate_output_path(
&self,
task_type: TaskType,
symbol: &str,
task_id: &str,
) -> String {
// Convention based paths:
// DataFetch: raw/{provider_id}/{symbol}
// DataProcessing: processed/{processor_id}/{symbol}
// Analysis: analysis/{module_id}/{symbol}.md
let clean_task_id = task_id.split(':').last().unwrap_or(task_id);
match task_type {
TaskType::DataFetch => format!("raw/{}/{}", clean_task_id, symbol),
TaskType::DataProcessing => format!("processed/{}/{}", clean_task_id, symbol),
TaskType::Analysis => format!("analysis/{}/{}.md", clean_task_id, symbol),
}
}
pub fn allocate_trace_path(
&self,
task_type: TaskType,
symbol: &str,
task_id: &str,
) -> String {
let clean_task_id = task_id.split(':').last().unwrap_or(task_id);
match task_type {
TaskType::Analysis => format!("analysis/{}/{}_trace.md", clean_task_id, symbol),
_ => format!("debug/{}/{}_trace.md", clean_task_id, symbol),
}
}
}
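The binder is pure string convention, so its behavior is easy to pin down with a test. A sketch assuming `IOBinder` and `TaskType` above are in scope (the task ids are hypothetical examples of the `type:id` format produced by the DAG builder):

```rust
#[test]
fn io_binder_conventions() {
    let binder = IOBinder::new();
    // "analysis:fundamental_analysis" -> clean_task_id "fundamental_analysis"
    assert_eq!(
        binder.allocate_output_path(TaskType::Analysis, "AAPL", "analysis:fundamental_analysis"),
        "analysis/fundamental_analysis/AAPL.md"
    );
    // Non-analysis tasks get their trace under debug/.
    assert_eq!(
        binder.allocate_trace_path(TaskType::DataFetch, "AAPL", "fetch:tushare"),
        "debug/tushare/AAPL_trace.md"
    );
}
```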

View File

@@ -5,3 +5,6 @@ pub mod persistence;
 pub mod state;
 pub mod workflow;
 pub mod dag_scheduler;
+pub mod context_resolver;
+pub mod io_binder;
+pub mod llm_client;

View File

@@ -0,0 +1,76 @@
use anyhow::{Result, anyhow};
use serde_json::{json, Value};
use std::time::Duration;
use tracing::{debug, error, info};
pub struct LlmClient {
http_client: reqwest::Client,
api_base_url: String,
api_key: String,
model: String,
}
impl LlmClient {
pub fn new(api_url: String, api_key: String, model: String) -> Self {
let api_url = api_url.trim();
// Normalize base URL (handling /v1, /chat/completions etc is tricky, keeping it simple for now)
// Assuming api_url is the base (e.g., https://api.openai.com/v1)
let base_url = api_url.trim_end_matches('/').to_string();
let http_client = reqwest::Client::builder()
.timeout(Duration::from_secs(60))
.build()
.unwrap_or_default();
Self {
http_client,
api_base_url: base_url,
api_key,
model,
}
}
pub async fn chat_completion(&self, system_prompt: &str, user_prompt: &str) -> Result<String> {
let url = if self.api_base_url.ends_with("/chat/completions") {
self.api_base_url.clone()
} else {
format!("{}/chat/completions", self.api_base_url)
};
debug!("Sending request to LLM: {} ({})", self.model, url);
let body = json!({
"model": self.model,
"messages": [
{ "role": "system", "content": system_prompt },
{ "role": "user", "content": user_prompt }
],
"temperature": 0.1 // Low temperature for deterministic selection
});
let res = self.http_client.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.map_err(|e| anyhow!("LLM request failed: {}", e))?;
if !res.status().is_success() {
let status = res.status();
let text = res.text().await.unwrap_or_default();
return Err(anyhow!("LLM API error {}: {}", status, text));
}
let json: Value = res.json().await
.map_err(|e| anyhow!("Failed to parse LLM response: {}", e))?;
let content = json["choices"][0]["message"]["content"]
.as_str()
.ok_or_else(|| anyhow!("Invalid LLM response format"))?;
Ok(content.to_string())
}
}
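Note that the client appends `/chat/completions` unless the base URL already ends with it, so both bare bases and full endpoint URLs work. A hedged usage sketch (endpoint, env var name, and model are illustrative, not values from this commit):

```rust
// Hypothetical call site for the orchestrator's LlmClient above.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let client = LlmClient::new(
        "https://api.openai.com/v1".to_string(), // "/chat/completions" is appended
        std::env::var("LLM_API_KEY")?,           // assumed env var for the key
        "gpt-4o-mini".to_string(),
    );
    let selection = client
        .chat_completion(
            "Return ONLY a JSON array of file paths.",
            "Select the financial statements from: raw/mock/financials.json",
        )
        .await?;
    println!("{}", selection);
    Ok(())
}
```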

View File

@@ -14,7 +14,10 @@ use tokio::sync::Mutex;
 
 use crate::dag_scheduler::DagScheduler;
 use crate::state::AppState;
-use workflow_context::{Vgcs, ContextStore, traits::Transaction}; // Added Transaction
+use crate::context_resolver::ContextResolver;
+use crate::io_binder::IOBinder;
+use common_contracts::configs::AnalysisModuleConfig;
+use workflow_context::{Vgcs, ContextStore};
 
 pub struct WorkflowEngine {
     state: Arc<AppState>,
@@ -255,7 +258,7 @@ impl WorkflowEngine {
     async fn dispatch_task(&self, dag: &mut DagScheduler, task_id: &str, vgcs: &Vgcs) -> Result<()> {
         // 1. Resolve Context (Merge if needed)
-        let context = dag.resolve_context(task_id, vgcs)?;
+        let mut context = dag.resolve_context(task_id, vgcs)?;
 
         // Store the input commit in the node for observability
         if let Some(base_commit) = &context.base_commit {
@@ -267,13 +270,102 @@ impl WorkflowEngine {
         self.publish_log(dag.request_id, task_id, "INFO", "Task scheduled and dispatched.").await;
 
         // 3. Construct Command
+        let (routing_key, task_type, mut config) = {
             let node = dag.nodes.get(task_id).ok_or_else(|| anyhow::anyhow!("Node not found"))?;
+            (node.routing_key.clone(), node.task_type, node.config.clone())
+        };
+
+        // --- Resolution Phase ---
+        let symbol = self.get_symbol_from_config(&config);
+
+        // 3.1 IO Binding
+        let io_binder = IOBinder::new();
+        let output_path = io_binder.allocate_output_path(task_type, &symbol, task_id);
+        if let Some(obj) = config.as_object_mut() {
+            obj.insert("output_path".to_string(), serde_json::Value::String(output_path.clone()));
+        }
+
+        // 3.2 Context Resolution (for Analysis)
+        if task_type == TaskType::Analysis {
+            // We might update the base_commit if we write a trace file
+            let mut current_base_commit = context.base_commit.clone().unwrap_or_default();
+
+            if let Some(module_config_val) = config.get("_module_config") {
+                if let Ok(module_config) = serde_json::from_value::<AnalysisModuleConfig>(module_config_val.clone()) {
+                    let mut variables = std::collections::HashMap::new();
+                    variables.insert("symbol".to_string(), symbol.clone());
+                    if let Some(market) = config.get("market").and_then(|v| v.as_str()) {
+                        variables.insert("market".to_string(), market.to_string());
+                    }
+
+                    let resolver = ContextResolver::new(self.state.vgcs.clone());
+
+                    // Fetch LLM providers for context resolution (Hybrid/Auto modes)
+                    let llm_providers = match self.state.persistence_client.get_llm_providers_config().await {
+                        Ok(p) => p,
+                        Err(e) => {
+                            warn!("Failed to fetch LLM providers config (non-fatal for Manual mode): {}", e);
+                            common_contracts::config_models::LlmProvidersConfig::default()
+                        }
+                    };
+
+                    match resolver.resolve_input(&module_config.context_selector, &dag.request_id.to_string(), &current_base_commit, &variables, &llm_providers, &module_config.analysis_prompt).await {
+                        Ok(resolution) => {
+                            // 1. Inject Input Bindings
+                            if let Some(obj) = config.as_object_mut() {
+                                obj.insert("input_bindings".to_string(), serde_json::to_value(&resolution.paths)?);
+                            }
+
+                            // 2. Write Trace Sidecar to VGCS
+                            let trace_path = io_binder.allocate_trace_path(task_type, &symbol, task_id);
+                            // Use a blocking task for VGCS write/commit to avoid async issues with standard IO
+                            let vgcs = self.state.vgcs.clone();
+                            let req_id_str = dag.request_id.to_string();
+                            let base_commit_for_write = current_base_commit.clone();
+                            let trace_content = resolution.trace.clone();
+                            let task_id_str = task_id.to_string();
+                            let trace_path_clone = trace_path.clone();
+
+                            let trace_commit_res = tokio::task::spawn_blocking(move || -> Result<String> {
+                                let mut tx = vgcs.begin_transaction(&req_id_str, &base_commit_for_write)?;
+                                tx.write(&trace_path_clone, trace_content.as_bytes())?;
+                                let new_commit = Box::new(tx).commit(&format!("Context Resolution Trace for {}", task_id_str), "Orchestrator")?;
+                                Ok(new_commit)
+                            }).await;
+
+                            match trace_commit_res {
+                                Ok(Ok(new_commit)) => {
+                                    info!("Written context resolution trace to {} (Commit: {})", trace_path, new_commit);
+                                    // Update the base commit for the worker, so it sees the trace file (linear history)
+                                    current_base_commit = new_commit;
+                                    // Update the context passed to the worker
+                                    context.base_commit = Some(current_base_commit.clone());
+                                    // Also update the DAG node's input commit for observability
+                                    // Note: dag is locked in this scope, we can modify it but we need to handle scope issues if we were using dag inside closure.
+                                    // We are outside closure here.
+                                    dag.set_input_commit(task_id, current_base_commit);
+                                },
+                                Ok(Err(e)) => error!("Failed to write trace file: {}", e),
+                                Err(e) => error!("Failed to join trace write task: {}", e),
+                            }
+                        },
+                        Err(e) => {
+                            error!("Context resolution failed for task {}: {}", task_id, e);
+                            // We proceed, but the worker might fail if it relies on inputs
+                        }
+                    }
+                }
+            }
+        }
+
         let cmd = WorkflowTaskCommand {
             request_id: dag.request_id,
             task_id: task_id.to_string(),
-            routing_key: node.routing_key.clone(),
-            config: node.config.clone(),
+            routing_key: routing_key.clone(),
+            config, // Use modified config
             context,
             storage: StorageConfig {
                 root_path: self.state.config.workflow_data_path.clone(),
@@ -299,6 +391,10 @@ impl WorkflowEngine {
         Ok(())
     }
 
+    fn get_symbol_from_config(&self, config: &serde_json::Value) -> String {
+        config.get("symbol").and_then(|v| v.as_str()).unwrap_or("unknown").to_string()
+    }
+
     // Helper to build DAG
     fn build_dag(
         &self,
@@ -321,6 +417,11 @@ impl WorkflowEngine {
         for key in source_keys {
             let config = &data_sources[key];
             if config.enabled {
+                // Special handling for MOCK market: skip real providers
+                if market == "MOCK" && key.to_lowercase() != "mock" {
+                    continue;
+                }
+
                 let provider_key = key.to_lowercase();
                 let task_id = format!("fetch:{}", provider_key);
                 fetch_tasks.push(task_id.clone());
@@ -369,17 +470,25 @@ impl WorkflowEngine {
             // We pass the FULL module config here if we want the worker to be stateless,
             // BUT existing worker logic fetches template again.
             // To support "Single Module Execution", we should probably pass the module_id.
+            let mut node_config = json!({
+                "template_id": template_id,
+                "module_id": module_id,
+                "symbol": symbol.as_str(),
+                "market": market
+            });
+
+            // Embed internal module config for Orchestrator use (Context Resolution)
+            if let Some(obj) = node_config.as_object_mut() {
+                obj.insert("_module_config".to_string(), serde_json::to_value(module_config).unwrap_or(serde_json::Value::Null));
+            }
+
             dag.add_node(
                 task_id.clone(),
                 Some(module_config.name.clone()),
                 TaskType::Analysis,
                 "analysis.report".to_string(), // routing_key matches what report-generator consumes
-                json!({
-                    "template_id": template_id,
-                    "module_id": module_id,
-                    "symbol": symbol.as_str(),
-                    "market": market
-                })
+                node_config
             );
 
             // Dependencies

View File

@@ -86,9 +86,35 @@ impl YFinanceDataProvider {
         Ok(crumb_text)
     }
 
-    /// Tests connectivity by ensuring we can get a crumb from Yahoo Finance.
+    /// Tests connectivity by ensuring we can get a crumb AND fetch a small piece of data.
+    /// This bypasses the crumb cache check implicitly because the fetch will fail if network is down.
+    /// However, to be absolutely sure we are testing connectivity, we should force a fetch.
     pub async fn ping(&self) -> Result<(), AppError> {
-        self.ensure_crumb().await.map(|_| ())
+        // 1. Ensure we have a crumb (this might use cache)
+        let crumb = self.ensure_crumb().await?;
+
+        // 2. Perform a real request to verify connectivity and crumb validity.
+        // We fetch a lightweight endpoint, e.g., quoteType for a stable symbol like "SPY".
+        let url = format!(
+            "https://query2.finance.yahoo.com/v10/finance/quoteSummary/SPY?modules=quoteType&crumb={}",
+            crumb
+        );
+
+        info!("Ping: sending verification request to Yahoo Finance...");
+        let resp = self.client
+            .get(&url)
+            .send()
+            .await
+            .map_err(|e| AppError::ServiceRequest(e))?;
+
+        if !resp.status().is_success() {
+            warn!("Ping failed: Yahoo Finance returned status {}", resp.status());
+            // If unauthorized, maybe crumb is stale? We could invalidate cache here,
+            // but for now let's just report failure.
+            return Err(AppError::ServiceRequest(resp.error_for_status().unwrap_err()));
+        }
+
+        Ok(())
     }
 
     pub async fn fetch_all_data(