"""
Data-fetching service infrastructure.

Handles retrieval of financial and market data from external data sources.
"""

import asyncio
import time
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Dict, List, Optional, Sequence, Type

import httpx

from ..schemas.data import (
    FinancialDataResponse,
    MarketDataResponse,
    SymbolValidationResponse,
    DataSourceStatus
)
from ..core.exceptions import (
    DataSourceError,
    APIError,
    SymbolNotFoundError,
    RateLimitError,
    AuthenticationError
)

# Transport failures that are worth retrying. Both are subclasses of
# httpx.RequestError, so they must be handled *before* any generic
# RequestError handler or they will never reach the retry loop.
_TRANSIENT_ERRORS = (httpx.TimeoutException, httpx.ConnectError)


def _with_exchange_suffix(symbol: str, suffix: str) -> str:
    """Append ``suffix`` to ``symbol`` unless the symbol already ends with it."""
    if symbol.upper().endswith(suffix.upper()):
        return symbol
    return f"{symbol}{suffix}"


class DataFetcher(ABC):
    """Abstract base class for data-source fetchers.

    Reads ``name``, ``timeout``, ``max_retries`` and ``retry_delay`` from the
    config dict and provides a shared status check and retry helper.
    """

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.name = config.get("name", "unknown")
        self.timeout = config.get("timeout", 30)
        self.max_retries = config.get("max_retries", 3)
        self.retry_delay = config.get("retry_delay", 1)

    @abstractmethod
    async def fetch_financial_data(self, symbol: str, market: str) -> FinancialDataResponse:
        """Fetch financial statement data for ``symbol``."""
        pass

    @abstractmethod
    async def fetch_market_data(self, symbol: str, market: str) -> MarketDataResponse:
        """Fetch market (price/volume) data for ``symbol``."""
        pass

    @abstractmethod
    async def validate_symbol(self, symbol: str, market: str) -> SymbolValidationResponse:
        """Check whether ``symbol`` is known to the data source."""
        pass

    async def check_status(self) -> DataSourceStatus:
        """Run the health check and report availability and latency.

        Never raises: any exception from ``_health_check`` is reported as an
        unavailable status carrying the error text.
        """
        # perf_counter is monotonic, so the measured latency is immune to
        # wall-clock adjustments; datetime.now() is kept only as a timestamp.
        started = time.perf_counter()
        try:
            await self._health_check()
        except Exception as e:
            return DataSourceStatus(
                name=self.name,
                is_available=False,
                last_check=datetime.now(),
                error_message=str(e)
            )
        response_time = int((time.perf_counter() - started) * 1000)
        return DataSourceStatus(
            name=self.name,
            is_available=True,
            last_check=datetime.now(),
            response_time_ms=response_time
        )

    @abstractmethod
    async def _health_check(self):
        """Source-specific health check; must raise on failure."""
        pass

    async def _retry_request(self, func, *args, **kwargs):
        """Await ``func(*args, **kwargs)``, retrying transient transport errors.

        Only timeouts and connection errors are retried, with exponential
        backoff (``retry_delay * 2 ** attempt``). Any other exception
        propagates immediately.

        Raises:
            DataSourceError: when every attempt failed with a transient error.
        """
        # Always make at least one attempt, even if max_retries is 0 or less.
        attempts = max(1, self.max_retries)
        last_exception = None
        for attempt in range(attempts):
            try:
                return await func(*args, **kwargs)
            except _TRANSIENT_ERRORS as e:
                last_exception = e
                if attempt < attempts - 1:
                    # Exponential backoff between attempts.
                    await asyncio.sleep(self.retry_delay * (2 ** attempt))

        # All attempts failed.
        raise DataSourceError(
            f"数据源 {self.name} 请求失败,已重试 {attempts} 次",
            self.name,
            {"last_error": str(last_exception)}
        ) from last_exception


