""" Tushare CN Client for Backend API Direct port from src/fetchers/tushare_cn_client.py Removed storage.file_io dependency, only uses PostgreSQL """ import tushare as ts import pandas as pd import os import psycopg2 from datetime import datetime from dotenv import load_dotenv from typing import Union, List, Dict, Any import numpy as np from pathlib import Path import logging import time # 获取logger (不再配置basicConfig,由main.py统一配置) logger = logging.getLogger(__name__) # Explicitly load .env from project root ROOT_DIR = Path(__file__).resolve().parent.parent.parent.parent load_dotenv(ROOT_DIR / ".env") class TushareCnClient: def __init__(self, api_key: str): ts.set_token(api_key) self.pro = ts.pro_api() self.api_key = api_key def _get_db_connection(self): """Create a database connection.""" db_host = os.getenv("DB_HOST", "192.168.3.195") db_user = os.getenv("DB_USER", "value") db_pass = os.getenv("DB_PASSWORD", "Value609!") db_name = os.getenv("DB_NAME", "fa3") db_port = os.getenv("DB_PORT", "5432") try: conn = psycopg2.connect( host=db_host, user=db_user, password=db_pass, dbname=db_name, port=db_port ) return conn except Exception as e: print(f"DB Connection Error: {e}") return None def _map_dtype_to_sql(self, dtype): """Map pandas dtype to PostgreSQL type.""" if pd.api.types.is_integer_dtype(dtype): return "BIGINT" elif pd.api.types.is_float_dtype(dtype): return "NUMERIC" elif pd.api.types.is_bool_dtype(dtype): return "BOOLEAN" elif pd.api.types.is_datetime64_any_dtype(dtype): return "TIMESTAMP" else: return "TEXT" def _save_df_to_wide_table(self, table_name: str, df: pd.DataFrame, pk_cols: List[str]): """ Save DataFrame to a specific wide table. Creates table if not exists using DF columns. Performs incremental save: checks existing records and only inserts new ones. """ if df is None or df.empty: return start_time = time.time() # 1. Clean Data df_clean = df.replace({np.nan: None}) # Convert date columns to YYYY-MM-DD format for col in df_clean.columns: if 'date' in col.lower() and df_clean[col].dtype == 'object': try: sample = df_clean[col].dropna().astype(str).iloc[0] if not df_clean[col].dropna().empty else "" if len(sample) == 8 and sample.isdigit(): df_clean[col] = df_clean[col].astype(str).apply( lambda x: f"{x[:4]}-{x[4:6]}-{x[6:]}" if x and len(str(x))==8 else x ) except: pass conn = self._get_db_connection() if not conn: return try: with conn.cursor() as cur: # 2. Check if table exists cur.execute("SELECT to_regclass(%s)", (f"public.{table_name}",)) table_exists = cur.fetchone()[0] is not None columns = list(df_clean.columns) if not table_exists: # Create table if not exists logger.info(f"🆕 [数据库] 表 {table_name} 不存在,正在创建...") col_defs = ['"id" SERIAL PRIMARY KEY'] for col in columns: sql_type = self._map_dtype_to_sql(df_clean[col].dtype) col_defs.append(f'"{col}" {sql_type}') col_defs.append('"update_date" TIMESTAMP DEFAULT CURRENT_TIMESTAMP') pk_str = ", ".join([f'"{c}"' for c in pk_cols]) constraint_name = f"uq_{table_name}" create_sql = f""" CREATE TABLE IF NOT EXISTS {table_name} ( {', '.join(col_defs)}, CONSTRAINT {constraint_name} UNIQUE ({pk_str}) ); """ cur.execute(create_sql) conn.commit() # No existing data, insert all df_to_insert = df_clean else: # 3. 

    def _save_df_to_wide_table(self, table_name: str, df: pd.DataFrame, pk_cols: List[str]):
        """
        Save a DataFrame to a specific wide table.

        Creates the table if it does not exist, using the DataFrame columns.
        Performs an incremental save: existing records are checked and only new
        rows are inserted.
        """
        if df is None or df.empty:
            return
        start_time = time.time()

        # 1. Clean data
        df_clean = df.replace({np.nan: None})
        # Convert YYYYMMDD date columns to YYYY-MM-DD format
        for col in df_clean.columns:
            if 'date' in col.lower() and df_clean[col].dtype == 'object':
                try:
                    sample = df_clean[col].dropna().astype(str).iloc[0] if not df_clean[col].dropna().empty else ""
                    if len(sample) == 8 and sample.isdigit():
                        df_clean[col] = df_clean[col].astype(str).apply(
                            lambda x: f"{x[:4]}-{x[4:6]}-{x[6:]}" if x and len(str(x)) == 8 else x
                        )
                except Exception:
                    pass

        conn = self._get_db_connection()
        if not conn:
            return
        try:
            with conn.cursor() as cur:
                # 2. Check whether the table exists
                cur.execute("SELECT to_regclass(%s)", (f"public.{table_name}",))
                table_exists = cur.fetchone()[0] is not None
                columns = list(df_clean.columns)

                if not table_exists:
                    # Create the table from the DataFrame schema
                    logger.info(f"🆕 [DB] Table {table_name} does not exist, creating it...")
                    col_defs = ['"id" SERIAL PRIMARY KEY']
                    for col in columns:
                        sql_type = self._map_dtype_to_sql(df_clean[col].dtype)
                        col_defs.append(f'"{col}" {sql_type}')
                    col_defs.append('"update_date" TIMESTAMP DEFAULT CURRENT_TIMESTAMP')
                    pk_str = ", ".join([f'"{c}"' for c in pk_cols])
                    constraint_name = f"uq_{table_name}"
                    create_sql = f"""
                        CREATE TABLE IF NOT EXISTS {table_name} (
                            {', '.join(col_defs)},
                            CONSTRAINT {constraint_name} UNIQUE ({pk_str})
                        );
                    """
                    cur.execute(create_sql)
                    conn.commit()
                    # No existing data, insert everything
                    df_to_insert = df_clean
                else:
                    # 3. Incremental logic: filter out records that already exist.
                    # We assume 'ts_code' is always in pk_cols and present in df.
                    if 'ts_code' in df_clean.columns and 'ts_code' in pk_cols:
                        ts_code_val = df_clean['ts_code'].iloc[0]
                        # Build a query fetching only the existing PKs for this stock;
                        # selecting just the PK columns minimizes data transfer
                        pk_select = ", ".join([f'"{c}"' for c in pk_cols])
                        header_sql = f'SELECT {pk_select} FROM {table_name} WHERE "ts_code" = %s'
                        cur.execute(header_sql, (ts_code_val,))
                        existing_rows = cur.fetchall()

                        if existing_rows:
                            # Build a set of existing PK tuples for fast lookup.
                            # The DB returns tuples in the same column order as pk_cols in the query.
                            existing_keys = set(existing_rows)

                            def row_to_key(row):
                                return tuple(row[col] for col in pk_cols)

                            # Identify the rows to keep. Iterating is a bit slow for very
                            # large DataFrames, but we usually have fewer than 1000 rows.
                            is_new_list = []
                            for _, row in df_clean.iterrows():
                                key = row_to_key(row)
                                # The DB driver may return dates as datetime.date while the
                                # DataFrame holds strings or timestamps. We try a simple
                                # comparison first; dates converted to YYYY-MM-DD strings
                                # above are expected to match.
                                is_new = key not in existing_keys
                                is_new_list.append(is_new)
                            df_to_insert = df_clean[is_new_list]
                            skipped_count = len(df_clean) - len(df_to_insert)
                            if skipped_count > 0:
                                logger.info(f"⏭️ [DB] Table {table_name}: skipping {skipped_count} existing rows, inserting {len(df_to_insert)} new rows")
                        else:
                            df_to_insert = df_clean
                    else:
                        # Fallback for tables without ts_code (rare in this context):
                        # the dedicated "code and end date" incremental check does not apply,
                        # so insert everything and rely on ON CONFLICT DO NOTHING below.
                        df_to_insert = df_clean

                if df_to_insert.empty:
                    logger.info(f"✅ [DB] Table {table_name}: all rows already exist, nothing to update.")
                    return

                # 4. Get the existing columns from the DB to ensure the schema matches
                cur.execute("""
                    SELECT column_name FROM information_schema.columns
                    WHERE table_name = %s
                """, (table_name,))
                db_cols = {row[0] for row in cur.fetchall()}

                # Add the auto-increment id column if it is missing
                if 'id' not in db_cols:
                    try:
                        cur.execute(f'ALTER TABLE "{table_name}" ADD COLUMN "id" SERIAL')
                        conn.commit()
                        db_cols.add('id')
                    except Exception as e:
                        print(f"Error adding id column to {table_name}: {e}")
                        conn.rollback()

                valid_cols = [c for c in columns if c in db_cols]

                # 5. Construct the INSERT statement. Since we already filtered, a plain
                # INSERT would work, but ON CONFLICT DO NOTHING is safer against races
                # and preserves the "never update existing rows" requirement.
                cols_str = ", ".join([f'"{c}"' for c in valid_cols])
                vals_str = ", ".join(["%s"] * len(valid_cols))
                conflict_str = ", ".join([f'"{c}"' for c in pk_cols])
                insert_sql = f"""
                    INSERT INTO {table_name} ({cols_str})
                    VALUES ({vals_str})
                    ON CONFLICT ({conflict_str}) DO NOTHING
                """

                # 6. Execute the batch insert
                logger.info(f"📥 [DB] Saving {len(df_to_insert)} new rows to {table_name}...")
                data_tuples = [tuple(x) for x in df_to_insert[valid_cols].to_numpy()]
                cur.executemany(insert_sql, data_tuples)
                conn.commit()
                elapsed = time.time() - start_time
                logger.info(f"✅ [DB] {table_name}: inserted {len(data_tuples)} rows in {elapsed:.2f}s")
        except Exception as e:
            logger.error(f"❌ [DB] Failed to save {table_name}: {e}")
            conn.rollback()
        finally:
            conn.close()
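
    # Sketch (an assumption, not part of the original logic): if the
    # datetime.date vs. string mismatch noted in the incremental check above
    # ever causes duplicate inserts, both sides of the comparison could be
    # normalized to strings before building the key sets, roughly like this:
    #
    #   def _normalize_key(values):
    #       return tuple(str(v) for v in values)
    #
    #   existing_keys = {_normalize_key(r) for r in existing_rows}
    #   key = _normalize_key(row[col] for col in pk_cols)
    #
    # ON CONFLICT DO NOTHING already guards the database itself either way.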

    def _get_ts_code(self, symbol: str) -> str:
        return symbol

    def _filter_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Sort by end_date (newest first) and keep one row per end_date."""
        if df.empty or 'end_date' not in df.columns:
            return df
        df = df.sort_values(by='end_date', ascending=False)
        df = df.drop_duplicates(subset=['end_date'], keep='first')
        return df

    def get_income_statement(self, symbol: str) -> pd.DataFrame:
        logger.info(f"🐛 [DEBUG] get_income_statement called with symbol: '{symbol}'")
        ts_code = self._get_ts_code(symbol)
        logger.info(f"🔍 [Tushare API] Calling income() for the income statement, symbol={ts_code}")
        start_time = time.time()
        df = self.pro.income(ts_code=ts_code)
        elapsed = time.time() - start_time
        logger.info(f"✅ [Tushare API] income() done, returned {len(df)} rows in {elapsed:.2f}s")

        # Save to DB (wide table)
        if not df.empty and {'ts_code', 'end_date', 'report_type'}.issubset(df.columns):
            self._save_df_to_wide_table('tushare_income_statement', df, ['ts_code', 'end_date', 'report_type'])

        # Legacy rename for frontend compatibility
        rename_map = {
            'end_date': 'date',
            'revenue': 'revenue',
            'n_income_attr_p': 'net_income'
        }
        df_processed = self._filter_data(df)
        df_processed = df_processed.rename(columns=rename_map)
        return df_processed

    def get_balance_sheet(self, symbol: str) -> pd.DataFrame:
        ts_code = self._get_ts_code(symbol)
        logger.info(f"🔍 [Tushare API] Calling balancesheet() for the balance sheet, symbol={ts_code}")
        start_time = time.time()
        df = self.pro.balancesheet(ts_code=ts_code)
        elapsed = time.time() - start_time
        logger.info(f"✅ [Tushare API] balancesheet() done, returned {len(df)} rows in {elapsed:.2f}s")

        # Save to DB
        if not df.empty and {'ts_code', 'end_date', 'report_type'}.issubset(df.columns):
            self._save_df_to_wide_table('tushare_balance_sheet', df, ['ts_code', 'end_date', 'report_type'])

        rename_map = {
            'end_date': 'date',
            'total_hldr_eqy_exc_min_int': 'total_equity',
            'total_liab': 'total_liabilities',
            'total_cur_assets': 'current_assets',
            'total_cur_liab': 'current_liabilities'
        }
        df_processed = self._filter_data(df)
        df_processed = df_processed.rename(columns=rename_map)
        return df_processed

    def get_cash_flow(self, symbol: str) -> pd.DataFrame:
        ts_code = self._get_ts_code(symbol)
        logger.info(f"🔍 [Tushare API] Calling cashflow() for the cash flow statement, symbol={ts_code}")
        start_time = time.time()
        df = self.pro.cashflow(ts_code=ts_code)
        elapsed = time.time() - start_time
        logger.info(f"✅ [Tushare API] cashflow() done, returned {len(df)} rows in {elapsed:.2f}s")

        # Save to DB
        if not df.empty and {'ts_code', 'end_date', 'report_type'}.issubset(df.columns):
            self._save_df_to_wide_table('tushare_cash_flow', df, ['ts_code', 'end_date', 'report_type'])

        df_processed = self._filter_data(df)
        df_processed = df_processed.rename(columns={
            'end_date': 'date',
            'n_cashflow_act': 'net_cash_flow',
            'depr_fa_coga_dpba': 'depreciation'
        })
        return df_processed
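
    # Usage sketch (illustrative only; the sample ts_code 600519.SH and the
    # surrounding setup are assumptions, not part of this module):
    #
    #   client = TushareCnClient(api_key="<your tushare token>")
    #   income = client.get_income_statement("600519.SH")
    #   # After the legacy rename, the frontend-facing columns include
    #   # 'date', 'revenue' and 'net_income'.
    #   print(income[['date', 'revenue', 'net_income']].head())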

    def get_market_metrics(self, symbol: str) -> dict:
        ts_code = self._get_ts_code(symbol)
        metrics = {
            "price": 0.0,
            "market_cap": 0.0,
            "pe": 0.0,
            "pb": 0.0,
            "total_share_holders": 0,
            "employee_count": 0
        }
        try:
            # 1. Daily basic
            df_daily = self.pro.daily_basic(ts_code=ts_code, limit=1)
            if not df_daily.empty:
                row = df_daily.iloc[0]
                metrics["price"] = row.get('close', 0.0)
                metrics["pe"] = row.get('pe', 0.0)
                metrics["pb"] = row.get('pb', 0.0)
                # total_mv is reported in units of 10,000 CNY
                metrics["market_cap"] = row.get('total_mv', 0.0) * 10000
                metrics["dividend_yield"] = row.get('dv_ttm', 0.0)
                # Save to DB
                if 'trade_date' in df_daily.columns:
                    self._save_df_to_wide_table('tushare_daily_basic', df_daily, ['ts_code', 'trade_date'])

            # 2. Stock basic
            df_basic = self.pro.stock_basic(ts_code=ts_code, fields='ts_code,symbol,name,area,industry,list_date')
            if not df_basic.empty:
                metrics['name'] = df_basic.iloc[0]['name']
                metrics['list_date'] = df_basic.iloc[0]['list_date']
                # Save to DB
                self._save_df_to_wide_table('tushare_stock_basic', df_basic, ['ts_code'])

            # 3. Company info
            df_comp = self.pro.stock_company(ts_code=ts_code)
            if not df_comp.empty:
                metrics["employee_count"] = int(df_comp.iloc[0].get('employees', 0) or 0)

            # 4. Shareholder number
            df_holder = self.pro.stk_holdernumber(ts_code=ts_code, limit=1)
            if not df_holder.empty:
                metrics["total_share_holders"] = int(df_holder.iloc[0].get('holder_num', 0) or 0)
                # Save to DB
                if 'end_date' in df_holder.columns:
                    self._save_df_to_wide_table('tushare_stk_holdernumber', df_holder, ['ts_code', 'end_date'])
        except Exception as e:
            print(f"Error fetching market metrics for {symbol}: {e}")
            import traceback
            traceback.print_exc()
        return metrics

    def get_historical_metrics(self, symbol: str, dates: list) -> pd.DataFrame:
        ts_code = self._get_ts_code(symbol)
        results = []
        if not dates:
            return pd.DataFrame()
        # Normalize to unique YYYYMMDD strings, newest first
        unique_dates = sorted(list(set([str(d).replace('-', '') for d in dates])), reverse=True)
        try:
            min_date = min(unique_dates)
            max_date = max(unique_dates)

            # Fetch and save daily basic
            df_daily = self.pro.daily_basic(ts_code=ts_code, start_date=min_date, end_date=max_date)
            if not df_daily.empty:
                self._save_df_to_wide_table('tushare_daily_basic', df_daily, ['ts_code', 'trade_date'])
                df_daily = df_daily.sort_values('trade_date', ascending=False)

            # Fetch and save shareholder number
            df_holder = self.pro.stk_holdernumber(ts_code=ts_code, start_date=min_date, end_date=max_date)
            if not df_holder.empty:
                self._save_df_to_wide_table('tushare_stk_holdernumber', df_holder, ['ts_code', 'end_date'])
                df_holder = df_holder.sort_values('end_date', ascending=False)

            # Build the legacy results DataFrame for the internal return value.
            # For each requested date, use the most recent record on or before it.
            for date_str in unique_dates:
                metrics = {'date': date_str}
                if not df_daily.empty:
                    closest_daily = df_daily[df_daily['trade_date'] <= date_str]
                    if not closest_daily.empty:
                        row = closest_daily.iloc[0]
                        metrics['Price'] = row.get('close')
                        metrics['PE'] = row.get('pe')
                        metrics['PB'] = row.get('pb')
                        metrics['MarketCap'] = row.get('total_mv', 0) * 10000
                if not df_holder.empty:
                    closest_holder = df_holder[df_holder['end_date'] <= date_str]
                    if not closest_holder.empty:
                        metrics['Shareholders'] = closest_holder.iloc[0].get('holder_num')
                results.append(metrics)
        except Exception as e:
            print(f"Error fetching historical metrics for {symbol}: {e}")
        return pd.DataFrame(results)
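
    # Worked example for the dividend total computed in get_dividends below
    # (illustrative figures, not values returned by the API): daily_basic
    # reports total_share in units of 10,000 shares, so with cash_div = 1.50
    # CNY per share and total_share = 125,000 the yearly total is
    #   1.50 * 125000 * 10000 = 1,875,000,000 CNY (about 1.9 billion).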

    def get_dividends(self, symbol: str) -> pd.DataFrame:
        ts_code = self._get_ts_code(symbol)
        df_div = self.pro.dividend(
            ts_code=ts_code,
            fields='ts_code,end_date,ann_date,div_proc,stk_div,cash_div_tax,cash_div,record_date,ex_date,pay_date,div_listdate,imp_ann_date'
        )
        if not df_div.empty:
            # Save to DB
            self._save_df_to_wide_table('tushare_dividend', df_div, ['ts_code', 'end_date', 'ann_date'])
        if df_div.empty:
            return pd.DataFrame()

        # Legacy processing for the return value:
        # keep only executed plans (div_proc == '实施') with a positive cash dividend
        df_div = df_div[(df_div['div_proc'] == '实施') & (df_div['cash_div'] > 0)]
        if df_div.empty:
            return pd.DataFrame()

        calced_divs = []
        for index, row in df_div.iterrows():
            ex_date = row['ex_date']
            if not ex_date or pd.isna(ex_date):
                continue
            try:
                time.sleep(0.1)
                df_daily = self.pro.daily_basic(ts_code=ts_code, trade_date=ex_date, fields='total_share')
                if not df_daily.empty:
                    # total_share is in units of 10,000 shares; cash_div is CNY per share
                    total_share = df_daily.iloc[0]['total_share']
                    total_cash_dividend = (row['cash_div'] * total_share * 10000)
                    calced_divs.append({
                        'year': int(str(row['end_date'])[:4]),
                        'dividends': total_cash_dividend
                    })
            except Exception:
                pass

        if calced_divs:
            df_calc = pd.DataFrame(calced_divs)
            dividends_by_year = df_calc.groupby('year')['dividends'].sum().reset_index()
            dividends_by_year['date'] = dividends_by_year['year'].astype(str) + '1231'
            return dividends_by_year[['date', 'dividends']]
        return pd.DataFrame()

    def get_repurchases(self, symbol: str) -> pd.DataFrame:
        ts_code = self._get_ts_code(symbol)
        df = self.pro.repurchase(ts_code=ts_code)
        if not df.empty:
            if 'ann_date' in df.columns and 'end_date' in df.columns:
                self._save_df_to_wide_table('tushare_repurchase', df, ['ts_code', 'ann_date', 'end_date'])
        if df.empty or 'ann_date' not in df.columns or 'amount' not in df.columns:
            return pd.DataFrame()

        # Legacy processing: aggregate repurchase amounts by announcement year
        df = df[df['amount'] > 0]
        if df.empty:
            return pd.DataFrame()
        df['year'] = pd.to_datetime(df['ann_date']).dt.year
        repurchases_by_year = df.groupby('year')['amount'].sum().reset_index()
        repurchases_by_year['date_str'] = repurchases_by_year['year'].astype(str) + '1231'
        repurchases_by_year.rename(columns={'amount': 'repurchases', 'date_str': 'date'}, inplace=True)
        return repurchases_by_year[['date', 'repurchases']]
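

# Minimal smoke-test sketch (assumptions: the TUSHARE_TOKEN env var name and
# the 600519.SH sample code are illustrative, not part of the backend contract).
if __name__ == "__main__":
    _token = os.getenv("TUSHARE_TOKEN", "")
    if not _token:
        print("Set TUSHARE_TOKEN in the environment to run this smoke test.")
    else:
        _client = TushareCnClient(api_key=_token)
        print(_client.get_market_metrics("600519.SH"))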