"""
API router for financial data (Tushare for China market)
"""
import json
import logging
import os
import time
from collections import deque
from datetime import datetime, timezone, timedelta
from typing import Dict, List, Optional, Tuple

from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import StreamingResponse

from app.core.config import settings
from app.schemas.financial import (
    BatchFinancialDataResponse,
    FinancialConfigResponse,
    FinancialMeta,
    StepRecord,
    CompanyProfileResponse,
    AnalysisResponse,
    AnalysisConfigResponse,
)
from app.services.company_profile_client import CompanyProfileClient
from app.services.analysis_client import AnalysisClient, load_analysis_config, get_analysis_config

logger = logging.getLogger(__name__)

# Lazy DataManager loader to avoid import-time failures when optional providers/config are missing
_dm = None


def get_dm():
    """Return the shared DataManager, importing it on first use.

    If importing ``app.data_manager`` fails for any reason, a minimal stub is
    cached instead so endpoints degrade to "no data" rather than crashing.
    The stub is cached as well, so the real import is not retried afterwards.
    """
    global _dm
    if _dm is not None:
        return _dm
    try:
        from app.data_manager import data_manager as real_dm
        _dm = real_dm
        return _dm
    except Exception as e:
        logger.warning(f"[API] DataManager unavailable, falling back to stub: {e}")

        class _StubDM:
            # Mirrors the DataManager methods this router calls; every lookup yields "no data".
            config = {}

            async def get_stock_basic(self, stock_code: str):
                return None

            async def get_financial_statements(self, stock_code: str, report_dates):
                # Empty mapping, matching the {metric_key: [points]} shape callers iterate over.
                return {}

            async def get_financial_statement(self, stock_code: str, report_dates):
                return None

            async def get_data(self, method: str, **kwargs):
                return []

        _dm = _StubDM()
        return _dm


router = APIRouter()

# Load metric config from file (project root is repo root, not backend/)
# routers/ -> app/ -> backend/ -> repo root
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
FINANCIAL_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "financial-tushare.json")
BASE_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "config.json")
ANALYSIS_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "analysis-config.json")

# Market-data row fields that identify the row rather than carry a metric value.
_MARKET_ROW_META_KEYS = frozenset({"ts_code", "trade_date", "trade_dt", "date"})


def _load_json(path: str) -> Dict:
    """Read a JSON file, returning ``{}`` if it is missing or cannot be parsed."""
    if not os.path.exists(path):
        return {}
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except Exception as e:
        logger.warning(f"[API] Failed to load JSON config {path}: {e}")
        return {}


def _get_llm_settings() -> Tuple[str, str, Optional[str]]:
    """Return ``(provider, api_key, base_url)`` for the LLM provider configured in config.json.

    The provider name defaults to "gemini"; ``base_url`` is None when not set.

    Raises:
        HTTPException: 500 if no API key is configured for the provider.
    """
    llm_section = _load_json(BASE_CONFIG_PATH).get("llm", {})
    llm_provider = llm_section.get("provider", "gemini")
    llm_config = llm_section.get(llm_provider, {})
    api_key = llm_config.get("api_key")
    if not api_key:
        logger.error(f"[API] API key for {llm_provider} not configured")
        raise HTTPException(status_code=500, detail=f"API key for {llm_provider} not configured.")
    return llm_provider, api_key, llm_config.get("base_url")


def _empty_token_usage() -> Dict[str, int]:
    """Fresh zeroed token-usage dict used when the client reports none."""
    return {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}


async def _fetch_stock_basic(ts_code: str) -> Optional[Dict]:
    """Best-effort ``stock_basic`` lookup; returns None (and logs) on any failure."""
    try:
        return await get_dm().get_stock_basic(stock_code=ts_code)
    except Exception as e:
        logger.warning(f"[API] Failed to get company name for {ts_code}: {e}")
        return None


def _company_name_from_basic(basic_data: Optional[Dict], ts_code: str) -> str:
    """Pick the company name from a stock_basic record, falling back to ``ts_code``."""
    return (basic_data or {}).get("name") or ts_code


