import time

import pandas as pd
import tushare as ts

from storage.file_io import DataStorage

from .base import DataFetcher


class CnFetcher(DataFetcher):
    """Fetcher for ``'CN'`` market data backed by the Tushare Pro API.

    Every raw API response is persisted via ``DataStorage.save_data`` (as
    ``raw_<name>``) before it is filtered or reshaped for the caller.
    """

    # Calendar days fetched *before* the earliest requested date in
    # get_historical_metrics. Each requested date is matched to the last row
    # on or before it; without this margin a requested date that is not itself
    # a trading day / report date (e.g. a weekend year-end) has no such row.
    _DAILY_LOOKBACK_DAYS = 30
    _HOLDER_LOOKBACK_DAYS = 180

    def __init__(self, api_key: str):
        super().__init__(api_key)
        ts.set_token(self.api_key)
        self.pro = ts.pro_api()
        self.storage = DataStorage()

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #
    def _save_raw_data(self, df: pd.DataFrame, symbol: str, name: str):
        """Persist a raw API response under market 'CN' as ``raw_<name>``; skip None/empty frames."""
        if df is None or df.empty:
            return
        market = 'CN'
        self.storage.save_data(df, market, symbol, f"raw_{name}")

    def _get_ts_code(self, symbol: str) -> str:
        """Map a caller symbol to a Tushare ``ts_code``.

        Symbols are passed through unchanged, so callers presumably already
        supply Tushare codes — confirm against callers.
        """
        return symbol

    @staticmethod
    def _to_int(value, default: int = 0) -> int:
        """Convert an API scalar to ``int``; None, NaN and unparseable values become *default*.

        ``int(value or 0)`` is not enough: NaN is truthy, so it reaches
        ``int(nan)`` and raises ``ValueError``.
        """
        if value is None or pd.isna(value):
            return default
        try:
            return int(value)
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _shift_date(date_str: str, days: int) -> str:
        """Return the ``YYYYMMDD`` string *days* calendar days before *date_str*.

        If *date_str* cannot be parsed as ``YYYYMMDD`` it is returned unchanged.
        """
        parsed = pd.to_datetime(date_str, format='%Y%m%d', errors='coerce')
        if pd.isna(parsed):
            return date_str
        return (parsed - pd.Timedelta(days=days)).strftime('%Y%m%d')

    @staticmethod
    def _year_end_totals(df: pd.DataFrame, date_col: str, value_col: str,
                         out_col: str) -> pd.DataFrame:
        """Sum *value_col* per calendar year of *date_col*.

        Returns:
            A frame with columns ``['date_str', out_col]``, where ``date_str``
            is ``'<year>1231'``. Rows whose date cannot be parsed are dropped.
            Without this, the year column turns float and yields ``'2023.01231'``.
        """
        years = pd.to_datetime(df[date_col], errors='coerce').dt.year
        valid = years.notna()
        totals = (
            df.loc[valid]
            .assign(year=years[valid].astype(int))
            .groupby('year')[value_col]
            .sum()
            .reset_index()
        )
        totals['date_str'] = totals['year'].astype(str) + '1231'
        totals = totals.rename(columns={value_col: out_col})
        return totals[['date_str', out_col]]

    def _filter_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Reduce a periodic financial statement to the periods callers use.

        Keeps one row per ``end_date``, newest first. The rows kept are:
          * the latest reported period,
          * the period exactly one year earlier (``end_date - 10000``), for YoY
            comparison, and
          * every annual period (``end_date`` ending in ``'1231'``).

        Args:
            df: Raw Tushare statement frame. It is returned untouched when
                empty or lacking an ``end_date`` column.

        Returns:
            The filtered frame, sorted by ``end_date`` descending.
        """
        if df.empty or 'end_date' not in df.columns:
            return df

        sort_cols, ascending = ['end_date'], [False]
        if 'update_flag' in df.columns:
            # NOTE(review): Tushare may return several versions of one period;
            # assuming update_flag '1' marks the corrected version — confirm.
            sort_cols.append('update_flag')
            ascending.append(False)
        # Stable sort so the duplicate kept below is deterministic (default
        # quicksort is not stable; multi-column sorts are always stable).
        df = df.sort_values(by=sort_cols, ascending=ascending, kind='mergesort')
        df = df.drop_duplicates(subset=['end_date'], keep='first')

        latest_record = df.iloc[[0]]
        latest_date_str = str(latest_record['end_date'].iloc[0])
        try:
            last_year_date_str = str(int(latest_date_str) - 10000)
        except (TypeError, ValueError):
            # Non-numeric end_date: no YoY row. Keep the columns so concat stays consistent.
            comparable_record = df.iloc[0:0]
        else:
            comparable_record = df[df['end_date'].astype(str) == last_year_date_str]

        annual_records = df[df['end_date'].astype(str).str.endswith('1231')]

        combined = pd.concat([latest_record, comparable_record, annual_records])
        combined = combined.drop_duplicates(subset=['end_date'])
        return combined.sort_values(by='end_date', ascending=False)

    # ------------------------------------------------------------------ #
    # Financial statements
    # ------------------------------------------------------------------ #
    def get_income_statement(self, symbol: str) -> pd.DataFrame:
        """Return filtered income-statement rows (see ``_filter_data``).

        Columns are renamed: ``end_date`` -> ``date`` and
        ``n_income_attr_p`` -> ``net_income``.
        """
        ts_code = self._get_ts_code(symbol)
        df = self.pro.income(ts_code=ts_code)
        self._save_raw_data(df, ts_code, "income_statement")
        rename_map = {
            'end_date': 'date',
            'revenue': 'revenue',
            'n_income_attr_p': 'net_income',
        }
        df = self._filter_data(df)
        return df.rename(columns=rename_map)

    def get_balance_sheet(self, symbol: str) -> pd.DataFrame:
        """Return filtered balance-sheet rows with equity/liability/current columns renamed."""
        ts_code = self._get_ts_code(symbol)
        df = self.pro.balancesheet(ts_code=ts_code)
        self._save_raw_data(df, ts_code, "balance_sheet")
        rename_map = {
            'end_date': 'date',
            'total_hldr_eqy_exc_min_int': 'total_equity',
            'total_liab': 'total_liabilities',
            'total_cur_assets': 'current_assets',
            'total_cur_liab': 'current_liabilities',
        }
        df = self._filter_data(df)
        return df.rename(columns=rename_map)

    def get_cash_flow(self, symbol: str) -> pd.DataFrame:
        """Return filtered cash-flow rows with net operating cash flow and depreciation renamed."""
        ts_code = self._get_ts_code(symbol)
        df = self.pro.cashflow(ts_code=ts_code)
        self._save_raw_data(df, ts_code, "cash_flow")
        rename_map = {
            'end_date': 'date',
            'n_cashflow_act': 'net_cash_flow',
            'depr_fa_coga_dpba': 'depreciation',
        }
        df = self._filter_data(df)
        return df.rename(columns=rename_map)

    # ------------------------------------------------------------------ #
    # Market metrics
    # ------------------------------------------------------------------ #
    def _daily_basic_metrics(self, ts_code: str) -> dict:
        """Latest price / valuation snapshot from ``daily_basic``."""
        df = self.pro.daily_basic(ts_code=ts_code, limit=1)
        self._save_raw_data(df, ts_code, "market_metrics_daily_basic")
        if df.empty:
            return {}
        row = df.iloc[0]
        return {
            "price": row.get('close', 0.0),
            "pe": row.get('pe', 0.0),
            "pb": row.get('pb', 0.0),
            # Scaled by 10,000 — presumably total_mv is reported in units of 10k CNY; confirm.
            "market_cap": row.get('total_mv', 0.0) * 10000,
            "dividend_yield": row.get('dv_ttm', 0.0),
        }

    def _stock_basic_metrics(self, ts_code: str) -> dict:
        """Company name and listing date from ``stock_basic``."""
        df = self.pro.stock_basic(ts_code=ts_code, fields='name,list_date')
        self._save_raw_data(df, ts_code, "market_metrics_stock_basic")
        if df.empty:
            return {}
        first = df.iloc[0]
        return {'name': first['name'], 'list_date': first['list_date']}

    def _company_metrics(self, ts_code: str) -> dict:
        """Employee count from ``stock_company``."""
        df = self.pro.stock_company(ts_code=ts_code)
        self._save_raw_data(df, ts_code, "market_metrics_stock_company")
        if df.empty:
            return {}
        return {"employee_count": self._to_int(df.iloc[0].get('employees'))}

    def _holder_metrics(self, ts_code: str) -> dict:
        """Shareholder count from the first ``stk_holdernumber`` row."""
        df = self.pro.stk_holdernumber(ts_code=ts_code, limit=1)
        self._save_raw_data(df, ts_code, "market_metrics_shareholder_number")
        if df.empty:
            return {}
        return {"total_share_holders": self._to_int(df.iloc[0].get('holder_num'))}

    def get_market_metrics(self, symbol: str) -> dict:
        """Collect current market metrics for *symbol*.

        Always contains price, market_cap, pe, pb, dividend_yield,
        total_share_holders and employee_count, defaulting to 0 when
        unavailable. ``name`` and ``list_date`` are added when ``stock_basic``
        returns data.

        Each API section is fetched independently. An error in one section is
        printed and does not prevent the remaining sections from being fetched.
        """
        ts_code = self._get_ts_code(symbol)
        metrics = {
            "price": 0.0,
            "market_cap": 0.0,
            "pe": 0.0,
            "pb": 0.0,
            "dividend_yield": 0.0,
            "total_share_holders": 0,
            "employee_count": 0,
        }
        sections = (
            self._daily_basic_metrics,
            self._stock_basic_metrics,
            self._company_metrics,
            self._holder_metrics,
        )
        for fetch_section in sections:
            try:
                metrics.update(fetch_section(ts_code))
            except Exception as e:  # best-effort: API/network errors must not abort the rest
                print(f"Error fetching market metrics for {symbol}: {e}")
        return metrics

    def get_historical_metrics(self, symbol: str, dates: list) -> pd.DataFrame:
        """Return point-in-time market metrics for each requested date.

        Args:
            symbol: Instrument symbol (see ``_get_ts_code``).
            dates: Dates as ``YYYYMMDD`` or ``YYYY-MM-DD`` values. Duplicates are
                collapsed.

        Returns:
            One row per unique date, newest first. Each row has ``date_str``,
            plus Price/PE/PB/MarketCap taken from the last trading day on or
            before that date. It also has Shareholders from the last holder
            report with ``end_date`` on or before that date, when available.
            An empty frame is returned when *dates* is empty.
        """
        if not dates:
            return pd.DataFrame()
        ts_code = self._get_ts_code(symbol)
        unique_dates = sorted({str(d).replace('-', '') for d in dates}, reverse=True)
        results = []
        try:
            min_date, max_date = unique_dates[-1], unique_dates[0]
            # Start before min_date so the earliest requested date can still
            # match a row on or before it (e.g. a weekend year-end).
            daily_start = self._shift_date(min_date, self._DAILY_LOOKBACK_DAYS)
            holder_start = self._shift_date(min_date, self._HOLDER_LOOKBACK_DAYS)

            df_daily = self.pro.daily_basic(ts_code=ts_code, start_date=daily_start, end_date=max_date)
            self._save_raw_data(df_daily, ts_code, "historical_metrics_daily_basic")
            if not df_daily.empty:
                df_daily = df_daily.sort_values('trade_date', ascending=False)

            df_holder = self.pro.stk_holdernumber(ts_code=ts_code, start_date=holder_start, end_date=max_date)
            self._save_raw_data(df_holder, ts_code, "historical_metrics_shareholder_number")
            if not df_holder.empty:
                df_holder = df_holder.sort_values('end_date', ascending=False)

            for date_str in unique_dates:
                metrics = {'date_str': date_str}
                if not df_daily.empty:
                    closest_daily = df_daily[df_daily['trade_date'] <= date_str]
                    if not closest_daily.empty:
                        row = closest_daily.iloc[0]
                        metrics['Price'] = row.get('close')
                        metrics['PE'] = row.get('pe')
                        metrics['PB'] = row.get('pb')
                        metrics['MarketCap'] = row.get('total_mv', 0) * 10000
                if not df_holder.empty:
                    closest_holder = df_holder[df_holder['end_date'] <= date_str]
                    if not closest_holder.empty:
                        metrics['Shareholders'] = closest_holder.iloc[0].get('holder_num')
                results.append(metrics)
        except Exception as e:  # best-effort: return whatever was collected
            print(f"Error fetching historical metrics for {symbol}: {e}")
        return pd.DataFrame(results)

    # ------------------------------------------------------------------ #
    # Shareholder returns
    # ------------------------------------------------------------------ #
    def get_dividends(self, symbol: str) -> pd.DataFrame:
        """Return implemented cash dividends summed per fiscal year.

        Each payout is per-share ``cash_div`` × ``total_share`` on the ex-date ×
        10,000. Per the original note, ``total_share`` is in units of 10k shares.

        Returns:
            Columns ``['date_str', 'dividends']``, with ``date_str`` set to
            ``'<year of end_date>1231'``. An empty frame is returned when there
            is no implemented cash dividend.
        """
        ts_code = self._get_ts_code(symbol)
        # NOTE(review): Tushare documents cash_div as the after-tax per-share
        # amount (cash_div_tax is pre-tax); a company payout total may want the
        # pre-tax figure — confirm.
        df_div = self.pro.dividend(ts_code=ts_code, fields='end_date,ex_date,div_proc,cash_div')
        self._save_raw_data(df_div, ts_code, "dividends_raw")
        if df_div.empty:
            return pd.DataFrame()

        # Coerce first: comparing an object column containing None with 0 raises TypeError.
        df_div = df_div.assign(cash_div=pd.to_numeric(df_div['cash_div'], errors='coerce'))
        # Keep only implemented ('实施') cash dividends. The .copy() avoids
        # chained assignment when the total column is added below.
        df_div = df_div[(df_div['div_proc'] == '实施') & (df_div['cash_div'] > 0)].copy()
        # Guard against the same implemented dividend being listed twice.
        df_div = df_div.drop_duplicates(subset=['end_date', 'ex_date', 'cash_div'])
        if df_div.empty:
            return pd.DataFrame()

        df_div['total_cash_div'] = 0.0
        # One daily_basic call per ex-date to get the share count on that day.
        for index, row in df_div.iterrows():
            ex_date = row['ex_date']
            if pd.isna(ex_date) or not ex_date:
                continue
            try:
                time.sleep(0.2)  # Sleep for 200ms to avoid hitting API limits
                df_daily = self.pro.daily_basic(ts_code=ts_code, trade_date=ex_date, fields='total_share')
                if not df_daily.empty:
                    total_share = df_daily.iloc[0]['total_share']  # in 万股 (10k shares)
                    # Total dividend in Yuan
                    df_div.loc[index, 'total_cash_div'] = row['cash_div'] * total_share * 10000
            except Exception as e:  # best-effort per ex-date
                print(f"Could not fetch daily basic for {ts_code} on {ex_date}: {e}")

        return self._year_end_totals(df_div, 'end_date', 'total_cash_div', 'dividends')

    def get_repurchases(self, symbol: str) -> pd.DataFrame:
        """Return positive share-repurchase amounts summed per announcement year.

        Returns:
            Columns ``['date_str', 'repurchases']``, with ``date_str`` set to
            ``'<year of ann_date>1231'``. An empty frame is returned when there
            is no usable data.
        """
        ts_code = self._get_ts_code(symbol)
        df = self.pro.repurchase(ts_code=ts_code)
        self._save_raw_data(df, ts_code, "repurchases")
        if df.empty or 'ann_date' not in df.columns or 'amount' not in df.columns:
            return pd.DataFrame()

        # Coerce first: comparing an object column containing None with 0 raises TypeError.
        amount = pd.to_numeric(df['amount'], errors='coerce')
        df = df.assign(amount=amount).loc[amount > 0]
        if df.empty:
            return pd.DataFrame()

        # NOTE(review): every announcement row with a positive amount is summed.
        # If Tushare reports one programme at several stages (proc), this may
        # double-count — confirm.
        # Based on user feedback, it appears the unit from the API is Yuan, so no conversion is needed.
        return self._year_end_totals(df, 'ann_date', 'amount', 'repurchases')