"""
|
|
Tushare CN Client for Backend API
|
|
Direct port from src/fetchers/tushare_cn_client.py
|
|
Removed storage.file_io dependency, only uses PostgreSQL
|
|
"""
import tushare as ts
import pandas as pd
import os
import psycopg2
from datetime import datetime
from dotenv import load_dotenv
from typing import Union, List, Dict, Any
import numpy as np
from pathlib import Path
import logging
import time

# Get the module logger (basicConfig is no longer configured here; main.py configures logging centrally)
logger = logging.getLogger(__name__)

# Explicitly load .env from project root
ROOT_DIR = Path(__file__).resolve().parent.parent.parent.parent
load_dotenv(ROOT_DIR / ".env")


class TushareCnClient:
    """Tushare client that fetches CN market data and persists raw results to PostgreSQL wide tables."""

    def __init__(self, api_key: str):
        ts.set_token(api_key)
        self.pro = ts.pro_api()
        self.api_key = api_key

    def _get_db_connection(self):
        """Create a database connection."""
        db_host = os.getenv("DB_HOST", "192.168.3.195")
        db_user = os.getenv("DB_USER", "value")
        db_pass = os.getenv("DB_PASSWORD", "Value609!")
        db_name = os.getenv("DB_NAME", "fa3")
        db_port = os.getenv("DB_PORT", "5432")

        try:
            conn = psycopg2.connect(
                host=db_host, user=db_user, password=db_pass, dbname=db_name, port=db_port
            )
            return conn
        except Exception as e:
            logger.error(f"DB Connection Error: {e}")
            return None
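
    # Configuration note (illustrative only): the connection settings above are read from
    # the project .env loaded at import time. A minimal .env sketch with placeholder
    # values (the real values are deployment-specific) might look like:
    #
    #   DB_HOST=localhost
    #   DB_PORT=5432
    #   DB_NAME=fa3
    #   DB_USER=<user>
    #   DB_PASSWORD=<password>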

    def _map_dtype_to_sql(self, dtype):
        """Map pandas dtype to PostgreSQL type."""
        if pd.api.types.is_integer_dtype(dtype):
            return "BIGINT"
        elif pd.api.types.is_float_dtype(dtype):
            return "NUMERIC"
        elif pd.api.types.is_bool_dtype(dtype):
            return "BOOLEAN"
        elif pd.api.types.is_datetime64_any_dtype(dtype):
            return "TIMESTAMP"
        else:
            return "TEXT"
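
    # Mapping sketch (not executed): int64 -> BIGINT, float64 -> NUMERIC,
    # bool -> BOOLEAN, datetime64[ns] -> TIMESTAMP, anything else (object/str) -> TEXT.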

    def _save_df_to_wide_table(self, table_name: str, df: pd.DataFrame, pk_cols: List[str]):
        """
        Save DataFrame to a specific wide table.
        Creates table if not exists using DF columns.
        Performs incremental save: checks existing records and only inserts new ones.
        """
        if df is None or df.empty:
            return

        start_time = time.time()
        # 1. Clean Data
        df_clean = df.replace({np.nan: None})

        # Convert YYYYMMDD date columns to YYYY-MM-DD format
        for col in df_clean.columns:
            if 'date' in col.lower() and df_clean[col].dtype == 'object':
                try:
                    sample = df_clean[col].dropna().astype(str).iloc[0] if not df_clean[col].dropna().empty else ""
                    if len(sample) == 8 and sample.isdigit():
                        df_clean[col] = df_clean[col].astype(str).apply(
                            lambda x: f"{x[:4]}-{x[4:6]}-{x[6:]}" if x and len(str(x)) == 8 else x
                        )
                except Exception:
                    # Leave the column untouched if sampling/conversion fails
                    pass

        conn = self._get_db_connection()
        if not conn:
            return

        try:
            with conn.cursor() as cur:
                # 2. Check if table exists
                cur.execute("SELECT to_regclass(%s)", (f"public.{table_name}",))
                table_exists = cur.fetchone()[0] is not None

                columns = list(df_clean.columns)

                if not table_exists:
                    # Create table if not exists
                    logger.info(f"🆕 [DB] Table {table_name} does not exist, creating it...")
                    col_defs = ['"id" SERIAL PRIMARY KEY']
                    for col in columns:
                        sql_type = self._map_dtype_to_sql(df_clean[col].dtype)
                        col_defs.append(f'"{col}" {sql_type}')

                    col_defs.append('"update_date" TIMESTAMP DEFAULT CURRENT_TIMESTAMP')

                    pk_str = ", ".join([f'"{c}"' for c in pk_cols])
                    constraint_name = f"uq_{table_name}"
                    create_sql = f"""
                        CREATE TABLE IF NOT EXISTS {table_name} (
                            {', '.join(col_defs)},
                            CONSTRAINT {constraint_name} UNIQUE ({pk_str})
                        );
                    """
                    cur.execute(create_sql)
                    conn.commit()
                    # No existing data, insert all
                    df_to_insert = df_clean
                else:
                    # 3. Incremental logic: filter out records that already exist.
                    # We assume 'ts_code' is always in pk_cols and present in df.
                    if 'ts_code' in df_clean.columns and 'ts_code' in pk_cols:
                        ts_code_val = df_clean['ts_code'].iloc[0]

                        # Build a query that fetches the existing PKs for this stock.
                        # Only the PK columns are selected to minimize data transfer.
                        pk_select = ", ".join([f'"{c}"' for c in pk_cols])

                        header_sql = f"SELECT {pk_select} FROM {table_name} WHERE \"ts_code\" = %s"
                        cur.execute(header_sql, (ts_code_val,))
                        existing_rows = cur.fetchall()

                        if existing_rows:
                            # Create a set of existing PK tuples for fast lookup.
                            # Ensure types match (convert to string if necessary/consistent with df).
                            existing_keys = set(existing_rows)

                            # Filter the DataFrame by building a tuple from each row using pk_cols.
                            # Note: the DB returns tuples in the order of the pk_cols query above.

                            def row_to_key(row):
                                return tuple(row[col] for col in pk_cols)

                            # Identify rows to keep.
                            # This is a bit slow for very large DFs, but usually we have < 1000 rows.
                            is_new_list = []
                            for _, row in df_clean.iterrows():
                                key = row_to_key(row)
                                # The DB driver might return dates as datetime.date while the df has
                                # strings or timestamps. Simple comparison is used first; if issues
                                # arise, keys may need normalization. For now, string matching is
                                # assumed for dates converted above.
                                is_new = key not in existing_keys
                                is_new_list.append(is_new)

                            df_to_insert = df_clean[is_new_list]
                            skipped_count = len(df_clean) - len(df_to_insert)
                            if skipped_count > 0:
                                logger.info(f"⏭️ [DB] Table {table_name}: skipped {skipped_count} existing rows, preparing to insert {len(df_to_insert)} rows")
                        else:
                            df_to_insert = df_clean
                    else:
                        # Fallback for tables without ts_code (rare in this context).
                        # We could look up existing PKs, or insert everything with
                        # ON CONFLICT DO NOTHING; the requirement was a 'code and end date' check.
                        df_to_insert = df_clean

                if df_to_insert.empty:
                    logger.info(f"✅ [DB] Table {table_name}: all rows already exist, nothing to update.")
                    return

                # 4. Get existing columns from DB to ensure schema match
                cur.execute("""
                    SELECT column_name
                    FROM information_schema.columns
                    WHERE table_name = %s
                """, (table_name,))
                db_cols = {row[0] for row in cur.fetchall()}

                # Add auto-columns (id) if missing
                if 'id' not in db_cols:
                    try:
                        cur.execute(f'ALTER TABLE "{table_name}" ADD COLUMN "id" SERIAL')
                        conn.commit()
                        db_cols.add('id')
                    except Exception as e:
                        logger.error(f"Error adding id column to {table_name}: {e}")
                        conn.rollback()

                valid_cols = [c for c in columns if c in db_cols]

                # 5. Construct INSERT statement
                cols_str = ", ".join([f'"{c}"' for c in valid_cols])
                vals_str = ", ".join(["%s"] * len(valid_cols))

                # Since we already filtered, a plain INSERT would work, but
                # ON CONFLICT DO NOTHING is safer against race conditions and
                # preserves the requirement of not updating existing rows.

                insert_sql = f"""
                    INSERT INTO {table_name} ({cols_str})
                    VALUES ({vals_str})
                    ON CONFLICT ({", ".join([f'"{c}"' for c in pk_cols])})
                    DO NOTHING
                """

                # 6. Execute Batch Insert
                logger.info(f"📥 [DB] Saving {len(df_to_insert)} new rows to {table_name}...")
                data_tuples = [tuple(x) for x in df_to_insert[valid_cols].to_numpy()]
                cur.executemany(insert_sql, data_tuples)
                conn.commit()

                elapsed = time.time() - start_time
                logger.info(f"✅ [DB] {table_name}: inserted {len(data_tuples)} rows in {elapsed:.2f}s")

        except Exception as e:
            logger.error(f"❌ [DB] Failed to save {table_name}: {e}")
            conn.rollback()
        finally:
            conn.close()
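
    # Illustrative sketch of the SQL generated above (not executed): for a hypothetical
    # DataFrame with columns ['ts_code', 'end_date', 'revenue'], table_name='example_table'
    # and pk_cols=['ts_code', 'end_date'], the statements would look roughly like:
    #
    #   CREATE TABLE IF NOT EXISTS example_table (
    #       "id" SERIAL PRIMARY KEY,
    #       "ts_code" TEXT,
    #       "end_date" TEXT,
    #       "revenue" NUMERIC,
    #       "update_date" TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    #       CONSTRAINT uq_example_table UNIQUE ("ts_code", "end_date")
    #   );
    #
    #   INSERT INTO example_table ("ts_code", "end_date", "revenue")
    #   VALUES (%s, %s, %s)
    #   ON CONFLICT ("ts_code", "end_date") DO NOTHING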

    def _get_ts_code(self, symbol: str) -> str:
        # Symbols are expected to already be in Tushare ts_code form, so this is a pass-through.
        return symbol

    def _filter_data(self, df: pd.DataFrame) -> pd.DataFrame:
        # Sort newest-first and keep a single row per end_date (reporting period)
        if df.empty or 'end_date' not in df.columns:
            return df
        df = df.sort_values(by='end_date', ascending=False)
        df = df.drop_duplicates(subset=['end_date'], keep='first')
        return df
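
    # Example of _filter_data (illustrative): if the API returns two rows for
    # end_date 20231231 (e.g. different report_type values), only one row per
    # end_date is kept, with periods ordered newest-first.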

    def get_income_statement(self, symbol: str) -> pd.DataFrame:
        logger.info(f"🐛 [DEBUG] get_income_statement called with symbol: '{symbol}'")
        ts_code = self._get_ts_code(symbol)
        logger.info(f"🔍 [Tushare API] Calling income() for the income statement, symbol={ts_code}")
        start_time = time.time()
        df = self.pro.income(ts_code=ts_code)
        elapsed = time.time() - start_time
        logger.info(f"✅ [Tushare API] income() done, returned {len(df)} rows in {elapsed:.2f}s")

        # Save to DB (Wide Table)
        if not df.empty and {'ts_code', 'end_date', 'report_type'}.issubset(df.columns):
            self._save_df_to_wide_table('tushare_income_statement', df, ['ts_code', 'end_date', 'report_type'])

        # Legacy rename for frontend compatibility
        rename_map = {
            'end_date': 'date',
            'revenue': 'revenue',
            'n_income_attr_p': 'net_income'
        }
        df_processed = self._filter_data(df)
        df_processed = df_processed.rename(columns=rename_map)

        return df_processed

    def get_balance_sheet(self, symbol: str) -> pd.DataFrame:
        ts_code = self._get_ts_code(symbol)
        logger.info(f"🔍 [Tushare API] Calling balancesheet() for the balance sheet, symbol={ts_code}")
        start_time = time.time()
        df = self.pro.balancesheet(ts_code=ts_code)
        elapsed = time.time() - start_time
        logger.info(f"✅ [Tushare API] balancesheet() done, returned {len(df)} rows in {elapsed:.2f}s")

        # Save to DB
        if not df.empty and {'ts_code', 'end_date', 'report_type'}.issubset(df.columns):
            self._save_df_to_wide_table('tushare_balance_sheet', df, ['ts_code', 'end_date', 'report_type'])

        rename_map = {
            'end_date': 'date',
            'total_hldr_eqy_exc_min_int': 'total_equity',
            'total_liab': 'total_liabilities',
            'total_cur_assets': 'current_assets',
            'total_cur_liab': 'current_liabilities'
        }
        df_processed = self._filter_data(df)
        df_processed = df_processed.rename(columns=rename_map)

        return df_processed

    def get_cash_flow(self, symbol: str) -> pd.DataFrame:
        ts_code = self._get_ts_code(symbol)
        logger.info(f"🔍 [Tushare API] Calling cashflow() for the cash flow statement, symbol={ts_code}")
        start_time = time.time()
        df = self.pro.cashflow(ts_code=ts_code)
        elapsed = time.time() - start_time
        logger.info(f"✅ [Tushare API] cashflow() done, returned {len(df)} rows in {elapsed:.2f}s")

        # Save to DB
        if not df.empty and {'ts_code', 'end_date', 'report_type'}.issubset(df.columns):
            self._save_df_to_wide_table('tushare_cash_flow', df, ['ts_code', 'end_date', 'report_type'])

        df_processed = self._filter_data(df)
        df_processed = df_processed.rename(columns={
            'end_date': 'date',
            'n_cashflow_act': 'net_cash_flow',
            'depr_fa_coga_dpba': 'depreciation'
        })

        return df_processed

    def get_market_metrics(self, symbol: str) -> dict:
        ts_code = self._get_ts_code(symbol)
        metrics = {
            "price": 0.0,
            "market_cap": 0.0,
            "pe": 0.0,
            "pb": 0.0,
            "total_share_holders": 0,
            "employee_count": 0
        }

        try:
            # 1. Daily Basic
            df_daily = self.pro.daily_basic(ts_code=ts_code, limit=1)
            if not df_daily.empty:
                row = df_daily.iloc[0]
                metrics["price"] = row.get('close', 0.0)
                metrics["pe"] = row.get('pe', 0.0)
                metrics["pb"] = row.get('pb', 0.0)
                # total_mv is reported in units of 10k CNY; convert to CNY
                metrics["market_cap"] = row.get('total_mv', 0.0) * 10000
                metrics["dividend_yield"] = row.get('dv_ttm', 0.0)
                # Save to DB
                if 'trade_date' in df_daily.columns:
                    self._save_df_to_wide_table('tushare_daily_basic', df_daily, ['ts_code', 'trade_date'])

            # 2. Stock Basic
            df_basic = self.pro.stock_basic(ts_code=ts_code, fields='ts_code,symbol,name,area,industry,list_date')
            if not df_basic.empty:
                metrics['name'] = df_basic.iloc[0]['name']
                metrics['list_date'] = df_basic.iloc[0]['list_date']
                # Save to DB
                self._save_df_to_wide_table('tushare_stock_basic', df_basic, ['ts_code'])

            # 3. Company Info
            df_comp = self.pro.stock_company(ts_code=ts_code)
            if not df_comp.empty:
                metrics["employee_count"] = int(df_comp.iloc[0].get('employees', 0) or 0)

            # 4. Shareholder Number
            df_holder = self.pro.stk_holdernumber(ts_code=ts_code, limit=1)
            if not df_holder.empty:
                metrics["total_share_holders"] = int(df_holder.iloc[0].get('holder_num', 0) or 0)
                # Save to DB
                if 'end_date' in df_holder.columns:
                    self._save_df_to_wide_table('tushare_stk_holdernumber', df_holder, ['ts_code', 'end_date'])

        except Exception as e:
            logger.exception(f"Error fetching market metrics for {symbol}: {e}")

        return metrics
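
    # Shape of the returned metrics dict (from the code above): 'price', 'market_cap',
    # 'pe', 'pb', 'total_share_holders' and 'employee_count' are always present;
    # 'dividend_yield', 'name' and 'list_date' are added only when the corresponding
    # Tushare calls return data.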

    def get_historical_metrics(self, symbol: str, dates: list) -> pd.DataFrame:
        ts_code = self._get_ts_code(symbol)
        results = []

        if not dates:
            return pd.DataFrame()

        unique_dates = sorted(list(set([str(d).replace('-', '') for d in dates])), reverse=True)

        try:
            min_date = min(unique_dates)
            max_date = max(unique_dates)

            # Fetch and save Daily Basic
            df_daily = self.pro.daily_basic(ts_code=ts_code, start_date=min_date, end_date=max_date)
            if not df_daily.empty:
                self._save_df_to_wide_table('tushare_daily_basic', df_daily, ['ts_code', 'trade_date'])
                df_daily = df_daily.sort_values('trade_date', ascending=False)

            # Fetch and save Shareholder Number
            df_holder = self.pro.stk_holdernumber(ts_code=ts_code, start_date=min_date, end_date=max_date)
            if not df_holder.empty:
                self._save_df_to_wide_table('tushare_stk_holdernumber', df_holder, ['ts_code', 'end_date'])
                df_holder = df_holder.sort_values('end_date', ascending=False)

            # Build legacy results DataFrame for internal return
            for date_str in unique_dates:
                metrics = {'date': date_str}

                if not df_daily.empty:
                    # Use the most recent record on or before the requested date
                    closest_daily = df_daily[df_daily['trade_date'] <= date_str]
                    if not closest_daily.empty:
                        row = closest_daily.iloc[0]
                        metrics['Price'] = row.get('close')
                        metrics['PE'] = row.get('pe')
                        metrics['PB'] = row.get('pb')
                        metrics['MarketCap'] = row.get('total_mv', 0) * 10000

                if not df_holder.empty:
                    closest_holder = df_holder[df_holder['end_date'] <= date_str]
                    if not closest_holder.empty:
                        metrics['Shareholders'] = closest_holder.iloc[0].get('holder_num')

                results.append(metrics)

        except Exception as e:
            logger.error(f"Error fetching historical metrics for {symbol}: {e}")

        return pd.DataFrame(results)
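
    # Returned frame sketch (column names taken from the code above): one row per
    # requested date with columns 'date', 'Price', 'PE', 'PB', 'MarketCap' and
    # 'Shareholders', filled from the closest available record on or before each date.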

    def get_dividends(self, symbol: str) -> pd.DataFrame:
        ts_code = self._get_ts_code(symbol)
        df_div = self.pro.dividend(ts_code=ts_code, fields='ts_code,end_date,ann_date,div_proc,stk_div,cash_div_tax,cash_div,record_date,ex_date,pay_date,div_listdate,imp_ann_date')

        if not df_div.empty:
            # Save to DB
            self._save_df_to_wide_table('tushare_dividend', df_div, ['ts_code', 'end_date', 'ann_date'])

        if df_div.empty:
            return pd.DataFrame()

        # Legacy processing for return: keep only implemented ('实施') cash dividends
        df_div = df_div[(df_div['div_proc'] == '实施') & (df_div['cash_div'] > 0)]
        if df_div.empty:
            return pd.DataFrame()

        df_div['total_cash_div'] = 0.0

        calced_divs = []
        for index, row in df_div.iterrows():
            ex_date = row['ex_date']
            if not ex_date or pd.isna(ex_date):
                continue

            try:
                time.sleep(0.1)
                df_daily = self.pro.daily_basic(ts_code=ts_code, trade_date=ex_date, fields='total_share')
                if not df_daily.empty:
                    total_share = df_daily.iloc[0]['total_share']
                    # cash_div is per share; total_share is in 10k-share units, hence * 10000
                    total_cash_dividend = (row['cash_div'] * total_share * 10000)
                    calced_divs.append({
                        'year': int(str(row['end_date'])[:4]),
                        'dividends': total_cash_dividend
                    })
            except Exception:
                # Skip this record if the share-count lookup fails
                pass

        if calced_divs:
            df_calc = pd.DataFrame(calced_divs)
            dividends_by_year = df_calc.groupby('year')['dividends'].sum().reset_index()
            dividends_by_year['date'] = dividends_by_year['year'].astype(str) + '1231'
            return dividends_by_year[['date', 'dividends']]

        return pd.DataFrame()

    def get_repurchases(self, symbol: str) -> pd.DataFrame:
        ts_code = self._get_ts_code(symbol)
        df = self.pro.repurchase(ts_code=ts_code)

        if not df.empty:
            if 'ann_date' in df.columns and 'end_date' in df.columns:
                self._save_df_to_wide_table('tushare_repurchase', df, ['ts_code', 'ann_date', 'end_date'])

        if df.empty or 'ann_date' not in df.columns or 'amount' not in df.columns:
            return pd.DataFrame()

        # Legacy processing
        df = df[df['amount'] > 0]
        if df.empty:
            return pd.DataFrame()

        df['year'] = pd.to_datetime(df['ann_date']).dt.year
        repurchases_by_year = df.groupby('year')['amount'].sum().reset_index()
        repurchases_by_year['date_str'] = repurchases_by_year['year'].astype(str) + '1231'
        repurchases_by_year.rename(columns={'amount': 'repurchases', 'date_str': 'date'}, inplace=True)

        return repurchases_by_year[['date', 'repurchases']]
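

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how this client might be exercised, assuming the API key is
# available in the environment; the TUSHARE_API_KEY variable name and the sample
# ts_code below are assumptions for illustration only.
if __name__ == "__main__":
    _api_key = os.getenv("TUSHARE_API_KEY", "")  # assumed env var name
    _client = TushareCnClient(_api_key)
    _income = _client.get_income_statement("600519.SH")  # fetches, persists to PostgreSQL, returns renamed frame
    _metrics = _client.get_market_metrics("600519.SH")
    print(f"income rows: {len(_income)}, metric keys: {sorted(_metrics.keys())}")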