Fundamental_Analysis/backend/app/routers/financial.py

250 lines
9.2 KiB
Python

"""
API router for financial data (Tushare for China market)
"""
import json
import os
import time
from datetime import datetime, timezone
from typing import Dict, List
from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import StreamingResponse
import os
from app.core.config import settings
from app.schemas.financial import BatchFinancialDataResponse, FinancialConfigResponse, FinancialMeta, StepRecord, CompanyProfileResponse
from app.services.tushare_client import TushareClient
from app.services.company_profile_client import CompanyProfileClient
router = APIRouter()

# The JSON config lives in <repo root>/config, not under backend/.
# Walk up from this file: financial.py -> routers/ -> app/ -> backend/ -> repo root.
REPO_ROOT = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
FINANCIAL_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "financial-tushare.json")
BASE_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "config.json")
def _load_json(path: str) -> Dict:
if not os.path.exists(path):
return {}
try:
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception:
return {}
@router.get("/config", response_model=FinancialConfigResponse)
async def get_financial_config():
    """Expose the ``api_groups`` section of config/financial-tushare.json.

    Falls back to an empty mapping when the file is missing or has no
    ``api_groups`` key.
    """
    config = _load_json(FINANCIAL_CONFIG_PATH)
    return FinancialConfigResponse(api_groups=config.get("api_groups", {}))
# Groups whose rows are keyed by report period; all other groups use trade_date.
_REPORT_DATE_GROUPS = frozenset({"fina_indicator", "income", "balancesheet", "cashflow"})


def _elapsed_ms(start_ns: int) -> int:
    """Return whole milliseconds elapsed since ``start_ns`` (a perf_counter_ns value)."""
    return int((time.perf_counter_ns() - start_ns) / 1_000_000)


def _finish_step(step: StepRecord, step_started_ns: int, status: str, error: str = "") -> None:
    """Close out *step*: set its status, optional error, end time and own duration."""
    step.status = status
    if error:
        step.error = error
    step.end_ts = datetime.now(timezone.utc).isoformat()
    step.duration_ms = _elapsed_ms(step_started_ns)


def _latest_row_per_year(rows, date_field: str) -> Dict[str, Dict]:
    """Map year (first four characters of ``row[date_field]``) to the latest row of that year.

    Rows whose date is missing or empty are ignored. "Latest" is decided by
    comparing the date values as strings, which orders correctly for
    fixed-width dates (presumably ``YYYYMMDD`` from Tushare -- confirm).
    On ties the first row seen wins.
    """
    latest: Dict[str, Dict] = {}
    for row in rows:
        date_val = row.get(date_field)
        if not date_val:
            continue
        year = str(date_val)[:4]
        current = latest.get(year)
        if current is None or str(date_val) > str(current.get(date_field)):
            latest[year] = row
    return latest


