"""
Data management API routes.

Handles checking, fetching and reading of financial data.
"""
import json
import logging
import math
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, List, Optional

from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.database import get_db
from app.schemas import (
    DataCheckRequest,
    DataCheckResponse,
    FetchDataRequest,
    FetchDataResponse,
    DataUpdateResponse,
    FinancialDataResponse,
    DataSourceListResponse,
    DataSourceInfo
)
from app.services import data_fetcher_service
from app.models import DataUpdate

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/data", tags=["data"])

# Written by the background task once the synchronous fetch has returned.
_FINISH_UPDATE_SQL = """
    UPDATE data_updates
    SET status = %s,
        completed_at = %s,
        error_message = %s,
        data_start_date = %s,
        data_end_date = %s,
        fetched_tables = %s,
        row_counts = %s
    WHERE id = %s
"""

# Written by the background task when anything in the fetch/update path raised.
_FAIL_UPDATE_SQL = (
    "UPDATE data_updates SET status = %s, error_message = %s, completed_at = %s WHERE id = %s"
)


def _open_sync_connection():
    """Open a synchronous psycopg2 connection configured from DB_* env vars.

    psycopg2 is imported here rather than at module level so that importing
    this router does not require the driver.
    """
    import psycopg2

    # NOTE(review): the fallback host/user/password below are hard-coded
    # credentials. They are kept unchanged for backward compatibility, but
    # should be moved to configuration / a secret store.
    return psycopg2.connect(
        host=os.getenv('DB_HOST', '192.168.3.195'),
        user=os.getenv('DB_USER', 'value'),
        password=os.getenv('DB_PASSWORD', 'Value609!'),
        dbname=os.getenv('DB_NAME', 'fa3'),
        port=os.getenv('DB_PORT', '5432')
    )


def _execute_sync_update(sql: str, params: tuple) -> None:
    """Run one parameterized statement and commit, always closing the connection."""
    conn = _open_sync_connection()
    try:
        # The cursor context manager closes the cursor even if execute() raises.
        with conn.cursor() as cur:
            cur.execute(sql, params)
        conn.commit()
    finally:
        # `with conn:` would only end the transaction, not close the
        # connection, so close it explicitly.
        conn.close()


def _decode_json_column(raw: Any, fallback: Any) -> Any:
    """Decode a column that may hold a JSON string, falling back on bad/empty data.

    Returns ``fallback`` when ``raw`` is falsy, is not valid JSON, or decodes
    to ``None`` (the background task stores ``json.dumps(None)``, i.e. the
    string ``"null"``, when the fetch result has no value for the column).
    Non-string values are returned unchanged.
    """
    if not raw:
        return fallback
    if not isinstance(raw, str):
        return raw
    try:
        decoded = json.loads(raw)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return fallback
    return fallback if decoded is None else decoded


