Fundamental_Analysis/backend/app/routers/financial.py
xucheng edfd51b0a7 feat: 昨日快照API与前端卡片;注册orgs路由;多项优化
- backend(financial): 新增 /china/{ts_code}/snapshot API,返回昨日交易日的收盘价/市值/PE/PB/股息率等

- backend(schemas): 新增 TodaySnapshotResponse

- backend(main): 注册 orgs 路由 /api/v1/orgs

- backend(providers:finnhub): 归一化财报字段并计算 gross_margin/net_margin/ROA/ROE

- backend(providers:tushare): 股东户数报告期与财报期对齐

- backend(routers/financial): years 默认改为 10(最大 10)

- config: analysis-config.json 切换到 qwen-flash-2025-07-28

- frontend(report/[symbol]): 新增“昨日快照”卡片、限制展示期数为10、优化增长与阈值高亮、修正类名与标题处理

- frontend(reports/[id]): 统一 period 变量与计算,修正表格 key

- frontend(hooks): 新增 useChinaSnapshot 钩子与类型

- scripts: dev.sh 增加调试输出
2025-11-05 17:00:32 +08:00

928 lines
37 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
API router for financial data (Tushare for China market)
"""
import json
import os
import time
from datetime import datetime, timezone, timedelta
from typing import Dict, List
from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import StreamingResponse
from app.core.config import settings
from app.schemas.financial import (
BatchFinancialDataResponse,
FinancialConfigResponse,
FinancialMeta,
StepRecord,
CompanyProfileResponse,
AnalysisResponse,
AnalysisConfigResponse,
TodaySnapshotResponse,
)
from app.services.company_profile_client import CompanyProfileClient
from app.services.analysis_client import AnalysisClient, load_analysis_config, get_analysis_config
# Lazy DataManager loader to avoid import-time failures when optional providers/config are missing
_dm = None


def get_dm():
    """Return the shared DataManager, importing it on first use.

    The import of ``app.data_manager`` is deferred to the first call so that a
    failure there (for example a missing optional provider or config) cannot
    break importing this router. If the import raises, the failure is logged
    and a minimal stub is cached and returned instead. The stub exposes an
    empty ``config`` and coroutines ``get_stock_basic`` (returns ``None``) and
    ``get_financial_statements`` (returns ``[]``).

    Returns:
        The real data manager, or the cached stub when it cannot be imported.
    """
    global _dm
    if _dm is not None:
        return _dm
    try:
        from app.data_manager import data_manager as real_dm
    except Exception as exc:
        # Do not fail silently: without this log line nothing explains why
        # every endpoint suddenly returns empty data.
        import logging

        logging.getLogger(__name__).warning(
            "DataManager could not be imported; using stub instead: %s", exc
        )

        class _StubDM:
            config = {}

            async def get_stock_basic(self, stock_code: str):
                return None

            async def get_financial_statements(self, stock_code: str, report_dates):
                return []

        _dm = _StubDM()
    else:
        _dm = real_dm
    return _dm
router = APIRouter()

# Configuration files live in <repo root>/config (the repo root, not backend/).
# This module is in backend/app/routers/, so the repo root is three levels up:
# routers/ -> app/ -> backend/ -> repo root
_ROUTERS_DIR = os.path.dirname(__file__)
REPO_ROOT = os.path.abspath(
    os.path.join(_ROUTERS_DIR, os.pardir, os.pardir, os.pardir)
)
_CONFIG_DIR = os.path.join(REPO_ROOT, "config")
FINANCIAL_CONFIG_PATH = os.path.join(_CONFIG_DIR, "financial-tushare.json")
BASE_CONFIG_PATH = os.path.join(_CONFIG_DIR, "config.json")
ANALYSIS_CONFIG_PATH = os.path.join(_CONFIG_DIR, "analysis-config.json")
def _load_json(path: str) -> Dict:
if not os.path.exists(path):
return {}
try:
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception:
return {}
@router.get("/data-sources", response_model=Dict[str, List[str]])
async def get_data_sources():
    """
    List the configured data sources that declare an ``api_key_env`` entry.

    Returns a dict with a single ``"sources"`` key holding the source names.
    Any failure while reading the data manager config is reported as HTTP 500.
    """
    try:
        configured = get_dm().config.get("data_sources", {})
        needing_keys = []
        for source_name, source_cfg in configured.items():
            if source_cfg.get("api_key_env"):
                needing_keys.append(source_name)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to load data sources configuration: {e}")
    return {"sources": needing_keys}
@router.post("/china/{ts_code}/analysis", response_model=List[AnalysisResponse])
async def generate_full_analysis(
    ts_code: str,
    company_name: str = Query(None, description="Company name for better context"),
):
    """
    Generate a full analysis report by running every configured analysis
    module in dependency order.

    Args:
        ts_code: Stock code taken from the URL path.
        company_name: Optional company name. When omitted it is looked up via
            the data manager, falling back to ``ts_code``.

    Returns:
        One ``AnalysisResponse`` per module, in execution order. Modules that
        fail are still included, with ``success`` taken from the client result.

    Raises:
        HTTPException: 500 when the LLM API key is not configured, 404 when no
            analysis modules are configured, 400 when the module dependencies
            contain a cycle.
    """
    import logging
    from collections import deque

    logger = logging.getLogger(__name__)
    logger.info(f"[API] Full analysis requested for {ts_code}")

    # Load base and analysis configurations
    base_cfg = _load_json(BASE_CONFIG_PATH)
    llm_provider = base_cfg.get("llm", {}).get("provider", "gemini")
    llm_config = base_cfg.get("llm", {}).get(llm_provider, {})
    api_key = llm_config.get("api_key")
    base_url = llm_config.get("base_url")
    if not api_key:
        logger.error(f"[API] API key for {llm_provider} not configured")
        raise HTTPException(
            status_code=500,
            detail=f"API key for {llm_provider} not configured."
        )

    analysis_config_full = load_analysis_config()
    modules_config = analysis_config_full.get("analysis_modules", {})
    if not modules_config:
        raise HTTPException(status_code=404, detail="Analysis modules configuration not found.")

    # --- Dependency Resolution (Topological Sort) ---
    def topological_sort(graph):
        """Order ``graph`` (node -> list of successor nodes) with Kahn's algorithm.

        Always returns a ``(order, cycles)`` pair so the caller can unpack it
        unconditionally: ``(sorted_nodes, [])`` on success, or
        ``(None, cycles)`` when the graph contains at least one cycle.
        """
        in_degree = {u: 0 for u in graph}
        for successors in graph.values():
            for v in successors:
                in_degree[v] += 1

        queue = deque(u for u in graph if in_degree[u] == 0)
        sorted_order = []
        while queue:
            u = queue.popleft()
            sorted_order.append(u)
            for v in graph.get(u, []):
                in_degree[v] -= 1
                if in_degree[v] == 0:
                    queue.append(v)

        if len(sorted_order) == len(graph):
            return sorted_order, []

        # Some nodes were never released: walk the graph depth-first to
        # collect the cycles for a meaningful error message.
        cycles = []
        visited = set()
        path = []

        def find_cycle_util(node):
            visited.add(node)
            path.append(node)
            for neighbor in graph.get(node, []):
                if neighbor in path:
                    cycle_start_index = path.index(neighbor)
                    cycles.append(path[cycle_start_index:] + [neighbor])
                elif neighbor not in visited:
                    find_cycle_util(neighbor)
            path.pop()

        for node in graph:
            if node not in visited:
                find_cycle_util(node)
        return None, cycles

    # Build dependency graph (module -> modules it depends on)
    dependency_graph = {
        name: config.get("dependencies", [])
        for name, config in modules_config.items()
    }
    # Invert graph for topological sort (from dependency to dependent).
    # Dependencies that are not configured modules are ignored.
    adj_list = {u: [] for u in dependency_graph}
    for u, dependencies in dependency_graph.items():
        for dep in dependencies:
            if dep in adj_list:
                adj_list[dep].append(u)

    sorted_modules, cycle = topological_sort(adj_list)
    if sorted_modules is None:
        raise HTTPException(
            status_code=400,
            detail=f"Circular dependency detected in analysis modules configuration. Cycle: {cycle}"
        )

    # --- Fetch common data (company name, financial data) ---
    # This logic is duplicated, could be refactored into a helper
    financial_data = None
    if not company_name:
        logger.info(f"[API] Fetching company name for {ts_code}")
        try:
            basic_data = await get_dm().get_stock_basic(stock_code=ts_code)
            if basic_data:
                company_name = basic_data.get("name", ts_code)
                logger.info(f"[API] Got company name: {company_name}")
            else:
                company_name = ts_code
        except Exception as e:
            logger.warning(f"Failed to get company name, proceeding with ts_code. Error: {e}")
            company_name = ts_code

    # --- Execute modules in order ---
    analysis_results = []
    completed_modules_content = {}
    for module_type in sorted_modules:
        module_config = modules_config[module_type]
        logger.info(f"[Orchestrator] Starting analysis for module: {module_type}")

        # Resolve the model once so the client and the response fallback agree.
        module_model = module_config.get("model", "gemini-1.5-flash")
        client = AnalysisClient(
            api_key=api_key,
            base_url=base_url,
            model=module_model
        )

        # Gather context from completed dependencies
        context = {
            dep: completed_modules_content.get(dep, "")
            for dep in module_config.get("dependencies", [])
        }

        result = await client.generate_analysis(
            analysis_type=module_type,
            company_name=company_name,
            ts_code=ts_code,
            prompt_template=module_config.get("prompt_template", ""),
            financial_data=financial_data,
            context=context,
        )

        response = AnalysisResponse(
            ts_code=ts_code,
            company_name=company_name,
            analysis_type=module_type,
            content=result.get("content", ""),
            model=result.get("model", module_model),
            tokens=result.get("tokens", {}),
            elapsed_ms=result.get("elapsed_ms", 0),
            success=result.get("success", False),
            error=result.get("error")
        )
        analysis_results.append(response)

        if response.success:
            completed_modules_content[module_type] = response.content
        else:
            # If a module fails, dependent modules receive an error marker as
            # its context instead of real content. This prevents total failure
            # but may affect quality.
            completed_modules_content[module_type] = f"Error: Analysis for {module_type} failed."
            logger.error(f"[Orchestrator] Module {module_type} failed: {response.error}")

    logger.info(f"[API] Full analysis for {ts_code} completed.")
    return analysis_results
@router.get("/config", response_model=FinancialConfigResponse)
async def get_financial_config():
    """Return the ``api_groups`` section of the financial metric config file."""
    financial_cfg = _load_json(FINANCIAL_CONFIG_PATH)
    return FinancialConfigResponse(api_groups=financial_cfg.get("api_groups", {}))
def _merge_market_rows(series, rows, latest_report_period):
    """Fold per-trade-date market rows into ``series``, keyed by reporting period.

    Each numeric field of a row becomes a ``{"period", "value"}`` point under
    ``series[field]``. A row whose trade date equals ``latest_report_period``
    keeps that period; every other row is attributed to the year-end period
    ``YYYY1231`` of its trade date. A point that already exists for the same
    period is replaced. ``rows`` that is not a list is ignored.
    """
    if not isinstance(rows, list):
        return
    id_keys = ("ts_code", "trade_date", "trade_dt", "date")
    for row in rows:
        trade_date = row.get("trade_date") or row.get("trade_dt") or row.get("date")
        if not trade_date:
            continue
        trade_date = str(trade_date)
        if latest_report_period and trade_date == latest_report_period:
            period = latest_report_period
        else:
            period = f"{trade_date[:4]}1231"
        for key, value in row.items():
            if key in id_keys or not isinstance(value, (int, float)):
                continue
            points = series.setdefault(key, [])
            point = {"period": period, "value": value}
            for idx, existing in enumerate(points):
                if existing.get("period") == period:
                    points[idx] = point
                    break
            else:
                points.append(point)


@router.get("/china/{ts_code}", response_model=BatchFinancialDataResponse)
async def get_china_financials(
    ts_code: str,
    years: int = Query(10, ge=1, le=10),
):
    """
    Return financial-statement series, plus market data when configured, for a stock.

    Args:
        ts_code: Stock code taken from the URL path.
        years: Number of most recent periods to keep per metric (1-10).

    Returns:
        ``BatchFinancialDataResponse`` whose ``series`` maps each metric key
        to its points sorted by ascending period, with timing metadata.

    Raises:
        HTTPException: 502 when neither the statements nor the market data
            produced any series; collected errors are in the detail.
    """
    # Load metric config
    fin_cfg = _load_json(FINANCIAL_CONFIG_PATH)
    api_groups: Dict[str, List[Dict]] = fin_cfg.get("api_groups", {})

    # Meta tracking
    started_real = datetime.now(timezone.utc)
    started = time.perf_counter_ns()
    api_calls_total = 0  # This will be harder to track now, maybe DataManager should provide it
    api_calls_by_group: Dict[str, int] = {}
    steps: List[StepRecord] = []
    errors: Dict[str, str] = {}

    # Get company name (best effort: fall back to the code itself)
    company_name = ts_code
    try:
        basic_data = await get_dm().get_stock_basic(stock_code=ts_code)
        if basic_data:
            company_name = basic_data.get("name", ts_code)
    except Exception:
        pass  # Continue without it

    # Year-end report dates from (current_year - years) through current_year
    current_year = datetime.now().year
    report_dates = [f"{year}1231" for year in range(current_year - years, current_year + 1)]

    # --- Step 1: fetch all financial statements at once ---
    step_financials = StepRecord(name="拉取财务报表", start_ts=started_real.isoformat(), status="running")
    steps.append(step_financials)
    step_started = time.perf_counter_ns()
    series: Dict[str, List[Dict]] = {}
    try:
        fetched = await get_dm().get_financial_statements(stock_code=ts_code, report_dates=report_dates)
        if isinstance(fetched, dict):
            # Copy the containers so merging market data below does not
            # mutate the object handed out by the data manager.
            series = {key: list(points or []) for key, points in fetched.items()}
    except Exception as e:
        errors["financial_statements"] = f"Failed to fetch financial statements: {e}"
    if not series:
        errors.setdefault("financial_statements", "Failed to fetch from all providers.")

    # Latest non-year-end report period of the current year (if any); market
    # data is additionally requested for this date.
    latest_current_year_report = None
    current_year_str = str(current_year)
    for points in series.values():
        for item in points:
            period = str(item.get("period") or "")
            if period.startswith(current_year_str) and not period.endswith("1231"):
                if latest_current_year_report is None or period > latest_current_year_report:
                    latest_current_year_report = period

    step_financials.status = "done"
    step_financials.end_ts = datetime.now(timezone.utc).isoformat()
    step_financials.duration_ms = int((time.perf_counter_ns() - step_started) / 1_000_000)

    # --- Step 2: fetch market cap / valuation (daily_basic) and price (daily) ---
    try:
        # Only attempt the fetch when the config contains the matching group
        has_daily_basic = bool(api_groups.get("daily_basic"))
        has_daily = bool(api_groups.get("daily"))
        if has_daily_basic or has_daily:
            step_market = StepRecord(name="拉取市值与股价", start_ts=datetime.now(timezone.utc).isoformat(), status="running")
            steps.append(step_market)
            step_started = time.perf_counter_ns()

            # Query dates: the year-end dates plus the latest current-year
            # report period, provided it parses as a YYYYMMDD date.
            market_dates = list(report_dates)
            if latest_current_year_report:
                try:
                    datetime.strptime(latest_current_year_report, '%Y%m%d')
                except ValueError:
                    pass
                else:
                    market_dates.append(latest_current_year_report)

            try:
                if has_daily_basic:
                    db_rows = await get_dm().get_data('get_daily_basic_points', stock_code=ts_code, trade_dates=market_dates)
                    _merge_market_rows(series, db_rows, latest_current_year_report)
                if has_daily:
                    d_rows = await get_dm().get_data('get_daily_points', stock_code=ts_code, trade_dates=market_dates)
                    _merge_market_rows(series, d_rows, latest_current_year_report)
            except Exception as e:
                errors["market_data"] = f"Failed to fetch market data: {e}"
            finally:
                step_market.status = "done"
                step_market.end_ts = datetime.now(timezone.utc).isoformat()
                step_market.duration_ms = int((time.perf_counter_ns() - step_started) / 1_000_000)
    except Exception as e:
        errors["market_data_init"] = f"Market data init failed: {e}"

    finished_real = datetime.now(timezone.utc)
    elapsed_ms = int((time.perf_counter_ns() - started) / 1_000_000)

    if not series:
        raise HTTPException(status_code=502, detail={"message": "No data returned from any data source", "errors": errors})

    # Deduplicate by period (last point wins), keep the `years` most recent
    # periods and return them in ascending order.
    for key, arr in series.items():
        uniq = {item["period"]: item for item in arr}
        ascending = sorted(uniq.values(), key=lambda x: x["period"])
        series[key] = ascending[-years:]

    # Note: Derived financial metrics calculation has been moved to individual data providers
    # The data_manager.get_financial_statements() should handle this
    meta = FinancialMeta(
        started_at=started_real.isoformat(),
        finished_at=finished_real.isoformat(),
        elapsed_ms=elapsed_ms,
        api_calls_total=api_calls_total,
        api_calls_by_group=api_calls_by_group,
        current_action=None,
        steps=steps,
    )
    return BatchFinancialDataResponse(ts_code=ts_code, name=company_name, series=series, meta=meta)
@router.get("/china/{ts_code}/company-profile", response_model=CompanyProfileResponse)
async def get_company_profile(
    ts_code: str,
    company_name: str = Query(None, description="Company name for better context"),
):
    """
    Generate a company profile with the configured LLM (non-streaming, single response).

    Args:
        ts_code: Stock code taken from the URL path.
        company_name: Optional company name. When omitted it is looked up via
            the data manager, falling back to ``ts_code``.

    Returns:
        ``CompanyProfileResponse`` built from the client result; missing
        fields fall back to empty/zero defaults.

    Raises:
        HTTPException: 500 when the LLM API key is not configured.
    """
    import logging
    logger = logging.getLogger(__name__)
    logger.info(f"[API] Company profile requested for {ts_code}")

    # Load config
    base_cfg = _load_json(BASE_CONFIG_PATH)
    llm_provider = base_cfg.get("llm", {}).get("provider", "gemini")
    llm_config = base_cfg.get("llm", {}).get(llm_provider, {})
    api_key = llm_config.get("api_key")
    base_url = llm_config.get("base_url")  # Will be None if not set, handled by client
    if not api_key:
        logger.error(f"[API] API key for {llm_provider} not configured")
        raise HTTPException(
            status_code=500,
            detail=f"API key for {llm_provider} not configured."
        )

    # Single source of truth for the model: the response fallback below must
    # report the model the client was actually created with.
    profile_model = "gemini-1.5-flash"
    client = CompanyProfileClient(
        api_key=api_key,
        base_url=base_url,
        model=profile_model
    )

    # Get company name from ts_code if not provided
    if not company_name:
        logger.info(f"[API] Fetching company name for {ts_code}")
        # Try to get from stock_basic API
        try:
            basic_data = await get_dm().get_stock_basic(stock_code=ts_code)
            if basic_data:
                company_name = basic_data.get("name", ts_code)
                logger.info(f"[API] Got company name: {company_name}")
            else:
                company_name = ts_code
        except Exception as e:
            logger.warning(f"[API] Failed to get company name: {e}")
            company_name = ts_code

    logger.info(f"[API] Generating profile for {company_name}")
    # Generate profile using non-streaming API
    result = await client.generate_profile(
        company_name=company_name,
        ts_code=ts_code,
        financial_data=None
    )
    logger.info(f"[API] Profile generation completed, success={result.get('success')}")

    return CompanyProfileResponse(
        ts_code=ts_code,
        company_name=company_name,
        content=result.get("content", ""),
        model=result.get("model", profile_model),
        tokens=result.get("tokens", {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}),
        elapsed_ms=result.get("elapsed_ms", 0),
        success=result.get("success", False),
        error=result.get("error")
    )
@router.get("/analysis-config", response_model=AnalysisConfigResponse)
async def get_analysis_config_endpoint():
    """Return the ``analysis_modules`` section of the analysis configuration."""
    modules = load_analysis_config().get("analysis_modules", {})
    return AnalysisConfigResponse(analysis_modules=modules)
@router.put("/analysis-config", response_model=AnalysisConfigResponse)
async def update_analysis_config_endpoint(analysis_config: AnalysisConfigResponse):
    """
    Overwrite the analysis configuration file with the submitted modules.

    Args:
        analysis_config: Request body; its ``analysis_modules`` is written to
            ``ANALYSIS_CONFIG_PATH`` under the ``"analysis_modules"`` key.

    Returns:
        ``AnalysisConfigResponse`` echoing the stored modules.

    Raises:
        HTTPException: 500 when serializing or writing the file fails.
    """
    import logging
    logger = logging.getLogger(__name__)
    try:
        config_data = {
            "analysis_modules": analysis_config.analysis_modules
        }
        # Serialize BEFORE opening the file: opening with "w" truncates it, so
        # a serialization error must not be able to leave a half-written or
        # empty config behind.
        payload = json.dumps(config_data, ensure_ascii=False, indent=2)
        with open(ANALYSIS_CONFIG_PATH, "w", encoding="utf-8") as f:
            f.write(payload)
        logger.info("[API] Analysis config updated successfully")
        return AnalysisConfigResponse(analysis_modules=analysis_config.analysis_modules)
    except Exception as e:
        logger.error(f"[API] Failed to update analysis config: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to update analysis config: {str(e)}"
        ) from e
@router.get("/china/{ts_code}/analysis/{analysis_type}", response_model=AnalysisResponse)
async def generate_analysis(
    ts_code: str,
    analysis_type: str,
    company_name: str = Query(None, description="Company name for better context"),
):
    """
    Generate a single analysis module for a company with the configured LLM.

    Modules the requested one depends on (directly or transitively) are
    generated first, dependencies before dependents, and the outputs of the
    direct dependencies are passed to the requested module as context.

    Supported analysis types:
    - fundamental_analysis (fundamental analysis)
    - bull_case (bullish case)
    - bear_case (bearish case)
    - market_analysis (market analysis)
    - news_analysis (news analysis)
    - trading_analysis (trading analysis)
    - insider_institutional (insider and institutional activity)
    - final_conclusion (final conclusion)

    Raises:
        HTTPException: 500 when the LLM API key or the prompt template is
            missing, 404 when ``analysis_type`` is not configured.
    """
    import logging
    from collections import deque

    logger = logging.getLogger(__name__)
    logger.info(f"[API] Analysis requested for {ts_code}, type: {analysis_type}")

    # Load config
    base_cfg = _load_json(BASE_CONFIG_PATH)
    llm_provider = base_cfg.get("llm", {}).get("provider", "gemini")
    llm_config = base_cfg.get("llm", {}).get(llm_provider, {})
    api_key = llm_config.get("api_key")
    base_url = llm_config.get("base_url")
    if not api_key:
        logger.error(f"[API] API key for {llm_provider} not configured")
        raise HTTPException(
            status_code=500,
            detail=f"API key for {llm_provider} not configured."
        )

    # Get analysis configuration
    analysis_cfg = get_analysis_config(analysis_type)
    if not analysis_cfg:
        raise HTTPException(
            status_code=404,
            detail=f"Analysis type '{analysis_type}' not found in configuration"
        )
    model = analysis_cfg.get("model", "gemini-2.5-flash")
    prompt_template = analysis_cfg.get("prompt_template", "")
    if not prompt_template:
        raise HTTPException(
            status_code=500,
            detail=f"Prompt template not found for analysis type '{analysis_type}'"
        )

    # Get company name from ts_code if not provided.
    # NOTE(review): financial data is only fetched on this path, so a caller
    # that supplies company_name gets no financial context - confirm intent.
    financial_data = None
    if not company_name:
        logger.info(f"[API] Fetching company name and financial data for {ts_code}")
        try:
            basic_data = await get_dm().get_stock_basic(stock_code=ts_code)
            if basic_data:
                company_name = basic_data.get("name", ts_code)
                logger.info(f"[API] Got company name: {company_name}")
                # Try to get financial data for context
                try:
                    # A simplified approach: last year's annual report only
                    last_annual_period = f"{datetime.now().year - 1}1231"
                    latest_financials_report = await get_dm().get_financial_statement(
                        stock_code=ts_code,
                        report_dates=last_annual_period
                    )
                    if latest_financials_report:
                        financial_data = {"series": latest_financials_report}
                except Exception as e:
                    logger.warning(f"[API] Failed to get financial data: {e}")
                    financial_data = None
            else:
                company_name = ts_code
        except Exception as e:
            logger.warning(f"[API] Failed to get company name: {e}")
            company_name = ts_code

    logger.info(f"[API] Generating {analysis_type} for {company_name}")
    # Initialize analysis client with configured model
    client = AnalysisClient(api_key=api_key, base_url=base_url, model=model)

    # Prepare dependency context for single-module generation.
    # If the requested module declares dependencies, generate them first and inject their outputs.
    context = {}
    try:
        dependencies = analysis_cfg.get("dependencies", []) or []
        if dependencies:
            # Load full modules config to resolve the dependency graph
            modules_config = load_analysis_config().get("analysis_modules", {})

            def deps_of(mod_name: str):
                return modules_config.get(mod_name, {}).get("dependencies", []) or []

            # Collect all transitive dependencies
            all_required = set()
            stack = list(dependencies)
            while stack:
                mod = stack.pop()
                if mod not in all_required:
                    all_required.add(mod)
                    stack.extend(deps_of(mod))

            # Topological order with dependencies BEFORE dependents, so each
            # module can receive the finished output of what it depends on.
            unmet = {name: set(deps_of(name)) for name in sorted(all_required)}
            dependents = {name: [] for name in unmet}
            for name, deps in unmet.items():
                for dep in deps:
                    dependents[dep].append(name)
            ready = deque(name for name, deps in unmet.items() if not deps)
            order = []
            while ready:
                done = ready.popleft()
                order.append(done)
                for name in dependents[done]:
                    unmet[name].discard(done)
                    if not unmet[name]:
                        ready.append(name)
            if len(order) != len(unmet):
                # Fallback: if a cycle is detected, just use any order
                order = list(unmet)

            # Generate dependencies in order
            completed = {}
            for mod in order:
                cfg = modules_config.get(mod, {})
                dep_ctx = {d: completed.get(d, "") for d in deps_of(mod)}
                dep_client = AnalysisClient(api_key=api_key, base_url=base_url, model=cfg.get("model", model))
                dep_result = await dep_client.generate_analysis(
                    analysis_type=mod,
                    company_name=company_name,
                    ts_code=ts_code,
                    prompt_template=cfg.get("prompt_template", ""),
                    financial_data=financial_data,
                    context=dep_ctx,
                )
                # A failed dependency contributes an empty string
                completed[mod] = dep_result.get("content", "") if dep_result.get("success") else ""
            context = {dep: completed.get(dep, "") for dep in dependencies}
    except Exception:
        # Best-effort context; if anything goes wrong, continue without it
        context = {}

    # Generate analysis
    result = await client.generate_analysis(
        analysis_type=analysis_type,
        company_name=company_name,
        ts_code=ts_code,
        prompt_template=prompt_template,
        financial_data=financial_data,
        context=context,
    )
    logger.info(f"[API] Analysis generation completed, success={result.get('success')}")

    return AnalysisResponse(
        ts_code=ts_code,
        company_name=company_name,
        analysis_type=analysis_type,
        content=result.get("content", ""),
        model=result.get("model", model),
        tokens=result.get("tokens", {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}),
        elapsed_ms=result.get("elapsed_ms", 0),
        success=result.get("success", False),
        error=result.get("error")
    )
@router.get("/china/{ts_code}/snapshot", response_model=TodaySnapshotResponse)
async def get_today_snapshot(ts_code: str):
    """
    Return the "yesterday snapshot" of market data for a stock.

    The query date is the previous calendar day in China Standard Time
    (UTC+8), independent of the server's local timezone. Per the original
    author's note, the provider maps it to the latest trading day not after
    that date.

    Fields (units as stated in the original note):
    - trade_date: date of the returned record
    - close: closing price
    - total_mv: market cap, raw value in units of 10,000 CNY
    - pe, pb: valuation ratios
    - dv_ratio: dividend yield, in percent

    Raises:
        HTTPException: 500 when fetching the snapshot fails.
    """
    try:
        # Company name is optional; failures here are ignored
        company_name = None
        try:
            basic = await get_dm().get_stock_basic(stock_code=ts_code)
            if basic:
                company_name = basic.get("name")
        except Exception:
            company_name = None

        # "Yesterday" must follow the China market calendar, not the host
        # clock: on a UTC host the naive local date lags by up to 8 hours.
        china_tz = timezone(timedelta(hours=8))
        base_dt = (datetime.now(china_tz) - timedelta(days=1)).date()
        base_str = base_dt.strftime("%Y%m%d")

        # daily_basic carries the main fields: close, pe, pb, dv_ratio, total_mv
        rows = await get_dm().get_data(
            'get_daily_basic_points',
            stock_code=ts_code,
            trade_dates=[base_str]
        )
        row = None
        if isinstance(rows, list) and rows:
            # Only one date was requested, so the first record is used
            row = rows[0]

        trade_date = None
        close = None
        pe = None
        pb = None
        dv_ratio = None
        total_mv = None
        if isinstance(row, dict):
            trade_date = str(row.get('trade_date') or row.get('trade_dt') or row.get('date') or base_str)
            close = row.get('close')
            pe = row.get('pe')
            pb = row.get('pb')
            dv_ratio = row.get('dv_ratio')
            total_mv = row.get('total_mv')

        # Fallback: when close is missing, take the closing price from daily
        if close is None:
            d_rows = await get_dm().get_data('get_daily_points', stock_code=ts_code, trade_dates=[base_str])
            if isinstance(d_rows, list) and d_rows:
                d = d_rows[0]
                close = d.get('close')
                if trade_date is None:
                    trade_date = str(d.get('trade_date') or d.get('trade_dt') or d.get('date') or base_str)

        if trade_date is None:
            trade_date = base_str

        return TodaySnapshotResponse(
            ts_code=ts_code,
            trade_date=trade_date,
            name=company_name,
            close=close,
            pe=pe,
            pb=pb,
            dv_ratio=dv_ratio,
            total_mv=total_mv,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to fetch snapshot: {e}")
@router.get("/china/{ts_code}/analysis/{analysis_type}/stream")
async def stream_analysis(
    ts_code: str,
    analysis_type: str,
    company_name: str = Query(None, description="Company name for better context"),
):
    """
    Stream the content of one analysis module as plain text.

    The response is ``text/plain; charset=utf-8``. Modules the requested one
    depends on (directly or transitively) are generated first without
    streaming, dependencies before dependents; the target module is then
    streamed chunk by chunk, preceded by a Markdown heading line.

    Raises:
        HTTPException: 500 when the LLM API key or the prompt template is
            missing, 404 when ``analysis_type`` is not configured.
    """
    import logging
    from collections import deque

    logger = logging.getLogger(__name__)
    logger.info(f"[API] Streaming analysis requested for {ts_code}, type: {analysis_type}")

    # Load config
    base_cfg = _load_json(BASE_CONFIG_PATH)
    llm_provider = base_cfg.get("llm", {}).get("provider", "gemini")
    llm_config = base_cfg.get("llm", {}).get(llm_provider, {})
    api_key = llm_config.get("api_key")
    base_url = llm_config.get("base_url")
    if not api_key:
        logger.error(f"[API] API key for {llm_provider} not configured")
        raise HTTPException(status_code=500, detail=f"API key for {llm_provider} not configured.")

    # Get analysis configuration
    analysis_cfg = get_analysis_config(analysis_type)
    if not analysis_cfg:
        raise HTTPException(status_code=404, detail=f"Analysis type '{analysis_type}' not found in configuration")
    model = analysis_cfg.get("model", "gemini-2.5-flash")
    prompt_template = analysis_cfg.get("prompt_template", "")
    if not prompt_template:
        raise HTTPException(status_code=500, detail=f"Prompt template not found for analysis type '{analysis_type}'")

    # Get company name from ts_code if not provided; we don't need full financials here
    financial_data = None
    if not company_name:
        try:
            basic_data = await get_dm().get_stock_basic(stock_code=ts_code)
            if basic_data:
                company_name = basic_data.get("name", ts_code)
            else:
                company_name = ts_code
        except Exception:
            company_name = ts_code

    # Resolve dependency context (non-streaming)
    context = {}
    try:
        dependencies = analysis_cfg.get("dependencies", []) or []
        if dependencies:
            modules_config = load_analysis_config().get("analysis_modules", {})

            def deps_of(mod_name: str):
                return modules_config.get(mod_name, {}).get("dependencies", []) or []

            # Collect all transitive dependencies
            all_required = set()
            stack = list(dependencies)
            while stack:
                mod = stack.pop()
                if mod not in all_required:
                    all_required.add(mod)
                    stack.extend(deps_of(mod))

            # Topological order with dependencies BEFORE dependents, so each
            # module can receive the finished output of what it depends on.
            unmet = {name: set(deps_of(name)) for name in sorted(all_required)}
            dependents = {name: [] for name in unmet}
            for name, deps in unmet.items():
                for dep in deps:
                    dependents[dep].append(name)
            ready = deque(name for name, deps in unmet.items() if not deps)
            order = []
            while ready:
                done = ready.popleft()
                order.append(done)
                for name in dependents[done]:
                    unmet[name].discard(done)
                    if not unmet[name]:
                        ready.append(name)
            if len(order) != len(unmet):
                # Cycle detected: fall back to an arbitrary order
                order = list(unmet)

            completed = {}
            for mod in order:
                cfg = modules_config.get(mod, {})
                dep_ctx = {d: completed.get(d, "") for d in deps_of(mod)}
                dep_client = AnalysisClient(api_key=api_key, base_url=base_url, model=cfg.get("model", model))
                dep_result = await dep_client.generate_analysis(
                    analysis_type=mod,
                    company_name=company_name,
                    ts_code=ts_code,
                    prompt_template=cfg.get("prompt_template", ""),
                    financial_data=financial_data,
                    context=dep_ctx,
                )
                # A failed dependency contributes an empty string
                completed[mod] = dep_result.get("content", "") if dep_result.get("success") else ""
            context = {dep: completed.get(dep, "") for dep in dependencies}
    except Exception:
        # Best-effort context; if anything goes wrong, stream without it
        context = {}

    client = AnalysisClient(api_key=api_key, base_url=base_url, model=model)

    async def streamer():
        # Optional header line to help client-side UI
        header = f"# {analysis_cfg.get('name', analysis_type)}\n\n"
        yield header
        async for chunk in client.generate_analysis_stream(
            analysis_type=analysis_type,
            company_name=company_name,
            ts_code=ts_code,
            prompt_template=prompt_template,
            financial_data=financial_data,
            context=context,
        ):
            yield chunk

    headers = {
        # Disable buffering in intermediaries so chunks reach the client promptly
        "Cache-Control": "no-cache, no-transform",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(streamer(), media_type="text/plain; charset=utf-8", headers=headers)