@router.get("/china/{ts_code}", response_model=BatchFinancialDataResponse)
async def get_china_financials(
    ts_code: str,
    years: int = Query(5, ge=1, le=15),
):
    """Fetch yearly Tushare series for *ts_code*, one series per configured metric.

    For every group in ``api_groups`` of config/financial-tushare.json, the
    group's Tushare API is queried once. The latest row per year is kept,
    judged by ``end_date`` for report groups and ``trade_date`` otherwise.
    Each metric's ``tushareParam`` column is then recorded. Each series is
    trimmed to the ``years`` most recent years and returned in ascending
    year order. ``meta.steps`` records one step per group with that step's
    own start time and duration.

    Raises:
        HTTPException(500): no Tushare token is configured.
        HTTPException(502): no group produced any data; ``detail["errors"]``
            maps group names to the messages of queries that raised.
    """
    import logging

    logger = logging.getLogger(__name__)

    # Token resolution order: environment, app settings, then config/config.json.
    base_cfg = _load_json(BASE_CONFIG_PATH)
    token = (
        os.environ.get("TUSHARE_TOKEN")
        or settings.TUSHARE_TOKEN
        or base_cfg.get("data_sources", {}).get("tushare", {}).get("api_key")
    )
    if not token:
        raise HTTPException(status_code=500, detail="Tushare API token not configured. Set TUSHARE_TOKEN env or config/config.json data_sources.tushare.api_key")

    fin_cfg = _load_json(FINANCIAL_CONFIG_PATH)
    api_groups: Dict[str, List[Dict]] = fin_cfg.get("api_groups", {})
    client = TushareClient(token=token)

    # Meta tracking
    started_real = datetime.now(timezone.utc)
    started = time.perf_counter_ns()
    api_calls_total = 0
    api_calls_by_group: Dict[str, int] = {}
    steps: List[StepRecord] = []

    # The company name is optional in the response, so a failed lookup is only logged.
    company_name = None
    try:
        basic_data = await client.query(api_name="stock_basic", params={"ts_code": ts_code}, fields="ts_code,name")
        api_calls_total += 1
        if basic_data:
            company_name = basic_data[0].get("name")
    except Exception as e:
        logger.warning("Failed to fetch company name for %s: %s", ts_code, e)

    # metric key -> {year: value}; later groups overwrite earlier ones for the same key/year.
    series_by_year: Dict[str, Dict[str, object]] = {}
    errors: Dict[str, str] = {}

    for group_name, metrics in api_groups.items():
        step_started = time.perf_counter_ns()
        step = StepRecord(
            name=f"拉取 {group_name}",
            start_ts=datetime.now(timezone.utc).isoformat(),
            status="running",
        )
        steps.append(step)

        # Nothing to fetch for this group: close the step instead of leaving it "running".
        if not metrics or not any(m.get("tushareParam") for m in metrics):
            _finish_step(step, step_started, "done")
            continue

        api_name = metrics[0].get("api") or group_name
        date_field = "end_date" if group_name in _REPORT_DATE_GROUPS else "trade_date"
        try:
            # fields=None: request all columns so date_field is present alongside metrics.
            data_rows = await client.query(api_name=api_name, params={"ts_code": ts_code, "limit": 5000}, fields=None)
        except Exception as e:
            _finish_step(step, step_started, "error", str(e))
            errors[group_name] = str(e)
            continue
        api_calls_total += 1
        api_calls_by_group[group_name] = api_calls_by_group.get(group_name, 0) + 1

        latest_rows = _latest_row_per_year(data_rows or [], date_field)
        if latest_rows:
            for metric in metrics:
                key = metric.get("tushareParam")
                if not key:
                    continue
                values = series_by_year.setdefault(key, {})
                for year, row in latest_rows.items():
                    values[year] = row.get(key)
        _finish_step(step, step_started, "done")

    finished_real = datetime.now(timezone.utc)
    elapsed_ms = _elapsed_ms(started)
    if not series_by_year:
        # If nothing succeeded, expose partial error info
        raise HTTPException(status_code=502, detail={"message": "No data returned from Tushare", "errors": errors})

    # Keep the `years` most recent years per metric, returned in ascending year order.
    series: Dict[str, List[Dict]] = {}
    for key, by_year in series_by_year.items():
        recent_years = sorted(by_year, reverse=True)[:years]
        series[key] = [{"year": y, "value": by_year[y]} for y in sorted(recent_years)]

    meta = FinancialMeta(
        started_at=started_real.isoformat(),
        finished_at=finished_real.isoformat(),
        elapsed_ms=elapsed_ms,
        api_calls_total=api_calls_total,
        api_calls_by_group=api_calls_by_group,
        current_action=None,
        steps=steps,
    )
    return BatchFinancialDataResponse(ts_code=ts_code, name=company_name, series=series, meta=meta)
@router.get("/china/{ts_code}/company-profile", response_model=CompanyProfileResponse)
async def get_company_profile(
    ts_code: str,
    company_name: str = Query(None, description="Company name for better context"),
):
    """
    Get company profile for a company using Gemini AI (non-streaming, single response)

    The Gemini key is read from config/config.json ``llm.gemini.api_key``.
    When ``company_name`` is not supplied (or is empty), it is looked up via
    Tushare ``stock_basic``. Any lookup failure, missing token, or
    null/empty name falls back to ``ts_code``.

    Raises:
        HTTPException(500): the Gemini API key is not configured.
    """
    import logging

    logger = logging.getLogger(__name__)
    logger.info("[API] Company profile requested for %s", ts_code)

    # Loaded once; reused for both the Gemini key and the Tushare token fallback.
    base_cfg = _load_json(BASE_CONFIG_PATH)
    api_key = base_cfg.get("llm", {}).get("gemini", {}).get("api_key")
    if not api_key:
        logger.error("[API] Gemini API key not configured")
        raise HTTPException(
            status_code=500,
            detail="Gemini API key not configured. Set config.json llm.gemini.api_key",
        )
    client = CompanyProfileClient(api_key=api_key)

    if not company_name:
        logger.info("[API] Fetching company name for %s", ts_code)
        fetched_name = None
        try:
            token = (
                os.environ.get("TUSHARE_TOKEN")
                or settings.TUSHARE_TOKEN
                or base_cfg.get("data_sources", {}).get("tushare", {}).get("api_key")
            )
            if token:
                tushare_client = TushareClient(token=token)
                basic_data = await tushare_client.query(
                    api_name="stock_basic", params={"ts_code": ts_code}, fields="ts_code,name"
                )
                if basic_data:
                    fetched_name = basic_data[0].get("name")
                    logger.info("[API] Got company name: %s", fetched_name)
        except Exception as e:
            # Best effort: fall back to ts_code below.
            logger.warning("[API] Failed to get company name: %s", e)
        # `or` also covers a name field that is present but null/empty.
        company_name = fetched_name or ts_code

    logger.info("[API] Generating profile for %s", company_name)
    # Generate profile using non-streaming API
    result = await client.generate_profile(
        company_name=company_name,
        ts_code=ts_code,
        financial_data=None,
    )
    logger.info("[API] Profile generation completed, success=%s", result.get("success"))

    return CompanyProfileResponse(
        ts_code=ts_code,
        company_name=company_name,
        content=result.get("content", ""),
        model=result.get("model", "gemini-2.5-flash"),
        tokens=result.get("tokens", {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}),
        elapsed_ms=result.get("elapsed_ms", 0),
        success=result.get("success", False),
        error=result.get("error"),
    )