async def _fetch_latest_annual_financials(ts_code: str) -> Optional[Dict]:
    """Best-effort fetch of last year's annual report (``YYYY1231``) wrapped as ``{"series": report}``.

    Returns None if the report is empty or the lookup fails.
    """
    report_date = f"{datetime.now().year - 1}1231"
    try:
        # get_financial_statement is expected to return a single flat report for one date.
        report = await get_dm().get_financial_statement(stock_code=ts_code, report_dates=report_date)
    except Exception as e:
        logger.warning(f"[API] Failed to get financial data: {e}")
        return None
    return {"series": report} if report else None


def _module_dependencies(modules_config: Dict, name: str) -> List[str]:
    """Declared dependencies of module ``name`` (empty if unknown or unset)."""
    return modules_config.get(name, {}).get("dependencies", []) or []


def _find_cycles(graph: Dict[str, List[str]]) -> List[List[str]]:
    """Return dependency cycles found by depth-first search, e.g. ``[["a", "b", "a"]]``."""
    cycles: List[List[str]] = []
    visited = set()
    path: List[str] = []
    on_path = set()

    def visit(node: str) -> None:
        visited.add(node)
        path.append(node)
        on_path.add(node)
        for dep in graph.get(node, []):
            if dep in on_path:
                cycles.append(path[path.index(dep):] + [dep])
            elif dep not in visited:
                visit(dep)
        # Always unwind, so later searches never see stale path entries.
        path.pop()
        on_path.discard(node)

    for node in graph:
        if node not in visited:
            visit(node)
    return cycles


def _topological_order(graph: Dict[str, List[str]]) -> Tuple[List[str], List[List[str]]]:
    """Order nodes so every node comes after the nodes it depends on (Kahn's algorithm).

    Args:
        graph: maps each node to the nodes it depends on. Dependencies that are
            not themselves keys of ``graph`` are ignored for ordering.

    Returns:
        ``(order, cycles)``. For an acyclic graph ``order`` contains every node
        and ``cycles`` is empty; otherwise ``order`` holds only the nodes that
        could be ordered and ``cycles`` lists the cycles found, each as a path
        that starts and ends on the same node.
    """
    deps = {node: [d for d in node_deps if d in graph] for node, node_deps in graph.items()}
    dependents: Dict[str, List[str]] = {node: [] for node in deps}
    unmet = {node: len(node_deps) for node, node_deps in deps.items()}
    for node, node_deps in deps.items():
        for dep in node_deps:
            dependents[dep].append(node)

    ready = deque(node for node, count in unmet.items() if count == 0)
    order: List[str] = []
    while ready:
        node = ready.popleft()
        order.append(node)
        for dependent in dependents[node]:
            unmet[dependent] -= 1
            if unmet[dependent] == 0:
                ready.append(dependent)

    if len(order) == len(deps):
        return order, []
    return order, _find_cycles(deps)


async def _resolve_dependency_context(
    analysis_cfg: Dict,
    *,
    api_key: str,
    base_url: Optional[str],
    default_model: str,
    company_name: str,
    ts_code: str,
    financial_data: Optional[Dict],
) -> Dict[str, str]:
    """Generate (non-streaming) every module ``analysis_cfg`` depends on, transitively.

    Dependencies are generated dependencies-first, so each one receives its own
    dependencies' output as context. Returns ``{direct_dependency: content}``
    for the module's direct dependencies; a failed dependency contributes "".
    Best-effort: any unexpected error is logged and yields an empty context.
    """
    dependencies = analysis_cfg.get("dependencies", []) or []
    if not dependencies:
        return {}
    try:
        modules_config = load_analysis_config().get("analysis_modules", {})

        # Collect all transitive dependencies (a dict keeps iteration order deterministic).
        required: Dict[str, None] = {}
        pending = list(dependencies)
        while pending:
            name = pending.pop()
            if name not in required:
                required[name] = None
                pending.extend(_module_dependencies(modules_config, name))

        graph = {
            name: [d for d in _module_dependencies(modules_config, name) if d in required]
            for name in required
        }
        order, cycles = _topological_order(graph)
        if cycles:
            # Fallback: still run every required module; the unorderable ones go last.
            logger.warning(f"[API] Circular dependency among analysis modules: {cycles}")
            ordered = set(order)
            order += [name for name in required if name not in ordered]

        completed: Dict[str, str] = {}
        for mod in order:
            cfg = modules_config.get(mod, {})
            dep_ctx = {d: completed.get(d, "") for d in _module_dependencies(modules_config, mod)}
            dep_client = AnalysisClient(api_key=api_key, base_url=base_url, model=cfg.get("model", default_model))
            dep_result = await dep_client.generate_analysis(
                analysis_type=mod,
                company_name=company_name,
                ts_code=ts_code,
                prompt_template=cfg.get("prompt_template", ""),
                financial_data=financial_data,
                context=dep_ctx,
            )
            completed[mod] = dep_result.get("content", "") if dep_result.get("success") else ""
        return {dep: completed.get(dep, "") for dep in dependencies}
    except Exception as e:
        logger.warning(f"[API] Failed to resolve dependency context: {e}")
        return {}


