"""
API router for financial data (Tushare for China market)
"""
import json
import logging
import os
import shutil
import tempfile
import time
from collections import deque
from datetime import date, datetime, timezone, timedelta
from enum import Enum
from typing import Dict, List, Optional, Tuple

from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import StreamingResponse

from app.core.config import settings
from app.schemas.financial import (
    BatchFinancialDataResponse,
    FinancialConfigResponse,
    FinancialMeta,
    StepRecord,
    CompanyProfileResponse,
    AnalysisResponse,
    AnalysisConfigResponse,
    TodaySnapshotResponse,
    RealTimeQuoteResponse,
)
from app.services.company_profile_client import CompanyProfileClient
from app.services.analysis_client import AnalysisClient, load_analysis_config, get_analysis_config
from app.services.data_persistence_client import (
    DataPersistenceClient,
    NewAnalysisResult,
    DailyMarketData,
    DailyMarketDataBatch,
)

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# DataManager access
# ---------------------------------------------------------------------------

class _StubDM:
    """Inert stand-in used when the real DataManager cannot be imported.

    Every method returns "no data" in the shape its callers in this module expect,
    so endpoints degrade gracefully instead of failing with AttributeError.
    """

    config: Dict = {}

    async def get_stock_basic(self, stock_code: str):
        return None

    async def get_financial_statements(self, stock_code: str, report_dates):
        return {}

    async def get_financial_statement(self, stock_code: str, report_dates):
        return None

    async def get_data(self, method_name: str, **kwargs):
        return []

    async def get_daily_price(self, stock_code: str, start_date: str, end_date: str):
        return []


# Lazily-resolved DataManager; importing it at module load can fail when optional
# providers or their configuration are missing.
_dm = None


def get_dm():
    """Return the shared DataManager, caching a `_StubDM` if the import fails.

    The import is attempted only on the first call; the result (real or stub) is
    cached for the life of the process.
    """
    global _dm
    if _dm is not None:
        return _dm
    try:
        from app.data_manager import data_manager as real_dm
        _dm = real_dm
    except Exception as exc:
        logger.warning("DataManager unavailable, falling back to stub: %s", exc)
        _dm = _StubDM()
    return _dm


router = APIRouter()


class MarketEnum(str, Enum):
    cn = "cn"
    us = "us"
    hk = "hk"
    jp = "jp"


# Config files live at the repo root (routers/ -> app/ -> backend/ -> repo root).
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
FINANCIAL_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "financial-tushare.json")
BASE_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "config.json")
ANALYSIS_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "analysis-config.json")


# ---------------------------------------------------------------------------
# Generic helpers
# ---------------------------------------------------------------------------

def _load_json(path: str) -> Dict:
    """Read a JSON file; return {} when it is missing or cannot be parsed."""
    if not os.path.exists(path):
        return {}
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except Exception as exc:
        logger.warning("Failed to load JSON config %s: %s", path, exc)
        return {}


