This commit introduces a set of significant features, centered on dynamic configuration for the financial analysis module, together with user-interface improvements to the configuration and report pages.

Main changes:

- **Dynamic configuration:**
  - Implemented a `ConfigManager` service on the backend for dynamically managing `analysis-config.json` and `config.json` (a sketch of the expected config shape follows below).
  - Added API endpoints for reading and updating the configuration.
  - Built the frontend `/config` page, which lets users view and modify the analysis configuration in real time.
- **Backend enhancements:**
  - Updated `AnalysisClient` and `CompanyProfileClient` to use the new configuration system.
  - Refactored the financial-data routes.
- **Frontend improvements:**
  - Added a reusable `Checkbox` UI component.
  - Redesigned the configuration page with a more intuitive, user-friendly interface.
  - Improved the layout and data presentation of the financial report page.
- **Documentation and chores:**
  - Updated the design and requirements documents to reflect the new features.
  - Updated frontend and backend dependencies.
  - Modified the development script `dev.sh`.
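For orientation, here is a minimal sketch of the shape the router below expects from `config/analysis-config.json`: a top-level `analysis_modules` map whose entries carry `model`, `prompt_template`, and `dependencies`. The module names come from the analysis types listed in the router; the prompt text and the `{company_name}`/`{ts_code}` placeholder syntax are illustrative assumptions, not the shipped defaults.

```json
{
  "analysis_modules": {
    "fundamental_analysis": {
      "model": "gemini-2.5-flash",
      "prompt_template": "Write a fundamental analysis of {company_name} ({ts_code}) ...",
      "dependencies": []
    },
    "final_conclusion": {
      "model": "gemini-2.5-flash",
      "prompt_template": "Combine the completed analyses into a final conclusion for {company_name} ...",
      "dependencies": ["fundamental_analysis"]
    }
  }
}
```

The `PUT /analysis-config` endpoint persists exactly this structure back to disk, and the full-report endpoint topologically sorts the modules by `dependencies` before running them.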
694 lines · 28 KiB · Python
"""
|
||
API router for financial data (Tushare for China market)
|
||
"""
|
||
import json
|
||
import os
|
||
import time
|
||
from datetime import datetime, timezone, timedelta
|
||
from typing import Dict, List
|
||
|
||
from fastapi import APIRouter, HTTPException, Query
|
||
from fastapi.responses import StreamingResponse
|
||
import os
|
||
|
||
from app.core.config import settings
|
||
from app.schemas.financial import (
|
||
BatchFinancialDataResponse,
|
||
FinancialConfigResponse,
|
||
FinancialMeta,
|
||
StepRecord,
|
||
CompanyProfileResponse,
|
||
AnalysisResponse,
|
||
AnalysisConfigResponse
|
||
)
|
||
from app.services.tushare_client import TushareClient
|
||
from app.services.company_profile_client import CompanyProfileClient
|
||
from app.services.analysis_client import AnalysisClient, load_analysis_config, get_analysis_config
|
||
|
||
router = APIRouter()
|
||
|
||
# Load metric config from file (project root is repo root, not backend/)
|
||
# routers/ -> app/ -> backend/ -> repo root
|
||
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||
FINANCIAL_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "financial-tushare.json")
|
||
BASE_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "config.json")
|
||
ANALYSIS_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "analysis-config.json")
|
||
|
||
|
||
def _load_json(path: str) -> Dict:
|
||
if not os.path.exists(path):
|
||
return {}
|
||
try:
|
||
with open(path, "r", encoding="utf-8") as f:
|
||
return json.load(f)
|
||
except Exception:
|
||
return {}
|
||
|
||
|
||
@router.post("/china/{ts_code}/analysis", response_model=List[AnalysisResponse])
|
||
async def generate_full_analysis(
|
||
ts_code: str,
|
||
company_name: str = Query(None, description="Company name for better context"),
|
||
):
|
||
"""
|
||
Generate a full analysis report by orchestrating multiple analysis modules
|
||
based on dependencies defined in the configuration.
|
||
"""
|
||
import logging
|
||
logger = logging.getLogger(__name__)
|
||
|
||
logger.info(f"[API] Full analysis requested for {ts_code}")
|
||
|
||
# Load base and analysis configurations
|
||
base_cfg = _load_json(BASE_CONFIG_PATH)
|
||
llm_provider = base_cfg.get("llm", {}).get("provider", "gemini")
|
||
llm_config = base_cfg.get("llm", {}).get(llm_provider, {})
|
||
|
||
api_key = llm_config.get("api_key")
|
||
base_url = llm_config.get("base_url")
|
||
|
||
if not api_key:
|
||
logger.error(f"[API] API key for {llm_provider} not configured")
|
||
raise HTTPException(
|
||
status_code=500,
|
||
detail=f"API key for {llm_provider} not configured."
|
||
)
|
||
|
||
analysis_config_full = load_analysis_config()
|
||
modules_config = analysis_config_full.get("analysis_modules", {})
|
||
if not modules_config:
|
||
raise HTTPException(status_code=404, detail="Analysis modules configuration not found.")
|
||
|
||
# --- Dependency Resolution (Topological Sort) ---
|
||
def topological_sort(graph):
|
||
in_degree = {u: 0 for u in graph}
|
||
for u in graph:
|
||
for v in graph[u]:
|
||
in_degree[v] += 1
|
||
|
||
queue = [u for u in graph if in_degree[u] == 0]
|
||
sorted_order = []
|
||
|
||
while queue:
|
||
u = queue.pop(0)
|
||
sorted_order.append(u)
|
||
for v in graph.get(u, []):
|
||
in_degree[v] -= 1
|
||
if in_degree[v] == 0:
|
||
queue.append(v)
|
||
|
||
if len(sorted_order) == len(graph):
|
||
return sorted_order
|
||
else:
|
||
# Detect cycles and provide a meaningful error
|
||
cycles = []
|
||
visited = set()
|
||
path = []
|
||
|
||
def find_cycle_util(node):
|
||
visited.add(node)
|
||
path.append(node)
|
||
for neighbor in graph.get(node, []):
|
||
if neighbor in path:
|
||
cycle_start_index = path.index(neighbor)
|
||
cycles.append(path[cycle_start_index:] + [neighbor])
|
||
return
|
||
if neighbor not in visited:
|
||
find_cycle_util(neighbor)
|
||
path.pop()
|
||
|
||
for node in graph:
|
||
if node not in visited:
|
||
find_cycle_util(node)
|
||
|
||
return None, cycles
|
||
|
||
|
||
# Build dependency graph
|
||
dependency_graph = {
|
||
name: config.get("dependencies", [])
|
||
for name, config in modules_config.items()
|
||
}
|
||
|
||
# Invert graph for topological sort (from dependency to dependent)
|
||
adj_list = {u: [] for u in dependency_graph}
|
||
for u, dependencies in dependency_graph.items():
|
||
for dep in dependencies:
|
||
if dep in adj_list:
|
||
adj_list[dep].append(u)
|
||
|
||
sorted_modules, cycle = topological_sort(adj_list)
|
||
if not sorted_modules:
|
||
raise HTTPException(
|
||
status_code=400,
|
||
detail=f"Circular dependency detected in analysis modules configuration. Cycle: {cycle}"
|
||
)
|
||
|
||
# --- Fetch common data (company name, financial data) ---
|
||
# This logic is duplicated, could be refactored into a helper
|
||
financial_data = None
|
||
if not company_name:
|
||
logger.info(f"[API] Fetching company name for {ts_code}")
|
||
try:
|
||
token = base_cfg.get("data_sources", {}).get("tushare", {}).get("api_key")
|
||
if token:
|
||
tushare_client = TushareClient(token=token)
|
||
basic_data = await tushare_client.query(api_name="stock_basic", params={"ts_code": ts_code}, fields="ts_code,name")
|
||
if basic_data:
|
||
company_name = basic_data[0].get("name", ts_code)
|
||
logger.info(f"[API] Got company name: {company_name}")
|
||
except Exception as e:
|
||
logger.warning(f"Failed to get company name, proceeding with ts_code. Error: {e}")
|
||
company_name = ts_code
|
||
|
||
# --- Execute modules in order ---
|
||
analysis_results = []
|
||
completed_modules_content = {}
|
||
|
||
for module_type in sorted_modules:
|
||
module_config = modules_config[module_type]
|
||
logger.info(f"[Orchestrator] Starting analysis for module: {module_type}")
|
||
|
||
client = AnalysisClient(
|
||
api_key=api_key,
|
||
base_url=base_url,
|
||
model=module_config.get("model", "gemini-1.5-flash")
|
||
)
|
||
|
||
# Gather context from completed dependencies
|
||
context = {
|
||
dep: completed_modules_content.get(dep, "")
|
||
for dep in module_config.get("dependencies", [])
|
||
}
|
||
|
||
result = await client.generate_analysis(
|
||
analysis_type=module_type,
|
||
company_name=company_name,
|
||
ts_code=ts_code,
|
||
prompt_template=module_config.get("prompt_template", ""),
|
||
financial_data=financial_data,
|
||
context=context,
|
||
)
|
||
|
||
response = AnalysisResponse(
|
||
ts_code=ts_code,
|
||
company_name=company_name,
|
||
analysis_type=module_type,
|
||
content=result.get("content", ""),
|
||
model=result.get("model", module_config.get("model")),
|
||
tokens=result.get("tokens", {}),
|
||
elapsed_ms=result.get("elapsed_ms", 0),
|
||
success=result.get("success", False),
|
||
error=result.get("error")
|
||
)
|
||
|
||
analysis_results.append(response)
|
||
|
||
if response.success:
|
||
completed_modules_content[module_type] = response.content
|
||
else:
|
||
# If a module fails, subsequent dependent modules will get an empty string for its context.
|
||
# This prevents total failure but may affect quality.
|
||
completed_modules_content[module_type] = f"Error: Analysis for {module_type} failed."
|
||
logger.error(f"[Orchestrator] Module {module_type} failed: {response.error}")
|
||
|
||
logger.info(f"[API] Full analysis for {ts_code} completed.")
|
||
return analysis_results
|
||
|
||
|
||
@router.get("/config", response_model=FinancialConfigResponse)
|
||
async def get_financial_config():
|
||
data = _load_json(FINANCIAL_CONFIG_PATH)
|
||
api_groups = data.get("api_groups", {})
|
||
return FinancialConfigResponse(api_groups=api_groups)
|
||
|
||
|
||
@router.get("/china/{ts_code}", response_model=BatchFinancialDataResponse)
|
||
async def get_china_financials(
|
||
ts_code: str,
|
||
years: int = Query(5, ge=1, le=15),
|
||
):
|
||
# Load Tushare token
|
||
base_cfg = _load_json(BASE_CONFIG_PATH)
|
||
token = (
|
||
os.environ.get("TUSHARE_TOKEN")
|
||
or settings.TUSHARE_TOKEN
|
||
or base_cfg.get("data_sources", {}).get("tushare", {}).get("api_key")
|
||
)
|
||
if not token:
|
||
raise HTTPException(status_code=500, detail="Tushare API token not configured. Set TUSHARE_TOKEN env or config/config.json data_sources.tushare.api_key")
|
||
|
||
# Load metric config
|
||
fin_cfg = _load_json(FINANCIAL_CONFIG_PATH)
|
||
api_groups: Dict[str, List[Dict]] = fin_cfg.get("api_groups", {})
|
||
|
||
client = TushareClient(token=token)
|
||
|
||
# Meta tracking
|
||
started_real = datetime.now(timezone.utc)
|
||
started = time.perf_counter_ns()
|
||
api_calls_total = 0
|
||
api_calls_by_group: Dict[str, int] = {}
|
||
steps: List[StepRecord] = []
|
||
current_action = "初始化"
|
||
|
||
# Get company name from stock_basic API
|
||
company_name = None
|
||
try:
|
||
basic_data = await client.query(api_name="stock_basic", params={"ts_code": ts_code}, fields="ts_code,name")
|
||
api_calls_total += 1
|
||
if basic_data and len(basic_data) > 0:
|
||
company_name = basic_data[0].get("name")
|
||
except Exception:
|
||
# If getting company name fails, continue without it
|
||
pass
|
||
|
||
# Collect series per metric key
|
||
series: Dict[str, List[Dict]] = {}
|
||
|
||
# Helper to store year-value pairs while keeping most recent per year
|
||
def _merge_year_value(key: str, year: str, value, month: int = None):
|
||
arr = series.setdefault(key, [])
|
||
# upsert by year
|
||
for item in arr:
|
||
if item["year"] == year:
|
||
item["value"] = value
|
||
if month is not None:
|
||
item["month"] = month
|
||
return
|
||
arr.append({"year": year, "value": value, "month": month})
|
||
|
||
# Query each API group we care
|
||
errors: Dict[str, str] = {}
|
||
for group_name, metrics in api_groups.items():
|
||
step = StepRecord(
|
||
name=f"拉取 {group_name}",
|
||
start_ts=started_real.isoformat(),
|
||
status="running",
|
||
)
|
||
steps.append(step)
|
||
current_action = step.name
|
||
if not metrics:
|
||
continue
|
||
|
||
# 按 API 分组 metrics(处理 unknown 组中有多个不同 API 的情况)
|
||
api_groups_dict: Dict[str, List[Dict]] = {}
|
||
for metric in metrics:
|
||
api = metric.get("api") or group_name
|
||
if api: # 跳过空 API
|
||
if api not in api_groups_dict:
|
||
api_groups_dict[api] = []
|
||
api_groups_dict[api].append(metric)
|
||
|
||
# 对每个 API 分别处理
|
||
for api_name, api_metrics in api_groups_dict.items():
|
||
fields = [m.get("tushareParam") for m in api_metrics if m.get("tushareParam")]
|
||
if not fields:
|
||
continue
|
||
|
||
date_field = "end_date" if group_name in ("fina_indicator", "income", "balancesheet", "cashflow") else "trade_date"
|
||
|
||
# 构建 API 参数
|
||
params = {"ts_code": ts_code, "limit": 5000}
|
||
|
||
# 对于需要日期范围的 API(如 stk_holdernumber),添加日期参数
|
||
if api_name == "stk_holdernumber":
|
||
# 计算日期范围:从 years 年前到现在
|
||
end_date = datetime.now().strftime("%Y%m%d")
|
||
start_date = (datetime.now() - timedelta(days=years * 365)).strftime("%Y%m%d")
|
||
params["start_date"] = start_date
|
||
params["end_date"] = end_date
|
||
# stk_holdernumber 返回的日期字段通常是 end_date
|
||
date_field = "end_date"
|
||
|
||
# 对于非时间序列 API(如 stock_company),标记为静态数据
|
||
is_static_data = api_name == "stock_company"
|
||
|
||
# 构建 fields 字符串:包含日期字段和所有需要的指标字段
|
||
# 确保日期字段存在,因为我们需要用它来确定年份
|
||
fields_list = list(fields)
|
||
if date_field not in fields_list:
|
||
fields_list.insert(0, date_field)
|
||
# 对于 fina_indicator 等 API,通常还需要 ts_code 和 ann_date
|
||
if api_name in ("fina_indicator", "income", "balancesheet", "cashflow"):
|
||
for req_field in ["ts_code", "ann_date"]:
|
||
if req_field not in fields_list:
|
||
fields_list.insert(0, req_field)
|
||
fields_str = ",".join(fields_list)
|
||
|
||
try:
|
||
data_rows = await client.query(api_name=api_name, params=params, fields=fields_str)
|
||
api_calls_total += 1
|
||
api_calls_by_group[group_name] = api_calls_by_group.get(group_name, 0) + 1
|
||
except Exception as e:
|
||
# 记录错误但继续处理其他 API
|
||
error_key = f"{group_name}_{api_name}"
|
||
errors[error_key] = str(e)
|
||
continue
|
||
|
||
tmp: Dict[str, Dict] = {}
|
||
current_year = datetime.now().strftime("%Y")
|
||
|
||
for row in data_rows:
|
||
if is_static_data:
|
||
# 对于静态数据(如 stock_company),使用当前年份
|
||
# 只处理第一行数据,因为静态数据通常只有一行
|
||
if current_year not in tmp:
|
||
year = current_year
|
||
month = None
|
||
tmp[year] = row
|
||
tmp[year]['_month'] = month
|
||
# 跳过后续行
|
||
continue
|
||
else:
|
||
# 对于时间序列数据,按日期字段处理
|
||
date_val = row.get(date_field)
|
||
if not date_val:
|
||
continue
|
||
year = str(date_val)[:4]
|
||
month = int(str(date_val)[4:6]) if len(str(date_val)) >= 6 else None
|
||
existing = tmp.get(year)
|
||
if existing is None or str(row.get(date_field)) > str(existing.get(date_field)):
|
||
tmp[year] = row
|
||
tmp[year]['_month'] = month
|
||
|
||
for metric in api_metrics:
|
||
key = metric.get("tushareParam")
|
||
if not key:
|
||
continue
|
||
for year, row in tmp.items():
|
||
month = row.get('_month')
|
||
_merge_year_value(key, year, row.get(key), month)
|
||
|
||
step.status = "done"
|
||
step.end_ts = datetime.now(timezone.utc).isoformat()
|
||
step.duration_ms = int((time.perf_counter_ns() - started) / 1_000_000)
|
||
|
||
finished_real = datetime.now(timezone.utc)
|
||
elapsed_ms = int((time.perf_counter_ns() - started) / 1_000_000)
|
||
|
||
if not series:
|
||
# If nothing succeeded, expose partial error info
|
||
raise HTTPException(status_code=502, detail={"message": "No data returned from Tushare", "errors": errors})
|
||
|
||
# Truncate years and sort
|
||
for key, arr in series.items():
|
||
# Deduplicate and sort desc by year, then cut to requested years, and return asc
|
||
uniq = {item["year"]: item for item in arr}
|
||
arr_sorted_desc = sorted(uniq.values(), key=lambda x: x["year"], reverse=True)
|
||
arr_limited = arr_sorted_desc[:years]
|
||
arr_sorted = sorted(arr_limited, key=lambda x: x["year"]) # ascending by year
|
||
series[key] = arr_sorted
|
||
|
||
meta = FinancialMeta(
|
||
started_at=started_real.isoformat(),
|
||
finished_at=finished_real.isoformat(),
|
||
elapsed_ms=elapsed_ms,
|
||
api_calls_total=api_calls_total,
|
||
api_calls_by_group=api_calls_by_group,
|
||
current_action=None,
|
||
steps=steps,
|
||
)
|
||
|
||
return BatchFinancialDataResponse(ts_code=ts_code, name=company_name, series=series, meta=meta)
|
||
|
||
|
||
@router.get("/china/{ts_code}/company-profile", response_model=CompanyProfileResponse)
|
||
async def get_company_profile(
|
||
ts_code: str,
|
||
company_name: str = Query(None, description="Company name for better context"),
|
||
):
|
||
"""
|
||
Get company profile for a company using Gemini AI (non-streaming, single response)
|
||
"""
|
||
import logging
|
||
logger = logging.getLogger(__name__)
|
||
|
||
logger.info(f"[API] Company profile requested for {ts_code}")
|
||
|
||
# Load config
|
||
base_cfg = _load_json(BASE_CONFIG_PATH)
|
||
llm_provider = base_cfg.get("llm", {}).get("provider", "gemini")
|
||
llm_config = base_cfg.get("llm", {}).get(llm_provider, {})
|
||
|
||
api_key = llm_config.get("api_key")
|
||
base_url = llm_config.get("base_url") # Will be None if not set, handled by client
|
||
|
||
if not api_key:
|
||
logger.error(f"[API] API key for {llm_provider} not configured")
|
||
raise HTTPException(
|
||
status_code=500,
|
||
detail=f"API key for {llm_provider} not configured."
|
||
)
|
||
|
||
client = CompanyProfileClient(
|
||
api_key=api_key,
|
||
base_url=base_url,
|
||
model="gemini-1.5-flash"
|
||
)
|
||
|
||
# Get company name from ts_code if not provided
|
||
if not company_name:
|
||
logger.info(f"[API] Fetching company name for {ts_code}")
|
||
# Try to get from stock_basic API
|
||
try:
|
||
base_cfg = _load_json(BASE_CONFIG_PATH)
|
||
token = (
|
||
os.environ.get("TUSHARE_TOKEN")
|
||
or settings.TUSHARE_TOKEN
|
||
or base_cfg.get("data_sources", {}).get("tushare", {}).get("api_key")
|
||
)
|
||
if token:
|
||
from app.services.tushare_client import TushareClient
|
||
tushare_client = TushareClient(token=token)
|
||
basic_data = await tushare_client.query(api_name="stock_basic", params={"ts_code": ts_code}, fields="ts_code,name")
|
||
if basic_data and len(basic_data) > 0:
|
||
company_name = basic_data[0].get("name", ts_code)
|
||
logger.info(f"[API] Got company name: {company_name}")
|
||
else:
|
||
company_name = ts_code
|
||
else:
|
||
company_name = ts_code
|
||
except Exception as e:
|
||
logger.warning(f"[API] Failed to get company name: {e}")
|
||
company_name = ts_code
|
||
|
||
logger.info(f"[API] Generating profile for {company_name}")
|
||
|
||
# Generate profile using non-streaming API
|
||
result = await client.generate_profile(
|
||
company_name=company_name,
|
||
ts_code=ts_code,
|
||
financial_data=None
|
||
)
|
||
|
||
logger.info(f"[API] Profile generation completed, success={result.get('success')}")
|
||
|
||
return CompanyProfileResponse(
|
||
ts_code=ts_code,
|
||
company_name=company_name,
|
||
content=result.get("content", ""),
|
||
model=result.get("model", "gemini-2.5-flash"),
|
||
tokens=result.get("tokens", {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}),
|
||
elapsed_ms=result.get("elapsed_ms", 0),
|
||
success=result.get("success", False),
|
||
error=result.get("error")
|
||
)
|
||
|
||
|
||
@router.get("/analysis-config", response_model=AnalysisConfigResponse)
|
||
async def get_analysis_config_endpoint():
|
||
"""Get analysis configuration"""
|
||
config = load_analysis_config()
|
||
return AnalysisConfigResponse(analysis_modules=config.get("analysis_modules", {}))
|
||
|
||
|
||
@router.put("/analysis-config", response_model=AnalysisConfigResponse)
|
||
async def update_analysis_config_endpoint(analysis_config: AnalysisConfigResponse):
|
||
"""Update analysis configuration"""
|
||
import logging
|
||
logger = logging.getLogger(__name__)
|
||
|
||
try:
|
||
# 保存到文件
|
||
config_data = {
|
||
"analysis_modules": analysis_config.analysis_modules
|
||
}
|
||
|
||
with open(ANALYSIS_CONFIG_PATH, "w", encoding="utf-8") as f:
|
||
json.dump(config_data, f, ensure_ascii=False, indent=2)
|
||
|
||
logger.info(f"[API] Analysis config updated successfully")
|
||
return AnalysisConfigResponse(analysis_modules=analysis_config.analysis_modules)
|
||
except Exception as e:
|
||
logger.error(f"[API] Failed to update analysis config: {e}")
|
||
raise HTTPException(
|
||
status_code=500,
|
||
detail=f"Failed to update analysis config: {str(e)}"
|
||
)
|
||
|
||
|
||
@router.get("/china/{ts_code}/analysis/{analysis_type}", response_model=AnalysisResponse)
|
||
async def generate_analysis(
|
||
ts_code: str,
|
||
analysis_type: str,
|
||
company_name: str = Query(None, description="Company name for better context"),
|
||
):
|
||
"""
|
||
Generate analysis for a company using Gemini AI
|
||
Supported analysis types:
|
||
- fundamental_analysis (基本面分析)
|
||
- bull_case (看涨分析)
|
||
- bear_case (看跌分析)
|
||
- market_analysis (市场分析)
|
||
- news_analysis (新闻分析)
|
||
- trading_analysis (交易分析)
|
||
- insider_institutional (内部人与机构动向分析)
|
||
- final_conclusion (最终结论)
|
||
"""
|
||
import logging
|
||
logger = logging.getLogger(__name__)
|
||
|
||
logger.info(f"[API] Analysis requested for {ts_code}, type: {analysis_type}")
|
||
|
||
# Load config
|
||
base_cfg = _load_json(BASE_CONFIG_PATH)
|
||
llm_provider = base_cfg.get("llm", {}).get("provider", "gemini")
|
||
llm_config = base_cfg.get("llm", {}).get(llm_provider, {})
|
||
|
||
api_key = llm_config.get("api_key")
|
||
base_url = llm_config.get("base_url")
|
||
|
||
if not api_key:
|
||
logger.error(f"[API] API key for {llm_provider} not configured")
|
||
raise HTTPException(
|
||
status_code=500,
|
||
detail=f"API key for {llm_provider} not configured."
|
||
)
|
||
|
||
# Get analysis configuration
|
||
analysis_cfg = get_analysis_config(analysis_type)
|
||
if not analysis_cfg:
|
||
raise HTTPException(
|
||
status_code=404,
|
||
detail=f"Analysis type '{analysis_type}' not found in configuration"
|
||
)
|
||
|
||
model = analysis_cfg.get("model", "gemini-2.5-flash")
|
||
prompt_template = analysis_cfg.get("prompt_template", "")
|
||
|
||
if not prompt_template:
|
||
raise HTTPException(
|
||
status_code=500,
|
||
detail=f"Prompt template not found for analysis type '{analysis_type}'"
|
||
)
|
||
|
||
# Get company name from ts_code if not provided
|
||
financial_data = None
|
||
if not company_name:
|
||
logger.info(f"[API] Fetching company name and financial data for {ts_code}")
|
||
try:
|
||
token = (
|
||
os.environ.get("TUSHARE_TOKEN")
|
||
or settings.TUSHARE_TOKEN
|
||
or base_cfg.get("data_sources", {}).get("tushare", {}).get("api_key")
|
||
)
|
||
if token:
|
||
tushare_client = TushareClient(token=token)
|
||
basic_data = await tushare_client.query(api_name="stock_basic", params={"ts_code": ts_code}, fields="ts_code,name")
|
||
if basic_data and len(basic_data) > 0:
|
||
company_name = basic_data[0].get("name", ts_code)
|
||
logger.info(f"[API] Got company name: {company_name}")
|
||
|
||
# Try to get financial data for context
|
||
try:
|
||
fin_cfg = _load_json(FINANCIAL_CONFIG_PATH)
|
||
api_groups = fin_cfg.get("api_groups", {})
|
||
|
||
# Get financial data summary for context
|
||
series: Dict[str, List[Dict]] = {}
|
||
for group_name, metrics in api_groups.items():
|
||
if not metrics:
|
||
continue
|
||
api_groups_dict: Dict[str, List[Dict]] = {}
|
||
for metric in metrics:
|
||
api = metric.get("api") or group_name
|
||
if api:
|
||
if api not in api_groups_dict:
|
||
api_groups_dict[api] = []
|
||
api_groups_dict[api].append(metric)
|
||
|
||
for api_name, api_metrics in api_groups_dict.items():
|
||
fields = [m.get("tushareParam") for m in api_metrics if m.get("tushareParam")]
|
||
if not fields:
|
||
continue
|
||
|
||
date_field = "end_date" if group_name in ("fina_indicator", "income", "balancesheet", "cashflow") else "trade_date"
|
||
|
||
params = {"ts_code": ts_code, "limit": 500}
|
||
fields_list = list(fields)
|
||
if date_field not in fields_list:
|
||
fields_list.insert(0, date_field)
|
||
if api_name in ("fina_indicator", "income", "balancesheet", "cashflow"):
|
||
for req_field in ["ts_code", "ann_date"]:
|
||
if req_field not in fields_list:
|
||
fields_list.insert(0, req_field)
|
||
fields_str = ",".join(fields_list)
|
||
|
||
try:
|
||
data_rows = await tushare_client.query(api_name=api_name, params=params, fields=fields_str)
|
||
if data_rows:
|
||
# Get latest year's data
|
||
latest_row = data_rows[0] if data_rows else {}
|
||
for metric in api_metrics:
|
||
key = metric.get("tushareParam")
|
||
if key and key in latest_row:
|
||
if key not in series:
|
||
series[key] = []
|
||
series[key].append({
|
||
"year": latest_row.get(date_field, "")[:4] if latest_row.get(date_field) else str(datetime.now().year),
|
||
"value": latest_row.get(key)
|
||
})
|
||
except Exception:
|
||
pass
|
||
|
||
financial_data = {"series": series}
|
||
except Exception as e:
|
||
logger.warning(f"[API] Failed to get financial data: {e}")
|
||
financial_data = None
|
||
else:
|
||
company_name = ts_code
|
||
else:
|
||
company_name = ts_code
|
||
except Exception as e:
|
||
logger.warning(f"[API] Failed to get company name: {e}")
|
||
company_name = ts_code
|
||
|
||
logger.info(f"[API] Generating {analysis_type} for {company_name}")
|
||
|
||
# Initialize analysis client with configured model
|
||
client = AnalysisClient(api_key=api_key, base_url=base_url, model=model)
|
||
|
||
# Generate analysis
|
||
result = await client.generate_analysis(
|
||
analysis_type=analysis_type,
|
||
company_name=company_name,
|
||
ts_code=ts_code,
|
||
prompt_template=prompt_template,
|
||
financial_data=financial_data
|
||
)
|
||
|
||
logger.info(f"[API] Analysis generation completed, success={result.get('success')}")
|
||
|
||
return AnalysisResponse(
|
||
ts_code=ts_code,
|
||
company_name=company_name,
|
||
analysis_type=analysis_type,
|
||
content=result.get("content", ""),
|
||
model=result.get("model", model),
|
||
tokens=result.get("tokens", {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}),
|
||
elapsed_ms=result.get("elapsed_ms", 0),
|
||
success=result.get("success", False),
|
||
error=result.get("error")
|
||
)
|
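For reference, a minimal client-side sketch of how these endpoints might be exercised end to end. The base URL and the `/api/financial` mount prefix are assumptions (the real prefix depends on how `router` is included in the FastAPI app), `600519.SH` is only an example ts_code, and `httpx` is an assumed client dependency.

```python
# Hypothetical usage sketch for the financial router; adjust BASE to the actual mount point.
import asyncio

import httpx

BASE = "http://localhost:8000/api/financial"  # assumed prefix


async def main() -> None:
    async with httpx.AsyncClient(timeout=300) as client:
        # Metric configuration (api_groups from config/financial-tushare.json)
        cfg = await client.get(f"{BASE}/config")
        print(cfg.json())

        # Five years of financial series for a ts_code
        fin = await client.get(f"{BASE}/china/600519.SH", params={"years": 5})
        print(fin.json()["meta"])

        # A single analysis module, then the full orchestrated report
        one = await client.get(f"{BASE}/china/600519.SH/analysis/fundamental_analysis")
        print(one.json()["success"])

        full = await client.post(f"{BASE}/china/600519.SH/analysis")
        print([r["analysis_type"] for r in full.json()])


if __name__ == "__main__":
    asyncio.run(main())
```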