324 lines
9.5 KiB
Python
324 lines
9.5 KiB
Python
"""
|
||
数据管理 API 路由
|
||
处理财务数据的检查、获取和读取
|
||
"""
|
||
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from typing import List
|
||
|
||
from app.database import get_db
|
||
from app.schemas import (
|
||
DataCheckRequest,
|
||
DataCheckResponse,
|
||
FetchDataRequest,
|
||
FetchDataResponse,
|
||
DataUpdateResponse,
|
||
FinancialDataResponse,
|
||
DataSourceListResponse,
|
||
DataSourceInfo
|
||
)
|
||
from app.services import data_fetcher_service
|
||
from app.models import DataUpdate
|
||
from sqlalchemy import select
|
||
|
||
# Every endpoint in this module is registered on this router: paths get the
# "/data" prefix and are grouped under the "data" tag in the OpenAPI schema.
router = APIRouter(tags=["data"], prefix="/data")
|
||
|
||
|
||
@router.get("/check", response_model=DataCheckResponse)
async def check_data_status(
    market: str,
    symbol: str,
    data_source: str,
    db: AsyncSession = Depends(get_db),
):
    """Report whether data for a company / data-source pair is stored.

    Pure delegation to ``data_fetcher_service.check_data_status``. Per the
    endpoint's original description, the response carries the most recent
    update time and the covered data range when data exists.
    """
    lookup = dict(market=market, symbol=symbol, data_source=data_source)
    return await data_fetcher_service.check_data_status(db=db, **lookup)
|
||
|
||
|
||
@router.post("/fetch", response_model=FetchDataResponse)
async def fetch_data(
    request: FetchDataRequest,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db),
):
    """Start fetching / refreshing a company's financial data.

    Steps:
      1. ensure a company row exists for (market, symbol);
      2. create a ``'full'`` data-update record for the requested source;
      3. schedule ``fetch_data_background`` as a FastAPI background task.

    Returns immediately with status ``'in_progress'``; the returned
    ``update_id`` is the id queried by this router's ``/status/{update_id}``.
    """
    company = await data_fetcher_service.create_or_get_company(
        market=request.market,
        symbol=request.symbol,
        company_name=request.company_name,
        db=db,
    )

    update_record = await data_fetcher_service.create_data_update_record(
        company_id=company.id,
        data_source=request.data_source,
        update_type='full',
        db=db,
    )

    # Only plain ids/strings are handed to the task; it opens its own
    # synchronous DB connection rather than reusing this request's session.
    task_kwargs = {
        'company_id': company.id,
        'market': request.market,
        'symbol': request.symbol,
        'data_source': request.data_source,
        'update_id': update_record.id,
        'currency': request.currency,
    }
    background_tasks.add_task(fetch_data_background, **task_kwargs)

    return FetchDataResponse(
        update_id=update_record.id,
        data_source=request.data_source,
        status='in_progress',
        message=f'正在从 {request.data_source} 获取数据,请稍候...',
    )
|
||
|
||
|
||
def _open_sync_db_connection():
    """Open a synchronous psycopg2 connection configured from DB_* env vars.

    The background task is deliberately fully synchronous, to avoid
    event-loop conflicts, so it cannot reuse the request's AsyncSession.

    NOTE(review): the fallback values, including a plaintext password, are
    kept only for backward compatibility. Credentials should come from the
    environment or a secret store, not from source code.
    """
    import os

    import psycopg2

    return psycopg2.connect(
        host=os.getenv('DB_HOST', '192.168.3.195'),
        user=os.getenv('DB_USER', 'value'),
        password=os.getenv('DB_PASSWORD', 'Value609!'),
        dbname=os.getenv('DB_NAME', 'fa3'),
        port=os.getenv('DB_PORT', '5432'),
    )


def _execute_sync_update(sql, params):
    """Execute one parameterised write statement on a fresh connection and commit.

    The cursor and connection are closed even when execute/commit raises, so
    a failing update no longer leaks a database connection.
    """
    conn = _open_sync_db_connection()
    try:
        cur = conn.cursor()
        try:
            cur.execute(sql, params)
            conn.commit()
        finally:
            cur.close()
    finally:
        conn.close()


def fetch_data_background(
    company_id: int,
    market: str,
    symbol: str,
    data_source: str,
    update_id: int,
    currency: str = None,
) -> None:
    """Fetch financial data, then record the outcome on the ``data_updates`` row.

    Runs fully synchronously, using psycopg2 instead of the async session, to
    avoid event-loop conflicts when executed as a background task.

    Args:
        company_id: Company the data belongs to.
        market: Market of the company, passed to the fetcher.
        symbol: Ticker / symbol of the company, passed to the fetcher.
        data_source: Name of the source to fetch from.
        update_id: ``data_updates.id`` of the record to update.
        currency: Optional currency, passed through to the fetcher.

    Errors are not propagated. On any exception the traceback is printed and
    the record is marked ``'failed'`` on a best-effort basis.
    """
    import json
    import sys
    import traceback
    from datetime import datetime
    from pathlib import Path

    # Make sure the project root is on the Python path.
    project_root = str(Path(__file__).parent.parent.parent.parent)
    if project_root not in sys.path:
        sys.path.insert(0, project_root)

    final_update_sql = """
        UPDATE data_updates
        SET status = %s,
            completed_at = %s,
            error_message = %s,
            data_start_date = %s,
            data_end_date = %s,
            fetched_tables = %s,
            row_counts = %s
        WHERE id = %s
    """

    try:
        result = data_fetcher_service.fetch_financial_data_sync(
            company_id=company_id,
            market=market,
            symbol=symbol,
            data_source=data_source,
            update_id=update_id,
            currency=currency,
        )

        status = result['status']
        # NOTE(review): completed_at is stamped only for 'completed', while the
        # exception path below stamps it for 'failed'. Confirm whether a
        # non-'completed' result returned by the fetcher should get a timestamp.
        _execute_sync_update(
            final_update_sql,
            (
                status,
                datetime.now() if status == 'completed' else None,
                result.get('error_message'),
                result.get('data_start_date'),
                result.get('data_end_date'),
                # fetched_tables / row_counts are TEXT columns holding JSON.
                json.dumps(result.get('fetched_tables')),
                json.dumps(result.get('row_counts')),
                update_id,
            ),
        )
    except Exception as e:
        print(f"❌ 后台任务执行错误: {e}")
        traceback.print_exc()

        # Best effort: mark the record as failed so status polling sees a
        # terminal state instead of a record stuck in 'in_progress'.
        try:
            _execute_sync_update(
                "UPDATE data_updates SET status = %s, error_message = %s, completed_at = %s WHERE id = %s",
                ('failed', str(e), datetime.now(), update_id),
            )
        except Exception as mark_err:
            # Still best effort, but no longer silent: without this, a stuck
            # record would be impossible to diagnose.
            print(f"❌ 无法将数据更新记录 {update_id} 标记为失败: {mark_err}")
            traceback.print_exc()
