"""
数据获取服务基础架构

处理外部数据源的数据获取

Data-fetching service infrastructure: fetchers that retrieve financial and
market data from external data sources (Tushare Pro, Yahoo Finance).
"""

import asyncio
import time
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from typing import Any, Dict, Optional, Tuple

import httpx

from ..schemas.data import (
    FinancialDataResponse,
    MarketDataResponse,
    SymbolValidationResponse,
    DataSourceStatus,
)
from ..core.exceptions import (
    DataSourceError,
    APIError,
    SymbolNotFoundError,
    RateLimitError,
    AuthenticationError,
)

class DataFetcher(ABC):
    """Abstract base class for data-source fetchers.

    Subclasses implement the three data operations plus ``_health_check``;
    the base class supplies status reporting and a retry helper.

    Recognized ``config`` keys (with defaults): ``name`` ("unknown"),
    ``timeout`` (30), ``max_retries`` (3) and ``retry_delay`` (1).
    """

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.name = config.get("name", "unknown")
        self.timeout = config.get("timeout", 30)
        self.max_retries = config.get("max_retries", 3)
        self.retry_delay = config.get("retry_delay", 1)

    @abstractmethod
    async def fetch_financial_data(self, symbol: str, market: str) -> FinancialDataResponse:
        """Fetch financial statements and key metrics for ``symbol`` in ``market``."""

    @abstractmethod
    async def fetch_market_data(self, symbol: str, market: str) -> MarketDataResponse:
        """Fetch price, volume and technical-indicator data for ``symbol`` in ``market``."""

    @abstractmethod
    async def validate_symbol(self, symbol: str, market: str) -> SymbolValidationResponse:
        """Check whether ``symbol`` is a valid security code in ``market``."""

    async def check_status(self) -> DataSourceStatus:
        """Run ``_health_check`` and report the data source's availability.

        Any ``Exception`` raised by the health check is reported as
        ``is_available=False`` with its message rather than propagated.
        The response time is measured with a monotonic clock so wall-clock
        adjustments cannot skew it.
        """
        started = time.perf_counter()
        try:
            await self._health_check()
        except Exception as e:
            return DataSourceStatus(
                name=self.name,
                is_available=False,
                last_check=datetime.now(),
                error_message=str(e),
            )

        response_time = int((time.perf_counter() - started) * 1000)
        return DataSourceStatus(
            name=self.name,
            is_available=True,
            last_check=datetime.now(),
            response_time_ms=response_time,
        )

    @abstractmethod
    async def _health_check(self):
        """Raise an exception if the data source is unreachable or unhealthy."""

    async def _retry_request(self, func, *args, **kwargs):
        """Await ``func(*args, **kwargs)``, retrying transient network errors.

        Only ``httpx.TimeoutException`` and ``httpx.ConnectError`` are retried,
        with exponential backoff (``retry_delay * 2**attempt`` seconds between
        attempts). Any other exception propagates immediately. At least one
        attempt is always made, even if ``max_retries`` is 0 or negative.

        Raises:
            DataSourceError: when every attempt failed with a retryable error;
                the last such error is chained as ``__cause__``.
        """
        attempts = max(1, int(self.max_retries))
        last_exception = None

        for attempt in range(attempts):
            try:
                return await func(*args, **kwargs)
            except (httpx.TimeoutException, httpx.ConnectError) as e:
                last_exception = e
                if attempt < attempts - 1:
                    # Exponential backoff before the next attempt.
                    await asyncio.sleep(self.retry_delay * (2 ** attempt))

        # Every attempt failed with a retryable network error.
        raise DataSourceError(
            f"数据源 {self.name} 请求失败,已重试 {attempts} 次",
            self.name,
            {"last_error": str(last_exception)},
        ) from last_exception


