import asyncio
import bisect
import datetime
import functools
import logging
import math
from typing import Any, Callable, Dict, List, Optional

import tushare as ts

from .base import BaseDataProvider

logger = logging.getLogger(__name__)

# Columns that hold dates / report periods (YYYYMMDD). Their values are always
# normalised to strings so that e.g. 20231231.0 does not end up as "20231231.0"
# and break equality checks against "20231231".
_DATE_KEYS = frozenset({
    "cal_date", "pretrade_date", "trade_date", "trade_dt", "date",
    "end_date", "ann_date", "f_ann_date", "period",
    "pay_date", "ex_date", "record_date", "exp_date", "imp_ann_date",
    "div_listdate", "base_date", "list_date", "delist_date",
})

# Statement columns that are identifiers/metadata rather than numeric metrics.
_NON_METRIC_KEYS = frozenset({
    "ts_code", "end_date", "ann_date", "f_ann_date", "report_type",
    "comp_type", "end_type", "update_flag", "period",
})


def _is_missing(value: Any) -> bool:
    """Return True for None, NaN and +/-Inf (including numpy/Decimal variants)."""
    if value is None:
        return True
    if isinstance(value, str):
        # Strings are data (codes, dates, names), never "missing" markers here.
        return False
    try:
        return not math.isfinite(float(value))
    except (TypeError, ValueError, OverflowError):
        pass
    try:
        # NaN-like objects (e.g. pandas.NaT) are not equal to themselves.
        return bool(value != value)
    except Exception:
        return False


def _sanitize_value(key: str, value: Any) -> Any:
    """Make one SDK cell JSON-safe.

    - None/NaN/Inf -> None (also for date columns, so no literal "nan" dates).
    - Date columns -> string, with a trailing ".0" stripped.
    - Other strings are returned unchanged (e.g. symbol "000001" stays "000001").
    - Other numeric values -> Python float; anything else is returned as-is.
    """
    if _is_missing(value):
        return None
    if key in _DATE_KEYS:
        text = str(value)
        return text[:-2] if text.endswith(".0") else text
    if isinstance(value, str):
        return value
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        return value


