Fundamental_Analysis/backend/app/routers/financial.py
xucheng a79efd8150 feat: Enhance configuration management with new LLM provider support and API integration
- Backend: Introduced new endpoints for LLM configuration retrieval and updates in `config.py`, allowing dynamic management of LLM provider settings.
- Updated schemas to include `AlphaEngineConfig` for better integration with the new provider.
- Frontend: Added state management for AlphaEngine API credentials in the configuration page, ensuring seamless user experience.
- Configuration files updated to reflect changes in LLM provider settings and API keys.

BREAKING CHANGE: The default LLM provider has been changed from `new_api` to `alpha_engine`, requiring updates to existing configurations.
2025-11-11 20:49:27 +08:00

1042 lines
41 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
API router for financial data (Tushare for China market)
"""
import json
import os
import time
from datetime import datetime, timezone, timedelta
from enum import Enum
from typing import Dict, List
from fastapi import APIRouter, HTTPException, Query, Depends
from fastapi.responses import StreamingResponse
from app.core.config import settings
from app.schemas.financial import (
BatchFinancialDataResponse,
FinancialConfigResponse,
FinancialMeta,
StepRecord,
CompanyProfileResponse,
AnalysisResponse,
AnalysisConfigResponse,
TodaySnapshotResponse,
)
from app.services.company_profile_client import CompanyProfileClient
from app.services.analysis_client import AnalysisClient, load_analysis_config, get_analysis_config
from app.core.dependencies import get_config_manager
from app.services.config_manager import ConfigManager
from app.services.client_factory import create_analysis_client
# Lazy DataManager loader to avoid import-time failures when optional providers/config are missing
_dm = None


def get_dm():
    """Return the shared DataManager, importing it on first use.

    If ``app.data_manager`` cannot be imported, a stub is cached instead so the
    API can still start. The stub implements every DataManager method this router
    calls and returns empty results. Because the stub is cached, the import is not
    retried; the failure is logged once with its traceback.
    """
    global _dm
    if _dm is not None:
        return _dm
    try:
        from app.data_manager import data_manager as real_dm
        _dm = real_dm
    except Exception:
        import logging
        logging.getLogger(__name__).exception(
            "Failed to import app.data_manager; falling back to an empty stub DataManager"
        )

        class _StubDM:
            """Empty stand-in exposing the DataManager methods used by this router."""

            config = {}

            async def get_stock_basic(self, stock_code: str):
                return None

            async def get_financial_statements(self, stock_code: str, report_dates):
                # Callers treat the result as a {metric: [points]} mapping.
                return {}

            async def get_financial_statement(self, stock_code: str, report_dates):
                return None

            async def get_data(self, method_name: str, *args, **kwargs):
                return None

            async def get_daily_price(self, stock_code: str, start_date: str, end_date: str):
                return None

        _dm = _StubDM()
    return _dm
# Router for every financial-data endpoint in this module; presumably mounted under
# a prefix by the application (not visible here) — TODO confirm the mount point.
router = APIRouter()
class MarketEnum(str, Enum):
    """Market codes accepted in the ``{market}`` path segment.

    Subclassing ``str`` makes members compare equal to their raw values
    (e.g. ``MarketEnum.cn == "cn"``).
    """
    cn = "cn"  # China; the only market with daily_basic/daily market data in this router
    us = "us"  # presumably United States
    hk = "hk"  # presumably Hong Kong
    jp = "jp"  # presumably Japan
# Config files live in <repo root>/config. This module sits in backend/app/routers/,
# so the repo root is three directory levels up (routers -> app -> backend -> root),
# not the backend/ directory.
_ROUTERS_DIR = os.path.dirname(__file__)
REPO_ROOT = os.path.abspath(os.path.join(_ROUTERS_DIR, *([os.pardir] * 3)))
_CONFIG_DIR = os.path.join(REPO_ROOT, "config")
FINANCIAL_CONFIG_PATH = os.path.join(_CONFIG_DIR, "financial-tushare.json")
BASE_CONFIG_PATH = os.path.join(_CONFIG_DIR, "config.json")
ANALYSIS_CONFIG_PATH = os.path.join(_CONFIG_DIR, "analysis-config.json")
def _load_json(path: str) -> Dict:
    """Load a JSON object from ``path``; return ``{}`` if it cannot be used.

    A missing file yields ``{}`` silently. Unreadable files, invalid JSON, and
    files whose top-level value is not a JSON object also yield ``{}``, but are
    logged as warnings. Callers can therefore always chain ``.get()`` on the
    result.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return {}
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        import logging
        logging.getLogger(__name__).warning("Failed to load JSON file %s: %s", path, exc)
        return {}
    if not isinstance(data, dict):
        import logging
        logging.getLogger(__name__).warning(
            "Ignoring JSON file %s: top-level value is %s, expected an object",
            path,
            type(data).__name__,
        )
        return {}
    return data
