357 lines
13 KiB
Python
357 lines
13 KiB
Python
"""
|
||
数据源管理服务
|
||
处理数据源配置和切换逻辑
|
||
"""
|
||
|
||
from typing import Dict, Any, Optional, List
|
||
import asyncio
|
||
from datetime import datetime
|
||
|
||
from .data_fetcher import DataFetcher, DataFetcherFactory
|
||
from .ai_analyzer import GeminiAnalyzer, AIAnalyzerFactory
|
||
from ..core.exceptions import (
|
||
DataSourceError,
|
||
ConfigurationError,
|
||
AIAnalysisError
|
||
)
|
||
from ..schemas.data import (
|
||
FinancialDataResponse,
|
||
MarketDataResponse,
|
||
SymbolValidationResponse,
|
||
DataSourceStatus,
|
||
DataSourcesStatusResponse
|
||
)
|
||
from ..schemas.report import AIAnalysisResponse
|
||
|
||
|
||
class DataSourceManager:
    """Route data and AI-analysis requests to the configured back ends.

    The manager owns one ``DataFetcher`` per enabled entry in
    ``config["data_sources"]`` and, when ``config["ai_services"]["gemini"]``
    is enabled and carries an ``api_key``, one ``GeminiAnalyzer``. Data
    requests are routed to a source by market and fall back to alternative
    sources when the primary source raises ``DataSourceError``.
    """

    # analysis_type -> name of the analyzer coroutine method that accepts
    # ``(symbol, market, context_data)`` unchanged. Types that need their
    # arguments unpacked from ``context_data`` are handled in analyze_with_ai.
    _CONTEXT_ANALYSES: Dict[str, str] = {
        "business_info": "analyze_business_info",
        "bullish_analysis": "analyze_bullish_case",
        "bearish_analysis": "analyze_bearish_case",
        "market_analysis": "analyze_market_sentiment",
        "news_analysis": "analyze_news_catalysts",
        "trading_analysis": "analyze_trading_dynamics",
        "insider_analysis": "analyze_insider_institutional",
    }

    def __init__(self, config: Dict[str, Any]):
        """Create the manager and eagerly build every configured client.

        Args:
            config: Manager configuration. Top-level keys read by this class
                are ``data_sources``, ``ai_services``, ``market_mapping`` and
                ``fallback_sources``.
        """
        self.config = config
        self._data_fetchers: Dict[str, DataFetcher] = {}
        self._ai_analyzer: Optional[GeminiAnalyzer] = None
        # Built-in market -> data-source routing, consulted when the config's
        # ``market_mapping`` has no entry for the requested market.
        self._market_source_mapping = {
            "china": "tushare",
            "中国": "tushare",
            "hongkong": "yahoo",
            "香港": "yahoo",
            "usa": "yahoo",
            "美国": "yahoo",
            "japan": "yahoo",
            "日本": "yahoo",
        }

        self._initialize_data_fetchers()
        self._initialize_ai_analyzer()

    def _initialize_data_fetchers(self):
        """Create a fetcher for every enabled entry in ``config["data_sources"]``.

        A source is enabled unless its config sets ``"enabled"`` to a falsy
        value. A source that fails to initialize is skipped with a printed
        warning, so one bad entry does not prevent the others from loading.
        """
        # ``or {}`` also tolerates an explicit ``None`` section.
        data_sources_config = self.config.get("data_sources") or {}

        for source_name, source_config in data_sources_config.items():
            try:
                if source_config.get("enabled", True):
                    self._data_fetchers[source_name] = DataFetcherFactory.create_fetcher(
                        source_name, source_config
                    )
            except Exception as e:
                print(f"警告: 初始化数据源 {source_name} 失败: {str(e)}")

    def _initialize_ai_analyzer(self):
        """Create the Gemini analyzer if it is enabled and has an API key.

        Initialization failures are printed as a warning and leave the
        analyzer unset (``None``).
        """
        ai_config = self.config.get("ai_services") or {}
        gemini_config = ai_config.get("gemini") or {}

        if gemini_config.get("enabled", True) and gemini_config.get("api_key"):
            try:
                self._ai_analyzer = AIAnalyzerFactory.create_gemini_analyzer(
                    gemini_config["api_key"],
                    gemini_config,
                )
            except Exception as e:
                print(f"警告: 初始化Gemini分析器失败: {str(e)}")

    def get_data_source_for_market(self, market: str) -> str:
        """Return the data-source name used for ``market``.

        The lower-cased market name is looked up first in
        ``config["market_mapping"]``, then in the built-in mapping; unknown
        markets default to ``"tushare"``.
        """
        market_lower = market.lower()

        market_mapping = self.config.get("market_mapping") or {}
        if market_lower in market_mapping:
            return market_mapping[market_lower]

        return self._market_source_mapping.get(market_lower, "tushare")

    def get_data_fetcher(self, data_source: str) -> DataFetcher:
        """Return the fetcher for ``data_source``.

        Raises:
            DataSourceError: if the source is not configured or failed to
                initialize.
        """
        fetcher = self._data_fetchers.get(data_source)
        if fetcher is None:
            raise DataSourceError(f"数据源 {data_source} 未配置或不可用", data_source)
        return fetcher

    def get_ai_analyzer(self) -> GeminiAnalyzer:
        """Return the AI analyzer.

        Raises:
            AIAnalysisError: if no analyzer is configured or it failed to
                initialize.
        """
        if self._ai_analyzer is None:
            raise AIAnalysisError("AI分析器未配置或不可用", "gemini")
        return self._ai_analyzer

    async def _fetch_with_fallback(
        self,
        method_name: str,
        symbol: str,
        market: str,
        preferred_source: Optional[str],
    ) -> Any:
        """Await ``fetcher.<method_name>(symbol, market)`` with source fallback.

        The primary source is ``preferred_source`` when given, otherwise the
        one mapped to ``market``. Only a ``DataSourceError`` from the primary
        source triggers fallback. Fallback sources are tried in order, and any
        exception they raise is ignored. If every fallback fails, the primary
        source's ``DataSourceError`` is re-raised.
        """
        data_source = preferred_source or self.get_data_source_for_market(market)

        try:
            fetcher = self.get_data_fetcher(data_source)
            return await getattr(fetcher, method_name)(symbol, market)
        except DataSourceError:
            for fallback_source in self._get_fallback_sources(data_source):
                try:
                    fallback_fetcher = self.get_data_fetcher(fallback_source)
                    return await getattr(fallback_fetcher, method_name)(symbol, market)
                except Exception:
                    # Best effort: a failing fallback just moves on to the next.
                    continue
            # Every source failed. The bare ``raise`` re-raises the primary
            # DataSourceError, because the inner handlers have already exited.
            raise

    async def fetch_financial_data(self, symbol: str, market: str, preferred_source: Optional[str] = None) -> FinancialDataResponse:
        """Fetch financial data, switching to fallback sources on ``DataSourceError``."""
        return await self._fetch_with_fallback(
            "fetch_financial_data", symbol, market, preferred_source
        )

    async def fetch_market_data(self, symbol: str, market: str, preferred_source: Optional[str] = None) -> MarketDataResponse:
        """Fetch market data, switching to fallback sources on ``DataSourceError``."""
        return await self._fetch_with_fallback(
            "fetch_market_data", symbol, market, preferred_source
        )

    async def validate_symbol(self, symbol: str, market: str, preferred_source: Optional[str] = None) -> SymbolValidationResponse:
        """Validate a symbol, switching to fallback sources on ``DataSourceError``."""
        return await self._fetch_with_fallback(
            "validate_symbol", symbol, market, preferred_source
        )

    async def analyze_with_ai(self, analysis_type: str, symbol: str, market: str, context_data: Dict[str, Any]) -> AIAnalysisResponse:
        """Run the AI analysis named by ``analysis_type``.

        ``fundamental_analysis`` reads ``financial_data`` and ``business_info``
        from ``context_data``. ``final_conclusion`` reads ``all_analyses``.
        Every other supported type receives ``context_data`` as is.

        Raises:
            AIAnalysisError: if no analyzer is available or ``analysis_type``
                is not supported.
        """
        analyzer = self.get_ai_analyzer()

        if analysis_type == "fundamental_analysis":
            return await analyzer.analyze_fundamental(
                symbol,
                market,
                context_data.get("financial_data", {}),
                context_data.get("business_info", {}),
            )
        if analysis_type == "final_conclusion":
            return await analyzer.generate_final_conclusion(
                symbol, market, context_data.get("all_analyses", [])
            )

        method_name = self._CONTEXT_ANALYSES.get(analysis_type)
        if method_name is None:
            raise AIAnalysisError(f"不支持的分析类型: {analysis_type}", "gemini")
        return await getattr(analyzer, method_name)(symbol, market, context_data)

    async def check_all_sources_status(self) -> DataSourcesStatusResponse:
        """Check every data source (and the AI analyzer) concurrently.

        A check that raises is reported as an unavailable source carrying the
        error text. ``overall_status`` is ``"healthy"`` when every checked
        source is available, ``"degraded"`` when only some are, and ``"down"``
        when none are or when nothing is configured at all.
        """
        # Keep names paired with their checks so a failed check can be
        # attributed to the right source.
        names: List[str] = []
        checks = []
        for source_name, fetcher in self._data_fetchers.items():
            names.append(source_name)
            checks.append(fetcher.check_status())
        if self._ai_analyzer is not None:
            names.append("gemini")
            checks.append(self._check_ai_analyzer_status())

        results = await asyncio.gather(*checks, return_exceptions=True)

        source_statuses: List[DataSourceStatus] = []
        healthy_count = 0
        for source_name, result in zip(names, results):
            # CancelledError is a BaseException (not an Exception) on 3.8+,
            # but gather() still hands it back as a result here.
            if isinstance(result, (Exception, asyncio.CancelledError)):
                source_statuses.append(DataSourceStatus(
                    name=source_name,
                    is_available=False,
                    last_check=datetime.now(),
                    error_message=str(result),
                ))
                continue
            source_statuses.append(result)
            if result.is_available:
                healthy_count += 1

        total_sources = len(source_statuses)
        if total_sources and healthy_count == total_sources:
            overall_status = "healthy"
        elif healthy_count > 0:
            overall_status = "degraded"
        else:
            overall_status = "down"

        return DataSourcesStatusResponse(
            sources=source_statuses,
            overall_status=overall_status,
        )

    async def _check_ai_analyzer_status(self) -> DataSourceStatus:
        """Probe the Gemini analyzer with a tiny prompt and report its status."""
        # The loop clock is monotonic, so wall-clock adjustments cannot
        # distort (or make negative) the measured response time.
        loop = asyncio.get_running_loop()
        started = loop.time()
        try:
            test_prompt = "请回答:1+1等于几?"
            await self._ai_analyzer._call_gemini_api(test_prompt)
        except Exception as e:
            return DataSourceStatus(
                name="gemini",
                is_available=False,
                last_check=datetime.now(),
                error_message=str(e),
            )
        else:
            response_time = int((loop.time() - started) * 1000)
            return DataSourceStatus(
                name="gemini",
                is_available=True,
                last_check=datetime.now(),
                response_time_ms=response_time,
            )

    def _get_fallback_sources(self, primary_source: str) -> List[str]:
        """Return the sources to try, in order, after ``primary_source`` fails.

        ``config["fallback_sources"][primary_source]`` is used when present
        (a single string is accepted as a one-element list). Otherwise every
        loaded source is used. The primary source itself is never included.
        """
        fallback_config = self.config.get("fallback_sources") or {}

        if primary_source in fallback_config:
            candidates = fallback_config[primary_source] or []
            if isinstance(candidates, str):
                candidates = [candidates]
        else:
            candidates = list(self._data_fetchers)

        # Never retry the source that just failed; the new list also keeps
        # callers from mutating the config's own list.
        return [source for source in candidates if source != primary_source]

    def update_config(self, new_config: Dict[str, Any]):
        """Merge ``new_config`` (top level) into the config and rebuild all clients."""
        # Rebind to a merged copy instead of calling dict.update(), so the
        # dict originally passed to __init__ is not mutated behind the
        # caller's back.
        self.config = {**self.config, **new_config}

        self._data_fetchers.clear()
        self._initialize_data_fetchers()

        self._ai_analyzer = None
        self._initialize_ai_analyzer()

    def get_supported_sources(self) -> List[str]:
        """Return every source type the fetcher factory can build."""
        return DataFetcherFactory.get_supported_sources()

    def get_available_sources(self) -> List[str]:
        """Return the names of the sources that were successfully loaded."""
        return list(self._data_fetchers.keys())

    def is_ai_analyzer_available(self) -> bool:
        """Return True if an AI analyzer was successfully initialized."""
        return self._ai_analyzer is not None
