250 lines
9.2 KiB
Python
250 lines
9.2 KiB
Python
"""
|
|
API router for financial data (Tushare for China market)
|
|
"""
|
|
import json
|
|
import os
|
|
import time
|
|
from datetime import datetime, timezone
|
|
from typing import Dict, List
|
|
|
|
from fastapi import APIRouter, HTTPException, Query
|
|
from fastapi.responses import StreamingResponse
|
|
import os
|
|
|
|
from app.core.config import settings
|
|
from app.schemas.financial import BatchFinancialDataResponse, FinancialConfigResponse, FinancialMeta, StepRecord, CompanyProfileResponse
|
|
from app.services.tushare_client import TushareClient
|
|
from app.services.company_profile_client import CompanyProfileClient
|
|
|
|
router = APIRouter()

# Config files live under the repository root, not under backend/.
# This module's directory is routers/ -> app/ -> backend/, so the repo root
# is three parent-directory hops above it.
_MODULE_DIR = os.path.dirname(__file__)
REPO_ROOT = os.path.abspath(os.path.join(_MODULE_DIR, os.pardir, os.pardir, os.pardir))

_CONFIG_DIR = os.path.join(REPO_ROOT, "config")
# Metric definitions (api_groups) for the Tushare financial endpoints.
FINANCIAL_CONFIG_PATH = os.path.join(_CONFIG_DIR, "financial-tushare.json")
# Shared config holding data-source and LLM credentials.
BASE_CONFIG_PATH = os.path.join(_CONFIG_DIR, "config.json")
|
|
|
|
|
|
def _load_json(path: str) -> Dict:
|
|
if not os.path.exists(path):
|
|
return {}
|
|
try:
|
|
with open(path, "r", encoding="utf-8") as f:
|
|
return json.load(f)
|
|
except Exception:
|
|
return {}
|
|
|
|
|
|
@router.get("/config", response_model=FinancialConfigResponse)
async def get_financial_config():
    """Return the ``api_groups`` section of the Tushare financial metric config.

    An absent or unreadable config file results in an empty mapping.
    """
    groups = _load_json(FINANCIAL_CONFIG_PATH).get("api_groups", {})
    return FinancialConfigResponse(api_groups=groups)
|
|
|
|
|
|
# Tushare financial-statement APIs whose rows are keyed by reporting-period
# end date; every other API is assumed to be keyed by trade_date.
_PERIOD_DATE_APIS = frozenset({"fina_indicator", "income", "balancesheet", "cashflow"})


def _latest_row_per_year(rows, date_field: str) -> Dict[str, Dict]:
    """Map year (first 4 chars of ``row[date_field]``) to the row with the greatest date in that year."""
    latest: Dict[str, Dict] = {}
    for row in rows or []:
        date_val = row.get(date_field)
        if not date_val:
            continue
        date_str = str(date_val)
        year = date_str[:4]
        existing = latest.get(year)
        # Date strings are compared lexicographically (e.g. "20231231" > "20230930").
        if existing is None or date_str > str(existing.get(date_field)):
            latest[year] = row
    return latest


def _finish_step(step: StepRecord, step_started_ns: int, status: str, error: str = None) -> None:
    """Close ``step`` with ``status``, its end timestamp and its own elapsed time."""
    step.status = status
    if error is not None:
        step.error = error
    step.end_ts = datetime.now(timezone.utc).isoformat()
    step.duration_ms = int((time.perf_counter_ns() - step_started_ns) / 1_000_000)


