# Fundamental_Analysis/backend/app/services/data_fetcher.py
#
# 673 lines
# 25 KiB
# Python

"""
数据获取服务基础架构
处理外部数据源的数据获取
"""
from typing import Dict, Any, Optional
from abc import ABC, abstractmethod
import httpx
import asyncio
from datetime import datetime
from ..schemas.data import (
FinancialDataResponse,
MarketDataResponse,
SymbolValidationResponse,
DataSourceStatus
)
from ..core.exceptions import (
DataSourceError,
APIError,
SymbolNotFoundError,
RateLimitError,
AuthenticationError
)
class DataFetcher(ABC):
    """Abstract base class for external data-source fetchers.

    Subclasses implement the three fetch/validate coroutines and
    ``_health_check``. This base class stores the configuration, exposes a
    status probe (``check_status``) and provides ``_retry_request``, a retry
    helper with exponential backoff for transient transport failures.

    Config keys read here (all optional):
        name: label used in status reports and error messages
            (default ``"unknown"``).
        timeout: stored on ``self.timeout`` for subclasses to pass to their
            HTTP client (default ``30``).
        max_retries: total number of attempts made by ``_retry_request``
            (default ``3``).
        retry_delay: base delay for the backoff, passed to
            ``asyncio.sleep`` (default ``1``).
    """

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.name = config.get("name", "unknown")
        self.timeout = config.get("timeout", 30)
        self.max_retries = config.get("max_retries", 3)
        self.retry_delay = config.get("retry_delay", 1)

    @abstractmethod
    async def fetch_financial_data(self, symbol: str, market: str) -> FinancialDataResponse:
        """Fetch financial-statement data for ``symbol`` on ``market``."""
        pass

    @abstractmethod
    async def fetch_market_data(self, symbol: str, market: str) -> MarketDataResponse:
        """Fetch market (price/volume) data for ``symbol`` on ``market``."""
        pass

    @abstractmethod
    async def validate_symbol(self, symbol: str, market: str) -> SymbolValidationResponse:
        """Check whether ``symbol`` is known to the data source."""
        pass

    async def check_status(self) -> DataSourceStatus:
        """Probe the data source and report its availability.

        Runs ``_health_check`` and measures the elapsed wall-clock time.
        Never raises for a failing health check: any ``Exception`` is
        reported as ``is_available=False`` with the error text attached.

        Returns:
            A ``DataSourceStatus`` carrying either the response time in
            milliseconds (success) or the error message (failure).
        """
        start_time = datetime.now()
        try:
            await self._health_check()
            end_time = datetime.now()
            response_time = int((end_time - start_time).total_seconds() * 1000)
            return DataSourceStatus(
                name=self.name,
                is_available=True,
                last_check=end_time,
                response_time_ms=response_time
            )
        except Exception as e:
            end_time = datetime.now()
            return DataSourceStatus(
                name=self.name,
                is_available=False,
                last_check=end_time,
                error_message=str(e)
            )

    @abstractmethod
    async def _health_check(self):
        """Perform a cheap request against the source; raise on failure."""
        pass

    async def _retry_request(self, func, *args, **kwargs):
        """Await ``func(*args, **kwargs)``, retrying transient failures.

        Only ``httpx.TimeoutException`` and ``httpx.ConnectError`` are
        retried; every other exception propagates immediately. The delay
        before retry ``k`` (0-based) is ``retry_delay * 2 ** k``.

        Args:
            func: coroutine function to call.
            *args, **kwargs: forwarded to ``func`` unchanged.

        Returns:
            Whatever ``func`` returns on the first successful attempt.

        Raises:
            DataSourceError: when every attempt failed with a retryable
                error; the last error is attached as ``last_error`` and as
                the exception cause.
        """
        # A non-positive max_retries used to skip the call entirely and
        # report a failure that never happened; always try at least once.
        attempts = max(1, self.max_retries)
        last_exception = None

        for attempt in range(attempts):
            try:
                return await func(*args, **kwargs)
            except (httpx.TimeoutException, httpx.ConnectError) as e:
                last_exception = e
                # No point sleeping after the final attempt.
                if attempt < attempts - 1:
                    await asyncio.sleep(self.retry_delay * (2 ** attempt))

        raise DataSourceError(
            f"数据源 {self.name} 请求失败,已重试 {self.max_retries}",
            self.name,
            {"last_error": str(last_exception)}
        ) from last_exception
