import asyncio
import logging
from typing import Any, Dict, List, Optional

import httpx

from .base import BaseDataProvider

logger = logging.getLogger(__name__)

TUSHARE_PRO_URL = "https://api.tushare.pro"

# Page size sent to Tushare when the caller does not pass an explicit ``limit``.
_DEFAULT_QUERY_LIMIT = 5000

# Identifier / bookkeeping columns of a merged statement that are not metrics
# and therefore never become a series in ``get_financial_statements``.
_NON_METRIC_KEYS = frozenset({
    'ts_code', 'end_date', 'ann_date', 'f_ann_date', 'report_type',
    'comp_type', 'end_type', 'update_flag', 'period',
})


class TushareProvider(BaseDataProvider):
    """Data provider backed by the Tushare Pro HTTP API.

    Every endpoint is reached through :meth:`_query`, which POSTs a JSON
    payload to ``TUSHARE_PRO_URL`` and converts the column-oriented
    ``{"fields": [...], "items": [[...], ...]}`` response into row dicts.
    """

    def _initialize(self):
        """Validate the API token and create the shared async HTTP client.

        Raises:
            ValueError: if ``self.token`` is falsy.
        """
        if not self.token:
            raise ValueError("Tushare API token not provided.")
        # A single AsyncClient is reused for every request made by this provider.
        # NOTE(review): nothing in this class closes the client; confirm the owner
        # (or BaseDataProvider) eventually awaits ``self._client.aclose()``.
        self._client = httpx.AsyncClient(timeout=30)

    async def _query(
        self,
        api_name: str,
        params: Optional[Dict[str, Any]] = None,
        fields: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Call one Tushare Pro endpoint and return its rows as dicts.

        Args:
            api_name: Tushare endpoint name, e.g. ``"daily"``.
            params: Endpoint parameters. The caller's dict is never modified;
                ``limit`` defaults to ``_DEFAULT_QUERY_LIMIT`` when absent.
            fields: Optional comma-separated column list, sent as ``fields``.

        Returns:
            One dict per returned row, keyed by the response's field names.

        Raises:
            httpx.HTTPStatusError: on a non-2xx HTTP response.
            RuntimeError: when the response's ``code`` is not 0.
            Exception: any other transport/decoding error is logged and re-raised.
        """
        # Copy so the injected default limit never leaks into the caller's dict.
        query_params: Dict[str, Any] = dict(params or {})
        query_params.setdefault("limit", _DEFAULT_QUERY_LIMIT)
        payload: Dict[str, Any] = {
            "api_name": api_name,
            "token": self.token,
            "params": query_params,
        }
        if fields:
            payload["fields"] = fields

        logger.info(f"Querying Tushare API '{api_name}' with params: {query_params}")
        try:
            resp = await self._client.post(TUSHARE_PRO_URL, json=payload)
            resp.raise_for_status()
            data = resp.json()
        except httpx.HTTPStatusError as e:
            logger.error(
                f"HTTP error calling Tushare API '{api_name}': "
                f"{e.response.status_code} - {e.response.text}"
            )
            raise
        except Exception as e:
            logger.error(f"Exception calling Tushare API '{api_name}': {e}")
            raise

        # Checked outside the try so an API-level error is logged exactly once.
        if data.get("code") != 0:
            err_msg = data.get("msg") or "Unknown Tushare error"
            logger.error(f"Tushare API error for '{api_name}': {err_msg}")
            raise RuntimeError(f"{api_name}: {err_msg}")

        # Guard against a null ``data``/``fields``/``items`` member.
        body = data.get("data") or {}
        fields_def = body.get("fields") or []
        items = body.get("items") or []
        rows = [dict(zip(fields_def, item)) for item in items]
        logger.info(f"Tushare API '{api_name}' returned {len(rows)} rows.")
        return rows

    async def _query_per_date(
        self, api_name: str, stock_code: str, trade_dates: List[str]
    ) -> List[Dict[str, Any]]:
        """Query ``api_name`` once per trade date concurrently and concatenate the rows.

        A date whose query raises is skipped (the number of skipped dates is
        logged as a warning); the remaining rows are returned in date order.
        """
        results = await asyncio.gather(
            *(
                self._query(
                    api_name=api_name,
                    params={"ts_code": stock_code, "trade_date": d},
                )
                for d in trade_dates
            ),
            return_exceptions=True,
        )
        rows: List[Dict[str, Any]] = []
        failed = 0
        for res in results:
            if isinstance(res, BaseException):
                failed += 1
            elif res:
                rows.extend(res)
        if failed:
            logger.warning(
                f"Tushare {api_name} failed for {failed} of {len(trade_dates)} dates for {stock_code}"
            )
        return rows

    async def get_stock_basic(self, stock_code: str) -> Optional[Dict[str, Any]]:
        """Return the first ``stock_basic`` row for ``stock_code``, or None on no data/error."""
        try:
            rows = await self._query(
                api_name="stock_basic",
                params={"ts_code": stock_code},
            )
            return rows[0] if rows else None
        except Exception as e:
            logger.error(f"Tushare get_stock_basic failed for {stock_code}: {e}")
            return None

    async def get_daily_price(self, stock_code: str, start_date: str, end_date: str) -> List[Dict[str, Any]]:
        """Return ``daily`` rows for ``stock_code`` between the two dates; [] on error."""
        try:
            return await self._query(
                api_name="daily",
                params={
                    "ts_code": stock_code,
                    "start_date": start_date,
                    "end_date": end_date,
                },
            )
        except Exception as e:
            logger.error(f"Tushare get_daily_price failed for {stock_code}: {e}")
            return []

    async def get_daily_basic_points(self, stock_code: str, trade_dates: List[str]) -> List[Dict[str, Any]]:
        """
        Fetch ``daily_basic`` data (e.g. total_mv, pe, pb) for each given trade date.

        Returns the concatenated rows; dates that fail are skipped, and [] is
        returned if the whole fan-out fails.
        """
        try:
            rows = await self._query_per_date("daily_basic", stock_code, trade_dates)
            logger.info(f"Tushare daily_basic returned {len(rows)} rows for {stock_code} on {len(trade_dates)} dates")
            return rows
        except Exception as e:
            logger.error(f"Tushare get_daily_basic_points failed for {stock_code}: {e}")
            return []

    async def get_daily_points(self, stock_code: str, trade_dates: List[str]) -> List[Dict[str, Any]]:
        """
        Fetch daily quotes (e.g. close) for each given trade date.

        Returns the concatenated rows; dates that fail are skipped, and [] is
        returned if the whole fan-out fails.
        """
        try:
            rows = await self._query_per_date("daily", stock_code, trade_dates)
            logger.info(f"Tushare daily returned {len(rows)} rows for {stock_code} on {len(trade_dates)} dates")
            return rows
        except Exception as e:
            logger.error(f"Tushare get_daily_points failed for {stock_code}: {e}")
            return []

    def _calculate_derived_metrics(self, series: Dict[str, List[Dict]], years: List[str]) -> Dict[str, List[Dict]]:
        """
        Compute derived metrics from the raw yearly series, inside the Tushare provider.

        ``series`` maps a metric key to ``[{"year": "YYYY", "value": ...}, ...]``.
        New series (mostly ``__``-prefixed, plus ``payturn_days``) are added to
        ``series`` in place, only when they have at least one data point, and
        the same dict is returned. Ratios are expressed in percent unless noted.
        """

        # --- Helper Functions ---
        def _get_value(key: str, year: str) -> Optional[float]:
            """Return the first value of ``series[key]`` for ``year`` as a float, or None."""
            if key not in series:
                return None
            point = next((p for p in series[key] if p.get("year") == year), None)
            if point is None or point.get("value") is None:
                return None
            try:
                return float(point["value"])
            except (ValueError, TypeError):
                return None

        def _get_avg_value(key: str, year: str) -> Optional[float]:
            """Average of ``year`` and the previous year; falls back to ``year`` alone."""
            current_val = _get_value(key, year)
            try:
                prev_year = str(int(year) - 1)
                prev_val = _get_value(key, prev_year)
            except (ValueError, TypeError):
                prev_val = None
            if current_val is None:
                return None
            if prev_val is None:
                return current_val
            return (current_val + prev_val) / 2

        def _as_fraction(raw: float) -> float:
            """Treat ``|raw| > 1`` as a percentage and convert it to a fraction."""
            # NOTE(review): heuristic — a genuine sub-1% margin reported in percent
            # (e.g. 0.5) is read as 50%; confirm the units of the source fields.
            return raw / 100.0 if abs(raw) > 1 else raw

        def _get_cogs(year: str) -> Optional[float]:
            """Approximate cost of goods sold as ``revenue * (1 - gross margin)``."""
            revenue = _get_value('revenue', year)
            gp_margin_raw = _get_value('grossprofit_margin', year)
            if revenue is None or gp_margin_raw is None:
                return None
            return revenue * (1 - _as_fraction(gp_margin_raw))

        def _pct_series(num_key: str, den_key: str) -> List[Dict]:
            """Yearly ``num_key / den_key * 100``; years with a missing value or zero denominator are skipped."""
            data = []
            for year in years:
                numerator = _get_value(num_key, year)
                denominator = _get_value(den_key, year)
                if numerator is not None and denominator is not None and denominator != 0:
                    data.append({"year": year, "value": (numerator / denominator) * 100})
            return data

        def add_series(key: str, data: List[Dict]):
            """Store ``data`` under ``key`` only when it is non-empty."""
            if data:
                series[key] = data

        # --- Calculations ---
        # Free cash flow = operating cash flow - capital expenditure (absolute value).
        fcf_data = []
        for year in years:
            op_cashflow = _get_value('n_cashflow_act', year)
            capex = _get_value('c_pay_acq_const_fiolta', year)
            if op_cashflow is not None and capex is not None:
                fcf_data.append({"year": year, "value": op_cashflow - capex})
        add_series('__free_cash_flow', fcf_data)

        # Expense items as a percentage of revenue.
        fee_calcs = [
            ('__sell_rate', 'sell_exp', 'revenue'),
            ('__admin_rate', 'admin_exp', 'revenue'),
            ('__rd_rate', 'rd_exp', 'revenue'),
            ('__depr_ratio', 'depr_fa_coga_dpba', 'revenue'),
        ]
        for key, num_key, den_key in fee_calcs:
            add_series(key, _pct_series(num_key, den_key))

        # Tax rate, normalised to percent (values with |x| <= 1 are taken as fractions).
        tax_rate_data = []
        for year in years:
            tax_to_ebt = _get_value('tax_to_ebt', year)
            if tax_to_ebt is not None:
                rate = tax_to_ebt * 100 if abs(tax_to_ebt) <= 1 else tax_to_ebt
                tax_rate_data.append({"year": year, "value": rate})
        add_series('__tax_rate', tax_rate_data)

        # Residual expense rate: gross margin - net margin - selling/admin/R&D rates.
        other_fee_data = []
        for year in years:
            gp_raw = _get_value('grossprofit_margin', year)
            np_raw = _get_value('netprofit_margin', year)
            rev = _get_value('revenue', year)
            sell_exp = _get_value('sell_exp', year)
            admin_exp = _get_value('admin_exp', year)
            rd_exp = _get_value('rd_exp', year)
            if all(v is not None for v in [gp_raw, np_raw, rev, sell_exp, admin_exp, rd_exp]) and rev != 0:
                gp_margin = _as_fraction(gp_raw)
                np_margin = _as_fraction(np_raw)
                sell_rate = sell_exp / rev
                admin_rate = admin_exp / rev
                rd_rate = rd_exp / rev
                other_rate = (gp_margin - np_margin - sell_rate - admin_rate - rd_rate) * 100
                other_fee_data.append({"year": year, "value": other_rate})
        add_series('__other_fee_rate', other_fee_data)

        # Balance-sheet items as a percentage of total assets.
        asset_ratio_keys = [
            ('__money_cap_ratio', 'money_cap'),
            ('__inventories_ratio', 'inventories'),
            ('__ar_ratio', 'accounts_receiv_bill'),
            ('__prepay_ratio', 'prepayment'),
            ('__fix_assets_ratio', 'fix_assets'),
            ('__lt_invest_ratio', 'lt_eqt_invest'),
            ('__goodwill_ratio', 'goodwill'),
            ('__ap_ratio', 'accounts_pay'),
            ('__st_borr_ratio', 'st_borr'),
            ('__lt_borr_ratio', 'lt_borr'),
        ]
        for key, num_key in asset_ratio_keys:
            add_series(key, _pct_series(num_key, 'total_assets'))

        # Advance receipts + contract liabilities over total assets (missing parts count as 0).
        adv_data = []
        for year in years:
            adv = _get_value('adv_receipts', year) or 0
            contract = _get_value('contract_liab', year) or 0
            total_assets = _get_value('total_assets', year)
            if total_assets is not None and total_assets != 0:
                adv_data.append({"year": year, "value": ((adv + contract) / total_assets) * 100})
        add_series('__adv_ratio', adv_data)

        # Share of total assets not covered by the itemised asset keys below.
        other_assets_data = []
        known_assets_keys = ['money_cap', 'inventories', 'accounts_receiv_bill', 'prepayment', 'fix_assets', 'lt_eqt_invest', 'goodwill']
        for year in years:
            total_assets = _get_value('total_assets', year)
            if total_assets is not None and total_assets != 0:
                sum_known = sum(_get_value(k, year) or 0 for k in known_assets_keys)
                other_assets_data.append({"year": year, "value": ((total_assets - sum_known) / total_assets) * 100})
        add_series('__other_assets_ratio', other_assets_data)

        # Net operating assets: inventories + receivables + prepayments
        # - payables - advance receipts - contract liabilities, over total assets.
        op_assets_data = []
        for year in years:
            total_assets = _get_value('total_assets', year)
            if total_assets is not None and total_assets != 0:
                inv = _get_value('inventories', year) or 0
                ar = _get_value('accounts_receiv_bill', year) or 0
                pre = _get_value('prepayment', year) or 0
                ap = _get_value('accounts_pay', year) or 0
                adv = _get_value('adv_receipts', year) or 0
                contract_liab = _get_value('contract_liab', year) or 0
                operating_assets = inv + ar + pre - ap - adv - contract_liab
                op_assets_data.append({"year": year, "value": (operating_assets / total_assets) * 100})
        add_series('__operating_assets_ratio', op_assets_data)

        # Short- plus long-term borrowings over total assets.
        debt_ratio_data = []
        for year in years:
            total_assets = _get_value('total_assets', year)
            if total_assets is not None and total_assets != 0:
                st_borr = _get_value('st_borr', year) or 0
                lt_borr = _get_value('lt_borr', year) or 0
                debt_ratio_data.append({"year": year, "value": ((st_borr + lt_borr) / total_assets) * 100})
        add_series('__interest_bearing_debt_ratio', debt_ratio_data)

        # Days payable: 365 * average accounts payable / approximated COGS (in days, not percent).
        payturn_data = []
        for year in years:
            avg_ap = _get_avg_value('accounts_pay', year)
            cogs = _get_cogs(year)
            if avg_ap is not None and cogs is not None and cogs != 0:
                payturn_data.append({"year": year, "value": (365 * avg_ap) / cogs})
        add_series('payturn_days', payturn_data)

        # Per-employee figures, divided by 10000 (not percent).
        per_capita_calcs = [
            ('__rev_per_emp', 'revenue', 10000),
            ('__profit_per_emp', 'n_income', 10000),
            ('__salary_per_emp', 'c_paid_to_for_empl', 10000),
        ]
        for key, num_key, divisor in per_capita_calcs:
            data = []
            for year in years:
                numerator = _get_value(num_key, year)
                employees = _get_value('employees', year)
                if numerator is not None and employees is not None and employees != 0:
                    data.append({"year": year, "value": (numerator / employees) / divisor})
            add_series(key, data)

        return series

    async def _fetch_statement(self, stock_code: str, date: str) -> Optional[Dict[str, Any]]:
        """Fetch and merge the four statement endpoints for one report period.

        Rows are merged in the order balancesheet, income, cashflow,
        fina_indicator; a later endpoint's value wins on duplicate keys.
        Returns None (after logging) when every endpoint is empty or any query fails.
        """
        logger.info(f"Fetching financial statements for {stock_code}, report date: {date}")
        try:
            bs_rows, ic_rows, cf_rows, fi_rows = await asyncio.gather(
                self._query(
                    api_name="balancesheet",
                    params={"ts_code": stock_code, "period": date, "report_type": 1},
                ),
                self._query(
                    api_name="income",
                    params={"ts_code": stock_code, "period": date, "report_type": 1},
                ),
                self._query(
                    api_name="cashflow",
                    params={"ts_code": stock_code, "period": date, "report_type": 1},
                ),
                # Key financial ratios (ROE / ROA / gross margin, ...).
                self._query(
                    api_name="fina_indicator",
                    params={"ts_code": stock_code, "period": date},
                ),
            )
        except Exception as e:
            logger.error(f"Tushare get_financial_statement failed for {stock_code} on {date}: {e}")
            return None

        if not (bs_rows or ic_rows or cf_rows or fi_rows):
            logger.warning(f"No financial statements components found from Tushare for {stock_code} on {date}")
            return None

        merged: Dict[str, Any] = {"ts_code": stock_code, "end_date": date}
        # NOTE(review): only the first row of each endpoint is used; if an endpoint
        # can return several rows per period, confirm which one comes first.
        for rows in (bs_rows, ic_rows, cf_rows, fi_rows):
            if rows:
                merged.update(rows[0])
        merged["end_date"] = merged.get("end_date") or merged.get("period") or date
        logger.debug(f"Merged statement for {date} has keys: {list(merged.keys())}")
        return merged

    @staticmethod
    def _statements_to_series(statements: List[Dict[str, Any]]) -> Dict[str, List[Dict]]:
        """Pivot merged statements into ``{metric: [{"year", "value"}, ...]}``.

        Only int/float values become points, and metadata columns in
        ``_NON_METRIC_KEYS`` are skipped. The year is the first four characters
        of ``end_date``; when several statements share a year, the first one in
        ``statements`` order wins for each metric.
        """
        series: Dict[str, List[Dict]] = {}
        seen_years: Dict[str, set] = {}
        for report in statements:
            year = report.get("end_date", "")[:4]
            if not year:
                continue
            for key, value in report.items():
                if key in _NON_METRIC_KEYS or not isinstance(value, (int, float)):
                    continue
                years_for_key = seen_years.setdefault(key, set())
                if year in years_for_key:
                    continue
                years_for_key.add(year)
                series.setdefault(key, []).append({"year": year, "value": value})
        return series

    async def get_financial_statements(self, stock_code: str, report_dates: List[str]) -> Dict[str, List[Dict[str, Any]]]:
        """Fetch statements for each report date and return yearly metric series.

        Periods are fetched one after another (the four endpoints of a single
        period run concurrently). Periods that fail or return nothing are
        skipped. Derived metrics from :meth:`_calculate_derived_metrics` are
        added to the returned series.
        """
        all_statements: List[Dict[str, Any]] = []
        for date in report_dates:
            statement = await self._fetch_statement(stock_code, date)
            if statement is not None:
                all_statements.append(statement)

        logger.info(f"Successfully fetched {len(all_statements)} statement(s) for {stock_code}.")

        # Transform to series format, then add derived metrics.
        series = self._statements_to_series(all_statements)
        years = sorted({d['year'] for s in series.values() for d in s})
        return self._calculate_derived_metrics(series, years)