"""Central access point that routes data requests to market-specific providers.

The module exposes a process-wide ``DataManager`` singleton (``data_manager``)
which reads provider priorities from ``config/data_sources.yaml`` and API keys
from ``config/config.json``, then tries each configured provider in order
until one returns usable data.
"""

import json
import logging
import os
from numbers import Number
from typing import Any, Dict, List, Optional

import yaml

from app.data_providers.base import BaseDataProvider
from app.data_providers.tushare import TushareProvider
# from app.data_providers.ifind import TonghsProvider
from app.data_providers.yfinance import YfinanceProvider
from app.data_providers.finnhub import FinnhubProvider

logger = logging.getLogger(__name__)


class DataManager:
    """Singleton that dispatches data requests to the first provider that succeeds."""

    _instance = None

    # Report fields that identify a record rather than carry a metric value;
    # they are skipped when flat reports are converted to per-metric series.
    _NON_METRIC_KEYS = frozenset({
        'ts_code', 'stock_code', 'year', 'end_date', 'period',
        'ann_date', 'f_ann_date', 'report_type',
    })

    def __new__(cls, *args, **kwargs):
        # Every construction returns the same object; __init__ guards against
        # re-running the (I/O heavy) initialisation on that shared instance.
        if not cls._instance:
            cls._instance = super(DataManager, cls).__new__(cls)
        return cls._instance

    def __init__(self, config_path: Optional[str] = None):
        """Load configuration and initialise the providers (once per process).

        Args:
            config_path: Path to the data-source YAML file. When ``None`` the
                file ``config/data_sources.yaml`` under the detected repository
                root is used.

        Raises:
            OSError: If the YAML configuration file cannot be opened.
            yaml.YAMLError: If the YAML configuration file cannot be parsed.
        """
        if getattr(self, '_initialized', False):
            return

        repo_root = self._find_repo_root()
        if config_path is None:
            config_path = os.path.join(repo_root, "config", "data_sources.yaml")

        with open(config_path, 'r', encoding='utf-8') as f:
            loaded = yaml.safe_load(f)
        # An empty YAML document parses to None; normalise so that later
        # lookups never have to deal with a non-mapping configuration.
        if loaded is not None and not isinstance(loaded, dict):
            logger.warning(
                f"Unexpected top-level type in '{config_path}': {type(loaded).__name__}"
            )
        self.config = loaded if isinstance(loaded, dict) else {}

        self.providers: Dict[str, BaseDataProvider] = {}

        # Build provider base config ONLY from config/config.json (do not read env vars)
        base_cfg = self._load_api_keys(repo_root)

        try:
            self._init_providers(base_cfg)
        except Exception as e:
            # Best effort: a broken provider setup must not prevent the
            # manager (and the importing module) from loading.
            logger.error(f"Failed to initialize data providers: {e}")

        self._initialized = True

    @staticmethod
    def _find_repo_root() -> str:
        """Return the nearest ancestor directory containing ``config/data_sources.yaml``.

        Falls back to three levels above this file when no ancestor matches.
        """
        here = os.path.dirname(os.path.abspath(__file__))
        current_dir = here
        # dirname() is a fixed point only at the filesystem root.
        while current_dir != os.path.dirname(current_dir):
            if os.path.exists(os.path.join(current_dir, "config", "data_sources.yaml")):
                return current_dir
            current_dir = os.path.dirname(current_dir)
        return os.path.abspath(os.path.join(here, "..", "..", ".."))

    @staticmethod
    def _load_api_keys(repo_root: str) -> Dict[str, Any]:
        """Read provider API keys from ``config/config.json`` under *repo_root*.

        Returns:
            ``{"data_sources": {name: {"api_key": key}}}`` containing only the
            providers for which a non-empty key was found. Any read or parse
            failure is logged and yields whatever was collected so far.
        """
        base_cfg: Dict[str, Any] = {"data_sources": {}}
        cfg_json_path = os.path.join(repo_root, "config", "config.json")
        try:
            if not os.path.exists(cfg_json_path):
                logger.debug("config/config.json not found; skipping JSON token load.")
                return base_cfg

            with open(cfg_json_path, "r", encoding="utf-8") as jf:
                cfg_json = json.load(jf)

            ds_from_json = cfg_json.get("data_sources") or {}
            for name, node in ds_from_json.items():
                # A malformed entry must not prevent the remaining keys from loading.
                if not isinstance(node, dict):
                    continue
                api_key = node.get("api_key")
                if api_key:
                    base_cfg["data_sources"][name] = {"api_key": api_key}
                    logger.info(f"Loaded API key for provider '{name}' from config.json")
        except Exception as e:
            # exc_info routes the traceback through logging instead of stderr.
            logger.warning(
                f"Failed to read tokens from config/config.json: {e}", exc_info=True
            )
        return base_cfg

    def _init_providers(self, base_cfg: Dict[str, Any]) -> None:
        """
        Initializes providers with the given base configuration.
        This method should be called after the base config is loaded.

        A provider is instantiated when an API key is available for it, or when
        its YAML entry does not declare ``api_key_env`` (i.e. no key is needed).
        Failures of a single provider are logged and do not affect the others.
        """
        provider_map = {
            "tushare": TushareProvider,
            # "ifind": TonghsProvider,
            "yfinance": YfinanceProvider,
            "finnhub": FinnhubProvider,
        }

        # `or {}` also covers keys that are present in the file but null.
        yaml_sources = (self.config or {}).get('data_sources') or {}
        json_sources = base_cfg.get("data_sources") or {}

        for name, provider_class in provider_map.items():
            token = (json_sources.get(name) or {}).get("api_key")
            source_config = yaml_sources.get(name) or {}

            # Initialize the provider if a token is found or not required
            if token or not source_config.get('api_key_env'):
                try:
                    self.providers[name] = provider_class(token=token)
                except Exception as e:
                    logger.error(f"Failed to initialize provider '{name}': {e}")
            else:
                logger.warning(f"Provider '{name}' requires API key but none provided in config.json. Skipping.")

    def _detect_market(self, stock_code: str) -> str:
        """Map a stock code to a market key based on its exchange suffix."""
        if stock_code.endswith(('.SH', '.SZ')):
            return 'CN'
        elif stock_code.endswith('.HK'):
            return 'HK'
        elif stock_code.endswith('.T'):  # Assuming .T for Tokyo
            return 'JP'
        else:  # Default to US
            return 'US'

    async def get_data(self, method_name: str, stock_code: str, **kwargs):
        """Call *method_name* on each provider configured for the stock's market.

        Providers are tried in the priority order from the YAML configuration.
        The first result that is not ``None`` and not an empty list/dict is
        returned; provider exceptions are logged and the next one is tried.

        Args:
            method_name: Name of the coroutine method to invoke on the provider.
            stock_code: Stock identifier; its suffix selects the market.
            **kwargs: Extra keyword arguments forwarded to the provider method.

        Returns:
            The provider's data, or ``None`` when every provider failed.
        """
        market = self._detect_market(stock_code)
        markets_cfg = self.config.get('markets') or {}
        priority_list = (markets_cfg.get(market) or {}).get('priority') or []

        for provider_name in priority_list:
            provider = self.providers.get(provider_name)
            if not provider:
                logger.warning(f"Provider '{provider_name}' not initialized.")
                continue

            try:
                # getattr is inside the try so that a provider lacking the
                # method is treated like any other provider failure.
                method = getattr(provider, method_name)
                data = await method(stock_code=stock_code, **kwargs)

                if data is None:
                    is_success = False
                elif isinstance(data, (list, dict)):
                    is_success = len(data) > 0
                else:
                    is_success = True

                if is_success:
                    logger.info(f"Data successfully fetched from '{provider_name}' for '{stock_code}'.")
                    return data
                logger.debug(
                    f"Provider '{provider_name}' returned no data for '{stock_code}'."
                )
            except Exception as e:
                logger.warning(f"Provider '{provider_name}' failed for '{stock_code}': {e}. Trying next provider.")

        logger.error(f"All data providers failed for '{stock_code}' on method '{method_name}'.")
        return None

    async def get_financial_statements(self, stock_code: str, report_dates: List[str]) -> Dict[str, List[Dict[str, Any]]]:
        """Fetch financial statements and normalise them to series format.

        Returns:
            A mapping ``metric -> [{"year": str, "value": float}, ...]``. A dict
            returned by the provider is passed through unchanged; a list of
            flat reports is converted; anything else yields ``{}``.
        """
        data = await self.get_data('get_financial_statements', stock_code, report_dates=report_dates)
        if data is None:
            return {}

        # Normalize to series format
        if isinstance(data, dict):
            # Already in series format (e.g., tushare)
            return data
        if not isinstance(data, list):
            return {}

        # Convert from flat format to series format
        series: Dict[str, List[Dict[str, Any]]] = {}
        seen_years: Dict[str, set] = {}
        for report in data:
            # Resolve the year lazily: only consult 'end_date' when 'year' is
            # missing or null, and tolerate a null or non-string end date.
            year_value = report.get('year')
            if year_value is None:
                year_value = str(report.get('end_date') or '')[:4]
            year = str(year_value)
            if not year:
                continue

            for key, value in report.items():
                if key in self._NON_METRIC_KEYS:
                    continue
                # Accept numpy/pandas numeric types as well as builtin numbers
                if value is None or not isinstance(value, Number):
                    continue

                points = series.setdefault(key, [])
                years = seen_years.setdefault(key, set())
                if year in years:
                    # Keep only the first value seen for a metric in a given year.
                    continue
                # Store as builtin float to avoid JSON serialization issues
                try:
                    numeric_value = float(value)
                except (TypeError, ValueError, OverflowError):
                    # Fallback: skip if cannot coerce to float
                    continue
                points.append({"year": year, "value": numeric_value})
                years.add(year)
        return series

    async def get_daily_price(self, stock_code: str, start_date: str, end_date: str) -> Optional[List[Dict[str, Any]]]:
        """Fetch daily prices for the date range; ``None`` when every provider failed."""
        return await self.get_data('get_daily_price', stock_code, start_date=start_date, end_date=end_date)

    async def get_stock_basic(self, stock_code: str) -> Optional[Dict[str, Any]]:
        """Fetch basic information for the stock; ``None`` when every provider failed."""
        return await self.get_data('get_stock_basic', stock_code)


data_manager = DataManager()