592 lines
20 KiB
Python
592 lines
20 KiB
Python
"""
|
||
数据获取服务层
|
||
负责从不同数据源获取财务数据并存储到数据库
|
||
"""
|
||
import os
|
||
from datetime import datetime
|
||
from typing import Dict, List, Optional
|
||
import pandas as pd
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy import select, and_, text
|
||
from sqlalchemy.orm import selectinload
|
||
|
||
from app.models import Company, DataUpdate, DataSourceAvailability
|
||
from app.schemas import DataCheckResponse, FetchDataRequest
|
||
from app.fetchers.factory import FetcherFactory
|
||
from app.services.bloomberg_service import get_bloomberg_data
|
||
|
||
|
||
import logging
|
||
|
||
# Module-level logger shared by the functions in this service module.
logger = logging.getLogger(__name__)
|
||
|
||
def update_progress_sync(update_id: int, message: str, percentage: int) -> None:
    """Synchronously write progress for a data-fetch job to ``data_updates``.

    Opens a short-lived direct psycopg2 connection instead of using the async
    SQLAlchemy session, so it can be called from synchronous code (such as the
    background fetch) without touching an event loop.

    Errors are logged and swallowed: a failed progress write must not abort
    the fetch that is reporting progress.

    Args:
        update_id: Primary key of the ``data_updates`` row to update.
        message: Text stored in ``progress_message``.
        percentage: Value stored in ``progress_percentage``.
    """
    import psycopg2

    logger.info(f"🔄 [进度更新] ID={update_id}, Msg={message}, Progress={percentage}%")

    try:
        # NOTE(security): these fallback connection values (including a
        # password) are committed to source control; consider requiring them
        # to be supplied via the environment instead.
        conn = psycopg2.connect(
            host=os.getenv('DB_HOST', '192.168.3.195'),
            user=os.getenv('DB_USER', 'value'),
            password=os.getenv('DB_PASSWORD', 'Value609!'),
            dbname=os.getenv('DB_NAME', 'fa3'),
            port=os.getenv('DB_PORT', '5432'),
        )
        try:
            # The cursor is closed automatically when the with-block exits.
            with conn.cursor() as cur:
                cur.execute(
                    """
                    UPDATE data_updates
                    SET progress_message = %s, progress_percentage = %s
                    WHERE id = %s
                    """,
                    (message, percentage, update_id),
                )
                if cur.rowcount == 0:
                    logger.warning(f"⚠️ [进度更新] ID={update_id} 未找到记录!")
            conn.commit()
        finally:
            # Always release the connection, even when execute/commit raised;
            # otherwise every failed progress write would leak a connection.
            conn.close()
    except Exception as e:
        logger.error(f"❌ [进度更新失败] ID={update_id}: {e}", exc_info=True)
|
||
|
||
|
||
|
||
async def check_data_status(
    market: str,
    symbol: str,
    data_source: str,
    db: AsyncSession
) -> DataCheckResponse:
    """Report whether stored data exists for a company from a given data source.

    Looks up the company by ``(market, symbol)``; this function never creates
    a company. It then looks up the most recent ``DataUpdate`` with status
    ``'completed'`` for that company and data source.

    Args:
        market: Market code of the company.
        symbol: Company symbol within the market.
        data_source: Data source name to check.
        db: Async database session.

    Returns:
        DataCheckResponse: ``has_data=False`` when either the company or a
        completed update is missing. Otherwise ``has_data=True`` plus a
        ``last_update`` summary: completion time, data date range and row
        counts.
    """
    # 1. Look up the company (read-only).
    result = await db.execute(
        select(Company).where(
            and_(Company.market == market, Company.symbol == symbol)
        )
    )
    company = result.scalar_one_or_none()

    if not company:
        # No company record found.
        return DataCheckResponse(
            has_data=False,
            company_id=None,
            data_source=data_source,
            message="该数据源暂无该公司数据"
        )

    # 2. Fetch the most recent completed update. NULLS LAST prevents a
    # completed row without a completed_at from being chosen first
    # (PostgreSQL sorts NULLs first under DESC by default).
    result = await db.execute(
        select(DataUpdate)
        .where(
            and_(
                DataUpdate.company_id == company.id,
                DataUpdate.data_source == data_source,
                DataUpdate.status == 'completed'
            )
        )
        .order_by(DataUpdate.completed_at.desc().nulls_last())
        .limit(1)
    )
    latest_update = result.scalar_one_or_none()

    if not latest_update:
        return DataCheckResponse(
            has_data=False,
            company_id=company.id,
            data_source=data_source,
            message="该数据源暂无该公司数据"
        )

    # 3. Build the response.
    completed_at = latest_update.completed_at
    return DataCheckResponse(
        has_data=True,
        company_id=company.id,
        data_source=data_source,
        last_update={
            # Guard against a completed row with no timestamp rather than
            # raising AttributeError on None.isoformat().
            "date": completed_at.isoformat() if completed_at else None,
            "data_start_date": latest_update.data_start_date,
            "data_end_date": latest_update.data_end_date,
            "table_counts": latest_update.row_counts or {}
        }
    )