@router.get("/data-sources", response_model=Dict[str, List[str]])
async def get_data_sources():
    """
    List the configured data sources whose config declares an ``api_key_env``,
    i.e. the sources that need an API key to be supplied.
    """
    try:
        configured_sources = get_dm().config.get("data_sources", {})
        keyed_sources = []
        for source_name, source_cfg in configured_sources.items():
            if not source_cfg.get("api_key_env"):
                continue
            keyed_sources.append(source_name)
        return {"sources": keyed_sources}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Failed to load data sources configuration: {exc}")
@router.post("/china/{ts_code}/analysis", response_model=List[AnalysisResponse])
async def generate_full_analysis(
    ts_code: str,
    company_name: str = Query(None, description="Company name for better context"),
    config_manager: ConfigManager = Depends(get_config_manager),
):
    """
    Generate a full analysis report by running every configured analysis module.

    Modules run in dependency order: each module's ``dependencies`` run before it.
    The output of completed dependencies is passed to dependents as ``context``.
    Every module uses the globally configured LLM provider and model.

    Returns:
        One ``AnalysisResponse`` per module, in execution order. A module reported
        as unsuccessful does not abort the run; its dependents receive an error
        placeholder as context instead.

    Raises:
        HTTPException(404): no analysis modules are configured.
        HTTPException(400): the module dependency graph contains a cycle.
    """
    import logging
    from collections import deque

    logger = logging.getLogger(__name__)
    logger.info(f"[API] Full analysis requested for {ts_code}")

    # Load LLM configuration using ConfigManager
    llm_config_result = await config_manager.get_llm_config()
    default_provider = llm_config_result["provider"]
    default_config = llm_config_result["config"]
    global_model = llm_config_result.get("model")  # global model setting

    analysis_config_full = load_analysis_config()
    modules_config = analysis_config_full.get("analysis_modules", {})
    if not modules_config:
        raise HTTPException(status_code=404, detail="Analysis modules configuration not found.")

    # --- Dependency Resolution (Topological Sort) ---
    def topological_sort(graph):
        """Order ``graph`` (node -> nodes that must run after it) with Kahn's algorithm.

        Returns ``(order, [])`` on success, or ``(None, cycles)`` if the graph is
        cyclic. Each cycle is a list of node names that starts and ends on the
        same node.
        """
        in_degree = {u: 0 for u in graph}
        for u in graph:
            for v in graph[u]:
                in_degree[v] += 1
        queue = deque(u for u in graph if in_degree[u] == 0)
        sorted_order = []
        while queue:
            u = queue.popleft()
            sorted_order.append(u)
            for v in graph.get(u, []):
                in_degree[v] -= 1
                if in_degree[v] == 0:
                    queue.append(v)
        if len(sorted_order) == len(graph):
            return sorted_order, []

        # Cyclic graph: walk it depth-first to report the offending cycles.
        cycles = []
        visited = set()
        path = []

        def find_cycle_util(node):
            visited.add(node)
            path.append(node)
            for neighbor in graph.get(node, []):
                if neighbor in path:
                    cycles.append(path[path.index(neighbor):] + [neighbor])
                elif neighbor not in visited:
                    find_cycle_util(neighbor)
            # Always unwind so `path` stays the current DFS stack.
            path.pop()

        for node in graph:
            if node not in visited:
                find_cycle_util(node)
        return None, cycles

    # Build a "dependency -> dependents" adjacency list so Kahn's order runs
    # dependencies first. Dependencies naming unknown modules are ignored.
    adj_list = {name: [] for name in modules_config}
    for name, config in modules_config.items():
        for dep in config.get("dependencies", []) or []:
            if dep in adj_list:
                adj_list[dep].append(name)

    sorted_modules, cycles = topological_sort(adj_list)
    if sorted_modules is None:
        raise HTTPException(
            status_code=400,
            detail=f"Circular dependency detected in analysis modules configuration. Cycle: {cycles}"
        )

    # --- Fetch common data (company name) ---
    # Financial data is not fetched for the full run; modules rely on prompt + context.
    financial_data = None
    if not company_name:
        logger.info(f"[API] Fetching company name for {ts_code}")
        try:
            basic_data = await get_dm().get_stock_basic(stock_code=ts_code)
            if basic_data:
                company_name = basic_data.get("name", ts_code)
                logger.info(f"[API] Got company name: {company_name}")
            else:
                company_name = ts_code
        except Exception as e:
            logger.warning(f"Failed to get company name, proceeding with ts_code. Error: {e}")
            company_name = ts_code

    # All modules share the global provider/model; per-module settings are ignored.
    model = global_model or default_config.get("model", "gemini-1.5-flash")

    # --- Execute modules in order ---
    analysis_results = []
    completed_modules_content = {}
    for module_type in sorted_modules:
        module_config = modules_config[module_type]
        logger.info(f"[Orchestrator] Starting analysis for module: {module_type}")

        # A fresh client per module, exactly as configured globally.
        client = create_analysis_client(
            provider=default_provider,
            config=default_config,
            model=model
        )

        # Gather context from completed dependencies
        context = {
            dep: completed_modules_content.get(dep, "")
            for dep in module_config.get("dependencies", []) or []
        }

        result = await client.generate_analysis(
            analysis_type=module_type,
            company_name=company_name,
            ts_code=ts_code,
            prompt_template=module_config.get("prompt_template", ""),
            financial_data=financial_data,
            context=context,
        )

        response = AnalysisResponse(
            ts_code=ts_code,
            company_name=company_name,
            analysis_type=module_type,
            content=result.get("content", ""),
            # Report the model actually used when the client does not say.
            model=result.get("model", model),
            tokens=result.get("tokens", {}),
            elapsed_ms=result.get("elapsed_ms", 0),
            success=result.get("success", False),
            error=result.get("error")
        )
        analysis_results.append(response)

        if response.success:
            completed_modules_content[module_type] = response.content
        else:
            # Dependents get this placeholder as context instead of real content,
            # so the run continues (possibly with reduced quality).
            completed_modules_content[module_type] = f"Error: Analysis for {module_type} failed."
            logger.error(f"[Orchestrator] Module {module_type} failed: {response.error}")

    logger.info(f"[API] Full analysis for {ts_code} completed.")
    return analysis_results