class TushareProvider(BaseDataProvider):
    """Market/financial data provider backed by the official Tushare Pro SDK.

    SDK calls are synchronous, so every request runs in a worker thread via
    ``asyncio.to_thread`` to keep the event loop responsive.
    """

    def _initialize(self):
        """Create the Tushare Pro client and an empty trade-calendar cache.

        Raises:
            ValueError: if ``self.token`` is falsy.
        """
        # NOTE(review): self.token is presumably set by BaseDataProvider before
        # this hook runs -- confirm against base.py.
        if not self.token:
            raise ValueError("Tushare API token not provided.")
        self._pro = ts.pro_api(self.token)
        # Trade-calendar cache: "exchange:start:end" -> trade_cal rows.
        self._trade_cal_cache: Dict[str, List[Dict[str, Any]]] = {}

    async def _resolve_trade_dates(self, dates: List[str], exchange: str = "SSE") -> Dict[str, str]:
        """Map every date to the latest trading day that is not after it.

        Args:
            dates: Requested dates as ``YYYYMMDD`` strings.
            exchange: Exchange whose calendar is used (``trade_cal`` ``exchange``).

        Returns:
            ``{requested_date: resolved_trade_date}``. A trading day maps to
            itself; a non-trading day maps to its ``pretrade_date`` or, failing
            that, to the latest open day in the fetched calendar. Dates that
            cannot be resolved at all map to themselves.
        """
        if not dates:
            return {}

        start_date, end_date = min(dates), max(dates)
        cache_key = f"{exchange}:{start_date}:{end_date}"
        cal_rows = self._trade_cal_cache.get(cache_key)
        if cal_rows is None:
            cal_rows = await self._query(
                api_name="trade_cal",
                params={
                    "exchange": exchange,
                    "start_date": start_date,
                    "end_date": end_date,
                },
                fields=["cal_date", "is_open", "pretrade_date"],
            )
            self._trade_cal_cache[cache_key] = cal_rows

        def _is_open(row: Dict[str, Any]) -> bool:
            # is_open may arrive as "1", 1 or 1.0; anything unparsable counts as closed.
            try:
                return int(float(row.get("is_open") or 0)) == 1
            except (TypeError, ValueError):
                return False

        by_date: Dict[str, Dict[str, Any]] = {str(r.get("cal_date")): r for r in cal_rows}
        # Sorted open days, used as a fallback when pretrade_date is unavailable.
        open_dates = sorted(d for d, r in by_date.items() if _is_open(r))

        def _prev_open(day: str) -> Optional[str]:
            """Return the largest open day <= ``day``, or None."""
            idx = bisect.bisect_right(open_dates, day)
            return open_dates[idx - 1] if idx else None

        resolved: Dict[str, str] = {}
        for day in dates:
            row = by_date.get(day)
            if row is not None and _is_open(row):
                resolved[day] = day
                continue
            # Non-trading day (or a date missing from the calendar): prefer the
            # calendar's own pretrade_date, then the nearest known open day.
            prev = str(row.get("pretrade_date") or "") if row is not None else ""
            resolved[day] = prev or _prev_open(day) or day
        return resolved

    async def _query(
        self,
        api_name: str,
        params: Optional[Dict[str, Any]] = None,
        fields: Optional[List[str]] = None,
    ) -> List[Dict[str, Any]]:
        """Call a Tushare Pro endpoint and return its rows as sanitized dicts.

        The synchronous SDK call runs in a worker thread (``asyncio.to_thread``).
        Every cell goes through ``_sanitize_value``: None/NaN/Inf become None,
        date columns become strings, other numbers become Python floats, and
        strings are left untouched.

        Args:
            api_name: Endpoint name, e.g. ``"daily"`` or ``"fina_indicator"``.
            params: Endpoint parameters.
            fields: Optional list of columns to request.

        Returns:
            One dict per row; an empty list when the SDK returns no data.

        Raises:
            RuntimeError: wrapping any exception raised by the SDK call.
        """
        params = params or {}
        # The SDK expects a comma-separated string of field names.
        fields_arg: Optional[str] = ",".join(fields) if isinstance(fields, list) else None

        def _call() -> List[Dict[str, Any]]:
            # Prefer the attribute form (pro.daily(...)); otherwise fall back to
            # the generic pro.query(name, **endpoint_params).
            func: Optional[Callable] = getattr(self._pro, api_name, None)
            if not callable(func):
                func = functools.partial(self._pro.query, api_name)
            kwargs = dict(params)
            if fields_arg:
                kwargs["fields"] = fields_arg
            try:
                df = func(**kwargs)
            except Exception as exc:
                raise RuntimeError(f"tushare.{api_name} failed: {exc}") from exc
            if df is None or df.empty:
                return []
            return df.to_dict(orient="records")

        try:
            raw_rows = await asyncio.to_thread(_call)
            return [
                {key: _sanitize_value(key, value) for key, value in row.items()}
                for row in raw_rows
            ]
        except Exception as e:
            logger.error(f"Exception calling tushare '{api_name}': {e}")
            raise

    async def get_stock_basic(self, stock_code: str) -> Optional[Dict[str, Any]]:
        """Return the first ``stock_basic`` row for ``stock_code``, or None on no data/error."""
        try:
            rows = await self._query(
                api_name="stock_basic",
                params={"ts_code": stock_code},
            )
            return rows[0] if rows else None
        except Exception as e:
            logger.error(f"Tushare get_stock_basic failed for {stock_code}: {e}")
            return None

    async def get_daily_price(self, stock_code: str, start_date: str, end_date: str) -> List[Dict[str, Any]]:
        """Return ``daily`` rows for ``stock_code`` in [start_date, end_date]; [] on error."""
        try:
            rows = await self._query(
                api_name="daily",
                params={
                    "ts_code": stock_code,
                    "start_date": start_date,
                    "end_date": end_date,
                },
            )
            return rows or []
        except Exception as e:
            logger.error(f"Tushare get_daily_price failed for {stock_code}: {e}")
            return []

    async def _fetch_points(self, api_name: str, stock_code: str, trade_dates: List[str]) -> List[Dict[str, Any]]:
        """Fetch ``api_name`` rows for the trading days that ``trade_dates`` resolve to.

        Each requested date is mapped to the latest trading day not after it.
        The covering range is then fetched in one request and filtered down to
        those resolved days. Exceptions propagate to the caller.
        """
        if not trade_dates:
            return []
        mapping = await self._resolve_trade_dates(trade_dates, exchange="SSE")
        wanted = set(mapping.values())
        all_rows = await self._query(
            api_name=api_name,
            params={
                "ts_code": stock_code,
                "start_date": min(wanted),
                "end_date": max(wanted),
            },
        )
        rows = [r for r in all_rows if str(r.get("trade_date")) in wanted]
        logger.info(f"Tushare {api_name} returned {len(rows)} rows for {stock_code} on {len(trade_dates)} requested dates (resolved to {len(wanted)} trading dates)")
        return rows

    async def get_daily_basic_points(self, stock_code: str, trade_dates: List[str]) -> List[Dict[str, Any]]:
        """Return ``daily_basic`` rows (e.g. total_mv, pe, pb) for the given dates; [] on error."""
        try:
            return await self._fetch_points("daily_basic", stock_code, trade_dates)
        except Exception as e:
            logger.error(f"Tushare get_daily_basic_points failed for {stock_code}: {e}")
            return []

    async def get_daily_points(self, stock_code: str, trade_dates: List[str]) -> List[Dict[str, Any]]:
        """Return ``daily`` quote rows (e.g. close) for the given dates; [] on error."""
        try:
            return await self._fetch_points("daily", stock_code, trade_dates)
        except Exception as e:
            logger.error(f"Tushare get_daily_points failed for {stock_code}: {e}")
            return []

    def _calculate_derived_metrics(self, series: Dict[str, List[Dict]], periods: List[str]) -> Dict[str, List[Dict]]:
        """Add derived ratio series (fee rates, asset mix, per-capita figures, ...).

        Args:
            series: ``{metric: [{"period": "YYYYMMDD", "value": number}, ...]}``.
                It is mutated in place with the derived keys.
            periods: Report periods for which derived values are computed.

        Returns:
            The same ``series`` dict.

        The fina_indicator inputs ``grossprofit_margin``, ``netprofit_margin``
        and ``tax_to_ebt`` are percentages. They are always divided by 100 where
        a fraction is needed, so small values such as a 0.5% margin are not
        misread as 50%.
        """
        # Lazily built {metric: {period: point}} index; invalidated by add_series.
        index: Dict[str, Dict[Any, Dict]] = {}

        def _get_value(key: str, period: str) -> Optional[float]:
            points = index.get(key)
            if points is None:
                points = {}
                for point in series.get(key, []):
                    points.setdefault(point.get("period"), point)  # first match wins
                index[key] = points
            point = points.get(period)
            if point is None or point.get("value") is None:
                return None
            try:
                return float(point["value"])
            except (ValueError, TypeError):
                return None

        def _get_avg_value(key: str, period: str) -> Optional[float]:
            """Average of the value at ``period`` and the prior year's 1231 value (if any)."""
            current_val = _get_value(key, period)
            if current_val is None:
                return None
            try:
                prev_val = _get_value(key, f"{int(period[:4]) - 1}1231")
            except (ValueError, TypeError):
                prev_val = None
            return current_val if prev_val is None else (current_val + prev_val) / 2

        def _get_cogs(period: str) -> Optional[float]:
            """Cost of goods sold approximated as revenue * (1 - gross margin)."""
            revenue = _get_value("revenue", period)
            gp_margin_pct = _get_value("grossprofit_margin", period)
            if revenue is None or gp_margin_pct is None:
                return None
            return revenue * (1 - gp_margin_pct / 100.0)

        def add_series(key: str, data: List[Dict]):
            if data:
                series[key] = data
                index.pop(key, None)

        def _ratio_series(num_key: str, den_key: str) -> List[Dict]:
            """(num / den) * 100 for every period where both exist and den != 0."""
            out = []
            for period in periods:
                numerator = _get_value(num_key, period)
                denominator = _get_value(den_key, period)
                if numerator is not None and denominator is not None and denominator != 0:
                    out.append({"period": period, "value": (numerator / denominator) * 100})
            return out

        # Free cash flow = operating cash flow - capital expenditure.
        fcf_data = []
        for period in periods:
            op_cashflow = _get_value("n_cashflow_act", period)
            capex = _get_value("c_pay_acq_const_fiolta", period)
            if op_cashflow is not None and capex is not None:
                fcf_data.append({"period": period, "value": op_cashflow - capex})
        add_series("__free_cash_flow", fcf_data)

        # Expense items as a percentage of revenue.
        fee_calcs = (
            ("__sell_rate", "sell_exp"),
            ("__admin_rate", "admin_exp"),
            ("__rd_rate", "rd_exp"),
            ("__depr_ratio", "depr_fa_coga_dpba"),
        )
        for key, num_key in fee_calcs:
            add_series(key, _ratio_series(num_key, "revenue"))

        # Effective tax rate; tax_to_ebt is already a percentage.
        tax_rate_data = []
        for period in periods:
            tax_to_ebt = _get_value("tax_to_ebt", period)
            if tax_to_ebt is not None:
                tax_rate_data.append({"period": period, "value": tax_to_ebt})
        add_series("__tax_rate", tax_rate_data)

        # "Other" expense rate = gross margin - net margin - selling/admin/R&D rates.
        other_fee_data = []
        for period in periods:
            gp_raw = _get_value("grossprofit_margin", period)
            np_raw = _get_value("netprofit_margin", period)
            rev = _get_value("revenue", period)
            sell_exp = _get_value("sell_exp", period)
            admin_exp = _get_value("admin_exp", period)
            rd_exp = _get_value("rd_exp", period)
            if any(v is None for v in (gp_raw, np_raw, rev, sell_exp, admin_exp, rd_exp)) or rev == 0:
                continue
            gross_margin = gp_raw / 100
            net_margin = np_raw / 100
            sell_rate = sell_exp / rev
            admin_rate = admin_exp / rev
            rd_rate = rd_exp / rev
            other_rate = (gross_margin - net_margin - sell_rate - admin_rate - rd_rate) * 100
            other_fee_data.append({"period": period, "value": other_rate})
        add_series("__other_fee_rate", other_fee_data)

        # Balance-sheet items as a percentage of total assets.
        asset_ratio_keys = (
            ("__money_cap_ratio", "money_cap"),
            ("__inventories_ratio", "inventories"),
            ("__ar_ratio", "accounts_receiv_bill"),
            ("__prepay_ratio", "prepayment"),
            ("__fix_assets_ratio", "fix_assets"),
            ("__lt_invest_ratio", "lt_eqt_invest"),
            ("__goodwill_ratio", "goodwill"),
            ("__ap_ratio", "accounts_pay"),
            ("__st_borr_ratio", "st_borr"),
            ("__lt_borr_ratio", "lt_borr"),
        )
        for key, num_key in asset_ratio_keys:
            add_series(key, _ratio_series(num_key, "total_assets"))

        def _or_zero(key: str, period: str) -> float:
            # Missing line items count as zero in the composite ratios below.
            return _get_value(key, period) or 0

        # (advance receipts + contract liabilities) / total assets.
        adv_data = []
        for period in periods:
            total_assets = _get_value("total_assets", period)
            if total_assets is not None and total_assets != 0:
                adv = _or_zero("adv_receipts", period)
                contract = _or_zero("contract_liab", period)
                adv_data.append({"period": period, "value": ((adv + contract) / total_assets) * 100})
        add_series("__adv_ratio", adv_data)

        # Share of total assets not covered by the itemised asset lines.
        known_assets_keys = ["money_cap", "inventories", "accounts_receiv_bill", "prepayment", "fix_assets", "lt_eqt_invest", "goodwill"]
        other_assets_data = []
        for period in periods:
            total_assets = _get_value("total_assets", period)
            if total_assets is not None and total_assets != 0:
                sum_known = sum(_or_zero(k, period) for k in known_assets_keys)
                other_assets_data.append({"period": period, "value": ((total_assets - sum_known) / total_assets) * 100})
        add_series("__other_assets_ratio", other_assets_data)

        # Net operating assets = inventories + receivables + prepayments
        #                        - payables - advance receipts - contract liabilities.
        op_assets_data = []
        for period in periods:
            total_assets = _get_value("total_assets", period)
            if total_assets is not None and total_assets != 0:
                operating_assets = (
                    _or_zero("inventories", period)
                    + _or_zero("accounts_receiv_bill", period)
                    + _or_zero("prepayment", period)
                    - _or_zero("accounts_pay", period)
                    - _or_zero("adv_receipts", period)
                    - _or_zero("contract_liab", period)
                )
                op_assets_data.append({"period": period, "value": (operating_assets / total_assets) * 100})
        add_series("__operating_assets_ratio", op_assets_data)

        # Interest-bearing debt (short + long borrowings) / total assets.
        debt_ratio_data = []
        for period in periods:
            total_assets = _get_value("total_assets", period)
            if total_assets is not None and total_assets != 0:
                st_borr = _or_zero("st_borr", period)
                lt_borr = _or_zero("lt_borr", period)
                debt_ratio_data.append({"period": period, "value": ((st_borr + lt_borr) / total_assets) * 100})
        add_series("__interest_bearing_debt_ratio", debt_ratio_data)

        # Days payable = 365 * average payables / COGS.
        # NOTE(review): for a non-annual period the COGS is presumably year-to-date,
        # so a 365-day numerator overstates the days -- confirm the intended basis.
        payturn_data = []
        for period in periods:
            avg_ap = _get_avg_value("accounts_pay", period)
            cogs = _get_cogs(period)
            if avg_ap is not None and cogs is not None and cogs != 0:
                payturn_data.append({"period": period, "value": (365 * avg_ap) / cogs})
        add_series("payturn_days", payturn_data)

        # Per-employee figures, scaled down by the divisor.
        per_capita_calcs = (
            ("__rev_per_emp", "revenue", 10000),
            ("__profit_per_emp", "n_income", 10000),
            ("__salary_per_emp", "c_paid_to_for_empl", 10000),
        )
        for key, num_key, divisor in per_capita_calcs:
            data = []
            for period in periods:
                numerator = _get_value(num_key, period)
                employees = _get_value("employees", period)
                if numerator is not None and employees is not None and employees != 0:
                    data.append({"period": period, "value": (numerator / employees) / divisor})
            add_series(key, data)

        return series

    @staticmethod
    def _year_period(year: str, current_year: str, latest_current_year_report: Optional[str]) -> str:
        """Period key for a yearly aggregate: the current year's latest report, else ``YYYY1231``."""
        if year == current_year and latest_current_year_report:
            return latest_current_year_report
        return f"{year}1231"

    @staticmethod
    def _select_report_periods(available: Dict[str, Any], current_year: str):
        """Pick the current year's latest report period plus all prior-year 1231 periods.

        Returns:
            ``(wanted_dates, latest_current_year_report)``; dates are newest first.
        """
        ordered = sorted(available, reverse=True)
        latest = next((d for d in ordered if d.startswith(current_year)), None)
        wanted = [latest] if latest else []
        wanted.extend(d for d in ordered if d.endswith("1231") and not d.startswith(current_year))
        return wanted, latest

    @staticmethod
    def _statements_to_series(statements: List[Dict[str, Any]]) -> Dict[str, List[Dict]]:
        """Turn merged statements into ``{metric: [{"period", "value"}]}``.

        Only values convertible to a finite float are kept; metadata columns are skipped.
        """
        series: Dict[str, List[Dict]] = {}
        for report in statements:
            period = report.get("end_date", "")
            if not period:
                continue
            for key, value in report.items():
                if key in _NON_METRIC_KEYS:
                    continue
                try:
                    fv = float(value)
                except (TypeError, ValueError, OverflowError):
                    continue
                if not math.isfinite(fv):
                    continue
                points = series.setdefault(key, [])
                if not any(p["period"] == period for p in points):
                    points.append({"period": period, "value": fv})
        return series

    @classmethod
    def _repurchase_series(cls, rep_rows, current_year, latest_current_year_report) -> Dict[str, List[Dict]]:
        """Aggregate repurchase announcements into yearly series.

        Rows are grouped by the year of ``end_date`` (falling back to ``ann_date``).
        For each year:
          * ``repurchase_amount`` is the ``amount`` of the row with the latest
            ``ann_date`` within that year;
          * ``repurchase_vol`` is the summed ``vol`` (omitted when zero);
          * the high/low limits are the last numeric values seen.
        """
        by_year: Dict[str, Dict[str, Any]] = {}
        for r in rep_rows:
            end = str(r.get("end_date") or r.get("ann_date") or "")
            if not end:
                continue
            year = end[:4]
            bucket = by_year.setdefault(year, {
                "vol": 0.0,
                "high_limit": None,
                "low_limit": None,
                "last_ann_date": None,
                "amount_last": None,
            })
            amount = r.get("amount")
            vol = r.get("vol")
            high = r.get("high_limit")
            low = r.get("low_limit")
            ann = str(r.get("ann_date") or "")
            if isinstance(amount, (int, float)) and ann and ann[:4] == year:
                last = bucket["last_ann_date"]
                if last is None or ann > last:
                    bucket["last_ann_date"] = ann
                    bucket["amount_last"] = float(amount)
            if isinstance(vol, (int, float)):
                bucket["vol"] += float(vol)
            if isinstance(high, (int, float)):
                bucket["high_limit"] = float(high)
            if isinstance(low, (int, float)):
                bucket["low_limit"] = float(low)

        out: Dict[str, List[Dict]] = {}
        for year in sorted(by_year):
            bucket = by_year[year]
            period = cls._year_period(year, current_year, latest_current_year_report)
            if bucket["amount_last"] is not None:
                out.setdefault("repurchase_amount", []).append({"period": period, "value": bucket["amount_last"]})
            if bucket["vol"]:
                out.setdefault("repurchase_vol", []).append({"period": period, "value": bucket["vol"]})
            if bucket["high_limit"] is not None:
                out.setdefault("repurchase_high_limit", []).append({"period": period, "value": bucket["high_limit"]})
            if bucket["low_limit"] is not None:
                out.setdefault("repurchase_low_limit", []).append({"period": period, "value": bucket["low_limit"]})
        return out

    @classmethod
    def _dividend_series(cls, div_rows, current_year, latest_current_year_report) -> List[Dict]:
        """Yearly cash dividends grouped by the year of the actual ``pay_date``.

        Per row, cash_div_tax (yuan/share) * base_share (10k shares) gives 10k
        yuan. Dividing by 10000 gives the value in units of 100 million yuan.
        """
        div_by_year: Dict[str, float] = {}
        for r in div_rows:
            pay = str(r.get("pay_date") or "")
            # Only rows with a real (digit-bearing) pay date count.
            if len(pay) < 4 or not any(ch.isdigit() for ch in pay):
                continue
            cash_div = r.get("cash_div_tax")
            base_share = r.get("base_share")
            if isinstance(cash_div, (int, float)) and isinstance(base_share, (int, float)):
                year = pay[:4]
                amount_billion = (float(cash_div) * float(base_share)) / 10000.0
                div_by_year[year] = div_by_year.get(year, 0.0) + amount_billion
        return [
            {"period": cls._year_period(year, current_year, latest_current_year_report), "value": value}
            for year, value in sorted(div_by_year.items())
        ]

    @staticmethod
    def _holder_series(holder_rows, wanted_dates: List[str]) -> List[Dict]:
        """Shareholder count per report period, taken from the latest announcement.

        Only the report periods in ``wanted_dates`` are emitted.
        """
        latest: Dict[str, Dict[str, Any]] = {}
        for r in holder_rows:
            end_date = str(r.get("end_date") or "")
            if not end_date:
                continue
            ann_date = str(r.get("ann_date") or "")
            current = latest.get(end_date)
            newer = ann_date and (not current or not current["latest_ann_date"] or ann_date > current["latest_ann_date"])
            if current is None or newer:
                latest[end_date] = {"holder_num": r.get("holder_num"), "latest_ann_date": ann_date}
        return [
            {"period": d, "value": float(latest[d]["holder_num"])}
            for d in wanted_dates
            if d in latest and isinstance(latest[d]["holder_num"], (int, float))
        ]

    @staticmethod
    def _employee_series(company_rows) -> List[Dict]:
        """The first numeric employee count, placed at the previous year's 1231 period."""
        for r in company_rows:
            employees = r.get("employees")
            if isinstance(employees, (int, float)):
                previous_year = datetime.date.today().year - 1
                return [{"period": f"{previous_year}1231", "value": float(employees)}]
        return []

    async def get_financial_statements(self, stock_code: str, report_dates: Optional[List[str]] = None) -> Dict[str, List[Dict[str, Any]]]:
        """Fetch statements plus auxiliary data and return per-metric series.

        Data fetched:
          * balance sheet, income statement and cash flow (all with
            ``report_type`` 1) and ``fina_indicator``;
          * repurchase, dividend, holder-number and company data.

        The statements are merged by ``end_date``. Only the current year's
        latest report and every prior-year annual (``1231``) report are kept.

        Args:
            stock_code: Tushare ``ts_code``.
            report_dates: Currently ignored; kept for interface compatibility.

        Returns:
            ``{metric: [{"period": "YYYYMMDD", "value": float}, ...]}``,
            derived metrics included. A failing endpoint contributes no data
            but does not discard the others.
        """
        # NOTE(review): report_dates is accepted but never used for filtering.
        bs_fields = [
            "ts_code", "ann_date", "f_ann_date", "end_date", "report_type", "comp_type", "end_type",
            "money_cap", "inventories", "prepayment", "accounts_receiv", "accounts_receiv_bill", "goodwill",
            "lt_eqt_invest", "fix_assets", "total_assets", "accounts_pay", "adv_receipts", "contract_liab",
            "st_borr", "lt_borr", "total_cur_assets", "total_cur_liab", "total_ncl", "total_liab", "total_hldr_eqy_exc_min_int",
        ]
        ic_fields = [
            "ts_code", "ann_date", "f_ann_date", "end_date", "report_type", "comp_type", "end_type",
            "total_revenue", "revenue", "sell_exp", "admin_exp", "rd_exp", "operate_profit", "total_profit",
            "income_tax", "n_income", "n_income_attr_p", "ebit", "ebitda", "netprofit_margin", "grossprofit_margin",
        ]
        cf_fields = [
            "ts_code", "ann_date", "f_ann_date", "end_date", "comp_type", "report_type", "end_type",
            "n_cashflow_act", "c_pay_acq_const_fiolta", "c_paid_to_for_empl", "depr_fa_coga_dpba",
        ]
        fi_fields = [
            "ts_code", "end_date", "ann_date", "grossprofit_margin", "netprofit_margin", "tax_to_ebt", "roe", "roa", "roic",
            "invturn_days", "arturn_days", "fa_turn", "tr_yoy", "dt_netprofit_yoy", "assets_turn",
        ]

        requests = [
            ("balancesheet", self._query("balancesheet", params={"ts_code": stock_code, "report_type": 1}, fields=bs_fields)),
            ("income", self._query("income", params={"ts_code": stock_code, "report_type": 1}, fields=ic_fields)),
            ("cashflow", self._query("cashflow", params={"ts_code": stock_code, "report_type": 1}, fields=cf_fields)),
            ("fina_indicator", self._query("fina_indicator", params={"ts_code": stock_code}, fields=fi_fields)),
            # Share repurchase announcements.
            ("repurchase", self._query(
                "repurchase",
                params={"ts_code": stock_code},
                fields=["ts_code", "ann_date", "end_date", "proc", "exp_date", "vol", "amount", "high_limit", "low_limit"],
            )),
            # Dividend announcements (only the fields needed).
            ("dividend", self._query(
                "dividend",
                params={"ts_code": stock_code},
                fields=["ts_code", "end_date", "cash_div_tax", "pay_date", "base_share"],
            )),
            # Shareholder counts per report period.
            ("stk_holdernumber", self._query(
                "stk_holdernumber",
                params={"ts_code": stock_code},
                fields=["ts_code", "ann_date", "end_date", "holder_num"],
            )),
            # Company profile (includes employee count).
            ("stock_company", self._query(
                "stock_company",
                params={"ts_code": stock_code},
                fields=["ts_code", "employees"],
            )),
        ]
        # return_exceptions=True: one failing endpoint (e.g. missing permission)
        # must not throw away the data of the others.
        results = await asyncio.gather(*(coro for _, coro in requests), return_exceptions=True)
        fetched: List[List[Dict[str, Any]]] = []
        for (label, _), result in zip(requests, results):
            if isinstance(result, Exception):
                logger.error(f"Tushare {label} fetch failed for {stock_code}: {result}")
                fetched.append([])
            elif isinstance(result, BaseException):
                # Cancellation / interpreter shutdown: do not swallow.
                raise result
            else:
                fetched.append(result or [])
        bs_rows, ic_rows, cf_rows, fi_rows, rep_rows, div_rows, holder_rows, company_rows = fetched
        logger.info(f"[Dividend] fetched {len(div_rows)} rows for {stock_code}")

        # Merge the four statement tables by report period; later tables
        # overwrite same-named columns (fina_indicator is merged last).
        by_date: Dict[str, Dict[str, Any]] = {}
        for rows in (bs_rows, ic_rows, cf_rows, fi_rows):
            for r in rows:
                end_date = str(r.get("end_date") or r.get("period") or "")
                if not end_date:
                    continue
                by_date.setdefault(end_date, {"ts_code": stock_code, "end_date": end_date}).update(r)

        current_year = str(datetime.date.today().year)
        wanted_dates, latest_current_year_report = self._select_report_periods(by_date, current_year)
        all_statements = [by_date[d] for d in wanted_dates]
        logger.info(f"Successfully prepared {len(all_statements)} merged statement(s) for {stock_code} from {len(by_date)} available reports.")

        series = self._statements_to_series(all_statements)
        series.update(self._repurchase_series(rep_rows, current_year, latest_current_year_report))
        dividend_series = self._dividend_series(div_rows, current_year, latest_current_year_report)
        if dividend_series:
            series["dividend_amount"] = dividend_series
        holder_series = self._holder_series(holder_rows, wanted_dates)
        if holder_series:
            series["holder_num"] = holder_series
        employee_series = self._employee_series(company_rows)
        if employee_series:
            series["employees"] = employee_series

        periods = sorted({point["period"] for points in series.values() for point in points})
        return self._calculate_derived_metrics(series, periods)