|
||
|
||
|
||
async def create_or_get_company(
    market: str,
    symbol: str,
    company_name: str,
    db: AsyncSession
) -> Company:
    """Return the company identified by ``(market, symbol)``, creating it if absent.

    When no matching row exists, a new ``Company`` is added with the given
    ``company_name``, committed, and refreshed before it is returned. An
    existing company is returned as-is; its name is not updated.
    """
    lookup = select(Company).where(
        and_(Company.market == market, Company.symbol == symbol)
    )
    existing = (await db.execute(lookup)).scalar_one_or_none()
    if existing:
        return existing

    new_company = Company(market=market, symbol=symbol, company_name=company_name)
    db.add(new_company)
    await db.commit()
    await db.refresh(new_company)
    return new_company
|
||
|
||
|
||
async def create_data_update_record(
    company_id: int,
    data_source: str,
    update_type: str,
    db: AsyncSession
) -> DataUpdate:
    """Insert a new ``DataUpdate`` row in the ``'in_progress'`` state.

    The row is committed immediately and then refreshed, so values loaded
    from the database (such as its id) are available on the returned object.
    """
    record = DataUpdate(
        status='in_progress',
        company_id=company_id,
        data_source=data_source,
        update_type=update_type,
    )
    db.add(record)
    await db.commit()
    await db.refresh(record)
    return record
|
||
|
||
|
||
async def update_data_update_record(
    update_id: int,
    status: str,
    db: AsyncSession,
    completed_at: Optional[datetime] = None,
    error_message: Optional[str] = None,
    data_start_date: Optional[str] = None,
    data_end_date: Optional[str] = None,
    fetched_tables: Optional[List[str]] = None,
    row_counts: Optional[Dict[str, int]] = None
):
    """Update an existing ``DataUpdate`` row, commit, and return it refreshed.

    ``status`` is always written. Each optional field is written only when a
    truthy value is supplied. Passing ``None`` — or an empty string, list or
    dict — therefore leaves the stored value unchanged.

    Raises:
        Whatever ``scalar_one()`` raises when no row (or more than one row)
        has ``id == update_id``.
    """
    lookup = select(DataUpdate).where(DataUpdate.id == update_id)
    record = (await db.execute(lookup)).scalar_one()

    record.status = status

    optional_fields = (
        ('completed_at', completed_at),
        ('error_message', error_message),
        ('data_start_date', data_start_date),
        ('data_end_date', data_end_date),
        ('fetched_tables', fetched_tables),
        ('row_counts', row_counts),
    )
    for attr_name, new_value in optional_fields:
        if new_value:
            setattr(record, attr_name, new_value)

    await db.commit()
    await db.refresh(record)
    return record