def _write_json_atomic(path: str, data: Dict) -> None:
    """Write ``data`` to ``path`` as indented UTF-8 JSON without exposing a partial file.

    The JSON is written to a temporary file in the same directory, which then
    replaces ``path`` via ``os.replace``. If serialization fails, the previous file
    is left untouched. The permission bits of an existing file are carried over.
    """
    directory = os.path.dirname(path) or "."
    fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".tmp-", suffix=".json")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        if os.path.exists(path):
            # mkstemp creates the file 0600; keep the original file's mode instead.
            shutil.copymode(path, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise


def _elapsed_ms(start_ns: int) -> int:
    """Milliseconds elapsed since a `time.perf_counter_ns()` reading."""
    return int((time.perf_counter_ns() - start_ns) / 1_000_000)


def _empty_token_usage() -> Dict[str, int]:
    """Fresh zeroed token-usage dict (a new object each call, so it is never shared)."""
    return {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}


def _require_llm_credentials() -> Tuple[str, Optional[str]]:
    """Return ``(api_key, base_url)`` for the LLM provider configured in config.json.

    The provider is ``llm.provider`` (default "gemini"); its settings are read from
    ``llm.<provider>``. ``base_url`` may be None.

    Raises:
        HTTPException 500: when the provider has no API key configured.
    """
    llm_section = _load_json(BASE_CONFIG_PATH).get("llm", {})
    llm_provider = llm_section.get("provider", "gemini")
    llm_config = llm_section.get(llm_provider, {})
    api_key = llm_config.get("api_key")
    if not api_key:
        logger.error(f"[API] API key for {llm_provider} not configured")
        raise HTTPException(status_code=500, detail=f"API key for {llm_provider} not configured.")
    return api_key, llm_config.get("base_url")


# ---------------------------------------------------------------------------
# Analysis-module dependency helpers
# ---------------------------------------------------------------------------

def _module_dependencies(modules_config: Dict, name: str) -> List[str]:
    """Declared dependency names of module ``name`` (empty if unknown or unset)."""
    return list((modules_config.get(name) or {}).get("dependencies", []) or [])


def _dependency_order(graph: Dict[str, List[str]]) -> Tuple[List[str], List[str]]:
    """Order ``graph``'s nodes so each node comes after the dependencies it lists (Kahn).

    Args:
        graph: maps each node to the names it depends on. Names that are not keys of
            ``graph`` are ignored.

    Returns:
        ``(order, unresolved)``. ``order`` holds every orderable node,
        dependencies-first, with ties broken by ``graph`` iteration order.
        ``unresolved`` holds the nodes that lie on a dependency cycle or depend on
        one; it is empty for an acyclic graph.
    """
    pending_deps = {node: 0 for node in graph}
    dependents: Dict[str, List[str]] = {node: [] for node in graph}
    for node, deps in graph.items():
        for dep in deps:
            if dep in dependents:
                dependents[dep].append(node)
                pending_deps[node] += 1

    ready = deque(node for node in graph if pending_deps[node] == 0)
    order: List[str] = []
    while ready:
        node = ready.popleft()
        order.append(node)
        for dependent in dependents[node]:
            pending_deps[dependent] -= 1
            if pending_deps[dependent] == 0:
                ready.append(dependent)

    unresolved = [node for node in graph if pending_deps[node] > 0]
    return order, unresolved


def _find_cycle(graph: Dict[str, List[str]], unresolved: List[str]) -> List[str]:
    """Return one dependency cycle among ``unresolved`` nodes, e.g. ``["a", "b", "a"]``.

    Every unresolved node has at least one unresolved dependency. Following such
    dependencies from any unresolved node must therefore revisit a node; the
    revisited stretch is the cycle, listed node -> dependency.
    """
    stuck = set(unresolved)
    position: Dict[str, int] = {}
    path: List[str] = []
    node: Optional[str] = unresolved[0] if unresolved else None
    while node is not None and node not in position:
        position[node] = len(path)
        path.append(node)
        node = next((dep for dep in graph.get(node, []) if dep in stuck), None)
    if node is None:
        return path
    return path[position[node]:] + [node]


def _collect_transitive_dependencies(modules_config: Dict, roots: List[str]) -> List[str]:
    """All modules reachable from ``roots`` through dependencies (roots included).

    The result is in depth-first discovery order, so it is deterministic (unlike
    iterating a set).
    """
    found: Dict[str, None] = {}
    stack = list(reversed(roots))
    while stack:
        name = stack.pop()
        if name in found:
            continue
        found[name] = None
        stack.extend(reversed(_module_dependencies(modules_config, name)))
    return list(found)


async def _build_dependency_context(
    analysis_cfg: Dict,
    *,
    api_key: str,
    base_url: Optional[str],
    default_model: str,
    company_name: str,
    ts_code: str,
    financial_data,
) -> Dict[str, str]:
    """Generate (non-streaming) every transitive dependency of a module.

    Args:
        analysis_cfg: config of the target module; its ``dependencies`` are the roots.
        default_model: model used for dependencies whose config does not set one.

    Returns:
        ``{direct_dependency: content}`` for the target's direct dependencies.
        Failed generations contribute "".

    Dependencies are generated dependencies-first, so each module sees its own
    dependencies' output as context. Modules stuck on a cycle are still generated
    (last) rather than aborting. Exceptions propagate; callers treat the context as
    best-effort.
    """
    dependencies = analysis_cfg.get("dependencies", []) or []
    if not dependencies:
        return {}

    modules_config = load_analysis_config().get("analysis_modules", {})
    required = _collect_transitive_dependencies(modules_config, dependencies)
    required_set = set(required)
    graph = {
        name: [d for d in _module_dependencies(modules_config, name) if d in required_set]
        for name in required
    }
    order, unresolved = _dependency_order(graph)
    order.extend(unresolved)

    completed: Dict[str, str] = {}
    for mod in order:
        cfg = modules_config.get(mod) or {}
        dep_ctx = {d: completed.get(d, "") for d in _module_dependencies(modules_config, mod)}
        dep_client = AnalysisClient(api_key=api_key, base_url=base_url, model=cfg.get("model", default_model))
        dep_result = await dep_client.generate_analysis(
            analysis_type=mod,
            company_name=company_name,
            ts_code=ts_code,
            prompt_template=cfg.get("prompt_template", ""),
            financial_data=financial_data,
            context=dep_ctx,
        )
        completed[mod] = dep_result.get("content", "") if dep_result.get("success") else ""

    return {dep: completed.get(dep, "") for dep in dependencies}


async def _persist_analysis(ts_code: str, module_id: str, response: AnalysisResponse) -> None:
    """Store an analysis result via the Rust data-persistence service.

    Exceptions propagate to the caller, which decides how to log them.
    """
    dp = DataPersistenceClient()
    await dp.create_analysis_result(NewAnalysisResult(
        symbol=ts_code,
        module_id=module_id,
        model_name=response.model,
        content=response.content,
        meta_data={"tokens": response.tokens, "elapsed_ms": response.elapsed_ms},
    ))


# ---------------------------------------------------------------------------
# Financial-series helpers
# ---------------------------------------------------------------------------

# Row fields that identify a market-data row rather than carrying a metric value.
_MARKET_ROW_META_KEYS = ('ts_code', 'trade_date', 'trade_dt', 'date')


async def _fetch_financial_series(stock_code: str, report_dates: List[str]) -> Dict[str, List[Dict]]:
    """Fetch statements as ``{metric: [point, ...]}``; return {} on failure or a non-dict result.

    The point lists are copied, so later merging never mutates provider-owned objects.
    """
    try:
        raw = await get_dm().get_financial_statements(stock_code=stock_code, report_dates=report_dates)
    except Exception as exc:
        logger.warning("Failed to fetch financial statements for %s: %s", stock_code, exc)
        return {}
    if not isinstance(raw, dict):
        return {}
    return {key: list(points or []) for key, points in raw.items()}


def _latest_in_year_report(series: Dict[str, List[Dict]], current_year: int) -> Optional[str]:
    """Latest ``current_year`` period that is not the year-end (``YYYY1231``), or None."""
    year_prefix = str(current_year)
    latest: Optional[str] = None
    for points in series.values():
        for item in points or []:
            period = str(item.get('period') or '')
            if period.startswith(year_prefix) and not period.endswith('1231'):
                if latest is None or period > latest:
                    latest = period
    return latest


def _market_query_dates(report_dates: List[str], latest_report: Optional[str]) -> List[str]:
    """Year-end report dates plus the latest in-year report period, if it is a valid YYYYMMDD."""
    dates = list(report_dates)
    if latest_report:
        try:
            datetime.strptime(latest_report, '%Y%m%d')
        except ValueError:
            pass  # Malformed period: query the year-end dates only.
        else:
            # The report date itself is used as the query date. This assumes the
            # provider maps a non-trading day to the nearest trading day.
            dates.append(latest_report)
    return dates


def _merge_market_rows(series: Dict[str, List[Dict]], rows, latest_report: Optional[str]) -> None:
    """Merge daily_basic/daily rows into ``series`` (in place).

    A row whose trade date equals ``latest_report`` is labelled with that period.
    Every other row is labelled with its year-end period ``YYYY1231``. Each numeric
    field becomes a ``{"period", "value"}`` point. A point already present for the
    same metric and period is replaced, so later rows win.

    NOTE(review): if the provider returns the actual trading day (not the requested
    report date), the in-year row gets a year-end label; confirm provider behaviour.
    """
    if not isinstance(rows, list):
        return
    for row in rows:
        trade_date = row.get('trade_date') or row.get('trade_dt') or row.get('date')
        if not trade_date:
            continue
        trade_date = str(trade_date)
        if latest_report and trade_date == latest_report:
            period = latest_report
        else:
            period = f"{trade_date[:4]}1231"
        for key, value in row.items():
            if key in _MARKET_ROW_META_KEYS or not isinstance(value, (int, float)):
                continue
            point = {"period": period, "value": value}
            points = series.setdefault(key, [])
            for index, existing in enumerate(points):
                # .get(): statement points may carry only "year" until normalization.
                if existing.get('period') == period:
                    points[index] = point
                    break
            else:
                points.append(point)


def _normalize_points(points: List[Dict], limit: int) -> List[Dict]:
    """Normalize one metric's points to ``{"period", "value"}`` and keep the newest ``limit``.

    A missing period is derived from ``year`` as ``YYYY1231``; points with neither
    are dropped. For duplicate periods the last point wins. The result is sorted
    oldest-first.
    """
    normalized: List[Dict] = []
    for item in points:
        period = item.get("period")
        if not period:
            year = item.get("year")
            if year:
                period = f"{year}1231"
        if not period:
            continue
        normalized.append({"period": str(period), "value": item.get("value")})
    unique = {p["period"]: p for p in normalized}
    newest = sorted(unique.values(), key=lambda p: p["period"], reverse=True)[:limit]
    return sorted(newest, key=lambda p: p["period"])


# ---------------------------------------------------------------------------
# Snapshot helpers
# ---------------------------------------------------------------------------

def _parse_trade_date(trade_date: str) -> date:
    """Parse the leading ``YYYYMMDD`` of ``trade_date`` (raises ValueError if malformed)."""
    return datetime(int(trade_date[:4]), int(trade_date[4:6]), int(trade_date[6:8])).date()


def _price_record(symbol: str, row: Dict, trade_date: str) -> DailyMarketData:
    """Daily record carrying only price/volume fields from a daily-quote row."""
    return DailyMarketData(
        symbol=symbol,
        trade_date=_parse_trade_date(trade_date),
        open_price=row.get('open'),
        high_price=row.get('high'),
        low_price=row.get('low'),
        close_price=row.get('close'),
        volume=row.get('vol') or row.get('volume'),
        pe=None,
        pb=None,
        total_mv=None,
    )


async def _load_daily_window(dp: DataPersistenceClient, symbol: str, start_dt: date, base_dt: date):
    """Stored daily records for ``symbol`` from ``start_dt`` through ``base_dt + 1 day``."""
    return await dp.get_daily_data_by_symbol(
        symbol=symbol, start_date=start_dt, end_date=base_dt + timedelta(days=1)
    )


def _latest_daily_on_or_before(daily_list, base_str: str):
    """Stored daily record with the latest trade date not after ``base_str``, or None."""
    if not isinstance(daily_list, list) or not daily_list:
        return None
    candidates = [d for d in daily_list if d.trade_date.strftime("%Y%m%d") <= base_str]
    if not candidates:
        return None
    return sorted(candidates, key=lambda r: r.trade_date.strftime("%Y%m%d"))[-1]


async def _backfill_cn_daily(dp: DataPersistenceClient, ts_code: str, base_str: str) -> None:
    """Fetch one CN daily row for ``base_str`` and store it in the persistence service.

    daily_basic is tried first, because it also carries valuation and market cap.
    daily (price/volume only) is used only when daily_basic returns no rows.
    Exceptions propagate.
    """
    rows = await get_dm().get_data('get_daily_basic_points', stock_code=ts_code, trade_dates=[base_str])
    if isinstance(rows, list) and rows:
        r = rows[0]
        trade_date = str(r.get('trade_date') or r.get('trade_dt') or r.get('date') or base_str)
        record = DailyMarketData(
            symbol=ts_code,
            trade_date=_parse_trade_date(trade_date),
            open_price=None,
            high_price=None,
            low_price=None,
            close_price=r.get('close'),
            volume=r.get('vol') or r.get('volume'),
            pe=r.get('pe'),
            pb=r.get('pb'),
            total_mv=r.get('total_mv'),
        )
        await dp.batch_insert_daily_data(DailyMarketDataBatch(records=[record]))
        return

    d_rows = await get_dm().get_data('get_daily_points', stock_code=ts_code, trade_dates=[base_str])
    if isinstance(d_rows, list) and d_rows:
        d0 = d_rows[0]
        trade_date = str(d0.get('trade_date') or d0.get('trade_dt') or d0.get('date') or base_str)
        await dp.batch_insert_daily_data(DailyMarketDataBatch(records=[_price_record(ts_code, d0, trade_date)]))


async def _backfill_daily_price(
    dp: DataPersistenceClient, stock_code: str, start_dt: date, base_dt: date, base_str: str
) -> bool:
    """Fetch daily quotes for the window and store the latest row not after ``base_str``.

    Returns:
        True if a record was written. Exceptions from fetching or writing propagate.
    """
    start_str = start_dt.strftime("%Y%m%d")
    end_str = (base_dt + timedelta(days=1)).strftime("%Y%m%d")
    rows = await get_dm().get_daily_price(stock_code=stock_code, start_date=start_str, end_date=end_str)

    last_rec = None
    if isinstance(rows, list) and rows:
        try:
            candidates = [r for r in rows if str(r.get("trade_date") or r.get("date") or "") <= base_str]
            if candidates:
                last_rec = sorted(candidates, key=lambda r: str(r.get("trade_date") or r.get("date") or ""))[-1]
        except Exception:
            last_rec = None
    if not last_rec:
        return False

    trade_date = str(last_rec.get("trade_date") or last_rec.get("date") or base_str)
    await dp.batch_insert_daily_data(DailyMarketDataBatch(records=[_price_record(stock_code, last_rec, trade_date)]))
    return True


# ---------------------------------------------------------------------------
# Endpoints (registration order matters for FastAPI route matching)
# ---------------------------------------------------------------------------

@router.get("/data-sources", response_model=Dict[str, List[str]])
async def get_data_sources():
    """
    Get the list of data sources that require an API key from the config.
    """
    try:
        data_sources_config = get_dm().config.get("data_sources", {})
        sources_requiring_keys = [
            source for source, config in data_sources_config.items()
            if config.get("api_key_env")
        ]
        return {"sources": sources_requiring_keys}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to load data sources configuration: {e}")


@router.post("/china/{ts_code}/analysis", response_model=List[AnalysisResponse])
async def generate_full_analysis(
    ts_code: str,
    company_name: str = Query(None, description="Company name for better context"),
):
    """
    Generate a full analysis report by orchestrating multiple analysis modules
    based on dependencies defined in the configuration.

    Each configured module runs once, after the modules it depends on, and receives
    their output as context. Successful results are persisted best-effort.

    Raises:
        HTTPException 500: the LLM API key is not configured.
        HTTPException 404: no analysis modules are configured.
        HTTPException 400: the module dependencies contain a cycle.
    """
    logger.info(f"[API] Full analysis requested for {ts_code}")
    api_key, base_url = _require_llm_credentials()

    modules_config = load_analysis_config().get("analysis_modules", {})
    if not modules_config:
        raise HTTPException(status_code=404, detail="Analysis modules configuration not found.")

    # Dependencies that are not configured modules are ignored for ordering.
    dependency_graph = {name: _module_dependencies(modules_config, name) for name in modules_config}
    sorted_modules, unresolved = _dependency_order(dependency_graph)
    if unresolved:
        cycle = _find_cycle(dependency_graph, unresolved)
        raise HTTPException(
            status_code=400,
            detail=f"Circular dependency detected in analysis modules configuration. Cycle: {cycle}",
        )

    financial_data = None  # Full-report mode does not attach financial statements.
    if not company_name:
        logger.info(f"[API] Fetching company name for {ts_code}")
        try:
            basic_data = await get_dm().get_stock_basic(stock_code=ts_code)
            if basic_data:
                company_name = basic_data.get("name", ts_code)
                logger.info(f"[API] Got company name: {company_name}")
            else:
                company_name = ts_code
        except Exception as e:
            logger.warning(f"Failed to get company name, proceeding with ts_code. Error: {e}")
            company_name = ts_code

    analysis_results: List[AnalysisResponse] = []
    completed_modules_content: Dict[str, str] = {}

    for module_type in sorted_modules:
        module_config = modules_config[module_type]
        module_model = module_config.get("model", "gemini-1.5-flash")
        logger.info(f"[Orchestrator] Starting analysis for module: {module_type}")

        client = AnalysisClient(api_key=api_key, base_url=base_url, model=module_model)
        context = {dep: completed_modules_content.get(dep, "") for dep in dependency_graph[module_type]}

        result = await client.generate_analysis(
            analysis_type=module_type,
            company_name=company_name,
            ts_code=ts_code,
            prompt_template=module_config.get("prompt_template", ""),
            financial_data=financial_data,
            context=context,
        )

        response = AnalysisResponse(
            ts_code=ts_code,
            company_name=company_name,
            analysis_type=module_type,
            content=result.get("content", ""),
            model=result.get("model", module_model),
            tokens=result.get("tokens", {}),
            elapsed_ms=result.get("elapsed_ms", 0),
            success=result.get("success", False),
            error=result.get("error"),
        )
        analysis_results.append(response)

        if response.success:
            completed_modules_content[module_type] = response.content
            try:
                await _persist_analysis(ts_code, module_type, response)
            except Exception as e:
                logger.error(f"[Persistence] Failed to persist analysis result for {module_type}: {e}")
        else:
            # A failed module is not fatal. Its dependents receive this placeholder as
            # their context for it, which keeps the run going but may lower their quality.
            completed_modules_content[module_type] = f"Error: Analysis for {module_type} failed."
            logger.error(f"[Orchestrator] Module {module_type} failed: {response.error}")

    logger.info(f"[API] Full analysis for {ts_code} completed.")
    return analysis_results


@router.get("/config", response_model=FinancialConfigResponse)
async def get_financial_config():
    """Return the ``api_groups`` section of financial-tushare.json ({} if absent)."""
    data = _load_json(FINANCIAL_CONFIG_PATH)
    return FinancialConfigResponse(api_groups=data.get("api_groups", {}))


@router.get("/{market}/{stock_code}", response_model=BatchFinancialDataResponse)
async def get_financials(
    market: MarketEnum,
    stock_code: str,
    years: int = Query(10, ge=1, le=10),
):
    """Return per-metric financial series for ``stock_code`` over the last ``years`` years.

    For the CN market, market cap/valuation (daily_basic) and price (daily) points
    are merged in when financial-tushare.json enables those API groups. Each series
    is normalized to ``{"period", "value"}`` points, de-duplicated, limited to the
    newest ``years`` periods and sorted oldest-first.

    Raises:
        HTTPException 502: when no series at all could be obtained.
    """
    fin_cfg = _load_json(FINANCIAL_CONFIG_PATH)
    api_groups: Dict[str, List[Dict]] = fin_cfg.get("api_groups", {})

    started_real = datetime.now(timezone.utc)
    started = time.perf_counter_ns()
    api_calls_total = 0  # Not tracked yet; the DataManager would need to report it.
    api_calls_by_group: Dict[str, int] = {}
    steps: List[StepRecord] = []
    errors: Dict[str, str] = {}

    company_name = stock_code
    try:
        basic_data = await get_dm().get_stock_basic(stock_code=stock_code)
        if basic_data:
            company_name = basic_data.get("name", stock_code)
    except Exception:
        pass  # The name is cosmetic; fall back to the stock code.

    # Year-end report dates from (current_year - years) through current_year.
    current_year = datetime.now().year
    report_dates = [f"{year}1231" for year in range(current_year - years, current_year + 1)]

    step_financials = StepRecord(name="拉取财务报表", start_ts=started_real.isoformat(), status="running")
    steps.append(step_financials)
    # Already in series format ({metric: [points]}) from the provider.
    series = await _fetch_financial_series(stock_code, report_dates)
    latest_current_year_report = _latest_in_year_report(series, current_year)
    if not series:
        errors["financial_statements"] = "Failed to fetch from all providers."
    step_financials.status = "done"
    step_financials.end_ts = datetime.now(timezone.utc).isoformat()
    step_financials.duration_ms = _elapsed_ms(started)

    # Market cap/valuation (daily_basic) and price (daily) at each year-end date.
    try:
        # Only fetch the groups that the config enables.
        has_daily_basic = bool(api_groups.get("daily_basic"))
        has_daily = bool(api_groups.get("daily"))
        # Only CN fetches daily_basic/daily for now; other markets await provider support.
        if market == MarketEnum.cn and (has_daily_basic or has_daily):
            market_started = time.perf_counter_ns()
            step_market = StepRecord(
                name="拉取市值与股价", start_ts=datetime.now(timezone.utc).isoformat(), status="running"
            )
            steps.append(step_market)
            market_dates = _market_query_dates(report_dates, latest_current_year_report)
            try:
                if has_daily_basic:
                    db_rows = await get_dm().get_data(
                        'get_daily_basic_points', stock_code=stock_code, trade_dates=market_dates
                    )
                    _merge_market_rows(series, db_rows, latest_current_year_report)
                if has_daily:
                    d_rows = await get_dm().get_data(
                        'get_daily_points', stock_code=stock_code, trade_dates=market_dates
                    )
                    _merge_market_rows(series, d_rows, latest_current_year_report)
            except Exception as e:
                errors["market_data"] = f"Failed to fetch market data: {e}"
            finally:
                step_market.status = "done"
                step_market.end_ts = datetime.now(timezone.utc).isoformat()
                step_market.duration_ms = _elapsed_ms(market_started)
    except Exception as e:
        errors["market_data_init"] = f"Market data init failed: {e}"

    finished_real = datetime.now(timezone.utc)
    elapsed_ms = _elapsed_ms(started)

    if not series:
        raise HTTPException(status_code=502, detail={"message": "No data returned from any data source", "errors": errors})

    series = {key: _normalize_points(points, years) for key, points in series.items()}

    # Derived financial metrics are computed by the data providers
    # (data_manager.get_financial_statements), not here.
    meta = FinancialMeta(
        started_at=started_real.isoformat(),
        finished_at=finished_real.isoformat(),
        elapsed_ms=elapsed_ms,
        api_calls_total=api_calls_total,
        api_calls_by_group=api_calls_by_group,
        current_action=None,
        steps=steps,
    )
    return BatchFinancialDataResponse(ts_code=stock_code, name=company_name, series=series, meta=meta)


@router.get("/china/{ts_code}/company-profile", response_model=CompanyProfileResponse)
async def get_company_profile(
    ts_code: str,
    company_name: str = Query(None, description="Company name for better context"),
):
    """
    Get company profile for a company using Gemini AI (non-streaming, single response)
    """
    logger.info(f"[API] Company profile requested for {ts_code}")
    api_key, base_url = _require_llm_credentials()  # base_url may be None; the client handles it.

    profile_model = "gemini-1.5-flash"
    client = CompanyProfileClient(api_key=api_key, base_url=base_url, model=profile_model)

    if not company_name:
        logger.info(f"[API] Fetching company name for {ts_code}")
        try:
            basic_data = await get_dm().get_stock_basic(stock_code=ts_code)
            if basic_data:
                company_name = basic_data.get("name", ts_code)
                logger.info(f"[API] Got company name: {company_name}")
            else:
                company_name = ts_code
        except Exception as e:
            logger.warning(f"[API] Failed to get company name: {e}")
            company_name = ts_code

    logger.info(f"[API] Generating profile for {company_name}")
    result = await client.generate_profile(company_name=company_name, ts_code=ts_code, financial_data=None)
    logger.info(f"[API] Profile generation completed, success={result.get('success')}")

    return CompanyProfileResponse(
        ts_code=ts_code,
        company_name=company_name,
        content=result.get("content", ""),
        # Fall back to the model the client was actually constructed with.
        model=result.get("model", profile_model),
        tokens=result.get("tokens", _empty_token_usage()),
        elapsed_ms=result.get("elapsed_ms", 0),
        success=result.get("success", False),
        error=result.get("error"),
    )


@router.get("/analysis-config", response_model=AnalysisConfigResponse)
async def get_analysis_config_endpoint():
    """Get analysis configuration"""
    config = load_analysis_config()
    return AnalysisConfigResponse(analysis_modules=config.get("analysis_modules", {}))


@router.put("/analysis-config", response_model=AnalysisConfigResponse)
async def update_analysis_config_endpoint(analysis_config: AnalysisConfigResponse):
    """Update analysis configuration.

    Persists ``analysis_modules`` to analysis-config.json with an atomic replace, so
    concurrent readers never observe a partially written file.

    Raises:
        HTTPException 500: when the file cannot be written.
    """
    try:
        config_data = {"analysis_modules": analysis_config.analysis_modules}
        _write_json_atomic(ANALYSIS_CONFIG_PATH, config_data)
        logger.info("[API] Analysis config updated successfully")
        return AnalysisConfigResponse(analysis_modules=analysis_config.analysis_modules)
    except Exception as e:
        logger.error(f"[API] Failed to update analysis config: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to update analysis config: {str(e)}")


@router.get("/china/{ts_code}/analysis/{analysis_type}", response_model=AnalysisResponse)
async def generate_analysis(
    ts_code: str,
    analysis_type: str,
    company_name: str = Query(None, description="Company name for better context"),
):
    """
    Generate analysis for a company using Gemini AI

    Analysis types come from the analysis configuration; typical ones are:
    - fundamental_analysis (fundamental analysis)
    - bull_case (bull case)
    - bear_case (bear case)
    - market_analysis (market analysis)
    - news_analysis (news analysis)
    - trading_analysis (trading analysis)
    - insider_institutional (insider and institutional activity)
    - final_conclusion (final conclusion)

    Declared dependencies (transitively) are generated first and injected as context.
    """
    logger.info(f"[API] Analysis requested for {ts_code}, type: {analysis_type}")
    api_key, base_url = _require_llm_credentials()

    analysis_cfg = get_analysis_config(analysis_type)
    if not analysis_cfg:
        raise HTTPException(status_code=404, detail=f"Analysis type '{analysis_type}' not found in configuration")

    model = analysis_cfg.get("model", "gemini-2.5-flash")
    prompt_template = analysis_cfg.get("prompt_template", "")
    if not prompt_template:
        raise HTTPException(
            status_code=500, detail=f"Prompt template not found for analysis type '{analysis_type}'"
        )

    # Financial context is only fetched together with the company name (i.e. when the
    # caller did not supply a name).
    financial_data = None
    if not company_name:
        logger.info(f"[API] Fetching company name and financial data for {ts_code}")
        try:
            basic_data = await get_dm().get_stock_basic(stock_code=ts_code)
            if basic_data:
                company_name = basic_data.get("name", ts_code)
                logger.info(f"[API] Got company name: {company_name}")
                try:
                    # Last year's annual report as a single flat report.
                    last_annual = f"{datetime.now().year - 1}1231"
                    # NOTE(review): passes one date string despite the plural keyword;
                    # confirm against DataManager.get_financial_statement.
                    latest_financials_report = await get_dm().get_financial_statement(
                        stock_code=ts_code, report_dates=last_annual
                    )
                    if latest_financials_report:
                        financial_data = {"series": latest_financials_report}
                except Exception as e:
                    logger.warning(f"[API] Failed to get financial data: {e}")
                    financial_data = None
            else:
                company_name = ts_code
        except Exception as e:
            logger.warning(f"[API] Failed to get company name: {e}")
            company_name = ts_code

    logger.info(f"[API] Generating {analysis_type} for {company_name}")
    client = AnalysisClient(api_key=api_key, base_url=base_url, model=model)

    try:
        context = await _build_dependency_context(
            analysis_cfg,
            api_key=api_key,
            base_url=base_url,
            default_model=model,
            company_name=company_name,
            ts_code=ts_code,
            financial_data=financial_data,
        )
    except Exception as exc:
        # Best-effort context; continue without it.
        logger.warning("[API] Dependency context for %s failed, continuing without it: %s", analysis_type, exc)
        context = {}

    result = await client.generate_analysis(
        analysis_type=analysis_type,
        company_name=company_name,
        ts_code=ts_code,
        prompt_template=prompt_template,
        financial_data=financial_data,
        context=context,
    )
    logger.info(f"[API] Analysis generation completed, success={result.get('success')}")

    response = AnalysisResponse(
        ts_code=ts_code,
        company_name=company_name,
        analysis_type=analysis_type,
        content=result.get("content", ""),
        model=result.get("model", model),
        tokens=result.get("tokens", _empty_token_usage()),
        elapsed_ms=result.get("elapsed_ms", 0),
        success=result.get("success", False),
        error=result.get("error"),
    )

    if response.success:
        try:
            await _persist_analysis(ts_code, analysis_type, response)
        except Exception as e:
            logger.error(f"[Persistence] Failed to persist analysis result: {e}")

    return response


@router.get("/china/{ts_code}/snapshot", response_model=TodaySnapshotResponse)
async def get_today_snapshot(ts_code: str):
    """
    Return the "yesterday snapshot": the most recent stored trading day on or before
    the previous calendar day (server local time), with:
    - trade date (trade_date)
    - close price (close)
    - market cap (total_mv, raw value in units of 10k CNY)
    - valuation (pe, pb)
    - dividend yield (dv_ratio, %) — currently always None, because the stored daily
      records do not carry it

    If nothing is stored for the 10-day look-back window, one row is fetched from
    the data source and written to the persistence service first (best-effort).
    """
    try:
        # Company name is optional.
        company_name = None
        try:
            basic = await get_dm().get_stock_basic(stock_code=ts_code)
            if basic:
                company_name = basic.get("name")
        except Exception:
            company_name = None

        base_dt = (datetime.now() - timedelta(days=1)).date()
        base_str = base_dt.strftime("%Y%m%d")
        start_dt = base_dt - timedelta(days=10)

        dp = DataPersistenceClient()
        daily_list = await _load_daily_window(dp, ts_code, start_dt, base_dt)

        # Cache backfill: fetch from the data source, persist, then read back.
        if not isinstance(daily_list, list) or len(daily_list) == 0:
            try:
                await _backfill_cn_daily(dp, ts_code, base_str)
                daily_list = await _load_daily_window(dp, ts_code, start_dt, base_dt)
            except Exception:
                pass  # A failed backfill must not block the response.

        trade_date = base_str
        close = pe = pb = dv_ratio = total_mv = None
        last = _latest_daily_on_or_before(daily_list, base_str)
        if last is not None:
            trade_date = last.trade_date.strftime("%Y%m%d")
            close = last.close_price
            pe = last.pe
            pb = last.pb
            total_mv = last.total_mv

        return TodaySnapshotResponse(
            ts_code=ts_code,
            trade_date=trade_date,
            name=company_name,
            close=close,
            pe=pe,
            pb=pb,
            dv_ratio=dv_ratio,
            total_mv=total_mv,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to fetch snapshot: {e}")


@router.get("/{market}/{stock_code}/snapshot", response_model=TodaySnapshotResponse)
async def get_market_snapshot(market: MarketEnum, stock_code: str):
    """
    Market-agnostic "yesterday snapshot".

    Reads the most recent stored trading day on or before yesterday. If nothing is
    stored, daily quotes are fetched for all markets, the latest row is persisted
    and then read back. pe, pb and total_mv are returned only for CN; dv_ratio is
    always None.
    """
    try:
        company_name = None
        try:
            basic = await get_dm().get_stock_basic(stock_code=stock_code)
            if basic:
                company_name = basic.get("name")
        except Exception:
            company_name = None

        base_dt = (datetime.now() - timedelta(days=1)).date()
        base_str = base_dt.strftime("%Y%m%d")
        start_dt = base_dt - timedelta(days=10)

        dp = DataPersistenceClient()
        daily_list = await _load_daily_window(dp, stock_code, start_dt, base_dt)

        if not isinstance(daily_list, list) or len(daily_list) == 0:
            try:
                if await _backfill_daily_price(dp, stock_code, start_dt, base_dt, base_str):
                    daily_list = await _load_daily_window(dp, stock_code, start_dt, base_dt)
            except Exception:
                pass  # A failed backfill must not block the response.

        trade_date = base_str
        close = pe = pb = total_mv = None
        last = _latest_daily_on_or_before(daily_list, base_str)
        if last is not None:
            trade_date = last.trade_date.strftime("%Y%m%d")
            close = last.close_price
            pe = last.pe
            pb = last.pb
            total_mv = last.total_mv

        is_cn = market == MarketEnum.cn
        return TodaySnapshotResponse(
            ts_code=stock_code,
            trade_date=trade_date,
            name=company_name,
            close=close,
            pe=pe if is_cn else None,
            pb=pb if is_cn else None,
            dv_ratio=None,
            total_mv=total_mv if is_cn else None,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to fetch snapshot: {e}")


@router.get("/{market}/{stock_code}/realtime", response_model=RealTimeQuoteResponse)
async def get_realtime_quote(
    market: MarketEnum,
    stock_code: str,
    max_age_seconds: int = Query(30, ge=1, le=3600),
):
    """
    Real-time quote with a strict TTL. Returns 404 when the quote is missing or older
    than ``max_age_seconds``; no fallback fetch from data sources is performed.
    The data must be pre-warmed into the Rust persistence service by an external job.
    """
    try:
        dp = DataPersistenceClient()
        quote = await dp.get_latest_realtime_quote(market.value, stock_code, max_age_seconds=max_age_seconds)
        if not quote:
            raise HTTPException(status_code=404, detail="quote not found or stale")
        return RealTimeQuoteResponse(
            symbol=quote.symbol,
            market=quote.market,
            ts=quote.ts.isoformat(),
            price=quote.price,
            open_price=quote.open_price,
            high_price=quote.high_price,
            low_price=quote.low_price,
            prev_close=quote.prev_close,
            change=quote.change,
            change_percent=quote.change_percent,
            volume=quote.volume,
            source=quote.source,
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to fetch realtime quote: {e}")


@router.get("/china/{ts_code}/analysis/{analysis_type}/stream")
async def stream_analysis(
    ts_code: str,
    analysis_type: str,
    company_name: str = Query(None, description="Company name for better context"),
):
    """
    Stream analysis content chunks for a given module using OpenAI-compatible streaming.
    Plain text streaming (text/plain; utf-8). Dependencies are resolved first (non-stream),
    then the target module content is streamed.
    """
    logger.info(f"[API] Streaming analysis requested for {ts_code}, type: {analysis_type}")
    api_key, base_url = _require_llm_credentials()

    analysis_cfg = get_analysis_config(analysis_type)
    if not analysis_cfg:
        raise HTTPException(status_code=404, detail=f"Analysis type '{analysis_type}' not found in configuration")

    model = analysis_cfg.get("model", "gemini-2.5-flash")
    prompt_template = analysis_cfg.get("prompt_template", "")
    if not prompt_template:
        raise HTTPException(status_code=500, detail=f"Prompt template not found for analysis type '{analysis_type}'")

    # Only the company name is looked up here; full financials are not needed.
    financial_data = None
    if not company_name:
        try:
            basic_data = await get_dm().get_stock_basic(stock_code=ts_code)
            company_name = basic_data.get("name", ts_code) if basic_data else ts_code
        except Exception:
            company_name = ts_code

    try:
        context = await _build_dependency_context(
            analysis_cfg,
            api_key=api_key,
            base_url=base_url,
            default_model=model,
            company_name=company_name,
            ts_code=ts_code,
            financial_data=financial_data,
        )
    except Exception as exc:
        logger.warning("[API] Dependency context for %s failed, continuing without it: %s", analysis_type, exc)
        context = {}

    client = AnalysisClient(api_key=api_key, base_url=base_url, model=model)

    async def streamer():
        # Optional header line to help client-side UI
        yield f"# {analysis_cfg.get('name', analysis_type)}\n\n"
        async for chunk in client.generate_analysis_stream(
            analysis_type=analysis_type,
            company_name=company_name,
            ts_code=ts_code,
            prompt_template=prompt_template,
            financial_data=financial_data,
            context=context,
        ):
            yield chunk

    headers = {
        # Disable intermediary/proxy buffering so chunks reach the client promptly.
        "Cache-Control": "no-cache, no-transform",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(streamer(), media_type="text/plain; charset=utf-8", headers=headers)
@router.get("/{market}/{stock_code}/analysis/{analysis_type}/stream")
async def stream_analysis_market(
    market: MarketEnum,
    stock_code: str,
    analysis_type: str,
    company_name: str = Query(None, description="Company name for better context"),
):
    """Market-agnostic streaming analysis endpoint.

    Behaves exactly like the China-specific stream route; only the URL shape
    differs. ``market`` is accepted for routing but not used.
    """
    delegated = stream_analysis(
        ts_code=stock_code,
        analysis_type=analysis_type,
        company_name=company_name,
    )
    return await delegated