""" 数据获取服务层 负责从不同数据源获取财务数据并存储到数据库 """ import os from datetime import datetime from typing import Dict, List, Optional import pandas as pd from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, and_, text from sqlalchemy.orm import selectinload from app.models import Company, DataUpdate, DataSourceAvailability from app.schemas import DataCheckResponse, FetchDataRequest from app.fetchers.factory import FetcherFactory from app.services.bloomberg_service import get_bloomberg_data import logging logger = logging.getLogger(__name__) def update_progress_sync(update_id: int, message: str, percentage: int): """同步更新数据获取进度 - 使用psycopg2直连避免event loop冲突""" import psycopg2 import os logger.info(f"🔄 [进度更新] ID={update_id}, Msg={message}, Progress={percentage}%") try: conn = psycopg2.connect( host=os.getenv('DB_HOST', '192.168.3.195'), user=os.getenv('DB_USER', 'value'), password=os.getenv('DB_PASSWORD', 'Value609!'), dbname=os.getenv('DB_NAME', 'fa3'), port=os.getenv('DB_PORT', '5432') ) cur = conn.cursor() cur.execute( """ UPDATE data_updates SET progress_message = %s, progress_percentage = %s WHERE id = %s """, (message, percentage, update_id) ) if cur.rowcount == 0: logger.warning(f"⚠️ [进度更新] ID={update_id} 未找到记录!") conn.commit() cur.close() conn.close() except Exception as e: logger.error(f"❌ [进度更新失败] ID={update_id}: {e}", exc_info=True) async def check_data_status( market: str, symbol: str, data_source: str, db: AsyncSession ) -> DataCheckResponse: """ 检查指定公司和数据源的数据状态 Returns: DataCheckResponse: 包含数据状态信息 """ # 1. 查询或创建公司记录 result = await db.execute( select(Company).where( and_(Company.market == market, Company.symbol == symbol) ) ) company = result.scalar_one_or_none() if not company: # 没有找到公司记录 return DataCheckResponse( has_data=False, company_id=None, data_source=data_source, message=f"该数据源暂无该公司数据" ) # 2. 查询最近一次成功的数据更新 result = await db.execute( select(DataUpdate) .where( and_( DataUpdate.company_id == company.id, DataUpdate.data_source == data_source, DataUpdate.status == 'completed' ) ) .order_by(DataUpdate.completed_at.desc()) .limit(1) ) latest_update = result.scalar_one_or_none() if not latest_update: return DataCheckResponse( has_data=False, company_id=company.id, data_source=data_source, message="该数据源暂无该公司数据" ) # 3. 
    return DataCheckResponse(
        has_data=True,
        company_id=company.id,
        data_source=data_source,
        last_update={
            "date": latest_update.completed_at.isoformat(),
            "data_start_date": latest_update.data_start_date,
            "data_end_date": latest_update.data_end_date,
            "table_counts": latest_update.row_counts or {}
        }
    )


async def create_or_get_company(
    market: str,
    symbol: str,
    company_name: str,
    db: AsyncSession
) -> Company:
    """Create the company record if it does not exist, otherwise return the existing one."""
    result = await db.execute(
        select(Company).where(
            and_(Company.market == market, Company.symbol == symbol)
        )
    )
    company = result.scalar_one_or_none()

    if not company:
        company = Company(
            market=market,
            symbol=symbol,
            company_name=company_name
        )
        db.add(company)
        await db.commit()
        await db.refresh(company)

    return company


async def create_data_update_record(
    company_id: int,
    data_source: str,
    update_type: str,
    db: AsyncSession
) -> DataUpdate:
    """Create a data update record."""
    data_update = DataUpdate(
        company_id=company_id,
        data_source=data_source,
        update_type=update_type,
        status='in_progress'
    )
    db.add(data_update)
    await db.commit()
    await db.refresh(data_update)
    return data_update


async def update_data_update_record(
    update_id: int,
    status: str,
    db: AsyncSession,
    completed_at: Optional[datetime] = None,
    error_message: Optional[str] = None,
    data_start_date: Optional[str] = None,
    data_end_date: Optional[str] = None,
    fetched_tables: Optional[List[str]] = None,
    row_counts: Optional[Dict[str, int]] = None
):
    """Update an existing data update record."""
    result = await db.execute(
        select(DataUpdate).where(DataUpdate.id == update_id)
    )
    data_update = result.scalar_one()

    data_update.status = status
    if completed_at:
        data_update.completed_at = completed_at
    if error_message:
        data_update.error_message = error_message
    if data_start_date:
        data_update.data_start_date = data_start_date
    if data_end_date:
        data_update.data_end_date = data_end_date
    if fetched_tables:
        data_update.fetched_tables = fetched_tables
    if row_counts:
        data_update.row_counts = row_counts

    await db.commit()
    await db.refresh(data_update)
    return data_update


def fetch_financial_data_sync(
    company_id: int,
    market: str,
    symbol: str,
    data_source: str,
    update_id: int,
    currency: Optional[str] = None
):
    """
    Fetch financial data synchronously (called from a background task).
    This function does the actual data fetching using the synchronous fetchers.
    """
    try:
        # 0. Initialize
        update_progress_sync(update_id, "正在初始化数据获取...", 0)

        # Normalize the stock symbol: CN-market symbols need a .SH or .SZ suffix
        formatted_symbol = symbol
        if market.upper() == 'CN' and data_source == 'Tushare':
            if '.' not in symbol:
                # Decide between Shanghai and Shenzhen by the leading digit
                if symbol.startswith('6'):
                    formatted_symbol = f"{symbol}.SH"
                elif symbol.startswith(('0', '3', '2')):
                    formatted_symbol = f"{symbol}.SZ"
                else:
                    formatted_symbol = f"{symbol}.SH"  # Default to Shanghai

        logger.info(f"📝 [数据获取] 股票代码格式化: {symbol} -> {formatted_symbol}")

        # 1. Get the fetcher for this market and data source
        fetcher = FetcherFactory.get_fetcher(market, data_source)
        update_progress_sync(update_id, "数据源连接成功", 10)

        # 2. Trigger the data sync
        # For Bloomberg, a single full sync is all we need
        if data_source == 'Bloomberg':
            update_progress_sync(update_id, "正在连接 Bloomberg 终端...", 20)

            # Progress callback: map the fetcher's internal progress (0-100) onto the overall range (20-90)
            def progress_callback(msg, pct):
                mapped_pct = 20 + int(pct * 0.7)
                update_progress_sync(update_id, msg, mapped_pct)

            # Prefer the sync_all_data method (implemented on the fetcher)
            if hasattr(fetcher, 'sync_all_data'):
                # Check whether sync_all_data accepts progress_callback / force_currency
                import inspect
                sig = inspect.signature(fetcher.sync_all_data)
                if 'progress_callback' in sig.parameters:
                    if 'force_currency' in sig.parameters:
                        fetcher.sync_all_data(
                            formatted_symbol,
                            progress_callback=progress_callback,
                            force_currency=currency
                        )
                    else:
                        fetcher.sync_all_data(formatted_symbol, progress_callback=progress_callback)
                else:
                    fetcher.sync_all_data(formatted_symbol)
            else:
                # Legacy fallback; should not be needed once sync_all_data exists
                fetcher.get_income_statement(formatted_symbol)

            update_progress_sync(update_id, "Bloomberg 数据同步完成", 100)
        else:
            # For other data sources, keep the original logic with simplified logging
            update_progress_sync(update_id, "正在获取财务报表数据...", 30)
            fetcher.get_income_statement(formatted_symbol)
            fetcher.get_balance_sheet(formatted_symbol)
            fetcher.get_cash_flow(formatted_symbol)
            fetcher.get_market_metrics(formatted_symbol)

            # Try to fetch auxiliary data
            try:
                fetcher.get_dividends(symbol)
                fetcher.get_repurchases(symbol)
                fetcher.get_employee_count(symbol)
            except Exception:
                pass  # Ignore errors for auxiliary data

            update_progress_sync(update_id, "数据同步完成", 100)

        result_data = {
            'status': 'completed',
            'completed_at': datetime.now(),
            'fetched_tables': ['unified_data'] if data_source == 'Bloomberg' else ['income_statement', 'balance_sheet', 'cash_flow', 'daily_basic'],
            # Row counts are no longer tracked
            'row_counts': {}
        }

        return result_data

    except Exception as e:
        update_progress_sync(update_id, f"数据获取失败: {str(e)}", 0)
        return {
            'status': 'failed',
            'error_message': str(e)
        }


async def get_financial_data_from_db(
    company_id: int,
    data_source: str,
    db: AsyncSession
) -> Dict:
    """
    Read financial data from the database.

    Args:
        company_id: company ID
        data_source: data source (iFinD, Bloomberg, Tushare)
        db: database session

    Returns:
        A dict containing all financial data.
    """
    # Fetch the company record
    result = await db.execute(
        select(Company).where(Company.id == company_id)
    )
    company = result.scalar_one()

    # Prepare the response structure
    response_data = {
        "company": {
            "id": company.id,
            "market": company.market,
            "symbol": company.symbol,
            "company_name": company.company_name,
            "created_at": company.created_at,
            "updated_at": company.updated_at
        },
        "data_source": data_source,
        "income_statement": [],
        "balance_sheet": [],
        "cash_flow": [],
        "daily_basic": [],
        "dividend": [],
        "repurchase": [],
        "employee": [],
        "unified_data": []
    }

    # Read data according to the data source
    if data_source == 'Tushare' and company.market == 'CN':
        # Table mapping
        tables = {
            "income_statement": "tushare_income_statement",
            "balance_sheet": "tushare_balance_sheet",
            "cash_flow": "tushare_cash_flow",
            "daily_basic": "tushare_daily_basic",
            "dividend": "tushare_dividend",
            "repurchase": "tushare_repurchase",
            "employee": "tushare_stock_company"  # Tushare keeps employee info in the company-info table
        }

        # Helper: read rows from a table by exact ts_code match
        async def fetch_table_data(table_name, symbol_code, order_by=None):
            try:
                # Check whether the table exists
                check_sql = text("SELECT to_regclass(:table_name)")
                exists = await db.execute(check_sql, {"table_name": f"public.{table_name}"})
                if not exists.scalar():
                    return []

                # Query the data
                query = f"SELECT * FROM {table_name} WHERE ts_code = :ts_code"
                if order_by:
                    query += f" ORDER BY {order_by} DESC"
                query += " LIMIT 100"  # Cap the number of rows returned

                result = await db.execute(text(query), {"ts_code": symbol_code})

                # Convert to a list of dicts, serializing date values to strings
                data = []
                for row in result.mappings():
                    item = dict(row)
                    for k, v in item.items():
                        if isinstance(v, datetime):
                            item[k] = v.isoformat()
                    data.append(item)

                return data
            except Exception as e:
                print(f"Error fetching from {table_name}: {e}")
                return []

        # Tushare rows are keyed by ts_code (e.g. 300750.SZ), but company.symbol may be the
        # bare 6-digit code (e.g. 300750) depending on how it was stored. Rather than trying to
        # reconstruct the exchange suffix here, match by prefix with a LIKE query.
        async def fetch_table_data_fuzzy(table_name, symbol_code, order_by=None):
            try:
                check_sql = text("SELECT to_regclass(:table_name)")
                exists = await db.execute(check_sql, {"table_name": f"public.{table_name}"})
                if not exists.scalar():
                    return []

                query = f"SELECT * FROM {table_name} WHERE ts_code LIKE :ts_code"
                if order_by:
                    query += f" ORDER BY {order_by} DESC"
                query += " LIMIT 50"

                # Match symbol% (e.g. 300750%)
                result = await db.execute(text(query), {"ts_code": f"{symbol_code}%"})

                data = []
                for row in result.mappings():
                    item = dict(row)
                    for k, v in item.items():
                        if hasattr(v, 'isoformat'):
                            item[k] = v.isoformat()
                    data.append(item)
                return data
            except Exception as e:
                print(f"Error fetching from {table_name}: {e}")
                return []

        # Fetch the tables one by one
        response_data["income_statement"] = await fetch_table_data_fuzzy(tables["income_statement"], company.symbol, "end_date")
        response_data["balance_sheet"] = await fetch_table_data_fuzzy(tables["balance_sheet"], company.symbol, "end_date")
        response_data["cash_flow"] = await fetch_table_data_fuzzy(tables["cash_flow"], company.symbol, "end_date")
        # Daily data
        response_data["daily_basic"] = await fetch_table_data_fuzzy(tables["daily_basic"], company.symbol, "trade_date")
        # Other data
        response_data["dividend"] = await fetch_table_data_fuzzy(tables["dividend"], company.symbol, "end_date")
        response_data["repurchase"] = await fetch_table_data_fuzzy(tables["repurchase"], company.symbol, "ann_date")

    elif data_source == 'Bloomberg':
        try:
            # Read unified data through the dedicated Bloomberg service
            unified_data = await get_bloomberg_data(company, db)
            response_data["unified_data"] = unified_data
            response_data["income_statement"] = []
        except Exception as e:
            logger.error(f"Error fetching Bloomberg data from stockcard: {e}", exc_info=True)

    return response_data


def get_available_data_sources(market: str) -> List[Dict]:
    """
    Get the data sources available for a given market.

    Returns:
        A list of data sources, each with source, available, description and supported_markets.
    """
    all_sources = {
        'iFinD': {
            'description': 'iFinD (同花顺) 数据源',
            'supported_markets': ['HK', 'JP', 'US', 'VN']
        },
        'Tushare': {
            'description': 'Tushare 数据源',
            'supported_markets': ['CN']
        },
        'Bloomberg': {
            'description': 'Bloomberg 数据源',
            'supported_markets': ['CN', 'HK', 'JP', 'US', 'VN']
        }
    }

    sources = []
    for source_name, info in all_sources.items():
        sources.append({
            'source': source_name,
            'available': market in info['supported_markets'],
            'description': info['description'],
            'supported_markets': info['supported_markets']
        })

    return sources


async def get_recent_companies(
    data_source: str,
    db: AsyncSession,
    limit: int = 20
) -> List[Dict]:
    """Get the most recently updated companies for a given data source."""
    query = (
        select(Company, DataUpdate.completed_at)
        .join(DataUpdate, Company.id == DataUpdate.company_id)
        .where(
            and_(
                DataUpdate.data_source == data_source,
                DataUpdate.status == 'completed'
            )
        )
        .order_by(DataUpdate.completed_at.desc())
        .limit(limit * 3)  # Fetch more to handle duplicates
    )
    result = await db.execute(query)
    rows = result.all()

    seen_ids = set()
    companies = []
    for row in rows:
        company, completed_at = row
        if company.id not in seen_ids:
            seen_ids.add(company.id)
            companies.append({
                "market": company.market,
                "symbol": company.symbol,
                "company_name": company.company_name,
                "last_update": completed_at.strftime('%Y-%m-%d %H:%M') if completed_at else ""
            })
            if len(companies) >= limit:
                break

    return companies
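

# ---------------------------------------------------------------------------
# Illustrative usage (a minimal sketch, not part of the service API).
# It shows how the functions above are meant to be chained for one company:
# check status -> ensure the company record -> create a DataUpdate record ->
# run the blocking fetch off the event loop -> persist the result.
# The function name, the update_type value "full", and the use of
# asyncio.to_thread are assumptions for illustration; the real API layer may
# wire the background task differently.
# ---------------------------------------------------------------------------
import asyncio  # used only by the example flow below


async def _example_fetch_flow(
    market: str,
    symbol: str,
    company_name: str,
    data_source: str,
    db: AsyncSession,
):
    """Hypothetical end-to-end flow for fetching one company's data."""
    # 1. Check whether this source already has data for the company
    status = await check_data_status(market, symbol, data_source, db)
    if status.has_data:
        logger.info(f"{symbol}: {data_source} already has data, re-fetching anyway")

    # 2. Ensure the company and DataUpdate records exist
    company = await create_or_get_company(market, symbol, company_name, db)
    data_update = await create_data_update_record(company.id, data_source, "full", db)

    # 3. fetch_financial_data_sync is blocking, so run it in a worker thread
    result = await asyncio.to_thread(
        fetch_financial_data_sync,
        company.id, market, symbol, data_source, data_update.id,
    )

    # 4. Persist the outcome on the DataUpdate record
    return await update_data_update_record(
        data_update.id,
        result['status'],
        db,
        completed_at=result.get('completed_at'),
        error_message=result.get('error_message'),
        fetched_tables=result.get('fetched_tables'),
        row_counts=result.get('row_counts'),
    )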