|
||
|
||
|
||
def create_data_source_manager(config: Dict[str, Any]) -> DataSourceManager:
    """Build a ``DataSourceManager`` from ``config``.

    Args:
        config: Configuration dict passed straight to ``DataSourceManager``.

    Returns:
        The newly constructed manager.
    """
    manager = DataSourceManager(config)
    return manager
|
||
|
||
|
||
# Example default configuration. API keys are intentionally left empty and
# must be supplied from environment variables or a config file.
DEFAULT_CONFIG = dict(
    data_sources=dict(
        tushare=dict(
            enabled=True,
            api_key="",  # supply via environment variable or config file
            base_url="http://api.tushare.pro",
            timeout=30,
            max_retries=3,
            retry_delay=1,
        ),
        yahoo=dict(
            enabled=True,
            base_url="https://query1.finance.yahoo.com",
            timeout=30,
            max_retries=3,
            retry_delay=1,
        ),
    ),
    ai_services=dict(
        gemini=dict(
            enabled=True,
            api_key="",  # supply via environment variable or config file
            model="gemini-pro",
            timeout=60,
            max_retries=3,
            retry_delay=2,
            temperature=0.7,
            max_output_tokens=8192,
        ),
    ),
    # Market name (English and Chinese spellings) -> data source.
    market_mapping={
        **dict.fromkeys(("china", "中国"), "tushare"),
        **dict.fromkeys(("hongkong", "香港", "usa", "美国", "japan", "日本"), "yahoo"),
    },
    # Source to try next when the primary source fails.
    fallback_sources=dict(
        tushare=["yahoo"],
        yahoo=["tushare"],
    ),
)