373 lines
16 KiB
Python
373 lines
16 KiB
Python
from .base import BaseDataProvider
|
||
from typing import Any, Dict, List, Optional
|
||
import httpx
|
||
import logging
|
||
import asyncio
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# Endpoint that Tushare Pro API requests are POSTed to.
TUSHARE_PRO_URL = "https://api.tushare.pro"
|
||
|
||
class TushareProvider(BaseDataProvider):
|
||
|
||
def _initialize(self):
    """Validate the API token and create the async HTTP client.

    Raises:
        ValueError: If ``self.token`` is empty or otherwise falsy.
    """
    token_missing = not self.token
    if token_missing:
        raise ValueError("Tushare API token not provided.")

    # A single httpx.AsyncClient is stored on the instance for later requests.
    http_client = httpx.AsyncClient(timeout=30)
    self._client = http_client
|
||
|
||
async def _query(
    self,
    api_name: str,
    params: Optional[Dict[str, Any]] = None,
    fields: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """POST one request to the Tushare Pro endpoint and return its rows.

    Args:
        api_name: Interface name placed in the request payload.
        params: Query parameters for the interface. The mapping is copied,
            so the caller's dict is never modified. When it has no
            ``limit`` entry, a default of 5000 is sent.
        fields: Optional value for the payload's ``fields`` entry; left out
            of the payload when empty.

    Returns:
        One dict per returned item, keyed by the field names listed in the
        response. Empty when the response carries no items.

    Raises:
        httpx.HTTPStatusError: If the HTTP response has an error status.
        RuntimeError: If the response body's ``code`` is not 0.
        Exception: Any other failure is logged and re-raised unchanged.
    """
    # Work on a copy: adding the default limit must not leak into the dict
    # the caller passed in (it may be reused for other requests).
    query_params: Dict[str, Any] = dict(params or {})
    query_params.setdefault("limit", 5000)

    payload: Dict[str, Any] = {
        "api_name": api_name,
        "token": self.token,
        "params": query_params,
    }
    if fields:
        payload["fields"] = fields

    logger.info(f"Querying Tushare API '{api_name}' with params: {query_params}")

    try:
        resp = await self._client.post(TUSHARE_PRO_URL, json=payload)
        resp.raise_for_status()
        data = resp.json()

        if data.get("code") != 0:
            err_msg = data.get("msg") or "Unknown Tushare error"
            logger.error(f"Tushare API error for '{api_name}': {err_msg}")
            raise RuntimeError(f"{api_name}: {err_msg}")

        # "data" may be present but null; treat that like an empty result
        # instead of failing on None.get(...).
        result = data.get("data") or {}
        fields_def = result.get("fields") or []
        items = result.get("items") or []

        # Pair each positional item with the response's field names.
        rows: List[Dict[str, Any]] = [dict(zip(fields_def, item)) for item in items]

        logger.info(f"Tushare API '{api_name}' returned {len(rows)} rows.")
        return rows

    except httpx.HTTPStatusError as e:
        logger.error(f"HTTP error calling Tushare API '{api_name}': {e.response.status_code} - {e.response.text}")
        raise
    except Exception as e:
        logger.error(f"Exception calling Tushare API '{api_name}': {e}")
        raise
|
||
|
||
async def get_stock_basic(self, stock_code: str) -> Optional[Dict[str, Any]]:
    """Return the first ``stock_basic`` row for ``stock_code``.

    Returns ``None`` when the query yields no rows or raises; a failure is
    logged instead of being propagated.
    """
    try:
        matches = await self._query(
            api_name="stock_basic",
            params={"ts_code": stock_code},
        )
        if not matches:
            return None
        return matches[0]
    except Exception as exc:
        logger.error(f"Tushare get_stock_basic failed for {stock_code}: {exc}")
        return None
|
||
|
||
async def get_daily_price(self, stock_code: str, start_date: str, end_date: str) -> List[Dict[str, Any]]:
    """Return ``daily`` rows for ``stock_code`` between two dates.

    ``start_date`` and ``end_date`` are passed through to the API unchanged.
    Returns an empty list when the query raises; the failure is logged
    instead of being propagated.
    """
    query_params = {
        "ts_code": stock_code,
        "start_date": start_date,
        "end_date": end_date,
    }
    try:
        daily_rows = await self._query(api_name="daily", params=query_params)
    except Exception as exc:
        logger.error(f"Tushare get_daily_price failed for {stock_code}: {exc}")
        return []
    return daily_rows
|
||
|
||
async def get_daily_basic_points(self, stock_code: str, trade_dates: List[str]) -> List[Dict[str, Any]]:
    """Fetch ``daily_basic`` rows (e.g. total_mv, pe, pb) for the given trade dates.

    One query is issued per trade date and all of them are awaited
    together. Dates whose query fails or returns nothing contribute no
    rows. Returns an empty list if anything else goes wrong.
    """
    try:
        pending = []
        for trade_date in trade_dates:
            pending.append(
                self._query(
                    api_name="daily_basic",
                    params={"ts_code": stock_code, "trade_date": trade_date},
                )
            )

        # return_exceptions=True: a failed date yields an exception object
        # in the result list instead of aborting the other queries.
        outcomes = await asyncio.gather(*pending, return_exceptions=True)

        rows: List[Dict[str, Any]] = [
            row
            for outcome in outcomes
            if isinstance(outcome, list) and outcome
            for row in outcome
        ]
        logger.info(f"Tushare daily_basic returned {len(rows)} rows for {stock_code} on {len(trade_dates)} dates")
        return rows
    except Exception as exc:
        logger.error(f"Tushare get_daily_basic_points failed for {stock_code}: {exc}")
        return []
|
||
|
||
async def get_daily_points(self, stock_code: str, trade_dates: List[str]) -> List[Dict[str, Any]]:
    """Fetch ``daily`` quote rows (e.g. close) for the given trade dates.

    One query is issued per trade date and all of them are awaited
    together. Dates whose query fails or returns nothing contribute no
    rows. Returns an empty list if anything else goes wrong.
    """
    try:
        pending = [
            self._query(
                api_name="daily",
                params={"ts_code": stock_code, "trade_date": trade_date},
            )
            for trade_date in trade_dates
        ]

        # Failed queries come back as exception objects and are skipped below.
        outcomes = await asyncio.gather(*pending, return_exceptions=True)

        collected: List[Dict[str, Any]] = []
        for outcome in outcomes:
            if not isinstance(outcome, list) or not outcome:
                continue
            collected.extend(outcome)

        logger.info(f"Tushare daily returned {len(collected)} rows for {stock_code} on {len(trade_dates)} dates")
        return collected
    except Exception as exc:
        logger.error(f"Tushare get_daily_points failed for {stock_code}: {exc}")
        return []
|
||
|
||
def _calculate_derived_metrics(self, series: Dict[str, List[Dict]], years: List[str]) -> Dict[str, List[Dict]]:
    """Compute derived metrics inside the Tushare provider.

    Each derived metric is a list of ``{"year": ..., "value": ...}`` points
    built from the raw metrics already present in ``series``. A derived
    metric is stored only when at least one year produced a value.

    Args:
        series: Mapping of metric name to its yearly points. It is updated
            in place with the derived metrics.
        years: Years to evaluate, in the order the points should appear.

    Returns:
        The same ``series`` object that was passed in.
    """

    def lookup(metric: str, year: str) -> Optional[float]:
        """Return ``metric`` for ``year`` as a float, or None when unavailable."""
        if metric not in series:
            return None
        for point in series[metric]:
            if point.get("year") == year:
                break
        else:
            return None
        raw = point.get("value")
        if raw is None:
            return None
        try:
            return float(raw)
        except (ValueError, TypeError):
            return None

    def lookup_or_zero(metric: str, year: str):
        """Return ``metric`` for ``year``, substituting 0 for a missing or zero value."""
        return lookup(metric, year) or 0

    def lookup_average(metric: str, year: str) -> Optional[float]:
        """Average this year's and the previous year's value.

        Falls back to the current value when the previous year is missing.
        """
        current = lookup(metric, year)
        try:
            previous = lookup(metric, str(int(year) - 1))
        except (ValueError, TypeError):
            previous = None
        if current is None:
            return None
        if previous is None:
            return current
        return (current + previous) / 2

    def as_fraction(raw: float) -> float:
        """Treat a magnitude above 1 as a percentage and scale it to a fraction."""
        if abs(raw) > 1:
            return raw / 100
        return raw

    def cost_of_goods(year: str) -> Optional[float]:
        """Derive cost of goods sold from revenue and gross profit margin."""
        revenue = lookup('revenue', year)
        margin_raw = lookup('grossprofit_margin', year)
        if revenue is None or margin_raw is None:
            return None
        return revenue * (1 - as_fraction(margin_raw))

    def store(name: str, points: List[Dict]) -> None:
        """Save ``points`` under ``name`` unless the list is empty."""
        if points:
            series[name] = points

    def percent_of(numerator_key: str, denominator_key: str) -> List[Dict]:
        """Per-year numerator / denominator * 100, skipping missing or zero denominators."""
        points = []
        for year in years:
            top = lookup(numerator_key, year)
            bottom = lookup(denominator_key, year)
            if top is None or bottom is None or bottom == 0:
                continue
            points.append({"year": year, "value": (top / bottom) * 100})
        return points

    def percent_of_total_assets(amount_for) -> List[Dict]:
        """Per-year ``amount_for(year, total_assets)`` as a percentage of total assets."""
        points = []
        for year in years:
            total_assets = lookup('total_assets', year)
            if total_assets is None or total_assets == 0:
                continue
            amount = amount_for(year, total_assets)
            points.append({"year": year, "value": (amount / total_assets) * 100})
        return points

    # --- Free cash flow: operating cash flow minus capital expenditure ---
    free_cash_flow = []
    for year in years:
        operating_cash = lookup('n_cashflow_act', year)
        capex = lookup('c_pay_acq_const_fiolta', year)
        if operating_cash is None or capex is None:
            continue
        free_cash_flow.append({"year": year, "value": operating_cash - capex})
    store('__free_cash_flow', free_cash_flow)

    # --- Expense items as a percentage of revenue ---
    revenue_ratios = (
        ('__sell_rate', 'sell_exp'),
        ('__admin_rate', 'admin_exp'),
        ('__rd_rate', 'rd_exp'),
        ('__depr_ratio', 'depr_fa_coga_dpba'),
    )
    for name, expense_key in revenue_ratios:
        store(name, percent_of(expense_key, 'revenue'))

    # --- Tax rate: values with magnitude <= 1 are scaled up to a percentage ---
    tax_rates = []
    for year in years:
        tax_to_ebt = lookup('tax_to_ebt', year)
        if tax_to_ebt is None:
            continue
        if abs(tax_to_ebt) <= 1:
            rate = tax_to_ebt * 100
        else:
            rate = tax_to_ebt
        tax_rates.append({"year": year, "value": rate})
    store('__tax_rate', tax_rates)

    # --- Other fee rate: gross margin minus net margin and the three expense rates ---
    other_fee_rates = []
    for year in years:
        gross_raw = lookup('grossprofit_margin', year)
        net_raw = lookup('netprofit_margin', year)
        revenue = lookup('revenue', year)
        selling = lookup('sell_exp', year)
        admin = lookup('admin_exp', year)
        research = lookup('rd_exp', year)
        inputs = (gross_raw, net_raw, revenue, selling, admin, research)
        if any(v is None for v in inputs) or revenue == 0:
            continue
        gross = as_fraction(gross_raw)
        net = as_fraction(net_raw)
        other_rate = (gross - net - selling / revenue - admin / revenue - research / revenue) * 100
        other_fee_rates.append({"year": year, "value": other_rate})
    store('__other_fee_rate', other_fee_rates)

    # --- Balance-sheet items as a percentage of total assets ---
    asset_ratios = (
        ('__money_cap_ratio', 'money_cap'),
        ('__inventories_ratio', 'inventories'),
        ('__ar_ratio', 'accounts_receiv_bill'),
        ('__prepay_ratio', 'prepayment'),
        ('__fix_assets_ratio', 'fix_assets'),
        ('__lt_invest_ratio', 'lt_eqt_invest'),
        ('__goodwill_ratio', 'goodwill'),
        ('__ap_ratio', 'accounts_pay'),
        ('__st_borr_ratio', 'st_borr'),
        ('__lt_borr_ratio', 'lt_borr'),
    )
    for name, item_key in asset_ratios:
        store(name, percent_of(item_key, 'total_assets'))

    # Advance receipts plus contract liabilities; missing items count as 0.
    store(
        '__adv_ratio',
        percent_of_total_assets(
            lambda year, _total: lookup_or_zero('adv_receipts', year) + lookup_or_zero('contract_liab', year)
        ),
    )

    # Whatever part of total assets is not covered by the itemised assets.
    itemised_assets = [
        'money_cap', 'inventories', 'accounts_receiv_bill', 'prepayment',
        'fix_assets', 'lt_eqt_invest', 'goodwill',
    ]
    store(
        '__other_assets_ratio',
        percent_of_total_assets(
            lambda year, total: total - sum(lookup_or_zero(k, year) for k in itemised_assets)
        ),
    )

    def operating_assets(year: str, _total: float):
        """Inventories + receivables + prepayments minus payables, advances and contract liabilities."""
        inventories = lookup_or_zero('inventories', year)
        receivables = lookup_or_zero('accounts_receiv_bill', year)
        prepayments = lookup_or_zero('prepayment', year)
        payables = lookup_or_zero('accounts_pay', year)
        advances = lookup_or_zero('adv_receipts', year)
        contract_liabilities = lookup_or_zero('contract_liab', year)
        return inventories + receivables + prepayments - payables - advances - contract_liabilities

    store('__operating_assets_ratio', percent_of_total_assets(operating_assets))

    # Short-term plus long-term borrowings.
    store(
        '__interest_bearing_debt_ratio',
        percent_of_total_assets(
            lambda year, _total: lookup_or_zero('st_borr', year) + lookup_or_zero('lt_borr', year)
        ),
    )

    # --- Payables turnover days: 365 * average payables / cost of goods sold ---
    payable_days = []
    for year in years:
        average_payables = lookup_average('accounts_pay', year)
        cogs = cost_of_goods(year)
        if average_payables is None or cogs is None or cogs == 0:
            continue
        payable_days.append({"year": year, "value": (365 * average_payables) / cogs})
    store('payturn_days', payable_days)

    # --- Per-employee figures, each additionally divided by the listed divisor ---
    per_employee = (
        ('__rev_per_emp', 'revenue', 10000),
        ('__profit_per_emp', 'n_income', 10000),
        ('__salary_per_emp', 'c_paid_to_for_empl', 10000),
    )
    for name, amount_key, divisor in per_employee:
        points = []
        for year in years:
            amount = lookup(amount_key, year)
            headcount = lookup('employees', year)
            if amount is None or headcount is None or headcount == 0:
                continue
            points.append({"year": year, "value": (amount / headcount) / divisor})
        store(name, points)

    return series
|
||
|
||
async def get_financial_statements(self, stock_code: str, report_dates: List[str]) -> Dict[str, List[Dict[str, Any]]]:
    """Fetch financial statements per report date and pivot them into yearly series.

    For every report date the ``balancesheet``, ``income``, ``cashflow`` and
    ``fina_indicator`` interfaces are queried together. The first row of
    each is merged into one record, with later interfaces overwriting
    duplicate keys. A date is skipped when all four are empty or when any
    query raises.

    The merged records are turned into ``{metric: [{"year", "value"}, ...]}``
    using the first four characters of ``end_date`` as the year. Only int or
    float values are kept, and only the first value seen per metric and year.
    Derived metrics are then added via ``_calculate_derived_metrics``.
    """
    statements: List[Dict[str, Any]] = []

    for date in report_dates:
        logger.info(f"Fetching financial statements for {stock_code}, report date: {date}")
        # Each query gets its own params dict. Order matters: results are
        # merged in this sequence, so later interfaces win on key clashes.
        requests = [
            ("balancesheet", {"ts_code": stock_code, "period": date, "report_type": 1}),
            ("income", {"ts_code": stock_code, "period": date, "report_type": 1}),
            ("cashflow", {"ts_code": stock_code, "period": date, "report_type": 1}),
            # Supplementary key financial ratios (ROE/ROA/gross margin, etc.)
            ("fina_indicator", {"ts_code": stock_code, "period": date}),
        ]
        try:
            responses = await asyncio.gather(
                *(self._query(api_name=name, params=query_params) for name, query_params in requests)
            )

            if not any(responses):
                logger.warning(f"No financial statements components found from Tushare for {stock_code} on {date}")
                continue

            merged: Dict[str, Any] = {"ts_code": stock_code, "end_date": date}
            for rows in responses:
                if rows:
                    merged.update(rows[0])

            # A merged row may have blanked end_date; fall back to period, then the requested date.
            merged["end_date"] = merged.get("end_date") or merged.get("period") or date
            logger.debug(f"Merged statement for {date} has keys: {list(merged.keys())}")

            statements.append(merged)
        except Exception as exc:
            logger.error(f"Tushare get_financial_statement failed for {stock_code} on {date}: {exc}")

    logger.info(f"Successfully fetched {len(statements)} statement(s) for {stock_code}.")

    # Identifier/bookkeeping columns that are not metrics.
    non_metric_keys = (
        'ts_code', 'end_date', 'ann_date', 'f_ann_date', 'report_type',
        'comp_type', 'end_type', 'update_flag', 'period',
    )

    # Transform to series format
    series: Dict[str, List[Dict]] = {}
    for report in statements:
        year = report.get("end_date", "")[:4]
        if not year:
            continue
        for key, value in report.items():
            if key in non_metric_keys or not isinstance(value, (int, float)):
                continue
            points = series.setdefault(key, [])
            # Keep only the first value seen for a given metric and year.
            if all(existing['year'] != year for existing in points):
                points.append({"year": year, "value": value})

    # Calculate derived metrics
    years = sorted({point['year'] for points in series.values() for point in points})
    series = self._calculate_derived_metrics(series, years)

    return series
|