Fundamental_Analysis/backend/app/services/config_manager.py
xucheng a79efd8150 feat: Enhance configuration management with new LLM provider support and API integration
- Backend: Introduced new endpoints for LLM configuration retrieval and updates in `config.py`, allowing dynamic management of LLM provider settings.
- Updated schemas to include `AlphaEngineConfig` for better integration with the new provider.
- Frontend: Added state management for AlphaEngine API credentials in the configuration page, ensuring seamless user experience.
- Configuration files updated to reflect changes in LLM provider settings and API keys.

BREAKING CHANGE: The default LLM provider has been changed from `new_api` to `alpha_engine`, requiring updates to existing configurations.
2025-11-11 20:49:27 +08:00

459 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Configuration Management Service
"""
import asyncio
import copy
import json
import os
from typing import Any, Dict, Optional

import asyncpg
import httpx
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.models.system_config import SystemConfig
from app.schemas.config import (
    AlphaEngineConfig,
    ConfigResponse,
    ConfigTestResponse,
    ConfigUpdateRequest,
    DatabaseConfig,
    DataSourceConfig,
    NewApiConfig,
)
class ConfigManager:
    """Manages system configuration by merging a static JSON file with dynamic settings from the database.

    The JSON file provides the base configuration; each ``SystemConfig`` row
    (``config_key`` -> ``config_value``) is deep-merged on top of it as an override.
    LLM provider settings are written under the ``llm`` section; when reading,
    legacy top-level provider keys are still honoured as a fallback.
    """

    # Update-request sections that are persisted under the "llm" row rather than
    # in a row of their own.
    _LLM_PROVIDER_KEYS = ("new_api", "alpha_engine")

    def __init__(self, db_session: AsyncSession, config_path: Optional[str] = None):
        """Create a manager bound to ``db_session``.

        Args:
            db_session: Async SQLAlchemy session used to read/write ``SystemConfig`` rows.
            config_path: Path of the base JSON configuration. Defaults to
                ``<project_root>/config/config.json``.
        """
        self.db = db_session
        if config_path is None:
            # __file__ is backend/app/services/config_manager.py; three ".." hops from its
            # directory (services -> app -> backend -> root) reach the project root.
            project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
            self.config_path = os.path.join(project_root, "config", "config.json")
        else:
            self.config_path = config_path

    def _load_base_config_from_file(self) -> Dict[str, Any]:
        """Load the base configuration from the JSON file.

        Returns an empty dict when the file is missing, unreadable, not valid
        UTF-8/JSON, or does not contain a JSON object at the top level.
        """
        if not os.path.exists(self.config_path):
            return {}
        try:
            with open(self.config_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError):
            # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
            return {}
        # Callers use dict methods on the result, so reject e.g. a top-level JSON list.
        return data if isinstance(data, dict) else {}

    async def _load_dynamic_config_from_db(self) -> Dict[str, Any]:
        """Load dynamic configuration overrides from the database.

        Degrades gracefully to "no overrides" when the table has not been created
        yet (e.g. migrations were not run in a development environment) or any
        other database error occurs, so the API does not fail with a 500.
        """
        try:
            result = await self.db.execute(select(SystemConfig))
            return {record.config_key: record.config_value for record in result.scalars().all()}
        except Exception:
            # A failed statement can leave the session's transaction unusable (e.g. on
            # PostgreSQL later statements are rejected until a rollback), so reset it
            # before swallowing the error. The rollback itself is best-effort.
            try:
                await self.db.rollback()
            except Exception:
                pass
            return {}

    async def _load_merged_config(self) -> Dict[str, Any]:
        """Return the JSON file configuration deep-merged with the DB overrides."""
        base_config = self._load_base_config_from_file()
        db_config = await self._load_dynamic_config_from_db()
        return self._merge_configs(base_config, db_config)

    def _merge_configs(self, base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
        """Deep-merge ``overrides`` into ``base`` and return the result as a new dict.

        Nested dicts are merged recursively; any other override value replaces the
        base value. Neither input is modified and the result shares no mutable
        objects with them. This matters for ORM-loaded JSON values: mutating them in
        place and re-assigning the same object is not seen as a change by SQLAlchemy
        (assuming ``config_value`` has no ``sqlalchemy.ext.mutable`` tracking), so the
        update would silently not be written.
        """
        merged = copy.deepcopy(base)
        for key, value in overrides.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                merged[key] = self._merge_configs(merged[key], value)
            else:
                merged[key] = copy.deepcopy(value)
        return merged

    def _resolve_provider_config(self, merged_config: Dict[str, Any], name: str) -> Dict[str, Any]:
        """Return the effective settings for the LLM provider ``name``.

        Settings may live at the legacy top-level key ``name`` and/or under
        ``llm.<name>`` (where ``update_config`` saves them). Both are deep-merged,
        with ``llm.<name>`` taking precedence so saved values are never shadowed by
        the legacy location. Non-dict values are ignored; returns {} if nothing is set.
        """
        legacy = merged_config.get(name)
        llm_section = merged_config.get("llm")
        current = llm_section.get(name) if isinstance(llm_section, dict) else None
        return self._merge_configs(
            legacy if isinstance(legacy, dict) else {},
            current if isinstance(current, dict) else {},
        )

    async def get_config(self) -> ConfigResponse:
        """Return the final configuration (JSON file deep-merged with DB overrides).

        ``alpha_engine`` is None when no AlphaEngine settings exist in either location.
        """
        merged_config = await self._load_merged_config()
        alpha_engine_src = self._resolve_provider_config(merged_config, "alpha_engine")
        data_sources = merged_config.get("data_sources") or {}
        return ConfigResponse(
            database=DatabaseConfig(**(merged_config.get("database") or {})),
            new_api=NewApiConfig(**self._resolve_provider_config(merged_config, "new_api")),
            alpha_engine=AlphaEngineConfig(**alpha_engine_src) if alpha_engine_src else None,
            data_sources={k: DataSourceConfig(**v) for k, v in data_sources.items()},
        )

    async def get_llm_config(self, provider: Optional[str] = None) -> Dict[str, Any]:
        """
        Get LLM configuration for a specific provider.

        Args:
            provider: Provider name (e.g., "new_api", "gemini", "alpha_engine").
                If None/empty, uses ``llm.provider`` from the configuration,
                falling back to "new_api".

        Returns:
            Dict with "provider" (the resolved name), "config" (the provider's
            settings, {} if none) and "model" (``config["model"]`` or None, the
            provider-wide model setting).
        """
        merged_config = await self._load_merged_config()
        if not provider:
            llm_config = merged_config.get("llm")
            provider = (llm_config.get("provider") if isinstance(llm_config, dict) else None) or "new_api"
        provider_config = self._resolve_provider_config(merged_config, provider)
        return {
            "provider": provider,
            "config": provider_config,
            "model": provider_config.get("model"),
        }

    def _filter_empty_values(self, config_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Remove empty strings and None values from config dict, but keep 0 and False.

        Applied recursively to nested dicts; nested dicts that end up empty are dropped.
        """
        filtered = {}
        for key, value in config_dict.items():
            if isinstance(value, dict):
                filtered_value = self._filter_empty_values(value)
                if filtered_value:
                    filtered[key] = filtered_value
            elif value is not None and value != "":
                filtered[key] = value
        return filtered

    async def _upsert_config_value(self, key: str, value: Any) -> None:
        """Stage ``value`` as the DB override for ``key`` (the caller commits).

        If a row exists and both the stored and new values are dicts, the new value
        is deep-merged into the stored one; otherwise the stored value is replaced.
        A missing row is created holding just ``value``.
        """
        result = await self.db.execute(
            select(SystemConfig).where(SystemConfig.config_key == key)
        )
        existing = result.scalar_one_or_none()
        if existing is None:
            self.db.add(SystemConfig(config_key=key, config_value=value))
        elif isinstance(existing.config_value, dict) and isinstance(value, dict):
            # _merge_configs returns a new object, so the ORM registers the change.
            existing.config_value = self._merge_configs(existing.config_value, value)
        else:
            existing.config_value = value

    async def update_config(self, config_update: ConfigUpdateRequest) -> ConfigResponse:
        """Persist the given settings as DB overrides and return the new merged config.

        Only fields explicitly set on ``config_update`` are used, and empty strings /
        None are dropped so they never blank out existing values. LLM provider
        sections ("new_api", "alpha_engine") are stored inside the "llm" row; every
        other top-level section gets its own row. Only overrides are stored; file
        values stay authoritative for anything not overridden.

        Raises:
            ValueError: If the data fails ``_validate_config_data``.
            Any other exception is re-raised after rolling back the session.
        """
        try:
            update_dict = self._filter_empty_values(config_update.dict(exclude_unset=True))
            self._validate_config_data(update_dict)

            llm_updates = {}
            for key in self._LLM_PROVIDER_KEYS:
                if key in update_dict:
                    llm_updates[key] = update_dict.pop(key)
            if llm_updates:
                await self._upsert_config_value("llm", llm_updates)

            # Remaining sections (database, data_sources, ...).
            for key, value in update_dict.items():
                await self._upsert_config_value(key, value)

            await self.db.commit()
            return await self.get_config()
        except Exception:
            await self.db.rollback()
            raise

    def _validate_config_data(self, config_data: Dict[str, Any]) -> None:
        """Validate configuration data before saving.

        Raises:
            ValueError: With a user-facing message for the first invalid field found.
        """
        database = config_data.get("database")
        if database:
            url = database.get("url")
            if url and not url.startswith(("postgresql://", "postgresql+asyncpg://")):
                raise ValueError("数据库URL必须以 postgresql:// 或 postgresql+asyncpg:// 开头")

        new_api = config_data.get("new_api")
        if new_api:
            api_key = new_api.get("api_key")
            if api_key and len(api_key) < 10:
                raise ValueError("New API Key长度不能少于10个字符")
            base_url = new_api.get("base_url")
            if base_url and not base_url.startswith(("http://", "https://")):
                raise ValueError("New API Base URL必须以 http:// 或 https:// 开头")

        alpha_engine = config_data.get("alpha_engine")
        if alpha_engine:
            api_key = alpha_engine.get("api_key")
            if api_key and len(api_key) < 5:
                raise ValueError("AlphaEngine API Key长度不能少于5个字符")
            api_url = alpha_engine.get("api_url")
            if api_url and not api_url.startswith(("http://", "https://")):
                raise ValueError("AlphaEngine API URL必须以 http:// 或 https:// 开头")

        for source_name, source_config in (config_data.get("data_sources") or {}).items():
            api_key = source_config.get("api_key")
            if api_key and len(api_key) < 10:
                raise ValueError(f"{source_name} API Key长度不能少于10个字符")

    async def test_config(self, config_type: str, config_data: Dict[str, Any]) -> ConfigTestResponse:
        """Test a specific configuration by contacting the corresponding service.

        Args:
            config_type: One of "database", "new_api", "tushare", "finnhub", "alpha_engine".
            config_data: The settings to test.

        Returns:
            A ConfigTestResponse. Unsupported types and unexpected errors raised by a
            tester are reported as ``success=False`` rather than propagated.
        """
        testers = {
            "database": self._test_database,
            "new_api": self._test_new_api,
            "tushare": self._test_tushare,
            "finnhub": self._test_finnhub,
            "alpha_engine": self._test_alpha_engine,
        }
        try:
            tester = testers.get(config_type)
            if tester is None:
                return ConfigTestResponse(
                    success=False,
                    message=f"不支持的配置类型: {config_type}"
                )
            return await tester(config_data)
        except Exception as e:
            return ConfigTestResponse(
                success=False,
                message=f"测试失败: {str(e)}"
            )

    async def _test_database(self, config_data: Dict[str, Any]) -> ConfigTestResponse:
        """Test the database connection by opening and closing one asyncpg connection."""
        db_url = config_data.get("url")
        if not db_url:
            return ConfigTestResponse(
                success=False,
                message="数据库URL不能为空"
            )
        try:
            # asyncpg expects a plain postgresql:// DSN; strip the SQLAlchemy driver
            # marker from the scheme only (not from anywhere else in the URL).
            asyncpg_prefix = "postgresql+asyncpg://"
            if db_url.startswith(asyncpg_prefix):
                db_url = "postgresql://" + db_url[len(asyncpg_prefix):]
            # Bound the wait like the HTTP tests do (asyncpg's own default is 60 s).
            conn = await asyncpg.connect(db_url, timeout=10.0)
            await conn.close()
            return ConfigTestResponse(
                success=True,
                message="数据库连接成功"
            )
        except Exception as e:
            return ConfigTestResponse(
                success=False,
                message=f"数据库连接失败: {str(e)}"
            )

    async def _test_new_api(self, config_data: Dict[str, Any]) -> ConfigTestResponse:
        """Test New API (OpenAI-compatible) connection by listing models at ``{base_url}/models``."""
        api_key = config_data.get("api_key")
        base_url = config_data.get("base_url")
        if not api_key or not base_url:
            return ConfigTestResponse(
                success=False,
                message="New API Key和Base URL均不能为空"
            )
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                response = await client.get(
                    f"{base_url.rstrip('/')}/models",
                    headers={"Authorization": f"Bearer {api_key}"}
                )
            if response.status_code == 200:
                return ConfigTestResponse(
                    success=True,
                    message="New API连接成功"
                )
            # Truncate the body (as the AlphaEngine test does) so large error pages
            # do not flood the UI.
            return ConfigTestResponse(
                success=False,
                message=f"New API测试失败: HTTP {response.status_code} - {response.text[:200]}"
            )
        except Exception as e:
            return ConfigTestResponse(
                success=False,
                message=f"New API连接失败: {str(e)}"
            )

    async def _test_tushare(self, config_data: Dict[str, Any]) -> ConfigTestResponse:
        """Test the Tushare API by issuing a small ``stock_basic`` query with the token."""
        api_key = config_data.get("api_key")
        if not api_key:
            return ConfigTestResponse(
                success=False,
                message="Tushare API Key不能为空"
            )
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                response = await client.post(
                    "http://api.tushare.pro",
                    json={
                        "api_name": "stock_basic",
                        "token": api_key,
                        "params": {"list_status": "L"},
                        "fields": "ts_code"
                    }
                )
            if response.status_code != 200:
                return ConfigTestResponse(
                    success=False,
                    message=f"Tushare API测试失败: HTTP {response.status_code}"
                )
            data = response.json()
            # Tushare signals success with code == 0 in the JSON body.
            if data.get("code") == 0:
                return ConfigTestResponse(
                    success=True,
                    message="Tushare API连接成功"
                )
            return ConfigTestResponse(
                success=False,
                message=f"Tushare API错误: {data.get('msg', '未知错误')}"
            )
        except Exception as e:
            return ConfigTestResponse(
                success=False,
                message=f"Tushare API连接失败: {str(e)}"
            )

    async def _test_finnhub(self, config_data: Dict[str, Any]) -> ConfigTestResponse:
        """Test the Finnhub API by requesting a quote for AAPL."""
        api_key = config_data.get("api_key")
        if not api_key:
            return ConfigTestResponse(
                success=False,
                message="Finnhub API Key不能为空"
            )
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                response = await client.get(
                    "https://finnhub.io/api/v1/quote",
                    params={"symbol": "AAPL", "token": api_key}
                )
            if response.status_code != 200:
                return ConfigTestResponse(
                    success=False,
                    message=f"Finnhub API测试失败: HTTP {response.status_code}"
                )
            data = response.json()
            # A valid quote response carries the current price under "c".
            if isinstance(data, dict) and "c" in data:
                return ConfigTestResponse(
                    success=True,
                    message="Finnhub API连接成功"
                )
            return ConfigTestResponse(
                success=False,
                message="Finnhub API响应格式错误"
            )
        except Exception as e:
            return ConfigTestResponse(
                success=False,
                message=f"Finnhub API连接失败: {str(e)}"
            )

    async def _test_alpha_engine(self, config_data: Dict[str, Any]) -> ConfigTestResponse:
        """Test the AlphaEngine API by posting a short chat request to ``/api/v3/finchat``.

        Success is decided from the HTTP status alone. The request is streamed and the
        body is only read for error responses, so the test does not wait for the
        model to finish generating an answer.
        """
        api_url = config_data.get("api_url")
        api_key = config_data.get("api_key")
        token = config_data.get("token")
        if not api_url or not api_key or not token:
            return ConfigTestResponse(
                success=False,
                message="AlphaEngine API URL、API Key和Token均不能为空"
            )
        headers = {
            'token': token,
            'X-API-KEY': api_key,
            'Content-Type': 'application/json'
        }
        payload = {
            "msg": "测试连接",
            "history": [],
            "user_id": config_data.get("user_id", 999041),
            "model": config_data.get("model", "deepseek-r1"),
            "using_indicator": config_data.get("using_indicator", True),
            "start_time": config_data.get("start_time", "2024-01-01"),
            "doc_show_type": config_data.get("doc_show_type", ["A001", "A002", "A003", "A004"]),
            "simple_tracking": config_data.get("simple_tracking", True)
        }
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                async with client.stream(
                    "POST",
                    f"{api_url.rstrip('/')}/api/v3/finchat",
                    json=payload,
                    headers=headers
                ) as response:
                    if response.status_code == 200:
                        return ConfigTestResponse(
                            success=True,
                            message="AlphaEngine API连接成功"
                        )
                    # Error bodies are needed for the message; load them explicitly.
                    await response.aread()
                    return ConfigTestResponse(
                        success=False,
                        message=f"AlphaEngine API测试失败: HTTP {response.status_code} - {response.text[:200]}"
                    )
        except Exception as e:
            return ConfigTestResponse(
                success=False,
                message=f"AlphaEngine API连接失败: {str(e)}"
            )