|
||
|
||
|
||
def fetch_financial_data_sync(
    company_id: int,
    market: str,
    symbol: str,
    data_source: str,
    update_id: int,
    currency: Optional[str] = None,
    frequency: Optional[str] = "Annual"
):
    """Fetch financial data synchronously (intended to run in a background task).

    Uses the synchronous fetcher returned by ``FetcherFactory`` and reports
    progress via ``update_progress_sync``.

    Args:
        company_id: Company id (currently unused by this function).
        market: Market code, e.g. ``'CN'``.
        symbol: Company symbol. For CN/Tushare a bare 6-digit code gets an
            exchange suffix (``.SH`` / ``.SZ``) added.
        data_source: Data source name; ``'Bloomberg'`` takes the full-sync path.
        update_id: Id of the ``data_updates`` row used for progress reporting.
        currency: Forwarded as ``force_currency`` to Bloomberg
            ``sync_all_data`` when that method accepts it.
        frequency: ``'Annual'`` or ``'Quarterly'``/``'Quarter'``. Forwarded
            to Bloomberg ``sync_all_data`` when that method accepts it.

    Returns:
        dict: On success, ``{'status': 'completed', 'completed_at', 'fetched_tables',
        'row_counts'}``. On any error, ``{'status': 'failed', 'error_message'}``.
    """
    try:
        # 0. Initialise.
        display_freq = "季度" if frequency in ("Quarterly", "Quarter") else "年度"
        update_progress_sync(update_id, f"正在初始化数据获取 ({display_freq})...", 0)

        # Tushare expects CN codes with an exchange suffix (.SH / .SZ).
        formatted_symbol = symbol
        if market.upper() == 'CN' and data_source == 'Tushare':
            if '.' not in symbol:
                if symbol.startswith('6'):
                    formatted_symbol = f"{symbol}.SH"
                elif symbol.startswith(('0', '3', '2')):
                    formatted_symbol = f"{symbol}.SZ"
                else:
                    # Default to Shanghai. NOTE(review): other prefixes (e.g.
                    # Beijing exchange codes) may need a different suffix — confirm.
                    formatted_symbol = f"{symbol}.SH"
            logger.info(f"📝 [数据获取] 股票代码格式化: {symbol} -> {formatted_symbol}")

        # 1. Obtain the fetcher for this market / data source.
        fetcher = FetcherFactory.get_fetcher(market, data_source)
        update_progress_sync(update_id, "数据源连接成功", 10)

        # 2. Trigger the data sync.
        if data_source == 'Bloomberg':
            # Bloomberg: a single full sync covers everything.
            update_progress_sync(update_id, "正在连接 Bloomberg 终端...", 20)

            def progress_callback(msg, pct):
                # Map the fetcher's internal 0-100 progress onto the overall 20-90 band.
                update_progress_sync(update_id, msg, 20 + int(pct * 0.7))

            if hasattr(fetcher, 'sync_all_data'):
                import inspect

                # Only pass the keyword arguments this sync_all_data accepts.
                params = inspect.signature(fetcher.sync_all_data).parameters
                kwargs = {}
                if 'progress_callback' in params:
                    kwargs['progress_callback'] = progress_callback
                if 'force_currency' in params:
                    kwargs['force_currency'] = currency
                if 'frequency' in params:
                    kwargs['frequency'] = frequency

                fetcher.sync_all_data(formatted_symbol, **kwargs)
            else:
                # Legacy fallback for fetchers without sync_all_data.
                fetcher.get_income_statement(formatted_symbol)

            update_progress_sync(update_id, "Bloomberg 数据同步完成", 100)

        else:
            # Other data sources: fetch each statement type individually.
            update_progress_sync(update_id, "正在获取财务报表数据...", 30)
            fetcher.get_income_statement(formatted_symbol)
            fetcher.get_balance_sheet(formatted_symbol)
            fetcher.get_cash_flow(formatted_symbol)
            fetcher.get_market_metrics(formatted_symbol)

            # Auxiliary data is best-effort. Fetch each item independently so
            # one failure does not skip the others, and log rather than hide
            # the error. Use the same formatted symbol as the main statements.
            for aux_method in ('get_dividends', 'get_repurchases', 'get_employee_count'):
                try:
                    getattr(fetcher, aux_method)(formatted_symbol)
                except Exception as aux_err:
                    logger.warning(
                        f"⚠️ [数据获取] 辅助数据 {aux_method} 获取失败 ({formatted_symbol}): {aux_err}"
                    )

            update_progress_sync(update_id, "数据同步完成", 100)

        return {
            'status': 'completed',
            'completed_at': datetime.now(),
            'fetched_tables': (
                ['unified_data'] if data_source == 'Bloomberg'
                else ['income_statement', 'balance_sheet', 'cash_flow', 'daily_basic']
            ),
            # Row counts are no longer collected.
            'row_counts': {}
        }

    except Exception as e:
        logger.error(
            f"❌ [数据获取失败] update_id={update_id}, {market}/{symbol} ({data_source}): {e}",
            exc_info=True,
        )
        update_progress_sync(update_id, f"数据获取失败: {str(e)}", 0)
        return {
            'status': 'failed',
            'error_message': str(e)
        }
