"""
|
|
API router for financial data (Tushare for China market)
|
|
"""
|
|

import json
import logging
import os
import time
from datetime import datetime, timezone
from typing import Dict, List

from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import StreamingResponse

from app.schemas.financial import (
    BatchFinancialDataResponse,
    FinancialConfigResponse,
    FinancialMeta,
    StepRecord,
    CompanyProfileResponse,
    AnalysisResponse,
    AnalysisConfigResponse,
)
from app.services.company_profile_client import CompanyProfileClient
from app.services.analysis_client import AnalysisClient, load_analysis_config, get_analysis_config

logger = logging.getLogger(__name__)


# Lazy DataManager loader to avoid import-time failures when optional providers/config are missing
_dm = None


def get_dm():
    global _dm
    if _dm is not None:
        return _dm
    try:
        from app.data_manager import data_manager as real_dm
        _dm = real_dm
        return _dm
    except Exception:
        class _StubDM:
            config = {}

            async def get_stock_basic(self, stock_code: str):
                return None

            async def get_financial_statements(self, stock_code: str, report_dates):
                return []

        _dm = _StubDM()
        return _dm


router = APIRouter()


# Load metric config from file (project root is the repo root, not backend/)
# routers/ -> app/ -> backend/ -> repo root
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
FINANCIAL_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "financial-tushare.json")
BASE_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "config.json")
ANALYSIS_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "analysis-config.json")
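

# The endpoints below read LLM settings from config.json. Inferred from the
# lookups in this file (the actual file contents may differ), the expected
# shape is roughly:
#
#   {
#     "llm": {
#       "provider": "gemini",
#       "gemini": {"api_key": "<key>", "base_url": "https://..."}
#     }
#   }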


def _load_json(path: str) -> Dict:
    # Best-effort loader: missing or unparseable files yield an empty dict
    if not os.path.exists(path):
        return {}
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except Exception:
        return {}
@router.get("/data-sources", response_model=Dict[str, List[str]])
|
|
async def get_data_sources():
|
|
"""
|
|
Get the list of data sources that require an API key from the config.
|
|
"""
|
|
try:
|
|
data_sources_config = get_dm().config.get("data_sources", {})
|
|
sources_requiring_keys = [
|
|
source for source, config in data_sources_config.items()
|
|
if config.get("api_key_env")
|
|
]
|
|
return {"sources": sources_requiring_keys}
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=f"Failed to load data sources configuration: {e}")
|
|
|
|
|
|
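

# For illustration, a data_sources config like the following (shape assumed
# from the lookup above; the real config lives in the DataManager):
#
#   {"tushare": {"api_key_env": "TUSHARE_TOKEN"}, "akshare": {}}
#
# would produce the response {"sources": ["tushare"]}.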


@router.post("/china/{ts_code}/analysis", response_model=List[AnalysisResponse])
async def generate_full_analysis(
    ts_code: str,
    company_name: str = Query(None, description="Company name for better context"),
):
    """
    Generate a full analysis report by orchestrating multiple analysis modules
    based on the dependencies defined in the configuration.
    """
    logger.info(f"[API] Full analysis requested for {ts_code}")

    # Load base and analysis configurations
    base_cfg = _load_json(BASE_CONFIG_PATH)
    llm_provider = base_cfg.get("llm", {}).get("provider", "gemini")
    llm_config = base_cfg.get("llm", {}).get(llm_provider, {})

    api_key = llm_config.get("api_key")
    base_url = llm_config.get("base_url")

    if not api_key:
        logger.error(f"[API] API key for {llm_provider} not configured")
        raise HTTPException(
            status_code=500,
            detail=f"API key for {llm_provider} not configured."
        )

    analysis_config_full = load_analysis_config()
    modules_config = analysis_config_full.get("analysis_modules", {})
    if not modules_config:
        raise HTTPException(status_code=404, detail="Analysis modules configuration not found.")

    # --- Dependency Resolution (Topological Sort) ---
    def topological_sort(graph):
        """Kahn's algorithm; returns (sorted_order, None), or (None, cycles) on failure."""
        in_degree = {u: 0 for u in graph}
        for u in graph:
            for v in graph[u]:
                in_degree[v] += 1

        queue = [u for u in graph if in_degree[u] == 0]
        sorted_order = []

        while queue:
            u = queue.pop(0)
            sorted_order.append(u)
            for v in graph.get(u, []):
                in_degree[v] -= 1
                if in_degree[v] == 0:
                    queue.append(v)

        if len(sorted_order) == len(graph):
            return sorted_order, None

        # Not all nodes were sorted: detect cycles and provide a meaningful error
        cycles = []
        visited = set()
        path = []

        def find_cycle_util(node):
            visited.add(node)
            path.append(node)
            for neighbor in graph.get(node, []):
                if neighbor in path:
                    cycle_start_index = path.index(neighbor)
                    cycles.append(path[cycle_start_index:] + [neighbor])
                elif neighbor not in visited:
                    find_cycle_util(neighbor)
            path.pop()

        for node in graph:
            if node not in visited:
                find_cycle_util(node)

        return None, cycles
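
    # For example, topological_sort({"a": ["b"], "b": ["c"], "c": []}) returns
    # (["a", "b", "c"], None); changing "c" to depend back on "a" instead yields
    # (None, [["a", "b", "c", "a"]]).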

    # Build dependency graph (module -> list of its prerequisites)
    dependency_graph = {
        name: config.get("dependencies", [])
        for name, config in modules_config.items()
    }

    # Invert graph for topological sort (edge from dependency to dependent)
    adj_list = {u: [] for u in dependency_graph}
    for u, dependencies in dependency_graph.items():
        for dep in dependencies:
            if dep in adj_list:
                adj_list[dep].append(u)

    sorted_modules, cycle = topological_sort(adj_list)
    if not sorted_modules:
        raise HTTPException(
            status_code=400,
            detail=f"Circular dependency detected in analysis modules configuration. Cycle: {cycle}"
        )
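
    # Illustration: if "final_conclusion" depends on "bull_case" and "bear_case",
    # dependency_graph holds {"final_conclusion": ["bull_case", "bear_case"], ...},
    # and the inverted adj_list gains edges bull_case -> final_conclusion and
    # bear_case -> final_conclusion, so prerequisites sort first.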

    # --- Fetch common data (company name, financial data) ---
    # This logic is duplicated across endpoints and could be refactored into a helper
    financial_data = None
    if not company_name:
        logger.info(f"[API] Fetching company name for {ts_code}")
        try:
            basic_data = await get_dm().get_stock_basic(stock_code=ts_code)
            if basic_data:
                company_name = basic_data.get("name", ts_code)
                logger.info(f"[API] Got company name: {company_name}")
            else:
                company_name = ts_code
        except Exception as e:
            logger.warning(f"Failed to get company name, proceeding with ts_code. Error: {e}")
            company_name = ts_code

    # --- Execute modules in order ---
    analysis_results = []
    completed_modules_content = {}

    for module_type in sorted_modules:
        module_config = modules_config[module_type]
        logger.info(f"[Orchestrator] Starting analysis for module: {module_type}")

        client = AnalysisClient(
            api_key=api_key,
            base_url=base_url,
            model=module_config.get("model", "gemini-1.5-flash")
        )

        # Gather context from completed dependencies
        context = {
            dep: completed_modules_content.get(dep, "")
            for dep in module_config.get("dependencies", [])
        }

        result = await client.generate_analysis(
            analysis_type=module_type,
            company_name=company_name,
            ts_code=ts_code,
            prompt_template=module_config.get("prompt_template", ""),
            financial_data=financial_data,
            context=context,
        )

        response = AnalysisResponse(
            ts_code=ts_code,
            company_name=company_name,
            analysis_type=module_type,
            content=result.get("content", ""),
            model=result.get("model", module_config.get("model")),
            tokens=result.get("tokens", {}),
            elapsed_ms=result.get("elapsed_ms", 0),
            success=result.get("success", False),
            error=result.get("error")
        )

        analysis_results.append(response)

        if response.success:
            completed_modules_content[module_type] = response.content
        else:
            # If a module fails, dependent modules receive an error marker as its context.
            # This prevents total failure but may affect quality.
            completed_modules_content[module_type] = f"Error: Analysis for {module_type} failed."
            logger.error(f"[Orchestrator] Module {module_type} failed: {response.error}")

    logger.info(f"[API] Full analysis for {ts_code} completed.")
    return analysis_results
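

# Example invocation (hypothetical mount prefix; adjust to how this router is
# included in the FastAPI app):
#
#   curl -X POST "http://localhost:8000/financial/china/600519.SH/analysis"
#
# The response is a JSON array with one AnalysisResponse per module, in
# dependency order.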
@router.get("/config", response_model=FinancialConfigResponse)
|
|
async def get_financial_config():
|
|
data = _load_json(FINANCIAL_CONFIG_PATH)
|
|
api_groups = data.get("api_groups", {})
|
|
return FinancialConfigResponse(api_groups=api_groups)
|
|
|
|
|
|
@router.get("/china/{ts_code}", response_model=BatchFinancialDataResponse)
|
|
async def get_china_financials(
|
|
ts_code: str,
|
|
years: int = Query(5, ge=1, le=15),
|
|
):
|
|
# Load metric config
|
|
fin_cfg = _load_json(FINANCIAL_CONFIG_PATH)
|
|
api_groups: Dict[str, List[Dict]] = fin_cfg.get("api_groups", {})
|
|
|
|
# Meta tracking
|
|
started_real = datetime.now(timezone.utc)
|
|
started = time.perf_counter_ns()
|
|
api_calls_total = 0 # This will be harder to track now, maybe DataManager should provide it
|
|
api_calls_by_group: Dict[str, int] = {}
|
|
steps: List[StepRecord] = []
|
|
|
|

    # Get company name
    company_name = ts_code
    try:
        basic_data = await get_dm().get_stock_basic(stock_code=ts_code)
        if basic_data:
            company_name = basic_data.get("name", ts_code)
    except Exception:
        pass  # Continue without it

    # Collect series per metric key
    series: Dict[str, List[Dict]] = {}
    errors: Dict[str, str] = {}

    # Generate the report-date range (annual reports). This produces years + 1
    # dates; the per-series truncation below enforces the requested limit.
    current_year = datetime.now().year
    report_dates = [f"{year}1231" for year in range(current_year - years, current_year + 1)]
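
    # For example, with years=5 and current_year=2025 this yields
    # ["20201231", "20211231", "20221231", "20231231", "20241231", "20251231"].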

    # Fetch all financial statements at once
    step_financials = StepRecord(name="Fetch financial statements", start_ts=started_real.isoformat(), status="running")
    steps.append(step_financials)

    all_financial_data = await get_dm().get_financial_statements(stock_code=ts_code, report_dates=report_dates)

    if all_financial_data:
        # Process financial data into the 'series' format
        for report in all_financial_data:
            year = report.get("end_date", "")[:4]
            for key, value in report.items():
                # Skip non-numeric identifier fields
                if key in ['ts_code', 'end_date', 'ann_date', 'f_ann_date', 'report_type', 'comp_type', 'end_type', 'update_flag']:
                    continue

                # Only include numeric values (bool is an int subclass, so exclude it explicitly)
                if isinstance(value, (int, float)) and not isinstance(value, bool):
                    if key not in series:
                        series[key] = []

                    # Avoid duplicates for the same year
                    if not any(d['year'] == year for d in series[key]):
                        series[key].append({"year": year, "value": value})
    else:
        errors["financial_statements"] = "Failed to fetch from all providers."

    step_financials.status = "done"
    step_financials.end_ts = datetime.now(timezone.utc).isoformat()
    step_financials.duration_ms = int((time.perf_counter_ns() - started) / 1_000_000)
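
    # Illustrative shape of `series` at this point (field names depend on the
    # provider's statement payload):
    #   {"total_revenue": [{"year": "2024", "value": 1.5e10}, ...],
    #    "n_income":      [{"year": "2024", "value": 2.3e9},  ...]}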

    # --- Potentially fetch other data types (e.g. daily prices) if needed by config ---
    # This part is simplified: the original code had per-api_group logic. For now the
    # main data comes from financial_statements; this can be extended to call other
    # data_manager methods based on `fin_cfg`.

    finished_real = datetime.now(timezone.utc)
    elapsed_ms = int((time.perf_counter_ns() - started) / 1_000_000)

    if not series:
        raise HTTPException(status_code=502, detail={"message": "No data returned from any data source", "errors": errors})

    # Truncate to the requested number of years and sort (the data should already
    # be mostly correct, but we make sure)
    for key, arr in series.items():
        # Deduplicate, sort descending by year, cut to the requested window, return ascending
        uniq = {item["year"]: item for item in arr}
        arr_sorted_desc = sorted(uniq.values(), key=lambda x: x["year"], reverse=True)
        arr_limited = arr_sorted_desc[:years]
        series[key] = sorted(arr_limited, key=lambda x: x["year"])

    meta = FinancialMeta(
        started_at=started_real.isoformat(),
        finished_at=finished_real.isoformat(),
        elapsed_ms=elapsed_ms,
        api_calls_total=api_calls_total,
        api_calls_by_group=api_calls_by_group,
        current_action=None,
        steps=steps,
    )

    return BatchFinancialDataResponse(ts_code=ts_code, name=company_name, series=series, meta=meta)
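

# Example invocation (hypothetical mount prefix):
#
#   curl "http://localhost:8000/financial/china/000001.SZ?years=5"
#
# returns a BatchFinancialDataResponse whose `series` maps metric keys to
# ascending per-year points and whose `meta.steps` records the fetch timing.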
@router.get("/china/{ts_code}/company-profile", response_model=CompanyProfileResponse)
|
|
async def get_company_profile(
|
|
ts_code: str,
|
|
company_name: str = Query(None, description="Company name for better context"),
|
|
):
|
|
"""
|
|
Get company profile for a company using Gemini AI (non-streaming, single response)
|
|
"""
|
|
import logging
|
|
logger = logging.getLogger(__name__)
|
|
|
|
logger.info(f"[API] Company profile requested for {ts_code}")
|
|
|
|
# Load config
|
|
base_cfg = _load_json(BASE_CONFIG_PATH)
|
|
llm_provider = base_cfg.get("llm", {}).get("provider", "gemini")
|
|
llm_config = base_cfg.get("llm", {}).get(llm_provider, {})
|
|
|
|
api_key = llm_config.get("api_key")
|
|
base_url = llm_config.get("base_url") # Will be None if not set, handled by client
|
|
|
|
if not api_key:
|
|
logger.error(f"[API] API key for {llm_provider} not configured")
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail=f"API key for {llm_provider} not configured."
|
|
)
|
|
|
|
client = CompanyProfileClient(
|
|
api_key=api_key,
|
|
base_url=base_url,
|
|
model="gemini-1.5-flash"
|
|
)
|
|
|
|

    # Get company name from ts_code if not provided
    if not company_name:
        logger.info(f"[API] Fetching company name for {ts_code}")
        # Try to get it from the stock_basic API
        try:
            basic_data = await get_dm().get_stock_basic(stock_code=ts_code)
            if basic_data:
                company_name = basic_data.get("name", ts_code)
                logger.info(f"[API] Got company name: {company_name}")
            else:
                company_name = ts_code
        except Exception as e:
            logger.warning(f"[API] Failed to get company name: {e}")
            company_name = ts_code

    logger.info(f"[API] Generating profile for {company_name}")

    # Generate profile using the non-streaming API
    result = await client.generate_profile(
        company_name=company_name,
        ts_code=ts_code,
        financial_data=None
    )

    logger.info(f"[API] Profile generation completed, success={result.get('success')}")

    return CompanyProfileResponse(
        ts_code=ts_code,
        company_name=company_name,
        content=result.get("content", ""),
        model=result.get("model", model),  # fall back to the model actually requested
        tokens=result.get("tokens", {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}),
        elapsed_ms=result.get("elapsed_ms", 0),
        success=result.get("success", False),
        error=result.get("error")
    )
@router.get("/analysis-config", response_model=AnalysisConfigResponse)
|
|
async def get_analysis_config_endpoint():
|
|
"""Get analysis configuration"""
|
|
config = load_analysis_config()
|
|
return AnalysisConfigResponse(analysis_modules=config.get("analysis_modules", {}))
|
|
|
|
|
|
@router.put("/analysis-config", response_model=AnalysisConfigResponse)
|
|
async def update_analysis_config_endpoint(analysis_config: AnalysisConfigResponse):
|
|
"""Update analysis configuration"""
|
|
import logging
|
|
logger = logging.getLogger(__name__)
|
|
|
|
try:
|
|
# 保存到文件
|
|
config_data = {
|
|
"analysis_modules": analysis_config.analysis_modules
|
|
}
|
|
|
|
with open(ANALYSIS_CONFIG_PATH, "w", encoding="utf-8") as f:
|
|
json.dump(config_data, f, ensure_ascii=False, indent=2)
|
|
|
|
logger.info(f"[API] Analysis config updated successfully")
|
|
return AnalysisConfigResponse(analysis_modules=analysis_config.analysis_modules)
|
|
except Exception as e:
|
|
logger.error(f"[API] Failed to update analysis config: {e}")
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail=f"Failed to update analysis config: {str(e)}"
|
|
)
|
|
|
|
|
|
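

# Example update (hypothetical mount prefix and module payload):
#
#   curl -X PUT "http://localhost:8000/financial/analysis-config" \
#        -H "Content-Type: application/json" \
#        -d '{"analysis_modules": {"fundamental_analysis": {"model": "gemini-2.5-flash", "prompt_template": "...", "dependencies": []}}}'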
@router.get("/china/{ts_code}/analysis/{analysis_type}", response_model=AnalysisResponse)
|
|
async def generate_analysis(
|
|
ts_code: str,
|
|
analysis_type: str,
|
|
company_name: str = Query(None, description="Company name for better context"),
|
|
):
|
|
"""
|
|
Generate analysis for a company using Gemini AI
|
|
Supported analysis types:
|
|
- fundamental_analysis (基本面分析)
|
|
- bull_case (看涨分析)
|
|
- bear_case (看跌分析)
|
|
- market_analysis (市场分析)
|
|
- news_analysis (新闻分析)
|
|
- trading_analysis (交易分析)
|
|
- insider_institutional (内部人与机构动向分析)
|
|
- final_conclusion (最终结论)
|
|
"""
|
|
import logging
|
|
logger = logging.getLogger(__name__)
|
|
|
|
logger.info(f"[API] Analysis requested for {ts_code}, type: {analysis_type}")
|
|
|
|
# Load config
|
|
base_cfg = _load_json(BASE_CONFIG_PATH)
|
|
llm_provider = base_cfg.get("llm", {}).get("provider", "gemini")
|
|
llm_config = base_cfg.get("llm", {}).get(llm_provider, {})
|
|
|
|
api_key = llm_config.get("api_key")
|
|
base_url = llm_config.get("base_url")
|
|
|
|
if not api_key:
|
|
logger.error(f"[API] API key for {llm_provider} not configured")
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail=f"API key for {llm_provider} not configured."
|
|
)
|
|
|
|

    # Get analysis configuration
    analysis_cfg = get_analysis_config(analysis_type)
    if not analysis_cfg:
        raise HTTPException(
            status_code=404,
            detail=f"Analysis type '{analysis_type}' not found in configuration"
        )

    model = analysis_cfg.get("model", "gemini-2.5-flash")
    prompt_template = analysis_cfg.get("prompt_template", "")

    if not prompt_template:
        raise HTTPException(
            status_code=500,
            detail=f"Prompt template not found for analysis type '{analysis_type}'"
        )

    # Get the company name (and, opportunistically, financial data) if not provided
    financial_data = None
    if not company_name:
        logger.info(f"[API] Fetching company name and financial data for {ts_code}")
        try:
            basic_data = await get_dm().get_stock_basic(stock_code=ts_code)
            if basic_data:
                company_name = basic_data.get("name", ts_code)
                logger.info(f"[API] Got company name: {company_name}")

                # Try to get financial data for context
                try:
                    # Simplified approach: fetch only last year's annual report
                    current_year = datetime.now().year
                    report_dates = [f"{current_year - 1}1231"]
                    latest_financials = await get_dm().get_financial_statements(
                        stock_code=ts_code,
                        report_dates=report_dates
                    )
                    if latest_financials:
                        financial_data = {"series": latest_financials[0]}
                except Exception as e:
                    logger.warning(f"[API] Failed to get financial data: {e}")
                    financial_data = None
            else:
                company_name = ts_code
        except Exception as e:
            logger.warning(f"[API] Failed to get company name: {e}")
            company_name = ts_code

    logger.info(f"[API] Generating {analysis_type} for {company_name}")

    # Initialize analysis client with the configured model
    client = AnalysisClient(api_key=api_key, base_url=base_url, model=model)

    # Prepare dependency context for single-module generation.
    # If the requested module declares dependencies, generate them first and inject their outputs.
    context = {}
    try:
        dependencies = analysis_cfg.get("dependencies", []) or []
        if dependencies:
            # Load the full modules config to resolve the dependency graph
            analysis_config_full = load_analysis_config()
            modules_config = analysis_config_full.get("analysis_modules", {})

            # Collect all transitive dependencies
            all_required = set()

            def collect_all_deps(mod_name: str):
                for dep in modules_config.get(mod_name, {}).get("dependencies", []) or []:
                    if dep not in all_required:
                        all_required.add(dep)
                        collect_all_deps(dep)

            for dep in dependencies:
                all_required.add(dep)
                collect_all_deps(dep)

            # Build the subgraph (module -> its prerequisites) and topologically sort it
            graph = {
                name: [d for d in (modules_config.get(name, {}).get("dependencies", []) or []) if d in all_required]
                for name in all_required
            }
            in_degree = {u: 0 for u in graph}
            for u, deps in graph.items():
                for v in deps:
                    in_degree[v] += 1
            queue = [u for u, deg in in_degree.items() if deg == 0]
            order = []
            while queue:
                u = queue.pop(0)
                order.append(u)
                for v in graph.get(u, []):
                    in_degree[v] -= 1
                    if in_degree[v] == 0:
                        queue.append(v)
            if len(order) == len(graph):
                # Edges point module -> prerequisite, so Kahn's algorithm emits
                # dependents first; reverse so prerequisites are generated first.
                order.reverse()
            else:
                # Fallback: if a cycle is detected, just use any order
                order = list(all_required)

            # Generate dependencies in order
            completed = {}
            for mod in order:
                cfg = modules_config.get(mod, {})
                dep_ctx = {d: completed.get(d, "") for d in (cfg.get("dependencies", []) or [])}
                dep_client = AnalysisClient(api_key=api_key, base_url=base_url, model=cfg.get("model", model))
                dep_result = await dep_client.generate_analysis(
                    analysis_type=mod,
                    company_name=company_name,
                    ts_code=ts_code,
                    prompt_template=cfg.get("prompt_template", ""),
                    financial_data=financial_data,
                    context=dep_ctx,
                )
                completed[mod] = dep_result.get("content", "") if dep_result.get("success") else ""

            context = {dep: completed.get(dep, "") for dep in dependencies}
    except Exception:
        # Best-effort context; if anything goes wrong, continue without it
        context = {}
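
    # Illustration: if analysis_type="final_conclusion" with dependencies
    # ["bull_case", "bear_case"], and both of those depend on
    # "fundamental_analysis", the resolved generation order is
    # fundamental_analysis -> bull_case / bear_case, and `context` carries the
    # bull_case and bear_case outputs into the final module's prompt.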

    # Generate the analysis
    result = await client.generate_analysis(
        analysis_type=analysis_type,
        company_name=company_name,
        ts_code=ts_code,
        prompt_template=prompt_template,
        financial_data=financial_data,
        context=context,
    )

    logger.info(f"[API] Analysis generation completed, success={result.get('success')}")

    return AnalysisResponse(
        ts_code=ts_code,
        company_name=company_name,
        analysis_type=analysis_type,
        content=result.get("content", ""),
        model=result.get("model", model),
        tokens=result.get("tokens", {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}),
        elapsed_ms=result.get("elapsed_ms", 0),
        success=result.get("success", False),
        error=result.get("error")
    )
@router.get("/china/{ts_code}/analysis/{analysis_type}/stream")
|
|
async def stream_analysis(
|
|
ts_code: str,
|
|
analysis_type: str,
|
|
company_name: str = Query(None, description="Company name for better context"),
|
|
):
|
|
"""
|
|
Stream analysis content chunks for a given module using OpenAI-compatible streaming.
|
|
Plain text streaming (text/plain; utf-8). Dependencies are resolved first (non-stream),
|
|
then the target module content is streamed.
|
|
"""
|
|
import logging
|
|
logger = logging.getLogger(__name__)
|
|
|
|
logger.info(f"[API] Streaming analysis requested for {ts_code}, type: {analysis_type}")
|
|
|
|
# Load config
|
|
base_cfg = _load_json(BASE_CONFIG_PATH)
|
|
llm_provider = base_cfg.get("llm", {}).get("provider", "gemini")
|
|
llm_config = base_cfg.get("llm", {}).get(llm_provider, {})
|
|
|
|
api_key = llm_config.get("api_key")
|
|
base_url = llm_config.get("base_url")
|
|
|
|
if not api_key:
|
|
logger.error(f"[API] API key for {llm_provider} not configured")
|
|
raise HTTPException(status_code=500, detail=f"API key for {llm_provider} not configured.")
|
|
|
|
# Get analysis configuration
|
|
analysis_cfg = get_analysis_config(analysis_type)
|
|
if not analysis_cfg:
|
|
raise HTTPException(status_code=404, detail=f"Analysis type '{analysis_type}' not found in configuration")
|
|
|
|
model = analysis_cfg.get("model", "gemini-2.5-flash")
|
|
prompt_template = analysis_cfg.get("prompt_template", "")
|
|
if not prompt_template:
|
|
raise HTTPException(status_code=500, detail=f"Prompt template not found for analysis type '{analysis_type}'")
|
|
|
|
# Get company name from ts_code if not provided; we don't need full financials here
|
|
financial_data = None
|
|
if not company_name:
|
|
try:
|
|
basic_data = await get_dm().get_stock_basic(stock_code=ts_code)
|
|
if basic_data:
|
|
company_name = basic_data.get("name", ts_code)
|
|
else:
|
|
company_name = ts_code
|
|
except Exception:
|
|
company_name = ts_code
|
|
|
|

    # Resolve dependency context (non-streaming); same approach as generate_analysis above
    context = {}
    try:
        dependencies = analysis_cfg.get("dependencies", []) or []
        if dependencies:
            analysis_config_full = load_analysis_config()
            modules_config = analysis_config_full.get("analysis_modules", {})

            all_required = set()

            def collect_all(mod_name: str):
                for dep in modules_config.get(mod_name, {}).get("dependencies", []) or []:
                    if dep not in all_required:
                        all_required.add(dep)
                        collect_all(dep)

            for dep in dependencies:
                all_required.add(dep)
                collect_all(dep)

            graph = {
                name: [d for d in (modules_config.get(name, {}).get("dependencies", []) or []) if d in all_required]
                for name in all_required
            }
            in_degree = {u: 0 for u in graph}
            for u, deps in graph.items():
                for v in deps:
                    in_degree[v] += 1
            queue = [u for u, deg in in_degree.items() if deg == 0]
            order = []
            while queue:
                u = queue.pop(0)
                order.append(u)
                for v in graph.get(u, []):
                    in_degree[v] -= 1
                    if in_degree[v] == 0:
                        queue.append(v)
            if len(order) == len(graph):
                # As above: reverse so prerequisites are generated before dependents
                order.reverse()
            else:
                order = list(all_required)

            completed = {}
            for mod in order:
                cfg = modules_config.get(mod, {})
                dep_ctx = {d: completed.get(d, "") for d in (cfg.get("dependencies", []) or [])}
                dep_client = AnalysisClient(api_key=api_key, base_url=base_url, model=cfg.get("model", model))
                dep_result = await dep_client.generate_analysis(
                    analysis_type=mod,
                    company_name=company_name,
                    ts_code=ts_code,
                    prompt_template=cfg.get("prompt_template", ""),
                    financial_data=financial_data,
                    context=dep_ctx,
                )
                completed[mod] = dep_result.get("content", "") if dep_result.get("success") else ""
            context = {dep: completed.get(dep, "") for dep in dependencies}
    except Exception:
        context = {}

    client = AnalysisClient(api_key=api_key, base_url=base_url, model=model)

    async def streamer():
        # Optional header line to help the client-side UI
        header = f"# {analysis_cfg.get('name', analysis_type)}\n\n"
        yield header
        async for chunk in client.generate_analysis_stream(
            analysis_type=analysis_type,
            company_name=company_name,
            ts_code=ts_code,
            prompt_template=prompt_template,
            financial_data=financial_data,
            context=context,
        ):
            yield chunk

    headers = {
        # Disable intermediary buffering so chunks reach the client as soon as possible
        "Cache-Control": "no-cache, no-transform",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(streamer(), media_type="text/plain; charset=utf-8", headers=headers)
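

# Example consumption (hypothetical mount prefix; -N disables curl's buffering):
#
#   curl -N "http://localhost:8000/financial/china/600519.SH/analysis/fundamental_analysis/stream"
#
# The first line is a markdown-style heading, followed by streamed content chunks.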