feat: Enhance configuration management with new LLM provider support and API integration
- Backend: Introduced new endpoints for LLM configuration retrieval and updates in `config.py`, allowing dynamic management of LLM provider settings. - Updated schemas to include `AlphaEngineConfig` for better integration with the new provider. - Frontend: Added state management for AlphaEngine API credentials in the configuration page, ensuring seamless user experience. - Configuration files updated to reflect changes in LLM provider settings and API keys. BREAKING CHANGE: The default LLM provider has been changed from `new_api` to `alpha_engine`, requiring updates to existing configurations.
This commit is contained in:
parent
00a79499d4
commit
a79efd8150
@ -1,7 +1,6 @@
|
||||
"""
|
||||
API router for configuration management
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Dict, Any, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.core.dependencies import get_config_manager
|
||||
from app.schemas.config import ConfigResponse, ConfigUpdateRequest, ConfigTestRequest, ConfigTestResponse
|
||||
@ -9,11 +8,112 @@ from app.services.config_manager import ConfigManager
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
class LLMConfigUpdate(BaseModel):
|
||||
provider: str
|
||||
model: Optional[str] = None
|
||||
|
||||
@router.get("/", response_model=ConfigResponse)
|
||||
async def get_config(config_manager: ConfigManager = Depends(get_config_manager)):
|
||||
"""Retrieve the current system configuration."""
|
||||
return await config_manager.get_config()
|
||||
|
||||
@router.get("/llm", response_model=Dict[str, Any])
|
||||
async def get_llm_config(config_manager: ConfigManager = Depends(get_config_manager)):
|
||||
"""Get LLM provider and model configuration."""
|
||||
llm_config = await config_manager.get_llm_config()
|
||||
return llm_config
|
||||
|
||||
@router.put("/llm", response_model=Dict[str, Any])
|
||||
async def update_llm_config(
|
||||
llm_update: LLMConfigUpdate,
|
||||
config_manager: ConfigManager = Depends(get_config_manager)
|
||||
):
|
||||
"""Update LLM provider and model configuration."""
|
||||
import json
|
||||
import os
|
||||
|
||||
provider = llm_update.provider
|
||||
model = llm_update.model
|
||||
|
||||
# Load base config
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
config_path = os.path.join(project_root, "config", "config.json")
|
||||
|
||||
base_config = {}
|
||||
if os.path.exists(config_path):
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
base_config = json.load(f)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Update llm config
|
||||
if "llm" not in base_config:
|
||||
base_config["llm"] = {}
|
||||
|
||||
base_config["llm"]["provider"] = provider
|
||||
if model:
|
||||
# Update model in the provider-specific config
|
||||
if provider in base_config["llm"]:
|
||||
if not isinstance(base_config["llm"][provider], dict):
|
||||
base_config["llm"][provider] = {}
|
||||
base_config["llm"][provider]["model"] = model
|
||||
|
||||
# Save to file
|
||||
try:
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(base_config, f, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to save config: {str(e)}")
|
||||
|
||||
# Also update in database - use same error handling as _load_dynamic_config_from_db
|
||||
try:
|
||||
from app.models.system_config import SystemConfig
|
||||
from sqlalchemy.future import select
|
||||
|
||||
result = await config_manager.db.execute(
|
||||
select(SystemConfig).where(SystemConfig.config_key == "llm")
|
||||
)
|
||||
existing_llm_config = result.scalar_one_or_none()
|
||||
|
||||
if existing_llm_config:
|
||||
if isinstance(existing_llm_config.config_value, dict):
|
||||
existing_llm_config.config_value["provider"] = provider
|
||||
if model:
|
||||
if provider not in existing_llm_config.config_value:
|
||||
existing_llm_config.config_value[provider] = {}
|
||||
elif not isinstance(existing_llm_config.config_value[provider], dict):
|
||||
existing_llm_config.config_value[provider] = {}
|
||||
existing_llm_config.config_value[provider]["model"] = model
|
||||
else:
|
||||
existing_llm_config.config_value = {"provider": provider}
|
||||
if model:
|
||||
existing_llm_config.config_value[provider] = {"model": model}
|
||||
else:
|
||||
new_llm_config = SystemConfig(
|
||||
config_key="llm",
|
||||
config_value={"provider": provider}
|
||||
)
|
||||
if model:
|
||||
new_llm_config.config_value[provider] = {"model": model}
|
||||
config_manager.db.add(new_llm_config)
|
||||
|
||||
await config_manager.db.commit()
|
||||
except Exception as e:
|
||||
# Rollback on error
|
||||
try:
|
||||
await config_manager.db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Log the error but don't fail the request since file was already saved
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(f"Failed to update LLM config in database (file saved successfully): {e}")
|
||||
# Continue anyway since file config was saved successfully
|
||||
|
||||
return await config_manager.get_llm_config()
|
||||
|
||||
@router.put("/", response_model=ConfigResponse)
|
||||
async def update_config(
|
||||
config_update: ConfigUpdateRequest,
|
||||
|
||||
@ -8,7 +8,7 @@ from datetime import datetime, timezone, timedelta
|
||||
from enum import Enum
|
||||
from typing import Dict, List
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from fastapi import APIRouter, HTTPException, Query, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.core.config import settings
|
||||
@ -24,6 +24,9 @@ from app.schemas.financial import (
|
||||
)
|
||||
from app.services.company_profile_client import CompanyProfileClient
|
||||
from app.services.analysis_client import AnalysisClient, load_analysis_config, get_analysis_config
|
||||
from app.core.dependencies import get_config_manager
|
||||
from app.services.config_manager import ConfigManager
|
||||
from app.services.client_factory import create_analysis_client
|
||||
|
||||
# Lazy DataManager loader to avoid import-time failures when optional providers/config are missing
|
||||
_dm = None
|
||||
@ -92,6 +95,7 @@ async def get_data_sources():
|
||||
async def generate_full_analysis(
|
||||
ts_code: str,
|
||||
company_name: str = Query(None, description="Company name for better context"),
|
||||
config_manager: ConfigManager = Depends(get_config_manager),
|
||||
):
|
||||
"""
|
||||
Generate a full analysis report by orchestrating multiple analysis modules
|
||||
@ -102,20 +106,11 @@ async def generate_full_analysis(
|
||||
|
||||
logger.info(f"[API] Full analysis requested for {ts_code}")
|
||||
|
||||
# Load base and analysis configurations
|
||||
base_cfg = _load_json(BASE_CONFIG_PATH)
|
||||
llm_provider = base_cfg.get("llm", {}).get("provider", "gemini")
|
||||
llm_config = base_cfg.get("llm", {}).get(llm_provider, {})
|
||||
|
||||
api_key = llm_config.get("api_key")
|
||||
base_url = llm_config.get("base_url")
|
||||
|
||||
if not api_key:
|
||||
logger.error(f"[API] API key for {llm_provider} not configured")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"API key for {llm_provider} not configured."
|
||||
)
|
||||
# Load LLM configuration using ConfigManager
|
||||
llm_config_result = await config_manager.get_llm_config()
|
||||
default_provider = llm_config_result["provider"]
|
||||
default_config = llm_config_result["config"]
|
||||
global_model = llm_config_result.get("model") # 全局模型配置
|
||||
|
||||
analysis_config_full = load_analysis_config()
|
||||
modules_config = analysis_config_full.get("analysis_modules", {})
|
||||
@ -211,10 +206,15 @@ async def generate_full_analysis(
|
||||
module_config = modules_config[module_type]
|
||||
logger.info(f"[Orchestrator] Starting analysis for module: {module_type}")
|
||||
|
||||
client = AnalysisClient(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
model=module_config.get("model", "gemini-1.5-flash")
|
||||
# 统一使用全局配置,不再从模块配置读取 provider 和 model
|
||||
# 使用全局 provider 和 model
|
||||
model = global_model or default_config.get("model", "gemini-1.5-flash")
|
||||
|
||||
# Create client using factory with global config
|
||||
client = create_analysis_client(
|
||||
provider=default_provider,
|
||||
config=default_config,
|
||||
model=model
|
||||
)
|
||||
|
||||
# Gather context from completed dependencies
|
||||
@ -468,6 +468,7 @@ async def get_financials(
|
||||
async def get_company_profile(
|
||||
ts_code: str,
|
||||
company_name: str = Query(None, description="Company name for better context"),
|
||||
config_manager: ConfigManager = Depends(get_config_manager),
|
||||
):
|
||||
"""
|
||||
Get company profile for a company using Gemini AI (non-streaming, single response)
|
||||
@ -477,19 +478,26 @@ async def get_company_profile(
|
||||
|
||||
logger.info(f"[API] Company profile requested for {ts_code}")
|
||||
|
||||
# Load config
|
||||
base_cfg = _load_json(BASE_CONFIG_PATH)
|
||||
llm_provider = base_cfg.get("llm", {}).get("provider", "gemini")
|
||||
llm_config = base_cfg.get("llm", {}).get(llm_provider, {})
|
||||
# Load LLM configuration using ConfigManager
|
||||
llm_config_result = await config_manager.get_llm_config()
|
||||
provider = llm_config_result["provider"]
|
||||
provider_config = llm_config_result["config"]
|
||||
|
||||
api_key = llm_config.get("api_key")
|
||||
base_url = llm_config.get("base_url") # Will be None if not set, handled by client
|
||||
# CompanyProfileClient only supports OpenAI-compatible APIs
|
||||
if provider == "alpha_engine":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Company profile generation does not support AlphaEngine provider. Please use OpenAI-compatible API."
|
||||
)
|
||||
|
||||
api_key = provider_config.get("api_key")
|
||||
base_url = provider_config.get("base_url")
|
||||
|
||||
if not api_key:
|
||||
logger.error(f"[API] API key for {llm_provider} not configured")
|
||||
logger.error(f"[API] API key for {provider} not configured")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"API key for {llm_provider} not configured."
|
||||
detail=f"API key for {provider} not configured."
|
||||
)
|
||||
|
||||
client = CompanyProfileClient(
|
||||
@ -573,6 +581,7 @@ async def generate_analysis(
|
||||
ts_code: str,
|
||||
analysis_type: str,
|
||||
company_name: str = Query(None, description="Company name for better context"),
|
||||
config_manager: ConfigManager = Depends(get_config_manager),
|
||||
):
|
||||
"""
|
||||
Generate analysis for a company using Gemini AI
|
||||
@ -591,37 +600,11 @@ async def generate_analysis(
|
||||
|
||||
logger.info(f"[API] Analysis requested for {ts_code}, type: {analysis_type}")
|
||||
|
||||
# Load config
|
||||
base_cfg = _load_json(BASE_CONFIG_PATH)
|
||||
llm_provider = base_cfg.get("llm", {}).get("provider", "gemini")
|
||||
llm_config = base_cfg.get("llm", {}).get(llm_provider, {})
|
||||
|
||||
api_key = llm_config.get("api_key")
|
||||
base_url = llm_config.get("base_url")
|
||||
|
||||
if not api_key:
|
||||
logger.error(f"[API] API key for {llm_provider} not configured")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"API key for {llm_provider} not configured."
|
||||
)
|
||||
|
||||
# Get analysis configuration
|
||||
analysis_cfg = get_analysis_config(analysis_type)
|
||||
if not analysis_cfg:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Analysis type '{analysis_type}' not found in configuration"
|
||||
)
|
||||
|
||||
model = analysis_cfg.get("model", "gemini-2.5-flash")
|
||||
prompt_template = analysis_cfg.get("prompt_template", "")
|
||||
|
||||
if not prompt_template:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Prompt template not found for analysis type '{analysis_type}'"
|
||||
)
|
||||
# Load LLM configuration using ConfigManager
|
||||
llm_config_result = await config_manager.get_llm_config()
|
||||
default_provider = llm_config_result["provider"]
|
||||
default_config = llm_config_result["config"]
|
||||
global_model = llm_config_result.get("model") # 全局模型配置
|
||||
|
||||
# Get company name from ts_code if not provided
|
||||
financial_data = None
|
||||
@ -656,8 +639,30 @@ async def generate_analysis(
|
||||
|
||||
logger.info(f"[API] Generating {analysis_type} for {company_name}")
|
||||
|
||||
# Initialize analysis client with configured model
|
||||
client = AnalysisClient(api_key=api_key, base_url=base_url, model=model)
|
||||
# Get analysis configuration for prompt template
|
||||
analysis_cfg = get_analysis_config(analysis_type)
|
||||
if not analysis_cfg:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Analysis type '{analysis_type}' not found in configuration"
|
||||
)
|
||||
|
||||
prompt_template = analysis_cfg.get("prompt_template", "")
|
||||
if not prompt_template:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Prompt template not found for analysis type '{analysis_type}'"
|
||||
)
|
||||
|
||||
# 统一使用全局配置,不再从模块配置读取 provider 和 model
|
||||
model = global_model or default_config.get("model", "gemini-1.5-flash")
|
||||
|
||||
# 统一使用全局配置创建客户端
|
||||
client = create_analysis_client(
|
||||
provider=default_provider,
|
||||
config=default_config,
|
||||
model=model
|
||||
)
|
||||
|
||||
# Prepare dependency context for single-module generation
|
||||
# If the requested module declares dependencies, generate them first and inject their outputs
|
||||
@ -701,12 +706,18 @@ async def generate_analysis(
|
||||
# Fallback: if cycle detected, just use any order
|
||||
order = list(all_required)
|
||||
|
||||
# Generate dependencies in order
|
||||
# Generate dependencies in order - 统一使用全局配置
|
||||
completed = {}
|
||||
for mod in order:
|
||||
cfg = modules_config.get(mod, {})
|
||||
dep_ctx = {d: completed.get(d, "") for d in (cfg.get("dependencies", []) or [])}
|
||||
dep_client = AnalysisClient(api_key=api_key, base_url=base_url, model=cfg.get("model", model))
|
||||
|
||||
# 统一使用全局配置,不再从模块配置读取
|
||||
dep_client = create_analysis_client(
|
||||
provider=default_provider,
|
||||
config=default_config,
|
||||
model=model
|
||||
)
|
||||
dep_result = await dep_client.generate_analysis(
|
||||
analysis_type=mod,
|
||||
company_name=company_name,
|
||||
@ -888,6 +899,7 @@ async def stream_analysis(
|
||||
ts_code: str,
|
||||
analysis_type: str,
|
||||
company_name: str = Query(None, description="Company name for better context"),
|
||||
config_manager: ConfigManager = Depends(get_config_manager),
|
||||
):
|
||||
"""
|
||||
Stream analysis content chunks for a given module using OpenAI-compatible streaming.
|
||||
@ -899,24 +911,19 @@ async def stream_analysis(
|
||||
|
||||
logger.info(f"[API] Streaming analysis requested for {ts_code}, type: {analysis_type}")
|
||||
|
||||
# Load config
|
||||
base_cfg = _load_json(BASE_CONFIG_PATH)
|
||||
llm_provider = base_cfg.get("llm", {}).get("provider", "gemini")
|
||||
llm_config = base_cfg.get("llm", {}).get(llm_provider, {})
|
||||
|
||||
api_key = llm_config.get("api_key")
|
||||
base_url = llm_config.get("base_url")
|
||||
|
||||
if not api_key:
|
||||
logger.error(f"[API] API key for {llm_provider} not configured")
|
||||
raise HTTPException(status_code=500, detail=f"API key for {llm_provider} not configured.")
|
||||
# Load LLM configuration using ConfigManager
|
||||
llm_config_result = await config_manager.get_llm_config()
|
||||
default_provider = llm_config_result["provider"]
|
||||
default_config = llm_config_result["config"]
|
||||
global_model = llm_config_result.get("model") # 全局模型配置
|
||||
|
||||
# Get analysis configuration
|
||||
analysis_cfg = get_analysis_config(analysis_type)
|
||||
if not analysis_cfg:
|
||||
raise HTTPException(status_code=404, detail=f"Analysis type '{analysis_type}' not found in configuration")
|
||||
|
||||
model = analysis_cfg.get("model", "gemini-2.5-flash")
|
||||
# 统一使用全局配置,不再从模块配置读取 provider 和 model
|
||||
model = global_model or default_config.get("model", "gemini-1.5-flash")
|
||||
prompt_template = analysis_cfg.get("prompt_template", "")
|
||||
if not prompt_template:
|
||||
raise HTTPException(status_code=500, detail=f"Prompt template not found for analysis type '{analysis_type}'")
|
||||
@ -972,7 +979,13 @@ async def stream_analysis(
|
||||
for mod in order:
|
||||
cfg = modules_config.get(mod, {})
|
||||
dep_ctx = {d: completed.get(d, "") for d in (cfg.get("dependencies", []) or [])}
|
||||
dep_client = AnalysisClient(api_key=api_key, base_url=base_url, model=cfg.get("model", model))
|
||||
|
||||
# 统一使用全局配置,不再从模块配置读取
|
||||
dep_client = create_analysis_client(
|
||||
provider=default_provider,
|
||||
config=default_config,
|
||||
model=model
|
||||
)
|
||||
dep_result = await dep_client.generate_analysis(
|
||||
analysis_type=mod,
|
||||
company_name=company_name,
|
||||
@ -986,7 +999,12 @@ async def stream_analysis(
|
||||
except Exception:
|
||||
context = {}
|
||||
|
||||
client = AnalysisClient(api_key=api_key, base_url=base_url, model=model)
|
||||
# 统一使用全局配置创建客户端
|
||||
client = create_analysis_client(
|
||||
provider=default_provider,
|
||||
config=default_config,
|
||||
model=model
|
||||
)
|
||||
|
||||
async def streamer():
|
||||
# Optional header line to help client-side UI
|
||||
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
import os
|
||||
import json
|
||||
from typing import Dict
|
||||
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
||||
from fastapi import APIRouter, BackgroundTasks, HTTPException, Depends
|
||||
|
||||
# Lazy loader for DataManager
|
||||
_dm = None
|
||||
@ -23,6 +23,9 @@ def get_dm():
|
||||
return _dm
|
||||
|
||||
from app.services.analysis_client import AnalysisClient, load_analysis_config
|
||||
from app.core.dependencies import get_config_manager
|
||||
from app.services.config_manager import ConfigManager
|
||||
from app.services.client_factory import create_analysis_client
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -40,7 +43,7 @@ def _load_json(path: str) -> Dict:
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
async def run_full_analysis(org_id: str):
|
||||
async def run_full_analysis(org_id: str, config_manager: ConfigManager = None):
|
||||
"""
|
||||
Asynchronous task to run a full analysis for a given stock.
|
||||
This function is market-agnostic and relies on DataManager.
|
||||
@ -48,16 +51,23 @@ async def run_full_analysis(org_id: str):
|
||||
logger.info(f"Starting full analysis task for {org_id}")
|
||||
|
||||
# 1. Load configurations
|
||||
base_cfg = _load_json(BASE_CONFIG_PATH)
|
||||
llm_provider = base_cfg.get("llm", {}).get("provider", "gemini")
|
||||
llm_config = base_cfg.get("llm", {}).get(llm_provider, {})
|
||||
if config_manager is None:
|
||||
# If called from background task, we need to create a new session
|
||||
from app.core.database import AsyncSessionLocal
|
||||
async with AsyncSessionLocal() as session:
|
||||
config_manager = ConfigManager(db_session=session)
|
||||
await _run_analysis_with_config(org_id, config_manager)
|
||||
else:
|
||||
await _run_analysis_with_config(org_id, config_manager)
|
||||
|
||||
api_key = llm_config.get("api_key")
|
||||
base_url = llm_config.get("base_url")
|
||||
|
||||
if not api_key:
|
||||
logger.error(f"API key for {llm_provider} not configured. Aborting analysis for {org_id}.")
|
||||
return
|
||||
async def _run_analysis_with_config(org_id: str, config_manager: ConfigManager):
|
||||
"""Internal function to run analysis with a ConfigManager instance"""
|
||||
# Load LLM configuration using ConfigManager
|
||||
llm_config_result = await config_manager.get_llm_config()
|
||||
default_provider = llm_config_result["provider"]
|
||||
default_config = llm_config_result["config"]
|
||||
global_model = llm_config_result.get("model") # 全局模型配置
|
||||
|
||||
analysis_config_full = load_analysis_config()
|
||||
modules_config = analysis_config_full.get("analysis_modules", {})
|
||||
@ -96,10 +106,15 @@ async def run_full_analysis(org_id: str):
|
||||
analysis_results = {}
|
||||
for module_type, module_config in modules_config.items():
|
||||
logger.info(f"Running analysis module: {module_type} for {org_id}")
|
||||
client = AnalysisClient(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
model=module_config.get("model", "gemini-1.5-flash")
|
||||
|
||||
# 统一使用全局配置,不再从模块配置读取 provider 和 model
|
||||
model = global_model or default_config.get("model", "gemini-1.5-flash")
|
||||
|
||||
# Create client using factory with global config
|
||||
client = create_analysis_client(
|
||||
provider=default_provider,
|
||||
config=default_config,
|
||||
model=model
|
||||
)
|
||||
|
||||
# Simplified context: use results from all previously completed modules
|
||||
@ -128,7 +143,7 @@ async def run_full_analysis(org_id: str):
|
||||
|
||||
|
||||
@router.post("/{market}/{org_id}/reports/generate")
|
||||
async def trigger_report_generation(market: str, org_id: str, background_tasks: BackgroundTasks):
|
||||
async def trigger_report_generation(market: str, org_id: str, background_tasks: BackgroundTasks, config_manager: ConfigManager = Depends(get_config_manager)):
|
||||
"""
|
||||
Triggers a background task to generate a full financial report.
|
||||
This endpoint is now market-agnostic.
|
||||
@ -137,7 +152,8 @@ async def trigger_report_generation(market: str, org_id: str, background_tasks:
|
||||
|
||||
# TODO: Create a report record in the database with "generating" status here.
|
||||
|
||||
background_tasks.add_task(run_full_analysis, org_id)
|
||||
# Pass config_manager to the background task
|
||||
background_tasks.add_task(run_full_analysis, org_id, config_manager)
|
||||
|
||||
logger.info(f"Queued analysis task for {org_id}.")
|
||||
return {"queued": True, "market": market, "org_id": org_id}
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
"""
|
||||
Configuration-related Pydantic schemas
|
||||
"""
|
||||
from typing import Dict, Optional, Any
|
||||
from typing import Dict, Optional, Any, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class DatabaseConfig(BaseModel):
|
||||
@ -11,17 +11,30 @@ class NewApiConfig(BaseModel):
|
||||
api_key: str = Field(..., description="New API Key")
|
||||
base_url: Optional[str] = None
|
||||
|
||||
class AlphaEngineConfig(BaseModel):
|
||||
api_url: str = Field(..., description="AlphaEngine API URL")
|
||||
api_key: str = Field(..., description="AlphaEngine API Key")
|
||||
token: str = Field(..., description="AlphaEngine Token")
|
||||
user_id: int = Field(999041, description="User ID")
|
||||
model: str = Field("deepseek-r1", description="Model name")
|
||||
using_indicator: bool = Field(True, description="Whether to use indicators")
|
||||
start_time: str = Field("2024-01-01", description="Start time for data query")
|
||||
doc_show_type: List[str] = Field(["A001", "A002", "A003", "A004"], description="Document types")
|
||||
simple_tracking: bool = Field(True, description="Whether to enable simple tracking")
|
||||
|
||||
class DataSourceConfig(BaseModel):
|
||||
api_key: str = Field(..., description="数据源API Key")
|
||||
|
||||
class ConfigResponse(BaseModel):
|
||||
database: DatabaseConfig
|
||||
new_api: NewApiConfig
|
||||
alpha_engine: Optional[AlphaEngineConfig] = None
|
||||
data_sources: Dict[str, DataSourceConfig]
|
||||
|
||||
class ConfigUpdateRequest(BaseModel):
|
||||
database: Optional[DatabaseConfig] = None
|
||||
new_api: Optional[NewApiConfig] = None
|
||||
alpha_engine: Optional[AlphaEngineConfig] = None
|
||||
data_sources: Optional[Dict[str, DataSourceConfig]] = None
|
||||
|
||||
class ConfigTestRequest(BaseModel):
|
||||
|
||||
260
backend/app/services/alpha_engine_client.py
Normal file
260
backend/app/services/alpha_engine_client.py
Normal file
@ -0,0 +1,260 @@
|
||||
"""
|
||||
AlphaEngine Client for investment Q&A API
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from typing import Dict, Optional, AsyncGenerator
|
||||
import httpx
|
||||
from requests.exceptions import ChunkedEncodingError
|
||||
|
||||
|
||||
class AlphaEngineClient:
|
||||
"""Client for AlphaEngine investment Q&A API"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_url: str,
|
||||
api_key: str,
|
||||
token: str,
|
||||
user_id: int = 999041,
|
||||
model: str = "deepseek-r1",
|
||||
using_indicator: bool = True,
|
||||
start_time: str = "2024-01-01",
|
||||
doc_show_type: list = None,
|
||||
simple_tracking: bool = True
|
||||
):
|
||||
"""
|
||||
Initialize AlphaEngine client
|
||||
|
||||
Args:
|
||||
api_url: API endpoint URL
|
||||
api_key: X-API-KEY for authentication
|
||||
token: Token for authentication
|
||||
user_id: User ID
|
||||
model: Model name (default: deepseek-r1)
|
||||
using_indicator: Whether to use indicators
|
||||
start_time: Start time for data query
|
||||
doc_show_type: Document types to show (default: ["A001", "A002", "A003", "A004"])
|
||||
simple_tracking: Whether to enable simple tracking
|
||||
"""
|
||||
self.api_url = api_url.rstrip('/')
|
||||
self.api_key = api_key
|
||||
self.token = token
|
||||
self.user_id = user_id
|
||||
self.model = model
|
||||
self.using_indicator = using_indicator
|
||||
self.start_time = start_time
|
||||
self.doc_show_type = doc_show_type or ["A001", "A002", "A003", "A004"]
|
||||
self.simple_tracking = simple_tracking
|
||||
|
||||
async def generate_analysis(
|
||||
self,
|
||||
analysis_type: str,
|
||||
company_name: str,
|
||||
ts_code: str,
|
||||
prompt_template: str,
|
||||
financial_data: Optional[Dict] = None,
|
||||
context: Optional[Dict] = None
|
||||
) -> Dict:
|
||||
"""
|
||||
Generate analysis using AlphaEngine API (non-streaming)
|
||||
|
||||
Args:
|
||||
analysis_type: Type of analysis
|
||||
company_name: Company name
|
||||
ts_code: Stock code
|
||||
prompt_template: Prompt template with placeholders
|
||||
financial_data: Optional financial data for context
|
||||
context: Optional dictionary with results from previous analyses
|
||||
|
||||
Returns:
|
||||
Dict with analysis content and metadata
|
||||
"""
|
||||
start_time = time.perf_counter_ns()
|
||||
|
||||
# Build prompt from template
|
||||
prompt = self._build_prompt(
|
||||
prompt_template,
|
||||
company_name,
|
||||
ts_code,
|
||||
financial_data,
|
||||
context
|
||||
)
|
||||
|
||||
# Call AlphaEngine API
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=300.0) as client:
|
||||
headers = {
|
||||
'token': self.token,
|
||||
'X-API-KEY': self.api_key,
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
payload = {
|
||||
"msg": prompt,
|
||||
"history": [],
|
||||
"user_id": self.user_id,
|
||||
"model": self.model,
|
||||
"using_indicator": self.using_indicator,
|
||||
"start_time": self.start_time,
|
||||
"doc_show_type": self.doc_show_type,
|
||||
"simple_tracking": self.simple_tracking
|
||||
}
|
||||
|
||||
response = await client.post(
|
||||
f"{self.api_url}/api/v3/finchat",
|
||||
json=payload,
|
||||
headers=headers
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"AlphaEngine API error: HTTP {response.status_code} - {response.text}")
|
||||
|
||||
result_text = response.text
|
||||
|
||||
# Parse response to extract final answer
|
||||
final_answer_match = re.findall(r'\{"id":"_final","content":"(.*?)"}', result_text)
|
||||
final_answer = final_answer_match[0] if final_answer_match else result_text
|
||||
|
||||
# Extract COT if available
|
||||
cot_match = re.findall(r'\{"id":"_cot","content":"(.*?)"}', result_text)
|
||||
cot = "".join(cot_match) if cot_match else ""
|
||||
|
||||
# Extract tracking documents if available
|
||||
tracking_match = re.findall(r'\{"id":"tracking_documents","content":\s*(\[[^]]*])}', result_text)
|
||||
tracking_docs = json.loads(tracking_match[0]) if tracking_match else []
|
||||
|
||||
elapsed_ms = int((time.perf_counter_ns() - start_time) / 1_000_000)
|
||||
|
||||
return {
|
||||
"content": final_answer,
|
||||
"model": self.model,
|
||||
"tokens": {
|
||||
"prompt_tokens": 0, # AlphaEngine doesn't provide token usage
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
},
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"success": True,
|
||||
"analysis_type": analysis_type,
|
||||
"cot": cot,
|
||||
"tracking_documents": tracking_docs,
|
||||
}
|
||||
except Exception as e:
|
||||
elapsed_ms = int((time.perf_counter_ns() - start_time) / 1_000_000)
|
||||
return {
|
||||
"content": "",
|
||||
"model": self.model,
|
||||
"tokens": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"analysis_type": analysis_type,
|
||||
}
|
||||
|
||||
async def generate_analysis_stream(
|
||||
self,
|
||||
analysis_type: str,
|
||||
company_name: str,
|
||||
ts_code: str,
|
||||
prompt_template: str,
|
||||
financial_data: Optional[Dict] = None,
|
||||
context: Optional[Dict] = None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Yield analysis content chunks using AlphaEngine streaming API
|
||||
|
||||
Yields plain text chunks as they arrive.
|
||||
"""
|
||||
# Build prompt
|
||||
prompt = self._build_prompt(
|
||||
prompt_template,
|
||||
company_name,
|
||||
ts_code,
|
||||
financial_data,
|
||||
context,
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=300.0) as client:
|
||||
headers = {
|
||||
'token': self.token,
|
||||
'X-API-KEY': self.api_key,
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
payload = {
|
||||
"msg": prompt,
|
||||
"history": [],
|
||||
"user_id": self.user_id,
|
||||
"model": self.model,
|
||||
"using_indicator": self.using_indicator,
|
||||
"start_time": self.start_time,
|
||||
"doc_show_type": self.doc_show_type,
|
||||
"simple_tracking": self.simple_tracking
|
||||
}
|
||||
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{self.api_url}/api/v3/finchat",
|
||||
json=payload,
|
||||
headers=headers
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
yield f"\n\n[错误] HTTP {response.status_code}: {response.text}\n"
|
||||
return
|
||||
|
||||
async for chunk in response.aiter_bytes(chunk_size=128):
|
||||
try:
|
||||
chunk_text = chunk.decode('utf-8', 'ignore')
|
||||
yield chunk_text
|
||||
except UnicodeDecodeError:
|
||||
chunk_text = chunk.decode('utf-8', 'replace')
|
||||
yield chunk_text
|
||||
except Exception as e:
|
||||
yield f"\n\n[错误] {type(e).__name__}: {str(e)}\n"
|
||||
|
||||
def _build_prompt(
|
||||
self,
|
||||
prompt_template: str,
|
||||
company_name: str,
|
||||
ts_code: str,
|
||||
financial_data: Optional[Dict] = None,
|
||||
context: Optional[Dict] = None
|
||||
) -> str:
|
||||
"""Build prompt from template by replacing placeholders"""
|
||||
import string
|
||||
|
||||
# Start with base placeholders
|
||||
placeholders = {
|
||||
"company_name": company_name,
|
||||
"ts_code": ts_code,
|
||||
}
|
||||
|
||||
# Add financial data if provided
|
||||
financial_data_str = ""
|
||||
if financial_data:
|
||||
try:
|
||||
financial_data_str = json.dumps(financial_data, ensure_ascii=False, indent=2)
|
||||
except Exception:
|
||||
financial_data_str = str(financial_data)
|
||||
placeholders["financial_data"] = financial_data_str
|
||||
|
||||
# Add context from previous analysis steps
|
||||
if context:
|
||||
placeholders.update(context)
|
||||
|
||||
# Replace placeholders in template
|
||||
class SafeFormatter(string.Formatter):
|
||||
def get_value(self, key, args, kwargs):
|
||||
if isinstance(key, str):
|
||||
return kwargs.get(key, f"{{{key}}}")
|
||||
else:
|
||||
return super().get_value(key, args, kwargs)
|
||||
|
||||
formatter = SafeFormatter()
|
||||
prompt = formatter.format(prompt_template, **placeholders)
|
||||
|
||||
return prompt
|
||||
|
||||
60
backend/app/services/client_factory.py
Normal file
60
backend/app/services/client_factory.py
Normal file
@ -0,0 +1,60 @@
|
||||
"""
|
||||
Unified Analysis Client Factory
|
||||
Creates appropriate client based on provider type
|
||||
"""
|
||||
from typing import Dict, Optional
|
||||
from app.services.analysis_client import AnalysisClient
|
||||
from app.services.alpha_engine_client import AlphaEngineClient
|
||||
|
||||
|
||||
def create_analysis_client(
|
||||
provider: str,
|
||||
config: Dict,
|
||||
model: str = None
|
||||
):
|
||||
"""
|
||||
Create an analysis client based on provider type
|
||||
|
||||
Args:
|
||||
provider: Provider type ("openai", "gemini", "new_api", "alpha_engine")
|
||||
config: Configuration dictionary containing provider-specific settings
|
||||
model: Model name (optional, may be overridden by config)
|
||||
|
||||
Returns:
|
||||
Client instance (AnalysisClient or AlphaEngineClient)
|
||||
"""
|
||||
if provider == "alpha_engine":
|
||||
# AlphaEngine specific configuration
|
||||
api_url = config.get("api_url", "")
|
||||
api_key = config.get("api_key", "")
|
||||
token = config.get("token", "")
|
||||
user_id = config.get("user_id", 999041)
|
||||
model_name = model or config.get("model", "deepseek-r1")
|
||||
using_indicator = config.get("using_indicator", True)
|
||||
start_time = config.get("start_time", "2024-01-01")
|
||||
doc_show_type = config.get("doc_show_type", ["A001", "A002", "A003", "A004"])
|
||||
simple_tracking = config.get("simple_tracking", True)
|
||||
|
||||
return AlphaEngineClient(
|
||||
api_url=api_url,
|
||||
api_key=api_key,
|
||||
token=token,
|
||||
user_id=user_id,
|
||||
model=model_name,
|
||||
using_indicator=using_indicator,
|
||||
start_time=start_time,
|
||||
doc_show_type=doc_show_type,
|
||||
simple_tracking=simple_tracking
|
||||
)
|
||||
else:
|
||||
# OpenAI-compatible API (openai, gemini, new_api)
|
||||
api_key = config.get("api_key", "")
|
||||
base_url = config.get("base_url", "")
|
||||
model_name = model or config.get("model", "gemini-1.5-flash")
|
||||
|
||||
return AnalysisClient(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
model=model_name
|
||||
)
|
||||
|
||||
@ -68,14 +68,14 @@ class CompanyProfileClient:
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
def generate_profile_stream(
|
||||
async def generate_profile_stream(
|
||||
self,
|
||||
company_name: str,
|
||||
ts_code: str,
|
||||
financial_data: Optional[Dict] = None
|
||||
):
|
||||
"""
|
||||
Generate company profile using Gemini API with streaming
|
||||
Generate company profile using OpenAI-compatible streaming API
|
||||
|
||||
Args:
|
||||
company_name: Company name
|
||||
@ -85,40 +85,31 @@ class CompanyProfileClient:
|
||||
Yields:
|
||||
Chunks of generated content
|
||||
"""
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logger.info(f"[CompanyProfile] Starting stream generation for {company_name} ({ts_code})")
|
||||
|
||||
# Build prompt
|
||||
prompt = self._build_prompt(company_name, ts_code, financial_data)
|
||||
logger.info(f"[CompanyProfile] Prompt built, length: {len(prompt)} chars")
|
||||
|
||||
# Call Gemini API with streaming
|
||||
# Call OpenAI-compatible API with streaming
|
||||
try:
|
||||
logger.info("[CompanyProfile] Calling Gemini API with stream=True")
|
||||
# Generate streaming response (sync call, but yields chunks)
|
||||
response_stream = self.model.generate_content(prompt, stream=True)
|
||||
logger.info("[CompanyProfile] Gemini API stream object created")
|
||||
|
||||
chunk_count = 0
|
||||
# Stream chunks
|
||||
logger.info("[CompanyProfile] Starting to iterate response stream")
|
||||
for chunk in response_stream:
|
||||
logger.info(f"[CompanyProfile] Received chunk from Gemini, has text: {hasattr(chunk, 'text')}")
|
||||
if hasattr(chunk, 'text') and chunk.text:
|
||||
chunk_count += 1
|
||||
text_len = len(chunk.text)
|
||||
logger.info(f"[CompanyProfile] Chunk {chunk_count}: {text_len} chars")
|
||||
yield chunk.text
|
||||
else:
|
||||
logger.warning(f"[CompanyProfile] Chunk has no text attribute or empty, chunk: {chunk}")
|
||||
|
||||
logger.info(f"[CompanyProfile] Stream iteration completed. Total chunks: {chunk_count}")
|
||||
stream = await self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# The SDK yields events with incremental deltas
|
||||
async for event in stream:
|
||||
try:
|
||||
choice = event.choices[0] if getattr(event, "choices", None) else None
|
||||
delta = getattr(choice, "delta", None) if choice is not None else None
|
||||
content = getattr(delta, "content", None) if delta is not None else None
|
||||
if content:
|
||||
yield content
|
||||
except Exception:
|
||||
# Best-effort: ignore malformed chunks
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"[CompanyProfile] Error during streaming: {type(e).__name__}: {str(e)}", exc_info=True)
|
||||
yield f"\n\n---\n\n**错误**: {type(e).__name__}: {str(e)}"
|
||||
# Emit error message to the stream so the client can surface it
|
||||
yield f"\n\n[错误] {type(e).__name__}: {str(e)}\n"
|
||||
|
||||
def _build_prompt(self, company_name: str, ts_code: str, financial_data: Optional[Dict] = None) -> str:
|
||||
"""Build prompt for company profile generation"""
|
||||
|
||||
@ -12,7 +12,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.models.system_config import SystemConfig
|
||||
from app.schemas.config import ConfigResponse, ConfigUpdateRequest, DatabaseConfig, NewApiConfig, DataSourceConfig, ConfigTestResponse
|
||||
from app.schemas.config import ConfigResponse, ConfigUpdateRequest, DatabaseConfig, NewApiConfig, AlphaEngineConfig, DataSourceConfig, ConfigTestResponse
|
||||
|
||||
class ConfigManager:
|
||||
"""Manages system configuration by merging a static JSON file with dynamic settings from the database."""
|
||||
@ -72,25 +72,114 @@ class ConfigManager:
|
||||
# 兼容两种位置:优先使用 new_api,其次回退到 llm.new_api
|
||||
new_api_src = merged_config.get("new_api") or merged_config.get("llm", {}).get("new_api", {})
|
||||
|
||||
# 获取 alpha_engine 配置
|
||||
alpha_engine_src = merged_config.get("alpha_engine") or merged_config.get("llm", {}).get("alpha_engine")
|
||||
alpha_engine_config = None
|
||||
if alpha_engine_src:
|
||||
alpha_engine_config = AlphaEngineConfig(**alpha_engine_src)
|
||||
|
||||
return ConfigResponse(
|
||||
database=DatabaseConfig(**merged_config.get("database", {})),
|
||||
new_api=NewApiConfig(**(new_api_src or {})),
|
||||
alpha_engine=alpha_engine_config,
|
||||
data_sources={
|
||||
k: DataSourceConfig(**v)
|
||||
for k, v in merged_config.get("data_sources", {}).items()
|
||||
}
|
||||
)
|
||||
|
||||
async def get_llm_config(self, provider: str = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Get LLM configuration for a specific provider
|
||||
|
||||
Args:
|
||||
provider: Provider name (e.g., "new_api", "gemini", "alpha_engine")
|
||||
If None, uses the configured provider from config
|
||||
|
||||
Returns:
|
||||
Dictionary with provider configuration and provider name
|
||||
"""
|
||||
base_config = self._load_base_config_from_file()
|
||||
db_config = await self._load_dynamic_config_from_db()
|
||||
|
||||
merged_config = self._merge_configs(base_config, db_config)
|
||||
|
||||
llm_config = merged_config.get("llm", {})
|
||||
|
||||
# Determine provider
|
||||
if not provider:
|
||||
provider = llm_config.get("provider", "new_api")
|
||||
|
||||
# Get provider-specific config
|
||||
provider_config = llm_config.get(provider, {})
|
||||
|
||||
# Get global model from provider config if available
|
||||
global_model = provider_config.get("model")
|
||||
|
||||
return {
|
||||
"provider": provider,
|
||||
"config": provider_config,
|
||||
"model": global_model # 返回全局模型配置
|
||||
}
|
||||
|
||||
def _filter_empty_values(self, config_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Remove empty strings and None values from config dict, but keep 0 and False."""
|
||||
filtered = {}
|
||||
for key, value in config_dict.items():
|
||||
if isinstance(value, dict):
|
||||
filtered_value = self._filter_empty_values(value)
|
||||
if filtered_value: # Only add if dict is not empty
|
||||
filtered[key] = filtered_value
|
||||
elif value is not None and value != "":
|
||||
filtered[key] = value
|
||||
return filtered
|
||||
|
||||
async def update_config(self, config_update: ConfigUpdateRequest) -> ConfigResponse:
|
||||
"""Updates configuration in the database and returns the new merged config."""
|
||||
try:
|
||||
update_dict = config_update.dict(exclude_unset=True)
|
||||
|
||||
# 过滤空值
|
||||
update_dict = self._filter_empty_values(update_dict)
|
||||
|
||||
# 验证配置数据
|
||||
self._validate_config_data(update_dict)
|
||||
|
||||
# 处理 LLM 相关配置:需要保存到 llm 配置下
|
||||
llm_updates = {}
|
||||
if "new_api" in update_dict:
|
||||
llm_updates["new_api"] = update_dict.pop("new_api")
|
||||
if "alpha_engine" in update_dict:
|
||||
llm_updates["alpha_engine"] = update_dict.pop("alpha_engine")
|
||||
|
||||
# 保存 LLM 配置
|
||||
if llm_updates:
|
||||
result = await self.db.execute(
|
||||
select(SystemConfig).where(SystemConfig.config_key == "llm")
|
||||
)
|
||||
existing_llm_config = result.scalar_one_or_none()
|
||||
|
||||
if existing_llm_config:
|
||||
if isinstance(existing_llm_config.config_value, dict):
|
||||
merged_llm = self._merge_configs(existing_llm_config.config_value, llm_updates)
|
||||
existing_llm_config.config_value = merged_llm
|
||||
else:
|
||||
existing_llm_config.config_value = llm_updates
|
||||
else:
|
||||
# 从文件加载基础配置,然后合并
|
||||
base_config = self._load_base_config_from_file()
|
||||
base_llm = base_config.get("llm", {})
|
||||
merged_llm = self._merge_configs(base_llm, llm_updates)
|
||||
new_llm_config = SystemConfig(config_key="llm", config_value=merged_llm)
|
||||
self.db.add(new_llm_config)
|
||||
|
||||
# 保存其他配置(database, data_sources 等)
|
||||
for key, value in update_dict.items():
|
||||
existing_config = await self.db.get(SystemConfig, key)
|
||||
result = await self.db.execute(
|
||||
select(SystemConfig).where(SystemConfig.config_key == key)
|
||||
)
|
||||
existing_config = result.scalar_one_or_none()
|
||||
|
||||
if existing_config:
|
||||
# Merge with existing DB value before updating
|
||||
if isinstance(existing_config.config_value, dict) and isinstance(value, dict):
|
||||
@ -112,23 +201,32 @@ class ConfigManager:
|
||||
"""Validate configuration data before saving."""
|
||||
if "database" in config_data:
|
||||
db_config = config_data["database"]
|
||||
if "url" in db_config:
|
||||
if "url" in db_config and db_config["url"]:
|
||||
url = db_config["url"]
|
||||
if not url.startswith(("postgresql://", "postgresql+asyncpg://")):
|
||||
raise ValueError("数据库URL必须以 postgresql:// 或 postgresql+asyncpg:// 开头")
|
||||
|
||||
if "new_api" in config_data:
|
||||
new_api_config = config_data["new_api"]
|
||||
if "api_key" in new_api_config and len(new_api_config["api_key"]) < 10:
|
||||
if "api_key" in new_api_config and new_api_config["api_key"] and len(new_api_config["api_key"]) < 10:
|
||||
raise ValueError("New API Key长度不能少于10个字符")
|
||||
if "base_url" in new_api_config and new_api_config["base_url"]:
|
||||
base_url = new_api_config["base_url"]
|
||||
if not base_url.startswith(("http://", "https://")):
|
||||
raise ValueError("New API Base URL必须以 http:// 或 https:// 开头")
|
||||
|
||||
if "alpha_engine" in config_data:
|
||||
alpha_engine_config = config_data["alpha_engine"]
|
||||
if "api_key" in alpha_engine_config and alpha_engine_config["api_key"] and len(alpha_engine_config["api_key"]) < 5:
|
||||
raise ValueError("AlphaEngine API Key长度不能少于5个字符")
|
||||
if "api_url" in alpha_engine_config and alpha_engine_config["api_url"]:
|
||||
api_url = alpha_engine_config["api_url"]
|
||||
if not api_url.startswith(("http://", "https://")):
|
||||
raise ValueError("AlphaEngine API URL必须以 http:// 或 https:// 开头")
|
||||
|
||||
if "data_sources" in config_data:
|
||||
for source_name, source_config in config_data["data_sources"].items():
|
||||
if "api_key" in source_config and len(source_config["api_key"]) < 10:
|
||||
if "api_key" in source_config and source_config["api_key"] and len(source_config["api_key"]) < 10:
|
||||
raise ValueError(f"{source_name} API Key长度不能少于10个字符")
|
||||
|
||||
async def test_config(self, config_type: str, config_data: Dict[str, Any]) -> ConfigTestResponse:
|
||||
@ -142,6 +240,8 @@ class ConfigManager:
|
||||
return await self._test_tushare(config_data)
|
||||
elif config_type == "finnhub":
|
||||
return await self._test_finnhub(config_data)
|
||||
elif config_type == "alpha_engine":
|
||||
return await self._test_alpha_engine(config_data)
|
||||
else:
|
||||
return ConfigTestResponse(
|
||||
success=False,
|
||||
@ -302,3 +402,57 @@ class ConfigManager:
|
||||
success=False,
|
||||
message=f"Finnhub API连接失败: {str(e)}"
|
||||
)
|
||||
|
||||
async def _test_alpha_engine(self, config_data: Dict[str, Any]) -> ConfigTestResponse:
|
||||
"""Test AlphaEngine API connection."""
|
||||
api_url = config_data.get("api_url")
|
||||
api_key = config_data.get("api_key")
|
||||
token = config_data.get("token")
|
||||
|
||||
if not api_url or not api_key or not token:
|
||||
return ConfigTestResponse(
|
||||
success=False,
|
||||
message="AlphaEngine API URL、API Key和Token均不能为空"
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
headers = {
|
||||
'token': token,
|
||||
'X-API-KEY': api_key,
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
# 发送一个简单的测试请求
|
||||
payload = {
|
||||
"msg": "测试连接",
|
||||
"history": [],
|
||||
"user_id": config_data.get("user_id", 999041),
|
||||
"model": config_data.get("model", "deepseek-r1"),
|
||||
"using_indicator": config_data.get("using_indicator", True),
|
||||
"start_time": config_data.get("start_time", "2024-01-01"),
|
||||
"doc_show_type": config_data.get("doc_show_type", ["A001", "A002", "A003", "A004"]),
|
||||
"simple_tracking": config_data.get("simple_tracking", True)
|
||||
}
|
||||
|
||||
response = await client.post(
|
||||
f"{api_url.rstrip('/')}/api/v3/finchat",
|
||||
json=payload,
|
||||
headers=headers
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return ConfigTestResponse(
|
||||
success=True,
|
||||
message="AlphaEngine API连接成功"
|
||||
)
|
||||
else:
|
||||
return ConfigTestResponse(
|
||||
success=False,
|
||||
message=f"AlphaEngine API测试失败: HTTP {response.status_code} - {response.text[:200]}"
|
||||
)
|
||||
except Exception as e:
|
||||
return ConfigTestResponse(
|
||||
success=False,
|
||||
message=f"AlphaEngine API连接失败: {str(e)}"
|
||||
)
|
||||
|
||||
File diff suppressed because one or more lines are too long
@ -1,6 +1,6 @@
|
||||
{
|
||||
"llm": {
|
||||
"provider": "new_api",
|
||||
"provider": "alpha_engine",
|
||||
"gemini": {
|
||||
"base_url": "",
|
||||
"api_key": "YOUR_GEMINI_API_KEY"
|
||||
@ -8,6 +8,22 @@
|
||||
"new_api": {
|
||||
"base_url": "http://192.168.3.214:3000/v1",
|
||||
"api_key": "sk-DdTTQ5fdU1aFW6gnYxSNYDgFsVQg938zUcmY4vaB7oPtcNs7"
|
||||
},
|
||||
"alpha_engine": {
|
||||
"api_url": "http://api-ai-prod.valuesimplex.tech",
|
||||
"api_key": "api@shangjian!",
|
||||
"token": "9b5c0b6a5e1e4e8fioouiouqiuioasaz",
|
||||
"user_id": 999041,
|
||||
"model": "deepseek-r1",
|
||||
"using_indicator": true,
|
||||
"start_time": "2024-01-01",
|
||||
"doc_show_type": [
|
||||
"A001",
|
||||
"A002",
|
||||
"A003",
|
||||
"A004"
|
||||
],
|
||||
"simple_tracking": true
|
||||
}
|
||||
},
|
||||
"data_sources": {
|
||||
|
||||
BIN
docs/AlphaEngine/1 熵简科技-投研问答API技术文档.pdf
Normal file
BIN
docs/AlphaEngine/1 熵简科技-投研问答API技术文档.pdf
Normal file
Binary file not shown.
99
docs/AlphaEngine/2 test_investment_qa_v3(1).py
Normal file
99
docs/AlphaEngine/2 test_investment_qa_v3(1).py
Normal file
@ -0,0 +1,99 @@
|
||||
# coding:utf-8
|
||||
import json
|
||||
import re
|
||||
|
||||
import requests
|
||||
from requests.exceptions import ChunkedEncodingError
|
||||
|
||||
# 请求地址
|
||||
qa_url = "http://api-ai-prod.valuesimplex.tech/api/v3/finchat"
|
||||
# 熵简提供的x-api-key
|
||||
api_key = "api@shangjian!"
|
||||
token = "9b5c0b6a5e1e4e8fioouiouqiuioasaz"
|
||||
user_id = 999041
|
||||
|
||||
|
||||
def ask(question, user_id):
|
||||
# 设置请求头
|
||||
headers = {
|
||||
'token': token,
|
||||
'X-API-KEY': api_key,
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
# 构造请求体
|
||||
payload = json.dumps({
|
||||
"msg": question,
|
||||
# 历史问答,没有则为空List
|
||||
"history": [],
|
||||
"user_id": user_id,
|
||||
"model": "deepseek-r1", # 默认值 不用改
|
||||
"using_indicator": True, # 是否用指标
|
||||
"start_time": "2024-01-01", # 开始时间
|
||||
"doc_show_type": ["A001", "A002", "A003", "A004"], # 文档类型
|
||||
"simple_tracking": simple_tracking # 是否简单溯源
|
||||
})
|
||||
print(f"******开始提问:[{question}]")
|
||||
|
||||
# 发送请求
|
||||
response = requests.request("POST", qa_url, data=payload, headers=headers, stream=True)
|
||||
|
||||
qa_result = ''
|
||||
|
||||
# 判断请求是否成功
|
||||
if response.status_code == 200:
|
||||
if stream_enabled:
|
||||
try:
|
||||
for chunk in response.iter_content(chunk_size=128):
|
||||
try:
|
||||
chunk_event = chunk.decode('utf-8', 'ignore')
|
||||
except UnicodeDecodeError as e:
|
||||
# 自定义处理解码错误,例如替换无法解码的部分
|
||||
chunk_event = chunk.decode('utf-8', 'replace')
|
||||
print(f"Decoding error occurred: {e}")
|
||||
qa_result += chunk_event
|
||||
print(f"\033[1;32m" + chunk_event)
|
||||
except ChunkedEncodingError:
|
||||
print("Stream ended prematurely. Handling gracefully.")
|
||||
|
||||
else:
|
||||
# 获取响应内容
|
||||
qa_result = response.content
|
||||
# 将响应内容解码为utf-8格式
|
||||
qa_result = qa_result.decode('utf-8')
|
||||
else:
|
||||
print(f"Failed to get stream data. Status code: {response.status_code}")
|
||||
# 返回结果
|
||||
|
||||
return qa_result
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 问题内容
|
||||
question = '科大讯飞业绩怎么样?'
|
||||
# 关闭吐字模式
|
||||
stream_enabled = True
|
||||
# 开启简单溯源
|
||||
simple_tracking = True
|
||||
# 调用函数进行问答
|
||||
result = ask(question, user_id)
|
||||
|
||||
# 仅打印最终问答结果
|
||||
print("**************COT**************")
|
||||
cot_list = re.findall(r'\{"id":"_cot","content":"(.*?)"}', result)
|
||||
cot = "".join(cot_list)
|
||||
print(cot)
|
||||
print("**********************************")
|
||||
|
||||
# 仅打印最终问答结果
|
||||
print("**************最终答案**************")
|
||||
print(re.findall(r'\{"id":"_final","content":"(.*?)"}', result)[0])
|
||||
# print(result['answer'])
|
||||
print("**********************************")
|
||||
|
||||
if simple_tracking:
|
||||
print("**************溯源文件**************")
|
||||
source_file = re.findall(r'\{"id":"tracking_documents","content":\s*(\[[^]]*])}', result)
|
||||
if source_file and source_file.__len__() > 0:
|
||||
print(source_file[0])
|
||||
print("**********************************")
|
||||
@ -275,3 +275,4 @@ A:
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@ -27,13 +27,20 @@ export default function ConfigPage() {
|
||||
const [dbUrl, setDbUrl] = useState('');
|
||||
const [newApiApiKey, setNewApiApiKey] = useState('');
|
||||
const [newApiBaseUrl, setNewApiBaseUrl] = useState('');
|
||||
const [alphaEngineApiUrl, setAlphaEngineApiUrl] = useState('');
|
||||
const [alphaEngineApiKey, setAlphaEngineApiKey] = useState('');
|
||||
const [alphaEngineToken, setAlphaEngineToken] = useState('');
|
||||
const [alphaEngineUserId, setAlphaEngineUserId] = useState('');
|
||||
const [tushareApiKey, setTushareApiKey] = useState('');
|
||||
const [finnhubApiKey, setFinnhubApiKey] = useState('');
|
||||
|
||||
// 分析配置的本地状态
|
||||
// 全局 LLM 配置
|
||||
const [llmProvider, setLlmProvider] = useState('');
|
||||
const [llmModel, setLlmModel] = useState('');
|
||||
|
||||
// 分析配置的本地状态(移除 provider 和 model)
|
||||
const [localAnalysisConfig, setLocalAnalysisConfig] = useState<Record<string, {
|
||||
name: string;
|
||||
model: string;
|
||||
prompt_template: string;
|
||||
dependencies?: string[];
|
||||
}>>({});
|
||||
@ -52,12 +59,40 @@ export default function ConfigPage() {
|
||||
// 初始化分析配置的本地状态
|
||||
useEffect(() => {
|
||||
if (analysisConfig?.analysis_modules) {
|
||||
setLocalAnalysisConfig(analysisConfig.analysis_modules);
|
||||
// 移除每个模块的 provider 和 model 字段
|
||||
const cleanedConfig: typeof localAnalysisConfig = {};
|
||||
Object.entries(analysisConfig.analysis_modules).forEach(([key, value]: [string, any]) => {
|
||||
cleanedConfig[key] = {
|
||||
name: value.name || '',
|
||||
prompt_template: value.prompt_template || '',
|
||||
dependencies: value.dependencies || []
|
||||
};
|
||||
});
|
||||
setLocalAnalysisConfig(cleanedConfig);
|
||||
}
|
||||
}, [analysisConfig]);
|
||||
|
||||
// 更新分析配置中的某个字段
|
||||
const updateAnalysisField = (type: string, field: 'name' | 'model' | 'prompt_template', value: string) => {
|
||||
// 初始化全局 LLM 配置(从后端获取)
|
||||
useEffect(() => {
|
||||
const loadLlmConfig = async () => {
|
||||
try {
|
||||
const response = await fetch('/api/config/llm');
|
||||
if (response.ok) {
|
||||
const data = await response.json();
|
||||
setLlmProvider(data.provider || '');
|
||||
// 从 provider 配置中获取 model
|
||||
const providerConfig = data.config || {};
|
||||
setLlmModel(providerConfig.model || '');
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Failed to load LLM config:', e);
|
||||
}
|
||||
};
|
||||
loadLlmConfig();
|
||||
}, []);
|
||||
|
||||
// 更新分析配置中的某个字段(移除 provider 和 model)
|
||||
const updateAnalysisField = (type: string, field: 'name' | 'prompt_template', value: string) => {
|
||||
setLocalAnalysisConfig(prev => ({
|
||||
...prev,
|
||||
[type]: {
|
||||
@ -161,6 +196,15 @@ export default function ConfigPage() {
|
||||
};
|
||||
}
|
||||
|
||||
if (alphaEngineApiUrl || alphaEngineApiKey || alphaEngineToken) {
|
||||
newConfig.alpha_engine = {
|
||||
api_url: alphaEngineApiUrl || config?.alpha_engine?.api_url || '',
|
||||
api_key: alphaEngineApiKey || config?.alpha_engine?.api_key || '',
|
||||
token: alphaEngineToken || config?.alpha_engine?.token || '',
|
||||
user_id: alphaEngineUserId ? parseInt(alphaEngineUserId) : (config?.alpha_engine?.user_id || 999041),
|
||||
};
|
||||
}
|
||||
|
||||
if (tushareApiKey || finnhubApiKey) {
|
||||
newConfig.data_sources = {
|
||||
...config?.data_sources,
|
||||
@ -216,10 +260,23 @@ export default function ConfigPage() {
|
||||
handleTest('finnhub', { api_key: finnhubApiKey || config?.data_sources?.finnhub?.api_key });
|
||||
};
|
||||
|
||||
const handleTestAlphaEngine = () => {
|
||||
handleTest('alpha_engine', {
|
||||
api_url: alphaEngineApiUrl || config?.alpha_engine?.api_url,
|
||||
api_key: alphaEngineApiKey || config?.alpha_engine?.api_key,
|
||||
token: alphaEngineToken || config?.alpha_engine?.token,
|
||||
user_id: alphaEngineUserId ? parseInt(alphaEngineUserId) : (config?.alpha_engine?.user_id || 999041)
|
||||
});
|
||||
};
|
||||
|
||||
const handleReset = () => {
|
||||
setDbUrl('');
|
||||
setNewApiApiKey('');
|
||||
setNewApiBaseUrl('');
|
||||
setAlphaEngineApiUrl('');
|
||||
setAlphaEngineApiKey('');
|
||||
setAlphaEngineToken('');
|
||||
setAlphaEngineUserId('');
|
||||
setTushareApiKey('');
|
||||
setFinnhubApiKey('');
|
||||
setTestResults({});
|
||||
@ -345,10 +402,91 @@ export default function ConfigPage() {
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle>AI 服务配置</CardTitle>
|
||||
<CardDescription>New API 设置 (兼容 OpenAI 格式)</CardDescription>
|
||||
<CardDescription>配置大模型服务提供商和全局设置</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent className="space-y-4">
|
||||
<CardContent className="space-y-6">
|
||||
{/* 全局 LLM 配置 */}
|
||||
<div className="space-y-4 p-4 bg-muted/50 rounded-lg">
|
||||
<h3 className="font-semibold">全局大模型设置</h3>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
所有分析模块将统一使用以下配置
|
||||
</p>
|
||||
<div className="space-y-4">
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="llm-provider">大模型提供商</Label>
|
||||
<select
|
||||
id="llm-provider"
|
||||
value={llmProvider}
|
||||
onChange={(e) => setLlmProvider(e.target.value)}
|
||||
className="flex h-10 w-full rounded-md border border-input bg-background px-3 py-2 text-sm ring-offset-background file:border-0 file:bg-transparent file:text-sm file:font-medium placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50"
|
||||
>
|
||||
<option value="">请选择提供商</option>
|
||||
<option value="new_api">New API (OpenAI 兼容)</option>
|
||||
<option value="gemini">Gemini</option>
|
||||
<option value="alpha_engine">AlphaEngine</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="llm-model">模型名称</Label>
|
||||
<Input
|
||||
id="llm-model"
|
||||
type="text"
|
||||
value={llmModel}
|
||||
onChange={(e) => setLlmModel(e.target.value)}
|
||||
placeholder="例如: gemini-1.5-pro, deepseek-r1"
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
根据选择的提供商输入对应的模型名称
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<Button
|
||||
onClick={async () => {
|
||||
if (!llmProvider) {
|
||||
setSaveMessage('请先选择大模型提供商');
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const response = await fetch('/api/config/llm', {
|
||||
method: 'PUT',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
provider: llmProvider,
|
||||
model: llmModel || undefined
|
||||
})
|
||||
});
|
||||
if (response.ok) {
|
||||
setSaveMessage('全局 LLM 配置保存成功!');
|
||||
} else {
|
||||
const data = await response.json();
|
||||
setSaveMessage(`保存失败: ${data.detail || '未知错误'}`);
|
||||
}
|
||||
} catch (e: any) {
|
||||
setSaveMessage(`保存失败: ${e.message}`);
|
||||
}
|
||||
setTimeout(() => setSaveMessage(''), 5000);
|
||||
}}
|
||||
variant="outline"
|
||||
>
|
||||
保存全局 LLM 配置
|
||||
</Button>
|
||||
{saveMessage && saveMessage.includes('LLM') && (
|
||||
<Badge variant={saveMessage.includes('成功') ? 'default' : 'destructive'}>
|
||||
{saveMessage}
|
||||
</Badge>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
|
||||
<div className="space-y-4">
|
||||
<div>
|
||||
<Label className="text-base font-medium">New API (OpenAI 兼容)</Label>
|
||||
<p className="text-sm text-muted-foreground mb-2">兼容 OpenAI API 格式的服务</p>
|
||||
<div className="space-y-2">
|
||||
<div>
|
||||
<Label htmlFor="new-api-key">API Key</Label>
|
||||
<div className="flex gap-2">
|
||||
<Input
|
||||
@ -364,13 +502,13 @@ export default function ConfigPage() {
|
||||
</Button>
|
||||
</div>
|
||||
{testResults.new_api && (
|
||||
<Badge variant={testResults.new_api.success ? 'default' : 'destructive'}>
|
||||
<Badge variant={testResults.new_api.success ? 'default' : 'destructive'} className="mt-2">
|
||||
{testResults.new_api.message}
|
||||
</Badge>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<div>
|
||||
<Label htmlFor="new-api-base-url">Base URL</Label>
|
||||
<Input
|
||||
id="new-api-base-url"
|
||||
@ -381,6 +519,72 @@ export default function ConfigPage() {
|
||||
className="flex-1"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
|
||||
<div>
|
||||
<Label className="text-base font-medium">AlphaEngine</Label>
|
||||
<p className="text-sm text-muted-foreground mb-2">熵简科技投研问答 API</p>
|
||||
<div className="space-y-2">
|
||||
<div>
|
||||
<Label htmlFor="alpha-engine-api-url">API URL</Label>
|
||||
<Input
|
||||
id="alpha-engine-api-url"
|
||||
type="text"
|
||||
value={alphaEngineApiUrl}
|
||||
onChange={(e) => setAlphaEngineApiUrl(e.target.value)}
|
||||
placeholder="例如: http://api-ai-prod.valuesimplex.tech"
|
||||
className="flex-1"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<Label htmlFor="alpha-engine-api-key">API Key</Label>
|
||||
<div className="flex gap-2">
|
||||
<Input
|
||||
id="alpha-engine-api-key"
|
||||
type="password"
|
||||
value={alphaEngineApiKey}
|
||||
onChange={(e) => setAlphaEngineApiKey(e.target.value)}
|
||||
placeholder="留空表示保持当前值"
|
||||
className="flex-1"
|
||||
/>
|
||||
<Button onClick={handleTestAlphaEngine} variant="outline">
|
||||
测试
|
||||
</Button>
|
||||
</div>
|
||||
{testResults.alpha_engine && (
|
||||
<Badge variant={testResults.alpha_engine.success ? 'default' : 'destructive'} className="mt-2">
|
||||
{testResults.alpha_engine.message}
|
||||
</Badge>
|
||||
)}
|
||||
</div>
|
||||
<div>
|
||||
<Label htmlFor="alpha-engine-token">Token</Label>
|
||||
<Input
|
||||
id="alpha-engine-token"
|
||||
type="password"
|
||||
value={alphaEngineToken}
|
||||
onChange={(e) => setAlphaEngineToken(e.target.value)}
|
||||
placeholder="留空表示保持当前值"
|
||||
className="flex-1"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<Label htmlFor="alpha-engine-user-id">User ID</Label>
|
||||
<Input
|
||||
id="alpha-engine-user-id"
|
||||
type="number"
|
||||
value={alphaEngineUserId}
|
||||
onChange={(e) => setAlphaEngineUserId(e.target.value)}
|
||||
placeholder="默认: 999041"
|
||||
className="flex-1"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
</TabsContent>
|
||||
@ -470,19 +674,6 @@ export default function ConfigPage() {
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor={`${type}-model`}>模型名称</Label>
|
||||
<Input
|
||||
id={`${type}-model`}
|
||||
value={config.model || ''}
|
||||
onChange={(e) => updateAnalysisField(type, 'model', e.target.value)}
|
||||
placeholder="例如: gemini-1.5-pro"
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
在 AI 服务中配置的模型名称
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<Label>模块依赖</Label>
|
||||
<div className="grid grid-cols-2 sm:grid-cols-3 md:grid-cols-4 gap-2 rounded-lg border p-4">
|
||||
|
||||
@ -21,3 +21,4 @@
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@ -110,6 +110,9 @@ export default function ReportPage() {
|
||||
const [saving, setSaving] = useState(false)
|
||||
const [saveMsg, setSaveMsg] = useState<string | null>(null)
|
||||
|
||||
// TradingView 显示控制
|
||||
const [showTradingView, setShowTradingView] = useState(false)
|
||||
|
||||
const saveReport = async () => {
|
||||
try {
|
||||
setSaving(true)
|
||||
@ -155,6 +158,9 @@ export default function ReportPage() {
|
||||
return;
|
||||
}
|
||||
|
||||
// 标记已触发分析
|
||||
fullAnalysisTriggeredRef.current = true;
|
||||
|
||||
// 初始化/重置状态,准备顺序执行
|
||||
stopRequestedRef.current = false;
|
||||
abortControllerRef.current?.abort();
|
||||
@ -182,12 +188,7 @@ export default function ReportPage() {
|
||||
setManualRunKey((k) => k + 1);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (financials && !fullAnalysisTriggeredRef.current) {
|
||||
fullAnalysisTriggeredRef.current = true;
|
||||
runFullAnalysis();
|
||||
}
|
||||
}, [financials]);
|
||||
// 移除自动开始分析的逻辑,改为手动触发
|
||||
|
||||
// 计算完成比例
|
||||
const completionProgress = useMemo(() => {
|
||||
@ -796,23 +797,53 @@ export default function ReportPage() {
|
||||
</TabsList>
|
||||
|
||||
<TabsContent value="chart" className="space-y-4">
|
||||
<div className="flex items-center justify-between mb-4">
|
||||
<h2 className="text-lg font-medium">股价图表(来自 TradingView)</h2>
|
||||
<Button
|
||||
onClick={() => setShowTradingView(!showTradingView)}
|
||||
variant={showTradingView ? "outline" : "default"}
|
||||
>
|
||||
{showTradingView ? '隐藏图表' : '显示图表'}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{showTradingView ? (
|
||||
<>
|
||||
<div className="flex items-center gap-3 text-sm mb-4">
|
||||
<CheckCircle className="size-4 text-green-600" />
|
||||
<div className="text-muted-foreground">
|
||||
实时股价图表 - {unifiedSymbol}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<TradingViewWidget
|
||||
symbol={unifiedSymbol}
|
||||
market={marketParam}
|
||||
height={500}
|
||||
/>
|
||||
</>
|
||||
) : (
|
||||
<div className="flex items-center justify-center p-8 border border-dashed rounded-lg">
|
||||
<div className="text-center space-y-2">
|
||||
<p className="text-muted-foreground">点击"显示图表"按钮加载 TradingView 股价图表</p>
|
||||
<p className="text-xs text-muted-foreground">图表数据来自 TradingView</p>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</TabsContent>
|
||||
|
||||
<TabsContent value="financial" className="space-y-4">
|
||||
<div className="flex items-center justify-between">
|
||||
<h2 className="text-lg font-medium">财务数据</h2>
|
||||
{financials && !fullAnalysisTriggeredRef.current && analysisConfig?.analysis_modules && (
|
||||
<Button
|
||||
onClick={runFullAnalysis}
|
||||
disabled={isAnalysisRunningRef.current}
|
||||
className="ml-auto"
|
||||
>
|
||||
开始分析
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex items-center gap-3 text-sm">
|
||||
{isLoading ? (
|
||||
<Spinner className="size-4" />
|
||||
@ -834,6 +865,25 @@ export default function ReportPage() {
|
||||
</div>
|
||||
)}
|
||||
|
||||
{financials && !fullAnalysisTriggeredRef.current && analysisConfig?.analysis_modules && (
|
||||
<div className="bg-blue-50 border border-blue-200 rounded-lg p-4">
|
||||
<div className="flex items-start gap-3">
|
||||
<div className="flex-1">
|
||||
<h3 className="font-medium text-blue-900 mb-1">准备开始分析</h3>
|
||||
<p className="text-sm text-blue-700">
|
||||
财务数据已加载完成。点击"开始分析"按钮启动大模型分析流程。
|
||||
</p>
|
||||
</div>
|
||||
<Button
|
||||
onClick={runFullAnalysis}
|
||||
disabled={isAnalysisRunningRef.current}
|
||||
>
|
||||
开始分析
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
|
||||
|
||||
{financials && (
|
||||
|
||||
@ -43,3 +43,4 @@ if (process.env.NODE_ENV !== 'production') globalForPrisma.prisma = prisma
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user