FA3-Datafetch/backend/app/services/data_fetcher_service.py

584 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
数据获取服务层
负责从不同数据源获取财务数据并存储到数据库
"""
import os
from datetime import datetime
from typing import Dict, List, Optional
import pandas as pd
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text
from sqlalchemy.orm import selectinload
from app.models import Company, DataUpdate, DataSourceAvailability
from app.schemas import DataCheckResponse, FetchDataRequest
from app.fetchers.factory import FetcherFactory
from app.services.bloomberg_service import get_bloomberg_data
import logging
logger = logging.getLogger(__name__)
def update_progress_sync(update_id: int, message: str, percentage: int):
    """Synchronously write fetch progress onto a ``data_updates`` row.

    Uses a direct psycopg2 connection instead of the async SQLAlchemy session
    so it can be called from synchronous background code without event-loop
    conflicts.

    Database errors are logged and swallowed (best-effort), so a failed
    progress write never aborts the fetch itself. An ``ImportError`` for
    psycopg2 is not caught and propagates to the caller.

    Args:
        update_id: Primary key of the ``data_updates`` row to update.
        message: Text stored in ``progress_message``.
        percentage: Value stored in ``progress_percentage``.
    """
    from contextlib import closing

    import psycopg2

    logger.info(f"🔄 [进度更新] ID={update_id}, Msg={message}, Progress={percentage}%")
    try:
        # NOTE(review): the fallback host/user/password are credentials
        # committed to source control; they should come only from the
        # environment or a secret store. Kept so existing deployments work.
        conn = psycopg2.connect(
            host=os.getenv('DB_HOST', '192.168.3.195'),
            user=os.getenv('DB_USER', 'value'),
            password=os.getenv('DB_PASSWORD', 'Value609!'),
            dbname=os.getenv('DB_NAME', 'fa3'),
            port=os.getenv('DB_PORT', '5432')
        )
        # closing() guarantees cursor and connection are released even when
        # execute()/commit() raises (previously the connection leaked).
        # psycopg2's own connection context manager does not close it.
        with closing(conn), closing(conn.cursor()) as cur:
            cur.execute(
                """
                UPDATE data_updates
                SET progress_message = %s, progress_percentage = %s
                WHERE id = %s
                """,
                (message, percentage, update_id)
            )
            if cur.rowcount == 0:
                logger.warning(f"⚠️ [进度更新] ID={update_id} 未找到记录!")
            conn.commit()
    except Exception as e:
        logger.error(f"❌ [进度更新失败] ID={update_id}: {e}", exc_info=True)
async def check_data_status(
    market: str,
    symbol: str,
    data_source: str,
    db: AsyncSession
) -> DataCheckResponse:
    """Report whether a completed fetch exists for a company and data source.

    Args:
        market: Market code of the company (e.g. ``"CN"``).
        symbol: Symbol as stored in ``Company.symbol``.
        data_source: Data source name (e.g. ``"Tushare"``, ``"Bloomberg"``).
        db: Async database session.

    Returns:
        DataCheckResponse: ``has_data=False`` with a message when the company
        or any completed update is missing; otherwise ``has_data=True`` with
        details of the most recent completed update in ``last_update``.
    """
    # 1. Look up the company (read-only: no record is created here).
    result = await db.execute(
        select(Company).where(
            and_(Company.market == market, Company.symbol == symbol)
        )
    )
    company = result.scalar_one_or_none()
    if not company:
        return DataCheckResponse(
            has_data=False,
            company_id=None,
            data_source=data_source,
            message="该数据源暂无该公司数据"
        )

    # 2. Most recent completed update. PostgreSQL sorts NULLs first for DESC,
    #    so without nulls_last() a completed row lacking completed_at would be
    #    chosen over properly timestamped rows.
    result = await db.execute(
        select(DataUpdate)
        .where(
            and_(
                DataUpdate.company_id == company.id,
                DataUpdate.data_source == data_source,
                DataUpdate.status == 'completed'
            )
        )
        .order_by(DataUpdate.completed_at.desc().nulls_last())
        .limit(1)
    )
    latest_update = result.scalar_one_or_none()
    if not latest_update:
        return DataCheckResponse(
            has_data=False,
            company_id=company.id,
            data_source=data_source,
            message="该数据源暂无该公司数据"
        )

    # 3. Build the response. Guard against a missing completion time instead
    #    of crashing with AttributeError on None.isoformat().
    completed_at = latest_update.completed_at
    return DataCheckResponse(
        has_data=True,
        company_id=company.id,
        data_source=data_source,
        last_update={
            "date": completed_at.isoformat() if completed_at else None,
            "data_start_date": latest_update.data_start_date,
            "data_end_date": latest_update.data_end_date,
            "table_counts": latest_update.row_counts or {}
        }
    )
async def create_or_get_company(
    market: str,
    symbol: str,
    company_name: str,
    db: AsyncSession
) -> Company:
    """Return the company identified by (market, symbol), inserting it if absent.

    An existing row is returned unchanged; ``company_name`` is used only when
    a new row is created. A new row is committed and then refreshed so that
    database-generated fields (such as ``id``) are loaded.
    """
    lookup = select(Company).where(
        and_(Company.market == market, Company.symbol == symbol)
    )
    existing = (await db.execute(lookup)).scalar_one_or_none()
    if existing:
        return existing

    new_company = Company(market=market, symbol=symbol, company_name=company_name)
    db.add(new_company)
    await db.commit()
    await db.refresh(new_company)
    return new_company
async def create_data_update_record(
    company_id: int,
    data_source: str,
    update_type: str,
    db: AsyncSession
) -> DataUpdate:
    """Insert a new ``DataUpdate`` row whose status is ``'in_progress'``.

    The row is committed immediately and refreshed, so the returned instance
    carries database-generated fields such as its primary key.
    """
    record = DataUpdate(
        status='in_progress',
        company_id=company_id,
        data_source=data_source,
        update_type=update_type,
    )
    db.add(record)
    await db.commit()
    await db.refresh(record)
    return record
async def update_data_update_record(
    update_id: int,
    status: str,
    db: AsyncSession,
    completed_at: Optional[datetime] = None,
    error_message: Optional[str] = None,
    data_start_date: Optional[str] = None,
    data_end_date: Optional[str] = None,
    fetched_tables: Optional[List[str]] = None,
    row_counts: Optional[Dict[str, int]] = None
):
    """Set the status, and any supplied details, of an existing ``DataUpdate``.

    ``status`` is always written. Each optional field is written only when a
    truthy value is passed. ``None`` leaves the stored value untouched, and so
    do empty values such as ``""``, ``[]`` or ``{}``.

    Raises:
        sqlalchemy.exc.NoResultFound: if no row has ``id == update_id``
            (raised by ``scalar_one()``).

    Returns:
        The committed and refreshed ``DataUpdate`` instance.
    """
    record = (
        await db.execute(select(DataUpdate).where(DataUpdate.id == update_id))
    ).scalar_one()
    record.status = status

    # Attribute name -> new value; applied in this order, skipping falsy values.
    optional_fields = {
        'completed_at': completed_at,
        'error_message': error_message,
        'data_start_date': data_start_date,
        'data_end_date': data_end_date,
        'fetched_tables': fetched_tables,
        'row_counts': row_counts,
    }
    for attr_name, value in optional_fields.items():
        if value:
            setattr(record, attr_name, value)

    await db.commit()
    await db.refresh(record)
    return record
def _format_fetch_symbol(market: str, symbol: str, data_source: str) -> str:
    """Return ``symbol`` in the form the selected fetcher expects.

    Only Tushare on the CN market is rewritten, and only when the code has no
    exchange suffix yet. Codes starting with ``6`` get ``.SH``; codes starting
    with ``0``, ``3`` or ``2`` get ``.SZ``; anything else defaults to ``.SH``.
    """
    if market.upper() != 'CN' or data_source != 'Tushare' or '.' in symbol:
        return symbol
    if symbol.startswith('6'):
        return f"{symbol}.SH"
    if symbol.startswith(('0', '3', '2')):
        return f"{symbol}.SZ"
    # NOTE(review): every other prefix falls back to Shanghai. Beijing Stock
    # Exchange codes would presumably need ".BJ" -- confirm with Tushare.
    return f"{symbol}.SH"


def _sync_bloomberg(fetcher, formatted_symbol: str, update_id: int,
                    currency: Optional[str]) -> None:
    """Run a full Bloomberg sync, reporting progress within the 20-100% band."""
    import inspect

    update_progress_sync(update_id, "正在连接 Bloomberg 终端...", 20)

    def progress_callback(msg, pct):
        # Map the fetcher's internal 0-100 progress onto overall 20-90.
        update_progress_sync(update_id, msg, 20 + int(pct * 0.7))

    if hasattr(fetcher, 'sync_all_data'):
        # Pass only the optional keyword arguments this fetcher declares; each
        # one is checked independently so force_currency is not dropped when
        # the fetcher has no progress_callback parameter.
        params = inspect.signature(fetcher.sync_all_data).parameters
        kwargs = {}
        if 'progress_callback' in params:
            kwargs['progress_callback'] = progress_callback
        if 'force_currency' in params:
            kwargs['force_currency'] = currency
        fetcher.sync_all_data(formatted_symbol, **kwargs)
    else:
        # Legacy fetchers that predate sync_all_data().
        fetcher.get_income_statement(formatted_symbol)
    update_progress_sync(update_id, "Bloomberg 数据同步完成", 100)


def _fetch_standard(fetcher, formatted_symbol: str, update_id: int) -> None:
    """Fetch statements and market metrics, then best-effort auxiliary data."""
    update_progress_sync(update_id, "正在获取财务报表数据...", 30)
    fetcher.get_income_statement(formatted_symbol)
    fetcher.get_balance_sheet(formatted_symbol)
    fetcher.get_cash_flow(formatted_symbol)
    fetcher.get_market_metrics(formatted_symbol)

    # Auxiliary data is optional. Each item is fetched independently so one
    # failure (including a fetcher lacking the method) does not skip the rest,
    # and failures are logged instead of silently ignored. The same formatted
    # symbol as the main statements is used for consistency.
    for method_name in ('get_dividends', 'get_repurchases', 'get_employee_count'):
        try:
            getattr(fetcher, method_name)(formatted_symbol)
        except Exception as e:
            logger.warning(f"⚠️ [数据获取] 辅助数据 {method_name} 获取失败: {e}")
    update_progress_sync(update_id, "数据同步完成", 100)


def fetch_financial_data_sync(
    company_id: int,
    market: str,
    symbol: str,
    data_source: str,
    update_id: int,
    currency: Optional[str] = None
):
    """Synchronously fetch financial data (intended to run in a background task).

    Uses the synchronous fetcher returned by ``FetcherFactory`` and reports
    progress on the ``data_updates`` row ``update_id``.

    Args:
        company_id: Company primary key (currently unused; kept for callers).
        market: Market code, e.g. ``"CN"``.
        symbol: Symbol as entered; see ``_format_fetch_symbol``.
        data_source: ``"Bloomberg"`` or another fetcher-backed source.
        update_id: ``data_updates`` row used for progress reporting.
        currency: Passed as ``force_currency`` to Bloomberg's
            ``sync_all_data`` when that parameter is supported.

    Returns:
        On success: ``{'status': 'completed', 'completed_at', 'fetched_tables',
        'row_counts': {}}``. On any error: ``{'status': 'failed',
        'error_message'}``. Exceptions are not propagated.
    """
    try:
        update_progress_sync(update_id, "正在初始化数据获取...", 0)

        formatted_symbol = _format_fetch_symbol(market, symbol, data_source)
        if formatted_symbol != symbol:
            logger.info(f"📝 [数据获取] 股票代码格式化: {symbol} -> {formatted_symbol}")

        fetcher = FetcherFactory.get_fetcher(market, data_source)
        update_progress_sync(update_id, "数据源连接成功", 10)

        if data_source == 'Bloomberg':
            # Bloomberg needs a single full sync into the unified table.
            _sync_bloomberg(fetcher, formatted_symbol, update_id, currency)
            fetched_tables = ['unified_data']
        else:
            _fetch_standard(fetcher, formatted_symbol, update_id)
            fetched_tables = ['income_statement', 'balance_sheet', 'cash_flow', 'daily_basic']

        return {
            'status': 'completed',
            'completed_at': datetime.now(),
            'fetched_tables': fetched_tables,
            # Row counts are no longer collected.
            'row_counts': {}
        }
    except Exception as e:
        logger.error(f"❌ [数据获取失败] ID={update_id}: {e}", exc_info=True)
        update_progress_sync(update_id, f"数据获取失败: {str(e)}", 0)
        return {
            'status': 'failed',
            'error_message': str(e)
        }
async def _fetch_tushare_rows(
    db: AsyncSession,
    table_name: str,
    symbol_code: str,
    order_by: Optional[str] = None,
    limit: int = 50,
) -> List[Dict]:
    """Read up to ``limit`` rows whose ``ts_code`` starts with ``symbol_code``.

    Tushare stores suffixed codes (e.g. ``000001.SZ``) while ``symbol_code``
    may be a bare code, hence the prefix ``LIKE`` match. Values that have an
    ``isoformat`` method (dates/datetimes) are converted to ISO strings.

    Returns ``[]`` when the table does not exist or the query fails; failures
    are logged.

    ``table_name`` and ``order_by`` are interpolated into the SQL text, so they
    must be trusted identifiers (callers pass hard-coded constants only).
    """
    try:
        # Run inside a SAVEPOINT. On PostgreSQL a failed statement aborts the
        # whole transaction, which previously made every later read on this
        # session fail as well (each silently returning []).
        async with db.begin_nested():
            exists = await db.execute(
                text("SELECT to_regclass(:table_name)"),
                {"table_name": f"public.{table_name}"},
            )
            if not exists.scalar():
                return []

            query = f"SELECT * FROM {table_name} WHERE ts_code LIKE :ts_code"
            if order_by:
                query += f" ORDER BY {order_by} DESC"
            query += " LIMIT :row_limit"
            result = await db.execute(
                text(query),
                {"ts_code": f"{symbol_code}%", "row_limit": limit},
            )
            return [
                {
                    key: value.isoformat() if hasattr(value, 'isoformat') else value
                    for key, value in row.items()
                }
                for row in result.mappings()
            ]
    except Exception as e:
        logger.error(f"Error fetching from {table_name}: {e}")
        return []


async def get_financial_data_from_db(
    company_id: int,
    data_source: str,
    db: AsyncSession
) -> Dict:
    """Read stored financial data for a company from the database.

    Args:
        company_id: Company primary key.
        data_source: Data source (iFinD, Bloomberg, Tushare).
        db: Async database session.

    Returns:
        A dict with company info, ``data_source`` and one list per dataset.
        Datasets the source does not provide stay as empty lists.

    Raises:
        sqlalchemy.exc.NoResultFound: if ``company_id`` does not exist.
    """
    result = await db.execute(
        select(Company).where(Company.id == company_id)
    )
    company = result.scalar_one()

    response_data = {
        "company": {
            "id": company.id,
            "market": company.market,
            "symbol": company.symbol,
            "company_name": company.company_name,
            "created_at": company.created_at,
            "updated_at": company.updated_at
        },
        "data_source": data_source,
        "income_statement": [],
        "balance_sheet": [],
        "cash_flow": [],
        "daily_basic": [],
        "dividend": [],
        "repurchase": [],
        "employee": [],
        "unified_data": []
    }

    if data_source == 'Tushare' and company.market == 'CN':
        # Read the symbol once up front. Touching ORM attributes after a
        # savepoint rollback could trigger a lazy load, which fails under
        # async.
        symbol = company.symbol
        # (response key, table name, column used for newest-first ordering)
        tushare_tables = (
            ("income_statement", "tushare_income_statement", "end_date"),
            ("balance_sheet", "tushare_balance_sheet", "end_date"),
            ("cash_flow", "tushare_cash_flow", "end_date"),
            ("daily_basic", "tushare_daily_basic", "trade_date"),
            ("dividend", "tushare_dividend", "end_date"),
            ("repurchase", "tushare_repurchase", "ann_date"),
        )
        # NOTE(review): "employee" stays empty. The previous mapping pointed
        # it at tushare_stock_company, but that table was never queried.
        for key, table_name, order_col in tushare_tables:
            response_data[key] = await _fetch_tushare_rows(db, table_name, symbol, order_col)
    elif data_source == 'Bloomberg':
        try:
            # Bloomberg data is read through the dedicated Bloomberg service.
            response_data["unified_data"] = await get_bloomberg_data(company, db)
        except Exception as e:
            logger.error(f"Error fetching Bloomberg data from stockcard: {e}", exc_info=True)
    return response_data
def get_available_data_sources(market: str) -> List[Dict]:
    """List every known data source and whether it serves ``market``.

    Args:
        market: Market code, compared case-sensitively (e.g. ``"CN"``).

    Returns:
        One dict per source, in the order iFinD, Tushare, Bloomberg, with keys
        ``source``, ``available``, ``description`` and ``supported_markets``.
    """
    # (name, description, supported markets). Built per call so each result
    # gets its own market lists.
    catalog = (
        ('iFinD', 'iFinD (同花顺) 数据源', ['HK', 'JP', 'US', 'VN']),
        ('Tushare', 'Tushare 数据源', ['CN']),
        ('Bloomberg', 'Bloomberg 数据源', ['CN', 'HK', 'JP', 'US', 'VN']),
    )
    return [
        {
            'source': name,
            'available': market in markets,
            'description': description,
            'supported_markets': markets,
        }
        for name, description, markets in catalog
    ]
async def get_recent_companies(
    data_source: Optional[str],
    db: AsyncSession,
    limit: int = 20
) -> List[Dict]:
    """Return the most recently updated companies, one entry per company.

    Args:
        data_source: Optional data-source filter. When ``None`` (or empty),
            recent completed updates from every source are considered.
        db: Async database session.
        limit: Maximum number of companies to return.

    Returns:
        Dicts with ``market``, ``symbol``, ``company_name``, ``data_source``
        and ``last_update`` (``"YYYY-MM-DD HH:MM"``, or ``""`` when unknown),
        newest first.

    Note:
        Only the newest ``limit * 5`` completed updates are scanned. Fewer than
        ``limit`` companies can therefore be returned when a few companies
        account for most recent updates.
    """
    query = (
        select(Company, DataUpdate.completed_at, DataUpdate.data_source)
        .join(DataUpdate, Company.id == DataUpdate.company_id)
        .where(DataUpdate.status == 'completed')
        # PostgreSQL sorts NULLs first for DESC. Push updates without a
        # completion time to the end so they don't crowd out real ones.
        .order_by(DataUpdate.completed_at.desc().nulls_last())
        .limit(limit * 5)  # over-fetch: one company may have several updates
    )
    if data_source:
        query = query.where(DataUpdate.data_source == data_source)

    rows = (await db.execute(query)).all()

    seen_ids = set()
    companies = []
    for company, completed_at, source in rows:
        if company.id in seen_ids:
            continue
        seen_ids.add(company.id)
        companies.append({
            "market": company.market,
            "symbol": company.symbol,
            "company_name": company.company_name,
            "data_source": source,
            "last_update": completed_at.strftime('%Y-%m-%d %H:%M') if completed_at else ""
        })
        if len(companies) >= limit:
            break
    return companies
async def search_companies(
    query: str,
    db: AsyncSession,
    limit: int = 10
) -> List[Dict]:
    """Search local companies whose symbol or name contains ``query``.

    Matching is case-insensitive (``ILIKE``). Percent, underscore and
    backslash characters in ``query`` are escaped so they match literally.
    Previously they acted as SQL wildcards, so e.g. ``"_"`` matched every
    company.

    Args:
        query: Search keyword (symbol or company name fragment).
        db: Async database session.
        limit: Maximum number of results.

    Returns:
        Dicts with ``market``, ``symbol``, ``company_name`` and ``id``; an
        empty list for an empty query.
    """
    if not query:
        return []

    # Escape the escape character first, then the LIKE wildcards.
    escaped = query.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    search_term = f"%{escaped}%"

    stmt = (
        select(Company)
        .where(
            Company.symbol.ilike(search_term, escape="\\")
            | Company.company_name.ilike(search_term, escape="\\")
        )
        .limit(limit)
    )
    result = await db.execute(stmt)
    return [
        {
            "market": c.market,
            "symbol": c.symbol,
            "company_name": c.company_name,
            "id": c.id
        }
        for c in result.scalars().all()
    ]