- backend(financial): 新增 /china/{ts_code}/snapshot API,返回昨日交易日的收盘价/市值/PE/PB/股息率等
- backend(schemas): 新增 TodaySnapshotResponse
- backend(main): 注册 orgs 路由 /api/v1/orgs
- backend(providers:finnhub): 归一化财报字段并计算 gross_margin/net_margin/ROA/ROE
- backend(providers:tushare): 股东户数报告期与财报期对齐
- backend(routers/financial): years 默认改为 10(最大 10)
- config: analysis-config.json 切换到 qwen-flash-2025-07-28
- frontend(report/[symbol]): 新增“昨日快照”卡片、限制展示期数为10、优化增长与阈值高亮、修正类名与标题处理
- frontend(reports/[id]): 统一 period 变量与计算,修正表格 key
- frontend(hooks): 新增 useChinaSnapshot 钩子与类型
- scripts: dev.sh 增加调试输出
928 lines
37 KiB
Python
928 lines
37 KiB
Python
"""
API router for financial data (Tushare for China market)
"""
|
||
import json
import logging
import os
import time
from collections import deque
from datetime import datetime, timezone, timedelta
from typing import Dict, List

from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import StreamingResponse

from app.core.config import settings
from app.schemas.financial import (
    BatchFinancialDataResponse,
    FinancialConfigResponse,
    FinancialMeta,
    StepRecord,
    CompanyProfileResponse,
    AnalysisResponse,
    AnalysisConfigResponse,
    TodaySnapshotResponse,
)
from app.services.company_profile_client import CompanyProfileClient
from app.services.analysis_client import AnalysisClient, load_analysis_config, get_analysis_config
|
||
|
||
# Lazily resolved DataManager. Importing app.data_manager at module import
# time can fail when optional providers/config are missing, so the import is
# deferred until the first call to get_dm() and the result is cached here.
_dm = None


def get_dm():
    """Return the cached DataManager, importing it on first use.

    If ``app.data_manager`` cannot be imported for any reason, a stub object
    is cached and returned instead so this router keeps working in a degraded
    mode. The stub provides every coroutine this module calls on the manager
    and answers each of them with an empty result.

    Returns:
        The real ``data_manager`` object, or the fallback stub.
    """
    global _dm
    if _dm is not None:
        return _dm

    try:
        from app.data_manager import data_manager as real_dm
    except Exception as exc:
        # Deliberately broad: any import-time failure (missing provider,
        # broken config, ...) must not prevent the router from loading.
        logging.getLogger(__name__).warning(
            "DataManager unavailable, falling back to stub: %s", exc
        )

        class _StubDM:
            """Minimal stand-in that reports "no data" for every query."""

            config = {}

            async def get_stock_basic(self, stock_code: str):
                return None

            async def get_financial_statements(self, stock_code: str, report_dates):
                return []

            async def get_financial_statement(self, stock_code: str, report_dates):
                return None

            async def get_data(self, method_name: str, **kwargs):
                return []

        _dm = _StubDM()
    else:
        _dm = real_dm
    return _dm