|
||
|
||
|
||
|
||
def _parse_json_column(raw, expected_type):
    """Decode a JSON-in-TEXT column into an instance of ``expected_type``.

    Returns a fresh ``expected_type()`` (e.g. ``[]`` or ``{}``) when the
    column is empty/NULL, is not valid JSON, or decodes to a value of another
    type. This covers the literal ``"null"`` written by ``json.dumps(None)``,
    which would otherwise reach the response model as ``None``.
    Already-deserialised (non-str) values of the expected type pass through.
    """
    import json

    if not raw:
        return expected_type()
    if isinstance(raw, str):
        try:
            raw = json.loads(raw)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            return expected_type()
    return raw if isinstance(raw, expected_type) else expected_type()


@router.get("/status/{update_id}", response_model=DataUpdateResponse)
async def get_fetch_status(
    update_id: int,
    db: AsyncSession = Depends(get_db),
):
    """Return the current state of a data-fetch task.

    Args:
        update_id: Primary key of the ``DataUpdate`` record to look up.

    Raises:
        HTTPException: 404 if no record with ``update_id`` exists.
    """
    result = await db.execute(
        select(DataUpdate).where(DataUpdate.id == update_id)
    )
    data_update = result.scalar_one_or_none()

    if data_update is None:
        raise HTTPException(status_code=404, detail="数据更新记录不存在")

    # fetched_tables and row_counts are stored as TEXT containing JSON, so they
    # must be decoded here. Otherwise Pydantic validation of the response fails.
    return DataUpdateResponse(
        id=data_update.id,
        company_id=data_update.company_id,
        data_source=data_update.data_source,
        update_type=data_update.update_type,
        status=data_update.status,
        progress_message=data_update.progress_message,
        progress_percentage=data_update.progress_percentage,
        started_at=data_update.started_at,
        completed_at=data_update.completed_at,
        error_message=data_update.error_message,
        data_start_date=data_update.data_start_date,
        data_end_date=data_update.data_end_date,
        fetched_tables=_parse_json_column(data_update.fetched_tables, list),
        row_counts=_parse_json_column(data_update.row_counts, dict),
    )
|
||
|
||
|
||
def _clean_nan(obj):
    """Recursively replace NaN / ±Infinity floats with ``None``.

    Standard JSON cannot represent these values, so leaving them in breaks
    response serialisation. Dicts, lists and tuples are traversed; tuples come
    back as lists, which is how JSON encodes them anyway.
    """
    import math

    if isinstance(obj, float):
        return obj if math.isfinite(obj) else None
    if isinstance(obj, dict):
        return {key: _clean_nan(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_clean_nan(value) for value in obj]
    return obj


@router.get("/financial", response_model=FinancialDataResponse)
async def get_financial_data(
    company_id: int,
    data_source: str,
    db: AsyncSession = Depends(get_db),
):
    """Read all stored financial data for a company from one data source.

    NaN / Infinity floats in the result are replaced with ``None`` so that the
    payload is valid JSON.

    Raises:
        HTTPException: re-raised unchanged if the service raises one; any
            other error is printed and converted to a 500 whose detail is
            the error message.
    """
    try:
        data = await data_fetcher_service.get_financial_data_from_db(
            company_id=company_id,
            data_source=data_source,
            db=db,
        )
        return _clean_nan(data)
    except HTTPException:
        # Preserve intentional HTTP errors instead of turning them into a 500.
        raise
    except Exception as e:
        import traceback

        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||
|
||
|
||
@router.get("/sources", response_model=DataSourceListResponse)
async def get_available_sources(market: str):
    """List the data sources supported for ``market``.

    Each entry from ``data_fetcher_service.get_available_data_sources`` is
    unpacked as keyword arguments into ``DataSourceInfo``.
    """
    source_infos = []
    for raw_source in data_fetcher_service.get_available_data_sources(market):
        source_infos.append(DataSourceInfo(**raw_source))
    return DataSourceListResponse(market=market, sources=source_infos)
|
||
|
||
@router.get("/recent", response_model=List[dict])
async def get_recent_companies(
    data_source: str = None,
    db: AsyncSession = Depends(get_db),
):
    """Return recently updated companies, requesting at most 20 from the service.

    When ``data_source`` is omitted, recent updates from all data sources are
    included.
    """
    recent_limit = 20
    return await data_fetcher_service.get_recent_companies(
        limit=recent_limit,
        data_source=data_source,
        db=db,
    )
|