Fundamental_Analysis/backend/app/data_manager.py

195 lines
8.4 KiB
Python

import json
import logging
import math
import os
from numbers import Number
from typing import Any, Dict, List, Optional

import yaml

from app.data_providers.base import BaseDataProvider
from app.data_providers.tushare import TushareProvider
# from app.data_providers.ifind import TonghsProvider
from app.data_providers.yfinance import YfinanceProvider
from app.data_providers.finnhub import FinnhubProvider
logger = logging.getLogger(__name__)
class DataManager:
    """Singleton facade that routes data requests to configured providers.

    On first construction it loads ``config/data_sources.yaml`` (market ->
    provider priority plus per-provider settings) and optional API keys from
    ``config/config.json``, then instantiates every provider that is usable.
    Requests are served by trying the providers listed under
    ``markets.<market>.priority`` in order until one returns non-empty data.
    """

    _instance = None

    # Keys of a flat financial report that identify the report rather than
    # carry a metric; they are never turned into a series.
    _NON_METRIC_KEYS = frozenset({
        'ts_code', 'stock_code', 'year', 'end_date',
        'period', 'ann_date', 'f_ann_date', 'report_type',
    })

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it on the first call."""
        if cls._instance is None:
            cls._instance = super(DataManager, cls).__new__(cls)
        return cls._instance

    def __init__(self, config_path: Optional[str] = None):
        """Load configuration and initialize providers (first call only).

        Args:
            config_path: Path to the data-sources YAML file. When omitted,
                ``<repo root>/config/data_sources.yaml`` is used.

        Raises:
            OSError: If the YAML configuration file cannot be opened.
        """
        # __init__ runs on every DataManager() call; only the first one
        # that completes successfully does any work.
        if getattr(self, '_initialized', False):
            return

        repo_root = self._find_repo_root()
        if config_path is None:
            config_path = os.path.join(repo_root, "config", "data_sources.yaml")

        with open(config_path, 'r', encoding='utf-8') as f:
            # An empty YAML document parses to None; normalize to a dict so
            # later ``.get`` lookups do not fail.
            self.config = yaml.safe_load(f) or {}

        self.providers: Dict[str, Any] = {}

        # Provider API keys come ONLY from config/config.json (env vars are
        # deliberately not read).
        base_cfg = self._load_api_keys(
            os.path.join(repo_root, "config", "config.json")
        )

        try:
            self._init_providers(base_cfg)
        except Exception as e:
            # Best effort: a broken provider setup must not prevent the
            # manager itself from being constructed.
            logger.error(f"Failed to initialize data providers: {e}")

        self._initialized = True

    @staticmethod
    def _find_repo_root() -> str:
        """Locate the directory that contains ``config/data_sources.yaml``.

        Walks upward from this file's directory. If no ancestor contains the
        file, falls back to three levels above this file's directory.
        """
        start_dir = os.path.dirname(os.path.abspath(__file__))
        current_dir = start_dir
        # dirname(x) == x only at the filesystem root.
        while current_dir != os.path.dirname(current_dir):
            if os.path.exists(os.path.join(current_dir, "config", "data_sources.yaml")):
                return current_dir
            current_dir = os.path.dirname(current_dir)
        return os.path.abspath(os.path.join(start_dir, "..", "..", ".."))

    @staticmethod
    def _load_api_keys(cfg_json_path: str) -> Dict[str, Any]:
        """Read provider API keys from ``config.json``.

        Args:
            cfg_json_path: Path to the JSON configuration file.

        Returns:
            ``{"data_sources": {name: {"api_key": key}}}`` containing only
            the providers that have a non-empty ``api_key``. Read or parse
            failures are logged and yield whatever was collected so far.
        """
        base_cfg: Dict[str, Any] = {"data_sources": {}}
        if not os.path.exists(cfg_json_path):
            logger.debug("config/config.json not found; skipping JSON token load.")
            return base_cfg

        try:
            with open(cfg_json_path, "r", encoding="utf-8") as jf:
                cfg_json = json.load(jf)
            ds_from_json = cfg_json.get("data_sources") or {}
            for name, node in ds_from_json.items():
                # A malformed entry must not stop the remaining providers
                # from getting their keys.
                if not isinstance(node, dict):
                    continue
                api_key = node.get("api_key")
                if api_key:
                    base_cfg["data_sources"][name] = {"api_key": api_key}
                    logger.info(f"Loaded API key for provider '{name}' from config.json")
        except Exception as e:
            # exc_info routes the traceback through logging instead of
            # printing it straight to stderr.
            logger.warning(
                f"Failed to read tokens from config/config.json: {e}",
                exc_info=True,
            )
        return base_cfg

    def _init_providers(self, base_cfg: Dict[str, Any]) -> None:
        """Instantiate every known provider that can be used.

        A provider is created when an API key was found for it in
        ``base_cfg`` or when its YAML entry has no ``api_key_env`` setting
        (i.e. it needs no key). Providers that fail to construct are logged
        and skipped.

        Args:
            base_cfg: Mapping shaped like
                ``{"data_sources": {name: {"api_key": ...}}}``.
        """
        provider_map = {
            "tushare": TushareProvider,
            # "ifind": TonghsProvider,
            "yfinance": YfinanceProvider,
            "finnhub": FinnhubProvider,
        }
        key_cfg = base_cfg.get("data_sources") or {}
        # Tolerate a YAML file with no (or a null) 'data_sources' section so
        # key-less providers can still be created.
        source_cfgs = self.config.get('data_sources') or {}

        for name, provider_class in provider_map.items():
            token = (key_cfg.get(name) or {}).get("api_key")
            source_config = source_cfgs.get(name) or {}

            if not token and source_config.get('api_key_env'):
                logger.warning(f"Provider '{name}' requires API key but none provided in config.json. Skipping.")
                continue

            try:
                self.providers[name] = provider_class(token=token)
            except Exception as e:
                logger.error(f"Failed to initialize provider '{name}': {e}")

    def _detect_market(self, stock_code: str) -> str:
        """Map a stock code to a market key based on its suffix.

        ``.SH``/``.SZ`` -> ``'CN'``, ``.HK`` -> ``'HK'``, ``.T`` -> ``'JP'``;
        every other code is treated as ``'US'``.
        """
        if stock_code.endswith(('.SH', '.SZ')):
            return 'CN'
        if stock_code.endswith('.HK'):
            return 'HK'
        if stock_code.endswith('.T'):  # Assuming .T for Tokyo
            return 'JP'
        return 'US'  # Default to US

    async def get_data(self, method_name: str, stock_code: str, **kwargs) -> Any:
        """Call ``method_name`` on providers in priority order.

        Args:
            method_name: Name of the coroutine method to invoke on a provider.
            stock_code: Code used both for market detection and as the
                ``stock_code`` argument of the provider call.
            **kwargs: Extra keyword arguments forwarded to the provider.

        Returns:
            The first result that is not ``None`` and not an empty list or
            dict, or ``None`` when every provider failed or returned nothing.
        """
        market = self._detect_market(stock_code)
        market_cfg = (self.config.get('markets') or {}).get(market) or {}
        priority_list = market_cfg.get('priority') or []

        for provider_name in priority_list:
            provider = self.providers.get(provider_name)
            if not provider:
                logger.warning(f"Provider '{provider_name}' not initialized.")
                continue
            try:
                method = getattr(provider, method_name)
                data = await method(stock_code=stock_code, **kwargs)
            except Exception as e:
                # Any provider failure (including a missing method) simply
                # moves on to the next provider in the list.
                logger.warning(f"Provider '{provider_name}' failed for '{stock_code}': {e}. Trying next provider.")
                continue

            # None and empty list/dict count as "no data"; any other value
            # (including other falsy values) counts as success.
            if data is None or (isinstance(data, (list, dict)) and not data):
                logger.debug(f"Provider '{provider_name}' returned no data for '{stock_code}'.")
                continue

            logger.info(f"Data successfully fetched from '{provider_name}' for '{stock_code}'.")
            return data

        logger.error(f"All data providers failed for '{stock_code}' on method '{method_name}'.")
        return None

    @staticmethod
    def _report_year(report: Dict[str, Any]) -> str:
        """Return the report's year as a string, or ``''`` if unknown.

        Uses ``report['year']`` when set; otherwise the first four characters
        of ``report['end_date']`` (which may be a string or an integer).
        """
        year = report.get('year')
        if year is None or year == '':
            end_date = report.get('end_date')
            year = '' if end_date is None else str(end_date)[:4]
        return str(year)

    @staticmethod
    def _to_finite_float(value: Any) -> Optional[float]:
        """Coerce a numeric value to a builtin float, or return ``None``.

        Non-numbers, values that cannot be converted, and NaN/infinity are
        rejected; the latter are not representable in strict JSON.
        """
        if not isinstance(value, Number):
            return None
        try:
            number = float(value)
        except (TypeError, ValueError, OverflowError):
            return None
        return number if math.isfinite(number) else None

    async def get_financial_statements(self, stock_code: str, report_dates: List[str]) -> Dict[str, List[Dict[str, Any]]]:
        """Fetch financial statements normalized to per-metric series.

        Args:
            stock_code: Code of the stock to query.
            report_dates: Report dates forwarded to the provider.

        Returns:
            ``{metric: [{"year": "YYYY", "value": float}, ...]}``. A dict
            returned by a provider is passed through unchanged; a list of
            flat reports is converted, keeping the first value seen for each
            (metric, year) pair. ``{}`` when no data is available.
        """
        data = await self.get_data('get_financial_statements', stock_code, report_dates=report_dates)

        if isinstance(data, dict):
            # Already in series format (e.g., tushare)
            return data
        if not isinstance(data, list):
            return {}

        # Convert from flat format to series format
        series: Dict[str, List[Dict[str, Any]]] = {}
        seen = set()  # (metric, year) pairs already emitted
        for report in data:
            if not isinstance(report, dict):
                continue
            year = self._report_year(report)
            if not year:
                continue
            for key, value in report.items():
                if key in self._NON_METRIC_KEYS or (key, year) in seen:
                    continue
                numeric_value = self._to_finite_float(value)
                if numeric_value is None:
                    continue
                seen.add((key, year))
                series.setdefault(key, []).append({"year": year, "value": numeric_value})
        return series

    async def get_daily_price(self, stock_code: str, start_date: str, end_date: str) -> Optional[List[Dict[str, Any]]]:
        """Fetch daily prices via the provider chain.

        Returns whatever the first successful provider returns, or ``None``
        when every provider failed or returned no data.
        """
        return await self.get_data('get_daily_price', stock_code, start_date=start_date, end_date=end_date)

    async def get_stock_basic(self, stock_code: str) -> Optional[Dict[str, Any]]:
        """Fetch basic stock information via the provider chain.

        Returns the first successful provider's result, or ``None`` when
        every provider failed or returned no data.
        """
        return await self.get_data('get_stock_basic', stock_code)
# Shared module-level instance. It is constructed at import time, so importing
# this module reads the configuration files and initializes the providers.
data_manager = DataManager()