"""
|
||
API router for financial data (Tushare for China market)
|
||
"""
|
||
import json
|
||
import os
|
||
import time
|
||
from datetime import datetime, timezone, timedelta
|
||
from typing import Dict, List
|
||
|
||
from fastapi import APIRouter, HTTPException, Query
|
||
from fastapi.responses import StreamingResponse
|
||
import os
|
||
|
||
from app.core.config import settings
|
||
from app.schemas.financial import (
|
||
BatchFinancialDataResponse,
|
||
FinancialConfigResponse,
|
||
FinancialMeta,
|
||
StepRecord,
|
||
CompanyProfileResponse,
|
||
AnalysisResponse,
|
||
AnalysisConfigResponse
|
||
)
|
||
from app.services.tushare_client import TushareClient
|
||
from app.services.company_profile_client import CompanyProfileClient
|
||
from app.services.analysis_client import AnalysisClient, load_analysis_config, get_analysis_config
|
||
|
||
router = APIRouter()
|
||
|
||
# Load metric config from file (project root is repo root, not backend/)
|
||
# routers/ -> app/ -> backend/ -> repo root
|
||
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||
FINANCIAL_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "financial-tushare.json")
|
||
BASE_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "config.json")
|
||
ANALYSIS_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "analysis-config.json")


def _load_json(path: str) -> Dict:
    """Best-effort JSON loader: return {} if the file is missing or unreadable."""
    if not os.path.exists(path):
        return {}
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except Exception:
        return {}


@router.get("/config", response_model=FinancialConfigResponse)
async def get_financial_config():
    data = _load_json(FINANCIAL_CONFIG_PATH)
    api_groups = data.get("api_groups", {})
    return FinancialConfigResponse(api_groups=api_groups)
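
# Illustrative sketch only (not the authoritative schema): the loader above and the
# batch endpoint below assume config/financial-tushare.json looks roughly like this,
# with each metric naming its Tushare field in "tushareParam" and, optionally, its
# source API in "api". The group and field names in this example are made up.
#
#   {
#     "api_groups": {
#       "fina_indicator": [
#         {"tushareParam": "roe", "api": "fina_indicator"},
#         {"tushareParam": "grossprofit_margin"}
#       ]
#     }
#   }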


@router.get("/china/{ts_code}", response_model=BatchFinancialDataResponse)
async def get_china_financials(
    ts_code: str,
    years: int = Query(5, ge=1, le=15),
):
    # Load the Tushare token: env var first, then settings, then config.json
    base_cfg = _load_json(BASE_CONFIG_PATH)
    token = (
        os.environ.get("TUSHARE_TOKEN")
        or settings.TUSHARE_TOKEN
        or base_cfg.get("data_sources", {}).get("tushare", {}).get("api_key")
    )
    if not token:
        raise HTTPException(
            status_code=500,
            detail="Tushare API token not configured. Set the TUSHARE_TOKEN env var or config/config.json data_sources.tushare.api_key",
        )

    # Load metric config
    fin_cfg = _load_json(FINANCIAL_CONFIG_PATH)
    api_groups: Dict[str, List[Dict]] = fin_cfg.get("api_groups", {})

    client = TushareClient(token=token)

    # Meta tracking
    started_real = datetime.now(timezone.utc)
    started = time.perf_counter_ns()
    api_calls_total = 0
    api_calls_by_group: Dict[str, int] = {}
    steps: List[StepRecord] = []
    current_action = "初始化"

    # Get the company name from the stock_basic API
    company_name = None
    try:
        basic_data = await client.query(api_name="stock_basic", params={"ts_code": ts_code}, fields="ts_code,name")
        api_calls_total += 1
        if basic_data and len(basic_data) > 0:
            company_name = basic_data[0].get("name")
    except Exception:
        # If getting the company name fails, continue without it
        pass

    # Collect series per metric key
    series: Dict[str, List[Dict]] = {}

    # Helper to store year-value pairs while keeping the most recent value per year
    def _merge_year_value(key: str, year: str, value, month: int = None):
        arr = series.setdefault(key, [])
        # Upsert by year
        for item in arr:
            if item["year"] == year:
                item["value"] = value
                if month is not None:
                    item["month"] = month
                return
        arr.append({"year": year, "value": value, "month": month})
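
    # Illustrative only: calling _merge_year_value("roe", "2023", 12.1) and then
    # _merge_year_value("roe", "2023", 12.4) leaves a single entry for 2023 holding the
    # later value, so each metric key keeps at most one point per year.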

    # Query each API group we care about
    errors: Dict[str, str] = {}
    for group_name, metrics in api_groups.items():
        step_started_real = datetime.now(timezone.utc)
        step_started = time.perf_counter_ns()
        step = StepRecord(
            name=f"拉取 {group_name}",
            start_ts=step_started_real.isoformat(),
            status="running",
        )
        steps.append(step)
        current_action = step.name
        if not metrics:
            continue

        # Group metrics by API (handles an "unknown" group that mixes several different APIs)
        api_groups_dict: Dict[str, List[Dict]] = {}
        for metric in metrics:
            api = metric.get("api") or group_name
            if api:  # skip empty API names
                if api not in api_groups_dict:
                    api_groups_dict[api] = []
                api_groups_dict[api].append(metric)

        # Process each API separately
        for api_name, api_metrics in api_groups_dict.items():
            fields = [m.get("tushareParam") for m in api_metrics if m.get("tushareParam")]
            if not fields:
                continue

            date_field = "end_date" if group_name in ("fina_indicator", "income", "balancesheet", "cashflow") else "trade_date"

            # Build API parameters
            params = {"ts_code": ts_code, "limit": 5000}

            # APIs that need an explicit date range (e.g. stk_holdernumber) get date parameters
            if api_name == "stk_holdernumber":
                # Date range: from `years` years ago until today
                end_date = datetime.now().strftime("%Y%m%d")
                start_date = (datetime.now() - timedelta(days=years * 365)).strftime("%Y%m%d")
                params["start_date"] = start_date
                params["end_date"] = end_date
                # stk_holdernumber usually reports its date in end_date
                date_field = "end_date"

            # Non-time-series APIs (e.g. stock_company) are treated as static data
            is_static_data = api_name == "stock_company"

            # Build the fields string: the date field plus all requested metric fields.
            # The date field must be present because it determines the year.
            fields_list = list(fields)
            if date_field not in fields_list:
                fields_list.insert(0, date_field)
            # APIs such as fina_indicator usually also need ts_code and ann_date
            if api_name in ("fina_indicator", "income", "balancesheet", "cashflow"):
                for req_field in ["ts_code", "ann_date"]:
                    if req_field not in fields_list:
                        fields_list.insert(0, req_field)
            fields_str = ",".join(fields_list)
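
            # Illustrative only (the actual field list comes from config/financial-tushare.json):
            # for the fina_indicator group this might end up issuing something like
            #   client.query(api_name="fina_indicator",
            #                params={"ts_code": "600519.SH", "limit": 5000},
            #                fields="ann_date,ts_code,end_date,roe,grossprofit_margin")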

            try:
                data_rows = await client.query(api_name=api_name, params=params, fields=fields_str)
                api_calls_total += 1
                api_calls_by_group[group_name] = api_calls_by_group.get(group_name, 0) + 1
            except Exception as e:
                # Record the error but keep processing the other APIs
                error_key = f"{group_name}_{api_name}"
                errors[error_key] = str(e)
                continue

            tmp: Dict[str, Dict] = {}
            current_year = datetime.now().strftime("%Y")

            for row in data_rows:
                if is_static_data:
                    # Static data (e.g. stock_company) is keyed under the current year.
                    # Only the first row is used; static data normally has a single row.
                    if current_year not in tmp:
                        year = current_year
                        month = None
                        tmp[year] = row
                        tmp[year]['_month'] = month
                    # Skip the remaining rows
                    continue
                else:
                    # Time-series data: keep the most recent row per year, keyed by the date field
                    date_val = row.get(date_field)
                    if not date_val:
                        continue
                    year = str(date_val)[:4]
                    month = int(str(date_val)[4:6]) if len(str(date_val)) >= 6 else None
                    existing = tmp.get(year)
                    if existing is None or str(row.get(date_field)) > str(existing.get(date_field)):
                        tmp[year] = row
                        tmp[year]['_month'] = month

            for metric in api_metrics:
                key = metric.get("tushareParam")
                if not key:
                    continue
                for year, row in tmp.items():
                    month = row.get('_month')
                    _merge_year_value(key, year, row.get(key), month)

        step.status = "done"
        step.end_ts = datetime.now(timezone.utc).isoformat()
        step.duration_ms = int((time.perf_counter_ns() - step_started) / 1_000_000)

    finished_real = datetime.now(timezone.utc)
    elapsed_ms = int((time.perf_counter_ns() - started) / 1_000_000)

    if not series:
        # If nothing succeeded, expose the partial error info
        raise HTTPException(status_code=502, detail={"message": "No data returned from Tushare", "errors": errors})

    # Truncate to the requested number of years and sort
    for key, arr in series.items():
        # Deduplicate, sort descending by year, cut to the requested years, then return ascending
        uniq = {item["year"]: item for item in arr}
        arr_sorted_desc = sorted(uniq.values(), key=lambda x: x["year"], reverse=True)
        arr_limited = arr_sorted_desc[:years]
        arr_sorted = sorted(arr_limited, key=lambda x: x["year"])  # ascending by year
        series[key] = arr_sorted
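
    # Illustrative only: after truncation each metric series is ascending by year, e.g.
    #   series["roe"] == [{"year": "2022", "value": 11.2, "month": 12},
    #                     {"year": "2023", "value": 12.4, "month": 12}]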

    meta = FinancialMeta(
        started_at=started_real.isoformat(),
        finished_at=finished_real.isoformat(),
        elapsed_ms=elapsed_ms,
        api_calls_total=api_calls_total,
        api_calls_by_group=api_calls_by_group,
        current_action=None,
        steps=steps,
    )

    return BatchFinancialDataResponse(ts_code=ts_code, name=company_name, series=series, meta=meta)
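
# Example request (illustrative; the URL prefix depends on how this router is mounted):
#   GET /china/600519.SH?years=5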


@router.get("/china/{ts_code}/company-profile", response_model=CompanyProfileResponse)
async def get_company_profile(
    ts_code: str,
    company_name: str = Query(None, description="Company name for better context"),
):
    """
    Get a company profile using Gemini AI (non-streaming, single response)
    """
    import logging
    logger = logging.getLogger(__name__)

    logger.info(f"[API] Company profile requested for {ts_code}")

    # Load config
    base_cfg = _load_json(BASE_CONFIG_PATH)
    gemini_cfg = base_cfg.get("llm", {}).get("gemini", {})
    api_key = gemini_cfg.get("api_key")

    if not api_key:
        logger.error("[API] Gemini API key not configured")
        raise HTTPException(
            status_code=500,
            detail="Gemini API key not configured. Set config.json llm.gemini.api_key"
        )

    client = CompanyProfileClient(api_key=api_key)

    # Get the company name from ts_code if not provided
    if not company_name:
        logger.info(f"[API] Fetching company name for {ts_code}")
        # Try to get it from the stock_basic API
        try:
            token = (
                os.environ.get("TUSHARE_TOKEN")
                or settings.TUSHARE_TOKEN
                or base_cfg.get("data_sources", {}).get("tushare", {}).get("api_key")
            )
            if token:
                tushare_client = TushareClient(token=token)
                basic_data = await tushare_client.query(api_name="stock_basic", params={"ts_code": ts_code}, fields="ts_code,name")
                if basic_data and len(basic_data) > 0:
                    company_name = basic_data[0].get("name", ts_code)
                    logger.info(f"[API] Got company name: {company_name}")
                else:
                    company_name = ts_code
            else:
                company_name = ts_code
        except Exception as e:
            logger.warning(f"[API] Failed to get company name: {e}")
            company_name = ts_code

    logger.info(f"[API] Generating profile for {company_name}")

    # Generate the profile using the non-streaming API
    result = await client.generate_profile(
        company_name=company_name,
        ts_code=ts_code,
        financial_data=None
    )

    logger.info(f"[API] Profile generation completed, success={result.get('success')}")

    return CompanyProfileResponse(
        ts_code=ts_code,
        company_name=company_name,
        content=result.get("content", ""),
        model=result.get("model", "gemini-2.5-flash"),
        tokens=result.get("tokens", {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}),
        elapsed_ms=result.get("elapsed_ms", 0),
        success=result.get("success", False),
        error=result.get("error")
    )
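
# Example request (illustrative): GET /china/600519.SH/company-profile
# Passing an explicit company_name query parameter skips the stock_basic lookup above.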


@router.get("/analysis-config", response_model=AnalysisConfigResponse)
async def get_analysis_config_endpoint():
    """Get analysis configuration"""
    config = load_analysis_config()
    return AnalysisConfigResponse(analysis_modules=config.get("analysis_modules", {}))


@router.put("/analysis-config", response_model=AnalysisConfigResponse)
async def update_analysis_config_endpoint(analysis_config: AnalysisConfigResponse):
    """Update analysis configuration"""
    import logging
    logger = logging.getLogger(__name__)

    try:
        # Persist to file
        config_data = {
            "analysis_modules": analysis_config.analysis_modules
        }

        with open(ANALYSIS_CONFIG_PATH, "w", encoding="utf-8") as f:
            json.dump(config_data, f, ensure_ascii=False, indent=2)

        logger.info("[API] Analysis config updated successfully")
        return AnalysisConfigResponse(analysis_modules=analysis_config.analysis_modules)
    except Exception as e:
        logger.error(f"[API] Failed to update analysis config: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to update analysis config: {str(e)}"
        )
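
# Illustrative payload for PUT /analysis-config (the module name shown is an assumption;
# the real schema lives in AnalysisConfigResponse and config/analysis-config.json):
#   {"analysis_modules": {"fundamental_analysis": {"model": "gemini-2.5-flash",
#                                                  "prompt_template": "..."}}}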


@router.get("/china/{ts_code}/analysis/{analysis_type}", response_model=AnalysisResponse)
async def generate_analysis(
    ts_code: str,
    analysis_type: str,
    company_name: str = Query(None, description="Company name for better context"),
):
    """
    Generate analysis for a company using Gemini AI

    Supported analysis types:
    - fundamental_analysis (基本面分析)
    - bull_case (看涨分析)
    - bear_case (看跌分析)
    - market_analysis (市场分析)
    - news_analysis (新闻分析)
    - trading_analysis (交易分析)
    - insider_institutional (内部人与机构动向分析)
    - final_conclusion (最终结论)
    """
    import logging
    logger = logging.getLogger(__name__)

    logger.info(f"[API] Analysis requested for {ts_code}, type: {analysis_type}")

    # Load config
    base_cfg = _load_json(BASE_CONFIG_PATH)
    gemini_cfg = base_cfg.get("llm", {}).get("gemini", {})
    api_key = gemini_cfg.get("api_key")

    if not api_key:
        logger.error("[API] Gemini API key not configured")
        raise HTTPException(
            status_code=500,
            detail="Gemini API key not configured. Set config.json llm.gemini.api_key"
        )

    # Get analysis configuration
    analysis_cfg = get_analysis_config(analysis_type)
    if not analysis_cfg:
        raise HTTPException(
            status_code=404,
            detail=f"Analysis type '{analysis_type}' not found in configuration"
        )

    model = analysis_cfg.get("model", "gemini-2.5-flash")
    prompt_template = analysis_cfg.get("prompt_template", "")

    if not prompt_template:
        raise HTTPException(
            status_code=500,
            detail=f"Prompt template not found for analysis type '{analysis_type}'"
        )

    # Get company name from ts_code if not provided
    financial_data = None
    if not company_name:
        logger.info(f"[API] Fetching company name and financial data for {ts_code}")
        try:
            token = (
                os.environ.get("TUSHARE_TOKEN")
                or settings.TUSHARE_TOKEN
                or base_cfg.get("data_sources", {}).get("tushare", {}).get("api_key")
            )
            if token:
                tushare_client = TushareClient(token=token)
                basic_data = await tushare_client.query(api_name="stock_basic", params={"ts_code": ts_code}, fields="ts_code,name")
                if basic_data and len(basic_data) > 0:
                    company_name = basic_data[0].get("name", ts_code)
                    logger.info(f"[API] Got company name: {company_name}")

                    # Try to get financial data for context
                    try:
                        fin_cfg = _load_json(FINANCIAL_CONFIG_PATH)
                        api_groups = fin_cfg.get("api_groups", {})

                        # Get financial data summary for context
                        series: Dict[str, List[Dict]] = {}
                        for group_name, metrics in api_groups.items():
                            if not metrics:
                                continue
                            api_groups_dict: Dict[str, List[Dict]] = {}
                            for metric in metrics:
                                api = metric.get("api") or group_name
                                if api:
                                    if api not in api_groups_dict:
                                        api_groups_dict[api] = []
                                    api_groups_dict[api].append(metric)

                            for api_name, api_metrics in api_groups_dict.items():
                                fields = [m.get("tushareParam") for m in api_metrics if m.get("tushareParam")]
                                if not fields:
                                    continue

                                date_field = "end_date" if group_name in ("fina_indicator", "income", "balancesheet", "cashflow") else "trade_date"

                                params = {"ts_code": ts_code, "limit": 500}
                                fields_list = list(fields)
                                if date_field not in fields_list:
                                    fields_list.insert(0, date_field)
                                if api_name in ("fina_indicator", "income", "balancesheet", "cashflow"):
                                    for req_field in ["ts_code", "ann_date"]:
                                        if req_field not in fields_list:
                                            fields_list.insert(0, req_field)
                                fields_str = ",".join(fields_list)

                                try:
                                    data_rows = await tushare_client.query(api_name=api_name, params=params, fields=fields_str)
                                    if data_rows:
                                        # Get latest year's data
                                        latest_row = data_rows[0] if data_rows else {}
                                        for metric in api_metrics:
                                            key = metric.get("tushareParam")
                                            if key and key in latest_row:
                                                if key not in series:
                                                    series[key] = []
                                                series[key].append({
                                                    "year": latest_row.get(date_field, "")[:4] if latest_row.get(date_field) else str(datetime.now().year),
                                                    "value": latest_row.get(key)
                                                })
                                except Exception:
                                    pass

                        financial_data = {"series": series}
                    except Exception as e:
                        logger.warning(f"[API] Failed to get financial data: {e}")
                        financial_data = None
                else:
                    company_name = ts_code
            else:
                company_name = ts_code
        except Exception as e:
            logger.warning(f"[API] Failed to get company name: {e}")
            company_name = ts_code

    logger.info(f"[API] Generating {analysis_type} for {company_name}")

    # Initialize analysis client with configured model
    client = AnalysisClient(api_key=api_key, model=model)

    # Generate analysis
    result = await client.generate_analysis(
        analysis_type=analysis_type,
        company_name=company_name,
        ts_code=ts_code,
        prompt_template=prompt_template,
        financial_data=financial_data
    )

    logger.info(f"[API] Analysis generation completed, success={result.get('success')}")

    return AnalysisResponse(
        ts_code=ts_code,
        company_name=company_name,
        analysis_type=analysis_type,
        content=result.get("content", ""),
        model=result.get("model", model),
        tokens=result.get("tokens", {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}),
        elapsed_ms=result.get("elapsed_ms", 0),
        success=result.get("success", False),
        error=result.get("error")
    )
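
# Example request (illustrative; the URL prefix depends on how this router is mounted):
#   GET /china/600519.SH/analysis/fundamental_analysis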