|
||
|
||
|
||
async def _fetch_tushare_rows(
    db: AsyncSession,
    table_name: str,
    symbol_prefix: str,
    order_by: Optional[str] = None,
    limit: int = 50,
) -> List[Dict]:
    """Return up to ``limit`` rows of ``table_name`` whose ts_code starts with ``symbol_prefix`` ([] on any error)."""
    try:
        # Run the lookup inside a SAVEPOINT. On PostgreSQL a failed statement
        # aborts the whole transaction, which would otherwise make every later
        # table query in this session fail as well.
        async with db.begin_nested():
            exists = await db.execute(
                text("SELECT to_regclass(:table_name)"),
                {"table_name": f"public.{table_name}"},
            )
            if not exists.scalar():
                return []

            # table_name / order_by are interpolated into the SQL, so they must
            # only ever be trusted constants. The symbol is a bound parameter.
            query = f"SELECT * FROM {table_name} WHERE ts_code LIKE :ts_code"
            if order_by:
                query += f" ORDER BY {order_by} DESC"
            query += f" LIMIT {int(limit)}"

            # Prefix match, e.g. "300750%" matches the stored "300750.SZ".
            result = await db.execute(text(query), {"ts_code": f"{symbol_prefix}%"})

            rows = []
            for row in result.mappings():
                item = dict(row)
                # Convert date/datetime-like values to ISO strings.
                for key, value in item.items():
                    if hasattr(value, 'isoformat'):
                        item[key] = value.isoformat()
                rows.append(item)
            return rows
    except Exception as e:
        logger.error(f"Error fetching from {table_name}: {e}", exc_info=True)
        return []


async def get_financial_data_from_db(
    company_id: int,
    data_source: str,
    db: AsyncSession,
    frequency: str = "Annual"
) -> Dict:
    """Read a company's stored financial data from the database.

    Args:
        company_id: Company id.
        data_source: Data source (iFinD, Bloomberg, Tushare).
        db: Async database session.
        frequency: ``'Annual'`` or ``'Quarterly'``; only forwarded to the
            Bloomberg reader.

    Returns:
        dict: Company info plus one list per dataset (``income_statement``,
        ``balance_sheet``, ``cash_flow``, ``daily_basic``, ``dividend``,
        ``repurchase``, ``employee``, ``unified_data``). Datasets that are
        not read for the given source stay empty.

    Raises:
        Whatever ``scalar_one()`` raises if no company has ``id == company_id``.
    """
    result = await db.execute(select(Company).where(Company.id == company_id))
    company = result.scalar_one()

    response_data = {
        "company": {
            "id": company.id,
            "market": company.market,
            "symbol": company.symbol,
            "company_name": company.company_name,
            "created_at": company.created_at,
            "updated_at": company.updated_at
        },
        "data_source": data_source,
        "income_statement": [],
        "balance_sheet": [],
        "cash_flow": [],
        "daily_basic": [],
        "dividend": [],
        "repurchase": [],
        "employee": [],
        "unified_data": []
    }

    if data_source == 'Tushare' and company.market == 'CN':
        # Tushare rows are stored by ts_code (e.g. 000001.SZ), while
        # company.symbol may lack the exchange suffix, so match on the prefix.
        # Read the symbol once, up front, so nothing touches the ORM instance
        # between the per-table savepoints.
        symbol = company.symbol

        # (response key, table name, ORDER BY column)
        tushare_tables = (
            ("income_statement", "tushare_income_statement", "end_date"),
            ("balance_sheet", "tushare_balance_sheet", "end_date"),
            ("cash_flow", "tushare_cash_flow", "end_date"),
            ("daily_basic", "tushare_daily_basic", "trade_date"),
            ("dividend", "tushare_dividend", "end_date"),
            ("repurchase", "tushare_repurchase", "ann_date"),
        )
        for key, table_name, order_column in tushare_tables:
            response_data[key] = await _fetch_tushare_rows(db, table_name, symbol, order_column)
        # Employee data is not read from the DB for Tushare; "employee" stays [].

    elif data_source == 'Bloomberg':
        try:
            # Bloomberg data is read by the dedicated Bloomberg service.
            response_data["unified_data"] = await get_bloomberg_data(
                company, db, frequency=frequency
            )
        except Exception as e:
            logger.error(f"Error fetching Bloomberg data from stockcard: {e}", exc_info=True)

    return response_data
