260 lines
9.8 KiB
Python
260 lines
9.8 KiB
Python
"""
|
|
配置管理服务
|
|
处理系统配置的读取、更新和验证
|
|
"""
|
|
|
|
from typing import Dict, Any, Optional
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select
|
|
import httpx
|
|
import asyncio
|
|
|
|
from ..models.system_config import SystemConfig
|
|
from ..schemas.config import ConfigResponse, ConfigUpdateRequest, ConfigTestResponse
|
|
from ..core.exceptions import ConfigurationError, DatabaseError, APIError
|
|
|
|
|
|
class ConfigManager:
    """Read, update and connection-test system configuration.

    Configuration is persisted as ``SystemConfig`` rows, one row per section,
    keyed by ``config_key`` ("database", "gemini_api", "data_sources") with the
    section's dict stored in ``config_value``.

    Args:
        db_session: Async SQLAlchemy session used for every query.
            ``update_config`` commits (or rolls back) this session.
    """

    def __init__(self, db_session: AsyncSession):
        self.db = db_session

    async def get_config(self) -> ConfigResponse:
        """Load every stored configuration section.

        Returns:
            ConfigResponse whose ``database`` / ``gemini_api`` are the stored
            values (``None`` when no row exists) and whose ``data_sources``
            defaults to ``{}`` when no row exists.

        Raises:
            DatabaseError: if querying or building the response fails.
        """
        try:
            result = await self.db.execute(select(SystemConfig))
            configs = result.scalars().all()

            # Flatten the rows into {config_key: config_value}.
            config_dict = {config.config_key: config.config_value for config in configs}

            return ConfigResponse(
                database=config_dict.get("database"),
                gemini_api=config_dict.get("gemini_api"),
                data_sources=config_dict.get("data_sources", {}),
            )
        except Exception as e:
            raise DatabaseError(f"获取配置失败: {str(e)}", "get_config") from e

    async def update_config(self, config_update: ConfigUpdateRequest) -> ConfigResponse:
        """Upsert the sections present in *config_update*, commit, and re-read.

        Only truthy sections of the request are written; omitted sections keep
        their stored values.

        Returns:
            The full configuration as read back after the commit.

        Raises:
            DatabaseError: if writing/committing fails (the session is rolled
                back), or if re-reading the configuration after a successful
                commit fails.
        """
        try:
            if config_update.database:
                await self._update_config_item("database", config_update.database.dict())

            if config_update.gemini_api:
                await self._update_config_item("gemini_api", config_update.gemini_api.dict())

            if config_update.data_sources:
                data_sources_dict = {k: v.dict() for k, v in config_update.data_sources.items()}
                await self._update_config_item("data_sources", data_sources_dict)

            await self.db.commit()
        except Exception as e:
            await self.db.rollback()
            raise DatabaseError(f"更新配置失败: {str(e)}", "update_config") from e

        # Re-read outside the write transaction: the update is already committed,
        # so a read failure must not be reported as "update failed" nor trigger
        # a (meaningless) rollback. get_config raises its own DatabaseError.
        return await self.get_config()

    async def test_config(self, config_type: str, config_data: Dict[str, Any]) -> ConfigTestResponse:
        """Dispatch a connectivity test for one configuration type.

        Args:
            config_type: "database", "gemini" or "data_source".
            config_data: Settings to test; schema depends on *config_type*.

        Returns:
            ConfigTestResponse; never raises. Unknown types and unexpected
            errors are reported with ``success=False``.
        """
        try:
            if config_type == "database":
                return await self._test_database_config(config_data)
            elif config_type == "gemini":
                return await self._test_gemini_config(config_data)
            elif config_type == "data_source":
                return await self._test_data_source_config(config_data)
            else:
                return ConfigTestResponse(
                    success=False,
                    message=f"不支持的配置类型: {config_type}",
                )
        except Exception as e:
            return ConfigTestResponse(
                success=False,
                message=f"配置测试失败: {str(e)}",
            )

    async def _update_config_item(self, key: str, value: Dict[str, Any]):
        """Insert or replace the row for *key* (no commit; caller commits)."""
        result = await self.db.execute(
            select(SystemConfig).where(SystemConfig.config_key == key)
        )
        config = result.scalar_one_or_none()

        if config is not None:
            # Assign a new object so the ORM detects the change.
            config.config_value = value
        else:
            config = SystemConfig(config_key=key, config_value=value)
            self.db.add(config)

    async def _test_database_config(self, config_data: Dict[str, Any]) -> ConfigTestResponse:
        """Open a throw-away engine on ``config_data["url"]`` and run ``SELECT 1``.

        Returns:
            ConfigTestResponse; failures are reported, never raised.
        """
        try:
            # Imported lazily: only needed when a connection test is requested.
            from sqlalchemy import text
            from sqlalchemy.ext.asyncio import create_async_engine

            db_url = config_data.get("url")
            if not db_url:
                return ConfigTestResponse(
                    success=False,
                    message="数据库URL未配置",
                )

            test_engine = create_async_engine(db_url, echo=False)
            try:
                async with test_engine.begin() as conn:
                    # SQLAlchemy 2.x refuses plain SQL strings; wrap in text().
                    await conn.execute(text("SELECT 1"))
            finally:
                # Release the temporary pool even when the connection fails.
                await test_engine.dispose()

            return ConfigTestResponse(
                success=True,
                message="数据库连接测试成功",
            )
        except Exception as e:
            return ConfigTestResponse(
                success=False,
                message=f"数据库连接测试失败: {str(e)}",
            )

    async def _test_gemini_config(self, config_data: Dict[str, Any]) -> ConfigTestResponse:
        """Check that a Gemini API key is present.

        The remote call is not implemented yet. Any non-empty key is reported
        as a successful test after a short simulated delay.
        """
        try:
            api_key = config_data.get("api_key")
            if not api_key:
                return ConfigTestResponse(
                    success=False,
                    message="Gemini API密钥未配置",
                )

            # TODO: call a real Gemini endpoint to validate the key.
            # NOTE(review): Gemini REST keys are normally sent via the
            # `x-goog-api-key` header (or `key` query param), not as a Bearer
            # token. Confirm when implementing.
            await asyncio.sleep(0.1)  # simulated network latency

            return ConfigTestResponse(
                success=True,
                message="Gemini API连接测试成功",
            )
        except Exception as e:
            return ConfigTestResponse(
                success=False,
                message=f"Gemini API连接测试失败: {str(e)}",
            )

    async def get_data_source_config(self, market: str) -> Dict[str, Any]:
        """Return the data-source settings to use for *market*.

        "china" (case-insensitive) maps to the "tushare" entry. Every other
        market maps to the "yahoo" entry.

        Raises:
            ConfigurationError: if the data_sources row or the required entry
                is missing, or on any other failure.
        """
        try:
            result = await self.db.execute(
                select(SystemConfig).where(SystemConfig.config_key == "data_sources")
            )
            config = result.scalar_one_or_none()

            if config is None:
                raise ConfigurationError("数据源配置未找到", "data_sources")

            data_sources = config.config_value

            if market.lower() == "china":
                if "tushare" in data_sources:
                    return data_sources["tushare"]
                raise ConfigurationError("中国市场数据源(Tushare)未配置", "tushare")

            # All non-China markets are served by Yahoo Finance.
            if "yahoo" in data_sources:
                return data_sources["yahoo"]
            raise ConfigurationError("国际市场数据源(Yahoo)未配置", "yahoo")
        except ConfigurationError:
            raise
        except Exception as e:
            raise ConfigurationError(f"获取数据源配置失败: {str(e)}", "data_sources") from e

    async def get_gemini_config(self) -> Dict[str, Any]:
        """Return the stored Gemini API settings.

        Raises:
            ConfigurationError: if the gemini_api row is missing, it has no
                truthy "api_key", or any other failure occurs.
        """
        try:
            result = await self.db.execute(
                select(SystemConfig).where(SystemConfig.config_key == "gemini_api")
            )
            config = result.scalar_one_or_none()

            if config is None:
                raise ConfigurationError("Gemini API配置未找到", "gemini_api")

            gemini_config = config.config_value

            if not gemini_config.get("api_key"):
                raise ConfigurationError("Gemini API密钥未配置", "gemini_api")

            return gemini_config
        except ConfigurationError:
            raise
        except Exception as e:
            raise ConfigurationError(f"获取Gemini配置失败: {str(e)}", "gemini_api") from e

    async def _test_data_source_config(self, config_data: Dict[str, Any]) -> ConfigTestResponse:
        """Test one data source described by *config_data*.

        Behavior by source:
        - "tushare" requires "api_key"; the remote check is still simulated.
        - "yahoo" GETs ``{base_url}/health`` when "base_url" is set and
          expects HTTP 200.
        - Any other name, or "yahoo" without a base_url, reports success
          without a remote check.

        Returns:
            ConfigTestResponse; failures are reported, never raised.
        """
        try:
            name = config_data.get("name")
            api_key = config_data.get("api_key")
            base_url = config_data.get("base_url")
            timeout = config_data.get("timeout", 30)

            if not name:
                return ConfigTestResponse(
                    success=False,
                    message="数据源名称未配置",
                )

            source = name.lower()
            if source == "tushare":
                if not api_key:
                    return ConfigTestResponse(
                        success=False,
                        message="Tushare API密钥未配置",
                    )
                # TODO: call the real Tushare API; currently simulated.
                await asyncio.sleep(0.1)
            elif source == "yahoo" and base_url:
                # Strip trailing slashes so "http://host/" does not become
                # "http://host//health".
                health_url = f"{base_url.rstrip('/')}/health"
                async with httpx.AsyncClient(timeout=timeout) as client:
                    response = await client.get(health_url)
                if response.status_code != 200:
                    return ConfigTestResponse(
                        success=False,
                        message=f"数据源API返回错误状态码: {response.status_code}",
                    )

            return ConfigTestResponse(
                success=True,
                message=f"数据源 {name} 连接测试成功",
            )
        except Exception as e:
            return ConfigTestResponse(
                success=False,
                message=f"数据源连接测试失败: {str(e)}",
            )