def _is_yyyymmdd(value: str) -> bool:
    """True if ``value`` parses as a ``YYYYMMDD`` date."""
    try:
        datetime.strptime(value, "%Y%m%d")
    except ValueError:
        return False
    return True


def _merge_market_rows(
    series: Dict[str, List[Dict]],
    rows,
    latest_current_year_report: Optional[str],
) -> None:
    """Fold market-data rows into ``series`` in place, one point per metric and period.

    Each row is dated by ``trade_date``/``trade_dt``/``date``. A row dated exactly
    at ``latest_current_year_report`` is filed under that interim period; every
    other row is filed under its year's ``YYYY1231`` period. Numeric fields become
    ``{"period", "value"}`` points, replacing any existing point for the same
    period. Rows without a date, and non-list ``rows``, are ignored.
    """
    if not isinstance(rows, list):
        return
    for row in rows:
        trade_date = row.get("trade_date") or row.get("trade_dt") or row.get("date")
        if not trade_date:
            continue
        trade_date = str(trade_date)
        if latest_current_year_report and trade_date == latest_current_year_report:
            period = latest_current_year_report
        else:
            # NOTE(review): if the provider answers the interim report date with the nearest
            # trading day instead, that row is filed under "<year>1231" — confirm with provider.
            period = f"{trade_date[:4]}1231"
        for key, value in row.items():
            if key in _MARKET_ROW_META_KEYS or not isinstance(value, (int, float)):
                continue
            points = series.setdefault(key, [])
            point = {"period": period, "value": value}
            index = next((i for i, p in enumerate(points) if p.get("period") == period), None)
            if index is None:
                points.append(point)
            else:
                points[index] = point


@router.get("/data-sources", response_model=Dict[str, List[str]])
async def get_data_sources():
    """
    Get the list of data sources that require an API key from the config.
    """
    try:
        data_sources_config = get_dm().config.get("data_sources", {})
        sources_requiring_keys = [
            source for source, config in data_sources_config.items() if config.get("api_key_env")
        ]
        return {"sources": sources_requiring_keys}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to load data sources configuration: {e}")