@router.get("/china/{ts_code}", response_model=BatchFinancialDataResponse)
async def get_china_financials(
    ts_code: str,
    years: int = Query(5, ge=1, le=15),
):
    """Fetch yearly financial metric series for a China-market stock from Tushare.

    For every API group in the financial metric config, the group's Tushare
    API is queried once. For each year, the row with the latest date is kept.
    Each configured metric (``tushareParam``) then becomes a list of
    ``{"year", "value"}`` points. The list is limited to the most recent
    ``years`` years and returned in ascending year order.

    Args:
        ts_code: Tushare security code, passed to every query.
        years: Number of most recent years to keep per metric (1-15).

    Returns:
        BatchFinancialDataResponse with the series, the company name
        (best-effort, may be None) and per-step timing metadata.

    Raises:
        HTTPException(500): no Tushare token is configured.
        HTTPException(502): no group produced any data; ``detail["errors"]``
            maps group names to the error raised by their query.
    """
    # Token precedence: environment, then settings, then config/config.json.
    base_cfg = _load_json(BASE_CONFIG_PATH)
    token = (
        os.environ.get("TUSHARE_TOKEN")
        or settings.TUSHARE_TOKEN
        or base_cfg.get("data_sources", {}).get("tushare", {}).get("api_key")
    )
    if not token:
        raise HTTPException(status_code=500, detail="Tushare API token not configured. Set TUSHARE_TOKEN env or config/config.json data_sources.tushare.api_key")

    fin_cfg = _load_json(FINANCIAL_CONFIG_PATH)
    api_groups: Dict[str, List[Dict]] = fin_cfg.get("api_groups", {})

    client = TushareClient(token=token)

    # Meta tracking
    started_real = datetime.now(timezone.utc)
    started = time.perf_counter_ns()
    api_calls_total = 0
    api_calls_by_group: Dict[str, int] = {}
    steps: List[StepRecord] = []

    # Company name is best-effort: failing to fetch it must not block the data.
    company_name = None
    try:
        basic_data = await client.query(api_name="stock_basic", params={"ts_code": ts_code}, fields="ts_code,name")
        api_calls_total += 1
        if basic_data:
            company_name = basic_data[0].get("name")
    except Exception:
        pass

    # metric key -> {year -> value}. A later group overwrites an earlier one
    # for the same key and year.
    collected: Dict[str, Dict] = {}
    errors: Dict[str, str] = {}

    for group_name, metrics in api_groups.items():
        step_started = time.perf_counter_ns()
        step = StepRecord(
            name=f"拉取 {group_name}",
            start_ts=datetime.now(timezone.utc).isoformat(),
            status="running",
        )
        steps.append(step)

        metric_keys = [m.get("tushareParam") for m in metrics or [] if m.get("tushareParam")]
        if not metric_keys:
            # Nothing to fetch; close the step instead of leaving it "running".
            _finish_step(step, step_started, "done")
            continue

        api_name = metrics[0].get("api") or group_name
        is_period_api = api_name in _PERIOD_DATE_APIS or group_name in _PERIOD_DATE_APIS
        date_field = "end_date" if is_period_api else "trade_date"

        try:
            data_rows = await client.query(api_name=api_name, params={"ts_code": ts_code, "limit": 5000}, fields=None)
            api_calls_total += 1
            api_calls_by_group[group_name] = api_calls_by_group.get(group_name, 0) + 1
        except Exception as e:
            _finish_step(step, step_started, "error", str(e))
            errors[group_name] = str(e)
            continue

        # NOTE(review): the newest year may hold an interim (non-annual) report
        # if the annual one is not yet published — confirm this is intended.
        latest_by_year = _latest_row_per_year(data_rows, date_field)
        for key in metric_keys:
            for year, row in latest_by_year.items():
                collected.setdefault(key, {})[year] = row.get(key)
        _finish_step(step, step_started, "done")

    finished_real = datetime.now(timezone.utc)
    elapsed_ms = int((time.perf_counter_ns() - started) / 1_000_000)

    if not collected:
        # If nothing succeeded, expose partial error info
        raise HTTPException(status_code=502, detail={"message": "No data returned from Tushare", "errors": errors})

    # Keep the most recent `years` years per metric, returned oldest-first.
    series: Dict[str, List[Dict]] = {}
    for key, per_year in collected.items():
        recent_years = sorted(per_year, reverse=True)[:years]
        series[key] = [{"year": year, "value": per_year[year]} for year in sorted(recent_years)]

    meta = FinancialMeta(
        started_at=started_real.isoformat(),
        finished_at=finished_real.isoformat(),
        elapsed_ms=elapsed_ms,
        api_calls_total=api_calls_total,
        api_calls_by_group=api_calls_by_group,
        current_action=None,
        steps=steps,
    )

    return BatchFinancialDataResponse(ts_code=ts_code, name=company_name, series=series, meta=meta)
|
|
|
|
|
|
@router.get("/china/{ts_code}/company-profile", response_model=CompanyProfileResponse)
async def get_company_profile(
    ts_code: str,
    company_name: str = Query(None, description="Company name for better context"),
):
    """
    Get company profile for a company using Gemini AI (non-streaming, single response)

    Args:
        ts_code: Tushare security code of the company.
        company_name: Optional display name. If omitted, it is looked up via
            Tushare ``stock_basic`` (best-effort), falling back to ``ts_code``.

    Returns:
        CompanyProfileResponse built from the result of
        ``CompanyProfileClient.generate_profile``. Missing result fields get
        defaults.

    Raises:
        HTTPException(500): ``llm.gemini.api_key`` is not set in config.json.
    """
    import logging
    logger = logging.getLogger(__name__)

    logger.info("[API] Company profile requested for %s", ts_code)

    # Loaded once; reused for both the Gemini key and the Tushare token fallback.
    base_cfg = _load_json(BASE_CONFIG_PATH)
    api_key = base_cfg.get("llm", {}).get("gemini", {}).get("api_key")
    if not api_key:
        logger.error("[API] Gemini API key not configured")
        raise HTTPException(
            status_code=500,
            detail="Gemini API key not configured. Set config.json llm.gemini.api_key"
        )

    client = CompanyProfileClient(api_key=api_key)

    if not company_name:
        logger.info("[API] Fetching company name for %s", ts_code)
        # Best-effort lookup via stock_basic; any failure falls back to ts_code.
        company_name = ts_code
        try:
            token = (
                os.environ.get("TUSHARE_TOKEN")
                or settings.TUSHARE_TOKEN
                or base_cfg.get("data_sources", {}).get("tushare", {}).get("api_key")
            )
            if token:
                tushare_client = TushareClient(token=token)
                basic_data = await tushare_client.query(api_name="stock_basic", params={"ts_code": ts_code}, fields="ts_code,name")
                if basic_data:
                    # `or` also covers a present-but-empty/None name field.
                    company_name = basic_data[0].get("name") or ts_code
                    logger.info("[API] Got company name: %s", company_name)
        except Exception as e:
            logger.warning("[API] Failed to get company name: %s", e)
            company_name = ts_code

    logger.info("[API] Generating profile for %s", company_name)

    # Generate profile using non-streaming API
    result = await client.generate_profile(
        company_name=company_name,
        ts_code=ts_code,
        financial_data=None
    )

    logger.info("[API] Profile generation completed, success=%s", result.get('success'))

    return CompanyProfileResponse(
        ts_code=ts_code,
        company_name=company_name,
        content=result.get("content", ""),
        model=result.get("model", "gemini-2.5-flash"),
        tokens=result.get("tokens", {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}),
        elapsed_ms=result.get("elapsed_ms", 0),
        success=result.get("success", False),
        error=result.get("error")
    )
|