|
||
|
||
# Router instance that every endpoint in this module registers itself on.
router = APIRouter()
|
||
|
||
# Config files are stored under <repo root>/config/ (the repo root, not
# backend/). Walking up from this file: routers/ -> app/ -> backend/ -> repo root.
_ROUTERS_DIR = os.path.dirname(__file__)
REPO_ROOT = os.path.abspath(
    os.path.join(_ROUTERS_DIR, os.pardir, os.pardir, os.pardir)
)
_CONFIG_DIR = os.path.join(REPO_ROOT, "config")
FINANCIAL_CONFIG_PATH = os.path.join(_CONFIG_DIR, "financial-tushare.json")
BASE_CONFIG_PATH = os.path.join(_CONFIG_DIR, "config.json")
ANALYSIS_CONFIG_PATH = os.path.join(_CONFIG_DIR, "analysis-config.json")
|
||
|
||
|
||
def _load_json(path: str) -> Dict:
    """Load a JSON object from ``path``, falling back to an empty dict.

    Args:
        path: Location of the JSON file.

    Returns:
        The parsed top-level object. ``{}`` is returned when the file is
        missing or unreadable, when it does not contain valid JSON, or when
        its top-level value is not an object (callers use ``.get`` on the
        result, so anything other than a dict would break them).
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError: malformed JSON or
        # undecodable bytes (JSONDecodeError and UnicodeDecodeError).
        return {}
    return data if isinstance(data, dict) else {}
|
||
|
||
|
||
@router.get("/data-sources", response_model=Dict[str, List[str]])
async def get_data_sources():
    """
    List the configured data sources that declare an ``api_key_env`` entry.

    The result has the shape ``{"sources": [<source name>, ...]}``. Any
    failure while reading the configuration is reported as HTTP 500.
    """
    try:
        configured = get_dm().config.get("data_sources", {})
        needs_key = []
        for source_name, source_cfg in configured.items():
            if source_cfg.get("api_key_env"):
                needs_key.append(source_name)
        return {"sources": needs_key}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to load data sources configuration: {e}")
|
||
|
||
|
||
def _topological_sort_modules(graph):
    """Order the nodes of ``graph`` so that each node precedes its successors.

    Args:
        graph: Mapping of node -> list of successor nodes. Every successor
            must itself be a key of ``graph``.

    Returns:
        A ``(order, cycles)`` tuple. When the graph is acyclic, ``order``
        lists every node and ``cycles`` is empty. Otherwise ``order`` is
        ``None`` and ``cycles`` holds the cycles that were found, each as a
        list of nodes whose first and last entries are the same node.
    """
    in_degree = {node: 0 for node in graph}
    for successors in graph.values():
        for successor in successors:
            in_degree[successor] += 1

    # Kahn's algorithm: repeatedly emit nodes with no unmet predecessors.
    queue = deque(node for node in graph if in_degree[node] == 0)
    order = []
    while queue:
        node = queue.popleft()
        order.append(node)
        for successor in graph.get(node, []):
            in_degree[successor] -= 1
            if in_degree[successor] == 0:
                queue.append(successor)

    if len(order) == len(graph):
        return order, []

    # Some nodes were never emitted, so there is a cycle. Walk the graph
    # depth-first to report which nodes take part in it.
    cycles = []
    visited = set()
    path = []

    def visit(node):
        visited.add(node)
        path.append(node)
        for neighbor in graph.get(node, []):
            if neighbor in path:
                cycles.append(path[path.index(neighbor):] + [neighbor])
            elif neighbor not in visited:
                visit(neighbor)
        path.pop()

    for node in graph:
        if node not in visited:
            visit(node)

    return None, cycles


@router.post("/china/{ts_code}/analysis", response_model=List[AnalysisResponse])
async def generate_full_analysis(
    ts_code: str,
    company_name: str = Query(None, description="Company name for better context"),
):
    """
    Generate a full analysis report by running every configured analysis
    module in dependency order.

    Each module receives the content produced by the modules it depends on.
    A failed module does not abort the run; its dependents receive an error
    marker in place of its content.

    Raises:
        HTTPException: 500 when the LLM API key is not configured, 404 when
            no analysis modules are configured, 400 when the module
            dependencies contain a cycle.
    """
    logger = logging.getLogger(__name__)

    logger.info(f"[API] Full analysis requested for {ts_code}")

    # Load base and analysis configurations
    base_cfg = _load_json(BASE_CONFIG_PATH)
    llm_provider = base_cfg.get("llm", {}).get("provider", "gemini")
    llm_config = base_cfg.get("llm", {}).get(llm_provider, {})

    api_key = llm_config.get("api_key")
    base_url = llm_config.get("base_url")

    if not api_key:
        logger.error(f"[API] API key for {llm_provider} not configured")
        raise HTTPException(
            status_code=500,
            detail=f"API key for {llm_provider} not configured."
        )

    analysis_config_full = load_analysis_config()
    modules_config = analysis_config_full.get("analysis_modules", {})
    if not modules_config:
        raise HTTPException(status_code=404, detail="Analysis modules configuration not found.")

    # --- Dependency resolution ---
    # Edges point from a dependency to the modules that need it, so the sort
    # yields prerequisites first. Dependencies that name an unknown module
    # are ignored.
    dependents_of = {name: [] for name in modules_config}
    for name, module_cfg in modules_config.items():
        for dep in module_cfg.get("dependencies", []) or []:
            if dep in dependents_of:
                dependents_of[dep].append(name)

    sorted_modules, cycle = _topological_sort_modules(dependents_of)
    if sorted_modules is None:
        raise HTTPException(
            status_code=400,
            detail=f"Circular dependency detected in analysis modules configuration. Cycle: {cycle}"
        )

    # --- Fetch common data (company name) ---
    # No financial data is loaded for the full report; modules run without it.
    financial_data = None
    if not company_name:
        logger.info(f"[API] Fetching company name for {ts_code}")
        try:
            basic_data = await get_dm().get_stock_basic(stock_code=ts_code)
            if basic_data:
                company_name = basic_data.get("name", ts_code)
                logger.info(f"[API] Got company name: {company_name}")
            else:
                company_name = ts_code
        except Exception as e:
            logger.warning(f"Failed to get company name, proceeding with ts_code. Error: {e}")
            company_name = ts_code

    # --- Execute modules in order ---
    analysis_results = []
    completed_modules_content = {}

    for module_type in sorted_modules:
        module_config = modules_config[module_type]
        logger.info(f"[Orchestrator] Starting analysis for module: {module_type}")

        client = AnalysisClient(
            api_key=api_key,
            base_url=base_url,
            model=module_config.get("model", "gemini-1.5-flash")
        )

        # Gather context from completed dependencies
        context = {
            dep: completed_modules_content.get(dep, "")
            for dep in module_config.get("dependencies", []) or []
        }

        result = await client.generate_analysis(
            analysis_type=module_type,
            company_name=company_name,
            ts_code=ts_code,
            prompt_template=module_config.get("prompt_template", ""),
            financial_data=financial_data,
            context=context,
        )

        response = AnalysisResponse(
            ts_code=ts_code,
            company_name=company_name,
            analysis_type=module_type,
            content=result.get("content", ""),
            model=result.get("model", module_config.get("model")),
            tokens=result.get("tokens", {}),
            elapsed_ms=result.get("elapsed_ms", 0),
            success=result.get("success", False),
            error=result.get("error")
        )

        analysis_results.append(response)

        if response.success:
            completed_modules_content[module_type] = response.content
        else:
            # Keep going: dependents of a failed module receive this error
            # marker as its context. This avoids total failure but may
            # lower the quality of the dependent analyses.
            completed_modules_content[module_type] = f"Error: Analysis for {module_type} failed."
            logger.error(f"[Orchestrator] Module {module_type} failed: {response.error}")

    logger.info(f"[API] Full analysis for {ts_code} completed.")
    return analysis_results
|
||
|
||
|
||
@router.get("/config", response_model=FinancialConfigResponse)
async def get_financial_config():
    """Return the ``api_groups`` section of the financial metric config file."""
    fin_cfg = _load_json(FINANCIAL_CONFIG_PATH)
    return FinancialConfigResponse(api_groups=fin_cfg.get("api_groups", {}))
|
||
|
||
|
||
def _merge_market_rows(series, rows, latest_report_period):
    """Fold per-trade-date market rows into ``series`` in place.

    Each numeric field of a row becomes a ``{"period", "value"}`` point in
    ``series[field]``. A row whose trade date equals ``latest_report_period``
    is filed under that period; every other row is filed under the year-end
    period (``YYYY1231``) of its trade date. An existing point with the same
    period is overwritten.

    Args:
        series: Mapping of metric key -> list of points, mutated in place.
        rows: Rows returned by the data manager; ignored unless it is a list.
        latest_report_period: Latest interim report period of the current
            year as ``YYYYMMDD``, or ``None`` when there is none.
    """
    if not isinstance(rows, list):
        return

    identity_keys = ('ts_code', 'trade_date', 'trade_dt', 'date')
    for row in rows:
        trade_date = row.get('trade_date') or row.get('trade_dt') or row.get('date')
        if not trade_date:
            continue
        trade_date = str(trade_date)

        if latest_report_period and trade_date == latest_report_period:
            period = latest_report_period
        else:
            period = f"{trade_date[:4]}1231"

        for key, value in row.items():
            if key in identity_keys or not isinstance(value, (int, float)):
                continue
            points = series.setdefault(key, [])
            point = {"period": period, "value": value}
            existing_index = next(
                (i for i, d in enumerate(points) if d['period'] == period), -1
            )
            if existing_index >= 0:
                points[existing_index] = point
            else:
                points.append(point)


@router.get("/china/{ts_code}", response_model=BatchFinancialDataResponse)
async def get_china_financials(
    ts_code: str,
    years: int = Query(10, ge=1, le=10),
):
    """
    Return financial statement series, plus market data when configured,
    for one stock.

    Statement series are fetched for the year-end report dates of the last
    ``years`` years and the current year. When the metric config defines the
    ``daily_basic`` and/or ``daily`` groups, market rows for the same dates
    (plus the latest interim report period of the current year) are merged
    in. Each series is finally reduced to its ``years`` most recent periods,
    sorted ascending.

    Raises:
        HTTPException: 502 when no data at all could be collected.
    """
    # Load metric config
    fin_cfg = _load_json(FINANCIAL_CONFIG_PATH)
    api_groups: Dict[str, List[Dict]] = fin_cfg.get("api_groups", {})

    # Meta tracking
    started_real = datetime.now(timezone.utc)
    started = time.perf_counter_ns()
    api_calls_total = 0  # Not tracked here; DataManager would have to provide it
    api_calls_by_group: Dict[str, int] = {}
    steps: List[StepRecord] = []

    # Company name is optional; fall back to the code when it is unavailable.
    company_name = ts_code
    try:
        basic_data = await get_dm().get_stock_basic(stock_code=ts_code)
        if basic_data:
            company_name = basic_data.get("name", ts_code)
    except Exception:
        pass  # Best effort: continue without the name

    errors: Dict[str, str] = {}

    # Year-end report dates to request
    current_year = datetime.now().year
    report_dates = [f"{year}1231" for year in range(current_year - years, current_year + 1)]

    # Fetch all financial statements at once (already in series format from provider)
    step_financials = StepRecord(name="拉取财务报表", start_ts=started_real.isoformat(), status="running")
    steps.append(step_financials)

    series: Dict[str, List[Dict]] = {}
    try:
        fetched = await get_dm().get_financial_statements(stock_code=ts_code, report_dates=report_dates)
    except Exception as e:
        errors["financial_statements"] = f"Failed to fetch financial statements: {e}"
    else:
        # Anything but a mapping (e.g. an empty list) counts as "no data".
        if isinstance(fetched, dict):
            series = fetched
    if not series:
        errors.setdefault("financial_statements", "Failed to fetch from all providers.")

    # Latest interim (non year-end) report period of the current year, if any
    latest_current_year_report = None
    current_year_str = str(current_year)
    for points in series.values():
        for item in points or []:
            period = str(item.get('period') or '')
            if period.startswith(current_year_str) and not period.endswith('1231'):
                if latest_current_year_report is None or period > latest_current_year_report:
                    latest_current_year_report = period

    step_financials.status = "done"
    step_financials.end_ts = datetime.now(timezone.utc).isoformat()
    step_financials.duration_ms = int((time.perf_counter_ns() - started) / 1_000_000)

    # --- Market cap / valuation (daily_basic) and price (daily) at the report dates ---
    try:
        # Only fetch the groups that the config actually defines
        has_daily_basic = bool(api_groups.get("daily_basic"))
        has_daily = bool(api_groups.get("daily"))

        if has_daily_basic or has_daily:
            market_started = time.perf_counter_ns()
            step_market = StepRecord(name="拉取市值与股价", start_ts=datetime.now(timezone.utc).isoformat(), status="running")
            steps.append(step_market)

            # Query dates: the year-end dates plus the latest interim report
            # period of the current year.
            # NOTE(review): the original comment says the provider maps each
            # date to the nearest trading day — confirm against the provider.
            market_dates = list(report_dates)
            if latest_current_year_report:
                try:
                    # Parsed only to validate the YYYYMMDD format
                    datetime.strptime(latest_current_year_report, '%Y%m%d')
                except ValueError:
                    pass
                else:
                    market_dates.append(latest_current_year_report)

            try:
                if has_daily_basic:
                    db_rows = await get_dm().get_data('get_daily_basic_points', stock_code=ts_code, trade_dates=market_dates)
                    _merge_market_rows(series, db_rows, latest_current_year_report)
                if has_daily:
                    d_rows = await get_dm().get_data('get_daily_points', stock_code=ts_code, trade_dates=market_dates)
                    _merge_market_rows(series, d_rows, latest_current_year_report)
            except Exception as e:
                errors["market_data"] = f"Failed to fetch market data: {e}"
            finally:
                step_market.status = "done"
                step_market.end_ts = datetime.now(timezone.utc).isoformat()
                # Measured from this step's own start, not from request start
                step_market.duration_ms = int((time.perf_counter_ns() - market_started) / 1_000_000)
    except Exception as e:
        errors["market_data_init"] = f"Market data init failed: {e}"

    finished_real = datetime.now(timezone.utc)
    elapsed_ms = int((time.perf_counter_ns() - started) / 1_000_000)

    if not series:
        raise HTTPException(status_code=502, detail={"message": "No data returned from any data source", "errors": errors})

    # Per series: deduplicate by period, keep the `years` most recent
    # periods, and return them in ascending order.
    for key in list(series):
        uniq = {item["period"]: item for item in series[key]}
        most_recent = sorted(uniq.values(), key=lambda x: x["period"], reverse=True)[:years]
        series[key] = sorted(most_recent, key=lambda x: x["period"])

    # Note: derived financial metrics are computed by the individual data
    # providers behind data_manager.get_financial_statements().

    meta = FinancialMeta(
        started_at=started_real.isoformat(),
        finished_at=finished_real.isoformat(),
        elapsed_ms=elapsed_ms,
        api_calls_total=api_calls_total,
        api_calls_by_group=api_calls_by_group,
        current_action=None,
        steps=steps,
    )

    return BatchFinancialDataResponse(ts_code=ts_code, name=company_name, series=series, meta=meta)
|
||
|
||
|
||
@router.get("/china/{ts_code}/company-profile", response_model=CompanyProfileResponse)
async def get_company_profile(
    ts_code: str,
    company_name: str = Query(None, description="Company name for better context"),
):
    """
    Generate a company profile with the configured LLM provider
    (non-streaming, single response).

    When ``company_name`` is not supplied it is looked up from the stock's
    basic data, falling back to ``ts_code`` itself.

    Raises:
        HTTPException: 500 when the LLM API key is not configured.
    """
    logger = logging.getLogger(__name__)

    logger.info(f"[API] Company profile requested for {ts_code}")

    # Load config
    base_cfg = _load_json(BASE_CONFIG_PATH)
    llm_provider = base_cfg.get("llm", {}).get("provider", "gemini")
    llm_config = base_cfg.get("llm", {}).get(llm_provider, {})

    api_key = llm_config.get("api_key")
    base_url = llm_config.get("base_url")  # Will be None if not set, handled by client

    if not api_key:
        logger.error(f"[API] API key for {llm_provider} not configured")
        raise HTTPException(
            status_code=500,
            detail=f"API key for {llm_provider} not configured."
        )

    # Single source of truth for the model name, so the response can never
    # report a different model than the one the client was created with.
    model_name = "gemini-1.5-flash"
    client = CompanyProfileClient(
        api_key=api_key,
        base_url=base_url,
        model=model_name
    )

    # Get company name from ts_code if not provided
    if not company_name:
        logger.info(f"[API] Fetching company name for {ts_code}")
        try:
            basic_data = await get_dm().get_stock_basic(stock_code=ts_code)
            if basic_data:
                company_name = basic_data.get("name", ts_code)
                logger.info(f"[API] Got company name: {company_name}")
            else:
                company_name = ts_code
        except Exception as e:
            logger.warning(f"[API] Failed to get company name: {e}")
            company_name = ts_code

    logger.info(f"[API] Generating profile for {company_name}")

    # Generate profile using non-streaming API
    result = await client.generate_profile(
        company_name=company_name,
        ts_code=ts_code,
        financial_data=None
    )

    logger.info(f"[API] Profile generation completed, success={result.get('success')}")

    return CompanyProfileResponse(
        ts_code=ts_code,
        company_name=company_name,
        content=result.get("content", ""),
        model=result.get("model", model_name),
        tokens=result.get("tokens", {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}),
        elapsed_ms=result.get("elapsed_ms", 0),
        success=result.get("success", False),
        error=result.get("error")
    )
|
||
|
||
|
||
@router.get("/analysis-config", response_model=AnalysisConfigResponse)
async def get_analysis_config_endpoint():
    """Return the currently stored analysis module configuration."""
    modules = load_analysis_config().get("analysis_modules", {})
    return AnalysisConfigResponse(analysis_modules=modules)
|
||
|
||
|
||
@router.put("/analysis-config", response_model=AnalysisConfigResponse)
async def update_analysis_config_endpoint(analysis_config: AnalysisConfigResponse):
    """
    Replace the analysis configuration file with the submitted modules.

    The new content is serialized first, written to a temporary file next to
    the target, and then moved over the target, so a failure at any stage
    leaves the existing configuration file intact.

    Raises:
        HTTPException: 500 when the configuration cannot be saved.
    """
    logger = logging.getLogger(__name__)

    try:
        config_data = {
            "analysis_modules": analysis_config.analysis_modules
        }
        # Serialize before touching the disk: opening the target with "w"
        # would truncate it even if serialization failed afterwards.
        payload = json.dumps(config_data, ensure_ascii=False, indent=2)

        tmp_path = f"{ANALYSIS_CONFIG_PATH}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(payload)
        os.replace(tmp_path, ANALYSIS_CONFIG_PATH)

        logger.info("[API] Analysis config updated successfully")
        return AnalysisConfigResponse(analysis_modules=analysis_config.analysis_modules)
    except Exception as e:
        logger.error(f"[API] Failed to update analysis config: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to update analysis config: {str(e)}"
        ) from e
|
||
|
||
|
||
def _dependency_generation_order(modules_config, dependencies):
    """Return every transitive dependency, ordered prerequisites-first.

    Args:
        modules_config: Mapping of module name -> module config; each config
            may list prerequisite module names under ``"dependencies"``.
        dependencies: Direct dependencies of the module being generated.

    Returns:
        A list holding each required module once, with every module placed
        after the modules it depends on. If the dependencies contain a
        cycle, the modules that cannot be ordered are appended at the end.
    """
    # Collect the transitive closure of the direct dependencies.
    required = set()
    stack = list(dependencies)
    while stack:
        name = stack.pop()
        if name in required:
            continue
        required.add(name)
        stack.extend(modules_config.get(name, {}).get("dependencies", []) or [])

    prerequisites = {
        name: {
            dep
            for dep in (modules_config.get(name, {}).get("dependencies", []) or [])
            if dep in required
        }
        for name in required
    }
    dependents = {name: [] for name in required}
    for name, deps in prerequisites.items():
        for dep in deps:
            dependents[dep].append(name)

    # Kahn's algorithm. A module becomes ready once all of its prerequisites
    # have been emitted, so prerequisites always come first.
    pending = {name: len(deps) for name, deps in prerequisites.items()}
    queue = deque(sorted(name for name, count in pending.items() if count == 0))
    order = []
    while queue:
        name = queue.popleft()
        order.append(name)
        for dependent in dependents[name]:
            pending[dependent] -= 1
            if pending[dependent] == 0:
                queue.append(dependent)

    if len(order) != len(required):
        # Cycle: keep what could be ordered, then add the rest in any order.
        order.extend(sorted(required.difference(order)))
    return order


@router.get("/china/{ts_code}/analysis/{analysis_type}", response_model=AnalysisResponse)
async def generate_analysis(
    ts_code: str,
    analysis_type: str,
    company_name: str = Query(None, description="Company name for better context"),
):
    """
    Generate one analysis module for a company with the configured LLM.

    If the module declares dependencies, those modules (and their own
    dependencies) are generated first and their content is passed to the
    requested module as context.

    Supported analysis types:
    - fundamental_analysis
    - bull_case
    - bear_case
    - market_analysis
    - news_analysis
    - trading_analysis
    - insider_institutional (insider and institutional activity)
    - final_conclusion

    Raises:
        HTTPException: 500 when the LLM API key or the module's prompt
            template is missing, 404 when ``analysis_type`` is not configured.
    """
    logger = logging.getLogger(__name__)

    logger.info(f"[API] Analysis requested for {ts_code}, type: {analysis_type}")

    # Load config
    base_cfg = _load_json(BASE_CONFIG_PATH)
    llm_provider = base_cfg.get("llm", {}).get("provider", "gemini")
    llm_config = base_cfg.get("llm", {}).get(llm_provider, {})

    api_key = llm_config.get("api_key")
    base_url = llm_config.get("base_url")

    if not api_key:
        logger.error(f"[API] API key for {llm_provider} not configured")
        raise HTTPException(
            status_code=500,
            detail=f"API key for {llm_provider} not configured."
        )

    # Get analysis configuration
    analysis_cfg = get_analysis_config(analysis_type)
    if not analysis_cfg:
        raise HTTPException(
            status_code=404,
            detail=f"Analysis type '{analysis_type}' not found in configuration"
        )

    model = analysis_cfg.get("model", "gemini-2.5-flash")
    prompt_template = analysis_cfg.get("prompt_template", "")

    if not prompt_template:
        raise HTTPException(
            status_code=500,
            detail=f"Prompt template not found for analysis type '{analysis_type}'"
        )

    # Look up the company name when the caller did not supply one. Financial
    # context is only fetched on this path, alongside the name lookup.
    financial_data = None
    if not company_name:
        logger.info(f"[API] Fetching company name and financial data for {ts_code}")
        try:
            basic_data = await get_dm().get_stock_basic(stock_code=ts_code)
            if basic_data:
                company_name = basic_data.get("name", ts_code)
                logger.info(f"[API] Got company name: {company_name}")

                # Try to get last year's annual report for context
                try:
                    current_year = datetime.now().year
                    latest_financials_report = await get_dm().get_financial_statement(
                        stock_code=ts_code,
                        report_dates=f"{current_year-1}1231"
                    )
                    if latest_financials_report:
                        financial_data = {"series": latest_financials_report}
                except Exception as e:
                    logger.warning(f"[API] Failed to get financial data: {e}")
                    financial_data = None
            else:
                company_name = ts_code
        except Exception as e:
            logger.warning(f"[API] Failed to get company name: {e}")
            company_name = ts_code

    logger.info(f"[API] Generating {analysis_type} for {company_name}")

    # Initialize analysis client with configured model
    client = AnalysisClient(api_key=api_key, base_url=base_url, model=model)

    # Generate the declared dependencies first and inject their outputs.
    context = {}
    try:
        dependencies = analysis_cfg.get("dependencies", []) or []
        if dependencies:
            modules_config = load_analysis_config().get("analysis_modules", {})

            completed = {}
            for mod in _dependency_generation_order(modules_config, dependencies):
                cfg = modules_config.get(mod)
                if not cfg:
                    # Unknown module: nothing to generate, context stays empty
                    logger.warning(f"[API] Dependency '{mod}' is not configured; skipping")
                    completed[mod] = ""
                    continue
                dep_ctx = {d: completed.get(d, "") for d in (cfg.get("dependencies", []) or [])}
                dep_client = AnalysisClient(api_key=api_key, base_url=base_url, model=cfg.get("model", model))
                dep_result = await dep_client.generate_analysis(
                    analysis_type=mod,
                    company_name=company_name,
                    ts_code=ts_code,
                    prompt_template=cfg.get("prompt_template", ""),
                    financial_data=financial_data,
                    context=dep_ctx,
                )
                completed[mod] = dep_result.get("content", "") if dep_result.get("success") else ""

            context = {dep: completed.get(dep, "") for dep in dependencies}
    except Exception as e:
        # Best-effort context; if anything goes wrong, continue without it
        logger.warning(f"[API] Failed to resolve dependency context: {e}")
        context = {}

    # Generate analysis
    result = await client.generate_analysis(
        analysis_type=analysis_type,
        company_name=company_name,
        ts_code=ts_code,
        prompt_template=prompt_template,
        financial_data=financial_data,
        context=context,
    )

    logger.info(f"[API] Analysis generation completed, success={result.get('success')}")

    return AnalysisResponse(
        ts_code=ts_code,
        company_name=company_name,
        analysis_type=analysis_type,
        content=result.get("content", ""),
        model=result.get("model", model),
        tokens=result.get("tokens", {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}),
        elapsed_ms=result.get("elapsed_ms", 0),
        success=result.get("success", False),
        error=result.get("error")
    )
|
||
|
||
|
||
@router.get("/china/{ts_code}/snapshot", response_model=TodaySnapshotResponse)
async def get_today_snapshot(ts_code: str):
    """
    Return the "yesterday snapshot" of market data for one stock.

    The query date is the previous calendar day in China Standard Time
    (UTC+8). According to the original notes the provider resolves it to the
    latest trading day not after that date. The response carries:
    - trade date (trade_date)
    - closing price (close)
    - market capitalisation (total_mv, passed through as returned by the
      provider; the original notes give the unit as 10k CNY)
    - valuation (pe, pb)
    - dividend yield (dv_ratio, percent per the original notes)

    Fields that the data source does not return are left as ``None``.

    Raises:
        HTTPException: 500 when fetching the snapshot fails.
    """
    try:
        # Company name is optional
        company_name = None
        try:
            basic = await get_dm().get_stock_basic(stock_code=ts_code)
            if basic:
                company_name = basic.get("name")
        except Exception:
            company_name = None

        # "Yesterday" is taken in China market time rather than the server's
        # local time, so a server running in another timezone still asks for
        # the right day. China observes no DST, so a fixed offset suffices.
        china_tz = timezone(timedelta(hours=8))
        base_dt = (datetime.now(china_tz) - timedelta(days=1)).date()
        base_str = base_dt.strftime("%Y%m%d")

        # daily_basic supplies the main fields: close, pe, pb, dv_ratio, total_mv
        rows = await get_dm().get_data(
            'get_daily_basic_points',
            stock_code=ts_code,
            trade_dates=[base_str]
        )
        row = rows[0] if isinstance(rows, list) and rows else None

        trade_date = None
        close = None
        pe = None
        pb = None
        dv_ratio = None
        total_mv = None

        if isinstance(row, dict):
            trade_date = str(row.get('trade_date') or row.get('trade_dt') or row.get('date') or base_str)
            close = row.get('close')
            pe = row.get('pe')
            pb = row.get('pb')
            dv_ratio = row.get('dv_ratio')
            total_mv = row.get('total_mv')

        # Fall back to the daily quotes when daily_basic has no close
        if close is None:
            d_rows = await get_dm().get_data('get_daily_points', stock_code=ts_code, trade_dates=[base_str])
            if isinstance(d_rows, list) and d_rows and isinstance(d_rows[0], dict):
                d = d_rows[0]
                close = d.get('close')
                if trade_date is None:
                    trade_date = str(d.get('trade_date') or d.get('trade_dt') or d.get('date') or base_str)

        if trade_date is None:
            trade_date = base_str

        return TodaySnapshotResponse(
            ts_code=ts_code,
            trade_date=trade_date,
            name=company_name,
            close=close,
            pe=pe,
            pb=pb,
            dv_ratio=dv_ratio,
            total_mv=total_mv,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to fetch snapshot: {e}") from e
|
||
|
||
|
||
def _stream_dependency_order(modules_config, dependencies):
    """Return every transitive dependency, ordered prerequisites-first.

    Args:
        modules_config: Mapping of module name -> module config; each config
            may list prerequisite module names under ``"dependencies"``.
        dependencies: Direct dependencies of the module being streamed.

    Returns:
        A list holding each required module once, with every module placed
        after the modules it depends on. If the dependencies contain a
        cycle, the modules that cannot be ordered are appended at the end.
    """
    # Collect the transitive closure of the direct dependencies.
    required = set()
    stack = list(dependencies)
    while stack:
        name = stack.pop()
        if name in required:
            continue
        required.add(name)
        stack.extend(modules_config.get(name, {}).get("dependencies", []) or [])

    prerequisites = {
        name: {
            dep
            for dep in (modules_config.get(name, {}).get("dependencies", []) or [])
            if dep in required
        }
        for name in required
    }
    dependents = {name: [] for name in required}
    for name, deps in prerequisites.items():
        for dep in deps:
            dependents[dep].append(name)

    # Kahn's algorithm. A module becomes ready once all of its prerequisites
    # have been emitted, so prerequisites always come first.
    pending = {name: len(deps) for name, deps in prerequisites.items()}
    queue = deque(sorted(name for name, count in pending.items() if count == 0))
    order = []
    while queue:
        name = queue.popleft()
        order.append(name)
        for dependent in dependents[name]:
            pending[dependent] -= 1
            if pending[dependent] == 0:
                queue.append(dependent)

    if len(order) != len(required):
        # Cycle: keep what could be ordered, then add the rest in any order.
        order.extend(sorted(required.difference(order)))
    return order


@router.get("/china/{ts_code}/analysis/{analysis_type}/stream")
async def stream_analysis(
    ts_code: str,
    analysis_type: str,
    company_name: str = Query(None, description="Company name for better context"),
):
    """
    Stream analysis content chunks for a given module using OpenAI-compatible streaming.
    Plain text streaming (text/plain; utf-8). Dependencies are resolved first (non-stream),
    then the target module content is streamed.

    Raises:
        HTTPException: 500 when the LLM API key or the module's prompt
            template is missing, 404 when ``analysis_type`` is not configured.
    """
    logger = logging.getLogger(__name__)

    logger.info(f"[API] Streaming analysis requested for {ts_code}, type: {analysis_type}")

    # Load config
    base_cfg = _load_json(BASE_CONFIG_PATH)
    llm_provider = base_cfg.get("llm", {}).get("provider", "gemini")
    llm_config = base_cfg.get("llm", {}).get(llm_provider, {})

    api_key = llm_config.get("api_key")
    base_url = llm_config.get("base_url")

    if not api_key:
        logger.error(f"[API] API key for {llm_provider} not configured")
        raise HTTPException(status_code=500, detail=f"API key for {llm_provider} not configured.")

    # Get analysis configuration
    analysis_cfg = get_analysis_config(analysis_type)
    if not analysis_cfg:
        raise HTTPException(status_code=404, detail=f"Analysis type '{analysis_type}' not found in configuration")

    model = analysis_cfg.get("model", "gemini-2.5-flash")
    prompt_template = analysis_cfg.get("prompt_template", "")
    if not prompt_template:
        raise HTTPException(status_code=500, detail=f"Prompt template not found for analysis type '{analysis_type}'")

    # Get company name from ts_code if not provided; we don't need full financials here
    financial_data = None
    if not company_name:
        try:
            basic_data = await get_dm().get_stock_basic(stock_code=ts_code)
            if basic_data:
                company_name = basic_data.get("name", ts_code)
            else:
                company_name = ts_code
        except Exception:
            company_name = ts_code

    # Resolve dependency context (non-streaming)
    context = {}
    try:
        dependencies = analysis_cfg.get("dependencies", []) or []
        if dependencies:
            modules_config = load_analysis_config().get("analysis_modules", {})

            completed = {}
            for mod in _stream_dependency_order(modules_config, dependencies):
                cfg = modules_config.get(mod)
                if not cfg:
                    # Unknown module: nothing to generate, context stays empty
                    logger.warning(f"[API] Dependency '{mod}' is not configured; skipping")
                    completed[mod] = ""
                    continue
                dep_ctx = {d: completed.get(d, "") for d in (cfg.get("dependencies", []) or [])}
                dep_client = AnalysisClient(api_key=api_key, base_url=base_url, model=cfg.get("model", model))
                dep_result = await dep_client.generate_analysis(
                    analysis_type=mod,
                    company_name=company_name,
                    ts_code=ts_code,
                    prompt_template=cfg.get("prompt_template", ""),
                    financial_data=financial_data,
                    context=dep_ctx,
                )
                completed[mod] = dep_result.get("content", "") if dep_result.get("success") else ""

            context = {dep: completed.get(dep, "") for dep in dependencies}
    except Exception as e:
        # Best-effort context; if anything goes wrong, stream without it
        logger.warning(f"[API] Failed to resolve dependency context: {e}")
        context = {}

    client = AnalysisClient(api_key=api_key, base_url=base_url, model=model)

    async def streamer():
        # Optional header line to help client-side UI
        header = f"# {analysis_cfg.get('name', analysis_type)}\n\n"
        yield header
        async for chunk in client.generate_analysis_stream(
            analysis_type=analysis_type,
            company_name=company_name,
            ts_code=ts_code,
            prompt_template=prompt_template,
            financial_data=financial_data,
            context=context,
        ):
            yield chunk

    headers = {
        # Disable buffering in intermediaries so chunks reach the client promptly
        "Cache-Control": "no-cache, no-transform",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(streamer(), media_type="text/plain; charset=utf-8", headers=headers)
|