FA3-Datafetch/backend/app/api/data_routes.py
2026-01-12 09:33:52 +08:00

321 lines
9.4 KiB
Python

"""
数据管理 API 路由
处理财务数据的检查、获取和读取
"""
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List
from app.database import get_db
from app.schemas import (
DataCheckRequest,
DataCheckResponse,
FetchDataRequest,
FetchDataResponse,
DataUpdateResponse,
FinancialDataResponse,
DataSourceListResponse,
DataSourceInfo
)
from app.services import data_fetcher_service
from app.models import DataUpdate
from sqlalchemy import select
# Router that groups every endpoint in this module under the "/data" URL
# prefix and tags them as "data".
router = APIRouter(prefix="/data", tags=["data"])
@router.get("/check", response_model=DataCheckResponse)
async def check_data_status(
    market: str,
    symbol: str,
    data_source: str,
    db: AsyncSession = Depends(get_db),
):
    """
    Check the data status of a company.

    Asks the data fetcher service whether the database holds data for the
    given company and data source; per the service contract, when data exists
    the result carries the time of the latest update and the data range.
    """
    # Collect the query arguments once and hand them to the service as-is.
    lookup = {
        "market": market,
        "symbol": symbol,
        "data_source": data_source,
        "db": db,
    }
    status = await data_fetcher_service.check_data_status(**lookup)
    return status
@router.post("/fetch", response_model=FetchDataResponse)
async def fetch_data(
    request: FetchDataRequest,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db),
):
    """
    Fetch or refresh financial data.

    Registers the company and a new data-update record, then schedules
    ``fetch_data_background`` to pull the data from the requested source.
    The fetch itself runs as a background task; the response only reports
    that the job has started.
    """
    # Step 1: create the company record, or reuse the existing one.
    company_kwargs = {
        "market": request.market,
        "symbol": request.symbol,
        "company_name": request.company_name,
        "db": db,
    }
    company = await data_fetcher_service.create_or_get_company(**company_kwargs)

    # Step 2: create the record that tracks this update run.
    data_update = await data_fetcher_service.create_data_update_record(
        company_id=company.id,
        data_source=request.data_source,
        update_type='full',
        db=db,
    )

    # Step 3: queue the background job with everything it needs.
    job_kwargs = {
        "company_id": company.id,
        "market": request.market,
        "symbol": request.symbol,
        "data_source": request.data_source,
        "update_id": data_update.id,
        "currency": request.currency,
    }
    background_tasks.add_task(fetch_data_background, **job_kwargs)

    response = FetchDataResponse(
        update_id=data_update.id,
        data_source=request.data_source,
        status='in_progress',
        message=f'正在从 {request.data_source} 获取数据,请稍候...',
    )
    return response
def _open_sync_db_connection():
    """Open a synchronous psycopg2 connection configured from DB_* env vars."""
    import os
    import psycopg2

    # NOTE(security): the fallback values below (including a password) are kept
    # only so existing deployments keep working. They should be supplied through
    # the environment and the hard-coded defaults removed from source control.
    return psycopg2.connect(
        host=os.getenv('DB_HOST', '192.168.3.195'),
        user=os.getenv('DB_USER', 'value'),
        password=os.getenv('DB_PASSWORD', 'Value609!'),
        dbname=os.getenv('DB_NAME', 'fa3'),
        port=os.getenv('DB_PORT', '5432')
    )


def _execute_sync_update(sql, params):
    """Run one write statement on a fresh connection, commit, and always close."""
    from contextlib import closing

    # closing() guarantees cursor and connection are released even when
    # execute() or commit() raises, so a failed update cannot leak a connection.
    with closing(_open_sync_db_connection()) as conn:
        with closing(conn.cursor()) as cur:
            cur.execute(sql, params)
        conn.commit()


def fetch_data_background(
    company_id: int,
    market: str,
    symbol: str,
    data_source: str,
    update_id: int,
    currency: str = None
):
    """
    Background data-fetch task; runs fully synchronously to avoid
    event-loop conflicts.

    Calls ``data_fetcher_service.fetch_financial_data_sync`` and writes its
    outcome into the ``data_updates`` row identified by ``update_id`` using a
    synchronous psycopg2 connection. If anything raises, the error is printed
    and the row is marked ``failed`` on a best-effort basis.

    Args:
        company_id: Id of the company the data belongs to.
        market: Market identifier passed through to the fetcher.
        symbol: Company symbol passed through to the fetcher.
        data_source: Name of the data source to fetch from.
        update_id: Id of the ``data_updates`` row to update.
        currency: Optional currency passed through to the fetcher.

    Returns:
        None. Exceptions from the fetch or the database update are caught and
        reported rather than propagated.
    """
    import json
    import sys
    import traceback
    from datetime import datetime
    from pathlib import Path

    # Make sure the project root is on the Python path.
    project_root = Path(__file__).parent.parent.parent.parent
    if str(project_root) not in sys.path:
        sys.path.insert(0, str(project_root))

    update_sql = """
    UPDATE data_updates
    SET status = %s,
    completed_at = %s,
    error_message = %s,
    data_start_date = %s,
    data_end_date = %s,
    fetched_tables = %s,
    row_counts = %s
    WHERE id = %s
    """

    try:
        # Run the fetch itself (synchronous).
        result = data_fetcher_service.fetch_financial_data_sync(
            company_id=company_id,
            market=market,
            symbol=symbol,
            data_source=data_source,
            update_id=update_id,
            currency=currency
        )

        # Persist the outcome. completed_at is only stamped for a
        # 'completed' status; list/dict fields are stored as JSON text.
        _execute_sync_update(
            update_sql,
            (
                result['status'],
                datetime.now() if result['status'] == 'completed' else None,
                result.get('error_message'),
                result.get('data_start_date'),
                result.get('data_end_date'),
                json.dumps(result.get('fetched_tables')),
                json.dumps(result.get('row_counts')),
                update_id
            )
        )
    except Exception as e:
        print(f"❌ 后台任务执行错误: {e}")
        traceback.print_exc()

        # Best effort: record the failure on the update row. A second failure
        # here must not crash the background task, but it is reported instead
        # of being silently discarded.
        try:
            _execute_sync_update(
                "UPDATE data_updates SET status = %s, error_message = %s, completed_at = %s WHERE id = %s",
                ('failed', str(e), datetime.now(), update_id)
            )
        except Exception as update_err:
            print(f"⚠️ 无法将更新记录标记为失败: {update_err}")
