115 lines
5.2 KiB
Python
115 lines
5.2 KiB
Python
import asyncio
import logging
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional

import pandas as pd
import yfinance as yf

from .base import BaseDataProvider
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
class YfinanceProvider(BaseDataProvider):
    """Data provider that reads from Yahoo Finance through ``yfinance``.

    Each public coroutine runs the blocking yfinance calls in the event
    loop's default executor. Any exception raised while fetching is logged
    and turned into an empty result (``None`` or ``[]``) instead of being
    propagated to the caller.
    """

    def _map_stock_code(self, stock_code: str) -> str:
        """Translate a ``.SH``/``.SZ``-suffixed code into a Yahoo ticker.

        Args:
            stock_code: Code such as ``600000.SH`` or ``000001.SZ``.

        Returns:
            The code with a trailing ``.SH`` swapped for ``.SS``. Every other
            code, including ``.SZ`` ones, is returned unchanged.
        """
        if stock_code.endswith('.SH'):
            # Rewrite only the trailing suffix; str.replace would also touch
            # any earlier occurrence of ".SH" inside the code.
            return stock_code[:-len('.SH')] + '.SS'
        # Shenzhen codes keep their ".SZ" suffix: Yahoo uses the same suffix,
        # and a bare numeric code does not identify the exchange.
        return stock_code

    async def get_stock_basic(self, stock_code: str) -> Optional[Dict[str, Any]]:
        """Fetch basic listing information for one stock.

        Args:
            stock_code: Provider-independent code; echoed back as ``ts_code``.

        Returns:
            A dict with the keys ``ts_code``, ``name``, ``area``,
            ``industry``, ``market``, ``exchange`` and ``list_date``
            (``YYYYMMDD`` or ``None``), or ``None`` when the lookup fails.
        """
        def _fetch() -> Optional[Dict[str, Any]]:
            try:
                ticker = yf.Ticker(self._map_stock_code(stock_code))
                info = ticker.info

                # Convert the epoch in UTC so the resulting date does not
                # depend on the timezone of the machine running this code.
                first_trade_epoch = info.get("firstTradeDateEpoch")
                list_date = None
                if first_trade_epoch:
                    list_date = datetime.fromtimestamp(
                        first_trade_epoch, tz=timezone.utc
                    ).strftime('%Y%m%d')

                # Normalize data to match expected format
                return {
                    "ts_code": stock_code,
                    "name": info.get("longName"),
                    "area": info.get("country"),
                    "industry": info.get("industry"),
                    "market": info.get("market"),
                    "exchange": info.get("exchange"),
                    "list_date": list_date,
                }
            except Exception as e:
                logger.error(f"yfinance get_stock_basic failed for {stock_code}: {e}")
                return None

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, _fetch)

    async def get_daily_price(self, stock_code: str, start_date: str, end_date: str) -> List[Dict[str, Any]]:
        """Fetch daily price bars between two dates, both inclusive.

        Args:
            stock_code: Provider-independent code, mapped to a Yahoo ticker.
            start_date: First day wanted, formatted ``YYYYMMDD``.
            end_date: Last day wanted, formatted ``YYYYMMDD``.

        Returns:
            One dict per row of the yfinance history frame, with ``Date``,
            ``Open``, ``High``, ``Low``, ``Close`` and ``Volume`` renamed to
            ``trade_date`` (``YYYYMMDD``), ``open``, ``high``, ``low``,
            ``close`` and ``vol``. Other columns keep their yfinance names.
            An empty list is returned when there is no data or on failure.
        """
        def _fetch() -> List[Dict[str, Any]]:
            try:
                # yfinance date format is YYYY-MM-DD
                start_fmt = datetime.strptime(start_date, '%Y%m%d').strftime('%Y-%m-%d')
                # yfinance treats `end` as exclusive, so move it one day
                # forward to keep the bar for `end_date` itself.
                end_exclusive = datetime.strptime(end_date, '%Y%m%d') + timedelta(days=1)
                end_fmt = end_exclusive.strftime('%Y-%m-%d')

                ticker = yf.Ticker(self._map_stock_code(stock_code))
                df = ticker.history(start=start_fmt, end=end_fmt)

                # No rows is a normal outcome (holiday range, unknown
                # ticker); do not let it fall into the error path below.
                if df is None or df.empty:
                    return []

                # Normalize column names
                df = df.reset_index().rename(columns={
                    "Date": "trade_date",
                    "Open": "open", "High": "high", "Low": "low", "Close": "close",
                    "Volume": "vol",
                })
                df['trade_date'] = df['trade_date'].dt.strftime('%Y%m%d')
                return df.to_dict('records')
            except Exception as e:
                logger.error(f"yfinance get_daily_price failed for {stock_code}: {e}")
                return []

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, _fetch)

    async def get_financial_statements(self, stock_code: str, report_dates: List[str]) -> List[Dict[str, Any]]:
        """Fetch annual statements for the years named in ``report_dates``.

        The income statement, balance sheet and cash-flow frames are merged
        into one row per period end. Rows are kept when the period's year
        matches the first four characters of any entry in ``report_dates``;
        this is a year-level approximation, not an exact date match.

        Args:
            stock_code: Provider-independent code, mapped to a Yahoo ticker.
            report_dates: Date strings whose first four characters are a year.

        Returns:
            One dict per matching period, containing ``end_date``,
            ``end_date_str`` (``YYYYMMDD``) and the statement line items.
            ``Total Revenue``, ``Net Income``, ``Total Assets`` and
            ``Total Liab`` are renamed to ``revenue``, ``net_income``,
            ``total_assets`` and ``total_liabilities`` when present.
            An empty list is returned when there is no data or on failure.
        """
        def _fetch() -> List[Dict[str, Any]]:
            try:
                ticker = yf.Ticker(self._map_stock_code(stock_code))

                # yfinance provides financials quarterly or annually. We fetch
                # annually and match on the year of each period end.
                statements = [
                    ticker.financials.transpose(),
                    ticker.balance_sheet.transpose(),
                    ticker.cash_flow.transpose(),
                ]

                # Skip empty statements: concatenating one alongside real data
                # can degrade the index away from a datetime dtype.
                available = [frame for frame in statements if not frame.empty]
                if not available:
                    return []

                # Combine the data
                df_combined = pd.concat(available, axis=1)
                # A label present in more than one statement would otherwise
                # be collapsed by to_dict() with a warning; keep the last
                # occurrence explicitly, which is the value to_dict() kept.
                df_combined = df_combined.loc[:, ~df_combined.columns.duplicated(keep='last')]

                df_combined.index.name = 'end_date'
                df_combined = df_combined.reset_index()
                df_combined['end_date'] = pd.to_datetime(df_combined['end_date'])
                df_combined['end_date_str'] = df_combined['end_date'].dt.strftime('%Y%m%d')

                # Filter by the years of the requested dates.
                years_to_fetch = {date[:4] for date in report_dates}
                in_requested_year = df_combined['end_date'].dt.year.astype(str).isin(years_to_fetch)

                # Data Normalization (yfinance columns are different from Tushare)
                # This is a sample, a more comprehensive mapping would be required.
                # rename() returns a new frame, so the filtered slice is never
                # modified in place.
                df_combined = df_combined[in_requested_year].rename(columns={
                    "Total Revenue": "revenue",
                    "Net Income": "net_income",
                    "Total Assets": "total_assets",
                    "Total Liab": "total_liabilities",
                }, errors='ignore')

                return df_combined.to_dict('records')

            except Exception as e:
                logger.error(f"yfinance get_financial_statements failed for {stock_code}: {e}")
                return []

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, _fetch)
|