Fundamental_Analysis/backend/app/data_providers/tushare.py

373 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from .base import BaseDataProvider
from typing import Any, Dict, List, Optional
import httpx
import logging
import asyncio
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# URL that every Tushare request made by this module is POSTed to.
TUSHARE_PRO_URL = "https://api.tushare.pro"
# Data provider that fetches its data by POSTing queries to TUSHARE_PRO_URL.
class TushareProvider(BaseDataProvider):
def _initialize(self):
if not self.token:
raise ValueError("Tushare API token not provided.")
# Use httpx.AsyncClient directly
self._client = httpx.AsyncClient(timeout=30)
async def _query(
    self,
    api_name: str,
    params: Optional[Dict[str, Any]] = None,
    fields: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """POST one query to the Tushare API and return its rows as dicts.

    Args:
        api_name: Interface name, sent as ``api_name`` in the payload.
        params: Query parameters. The mapping is copied before use, so the
            caller's dict is never modified. ``limit`` defaults to 5000
            when the caller did not set it.
        fields: Optional field selection; sent as ``fields`` when truthy.

    Returns:
        One dict per entry of the response's ``data.items``, keyed by the
        names in ``data.fields``. A missing or null ``data`` section yields
        an empty list.

    Raises:
        httpx.HTTPStatusError: the HTTP response had an error status.
        RuntimeError: the response ``code`` is not 0.
        Exception: any other error raised while sending the request or
            decoding the body is logged and re-raised unchanged.
    """
    # Work on a private copy: the same params dict may be reused by callers.
    request_params: Dict[str, Any] = dict(params or {})
    request_params.setdefault("limit", 5000)

    payload: Dict[str, Any] = {
        "api_name": api_name,
        "token": self.token,
        "params": request_params,
    }
    if fields:
        payload["fields"] = fields

    # Only the params are logged; the payload itself carries the token.
    logger.info(f"Querying Tushare API '{api_name}' with params: {request_params}")

    # The try block covers transport and decoding only, so an API-level
    # error (raised below) is logged exactly once.
    try:
        resp = await self._client.post(TUSHARE_PRO_URL, json=payload)
        resp.raise_for_status()
        data = resp.json()
    except httpx.HTTPStatusError as e:
        logger.error(f"HTTP error calling Tushare API '{api_name}': {e.response.status_code} - {e.response.text}")
        raise
    except Exception as e:
        logger.error(f"Exception calling Tushare API '{api_name}': {e}")
        raise

    if data.get("code") != 0:
        err_msg = data.get("msg") or "Unknown Tushare error"
        logger.error(f"Tushare API error for '{api_name}': {err_msg}")
        raise RuntimeError(f"{api_name}: {err_msg}")

    # ``or {}`` / ``or []`` also cover an explicit null in the response.
    result = data.get("data") or {}
    fields_def = result.get("fields") or []
    items = result.get("items") or []

    # Pair each item positionally with the field names.
    rows: List[Dict[str, Any]] = [dict(zip(fields_def, item)) for item in items]
    logger.info(f"Tushare API '{api_name}' returned {len(rows)} rows.")
    return rows
async def get_stock_basic(self, stock_code: str) -> Optional[Dict[str, Any]]:
    """Return the first ``stock_basic`` row for ``stock_code``.

    Best-effort lookup: returns ``None`` when the query yields no rows,
    and also when the query raises (the error is logged, not propagated).
    """
    lookup = {"ts_code": stock_code}
    try:
        matches = await self._query(api_name="stock_basic", params=lookup)
        if not matches:
            return None
        return matches[0]
    except Exception as e:
        logger.error(f"Tushare get_stock_basic failed for {stock_code}: {e}")
        return None
async def get_daily_price(self, stock_code: str, start_date: str, end_date: str) -> List[Dict[str, Any]]:
    """Return the ``daily`` rows for ``stock_code`` between two dates.

    ``start_date`` and ``end_date`` are passed to the query unchanged.
    Best-effort: if the query raises, the error is logged and an empty
    list is returned instead.
    """
    window = {
        "ts_code": stock_code,
        "start_date": start_date,
        "end_date": end_date,
    }
    try:
        bars = await self._query(api_name="daily", params=window)
    except Exception as e:
        logger.error(f"Tushare get_daily_price failed for {stock_code}: {e}")
        return []
    return bars
async def get_daily_basic_points(self, stock_code: str, trade_dates: List[str]) -> List[Dict[str, Any]]:
    """Fetch ``daily_basic`` rows (e.g. total_mv, pe, pb) for given trade dates.

    One ``daily_basic`` query per date is started and all are awaited
    together. A date whose query fails is skipped; the skipped dates are
    reported in a single warning so that a partial result is visible in
    the logs.

    Args:
        stock_code: Sent as ``ts_code``.
        trade_dates: Each entry is sent as ``trade_date`` of its own query.

    Returns:
        The rows of all successful queries, concatenated in the order of
        ``trade_dates``. Returns ``[]`` if an unexpected error occurs.
    """
    try:
        dates = list(trade_dates)
        # return_exceptions=True: a failing date comes back as an exception
        # object in its slot instead of aborting the other queries.
        results = await asyncio.gather(
            *(
                self._query(
                    api_name="daily_basic",
                    params={"ts_code": stock_code, "trade_date": d},
                )
                for d in dates
            ),
            return_exceptions=True,
        )
        rows: List[Dict[str, Any]] = []
        failed_dates: List[str] = []
        # gather() keeps the order of its inputs, so results align with dates.
        for d, res in zip(dates, results):
            if isinstance(res, BaseException):
                failed_dates.append(d)
            elif isinstance(res, list) and res:
                rows.extend(res)
        if failed_dates:
            logger.warning(
                f"Tushare daily_basic failed for {stock_code} on "
                f"{len(failed_dates)} of {len(dates)} dates: {failed_dates}"
            )
        logger.info(f"Tushare daily_basic returned {len(rows)} rows for {stock_code} on {len(dates)} dates")
        return rows
    except Exception as e:
        logger.error(f"Tushare get_daily_basic_points failed for {stock_code}: {e}")
        return []
async def get_daily_points(self, stock_code: str, trade_dates: List[str]) -> List[Dict[str, Any]]:
    """Fetch ``daily`` quote rows (e.g. close) for given trade dates.

    One ``daily`` query per date is started and all are awaited together.
    A date whose query fails is skipped; the skipped dates are reported in
    a single warning so that a partial result is visible in the logs.

    Args:
        stock_code: Sent as ``ts_code``.
        trade_dates: Each entry is sent as ``trade_date`` of its own query.

    Returns:
        The rows of all successful queries, concatenated in the order of
        ``trade_dates``. Returns ``[]`` if an unexpected error occurs.
    """
    try:
        dates = list(trade_dates)
        # return_exceptions=True: a failing date comes back as an exception
        # object in its slot instead of aborting the other queries.
        results = await asyncio.gather(
            *(
                self._query(
                    api_name="daily",
                    params={"ts_code": stock_code, "trade_date": d},
                )
                for d in dates
            ),
            return_exceptions=True,
        )
        rows: List[Dict[str, Any]] = []
        failed_dates: List[str] = []
        # gather() keeps the order of its inputs, so results align with dates.
        for d, res in zip(dates, results):
            if isinstance(res, BaseException):
                failed_dates.append(d)
            elif isinstance(res, list) and res:
                rows.extend(res)
        if failed_dates:
            logger.warning(
                f"Tushare daily failed for {stock_code} on "
                f"{len(failed_dates)} of {len(dates)} dates: {failed_dates}"
            )
        logger.info(f"Tushare daily returned {len(rows)} rows for {stock_code} on {len(dates)} dates")
        return rows
    except Exception as e:
        logger.error(f"Tushare get_daily_points failed for {stock_code}: {e}")
        return []
def _calculate_derived_metrics(self, series: Dict[str, List[Dict]], years: List[str]) -> Dict[str, List[Dict]]:
"""
在 Tushare provider 内部计算派生指标。
"""
# --- Helper Functions ---
def _get_value(key: str, year: str) -> Optional[float]:
if key not in series:
return None
point = next((p for p in series[key] if p.get("year") == year), None)
if point is None or point.get("value") is None:
return None
try:
return float(point["value"])
except (ValueError, TypeError):
return None
def _get_avg_value(key: str, year: str) -> Optional[float]:
current_val = _get_value(key, year)
try:
prev_year = str(int(year) - 1)
prev_val = _get_value(key, prev_year)
except (ValueError, TypeError):
prev_val = None
if current_val is None: return None
if prev_val is None: return current_val
return (current_val + prev_val) / 2
def _get_cogs(year: str) -> Optional[float]:
revenue = _get_value('revenue', year)
gp_margin_raw = _get_value('grossprofit_margin', year)
if revenue is None or gp_margin_raw is None: return None
gp_margin = gp_margin_raw / 100.0 if abs(gp_margin_raw) > 1 else gp_margin_raw
return revenue * (1 - gp_margin)
def add_series(key: str, data: List[Dict]):
if data:
series[key] = data
# --- Calculations ---
fcf_data = []
for year in years:
op_cashflow = _get_value('n_cashflow_act', year)
capex = _get_value('c_pay_acq_const_fiolta', year)
if op_cashflow is not None and capex is not None:
fcf_data.append({"year": year, "value": op_cashflow - capex})
add_series('__free_cash_flow', fcf_data)
fee_calcs = [
('__sell_rate', 'sell_exp', 'revenue'),
('__admin_rate', 'admin_exp', 'revenue'),
('__rd_rate', 'rd_exp', 'revenue'),
('__depr_ratio', 'depr_fa_coga_dpba', 'revenue'),
]
for key, num_key, den_key in fee_calcs:
data = []
for year in years:
numerator = _get_value(num_key, year)
denominator = _get_value(den_key, year)
if numerator is not None and denominator is not None and denominator != 0:
data.append({"year": year, "value": (numerator / denominator) * 100})
add_series(key, data)
tax_rate_data = []
for year in years:
tax_to_ebt = _get_value('tax_to_ebt', year)
if tax_to_ebt is not None:
rate = tax_to_ebt * 100 if abs(tax_to_ebt) <= 1 else tax_to_ebt
tax_rate_data.append({"year": year, "value": rate})
add_series('__tax_rate', tax_rate_data)
other_fee_data = []
for year in years:
gp_raw = _get_value('grossprofit_margin', year)
np_raw = _get_value('netprofit_margin', year)
rev = _get_value('revenue', year)
sell_exp = _get_value('sell_exp', year)
admin_exp = _get_value('admin_exp', year)
rd_exp = _get_value('rd_exp', year)
if all(v is not None for v in [gp_raw, np_raw, rev, sell_exp, admin_exp, rd_exp]) and rev != 0:
gp = gp_raw / 100 if abs(gp_raw) > 1 else gp_raw
np = np_raw / 100 if abs(np_raw) > 1 else np_raw
sell_rate = sell_exp / rev
admin_rate = admin_exp / rev
rd_rate = rd_exp / rev
other_rate = (gp - np - sell_rate - admin_rate - rd_rate) * 100
other_fee_data.append({"year": year, "value": other_rate})
add_series('__other_fee_rate', other_fee_data)
asset_ratio_keys = [
('__money_cap_ratio', 'money_cap'), ('__inventories_ratio', 'inventories'),
('__ar_ratio', 'accounts_receiv_bill'), ('__prepay_ratio', 'prepayment'),
('__fix_assets_ratio', 'fix_assets'), ('__lt_invest_ratio', 'lt_eqt_invest'),
('__goodwill_ratio', 'goodwill'), ('__ap_ratio', 'accounts_pay'),
('__st_borr_ratio', 'st_borr'), ('__lt_borr_ratio', 'lt_borr'),
]
for key, num_key in asset_ratio_keys:
data = []
for year in years:
numerator = _get_value(num_key, year)
denominator = _get_value('total_assets', year)
if numerator is not None and denominator is not None and denominator != 0:
data.append({"year": year, "value": (numerator / denominator) * 100})
add_series(key, data)
adv_data = []
for year in years:
adv = _get_value('adv_receipts', year) or 0
contract = _get_value('contract_liab', year) or 0
total_assets = _get_value('total_assets', year)
if total_assets is not None and total_assets != 0:
adv_data.append({"year": year, "value": ((adv + contract) / total_assets) * 100})
add_series('__adv_ratio', adv_data)
other_assets_data = []
known_assets_keys = ['money_cap', 'inventories', 'accounts_receiv_bill', 'prepayment', 'fix_assets', 'lt_eqt_invest', 'goodwill']
for year in years:
total_assets = _get_value('total_assets', year)
if total_assets is not None and total_assets != 0:
sum_known = sum(_get_value(k, year) or 0 for k in known_assets_keys)
other_assets_data.append({"year": year, "value": ((total_assets - sum_known) / total_assets) * 100})
add_series('__other_assets_ratio', other_assets_data)
op_assets_data = []
for year in years:
total_assets = _get_value('total_assets', year)
if total_assets is not None and total_assets != 0:
inv = _get_value('inventories', year) or 0
ar = _get_value('accounts_receiv_bill', year) or 0
pre = _get_value('prepayment', year) or 0
ap = _get_value('accounts_pay', year) or 0
adv = _get_value('adv_receipts', year) or 0
contract_liab = _get_value('contract_liab', year) or 0
operating_assets = inv + ar + pre - ap - adv - contract_liab
op_assets_data.append({"year": year, "value": (operating_assets / total_assets) * 100})
add_series('__operating_assets_ratio', op_assets_data)
debt_ratio_data = []
for year in years:
total_assets = _get_value('total_assets', year)
if total_assets is not None and total_assets != 0:
st_borr = _get_value('st_borr', year) or 0
lt_borr = _get_value('lt_borr', year) or 0
debt_ratio_data.append({"year": year, "value": ((st_borr + lt_borr) / total_assets) * 100})
add_series('__interest_bearing_debt_ratio', debt_ratio_data)
payturn_data = []
for year in years:
avg_ap = _get_avg_value('accounts_pay', year)
cogs = _get_cogs(year)
if avg_ap is not None and cogs is not None and cogs != 0:
payturn_data.append({"year": year, "value": (365 * avg_ap) / cogs})
add_series('payturn_days', payturn_data)
per_capita_calcs = [
('__rev_per_emp', 'revenue', 10000),
('__profit_per_emp', 'n_income', 10000),
('__salary_per_emp', 'c_paid_to_for_empl', 10000),
]
for key, num_key, divisor in per_capita_calcs:
data = []
for year in years:
numerator = _get_value(num_key, year)
employees = _get_value('employees', year)
if numerator is not None and employees is not None and employees != 0:
data.append({"year": year, "value": (numerator / employees) / divisor})
add_series(key, data)
return series
async def get_financial_statements(self, stock_code: str, report_dates: List[str]) -> Dict[str, List[Dict[str, Any]]]:
    """Fetch financial statements and return them as per-metric yearly series.

    For each report date, the ``balancesheet``, ``income``, ``cashflow``
    and ``fina_indicator`` interfaces are queried together and the first
    row of each is merged into one record (later sources overwrite
    earlier ones on key clashes). A date for which any query raises, or
    for which all four come back empty, is logged and skipped.

    The merged records are then pivoted into
    ``{metric: [{"year": ..., "value": ...}, ...]}``: only int/float values
    are kept, the year is the first four characters of ``end_date``, and
    when several records fall in the same year the first one in
    ``report_dates`` order wins. Derived metrics are added last.

    Args:
        stock_code: Sent as ``ts_code``.
        report_dates: Report periods, each sent as ``period``.

    Returns:
        The series dict, including derived metrics; empty when nothing
        could be fetched.
    """
    # Bookkeeping columns that must never become a metric series.
    non_metric_keys = frozenset((
        'ts_code', 'end_date', 'ann_date', 'f_ann_date', 'report_type',
        'comp_type', 'end_type', 'update_flag', 'period',
    ))

    all_statements: List[Dict[str, Any]] = []
    for date in report_dates:
        logger.info(f"Fetching financial statements for {stock_code}, report date: {date}")
        try:
            # Each query gets its own params dict; none is shared.
            queries = [
                self._query(
                    api_name=statement_api,
                    params={"ts_code": stock_code, "period": date, "report_type": 1},
                )
                for statement_api in ("balancesheet", "income", "cashflow")
            ]
            # Supplementary key financial ratios (ROE / ROA / gross margin, etc.).
            queries.append(
                self._query(
                    api_name="fina_indicator",
                    params={"ts_code": stock_code, "period": date},
                )
            )
            bs_rows, ic_rows, cf_rows, fi_rows = await asyncio.gather(*queries)

            if not (bs_rows or ic_rows or cf_rows or fi_rows):
                logger.warning(f"No financial statements components found from Tushare for {stock_code} on {date}")
                continue

            merged: Dict[str, Any] = {"ts_code": stock_code, "end_date": date}
            # Merge order matters: fina_indicator is applied last.
            for rows in (bs_rows, ic_rows, cf_rows, fi_rows):
                if rows:
                    merged.update(rows[0])
            merged["end_date"] = merged.get("end_date") or merged.get("period") or date
            logger.debug(f"Merged statement for {date} has keys: {list(merged.keys())}")
            all_statements.append(merged)
        except Exception as e:
            logger.error(f"Tushare get_financial_statement failed for {stock_code} on {date}: {e}")
            continue

    logger.info(f"Successfully fetched {len(all_statements)} statement(s) for {stock_code}.")

    # Transform to series format.
    series: Dict[str, List[Dict]] = {}
    seen_years: Dict[str, set] = {}  # metric key -> years already stored
    for report in all_statements:
        # str() keeps a non-string end_date from raising here, outside the
        # per-date error handling above.
        year = str(report.get("end_date") or "")[:4]
        if not year:
            continue
        for key, value in report.items():
            if key in non_metric_keys or not isinstance(value, (int, float)):
                continue
            points = series.setdefault(key, [])
            years_for_key = seen_years.setdefault(key, set())
            if year not in years_for_key:
                years_for_key.add(year)
                points.append({"year": year, "value": value})

    # Calculate derived metrics over every year present in any series.
    years = sorted({d['year'] for s in series.values() for d in s})
    return self._calculate_derived_metrics(series, years)