183 lines
5.1 KiB
Python
183 lines
5.1 KiB
Python
"""
|
|
应用配置管理
|
|
处理环境变量和系统配置
|
|
"""
|
|
|
|
import os
from typing import List, Optional, Union

from pydantic import field_validator, validator
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
|
|
|
|
class Settings(BaseSettings):
    """Application settings.

    Values are loaded from environment variables and from a ``.env`` file
    (see ``model_config``); variable names are matched case-sensitively.
    The class-level values below are the defaults used when no override is
    present.
    """

    model_config = SettingsConfigDict(env_file=".env", case_sensitive=True)

    # Application basics
    APP_NAME: str = "基本面选股系统"
    APP_VERSION: str = "1.0.0"
    DEBUG: bool = False

    # Database. The default is a local placeholder with dummy credentials;
    # override DATABASE_URL in any real deployment.
    DATABASE_URL: str = "postgresql+asyncpg://user:password@localhost:5432/stock_analysis"
    DATABASE_ECHO: bool = False

    # API
    API_V1_STR: str = "/api"
    # Declared as a union on purpose: pydantic-settings JSON-decodes env values
    # for pure list fields and raises on plain "a,b" input, which would make the
    # comma-splitting validator below unreachable for env/.env values. With the
    # union, a non-JSON value is handed through as a string and the validator
    # turns it into a list.
    ALLOWED_ORIGINS: Union[List[str], str] = ["http://localhost:3000", "http://127.0.0.1:3000"]

    # External services
    GEMINI_API_KEY: Optional[str] = None
    TUSHARE_TOKEN: Optional[str] = None

    # Data source name per market
    CHINA_DATA_SOURCE: str = "tushare"
    HK_DATA_SOURCE: str = "yahoo"
    US_DATA_SOURCE: str = "yahoo"
    JP_DATA_SOURCE: str = "yahoo"

    # Report generation
    MAX_CONCURRENT_REPORTS: int = 5
    REPORT_TIMEOUT_MINUTES: int = 30

    # Cache
    CACHE_TTL_SECONDS: int = 3600  # 1 hour

    @field_validator("ALLOWED_ORIGINS", mode="before")
    @classmethod
    def assemble_cors_origins(cls, v):
        """Normalize CORS origins.

        A string is treated as a comma-separated list: entries are stripped
        and blank entries (e.g. from a trailing comma) are dropped. Any other
        value (normally an already-decoded list) is returned unchanged.
        """
        if isinstance(v, str):
            return [origin.strip() for origin in v.split(",") if origin.strip()]
        return v

    @field_validator("DATABASE_URL", mode="before")
    @classmethod
    def assemble_db_connection(cls, v):
        """Require a PostgreSQL connection string.

        Raises:
            ValueError: if a non-empty string does not start with "postgresql".

        Empty values are passed through unchanged. Non-string input is left
        for the ``str`` field validation to reject, instead of crashing here
        on ``.startswith``.
        """
        if isinstance(v, str) and v and not v.startswith("postgresql"):
            raise ValueError("数据库URL必须使用PostgreSQL")
        return v
|
|
|
|
|
|
# Create the global settings instance (built once, at import time).
# The config classes below read their values from this instance.
settings = Settings()
|
|
|
|
|
|
class DatabaseConfig:
    """Database connection and pool parameters.

    The URL and echo flag are copied from the module-level ``settings`` when
    the object is created; the pool parameters are fixed values.
    """

    def __init__(self):
        # Taken from settings (a snapshot; later settings changes are not seen).
        self.url, self.echo = settings.DATABASE_URL, settings.DATABASE_ECHO
        # Fixed pool sizing. NOTE(review): timeout/recycle are presumably in
        # seconds, as in SQLAlchemy's create_engine — confirm with the consumer.
        self.pool_size, self.max_overflow = 10, 20
        self.pool_timeout, self.pool_recycle = 30, 3600
