321 lines
9.4 KiB
Python
321 lines
9.4 KiB
Python
"""
|
|
数据管理 API 路由
|
|
处理财务数据的检查、获取和读取
|
|
"""
|
|
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from typing import List
|
|
|
|
from app.database import get_db
|
|
from app.schemas import (
|
|
DataCheckRequest,
|
|
DataCheckResponse,
|
|
FetchDataRequest,
|
|
FetchDataResponse,
|
|
DataUpdateResponse,
|
|
FinancialDataResponse,
|
|
DataSourceListResponse,
|
|
DataSourceInfo
|
|
)
|
|
from app.services import data_fetcher_service
|
|
from app.models import DataUpdate
|
|
from sqlalchemy import select
|
|
|
|
router = APIRouter(prefix="/data", tags=["data"])
|
|
|
|
|
|
@router.get("/check", response_model=DataCheckResponse)
|
|
async def check_data_status(
|
|
market: str,
|
|
symbol: str,
|
|
data_source: str,
|
|
db: AsyncSession = Depends(get_db)
|
|
):
|
|
"""
|
|
检查公司数据状态
|
|
|
|
查询数据库中是否存在指定公司和数据源的数据,
|
|
如果存在则返回最近一次更新的时间和数据范围
|
|
"""
|
|
return await data_fetcher_service.check_data_status(
|
|
market=market,
|
|
symbol=symbol,
|
|
data_source=data_source,
|
|
db=db
|
|
)
|
|
|
|
|
|
@router.post("/fetch", response_model=FetchDataResponse)
|
|
async def fetch_data(
|
|
request: FetchDataRequest,
|
|
background_tasks: BackgroundTasks,
|
|
db: AsyncSession = Depends(get_db)
|
|
):
|
|
"""
|
|
获取/更新财务数据
|
|
|
|
从指定数据源获取公司的财务数据并存储到数据库
|
|
此操作在后台异步执行
|
|
"""
|
|
# 1. 创建或获取公司记录
|
|
company = await data_fetcher_service.create_or_get_company(
|
|
market=request.market,
|
|
symbol=request.symbol,
|
|
company_name=request.company_name,
|
|
db=db
|
|
)
|
|
|
|
# 2. 创建数据更新记录
|
|
data_update = await data_fetcher_service.create_data_update_record(
|
|
company_id=company.id,
|
|
data_source=request.data_source,
|
|
update_type='full',
|
|
db=db
|
|
)
|
|
|
|
# 3. 启动后台任务
|
|
background_tasks.add_task(
|
|
fetch_data_background,
|
|
company_id=company.id,
|
|
market=request.market,
|
|
symbol=request.symbol,
|
|
data_source=request.data_source,
|
|
update_id=data_update.id,
|
|
currency=request.currency
|
|
)
|
|
|
|
return FetchDataResponse(
|
|
update_id=data_update.id,
|
|
data_source=request.data_source,
|
|
status='in_progress',
|
|
message=f'正在从 {request.data_source} 获取数据,请稍候...'
|
|
)
|
|
|
|
|
|
def fetch_data_background(
|
|
company_id: int,
|
|
market: str,
|
|
symbol: str,
|
|
data_source: str,
|
|
update_id: int,
|
|
currency: str = None
|
|
):
|
|
"""后台数据获取任务 - 完全同步执行,避免event loop冲突"""
|
|
import sys
|
|
import os
|
|
from pathlib import Path
|
|
import psycopg2
|
|
from datetime import datetime
|
|
|
|
# 确保项目根目录在 Python 路径中
|
|
project_root = Path(__file__).parent.parent.parent.parent
|
|
if str(project_root) not in sys.path:
|
|
sys.path.insert(0, str(project_root))
|
|
|
|
try:
|
|
# 执行数据获取(同步)
|
|
result = data_fetcher_service.fetch_financial_data_sync(
|
|
company_id=company_id,
|
|
market=market,
|
|
symbol=symbol,
|
|
data_source=data_source,
|
|
update_id=update_id,
|
|
currency=currency
|
|
)
|
|
|
|
# 更新数据更新记录 - 使用psycopg2同步连接
|
|
conn = psycopg2.connect(
|
|
host=os.getenv('DB_HOST', '192.168.3.195'),
|
|
user=os.getenv('DB_USER', 'value'),
|
|
password=os.getenv('DB_PASSWORD', 'Value609!'),
|
|
dbname=os.getenv('DB_NAME', 'fa3'),
|
|
port=os.getenv('DB_PORT', '5432')
|
|
)
|
|
|
|
cur = conn.cursor()
|
|
|
|
# 更新记录
|
|
update_sql = """
|
|
UPDATE data_updates
|
|
SET status = %s,
|
|
completed_at = %s,
|
|
error_message = %s,
|
|
data_start_date = %s,
|
|
data_end_date = %s,
|
|
fetched_tables = %s,
|
|
row_counts = %s
|
|
WHERE id = %s
|
|
"""
|
|
|
|
import json
|
|
cur.execute(
|
|
update_sql,
|
|
(
|
|
result['status'],
|
|
datetime.now() if result['status'] == 'completed' else None,
|
|
result.get('error_message'),
|
|
result.get('data_start_date'),
|
|
result.get('data_end_date'),
|
|
json.dumps(result.get('fetched_tables')),
|
|
json.dumps(result.get('row_counts')),
|
|
update_id
|
|
)
|
|
)
|
|
conn.commit()
|
|
cur.close()
|
|
conn.close()
|
|
|
|
except Exception as e:
|
|
print(f"❌ 后台任务执行错误: {e}")
|
|
import traceback
|
|
traceback.print_exc()
|
|
|
|
# 尝试更新错误状态
|
|
try:
|
|
conn = psycopg2.connect(
|
|
host=os.getenv('DB_HOST', '192.168.3.195'),
|
|
user=os.getenv('DB_USER', 'value'),
|
|
password=os.getenv('DB_PASSWORD', 'Value609!'),
|
|
dbname=os.getenv('DB_NAME', 'fa3'),
|
|
port=os.getenv('DB_PORT', '5432')
|
|
)
|
|
cur = conn.cursor()
|
|
cur.execute(
|
|
"UPDATE data_updates SET status = %s, error_message = %s, completed_at = %s WHERE id = %s",
|
|
('failed', str(e), datetime.now(), update_id)
|
|
)
|
|
conn.commit()
|
|
cur.close()
|
|
conn.close()
|
|
except:
|
|
pass
|
|
|
|
|
|
|
|
@router.get("/status/{update_id}", response_model=DataUpdateResponse)
|
|
async def get_fetch_status(
|
|
update_id: int,
|
|
db: AsyncSession = Depends(get_db)
|
|
):
|
|
"""
|
|
查询数据获取状态
|
|
|
|
通过 update_id 查询数据获取任务的当前状态
|
|
"""
|
|
result = await db.execute(
|
|
select(DataUpdate).where(DataUpdate.id == update_id)
|
|
)
|
|
data_update = result.scalar_one_or_none()
|
|
|
|
if not data_update:
|
|
raise HTTPException(status_code=404, detail="数据更新记录不存在")
|
|
|
|
# ⚠️ 特殊处理:由于 fetched_tables 和 row_counts 现在是 TEXT 类型存储 JSON 字符串
|
|
# 我们需要手动解析它们,否则 Pydantic 会校验失败
|
|
import json
|
|
|
|
fetched_tables = []
|
|
if data_update.fetched_tables:
|
|
if isinstance(data_update.fetched_tables, str):
|
|
try:
|
|
fetched_tables = json.loads(data_update.fetched_tables)
|
|
except:
|
|
fetched_tables = []
|
|
else:
|
|
fetched_tables = data_update.fetched_tables
|
|
|
|
row_counts = {}
|
|
if data_update.row_counts:
|
|
if isinstance(data_update.row_counts, str):
|
|
try:
|
|
row_counts = json.loads(data_update.row_counts)
|
|
except:
|
|
row_counts = {}
|
|
else:
|
|
row_counts = data_update.row_counts
|
|
|
|
# 手动构建响应对象
|
|
return DataUpdateResponse(
|
|
id=data_update.id,
|
|
company_id=data_update.company_id,
|
|
data_source=data_update.data_source,
|
|
update_type=data_update.update_type,
|
|
status=data_update.status,
|
|
progress_message=data_update.progress_message,
|
|
progress_percentage=data_update.progress_percentage,
|
|
started_at=data_update.started_at,
|
|
completed_at=data_update.completed_at,
|
|
error_message=data_update.error_message,
|
|
data_start_date=data_update.data_start_date,
|
|
data_end_date=data_update.data_end_date,
|
|
fetched_tables=fetched_tables, # 传入解析后的 List
|
|
row_counts=row_counts # 传入解析后的 Dict
|
|
)
|
|
|
|
|
|
@router.get("/financial", response_model=FinancialDataResponse)
|
|
async def get_financial_data(
|
|
company_id: int,
|
|
data_source: str,
|
|
db: AsyncSession = Depends(get_db)
|
|
):
|
|
"""
|
|
读取财务数据
|
|
|
|
从数据库中读取指定公司和数据源的所有财务数据
|
|
"""
|
|
try:
|
|
data = await data_fetcher_service.get_financial_data_from_db(
|
|
company_id=company_id,
|
|
data_source=data_source,
|
|
db=db
|
|
)
|
|
|
|
# 递归清理 NaN/Infinity 值,这会导致 JSON 序列化失败
|
|
import math
|
|
def clean_nan(obj):
|
|
if isinstance(obj, float):
|
|
if math.isnan(obj) or math.isinf(obj):
|
|
return None
|
|
return obj
|
|
elif isinstance(obj, dict):
|
|
return {k: clean_nan(v) for k, v in obj.items()}
|
|
elif isinstance(obj, list):
|
|
return [clean_nan(v) for v in obj]
|
|
return obj
|
|
|
|
return clean_nan(data)
|
|
except Exception as e:
|
|
import traceback
|
|
traceback.print_exc()
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
return data
|
|
|
|
|
|
@router.get("/sources", response_model=DataSourceListResponse)
|
|
async def get_available_sources(market: str):
|
|
"""
|
|
获取可用数据源列表
|
|
|
|
返回指定市场支持的所有数据源
|
|
"""
|
|
sources = data_fetcher_service.get_available_data_sources(market)
|
|
|
|
return DataSourceListResponse(
|
|
market=market,
|
|
sources=[DataSourceInfo(**s) for s in sources]
|
|
)
|
|
|
|
@router.get("/recent", response_model=List[dict])
|
|
async def get_recent_companies(
|
|
data_source: str = "Bloomberg",
|
|
db: AsyncSession = Depends(get_db)
|
|
):
|
|
"""获取最近更新的公司列表"""
|
|
return await data_fetcher_service.get_recent_companies(
|
|
data_source=data_source,
|
|
db=db,
|
|
limit=20
|
|
)
|