FA3-Datafetch/backend/app/clients/tushare_cn_client.py
"""
Tushare CN Client for Backend API
Direct port of src/fetchers/tushare_cn_client.py.
Removed the storage.file_io dependency; persistence goes through PostgreSQL only.
"""
import tushare as ts
import pandas as pd
import os
import psycopg2
from datetime import datetime
from dotenv import load_dotenv
from typing import Union, List, Dict, Any
import numpy as np
from pathlib import Path
import logging
import time
# Get the module logger (basicConfig is no longer set here; main.py configures logging centrally)
logger = logging.getLogger(__name__)
# Explicitly load .env from project root
ROOT_DIR = Path(__file__).resolve().parent.parent.parent.parent
load_dotenv(ROOT_DIR / ".env")
class TushareCnClient:
def __init__(self, api_key: str):
ts.set_token(api_key)
self.pro = ts.pro_api()
self.api_key = api_key
def _get_db_connection(self):
"""Create a database connection."""
db_host = os.getenv("DB_HOST", "192.168.3.195")
db_user = os.getenv("DB_USER", "value")
db_pass = os.getenv("DB_PASSWORD", "Value609!")
db_name = os.getenv("DB_NAME", "fa3")
db_port = os.getenv("DB_PORT", "5432")
try:
conn = psycopg2.connect(
host=db_host, user=db_user, password=db_pass, dbname=db_name, port=db_port
)
return conn
except Exception as e:
print(f"DB Connection Error: {e}")
return None
def _map_dtype_to_sql(self, dtype):
"""Map pandas dtype to PostgreSQL type."""
if pd.api.types.is_integer_dtype(dtype):
return "BIGINT"
elif pd.api.types.is_float_dtype(dtype):
return "NUMERIC"
elif pd.api.types.is_bool_dtype(dtype):
return "BOOLEAN"
elif pd.api.types.is_datetime64_any_dtype(dtype):
return "TIMESTAMP"
else:
return "TEXT"
def _save_df_to_wide_table(self, table_name: str, df: pd.DataFrame, pk_cols: List[str]):
"""
        Save a DataFrame to a specific wide table.
        Creates the table if it does not exist, using the DataFrame's columns.
        Performs an incremental save: existing records are checked and only new rows are inserted.
"""
if df is None or df.empty:
return
start_time = time.time()
# 1. Clean Data
df_clean = df.replace({np.nan: None})
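        # Map np.nan to None so psycopg2 inserts SQL NULLs.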
# Convert date columns to YYYY-MM-DD format
for col in df_clean.columns:
if 'date' in col.lower() and df_clean[col].dtype == 'object':
try:
sample = df_clean[col].dropna().astype(str).iloc[0] if not df_clean[col].dropna().empty else ""
if len(sample) == 8 and sample.isdigit():
df_clean[col] = df_clean[col].astype(str).apply(
lambda x: f"{x[:4]}-{x[4:6]}-{x[6:]}" if x and len(str(x))==8 else x
)
                except Exception:
                    # Leave the column unchanged if sampling or conversion fails.
                    pass
conn = self._get_db_connection()
if not conn: return
try:
with conn.cursor() as cur:
# 2. Check if table exists
cur.execute("SELECT to_regclass(%s)", (f"public.{table_name}",))
table_exists = cur.fetchone()[0] is not None
columns = list(df_clean.columns)
if not table_exists:
# Create table if not exists
logger.info(f"🆕 [数据库] 表 {table_name} 不存在,正在创建...")
col_defs = ['"id" SERIAL PRIMARY KEY']
for col in columns:
sql_type = self._map_dtype_to_sql(df_clean[col].dtype)
col_defs.append(f'"{col}" {sql_type}')
col_defs.append('"update_date" TIMESTAMP DEFAULT CURRENT_TIMESTAMP')
pk_str = ", ".join([f'"{c}"' for c in pk_cols])
constraint_name = f"uq_{table_name}"
create_sql = f"""
CREATE TABLE IF NOT EXISTS {table_name} (
{', '.join(col_defs)},
CONSTRAINT {constraint_name} UNIQUE ({pk_str})
);
"""
cur.execute(create_sql)
conn.commit()
# No existing data, insert all
df_to_insert = df_clean
else:
# 3. Incremental Logic: Filter out existing records
# We assume 'ts_code' is always in pk_cols and present in df
if 'ts_code' in df_clean.columns and 'ts_code' in pk_cols:
ts_code_val = df_clean['ts_code'].iloc[0]
# Build query to fetch existing PKs for this stock
# We only select the PK columns to minimize data transfer
pk_select = ", ".join([f'"{c}"' for c in pk_cols])
header_sql = f"SELECT {pk_select} FROM {table_name} WHERE \"ts_code\" = %s"
cur.execute(header_sql, (ts_code_val,))
existing_rows = cur.fetchall()
if existing_rows:
# Create a set of existing PK tuples for fast lookup
                            # Ensure types match (convert to strings if necessary, consistent with the DataFrame)
existing_keys = set(existing_rows)
# Filter df
# We need to construct a tuple from df rows corresponding to pk_cols
                            # Note: the DB returns tuples in the same column order as the pk_cols SELECT above
def row_to_key(row):
return tuple(row[col] for col in pk_cols)
# Identify rows to keep
# This is a bit slow for very large DFs, but usually we have < 1000 rows
is_new_list = []
for _, row in df_clean.iterrows():
key = row_to_key(row)
                                # The DB driver may return dates as datetime.date while the DataFrame holds strings or timestamps.
                                # Plain tuple comparison is used first; if mismatches appear, key normalization will be needed.
                                # For now, string matching is assumed for the dates converted above.
is_new = key not in existing_keys
is_new_list.append(is_new)
df_to_insert = df_clean[is_new_list]
skipped_count = len(df_clean) - len(df_to_insert)
if skipped_count > 0:
logger.info(f"⏭️ [数据库] 表 {table_name}: 跳过 {skipped_count} 条已存在的数据, 准备插入 {len(df_to_insert)}")
else:
df_to_insert = df_clean
else:
# Fallback for tables without ts_code (rare in this context)
                        # Without ts_code there is no cheap existence check (the requested dedup key
                        # is code + end date), so insert everything and rely on ON CONFLICT DO NOTHING below.
df_to_insert = df_clean
if df_to_insert.empty:
logger.info(f"✅ [数据库] 表 {table_name}: 所有数据已存在,无需更新。")
return
# 4. Get existing columns from DB to ensure schema match
cur.execute("""
SELECT column_name
FROM information_schema.columns
WHERE table_name = %s
""", (table_name,))
db_cols = {row[0] for row in cur.fetchall()}
# Add auto-columns (id) if missing
if 'id' not in db_cols:
try:
cur.execute(f'ALTER TABLE "{table_name}" ADD COLUMN "id" SERIAL')
conn.commit()
db_cols.add('id')
except Exception as e:
print(f"Error adding id column to {table_name}: {e}")
conn.rollback()
valid_cols = [c for c in columns if c in db_cols]
# 5. Construct INSERT statement
cols_str = ", ".join([f'"{c}"' for c in valid_cols])
vals_str = ", ".join(["%s"] * len(valid_cols))
                # The pre-filter above removes known duplicates, but ON CONFLICT DO NOTHING also
                # guards against race conditions. Existing rows are never updated, matching the
                # requirement to skip (not overwrite) records that already exist.
insert_sql = f"""
INSERT INTO {table_name} ({cols_str})
VALUES ({vals_str})
ON CONFLICT ({", ".join([f'"{c}"' for c in pk_cols])})
DO NOTHING
"""
# 6. Execute Batch Insert
logger.info(f"📥 [数据库] 正在保存 {len(df_to_insert)} 条新数据到 {table_name}...")
data_tuples = [tuple(x) for x in df_to_insert[valid_cols].to_numpy()]
cur.executemany(insert_sql, data_tuples)
conn.commit()
elapsed = time.time() - start_time
logger.info(f"✅ [数据库] {table_name}: 成功插入 {len(data_tuples)} 条数据, 耗时: {elapsed:.2f}")
except Exception as e:
logger.error(f"❌ [数据库] 保存 {table_name} 失败: {e}")
conn.rollback()
finally:
conn.close()
def _get_ts_code(self, symbol: str) -> str:
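        # Symbols are assumed to already be full Tushare ts_codes (e.g. '600519.SH'), so this is a pass-through.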
return symbol
def _filter_data(self, df: pd.DataFrame) -> pd.DataFrame:
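        # De-duplicate by end_date, keeping one row per reporting period (newest periods first).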
if df.empty or 'end_date' not in df.columns:
return df
df = df.sort_values(by='end_date', ascending=False)
df = df.drop_duplicates(subset=['end_date'], keep='first')
return df
def get_income_statement(self, symbol: str) -> pd.DataFrame:
logger.info(f"🐛 [DEBUG] get_income_statement called with symbol: '{symbol}'")
ts_code = self._get_ts_code(symbol)
logger.info(f"🔍 [Tushare API] 调用 income() 获取利润表, symbol={ts_code}")
start_time = time.time()
df = self.pro.income(ts_code=ts_code)
elapsed = time.time() - start_time
logger.info(f"✅ [Tushare API] income() 完成, 返回 {len(df)} 行, 耗时: {elapsed:.2f}")
# Save to DB (Wide Table)
if not df.empty and {'ts_code', 'end_date', 'report_type'}.issubset(df.columns):
self._save_df_to_wide_table('tushare_income_statement', df, ['ts_code', 'end_date', 'report_type'])
# Legacy Rename for Frontend compatibility
rename_map = {
'end_date': 'date',
'revenue': 'revenue',
'n_income_attr_p': 'net_income'
}
df_processed = self._filter_data(df)
df_processed = df_processed.rename(columns=rename_map)
return df_processed
def get_balance_sheet(self, symbol: str) -> pd.DataFrame:
ts_code = self._get_ts_code(symbol)
logger.info(f"🔍 [Tushare API] 调用 balancesheet() 获取资产负债表, symbol={ts_code}")
start_time = time.time()
df = self.pro.balancesheet(ts_code=ts_code)
elapsed = time.time() - start_time
logger.info(f"✅ [Tushare API] balancesheet() 完成, 返回 {len(df)} 行, 耗时: {elapsed:.2f}")
# Save to DB
if not df.empty and {'ts_code', 'end_date', 'report_type'}.issubset(df.columns):
self._save_df_to_wide_table('tushare_balance_sheet', df, ['ts_code', 'end_date', 'report_type'])
rename_map = {
'end_date': 'date',
'total_hldr_eqy_exc_min_int': 'total_equity',
'total_liab': 'total_liabilities',
'total_cur_assets': 'current_assets',
'total_cur_liab': 'current_liabilities'
}
df_processed = self._filter_data(df)
df_processed = df_processed.rename(columns=rename_map)
return df_processed
def get_cash_flow(self, symbol: str) -> pd.DataFrame:
ts_code = self._get_ts_code(symbol)
logger.info(f"🔍 [Tushare API] 调用 cashflow() 获取现金流量表, symbol={ts_code}")
start_time = time.time()
df = self.pro.cashflow(ts_code=ts_code)
elapsed = time.time() - start_time
logger.info(f"✅ [Tushare API] cashflow() 完成, 返回 {len(df)} 行, 耗时: {elapsed:.2f}")
# Save to DB
if not df.empty and {'ts_code', 'end_date', 'report_type'}.issubset(df.columns):
self._save_df_to_wide_table('tushare_cash_flow', df, ['ts_code', 'end_date', 'report_type'])
df_processed = self._filter_data(df)
df_processed = df_processed.rename(columns={
'end_date': 'date',
'n_cashflow_act': 'net_cash_flow',
'depr_fa_coga_dpba': 'depreciation'
})
return df_processed
def get_market_metrics(self, symbol: str) -> dict:
ts_code = self._get_ts_code(symbol)
metrics = {
"price": 0.0,
"market_cap": 0.0,
"pe": 0.0,
"pb": 0.0,
"total_share_holders": 0,
"employee_count": 0
}
try:
# 1. Daily Basic
df_daily = self.pro.daily_basic(ts_code=ts_code, limit=1)
if not df_daily.empty:
row = df_daily.iloc[0]
metrics["price"] = row.get('close', 0.0)
metrics["pe"] = row.get('pe', 0.0)
metrics["pb"] = row.get('pb', 0.0)
metrics["market_cap"] = row.get('total_mv', 0.0) * 10000
metrics["dividend_yield"] = row.get('dv_ttm', 0.0)
# Save to DB
if 'trade_date' in df_daily.columns:
self._save_df_to_wide_table('tushare_daily_basic', df_daily, ['ts_code', 'trade_date'])
# 2. Stock Basic
df_basic = self.pro.stock_basic(ts_code=ts_code, fields='ts_code,symbol,name,area,industry,list_date')
if not df_basic.empty:
metrics['name'] = df_basic.iloc[0]['name']
metrics['list_date'] = df_basic.iloc[0]['list_date']
# Save to DB
self._save_df_to_wide_table('tushare_stock_basic', df_basic, ['ts_code'])
# 3. Company Info
df_comp = self.pro.stock_company(ts_code=ts_code)
if not df_comp.empty:
metrics["employee_count"] = int(df_comp.iloc[0].get('employees', 0) or 0)
# 4. Shareholder Number
df_holder = self.pro.stk_holdernumber(ts_code=ts_code, limit=1)
if not df_holder.empty:
metrics["total_share_holders"] = int(df_holder.iloc[0].get('holder_num', 0) or 0)
# Save to DB.
if 'end_date' in df_holder.columns:
self._save_df_to_wide_table('tushare_stk_holdernumber', df_holder, ['ts_code', 'end_date'])
except Exception as e:
print(f"Error fetching market metrics for {symbol}: {e}")
import traceback
traceback.print_exc()
return metrics
def get_historical_metrics(self, symbol: str, dates: list) -> pd.DataFrame:
ts_code = self._get_ts_code(symbol)
results = []
if not dates:
return pd.DataFrame()
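        # Normalize requested dates to Tushare's YYYYMMDD format, de-duplicated and newest first.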
unique_dates = sorted(list(set([str(d).replace('-', '') for d in dates])), reverse=True)
try:
min_date = min(unique_dates)
max_date = max(unique_dates)
# Fetch and Save Daily Basic
df_daily = self.pro.daily_basic(ts_code=ts_code, start_date=min_date, end_date=max_date)
if not df_daily.empty:
self._save_df_to_wide_table('tushare_daily_basic', df_daily, ['ts_code', 'trade_date'])
df_daily = df_daily.sort_values('trade_date', ascending=False)
# Fetch and Save Shareholder Number
df_holder = self.pro.stk_holdernumber(ts_code=ts_code, start_date=min_date, end_date=max_date)
if not df_holder.empty:
self._save_df_to_wide_table('tushare_stk_holdernumber', df_holder, ['ts_code', 'end_date'])
df_holder = df_holder.sort_values('end_date', ascending=False)
# Build legacy results DataFrame for internal return
for date_str in unique_dates:
metrics = {'date': date_str}
if not df_daily.empty:
closest_daily = df_daily[df_daily['trade_date'] <= date_str]
if not closest_daily.empty:
row = closest_daily.iloc[0]
metrics['Price'] = row.get('close')
metrics['PE'] = row.get('pe')
metrics['PB'] = row.get('pb')
metrics['MarketCap'] = row.get('total_mv', 0) * 10000
if not df_holder.empty:
closest_holder = df_holder[df_holder['end_date'] <= date_str]
if not closest_holder.empty:
metrics['Shareholders'] = closest_holder.iloc[0].get('holder_num')
results.append(metrics)
except Exception as e:
print(f"Error fetching historical metrics for {symbol}: {e}")
return pd.DataFrame(results)
def get_dividends(self, symbol: str) -> pd.DataFrame:
ts_code = self._get_ts_code(symbol)
df_div = self.pro.dividend(ts_code=ts_code, fields='ts_code,end_date,ann_date,div_proc,stk_div,cash_div_tax,cash_div,record_date,ex_date,pay_date,div_listdate,imp_ann_date')
if not df_div.empty:
# Save to DB
self._save_df_to_wide_table('tushare_dividend', df_div, ['ts_code', 'end_date', 'ann_date'])
if df_div.empty:
return pd.DataFrame()
# Legacy processing for return
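        # div_proc == '实施' keeps only dividends that were actually implemented (not merely proposed).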
df_div = df_div[(df_div['div_proc'] == '实施') & (df_div['cash_div'] > 0)]
if df_div.empty:
return pd.DataFrame()
calced_divs = []
for index, row in df_div.iterrows():
ex_date = row['ex_date']
if not ex_date or pd.isna(ex_date): continue
try:
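                # Throttle briefly to stay under Tushare's per-minute API rate limits.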
time.sleep(0.1)
df_daily = self.pro.daily_basic(ts_code=ts_code, trade_date=ex_date, fields='total_share')
if not df_daily.empty:
total_share = df_daily.iloc[0]['total_share']
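                    # total_share is in units of 10,000 shares (万股) and cash_div is RMB per share,
                    # so cash_div * total_share * 10000 gives the total cash paid out in RMB.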
total_cash_dividend = (row['cash_div'] * total_share * 10000)
calced_divs.append({
'year': int(str(row['end_date'])[:4]),
'dividends': total_cash_dividend
})
            except Exception:
                pass  # Skip ex-dates where the total_share lookup fails.
if calced_divs:
df_calc = pd.DataFrame(calced_divs)
dividends_by_year = df_calc.groupby('year')['dividends'].sum().reset_index()
dividends_by_year['date'] = dividends_by_year['year'].astype(str) + '1231'
return dividends_by_year[['date', 'dividends']]
return pd.DataFrame()
def get_repurchases(self, symbol: str) -> pd.DataFrame:
ts_code = self._get_ts_code(symbol)
df = self.pro.repurchase(ts_code=ts_code)
if not df.empty:
if 'ann_date' in df.columns and 'end_date' in df.columns:
self._save_df_to_wide_table('tushare_repurchase', df, ['ts_code', 'ann_date', 'end_date'])
if df.empty or 'ann_date' not in df.columns or 'amount' not in df.columns:
return pd.DataFrame()
# Legacy processing
df = df[df['amount'] > 0]
if df.empty: return pd.DataFrame()
df['year'] = pd.to_datetime(df['ann_date']).dt.year
repurchases_by_year = df.groupby('year')['amount'].sum().reset_index()
repurchases_by_year['date_str'] = repurchases_by_year['year'].astype(str) + '1231'
repurchases_by_year.rename(columns={'amount': 'repurchases', 'date_str': 'date'}, inplace=True)
return repurchases_by_year[['date', 'repurchases']]
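
# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal example of exercising the client, assuming a Tushare token is exposed via a
# TUSHARE_API_KEY environment variable (variable name assumed) and that the DB_* settings
# in .env point at a reachable PostgreSQL instance.
if __name__ == "__main__":
    client = TushareCnClient(os.getenv("TUSHARE_API_KEY", ""))
    # Fetch (and persist to the wide tables) the income statement for an example ts_code.
    income = client.get_income_statement("600519.SH")
    print(income.head())
    # Spot-check current market metrics for the same symbol.
    print(client.get_market_metrics("600519.SH"))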