"""API router for financial data (Tushare for the China market).

Routes (relative to wherever this router is mounted):
    GET  /config                                   metric / API-group configuration
    GET  /china/{ts_code}                          per-year financial series from Tushare
    GET  /china/{ts_code}/company-profile          Gemini-generated company profile
    GET  /analysis-config                          read the analysis configuration
    PUT  /analysis-config                          overwrite the analysis configuration
    GET  /china/{ts_code}/analysis/{analysis_type} Gemini-generated analysis
"""
import json
import logging
import os
import time
from datetime import datetime, timezone, timedelta
from typing import Dict, List, Optional, Tuple

from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import StreamingResponse

from app.core.config import settings
from app.schemas.financial import (
    BatchFinancialDataResponse,
    FinancialConfigResponse,
    FinancialMeta,
    StepRecord,
    CompanyProfileResponse,
    AnalysisResponse,
    AnalysisConfigResponse,
)
from app.services.tushare_client import TushareClient
from app.services.company_profile_client import CompanyProfileClient
from app.services.analysis_client import AnalysisClient, load_analysis_config, get_analysis_config

router = APIRouter()
logger = logging.getLogger(__name__)

# Config files live under the repo root, not under backend/.
# routers/ -> app/ -> backend/ -> repo root
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
FINANCIAL_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "financial-tushare.json")
BASE_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "config.json")
ANALYSIS_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "analysis-config.json")

# Tushare APIs whose rows are keyed by reporting period (end_date) and that
# should also carry ts_code / ann_date in the requested field list.
_STATEMENT_APIS = ("fina_indicator", "income", "balancesheet", "cashflow")


def _load_json(path: str) -> Dict:
    """Load a JSON file, returning {} if it is missing or cannot be read/parsed."""
    if not os.path.exists(path):
        return {}
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except (OSError, ValueError) as e:
        # Best effort: callers treat an unreadable config like a missing one,
        # but leave a trace so a malformed file is not silently ignored.
        logger.warning(f"[API] Failed to load JSON config {path}: {e}")
        return {}


def _resolve_tushare_token(base_cfg: Dict) -> Optional[str]:
    """Return the Tushare token: env TUSHARE_TOKEN, then settings, then config.json."""
    return (
        os.environ.get("TUSHARE_TOKEN")
        or settings.TUSHARE_TOKEN
        or base_cfg.get("data_sources", {}).get("tushare", {}).get("api_key")
    )


def _resolve_gemini_api_key(base_cfg: Dict) -> Optional[str]:
    """Return the Gemini API key from config.json (llm.gemini.api_key), if any."""
    return base_cfg.get("llm", {}).get("gemini", {}).get("api_key")


async def _fetch_company_name(client: TushareClient, ts_code: str) -> Optional[str]:
    """Look up the company name via Tushare stock_basic.

    Returns the "name" of the first row, or None when no row is returned.
    Exceptions from the client propagate to the caller.
    """
    rows = await client.query(api_name="stock_basic", params={"ts_code": ts_code}, fields="ts_code,name")
    if rows:
        return rows[0].get("name")
    return None


def _group_metrics_by_api(group_name: str, metrics: List[Dict]) -> Dict[str, List[Dict]]:
    """Bucket a group's metrics by the Tushare API that serves them.

    A metric without an explicit "api" falls back to the group name; this
    matters for groups (e.g. "unknown") that mix several APIs.
    """
    grouped: Dict[str, List[Dict]] = {}
    for metric in metrics:
        api = metric.get("api") or group_name
        if api:  # skip metrics that resolve to an empty API name
            grouped.setdefault(api, []).append(metric)
    return grouped


def _date_field_for(group_name: str, api_name: str) -> str:
    """Pick the date column used to assign rows to years.

    Checks the API name as well as the group name so that statement APIs
    listed inside a mixed group still use their reporting-period column.
    """
    if group_name in _STATEMENT_APIS or api_name in _STATEMENT_APIS:
        return "end_date"
    return "trade_date"


def _build_fields_str(api_name: str, fields: List[str], date_field: str) -> str:
    """Build the comma-separated Tushare field list.

    The date field is always requested because it determines the year; for
    statement APIs ts_code and ann_date are prepended as well.
    """
    fields_list = list(fields)
    if date_field not in fields_list:
        fields_list.insert(0, date_field)
    if api_name in _STATEMENT_APIS:
        for req_field in ("ts_code", "ann_date"):
            if req_field not in fields_list:
                fields_list.insert(0, req_field)
    return ",".join(fields_list)


def _parse_month(date_str: str) -> Optional[int]:
    """Return the month from a YYYYMM... string, or None if absent/malformed."""
    if len(date_str) < 6:
        return None
    try:
        return int(date_str[4:6])
    except ValueError:
        return None