@router.get("/config", response_model=FinancialConfigResponse)
async def get_financial_config():
    """Return the ``api_groups`` section of the Tushare financial metric config file."""
    return FinancialConfigResponse(
        api_groups=_load_json(FINANCIAL_CONFIG_PATH).get("api_groups", {})
    )
def _merge_market_rows(series, rows, latest_current_year_report):
    """Merge daily_basic/daily market rows into ``series`` in place.

    Each row's trade date maps to a period. If it equals
    ``latest_current_year_report``, that report period is used; otherwise the
    ``YYYY1231`` annual period of the trade date's year is used. Every numeric
    field except the id/date columns becomes a ``{"period", "value"}`` point in
    ``series[field]``, overwriting any existing point for the same period.
    Non-list ``rows``, non-dict rows, and rows without a date are ignored.
    """
    if not isinstance(rows, list):
        return
    skip_keys = ("ts_code", "trade_date", "trade_dt", "date")
    for row in rows:
        if not isinstance(row, dict):
            continue
        trade_date = row.get("trade_date") or row.get("trade_dt") or row.get("date")
        if not trade_date:
            continue
        trade_date = str(trade_date)
        if latest_current_year_report and trade_date == latest_current_year_report:
            # Current year's latest report: keep the report period itself.
            period = latest_current_year_report
        else:
            # Other years: show under the annual period YYYY1231.
            period = f"{trade_date[:4]}1231"
        for key, value in row.items():
            if key in skip_keys:
                continue
            # bool is a subclass of int but is never a market metric.
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                continue
            points = series.setdefault(key, [])
            point = {"period": period, "value": value}
            for idx, existing in enumerate(points):
                # .get(): provider points may carry only "year" and no "period".
                if existing.get("period") == period:
                    points[idx] = point
                    break
            else:
                points.append(point)


