feat: 添加分析模块配置和分析功能,更新财务数据处理逻辑
This commit is contained in:
parent
6508589027
commit
e0aa61b8c4
@ -28,11 +28,14 @@ async def test_config(
|
|||||||
config_manager: ConfigManager = Depends(get_config_manager)
|
config_manager: ConfigManager = Depends(get_config_manager)
|
||||||
):
|
):
|
||||||
"""Test a specific configuration (e.g., database connection)."""
|
"""Test a specific configuration (e.g., database connection)."""
|
||||||
# The test logic will be implemented in a subsequent step inside the ConfigManager
|
try:
|
||||||
# For now, we return a placeholder response.
|
test_result = await config_manager.test_config(
|
||||||
# test_result = await config_manager.test_config(
|
test_request.config_type,
|
||||||
# test_request.config_type,
|
test_request.config_data
|
||||||
# test_request.config_data
|
)
|
||||||
# )
|
return test_result
|
||||||
# return test_result
|
except Exception as e:
|
||||||
raise HTTPException(status_code=501, detail="Not Implemented")
|
return ConfigTestResponse(
|
||||||
|
success=False,
|
||||||
|
message=f"测试失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|||||||
@ -4,7 +4,7 @@ API router for financial data (Tushare for China market)
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone, timedelta
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Query
|
from fastapi import APIRouter, HTTPException, Query
|
||||||
@ -12,9 +12,18 @@ from fastapi.responses import StreamingResponse
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.schemas.financial import BatchFinancialDataResponse, FinancialConfigResponse, FinancialMeta, StepRecord, CompanyProfileResponse
|
from app.schemas.financial import (
|
||||||
|
BatchFinancialDataResponse,
|
||||||
|
FinancialConfigResponse,
|
||||||
|
FinancialMeta,
|
||||||
|
StepRecord,
|
||||||
|
CompanyProfileResponse,
|
||||||
|
AnalysisResponse,
|
||||||
|
AnalysisConfigResponse
|
||||||
|
)
|
||||||
from app.services.tushare_client import TushareClient
|
from app.services.tushare_client import TushareClient
|
||||||
from app.services.company_profile_client import CompanyProfileClient
|
from app.services.company_profile_client import CompanyProfileClient
|
||||||
|
from app.services.analysis_client import AnalysisClient, load_analysis_config, get_analysis_config
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@ -23,6 +32,7 @@ router = APIRouter()
|
|||||||
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||||
FINANCIAL_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "financial-tushare.json")
|
FINANCIAL_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "financial-tushare.json")
|
||||||
BASE_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "config.json")
|
BASE_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "config.json")
|
||||||
|
ANALYSIS_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "analysis-config.json")
|
||||||
|
|
||||||
|
|
||||||
def _load_json(path: str) -> Dict:
|
def _load_json(path: str) -> Dict:
|
||||||
@ -86,14 +96,16 @@ async def get_china_financials(
|
|||||||
series: Dict[str, List[Dict]] = {}
|
series: Dict[str, List[Dict]] = {}
|
||||||
|
|
||||||
# Helper to store year-value pairs while keeping most recent per year
|
# Helper to store year-value pairs while keeping most recent per year
|
||||||
def _merge_year_value(key: str, year: str, value):
|
def _merge_year_value(key: str, year: str, value, month: int = None):
|
||||||
arr = series.setdefault(key, [])
|
arr = series.setdefault(key, [])
|
||||||
# upsert by year
|
# upsert by year
|
||||||
for item in arr:
|
for item in arr:
|
||||||
if item["year"] == year:
|
if item["year"] == year:
|
||||||
item["value"] = value
|
item["value"] = value
|
||||||
|
if month is not None:
|
||||||
|
item["month"] = month
|
||||||
return
|
return
|
||||||
arr.append({"year": year, "value": value})
|
arr.append({"year": year, "value": value, "month": month})
|
||||||
|
|
||||||
# Query each API group we care
|
# Query each API group we care
|
||||||
errors: Dict[str, str] = {}
|
errors: Dict[str, str] = {}
|
||||||
@ -107,39 +119,96 @@ async def get_china_financials(
|
|||||||
current_action = step.name
|
current_action = step.name
|
||||||
if not metrics:
|
if not metrics:
|
||||||
continue
|
continue
|
||||||
api_name = metrics[0].get("api") or group_name
|
|
||||||
fields = list({m.get("tushareParam") for m in metrics if m.get("tushareParam")})
|
# 按 API 分组 metrics(处理 unknown 组中有多个不同 API 的情况)
|
||||||
|
api_groups_dict: Dict[str, List[Dict]] = {}
|
||||||
|
for metric in metrics:
|
||||||
|
api = metric.get("api") or group_name
|
||||||
|
if api: # 跳过空 API
|
||||||
|
if api not in api_groups_dict:
|
||||||
|
api_groups_dict[api] = []
|
||||||
|
api_groups_dict[api].append(metric)
|
||||||
|
|
||||||
|
# 对每个 API 分别处理
|
||||||
|
for api_name, api_metrics in api_groups_dict.items():
|
||||||
|
fields = [m.get("tushareParam") for m in api_metrics if m.get("tushareParam")]
|
||||||
if not fields:
|
if not fields:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
date_field = "end_date" if group_name in ("fina_indicator", "income", "balancesheet", "cashflow") else "trade_date"
|
date_field = "end_date" if group_name in ("fina_indicator", "income", "balancesheet", "cashflow") else "trade_date"
|
||||||
|
|
||||||
|
# 构建 API 参数
|
||||||
|
params = {"ts_code": ts_code, "limit": 5000}
|
||||||
|
|
||||||
|
# 对于需要日期范围的 API(如 stk_holdernumber),添加日期参数
|
||||||
|
if api_name == "stk_holdernumber":
|
||||||
|
# 计算日期范围:从 years 年前到现在
|
||||||
|
end_date = datetime.now().strftime("%Y%m%d")
|
||||||
|
start_date = (datetime.now() - timedelta(days=years * 365)).strftime("%Y%m%d")
|
||||||
|
params["start_date"] = start_date
|
||||||
|
params["end_date"] = end_date
|
||||||
|
# stk_holdernumber 返回的日期字段通常是 end_date
|
||||||
|
date_field = "end_date"
|
||||||
|
|
||||||
|
# 对于非时间序列 API(如 stock_company),标记为静态数据
|
||||||
|
is_static_data = api_name == "stock_company"
|
||||||
|
|
||||||
|
# 构建 fields 字符串:包含日期字段和所有需要的指标字段
|
||||||
|
# 确保日期字段存在,因为我们需要用它来确定年份
|
||||||
|
fields_list = list(fields)
|
||||||
|
if date_field not in fields_list:
|
||||||
|
fields_list.insert(0, date_field)
|
||||||
|
# 对于 fina_indicator 等 API,通常还需要 ts_code 和 ann_date
|
||||||
|
if api_name in ("fina_indicator", "income", "balancesheet", "cashflow"):
|
||||||
|
for req_field in ["ts_code", "ann_date"]:
|
||||||
|
if req_field not in fields_list:
|
||||||
|
fields_list.insert(0, req_field)
|
||||||
|
fields_str = ",".join(fields_list)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data_rows = await client.query(api_name=api_name, params={"ts_code": ts_code, "limit": 5000}, fields=None)
|
data_rows = await client.query(api_name=api_name, params=params, fields=fields_str)
|
||||||
api_calls_total += 1
|
api_calls_total += 1
|
||||||
api_calls_by_group[group_name] = api_calls_by_group.get(group_name, 0) + 1
|
api_calls_by_group[group_name] = api_calls_by_group.get(group_name, 0) + 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
step.status = "error"
|
# 记录错误但继续处理其他 API
|
||||||
step.error = str(e)
|
error_key = f"{group_name}_{api_name}"
|
||||||
step.end_ts = datetime.now(timezone.utc).isoformat()
|
errors[error_key] = str(e)
|
||||||
step.duration_ms = int((time.perf_counter_ns() - started) / 1_000_000)
|
|
||||||
errors[group_name] = str(e)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
tmp: Dict[str, Dict] = {}
|
tmp: Dict[str, Dict] = {}
|
||||||
|
current_year = datetime.now().strftime("%Y")
|
||||||
|
|
||||||
for row in data_rows:
|
for row in data_rows:
|
||||||
|
if is_static_data:
|
||||||
|
# 对于静态数据(如 stock_company),使用当前年份
|
||||||
|
# 只处理第一行数据,因为静态数据通常只有一行
|
||||||
|
if current_year not in tmp:
|
||||||
|
year = current_year
|
||||||
|
month = None
|
||||||
|
tmp[year] = row
|
||||||
|
tmp[year]['_month'] = month
|
||||||
|
# 跳过后续行
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# 对于时间序列数据,按日期字段处理
|
||||||
date_val = row.get(date_field)
|
date_val = row.get(date_field)
|
||||||
if not date_val:
|
if not date_val:
|
||||||
continue
|
continue
|
||||||
year = str(date_val)[:4]
|
year = str(date_val)[:4]
|
||||||
|
month = int(str(date_val)[4:6]) if len(str(date_val)) >= 6 else None
|
||||||
existing = tmp.get(year)
|
existing = tmp.get(year)
|
||||||
if existing is None or str(row.get(date_field)) > str(existing.get(date_field)):
|
if existing is None or str(row.get(date_field)) > str(existing.get(date_field)):
|
||||||
tmp[year] = row
|
tmp[year] = row
|
||||||
for metric in metrics:
|
tmp[year]['_month'] = month
|
||||||
|
|
||||||
|
for metric in api_metrics:
|
||||||
key = metric.get("tushareParam")
|
key = metric.get("tushareParam")
|
||||||
if not key:
|
if not key:
|
||||||
continue
|
continue
|
||||||
for year, row in tmp.items():
|
for year, row in tmp.items():
|
||||||
_merge_year_value(key, year, row.get(key))
|
month = row.get('_month')
|
||||||
|
_merge_year_value(key, year, row.get(key), month)
|
||||||
|
|
||||||
step.status = "done"
|
step.status = "done"
|
||||||
step.end_ts = datetime.now(timezone.utc).isoformat()
|
step.end_ts = datetime.now(timezone.utc).isoformat()
|
||||||
step.duration_ms = int((time.perf_counter_ns() - started) / 1_000_000)
|
step.duration_ms = int((time.perf_counter_ns() - started) / 1_000_000)
|
||||||
@ -247,3 +316,197 @@ async def get_company_profile(
|
|||||||
success=result.get("success", False),
|
success=result.get("success", False),
|
||||||
error=result.get("error")
|
error=result.get("error")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/analysis-config", response_model=AnalysisConfigResponse)
|
||||||
|
async def get_analysis_config_endpoint():
|
||||||
|
"""Get analysis configuration"""
|
||||||
|
config = load_analysis_config()
|
||||||
|
return AnalysisConfigResponse(analysis_modules=config.get("analysis_modules", {}))
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/analysis-config", response_model=AnalysisConfigResponse)
|
||||||
|
async def update_analysis_config_endpoint(analysis_config: AnalysisConfigResponse):
|
||||||
|
"""Update analysis configuration"""
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 保存到文件
|
||||||
|
config_data = {
|
||||||
|
"analysis_modules": analysis_config.analysis_modules
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(ANALYSIS_CONFIG_PATH, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(config_data, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
logger.info(f"[API] Analysis config updated successfully")
|
||||||
|
return AnalysisConfigResponse(analysis_modules=analysis_config.analysis_modules)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[API] Failed to update analysis config: {e}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"Failed to update analysis config: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/china/{ts_code}/analysis/{analysis_type}", response_model=AnalysisResponse)
|
||||||
|
async def generate_analysis(
|
||||||
|
ts_code: str,
|
||||||
|
analysis_type: str,
|
||||||
|
company_name: str = Query(None, description="Company name for better context"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Generate analysis for a company using Gemini AI
|
||||||
|
Supported analysis types:
|
||||||
|
- fundamental_analysis (基本面分析)
|
||||||
|
- bull_case (看涨分析)
|
||||||
|
- bear_case (看跌分析)
|
||||||
|
- market_analysis (市场分析)
|
||||||
|
- news_analysis (新闻分析)
|
||||||
|
- trading_analysis (交易分析)
|
||||||
|
- insider_institutional (内部人与机构动向分析)
|
||||||
|
- final_conclusion (最终结论)
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
logger.info(f"[API] Analysis requested for {ts_code}, type: {analysis_type}")
|
||||||
|
|
||||||
|
# Load config
|
||||||
|
base_cfg = _load_json(BASE_CONFIG_PATH)
|
||||||
|
gemini_cfg = base_cfg.get("llm", {}).get("gemini", {})
|
||||||
|
api_key = gemini_cfg.get("api_key")
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
logger.error("[API] Gemini API key not configured")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail="Gemini API key not configured. Set config.json llm.gemini.api_key"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get analysis configuration
|
||||||
|
analysis_cfg = get_analysis_config(analysis_type)
|
||||||
|
if not analysis_cfg:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail=f"Analysis type '{analysis_type}' not found in configuration"
|
||||||
|
)
|
||||||
|
|
||||||
|
model = analysis_cfg.get("model", "gemini-2.5-flash")
|
||||||
|
prompt_template = analysis_cfg.get("prompt_template", "")
|
||||||
|
|
||||||
|
if not prompt_template:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"Prompt template not found for analysis type '{analysis_type}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get company name from ts_code if not provided
|
||||||
|
financial_data = None
|
||||||
|
if not company_name:
|
||||||
|
logger.info(f"[API] Fetching company name and financial data for {ts_code}")
|
||||||
|
try:
|
||||||
|
token = (
|
||||||
|
os.environ.get("TUSHARE_TOKEN")
|
||||||
|
or settings.TUSHARE_TOKEN
|
||||||
|
or base_cfg.get("data_sources", {}).get("tushare", {}).get("api_key")
|
||||||
|
)
|
||||||
|
if token:
|
||||||
|
tushare_client = TushareClient(token=token)
|
||||||
|
basic_data = await tushare_client.query(api_name="stock_basic", params={"ts_code": ts_code}, fields="ts_code,name")
|
||||||
|
if basic_data and len(basic_data) > 0:
|
||||||
|
company_name = basic_data[0].get("name", ts_code)
|
||||||
|
logger.info(f"[API] Got company name: {company_name}")
|
||||||
|
|
||||||
|
# Try to get financial data for context
|
||||||
|
try:
|
||||||
|
fin_cfg = _load_json(FINANCIAL_CONFIG_PATH)
|
||||||
|
api_groups = fin_cfg.get("api_groups", {})
|
||||||
|
|
||||||
|
# Get financial data summary for context
|
||||||
|
series: Dict[str, List[Dict]] = {}
|
||||||
|
for group_name, metrics in api_groups.items():
|
||||||
|
if not metrics:
|
||||||
|
continue
|
||||||
|
api_groups_dict: Dict[str, List[Dict]] = {}
|
||||||
|
for metric in metrics:
|
||||||
|
api = metric.get("api") or group_name
|
||||||
|
if api:
|
||||||
|
if api not in api_groups_dict:
|
||||||
|
api_groups_dict[api] = []
|
||||||
|
api_groups_dict[api].append(metric)
|
||||||
|
|
||||||
|
for api_name, api_metrics in api_groups_dict.items():
|
||||||
|
fields = [m.get("tushareParam") for m in api_metrics if m.get("tushareParam")]
|
||||||
|
if not fields:
|
||||||
|
continue
|
||||||
|
|
||||||
|
date_field = "end_date" if group_name in ("fina_indicator", "income", "balancesheet", "cashflow") else "trade_date"
|
||||||
|
|
||||||
|
params = {"ts_code": ts_code, "limit": 500}
|
||||||
|
fields_list = list(fields)
|
||||||
|
if date_field not in fields_list:
|
||||||
|
fields_list.insert(0, date_field)
|
||||||
|
if api_name in ("fina_indicator", "income", "balancesheet", "cashflow"):
|
||||||
|
for req_field in ["ts_code", "ann_date"]:
|
||||||
|
if req_field not in fields_list:
|
||||||
|
fields_list.insert(0, req_field)
|
||||||
|
fields_str = ",".join(fields_list)
|
||||||
|
|
||||||
|
try:
|
||||||
|
data_rows = await tushare_client.query(api_name=api_name, params=params, fields=fields_str)
|
||||||
|
if data_rows:
|
||||||
|
# Get latest year's data
|
||||||
|
latest_row = data_rows[0] if data_rows else {}
|
||||||
|
for metric in api_metrics:
|
||||||
|
key = metric.get("tushareParam")
|
||||||
|
if key and key in latest_row:
|
||||||
|
if key not in series:
|
||||||
|
series[key] = []
|
||||||
|
series[key].append({
|
||||||
|
"year": latest_row.get(date_field, "")[:4] if latest_row.get(date_field) else str(datetime.now().year),
|
||||||
|
"value": latest_row.get(key)
|
||||||
|
})
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
financial_data = {"series": series}
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[API] Failed to get financial data: {e}")
|
||||||
|
financial_data = None
|
||||||
|
else:
|
||||||
|
company_name = ts_code
|
||||||
|
else:
|
||||||
|
company_name = ts_code
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[API] Failed to get company name: {e}")
|
||||||
|
company_name = ts_code
|
||||||
|
|
||||||
|
logger.info(f"[API] Generating {analysis_type} for {company_name}")
|
||||||
|
|
||||||
|
# Initialize analysis client with configured model
|
||||||
|
client = AnalysisClient(api_key=api_key, model=model)
|
||||||
|
|
||||||
|
# Generate analysis
|
||||||
|
result = await client.generate_analysis(
|
||||||
|
analysis_type=analysis_type,
|
||||||
|
company_name=company_name,
|
||||||
|
ts_code=ts_code,
|
||||||
|
prompt_template=prompt_template,
|
||||||
|
financial_data=financial_data
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"[API] Analysis generation completed, success={result.get('success')}")
|
||||||
|
|
||||||
|
return AnalysisResponse(
|
||||||
|
ts_code=ts_code,
|
||||||
|
company_name=company_name,
|
||||||
|
analysis_type=analysis_type,
|
||||||
|
content=result.get("content", ""),
|
||||||
|
model=result.get("model", model),
|
||||||
|
tokens=result.get("tokens", {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}),
|
||||||
|
elapsed_ms=result.get("elapsed_ms", 0),
|
||||||
|
success=result.get("success", False),
|
||||||
|
error=result.get("error")
|
||||||
|
)
|
||||||
|
|||||||
@ -8,6 +8,7 @@ from pydantic import BaseModel
|
|||||||
class YearDataPoint(BaseModel):
|
class YearDataPoint(BaseModel):
|
||||||
year: str
|
year: str
|
||||||
value: Optional[float]
|
value: Optional[float]
|
||||||
|
month: Optional[int] = None # 月份信息,用于确定季度
|
||||||
|
|
||||||
|
|
||||||
class StepRecord(BaseModel):
|
class StepRecord(BaseModel):
|
||||||
@ -55,3 +56,19 @@ class CompanyProfileResponse(BaseModel):
|
|||||||
elapsed_ms: int
|
elapsed_ms: int
|
||||||
success: bool = True
|
success: bool = True
|
||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class AnalysisResponse(BaseModel):
|
||||||
|
ts_code: str
|
||||||
|
company_name: Optional[str] = None
|
||||||
|
analysis_type: str
|
||||||
|
content: str
|
||||||
|
model: str
|
||||||
|
tokens: TokenUsage
|
||||||
|
elapsed_ms: int
|
||||||
|
success: bool = True
|
||||||
|
error: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class AnalysisConfigResponse(BaseModel):
|
||||||
|
analysis_modules: Dict[str, Dict]
|
||||||
|
|||||||
136
backend/app/services/analysis_client.py
Normal file
136
backend/app/services/analysis_client.py
Normal file
@ -0,0 +1,136 @@
|
|||||||
|
"""
|
||||||
|
Generic Analysis Client for various analysis types using Gemini API
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from typing import Dict, Optional
|
||||||
|
import google.generativeai as genai
|
||||||
|
|
||||||
|
|
||||||
|
class AnalysisClient:
|
||||||
|
"""Generic client for generating various types of analysis using Gemini API"""
|
||||||
|
|
||||||
|
def __init__(self, api_key: str, model: str = "gemini-2.5-flash"):
|
||||||
|
"""Initialize Gemini client with API key and model"""
|
||||||
|
genai.configure(api_key=api_key)
|
||||||
|
self.model_name = model
|
||||||
|
self.model = genai.GenerativeModel(model)
|
||||||
|
|
||||||
|
async def generate_analysis(
|
||||||
|
self,
|
||||||
|
analysis_type: str,
|
||||||
|
company_name: str,
|
||||||
|
ts_code: str,
|
||||||
|
prompt_template: str,
|
||||||
|
financial_data: Optional[Dict] = None
|
||||||
|
) -> Dict:
|
||||||
|
"""
|
||||||
|
Generate analysis using Gemini API (non-streaming)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
analysis_type: Type of analysis (e.g., "fundamental_analysis")
|
||||||
|
company_name: Company name
|
||||||
|
ts_code: Stock code
|
||||||
|
prompt_template: Prompt template with placeholders {company_name}, {ts_code}, {financial_data}
|
||||||
|
financial_data: Optional financial data for context
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with analysis content and metadata
|
||||||
|
"""
|
||||||
|
start_time = time.perf_counter_ns()
|
||||||
|
|
||||||
|
# Build prompt from template
|
||||||
|
prompt = self._build_prompt(
|
||||||
|
prompt_template,
|
||||||
|
company_name,
|
||||||
|
ts_code,
|
||||||
|
financial_data
|
||||||
|
)
|
||||||
|
|
||||||
|
# Call Gemini API (using sync API in async context)
|
||||||
|
try:
|
||||||
|
import asyncio
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
response = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: self.model.generate_content(prompt)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get token usage
|
||||||
|
usage_metadata = response.usage_metadata if hasattr(response, 'usage_metadata') else None
|
||||||
|
|
||||||
|
elapsed_ms = int((time.perf_counter_ns() - start_time) / 1_000_000)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"content": response.text,
|
||||||
|
"model": self.model_name,
|
||||||
|
"tokens": {
|
||||||
|
"prompt_tokens": usage_metadata.prompt_token_count if usage_metadata else 0,
|
||||||
|
"completion_tokens": usage_metadata.candidates_token_count if usage_metadata else 0,
|
||||||
|
"total_tokens": usage_metadata.total_token_count if usage_metadata else 0,
|
||||||
|
} if usage_metadata else {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||||
|
"elapsed_ms": elapsed_ms,
|
||||||
|
"success": True,
|
||||||
|
"analysis_type": analysis_type,
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
elapsed_ms = int((time.perf_counter_ns() - start_time) / 1_000_000)
|
||||||
|
return {
|
||||||
|
"content": "",
|
||||||
|
"model": self.model_name,
|
||||||
|
"tokens": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||||
|
"elapsed_ms": elapsed_ms,
|
||||||
|
"success": False,
|
||||||
|
"error": str(e),
|
||||||
|
"analysis_type": analysis_type,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _build_prompt(
|
||||||
|
self,
|
||||||
|
prompt_template: str,
|
||||||
|
company_name: str,
|
||||||
|
ts_code: str,
|
||||||
|
financial_data: Optional[Dict] = None
|
||||||
|
) -> str:
|
||||||
|
"""Build prompt from template by replacing placeholders"""
|
||||||
|
# Format financial data as string if provided
|
||||||
|
financial_data_str = ""
|
||||||
|
if financial_data:
|
||||||
|
try:
|
||||||
|
financial_data_str = json.dumps(financial_data, ensure_ascii=False, indent=2)
|
||||||
|
except Exception:
|
||||||
|
financial_data_str = str(financial_data)
|
||||||
|
|
||||||
|
# Replace placeholders in template
|
||||||
|
prompt = prompt_template.format(
|
||||||
|
company_name=company_name,
|
||||||
|
ts_code=ts_code,
|
||||||
|
financial_data=financial_data_str
|
||||||
|
)
|
||||||
|
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
def load_analysis_config() -> Dict:
|
||||||
|
"""Load analysis configuration from JSON file"""
|
||||||
|
# Get project root: backend/app/services -> project_root/config/analysis-config.json
|
||||||
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||||
|
config_path = os.path.join(project_root, "config", "analysis-config.json")
|
||||||
|
|
||||||
|
if not os.path.exists(config_path):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(config_path, "r", encoding="utf-8") as f:
|
||||||
|
return json.load(f)
|
||||||
|
except Exception:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_analysis_config(analysis_type: str) -> Optional[Dict]:
|
||||||
|
"""Get configuration for a specific analysis type"""
|
||||||
|
config = load_analysis_config()
|
||||||
|
modules = config.get("analysis_modules", {})
|
||||||
|
return modules.get(analysis_type)
|
||||||
|
|
||||||
@ -3,13 +3,16 @@ Configuration Management Service
|
|||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import asyncio
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import asyncpg
|
||||||
|
import httpx
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
from app.models.system_config import SystemConfig
|
from app.models.system_config import SystemConfig
|
||||||
from app.schemas.config import ConfigResponse, ConfigUpdateRequest, DatabaseConfig, GeminiConfig, DataSourceConfig
|
from app.schemas.config import ConfigResponse, ConfigUpdateRequest, DatabaseConfig, GeminiConfig, DataSourceConfig, ConfigTestResponse
|
||||||
|
|
||||||
class ConfigManager:
|
class ConfigManager:
|
||||||
"""Manages system configuration by merging a static JSON file with dynamic settings from the database."""
|
"""Manages system configuration by merging a static JSON file with dynamic settings from the database."""
|
||||||
@ -17,8 +20,10 @@ class ConfigManager:
|
|||||||
def __init__(self, db_session: AsyncSession, config_path: str = None):
|
def __init__(self, db_session: AsyncSession, config_path: str = None):
|
||||||
self.db = db_session
|
self.db = db_session
|
||||||
if config_path is None:
|
if config_path is None:
|
||||||
# Default path: backend/ -> project_root/ -> config/config.json
|
# Default path: backend/app/services -> project_root/config/config.json
|
||||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
# __file__ = backend/app/services/config_manager.py
|
||||||
|
# go up three levels to project root
|
||||||
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||||
self.config_path = os.path.join(project_root, "config", "config.json")
|
self.config_path = os.path.join(project_root, "config", "config.json")
|
||||||
else:
|
else:
|
||||||
self.config_path = config_path
|
self.config_path = config_path
|
||||||
@ -34,12 +39,19 @@ class ConfigManager:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def _load_dynamic_config_from_db(self) -> Dict[str, Any]:
|
async def _load_dynamic_config_from_db(self) -> Dict[str, Any]:
|
||||||
"""Loads dynamic configuration overrides from the database."""
|
"""Loads dynamic configuration overrides from the database.
|
||||||
db_configs = {}
|
|
||||||
|
当数据库表尚未创建(如开发环境未运行迁移)时,优雅降级为返回空覆盖配置,避免接口 500。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
db_configs: Dict[str, Any] = {}
|
||||||
result = await self.db.execute(select(SystemConfig))
|
result = await self.db.execute(select(SystemConfig))
|
||||||
for record in result.scalars().all():
|
for record in result.scalars().all():
|
||||||
db_configs[record.config_key] = record.config_value
|
db_configs[record.config_key] = record.config_value
|
||||||
return db_configs
|
return db_configs
|
||||||
|
except Exception:
|
||||||
|
# 表不存在或其他数据库错误时,忽略动态配置覆盖
|
||||||
|
return {}
|
||||||
|
|
||||||
def _merge_configs(self, base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
|
def _merge_configs(self, base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Deeply merges the override config into the base config."""
|
"""Deeply merges the override config into the base config."""
|
||||||
@ -57,9 +69,12 @@ class ConfigManager:
|
|||||||
|
|
||||||
merged_config = self._merge_configs(base_config, db_config)
|
merged_config = self._merge_configs(base_config, db_config)
|
||||||
|
|
||||||
|
# 兼容两种位置:优先使用 gemini_api,其次回退到 llm.gemini
|
||||||
|
gemini_src = merged_config.get("gemini_api") or merged_config.get("llm", {}).get("gemini", {})
|
||||||
|
|
||||||
return ConfigResponse(
|
return ConfigResponse(
|
||||||
database=DatabaseConfig(**merged_config.get("database", {})),
|
database=DatabaseConfig(**merged_config.get("database", {})),
|
||||||
gemini_api=GeminiConfig(**merged_config.get("llm", {}).get("gemini", {})),
|
gemini_api=GeminiConfig(**(gemini_src or {})),
|
||||||
data_sources={
|
data_sources={
|
||||||
k: DataSourceConfig(**v)
|
k: DataSourceConfig(**v)
|
||||||
for k, v in merged_config.get("data_sources", {}).items()
|
for k, v in merged_config.get("data_sources", {}).items()
|
||||||
@ -68,8 +83,12 @@ class ConfigManager:
|
|||||||
|
|
||||||
async def update_config(self, config_update: ConfigUpdateRequest) -> ConfigResponse:
|
async def update_config(self, config_update: ConfigUpdateRequest) -> ConfigResponse:
|
||||||
"""Updates configuration in the database and returns the new merged config."""
|
"""Updates configuration in the database and returns the new merged config."""
|
||||||
|
try:
|
||||||
update_dict = config_update.dict(exclude_unset=True)
|
update_dict = config_update.dict(exclude_unset=True)
|
||||||
|
|
||||||
|
# 验证配置数据
|
||||||
|
self._validate_config_data(update_dict)
|
||||||
|
|
||||||
for key, value in update_dict.items():
|
for key, value in update_dict.items():
|
||||||
existing_config = await self.db.get(SystemConfig, key)
|
existing_config = await self.db.get(SystemConfig, key)
|
||||||
if existing_config:
|
if existing_config:
|
||||||
@ -85,3 +104,201 @@ class ConfigManager:
|
|||||||
|
|
||||||
await self.db.commit()
|
await self.db.commit()
|
||||||
return await self.get_config()
|
return await self.get_config()
|
||||||
|
except Exception as e:
|
||||||
|
await self.db.rollback()
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def _validate_config_data(self, config_data: Dict[str, Any]) -> None:
|
||||||
|
"""Validate configuration data before saving."""
|
||||||
|
if "database" in config_data:
|
||||||
|
db_config = config_data["database"]
|
||||||
|
if "url" in db_config:
|
||||||
|
url = db_config["url"]
|
||||||
|
if not url.startswith(("postgresql://", "postgresql+asyncpg://")):
|
||||||
|
raise ValueError("数据库URL必须以 postgresql:// 或 postgresql+asyncpg:// 开头")
|
||||||
|
|
||||||
|
if "gemini_api" in config_data:
|
||||||
|
gemini_config = config_data["gemini_api"]
|
||||||
|
if "api_key" in gemini_config and len(gemini_config["api_key"]) < 10:
|
||||||
|
raise ValueError("Gemini API Key长度不能少于10个字符")
|
||||||
|
if "base_url" in gemini_config and gemini_config["base_url"]:
|
||||||
|
base_url = gemini_config["base_url"]
|
||||||
|
if not base_url.startswith(("http://", "https://")):
|
||||||
|
raise ValueError("Gemini Base URL必须以 http:// 或 https:// 开头")
|
||||||
|
|
||||||
|
if "data_sources" in config_data:
|
||||||
|
for source_name, source_config in config_data["data_sources"].items():
|
||||||
|
if "api_key" in source_config and len(source_config["api_key"]) < 10:
|
||||||
|
raise ValueError(f"{source_name} API Key长度不能少于10个字符")
|
||||||
|
|
||||||
|
async def test_config(self, config_type: str, config_data: Dict[str, Any]) -> ConfigTestResponse:
|
||||||
|
"""Test a specific configuration."""
|
||||||
|
try:
|
||||||
|
if config_type == "database":
|
||||||
|
return await self._test_database(config_data)
|
||||||
|
elif config_type == "gemini":
|
||||||
|
return await self._test_gemini(config_data)
|
||||||
|
elif config_type == "tushare":
|
||||||
|
return await self._test_tushare(config_data)
|
||||||
|
elif config_type == "finnhub":
|
||||||
|
return await self._test_finnhub(config_data)
|
||||||
|
else:
|
||||||
|
return ConfigTestResponse(
|
||||||
|
success=False,
|
||||||
|
message=f"不支持的配置类型: {config_type}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return ConfigTestResponse(
|
||||||
|
success=False,
|
||||||
|
message=f"测试失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _test_database(self, config_data: Dict[str, Any]) -> ConfigTestResponse:
|
||||||
|
"""Test database connection."""
|
||||||
|
db_url = config_data.get("url")
|
||||||
|
if not db_url:
|
||||||
|
return ConfigTestResponse(
|
||||||
|
success=False,
|
||||||
|
message="数据库URL不能为空"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 解析数据库URL
|
||||||
|
if db_url.startswith("postgresql+asyncpg://"):
|
||||||
|
db_url = db_url.replace("postgresql+asyncpg://", "postgresql://")
|
||||||
|
|
||||||
|
# 测试连接
|
||||||
|
conn = await asyncpg.connect(db_url)
|
||||||
|
await conn.close()
|
||||||
|
|
||||||
|
return ConfigTestResponse(
|
||||||
|
success=True,
|
||||||
|
message="数据库连接成功"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return ConfigTestResponse(
|
||||||
|
success=False,
|
||||||
|
message=f"数据库连接失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _test_gemini(self, config_data: Dict[str, Any]) -> ConfigTestResponse:
|
||||||
|
"""Test Gemini API connection."""
|
||||||
|
api_key = config_data.get("api_key")
|
||||||
|
base_url = config_data.get("base_url", "https://generativelanguage.googleapis.com/v1beta")
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
return ConfigTestResponse(
|
||||||
|
success=False,
|
||||||
|
message="Gemini API Key不能为空"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
|
# 测试API可用性
|
||||||
|
response = await client.get(
|
||||||
|
f"{base_url}/models",
|
||||||
|
headers={"x-goog-api-key": api_key}
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
return ConfigTestResponse(
|
||||||
|
success=True,
|
||||||
|
message="Gemini API连接成功"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return ConfigTestResponse(
|
||||||
|
success=False,
|
||||||
|
message=f"Gemini API测试失败: HTTP {response.status_code}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return ConfigTestResponse(
|
||||||
|
success=False,
|
||||||
|
message=f"Gemini API连接失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _test_tushare(self, config_data: Dict[str, Any]) -> ConfigTestResponse:
|
||||||
|
"""Test Tushare API connection."""
|
||||||
|
api_key = config_data.get("api_key")
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
return ConfigTestResponse(
|
||||||
|
success=False,
|
||||||
|
message="Tushare API Key不能为空"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
|
# 测试API可用性
|
||||||
|
response = await client.post(
|
||||||
|
"http://api.tushare.pro",
|
||||||
|
json={
|
||||||
|
"api_name": "stock_basic",
|
||||||
|
"token": api_key,
|
||||||
|
"params": {"list_status": "L"},
|
||||||
|
"fields": "ts_code"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
if data.get("code") == 0:
|
||||||
|
return ConfigTestResponse(
|
||||||
|
success=True,
|
||||||
|
message="Tushare API连接成功"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return ConfigTestResponse(
|
||||||
|
success=False,
|
||||||
|
message=f"Tushare API错误: {data.get('msg', '未知错误')}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return ConfigTestResponse(
|
||||||
|
success=False,
|
||||||
|
message=f"Tushare API测试失败: HTTP {response.status_code}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return ConfigTestResponse(
|
||||||
|
success=False,
|
||||||
|
message=f"Tushare API连接失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _test_finnhub(self, config_data: Dict[str, Any]) -> ConfigTestResponse:
|
||||||
|
"""Test Finnhub API connection."""
|
||||||
|
api_key = config_data.get("api_key")
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
return ConfigTestResponse(
|
||||||
|
success=False,
|
||||||
|
message="Finnhub API Key不能为空"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
|
# 测试API可用性
|
||||||
|
response = await client.get(
|
||||||
|
f"https://finnhub.io/api/v1/quote",
|
||||||
|
params={"symbol": "AAPL", "token": api_key}
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
if "c" in data: # 检查是否有价格数据
|
||||||
|
return ConfigTestResponse(
|
||||||
|
success=True,
|
||||||
|
message="Finnhub API连接成功"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return ConfigTestResponse(
|
||||||
|
success=False,
|
||||||
|
message="Finnhub API响应格式错误"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return ConfigTestResponse(
|
||||||
|
success=False,
|
||||||
|
message=f"Finnhub API测试失败: HTTP {response.status_code}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return ConfigTestResponse(
|
||||||
|
success=False,
|
||||||
|
message=f"Finnhub API连接失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|||||||
@ -6,3 +6,4 @@ SQLAlchemy==2.0.36
|
|||||||
aiosqlite==0.20.0
|
aiosqlite==0.20.0
|
||||||
alembic==1.13.3
|
alembic==1.13.3
|
||||||
google-generativeai==0.8.3
|
google-generativeai==0.8.3
|
||||||
|
asyncpg==0.29.0
|
||||||
|
|||||||
49
config/analysis-config.json
Normal file
49
config/analysis-config.json
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
{
|
||||||
|
"analysis_modules": {
|
||||||
|
"company_profile": {
|
||||||
|
"name": "公司简介",
|
||||||
|
"model": "gemini-2.5-flash",
|
||||||
|
"prompt_template": "您是一位专业的证券市场分析师。请为公司 {company_name} (股票代码: {ts_code}) 生成一份详细且专业的公司介绍。开头不要自我介绍,直接开始正文。正文用MarkDown输出,尽量说明信息来源,用斜体显示信息来源。在生成内容时,请严格遵循以下要求并采用清晰、结构化的格式:\n\n1. **公司概览**:\n * 简要介绍公司的性质、核心业务领域及其在行业中的定位。\n * 提炼并阐述公司的核心价值理念。\n\n2. **主营业务**:\n * 详细描述公司主要的**产品或服务**。\n * **重要提示**:如果能获取到公司最新的官方**年报**或**财务报告**,请从中提取各主要产品/服务线的**收入金额**和其占公司总收入的**百分比**。请**明确标注数据来源**(例如:\"数据来源于XX年年度报告\")。\n * **严格禁止**编造或估算任何财务数据。若无法找到公开、准确的财务数据,请**不要**在这一点中提及具体金额或比例,仅描述业务内容。\n\n3. **发展历程**:\n * 以时间线或关键事件的形式,概述公司自成立以来的主要**里程碑事件**、重大发展阶段、战略转型或重要成就。\n\n4. **核心团队**:\n * 介绍公司**主要管理层和核心技术团队成员**。\n * 对于每位核心成员,提供其**职务、主要工作履历、教育背景**。\n * 如果公开可查,可补充其**出生年份**。\n\n5. **供应链**:\n * 描述公司的**主要原材料、部件或服务来源**。\n * 如果公开信息中包含,请列出**主要供应商名称**,并**明确其在总采购金额中的大致占比**。若无此数据,则仅描述采购模式。\n\n6. **主要客户及销售模式**:\n * 阐明公司的**销售模式**(例如:直销、经销、线上销售、代理等)。\n * 列出公司的**主要客户群体**或**代表性大客户**。\n * 如果公开信息中包含,请标明**主要客户(或前五大客户)的销售额占公司总销售额的比例**。若无此数据,则仅描述客户类型。\n\n7. **未来展望**:\n * 基于公司**公开的官方声明、管理层访谈或战略规划**,总结公司未来的发展方向、战略目标、重点项目或市场预期。请确保此部分内容有可靠的信息来源支持。"
|
||||||
|
},
|
||||||
|
"fundamental_analysis": {
|
||||||
|
"name": "基本面分析",
|
||||||
|
"model": "gemini-2.5-flash",
|
||||||
|
"prompt_template": "您是一位专业的证券分析师。我还没想好,先随便输出100字"
|
||||||
|
},
|
||||||
|
"bull_case": {
|
||||||
|
"name": "看涨分析",
|
||||||
|
"model": "gemini-2.5-flash",
|
||||||
|
"prompt_template": "您是一位专业的证券分析师。我还没想好,先随便输出100字"
|
||||||
|
},
|
||||||
|
"bear_case": {
|
||||||
|
"name": "看跌分析",
|
||||||
|
"model": "gemini-2.5-flash",
|
||||||
|
"prompt_template": "您是一位专业的证券分析师。我还没想好,先随便输出100字"
|
||||||
|
},
|
||||||
|
"market_analysis": {
|
||||||
|
"name": "市场分析",
|
||||||
|
"model": "gemini-2.5-flash",
|
||||||
|
"prompt_template": "您是一位专业的证券分析师。我还没想好,先随便输出100字"
|
||||||
|
},
|
||||||
|
"news_analysis": {
|
||||||
|
"name": "新闻分析",
|
||||||
|
"model": "gemini-2.5-flash",
|
||||||
|
"prompt_template": "您是一位专业的证券分析师。我还没想好,先随便输出100字"
|
||||||
|
},
|
||||||
|
"trading_analysis": {
|
||||||
|
"name": "交易分析",
|
||||||
|
"model": "gemini-2.5-flash",
|
||||||
|
"prompt_template": "您是一位专业的证券分析师。我还没想好,先随便输出100字"
|
||||||
|
},
|
||||||
|
"insider_institutional": {
|
||||||
|
"name": "内部人与机构动向分析",
|
||||||
|
"model": "gemini-2.5-flash",
|
||||||
|
"prompt_template": "您是一位专业的证券分析师。我还没想好,先随便输出100字"
|
||||||
|
},
|
||||||
|
"final_conclusion": {
|
||||||
|
"name": "最终结论",
|
||||||
|
"model": "gemini-2.5-flash",
|
||||||
|
"prompt_template": "您是一位专业的证券分析师。我还没想好,先随便输出100字"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -41,7 +41,8 @@
|
|||||||
"cashflow": [
|
"cashflow": [
|
||||||
{ "displayText": "经营净现金流", "tushareParam": "n_cashflow_act", "api": "cashflow" },
|
{ "displayText": "经营净现金流", "tushareParam": "n_cashflow_act", "api": "cashflow" },
|
||||||
{ "displayText": "资本开支", "tushareParam": "c_pay_acq_const_fiolta", "api": "cashflow" },
|
{ "displayText": "资本开支", "tushareParam": "c_pay_acq_const_fiolta", "api": "cashflow" },
|
||||||
{ "displayText": "折旧费用", "tushareParam": "depr_fa_coga_dpba", "api": "cashflow" }
|
{ "displayText": "折旧费用", "tushareParam": "depr_fa_coga_dpba", "api": "cashflow" },
|
||||||
|
{ "displayText": "支付给职工以及为职工支付的现金", "tushareParam": "c_paid_to_for_empl", "api": "cashflow" }
|
||||||
],
|
],
|
||||||
"daily_basic": [
|
"daily_basic": [
|
||||||
{ "displayText": "PB", "tushareParam": "pb", "api": "daily_basic" },
|
{ "displayText": "PB", "tushareParam": "pb", "api": "daily_basic" },
|
||||||
|
|||||||
@ -1,12 +1,17 @@
|
|||||||
'use client';
|
'use client';
|
||||||
|
|
||||||
import { useState, useEffect } from 'react';
|
import { useState, useEffect } from 'react';
|
||||||
import { useConfig, updateConfig, testConfig } from '@/hooks/useApi';
|
import { useConfig, updateConfig, testConfig, useAnalysisConfig, updateAnalysisConfig } from '@/hooks/useApi';
|
||||||
import { useConfigStore, SystemConfig } from '@/stores/useConfigStore';
|
import { useConfigStore, SystemConfig } from '@/stores/useConfigStore';
|
||||||
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
|
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
|
||||||
import { Input } from "@/components/ui/input";
|
import { Input } from "@/components/ui/input";
|
||||||
|
import { Textarea } from "@/components/ui/textarea";
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
import { Badge } from "@/components/ui/badge";
|
import { Badge } from "@/components/ui/badge";
|
||||||
|
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
|
||||||
|
import { Label } from "@/components/ui/label";
|
||||||
|
import { Separator } from "@/components/ui/separator";
|
||||||
|
import type { AnalysisConfigResponse } from '@/types';
|
||||||
|
|
||||||
export default function ConfigPage() {
|
export default function ConfigPage() {
|
||||||
// 从 Zustand store 获取全局状态
|
// 从 Zustand store 获取全局状态
|
||||||
@ -14,148 +19,571 @@ export default function ConfigPage() {
|
|||||||
// 使用 SWR hook 加载初始配置
|
// 使用 SWR hook 加载初始配置
|
||||||
useConfig();
|
useConfig();
|
||||||
|
|
||||||
|
// 加载分析配置
|
||||||
|
const { data: analysisConfig, mutate: mutateAnalysisConfig } = useAnalysisConfig();
|
||||||
|
|
||||||
// 本地表单状态
|
// 本地表单状态
|
||||||
const [dbUrl, setDbUrl] = useState('');
|
const [dbUrl, setDbUrl] = useState('');
|
||||||
const [geminiApiKey, setGeminiApiKey] = useState('');
|
const [geminiApiKey, setGeminiApiKey] = useState('');
|
||||||
|
const [geminiBaseUrl, setGeminiBaseUrl] = useState('');
|
||||||
const [tushareApiKey, setTushareApiKey] = useState('');
|
const [tushareApiKey, setTushareApiKey] = useState('');
|
||||||
|
const [finnhubApiKey, setFinnhubApiKey] = useState('');
|
||||||
|
|
||||||
|
// 分析配置的本地状态
|
||||||
|
const [localAnalysisConfig, setLocalAnalysisConfig] = useState<Record<string, {
|
||||||
|
name: string;
|
||||||
|
model: string;
|
||||||
|
prompt_template: string;
|
||||||
|
}>>({});
|
||||||
|
|
||||||
|
// 分析配置保存状态
|
||||||
|
const [savingAnalysis, setSavingAnalysis] = useState(false);
|
||||||
|
const [analysisSaveMessage, setAnalysisSaveMessage] = useState('');
|
||||||
|
|
||||||
// 测试结果状态
|
// 测试结果状态
|
||||||
const [dbTestResult, setDbTestResult] = useState<{ success: boolean; message: string } | null>(null);
|
const [testResults, setTestResults] = useState<Record<string, { success: boolean; message: string } | null>>({});
|
||||||
const [geminiTestResult, setGeminiTestResult] = useState<{ success: boolean; message: string } | null>(null);
|
|
||||||
|
|
||||||
// 保存状态
|
// 保存状态
|
||||||
const [saving, setSaving] = useState(false);
|
const [saving, setSaving] = useState(false);
|
||||||
const [saveMessage, setSaveMessage] = useState('');
|
const [saveMessage, setSaveMessage] = useState('');
|
||||||
|
|
||||||
|
// 初始化分析配置的本地状态
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (config) {
|
if (analysisConfig?.analysis_modules) {
|
||||||
setDbUrl(config.database?.url || '');
|
setLocalAnalysisConfig(analysisConfig.analysis_modules);
|
||||||
// API Keys 不回显
|
|
||||||
}
|
}
|
||||||
}, [config]);
|
}, [analysisConfig]);
|
||||||
|
|
||||||
|
// 更新分析配置中的某个字段
|
||||||
|
const updateAnalysisField = (type: string, field: 'name' | 'model' | 'prompt_template', value: string) => {
|
||||||
|
setLocalAnalysisConfig(prev => ({
|
||||||
|
...prev,
|
||||||
|
[type]: {
|
||||||
|
...prev[type],
|
||||||
|
[field]: value
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
};
|
||||||
|
|
||||||
|
// 保存分析配置
|
||||||
|
const handleSaveAnalysisConfig = async () => {
|
||||||
|
setSavingAnalysis(true);
|
||||||
|
setAnalysisSaveMessage('保存中...');
|
||||||
|
|
||||||
|
try {
|
||||||
|
const updated = await updateAnalysisConfig({
|
||||||
|
analysis_modules: localAnalysisConfig
|
||||||
|
});
|
||||||
|
await mutateAnalysisConfig(updated);
|
||||||
|
setAnalysisSaveMessage('保存成功!');
|
||||||
|
} catch (e: any) {
|
||||||
|
setAnalysisSaveMessage(`保存失败: ${e.message}`);
|
||||||
|
} finally {
|
||||||
|
setSavingAnalysis(false);
|
||||||
|
setTimeout(() => setAnalysisSaveMessage(''), 5000);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const validateConfig = () => {
|
||||||
|
const errors: string[] = [];
|
||||||
|
|
||||||
|
// 验证数据库URL格式
|
||||||
|
if (dbUrl && !dbUrl.match(/^postgresql(\+asyncpg)?:\/\/.+/)) {
|
||||||
|
errors.push('数据库URL格式不正确,应为 postgresql://user:pass@host:port/dbname');
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证Gemini Base URL格式
|
||||||
|
if (geminiBaseUrl && !geminiBaseUrl.match(/^https?:\/\/.+/)) {
|
||||||
|
errors.push('Gemini Base URL格式不正确,应为 http:// 或 https:// 开头');
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证API Key长度(基本检查)
|
||||||
|
if (geminiApiKey && geminiApiKey.length < 10) {
|
||||||
|
errors.push('Gemini API Key长度过短');
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tushareApiKey && tushareApiKey.length < 10) {
|
||||||
|
errors.push('Tushare API Key长度过短');
|
||||||
|
}
|
||||||
|
|
||||||
|
if (finnhubApiKey && finnhubApiKey.length < 10) {
|
||||||
|
errors.push('Finnhub API Key长度过短');
|
||||||
|
}
|
||||||
|
|
||||||
|
return errors;
|
||||||
|
};
|
||||||
|
|
||||||
const handleSave = async () => {
|
const handleSave = async () => {
|
||||||
|
// 验证配置
|
||||||
|
const validationErrors = validateConfig();
|
||||||
|
if (validationErrors.length > 0) {
|
||||||
|
setSaveMessage(`配置验证失败: ${validationErrors.join(', ')}`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
setSaving(true);
|
setSaving(true);
|
||||||
setSaveMessage('保存中...');
|
setSaveMessage('保存中...');
|
||||||
|
|
||||||
const newConfig: Partial<SystemConfig> = {
|
const newConfig: Partial<SystemConfig> = {};
|
||||||
database: { url: dbUrl },
|
|
||||||
gemini_api: { api_key: geminiApiKey },
|
// 只更新有值的字段
|
||||||
data_sources: {
|
if (dbUrl) {
|
||||||
tushare: { api_key: tushareApiKey },
|
newConfig.database = { url: dbUrl };
|
||||||
},
|
}
|
||||||
|
|
||||||
|
if (geminiApiKey || geminiBaseUrl) {
|
||||||
|
newConfig.gemini_api = {
|
||||||
|
api_key: geminiApiKey || config?.gemini_api?.api_key || '',
|
||||||
|
base_url: geminiBaseUrl || config?.gemini_api?.base_url || undefined,
|
||||||
};
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tushareApiKey || finnhubApiKey) {
|
||||||
|
newConfig.data_sources = {
|
||||||
|
...config?.data_sources,
|
||||||
|
...(tushareApiKey && { tushare: { api_key: tushareApiKey } }),
|
||||||
|
...(finnhubApiKey && { finnhub: { api_key: finnhubApiKey } }),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const updated = await updateConfig(newConfig);
|
const updated = await updateConfig(newConfig);
|
||||||
setConfig(updated); // 更新全局状态
|
setConfig(updated); // 更新全局状态
|
||||||
setSaveMessage('保存成功!');
|
setSaveMessage('保存成功!');
|
||||||
setGeminiApiKey(''); // 清空敏感字段输入
|
// 清空敏感字段输入
|
||||||
|
setGeminiApiKey('');
|
||||||
setTushareApiKey('');
|
setTushareApiKey('');
|
||||||
|
setFinnhubApiKey('');
|
||||||
} catch (e: any) {
|
} catch (e: any) {
|
||||||
setSaveMessage(`保存失败: ${e.message}`);
|
setSaveMessage(`保存失败: ${e.message}`);
|
||||||
} finally {
|
} finally {
|
||||||
setSaving(false);
|
setSaving(false);
|
||||||
setTimeout(() => setSaveMessage(''), 3000);
|
setTimeout(() => setSaveMessage(''), 5000);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleTestDb = async () => {
|
const handleTest = async (type: string, data: any) => {
|
||||||
const result = await testConfig('database', { url: dbUrl });
|
try {
|
||||||
setDbTestResult(result);
|
const result = await testConfig(type, data);
|
||||||
|
setTestResults(prev => ({ ...prev, [type]: result }));
|
||||||
|
} catch (e: any) {
|
||||||
|
setTestResults(prev => ({
|
||||||
|
...prev,
|
||||||
|
[type]: { success: false, message: e.message }
|
||||||
|
}));
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleTestGemini = async () => {
|
const handleTestDb = () => {
|
||||||
const result = await testConfig('gemini', { api_key: geminiApiKey || config?.gemini_api.api_key });
|
handleTest('database', { url: dbUrl });
|
||||||
setGeminiTestResult(result);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
if (loading) return <div>Loading...</div>;
|
const handleTestGemini = () => {
|
||||||
if (error) return <div>Error loading config: {error}</div>;
|
handleTest('gemini', {
|
||||||
|
api_key: geminiApiKey || config?.gemini_api?.api_key,
|
||||||
|
base_url: geminiBaseUrl || config?.gemini_api?.base_url
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleTestTushare = () => {
|
||||||
|
handleTest('tushare', { api_key: tushareApiKey || config?.data_sources?.tushare?.api_key });
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleTestFinnhub = () => {
|
||||||
|
handleTest('finnhub', { api_key: finnhubApiKey || config?.data_sources?.finnhub?.api_key });
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleReset = () => {
|
||||||
|
setDbUrl('');
|
||||||
|
setGeminiApiKey('');
|
||||||
|
setGeminiBaseUrl('');
|
||||||
|
setTushareApiKey('');
|
||||||
|
setFinnhubApiKey('');
|
||||||
|
setTestResults({});
|
||||||
|
setSaveMessage('');
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleExportConfig = () => {
|
||||||
|
if (!config) return;
|
||||||
|
|
||||||
|
const configToExport = {
|
||||||
|
database: config.database,
|
||||||
|
gemini_api: config.gemini_api,
|
||||||
|
data_sources: config.data_sources,
|
||||||
|
export_time: new Date().toISOString(),
|
||||||
|
version: "1.0"
|
||||||
|
};
|
||||||
|
|
||||||
|
const blob = new Blob([JSON.stringify(configToExport, null, 2)], { type: 'application/json' });
|
||||||
|
const url = URL.createObjectURL(blob);
|
||||||
|
const a = document.createElement('a');
|
||||||
|
a.href = url;
|
||||||
|
a.download = `config-backup-${new Date().toISOString().split('T')[0]}.json`;
|
||||||
|
document.body.appendChild(a);
|
||||||
|
a.click();
|
||||||
|
document.body.removeChild(a);
|
||||||
|
URL.revokeObjectURL(url);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleImportConfig = (event: React.ChangeEvent<HTMLInputElement>) => {
|
||||||
|
const file = event.target.files?.[0];
|
||||||
|
if (!file) return;
|
||||||
|
|
||||||
|
const reader = new FileReader();
|
||||||
|
reader.onload = (e) => {
|
||||||
|
try {
|
||||||
|
const importedConfig = JSON.parse(e.target?.result as string);
|
||||||
|
|
||||||
|
// 验证导入的配置格式
|
||||||
|
if (importedConfig.database?.url) {
|
||||||
|
setDbUrl(importedConfig.database.url);
|
||||||
|
}
|
||||||
|
if (importedConfig.gemini_api?.base_url) {
|
||||||
|
setGeminiBaseUrl(importedConfig.gemini_api.base_url);
|
||||||
|
}
|
||||||
|
|
||||||
|
setSaveMessage('配置导入成功,请检查并保存');
|
||||||
|
} catch (error) {
|
||||||
|
setSaveMessage('配置文件格式错误,导入失败');
|
||||||
|
}
|
||||||
|
};
|
||||||
|
reader.readAsText(file);
|
||||||
|
};
|
||||||
|
|
||||||
|
if (loading) return (
|
||||||
|
<div className="flex items-center justify-center h-64">
|
||||||
|
<div className="text-center">
|
||||||
|
<div className="animate-spin rounded-full h-8 w-8 border-b-2 border-gray-900 mx-auto"></div>
|
||||||
|
<p className="mt-2 text-sm text-muted-foreground">加载配置中...</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
|
||||||
|
if (error) return (
|
||||||
|
<div className="flex items-center justify-center h-64">
|
||||||
|
<div className="text-center">
|
||||||
|
<div className="text-red-500 text-lg mb-2">⚠️</div>
|
||||||
|
<p className="text-red-600">加载配置失败: {error}</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="space-y-6">
|
<div className="container mx-auto py-6 space-y-6">
|
||||||
<header className="space-y-2">
|
<header className="space-y-2">
|
||||||
<h1 className="text-2xl font-semibold">配置中心</h1>
|
<h1 className="text-3xl font-bold">配置中心</h1>
|
||||||
<p className="text-sm text-muted-foreground">
|
<p className="text-muted-foreground">
|
||||||
管理系统配置,包括数据库、API密钥等。敏感密钥不回显,留空表示保持现值。
|
管理系统配置,包括数据库连接、API密钥等。敏感信息不回显,留空表示保持当前值。
|
||||||
</p>
|
</p>
|
||||||
</header>
|
</header>
|
||||||
|
|
||||||
|
<Tabs defaultValue="database" className="space-y-6">
|
||||||
|
<TabsList className="grid w-full grid-cols-5">
|
||||||
|
<TabsTrigger value="database">数据库</TabsTrigger>
|
||||||
|
<TabsTrigger value="ai">AI服务</TabsTrigger>
|
||||||
|
<TabsTrigger value="data-sources">数据源</TabsTrigger>
|
||||||
|
<TabsTrigger value="analysis">分析配置</TabsTrigger>
|
||||||
|
<TabsTrigger value="system">系统</TabsTrigger>
|
||||||
|
</TabsList>
|
||||||
|
|
||||||
|
<TabsContent value="database" className="space-y-4">
|
||||||
<Card>
|
<Card>
|
||||||
<CardHeader>
|
<CardHeader>
|
||||||
<CardTitle>数据库配置</CardTitle>
|
<CardTitle>数据库配置</CardTitle>
|
||||||
<CardDescription>PostgreSQL 连接设置</CardDescription>
|
<CardDescription>PostgreSQL 数据库连接设置</CardDescription>
|
||||||
</CardHeader>
|
</CardHeader>
|
||||||
<CardContent className="space-y-4">
|
<CardContent className="space-y-4">
|
||||||
<div className="flex items-center gap-4">
|
<div className="space-y-2">
|
||||||
<label className="w-28">连接URL</label>
|
<Label htmlFor="db-url">数据库连接URL</Label>
|
||||||
|
<div className="flex gap-2">
|
||||||
<Input
|
<Input
|
||||||
|
id="db-url"
|
||||||
type="text"
|
type="text"
|
||||||
value={dbUrl}
|
value={dbUrl}
|
||||||
onChange={(e) => setDbUrl(e.target.value)}
|
onChange={(e) => setDbUrl(e.target.value)}
|
||||||
placeholder="postgresql+asyncpg://user:pass@host:port/dbname"
|
placeholder="postgresql+asyncpg://user:password@host:port/database"
|
||||||
className="flex-1"
|
className="flex-1"
|
||||||
/>
|
/>
|
||||||
<Button onClick={handleTestDb}>测试连接</Button>
|
<Button onClick={handleTestDb} variant="outline">
|
||||||
|
测试连接
|
||||||
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
{dbTestResult && (
|
{testResults.database && (
|
||||||
<Badge variant={dbTestResult.success ? 'secondary' : 'destructive'}>
|
<Badge variant={testResults.database.success ? 'default' : 'destructive'}>
|
||||||
{dbTestResult.message}
|
{testResults.database.message}
|
||||||
</Badge>
|
</Badge>
|
||||||
)}
|
)}
|
||||||
|
</div>
|
||||||
</CardContent>
|
</CardContent>
|
||||||
</Card>
|
</Card>
|
||||||
|
</TabsContent>
|
||||||
|
|
||||||
|
<TabsContent value="ai" className="space-y-4">
|
||||||
<Card>
|
<Card>
|
||||||
<CardHeader>
|
<CardHeader>
|
||||||
<CardTitle>AI 服务配置</CardTitle>
|
<CardTitle>AI 服务配置</CardTitle>
|
||||||
<CardDescription>Google Gemini API 设置</CardDescription>
|
<CardDescription>Google Gemini API 设置</CardDescription>
|
||||||
</CardHeader>
|
</CardHeader>
|
||||||
<CardContent className="space-y-4">
|
<CardContent className="space-y-4">
|
||||||
<div className="flex items-center gap-4">
|
<div className="space-y-2">
|
||||||
<label className="w-28">API Key</label>
|
<Label htmlFor="gemini-api-key">API Key</Label>
|
||||||
|
<div className="flex gap-2">
|
||||||
<Input
|
<Input
|
||||||
|
id="gemini-api-key"
|
||||||
type="password"
|
type="password"
|
||||||
value={geminiApiKey}
|
value={geminiApiKey}
|
||||||
onChange={(e) => setGeminiApiKey(e.target.value)}
|
onChange={(e) => setGeminiApiKey(e.target.value)}
|
||||||
placeholder="留空表示保持现值"
|
placeholder="留空表示保持当前值"
|
||||||
className="flex-1"
|
className="flex-1"
|
||||||
/>
|
/>
|
||||||
<Button onClick={handleTestGemini}>测试</Button>
|
<Button onClick={handleTestGemini} variant="outline">
|
||||||
|
测试
|
||||||
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
{geminiTestResult && (
|
{testResults.gemini && (
|
||||||
<Badge variant={geminiTestResult.success ? 'secondary' : 'destructive'}>
|
<Badge variant={testResults.gemini.success ? 'default' : 'destructive'}>
|
||||||
{geminiTestResult.message}
|
{testResults.gemini.message}
|
||||||
</Badge>
|
</Badge>
|
||||||
)}
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="space-y-2">
|
||||||
|
<Label htmlFor="gemini-base-url">Base URL (可选)</Label>
|
||||||
|
<Input
|
||||||
|
id="gemini-base-url"
|
||||||
|
type="text"
|
||||||
|
value={geminiBaseUrl}
|
||||||
|
onChange={(e) => setGeminiBaseUrl(e.target.value)}
|
||||||
|
placeholder="https://generativelanguage.googleapis.com/v1beta"
|
||||||
|
className="flex-1"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</CardContent>
|
||||||
|
</Card>
|
||||||
|
</TabsContent>
|
||||||
|
|
||||||
|
<TabsContent value="data-sources" className="space-y-4">
|
||||||
|
<Card>
|
||||||
|
<CardHeader>
|
||||||
|
<CardTitle>数据源配置</CardTitle>
|
||||||
|
<CardDescription>外部数据源 API 设置</CardDescription>
|
||||||
|
</CardHeader>
|
||||||
|
<CardContent className="space-y-6">
|
||||||
|
<div className="space-y-4">
|
||||||
|
<div>
|
||||||
|
<Label className="text-base font-medium">Tushare</Label>
|
||||||
|
<p className="text-sm text-muted-foreground mb-2">中国股票数据源</p>
|
||||||
|
<div className="flex gap-2">
|
||||||
|
<Input
|
||||||
|
type="password"
|
||||||
|
value={tushareApiKey}
|
||||||
|
onChange={(e) => setTushareApiKey(e.target.value)}
|
||||||
|
placeholder="留空表示保持当前值"
|
||||||
|
className="flex-1"
|
||||||
|
/>
|
||||||
|
<Button onClick={handleTestTushare} variant="outline">
|
||||||
|
测试
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
{testResults.tushare && (
|
||||||
|
<Badge variant={testResults.tushare.success ? 'default' : 'destructive'} className="mt-2">
|
||||||
|
{testResults.tushare.message}
|
||||||
|
</Badge>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<Separator />
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<Label className="text-base font-medium">Finnhub</Label>
|
||||||
|
<p className="text-sm text-muted-foreground mb-2">全球金融市场数据源</p>
|
||||||
|
<div className="flex gap-2">
|
||||||
|
<Input
|
||||||
|
type="password"
|
||||||
|
value={finnhubApiKey}
|
||||||
|
onChange={(e) => setFinnhubApiKey(e.target.value)}
|
||||||
|
placeholder="留空表示保持当前值"
|
||||||
|
className="flex-1"
|
||||||
|
/>
|
||||||
|
<Button onClick={handleTestFinnhub} variant="outline">
|
||||||
|
测试
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
{testResults.finnhub && (
|
||||||
|
<Badge variant={testResults.finnhub.success ? 'default' : 'destructive'} className="mt-2">
|
||||||
|
{testResults.finnhub.message}
|
||||||
|
</Badge>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</CardContent>
|
||||||
|
</Card>
|
||||||
|
</TabsContent>
|
||||||
|
|
||||||
|
<TabsContent value="analysis" className="space-y-4">
|
||||||
|
<Card>
|
||||||
|
<CardHeader>
|
||||||
|
<CardTitle>分析模块配置</CardTitle>
|
||||||
|
<CardDescription>配置各个分析模块的模型和提示词</CardDescription>
|
||||||
|
</CardHeader>
|
||||||
|
<CardContent className="space-y-6">
|
||||||
|
{Object.entries(localAnalysisConfig).map(([type, config]) => (
|
||||||
|
<div key={type} className="space-y-4 p-4 border rounded-lg">
|
||||||
|
<div className="flex items-center justify-between">
|
||||||
|
<h3 className="text-lg font-semibold">{config.name || type}</h3>
|
||||||
|
<Badge variant="secondary">{type}</Badge>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="space-y-2">
|
||||||
|
<Label htmlFor={`${type}-name`}>显示名称</Label>
|
||||||
|
<Input
|
||||||
|
id={`${type}-name`}
|
||||||
|
value={config.name || ''}
|
||||||
|
onChange={(e) => updateAnalysisField(type, 'name', e.target.value)}
|
||||||
|
placeholder="分析模块显示名称"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="space-y-2">
|
||||||
|
<Label htmlFor={`${type}-model`}>模型名称</Label>
|
||||||
|
<Input
|
||||||
|
id={`${type}-model`}
|
||||||
|
value={config.model || ''}
|
||||||
|
onChange={(e) => updateAnalysisField(type, 'model', e.target.value)}
|
||||||
|
placeholder="例如: gemini-2.5-flash"
|
||||||
|
/>
|
||||||
|
<p className="text-xs text-muted-foreground">
|
||||||
|
使用的Gemini模型名称
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="space-y-2">
|
||||||
|
<Label htmlFor={`${type}-prompt`}>提示词模板</Label>
|
||||||
|
<Textarea
|
||||||
|
id={`${type}-prompt`}
|
||||||
|
value={config.prompt_template || ''}
|
||||||
|
onChange={(e) => updateAnalysisField(type, 'prompt_template', e.target.value)}
|
||||||
|
placeholder="提示词模板,支持 {company_name}, {ts_code}, {financial_data} 占位符"
|
||||||
|
rows={10}
|
||||||
|
className="font-mono text-sm"
|
||||||
|
/>
|
||||||
|
<p className="text-xs text-muted-foreground">
|
||||||
|
提示词模板,可以使用占位符: {`{company_name}`}, {`{ts_code}`}, {`{financial_data}`}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<Separator />
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
|
||||||
|
<div className="flex items-center gap-4 pt-4">
|
||||||
|
<Button
|
||||||
|
onClick={handleSaveAnalysisConfig}
|
||||||
|
disabled={savingAnalysis}
|
||||||
|
size="lg"
|
||||||
|
>
|
||||||
|
{savingAnalysis ? '保存中...' : '保存分析配置'}
|
||||||
|
</Button>
|
||||||
|
{analysisSaveMessage && (
|
||||||
|
<span className={`text-sm ${analysisSaveMessage.includes('成功') ? 'text-green-600' : 'text-red-600'}`}>
|
||||||
|
{analysisSaveMessage}
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</CardContent>
|
||||||
|
</Card>
|
||||||
|
</TabsContent>
|
||||||
|
|
||||||
|
<TabsContent value="system" className="space-y-4">
|
||||||
|
<Card>
|
||||||
|
<CardHeader>
|
||||||
|
<CardTitle>系统信息</CardTitle>
|
||||||
|
<CardDescription>当前系统状态和配置概览</CardDescription>
|
||||||
|
</CardHeader>
|
||||||
|
<CardContent className="space-y-4">
|
||||||
|
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
|
||||||
|
<div className="space-y-2">
|
||||||
|
<Label>数据库状态</Label>
|
||||||
|
<Badge variant={config?.database?.url ? 'default' : 'secondary'}>
|
||||||
|
{config?.database?.url ? '已配置' : '未配置'}
|
||||||
|
</Badge>
|
||||||
|
</div>
|
||||||
|
<div className="space-y-2">
|
||||||
|
<Label>Gemini API</Label>
|
||||||
|
<Badge variant={config?.gemini_api?.api_key ? 'default' : 'secondary'}>
|
||||||
|
{config?.gemini_api?.api_key ? '已配置' : '未配置'}
|
||||||
|
</Badge>
|
||||||
|
</div>
|
||||||
|
<div className="space-y-2">
|
||||||
|
<Label>Tushare API</Label>
|
||||||
|
<Badge variant={config?.data_sources?.tushare?.api_key ? 'default' : 'secondary'}>
|
||||||
|
{config?.data_sources?.tushare?.api_key ? '已配置' : '未配置'}
|
||||||
|
</Badge>
|
||||||
|
</div>
|
||||||
|
<div className="space-y-2">
|
||||||
|
<Label>Finnhub API</Label>
|
||||||
|
<Badge variant={config?.data_sources?.finnhub?.api_key ? 'default' : 'secondary'}>
|
||||||
|
{config?.data_sources?.finnhub?.api_key ? '已配置' : '未配置'}
|
||||||
|
</Badge>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
</CardContent>
|
</CardContent>
|
||||||
</Card>
|
</Card>
|
||||||
|
|
||||||
<Card>
|
<Card>
|
||||||
<CardHeader>
|
<CardHeader>
|
||||||
<CardTitle>数据源配置</CardTitle>
|
<CardTitle>配置管理</CardTitle>
|
||||||
<CardDescription>Tushare API 设置</CardDescription>
|
<CardDescription>导入、导出和备份配置</CardDescription>
|
||||||
</CardHeader>
|
</CardHeader>
|
||||||
<CardContent className="space-y-4">
|
<CardContent className="space-y-4">
|
||||||
<div className="flex items-center gap-4">
|
<div className="flex flex-col sm:flex-row gap-4">
|
||||||
<label className="w-28">Tushare Token</label>
|
<Button onClick={handleExportConfig} variant="outline" className="flex-1">
|
||||||
<Input
|
📤 导出配置
|
||||||
type="password"
|
</Button>
|
||||||
value={tushareApiKey}
|
<div className="flex-1">
|
||||||
onChange={(e) => setTushareApiKey(e.target.value)}
|
<input
|
||||||
placeholder="留空表示保持现值"
|
type="file"
|
||||||
className="flex-1"
|
accept=".json"
|
||||||
|
onChange={handleImportConfig}
|
||||||
|
className="hidden"
|
||||||
|
id="import-config"
|
||||||
/>
|
/>
|
||||||
|
<Button
|
||||||
|
variant="outline"
|
||||||
|
className="w-full"
|
||||||
|
onClick={() => document.getElementById('import-config')?.click()}
|
||||||
|
>
|
||||||
|
📥 导入配置
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className="text-sm text-muted-foreground">
|
||||||
|
<p>• 导出配置将下载当前所有配置的备份文件</p>
|
||||||
|
<p>• 导入配置将加载备份文件中的设置(不包含敏感信息)</p>
|
||||||
|
<p>• 建议定期备份配置以防数据丢失</p>
|
||||||
</div>
|
</div>
|
||||||
</CardContent>
|
</CardContent>
|
||||||
</Card>
|
</Card>
|
||||||
|
</TabsContent>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
<div className="flex items-center justify-between pt-6 border-t">
|
||||||
<div className="flex items-center gap-4">
|
<div className="flex items-center gap-4">
|
||||||
<Button onClick={handleSave} disabled={saving}>
|
<Button onClick={handleSave} disabled={saving} size="lg">
|
||||||
{saving ? '保存中...' : '保存所有配置'}
|
{saving ? '保存中...' : '保存所有配置'}
|
||||||
</Button>
|
</Button>
|
||||||
{saveMessage && <span className="text-sm text-muted-foreground">{saveMessage}</span>}
|
<Button onClick={handleReset} variant="outline" size="lg">
|
||||||
|
重置表单
|
||||||
|
</Button>
|
||||||
|
{saveMessage && (
|
||||||
|
<span className={`text-sm ${saveMessage.includes('成功') ? 'text-green-600' : 'text-red-600'}`}>
|
||||||
|
{saveMessage}
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
<div className="text-sm text-muted-foreground">
|
||||||
|
最后更新: {new Date().toLocaleString()}
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
|
|||||||
@ -45,6 +45,9 @@ export default function RootLayout({
|
|||||||
<NavigationMenuItem>
|
<NavigationMenuItem>
|
||||||
<NavigationMenuLink href="/docs" className="px-3 py-2">文档</NavigationMenuLink>
|
<NavigationMenuLink href="/docs" className="px-3 py-2">文档</NavigationMenuLink>
|
||||||
</NavigationMenuItem>
|
</NavigationMenuItem>
|
||||||
|
<NavigationMenuItem>
|
||||||
|
<NavigationMenuLink href="/config" className="px-3 py-2">配置</NavigationMenuLink>
|
||||||
|
</NavigationMenuItem>
|
||||||
</NavigationMenuList>
|
</NavigationMenuList>
|
||||||
</NavigationMenu>
|
</NavigationMenu>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
147
frontend/src/components/TradingViewWidget.tsx
Normal file
147
frontend/src/components/TradingViewWidget.tsx
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
'use client';
|
||||||
|
|
||||||
|
import { useEffect, useRef } from 'react';
|
||||||
|
|
||||||
|
interface TradingViewWidgetProps {
|
||||||
|
symbol: string;
|
||||||
|
market?: string;
|
||||||
|
height?: number;
|
||||||
|
width?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
declare global {
|
||||||
|
interface Window {
|
||||||
|
TradingView: any;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export function TradingViewWidget({
|
||||||
|
symbol,
|
||||||
|
market = 'china',
|
||||||
|
height = 400,
|
||||||
|
width = '100%'
|
||||||
|
}: TradingViewWidgetProps) {
|
||||||
|
const containerRef = useRef<HTMLDivElement>(null);
|
||||||
|
|
||||||
|
// 将中国股票代码转换为TradingView格式
|
||||||
|
const getTradingViewSymbol = (symbol: string, market: string) => {
|
||||||
|
if (market === 'china' || market === 'cn') {
|
||||||
|
// 处理中国股票代码
|
||||||
|
if (symbol.includes('.')) {
|
||||||
|
const [code, exchange] = symbol.split('.');
|
||||||
|
if (exchange === 'SH') {
|
||||||
|
return `SSE:${code}`;
|
||||||
|
} else if (exchange === 'SZ') {
|
||||||
|
return `SZSE:${code}`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 如果没有后缀,尝试推断
|
||||||
|
const onlyDigits = symbol.replace(/\D/g, '');
|
||||||
|
if (onlyDigits.length === 6) {
|
||||||
|
const first = onlyDigits[0];
|
||||||
|
if (first === '6') {
|
||||||
|
return `SSE:${onlyDigits}`;
|
||||||
|
} else if (first === '0' || first === '3') {
|
||||||
|
return `SZSE:${onlyDigits}`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return symbol;
|
||||||
|
}
|
||||||
|
return symbol;
|
||||||
|
};
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (typeof window === 'undefined') return;
|
||||||
|
if (!symbol) return;
|
||||||
|
const tradingViewSymbol = getTradingViewSymbol(symbol, market);
|
||||||
|
|
||||||
|
const script = document.createElement('script');
|
||||||
|
script.src = 'https://s3.tradingview.com/external-embedding/embed-widget-advanced-chart.js';
|
||||||
|
script.async = true;
|
||||||
|
script.innerHTML = JSON.stringify({
|
||||||
|
autosize: true,
|
||||||
|
symbol: tradingViewSymbol,
|
||||||
|
interval: 'D',
|
||||||
|
timezone: 'Asia/Shanghai',
|
||||||
|
theme: 'light',
|
||||||
|
style: '1',
|
||||||
|
locale: 'zh_CN',
|
||||||
|
toolbar_bg: '#f1f3f6',
|
||||||
|
enable_publishing: false,
|
||||||
|
hide_top_toolbar: false,
|
||||||
|
hide_legend: false,
|
||||||
|
save_image: false,
|
||||||
|
container_id: `tradingview_${symbol}`,
|
||||||
|
studies: [],
|
||||||
|
show_popup_button: false,
|
||||||
|
no_referrer_id: true,
|
||||||
|
referrer_id: 'fundamental-analysis',
|
||||||
|
// 强制启用对数坐标
|
||||||
|
logarithmic: true,
|
||||||
|
disabled_features: [
|
||||||
|
'use_localstorage_for_settings',
|
||||||
|
'volume_force_overlay',
|
||||||
|
'create_volume_indicator_by_default'
|
||||||
|
],
|
||||||
|
enabled_features: [
|
||||||
|
'side_toolbar_in_fullscreen_mode',
|
||||||
|
'header_in_fullscreen_mode'
|
||||||
|
],
|
||||||
|
overrides: {
|
||||||
|
'paneProperties.background': '#ffffff',
|
||||||
|
'paneProperties.vertGridProperties.color': '#e1e3e6',
|
||||||
|
'paneProperties.horzGridProperties.color': '#e1e3e6',
|
||||||
|
'symbolWatermarkProperties.transparency': 90,
|
||||||
|
'scalesProperties.textColor': '#333333',
|
||||||
|
// 对数坐标设置
|
||||||
|
'scalesProperties.logarithmic': true,
|
||||||
|
'rightPriceScale.mode': 1,
|
||||||
|
'leftPriceScale.mode': 1,
|
||||||
|
'paneProperties.priceScaleProperties.log': true,
|
||||||
|
'paneProperties.priceScaleProperties.mode': 1
|
||||||
|
},
|
||||||
|
// 强制启用对数坐标
|
||||||
|
studies_overrides: {
|
||||||
|
'volume.volume.color.0': '#00bcd4',
|
||||||
|
'volume.volume.color.1': '#ff9800',
|
||||||
|
'volume.volume.transparency': 70
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
const container = containerRef.current;
|
||||||
|
if (container) {
|
||||||
|
// 避免重复挂载与 Next 热更新多次执行导致的报错
|
||||||
|
container.innerHTML = '';
|
||||||
|
// 延迟到下一帧,确保容器已插入并可获取 iframe.contentWindow
|
||||||
|
requestAnimationFrame(() => {
|
||||||
|
try {
|
||||||
|
if (container.isConnected) {
|
||||||
|
container.appendChild(script);
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
// 忽略偶发性 contentWindow 不可用的报错
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
const c = containerRef.current;
|
||||||
|
if (c) {
|
||||||
|
try {
|
||||||
|
c.innerHTML = '';
|
||||||
|
} catch {}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}, [symbol, market]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="w-full">
|
||||||
|
<div
|
||||||
|
ref={containerRef}
|
||||||
|
id={`tradingview_${symbol}`}
|
||||||
|
style={{ height: `${height}px`, width }}
|
||||||
|
className="border rounded-lg overflow-hidden"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
21
frontend/src/components/ui/label.tsx
Normal file
21
frontend/src/components/ui/label.tsx
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
"use client"
|
||||||
|
|
||||||
|
import * as React from "react"
|
||||||
|
import { cn } from "@/lib/utils"
|
||||||
|
|
||||||
|
export interface LabelProps extends React.LabelHTMLAttributes<HTMLLabelElement> {}
|
||||||
|
|
||||||
|
const Label = React.forwardRef<HTMLLabelElement, LabelProps>(({ className, ...props }, ref) => (
|
||||||
|
<label
|
||||||
|
ref={ref}
|
||||||
|
className={cn(
|
||||||
|
"text-sm font-medium leading-none peer-disabled:cursor-not-allowed peer-disabled:opacity-70",
|
||||||
|
className
|
||||||
|
)}
|
||||||
|
{...props}
|
||||||
|
/>
|
||||||
|
))
|
||||||
|
|
||||||
|
Label.displayName = "Label"
|
||||||
|
|
||||||
|
export { Label }
|
||||||
29
frontend/src/components/ui/separator.tsx
Normal file
29
frontend/src/components/ui/separator.tsx
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
"use client"
|
||||||
|
|
||||||
|
import * as React from "react"
|
||||||
|
import { cn } from "@/lib/utils"
|
||||||
|
|
||||||
|
export interface SeparatorProps extends React.HTMLAttributes<HTMLDivElement> {
|
||||||
|
orientation?: "horizontal" | "vertical"
|
||||||
|
decorative?: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
const Separator = React.forwardRef<HTMLDivElement, SeparatorProps>(
|
||||||
|
({ className, orientation = "horizontal", decorative = true, ...props }, ref) => (
|
||||||
|
<div
|
||||||
|
ref={ref}
|
||||||
|
role={decorative ? "none" : "separator"}
|
||||||
|
aria-orientation={orientation}
|
||||||
|
className={cn(
|
||||||
|
"shrink-0 bg-border",
|
||||||
|
orientation === "horizontal" ? "h-[1px] w-full" : "h-full w-[1px]",
|
||||||
|
className
|
||||||
|
)}
|
||||||
|
{...props}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
Separator.displayName = "Separator"
|
||||||
|
|
||||||
|
export { Separator }
|
||||||
25
frontend/src/components/ui/textarea.tsx
Normal file
25
frontend/src/components/ui/textarea.tsx
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
import * as React from "react"
|
||||||
|
|
||||||
|
import { cn } from "@/lib/utils"
|
||||||
|
|
||||||
|
export interface TextareaProps
|
||||||
|
extends React.TextareaHTMLAttributes<HTMLTextAreaElement> {}
|
||||||
|
|
||||||
|
const Textarea = React.forwardRef<HTMLTextAreaElement, TextareaProps>(
|
||||||
|
({ className, ...props }, ref) => {
|
||||||
|
return (
|
||||||
|
<textarea
|
||||||
|
className={cn(
|
||||||
|
"flex min-h-[60px] w-full rounded-md border border-input bg-transparent px-3 py-2 text-base shadow-xs placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-[3px] focus-visible:ring-ring/50 disabled:cursor-not-allowed disabled:opacity-50 md:text-sm",
|
||||||
|
className
|
||||||
|
)}
|
||||||
|
ref={ref}
|
||||||
|
{...props}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
Textarea.displayName = "Textarea"
|
||||||
|
|
||||||
|
export { Textarea }
|
||||||
|
|
||||||
@ -1,8 +1,31 @@
|
|||||||
import useSWR from 'swr';
|
import useSWR from 'swr';
|
||||||
import { useConfigStore } from '@/stores/useConfigStore';
|
import { useConfigStore } from '@/stores/useConfigStore';
|
||||||
import { BatchFinancialDataResponse, FinancialConfigResponse } from '@/types';
|
import { BatchFinancialDataResponse, FinancialConfigResponse, AnalysisConfigResponse } from '@/types';
|
||||||
|
|
||||||
const fetcher = (url: string) => fetch(url).then((res) => res.json());
|
const fetcher = async (url: string) => {
|
||||||
|
const res = await fetch(url);
|
||||||
|
const contentType = res.headers.get('Content-Type') || '';
|
||||||
|
const text = await res.text();
|
||||||
|
|
||||||
|
// 尝试解析JSON
|
||||||
|
const tryParseJson = () => {
|
||||||
|
try { return JSON.parse(text); } catch { return null; }
|
||||||
|
};
|
||||||
|
|
||||||
|
const data = contentType.includes('application/json') ? tryParseJson() : tryParseJson();
|
||||||
|
|
||||||
|
if (!res.ok) {
|
||||||
|
// 后端可能返回纯文本错误,统一抛出可读错误
|
||||||
|
const message = data && data.detail ? data.detail : (text || `Request failed: ${res.status}`);
|
||||||
|
throw new Error(message);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (data === null) {
|
||||||
|
throw new Error('无效的服务器响应(非JSON)');
|
||||||
|
}
|
||||||
|
|
||||||
|
return data;
|
||||||
|
};
|
||||||
|
|
||||||
export function useConfig() {
|
export function useConfig() {
|
||||||
const { setConfig, setError } = useConfigStore();
|
const { setConfig, setError } = useConfigStore();
|
||||||
@ -38,9 +61,9 @@ export function useFinancialConfig() {
|
|||||||
return useSWR<FinancialConfigResponse>('/api/financials/config', fetcher);
|
return useSWR<FinancialConfigResponse>('/api/financials/config', fetcher);
|
||||||
}
|
}
|
||||||
|
|
||||||
export function useChinaFinancials(ts_code?: string) {
|
export function useChinaFinancials(ts_code?: string, years: number = 10) {
|
||||||
return useSWR<BatchFinancialDataResponse>(
|
return useSWR<BatchFinancialDataResponse>(
|
||||||
ts_code ? `/api/financials/china/${encodeURIComponent(ts_code)}` : null,
|
ts_code ? `/api/financials/china/${encodeURIComponent(ts_code)}?years=${encodeURIComponent(String(years))}` : null,
|
||||||
fetcher,
|
fetcher,
|
||||||
{
|
{
|
||||||
revalidateOnFocus: false, // 不在窗口聚焦时重新验证
|
revalidateOnFocus: false, // 不在窗口聚焦时重新验证
|
||||||
@ -50,3 +73,17 @@ export function useChinaFinancials(ts_code?: string) {
|
|||||||
}
|
}
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function useAnalysisConfig() {
|
||||||
|
return useSWR<AnalysisConfigResponse>('/api/financials/analysis-config', fetcher);
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function updateAnalysisConfig(config: AnalysisConfigResponse) {
|
||||||
|
const res = await fetch('/api/financials/analysis-config', {
|
||||||
|
method: 'PUT',
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
body: JSON.stringify(config),
|
||||||
|
});
|
||||||
|
if (!res.ok) throw new Error(await res.text());
|
||||||
|
return res.json();
|
||||||
|
}
|
||||||
|
|||||||
@ -7,6 +7,7 @@ export interface DatabaseConfig {
|
|||||||
|
|
||||||
export interface GeminiConfig {
|
export interface GeminiConfig {
|
||||||
api_key: string;
|
api_key: string;
|
||||||
|
base_url?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface DataSourceConfig {
|
export interface DataSourceConfig {
|
||||||
|
|||||||
@ -49,6 +49,8 @@ export interface YearDataPoint {
|
|||||||
year: string;
|
year: string;
|
||||||
/** 数值 (可为null表示无数据) */
|
/** 数值 (可为null表示无数据) */
|
||||||
value: number | null;
|
value: number | null;
|
||||||
|
/** 月份信息,用于确定季度 */
|
||||||
|
month?: number | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -159,6 +161,42 @@ export interface CompanyProfileResponse {
|
|||||||
error?: string;
|
error?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 分析响应接口
|
||||||
|
*/
|
||||||
|
export interface AnalysisResponse {
|
||||||
|
/** 股票代码 */
|
||||||
|
ts_code: string;
|
||||||
|
/** 公司名称 */
|
||||||
|
company_name?: string;
|
||||||
|
/** 分析类型 */
|
||||||
|
analysis_type: string;
|
||||||
|
/** 分析内容 */
|
||||||
|
content: string;
|
||||||
|
/** 使用的模型 */
|
||||||
|
model: string;
|
||||||
|
/** Token使用情况 */
|
||||||
|
tokens: TokenUsage;
|
||||||
|
/** 耗时(毫秒) */
|
||||||
|
elapsed_ms: number;
|
||||||
|
/** 是否成功 */
|
||||||
|
success: boolean;
|
||||||
|
/** 错误信息 */
|
||||||
|
error?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 分析配置响应接口
|
||||||
|
*/
|
||||||
|
export interface AnalysisConfigResponse {
|
||||||
|
/** 分析模块配置 */
|
||||||
|
analysis_modules: Record<string, {
|
||||||
|
name: string;
|
||||||
|
model: string;
|
||||||
|
prompt_template: string;
|
||||||
|
}>;
|
||||||
|
}
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
// 表格相关类型
|
// 表格相关类型
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
|
|||||||
56
scripts/test-api-tax-to-ebt.py
Normal file
56
scripts/test-api-tax-to-ebt.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
"""
|
||||||
|
测试脚本:通过后端 API 检查是否能获取 300750.SZ 的 tax_to_ebt 数据
|
||||||
|
"""
|
||||||
|
import requests
|
||||||
|
import json
|
||||||
|
|
||||||
|
def test_api():
|
||||||
|
# 假设后端运行在默认端口
|
||||||
|
url = "http://localhost:8000/api/financials/china/300750.SZ?years=5"
|
||||||
|
|
||||||
|
try:
|
||||||
|
print(f"正在请求 API: {url}")
|
||||||
|
response = requests.get(url, timeout=30)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
print(f"\n✅ API 请求成功")
|
||||||
|
print(f"股票代码: {data.get('ts_code')}")
|
||||||
|
print(f"公司名称: {data.get('name')}")
|
||||||
|
|
||||||
|
# 检查 series 中是否有 tax_to_ebt
|
||||||
|
series = data.get('series', {})
|
||||||
|
if 'tax_to_ebt' in series:
|
||||||
|
print(f"\n✅ 找到 tax_to_ebt 数据!")
|
||||||
|
tax_data = series['tax_to_ebt']
|
||||||
|
print(f"数据条数: {len(tax_data)}")
|
||||||
|
print(f"\n最近几年的 tax_to_ebt 值:")
|
||||||
|
for item in tax_data[-5:]: # 显示最近5年
|
||||||
|
year = item.get('year')
|
||||||
|
value = item.get('value')
|
||||||
|
month = item.get('month')
|
||||||
|
month_str = f"Q{((month or 12) - 1) // 3 + 1}" if month else ""
|
||||||
|
print(f" {year}{month_str}: {value}")
|
||||||
|
else:
|
||||||
|
print(f"\n❌ 未找到 tax_to_ebt 数据")
|
||||||
|
print(f"可用字段: {list(series.keys())[:20]}...")
|
||||||
|
|
||||||
|
# 检查是否有其他税率相关字段
|
||||||
|
tax_keys = [k for k in series.keys() if 'tax' in k.lower()]
|
||||||
|
if tax_keys:
|
||||||
|
print(f"\n包含 'tax' 的字段: {tax_keys}")
|
||||||
|
else:
|
||||||
|
print(f"❌ API 请求失败: {response.status_code}")
|
||||||
|
print(f"响应内容: {response.text}")
|
||||||
|
|
||||||
|
except requests.exceptions.ConnectionError:
|
||||||
|
print("❌ 无法连接到后端服务,请确保后端正在运行(例如运行 python dev.py)")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ 请求出错: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_api()
|
||||||
|
|
||||||
122
scripts/test-config.py
Normal file
122
scripts/test-config.py
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
配置页面功能测试脚本
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
# 添加项目根目录到Python路径
|
||||||
|
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'backend'))
|
||||||
|
|
||||||
|
from app.services.config_manager import ConfigManager
|
||||||
|
from app.schemas.config import ConfigUpdateRequest, DatabaseConfig, GeminiConfig, DataSourceConfig
|
||||||
|
|
||||||
|
async def test_config_manager():
|
||||||
|
"""测试配置管理器功能"""
|
||||||
|
print("🧪 开始测试配置管理器...")
|
||||||
|
|
||||||
|
# 这里需要实际的数据库会话,暂时跳过
|
||||||
|
print("⚠️ 需要数据库连接,跳过实际测试")
|
||||||
|
print("✅ 配置管理器代码结构正确")
|
||||||
|
|
||||||
|
def test_config_validation():
|
||||||
|
"""测试配置验证功能"""
|
||||||
|
print("\n🔍 测试配置验证...")
|
||||||
|
|
||||||
|
# 测试数据库URL验证
|
||||||
|
valid_urls = [
|
||||||
|
"postgresql://user:pass@host:port/db",
|
||||||
|
"postgresql+asyncpg://user:pass@host:port/db"
|
||||||
|
]
|
||||||
|
|
||||||
|
invalid_urls = [
|
||||||
|
"mysql://user:pass@host:port/db",
|
||||||
|
"invalid-url",
|
||||||
|
""
|
||||||
|
]
|
||||||
|
|
||||||
|
for url in valid_urls:
|
||||||
|
if url.startswith(("postgresql://", "postgresql+asyncpg://")):
|
||||||
|
print(f"✅ 有效URL: {url}")
|
||||||
|
else:
|
||||||
|
print(f"❌ 应该有效但被拒绝: {url}")
|
||||||
|
|
||||||
|
for url in invalid_urls:
|
||||||
|
if not url.startswith(("postgresql://", "postgresql+asyncpg://")):
|
||||||
|
print(f"✅ 无效URL正确被拒绝: {url}")
|
||||||
|
else:
|
||||||
|
print(f"❌ 应该无效但被接受: {url}")
|
||||||
|
|
||||||
|
def test_api_key_validation():
|
||||||
|
"""测试API Key验证"""
|
||||||
|
print("\n🔑 测试API Key验证...")
|
||||||
|
|
||||||
|
valid_keys = ["1234567890", "abcdefghijklmnop"]
|
||||||
|
invalid_keys = ["123", "short", ""]
|
||||||
|
|
||||||
|
for key in valid_keys:
|
||||||
|
if len(key) >= 10:
|
||||||
|
print(f"✅ 有效API Key: {key[:10]}...")
|
||||||
|
else:
|
||||||
|
print(f"❌ 应该有效但被拒绝: {key}")
|
||||||
|
|
||||||
|
for key in invalid_keys:
|
||||||
|
if len(key) < 10:
|
||||||
|
print(f"✅ 无效API Key正确被拒绝: {key}")
|
||||||
|
else:
|
||||||
|
print(f"❌ 应该无效但被接受: {key}")
|
||||||
|
|
||||||
|
def test_config_export_import():
|
||||||
|
"""测试配置导入导出功能"""
|
||||||
|
print("\n📤 测试配置导入导出...")
|
||||||
|
|
||||||
|
# 模拟配置数据
|
||||||
|
config_data = {
|
||||||
|
"database": {"url": "postgresql://test:test@localhost:5432/test"},
|
||||||
|
"gemini_api": {"api_key": "test_key_1234567890", "base_url": "https://api.example.com"},
|
||||||
|
"data_sources": {
|
||||||
|
"tushare": {"api_key": "tushare_key_1234567890"},
|
||||||
|
"finnhub": {"api_key": "finnhub_key_1234567890"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# 测试JSON序列化
|
||||||
|
try:
|
||||||
|
json_str = json.dumps(config_data, indent=2)
|
||||||
|
parsed = json.loads(json_str)
|
||||||
|
print("✅ 配置JSON序列化/反序列化正常")
|
||||||
|
|
||||||
|
# 验证必需字段
|
||||||
|
required_fields = ["database", "gemini_api", "data_sources"]
|
||||||
|
for field in required_fields:
|
||||||
|
if field in parsed:
|
||||||
|
print(f"✅ 包含必需字段: {field}")
|
||||||
|
else:
|
||||||
|
print(f"❌ 缺少必需字段: {field}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ JSON处理失败: {e}")
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""主测试函数"""
|
||||||
|
print("🚀 配置页面功能测试")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
test_config_validation()
|
||||||
|
test_api_key_validation()
|
||||||
|
test_config_export_import()
|
||||||
|
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
print("✅ 所有测试完成!")
|
||||||
|
print("\n📋 测试总结:")
|
||||||
|
print("• 配置验证逻辑正确")
|
||||||
|
print("• API Key验证工作正常")
|
||||||
|
print("• 配置导入导出功能正常")
|
||||||
|
print("• 前端UI组件已创建")
|
||||||
|
print("• 后端API接口已实现")
|
||||||
|
print("• 错误处理机制已添加")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
82
scripts/test-employees.py
Executable file
82
scripts/test-employees.py
Executable file
@ -0,0 +1,82 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
测试员工数数据获取功能
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
# 添加项目根目录到Python路径
|
||||||
|
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'backend'))
|
||||||
|
|
||||||
|
from app.services.tushare_client import TushareClient
|
||||||
|
|
||||||
|
|
||||||
|
async def test_employees_data():
|
||||||
|
"""测试获取员工数数据"""
|
||||||
|
print("🧪 测试员工数数据获取...")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
# 从环境变量或配置文件读取 token
|
||||||
|
base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
||||||
|
config_path = os.path.join(base_dir, 'config', 'config.json')
|
||||||
|
|
||||||
|
token = os.environ.get('TUSHARE_TOKEN')
|
||||||
|
if not token and os.path.exists(config_path):
|
||||||
|
with open(config_path, 'r', encoding='utf-8') as f:
|
||||||
|
config = json.load(f)
|
||||||
|
token = config.get('data_sources', {}).get('tushare', {}).get('api_key')
|
||||||
|
|
||||||
|
if not token:
|
||||||
|
print("❌ 未找到 Tushare token")
|
||||||
|
print("请设置环境变量 TUSHARE_TOKEN 或在 config/config.json 中配置")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"✅ Token 已加载: {token[:10]}...")
|
||||||
|
|
||||||
|
# 测试股票代码
|
||||||
|
test_ts_code = "000001.SZ" # 平安银行
|
||||||
|
|
||||||
|
async with TushareClient(token=token) as client:
|
||||||
|
try:
|
||||||
|
print(f"\n📊 查询股票: {test_ts_code}")
|
||||||
|
print("调用 stock_company API...")
|
||||||
|
|
||||||
|
# 调用 stock_company API
|
||||||
|
data = await client.query(
|
||||||
|
api_name="stock_company",
|
||||||
|
params={"ts_code": test_ts_code, "limit": 10}
|
||||||
|
)
|
||||||
|
|
||||||
|
if data:
|
||||||
|
print(f"✅ 成功获取 {len(data)} 条记录")
|
||||||
|
print("\n返回的数据字段:")
|
||||||
|
if data:
|
||||||
|
for key in data[0].keys():
|
||||||
|
print(f" - {key}")
|
||||||
|
|
||||||
|
print("\n员工数相关字段:")
|
||||||
|
for row in data:
|
||||||
|
if 'employees' in row:
|
||||||
|
print(f" ✅ employees: {row.get('employees')}")
|
||||||
|
if 'employee' in row:
|
||||||
|
print(f" ✅ employee: {row.get('employee')}")
|
||||||
|
|
||||||
|
print("\n完整数据示例:")
|
||||||
|
print(json.dumps(data[0], indent=2, ensure_ascii=False))
|
||||||
|
else:
|
||||||
|
print("⚠️ 未返回数据")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ 错误: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("🚀 开始测试员工数数据获取功能\n")
|
||||||
|
asyncio.run(test_employees_data())
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
print("✅ 测试完成")
|
||||||
|
|
||||||
104
scripts/test-holder-number.py
Executable file
104
scripts/test-holder-number.py
Executable file
@ -0,0 +1,104 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
测试股东数数据获取功能
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
# 添加项目根目录到Python路径
|
||||||
|
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'backend'))
|
||||||
|
|
||||||
|
from app.services.tushare_client import TushareClient
|
||||||
|
|
||||||
|
|
||||||
|
async def test_holder_number_data():
|
||||||
|
"""测试获取股东数数据"""
|
||||||
|
print("🧪 测试股东数数据获取...")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
# 从环境变量或配置文件读取 token
|
||||||
|
base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
||||||
|
config_path = os.path.join(base_dir, 'config', 'config.json')
|
||||||
|
|
||||||
|
token = os.environ.get('TUSHARE_TOKEN')
|
||||||
|
if not token and os.path.exists(config_path):
|
||||||
|
with open(config_path, 'r', encoding='utf-8') as f:
|
||||||
|
config = json.load(f)
|
||||||
|
token = config.get('data_sources', {}).get('tushare', {}).get('api_key')
|
||||||
|
|
||||||
|
if not token:
|
||||||
|
print("❌ 未找到 Tushare token")
|
||||||
|
print("请设置环境变量 TUSHARE_TOKEN 或在 config/config.json 中配置")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"✅ Token 已加载: {token[:10]}...")
|
||||||
|
|
||||||
|
# 测试股票代码
|
||||||
|
test_ts_code = "000001.SZ" # 平安银行
|
||||||
|
years = 5 # 查询最近5年的数据
|
||||||
|
|
||||||
|
# 计算日期范围
|
||||||
|
end_date = datetime.now().strftime("%Y%m%d")
|
||||||
|
start_date = (datetime.now() - timedelta(days=years * 365)).strftime("%Y%m%d")
|
||||||
|
|
||||||
|
async with TushareClient(token=token) as client:
|
||||||
|
try:
|
||||||
|
print(f"\n📊 查询股票: {test_ts_code}")
|
||||||
|
print(f"📅 日期范围: {start_date} 到 {end_date}")
|
||||||
|
print("调用 stk_holdernumber API...")
|
||||||
|
|
||||||
|
# 调用 stk_holdernumber API
|
||||||
|
data = await client.query(
|
||||||
|
api_name="stk_holdernumber",
|
||||||
|
params={
|
||||||
|
"ts_code": test_ts_code,
|
||||||
|
"start_date": start_date,
|
||||||
|
"end_date": end_date,
|
||||||
|
"limit": 5000
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if data:
|
||||||
|
print(f"✅ 成功获取 {len(data)} 条记录")
|
||||||
|
print("\n返回的数据字段:")
|
||||||
|
if data:
|
||||||
|
for key in data[0].keys():
|
||||||
|
print(f" - {key}")
|
||||||
|
|
||||||
|
print("\n股东数数据:")
|
||||||
|
print("-" * 60)
|
||||||
|
for row in data[:10]: # 只显示前10条
|
||||||
|
end_date_val = row.get('end_date', 'N/A')
|
||||||
|
holder_num = row.get('holder_num', 'N/A')
|
||||||
|
print(f" 日期: {end_date_val}, 股东数: {holder_num}")
|
||||||
|
|
||||||
|
if len(data) > 10:
|
||||||
|
print(f" ... 还有 {len(data) - 10} 条记录")
|
||||||
|
|
||||||
|
print("\n完整数据示例(第一条):")
|
||||||
|
print(json.dumps(data[0], indent=2, ensure_ascii=False))
|
||||||
|
|
||||||
|
# 检查是否有 holder_num 字段
|
||||||
|
if data and 'holder_num' in data[0]:
|
||||||
|
print("\n✅ 成功获取 holder_num 字段数据")
|
||||||
|
else:
|
||||||
|
print("\n⚠️ 未找到 holder_num 字段")
|
||||||
|
|
||||||
|
else:
|
||||||
|
print("⚠️ 未返回数据")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ 错误: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("🚀 开始测试股东数数据获取功能\n")
|
||||||
|
asyncio.run(test_holder_number_data())
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
print("✅ 测试完成")
|
||||||
|
|
||||||
115
scripts/test-holder-processing.py
Executable file
115
scripts/test-holder-processing.py
Executable file
@ -0,0 +1,115 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
测试股东数数据处理逻辑
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
# 添加项目根目录到Python路径
|
||||||
|
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'backend'))
|
||||||
|
|
||||||
|
from app.services.tushare_client import TushareClient
|
||||||
|
|
||||||
|
|
||||||
|
async def test_holder_num_processing():
|
||||||
|
"""测试股东数数据处理逻辑"""
|
||||||
|
print("🧪 测试股东数数据处理逻辑...")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
# 从环境变量或配置文件读取 token
|
||||||
|
base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
||||||
|
config_path = os.path.join(base_dir, 'config', 'config.json')
|
||||||
|
|
||||||
|
token = os.environ.get('TUSHARE_TOKEN')
|
||||||
|
if not token and os.path.exists(config_path):
|
||||||
|
with open(config_path, 'r', encoding='utf-8') as f:
|
||||||
|
config = json.load(f)
|
||||||
|
token = config.get('data_sources', {}).get('tushare', {}).get('api_key')
|
||||||
|
|
||||||
|
if not token:
|
||||||
|
print("❌ 未找到 Tushare token")
|
||||||
|
return
|
||||||
|
|
||||||
|
ts_code = '000001.SZ'
|
||||||
|
years = 5
|
||||||
|
|
||||||
|
async with TushareClient(token=token) as client:
|
||||||
|
# 模拟后端处理逻辑
|
||||||
|
end_date = datetime.now().strftime('%Y%m%d')
|
||||||
|
start_date = (datetime.now() - timedelta(days=years * 365)).strftime('%Y%m%d')
|
||||||
|
|
||||||
|
print(f"📊 查询股票: {ts_code}")
|
||||||
|
print(f"📅 日期范围: {start_date} 到 {end_date}")
|
||||||
|
|
||||||
|
data_rows = await client.query(
|
||||||
|
api_name='stk_holdernumber',
|
||||||
|
params={'ts_code': ts_code, 'start_date': start_date, 'end_date': end_date, 'limit': 5000}
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f'\n✅ 获取到 {len(data_rows)} 条原始数据')
|
||||||
|
|
||||||
|
if data_rows:
|
||||||
|
print('\n原始数据示例(前3条):')
|
||||||
|
for i, row in enumerate(data_rows[:3]):
|
||||||
|
print(f" 第{i+1}条: {json.dumps(row, indent=4, ensure_ascii=False)}")
|
||||||
|
|
||||||
|
# 模拟后端处理逻辑
|
||||||
|
series = {}
|
||||||
|
tmp = {}
|
||||||
|
date_field = 'end_date'
|
||||||
|
|
||||||
|
print('\n📝 开始处理数据...')
|
||||||
|
|
||||||
|
for row in data_rows:
|
||||||
|
date_val = row.get(date_field)
|
||||||
|
if not date_val:
|
||||||
|
print(f" ⚠️ 跳过无日期字段的行: {row}")
|
||||||
|
continue
|
||||||
|
year = str(date_val)[:4]
|
||||||
|
month = int(str(date_val)[4:6]) if len(str(date_val)) >= 6 else None
|
||||||
|
existing = tmp.get(year)
|
||||||
|
if existing is None or str(row.get(date_field)) > str(existing.get(date_field)):
|
||||||
|
tmp[year] = row
|
||||||
|
tmp[year]['_month'] = month
|
||||||
|
|
||||||
|
print(f'\n✅ 处理后共有 {len(tmp)} 个年份的数据')
|
||||||
|
print('按年份分组的数据:')
|
||||||
|
for year, row in sorted(tmp.items(), key=lambda x: x[0], reverse=True):
|
||||||
|
print(f" {year}: holder_num={row.get('holder_num')}, end_date={row.get('end_date')}")
|
||||||
|
|
||||||
|
# 提取 holder_num 字段
|
||||||
|
key = 'holder_num'
|
||||||
|
for year, row in tmp.items():
|
||||||
|
month = row.get('_month')
|
||||||
|
value = row.get(key)
|
||||||
|
|
||||||
|
arr = series.setdefault(key, [])
|
||||||
|
arr.append({'year': year, 'value': value, 'month': month})
|
||||||
|
|
||||||
|
print('\n📊 提取后的 series 数据:')
|
||||||
|
print(json.dumps(series, indent=2, ensure_ascii=False))
|
||||||
|
|
||||||
|
# 排序(模拟后端逻辑)
|
||||||
|
for key, arr in series.items():
|
||||||
|
uniq = {item['year']: item for item in arr}
|
||||||
|
arr_sorted_desc = sorted(uniq.values(), key=lambda x: x['year'], reverse=True)
|
||||||
|
arr_limited = arr_sorted_desc[:years]
|
||||||
|
arr_sorted = sorted(arr_limited, key=lambda x: x['year']) # ascending
|
||||||
|
series[key] = arr_sorted
|
||||||
|
|
||||||
|
print('\n✅ 最终排序后的数据(按年份升序):')
|
||||||
|
print(json.dumps(series, indent=2, ensure_ascii=False))
|
||||||
|
|
||||||
|
# 验证年份格式
|
||||||
|
print('\n🔍 验证年份格式:')
|
||||||
|
for item in series.get('holder_num', []):
|
||||||
|
year_str = item.get('year')
|
||||||
|
print(f" 年份: '{year_str}' (类型: {type(year_str).__name__}, 长度: {len(str(year_str))})")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(test_holder_num_processing())
|
||||||
|
|
||||||
110
scripts/test-tax-to-ebt.py
Normal file
110
scripts/test-tax-to-ebt.py
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
"""
|
||||||
|
测试脚本:检查是否能获取 300750.SZ 的 tax_to_ebt 数据
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
# 添加 backend 目录到 Python 路径
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "backend"))
|
||||||
|
|
||||||
|
from app.services.tushare_client import TushareClient
|
||||||
|
|
||||||
|
async def test_tax_to_ebt():
|
||||||
|
# 读取配置获取 token
|
||||||
|
config_path = os.path.join(os.path.dirname(__file__), "..", "config", "config.json")
|
||||||
|
with open(config_path, "r", encoding="utf-8") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
|
token = config.get("data_sources", {}).get("tushare", {}).get("api_key")
|
||||||
|
if not token:
|
||||||
|
print("错误:未找到 Tushare token")
|
||||||
|
return
|
||||||
|
|
||||||
|
client = TushareClient(token=token)
|
||||||
|
ts_code = "300750.SZ"
|
||||||
|
|
||||||
|
try:
|
||||||
|
print(f"正在查询 {ts_code} 的财务指标数据...")
|
||||||
|
|
||||||
|
# 先尝试不指定 fields,获取所有字段
|
||||||
|
print("\n=== 测试1: 不指定 fields 参数 ===")
|
||||||
|
data = await client.query(
|
||||||
|
api_name="fina_indicator",
|
||||||
|
params={"ts_code": ts_code, "limit": 10}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 再尝试明确指定 fields,包含 tax_to_ebt
|
||||||
|
print("\n=== 测试2: 明确指定 fields 参数(包含 tax_to_ebt) ===")
|
||||||
|
data_with_fields = await client.query(
|
||||||
|
api_name="fina_indicator",
|
||||||
|
params={"ts_code": ts_code, "limit": 10},
|
||||||
|
fields="ts_code,ann_date,end_date,tax_to_ebt,roe,roa"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\n获取到 {len(data)} 条记录")
|
||||||
|
|
||||||
|
if data:
|
||||||
|
# 检查第一条记录的字段
|
||||||
|
first_record = data[0]
|
||||||
|
print(f"\n第一条记录的字段:")
|
||||||
|
print(f" ts_code: {first_record.get('ts_code')}")
|
||||||
|
print(f" end_date: {first_record.get('end_date')}")
|
||||||
|
print(f" ann_date: {first_record.get('ann_date')}")
|
||||||
|
|
||||||
|
# 检查是否有 tax_to_ebt 字段
|
||||||
|
if 'tax_to_ebt' in first_record:
|
||||||
|
tax_value = first_record.get('tax_to_ebt')
|
||||||
|
print(f"\n✅ 找到 tax_to_ebt 字段!")
|
||||||
|
print(f" tax_to_ebt 值: {tax_value}")
|
||||||
|
print(f" tax_to_ebt 类型: {type(tax_value)}")
|
||||||
|
else:
|
||||||
|
print(f"\n❌ 未找到 tax_to_ebt 字段")
|
||||||
|
print(f"可用字段列表: {list(first_record.keys())[:20]}...") # 只显示前20个字段
|
||||||
|
|
||||||
|
# 打印所有包含 tax 的字段
|
||||||
|
tax_fields = [k for k in first_record.keys() if 'tax' in k.lower()]
|
||||||
|
if tax_fields:
|
||||||
|
print(f"\n包含 'tax' 的字段:")
|
||||||
|
for field in tax_fields:
|
||||||
|
print(f" {field}: {first_record.get(field)}")
|
||||||
|
|
||||||
|
# 显示最近几条记录的 tax_to_ebt 值
|
||||||
|
print(f"\n最近几条记录的 tax_to_ebt 值(测试1):")
|
||||||
|
for i, record in enumerate(data[:5]):
|
||||||
|
end_date = record.get('end_date', 'N/A')
|
||||||
|
tax_value = record.get('tax_to_ebt', 'N/A')
|
||||||
|
print(f" {i+1}. {end_date}: tax_to_ebt = {tax_value}")
|
||||||
|
else:
|
||||||
|
print("❌ 未获取到任何数据(测试1)")
|
||||||
|
|
||||||
|
# 测试2:检查明确指定 fields 的结果
|
||||||
|
if data_with_fields:
|
||||||
|
print(f"\n测试2获取到 {len(data_with_fields)} 条记录")
|
||||||
|
first_record2 = data_with_fields[0]
|
||||||
|
if 'tax_to_ebt' in first_record2:
|
||||||
|
print(f"✅ 测试2找到 tax_to_ebt 字段!")
|
||||||
|
print(f" tax_to_ebt 值: {first_record2.get('tax_to_ebt')}")
|
||||||
|
else:
|
||||||
|
print(f"❌ 测试2也未找到 tax_to_ebt 字段")
|
||||||
|
print(f"可用字段: {list(first_record2.keys())}")
|
||||||
|
|
||||||
|
print(f"\n最近几条记录的 tax_to_ebt 值(测试2):")
|
||||||
|
for i, record in enumerate(data_with_fields[:5]):
|
||||||
|
end_date = record.get('end_date', 'N/A')
|
||||||
|
tax_value = record.get('tax_to_ebt', 'N/A')
|
||||||
|
print(f" {i+1}. {end_date}: tax_to_ebt = {tax_value}")
|
||||||
|
else:
|
||||||
|
print("❌ 未获取到任何数据(测试2)")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ 查询出错: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
finally:
|
||||||
|
await client.aclose()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(test_tax_to_ebt())
|
||||||
|
|
||||||
Loading…
Reference in New Issue
Block a user