// This file will contain integration tests for the database functions in `src/db.rs`. // We will use the `#[sqlx::test]` macro to run tests in a transaction that is rolled back at the end. // This ensures that our tests are isolated and do not leave any data in the database. // Silence unused_imports warning for now, as we will add tests here shortly. #![allow(unused_imports)] use data_persistence_service::{ db, dtos::{CompanyProfileDto, TimeSeriesFinancialDto, DailyMarketDataDto, NewAnalysisResultDto}, models, }; use sqlx::{postgres::PgPoolOptions, PgPool}; use std::time::Duration; async fn setup() -> PgPool { dotenvy::dotenv().ok(); let db_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set for tests"); PgPoolOptions::new() .max_connections(1) .acquire_timeout(Duration::from_secs(3)) .connect(&db_url) .await .expect("Failed to create pool.") } // Example test structure // #[sqlx::test] // async fn test_some_db_function(pool: PgPool) { // // 1. Setup: Insert some test data // // 2. Act: Call the database function // // 3. Assert: Check the result // assert!(true); // } #[sqlx::test] async fn test_upsert_and_get_company(pool: PgPool) { // 1. Setup: Create a test company DTO let new_company = CompanyProfileDto { symbol: "TEST.SYM".to_string(), name: "Test Company Inc.".to_string(), industry: Some("Testing".to_string()), list_date: Some(chrono::NaiveDate::from_ymd_opt(2024, 1, 1).unwrap()), additional_info: Some(serde_json::json!({ "ceo": "John Doe" })), }; // 2. Act: Call the upsert function let upsert_result = db::upsert_company(&pool, &new_company).await; assert!(upsert_result.is_ok()); // 3. Assert: Call the get function and verify the data let fetched_company = db::get_company_by_symbol(&pool, "TEST.SYM").await.unwrap().unwrap(); assert_eq!(fetched_company.symbol, new_company.symbol); assert_eq!(fetched_company.name, new_company.name); assert_eq!(fetched_company.industry, new_company.industry); assert_eq!(fetched_company.list_date, new_company.list_date); assert_eq!(fetched_company.additional_info, new_company.additional_info); // 4. Act (Update): Create a modified DTO and upsert again let updated_company = CompanyProfileDto { symbol: "TEST.SYM".to_string(), name: "Test Company LLC".to_string(), // Name changed industry: Some("Advanced Testing".to_string()), // Industry changed list_date: new_company.list_date, additional_info: new_company.additional_info, }; let update_result = db::upsert_company(&pool, &updated_company).await; assert!(update_result.is_ok()); // 5. Assert (Update): Fetch again and verify the updated data let fetched_updated_company = db::get_company_by_symbol(&pool, "TEST.SYM").await.unwrap().unwrap(); assert_eq!(fetched_updated_company.name, "Test Company LLC"); assert_eq!(fetched_updated_company.industry, Some("Advanced Testing".to_string())); } #[sqlx::test] async fn test_batch_insert_and_get_financials(pool: PgPool) { // 1. Setup: Create some test financial DTOs let financials = vec![ TimeSeriesFinancialDto { symbol: "TEST.FIN".to_string(), metric_name: "revenue".to_string(), period_date: chrono::NaiveDate::from_ymd_opt(2023, 12, 31).unwrap(), value: 1000.0, source: Some("test".to_string()), }, TimeSeriesFinancialDto { symbol: "TEST.FIN".to_string(), metric_name: "roe".to_string(), period_date: chrono::NaiveDate::from_ymd_opt(2023, 12, 31).unwrap(), value: 15.5, source: Some("test".to_string()), }, ]; // 2. Act: Call the batch insert function let insert_result = db::batch_insert_financials(&pool, &financials).await; assert!(insert_result.is_ok()); // 3. Assert: Get all financials and verify let fetched_all = db::get_financials_by_symbol(&pool, "TEST.FIN", None).await.unwrap(); assert_eq!(fetched_all.len(), 2); // 4. Assert: Get specific metric and verify let fetched_roe = db::get_financials_by_symbol(&pool, "TEST.FIN", Some(vec!["roe".to_string()])).await.unwrap(); assert_eq!(fetched_roe.len(), 1); assert_eq!(fetched_roe[0].metric_name, "roe"); // Note: Comparing decimals requires conversion or a tolerance-based approach assert_eq!(fetched_roe[0].value.to_string(), "15.5"); } #[sqlx::test] async fn test_batch_insert_and_get_daily_data(pool: PgPool) { // 1. Setup: Create some test daily market data DTOs let daily_data = vec![ DailyMarketDataDto { symbol: "TEST.MKT".to_string(), trade_date: chrono::NaiveDate::from_ymd_opt(2024, 1, 1).unwrap(), open_price: Some(100.0), high_price: Some(102.5), low_price: Some(99.5), close_price: Some(101.0), volume: Some(10000), pe: Some(20.5), pb: Some(2.1), total_mv: Some(1000000.0), }, DailyMarketDataDto { symbol: "TEST.MKT".to_string(), trade_date: chrono::NaiveDate::from_ymd_opt(2024, 1, 2).unwrap(), open_price: Some(101.0), high_price: Some(103.5), low_price: Some(100.5), close_price: Some(102.0), volume: Some(12000), pe: Some(20.7), pb: Some(2.2), total_mv: Some(1020000.0), }, ]; // 2. Act: Call the batch insert function let insert_result = db::batch_insert_daily_data(&pool, &daily_data).await; assert!(insert_result.is_ok()); // 3. Assert: Get all daily data and verify let fetched_all = db::get_daily_data_by_symbol(&pool, "TEST.MKT", None, None).await.unwrap(); assert_eq!(fetched_all.len(), 2); assert_eq!(fetched_all[0].trade_date, chrono::NaiveDate::from_ymd_opt(2024, 1, 2).unwrap()); // Desc order // 4. Assert: Get data within a date range let start_date = chrono::NaiveDate::from_ymd_opt(2024, 1, 2).unwrap(); let fetched_one = db::get_daily_data_by_symbol(&pool, "TEST.MKT", Some(start_date), None).await.unwrap(); assert_eq!(fetched_one.len(), 1); assert_eq!(fetched_one[0].trade_date, start_date); let close_str = fetched_one[0].close_price.unwrap().to_string(); assert!(close_str == "102.0" || close_str == "102"); } #[sqlx::test] async fn test_create_and_get_analysis_results(pool: PgPool) { // 1. Setup: Create a test analysis result DTO let new_analysis = NewAnalysisResultDto { symbol: "TEST.AI".to_string(), module_id: "bull_case".to_string(), model_name: Some("test-model-v1".to_string()), content: "This is a bullish analysis.".to_string(), meta_data: Some(serde_json::json!({ "tokens": 123 })), }; // 2. Act: Call the create function let created_result = db::create_analysis_result(&pool, &new_analysis).await.unwrap(); assert_eq!(created_result.symbol, "TEST.AI"); assert_eq!(created_result.module_id, "bull_case"); // 3. Assert: Get by symbol and module_id let fetched_by_symbol = db::get_analysis_results(&pool, "TEST.AI", Some("bull_case")).await.unwrap(); assert_eq!(fetched_by_symbol.len(), 1); assert_eq!(fetched_by_symbol[0].id, created_result.id); assert_eq!(fetched_by_symbol[0].content, new_analysis.content); // 4. Assert: Get by ID let fetched_by_id = db::get_analysis_result_by_id(&pool, created_result.id).await.unwrap().unwrap(); assert_eq!(fetched_by_id.symbol, "TEST.AI"); assert_eq!(fetched_by_id.content, new_analysis.content); // 5. Assert: Get by symbol only let fetched_all_for_symbol = db::get_analysis_results(&pool, "TEST.AI", None).await.unwrap(); assert_eq!(fetched_all_for_symbol.len(), 1); }