@router.get("/{market}/{stock_code}", response_model=BatchFinancialDataResponse)
async def get_financials(
    market: MarketEnum,
    stock_code: str,
    years: int = Query(10, ge=1, le=10),
):
    """
    Return per-metric time series of financial statements and market data.

    Financial statements come from the DataManager, already in series format.
    For the China market, and only when the config has the relevant API groups,
    the response adds year-end market cap/valuation (daily_basic) and price
    (daily). It also adds values for the current year's latest interim report
    period. Each series is normalized to ``{"period": "YYYYMMDD", "value": ...}``,
    deduplicated by period, and limited to the latest ``years`` periods in
    ascending order.

    Raises:
        HTTPException(502): no data was returned by any source.
    """
    # Load metric config
    fin_cfg = _load_json(FINANCIAL_CONFIG_PATH)
    api_groups: Dict[str, List[Dict]] = fin_cfg.get("api_groups", {})

    # Meta tracking
    started_real = datetime.now(timezone.utc)
    started = time.perf_counter_ns()
    api_calls_total = 0  # not tracked yet; DataManager would need to expose it
    api_calls_by_group: Dict[str, int] = {}
    steps: List[StepRecord] = []
    errors: Dict[str, str] = {}

    # Company name is cosmetic; fall back to the stock code.
    company_name = stock_code
    try:
        basic_data = await get_dm().get_stock_basic(stock_code=stock_code)
        if basic_data:
            company_name = basic_data.get("name", stock_code)
    except Exception:
        pass

    # Annual report dates; the current year is included so it can be picked up once published.
    current_year = datetime.now().year
    report_dates = [f"{year}1231" for year in range(current_year - years, current_year + 1)]

    # --- Financial statements (already in series format from the provider) ---
    step_financials = StepRecord(name="拉取财务报表", start_ts=datetime.now(timezone.utc).isoformat(), status="running")
    steps.append(step_financials)
    step_financials_started = time.perf_counter_ns()
    series: Dict[str, List[Dict]] = {}
    try:
        raw_series = await get_dm().get_financial_statements(stock_code=stock_code, report_dates=report_dates)
        if isinstance(raw_series, dict):
            # Shallow copy: the provider's object may be shared/cached; we mutate ours.
            series = {key: list(points or []) for key, points in raw_series.items()}
    except Exception as e:
        errors["financial_statements"] = f"Failed to fetch financial statements: {e}"

    # Latest current-year interim report period (e.g. 20240930), used for market data.
    current_year_str = str(current_year)
    latest_current_year_report = None
    for points in series.values():
        for item in points:
            if not isinstance(item, dict):
                continue
            period = str(item.get("period") or "")
            if period.startswith(current_year_str) and not period.endswith("1231"):
                if latest_current_year_report is None or period > latest_current_year_report:
                    latest_current_year_report = period

    if not series:
        errors.setdefault("financial_statements", "Failed to fetch from all providers.")

    step_financials.status = "done"
    step_financials.end_ts = datetime.now(timezone.utc).isoformat()
    step_financials.duration_ms = int((time.perf_counter_ns() - step_financials_started) / 1_000_000)

    # --- Market cap / valuation (daily_basic) and price (daily) at each period end ---
    try:
        # Only fetch when the config contains the corresponding API groups.
        has_daily_basic = bool(api_groups.get("daily_basic"))
        has_daily = bool(api_groups.get("daily"))
        # Only the China market fetches daily_basic/daily for now; other markets are left to their providers.
        if market == MarketEnum.cn and (has_daily_basic or has_daily):
            step_market = StepRecord(name="拉取市值与股价", start_ts=datetime.now(timezone.utc).isoformat(), status="running")
            steps.append(step_market)
            step_market_started = time.perf_counter_ns()

            # Query dates: annual dates plus the current year's latest report period.
            market_dates = report_dates.copy()
            if latest_current_year_report:
                try:
                    # Validate only; the provider maps the date to the nearest trading day.
                    datetime.strptime(latest_current_year_report, "%Y%m%d")
                    market_dates.append(latest_current_year_report)
                except ValueError:
                    pass

            try:
                if has_daily_basic:
                    db_rows = await get_dm().get_data('get_daily_basic_points', stock_code=stock_code, trade_dates=market_dates)
                    _merge_market_rows(series, db_rows, latest_current_year_report)
                if has_daily:
                    d_rows = await get_dm().get_data('get_daily_points', stock_code=stock_code, trade_dates=market_dates)
                    _merge_market_rows(series, d_rows, latest_current_year_report)
            except Exception as e:
                errors["market_data"] = f"Failed to fetch market data: {e}"
            finally:
                step_market.status = "done"
                step_market.end_ts = datetime.now(timezone.utc).isoformat()
                # Measure this step only, not the whole request.
                step_market.duration_ms = int((time.perf_counter_ns() - step_market_started) / 1_000_000)
    except Exception as e:
        errors["market_data_init"] = f"Market data init failed: {e}"

    finished_real = datetime.now(timezone.utc)
    elapsed_ms = int((time.perf_counter_ns() - started) / 1_000_000)

    if not series:
        raise HTTPException(status_code=502, detail={"message": "No data returned from any data source", "errors": errors})

    # Normalize "period" (a bare "year" maps to YYYY1231), dedupe (last wins), keep latest `years`.
    for key, arr in list(series.items()):
        by_period: Dict[str, Dict] = {}
        for item in arr:
            if not isinstance(item, dict):
                continue
            period = item.get("period")
            if not period:
                year = item.get("year")
                if year:
                    period = f"{year}1231"
            if not period:
                # Skip items whose period cannot be determined.
                continue
            by_period[str(period)] = {"period": str(period), "value": item.get("value")}
        latest_first = sorted(by_period.values(), key=lambda x: x["period"], reverse=True)[:years]
        series[key] = sorted(latest_first, key=lambda x: x["period"])

    # Note: derived financial metrics are computed by the individual data providers.
    meta = FinancialMeta(
        started_at=started_real.isoformat(),
        finished_at=finished_real.isoformat(),
        elapsed_ms=elapsed_ms,
        api_calls_total=api_calls_total,
        api_calls_by_group=api_calls_by_group,
        current_action=None,
        steps=steps,
    )

    return BatchFinancialDataResponse(ts_code=stock_code, name=company_name, series=series, meta=meta)
