FA3-Datafetch/src/fetchers/cn_fetcher.py

244 lines
9.9 KiB
Python

import tushare as ts
import pandas as pd
from .base import DataFetcher
import time
from storage.file_io import DataStorage
class CnFetcher(DataFetcher):
def __init__(self, api_key: str):
    """Set up the Tushare client and the raw-data storage backend.

    Args:
        api_key: Tushare token; forwarded to the base class and then
            registered with the ``tushare`` module.
    """
    super().__init__(api_key)
    # The token is read back from the instance (set by the base class)
    # and registered before the pro client is created.
    token = self.api_key
    ts.set_token(token)
    self.pro = ts.pro_api()
    # Storage used by _save_raw_data to persist raw API responses.
    self.storage = DataStorage()
def _save_raw_data(self, df: pd.DataFrame, symbol: str, name: str):
if df is None or df.empty:
return
market = 'CN'
self.storage.save_data(df, market, symbol, f"raw_{name}")
def _get_ts_code(self, symbol: str) -> str:
return symbol
def _filter_data(self, df: pd.DataFrame) -> pd.DataFrame:
if df.empty or 'end_date' not in df.columns:
return df
df = df.sort_values(by='end_date', ascending=False)
df = df.drop_duplicates(subset=['end_date'], keep='first')
if df.empty:
return df
latest_record = df.iloc[[0]]
try:
latest_date_str = str(latest_record['end_date'].values[0])
last_year_date_str = str(int(latest_date_str) - 10000)
comparable_record = df[df['end_date'].astype(str) == last_year_date_str]
except:
comparable_record = pd.DataFrame()
is_annual = df['end_date'].astype(str).str.endswith('1231')
annual_records = df[is_annual]
combined = pd.concat([latest_record, comparable_record, annual_records])
combined = combined.drop_duplicates(subset=['end_date'])
combined = combined.sort_values(by='end_date', ascending=False)
return combined
def get_income_statement(self, symbol: str) -> pd.DataFrame:
ts_code = self._get_ts_code(symbol)
df = self.pro.income(ts_code=ts_code)
self._save_raw_data(df, ts_code, "income_statement")
rename_map = {
'end_date': 'date',
'revenue': 'revenue',
'n_income_attr_p': 'net_income'
}
df = self._filter_data(df)
df = df.rename(columns=rename_map)
return df
def get_balance_sheet(self, symbol: str) -> pd.DataFrame:
ts_code = self._get_ts_code(symbol)
df = self.pro.balancesheet(ts_code=ts_code)
self._save_raw_data(df, ts_code, "balance_sheet")
rename_map = {
'end_date': 'date',
'total_hldr_eqy_exc_min_int': 'total_equity',
'total_liab': 'total_liabilities',
'total_cur_assets': 'current_assets',
'total_cur_liab': 'current_liabilities'
}
df = self._filter_data(df)
df = df.rename(columns=rename_map)
return df
def get_cash_flow(self, symbol: str) -> pd.DataFrame:
ts_code = self._get_ts_code(symbol)
df = self.pro.cashflow(ts_code=ts_code)
self._save_raw_data(df, ts_code, "cash_flow")
df = self._filter_data(df)
df = df.rename(columns={
'end_date': 'date',
'n_cashflow_act': 'net_cash_flow',
'depr_fa_coga_dpba': 'depreciation'
})
return df
def get_market_metrics(self, symbol: str) -> dict:
ts_code = self._get_ts_code(symbol)
metrics = {
"price": 0.0,
"market_cap": 0.0,
"pe": 0.0,
"pb": 0.0,
"total_share_holders": 0,
"employee_count": 0
}
try:
df_daily = self.pro.daily_basic(ts_code=ts_code, limit=1)
self._save_raw_data(df_daily, ts_code, "market_metrics_daily_basic")
if not df_daily.empty:
row = df_daily.iloc[0]
metrics["price"] = row.get('close', 0.0)
metrics["pe"] = row.get('pe', 0.0)
metrics["pb"] = row.get('pb', 0.0)
metrics["market_cap"] = row.get('total_mv', 0.0) * 10000
metrics["dividend_yield"] = row.get('dv_ttm', 0.0)
df_basic = self.pro.stock_basic(ts_code=ts_code, fields='name,list_date')
self._save_raw_data(df_basic, ts_code, "market_metrics_stock_basic")
if not df_basic.empty:
metrics['name'] = df_basic.iloc[0]['name']
metrics['list_date'] = df_basic.iloc[0]['list_date']
df_comp = self.pro.stock_company(ts_code=ts_code)
if not df_comp.empty:
metrics["employee_count"] = int(df_comp.iloc[0].get('employees', 0) or 0)
df_holder = self.pro.stk_holdernumber(ts_code=ts_code, limit=1)
self._save_raw_data(df_holder, ts_code, "market_metrics_shareholder_number")
if not df_holder.empty:
metrics["total_share_holders"] = int(df_holder.iloc[0].get('holder_num', 0) or 0)
except Exception as e:
print(f"Error fetching market metrics for {symbol}: {e}")
return metrics
def get_historical_metrics(self, symbol: str, dates: list) -> pd.DataFrame:
ts_code = self._get_ts_code(symbol)
results = []
if not dates:
return pd.DataFrame()
unique_dates = sorted(list(set([str(d).replace('-', '') for d in dates])), reverse=True)
try:
import datetime
min_date = min(unique_dates)
max_date = max(unique_dates)
df_daily = self.pro.daily_basic(ts_code=ts_code, start_date=min_date, end_date=max_date)
self._save_raw_data(df_daily, ts_code, "historical_metrics_daily_basic")
if not df_daily.empty:
df_daily = df_daily.sort_values('trade_date', ascending=False)
df_holder = self.pro.stk_holdernumber(ts_code=ts_code, start_date=min_date, end_date=max_date)
self._save_raw_data(df_holder, ts_code, "historical_metrics_shareholder_number")
if not df_holder.empty:
df_holder = df_holder.sort_values('end_date', ascending=False)
for date_str in unique_dates:
metrics = {'date_str': date_str}
if not df_daily.empty:
closest_daily = df_daily[df_daily['trade_date'] <= date_str]
if not closest_daily.empty:
row = closest_daily.iloc[0]
metrics['Price'] = row.get('close')
metrics['PE'] = row.get('pe')
metrics['PB'] = row.get('pb')
metrics['MarketCap'] = row.get('total_mv', 0) * 10000
if not df_holder.empty:
closest_holder = df_holder[df_holder['end_date'] <= date_str]
if not closest_holder.empty:
metrics['Shareholders'] = closest_holder.iloc[0].get('holder_num')
results.append(metrics)
except Exception as e:
print(f"Error fetching historical metrics for {symbol}: {e}")
return pd.DataFrame(results)
def get_dividends(self, symbol: str) -> pd.DataFrame:
ts_code = self._get_ts_code(symbol)
df_div = self.pro.dividend(ts_code=ts_code, fields='end_date,ex_date,div_proc,cash_div')
self._save_raw_data(df_div, ts_code, "dividends_raw")
if df_div.empty:
return pd.DataFrame()
# Filter for implemented cash dividends
df_div = df_div[(df_div['div_proc'] == '实施') & (df_div['cash_div'] > 0)]
if df_div.empty:
return pd.DataFrame()
df_div['total_cash_div'] = 0.0
# Get total shares for each ex_date
for index, row in df_div.iterrows():
ex_date = row['ex_date']
if not ex_date or pd.isna(ex_date):
continue
try:
time.sleep(0.2) # Sleep for 200ms to avoid hitting API limits
df_daily = self.pro.daily_basic(ts_code=ts_code, trade_date=ex_date, fields='total_share')
if not df_daily.empty and not df_daily['total_share'].empty:
total_share = df_daily.iloc[0]['total_share'] # total_share is in 万股 (10k shares)
cash_div_per_share = row['cash_div'] # This is per-share
# Total dividend in Yuan
total_cash_dividend = (cash_div_per_share * total_share * 10000)
df_div.loc[index, 'total_cash_div'] = total_cash_dividend
except Exception as e:
print(f"Could not fetch daily basic for {ts_code} on {ex_date}: {e}")
df_div['year'] = pd.to_datetime(df_div['end_date']).dt.year
dividends_by_year = df_div.groupby('year')['total_cash_div'].sum().reset_index()
dividends_by_year['date_str'] = dividends_by_year['year'].astype(str) + '1231'
dividends_by_year.rename(columns={'total_cash_div': 'dividends'}, inplace=True)
return dividends_by_year[['date_str', 'dividends']]
def get_repurchases(self, symbol: str) -> pd.DataFrame:
ts_code = self._get_ts_code(symbol)
df = self.pro.repurchase(ts_code=ts_code)
self._save_raw_data(df, ts_code, "repurchases")
if df.empty or 'ann_date' not in df.columns or 'amount' not in df.columns:
return pd.DataFrame()
# Filter for repurchases with a valid amount
df = df[df['amount'] > 0]
if df.empty:
return pd.DataFrame()
# Extract year and group by it
df['year'] = pd.to_datetime(df['ann_date']).dt.year
repurchases_by_year = df.groupby('year')['amount'].sum().reset_index()
# Create date_str for merging (YYYY1231)
repurchases_by_year['date_str'] = repurchases_by_year['year'].astype(str) + '1231'
# Rename for merging.
# Based on user feedback, it appears the unit from the API is Yuan, so no conversion is needed.
repurchases_by_year.rename(columns={'amount': 'repurchases'}, inplace=True)
return repurchases_by_year[['date_str', 'repurchases']]