From e0aa61b8c4d676a33bad2f469f91cbb0b275630a Mon Sep 17 00:00:00 2001 From: xucheng Date: Wed, 29 Oct 2025 22:49:27 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E5=88=86=E6=9E=90?= =?UTF-8?q?=E6=A8=A1=E5=9D=97=E9=85=8D=E7=BD=AE=E5=92=8C=E5=88=86=E6=9E=90?= =?UTF-8?q?=E5=8A=9F=E8=83=BD=EF=BC=8C=E6=9B=B4=E6=96=B0=E8=B4=A2=E5=8A=A1?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=A4=84=E7=90=86=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/app/routers/config.py | 19 +- backend/app/routers/financial.py | 333 +++- backend/app/schemas/financial.py | 17 + backend/app/services/analysis_client.py | 136 ++ backend/app/services/config_manager.py | 267 ++- backend/requirements.txt | 1 + config/analysis-config.json | 49 + config/financial-tushare.json | 3 +- frontend/src/app/config/page.tsx | 636 ++++++- frontend/src/app/layout.tsx | 3 + frontend/src/app/report/[symbol]/page.tsx | 1677 +++++++++++++++-- frontend/src/components/TradingViewWidget.tsx | 147 ++ frontend/src/components/ui/label.tsx | 21 + frontend/src/components/ui/separator.tsx | 29 + frontend/src/components/ui/textarea.tsx | 25 + frontend/src/hooks/useApi.ts | 45 +- frontend/src/stores/useConfigStore.ts | 1 + frontend/src/types/index.ts | 38 + scripts/test-api-tax-to-ebt.py | 56 + scripts/test-config.py | 122 ++ scripts/test-employees.py | 82 + scripts/test-holder-number.py | 104 + scripts/test-holder-processing.py | 115 ++ scripts/test-tax-to-ebt.py | 110 ++ 24 files changed, 3735 insertions(+), 301 deletions(-) create mode 100644 backend/app/services/analysis_client.py create mode 100644 config/analysis-config.json create mode 100644 frontend/src/components/TradingViewWidget.tsx create mode 100644 frontend/src/components/ui/label.tsx create mode 100644 frontend/src/components/ui/separator.tsx create mode 100644 frontend/src/components/ui/textarea.tsx create mode 100644 scripts/test-api-tax-to-ebt.py create mode 100644 scripts/test-config.py create mode 100755 scripts/test-employees.py create mode 100755 scripts/test-holder-number.py create mode 100755 scripts/test-holder-processing.py create mode 100644 scripts/test-tax-to-ebt.py diff --git a/backend/app/routers/config.py b/backend/app/routers/config.py index 37d794e..0c3846f 100644 --- a/backend/app/routers/config.py +++ b/backend/app/routers/config.py @@ -28,11 +28,14 @@ async def test_config( config_manager: ConfigManager = Depends(get_config_manager) ): """Test a specific configuration (e.g., database connection).""" - # The test logic will be implemented in a subsequent step inside the ConfigManager - # For now, we return a placeholder response. - # test_result = await config_manager.test_config( - # test_request.config_type, - # test_request.config_data - # ) - # return test_result - raise HTTPException(status_code=501, detail="Not Implemented") + try: + test_result = await config_manager.test_config( + test_request.config_type, + test_request.config_data + ) + return test_result + except Exception as e: + return ConfigTestResponse( + success=False, + message=f"测试失败: {str(e)}" + ) diff --git a/backend/app/routers/financial.py b/backend/app/routers/financial.py index 2bd00b8..d6b872f 100644 --- a/backend/app/routers/financial.py +++ b/backend/app/routers/financial.py @@ -4,7 +4,7 @@ API router for financial data (Tushare for China market) import json import os import time -from datetime import datetime, timezone +from datetime import datetime, timezone, timedelta from typing import Dict, List from fastapi import APIRouter, HTTPException, Query @@ -12,9 +12,18 @@ from fastapi.responses import StreamingResponse import os from app.core.config import settings -from app.schemas.financial import BatchFinancialDataResponse, FinancialConfigResponse, FinancialMeta, StepRecord, CompanyProfileResponse +from app.schemas.financial import ( + BatchFinancialDataResponse, + FinancialConfigResponse, + FinancialMeta, + StepRecord, + CompanyProfileResponse, + AnalysisResponse, + AnalysisConfigResponse +) from app.services.tushare_client import TushareClient from app.services.company_profile_client import CompanyProfileClient +from app.services.analysis_client import AnalysisClient, load_analysis_config, get_analysis_config router = APIRouter() @@ -23,6 +32,7 @@ router = APIRouter() REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) FINANCIAL_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "financial-tushare.json") BASE_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "config.json") +ANALYSIS_CONFIG_PATH = os.path.join(REPO_ROOT, "config", "analysis-config.json") def _load_json(path: str) -> Dict: @@ -86,14 +96,16 @@ async def get_china_financials( series: Dict[str, List[Dict]] = {} # Helper to store year-value pairs while keeping most recent per year - def _merge_year_value(key: str, year: str, value): + def _merge_year_value(key: str, year: str, value, month: int = None): arr = series.setdefault(key, []) # upsert by year for item in arr: if item["year"] == year: item["value"] = value + if month is not None: + item["month"] = month return - arr.append({"year": year, "value": value}) + arr.append({"year": year, "value": value, "month": month}) # Query each API group we care errors: Dict[str, str] = {} @@ -107,39 +119,96 @@ async def get_china_financials( current_action = step.name if not metrics: continue - api_name = metrics[0].get("api") or group_name - fields = list({m.get("tushareParam") for m in metrics if m.get("tushareParam")}) - if not fields: - continue - - date_field = "end_date" if group_name in ("fina_indicator", "income", "balancesheet", "cashflow") else "trade_date" - try: - data_rows = await client.query(api_name=api_name, params={"ts_code": ts_code, "limit": 5000}, fields=None) - api_calls_total += 1 - api_calls_by_group[group_name] = api_calls_by_group.get(group_name, 0) + 1 - except Exception as e: - step.status = "error" - step.error = str(e) - step.end_ts = datetime.now(timezone.utc).isoformat() - step.duration_ms = int((time.perf_counter_ns() - started) / 1_000_000) - errors[group_name] = str(e) - continue - - tmp: Dict[str, Dict] = {} - for row in data_rows: - date_val = row.get(date_field) - if not date_val: - continue - year = str(date_val)[:4] - existing = tmp.get(year) - if existing is None or str(row.get(date_field)) > str(existing.get(date_field)): - tmp[year] = row + + # 按 API 分组 metrics(处理 unknown 组中有多个不同 API 的情况) + api_groups_dict: Dict[str, List[Dict]] = {} for metric in metrics: - key = metric.get("tushareParam") - if not key: + api = metric.get("api") or group_name + if api: # 跳过空 API + if api not in api_groups_dict: + api_groups_dict[api] = [] + api_groups_dict[api].append(metric) + + # 对每个 API 分别处理 + for api_name, api_metrics in api_groups_dict.items(): + fields = [m.get("tushareParam") for m in api_metrics if m.get("tushareParam")] + if not fields: continue - for year, row in tmp.items(): - _merge_year_value(key, year, row.get(key)) + + date_field = "end_date" if group_name in ("fina_indicator", "income", "balancesheet", "cashflow") else "trade_date" + + # 构建 API 参数 + params = {"ts_code": ts_code, "limit": 5000} + + # 对于需要日期范围的 API(如 stk_holdernumber),添加日期参数 + if api_name == "stk_holdernumber": + # 计算日期范围:从 years 年前到现在 + end_date = datetime.now().strftime("%Y%m%d") + start_date = (datetime.now() - timedelta(days=years * 365)).strftime("%Y%m%d") + params["start_date"] = start_date + params["end_date"] = end_date + # stk_holdernumber 返回的日期字段通常是 end_date + date_field = "end_date" + + # 对于非时间序列 API(如 stock_company),标记为静态数据 + is_static_data = api_name == "stock_company" + + # 构建 fields 字符串:包含日期字段和所有需要的指标字段 + # 确保日期字段存在,因为我们需要用它来确定年份 + fields_list = list(fields) + if date_field not in fields_list: + fields_list.insert(0, date_field) + # 对于 fina_indicator 等 API,通常还需要 ts_code 和 ann_date + if api_name in ("fina_indicator", "income", "balancesheet", "cashflow"): + for req_field in ["ts_code", "ann_date"]: + if req_field not in fields_list: + fields_list.insert(0, req_field) + fields_str = ",".join(fields_list) + + try: + data_rows = await client.query(api_name=api_name, params=params, fields=fields_str) + api_calls_total += 1 + api_calls_by_group[group_name] = api_calls_by_group.get(group_name, 0) + 1 + except Exception as e: + # 记录错误但继续处理其他 API + error_key = f"{group_name}_{api_name}" + errors[error_key] = str(e) + continue + + tmp: Dict[str, Dict] = {} + current_year = datetime.now().strftime("%Y") + + for row in data_rows: + if is_static_data: + # 对于静态数据(如 stock_company),使用当前年份 + # 只处理第一行数据,因为静态数据通常只有一行 + if current_year not in tmp: + year = current_year + month = None + tmp[year] = row + tmp[year]['_month'] = month + # 跳过后续行 + continue + else: + # 对于时间序列数据,按日期字段处理 + date_val = row.get(date_field) + if not date_val: + continue + year = str(date_val)[:4] + month = int(str(date_val)[4:6]) if len(str(date_val)) >= 6 else None + existing = tmp.get(year) + if existing is None or str(row.get(date_field)) > str(existing.get(date_field)): + tmp[year] = row + tmp[year]['_month'] = month + + for metric in api_metrics: + key = metric.get("tushareParam") + if not key: + continue + for year, row in tmp.items(): + month = row.get('_month') + _merge_year_value(key, year, row.get(key), month) + step.status = "done" step.end_ts = datetime.now(timezone.utc).isoformat() step.duration_ms = int((time.perf_counter_ns() - started) / 1_000_000) @@ -247,3 +316,197 @@ async def get_company_profile( success=result.get("success", False), error=result.get("error") ) + + +@router.get("/analysis-config", response_model=AnalysisConfigResponse) +async def get_analysis_config_endpoint(): + """Get analysis configuration""" + config = load_analysis_config() + return AnalysisConfigResponse(analysis_modules=config.get("analysis_modules", {})) + + +@router.put("/analysis-config", response_model=AnalysisConfigResponse) +async def update_analysis_config_endpoint(analysis_config: AnalysisConfigResponse): + """Update analysis configuration""" + import logging + logger = logging.getLogger(__name__) + + try: + # 保存到文件 + config_data = { + "analysis_modules": analysis_config.analysis_modules + } + + with open(ANALYSIS_CONFIG_PATH, "w", encoding="utf-8") as f: + json.dump(config_data, f, ensure_ascii=False, indent=2) + + logger.info(f"[API] Analysis config updated successfully") + return AnalysisConfigResponse(analysis_modules=analysis_config.analysis_modules) + except Exception as e: + logger.error(f"[API] Failed to update analysis config: {e}") + raise HTTPException( + status_code=500, + detail=f"Failed to update analysis config: {str(e)}" + ) + + +@router.get("/china/{ts_code}/analysis/{analysis_type}", response_model=AnalysisResponse) +async def generate_analysis( + ts_code: str, + analysis_type: str, + company_name: str = Query(None, description="Company name for better context"), +): + """ + Generate analysis for a company using Gemini AI + Supported analysis types: + - fundamental_analysis (基本面分析) + - bull_case (看涨分析) + - bear_case (看跌分析) + - market_analysis (市场分析) + - news_analysis (新闻分析) + - trading_analysis (交易分析) + - insider_institutional (内部人与机构动向分析) + - final_conclusion (最终结论) + """ + import logging + logger = logging.getLogger(__name__) + + logger.info(f"[API] Analysis requested for {ts_code}, type: {analysis_type}") + + # Load config + base_cfg = _load_json(BASE_CONFIG_PATH) + gemini_cfg = base_cfg.get("llm", {}).get("gemini", {}) + api_key = gemini_cfg.get("api_key") + + if not api_key: + logger.error("[API] Gemini API key not configured") + raise HTTPException( + status_code=500, + detail="Gemini API key not configured. Set config.json llm.gemini.api_key" + ) + + # Get analysis configuration + analysis_cfg = get_analysis_config(analysis_type) + if not analysis_cfg: + raise HTTPException( + status_code=404, + detail=f"Analysis type '{analysis_type}' not found in configuration" + ) + + model = analysis_cfg.get("model", "gemini-2.5-flash") + prompt_template = analysis_cfg.get("prompt_template", "") + + if not prompt_template: + raise HTTPException( + status_code=500, + detail=f"Prompt template not found for analysis type '{analysis_type}'" + ) + + # Get company name from ts_code if not provided + financial_data = None + if not company_name: + logger.info(f"[API] Fetching company name and financial data for {ts_code}") + try: + token = ( + os.environ.get("TUSHARE_TOKEN") + or settings.TUSHARE_TOKEN + or base_cfg.get("data_sources", {}).get("tushare", {}).get("api_key") + ) + if token: + tushare_client = TushareClient(token=token) + basic_data = await tushare_client.query(api_name="stock_basic", params={"ts_code": ts_code}, fields="ts_code,name") + if basic_data and len(basic_data) > 0: + company_name = basic_data[0].get("name", ts_code) + logger.info(f"[API] Got company name: {company_name}") + + # Try to get financial data for context + try: + fin_cfg = _load_json(FINANCIAL_CONFIG_PATH) + api_groups = fin_cfg.get("api_groups", {}) + + # Get financial data summary for context + series: Dict[str, List[Dict]] = {} + for group_name, metrics in api_groups.items(): + if not metrics: + continue + api_groups_dict: Dict[str, List[Dict]] = {} + for metric in metrics: + api = metric.get("api") or group_name + if api: + if api not in api_groups_dict: + api_groups_dict[api] = [] + api_groups_dict[api].append(metric) + + for api_name, api_metrics in api_groups_dict.items(): + fields = [m.get("tushareParam") for m in api_metrics if m.get("tushareParam")] + if not fields: + continue + + date_field = "end_date" if group_name in ("fina_indicator", "income", "balancesheet", "cashflow") else "trade_date" + + params = {"ts_code": ts_code, "limit": 500} + fields_list = list(fields) + if date_field not in fields_list: + fields_list.insert(0, date_field) + if api_name in ("fina_indicator", "income", "balancesheet", "cashflow"): + for req_field in ["ts_code", "ann_date"]: + if req_field not in fields_list: + fields_list.insert(0, req_field) + fields_str = ",".join(fields_list) + + try: + data_rows = await tushare_client.query(api_name=api_name, params=params, fields=fields_str) + if data_rows: + # Get latest year's data + latest_row = data_rows[0] if data_rows else {} + for metric in api_metrics: + key = metric.get("tushareParam") + if key and key in latest_row: + if key not in series: + series[key] = [] + series[key].append({ + "year": latest_row.get(date_field, "")[:4] if latest_row.get(date_field) else str(datetime.now().year), + "value": latest_row.get(key) + }) + except Exception: + pass + + financial_data = {"series": series} + except Exception as e: + logger.warning(f"[API] Failed to get financial data: {e}") + financial_data = None + else: + company_name = ts_code + else: + company_name = ts_code + except Exception as e: + logger.warning(f"[API] Failed to get company name: {e}") + company_name = ts_code + + logger.info(f"[API] Generating {analysis_type} for {company_name}") + + # Initialize analysis client with configured model + client = AnalysisClient(api_key=api_key, model=model) + + # Generate analysis + result = await client.generate_analysis( + analysis_type=analysis_type, + company_name=company_name, + ts_code=ts_code, + prompt_template=prompt_template, + financial_data=financial_data + ) + + logger.info(f"[API] Analysis generation completed, success={result.get('success')}") + + return AnalysisResponse( + ts_code=ts_code, + company_name=company_name, + analysis_type=analysis_type, + content=result.get("content", ""), + model=result.get("model", model), + tokens=result.get("tokens", {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}), + elapsed_ms=result.get("elapsed_ms", 0), + success=result.get("success", False), + error=result.get("error") + ) diff --git a/backend/app/schemas/financial.py b/backend/app/schemas/financial.py index 88ae52b..2aa6c02 100644 --- a/backend/app/schemas/financial.py +++ b/backend/app/schemas/financial.py @@ -8,6 +8,7 @@ from pydantic import BaseModel class YearDataPoint(BaseModel): year: str value: Optional[float] + month: Optional[int] = None # 月份信息,用于确定季度 class StepRecord(BaseModel): @@ -55,3 +56,19 @@ class CompanyProfileResponse(BaseModel): elapsed_ms: int success: bool = True error: Optional[str] = None + + +class AnalysisResponse(BaseModel): + ts_code: str + company_name: Optional[str] = None + analysis_type: str + content: str + model: str + tokens: TokenUsage + elapsed_ms: int + success: bool = True + error: Optional[str] = None + + +class AnalysisConfigResponse(BaseModel): + analysis_modules: Dict[str, Dict] diff --git a/backend/app/services/analysis_client.py b/backend/app/services/analysis_client.py new file mode 100644 index 0000000..0eda089 --- /dev/null +++ b/backend/app/services/analysis_client.py @@ -0,0 +1,136 @@ +""" +Generic Analysis Client for various analysis types using Gemini API +""" +import time +import json +import os +from typing import Dict, Optional +import google.generativeai as genai + + +class AnalysisClient: + """Generic client for generating various types of analysis using Gemini API""" + + def __init__(self, api_key: str, model: str = "gemini-2.5-flash"): + """Initialize Gemini client with API key and model""" + genai.configure(api_key=api_key) + self.model_name = model + self.model = genai.GenerativeModel(model) + + async def generate_analysis( + self, + analysis_type: str, + company_name: str, + ts_code: str, + prompt_template: str, + financial_data: Optional[Dict] = None + ) -> Dict: + """ + Generate analysis using Gemini API (non-streaming) + + Args: + analysis_type: Type of analysis (e.g., "fundamental_analysis") + company_name: Company name + ts_code: Stock code + prompt_template: Prompt template with placeholders {company_name}, {ts_code}, {financial_data} + financial_data: Optional financial data for context + + Returns: + Dict with analysis content and metadata + """ + start_time = time.perf_counter_ns() + + # Build prompt from template + prompt = self._build_prompt( + prompt_template, + company_name, + ts_code, + financial_data + ) + + # Call Gemini API (using sync API in async context) + try: + import asyncio + loop = asyncio.get_event_loop() + response = await loop.run_in_executor( + None, + lambda: self.model.generate_content(prompt) + ) + + # Get token usage + usage_metadata = response.usage_metadata if hasattr(response, 'usage_metadata') else None + + elapsed_ms = int((time.perf_counter_ns() - start_time) / 1_000_000) + + return { + "content": response.text, + "model": self.model_name, + "tokens": { + "prompt_tokens": usage_metadata.prompt_token_count if usage_metadata else 0, + "completion_tokens": usage_metadata.candidates_token_count if usage_metadata else 0, + "total_tokens": usage_metadata.total_token_count if usage_metadata else 0, + } if usage_metadata else {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + "elapsed_ms": elapsed_ms, + "success": True, + "analysis_type": analysis_type, + } + except Exception as e: + elapsed_ms = int((time.perf_counter_ns() - start_time) / 1_000_000) + return { + "content": "", + "model": self.model_name, + "tokens": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + "elapsed_ms": elapsed_ms, + "success": False, + "error": str(e), + "analysis_type": analysis_type, + } + + def _build_prompt( + self, + prompt_template: str, + company_name: str, + ts_code: str, + financial_data: Optional[Dict] = None + ) -> str: + """Build prompt from template by replacing placeholders""" + # Format financial data as string if provided + financial_data_str = "" + if financial_data: + try: + financial_data_str = json.dumps(financial_data, ensure_ascii=False, indent=2) + except Exception: + financial_data_str = str(financial_data) + + # Replace placeholders in template + prompt = prompt_template.format( + company_name=company_name, + ts_code=ts_code, + financial_data=financial_data_str + ) + + return prompt + + +def load_analysis_config() -> Dict: + """Load analysis configuration from JSON file""" + # Get project root: backend/app/services -> project_root/config/analysis-config.json + project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) + config_path = os.path.join(project_root, "config", "analysis-config.json") + + if not os.path.exists(config_path): + return {} + + try: + with open(config_path, "r", encoding="utf-8") as f: + return json.load(f) + except Exception: + return {} + + +def get_analysis_config(analysis_type: str) -> Optional[Dict]: + """Get configuration for a specific analysis type""" + config = load_analysis_config() + modules = config.get("analysis_modules", {}) + return modules.get(analysis_type) + diff --git a/backend/app/services/config_manager.py b/backend/app/services/config_manager.py index 283baed..b7fb35e 100644 --- a/backend/app/services/config_manager.py +++ b/backend/app/services/config_manager.py @@ -3,13 +3,16 @@ Configuration Management Service """ import json import os +import asyncio from typing import Any, Dict +import asyncpg +import httpx from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from app.models.system_config import SystemConfig -from app.schemas.config import ConfigResponse, ConfigUpdateRequest, DatabaseConfig, GeminiConfig, DataSourceConfig +from app.schemas.config import ConfigResponse, ConfigUpdateRequest, DatabaseConfig, GeminiConfig, DataSourceConfig, ConfigTestResponse class ConfigManager: """Manages system configuration by merging a static JSON file with dynamic settings from the database.""" @@ -17,8 +20,10 @@ class ConfigManager: def __init__(self, db_session: AsyncSession, config_path: str = None): self.db = db_session if config_path is None: - # Default path: backend/ -> project_root/ -> config/config.json - project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) + # Default path: backend/app/services -> project_root/config/config.json + # __file__ = backend/app/services/config_manager.py + # go up three levels to project root + project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) self.config_path = os.path.join(project_root, "config", "config.json") else: self.config_path = config_path @@ -34,12 +39,19 @@ class ConfigManager: return {} async def _load_dynamic_config_from_db(self) -> Dict[str, Any]: - """Loads dynamic configuration overrides from the database.""" - db_configs = {} - result = await self.db.execute(select(SystemConfig)) - for record in result.scalars().all(): - db_configs[record.config_key] = record.config_value - return db_configs + """Loads dynamic configuration overrides from the database. + + 当数据库表尚未创建(如开发环境未运行迁移)时,优雅降级为返回空覆盖配置,避免接口 500。 + """ + try: + db_configs: Dict[str, Any] = {} + result = await self.db.execute(select(SystemConfig)) + for record in result.scalars().all(): + db_configs[record.config_key] = record.config_value + return db_configs + except Exception: + # 表不存在或其他数据库错误时,忽略动态配置覆盖 + return {} def _merge_configs(self, base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]: """Deeply merges the override config into the base config.""" @@ -57,9 +69,12 @@ class ConfigManager: merged_config = self._merge_configs(base_config, db_config) + # 兼容两种位置:优先使用 gemini_api,其次回退到 llm.gemini + gemini_src = merged_config.get("gemini_api") or merged_config.get("llm", {}).get("gemini", {}) + return ConfigResponse( database=DatabaseConfig(**merged_config.get("database", {})), - gemini_api=GeminiConfig(**merged_config.get("llm", {}).get("gemini", {})), + gemini_api=GeminiConfig(**(gemini_src or {})), data_sources={ k: DataSourceConfig(**v) for k, v in merged_config.get("data_sources", {}).items() @@ -68,20 +83,222 @@ class ConfigManager: async def update_config(self, config_update: ConfigUpdateRequest) -> ConfigResponse: """Updates configuration in the database and returns the new merged config.""" - update_dict = config_update.dict(exclude_unset=True) - - for key, value in update_dict.items(): - existing_config = await self.db.get(SystemConfig, key) - if existing_config: - # Merge with existing DB value before updating - if isinstance(existing_config.config_value, dict) and isinstance(value, dict): - merged_value = self._merge_configs(existing_config.config_value, value) - existing_config.config_value = merged_value + try: + update_dict = config_update.dict(exclude_unset=True) + + # 验证配置数据 + self._validate_config_data(update_dict) + + for key, value in update_dict.items(): + existing_config = await self.db.get(SystemConfig, key) + if existing_config: + # Merge with existing DB value before updating + if isinstance(existing_config.config_value, dict) and isinstance(value, dict): + merged_value = self._merge_configs(existing_config.config_value, value) + existing_config.config_value = merged_value + else: + existing_config.config_value = value else: - existing_config.config_value = value - else: - new_config = SystemConfig(config_key=key, config_value=value) - self.db.add(new_config) + new_config = SystemConfig(config_key=key, config_value=value) + self.db.add(new_config) + + await self.db.commit() + return await self.get_config() + except Exception as e: + await self.db.rollback() + raise e + + def _validate_config_data(self, config_data: Dict[str, Any]) -> None: + """Validate configuration data before saving.""" + if "database" in config_data: + db_config = config_data["database"] + if "url" in db_config: + url = db_config["url"] + if not url.startswith(("postgresql://", "postgresql+asyncpg://")): + raise ValueError("数据库URL必须以 postgresql:// 或 postgresql+asyncpg:// 开头") - await self.db.commit() - return await self.get_config() + if "gemini_api" in config_data: + gemini_config = config_data["gemini_api"] + if "api_key" in gemini_config and len(gemini_config["api_key"]) < 10: + raise ValueError("Gemini API Key长度不能少于10个字符") + if "base_url" in gemini_config and gemini_config["base_url"]: + base_url = gemini_config["base_url"] + if not base_url.startswith(("http://", "https://")): + raise ValueError("Gemini Base URL必须以 http:// 或 https:// 开头") + + if "data_sources" in config_data: + for source_name, source_config in config_data["data_sources"].items(): + if "api_key" in source_config and len(source_config["api_key"]) < 10: + raise ValueError(f"{source_name} API Key长度不能少于10个字符") + + async def test_config(self, config_type: str, config_data: Dict[str, Any]) -> ConfigTestResponse: + """Test a specific configuration.""" + try: + if config_type == "database": + return await self._test_database(config_data) + elif config_type == "gemini": + return await self._test_gemini(config_data) + elif config_type == "tushare": + return await self._test_tushare(config_data) + elif config_type == "finnhub": + return await self._test_finnhub(config_data) + else: + return ConfigTestResponse( + success=False, + message=f"不支持的配置类型: {config_type}" + ) + except Exception as e: + return ConfigTestResponse( + success=False, + message=f"测试失败: {str(e)}" + ) + + async def _test_database(self, config_data: Dict[str, Any]) -> ConfigTestResponse: + """Test database connection.""" + db_url = config_data.get("url") + if not db_url: + return ConfigTestResponse( + success=False, + message="数据库URL不能为空" + ) + + try: + # 解析数据库URL + if db_url.startswith("postgresql+asyncpg://"): + db_url = db_url.replace("postgresql+asyncpg://", "postgresql://") + + # 测试连接 + conn = await asyncpg.connect(db_url) + await conn.close() + + return ConfigTestResponse( + success=True, + message="数据库连接成功" + ) + except Exception as e: + return ConfigTestResponse( + success=False, + message=f"数据库连接失败: {str(e)}" + ) + + async def _test_gemini(self, config_data: Dict[str, Any]) -> ConfigTestResponse: + """Test Gemini API connection.""" + api_key = config_data.get("api_key") + base_url = config_data.get("base_url", "https://generativelanguage.googleapis.com/v1beta") + + if not api_key: + return ConfigTestResponse( + success=False, + message="Gemini API Key不能为空" + ) + + try: + async with httpx.AsyncClient(timeout=10.0) as client: + # 测试API可用性 + response = await client.get( + f"{base_url}/models", + headers={"x-goog-api-key": api_key} + ) + + if response.status_code == 200: + return ConfigTestResponse( + success=True, + message="Gemini API连接成功" + ) + else: + return ConfigTestResponse( + success=False, + message=f"Gemini API测试失败: HTTP {response.status_code}" + ) + except Exception as e: + return ConfigTestResponse( + success=False, + message=f"Gemini API连接失败: {str(e)}" + ) + + async def _test_tushare(self, config_data: Dict[str, Any]) -> ConfigTestResponse: + """Test Tushare API connection.""" + api_key = config_data.get("api_key") + + if not api_key: + return ConfigTestResponse( + success=False, + message="Tushare API Key不能为空" + ) + + try: + async with httpx.AsyncClient(timeout=10.0) as client: + # 测试API可用性 + response = await client.post( + "http://api.tushare.pro", + json={ + "api_name": "stock_basic", + "token": api_key, + "params": {"list_status": "L"}, + "fields": "ts_code" + } + ) + + if response.status_code == 200: + data = response.json() + if data.get("code") == 0: + return ConfigTestResponse( + success=True, + message="Tushare API连接成功" + ) + else: + return ConfigTestResponse( + success=False, + message=f"Tushare API错误: {data.get('msg', '未知错误')}" + ) + else: + return ConfigTestResponse( + success=False, + message=f"Tushare API测试失败: HTTP {response.status_code}" + ) + except Exception as e: + return ConfigTestResponse( + success=False, + message=f"Tushare API连接失败: {str(e)}" + ) + + async def _test_finnhub(self, config_data: Dict[str, Any]) -> ConfigTestResponse: + """Test Finnhub API connection.""" + api_key = config_data.get("api_key") + + if not api_key: + return ConfigTestResponse( + success=False, + message="Finnhub API Key不能为空" + ) + + try: + async with httpx.AsyncClient(timeout=10.0) as client: + # 测试API可用性 + response = await client.get( + f"https://finnhub.io/api/v1/quote", + params={"symbol": "AAPL", "token": api_key} + ) + + if response.status_code == 200: + data = response.json() + if "c" in data: # 检查是否有价格数据 + return ConfigTestResponse( + success=True, + message="Finnhub API连接成功" + ) + else: + return ConfigTestResponse( + success=False, + message="Finnhub API响应格式错误" + ) + else: + return ConfigTestResponse( + success=False, + message=f"Finnhub API测试失败: HTTP {response.status_code}" + ) + except Exception as e: + return ConfigTestResponse( + success=False, + message=f"Finnhub API连接失败: {str(e)}" + ) diff --git a/backend/requirements.txt b/backend/requirements.txt index 6b1eeaa..ce16c92 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -6,3 +6,4 @@ SQLAlchemy==2.0.36 aiosqlite==0.20.0 alembic==1.13.3 google-generativeai==0.8.3 +asyncpg==0.29.0 diff --git a/config/analysis-config.json b/config/analysis-config.json new file mode 100644 index 0000000..b646ddc --- /dev/null +++ b/config/analysis-config.json @@ -0,0 +1,49 @@ +{ + "analysis_modules": { + "company_profile": { + "name": "公司简介", + "model": "gemini-2.5-flash", + "prompt_template": "您是一位专业的证券市场分析师。请为公司 {company_name} (股票代码: {ts_code}) 生成一份详细且专业的公司介绍。开头不要自我介绍,直接开始正文。正文用MarkDown输出,尽量说明信息来源,用斜体显示信息来源。在生成内容时,请严格遵循以下要求并采用清晰、结构化的格式:\n\n1. **公司概览**:\n * 简要介绍公司的性质、核心业务领域及其在行业中的定位。\n * 提炼并阐述公司的核心价值理念。\n\n2. **主营业务**:\n * 详细描述公司主要的**产品或服务**。\n * **重要提示**:如果能获取到公司最新的官方**年报**或**财务报告**,请从中提取各主要产品/服务线的**收入金额**和其占公司总收入的**百分比**。请**明确标注数据来源**(例如:\"数据来源于XX年年度报告\")。\n * **严格禁止**编造或估算任何财务数据。若无法找到公开、准确的财务数据,请**不要**在这一点中提及具体金额或比例,仅描述业务内容。\n\n3. **发展历程**:\n * 以时间线或关键事件的形式,概述公司自成立以来的主要**里程碑事件**、重大发展阶段、战略转型或重要成就。\n\n4. **核心团队**:\n * 介绍公司**主要管理层和核心技术团队成员**。\n * 对于每位核心成员,提供其**职务、主要工作履历、教育背景**。\n * 如果公开可查,可补充其**出生年份**。\n\n5. **供应链**:\n * 描述公司的**主要原材料、部件或服务来源**。\n * 如果公开信息中包含,请列出**主要供应商名称**,并**明确其在总采购金额中的大致占比**。若无此数据,则仅描述采购模式。\n\n6. **主要客户及销售模式**:\n * 阐明公司的**销售模式**(例如:直销、经销、线上销售、代理等)。\n * 列出公司的**主要客户群体**或**代表性大客户**。\n * 如果公开信息中包含,请标明**主要客户(或前五大客户)的销售额占公司总销售额的比例**。若无此数据,则仅描述客户类型。\n\n7. **未来展望**:\n * 基于公司**公开的官方声明、管理层访谈或战略规划**,总结公司未来的发展方向、战略目标、重点项目或市场预期。请确保此部分内容有可靠的信息来源支持。" + }, + "fundamental_analysis": { + "name": "基本面分析", + "model": "gemini-2.5-flash", + "prompt_template": "您是一位专业的证券分析师。我还没想好,先随便输出100字" + }, + "bull_case": { + "name": "看涨分析", + "model": "gemini-2.5-flash", + "prompt_template": "您是一位专业的证券分析师。我还没想好,先随便输出100字" + }, + "bear_case": { + "name": "看跌分析", + "model": "gemini-2.5-flash", + "prompt_template": "您是一位专业的证券分析师。我还没想好,先随便输出100字" + }, + "market_analysis": { + "name": "市场分析", + "model": "gemini-2.5-flash", + "prompt_template": "您是一位专业的证券分析师。我还没想好,先随便输出100字" + }, + "news_analysis": { + "name": "新闻分析", + "model": "gemini-2.5-flash", + "prompt_template": "您是一位专业的证券分析师。我还没想好,先随便输出100字" + }, + "trading_analysis": { + "name": "交易分析", + "model": "gemini-2.5-flash", + "prompt_template": "您是一位专业的证券分析师。我还没想好,先随便输出100字" + }, + "insider_institutional": { + "name": "内部人与机构动向分析", + "model": "gemini-2.5-flash", + "prompt_template": "您是一位专业的证券分析师。我还没想好,先随便输出100字" + }, + "final_conclusion": { + "name": "最终结论", + "model": "gemini-2.5-flash", + "prompt_template": "您是一位专业的证券分析师。我还没想好,先随便输出100字" + } + } +} \ No newline at end of file diff --git a/config/financial-tushare.json b/config/financial-tushare.json index e6c3d1d..fea9625 100644 --- a/config/financial-tushare.json +++ b/config/financial-tushare.json @@ -41,7 +41,8 @@ "cashflow": [ { "displayText": "经营净现金流", "tushareParam": "n_cashflow_act", "api": "cashflow" }, { "displayText": "资本开支", "tushareParam": "c_pay_acq_const_fiolta", "api": "cashflow" }, - { "displayText": "折旧费用", "tushareParam": "depr_fa_coga_dpba", "api": "cashflow" } + { "displayText": "折旧费用", "tushareParam": "depr_fa_coga_dpba", "api": "cashflow" }, + { "displayText": "支付给职工以及为职工支付的现金", "tushareParam": "c_paid_to_for_empl", "api": "cashflow" } ], "daily_basic": [ { "displayText": "PB", "tushareParam": "pb", "api": "daily_basic" }, diff --git a/frontend/src/app/config/page.tsx b/frontend/src/app/config/page.tsx index 98268b3..ac81959 100644 --- a/frontend/src/app/config/page.tsx +++ b/frontend/src/app/config/page.tsx @@ -1,161 +1,589 @@ 'use client'; import { useState, useEffect } from 'react'; -import { useConfig, updateConfig, testConfig } from '@/hooks/useApi'; +import { useConfig, updateConfig, testConfig, useAnalysisConfig, updateAnalysisConfig } from '@/hooks/useApi'; import { useConfigStore, SystemConfig } from '@/stores/useConfigStore'; import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; import { Input } from "@/components/ui/input"; +import { Textarea } from "@/components/ui/textarea"; import { Button } from "@/components/ui/button"; import { Badge } from "@/components/ui/badge"; +import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; +import { Label } from "@/components/ui/label"; +import { Separator } from "@/components/ui/separator"; +import type { AnalysisConfigResponse } from '@/types'; export default function ConfigPage() { // 从 Zustand store 获取全局状态 const { config, loading, error, setConfig } = useConfigStore(); // 使用 SWR hook 加载初始配置 useConfig(); + + // 加载分析配置 + const { data: analysisConfig, mutate: mutateAnalysisConfig } = useAnalysisConfig(); // 本地表单状态 const [dbUrl, setDbUrl] = useState(''); const [geminiApiKey, setGeminiApiKey] = useState(''); + const [geminiBaseUrl, setGeminiBaseUrl] = useState(''); const [tushareApiKey, setTushareApiKey] = useState(''); - + const [finnhubApiKey, setFinnhubApiKey] = useState(''); + + // 分析配置的本地状态 + const [localAnalysisConfig, setLocalAnalysisConfig] = useState>({}); + + // 分析配置保存状态 + const [savingAnalysis, setSavingAnalysis] = useState(false); + const [analysisSaveMessage, setAnalysisSaveMessage] = useState(''); + // 测试结果状态 - const [dbTestResult, setDbTestResult] = useState<{ success: boolean; message: string } | null>(null); - const [geminiTestResult, setGeminiTestResult] = useState<{ success: boolean; message: string } | null>(null); - + const [testResults, setTestResults] = useState>({}); + // 保存状态 const [saving, setSaving] = useState(false); const [saveMessage, setSaveMessage] = useState(''); - + + // 初始化分析配置的本地状态 useEffect(() => { - if (config) { - setDbUrl(config.database?.url || ''); - // API Keys 不回显 + if (analysisConfig?.analysis_modules) { + setLocalAnalysisConfig(analysisConfig.analysis_modules); } - }, [config]); + }, [analysisConfig]); + + // 更新分析配置中的某个字段 + const updateAnalysisField = (type: string, field: 'name' | 'model' | 'prompt_template', value: string) => { + setLocalAnalysisConfig(prev => ({ + ...prev, + [type]: { + ...prev[type], + [field]: value + } + })); + }; + + // 保存分析配置 + const handleSaveAnalysisConfig = async () => { + setSavingAnalysis(true); + setAnalysisSaveMessage('保存中...'); + + try { + const updated = await updateAnalysisConfig({ + analysis_modules: localAnalysisConfig + }); + await mutateAnalysisConfig(updated); + setAnalysisSaveMessage('保存成功!'); + } catch (e: any) { + setAnalysisSaveMessage(`保存失败: ${e.message}`); + } finally { + setSavingAnalysis(false); + setTimeout(() => setAnalysisSaveMessage(''), 5000); + } + }; + + const validateConfig = () => { + const errors: string[] = []; + + // 验证数据库URL格式 + if (dbUrl && !dbUrl.match(/^postgresql(\+asyncpg)?:\/\/.+/)) { + errors.push('数据库URL格式不正确,应为 postgresql://user:pass@host:port/dbname'); + } + + // 验证Gemini Base URL格式 + if (geminiBaseUrl && !geminiBaseUrl.match(/^https?:\/\/.+/)) { + errors.push('Gemini Base URL格式不正确,应为 http:// 或 https:// 开头'); + } + + // 验证API Key长度(基本检查) + if (geminiApiKey && geminiApiKey.length < 10) { + errors.push('Gemini API Key长度过短'); + } + + if (tushareApiKey && tushareApiKey.length < 10) { + errors.push('Tushare API Key长度过短'); + } + + if (finnhubApiKey && finnhubApiKey.length < 10) { + errors.push('Finnhub API Key长度过短'); + } + + return errors; + }; const handleSave = async () => { + // 验证配置 + const validationErrors = validateConfig(); + if (validationErrors.length > 0) { + setSaveMessage(`配置验证失败: ${validationErrors.join(', ')}`); + return; + } + setSaving(true); setSaveMessage('保存中...'); - const newConfig: Partial = { - database: { url: dbUrl }, - gemini_api: { api_key: geminiApiKey }, - data_sources: { - tushare: { api_key: tushareApiKey }, - }, - }; + const newConfig: Partial = {}; + + // 只更新有值的字段 + if (dbUrl) { + newConfig.database = { url: dbUrl }; + } + + if (geminiApiKey || geminiBaseUrl) { + newConfig.gemini_api = { + api_key: geminiApiKey || config?.gemini_api?.api_key || '', + base_url: geminiBaseUrl || config?.gemini_api?.base_url || undefined, + }; + } + + if (tushareApiKey || finnhubApiKey) { + newConfig.data_sources = { + ...config?.data_sources, + ...(tushareApiKey && { tushare: { api_key: tushareApiKey } }), + ...(finnhubApiKey && { finnhub: { api_key: finnhubApiKey } }), + }; + } try { const updated = await updateConfig(newConfig); setConfig(updated); // 更新全局状态 setSaveMessage('保存成功!'); - setGeminiApiKey(''); // 清空敏感字段输入 + // 清空敏感字段输入 + setGeminiApiKey(''); setTushareApiKey(''); + setFinnhubApiKey(''); } catch (e: any) { setSaveMessage(`保存失败: ${e.message}`); } finally { setSaving(false); - setTimeout(() => setSaveMessage(''), 3000); + setTimeout(() => setSaveMessage(''), 5000); } }; - const handleTestDb = async () => { - const result = await testConfig('database', { url: dbUrl }); - setDbTestResult(result); + const handleTest = async (type: string, data: any) => { + try { + const result = await testConfig(type, data); + setTestResults(prev => ({ ...prev, [type]: result })); + } catch (e: any) { + setTestResults(prev => ({ + ...prev, + [type]: { success: false, message: e.message } + })); + } }; - const handleTestGemini = async () => { - const result = await testConfig('gemini', { api_key: geminiApiKey || config?.gemini_api.api_key }); - setGeminiTestResult(result); + const handleTestDb = () => { + handleTest('database', { url: dbUrl }); }; - if (loading) return
Loading...
; - if (error) return
Error loading config: {error}
; + const handleTestGemini = () => { + handleTest('gemini', { + api_key: geminiApiKey || config?.gemini_api?.api_key, + base_url: geminiBaseUrl || config?.gemini_api?.base_url + }); + }; + + const handleTestTushare = () => { + handleTest('tushare', { api_key: tushareApiKey || config?.data_sources?.tushare?.api_key }); + }; + + const handleTestFinnhub = () => { + handleTest('finnhub', { api_key: finnhubApiKey || config?.data_sources?.finnhub?.api_key }); + }; + + const handleReset = () => { + setDbUrl(''); + setGeminiApiKey(''); + setGeminiBaseUrl(''); + setTushareApiKey(''); + setFinnhubApiKey(''); + setTestResults({}); + setSaveMessage(''); + }; + + const handleExportConfig = () => { + if (!config) return; + + const configToExport = { + database: config.database, + gemini_api: config.gemini_api, + data_sources: config.data_sources, + export_time: new Date().toISOString(), + version: "1.0" + }; + + const blob = new Blob([JSON.stringify(configToExport, null, 2)], { type: 'application/json' }); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = `config-backup-${new Date().toISOString().split('T')[0]}.json`; + document.body.appendChild(a); + a.click(); + document.body.removeChild(a); + URL.revokeObjectURL(url); + }; + + const handleImportConfig = (event: React.ChangeEvent) => { + const file = event.target.files?.[0]; + if (!file) return; + + const reader = new FileReader(); + reader.onload = (e) => { + try { + const importedConfig = JSON.parse(e.target?.result as string); + + // 验证导入的配置格式 + if (importedConfig.database?.url) { + setDbUrl(importedConfig.database.url); + } + if (importedConfig.gemini_api?.base_url) { + setGeminiBaseUrl(importedConfig.gemini_api.base_url); + } + + setSaveMessage('配置导入成功,请检查并保存'); + } catch (error) { + setSaveMessage('配置文件格式错误,导入失败'); + } + }; + reader.readAsText(file); + }; + + if (loading) return ( +
+
+
+

加载配置中...

+
+
+ ); + + if (error) return ( +
+
+
⚠️
+

加载配置失败: {error}

+
+
+ ); return ( -
+
-

配置中心

-

- 管理系统配置,包括数据库、API密钥等。敏感密钥不回显,留空表示保持现值。 +

配置中心

+

+ 管理系统配置,包括数据库连接、API密钥等。敏感信息不回显,留空表示保持当前值。

- - - 数据库配置 - PostgreSQL 连接设置 - - -
- - setDbUrl(e.target.value)} - placeholder="postgresql+asyncpg://user:pass@host:port/dbname" - className="flex-1" - /> - -
- {dbTestResult && ( - - {dbTestResult.message} - + + + 数据库 + AI服务 + 数据源 + 分析配置 + 系统 + + + + + + 数据库配置 + PostgreSQL 数据库连接设置 + + +
+ +
+ setDbUrl(e.target.value)} + placeholder="postgresql+asyncpg://user:password@host:port/database" + className="flex-1" + /> + +
+ {testResults.database && ( + + {testResults.database.message} + + )} +
+
+
+
+ + + + + AI 服务配置 + Google Gemini API 设置 + + +
+ +
+ setGeminiApiKey(e.target.value)} + placeholder="留空表示保持当前值" + className="flex-1" + /> + +
+ {testResults.gemini && ( + + {testResults.gemini.message} + + )} +
+ +
+ + setGeminiBaseUrl(e.target.value)} + placeholder="https://generativelanguage.googleapis.com/v1beta" + className="flex-1" + /> +
+
+
+
+ + + + + 数据源配置 + 外部数据源 API 设置 + + +
+
+ +

中国股票数据源

+
+ setTushareApiKey(e.target.value)} + placeholder="留空表示保持当前值" + className="flex-1" + /> + +
+ {testResults.tushare && ( + + {testResults.tushare.message} + + )} +
+ + + +
+ +

全球金融市场数据源

+
+ setFinnhubApiKey(e.target.value)} + placeholder="留空表示保持当前值" + className="flex-1" + /> + +
+ {testResults.finnhub && ( + + {testResults.finnhub.message} + + )} +
+
+
+
+
+ + + + + 分析模块配置 + 配置各个分析模块的模型和提示词 + + + {Object.entries(localAnalysisConfig).map(([type, config]) => ( +
+
+

{config.name || type}

+ {type} +
+ +
+ + updateAnalysisField(type, 'name', e.target.value)} + placeholder="分析模块显示名称" + /> +
+ +
+ + updateAnalysisField(type, 'model', e.target.value)} + placeholder="例如: gemini-2.5-flash" + /> +

+ 使用的Gemini模型名称 +

+
+ +
+ +