def _latest_row_per_year(
    rows: List[Dict], date_field: str, is_static: bool
) -> Dict[str, Tuple[Dict, Optional[int]]]:
    """Map year -> (row, month) keeping the row with the latest date per year.

    For static (non time-series) APIs only the first row is used and it is
    attributed to the current year with no month.
    """
    if is_static:
        if not rows:
            return {}
        return {datetime.now().strftime("%Y"): (rows[0], None)}

    latest: Dict[str, Tuple[Dict, Optional[int]]] = {}
    for row in rows:
        date_val = row.get(date_field)
        if not date_val:
            continue
        date_str = str(date_val)
        year = date_str[:4]
        existing = latest.get(year)
        # Dates are YYYYMMDD strings, so lexical order is chronological.
        if existing is None or date_str > str(existing[0].get(date_field)):
            latest[year] = (row, _parse_month(date_str))
    return latest


def _merge_year_value(
    series: Dict[str, List[Dict]], key: str, year: str, value, month: Optional[int] = None
) -> None:
    """Upsert a {year, value, month} point into series[key].

    An existing point for the same year has its value replaced; its month is
    only replaced when a month is supplied.
    """
    arr = series.setdefault(key, [])
    for item in arr:
        if item["year"] == year:
            item["value"] = value
            if month is not None:
                item["month"] = month
            return
    arr.append({"year": year, "value": value, "month": month})


@router.get("/config", response_model=FinancialConfigResponse)
async def get_financial_config():
    data = _load_json(FINANCIAL_CONFIG_PATH)
    api_groups = data.get("api_groups", {})
    return FinancialConfigResponse(api_groups=api_groups)


@router.get("/china/{ts_code}", response_model=BatchFinancialDataResponse)
async def get_china_financials(
    ts_code: str,
    years: int = Query(5, ge=1, le=15),
):
    """Fetch per-year financial series for a China-market stock from Tushare.

    Every API group in config/financial-tushare.json is queried, one step per
    group. For each metric, the row with the latest date in each year is kept,
    and the most recent `years` years are returned in ascending order. Responds
    with 502 when no API returned any data.
    """
    base_cfg = _load_json(BASE_CONFIG_PATH)
    token = _resolve_tushare_token(base_cfg)
    if not token:
        raise HTTPException(status_code=500, detail="Tushare API token not configured. Set TUSHARE_TOKEN env or config/config.json data_sources.tushare.api_key")

    fin_cfg = _load_json(FINANCIAL_CONFIG_PATH)
    api_groups: Dict[str, List[Dict]] = fin_cfg.get("api_groups", {})

    client = TushareClient(token=token)

    # Meta tracking
    started_real = datetime.now(timezone.utc)
    started = time.perf_counter_ns()
    api_calls_total = 0
    api_calls_by_group: Dict[str, int] = {}
    steps: List[StepRecord] = []

    # Company name is optional; failure to fetch it must not fail the request.
    company_name = None
    try:
        company_name = await _fetch_company_name(client, ts_code)
        api_calls_total += 1
    except Exception as e:
        logger.warning(f"[API] Failed to get company name for {ts_code}: {e}")

    series: Dict[str, List[Dict]] = {}
    errors: Dict[str, str] = {}

    for group_name, metrics in api_groups.items():
        step_started = time.perf_counter_ns()
        step = StepRecord(
            name=f"拉取 {group_name}",
            start_ts=datetime.now(timezone.utc).isoformat(),
            status="running",
        )
        steps.append(step)

        # An empty group simply has nothing to fetch; its step still completes.
        for api_name, api_metrics in _group_metrics_by_api(group_name, metrics or []).items():
            fields = [m.get("tushareParam") for m in api_metrics if m.get("tushareParam")]
            if not fields:
                continue

            date_field = _date_field_for(group_name, api_name)
            params = {"ts_code": ts_code, "limit": 5000}

            # stk_holdernumber needs an explicit date range covering `years`
            # years back from today, and its rows are keyed by end_date.
            if api_name == "stk_holdernumber":
                now = datetime.now()
                params["start_date"] = (now - timedelta(days=years * 365)).strftime("%Y%m%d")
                params["end_date"] = now.strftime("%Y%m%d")
                date_field = "end_date"

            fields_str = _build_fields_str(api_name, fields, date_field)

            try:
                data_rows = await client.query(api_name=api_name, params=params, fields=fields_str)
            except Exception as e:
                # Record the failure and keep going with the other APIs.
                logger.warning(f"[API] Tushare {api_name} query failed for {ts_code}: {e}")
                errors[f"{group_name}_{api_name}"] = str(e)
                continue
            api_calls_total += 1
            api_calls_by_group[group_name] = api_calls_by_group.get(group_name, 0) + 1

            # stock_company is static profile data, not a time series.
            latest = _latest_row_per_year(data_rows or [], date_field, is_static=api_name == "stock_company")

            for metric in api_metrics:
                key = metric.get("tushareParam")
                if not key:
                    continue
                for year, (row, month) in latest.items():
                    _merge_year_value(series, key, year, row.get(key), month)

        step.status = "done"
        step.end_ts = datetime.now(timezone.utc).isoformat()
        step.duration_ms = int((time.perf_counter_ns() - step_started) / 1_000_000)

    finished_real = datetime.now(timezone.utc)
    elapsed_ms = int((time.perf_counter_ns() - started) / 1_000_000)

    if not series:
        # Nothing succeeded: expose the per-API errors to the caller.
        raise HTTPException(status_code=502, detail={"message": "No data returned from Tushare", "errors": errors})

    # Keep the most recent `years` years per metric, returned oldest first.
    for key, arr in series.items():
        uniq = {item["year"]: item for item in arr}
        series[key] = sorted(uniq.values(), key=lambda item: item["year"])[-years:]

    meta = FinancialMeta(
        started_at=started_real.isoformat(),
        finished_at=finished_real.isoformat(),
        elapsed_ms=elapsed_ms,
        api_calls_total=api_calls_total,
        api_calls_by_group=api_calls_by_group,
        current_action=None,
        steps=steps,
    )

    return BatchFinancialDataResponse(ts_code=ts_code, name=company_name, series=series, meta=meta)