class TushareDataFetcher(DataFetcher):
    """Data fetcher backed by the Tushare Pro HTTP API.

    Every request is a JSON POST to ``base_url`` that carries the API name,
    the token and the query parameters. Successful payloads are read as
    ``{"fields": [...], "items": [...]}`` tables and reduced to plain dicts.

    Extra optional ``config`` keys:
        financial_period: ``YYYYMMDD`` report period for the statement APIs.
            Defaults to the latest annual report expected to be published.
        market_lookback_days: calendar days of daily bars to request
            (default 365).
    """

    # Error-message substrings that mark a throttling response. They are
    # checked BEFORE the auth markers because Tushare quota messages
    # (e.g. "抱歉，您每分钟最多访问该接口…次，权限的具体详情访问…") can also
    # contain "权限".
    # NOTE(review): markers are based on observed Tushare wording; confirm.
    _RATE_LIMIT_MARKERS = ("最多访问", "频率", "limit")
    # Error-message substrings that mark a permission or token problem.
    _AUTH_MARKERS = ("权限", "token")

    # Columns kept from each Tushare table; missing columns default to 0.
    _BALANCE_SHEET_FIELDS = (
        "total_assets",
        "total_liab",
        "total_hldr_eqy_exc_min_int",
        "monetary_cap",
        "accounts_receiv",
        "inventories",
    )
    _INCOME_STATEMENT_FIELDS = (
        "revenue",
        "operate_profit",
        "total_profit",
        "n_income",
        "n_income_attr_p",
        "basic_eps",
    )
    _CASH_FLOW_FIELDS = (
        "n_cashflow_act",
        "n_cashflow_inv_act",
        "n_cashflow_fin_act",
        "c_cash_equ_end_period",
    )
    # NOTE(review): pe/pb/ps may not be part of the fina_indicator output,
    # in which case they always come back as 0 here; confirm against the
    # Tushare docs (valuation ratios might need another API).
    _KEY_METRIC_FIELDS = ("pe", "pb", "ps", "roe", "roa", "gross_margin", "debt_to_assets")
    _PRICE_FIELDS = ("close", "open", "high", "low", "pre_close", "change", "pct_chg")
    _VOLUME_FIELDS = ("vol", "amount")
    # Stock-basic columns; missing ones default to "".
    _STOCK_BASIC_FIELDS = ("name", "industry", "market", "list_date")

    # Moving-average windows (in rows of daily data) that are reported.
    _MA_WINDOWS = (5, 20, 60)

    def __init__(self, config: Dict[str, Any]):
        """Read the API token and endpoint from ``config``.

        The token is taken from ``api_key``, falling back to ``token``.

        Raises:
            AuthenticationError: if no token is configured.
        """
        super().__init__(config)
        self.token = config.get("api_key") or config.get("token")
        self.base_url = config.get("base_url", "http://api.tushare.pro")

        if not self.token:
            raise AuthenticationError("tushare", {"message": "Tushare API token未配置"})

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    async def fetch_financial_data(self, symbol: str, market: str) -> FinancialDataResponse:
        """Fetch balance-sheet, income, cash-flow and indicator data for ``symbol``.

        Raises:
            DataSourceError, APIError: propagated unchanged. Any other failure
                is wrapped in a ``DataSourceError`` for source "tushare".
        """
        try:
            ts_code = self._convert_symbol_format(symbol, market)
            financial_data = await self._retry_request(self._fetch_tushare_financial, ts_code)

            return FinancialDataResponse(
                symbol=symbol,
                market=market,
                data_source="tushare",
                last_updated=datetime.now(),
                balance_sheet=financial_data.get("balance_sheet"),
                income_statement=financial_data.get("income_statement"),
                cash_flow=financial_data.get("cash_flow"),
                key_metrics=financial_data.get("key_metrics"),
            )
        except (DataSourceError, APIError):
            raise
        except Exception as e:
            raise DataSourceError(f"获取财务数据失败: {str(e)}", "tushare") from e

    async def fetch_market_data(self, symbol: str, market: str) -> MarketDataResponse:
        """Fetch the latest daily price/volume and moving averages for ``symbol``.

        Raises:
            DataSourceError, APIError: propagated unchanged. Any other failure
                is wrapped in a ``DataSourceError`` for source "tushare".
        """
        try:
            ts_code = self._convert_symbol_format(symbol, market)
            market_data = await self._retry_request(self._fetch_tushare_market, ts_code)

            return MarketDataResponse(
                symbol=symbol,
                market=market,
                data_source="tushare",
                last_updated=datetime.now(),
                price_data=market_data.get("price_data"),
                volume_data=market_data.get("volume_data"),
                technical_indicators=market_data.get("technical_indicators"),
            )
        except (DataSourceError, APIError):
            raise
        except Exception as e:
            raise DataSourceError(f"获取市场数据失败: {str(e)}", "tushare") from e

    async def validate_symbol(self, symbol: str, market: str) -> SymbolValidationResponse:
        """Check ``symbol`` against Tushare ``stock_basic``; never raises.

        Valid symbols report the listed company name. Network or API failures
        produce ``is_valid=False`` with a "验证失败" message, so they are not
        misreported as an invalid code.
        """
        try:
            ts_code = self._convert_symbol_format(symbol, market)
            basic_info = await self._retry_request(self._fetch_tushare_symbol_info, ts_code)
            is_valid = bool(basic_info)

            return SymbolValidationResponse(
                symbol=symbol,
                market=market,
                is_valid=is_valid,
                company_name=(basic_info.get("name") or None) if is_valid else None,
                message="证券代码有效" if is_valid else "证券代码无效",
            )
        except Exception as e:
            return SymbolValidationResponse(
                symbol=symbol,
                market=market,
                is_valid=False,
                message=f"验证失败: {str(e)}",
            )

    # ------------------------------------------------------------------ #
    # Health check / symbol handling
    # ------------------------------------------------------------------ #

    async def _health_check(self):
        """Issue a one-row ``stock_basic`` query with a 5 s timeout.

        Raises:
            DataSourceError: wrapping (and chaining) whatever went wrong.
        """
        try:
            async with httpx.AsyncClient(timeout=5) as client:
                await self._call_tushare_api(client, "stock_basic", {"limit": 1})
        except Exception as e:
            raise DataSourceError(f"Tushare健康检查失败: {str(e)}", "tushare") from e

    def _convert_symbol_format(self, symbol: str, market: str) -> str:
        """Convert a mainland code such as ``"600519"`` to Tushare's ``ts_code``.

        Only ``market == "china"`` (case-insensitive) is converted. A symbol
        that already carries a suffix is returned upper-cased rather than
        suffixed again, which would produce "600519.SH.SH". Codes with an
        unrecognized prefix are returned unchanged.
        """
        if market.lower() != "china":
            return symbol
        if "." in symbol:
            return symbol.upper()
        if symbol.startswith("6"):
            return f"{symbol}.SH"  # Shanghai Stock Exchange
        if symbol.startswith(("0", "3")):
            return f"{symbol}.SZ"  # Shenzhen Stock Exchange
        if symbol.startswith(("4", "8")):
            return f"{symbol}.BJ"  # Beijing Stock Exchange
        return symbol

    async def _fetch_tushare_symbol_info(self, ts_code: str) -> Dict[str, Any]:
        """Return processed ``stock_basic`` info for ``ts_code``.

        Returns ``{}`` when Tushare returns no rows. Network and API errors
        propagate so that ``_retry_request`` can retry transient ones.
        """
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            result = await self._call_tushare_api(client, "stock_basic", {"ts_code": ts_code})
        return self._process_stock_basic(result)

    async def _validate_tushare_symbol(self, ts_code: str) -> bool:
        """Return True if ``stock_basic`` has a row for ``ts_code``; False on any error.

        Kept for backward compatibility. ``validate_symbol`` uses
        ``_fetch_tushare_symbol_info`` directly so that failures are reported.
        """
        try:
            return bool(await self._fetch_tushare_symbol_info(ts_code))
        except Exception:
            return False

    # ------------------------------------------------------------------ #
    # Date helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _latest_annual_period(today: Optional[datetime] = None) -> str:
        """Return the ``YYYY1231`` period of the latest annual report expected out.

        A-share annual reports are due within four months of year end
        (by April 30). Before May, the previous year's report may still be
        missing, so the year before that is used.
        """
        today = today or datetime.now()
        year = today.year - 1 if today.month > 4 else today.year - 2
        return f"{year}1231"

    def _daily_date_range(self, today: Optional[datetime] = None) -> Tuple[str, str]:
        """Return ``(start_date, end_date)`` as ``YYYYMMDD`` strings ending today.

        The window length is ``config["market_lookback_days"]`` calendar days
        (default 365). That is comfortably more than the 60 rows needed by the
        longest moving average.
        """
        end = today or datetime.now()
        lookback_days = int(self.config.get("market_lookback_days", 365))
        start = end - timedelta(days=lookback_days)
        return start.strftime("%Y%m%d"), end.strftime("%Y%m%d")

    # ------------------------------------------------------------------ #
    # Raw fetches
    # ------------------------------------------------------------------ #

    async def _fetch_tushare_financial(self, ts_code: str) -> Dict[str, Any]:
        """Fetch the four statement/indicator tables for one report period."""
        period = self.config.get("financial_period") or self._latest_annual_period()
        params = {"ts_code": ts_code, "period": period}

        # Calls are issued sequentially on one client. Tushare enforces
        # per-interface rate limits, so no attempt is made to fan them out.
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            balance_sheet_data = await self._call_tushare_api(client, "balancesheet", params)
            income_data = await self._call_tushare_api(client, "income", params)
            cashflow_data = await self._call_tushare_api(client, "cashflow", params)
            fina_indicator_data = await self._call_tushare_api(client, "fina_indicator", params)

        return {
            "balance_sheet": self._process_balance_sheet(balance_sheet_data),
            "income_statement": self._process_income_statement(income_data),
            "cash_flow": self._process_cash_flow(cashflow_data),
            "key_metrics": self._process_key_metrics(fina_indicator_data),
        }

    async def _fetch_tushare_market(self, ts_code: str) -> Dict[str, Any]:
        """Fetch daily bars for the lookback window plus basic stock info."""
        start_date, end_date = self._daily_date_range()

        async with httpx.AsyncClient(timeout=self.timeout) as client:
            daily_data = await self._call_tushare_api(
                client,
                "daily",
                {"ts_code": ts_code, "start_date": start_date, "end_date": end_date},
            )
            stock_basic_data = await self._call_tushare_api(
                client, "stock_basic", {"ts_code": ts_code}
            )

        return {
            "price_data": self._process_price_data(daily_data),
            "volume_data": self._process_volume_data(daily_data),
            "technical_indicators": self._calculate_technical_indicators(daily_data),
            "stock_info": self._process_stock_basic(stock_basic_data),
        }

    async def _call_tushare_api(
        self, client: httpx.AsyncClient, api_name: str, params: Dict[str, Any]
    ) -> Dict[str, Any]:
        """POST one Tushare API request and return its ``data`` payload.

        Returns ``{}`` when the payload is missing or null.

        Raises:
            httpx.TimeoutException, httpx.ConnectError: re-raised unchanged so
                ``_retry_request`` can retry them.
            AuthenticationError: HTTP 401, or an error message about
                permission or token.
            RateLimitError: HTTP 429, or an error message about request quota.
            APIError: any other HTTP error status or non-zero Tushare ``code``.
            DataSourceError: any other network-level ``httpx.RequestError``.
        """
        request_data = {
            "api_name": api_name,
            "token": self.token,
            "params": params,
            "fields": "",
        }

        try:
            response = await client.post(self.base_url, json=request_data)
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            status_code = e.response.status_code
            if status_code == 401:
                raise AuthenticationError("tushare", {"status_code": status_code}) from e
            if status_code == 429:
                raise RateLimitError("tushare") from e
            raise APIError(f"HTTP错误: {status_code}", status_code, "tushare") from e
        except (httpx.TimeoutException, httpx.ConnectError):
            # Transient: let the retry helper see the original exception type.
            raise
        except httpx.RequestError as e:
            raise DataSourceError(f"网络请求失败: {str(e)}", "tushare") from e

        result = response.json()

        if result.get("code") != 0:
            # `or` also guards against an explicit null "msg".
            error_msg = result.get("msg") or "Unknown error"
            lowered = str(error_msg).lower()
            if any(marker in lowered for marker in self._RATE_LIMIT_MARKERS):
                raise RateLimitError("tushare")
            if any(marker in lowered for marker in self._AUTH_MARKERS):
                raise AuthenticationError("tushare", {"message": error_msg})
            raise APIError(f"Tushare API错误: {error_msg}", result.get("code"), "tushare")

        return result.get("data") or {}

    # ------------------------------------------------------------------ #
    # Payload processing
    # ------------------------------------------------------------------ #

    @staticmethod
    def _latest_record(data: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Return the first row of a ``{"fields", "items"}`` payload as a dict.

        Rows may be lists (zipped with ``fields``) or already-keyed dicts.
        Returns ``{}`` when there are no rows or a list row has no field names.
        """
        if not data:
            return {}
        items = data.get("items")
        if not items:
            return {}

        first = items[0]
        if isinstance(first, dict):
            return first
        fields = data.get("fields") or []
        if isinstance(first, (list, tuple)) and fields:
            return dict(zip(fields, first))
        return {}

    @staticmethod
    def _pick_fields(record: Dict[str, Any], keys, default: Any = 0) -> Dict[str, Any]:
        """Project ``record`` onto ``keys``, using ``default`` for absent keys."""
        return {key: record.get(key, default) for key in keys}

    def _process_balance_sheet(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Summarize the first balance-sheet row ({} when there is none)."""
        record = self._latest_record(data)
        return self._pick_fields(record, self._BALANCE_SHEET_FIELDS) if record else {}

    def _process_income_statement(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Summarize the first income-statement row ({} when there is none)."""
        record = self._latest_record(data)
        return self._pick_fields(record, self._INCOME_STATEMENT_FIELDS) if record else {}

    def _process_cash_flow(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Summarize the first cash-flow row ({} when there is none)."""
        record = self._latest_record(data)
        return self._pick_fields(record, self._CASH_FLOW_FIELDS) if record else {}

    def _process_key_metrics(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Summarize the first financial-indicator row ({} when there is none)."""
        record = self._latest_record(data)
        return self._pick_fields(record, self._KEY_METRIC_FIELDS) if record else {}

    def _process_price_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Summarize prices from the first daily row ({} when there is none)."""
        record = self._latest_record(data)
        return self._pick_fields(record, self._PRICE_FIELDS) if record else {}

    def _process_volume_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Summarize volume/amount from the first daily row ({} when there is none)."""
        record = self._latest_record(data)
        return self._pick_fields(record, self._VOLUME_FIELDS) if record else {}

    @staticmethod
    def _extract_closes(data: Optional[Dict[str, Any]]) -> list:
        """Collect numeric ``close`` values from every row, in payload order.

        Rows whose close is missing, null or non-numeric are skipped.
        """
        if not data or not data.get("items"):
            return []

        fields = data.get("fields") or []
        # Resolve the column index once instead of per row.
        close_idx = fields.index("close") if "close" in fields else None

        closes = []
        for row in data["items"]:
            if isinstance(row, dict):
                value = row.get("close")
            elif close_idx is not None and isinstance(row, (list, tuple)) and close_idx < len(row):
                value = row[close_idx]
            else:
                value = None
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                closes.append(value)
        return closes

    @staticmethod
    def _moving_average(closes: list, window: int) -> float:
        """Mean of the first ``window`` closes, or 0 if there are too few."""
        if len(closes) < window:
            return 0
        return sum(closes[:window]) / window

    def _calculate_technical_indicators(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Compute ``ma_5``/``ma_20``/``ma_60`` from daily closes.

        Assumes rows are ordered newest first, so the first N closes are the
        most recent N trading days. NOTE(review): confirm Tushare's ordering.
        A window with too few rows reports 0. Returns ``{}`` when no close
        prices are available at all.
        """
        closes = self._extract_closes(data)
        if not closes:
            return {}
        return {f"ma_{window}": self._moving_average(closes, window) for window in self._MA_WINDOWS}

    def _process_stock_basic(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Summarize the first stock-basic row ({} when there is none)."""
        record = self._latest_record(data)
        return self._pick_fields(record, self._STOCK_BASIC_FIELDS, "") if record else {}


class YahooDataFetcher(DataFetcher):
    """Yahoo Finance data fetcher.

    NOTE: only ``_health_check`` currently talks to Yahoo. Financial, market
    and validation data come from placeholder helpers (see the TODOs).
    """

    def __init__(self, config: Dict[str, Any]):
        """Read the Yahoo Finance endpoint from ``config["base_url"]``."""
        super().__init__(config)
        self.base_url = config.get("base_url", "https://query1.finance.yahoo.com")

    async def fetch_financial_data(self, symbol: str, market: str) -> FinancialDataResponse:
        """Fetch financial data for ``symbol``.

        Raises:
            DataSourceError, APIError: propagated unchanged. Any other failure
                is wrapped in a ``DataSourceError`` for source "yahoo".
        """
        try:
            yahoo_symbol = self._convert_symbol_format(symbol, market)

            # TODO: 实现实际的Yahoo Finance API调用
            financial_data = await self._retry_request(self._fetch_yahoo_financial, yahoo_symbol)

            return FinancialDataResponse(
                symbol=symbol,
                market=market,
                data_source="yahoo",
                last_updated=datetime.now(),
                balance_sheet=financial_data.get("balance_sheet"),
                income_statement=financial_data.get("income_statement"),
                cash_flow=financial_data.get("cash_flow"),
                key_metrics=financial_data.get("key_metrics"),
            )
        except (DataSourceError, APIError):
            raise
        except Exception as e:
            raise DataSourceError(f"获取财务数据失败: {str(e)}", "yahoo") from e

    async def fetch_market_data(self, symbol: str, market: str) -> MarketDataResponse:
        """Fetch market data for ``symbol``.

        Raises:
            DataSourceError, APIError: propagated unchanged. Any other failure
                is wrapped in a ``DataSourceError`` for source "yahoo".
        """
        try:
            yahoo_symbol = self._convert_symbol_format(symbol, market)

            # TODO: 实现实际的Yahoo Finance API调用
            market_data = await self._retry_request(self._fetch_yahoo_market, yahoo_symbol)

            return MarketDataResponse(
                symbol=symbol,
                market=market,
                data_source="yahoo",
                last_updated=datetime.now(),
                price_data=market_data.get("price_data"),
                volume_data=market_data.get("volume_data"),
                technical_indicators=market_data.get("technical_indicators"),
            )
        except (DataSourceError, APIError):
            raise
        except Exception as e:
            raise DataSourceError(f"获取市场数据失败: {str(e)}", "yahoo") from e

    async def validate_symbol(self, symbol: str, market: str) -> SymbolValidationResponse:
        """Validate ``symbol``; never raises (failures give ``is_valid=False``)."""
        try:
            yahoo_symbol = self._convert_symbol_format(symbol, market)

            # TODO: 实现实际的证券代码验证 (company_name below is a placeholder)
            is_valid = await self._retry_request(self._validate_yahoo_symbol, yahoo_symbol)

            return SymbolValidationResponse(
                symbol=symbol,
                market=market,
                is_valid=is_valid,
                company_name="Example Company" if is_valid else None,
                message="Symbol is valid" if is_valid else "Symbol not found",
            )
        except Exception as e:
            return SymbolValidationResponse(
                symbol=symbol,
                market=market,
                is_valid=False,
                message=f"Validation failed: {str(e)}",
            )

    async def _health_check(self):
        """GET the search endpoint for "AAPL".

        Raises:
            APIError: if the response status is anything other than 200.
        """
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            response = await client.get(f"{self.base_url}/v1/finance/search?q=AAPL")
            if response.status_code != 200:
                raise APIError(
                    f"Yahoo Finance API返回状态码: {response.status_code}",
                    response.status_code,
                    "yahoo",
                )

    def _convert_symbol_format(self, symbol: str, market: str) -> str:
        """Convert a local code to a Yahoo Finance ticker.

        For "hongkong", "japan" and "china" (case-insensitive), a symbol that
        already has a suffix is returned upper-cased rather than suffixed
        again. Numeric Hong Kong codes are normalized to 4 digits, so "00700"
        and "700" both become "0700.HK". Other markets, and mainland codes
        with an unrecognized prefix, are returned unchanged.
        """
        market_key = market.lower()
        if market_key not in ("hongkong", "japan", "china"):
            return symbol
        if "." in symbol:
            return symbol.upper()

        if market_key == "hongkong":
            code = symbol.lstrip("0").zfill(4) if symbol.isdigit() else symbol
            return f"{code}.HK"
        if market_key == "japan":
            return f"{symbol}.T"
        # Mainland China.
        if symbol.startswith("6"):
            return f"{symbol}.SS"  # Shanghai
        if symbol.startswith(("0", "3")):
            return f"{symbol}.SZ"  # Shenzhen
        return symbol

    async def _fetch_yahoo_financial(self, yahoo_symbol: str) -> Dict[str, Any]:
        """Return placeholder financial data (no request is made yet)."""
        # TODO: 实现实际的API调用
        return {
            "balance_sheet": {"totalAssets": 2000000},
            "income_statement": {"totalRevenue": 800000},
            "cash_flow": {"operatingCashflow": 300000},
            "key_metrics": {"trailingPE": 18.5},
        }

    async def _fetch_yahoo_market(self, yahoo_symbol: str) -> Dict[str, Any]:
        """Return placeholder market data (no request is made yet)."""
        # TODO: 实现实际的API调用
        return {
            "price_data": {"regularMarketPrice": 150.0, "dayHigh": 155.0, "dayLow": 145.0},
            "volume_data": {"regularMarketVolume": 2000000},
            "technical_indicators": {"fiftyDayAverage": 148.0, "twoHundredDayAverage": 152.0},
        }

    async def _validate_yahoo_symbol(self, yahoo_symbol: str) -> bool:
        """Placeholder validation: always returns True."""
        # TODO: 实现实际的验证逻辑
        return True


class DataFetcherFactory:
    """Registry-based factory mapping data-source names to fetcher classes."""

    # Lower-cased source name -> DataFetcher subclass. This is class-level
    # state: register_fetcher() mutates it for every user of the factory.
    _fetchers: Dict[str, type] = {
        "tushare": TushareDataFetcher,
        "yahoo": YahooDataFetcher,
    }

    @classmethod
    def create_fetcher(cls, data_source: str, config: Dict[str, Any]) -> DataFetcher:
        """Instantiate the fetcher registered for ``data_source`` (case-insensitive).

        Raises:
            DataSourceError: if no fetcher is registered under that name.
                Its details include the list of supported sources.
        """
        fetcher_class = cls._fetchers.get(data_source.lower())
        if fetcher_class is None:
            raise DataSourceError(
                f"不支持的数据源: {data_source}",
                data_source,
                {"supported_sources": list(cls._fetchers.keys())},
            )
        return fetcher_class(config)

    @classmethod
    def get_supported_sources(cls) -> list:
        """Return the registered (lower-cased) data-source names."""
        return list(cls._fetchers.keys())

    @classmethod
    def register_fetcher(cls, name: str, fetcher_class: type):
        """Register ``fetcher_class`` under ``name`` (stored lower-cased).

        Raises:
            ValueError: if ``fetcher_class`` is not a subclass of DataFetcher.
                This includes arguments that are not classes at all, for
                which ``issubclass`` alone would raise ``TypeError``.
        """
        if not (isinstance(fetcher_class, type) and issubclass(fetcher_class, DataFetcher)):
            raise ValueError("数据获取器必须继承自DataFetcher基类")
        cls._fetchers[name.lower()] = fetcher_class