Fundamental_Analysis/backend/app/services/data_source_manager.py

357 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
数据源管理服务
处理数据源配置和切换逻辑
"""
import asyncio
import time
from datetime import datetime
from typing import Dict, Any, Optional, List

from .data_fetcher import DataFetcher, DataFetcherFactory
from .ai_analyzer import GeminiAnalyzer, AIAnalyzerFactory
from ..core.exceptions import (
    DataSourceError,
    ConfigurationError,
    AIAnalysisError
)
from ..schemas.data import (
    FinancialDataResponse,
    MarketDataResponse,
    SymbolValidationResponse,
    DataSourceStatus,
    DataSourcesStatusResponse
)
from ..schemas.report import AIAnalysisResponse


class DataSourceManager:
    """Manage data-source fetchers and the AI analyzer.

    The manager builds one fetcher per enabled entry in
    ``config["data_sources"]``, optionally builds a Gemini analyzer from
    ``config["ai_services"]["gemini"]``, picks a data source per market and
    retries on fallback sources when the primary source raises
    ``DataSourceError``.
    """

    # Analysis types whose analyzer method takes ``(symbol, market, context_data)``.
    # The two remaining types ("fundamental_analysis", "final_conclusion")
    # unpack ``context_data`` first and are handled explicitly.
    _CONTEXT_ANALYSIS_METHODS = {
        "business_info": "analyze_business_info",
        "bullish_analysis": "analyze_bullish_case",
        "bearish_analysis": "analyze_bearish_case",
        "market_analysis": "analyze_market_sentiment",
        "news_analysis": "analyze_news_catalysts",
        "trading_analysis": "analyze_trading_dynamics",
        "insider_analysis": "analyze_insider_institutional",
    }

    def __init__(self, config: Dict[str, Any]):
        """Store the configuration and build fetchers and the AI analyzer.

        Args:
            config: Configuration mapping. Keys read by this class are
                ``data_sources``, ``ai_services``, ``market_mapping`` and
                ``fallback_sources``; all are optional. The mapping is
                stored by reference, not copied.
        """
        self.config = config
        self._data_fetchers: Dict[str, DataFetcher] = {}
        self._ai_analyzer: Optional[GeminiAnalyzer] = None
        # Built-in market -> source mapping, used when the configuration
        # has no "market_mapping" entry for the requested market.
        self._market_source_mapping = {
            "china": "tushare",
            "中国": "tushare",
            "hongkong": "yahoo",
            "香港": "yahoo",
            "usa": "yahoo",
            "美国": "yahoo",
            "japan": "yahoo",
            "日本": "yahoo",
        }
        self._initialize_data_fetchers()
        self._initialize_ai_analyzer()

    def _initialize_data_fetchers(self):
        """Create a fetcher for every enabled entry in ``config["data_sources"]``.

        A source is treated as enabled unless its config sets ``enabled`` to
        a falsy value. A failure while creating one fetcher is reported with
        a printed warning and does not stop the remaining sources.
        """
        data_sources_config = self.config.get("data_sources", {})
        for source_name, source_config in data_sources_config.items():
            try:
                if source_config.get("enabled", True):
                    self._data_fetchers[source_name] = DataFetcherFactory.create_fetcher(
                        source_name, source_config
                    )
            except Exception as e:
                # Best-effort initialisation: one broken source must not
                # prevent the others from being registered.
                print(f"警告: 初始化数据源 {source_name} 失败: {str(e)}")

    def _initialize_ai_analyzer(self):
        """Create the Gemini analyzer when it is enabled and has an API key.

        Leaves ``self._ai_analyzer`` untouched when the analyzer is disabled,
        has no ``api_key``, or when creation fails (a warning is printed in
        the failure case).
        """
        gemini_config = self.config.get("ai_services", {}).get("gemini", {})
        if not (gemini_config.get("enabled", True) and gemini_config.get("api_key")):
            return
        try:
            self._ai_analyzer = AIAnalyzerFactory.create_gemini_analyzer(
                gemini_config["api_key"],
                gemini_config
            )
        except Exception as e:
            print(f"警告: 初始化Gemini分析器失败: {str(e)}")

    def get_data_source_for_market(self, market: str) -> str:
        """Return the data-source name to use for ``market``.

        The lower-cased market is looked up in ``config["market_mapping"]``
        first, then in the built-in mapping; ``"tushare"`` is returned when
        neither contains it.
        """
        market_key = market.lower()
        configured_mapping = self.config.get("market_mapping", {})
        if market_key in configured_mapping:
            return configured_mapping[market_key]
        return self._market_source_mapping.get(market_key, "tushare")

    def get_data_fetcher(self, data_source: str) -> DataFetcher:
        """Return the fetcher registered under ``data_source``.

        Raises:
            DataSourceError: If no fetcher is registered under that name.
        """
        if data_source not in self._data_fetchers:
            raise DataSourceError(f"数据源 {data_source} 未配置或不可用", data_source)
        return self._data_fetchers[data_source]

    def get_ai_analyzer(self) -> GeminiAnalyzer:
        """Return the AI analyzer.

        Raises:
            AIAnalysisError: If no analyzer has been initialised.
        """
        # Identity check, consistent with is_ai_analyzer_available().
        if self._ai_analyzer is None:
            raise AIAnalysisError("AI分析器未配置或不可用", "gemini")
        return self._ai_analyzer

    async def _fetch_with_fallback(
        self,
        operation: str,
        symbol: str,
        market: str,
        preferred_source: Optional[str],
    ) -> Any:
        """Call fetcher method ``operation`` on the primary source, then on fallbacks.

        Args:
            operation: Name of the coroutine method to call on the fetcher;
                it is invoked as ``fetcher.<operation>(symbol, market)``.
            symbol: Security symbol passed through to the fetcher.
            market: Market name passed through to the fetcher; also used to
                choose the source when ``preferred_source`` is falsy.
            preferred_source: Source to try first, or ``None``.

        Returns:
            Whatever the first successful fetcher call returns.

        Raises:
            DataSourceError: The primary source's error, re-raised when every
                fallback source also fails. Exceptions of other types raised
                by the primary source propagate without trying fallbacks.
        """
        data_source = preferred_source or self.get_data_source_for_market(market)
        try:
            fetcher = self.get_data_fetcher(data_source)
            return await getattr(fetcher, operation)(symbol, market)
        except DataSourceError as primary_error:
            for fallback_source in self._get_fallback_sources(data_source):
                try:
                    fallback_fetcher = self.get_data_fetcher(fallback_source)
                    return await getattr(fallback_fetcher, operation)(symbol, market)
                except Exception:
                    # Any failure of a fallback just moves on to the next
                    # one; the primary error is what gets reported.
                    continue
            raise primary_error

    async def fetch_financial_data(self, symbol: str, market: str, preferred_source: Optional[str] = None) -> FinancialDataResponse:
        """Fetch financial data, falling back to other sources on ``DataSourceError``.

        See ``_fetch_with_fallback`` for source selection and error behaviour.
        """
        return await self._fetch_with_fallback(
            "fetch_financial_data", symbol, market, preferred_source
        )

    async def fetch_market_data(self, symbol: str, market: str, preferred_source: Optional[str] = None) -> MarketDataResponse:
        """Fetch market data, falling back to other sources on ``DataSourceError``.

        See ``_fetch_with_fallback`` for source selection and error behaviour.
        """
        return await self._fetch_with_fallback(
            "fetch_market_data", symbol, market, preferred_source
        )

    async def validate_symbol(self, symbol: str, market: str, preferred_source: Optional[str] = None) -> SymbolValidationResponse:
        """Validate a symbol, falling back to other sources on ``DataSourceError``.

        See ``_fetch_with_fallback`` for source selection and error behaviour.
        """
        return await self._fetch_with_fallback(
            "validate_symbol", symbol, market, preferred_source
        )

    async def analyze_with_ai(self, analysis_type: str, symbol: str, market: str, context_data: Dict[str, Any]) -> AIAnalysisResponse:
        """Run the AI analysis selected by ``analysis_type``.

        Args:
            analysis_type: One of ``business_info``, ``fundamental_analysis``,
                ``bullish_analysis``, ``bearish_analysis``, ``market_analysis``,
                ``news_analysis``, ``trading_analysis``, ``insider_analysis``
                or ``final_conclusion``.
            symbol: Security symbol passed to the analyzer.
            market: Market name passed to the analyzer.
            context_data: Input for the analysis. ``fundamental_analysis``
                reads its ``business_info`` and ``financial_data`` keys and
                ``final_conclusion`` reads ``all_analyses``; every other type
                receives the mapping unchanged.

        Raises:
            AIAnalysisError: If no analyzer is available or the analysis type
                is not supported.
        """
        # Resolve the analyzer first so that a missing analyzer is reported
        # even for an unknown analysis type (same order as before).
        analyzer = self.get_ai_analyzer()

        if analysis_type == "fundamental_analysis":
            business_info = context_data.get("business_info", {})
            financial_data = context_data.get("financial_data", {})
            return await analyzer.analyze_fundamental(symbol, market, financial_data, business_info)

        if analysis_type == "final_conclusion":
            all_analyses = context_data.get("all_analyses", [])
            return await analyzer.generate_final_conclusion(symbol, market, all_analyses)

        method_name = self._CONTEXT_ANALYSIS_METHODS.get(analysis_type)
        if method_name is None:
            raise AIAnalysisError(f"不支持的分析类型: {analysis_type}", "gemini")
        return await getattr(analyzer, method_name)(symbol, market, context_data)

    async def check_all_sources_status(self) -> DataSourcesStatusResponse:
        """Check every fetcher (and the AI analyzer, if any) concurrently.

        Returns:
            A ``DataSourcesStatusResponse`` holding one status per source and
            an overall status: ``"healthy"`` when every source is available,
            ``"degraded"`` when only some are, ``"down"`` when none is.
        """
        # Snapshot the names next to the tasks so results can be matched by
        # position even if the fetcher registry changes while awaiting.
        source_names = list(self._data_fetchers)
        status_tasks = [fetcher.check_status() for fetcher in self._data_fetchers.values()]
        if self._ai_analyzer is not None:
            source_names.append("gemini")
            status_tasks.append(self._check_ai_analyzer_status())

        results = await asyncio.gather(*status_tasks, return_exceptions=True)

        source_statuses = []
        healthy_count = 0
        for source_name, result in zip(source_names, results):
            # With return_exceptions=True the results may contain any
            # BaseException (e.g. CancelledError), not only Exception.
            if isinstance(result, BaseException):
                source_statuses.append(DataSourceStatus(
                    name=source_name,
                    is_available=False,
                    last_check=datetime.now(),
                    error_message=str(result)
                ))
                continue
            source_statuses.append(result)
            if result.is_available:
                healthy_count += 1

        # NOTE(review): with no sources at all this still reports "healthy"
        # (0 == 0), as it always has — confirm that this is intended.
        total_sources = len(source_statuses)
        if healthy_count == total_sources:
            overall_status = "healthy"
        elif healthy_count > 0:
            overall_status = "degraded"
        else:
            overall_status = "down"

        return DataSourcesStatusResponse(
            sources=source_statuses,
            overall_status=overall_status
        )

    async def _check_ai_analyzer_status(self) -> DataSourceStatus:
        """Probe the AI analyzer with a tiny prompt and report its status.

        Never raises: any exception from the probe is converted into an
        unavailable status carrying the error text.
        """
        # Monotonic clock: the elapsed time must not be affected by
        # wall-clock adjustments.
        started = time.perf_counter()
        try:
            # Minimal health check: ask for a very short completion.
            test_prompt = "请回答1+1等于几"
            await self._ai_analyzer._call_gemini_api(test_prompt)
            response_time = int((time.perf_counter() - started) * 1000)
            return DataSourceStatus(
                name="gemini",
                is_available=True,
                last_check=datetime.now(),
                response_time_ms=response_time
            )
        except Exception as e:
            return DataSourceStatus(
                name="gemini",
                is_available=False,
                last_check=datetime.now(),
                error_message=str(e)
            )

    def _get_fallback_sources(self, primary_source: str) -> List[str]:
        """Return the sources to try after ``primary_source`` fails.

        Uses ``config["fallback_sources"][primary_source]`` when present;
        otherwise every registered fetcher except the primary one.
        """
        fallback_config = self.config.get("fallback_sources", {})
        if primary_source in fallback_config:
            return list(fallback_config[primary_source])
        return [source for source in self._data_fetchers if source != primary_source]

    def update_config(self, new_config: Dict[str, Any]):
        """Merge ``new_config`` into the configuration and rebuild everything.

        The merge is shallow: each top-level key in ``new_config`` replaces
        the existing value wholesale. The merged mapping is a new dict, so
        the mapping originally passed to ``__init__`` (which may be shared,
        e.g. a module-level default) is not mutated.
        """
        self.config = {**self.config, **new_config}
        # Rebuild the fetchers from the merged configuration.
        self._data_fetchers.clear()
        self._initialize_data_fetchers()
        # Rebuild the AI analyzer from the merged configuration.
        self._ai_analyzer = None
        self._initialize_ai_analyzer()

    def get_supported_sources(self) -> List[str]:
        """Return the source names reported by ``DataFetcherFactory``."""
        return DataFetcherFactory.get_supported_sources()

    def get_available_sources(self) -> List[str]:
        """Return the names of the fetchers that are currently registered."""
        return list(self._data_fetchers)

    def is_ai_analyzer_available(self) -> bool:
        """Return ``True`` when an AI analyzer has been initialised."""
        return self._ai_analyzer is not None
def create_data_source_manager(config: Dict[str, Any]) -> DataSourceManager:
    """Build and return a ``DataSourceManager`` for the given configuration."""
    manager = DataSourceManager(config)
    return manager
# Example default configuration.
DEFAULT_CONFIG = {
    # One entry per data source.
    "data_sources": {
        "tushare": {
            "enabled": True,
            "api_key": "",  # must be supplied from an environment variable or config file
            "base_url": "http://api.tushare.pro",
            "timeout": 30,
            "max_retries": 3,
            "retry_delay": 1,
        },
        "yahoo": {
            "enabled": True,
            "base_url": "https://query1.finance.yahoo.com",
            "timeout": 30,
            "max_retries": 3,
            "retry_delay": 1,
        },
    },
    # AI service settings.
    "ai_services": {
        "gemini": {
            "enabled": True,
            "api_key": "",  # must be supplied from an environment variable or config file
            "model": "gemini-pro",
            "timeout": 60,
            "max_retries": 3,
            "retry_delay": 2,
            "temperature": 0.7,
            "max_output_tokens": 8192,
        },
    },
    # Market name (English or Chinese) -> data source name.
    "market_mapping": {
        "china": "tushare",
        "中国": "tushare",
        "hongkong": "yahoo",
        "香港": "yahoo",
        "usa": "yahoo",
        "美国": "yahoo",
        "japan": "yahoo",
        "日本": "yahoo",
    },
    # Primary source name -> ordered list of fallback source names.
    "fallback_sources": {
        "tushare": ["yahoo"],
        "yahoo": ["tushare"],
    },
}