FA3-Datafetch/backend/app/api/data_routes.py

344 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Data management API routes.
Handles checking, fetching, and reading financial data.
"""
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List
from app.database import get_db
from app.schemas import (
DataCheckRequest,
DataCheckResponse,
FetchDataRequest,
FetchDataResponse,
DataUpdateResponse,
FinancialDataResponse,
DataSourceListResponse,
DataSourceInfo
)
from app.services import data_fetcher_service
from app.models import DataUpdate
from sqlalchemy import select
# Router collecting the data-management endpoints defined in this module;
# every route registered on it is served under the "/data" prefix and is
# tagged "data".
router = APIRouter(prefix="/data", tags=["data"])
@router.get("/check", response_model=DataCheckResponse)
async def check_data_status(
    market: str,
    symbol: str,
    data_source: str,
    db: AsyncSession = Depends(get_db)
):
    """
    Check the data status of a company.

    Asks the data fetcher service whether the database already holds data
    for the given company and data source; when it does, the response
    carries the time of the latest update and the covered data range.
    """
    # All of the work is delegated to the service layer.
    status_report = await data_fetcher_service.check_data_status(
        market=market,
        symbol=symbol,
        data_source=data_source,
        db=db,
    )
    return status_report
@router.post("/fetch", response_model=FetchDataResponse)
async def fetch_data(
    request: FetchDataRequest,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db)
):
    """
    Fetch or refresh financial data for a company.

    Makes sure a company record and a fresh data-update record exist,
    schedules the actual download as a background task, and answers
    immediately with the update id and an 'in_progress' status.
    """
    # Step 1: create the company record, or reuse the existing one.
    company_record = await data_fetcher_service.create_or_get_company(
        market=request.market,
        symbol=request.symbol,
        company_name=request.company_name,
        db=db,
    )

    # Step 2: register a new data-update record for this run.
    update_record = await data_fetcher_service.create_data_update_record(
        company_id=company_record.id,
        data_source=request.data_source,
        update_type='full',
        db=db,
    )

    # Step 3: hand the long-running fetch over to the background task runner.
    task_kwargs = {
        "company_id": company_record.id,
        "market": request.market,
        "symbol": request.symbol,
        "data_source": request.data_source,
        "update_id": update_record.id,
        "currency": request.currency,
        "frequency": request.frequency,
    }
    background_tasks.add_task(fetch_data_background, **task_kwargs)

    response = FetchDataResponse(
        update_id=update_record.id,
        data_source=request.data_source,
        status='in_progress',
        message=f'正在从 {request.data_source} 获取数据,请稍候...',
    )
    return response
def _open_sync_db_connection():
    """Open a synchronous psycopg2 connection configured from DB_* env vars."""
    import os
    import psycopg2

    # NOTE(review): the fallback host/user/password below are hard-coded
    # credentials kept only so existing deployments keep working; they should
    # be moved into configuration and removed from source control.
    return psycopg2.connect(
        host=os.getenv('DB_HOST', '192.168.3.195'),
        user=os.getenv('DB_USER', 'value'),
        password=os.getenv('DB_PASSWORD', 'Value609!'),
        dbname=os.getenv('DB_NAME', 'fa3'),
        port=os.getenv('DB_PORT', '5432')
    )


def _run_sync_statement(sql, params):
    """Execute one write statement on a fresh connection and always close it."""
    conn = _open_sync_db_connection()
    try:
        cur = conn.cursor()
        try:
            cur.execute(sql, params)
        finally:
            cur.close()
        conn.commit()
    finally:
        # Closing without a commit discards the open transaction, so a failed
        # statement never leaves a dangling connection behind.
        conn.close()


def fetch_data_background(
    company_id: int,
    market: str,
    symbol: str,
    data_source: str,
    update_id: int,
    currency: str = None,
    frequency: str = "Annual"
):
    """
    Background data-fetch task, executed fully synchronously to avoid
    event-loop conflicts.

    Runs the synchronous fetch through the data fetcher service and writes
    the outcome onto the ``data_updates`` row identified by ``update_id``.
    If anything fails, the row is marked as 'failed' with the error text on
    a best-effort basis; the function never re-raises.

    Args:
        company_id: Id of the company the data belongs to.
        market: Market identifier forwarded to the fetcher.
        symbol: Company symbol forwarded to the fetcher.
        data_source: Name of the data source to fetch from.
        update_id: Id of the ``data_updates`` row to update.
        currency: Optional currency forwarded to the fetcher.
        frequency: Reporting frequency forwarded to the fetcher.
    """
    import json
    import sys
    import traceback
    from datetime import datetime
    from pathlib import Path

    # Make sure the project root is on the Python path.
    project_root = Path(__file__).parent.parent.parent.parent
    if str(project_root) not in sys.path:
        sys.path.insert(0, str(project_root))

    try:
        # Run the (synchronous) data fetch.
        result = data_fetcher_service.fetch_financial_data_sync(
            company_id=company_id,
            market=market,
            symbol=symbol,
            data_source=data_source,
            update_id=update_id,
            currency=currency,
            frequency=frequency
        )

        # Persist the outcome through a synchronous psycopg2 connection.
        update_sql = """
            UPDATE data_updates
            SET status = %s,
                completed_at = %s,
                error_message = %s,
                data_start_date = %s,
                data_end_date = %s,
                fetched_tables = %s,
                row_counts = %s
            WHERE id = %s
        """
        _run_sync_statement(
            update_sql,
            (
                result['status'],
                # Only a completed run gets a completion timestamp.
                datetime.now() if result['status'] == 'completed' else None,
                result.get('error_message'),
                result.get('data_start_date'),
                result.get('data_end_date'),
                json.dumps(result.get('fetched_tables')),
                json.dumps(result.get('row_counts')),
                update_id
            )
        )
    except Exception as e:
        print(f"❌ 后台任务执行错误: {e}")
        traceback.print_exc()
        # Best effort: try to record the failure on the update row.
        try:
            _run_sync_statement(
                "UPDATE data_updates SET status = %s, error_message = %s, completed_at = %s WHERE id = %s",
                ('failed', str(e), datetime.now(), update_id)
            )
        except Exception as update_error:
            # Still best-effort, but no longer silent: report why the
            # failure status could not be stored.
            print(f"❌ 更新失败状态时出错: {update_error}")
def _parse_json_field(raw, expected_type):
    """
    Decode a column that may hold a JSON string into ``expected_type``.

    Empty values, undecodable strings, and decoded values of the wrong type
    (for example the JSON literal ``null``) all yield an empty
    ``expected_type()``. Non-string values are passed through unchanged.
    """
    import json

    if not raw:
        return expected_type()
    if not isinstance(raw, str):
        return raw
    try:
        parsed = json.loads(raw)
    except ValueError:
        return expected_type()
    return parsed if isinstance(parsed, expected_type) else expected_type()


@router.get("/status/{update_id}", response_model=DataUpdateResponse)
async def get_fetch_status(
    update_id: int,
    db: AsyncSession = Depends(get_db)
):
    """
    Query the status of a data-fetch task.

    Looks up the data-update record with the given ``update_id`` and returns
    its current state.

    Raises:
        HTTPException: 404 when no record with that id exists.
    """
    result = await db.execute(
        select(DataUpdate).where(DataUpdate.id == update_id)
    )
    data_update = result.scalar_one_or_none()
    if not data_update:
        raise HTTPException(status_code=404, detail="数据更新记录不存在")

    # fetched_tables and row_counts are stored as TEXT holding JSON strings,
    # so they are decoded by hand before the response model is built;
    # otherwise Pydantic validation would fail.
    fetched_tables = _parse_json_field(data_update.fetched_tables, list)
    row_counts = _parse_json_field(data_update.row_counts, dict)

    # Build the response object explicitly.
    return DataUpdateResponse(
        id=data_update.id,
        company_id=data_update.company_id,
        data_source=data_update.data_source,
        update_type=data_update.update_type,
        status=data_update.status,
        progress_message=data_update.progress_message,
        progress_percentage=data_update.progress_percentage,
        started_at=data_update.started_at,
        completed_at=data_update.completed_at,
        error_message=data_update.error_message,
        data_start_date=data_update.data_start_date,
        data_end_date=data_update.data_end_date,
        fetched_tables=fetched_tables,  # decoded list
        row_counts=row_counts  # decoded dict
    )
@router.get("/financial", response_model=FinancialDataResponse)
async def get_financial_data(
    company_id: int,
    data_source: str,
    frequency: str = "Annual",
    db: AsyncSession = Depends(get_db)
):
    """
    Read financial data.

    Loads the financial data stored for the given company and data source
    from the database and replaces NaN/Infinity floats with ``None``.

    Raises:
        HTTPException: re-raised unchanged when the service layer raises one;
            any other error is reported as a 500 carrying the error text.
    """
    import math
    import traceback

    def clean_nan(obj):
        """Recursively replace NaN/Infinity floats inside dicts and lists with None."""
        if isinstance(obj, float):
            if math.isnan(obj) or math.isinf(obj):
                return None
            return obj
        elif isinstance(obj, dict):
            return {k: clean_nan(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [clean_nan(v) for v in obj]
        return obj

    try:
        data = await data_fetcher_service.get_financial_data_from_db(
            company_id=company_id,
            data_source=data_source,
            frequency=frequency,
            db=db
        )
        # NaN/Infinity values would make JSON serialization fail.
        return clean_nan(data)
    except HTTPException:
        # Keep deliberate HTTP errors (status code and detail) intact instead
        # of masking them as a generic 500.
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/sources", response_model=DataSourceListResponse)
async def get_available_sources(market: str):
    """
    List the available data sources.

    Returns every data source the data fetcher service reports for the
    given market.
    """
    source_infos = []
    for entry in data_fetcher_service.get_available_data_sources(market):
        # Each entry supplies the keyword arguments for one DataSourceInfo.
        source_infos.append(DataSourceInfo(**entry))
    return DataSourceListResponse(market=market, sources=source_infos)
@router.get("/recent", response_model=List[dict])
async def get_recent_companies(
    data_source: str = None,
    db: AsyncSession = Depends(get_db)
):
    """Return the list of most recently updated companies.

    When ``data_source`` is not given, recent updates from all data sources
    are returned. The service is asked for at most 20 entries.
    """
    recent_companies = await data_fetcher_service.get_recent_companies(
        data_source=data_source,
        db=db,
        limit=20,
    )
    return recent_companies
@router.get("/search_local", response_model=List[dict])
async def search_local_companies(
    q: str,
    db: AsyncSession = Depends(get_db)
):
    """
    Search the companies stored in the local database.

    Forwards the query string ``q`` to the data fetcher service, asking for
    at most 10 matches.
    """
    matches = await data_fetcher_service.search_companies(
        query=q,
        db=db,
        limit=10,
    )
    return matches