class TushareDataFetcher(DataFetcher):
    """Data fetcher that talks to the Tushare HTTP API.

    Every request is a JSON ``POST`` to ``base_url`` carrying the API name,
    the token and the query parameters (see ``_call_tushare_api``).

    Config keys read here, in addition to those of ``DataFetcher``:
        api_key / token: Tushare token (required).
        base_url: API endpoint (default ``"http://api.tushare.pro"``).
        financial_period: report period sent with the financial-statement
            queries (default ``"20231231"``).
        market_start_date / market_end_date: date range sent with the daily
            quote query (defaults ``"20240101"`` / ``"20241231"``).

    Raises:
        AuthenticationError: on construction when no token is configured.
    """

    # Defaults equal to the values that used to be hard-coded in the
    # request parameters, so behaviour is unchanged unless overridden.
    DEFAULT_FINANCIAL_PERIOD = "20231231"
    DEFAULT_MARKET_START_DATE = "20240101"
    DEFAULT_MARKET_END_DATE = "20241231"

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.token = config.get("api_key") or config.get("token")
        self.base_url = config.get("base_url", "http://api.tushare.pro")
        self.financial_period = config.get(
            "financial_period", self.DEFAULT_FINANCIAL_PERIOD
        )
        self.market_start_date = config.get(
            "market_start_date", self.DEFAULT_MARKET_START_DATE
        )
        self.market_end_date = config.get(
            "market_end_date", self.DEFAULT_MARKET_END_DATE
        )

        if not self.token:
            raise AuthenticationError("tushare", {"message": "Tushare API token未配置"})

    async def fetch_financial_data(self, symbol: str, market: str) -> FinancialDataResponse:
        """Fetch balance sheet, income statement, cash flow and key metrics.

        Raises:
            DataSourceError / APIError: propagated unchanged from the API
                layer; any other exception is wrapped in ``DataSourceError``.
        """
        try:
            ts_code = self._convert_symbol_format(symbol, market)
            financial_data = await self._retry_request(self._fetch_tushare_financial, ts_code)

            return FinancialDataResponse(
                symbol=symbol,
                market=market,
                data_source="tushare",
                last_updated=datetime.now(),
                balance_sheet=financial_data.get("balance_sheet"),
                income_statement=financial_data.get("income_statement"),
                cash_flow=financial_data.get("cash_flow"),
                key_metrics=financial_data.get("key_metrics")
            )
        except (DataSourceError, APIError):
            raise
        except Exception as e:
            raise DataSourceError(f"获取财务数据失败: {str(e)}", "tushare") from e

    async def fetch_market_data(self, symbol: str, market: str) -> MarketDataResponse:
        """Fetch the latest price/volume row and simple moving averages.

        Raises:
            DataSourceError / APIError: propagated unchanged from the API
                layer; any other exception is wrapped in ``DataSourceError``.
        """
        try:
            ts_code = self._convert_symbol_format(symbol, market)
            market_data = await self._retry_request(self._fetch_tushare_market, ts_code)

            return MarketDataResponse(
                symbol=symbol,
                market=market,
                data_source="tushare",
                last_updated=datetime.now(),
                price_data=market_data.get("price_data"),
                volume_data=market_data.get("volume_data"),
                technical_indicators=market_data.get("technical_indicators")
            )
        except (DataSourceError, APIError):
            raise
        except Exception as e:
            raise DataSourceError(f"获取市场数据失败: {str(e)}", "tushare") from e

    async def validate_symbol(self, symbol: str, market: str) -> SymbolValidationResponse:
        """Report whether Tushare knows ``symbol``.

        Never raises: any failure (network, authentication, rate limit, ...)
        is reported as ``is_valid=False`` with the error text in ``message``.
        """
        try:
            ts_code = self._convert_symbol_format(symbol, market)
            is_valid = await self._retry_request(self._validate_tushare_symbol, ts_code)

            # TODO: company_name is still a placeholder, not the real name.
            return SymbolValidationResponse(
                symbol=symbol,
                market=market,
                is_valid=is_valid,
                company_name="示例公司" if is_valid else None,
                message="证券代码有效" if is_valid else "证券代码无效"
            )
        except Exception as e:
            return SymbolValidationResponse(
                symbol=symbol,
                market=market,
                is_valid=False,
                message=f"验证失败: {str(e)}"
            )

    async def _health_check(self):
        """Issue a one-row ``stock_basic`` query with a 5 second timeout.

        Raises:
            DataSourceError: wrapping whatever the probe request raised.
        """
        try:
            async with httpx.AsyncClient(timeout=5) as client:
                await self._call_tushare_api(client, "stock_basic", {"limit": 1})
        except Exception as e:
            raise DataSourceError(f"Tushare健康检查失败: {str(e)}", "tushare") from e

    def _convert_symbol_format(self, symbol: str, market: str) -> str:
        """Append the exchange suffix Tushare expects for China symbols.

        Codes starting with ``6`` get ``.SH``; codes starting with ``0`` or
        ``3`` get ``.SZ``. Anything else, any non-"china" market, and any
        symbol that already contains a ``.`` is returned unchanged.
        """
        # Already qualified (e.g. "600000.SH"): do not append a second suffix.
        if "." in symbol:
            return symbol

        if market.lower() == "china":
            if symbol.startswith("6"):
                return f"{symbol}.SH"
            if symbol.startswith(("0", "3")):
                return f"{symbol}.SZ"
        return symbol

    async def _fetch_tushare_financial(
        self, ts_code: str, period: Optional[str] = None
    ) -> Dict[str, Any]:
        """Query the four financial endpoints and normalise their results.

        Args:
            ts_code: symbol in Tushare format.
            period: report period; defaults to ``self.financial_period``.

        Returns:
            Dict with keys ``balance_sheet``, ``income_statement``,
            ``cash_flow`` and ``key_metrics``.
        """
        params = {"ts_code": ts_code, "period": period or self.financial_period}

        async with httpx.AsyncClient(timeout=self.timeout) as client:
            # Requests are issued one after another, as in the original.
            balance_sheet_data = await self._call_tushare_api(
                client, "balancesheet", dict(params)
            )
            income_data = await self._call_tushare_api(
                client, "income", dict(params)
            )
            cashflow_data = await self._call_tushare_api(
                client, "cashflow", dict(params)
            )
            fina_indicator_data = await self._call_tushare_api(
                client, "fina_indicator", dict(params)
            )

        return {
            "balance_sheet": self._process_balance_sheet(balance_sheet_data),
            "income_statement": self._process_income_statement(income_data),
            "cash_flow": self._process_cash_flow(cashflow_data),
            "key_metrics": self._process_key_metrics(fina_indicator_data)
        }

    async def _fetch_tushare_market(
        self,
        ts_code: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Query daily quotes and basic stock info and normalise them.

        Args:
            ts_code: symbol in Tushare format.
            start_date: defaults to ``self.market_start_date``.
            end_date: defaults to ``self.market_end_date``.

        Returns:
            Dict with keys ``price_data``, ``volume_data``,
            ``technical_indicators`` and ``stock_info``. ``stock_info`` is
            not consumed by ``fetch_market_data``.
        """
        daily_params = {
            "ts_code": ts_code,
            "start_date": start_date or self.market_start_date,
            "end_date": end_date or self.market_end_date,
        }

        async with httpx.AsyncClient(timeout=self.timeout) as client:
            daily_data = await self._call_tushare_api(client, "daily", daily_params)
            stock_basic_data = await self._call_tushare_api(
                client, "stock_basic", {"ts_code": ts_code}
            )

        return {
            "price_data": self._process_price_data(daily_data),
            "volume_data": self._process_volume_data(daily_data),
            "technical_indicators": self._calculate_technical_indicators(daily_data),
            "stock_info": self._process_stock_basic(stock_basic_data)
        }

    async def _validate_tushare_symbol(self, ts_code: str) -> bool:
        """Return True when ``stock_basic`` yields at least one row.

        Errors are deliberately NOT swallowed here: transient transport
        errors must reach ``_retry_request`` to be retried, and other
        failures must reach ``validate_symbol`` so they are reported as a
        failed validation instead of as an invalid symbol.
        """
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            result = await self._call_tushare_api(
                client, "stock_basic", {"ts_code": ts_code}
            )
        return bool(result and result.get("items"))

    async def _call_tushare_api(self, client: httpx.AsyncClient, api_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """POST one query to the Tushare endpoint and return its ``data``.

        Args:
            client: open HTTP client to send the request with.
            api_name: Tushare API name, e.g. ``"daily"``.
            params: query parameters for that API.

        Returns:
            The ``data`` member of the JSON response, or ``{}`` when it is
            missing or null.

        Raises:
            AuthenticationError: HTTP 401, or an error message mentioning
                permission / token.
            RateLimitError: HTTP 429, or an error message mentioning
                frequency / limit.
            APIError: any other non-zero response code or HTTP error status.
            httpx.TimeoutException, httpx.ConnectError: propagated as-is so
                that ``_retry_request`` can retry them.
            DataSourceError: any other ``httpx.RequestError``.
        """
        request_data = {
            "api_name": api_name,
            "token": self.token,
            "params": params,
            "fields": ""
        }

        try:
            response = await client.post(self.base_url, json=request_data)
            response.raise_for_status()
            result = response.json()

            if result.get("code") != 0:
                # "msg" may be present but null; keep the substring tests safe.
                error_msg = result.get("msg") or "Unknown error"
                if "权限" in error_msg or "token" in error_msg.lower():
                    raise AuthenticationError("tushare", {"message": error_msg})
                elif "频率" in error_msg or "limit" in error_msg.lower():
                    raise RateLimitError("tushare")
                else:
                    raise APIError(f"Tushare API错误: {error_msg}", result.get("code"), "tushare")

            return result.get("data") or {}

        except httpx.HTTPStatusError as e:
            if e.response.status_code == 401:
                raise AuthenticationError("tushare", {"status_code": e.response.status_code}) from e
            elif e.response.status_code == 429:
                raise RateLimitError("tushare") from e
            else:
                raise APIError(f"HTTP错误: {e.response.status_code}", e.response.status_code, "tushare") from e
        except (httpx.TimeoutException, httpx.ConnectError):
            # Both are RequestError subclasses; converting them here would
            # hide them from the retry loop in _retry_request.
            raise
        except httpx.RequestError as e:
            raise DataSourceError(f"网络请求失败: {str(e)}", "tushare") from e

    @staticmethod
    def _latest_record(data: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """Return the first row of a Tushare ``data`` payload as a dict.

        Accepted shapes for ``data["items"]``: a list of rows (each a
        list/tuple aligned with ``data["fields"]``), a list of dicts, or a
        single flat row. Returns ``None`` when there is no usable row.
        """
        if not data:
            return None
        items = data.get("items")
        if not items:
            return None

        first = items[0]
        if isinstance(first, dict):
            return first

        # Either the first of several rows, or ``items`` itself is one row.
        row = first if isinstance(first, (list, tuple)) else items
        fields = data.get("fields") or []
        if not fields:
            # Positional values cannot be named without the field list.
            return None
        return dict(zip(fields, row))

    @staticmethod
    def _pick(record: Optional[Dict[str, Any]], keys: tuple, default: Any = 0) -> Dict[str, Any]:
        """Project ``record`` onto ``keys``; ``{}`` when there is no record."""
        if not record:
            return {}
        return {key: record.get(key, default) for key in keys}

    def _process_balance_sheet(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Extract the balance-sheet fields of the first returned row."""
        return self._pick(self._latest_record(data), (
            "total_assets",
            "total_liab",
            "total_hldr_eqy_exc_min_int",
            "monetary_cap",
            "accounts_receiv",
            "inventories",
        ))

    def _process_income_statement(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Extract the income-statement fields of the first returned row."""
        return self._pick(self._latest_record(data), (
            "revenue",
            "operate_profit",
            "total_profit",
            "n_income",
            "n_income_attr_p",
            "basic_eps",
        ))

    def _process_cash_flow(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Extract the cash-flow fields of the first returned row."""
        return self._pick(self._latest_record(data), (
            "n_cashflow_act",
            "n_cashflow_inv_act",
            "n_cashflow_fin_act",
            "c_cash_equ_end_period",
        ))

    def _process_key_metrics(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Extract key ratios from the first ``fina_indicator`` row.

        NOTE(review): keys absent from the response silently become 0;
        confirm that ``pe`` / ``pb`` / ``ps`` are actually returned by the
        ``fina_indicator`` API rather than by a separate valuation API.
        """
        return self._pick(self._latest_record(data), (
            "pe",
            "pb",
            "ps",
            "roe",
            "roa",
            "gross_margin",
            "debt_to_assets",
        ))

    def _process_price_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Extract the price fields of the first returned daily row."""
        return self._pick(self._latest_record(data), (
            "close",
            "open",
            "high",
            "low",
            "pre_close",
            "change",
            "pct_chg",
        ))

    def _process_volume_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Extract volume and turnover of the first returned daily row."""
        return self._pick(self._latest_record(data), ("vol", "amount"))

    @staticmethod
    def _moving_average(closes: list, window: int) -> float:
        """Mean of the first ``window`` closes, or 0 when not computable."""
        head = closes[:window]
        if len(head) < window or any(value is None for value in head):
            return 0
        return sum(head) / window

    def _calculate_technical_indicators(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Compute 5/20/60-row simple moving averages of the close price.

        The averages are taken over the FIRST rows of ``items``, i.e. the
        same rows the other processors treat as the latest
        (presumably the API returns newest rows first; verify).

        Returns:
            ``{}`` when fewer than 20 rows are available; otherwise a dict
            with ``ma_5``, ``ma_20`` and ``ma_60``, each 0 when it cannot be
            computed (too few rows, missing close column, null close).
        """
        if not data or not data.get("items"):
            return {}

        items = data["items"]
        if len(items) < 20:
            return {}

        # Resolve the close column once instead of once per row.
        fields = data.get("fields") or []
        close_idx = fields.index("close") if "close" in fields else -1

        closes = []
        for item in items[:60]:
            if isinstance(item, (list, tuple)):
                if 0 <= close_idx < len(item):
                    closes.append(item[close_idx])
            elif isinstance(item, dict):
                closes.append(item.get("close", 0))

        return {
            "ma_5": self._moving_average(closes, 5),
            "ma_20": self._moving_average(closes, 20),
            "ma_60": self._moving_average(closes, 60)
        }

    def _process_stock_basic(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Extract name, industry, market and listing date of the stock."""
        return self._pick(
            self._latest_record(data),
            ("name", "industry", "market", "list_date"),
            default="",
        )
class YahooDataFetcher(DataFetcher):
    """Data fetcher for Yahoo Finance.

    Only ``_health_check`` performs a real HTTP request. The three
    ``_fetch_yahoo_*`` / ``_validate_yahoo_symbol`` helpers still return
    fixed placeholder values (see their TODO notes).

    Config keys read here, in addition to those of ``DataFetcher``:
        base_url: API root (default ``"https://query1.finance.yahoo.com"``).
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.base_url = config.get("base_url", "https://query1.finance.yahoo.com")

    async def fetch_financial_data(self, symbol: str, market: str) -> FinancialDataResponse:
        """Fetch financial-statement data for ``symbol``.

        Raises:
            DataSourceError / APIError: propagated unchanged; any other
                exception is wrapped in ``DataSourceError``.
        """
        try:
            yahoo_symbol = self._convert_symbol_format(symbol, market)
            financial_data = await self._retry_request(self._fetch_yahoo_financial, yahoo_symbol)

            return FinancialDataResponse(
                symbol=symbol,
                market=market,
                data_source="yahoo",
                last_updated=datetime.now(),
                balance_sheet=financial_data.get("balance_sheet"),
                income_statement=financial_data.get("income_statement"),
                cash_flow=financial_data.get("cash_flow"),
                key_metrics=financial_data.get("key_metrics")
            )
        except (DataSourceError, APIError):
            raise
        except Exception as e:
            raise DataSourceError(f"获取财务数据失败: {str(e)}", "yahoo") from e

    async def fetch_market_data(self, symbol: str, market: str) -> MarketDataResponse:
        """Fetch market (price/volume) data for ``symbol``.

        Raises:
            DataSourceError / APIError: propagated unchanged; any other
                exception is wrapped in ``DataSourceError``.
        """
        try:
            yahoo_symbol = self._convert_symbol_format(symbol, market)
            market_data = await self._retry_request(self._fetch_yahoo_market, yahoo_symbol)

            return MarketDataResponse(
                symbol=symbol,
                market=market,
                data_source="yahoo",
                last_updated=datetime.now(),
                price_data=market_data.get("price_data"),
                volume_data=market_data.get("volume_data"),
                technical_indicators=market_data.get("technical_indicators")
            )
        except (DataSourceError, APIError):
            raise
        except Exception as e:
            raise DataSourceError(f"获取市场数据失败: {str(e)}", "yahoo") from e

    async def validate_symbol(self, symbol: str, market: str) -> SymbolValidationResponse:
        """Report whether ``symbol`` is valid.

        Never raises: any failure is reported as ``is_valid=False`` with the
        error text in ``message``.
        """
        try:
            yahoo_symbol = self._convert_symbol_format(symbol, market)
            is_valid = await self._retry_request(self._validate_yahoo_symbol, yahoo_symbol)

            # TODO: company_name is still a placeholder, not the real name.
            return SymbolValidationResponse(
                symbol=symbol,
                market=market,
                is_valid=is_valid,
                company_name="Example Company" if is_valid else None,
                message="Symbol is valid" if is_valid else "Symbol not found"
            )
        except Exception as e:
            return SymbolValidationResponse(
                symbol=symbol,
                market=market,
                is_valid=False,
                message=f"Validation failed: {str(e)}"
            )

    async def _health_check(self):
        """Run a search query for AAPL and require an HTTP 200 response.

        Raises:
            APIError: when the response status is not 200. Transport errors
                from the HTTP client propagate unchanged.
        """
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            response = await client.get(f"{self.base_url}/v1/finance/search?q=AAPL")
            if response.status_code != 200:
                raise APIError(f"Yahoo Finance API返回状态码: {response.status_code}", response.status_code, "yahoo")

    def _convert_symbol_format(self, symbol: str, market: str) -> str:
        """Convert ``symbol`` to the ticker form used by Yahoo Finance.

        * ``hongkong``: ``.HK`` suffix; purely numeric codes are normalised
          to at least four digits (``700`` and ``00700`` -> ``0700.HK``).
        * ``japan``: ``.T`` suffix.
        * ``china``: ``.SS`` for codes starting with ``6``, ``.SZ`` for
          codes starting with ``0`` or ``3``.

        Symbols that already contain a ``.`` and symbols of any other
        market are returned unchanged.
        """
        # Already qualified (e.g. "0700.HK"): do not append a second suffix.
        if "." in symbol:
            return symbol

        market_key = market.lower()
        if market_key == "hongkong":
            if symbol.isdigit():
                # Yahoo lists Hong Kong tickers zero-padded to four digits.
                symbol = symbol.lstrip("0").zfill(4)
            return f"{symbol}.HK"
        if market_key == "japan":
            return f"{symbol}.T"
        if market_key == "china":
            if symbol.startswith("6"):
                return f"{symbol}.SS"
            if symbol.startswith(("0", "3")):
                return f"{symbol}.SZ"
        return symbol

    async def _fetch_yahoo_financial(self, yahoo_symbol: str) -> Dict[str, Any]:
        """Return financial data for ``yahoo_symbol`` (placeholder values)."""
        # TODO: replace the fixed values with a real API call.
        return {
            "balance_sheet": {"totalAssets": 2000000},
            "income_statement": {"totalRevenue": 800000},
            "cash_flow": {"operatingCashflow": 300000},
            "key_metrics": {"trailingPE": 18.5}
        }

    async def _fetch_yahoo_market(self, yahoo_symbol: str) -> Dict[str, Any]:
        """Return market data for ``yahoo_symbol`` (placeholder values)."""
        # TODO: replace the fixed values with a real API call.
        return {
            "price_data": {"regularMarketPrice": 150.0, "dayHigh": 155.0, "dayLow": 145.0},
            "volume_data": {"regularMarketVolume": 2000000},
            "technical_indicators": {"fiftyDayAverage": 148.0, "twoHundredDayAverage": 152.0}
        }

    async def _validate_yahoo_symbol(self, yahoo_symbol: str) -> bool:
        """Return whether ``yahoo_symbol`` exists (placeholder: always True)."""
        # TODO: implement a real lookup.
        return True
class DataFetcherFactory:
    """Registry and factory for ``DataFetcher`` implementations.

    The registry is a class-level dict keyed by lower-cased source name, so
    registrations are shared by every user of the factory.
    """

    _fetchers = {
        "tushare": TushareDataFetcher,
        "yahoo": YahooDataFetcher,
    }

    @classmethod
    def create_fetcher(cls, data_source: str, config: Dict[str, Any]) -> DataFetcher:
        """Instantiate the fetcher registered for ``data_source``.

        Args:
            data_source: source name; matched case-insensitively.
            config: configuration dict passed to the fetcher constructor.

        Raises:
            DataSourceError: when no fetcher is registered under that name;
                the supported names are attached as ``supported_sources``.
        """
        fetcher_class = cls._fetchers.get(data_source.lower())
        if fetcher_class is None:
            raise DataSourceError(
                f"不支持的数据源: {data_source}",
                data_source,
                {"supported_sources": list(cls._fetchers.keys())}
            )
        return fetcher_class(config)

    @classmethod
    def get_supported_sources(cls) -> list:
        """Return the names of all registered data sources."""
        return list(cls._fetchers.keys())

    @classmethod
    def register_fetcher(cls, name: str, fetcher_class: type):
        """Register ``fetcher_class`` under ``name`` (stored lower-cased).

        An existing registration with the same name is replaced.

        Raises:
            ValueError: when ``fetcher_class`` is not a class derived from
                ``DataFetcher``.
        """
        # issubclass() raises TypeError for non-class arguments; check
        # isinstance first so callers always get the ValueError below.
        if not (isinstance(fetcher_class, type) and issubclass(fetcher_class, DataFetcher)):
            raise ValueError("数据获取器必须继承自DataFetcher基类")

        cls._fetchers[name.lower()] = fetcher_class