|
|
|
|
|
|
class ExternalAPIConfig:
    """Configuration for external market-data sources and AI services.

    All values are copied from the module-level ``settings`` in ``__init__``;
    later changes to ``settings`` are not reflected in an existing instance.
    """

    # Chinese market names exposed alongside the English market keys in the
    # manager's market mapping.
    _MARKET_CHINESE_NAMES = {
        "china": "中国",
        "hongkong": "香港",
        "usa": "美国",
        "japan": "日本",
    }

    def __init__(self):
        self.gemini_api_key = settings.GEMINI_API_KEY
        self.tushare_token = settings.TUSHARE_TOKEN

        # Data source per market, read from the same settings DataSourceConfig
        # uses, so both views agree when a market's source is overridden.
        self.market_sources = {
            "china": settings.CHINA_DATA_SOURCE,
            "hongkong": settings.HK_DATA_SOURCE,
            "usa": settings.US_DATA_SOURCE,
            "japan": settings.JP_DATA_SOURCE,
        }

        # Per-source connection settings. A source is enabled only when its
        # credential is non-empty (yahoo needs none and is always enabled).
        # NOTE(review): retry_delay/timeout are presumably seconds — confirm.
        self.data_sources_config = {
            "tushare": {
                "enabled": bool(self.tushare_token),
                # The token is exposed under both keys; presumably different
                # consumers read different ones — confirm before removing either.
                "api_key": self.tushare_token,
                "token": self.tushare_token,
                "base_url": "http://api.tushare.pro",
                "timeout": 30,
                "max_retries": 3,
                "retry_delay": 1,
                "name": "tushare",
            },
            "yahoo": {
                "enabled": True,
                "base_url": "https://query1.finance.yahoo.com",
                "timeout": 30,
                "max_retries": 3,
                "retry_delay": 1,
                "name": "yahoo",
            },
        }

        # AI service settings; Gemini is enabled only when an API key is set.
        self.ai_services_config = {
            "gemini": {
                "enabled": bool(self.gemini_api_key),
                "api_key": self.gemini_api_key,
                "model": "gemini-pro",
                "base_url": "https://generativelanguage.googleapis.com/v1beta",
                "timeout": 60,
                "max_retries": 3,
                "retry_delay": 2,
                "temperature": 0.7,
                "top_p": 0.8,
                "top_k": 40,
                "max_output_tokens": 8192,
            }
        }

    def validate_gemini_config(self) -> bool:
        """Return True when a non-empty Gemini API key is configured."""
        return bool(self.gemini_api_key)

    def validate_tushare_config(self) -> bool:
        """Return True when a non-empty Tushare token is configured."""
        return bool(self.tushare_token)

    def get_data_source_manager_config(self) -> dict:
        """Build the configuration dict for the data source manager.

        Returns:
            A dict with keys ``data_sources`` and ``ai_services`` (the
            instance's own dicts, not copies), ``market_mapping`` (English
            market key and Chinese market name -> source name, taken from the
            per-market settings), and ``fallback_sources`` (source -> list of
            alternative sources).
        """
        market_mapping = {}
        for market, source in self.market_sources.items():
            market_mapping[market] = source
            chinese_name = self._MARKET_CHINESE_NAMES.get(market)
            if chinese_name is not None:
                market_mapping[chinese_name] = source

        return {
            "data_sources": self.data_sources_config,
            "ai_services": self.ai_services_config,
            "market_mapping": market_mapping,
            "fallback_sources": {
                "tushare": ["yahoo"],
                "yahoo": ["tushare"],
            },
        }
|
|
|
|
|
|
class DataSourceConfig:
    """Resolve which data source to use for a given market.

    Source names are copied from the module-level ``settings`` in
    ``__init__``.
    """

    # Chinese market names mapped to the English keys used in ``self.sources``.
    _MARKET_ALIASES = {
        "中国": "china",
        "香港": "hongkong",
        "美国": "usa",
        "日本": "japan",
    }

    def __init__(self):
        self.sources = {
            "china": settings.CHINA_DATA_SOURCE,
            "hongkong": settings.HK_DATA_SOURCE,
            "usa": settings.US_DATA_SOURCE,
            "japan": settings.JP_DATA_SOURCE,
        }

    def get_source_for_market(self, market: str) -> str:
        """Return the configured data source name for ``market``.

        Args:
            market: A Chinese market name ("中国", "香港", "美国", "日本") or an
                English market key ("china", "hongkong", "usa", "japan").
                Matching ignores surrounding whitespace and letter case.

        Returns:
            The source name from ``self.sources``. Unrecognized markets (and
            non-string input) fall back to the China entry; if that entry is
            missing too, "tushare" is returned.
        """
        normalized = market.strip().casefold() if isinstance(market, str) else ""
        # English keys are accepted directly; previously they silently fell
        # through to the China default.
        if normalized in self.sources:
            market_key = normalized
        else:
            market_key = self._MARKET_ALIASES.get(normalized, "china")
        return self.sources.get(market_key, "tushare")
|
|
|
|
|
|
# Create the module-level configuration instances (built once, at import
# time, from the global `settings`).
db_config = DatabaseConfig()
api_config = ExternalAPIConfig()
data_source_config = DataSourceConfig()