Fundamental_Analysis/backend/app/data_manager.py
xucheng 3475138419 feat(数据): 新增员工、股东及税务指标并生成日志
- 后端: Tushare provider 新增 get_employee_number, get_holder_number, get_tax_to_ebt 方法,并在 financial 路由中集成。
- 前端: report 页面新增对应图表展示,并更新相关类型与工具函数。
- 清理: 移除多个过时的测试脚本。
- 文档: 创建 2025-11-04 开发日志并更新用户手册。
2025-11-04 21:22:32 +08:00

190 lines
8.4 KiB
Python

import yaml
import os
import json
from typing import Any, Dict, List, Optional
from app.data_providers.base import BaseDataProvider
from app.data_providers.tushare import TushareProvider
# from app.data_providers.ifind import TonghsProvider
from app.data_providers.yfinance import YfinanceProvider
from app.data_providers.finnhub import FinnhubProvider
import logging
# Module-level logger, named after this module, used by DataManager for all diagnostics.
logger = logging.getLogger(__name__)
class DataManager:
    """Singleton that routes data requests to per-market data providers.

    On first construction it loads ``config/data_sources.yaml``, collects API
    tokens (environment variables first, ``config/config.json`` as fallback)
    and instantiates every known provider for which a token is available or
    not required.  Requests are then dispatched to providers in the priority
    order configured for the market inferred from the stock code suffix.
    """

    _instance = None

    # Keys of a flat report row that describe the row itself rather than a metric.
    _NON_METRIC_KEYS = frozenset({
        'ts_code', 'stock_code', 'year', 'end_date', 'period',
        'ann_date', 'f_ann_date', 'report_type',
    })

    def __new__(cls, *args, **kwargs):
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, config_path: Optional[str] = None):
        """Load configuration and initialize providers (only once per process).

        Args:
            config_path: Path to the data-sources YAML file.  When omitted,
                ``config/data_sources.yaml`` under the detected repo root is used.

        Raises:
            OSError: If the YAML configuration file cannot be opened.
        """
        # __init__ runs on every DataManager() call; skip work after the first success.
        if getattr(self, '_initialized', False):
            return

        repo_root = self._find_repo_root()
        if config_path is None:
            config_path = os.path.join(repo_root, "config", "data_sources.yaml")

        with open(config_path, 'r', encoding='utf-8') as f:
            # An empty YAML document parses to None; normalise so .get() is always safe.
            self.config = yaml.safe_load(f) or {}

        self.providers = {}

        # Build provider base config: env vars take precedence, config.json fills gaps.
        base_cfg: Dict[str, Any] = {"data_sources": {}}
        self._load_env_tokens(base_cfg)
        self._load_json_tokens(base_cfg, repo_root)

        try:
            self._init_providers(base_cfg)
        except Exception as e:
            # Best effort: the manager stays usable (with fewer providers) on failure.
            logger.error(f"Failed to initialize data providers: {e}")

        self._initialized = True

    @staticmethod
    def _find_repo_root() -> str:
        """Locate the repo root by walking up until ``config/data_sources.yaml`` is found.

        Falls back to three directories above this file when no ancestor
        contains the configuration file.
        """
        here = os.path.dirname(os.path.abspath(__file__))
        current_dir = here
        while current_dir != os.path.dirname(current_dir):  # Not at filesystem root
            if os.path.exists(os.path.join(current_dir, "config", "data_sources.yaml")):
                return current_dir
            current_dir = os.path.dirname(current_dir)
        return os.path.abspath(os.path.join(here, "..", "..", ".."))

    def _load_env_tokens(self, base_cfg: Dict[str, Any]) -> None:
        """Copy API keys from the env vars named by ``api_key_env`` into ``base_cfg``."""
        for name, source_config in (self.config.get('data_sources') or {}).items():
            # A source declared with no body in YAML is None; treat it as empty.
            env_var = (source_config or {}).get('api_key_env')
            if not env_var:
                continue
            api_key = os.getenv(env_var)
            if api_key:
                base_cfg["data_sources"][name] = {"api_key": api_key}
            else:
                logger.warning(f"Env var '{env_var}' for provider '{name}' not set; will try config.json.")

    def _load_json_tokens(self, base_cfg: Dict[str, Any], repo_root: str) -> None:
        """Fill in API keys missing from ``base_cfg`` using ``config/config.json``.

        Keys already obtained from the environment are never overwritten.  Any
        failure while reading the file is logged and otherwise ignored.
        """
        cfg_json_path = os.path.join(repo_root, "config", "config.json")
        try:
            if not os.path.exists(cfg_json_path):
                logger.debug("config/config.json not found; skipping JSON token load.")
                return
            with open(cfg_json_path, "r", encoding="utf-8") as jf:
                cfg_json = json.load(jf)
            ds_from_json = cfg_json.get("data_sources") or {}
            for name, node in ds_from_json.items():
                if not isinstance(node, dict):
                    continue
                api_key = node.get("api_key")
                if api_key and name not in base_cfg["data_sources"]:
                    base_cfg["data_sources"][name] = {"api_key": api_key}
                    logger.info(f"Loaded API key for provider '{name}' from config.json")
        except Exception as e:
            # exc_info routes the traceback through logging instead of raw stderr.
            logger.warning(f"Failed to read tokens from config/config.json: {e}", exc_info=True)

    def _init_providers(self, base_cfg: Dict[str, Any]) -> None:
        """Instantiate every known provider whose token requirement is satisfied.

        Args:
            base_cfg: Mapping of the form ``{"data_sources": {name: {"api_key": ...}}}``
                as assembled by ``__init__``.

        A provider whose YAML entry declares ``api_key_env`` is skipped when no
        token was resolved for it; providers without that declaration are
        created with ``token=None``.  A constructor failure is logged and does
        not prevent the remaining providers from being initialized.
        """
        provider_map = {
            "tushare": TushareProvider,
            # "ifind": TonghsProvider,
            "yfinance": YfinanceProvider,
            "finnhub": FinnhubProvider,
        }
        configured = self.config.get('data_sources') or {}
        tokens = base_cfg.get("data_sources") or {}

        for name, provider_class in provider_map.items():
            source_config = configured.get(name) or {}
            env_var = source_config.get('api_key_env')
            token = (tokens.get(name) or {}).get("api_key") if env_var else None

            if env_var and not token:
                logger.warning(f"Provider '{name}' requires token env '{env_var}', but none provided. Skipping.")
                continue

            try:
                self.providers[name] = provider_class(token=token)
            except Exception as e:
                logger.error(f"Failed to initialize provider '{name}': {e}")

    def _detect_market(self, stock_code: str) -> str:
        """Infer the market code ('CN', 'HK', 'JP' or 'US') from the ticker suffix.

        Matching is case-insensitive; anything without a recognised suffix is
        treated as a US ticker.
        """
        code = stock_code.upper()
        if code.endswith(('.SH', '.SZ')):
            return 'CN'
        if code.endswith('.HK'):
            return 'HK'
        if code.endswith('.T'):  # Assuming .T for Tokyo
            return 'JP'
        return 'US'  # Default to US

    async def get_data(self, method_name: str, stock_code: str, **kwargs):
        """Call ``method_name`` on each provider in priority order until one yields data.

        Args:
            method_name: Name of the coroutine method to invoke on the provider.
            stock_code: Ticker; also determines which market's priority list is used.
            **kwargs: Extra keyword arguments forwarded to the provider method.

        Returns:
            The first result that is not ``None`` and not an empty list, or
            ``None`` when every provider is missing, fails or returns nothing.
        """
        market = self._detect_market(stock_code)
        markets_cfg = self.config.get('markets') or {}
        priority_list = (markets_cfg.get(market) or {}).get('priority') or []

        for provider_name in priority_list:
            provider = self.providers.get(provider_name)
            if not provider:
                logger.warning(f"Provider '{provider_name}' not initialized.")
                continue
            try:
                # getattr stays inside the try so a provider lacking the method
                # is handled like any other provider failure.
                method = getattr(provider, method_name)
                data = await method(stock_code=stock_code, **kwargs)
                # NOTE(review): only empty *lists* trigger the fallback; an empty
                # dict is returned as a successful result - confirm this is intended.
                if data is not None and (not isinstance(data, list) or data):
                    logger.info(f"Data successfully fetched from '{provider_name}' for '{stock_code}'.")
                    return data
            except Exception as e:
                logger.warning(f"Provider '{provider_name}' failed for '{stock_code}': {e}. Trying next provider.")

        logger.error(f"All data providers failed for '{stock_code}' on method '{method_name}'.")
        return None

    @staticmethod
    def _report_year(report: Dict[str, Any]) -> str:
        """Return the report's year as a string, or '' when it cannot be determined.

        Uses ``year`` when present and non-empty, otherwise the first four
        characters of ``end_date``.
        """
        year = report.get('year')
        if year is None or year == '':
            end_date = report.get('end_date')
            # str() tolerates numeric end dates such as 20231231.
            return '' if end_date is None else str(end_date)[:4]
        return str(year)

    async def get_financial_statements(self, stock_code: str, report_dates: List[str]) -> Dict[str, List[Dict[str, Any]]]:
        """Fetch financial statements and normalise them to series format.

        Args:
            stock_code: Ticker to query.
            report_dates: Report dates forwarded to the provider.

        Returns:
            ``{metric: [{"year": str, "value": number}, ...]}``.  A dict result
            from the provider is returned unchanged; a list of flat report rows
            is pivoted per metric, keeping the first value seen for each year.
            An empty dict is returned when no data is available.
        """
        data = await self.get_data('get_financial_statements', stock_code, report_dates=report_dates)
        if data is None:
            return {}

        # Already in series format (e.g., tushare)
        if isinstance(data, dict):
            return data
        if not isinstance(data, list):
            return {}

        # Convert from flat format to series format
        series: Dict[str, List[Dict[str, Any]]] = {}
        seen_years: Dict[str, set] = {}
        for report in data:
            if not isinstance(report, dict):
                continue
            year = self._report_year(report)
            if not year:
                continue
            for key, value in report.items():
                if key in self._NON_METRIC_KEYS:
                    continue
                # bool is a subclass of int but is never a financial metric.
                if isinstance(value, bool) or not isinstance(value, (int, float)):
                    continue
                years = seen_years.setdefault(key, set())
                if year in years:
                    continue
                years.add(year)
                series.setdefault(key, []).append({"year": year, "value": value})
        return series

    async def get_daily_price(self, stock_code: str, start_date: str, end_date: str) -> Optional[List[Dict[str, Any]]]:
        """Fetch daily prices for ``stock_code`` between the two dates.

        Returns whatever the first successful provider returns, or ``None``
        when no provider supplies data.
        """
        return await self.get_data('get_daily_price', stock_code, start_date=start_date, end_date=end_date)

    async def get_stock_basic(self, stock_code: str) -> Optional[Dict[str, Any]]:
        """Fetch basic stock information, or ``None`` when no provider supplies it."""
        return await self.get_data('get_stock_basic', stock_code)
# Shared module-level instance, created at import time. Constructing it reads the
# data-sources YAML file and initializes the providers, so importing this module
# performs file I/O. DataManager is a singleton, so later DataManager() calls
# return this same object.
data_manager = DataManager()