Fundamental_Analysis/backend/app/routers/financial.py

"""
API router for financial data (Tushare for China market)
"""
import json
import os
import time
from datetime import datetime, timezone, timedelta
from typing import Dict, List
from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import StreamingResponse
import os
from app.core.config import settings
from app.schemas.financial import (
BatchFinancialDataResponse,
FinancialConfigResponse,
FinancialMeta,
StepRecord,
CompanyProfileResponse,
AnalysisResponse,
AnalysisConfigResponse
)
from app.services.tushare_client import TushareClient
from app.services.company_profile_client import CompanyProfileClient
from app.services.analysis_client import AnalysisClient, load_analysis_config, get_analysis_config
router = APIRouter()
# Load metric config from file (project root is repo root, not backend/)
# routers/ -> app/ -> backend/ -> repo root
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
FINANCIAL_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "financial-tushare.json")
BASE_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "config.json")
ANALYSIS_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "analysis-config.json")

def _load_json(path: str) -> Dict:
    """Best-effort JSON loader: returns {} if the file is missing or unparseable."""
    if not os.path.exists(path):
        return {}
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except Exception:
        return {}
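
# Illustrative sketch (an assumption, not the authoritative schema) of the shape this
# module expects in config/financial-tushare.json, inferred from how it is read below:
#
#   {
#     "api_groups": {
#       "fina_indicator": [
#         {"api": "fina_indicator", "tushareParam": "roe"},
#         {"api": "fina_indicator", "tushareParam": "grossprofit_margin"}
#       ],
#       "daily_basic": [
#         {"api": "daily_basic", "tushareParam": "pe_ttm"}
#       ]
#     }
#   }
#
# Each metric entry needs at least a "tushareParam" (the Tushare field name); "api" is
# optional and defaults to the group name.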
@router.get("/config", response_model=FinancialConfigResponse)
async def get_financial_config():
data = _load_json(FINANCIAL_CONFIG_PATH)
api_groups = data.get("api_groups", {})
return FinancialConfigResponse(api_groups=api_groups)
@router.get("/china/{ts_code}", response_model=BatchFinancialDataResponse)
async def get_china_financials(
ts_code: str,
years: int = Query(5, ge=1, le=15),
):
# Load Tushare token
base_cfg = _load_json(BASE_CONFIG_PATH)
token = (
os.environ.get("TUSHARE_TOKEN")
or settings.TUSHARE_TOKEN
or base_cfg.get("data_sources", {}).get("tushare", {}).get("api_key")
)
if not token:
raise HTTPException(status_code=500, detail="Tushare API token not configured. Set TUSHARE_TOKEN env or config/config.json data_sources.tushare.api_key")
# Load metric config
fin_cfg = _load_json(FINANCIAL_CONFIG_PATH)
api_groups: Dict[str, List[Dict]] = fin_cfg.get("api_groups", {})
client = TushareClient(token=token)
# Meta tracking
started_real = datetime.now(timezone.utc)
started = time.perf_counter_ns()
api_calls_total = 0
api_calls_by_group: Dict[str, int] = {}
steps: List[StepRecord] = []
current_action = "初始化"
# Get company name from stock_basic API
company_name = None
try:
basic_data = await client.query(api_name="stock_basic", params={"ts_code": ts_code}, fields="ts_code,name")
api_calls_total += 1
if basic_data and len(basic_data) > 0:
company_name = basic_data[0].get("name")
except Exception:
# If getting company name fails, continue without it
pass
# Collect series per metric key
series: Dict[str, List[Dict]] = {}
# Helper to store year-value pairs while keeping most recent per year
def _merge_year_value(key: str, year: str, value, month: int = None):
arr = series.setdefault(key, [])
# upsert by year
for item in arr:
if item["year"] == year:
item["value"] = value
if month is not None:
item["month"] = month
return
arr.append({"year": year, "value": value, "month": month})
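
    # Illustrative behaviour of the helper above (not executed): calling it twice for the
    # same year overwrites the value, so each series ends up with one entry per year:
    #   _merge_year_value("roe", "2023", 11.2)
    #   _merge_year_value("roe", "2023", 12.5, month=12)
    #   # -> series["roe"] == [{"year": "2023", "value": 12.5, "month": 12}]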

    # Query each API group we care about
    errors: Dict[str, str] = {}
    for group_name, metrics in api_groups.items():
        step = StepRecord(
            name=f"拉取 {group_name}",
            start_ts=started_real.isoformat(),
            status="running",
        )
        steps.append(step)
        current_action = step.name
        if not metrics:
            continue

        # Group metrics by API (handles the case where the "unknown" group mixes several different APIs)
        api_groups_dict: Dict[str, List[Dict]] = {}
        for metric in metrics:
            api = metric.get("api") or group_name
            if api:  # skip empty API names
                if api not in api_groups_dict:
                    api_groups_dict[api] = []
                api_groups_dict[api].append(metric)

        # Handle each API separately
        for api_name, api_metrics in api_groups_dict.items():
            fields = [m.get("tushareParam") for m in api_metrics if m.get("tushareParam")]
            if not fields:
                continue
            date_field = "end_date" if group_name in ("fina_indicator", "income", "balancesheet", "cashflow") else "trade_date"

            # Build the API parameters
            params = {"ts_code": ts_code, "limit": 5000}
            # APIs that need a date range (e.g. stk_holdernumber) get explicit date parameters
            if api_name == "stk_holdernumber":
                # Date range: from `years` years ago until today
                end_date = datetime.now().strftime("%Y%m%d")
                start_date = (datetime.now() - timedelta(days=years * 365)).strftime("%Y%m%d")
                params["start_date"] = start_date
                params["end_date"] = end_date
                # stk_holdernumber reports its date in the end_date field
                date_field = "end_date"

            # Non-time-series APIs (e.g. stock_company) are treated as static data
            is_static_data = api_name == "stock_company"

            # Build the fields string: the date field plus all requested metric fields.
            # The date field must be present because it determines the year of each row.
            fields_list = list(fields)
            if date_field not in fields_list:
                fields_list.insert(0, date_field)
            # APIs such as fina_indicator usually also require ts_code and ann_date
            if api_name in ("fina_indicator", "income", "balancesheet", "cashflow"):
                for req_field in ["ts_code", "ann_date"]:
                    if req_field not in fields_list:
                        fields_list.insert(0, req_field)
            fields_str = ",".join(fields_list)

            try:
                data_rows = await client.query(api_name=api_name, params=params, fields=fields_str)
                api_calls_total += 1
                api_calls_by_group[group_name] = api_calls_by_group.get(group_name, 0) + 1
            except Exception as e:
                # Record the error but keep processing the other APIs
                error_key = f"{group_name}_{api_name}"
                errors[error_key] = str(e)
                continue

            tmp: Dict[str, Dict] = {}
            current_year = datetime.now().strftime("%Y")
            for row in data_rows:
                if is_static_data:
                    # Static data (e.g. stock_company) is attributed to the current year.
                    # Only the first row is used, because static data usually has a single row.
                    if current_year not in tmp:
                        year = current_year
                        month = None
                        tmp[year] = row
                        tmp[year]['_month'] = month
                    # Skip any remaining rows
                    continue
                else:
                    # Time-series data: bucket rows by the date field
                    date_val = row.get(date_field)
                    if not date_val:
                        continue
                    year = str(date_val)[:4]
                    month = int(str(date_val)[4:6]) if len(str(date_val)) >= 6 else None
                    existing = tmp.get(year)
                    if existing is None or str(row.get(date_field)) > str(existing.get(date_field)):
                        tmp[year] = row
                        tmp[year]['_month'] = month

            for metric in api_metrics:
                key = metric.get("tushareParam")
                if not key:
                    continue
                for year, row in tmp.items():
                    month = row.get('_month')
                    _merge_year_value(key, year, row.get(key), month)

        step.status = "done"
        step.end_ts = datetime.now(timezone.utc).isoformat()
        # Note: the duration is measured from the overall request start, so later steps report cumulative time
        step.duration_ms = int((time.perf_counter_ns() - started) / 1_000_000)

    finished_real = datetime.now(timezone.utc)
    elapsed_ms = int((time.perf_counter_ns() - started) / 1_000_000)
    if not series:
        # If nothing succeeded, expose the partial error info
        raise HTTPException(status_code=502, detail={"message": "No data returned from Tushare", "errors": errors})

    # Truncate to the requested number of years and sort
    for key, arr in series.items():
        # Deduplicate, sort descending by year, cut to the requested years, then return ascending
        uniq = {item["year"]: item for item in arr}
        arr_sorted_desc = sorted(uniq.values(), key=lambda x: x["year"], reverse=True)
        arr_limited = arr_sorted_desc[:years]
        arr_sorted = sorted(arr_limited, key=lambda x: x["year"])  # ascending by year
        series[key] = arr_sorted

    meta = FinancialMeta(
        started_at=started_real.isoformat(),
        finished_at=finished_real.isoformat(),
        elapsed_ms=elapsed_ms,
        api_calls_total=api_calls_total,
        api_calls_by_group=api_calls_by_group,
        current_action=None,
        steps=steps,
    )
    return BatchFinancialDataResponse(ts_code=ts_code, name=company_name, series=series, meta=meta)
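
# Example request (assuming the app mounts this router under a financial prefix such as
# "/api/financial"; the actual prefix is defined where the router is included):
#   GET /api/financial/china/600519.SH?years=5
# The response carries one series per tushareParam key plus a `meta` block with timing,
# per-group API-call counts, and the step records collected above.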
@router.get("/china/{ts_code}/company-profile", response_model=CompanyProfileResponse)
async def get_company_profile(
ts_code: str,
company_name: str = Query(None, description="Company name for better context"),
):
"""
Get company profile for a company using Gemini AI (non-streaming, single response)
"""
import logging
logger = logging.getLogger(__name__)
logger.info(f"[API] Company profile requested for {ts_code}")
# Load config
base_cfg = _load_json(BASE_CONFIG_PATH)
gemini_cfg = base_cfg.get("llm", {}).get("gemini", {})
api_key = gemini_cfg.get("api_key")
if not api_key:
logger.error("[API] Gemini API key not configured")
raise HTTPException(
status_code=500,
detail="Gemini API key not configured. Set config.json llm.gemini.api_key"
)
client = CompanyProfileClient(api_key=api_key)
# Get company name from ts_code if not provided
if not company_name:
logger.info(f"[API] Fetching company name for {ts_code}")
# Try to get from stock_basic API
try:
base_cfg = _load_json(BASE_CONFIG_PATH)
token = (
os.environ.get("TUSHARE_TOKEN")
or settings.TUSHARE_TOKEN
or base_cfg.get("data_sources", {}).get("tushare", {}).get("api_key")
)
if token:
from app.services.tushare_client import TushareClient
tushare_client = TushareClient(token=token)
basic_data = await tushare_client.query(api_name="stock_basic", params={"ts_code": ts_code}, fields="ts_code,name")
if basic_data and len(basic_data) > 0:
company_name = basic_data[0].get("name", ts_code)
logger.info(f"[API] Got company name: {company_name}")
else:
company_name = ts_code
else:
company_name = ts_code
except Exception as e:
logger.warning(f"[API] Failed to get company name: {e}")
company_name = ts_code
logger.info(f"[API] Generating profile for {company_name}")
# Generate profile using non-streaming API
result = await client.generate_profile(
company_name=company_name,
ts_code=ts_code,
financial_data=None
)
logger.info(f"[API] Profile generation completed, success={result.get('success')}")
return CompanyProfileResponse(
ts_code=ts_code,
company_name=company_name,
content=result.get("content", ""),
model=result.get("model", "gemini-2.5-flash"),
tokens=result.get("tokens", {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}),
elapsed_ms=result.get("elapsed_ms", 0),
success=result.get("success", False),
error=result.get("error")
)
@router.get("/analysis-config", response_model=AnalysisConfigResponse)
async def get_analysis_config_endpoint():
"""Get analysis configuration"""
config = load_analysis_config()
return AnalysisConfigResponse(analysis_modules=config.get("analysis_modules", {}))
@router.put("/analysis-config", response_model=AnalysisConfigResponse)
async def update_analysis_config_endpoint(analysis_config: AnalysisConfigResponse):
"""Update analysis configuration"""
import logging
logger = logging.getLogger(__name__)
try:
# 保存到文件
config_data = {
"analysis_modules": analysis_config.analysis_modules
}
with open(ANALYSIS_CONFIG_PATH, "w", encoding="utf-8") as f:
json.dump(config_data, f, ensure_ascii=False, indent=2)
logger.info(f"[API] Analysis config updated successfully")
return AnalysisConfigResponse(analysis_modules=analysis_config.analysis_modules)
except Exception as e:
logger.error(f"[API] Failed to update analysis config: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to update analysis config: {str(e)}"
)
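
# Illustrative sketch (an assumption, not the authoritative schema) of config/analysis-config.json
# as this router uses it: one entry per analysis type, each providing at least "model" and
# "prompt_template" (both read in generate_analysis below):
#
#   {
#     "analysis_modules": {
#       "fundamental_analysis": {
#         "model": "gemini-2.5-flash",
#         "prompt_template": "..."
#       }
#     }
#   }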
@router.get("/china/{ts_code}/analysis/{analysis_type}", response_model=AnalysisResponse)
async def generate_analysis(
ts_code: str,
analysis_type: str,
company_name: str = Query(None, description="Company name for better context"),
):
"""
Generate analysis for a company using Gemini AI
Supported analysis types:
- fundamental_analysis (基本面分析)
- bull_case (看涨分析)
- bear_case (看跌分析)
- market_analysis (市场分析)
- news_analysis (新闻分析)
- trading_analysis (交易分析)
- insider_institutional (内部人与机构动向分析)
- final_conclusion (最终结论)
"""
import logging
logger = logging.getLogger(__name__)
logger.info(f"[API] Analysis requested for {ts_code}, type: {analysis_type}")
# Load config
base_cfg = _load_json(BASE_CONFIG_PATH)
gemini_cfg = base_cfg.get("llm", {}).get("gemini", {})
api_key = gemini_cfg.get("api_key")
if not api_key:
logger.error("[API] Gemini API key not configured")
raise HTTPException(
status_code=500,
detail="Gemini API key not configured. Set config.json llm.gemini.api_key"
)
# Get analysis configuration
analysis_cfg = get_analysis_config(analysis_type)
if not analysis_cfg:
raise HTTPException(
status_code=404,
detail=f"Analysis type '{analysis_type}' not found in configuration"
)
model = analysis_cfg.get("model", "gemini-2.5-flash")
prompt_template = analysis_cfg.get("prompt_template", "")
if not prompt_template:
raise HTTPException(
status_code=500,
detail=f"Prompt template not found for analysis type '{analysis_type}'"
)
# Get company name from ts_code if not provided
financial_data = None
if not company_name:
logger.info(f"[API] Fetching company name and financial data for {ts_code}")
try:
token = (
os.environ.get("TUSHARE_TOKEN")
or settings.TUSHARE_TOKEN
or base_cfg.get("data_sources", {}).get("tushare", {}).get("api_key")
)
if token:
tushare_client = TushareClient(token=token)
basic_data = await tushare_client.query(api_name="stock_basic", params={"ts_code": ts_code}, fields="ts_code,name")
if basic_data and len(basic_data) > 0:
company_name = basic_data[0].get("name", ts_code)
logger.info(f"[API] Got company name: {company_name}")
# Try to get financial data for context
try:
fin_cfg = _load_json(FINANCIAL_CONFIG_PATH)
api_groups = fin_cfg.get("api_groups", {})
# Get financial data summary for context
series: Dict[str, List[Dict]] = {}
for group_name, metrics in api_groups.items():
if not metrics:
continue
api_groups_dict: Dict[str, List[Dict]] = {}
for metric in metrics:
api = metric.get("api") or group_name
if api:
if api not in api_groups_dict:
api_groups_dict[api] = []
api_groups_dict[api].append(metric)
for api_name, api_metrics in api_groups_dict.items():
fields = [m.get("tushareParam") for m in api_metrics if m.get("tushareParam")]
if not fields:
continue
date_field = "end_date" if group_name in ("fina_indicator", "income", "balancesheet", "cashflow") else "trade_date"
params = {"ts_code": ts_code, "limit": 500}
fields_list = list(fields)
if date_field not in fields_list:
fields_list.insert(0, date_field)
if api_name in ("fina_indicator", "income", "balancesheet", "cashflow"):
for req_field in ["ts_code", "ann_date"]:
if req_field not in fields_list:
fields_list.insert(0, req_field)
fields_str = ",".join(fields_list)
try:
data_rows = await tushare_client.query(api_name=api_name, params=params, fields=fields_str)
if data_rows:
# Get latest year's data
latest_row = data_rows[0] if data_rows else {}
for metric in api_metrics:
key = metric.get("tushareParam")
if key and key in latest_row:
if key not in series:
series[key] = []
series[key].append({
"year": latest_row.get(date_field, "")[:4] if latest_row.get(date_field) else str(datetime.now().year),
"value": latest_row.get(key)
})
except Exception:
pass
financial_data = {"series": series}
except Exception as e:
logger.warning(f"[API] Failed to get financial data: {e}")
financial_data = None
else:
company_name = ts_code
else:
company_name = ts_code
except Exception as e:
logger.warning(f"[API] Failed to get company name: {e}")
company_name = ts_code
logger.info(f"[API] Generating {analysis_type} for {company_name}")
# Initialize analysis client with configured model
client = AnalysisClient(api_key=api_key, model=model)
# Generate analysis
result = await client.generate_analysis(
analysis_type=analysis_type,
company_name=company_name,
ts_code=ts_code,
prompt_template=prompt_template,
financial_data=financial_data
)
logger.info(f"[API] Analysis generation completed, success={result.get('success')}")
return AnalysisResponse(
ts_code=ts_code,
company_name=company_name,
analysis_type=analysis_type,
content=result.get("content", ""),
model=result.get("model", model),
tokens=result.get("tokens", {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}),
elapsed_ms=result.get("elapsed_ms", 0),
success=result.get("success", False),
error=result.get("error")
)
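
# Example request (same assumed prefix as above; the real prefix comes from where this
# router is mounted):
#   GET /api/financial/china/600519.SH/analysis/fundamental_analysis
# If company_name is omitted, the handler resolves it via Tushare's stock_basic API and
# also pulls a latest-period snapshot of the configured metrics as context for the LLM.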