@router.post("/china/{ts_code}/analysis", response_model=List[AnalysisResponse])
async def generate_full_analysis(
    ts_code: str,
    company_name: str = Query(None, description="Company name for better context"),
):
    """
    Generate a full analysis report by orchestrating multiple analysis modules
    based on dependencies defined in the configuration.

    Modules run sequentially, each after the modules it depends on, and receive
    their dependencies' output as context. Returns one AnalysisResponse per module.
    Raises 500 if no LLM API key is configured, 404 if no modules are configured,
    and 400 if the module dependencies contain a cycle.
    """
    logger.info(f"[API] Full analysis requested for {ts_code}")

    _, api_key, base_url = _get_llm_settings()

    modules_config = load_analysis_config().get("analysis_modules", {})
    if not modules_config:
        raise HTTPException(status_code=404, detail="Analysis modules configuration not found.")

    # --- Dependency resolution (topological sort): module -> modules it depends on ---
    dependency_graph = {name: _module_dependencies(modules_config, name) for name in modules_config}
    sorted_modules, cycles = _topological_order(dependency_graph)
    if cycles:
        raise HTTPException(
            status_code=400,
            detail=f"Circular dependency detected in analysis modules configuration. Cycle: {cycles}",
        )

    # --- Common data; no financial context is fetched for the full run ---
    financial_data = None
    if not company_name:
        logger.info(f"[API] Fetching company name for {ts_code}")
        basic_data = await _fetch_stock_basic(ts_code)
        company_name = _company_name_from_basic(basic_data, ts_code)
        if basic_data:
            logger.info(f"[API] Got company name: {company_name}")

    # --- Execute modules in dependency order ---
    analysis_results = []
    completed_modules_content: Dict[str, str] = {}
    for module_type in sorted_modules:
        module_config = modules_config[module_type]
        model = module_config.get("model", "gemini-1.5-flash")
        logger.info(f"[Orchestrator] Starting analysis for module: {module_type}")

        client = AnalysisClient(api_key=api_key, base_url=base_url, model=model)

        # Gather context from completed dependencies ("" for unknown modules)
        context = {
            dep: completed_modules_content.get(dep, "")
            for dep in _module_dependencies(modules_config, module_type)
        }

        result = await client.generate_analysis(
            analysis_type=module_type,
            company_name=company_name,
            ts_code=ts_code,
            prompt_template=module_config.get("prompt_template", ""),
            financial_data=financial_data,
            context=context,
        )

        response = AnalysisResponse(
            ts_code=ts_code,
            company_name=company_name,
            analysis_type=module_type,
            content=result.get("content", ""),
            model=result.get("model", model),
            tokens=result.get("tokens", {}),
            elapsed_ms=result.get("elapsed_ms", 0),
            success=result.get("success", False),
            error=result.get("error"),
        )
        analysis_results.append(response)

        if response.success:
            completed_modules_content[module_type] = response.content
        else:
            # A failed module does not abort the run: its dependents receive this
            # error marker as context instead, which may affect their quality.
            completed_modules_content[module_type] = f"Error: Analysis for {module_type} failed."
            logger.error(f"[Orchestrator] Module {module_type} failed: {response.error}")

    logger.info(f"[API] Full analysis for {ts_code} completed.")
    return analysis_results


@router.get("/config", response_model=FinancialConfigResponse)
async def get_financial_config():
    """Return the ``api_groups`` section of financial-tushare.json ({} if absent)."""
    data = _load_json(FINANCIAL_CONFIG_PATH)
    api_groups = data.get("api_groups", {})
    return FinancialConfigResponse(api_groups=api_groups)