@router.get("/china/{ts_code}/company-profile", response_model=CompanyProfileResponse)
async def get_company_profile(
    ts_code: str,
    company_name: str = Query(None, description="Company name for better context"),
    config_manager: ConfigManager = Depends(get_config_manager),
):
    """
    Generate a company profile (non-streaming, single response) with the configured LLM.

    The model is selected the same way as in the analysis endpoints:
    1. the global model from the LLM config;
    2. otherwise the provider config's ``model``;
    3. otherwise ``"gemini-1.5-flash"``.

    Only OpenAI-compatible providers are supported.

    Raises:
        HTTPException(400): the configured provider is ``alpha_engine``.
        HTTPException(500): the provider has no API key configured.
    """
    import logging
    logger = logging.getLogger(__name__)

    logger.info(f"[API] Company profile requested for {ts_code}")

    # Load LLM configuration using ConfigManager
    llm_config_result = await config_manager.get_llm_config()
    provider = llm_config_result["provider"]
    provider_config = llm_config_result["config"]

    # CompanyProfileClient only supports OpenAI-compatible APIs
    if provider == "alpha_engine":
        raise HTTPException(
            status_code=400,
            detail="Company profile generation does not support AlphaEngine provider. Please use OpenAI-compatible API."
        )

    api_key = provider_config.get("api_key")
    base_url = provider_config.get("base_url")

    if not api_key:
        logger.error(f"[API] API key for {provider} not configured")
        raise HTTPException(
            status_code=500,
            detail=f"API key for {provider} not configured."
        )

    # Use the global model, consistent with the analysis endpoints, instead of a hard-coded one.
    model = llm_config_result.get("model") or provider_config.get("model", "gemini-1.5-flash")

    client = CompanyProfileClient(
        api_key=api_key,
        base_url=base_url,
        model=model
    )

    # Get company name from ts_code if not provided
    if not company_name:
        logger.info(f"[API] Fetching company name for {ts_code}")
        # Try to get from stock_basic API
        try:
            basic_data = await get_dm().get_stock_basic(stock_code=ts_code)
            if basic_data:
                company_name = basic_data.get("name", ts_code)
                logger.info(f"[API] Got company name: {company_name}")
            else:
                company_name = ts_code
        except Exception as e:
            logger.warning(f"[API] Failed to get company name: {e}")
            company_name = ts_code

    logger.info(f"[API] Generating profile for {company_name}")

    # Generate profile using non-streaming API
    result = await client.generate_profile(
        company_name=company_name,
        ts_code=ts_code,
        financial_data=None
    )

    logger.info(f"[API] Profile generation completed, success={result.get('success')}")

    return CompanyProfileResponse(
        ts_code=ts_code,
        company_name=company_name,
        content=result.get("content", ""),
        # Fall back to the model actually requested, not an unrelated default.
        model=result.get("model", model),
        tokens=result.get("tokens", {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}),
        elapsed_ms=result.get("elapsed_ms", 0),
        success=result.get("success", False),
        error=result.get("error")
    )
@router.get("/analysis-config", response_model=AnalysisConfigResponse)
async def get_analysis_config_endpoint():
    """Return the analysis module definitions from the analysis configuration."""
    modules = load_analysis_config().get("analysis_modules", {})
    return AnalysisConfigResponse(analysis_modules=modules)
@router.put("/analysis-config", response_model=AnalysisConfigResponse)
async def update_analysis_config_endpoint(analysis_config: AnalysisConfigResponse):
    """
    Overwrite the analysis configuration file with the submitted modules.

    The payload is serialized to a string *before* the file is opened for
    writing. A value that cannot be JSON-encoded is therefore rejected without
    truncating the existing config file.

    Raises:
        HTTPException(500): serialization or writing the file failed.
    """
    import logging
    logger = logging.getLogger(__name__)

    try:
        # Save to file
        config_data = {
            "analysis_modules": analysis_config.analysis_modules
        }
        payload = json.dumps(config_data, ensure_ascii=False, indent=2)
        with open(ANALYSIS_CONFIG_PATH, "w", encoding="utf-8") as f:
            f.write(payload)

        logger.info("[API] Analysis config updated successfully")
        return AnalysisConfigResponse(analysis_modules=analysis_config.analysis_modules)
    except Exception as e:
        logger.error(f"[API] Failed to update analysis config: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to update analysis config: {str(e)}"
        ) from e