@router.get("/china/{ts_code}/company-profile", response_model=CompanyProfileResponse)
async def get_company_profile(
    ts_code: str,
    company_name: Optional[str] = Query(None, description="Company name for better context"),
):
    """
    Get company profile for a company using Gemini AI (non-streaming, single response)
    """
    logger.info(f"[API] Company profile requested for {ts_code}")

    base_cfg = _load_json(BASE_CONFIG_PATH)
    api_key = _resolve_gemini_api_key(base_cfg)
    if not api_key:
        logger.error("[API] Gemini API key not configured")
        raise HTTPException(
            status_code=500,
            detail="Gemini API key not configured. Set config.json llm.gemini.api_key"
        )

    client = CompanyProfileClient(api_key=api_key)

    # Resolve the company name from Tushare when not supplied; fall back to ts_code.
    if not company_name:
        logger.info(f"[API] Fetching company name for {ts_code}")
        try:
            token = _resolve_tushare_token(base_cfg)
            if token:
                name = await _fetch_company_name(TushareClient(token=token), ts_code)
                if name:
                    logger.info(f"[API] Got company name: {name}")
                company_name = name or ts_code
            else:
                company_name = ts_code
        except Exception as e:
            logger.warning(f"[API] Failed to get company name: {e}")
            company_name = ts_code

    logger.info(f"[API] Generating profile for {company_name}")

    result = await client.generate_profile(
        company_name=company_name,
        ts_code=ts_code,
        financial_data=None
    )

    logger.info(f"[API] Profile generation completed, success={result.get('success')}")

    return CompanyProfileResponse(
        ts_code=ts_code,
        company_name=company_name,
        content=result.get("content", ""),
        model=result.get("model", "gemini-2.5-flash"),
        tokens=result.get("tokens", {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}),
        elapsed_ms=result.get("elapsed_ms", 0),
        success=result.get("success", False),
        error=result.get("error")
    )


@router.get("/analysis-config", response_model=AnalysisConfigResponse)
async def get_analysis_config_endpoint():
    """Get analysis configuration"""
    config = load_analysis_config()
    return AnalysisConfigResponse(analysis_modules=config.get("analysis_modules", {}))


@router.put("/analysis-config", response_model=AnalysisConfigResponse)
async def update_analysis_config_endpoint(analysis_config: AnalysisConfigResponse):
    """Update analysis configuration"""
    try:
        config_data = {
            "analysis_modules": analysis_config.analysis_modules
        }
        # Serialize before opening: mode "w" truncates the file, so a failure
        # mid-dump would otherwise leave a corrupt config on disk.
        payload = json.dumps(config_data, ensure_ascii=False, indent=2)
        with open(ANALYSIS_CONFIG_PATH, "w", encoding="utf-8") as f:
            f.write(payload)

        logger.info("[API] Analysis config updated successfully")
        return AnalysisConfigResponse(analysis_modules=analysis_config.analysis_modules)
    except Exception as e:
        logger.error(f"[API] Failed to update analysis config: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to update analysis config: {str(e)}"
        ) from e