@router.get("/china/{ts_code}", response_model=BatchFinancialDataResponse)
async def get_china_financials(
    ts_code: str,
    years: int = Query(5, ge=1, le=15),
):
    """Return per-metric period series for ``ts_code`` plus fetch metadata.

    Financial statements come from the DataManager for year-end dates; market
    data (daily_basic / daily) is merged in when those groups are configured.
    Each series is deduplicated by period, limited to the ``years`` most recent
    periods and returned in ascending period order.

    Raises:
        HTTPException: 502 if no data could be obtained from any source.
    """
    # Load metric config
    fin_cfg = _load_json(FINANCIAL_CONFIG_PATH)
    api_groups: Dict[str, List[Dict]] = fin_cfg.get("api_groups", {})

    # Meta tracking
    started_real = datetime.now(timezone.utc)
    started = time.perf_counter_ns()
    api_calls_total = 0  # This will be harder to track now, maybe DataManager should provide it
    api_calls_by_group: Dict[str, int] = {}
    steps: List[StepRecord] = []
    errors: Dict[str, str] = {}

    # Company name is optional; fall back to ts_code
    company_name = _company_name_from_basic(await _fetch_stock_basic(ts_code), ts_code)

    # Year-end report dates, oldest first
    current_year = datetime.now().year
    report_dates = [f"{year}1231" for year in range(current_year - years, current_year + 1)]

    # --- Fetch all financial statements at once (already in series format from provider) ---
    step_financials = StepRecord(
        name="拉取财务报表", start_ts=datetime.now(timezone.utc).isoformat(), status="running"
    )
    steps.append(step_financials)
    step_financials_started = time.perf_counter_ns()

    series: Dict[str, List[Dict]] = {}
    try:
        statements = await get_dm().get_financial_statements(stock_code=ts_code, report_dates=report_dates)
        if isinstance(statements, dict):
            # Shallow copy: market data is merged into these lists below, and the
            # provider's own object should not be mutated (it may be cached).
            series = {key: list(points or []) for key, points in statements.items()}
    except Exception as e:
        logger.warning(f"[API] Failed to fetch financial statements for {ts_code}: {e}")
        errors["financial_statements"] = f"Failed to fetch financial statements: {e}"
    if not series:
        errors.setdefault("financial_statements", "Failed to fetch from all providers.")

    # Latest interim (non-year-end) report period of the current year, used for market data
    current_year_str = str(current_year)
    interim_periods = [
        period
        for points in series.values()
        for period in (str(item.get("period", "")) for item in points)
        if period.startswith(current_year_str) and not period.endswith("1231")
    ]
    latest_current_year_report = max(interim_periods, default=None)

    step_financials.status = "done"
    step_financials.end_ts = datetime.now(timezone.utc).isoformat()
    step_financials.duration_ms = int((time.perf_counter_ns() - step_financials_started) / 1_000_000)

    # --- Fetch market cap / valuation (daily_basic) and share price (daily) at year-end dates ---
    try:
        # Only attempt these when the config contains the corresponding groups
        has_daily_basic = bool(api_groups.get("daily_basic"))
        has_daily = bool(api_groups.get("daily"))
        if has_daily_basic or has_daily:
            step_market = StepRecord(
                name="拉取市值与股价", start_ts=datetime.now(timezone.utc).isoformat(), status="running"
            )
            steps.append(step_market)
            step_market_started = time.perf_counter_ns()

            # Query dates: the year-end dates plus the latest interim report date of the
            # current year (the API is expected to resolve it to the nearest trading day)
            market_dates = report_dates.copy()
            if latest_current_year_report and _is_yyyymmdd(latest_current_year_report):
                market_dates.append(latest_current_year_report)

            try:
                if has_daily_basic:
                    db_rows = await get_dm().get_data(
                        'get_daily_basic_points', stock_code=ts_code, trade_dates=market_dates
                    )
                    _merge_market_rows(series, db_rows, latest_current_year_report)
                if has_daily:
                    d_rows = await get_dm().get_data(
                        'get_daily_points', stock_code=ts_code, trade_dates=market_dates
                    )
                    _merge_market_rows(series, d_rows, latest_current_year_report)
            except Exception as e:
                errors["market_data"] = f"Failed to fetch market data: {e}"
            finally:
                step_market.status = "done"
                step_market.end_ts = datetime.now(timezone.utc).isoformat()
                step_market.duration_ms = int((time.perf_counter_ns() - step_market_started) / 1_000_000)
    except Exception as e:
        errors["market_data_init"] = f"Market data init failed: {e}"

    finished_real = datetime.now(timezone.utc)
    elapsed_ms = int((time.perf_counter_ns() - started) / 1_000_000)

    if not series:
        raise HTTPException(
            status_code=502,
            detail={"message": "No data returned from any data source", "errors": errors},
        )

    # Deduplicate by period (last entry wins), keep the `years` most recent periods, return ascending
    for key, points in series.items():
        unique = {item["period"]: item for item in points}
        series[key] = sorted(unique.values(), key=lambda item: item["period"])[-years:]

    # Note: Derived financial metrics calculation has been moved to individual data providers
    # The data_manager.get_financial_statements() should handle this

    meta = FinancialMeta(
        started_at=started_real.isoformat(),
        finished_at=finished_real.isoformat(),
        elapsed_ms=elapsed_ms,
        api_calls_total=api_calls_total,
        api_calls_by_group=api_calls_by_group,
        current_action=None,
        steps=steps,
    )

    return BatchFinancialDataResponse(ts_code=ts_code, name=company_name, series=series, meta=meta)


