# Last commit: feat(frontend): integrate Prisma and reports API/pages;
#   chore(config): add data_sources.yaml; update analysis-config.json;
#   docs: add 2025-11-03 dev log; update user guide;
#   scripts: enhance dev.sh; add tushare_legacy_client;
#   deps: update backend and frontend dependencies
# File stats (from source viewer): 165 lines, 7.3 KiB, Python
import yaml
|
|
import os
|
|
import json
|
|
from typing import Any, Dict, List, Optional
|
|
from app.data_providers.base import BaseDataProvider
|
|
from app.data_providers.tushare import TushareProvider
|
|
# from app.data_providers.ifind import TonghsProvider
|
|
from app.data_providers.yfinance import YfinanceProvider
|
|
from app.data_providers.finnhub import FinnhubProvider
|
|
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class DataManager:
    """Singleton facade that routes market-data requests to configured providers.

    ``__new__`` always returns the same instance and ``__init__`` performs its
    setup only once, so later ``DataManager(...)`` calls (even with a different
    ``config_path``) return the already-initialized object unchanged.

    Configuration is read from a YAML file (default
    ``<repo root>/config/data_sources.yaml``):

    * ``data_sources.<name>.api_key_env`` names the environment variable that
      holds that provider's API key;
    * ``markets.<MARKET>.priority`` lists provider names to try, in order, for
      tickers of that market (see ``_detect_market``).

    API keys not found in the environment fall back to
    ``data_sources.<name>.api_key`` in ``<repo root>/config/config.json``.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        if not cls._instance:
            cls._instance = super(DataManager, cls).__new__(cls)
        return cls._instance

    def __init__(self, config_path: Optional[str] = None):
        """Load the YAML config, resolve API keys and initialize providers.

        Runs only on the first construction of the singleton.

        Args:
            config_path: Path to the data-sources YAML file. Defaults to
                ``<repo root>/config/data_sources.yaml``.

        Raises:
            OSError: If the YAML config file cannot be opened. YAML parse
                errors from ``yaml.safe_load`` also propagate.
        """
        if getattr(self, "_initialized", False):
            return

        repo_root = self._find_repo_root()
        if config_path is None:
            config_path = os.path.join(repo_root, "config", "data_sources.yaml")

        with open(config_path, "r", encoding="utf-8") as f:
            # safe_load returns None for an empty document; normalize to {} so
            # the .get() lookups elsewhere cannot hit AttributeError.
            self.config = yaml.safe_load(f) or {}

        self.providers: Dict[str, Any] = {}

        # Resolve API keys: environment variables first, then config/config.json.
        base_cfg: Dict[str, Any] = {"data_sources": {}}
        self._load_keys_from_env(base_cfg)
        self._load_keys_from_json(
            base_cfg, os.path.join(repo_root, "config", "config.json")
        )

        try:
            self._init_providers(base_cfg)
        except Exception as e:
            # Best effort: keep the manager usable with whatever providers were
            # created instead of propagating out of the constructor.
            logger.exception("Failed to initialize data providers: %s", e)

        self._initialized = True

    @staticmethod
    def _find_repo_root(start_dir: Optional[str] = None) -> str:
        """Return the nearest ancestor directory containing ``config/data_sources.yaml``.

        The search walks upward from ``start_dir`` (default: this module's
        directory). If no ancestor qualifies, it falls back to the directory
        three levels above this module.
        """
        module_dir = os.path.dirname(os.path.abspath(__file__))
        current = os.path.abspath(start_dir) if start_dir else module_dir
        # Use absolute paths so the walk really climbs to the filesystem root,
        # even when __file__ is relative.
        while current != os.path.dirname(current):
            if os.path.exists(os.path.join(current, "config", "data_sources.yaml")):
                return current
            current = os.path.dirname(current)
        return os.path.abspath(os.path.join(module_dir, "..", "..", ".."))

    def _data_sources(self) -> Dict[str, Any]:
        """Return the YAML ``data_sources`` mapping, or ``{}`` if it is absent/null."""
        return self.config.get("data_sources") or {}

    def _load_keys_from_env(self, base_cfg: Dict[str, Any]) -> None:
        """Store in ``base_cfg`` the API keys found in each provider's ``api_key_env`` variable."""
        for name, source_config in self._data_sources().items():
            # A provider entry may be null in YAML; treat it as "no settings".
            env_var = (source_config or {}).get("api_key_env")
            if not env_var:
                continue
            api_key = os.getenv(env_var)
            if api_key:
                base_cfg["data_sources"][name] = {"api_key": api_key}
            else:
                logger.warning(
                    "Env var '%s' for provider '%s' not set; will try config.json.",
                    env_var,
                    name,
                )

    @staticmethod
    def _load_keys_from_json(base_cfg: Dict[str, Any], cfg_json_path: str) -> None:
        """Fill API keys still missing from ``base_cfg`` using ``data_sources`` in config.json.

        Keys already taken from the environment are never overridden. This step
        is best effort: read/parse errors are logged with a traceback and ignored.
        """
        if not os.path.exists(cfg_json_path):
            logger.debug("config/config.json not found; skipping JSON token load.")
            return
        try:
            with open(cfg_json_path, "r", encoding="utf-8") as jf:
                cfg_json = json.load(jf)
            for name, node in (cfg_json.get("data_sources") or {}).items():
                # Skip malformed (non-object) entries instead of aborting the loop.
                api_key = node.get("api_key") if isinstance(node, dict) else None
                if api_key and name not in base_cfg["data_sources"]:
                    base_cfg["data_sources"][name] = {"api_key": api_key}
                    logger.info("Loaded API key for provider '%s' from config.json", name)
        except Exception as e:
            logger.warning(
                "Failed to read tokens from config/config.json: %s", e, exc_info=True
            )

    def _init_providers(self, base_cfg: Dict[str, Any]) -> None:
        """Instantiate every known provider whose token requirement is satisfied.

        A provider whose YAML entry declares ``api_key_env`` is created only when
        ``base_cfg`` holds a key for it. Providers without that setting are
        created with ``token=None``. A failure in one provider's constructor is
        logged and does not stop the others.
        """
        provider_map = {
            "tushare": TushareProvider,
            # "ifind": TonghsProvider,
            "yfinance": YfinanceProvider,
            "finnhub": FinnhubProvider,
        }
        sources = self._data_sources()
        resolved_keys = base_cfg.get("data_sources") or {}

        for name, provider_class in provider_map.items():
            source_config = sources.get(name) or {}
            env_var = source_config.get("api_key_env")
            token = (resolved_keys.get(name) or {}).get("api_key") if env_var else None

            if env_var and not token:
                logger.warning(
                    "Provider '%s' requires token env '%s', but none provided. Skipping.",
                    name,
                    env_var,
                )
                continue

            try:
                self.providers[name] = provider_class(token=token)
            except Exception as e:
                logger.error(
                    "Failed to initialize provider '%s': %s", name, e, exc_info=True
                )

    def _detect_market(self, stock_code: str) -> str:
        """Infer the market code from the ticker suffix (case-sensitive).

        ``.SH``/``.SZ`` -> ``'CN'``, ``.HK`` -> ``'HK'``, ``.T`` -> ``'JP'``;
        any other code defaults to ``'US'``.
        """
        if stock_code.endswith((".SH", ".SZ")):
            return "CN"
        if stock_code.endswith(".HK"):
            return "HK"
        if stock_code.endswith(".T"):  # Assuming .T for Tokyo
            return "JP"
        return "US"

    async def get_data(self, method_name: str, stock_code: str, **kwargs):
        """Call ``method_name`` on each provider in the market's priority order.

        Providers that are not initialized, raise, return ``None`` or return an
        empty list are skipped. The first other result is returned.

        Args:
            method_name: Name of the async provider method to await.
            stock_code: Ticker; its suffix selects the market (see ``_detect_market``).
            **kwargs: Extra keyword arguments forwarded to the provider method.

        Returns:
            The first usable result, or ``None`` if no provider produced one.
        """
        market = self._detect_market(stock_code)
        # Any level of the markets config may be missing or null in YAML.
        market_cfg = (self.config.get("markets") or {}).get(market) or {}
        priority_list = market_cfg.get("priority") or []
        if not priority_list:
            logger.warning("No provider priority configured for market '%s'.", market)

        for provider_name in priority_list:
            provider = self.providers.get(provider_name)
            if not provider:
                logger.warning("Provider '%s' not initialized.", provider_name)
                continue

            try:
                method = getattr(provider, method_name)
                data = await method(stock_code=stock_code, **kwargs)
            except Exception as e:
                # Provider boundary: any failure just moves on to the next provider.
                logger.warning(
                    "Provider '%s' failed for '%s': %s. Trying next provider.",
                    provider_name,
                    stock_code,
                    e,
                )
                continue

            # An empty list counts as "no data"; any other non-None value wins.
            if data is not None and (not isinstance(data, list) or data):
                logger.info(
                    "Data successfully fetched from '%s' for '%s'.",
                    provider_name,
                    stock_code,
                )
                return data
            logger.debug(
                "Provider '%s' returned no data for '%s'; trying next provider.",
                provider_name,
                stock_code,
            )

        logger.error(
            "All data providers failed for '%s' on method '%s'.", stock_code, method_name
        )
        return None

    async def get_financial_statements(
        self, stock_code: str, report_dates: List[str]
    ) -> Optional[List[Dict[str, Any]]]:
        """Fetch financial statements for ``report_dates``; ``None`` if no provider succeeds."""
        return await self.get_data(
            "get_financial_statements", stock_code, report_dates=report_dates
        )

    async def get_daily_price(
        self, stock_code: str, start_date: str, end_date: str
    ) -> Optional[List[Dict[str, Any]]]:
        """Fetch daily prices between ``start_date`` and ``end_date``; ``None`` if no provider succeeds."""
        return await self.get_data(
            "get_daily_price", stock_code, start_date=start_date, end_date=end_date
        )

    async def get_stock_basic(self, stock_code: str) -> Optional[Dict[str, Any]]:
        """Fetch basic stock information; ``None`` if no provider succeeds."""
        return await self.get_data("get_stock_basic", stock_code)
# Module-level shared instance, created at import time. DataManager is a
# singleton, so later DataManager() calls return this same object. Constructing
# it opens config/data_sources.yaml without catching errors, so importing this
# module raises (e.g. OSError from open()) if that file is missing.
data_manager = DataManager()
|