@router.get("/china/{ts_code}/analysis/{analysis_type}", response_model=AnalysisResponse)
async def generate_analysis(
    ts_code: str,
    analysis_type: str,
    company_name: str = Query(None, description="Company name for better context"),
    config_manager: ConfigManager = Depends(get_config_manager),
):
    """
    Generate one analysis module for a company with the globally configured LLM.

    Supported analysis types (as defined in the analysis configuration):
    - fundamental_analysis (fundamental analysis)
    - bull_case (bullish case)
    - bear_case (bearish case)
    - market_analysis (market analysis)
    - news_analysis (news analysis)
    - trading_analysis (trading analysis)
    - insider_institutional (insider and institutional activity)
    - final_conclusion (final conclusion)

    If the requested module declares dependencies, they are generated first,
    transitively and dependencies before dependents. Their outputs are passed to
    the requested module as ``context``. This step is best-effort: on any error
    the module is generated without context.

    Raises:
        HTTPException(404): ``analysis_type`` is not configured.
        HTTPException(500): the module has no prompt template.
    """
    import logging
    from collections import deque

    logger = logging.getLogger(__name__)

    logger.info(f"[API] Analysis requested for {ts_code}, type: {analysis_type}")

    # Load LLM configuration using ConfigManager
    llm_config_result = await config_manager.get_llm_config()
    default_provider = llm_config_result["provider"]
    default_config = llm_config_result["config"]
    global_model = llm_config_result.get("model")  # global model setting

    # Get company name from ts_code if not provided
    financial_data = None
    if not company_name:
        logger.info(f"[API] Fetching company name and financial data for {ts_code}")
        try:
            basic_data = await get_dm().get_stock_basic(stock_code=ts_code)
            if basic_data:
                company_name = basic_data.get("name", ts_code)
                logger.info(f"[API] Got company name: {company_name}")

                # Try to get financial data for context
                try:
                    # A simplified approach to get a single financial report
                    current_year = datetime.now().year
                    report_dates = [f"{current_year-1}1231"]  # Get last year's report
                    # Use get_financial_statement which is designed to return a single flat report
                    latest_financials_report = await get_dm().get_financial_statement(
                        stock_code=ts_code,
                        report_dates=report_dates[0]
                    )
                    if latest_financials_report:
                        financial_data = {"series": latest_financials_report}
                except Exception as e:
                    logger.warning(f"[API] Failed to get financial data: {e}")
                    financial_data = None
            else:
                company_name = ts_code
        except Exception as e:
            logger.warning(f"[API] Failed to get company name: {e}")
            company_name = ts_code

    logger.info(f"[API] Generating {analysis_type} for {company_name}")

    # Get analysis configuration for prompt template
    analysis_cfg = get_analysis_config(analysis_type)
    if not analysis_cfg:
        raise HTTPException(
            status_code=404,
            detail=f"Analysis type '{analysis_type}' not found in configuration"
        )

    prompt_template = analysis_cfg.get("prompt_template", "")
    if not prompt_template:
        raise HTTPException(
            status_code=500,
            detail=f"Prompt template not found for analysis type '{analysis_type}'"
        )

    # Provider and model always come from the global config, not the module config.
    model = global_model or default_config.get("model", "gemini-1.5-flash")
    client = create_analysis_client(
        provider=default_provider,
        config=default_config,
        model=model
    )

    # Prepare dependency context: generate declared dependencies first and inject their outputs.
    context = {}
    try:
        dependencies = analysis_cfg.get("dependencies", []) or []
        if dependencies:
            analysis_config_full = load_analysis_config()
            modules_config = analysis_config_full.get("analysis_modules", {})

            def deps_of(mod_name):
                return modules_config.get(mod_name, {}).get("dependencies", []) or []

            # Collect all transitive dependencies.
            all_required = set()
            pending = list(dependencies)
            while pending:
                dep = pending.pop()
                if dep not in all_required:
                    all_required.add(dep)
                    pending.extend(deps_of(dep))

            # Kahn's algorithm: a module becomes ready once all of its required
            # dependencies are done, so the order is dependencies first.
            remaining = {}
            dependents = {name: [] for name in all_required}
            for name in sorted(all_required):
                inner_deps = [d for d in deps_of(name) if d in all_required]
                remaining[name] = len(inner_deps)
                for d in inner_deps:
                    dependents[d].append(name)
            ready = deque(name for name in sorted(all_required) if remaining[name] == 0)
            order = []
            while ready:
                mod = ready.popleft()
                order.append(mod)
                for dependent in dependents[mod]:
                    remaining[dependent] -= 1
                    if remaining[dependent] == 0:
                        ready.append(dependent)
            if len(order) != len(all_required):
                logger.warning(f"[API] Dependency cycle among {sorted(all_required)}; using arbitrary order")
                order = sorted(all_required)

            completed = {}
            for mod in order:
                cfg = modules_config.get(mod, {})
                dep_ctx = {d: completed.get(d, "") for d in deps_of(mod)}
                # Same global provider/model as the requested module.
                dep_client = create_analysis_client(
                    provider=default_provider,
                    config=default_config,
                    model=model
                )
                dep_result = await dep_client.generate_analysis(
                    analysis_type=mod,
                    company_name=company_name,
                    ts_code=ts_code,
                    prompt_template=cfg.get("prompt_template", ""),
                    financial_data=financial_data,
                    context=dep_ctx,
                )
                completed[mod] = dep_result.get("content", "") if dep_result.get("success") else ""

            context = {dep: completed.get(dep, "") for dep in dependencies}
    except Exception:
        # Best-effort context; continue without it, but leave a trace.
        logger.exception(f"[API] Failed to build dependency context for {analysis_type}; continuing without it")
        context = {}

    # Generate analysis
    result = await client.generate_analysis(
        analysis_type=analysis_type,
        company_name=company_name,
        ts_code=ts_code,
        prompt_template=prompt_template,
        financial_data=financial_data,
        context=context,
    )

    logger.info(f"[API] Analysis generation completed, success={result.get('success')}")

    return AnalysisResponse(
        ts_code=ts_code,
        company_name=company_name,
        analysis_type=analysis_type,
        content=result.get("content", ""),
        model=result.get("model", model),
        tokens=result.get("tokens", {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}),
        elapsed_ms=result.get("elapsed_ms", 0),
        success=result.get("success", False),
        error=result.get("error")
    )