@router.get("/china/{ts_code}/company-profile", response_model=CompanyProfileResponse)
async def get_company_profile(
    ts_code: str,
    company_name: str = Query(None, description="Company name for better context"),
):
    """
    Get company profile for a company using Gemini AI (non-streaming, single response)
    """
    logger.info(f"[API] Company profile requested for {ts_code}")

    _, api_key, base_url = _get_llm_settings()  # base_url may be None; handled by client

    model = "gemini-1.5-flash"
    client = CompanyProfileClient(api_key=api_key, base_url=base_url, model=model)

    # Get company name from ts_code if not provided
    if not company_name:
        logger.info(f"[API] Fetching company name for {ts_code}")
        basic_data = await _fetch_stock_basic(ts_code)
        company_name = _company_name_from_basic(basic_data, ts_code)
        if basic_data:
            logger.info(f"[API] Got company name: {company_name}")

    logger.info(f"[API] Generating profile for {company_name}")

    result = await client.generate_profile(company_name=company_name, ts_code=ts_code, financial_data=None)

    logger.info(f"[API] Profile generation completed, success={result.get('success')}")

    return CompanyProfileResponse(
        ts_code=ts_code,
        company_name=company_name,
        content=result.get("content", ""),
        # Fall back to the model actually requested, not a different default
        model=result.get("model", model),
        tokens=result.get("tokens", _empty_token_usage()),
        elapsed_ms=result.get("elapsed_ms", 0),
        success=result.get("success", False),
        error=result.get("error"),
    )


@router.get("/analysis-config", response_model=AnalysisConfigResponse)
async def get_analysis_config_endpoint():
    """Get analysis configuration"""
    config = load_analysis_config()
    return AnalysisConfigResponse(analysis_modules=config.get("analysis_modules", {}))


@router.put("/analysis-config", response_model=AnalysisConfigResponse)
async def update_analysis_config_endpoint(analysis_config: AnalysisConfigResponse):
    """Update analysis configuration by overwriting analysis-config.json.

    Raises:
        HTTPException: 500 if the configuration cannot be serialized or written.
    """
    try:
        config_data = {"analysis_modules": analysis_config.analysis_modules}
        # Serialize before opening the file: opening with "w" truncates it, so a
        # serialization error must not be able to leave a half-written config.
        payload = json.dumps(config_data, ensure_ascii=False, indent=2)
        with open(ANALYSIS_CONFIG_PATH, "w", encoding="utf-8") as f:
            f.write(payload)

        logger.info("[API] Analysis config updated successfully")
        return AnalysisConfigResponse(analysis_modules=analysis_config.analysis_modules)
    except Exception as e:
        logger.error(f"[API] Failed to update analysis config: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to update analysis config: {str(e)}")