|
||
|
||
|
||
def get_available_data_sources(market: str) -> List[Dict]:
    """List every known data source and whether it supports ``market``.

    Args:
        market: Market code to check, compared exactly (case-sensitive)
            against each source's supported markets.

    Returns:
        One dict per source (iFinD, Tushare, Bloomberg, in that order) with
        keys ``source``, ``available``, ``description`` and
        ``supported_markets``. Fresh lists are built on every call.
    """
    catalogue = (
        ('iFinD', 'iFinD (同花顺) 数据源', ['HK', 'JP', 'US', 'VN']),
        ('Tushare', 'Tushare 数据源', ['CN']),
        ('Bloomberg', 'Bloomberg 数据源', ['CN', 'HK', 'JP', 'US', 'VN']),
    )
    return [
        {
            'source': name,
            'available': market in markets,
            'description': description,
            'supported_markets': markets,
        }
        for name, description, markets in catalogue
    ]
|
||
|
||
async def get_recent_companies(
    data_source: Optional[str],
    db: AsyncSession,
    limit: int = 20
) -> List[Dict]:
    """Return recently updated companies, newest first, one entry per company.

    Args:
        data_source: Optional data-source filter. If falsy, completed updates
            from all data sources are considered.
        db: Async database session.
        limit: Maximum number of companies to return.

    Returns:
        A list of dicts with ``market``, ``symbol``, ``company_name``,
        ``data_source`` and ``last_update``. ``data_source`` and
        ``last_update`` come from the company's newest completed update in
        the scanned rows; ``last_update`` is formatted as
        ``'%Y-%m-%d %H:%M'``, or ``""`` when there is no timestamp.

    Note:
        Only the newest ``limit * 5`` completed updates are scanned. If a few
        companies account for most recent updates, fewer than ``limit``
        companies may be returned.
    """
    stmt = (
        select(Company, DataUpdate.completed_at, DataUpdate.data_source)
        .join(DataUpdate, Company.id == DataUpdate.company_id)
        .where(DataUpdate.status == 'completed')
    )
    if data_source:
        stmt = stmt.where(DataUpdate.data_source == data_source)
    # Over-fetch because the same company can appear in many update rows.
    stmt = stmt.order_by(DataUpdate.completed_at.desc()).limit(limit * 5)

    rows = (await db.execute(stmt)).all()

    recent = []
    seen = set()
    for company, completed_at, source in rows:
        if company.id in seen:
            continue
        seen.add(company.id)
        recent.append({
            "market": company.market,
            "symbol": company.symbol,
            "company_name": company.company_name,
            "data_source": source,
            "last_update": completed_at.strftime('%Y-%m-%d %H:%M') if completed_at else ""
        })
        if len(recent) >= limit:
            break

    return recent
|
||
|
||
async def search_companies(
    query: str,
    db: AsyncSession,
    limit: int = 10
) -> List[Dict]:
    """Search the local company table by symbol or name.

    Matching is a case-insensitive substring match (ILIKE). Surrounding
    whitespace in ``query`` is ignored, and ``%`` / ``_`` in the query are
    matched literally instead of acting as SQL wildcards.

    Args:
        query: Search keyword (symbol or company name).
        db: Async database session.
        limit: Maximum number of results.

    Returns:
        A list of ``{"market", "symbol", "company_name", "id"}`` dicts,
        ordered by symbol. Empty when ``query`` is empty or whitespace-only.
    """
    if not query:
        return []
    term = query.strip()
    if not term:
        return []

    # Escape LIKE wildcards so user input is matched literally. Otherwise
    # "_" alone would match every company. "/" is the escape character here.
    escaped = term.replace("/", "//").replace("%", "/%").replace("_", "/_")
    search_term = f"%{escaped}%"

    stmt = (
        select(Company)
        .where(
            Company.symbol.ilike(search_term, escape="/")
            | Company.company_name.ilike(search_term, escape="/")
        )
        # Deterministic ordering so LIMIT returns a stable subset.
        .order_by(Company.symbol, Company.id)
        .limit(limit)
    )

    result = await db.execute(stmt)
    companies = result.scalars().all()

    return [
        {
            "market": c.market,
            "symbol": c.symbol,
            "company_name": c.company_name,
            "id": c.id
        }
        for c in companies
    ]
|