@router.get("/china/{ts_code}/snapshot", response_model=TodaySnapshotResponse)
async def get_today_snapshot(ts_code: str):
    """
    Return the "yesterday snapshot" of market data for a China-market stock.

    The query date is the previous calendar day. The provider resolves it to the
    latest trading day not after that date. Returned fields:
    - trade_date
    - close: taken from the daily endpoint when daily_basic has no close
    - total_mv: raw value in units of 10,000 CNY, as returned by the provider
    - pe, pb
    - dv_ratio: dividend yield, in %
    """
    metric_fields = ("close", "pe", "pb", "dv_ratio", "total_mv")

    def row_trade_date(record, fallback):
        return str(record.get('trade_date') or record.get('trade_dt') or record.get('date') or fallback)

    try:
        dm = get_dm()

        # Company name is optional.
        try:
            basic = await dm.get_stock_basic(stock_code=ts_code)
            company_name = basic.get("name") if basic else None
        except Exception:
            company_name = None

        # Query "yesterday"; the provider maps it to the nearest trading day not after it.
        query_day = (datetime.now() - timedelta(days=1)).date().strftime("%Y%m%d")

        # daily_basic carries close, pe, pb, dv_ratio and total_mv.
        basic_rows = await dm.get_data(
            'get_daily_basic_points',
            stock_code=ts_code,
            trade_dates=[query_day]
        )
        # One record per trading day is expected; use the first.
        first_row = basic_rows[0] if isinstance(basic_rows, list) and basic_rows else None

        snapshot = dict.fromkeys(("trade_date",) + metric_fields)
        if isinstance(first_row, dict):
            snapshot["trade_date"] = row_trade_date(first_row, query_day)
            snapshot.update({field: first_row.get(field) for field in metric_fields})

        # Fall back to the daily endpoint for the close price.
        if snapshot["close"] is None:
            price_rows = await dm.get_data('get_daily_points', stock_code=ts_code, trade_dates=[query_day])
            if isinstance(price_rows, list) and price_rows:
                price_row = price_rows[0]
                snapshot["close"] = price_row.get('close')
                if snapshot["trade_date"] is None:
                    snapshot["trade_date"] = row_trade_date(price_row, query_day)

        resolved_date = query_day if snapshot["trade_date"] is None else snapshot["trade_date"]
        return TodaySnapshotResponse(
            ts_code=ts_code,
            trade_date=resolved_date,
            name=company_name,
            close=snapshot["close"],
            pe=snapshot["pe"],
            pb=snapshot["pb"],
            dv_ratio=snapshot["dv_ratio"],
            total_mv=snapshot["total_mv"],
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to fetch snapshot: {e}")
@router.get("/{market}/{stock_code}/snapshot", response_model=TodaySnapshotResponse)
async def get_market_snapshot(market: MarketEnum, stock_code: str):
    """
    Market-agnostic "yesterday snapshot" endpoint.

    - CN: delegates to the China-market snapshot (daily_basic/daily).
    - Other markets: uses daily prices to find the close of the latest trading day
      on or before yesterday. All other fields are returned as None.
    """
    if market == MarketEnum.cn:
        return await get_today_snapshot(stock_code)

    def row_date(record):
        return str(record.get("trade_date") or record.get("date") or "")

    try:
        dm = get_dm()

        # Company name is optional.
        try:
            basic = await dm.get_stock_basic(stock_code=stock_code)
            company_name = basic.get("name") if basic else None
        except Exception:
            company_name = None

        anchor = (datetime.now() - timedelta(days=1)).date()
        anchor_str = anchor.strftime("%Y%m%d")
        # Look back 10 days to be sure the window contains a trading day.
        window_start = (anchor - timedelta(days=10)).strftime("%Y%m%d")
        window_end = (anchor + timedelta(days=1)).strftime("%Y%m%d")

        rows = await dm.get_daily_price(stock_code=stock_code, start_date=window_start, end_date=window_end)

        trade_date, close = None, None
        if isinstance(rows, list) and rows:
            # Pick the last record dated on or before the anchor day.
            try:
                eligible = [r for r in rows if row_date(r) <= anchor_str]
                if eligible:
                    latest = sorted(eligible, key=row_date).pop()
                    trade_date = str(latest.get("trade_date") or latest.get("date") or anchor_str)
                    close = latest.get("close")
            except Exception:
                pass

        return TodaySnapshotResponse(
            ts_code=stock_code,
            trade_date=anchor_str if trade_date is None else trade_date,
            name=company_name,
            close=close,
            pe=None,
            pb=None,
            dv_ratio=None,
            total_mv=None,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to fetch snapshot: {e}")
@router.get("/china/{ts_code}/analysis/{analysis_type}/stream")
async def stream_analysis(
    ts_code: str,
    analysis_type: str,
    company_name: str = Query(None, description="Company name for better context"),
    config_manager: ConfigManager = Depends(get_config_manager),
):
    """
    Stream the content of one analysis module as plain text (text/plain; utf-8).

    Dependencies are first generated without streaming, transitively and
    dependencies before dependents, and passed in as ``context``. This step is
    best-effort: on any error the module is streamed without context. The
    target module is then streamed chunk by chunk, preceded by a Markdown
    header line with the module's configured name.

    Raises:
        HTTPException(404): ``analysis_type`` is not configured.
        HTTPException(500): the module has no prompt template.
    """
    import logging
    from collections import deque

    logger = logging.getLogger(__name__)

    logger.info(f"[API] Streaming analysis requested for {ts_code}, type: {analysis_type}")

    # Load LLM configuration using ConfigManager
    llm_config_result = await config_manager.get_llm_config()
    default_provider = llm_config_result["provider"]
    default_config = llm_config_result["config"]
    global_model = llm_config_result.get("model")  # global model setting

    # Get analysis configuration
    analysis_cfg = get_analysis_config(analysis_type)
    if not analysis_cfg:
        raise HTTPException(status_code=404, detail=f"Analysis type '{analysis_type}' not found in configuration")

    # Provider and model always come from the global config, not the module config.
    model = global_model or default_config.get("model", "gemini-1.5-flash")

    prompt_template = analysis_cfg.get("prompt_template", "")
    if not prompt_template:
        raise HTTPException(status_code=500, detail=f"Prompt template not found for analysis type '{analysis_type}'")

    # Get company name from ts_code if not provided; full financials are not needed here.
    financial_data = None
    if not company_name:
        try:
            basic_data = await get_dm().get_stock_basic(stock_code=ts_code)
            if basic_data:
                company_name = basic_data.get("name", ts_code)
            else:
                company_name = ts_code
        except Exception:
            company_name = ts_code

    # Resolve dependency context (non-streaming)
    context = {}
    try:
        dependencies = analysis_cfg.get("dependencies", []) or []
        if dependencies:
            analysis_config_full = load_analysis_config()
            modules_config = analysis_config_full.get("analysis_modules", {})

            def deps_of(mod_name):
                return modules_config.get(mod_name, {}).get("dependencies", []) or []

            # Collect all transitive dependencies.
            all_required = set()
            pending = list(dependencies)
            while pending:
                dep = pending.pop()
                if dep not in all_required:
                    all_required.add(dep)
                    pending.extend(deps_of(dep))

            # Kahn's algorithm: a module becomes ready once all of its required
            # dependencies are done, so the order is dependencies first.
            remaining = {}
            dependents = {name: [] for name in all_required}
            for name in sorted(all_required):
                inner_deps = [d for d in deps_of(name) if d in all_required]
                remaining[name] = len(inner_deps)
                for d in inner_deps:
                    dependents[d].append(name)
            ready = deque(name for name in sorted(all_required) if remaining[name] == 0)
            order = []
            while ready:
                mod = ready.popleft()
                order.append(mod)
                for dependent in dependents[mod]:
                    remaining[dependent] -= 1
                    if remaining[dependent] == 0:
                        ready.append(dependent)
            if len(order) != len(all_required):
                logger.warning(f"[API] Dependency cycle among {sorted(all_required)}; using arbitrary order")
                order = sorted(all_required)

            completed = {}
            for mod in order:
                cfg = modules_config.get(mod, {})
                dep_ctx = {d: completed.get(d, "") for d in deps_of(mod)}
                # Same global provider/model as the requested module.
                dep_client = create_analysis_client(
                    provider=default_provider,
                    config=default_config,
                    model=model
                )
                dep_result = await dep_client.generate_analysis(
                    analysis_type=mod,
                    company_name=company_name,
                    ts_code=ts_code,
                    prompt_template=cfg.get("prompt_template", ""),
                    financial_data=financial_data,
                    context=dep_ctx,
                )
                completed[mod] = dep_result.get("content", "") if dep_result.get("success") else ""
            context = {dep: completed.get(dep, "") for dep in dependencies}
    except Exception:
        # Best-effort context; continue without it, but leave a trace.
        logger.exception(f"[API] Failed to build dependency context for {analysis_type}; continuing without it")
        context = {}

    # Client for the streamed module, built from the global config.
    client = create_analysis_client(
        provider=default_provider,
        config=default_config,
        model=model
    )

    async def streamer():
        # Optional header line to help client-side UI
        header = f"# {analysis_cfg.get('name', analysis_type)}\n\n"
        yield header
        async for chunk in client.generate_analysis_stream(
            analysis_type=analysis_type,
            company_name=company_name,
            ts_code=ts_code,
            prompt_template=prompt_template,
            financial_data=financial_data,
            context=context,
        ):
            yield chunk

    headers = {
        # Disable buffering in intermediaries so chunks reach the client promptly.
        "Cache-Control": "no-cache, no-transform",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(streamer(), media_type="text/plain; charset=utf-8", headers=headers)
@router.get("/{market}/{stock_code}/analysis/{analysis_type}/stream")
async def stream_analysis_market(
    market: MarketEnum,
    stock_code: str,
    analysis_type: str,
    company_name: str = Query(None, description="Company name for better context"),
    config_manager: ConfigManager = Depends(get_config_manager),
):
    """
    Market-agnostic analysis stream endpoint: same logic as the China-market route, different path.

    ``market`` is currently only used for routing. ``config_manager`` must be
    resolved by FastAPI here and forwarded explicitly. Calling
    ``stream_analysis`` directly bypasses dependency injection, so its
    ``Depends(...)`` default would otherwise be passed through unresolved.
    """
    return await stream_analysis(stock_code, analysis_type, company_name, config_manager)