def _clean_nan(obj: Any) -> Any:
    """Recursively replace NaN/Infinity floats with None inside dicts and lists."""
    if isinstance(obj, float):
        if math.isnan(obj) or math.isinf(obj):
            return None
        return obj
    if isinstance(obj, dict):
        return {key: _clean_nan(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [_clean_nan(item) for item in obj]
    return obj


@router.get("/check", response_model=DataCheckResponse)
async def check_data_status(
    market: str,
    symbol: str,
    data_source: str,
    db: AsyncSession = Depends(get_db)
):
    """
    Check the data status of a company.

    Delegates to ``data_fetcher_service.check_data_status``, which is expected
    to report whether data exists for the given company and data source and,
    if so, the time of the latest update and the covered data range.
    """
    return await data_fetcher_service.check_data_status(
        market=market,
        symbol=symbol,
        data_source=data_source,
        db=db
    )


@router.post("/fetch", response_model=FetchDataResponse)
async def fetch_data(
    request: FetchDataRequest,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db)
):
    """
    Fetch / update financial data.

    Creates (or looks up) the company record, creates a data-update record,
    schedules ``fetch_data_background`` as a background task and returns
    immediately with status ``in_progress`` and the id of the update record.
    """
    # 1. Create or get the company record.
    company = await data_fetcher_service.create_or_get_company(
        market=request.market,
        symbol=request.symbol,
        company_name=request.company_name,
        db=db
    )

    # 2. Create the data-update record that tracks this fetch.
    data_update = await data_fetcher_service.create_data_update_record(
        company_id=company.id,
        data_source=request.data_source,
        update_type='full',
        db=db
    )

    # 3. Schedule the background task.
    background_tasks.add_task(
        fetch_data_background,
        company_id=company.id,
        market=request.market,
        symbol=request.symbol,
        data_source=request.data_source,
        update_id=data_update.id,
        currency=request.currency,
        frequency=request.frequency
    )

    return FetchDataResponse(
        update_id=data_update.id,
        data_source=request.data_source,
        status='in_progress',
        message=f'正在从 {request.data_source} 获取数据,请稍候...'
    )


def fetch_data_background(
    company_id: int,
    market: str,
    symbol: str,
    data_source: str,
    update_id: int,
    currency: Optional[str] = None,
    frequency: str = "Annual"
):
    """Background data-fetch task.

    Runs fully synchronously (the original author's stated reason: to avoid
    event-loop conflicts). Calls
    ``data_fetcher_service.fetch_financial_data_sync`` and writes the outcome
    to the ``data_updates`` row identified by ``update_id`` over a synchronous
    psycopg2 connection. Any exception is logged and the row is marked
    ``failed`` on a best-effort basis; nothing is raised to the caller.
    """
    # Make sure the project root is on the Python path.
    project_root = Path(__file__).parent.parent.parent.parent
    if str(project_root) not in sys.path:
        sys.path.insert(0, str(project_root))

    try:
        # Run the (synchronous) fetch.
        result = data_fetcher_service.fetch_financial_data_sync(
            company_id=company_id,
            market=market,
            symbol=symbol,
            data_source=data_source,
            update_id=update_id,
            currency=currency,
            frequency=frequency
        )

        status = result['status']
        _execute_sync_update(
            _FINISH_UPDATE_SQL,
            (
                status,
                # completed_at is only stamped for a successful fetch.
                datetime.now() if status == 'completed' else None,
                result.get('error_message'),
                result.get('data_start_date'),
                result.get('data_end_date'),
                json.dumps(result.get('fetched_tables')),
                json.dumps(result.get('row_counts')),
                update_id
            )
        )
    except Exception as exc:
        logger.exception("❌ 后台任务执行错误: %s", exc)

        # Best effort: record the failure so clients polling the status do
        # not see the task as in progress forever.
        try:
            _execute_sync_update(
                _FAIL_UPDATE_SQL,
                ('failed', str(exc), datetime.now(), update_id)
            )
        except Exception:
            # Still best-effort, but no longer silent.
            logger.exception("❌ 无法写入失败状态 (update_id=%s)", update_id)


@router.get("/status/{update_id}", response_model=DataUpdateResponse)
async def get_fetch_status(
    update_id: int,
    db: AsyncSession = Depends(get_db)
):
    """
    Query the status of a data fetch.

    Looks up the ``DataUpdate`` row with the given ``update_id`` and returns
    its current state. Responds with 404 if no such row exists.
    """
    result = await db.execute(
        select(DataUpdate).where(DataUpdate.id == update_id)
    )
    data_update = result.scalar_one_or_none()

    if not data_update:
        raise HTTPException(status_code=404, detail="数据更新记录不存在")

    # fetched_tables and row_counts are stored as TEXT holding JSON strings,
    # so they are decoded by hand; otherwise Pydantic validation would fail.
    fetched_tables = _decode_json_column(data_update.fetched_tables, [])
    row_counts = _decode_json_column(data_update.row_counts, {})

    # Build the response object manually with the decoded values.
    return DataUpdateResponse(
        id=data_update.id,
        company_id=data_update.company_id,
        data_source=data_update.data_source,
        update_type=data_update.update_type,
        status=data_update.status,
        progress_message=data_update.progress_message,
        progress_percentage=data_update.progress_percentage,
        started_at=data_update.started_at,
        completed_at=data_update.completed_at,
        error_message=data_update.error_message,
        data_start_date=data_update.data_start_date,
        data_end_date=data_update.data_end_date,
        fetched_tables=fetched_tables,
        row_counts=row_counts
    )


@router.get("/financial", response_model=FinancialDataResponse)
async def get_financial_data(
    company_id: int,
    data_source: str,
    frequency: str = "Annual",
    db: AsyncSession = Depends(get_db)
):
    """
    Read financial data.

    Returns what ``data_fetcher_service.get_financial_data_from_db`` yields
    for the given company, data source and frequency, with NaN/Infinity
    floats replaced by ``None``. An ``HTTPException`` from the service is
    passed through; any other error becomes a 500 carrying the error text.
    """
    try:
        data = await data_fetcher_service.get_financial_data_from_db(
            company_id=company_id,
            data_source=data_source,
            frequency=frequency,
            db=db
        )
        # NaN/Infinity values would make JSON serialization fail.
        return _clean_nan(data)
    except HTTPException:
        # Preserve status codes chosen deliberately further down the stack.
        raise
    except Exception as exc:
        logger.exception(
            "❌ 读取财务数据失败 (company_id=%s, data_source=%s)",
            company_id,
            data_source,
        )
        raise HTTPException(status_code=500, detail=str(exc)) from exc


@router.get("/sources", response_model=DataSourceListResponse)
async def get_available_sources(market: str):
    """
    List available data sources.

    Wraps each entry returned by
    ``data_fetcher_service.get_available_data_sources(market)`` in a
    ``DataSourceInfo``.
    """
    sources = data_fetcher_service.get_available_data_sources(market)
    return DataSourceListResponse(
        market=market,
        sources=[DataSourceInfo(**source) for source in sources]
    )


@router.get("/recent", response_model=List[dict])
async def get_recent_companies(
    data_source: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """List recently updated companies (the service is called with limit=20).

    If ``data_source`` is omitted, recent updates from all data sources are
    expected to be returned.
    """
    return await data_fetcher_service.get_recent_companies(
        data_source=data_source,
        db=db,
        limit=20
    )


@router.get("/search_local", response_model=List[dict])
async def search_local_companies(
    q: str,
    db: AsyncSession = Depends(get_db)
):
    """
    Search companies in the local database (the service is called with limit=10).
    """
    return await data_fetcher_service.search_companies(
        query=q,
        db=db,
        limit=10
    )