@router.get("/china/{ts_code}/analysis/{analysis_type}", response_model=AnalysisResponse)
async def generate_analysis(
    ts_code: str,
    analysis_type: str,
    company_name: str = Query(None, description="Company name for better context"),
):
    """
    Generate analysis for a company using Gemini AI

    Supported analysis types:
    - fundamental_analysis (fundamental analysis)
    - bull_case (bullish case)
    - bear_case (bearish case)
    - market_analysis (market analysis)
    - news_analysis (news analysis)
    - trading_analysis (trading analysis)
    - insider_institutional (insider and institutional activity)
    - final_conclusion (final conclusion)

    Modules the requested type depends on are generated first and passed in as context.
    """
    logger.info(f"[API] Analysis requested for {ts_code}, type: {analysis_type}")

    _, api_key, base_url = _get_llm_settings()

    analysis_cfg = get_analysis_config(analysis_type)
    if not analysis_cfg:
        raise HTTPException(
            status_code=404,
            detail=f"Analysis type '{analysis_type}' not found in configuration",
        )

    model = analysis_cfg.get("model", "gemini-2.5-flash")
    prompt_template = analysis_cfg.get("prompt_template", "")
    if not prompt_template:
        raise HTTPException(
            status_code=500,
            detail=f"Prompt template not found for analysis type '{analysis_type}'",
        )

    # Resolve company name (and, when found, last year's report as context) if not provided
    financial_data = None
    if not company_name:
        logger.info(f"[API] Fetching company name and financial data for {ts_code}")
        basic_data = await _fetch_stock_basic(ts_code)
        company_name = _company_name_from_basic(basic_data, ts_code)
        if basic_data:
            logger.info(f"[API] Got company name: {company_name}")
            financial_data = await _fetch_latest_annual_financials(ts_code)

    logger.info(f"[API] Generating {analysis_type} for {company_name}")

    client = AnalysisClient(api_key=api_key, base_url=base_url, model=model)

    context = await _resolve_dependency_context(
        analysis_cfg,
        api_key=api_key,
        base_url=base_url,
        default_model=model,
        company_name=company_name,
        ts_code=ts_code,
        financial_data=financial_data,
    )

    result = await client.generate_analysis(
        analysis_type=analysis_type,
        company_name=company_name,
        ts_code=ts_code,
        prompt_template=prompt_template,
        financial_data=financial_data,
        context=context,
    )

    logger.info(f"[API] Analysis generation completed, success={result.get('success')}")

    return AnalysisResponse(
        ts_code=ts_code,
        company_name=company_name,
        analysis_type=analysis_type,
        content=result.get("content", ""),
        model=result.get("model", model),
        tokens=result.get("tokens", _empty_token_usage()),
        elapsed_ms=result.get("elapsed_ms", 0),
        success=result.get("success", False),
        error=result.get("error"),
    )


@router.get("/china/{ts_code}/analysis/{analysis_type}/stream")
async def stream_analysis(
    ts_code: str,
    analysis_type: str,
    company_name: str = Query(None, description="Company name for better context"),
):
    """
    Stream analysis content chunks for a given module using OpenAI-compatible streaming.
    Plain text streaming (text/plain; utf-8). Dependencies are resolved first (non-stream),
    then the target module content is streamed after a markdown header line.
    """
    logger.info(f"[API] Streaming analysis requested for {ts_code}, type: {analysis_type}")

    _, api_key, base_url = _get_llm_settings()

    analysis_cfg = get_analysis_config(analysis_type)
    if not analysis_cfg:
        raise HTTPException(status_code=404, detail=f"Analysis type '{analysis_type}' not found in configuration")

    model = analysis_cfg.get("model", "gemini-2.5-flash")
    prompt_template = analysis_cfg.get("prompt_template", "")
    if not prompt_template:
        raise HTTPException(status_code=500, detail=f"Prompt template not found for analysis type '{analysis_type}'")

    # Company name only; full financials are not needed for streaming
    financial_data = None
    if not company_name:
        company_name = _company_name_from_basic(await _fetch_stock_basic(ts_code), ts_code)

    # Resolve dependency context (non-streaming)
    context = await _resolve_dependency_context(
        analysis_cfg,
        api_key=api_key,
        base_url=base_url,
        default_model=model,
        company_name=company_name,
        ts_code=ts_code,
        financial_data=financial_data,
    )

    client = AnalysisClient(api_key=api_key, base_url=base_url, model=model)

    async def streamer():
        # Optional header line to help client-side UI
        yield f"# {analysis_cfg.get('name', analysis_type)}\n\n"
        async for chunk in client.generate_analysis_stream(
            analysis_type=analysis_type,
            company_name=company_name,
            ts_code=ts_code,
            prompt_template=prompt_template,
            financial_data=financial_data,
            context=context,
        ):
            yield chunk

    headers = {
        # Disable buffering in intermediaries so chunks reach the client as soon as possible
        "Cache-Control": "no-cache, no-transform",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(streamer(), media_type="text/plain; charset=utf-8", headers=headers)