"""Tushare Pro data provider backed by an async ``httpx`` client."""

import asyncio
import logging
from typing import Any, Dict, List, Optional

import httpx

from .base import BaseDataProvider

logger = logging.getLogger(__name__)

# Every request is POSTed to this one URL; the API to invoke is named by the
# "api_name" field of the JSON body.
TUSHARE_PRO_URL = "https://api.tushare.pro"

# Row cap sent with a query when the caller does not supply its own "limit".
_DEFAULT_ROW_LIMIT = 5000

# API names queried for each report period, in merge order (later results
# override earlier ones on key collisions).
_STATEMENT_APIS = ("balancesheet", "income", "cashflow")


class TushareProvider(BaseDataProvider):
    """Data provider that queries the Tushare Pro HTTP API.

    NOTE(review): ``self.token`` is read here but never assigned in this
    module; presumably ``BaseDataProvider`` sets it before calling
    ``_initialize`` -- confirm against the base class.
    """

    def _initialize(self):
        """Validate the API token and create the shared async HTTP client.

        Raises:
            ValueError: If ``self.token`` is falsy.
        """
        if not self.token:
            raise ValueError("Tushare API token not provided.")
        # NOTE(review): nothing in this module closes the client; confirm the
        # owner of this provider eventually awaits ``self._client.aclose()``.
        self._client = httpx.AsyncClient(timeout=30)

    async def _query(
        self,
        api_name: str,
        params: Optional[Dict[str, Any]] = None,
        fields: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """POST one query to Tushare and return the result as a list of dicts.

        Args:
            api_name: Value sent as ``api_name`` in the request body.
            params: Query parameters. The dict is copied, never mutated; a
                ``limit`` of 5000 is added to the copy when none is given.
            fields: Optional field selector; sent only when truthy.

        Returns:
            One dict per returned item, keyed by the response's field names.

        Raises:
            httpx.HTTPStatusError: If the HTTP response status is an error.
            RuntimeError: If the response body's ``code`` is not ``0``.
            Exception: Any other transport or JSON-decoding error is logged
                and re-raised unchanged.
        """
        # Copy so the default limit never leaks into the caller's dict.
        query_params: Dict[str, Any] = dict(params or {})
        query_params.setdefault("limit", _DEFAULT_ROW_LIMIT)

        payload: Dict[str, Any] = {
            "api_name": api_name,
            "token": self.token,
            "params": query_params,
        }
        if fields:
            payload["fields"] = fields

        # Only the params are logged; the payload itself carries the token.
        logger.info(
            "Querying Tushare API '%s' with params: %s", api_name, query_params
        )

        # Keep the try body to the network round trip and JSON decoding so the
        # RuntimeError raised below is not caught and logged a second time.
        try:
            resp = await self._client.post(TUSHARE_PRO_URL, json=payload)
            resp.raise_for_status()
            data = resp.json()
        except httpx.HTTPStatusError as e:
            logger.error(
                "HTTP error calling Tushare API '%s': %s - %s",
                api_name,
                e.response.status_code,
                e.response.text,
            )
            raise
        except Exception as e:
            logger.error("Exception calling Tushare API '%s': %s", api_name, e)
            raise

        if data.get("code") != 0:
            err_msg = data.get("msg") or "Unknown Tushare error"
            logger.error("Tushare API error for '%s': %s", api_name, err_msg)
            raise RuntimeError(f"{api_name}: {err_msg}")

        # "data" may be present but null; treat that the same as missing.
        result = data.get("data") or {}
        column_names = result.get("fields") or []
        items = result.get("items") or []
        rows: List[Dict[str, Any]] = [
            dict(zip(column_names, item)) for item in items
        ]

        logger.info("Tushare API '%s' returned %d rows.", api_name, len(rows))
        return rows

    async def get_stock_basic(self, stock_code: str) -> Optional[Dict[str, Any]]:
        """Return the first ``stock_basic`` row for ``stock_code``.

        Best-effort: any failure is logged and reported as ``None``.

        Args:
            stock_code: Value sent as the ``ts_code`` query parameter.

        Returns:
            The first returned row, or ``None`` when there are no rows or the
            query fails.
        """
        try:
            rows = await self._query(
                api_name="stock_basic",
                params={"ts_code": stock_code},
            )
        except Exception as e:
            logger.error(
                "Tushare get_stock_basic failed for %s: %s", stock_code, e
            )
            return None
        return rows[0] if rows else None

    async def get_daily_price(
        self, stock_code: str, start_date: str, end_date: str
    ) -> List[Dict[str, Any]]:
        """Return ``daily`` rows for ``stock_code`` between two dates.

        Best-effort: any failure is logged and reported as an empty list.

        Args:
            stock_code: Value sent as the ``ts_code`` query parameter.
            start_date: Value sent as the ``start_date`` query parameter.
            end_date: Value sent as the ``end_date`` query parameter.

        Returns:
            The rows returned by the query, or ``[]`` when the query fails.
        """
        try:
            return await self._query(
                api_name="daily",
                params={
                    "ts_code": stock_code,
                    "start_date": start_date,
                    "end_date": end_date,
                },
            )
        except Exception as e:
            logger.error(
                "Tushare get_daily_price failed for %s: %s", stock_code, e
            )
            return []

    async def get_financial_statements(
        self, stock_code: str, report_dates: List[str]
    ) -> List[Dict[str, Any]]:
        """Fetch and merge the three financial statements per report date.

        For each date the ``balancesheet``, ``income`` and ``cashflow`` APIs
        are queried concurrently and the first row of each is merged into one
        dict; on key collisions the later statement wins. Dates for which all
        three queries return nothing, or for which any query fails, are
        logged and skipped.

        Args:
            stock_code: Value sent as the ``ts_code`` query parameter.
            report_dates: Values sent, one at a time, as the ``period``
                query parameter.

        Returns:
            One merged dict per successfully fetched report date, in the
            order of ``report_dates``.
        """
        all_statements: List[Dict[str, Any]] = []

        for date in report_dates:
            logger.info(
                "Fetching financial statements for %s, report date: %s",
                stock_code,
                date,
            )
            try:
                # One dict can be shared by the concurrent queries because
                # _query copies its params before adding the default limit.
                params = {"ts_code": stock_code, "period": date, "report_type": 1}
                bs_rows, ic_rows, cf_rows = await asyncio.gather(
                    *(
                        self._query(api_name=name, params=params)
                        for name in _STATEMENT_APIS
                    )
                )

                if not (bs_rows or ic_rows or cf_rows):
                    logger.warning(
                        "No financial statements components found from "
                        "Tushare for %s on %s",
                        stock_code,
                        date,
                    )
                    continue

                merged: Dict[str, Any] = {"ts_code": stock_code, "end_date": date}
                for statement_rows in (bs_rows, ic_rows, cf_rows):
                    if statement_rows:
                        merged.update(statement_rows[0])

                # A statement row may overwrite end_date with an empty value;
                # fall back to its "period" or the requested date.
                merged["end_date"] = (
                    merged.get("end_date") or merged.get("period") or date
                )

                logger.debug(
                    "Merged statement for %s has keys: %s",
                    date,
                    list(merged.keys()),
                )
                all_statements.append(merged)
            except Exception as e:
                logger.error(
                    "Tushare get_financial_statement failed for %s on %s: %s",
                    stock_code,
                    date,
                    e,
                )
                continue

        logger.info(
            "Successfully fetched %d statement(s) for %s.",
            len(all_statements),
            stock_code,
        )
        return all_statements