- Covered by data-persistence-service tests (db/api). - No references or compose entries.
189 lines
9.0 KiB
Python
189 lines
9.0 KiB
Python
"""
|
||
Configuration Management Service (file + service based; no direct DB)
|
||
"""
|
||
import json
import os
from typing import Any, Dict, Optional

import asyncpg
import httpx

from app.core.config import settings
from app.schemas.config import (
    ConfigResponse,
    ConfigUpdateRequest,
    DatabaseConfig,
    NewApiConfig,
    DataSourceConfig,
    ConfigTestResponse,
)
|
||
|
||
class ConfigManager:
    """Manage system configuration.

    Reads come from the config service (``GET {CONFIG_SERVICE_BASE_URL}/system``);
    updates are deep-merged into a local JSON file at ``self.config_path``.
    ``test_config`` probes connectivity for a single kind of setting
    (database, New API, Tushare, Finnhub) without persisting anything.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Bind the manager to the local JSON file that ``update_config`` writes.

        Args:
            config_path: Target file. Defaults to ``config/config.json`` under the
                directory three levels above this module's directory.
        """
        if config_path is None:
            project_root = os.path.abspath(
                os.path.join(os.path.dirname(__file__), "..", "..", "..")
            )
            config_path = os.path.join(project_root, "config", "config.json")
        self.config_path = config_path

    async def _fetch_base_config_from_service(self) -> Dict[str, Any]:
        """Fetch the system configuration document from the config service.

        Returns:
            The decoded JSON object from ``GET {CONFIG_SERVICE_BASE_URL}/system``.

        Raises:
            httpx.HTTPStatusError: The service answered with a 4xx/5xx status.
            httpx.HTTPError: Transport failure, including the 10 s timeout.
            ValueError: The body is not valid JSON, or is not a JSON object.
        """
        base_url = settings.CONFIG_SERVICE_BASE_URL.rstrip("/")
        url = f"{base_url}/system"
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(url)
            resp.raise_for_status()
            data = resp.json()
        if not isinstance(data, dict):
            raise ValueError("Config service 返回的系统配置格式错误")
        return data

    def _merge_configs(self, base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
        """Recursively merge ``overrides`` into ``base`` and return ``base``.

        Dicts present on both sides are merged key by key; any other value in
        ``overrides`` replaces the one in ``base``. ``base`` is mutated in place.
        """
        for key, value in overrides.items():
            if isinstance(value, dict) and isinstance(base.get(key), dict):
                base[key] = self._merge_configs(base[key], value)
            else:
                base[key] = value
        return base

    async def get_config(self) -> ConfigResponse:
        """Return the current configuration as served by the config service.

        Sections that are missing or ``null`` in the service response are
        treated as empty, so each schema model is built from whatever keys the
        service actually provided.

        Raises:
            Anything ``_fetch_base_config_from_service`` raises, plus validation
            errors from the schema models.
        """
        base_config = await self._fetch_base_config_from_service()

        # The New API settings may sit at the top level ("new_api") or nested
        # under "llm.new_api"; the top-level location wins when non-empty.
        llm_section = base_config.get("llm") or {}
        new_api_src = base_config.get("new_api") or llm_section.get("new_api") or {}

        return ConfigResponse(
            database=DatabaseConfig(**(base_config.get("database") or {})),
            new_api=NewApiConfig(**new_api_src),
            data_sources={
                name: DataSourceConfig(**(source or {}))
                for name, source in (base_config.get("data_sources") or {}).items()
            },
        )

    async def update_config(self, config_update: ConfigUpdateRequest) -> ConfigResponse:
        """Validate a partial update, merge it into the local file, return the new view.

        Only fields explicitly set on ``config_update`` are written; nested
        sections are deep-merged with what the file already contains.

        Args:
            config_update: The partial configuration to apply.

        Returns:
            The configuration as reported by ``get_config`` after the write.

        Raises:
            ValueError: A field fails ``_validate_config_data``, the existing file
                is not valid JSON, or its top level is not a JSON object.
            TypeError: The merged configuration is not JSON-serializable; the
                file is left untouched in that case.
            OSError: The file or its directory cannot be read or written.
        """
        # Pydantic v2 deprecates ``dict()`` in favour of ``model_dump()``; accept either.
        dump = getattr(config_update, "model_dump", None) or config_update.dict
        update_dict = dump(exclude_unset=True)
        self._validate_config_data(update_dict)

        current: Dict[str, Any] = {}
        if os.path.exists(self.config_path):
            with open(self.config_path, "r", encoding="utf-8") as f:
                text = f.read()
            # An empty / whitespace-only file counts as an empty config rather
            # than a parse error, so a blank file cannot block every update.
            current = (json.loads(text) if text.strip() else None) or {}
        if not isinstance(current, dict):
            raise ValueError(f"配置文件 {self.config_path} 的顶层结构必须是 JSON 对象")

        merged = self._merge_configs(current, update_dict)

        # Serialize BEFORE opening for write: "w" truncates immediately, so a
        # serialization error during a streaming dump would leave a partial file.
        serialized = json.dumps(merged, ensure_ascii=False, indent=2)
        config_dir = os.path.dirname(self.config_path)
        if config_dir:
            os.makedirs(config_dir, exist_ok=True)
        with open(self.config_path, "w", encoding="utf-8") as f:
            f.write(serialized)

        # Build the returned view exactly like get_config (one read from the
        # config service) so callers never see two diverging sources.
        # NOTE(review): this reflects the write above only if the config service
        # serves this same config.json - confirm.
        return await self.get_config()

    def _validate_config_data(self, config_data: Dict[str, Any]) -> None:
        """Reject obviously malformed values in a partial config update.

        Sections and fields that are absent or explicitly ``None`` are skipped;
        only values that are actually present are checked.

        Raises:
            ValueError: The database URL is not a PostgreSQL URL, an API key is
                shorter than 10 characters, or the New API base URL is not HTTP(S).
        """
        db_config = config_data.get("database") or {}
        url = db_config.get("url")
        if url is not None and not url.startswith(("postgresql://", "postgresql+asyncpg://")):
            raise ValueError("数据库URL必须以 postgresql:// 或 postgresql+asyncpg:// 开头")

        new_api_config = config_data.get("new_api") or {}
        api_key = new_api_config.get("api_key")
        # An empty-string key is still rejected; only None means "not provided".
        if api_key is not None and len(api_key) < 10:
            raise ValueError("New API Key长度不能少于10个字符")
        base_url = new_api_config.get("base_url")
        # A falsy base_url (None or "") is not validated.
        if base_url and not base_url.startswith(("http://", "https://")):
            raise ValueError("New API Base URL必须以 http:// 或 https:// 开头")

        for source_name, source_config in (config_data.get("data_sources") or {}).items():
            source_key = (source_config or {}).get("api_key")
            if source_key is not None and len(source_key) < 10:
                raise ValueError(f"{source_name} API Key长度不能少于10个字符")

    async def test_config(self, config_type: str, config_data: Dict[str, Any]) -> ConfigTestResponse:
        """Probe connectivity for one kind of setting without persisting it.

        Args:
            config_type: One of ``"database"``, ``"new_api"``, ``"tushare"``,
                ``"finnhub"``.
            config_data: Candidate settings for that kind (the probes read
                ``url``, ``api_key`` and/or ``base_url``).

        Returns:
            A ``ConfigTestResponse``. Failures, including unexpected exceptions
            and unsupported types, are reported with ``success=False``, not raised.
        """
        testers = {
            "database": self._test_database,
            "new_api": self._test_new_api,
            "tushare": self._test_tushare,
            "finnhub": self._test_finnhub,
        }
        tester = testers.get(config_type)
        if tester is None:
            return ConfigTestResponse(success=False, message=f"不支持的配置类型: {config_type}")
        try:
            return await tester(config_data)
        except Exception as e:
            return ConfigTestResponse(success=False, message=f"测试失败: {str(e)}")

    async def _test_database(self, config_data: Dict[str, Any]) -> ConfigTestResponse:
        """Open and immediately close a PostgreSQL connection to ``config_data["url"]``."""
        db_url = config_data.get("url")
        if not db_url:
            return ConfigTestResponse(success=False, message="数据库URL不能为空")
        prefix = "postgresql+asyncpg://"
        try:
            if db_url.startswith(prefix):
                # asyncpg only accepts postgresql:// / postgres:// DSNs, so strip
                # the SQLAlchemy "+asyncpg" driver marker (prefix only).
                db_url = "postgresql://" + db_url[len(prefix):]
            # Same 10 s budget as the HTTP probes (asyncpg's default is 60 s).
            conn = await asyncpg.connect(db_url, timeout=10.0)
            await conn.close()
            return ConfigTestResponse(success=True, message="数据库连接成功")
        except Exception as e:
            return ConfigTestResponse(success=False, message=f"数据库连接失败: {str(e)}")

    async def _test_new_api(self, config_data: Dict[str, Any]) -> ConfigTestResponse:
        """Call ``GET {base_url}/models`` with a Bearer token; HTTP 200 means success."""
        api_key = config_data.get("api_key")
        base_url = config_data.get("base_url")
        if not api_key or not base_url:
            return ConfigTestResponse(success=False, message="New API Key和Base URL均不能为空")
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                response = await client.get(
                    f"{base_url.rstrip('/')}/models",
                    headers={"Authorization": f"Bearer {api_key}"},
                )
            if response.status_code == 200:
                return ConfigTestResponse(success=True, message="New API连接成功")
            return ConfigTestResponse(
                success=False,
                message=f"New API测试失败: HTTP {response.status_code} - {response.text}",
            )
        except Exception as e:
            return ConfigTestResponse(success=False, message=f"New API连接失败: {str(e)}")

    async def _test_tushare(self, config_data: Dict[str, Any]) -> ConfigTestResponse:
        """Issue a minimal ``stock_basic`` query; success requires HTTP 200 and ``code == 0``."""
        api_key = config_data.get("api_key")
        if not api_key:
            return ConfigTestResponse(success=False, message="Tushare API Key不能为空")
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                # NOTE(review): the token travels over plain HTTP here; check
                # whether the Tushare endpoint supports HTTPS before switching.
                response = await client.post(
                    "http://api.tushare.pro",
                    json={
                        "api_name": "stock_basic",
                        "token": api_key,
                        "params": {"list_status": "L"},
                        "fields": "ts_code",
                    },
                )
            if response.status_code != 200:
                return ConfigTestResponse(success=False, message=f"Tushare API测试失败: HTTP {response.status_code}")
            data = response.json()
            if data.get("code") == 0:
                return ConfigTestResponse(success=True, message="Tushare API连接成功")
            return ConfigTestResponse(success=False, message=f"Tushare API错误: {data.get('msg', '未知错误')}")
        except Exception as e:
            return ConfigTestResponse(success=False, message=f"Tushare API连接失败: {str(e)}")

    async def _test_finnhub(self, config_data: Dict[str, Any]) -> ConfigTestResponse:
        """Fetch an AAPL quote; success requires HTTP 200 and a ``"c"`` key in the body."""
        api_key = config_data.get("api_key")
        if not api_key:
            return ConfigTestResponse(success=False, message="Finnhub API Key不能为空")
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                response = await client.get(
                    "https://finnhub.io/api/v1/quote",
                    params={"symbol": "AAPL", "token": api_key},
                )
            if response.status_code != 200:
                return ConfigTestResponse(success=False, message=f"Finnhub API测试失败: HTTP {response.status_code}")
            data = response.json()
            if "c" in data:
                return ConfigTestResponse(success=True, message="Finnhub API连接成功")
            return ConfigTestResponse(success=False, message="Finnhub API响应格式错误")
        except Exception as e:
            return ConfigTestResponse(success=False, message=f"Finnhub API连接失败: {str(e)}")
|