def _decode_json_column(raw, fallback):
    """Decode a JSON-text column value, returning *fallback* when unusable.

    Falsy values, undecodable text, and text that decodes to ``None`` (the
    literal ``"null"``) all yield *fallback*. Values that are not strings are
    returned unchanged.
    """
    import json

    if not raw:
        return fallback
    if not isinstance(raw, str):
        return raw
    try:
        decoded = json.loads(raw)
    except (ValueError, RecursionError):
        # json.JSONDecodeError is a ValueError; malformed text falls back.
        return fallback
    # The background task stores json.dumps(None) == "null" when a key is
    # missing; treat that like an absent value instead of passing None on.
    return fallback if decoded is None else decoded


@router.get("/status/{update_id}", response_model=DataUpdateResponse)
async def get_fetch_status(
    update_id: int,
    db: AsyncSession = Depends(get_db)
):
    """
    Query the status of a data-fetch job.

    Looks up the ``DataUpdate`` row with the given ``update_id`` and returns
    its current state.

    Raises:
        HTTPException: 404 when no row with ``update_id`` exists.
    """
    result = await db.execute(
        select(DataUpdate).where(DataUpdate.id == update_id)
    )
    data_update = result.scalar_one_or_none()
    if not data_update:
        raise HTTPException(status_code=404, detail="数据更新记录不存在")

    # fetched_tables and row_counts may be stored as JSON strings, so they are
    # decoded by hand before building the response; otherwise Pydantic
    # validation would fail on the raw text.
    fetched_tables = _decode_json_column(data_update.fetched_tables, [])
    row_counts = _decode_json_column(data_update.row_counts, {})

    # Build the response object explicitly.
    return DataUpdateResponse(
        id=data_update.id,
        company_id=data_update.company_id,
        data_source=data_update.data_source,
        update_type=data_update.update_type,
        status=data_update.status,
        progress_message=data_update.progress_message,
        progress_percentage=data_update.progress_percentage,
        started_at=data_update.started_at,
        completed_at=data_update.completed_at,
        error_message=data_update.error_message,
        data_start_date=data_update.data_start_date,
        data_end_date=data_update.data_end_date,
        fetched_tables=fetched_tables,  # decoded list
        row_counts=row_counts  # decoded dict
    )
def _clean_nan(obj):
    """Recursively replace NaN/Infinity floats with None in dicts and lists.

    Other values, including containers that are neither dict nor list, are
    returned unchanged.
    """
    import math

    if isinstance(obj, float):
        if math.isnan(obj) or math.isinf(obj):
            return None
        return obj
    if isinstance(obj, dict):
        return {k: _clean_nan(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_clean_nan(v) for v in obj]
    return obj


@router.get("/financial", response_model=FinancialDataResponse)
async def get_financial_data(
    company_id: int,
    data_source: str,
    db: AsyncSession = Depends(get_db)
):
    """
    Read financial data.

    Loads the financial data stored for the given company and data source via
    the data fetcher service and returns it with NaN/Infinity values replaced
    by ``None``.

    Raises:
        HTTPException: 500 with the error text when loading or cleaning fails.
    """
    try:
        data = await data_fetcher_service.get_financial_data_from_db(
            company_id=company_id,
            data_source=data_source,
            db=db
        )
        # NaN/Infinity values would make JSON serialization fail, so they are
        # stripped recursively before returning.
        return _clean_nan(data)
    except Exception as e:
        import traceback
        traceback.print_exc()
        # Chain the original error so the root cause stays in the traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/sources", response_model=DataSourceListResponse)
async def get_available_sources(market: str):
    """
    List the available data sources.

    Returns every data source the data fetcher service reports for the
    given market.
    """
    raw_sources = data_fetcher_service.get_available_data_sources(market)

    # Wrap each raw entry in the response schema, preserving order.
    source_infos = []
    for entry in raw_sources:
        source_infos.append(DataSourceInfo(**entry))

    return DataSourceListResponse(market=market, sources=source_infos)
@router.get("/recent", response_model=List[dict])
async def get_recent_companies(
    data_source: str = "Bloomberg",
    db: AsyncSession = Depends(get_db),
):
    """Return the list of recently updated companies (at most 20 requested)."""
    recent_companies = await data_fetcher_service.get_recent_companies(
        data_source=data_source, db=db, limit=20
    )
    return recent_companies