class TushareDataFetcher(DataFetcher):
    """Fetcher backed by the Tushare HTTP API."""

    # Keys copied from the latest record of each Tushare dataset.
    _BALANCE_SHEET_KEYS = (
        "total_assets", "total_liab", "total_hldr_eqy_exc_min_int",
        "monetary_cap", "accounts_receiv", "inventories",
    )
    _INCOME_KEYS = (
        "revenue", "operate_profit", "total_profit",
        "n_income", "n_income_attr_p", "basic_eps",
    )
    _CASH_FLOW_KEYS = (
        "n_cashflow_act", "n_cashflow_inv_act",
        "n_cashflow_fin_act", "c_cash_equ_end_period",
    )
    _KEY_METRIC_KEYS = (
        "pe", "pb", "ps", "roe", "roa", "gross_margin", "debt_to_assets",
    )
    _PRICE_KEYS = (
        "close", "open", "high", "low", "pre_close", "change", "pct_chg",
    )
    _VOLUME_KEYS = ("vol", "amount")
    _STOCK_BASIC_KEYS = ("name", "industry", "market", "list_date")

    def __init__(self, config: Dict[str, Any]):
        """Read the API token (``api_key`` or ``token``) and base URL.

        Raises:
            AuthenticationError: when no token is configured.
        """
        super().__init__(config)
        self.token = config.get("api_key") or config.get("token")
        # NOTE(review): the default endpoint is plain HTTP and the token is
        # sent in the request body; confirm whether an HTTPS endpoint is
        # available and configure base_url accordingly.
        self.base_url = config.get("base_url", "http://api.tushare.pro")

        if not self.token:
            raise AuthenticationError("tushare", {"message": "Tushare API token未配置"})

    async def fetch_financial_data(self, symbol: str, market: str) -> FinancialDataResponse:
        """Fetch balance sheet, income, cash flow and key metrics for ``symbol``.

        Raises:
            DataSourceError / APIError: re-raised unchanged.
            DataSourceError: wrapping any other failure.
        """
        try:
            ts_code = self._convert_symbol_format(symbol, market)
            financial_data = await self._retry_request(self._fetch_tushare_financial, ts_code)

            return FinancialDataResponse(
                symbol=symbol,
                market=market,
                data_source="tushare",
                last_updated=datetime.now(),
                balance_sheet=financial_data.get("balance_sheet"),
                income_statement=financial_data.get("income_statement"),
                cash_flow=financial_data.get("cash_flow"),
                key_metrics=financial_data.get("key_metrics")
            )
        except (DataSourceError, APIError):
            raise
        except Exception as e:
            raise DataSourceError(f"获取财务数据失败: {str(e)}", "tushare") from e

    async def fetch_market_data(self, symbol: str, market: str) -> MarketDataResponse:
        """Fetch price, volume and technical-indicator data for ``symbol``.

        Raises:
            DataSourceError / APIError: re-raised unchanged.
            DataSourceError: wrapping any other failure.
        """
        try:
            ts_code = self._convert_symbol_format(symbol, market)
            market_data = await self._retry_request(self._fetch_tushare_market, ts_code)

            return MarketDataResponse(
                symbol=symbol,
                market=market,
                data_source="tushare",
                last_updated=datetime.now(),
                price_data=market_data.get("price_data"),
                volume_data=market_data.get("volume_data"),
                technical_indicators=market_data.get("technical_indicators")
            )
        except (DataSourceError, APIError):
            raise
        except Exception as e:
            raise DataSourceError(f"获取市场数据失败: {str(e)}", "tushare") from e

    async def validate_symbol(self, symbol: str, market: str) -> SymbolValidationResponse:
        """Validate ``symbol``; never raises, failures are reported in the response."""
        try:
            ts_code = self._convert_symbol_format(symbol, market)
            is_valid = await self._retry_request(self._validate_tushare_symbol, ts_code)

            # TODO: return the real company name instead of a placeholder.
            return SymbolValidationResponse(
                symbol=symbol,
                market=market,
                is_valid=is_valid,
                company_name="示例公司" if is_valid else None,
                message="证券代码有效" if is_valid else "证券代码无效"
            )
        except Exception as e:
            return SymbolValidationResponse(
                symbol=symbol,
                market=market,
                is_valid=False,
                message=f"验证失败: {str(e)}"
            )

    async def _health_check(self):
        """Probe connectivity with a minimal ``stock_basic`` call.

        Raises:
            DataSourceError: wrapping any failure of the probe.
        """
        try:
            async with httpx.AsyncClient(timeout=5) as client:
                await self._call_tushare_api(client, "stock_basic", {"limit": 1})
        except Exception as e:
            raise DataSourceError(f"Tushare健康检查失败: {str(e)}", "tushare") from e

    def _convert_symbol_format(self, symbol: str, market: str) -> str:
        """Convert a symbol to Tushare's ``CODE.EXCHANGE`` format.

        For market ``china``: codes starting with ``6`` get ``.SH`` (Shanghai),
        codes starting with ``0`` or ``3`` get ``.SZ`` (Shenzhen). A symbol
        that already carries the suffix is returned unchanged; anything else
        is passed through as-is.
        """
        if market.lower() == "china":
            if symbol.startswith("6"):
                return _with_exchange_suffix(symbol, ".SH")
            if symbol.startswith(("0", "3")):
                return _with_exchange_suffix(symbol, ".SZ")
        return symbol

    async def _fetch_tushare_financial(self, ts_code: str, period: str = "20231231") -> Dict[str, Any]:
        """Fetch and normalize the four financial datasets for one report period.

        Args:
            ts_code: symbol in Tushare format.
            period: report period passed to the API as ``period``
                (defaults to the previously hard-coded ``20231231``).
        """
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            balance_sheet_data = await self._call_tushare_api(
                client, "balancesheet", {"ts_code": ts_code, "period": period}
            )
            income_data = await self._call_tushare_api(
                client, "income", {"ts_code": ts_code, "period": period}
            )
            cashflow_data = await self._call_tushare_api(
                client, "cashflow", {"ts_code": ts_code, "period": period}
            )
            fina_indicator_data = await self._call_tushare_api(
                client, "fina_indicator", {"ts_code": ts_code, "period": period}
            )

            return {
                "balance_sheet": self._process_balance_sheet(balance_sheet_data),
                "income_statement": self._process_income_statement(income_data),
                "cash_flow": self._process_cash_flow(cashflow_data),
                "key_metrics": self._process_key_metrics(fina_indicator_data)
            }

    async def _fetch_tushare_market(
        self,
        ts_code: str,
        start_date: str = "20240101",
        end_date: str = "20241231",
    ) -> Dict[str, Any]:
        """Fetch daily quotes and basic stock info, then normalize them.

        Args:
            ts_code: symbol in Tushare format.
            start_date: passed to the ``daily`` API as ``start_date``.
            end_date: passed to the ``daily`` API as ``end_date``.
        """
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            daily_data = await self._call_tushare_api(
                client, "daily",
                {"ts_code": ts_code, "start_date": start_date, "end_date": end_date}
            )
            stock_basic_data = await self._call_tushare_api(
                client, "stock_basic", {"ts_code": ts_code}
            )

            return {
                "price_data": self._process_price_data(daily_data),
                "volume_data": self._process_volume_data(daily_data),
                "technical_indicators": self._calculate_technical_indicators(daily_data),
                "stock_info": self._process_stock_basic(stock_basic_data)
            }

    async def _validate_tushare_symbol(self, ts_code: str) -> bool:
        """Return True when ``stock_basic`` yields at least one item for ``ts_code``.

        Timeouts and connection errors propagate so that ``_retry_request``
        can retry them; every other failure is treated as "not valid".
        """
        try:
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                result = await self._call_tushare_api(
                    client, "stock_basic", {"ts_code": ts_code}
                )
                return bool(result and len(result.get("items") or []) > 0)
        except _TRANSIENT_ERRORS:
            # A network hiccup says nothing about the symbol's validity.
            raise
        except Exception:
            return False

    async def _call_tushare_api(self, client: httpx.AsyncClient, api_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """POST one request to the Tushare API and return its ``data`` payload.

        Returns:
            The ``data`` member of the JSON response, or ``{}`` when absent/null.

        Raises:
            AuthenticationError: HTTP 401, or an API error message mentioning
                permissions/token.
            RateLimitError: HTTP 429, or an API error message mentioning
                frequency/limit.
            APIError: any other HTTP status error or non-zero API ``code``.
            httpx.TimeoutException / httpx.ConnectError: propagated unchanged
                so that ``_retry_request`` can retry them.
            DataSourceError: any other httpx request error.
        """
        request_data = {
            "api_name": api_name,
            "token": self.token,
            "params": params,
            "fields": ""
        }

        try:
            response = await client.post(self.base_url, json=request_data)
            response.raise_for_status()
            result = response.json()
        except httpx.HTTPStatusError as e:
            if e.response.status_code == 401:
                raise AuthenticationError("tushare", {"status_code": e.response.status_code}) from e
            elif e.response.status_code == 429:
                raise RateLimitError("tushare") from e
            else:
                raise APIError(f"HTTP错误: {e.response.status_code}", e.response.status_code, "tushare") from e
        except _TRANSIENT_ERRORS:
            # Must precede the RequestError handler below (they are its
            # subclasses); wrapping them here would disable the retry logic.
            raise
        except httpx.RequestError as e:
            raise DataSourceError(f"网络请求失败: {str(e)}", "tushare") from e

        if result.get("code") != 0:
            # "msg" may be present but null; fall back to a generic text.
            error_msg = result.get("msg") or "Unknown error"
            if "权限" in error_msg or "token" in error_msg.lower():
                raise AuthenticationError("tushare", {"message": error_msg})
            elif "频率" in error_msg or "limit" in error_msg.lower():
                raise RateLimitError("tushare")
            else:
                raise APIError(f"Tushare API错误: {error_msg}", result.get("code"), "tushare")

        return result.get("data") or {}

    @staticmethod
    def _latest_record(data: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """Return the first row of a Tushare payload as a field->value dict.

        Accepts rows given as lists (zipped with ``data["fields"]``), rows
        given as dicts, or ``items`` being a single flat row. Returns ``None``
        when there is no row or it cannot be mapped to field names.
        """
        if not data:
            return None
        items = data.get("items")
        if not items:
            return None
        if isinstance(items, dict):
            return items

        first = items[0]
        if isinstance(first, dict):
            return first

        fields = data.get("fields") or []
        if not fields:
            return None
        # Either a list of rows (use the first row) or one flat row.
        row = first if isinstance(first, (list, tuple)) else items
        return dict(zip(fields, row))

    def _extract_latest(self, data: Optional[Dict[str, Any]], keys: Sequence[str], default: Any = 0) -> Dict[str, Any]:
        """Pick ``keys`` from the latest record, or return ``{}`` when there is none."""
        record = self._latest_record(data)
        if record is None:
            return {}
        return {key: record.get(key, default) for key in keys}

    def _process_balance_sheet(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Extract balance sheet figures from the latest record."""
        return self._extract_latest(data, self._BALANCE_SHEET_KEYS)

    def _process_income_statement(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Extract income statement figures from the latest record."""
        return self._extract_latest(data, self._INCOME_KEYS)

    def _process_cash_flow(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Extract cash flow figures from the latest record."""
        return self._extract_latest(data, self._CASH_FLOW_KEYS)

    def _process_key_metrics(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Extract key financial indicators from the latest record."""
        return self._extract_latest(data, self._KEY_METRIC_KEYS)

    def _process_price_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Extract price fields from the first (latest) daily record."""
        return self._extract_latest(data, self._PRICE_KEYS)

    def _process_volume_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Extract volume and turnover from the first (latest) daily record."""
        return self._extract_latest(data, self._VOLUME_KEYS)

    def _calculate_technical_indicators(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Compute 5- and 20-row simple moving averages of the close price.

        Uses the first 20 rows of ``items``. Returns ``{}`` when fewer than
        20 rows are available. Rows without a numeric close are skipped; an
        average is reported as 0 when too few closes remain.
        """
        if not data:
            return {}
        items = data.get("items")
        if not isinstance(items, (list, tuple)) or len(items) < 20:
            return {}

        # Resolve the column position once instead of once per row.
        fields = data.get("fields") or []
        close_idx = fields.index("close") if "close" in fields else -1

        closes = []
        for item in items[:20]:
            if isinstance(item, (list, tuple)):
                value = item[close_idx] if 0 <= close_idx < len(item) else None
            elif isinstance(item, dict):
                value = item.get("close")
            else:
                value = None
            # Skip missing/null closes rather than averaging in zeros.
            if isinstance(value, (int, float)):
                closes.append(value)

        ma_5 = sum(closes[:5]) / 5 if len(closes) >= 5 else 0
        ma_20 = sum(closes) / 20 if len(closes) >= 20 else 0

        return {
            "ma_5": ma_5,
            "ma_20": ma_20,
            "ma_60": 0  # TODO: needs 60 rows of data to compute.
        }

    def _process_stock_basic(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Extract basic stock information from the first record."""
        return self._extract_latest(data, self._STOCK_BASIC_KEYS, default="")


class YahooDataFetcher(DataFetcher):
    """Fetcher for Yahoo Finance (data calls are currently stubbed)."""

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.base_url = config.get("base_url", "https://query1.finance.yahoo.com")

    async def fetch_financial_data(self, symbol: str, market: str) -> FinancialDataResponse:
        """Fetch financial data for ``symbol``.

        Raises:
            DataSourceError / APIError: re-raised unchanged.
            DataSourceError: wrapping any other failure.
        """
        try:
            yahoo_symbol = self._convert_symbol_format(symbol, market)
            financial_data = await self._retry_request(self._fetch_yahoo_financial, yahoo_symbol)

            return FinancialDataResponse(
                symbol=symbol,
                market=market,
                data_source="yahoo",
                last_updated=datetime.now(),
                balance_sheet=financial_data.get("balance_sheet"),
                income_statement=financial_data.get("income_statement"),
                cash_flow=financial_data.get("cash_flow"),
                key_metrics=financial_data.get("key_metrics")
            )
        except (DataSourceError, APIError):
            raise
        except Exception as e:
            raise DataSourceError(f"获取财务数据失败: {str(e)}", "yahoo") from e

    async def fetch_market_data(self, symbol: str, market: str) -> MarketDataResponse:
        """Fetch market data for ``symbol``.

        Raises:
            DataSourceError / APIError: re-raised unchanged.
            DataSourceError: wrapping any other failure.
        """
        try:
            yahoo_symbol = self._convert_symbol_format(symbol, market)
            market_data = await self._retry_request(self._fetch_yahoo_market, yahoo_symbol)

            return MarketDataResponse(
                symbol=symbol,
                market=market,
                data_source="yahoo",
                last_updated=datetime.now(),
                price_data=market_data.get("price_data"),
                volume_data=market_data.get("volume_data"),
                technical_indicators=market_data.get("technical_indicators")
            )
        except (DataSourceError, APIError):
            raise
        except Exception as e:
            raise DataSourceError(f"获取市场数据失败: {str(e)}", "yahoo") from e

    async def validate_symbol(self, symbol: str, market: str) -> SymbolValidationResponse:
        """Validate ``symbol``; never raises, failures are reported in the response."""
        try:
            yahoo_symbol = self._convert_symbol_format(symbol, market)
            is_valid = await self._retry_request(self._validate_yahoo_symbol, yahoo_symbol)

            # TODO: return the real company name instead of a placeholder.
            return SymbolValidationResponse(
                symbol=symbol,
                market=market,
                is_valid=is_valid,
                company_name="Example Company" if is_valid else None,
                message="Symbol is valid" if is_valid else "Symbol not found"
            )
        except Exception as e:
            return SymbolValidationResponse(
                symbol=symbol,
                market=market,
                is_valid=False,
                message=f"Validation failed: {str(e)}"
            )

    async def _health_check(self):
        """Probe the search endpoint.

        Raises:
            APIError: when the response status is not 200.
        """
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            response = await client.get(f"{self.base_url}/v1/finance/search?q=AAPL")
            if response.status_code != 200:
                raise APIError(f"Yahoo Finance API返回状态码: {response.status_code}", response.status_code, "yahoo")

    def _convert_symbol_format(self, symbol: str, market: str) -> str:
        """Convert a symbol to Yahoo Finance format.

        ``hongkong`` -> ``.HK``, ``japan`` -> ``.T``; for ``china`` codes
        starting with ``6`` get ``.SS`` and codes starting with ``0``/``3``
        get ``.SZ``. A symbol that already carries the suffix is returned
        unchanged; anything else is passed through as-is.
        """
        market_key = market.lower()
        if market_key == "hongkong":
            return _with_exchange_suffix(symbol, ".HK")
        if market_key == "japan":
            return _with_exchange_suffix(symbol, ".T")
        if market_key == "china":
            if symbol.startswith("6"):
                return _with_exchange_suffix(symbol, ".SS")
            if symbol.startswith(("0", "3")):
                return _with_exchange_suffix(symbol, ".SZ")
        return symbol

    async def _fetch_yahoo_financial(self, yahoo_symbol: str) -> Dict[str, Any]:
        """Return financial data for ``yahoo_symbol`` (stub: fixed sample values)."""
        # TODO: implement the real API call.
        return {
            "balance_sheet": {"totalAssets": 2000000},
            "income_statement": {"totalRevenue": 800000},
            "cash_flow": {"operatingCashflow": 300000},
            "key_metrics": {"trailingPE": 18.5}
        }

    async def _fetch_yahoo_market(self, yahoo_symbol: str) -> Dict[str, Any]:
        """Return market data for ``yahoo_symbol`` (stub: fixed sample values)."""
        # TODO: implement the real API call.
        return {
            "price_data": {"regularMarketPrice": 150.0, "dayHigh": 155.0, "dayLow": 145.0},
            "volume_data": {"regularMarketVolume": 2000000},
            "technical_indicators": {"fiftyDayAverage": 148.0, "twoHundredDayAverage": 152.0}
        }

    async def _validate_yahoo_symbol(self, yahoo_symbol: str) -> bool:
        """Validate ``yahoo_symbol`` (stub: always True)."""
        # TODO: implement real validation.
        return True


class DataFetcherFactory:
    """Factory that maps data-source names to fetcher classes."""

    # Registry shared by the whole process; keys are lower-case source names.
    _fetchers: Dict[str, Type[DataFetcher]] = {
        "tushare": TushareDataFetcher,
        "yahoo": YahooDataFetcher,
    }

    @classmethod
    def create_fetcher(cls, data_source: str, config: Dict[str, Any]) -> DataFetcher:
        """Instantiate the fetcher registered for ``data_source`` (case-insensitive).

        Raises:
            DataSourceError: when the data source is not registered.
        """
        fetcher_class = cls._fetchers.get(data_source.lower())
        if fetcher_class is None:
            raise DataSourceError(
                f"不支持的数据源: {data_source}",
                data_source,
                {"supported_sources": list(cls._fetchers.keys())}
            )
        return fetcher_class(config)

    @classmethod
    def get_supported_sources(cls) -> List[str]:
        """Return the names of all registered data sources."""
        return list(cls._fetchers.keys())

    @classmethod
    def register_fetcher(cls, name: str, fetcher_class: type):
        """Register ``fetcher_class`` under ``name`` (stored lower-case).

        Raises:
            ValueError: when ``fetcher_class`` is not a DataFetcher subclass.
        """
        # issubclass() raises TypeError for non-class arguments, so check
        # that we were given a class first and report both cases uniformly.
        if not (isinstance(fetcher_class, type) and issubclass(fetcher_class, DataFetcher)):
            raise ValueError("数据获取器必须继承自DataFetcher基类")
        cls._fetchers[name.lower()] = fetcher_class