Fundamental_Analysis/services/data-persistence-service/tests/db_tests.rs
Lv, Qi 0cb31e363e Refactor E2E tests and improve error handling in Orchestrator
- Fix `simple_test_analysis` template in E2E test setup to align with Orchestrator's data fetch logic.
- Implement and verify additional E2E scenarios:
    - Scenario C: Partial Provider Failure (verified error propagation fix in Orchestrator).
    - Scenario D: Invalid Symbol input.
    - Scenario E: Analysis Module failure.
- Update `WorkflowStateMachine::handle_report_failed` to scope error broadcasting to the specific failed task, instead of failing silently or too broadly.
- Update testing strategy documentation to reflect completed Phase 4 testing.
- Skip Scenario B (Orchestrator Restart) as persistence is not yet implemented (decision made to defer persistence).
2025-11-21 20:44:32 +08:00

188 lines
7.6 KiB
Rust

// Integration tests for the database functions in `src/db.rs`.
// Each test uses `#[sqlx::test]`, which hands it a `PgPool` connected to a dedicated test
// database provisioned by sqlx, so tests are isolated from each other and leave no data behind
// in the main database.
// `unused_imports` is silenced because some imports (e.g. `models`) are not used by any test yet.
#![allow(unused_imports)]
use data_persistence_service::{
db,
dtos::{CompanyProfileDto, TimeSeriesFinancialDto, DailyMarketDataDto, NewAnalysisResult},
models,
};
use sqlx::{postgres::PgPoolOptions, PgPool};
use std::time::Duration;
use uuid::Uuid;
/// Builds a standalone connection pool from `DATABASE_URL`, loading a `.env` file first if one
/// exists.
///
/// The `#[sqlx::test]` tests in this file receive their pool from the macro and do not call
/// this helper; it is kept for ad-hoc tests that need to talk to the configured database directly.
///
/// # Panics
///
/// Panics if `DATABASE_URL` is unset or if the pool cannot be created (e.g. the database is
/// unreachable).
#[allow(dead_code)] // Not used by the `#[sqlx::test]` tests, which get a pool injected.
async fn setup() -> PgPool {
    // A missing `.env` file is fine: the variable may come from the real environment.
    dotenvy::dotenv().ok();
    let db_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set for tests");
    PgPoolOptions::new()
        .max_connections(1)
        // Fail fast instead of hanging when the database is unreachable.
        .acquire_timeout(Duration::from_secs(3))
        .connect(&db_url)
        .await
        .expect("Failed to create pool.")
}
/// Round-trips a company profile through `db::upsert_company` / `db::get_company_by_symbol`.
///
/// The first upsert inserts the row. The second upsert uses the same symbol and must overwrite
/// the changed fields (`name`, `industry`) while keeping the unchanged ones.
#[sqlx::test]
async fn test_upsert_and_get_company(pool: PgPool) {
    const SYMBOL: &str = "TEST.SYM";

    // 1. Setup: Create a test company DTO
    let new_company = CompanyProfileDto {
        symbol: SYMBOL.to_string(),
        name: "Test Company Inc.".to_string(),
        industry: Some("Testing".to_string()),
        list_date: Some(chrono::NaiveDate::from_ymd_opt(2024, 1, 1).unwrap()),
        additional_info: Some(serde_json::json!({ "ceo": "John Doe" })),
        updated_at: None,
    };

    // 2. Act: Insert it. `expect` (instead of `assert!(is_ok())`) reports the DB error on failure.
    db::upsert_company(&pool, &new_company)
        .await
        .expect("initial upsert_company failed");

    // 3. Assert: Every field we wrote is read back unchanged.
    let fetched_company = db::get_company_by_symbol(&pool, SYMBOL)
        .await
        .expect("get_company_by_symbol failed")
        .expect("company not found after insert");
    assert_eq!(fetched_company.symbol, new_company.symbol);
    assert_eq!(fetched_company.name, new_company.name);
    assert_eq!(fetched_company.industry, new_company.industry);
    assert_eq!(fetched_company.list_date, new_company.list_date);
    assert_eq!(fetched_company.additional_info, new_company.additional_info);

    // 4. Act (Update): Same symbol, different name and industry.
    let updated_company = CompanyProfileDto {
        symbol: SYMBOL.to_string(),
        name: "Test Company LLC".to_string(), // Name changed
        industry: Some("Advanced Testing".to_string()), // Industry changed
        list_date: new_company.list_date,
        additional_info: new_company.additional_info.clone(),
        updated_at: None,
    };
    db::upsert_company(&pool, &updated_company)
        .await
        .expect("second upsert_company failed");

    // 5. Assert (Update): Changed fields are overwritten and unchanged fields are preserved.
    let fetched_updated_company = db::get_company_by_symbol(&pool, SYMBOL)
        .await
        .expect("get_company_by_symbol failed")
        .expect("company not found after update");
    assert_eq!(fetched_updated_company.name, "Test Company LLC");
    assert_eq!(fetched_updated_company.industry, Some("Advanced Testing".to_string()));
    assert_eq!(fetched_updated_company.list_date, new_company.list_date);
    assert_eq!(fetched_updated_company.additional_info, new_company.additional_info);
}
/// Inserts two financial metrics for one symbol with `db::batch_insert_financials`, then reads
/// them back with `db::get_financials_by_symbol`: unfiltered first, then filtered to `roe`.
#[sqlx::test]
async fn test_batch_insert_and_get_financials(pool: PgPool) {
    const SYMBOL: &str = "TEST.FIN";

    // 1. Setup: Create some test financial DTOs
    let financials = vec![
        TimeSeriesFinancialDto {
            symbol: SYMBOL.to_string(),
            metric_name: "revenue".to_string(),
            period_date: chrono::NaiveDate::from_ymd_opt(2023, 12, 31).unwrap(),
            value: 1000.0,
            source: Some("test".to_string()),
        },
        TimeSeriesFinancialDto {
            symbol: SYMBOL.to_string(),
            metric_name: "roe".to_string(),
            period_date: chrono::NaiveDate::from_ymd_opt(2023, 12, 31).unwrap(),
            value: 15.5,
            source: Some("test".to_string()),
        },
    ];

    // 2. Act: Batch insert. `expect` reports the DB error on failure.
    db::batch_insert_financials(&pool, &financials)
        .await
        .expect("batch_insert_financials failed");

    // 3. Assert: Without a metric filter, both metrics come back.
    let fetched_all = db::get_financials_by_symbol(&pool, SYMBOL, None).await.unwrap();
    assert_eq!(fetched_all.len(), 2);
    let mut metric_names: Vec<&str> = fetched_all.iter().map(|f| f.metric_name.as_str()).collect();
    metric_names.sort_unstable();
    assert_eq!(metric_names, ["revenue", "roe"]);

    // 4. Assert: Filtering by metric returns only that metric.
    let fetched_roe = db::get_financials_by_symbol(&pool, SYMBOL, Some(vec!["roe".to_string()]))
        .await
        .unwrap();
    assert_eq!(fetched_roe.len(), 1);
    assert_eq!(fetched_roe[0].metric_name, "roe");
    // `value` may come back as a decimal whose text form depends on the column scale
    // (e.g. "15.5" vs "15.5000"), so compare the parsed number rather than the exact string.
    let roe: f64 = fetched_roe[0]
        .value
        .to_string()
        .parse()
        .expect("roe value is not numeric");
    assert_eq!(roe, 15.5);
}
/// Inserts two consecutive trading days with `db::batch_insert_daily_data`, then reads them back
/// with `db::get_daily_data_by_symbol`. The unfiltered query is checked for newest-first order;
/// the query with a start date of 2024-01-02 must return only that day.
#[sqlx::test]
async fn test_batch_insert_and_get_daily_data(pool: PgPool) {
    const SYMBOL: &str = "TEST.MKT";
    let day_one = chrono::NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
    let day_two = chrono::NaiveDate::from_ymd_opt(2024, 1, 2).unwrap();

    // 1. Setup: Create some test daily market data DTOs
    let daily_data = vec![
        DailyMarketDataDto {
            symbol: SYMBOL.to_string(),
            trade_date: day_one,
            open_price: Some(100.0),
            high_price: Some(102.5),
            low_price: Some(99.5),
            close_price: Some(101.0),
            volume: Some(10000),
            pe: Some(20.5),
            pb: Some(2.1),
            total_mv: Some(1000000.0),
        },
        DailyMarketDataDto {
            symbol: SYMBOL.to_string(),
            trade_date: day_two,
            open_price: Some(101.0),
            high_price: Some(103.5),
            low_price: Some(100.5),
            close_price: Some(102.0),
            volume: Some(12000),
            pe: Some(20.7),
            pb: Some(2.2),
            total_mv: Some(1020000.0),
        },
    ];

    // 2. Act: Batch insert. `expect` reports the DB error on failure.
    db::batch_insert_daily_data(&pool, &daily_data)
        .await
        .expect("batch_insert_daily_data failed");

    // 3. Assert: Both rows come back, newest first.
    let fetched_all = db::get_daily_data_by_symbol(&pool, SYMBOL, None, None).await.unwrap();
    assert_eq!(fetched_all.len(), 2);
    assert_eq!(fetched_all[0].trade_date, day_two); // Desc order
    assert_eq!(fetched_all[1].trade_date, day_one);

    // 4. Assert: A start date of `day_two` returns only that row.
    let fetched_one = db::get_daily_data_by_symbol(&pool, SYMBOL, Some(day_two), None)
        .await
        .unwrap();
    assert_eq!(fetched_one.len(), 1);
    assert_eq!(fetched_one[0].trade_date, day_two);
    // The rendered decimal text depends on type and scale ("102", "102.0", "102.0000", ...),
    // so compare the parsed number rather than specific strings.
    let close: f64 = fetched_one[0]
        .close_price
        .as_ref()
        .expect("close_price should be present")
        .to_string()
        .parse()
        .expect("close_price is not numeric");
    assert_eq!(close, 102.0);
}
/// Stores one analysis result via `db::create_analysis_result` and reads it back three ways:
/// filtered by symbol and module, by primary key, and by symbol alone.
#[sqlx::test]
async fn test_create_and_get_analysis_results(pool: PgPool) {
    let request = NewAnalysisResult {
        request_id: Uuid::new_v4(),
        symbol: "TEST.AI".to_string(),
        template_id: "default_template".to_string(),
        module_id: "bull_case".to_string(),
        content: "This is a bullish analysis.".to_string(),
        meta_data: serde_json::json!({ "tokens": 123 }),
    };

    // Insert, then check the row echoed back by the create call.
    let stored = db::create_analysis_result(&pool, &request).await.unwrap();
    assert_eq!(stored.symbol, "TEST.AI");
    assert_eq!(stored.module_id, "bull_case");

    // Lookup filtered by symbol and module yields exactly the stored row.
    let by_module = db::get_analysis_results(&pool, "TEST.AI", Some("bull_case"))
        .await
        .unwrap();
    assert_eq!(by_module.len(), 1);
    let only = &by_module[0];
    assert_eq!(only.id, stored.id);
    assert_eq!(only.content, request.content);

    // Lookup by primary key.
    let by_id = db::get_analysis_result_by_id(&pool, stored.id)
        .await
        .unwrap()
        .unwrap();
    assert_eq!(by_id.symbol, "TEST.AI");
    assert_eq!(by_id.content, request.content);

    // Lookup by symbol only (no module filter).
    let all_for_symbol = db::get_analysis_results(&pool, "TEST.AI", None).await.unwrap();
    assert_eq!(all_for_symbol.len(), 1);
}