async def _fetch_latest_financial_snapshot(client: TushareClient, ts_code: str) -> Dict[str, List[Dict]]:
    """Build a small financial context for the LLM prompt.

    For each configured API, takes the first returned row (treated as the
    latest; NOTE(review): relies on Tushare returning rows newest-first) and
    records {year, value} for every configured metric present in that row.
    APIs that fail are skipped.
    """
    fin_cfg = _load_json(FINANCIAL_CONFIG_PATH)
    api_groups = fin_cfg.get("api_groups", {})

    series: Dict[str, List[Dict]] = {}
    for group_name, metrics in api_groups.items():
        if not metrics:
            continue
        for api_name, api_metrics in _group_metrics_by_api(group_name, metrics).items():
            fields = [m.get("tushareParam") for m in api_metrics if m.get("tushareParam")]
            if not fields:
                continue

            date_field = _date_field_for(group_name, api_name)
            params = {"ts_code": ts_code, "limit": 500}
            fields_str = _build_fields_str(api_name, fields, date_field)

            try:
                data_rows = await client.query(api_name=api_name, params=params, fields=fields_str)
            except Exception as e:
                logger.debug(f"[API] Skipping {api_name} for financial context: {e}")
                continue
            if not data_rows:
                continue

            latest_row = data_rows[0]
            date_val = latest_row.get(date_field)
            # str() guards against non-string date values (e.g. ints).
            year = str(date_val)[:4] if date_val else str(datetime.now().year)
            for metric in api_metrics:
                key = metric.get("tushareParam")
                if key and key in latest_row:
                    series.setdefault(key, []).append({"year": year, "value": latest_row.get(key)})
    return series


@router.get("/china/{ts_code}/analysis/{analysis_type}", response_model=AnalysisResponse)
async def generate_analysis(
    ts_code: str,
    analysis_type: str,
    company_name: Optional[str] = Query(None, description="Company name for better context"),
):
    """
    Generate analysis for a company using Gemini AI
    Supported analysis types:
    - fundamental_analysis (基本面分析)
    - bull_case (看涨分析)
    - bear_case (看跌分析)
    - market_analysis (市场分析)
    - news_analysis (新闻分析)
    - trading_analysis (交易分析)
    - insider_institutional (内部人与机构动向分析)
    - final_conclusion (最终结论)
    """
    logger.info(f"[API] Analysis requested for {ts_code}, type: {analysis_type}")

    base_cfg = _load_json(BASE_CONFIG_PATH)
    api_key = _resolve_gemini_api_key(base_cfg)
    if not api_key:
        logger.error("[API] Gemini API key not configured")
        raise HTTPException(
            status_code=500,
            detail="Gemini API key not configured. Set config.json llm.gemini.api_key"
        )

    analysis_cfg = get_analysis_config(analysis_type)
    if not analysis_cfg:
        raise HTTPException(
            status_code=404,
            detail=f"Analysis type '{analysis_type}' not found in configuration"
        )

    model = analysis_cfg.get("model", "gemini-2.5-flash")
    prompt_template = analysis_cfg.get("prompt_template", "")
    if not prompt_template:
        raise HTTPException(
            status_code=500,
            detail=f"Prompt template not found for analysis type '{analysis_type}'"
        )

    # NOTE(review): financial context is only gathered when company_name is
    # not supplied by the caller — confirm this coupling is intended.
    financial_data = None
    if not company_name:
        logger.info(f"[API] Fetching company name and financial data for {ts_code}")
        try:
            token = _resolve_tushare_token(base_cfg)
            if token:
                tushare_client = TushareClient(token=token)
                name = await _fetch_company_name(tushare_client, ts_code)
                company_name = name or ts_code
                if name:
                    logger.info(f"[API] Got company name: {company_name}")
                    try:
                        financial_data = {"series": await _fetch_latest_financial_snapshot(tushare_client, ts_code)}
                    except Exception as e:
                        logger.warning(f"[API] Failed to get financial data: {e}")
                        financial_data = None
            else:
                company_name = ts_code
        except Exception as e:
            logger.warning(f"[API] Failed to get company name: {e}")
            company_name = ts_code

    logger.info(f"[API] Generating {analysis_type} for {company_name}")

    client = AnalysisClient(api_key=api_key, model=model)

    result = await client.generate_analysis(
        analysis_type=analysis_type,
        company_name=company_name,
        ts_code=ts_code,
        prompt_template=prompt_template,
        financial_data=financial_data
    )

    logger.info(f"[API] Analysis generation completed, success={result.get('success')}")

    return AnalysisResponse(
        ts_code=ts_code,
        company_name=company_name,
        analysis_type=analysis_type,
        content=result.get("content", ""),
        model=result.get("model", model),
        tokens=result.get("tokens", {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}),
        elapsed_ms=result.get("elapsed_ms", 0),
        success=result.get("success", False),
        error=result.get("error")
    )