Initial commit: Fundamental stock analysis project setup
This commit is contained in:
commit
91f701139f
314
.gitignore
vendored
Normal file
314
.gitignore
vendored
Normal file
@ -0,0 +1,314 @@
|
||||
# ===== 通用文件 =====
|
||||
# 操作系统生成的文件
|
||||
.DS_Store
|
||||
.DS_Store?
|
||||
._*
|
||||
.Spotlight-V100
|
||||
.Trashes
|
||||
ehthumbs.db
|
||||
Thumbs.db
|
||||
|
||||
# IDE 和编辑器
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# 日志文件
|
||||
*.log
|
||||
logs/
|
||||
|
||||
# 临时文件
|
||||
*.tmp
|
||||
*.temp
|
||||
.cache/
|
||||
|
||||
# ===== Python 后端 =====
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# pipenv
|
||||
Pipfile.lock
|
||||
|
||||
# poetry
|
||||
poetry.lock
|
||||
|
||||
# pdm
|
||||
.pdm.toml
|
||||
.pdm-python
|
||||
.pdm-build/
|
||||
|
||||
# PEP 582
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.env.*
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
.idea/
|
||||
|
||||
# ===== Node.js / Next.js 前端 =====
|
||||
# Dependencies
|
||||
node_modules/
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
.pnpm-debug.log*
|
||||
|
||||
# Runtime data
|
||||
pids
|
||||
*.pid
|
||||
*.seed
|
||||
*.pid.lock
|
||||
|
||||
# Coverage directory used by tools like istanbul
|
||||
coverage/
|
||||
*.lcov
|
||||
|
||||
# nyc test coverage
|
||||
.nyc_output
|
||||
|
||||
# Grunt intermediate storage
|
||||
.grunt
|
||||
|
||||
# Bower dependency directory
|
||||
bower_components
|
||||
|
||||
# node-waf configuration
|
||||
.lock-wscript
|
||||
|
||||
# Compiled binary addons
|
||||
build/Release
|
||||
|
||||
# Dependency directories
|
||||
jspm_packages/
|
||||
|
||||
# Snowpack dependency directory
|
||||
web_modules/
|
||||
|
||||
# TypeScript cache
|
||||
*.tsbuildinfo
|
||||
|
||||
# Optional npm cache directory
|
||||
.npm
|
||||
|
||||
# Optional eslint cache
|
||||
.eslintcache
|
||||
|
||||
# Optional stylelint cache
|
||||
.stylelintcache
|
||||
|
||||
# Microbundle cache
|
||||
.rpt2_cache/
|
||||
.rts2_cache_cjs/
|
||||
.rts2_cache_es/
|
||||
.rts2_cache_umd/
|
||||
|
||||
# Optional REPL history
|
||||
.node_repl_history
|
||||
|
||||
# Output of 'npm pack'
|
||||
*.tgz
|
||||
|
||||
# Yarn Integrity file
|
||||
.yarn-integrity
|
||||
|
||||
# parcel-bundler cache
|
||||
.parcel-cache
|
||||
|
||||
# Next.js build output
|
||||
.next/
|
||||
out/
|
||||
|
||||
# Nuxt.js build / generate output
|
||||
.nuxt
|
||||
dist
|
||||
|
||||
# Gatsby files
|
||||
.cache/
|
||||
public
|
||||
|
||||
# Storybook build outputs
|
||||
.out
|
||||
.storybook-out
|
||||
storybook-static
|
||||
|
||||
# Temporary folders
|
||||
tmp/
|
||||
temp/
|
||||
|
||||
# Vercel
|
||||
.vercel
|
||||
|
||||
# Turbo
|
||||
.turbo
|
||||
|
||||
# ===== 数据库 =====
|
||||
# SQLite
|
||||
*.sqlite
|
||||
*.sqlite3
|
||||
*.db
|
||||
|
||||
# PostgreSQL
|
||||
*.sql
|
||||
|
||||
# MySQL
|
||||
*.mysql
|
||||
|
||||
# ===== 配置和密钥 =====
|
||||
# 环境变量文件
|
||||
.env
|
||||
.env.local
|
||||
.env.development.local
|
||||
.env.test.local
|
||||
.env.production.local
|
||||
|
||||
# API 密钥和配置
|
||||
config/secrets.json
|
||||
secrets/
|
||||
*.key
|
||||
*.pem
|
||||
*.p12
|
||||
*.pfx
|
||||
|
||||
# ===== 项目特定 =====
|
||||
# 数据文件
|
||||
data/
|
||||
*.csv
|
||||
*.json.bak
|
||||
*.xlsx
|
||||
|
||||
# 报告和输出
|
||||
reports/
|
||||
output/
|
||||
exports/
|
||||
|
||||
# 备份文件
|
||||
*.bak
|
||||
*.backup
|
||||
backup/
|
||||
|
||||
# 测试数据
|
||||
test_data/
|
||||
mock_data/
|
||||
351
.kiro/specs/fundamental-stock-analysis/design.md
Normal file
351
.kiro/specs/fundamental-stock-analysis/design.md
Normal file
@ -0,0 +1,351 @@
|
||||
# 设计文档 - 基本面选股系统
|
||||
|
||||
## 概览
|
||||
|
||||
基本面选股系统是一个全栈Web应用,采用前后端分离架构。前端使用Next.js和shadcn/ui构建响应式中文界面,后端使用Python FastAPI提供API服务。系统通过多个专业分析模块,结合财务数据API、AI大模型和实时数据,为用户提供全面的股票基本面分析报告。
|
||||
|
||||
## 架构
|
||||
|
||||
### 系统架构图
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph "前端层"
|
||||
A[Next.js应用] --> B[shadcn/ui UI组件]
|
||||
A --> C[TradingView图表组件]
|
||||
end
|
||||
|
||||
subgraph "后端层"
|
||||
D[FastAPI服务器] --> E[报告生成引擎]
|
||||
D --> F[数据获取服务]
|
||||
D --> G[配置管理服务]
|
||||
end
|
||||
|
||||
subgraph "外部服务"
|
||||
H[Tushare API]
|
||||
I[Gemini API]
|
||||
J[其他数据源APIs]
|
||||
end
|
||||
|
||||
subgraph "数据层"
|
||||
K[PostgreSQL数据库]
|
||||
end
|
||||
|
||||
A --> D
|
||||
F --> H
|
||||
F --> I
|
||||
F --> J
|
||||
E --> K
|
||||
G --> K
|
||||
```
|
||||
|
||||
### 技术栈
|
||||
|
||||
**前端:**
|
||||
- Next.js 14 (App Router)
|
||||
- TypeScript
|
||||
- shadcn/ui组件库 (https://ui.shadcn.com/)
|
||||
- TradingView Charting Library
|
||||
- Tailwind CSS
|
||||
- Radix UI (shadcn/ui的底层组件)
|
||||
|
||||
**后端:**
|
||||
- Python 3.11+
|
||||
- FastAPI
|
||||
- SQLAlchemy (ORM)
|
||||
- Alembic (数据库迁移)
|
||||
- Pydantic (数据验证)
|
||||
- httpx (HTTP客户端)
|
||||
|
||||
**数据库:**
|
||||
- PostgreSQL 15+
|
||||
|
||||
**外部服务:**
|
||||
- Tushare API (中国股票数据)
|
||||
- Google Gemini API (AI分析)
|
||||
- 其他市场数据源APIs
|
||||
|
||||
## 组件和接口
|
||||
|
||||
### shadcn/ui组件使用规划
|
||||
|
||||
系统将使用shadcn/ui (https://ui.shadcn.com/) 的官方组件来构建一致的用户界面:
|
||||
|
||||
**核心组件使用:**
|
||||
- `Button`: 主要操作按钮(生成报告、保存配置等)
|
||||
- `Input`: 证券代码输入框
|
||||
- `Select`: 交易市场选择器
|
||||
- `Card`: 分析模块容器、报告卡片
|
||||
- `Progress`: 报告生成进度条
|
||||
- `Badge`: 状态标识(完成、进行中、失败)
|
||||
- `Tabs`: 分析模块切换
|
||||
- `Form`: 配置表单、搜索表单
|
||||
- `Alert`: 错误提示、成功消息
|
||||
- `Separator`: 内容分隔线
|
||||
- `Skeleton`: 加载占位符
|
||||
- `Toast`: 操作反馈通知
|
||||
- `Table`: 财务数据展示表格(资产负债表、利润表、现金流量表等)
|
||||
|
||||
**主题配置:**
|
||||
- 使用默认主题,支持深色/浅色模式切换
|
||||
- 自定义中文字体配置
|
||||
- 适配中文内容的间距和排版
|
||||
|
||||
### 前端组件结构
|
||||
|
||||
```
|
||||
src/
|
||||
├── app/
|
||||
│ ├── page.tsx # 首页
|
||||
│ ├── report/[symbol]/page.tsx # 报告页面
|
||||
│ ├── config/page.tsx # 配置页面
|
||||
│ └── layout.tsx # 根布局
|
||||
├── components/
|
||||
│ ├── ui/ # shadcn/ui基础组件 (Button, Input, Card, etc.)
|
||||
│ ├── StockSearchForm.tsx # 股票搜索表单 (使用Form, Input, Select)
|
||||
│ ├── ReportProgress.tsx # 报告生成进度 (使用Progress, Badge, Card)
|
||||
│ ├── TradingViewChart.tsx # TradingView图表
|
||||
│ ├── AnalysisModule.tsx # 分析模块容器 (使用Card, Tabs, Separator)
|
||||
│ ├── FinancialDataTable.tsx # 财务数据表格 (使用Table, TableHeader, TableBody, TableRow, TableCell)
|
||||
│ └── ConfigForm.tsx # 配置表单 (使用Form, Input, Button, Alert)
|
||||
├── lib/
|
||||
│ ├── api.ts # API客户端
|
||||
│ ├── types.ts # TypeScript类型定义
|
||||
│ └── utils.ts # 工具函数
|
||||
└── hooks/
|
||||
├── useReport.ts # 报告数据钩子
|
||||
└── useProgress.ts # 进度追踪钩子
|
||||
```
|
||||
|
||||
### 后端API结构
|
||||
|
||||
```
|
||||
app/
|
||||
├── main.py # FastAPI应用入口
|
||||
├── models/
|
||||
│ ├── __init__.py
|
||||
│ ├── report.py # 报告数据模型
|
||||
│ ├── config.py # 配置数据模型
|
||||
│ └── progress.py # 进度追踪模型
|
||||
├── schemas/
|
||||
│ ├── __init__.py
|
||||
│ ├── report.py # 报告Pydantic模式
|
||||
│ ├── config.py # 配置Pydantic模式
|
||||
│ └── progress.py # 进度Pydantic模式
|
||||
├── services/
|
||||
│ ├── __init__.py
|
||||
│ ├── data_fetcher.py # 数据获取服务
|
||||
│ ├── ai_analyzer.py # AI分析服务
|
||||
│ ├── report_generator.py # 报告生成服务
|
||||
│ └── config_manager.py # 配置管理服务
|
||||
├── routers/
|
||||
│ ├── __init__.py
|
||||
│ ├── reports.py # 报告相关API
|
||||
│ ├── config.py # 配置相关API
|
||||
│ └── progress.py # 进度相关API
|
||||
└── core/
|
||||
├── __init__.py
|
||||
├── database.py # 数据库连接
|
||||
├── config.py # 应用配置
|
||||
└── dependencies.py # 依赖注入
|
||||
```
|
||||
|
||||
### 核心API接口
|
||||
|
||||
#### 报告相关API
|
||||
|
||||
```python
|
||||
# GET /api/reports/{symbol}?market={market}
|
||||
# 获取或生成股票报告
|
||||
class ReportResponse:
|
||||
symbol: str
|
||||
market: str
|
||||
report_id: str
|
||||
status: str # "existing" | "generating" | "completed" | "failed"
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
modules: List[AnalysisModule]
|
||||
|
||||
# POST /api/reports/{symbol}/regenerate?market={market}
|
||||
# 重新生成报告
|
||||
class RegenerateRequest:
|
||||
force: bool = False
|
||||
|
||||
# GET /api/reports/{report_id}/progress
|
||||
# 获取报告生成进度
|
||||
class ProgressResponse:
|
||||
report_id: str
|
||||
current_step: int
|
||||
total_steps: int
|
||||
current_step_name: str
|
||||
status: str # "running" | "completed" | "failed"
|
||||
step_timings: List[StepTiming]
|
||||
estimated_remaining: Optional[int]
|
||||
```
|
||||
|
||||
#### 配置相关API
|
||||
|
||||
```python
|
||||
# GET /api/config
|
||||
# 获取系统配置
|
||||
class ConfigResponse:
|
||||
database: DatabaseConfig
|
||||
gemini_api: GeminiConfig
|
||||
data_sources: Dict[str, DataSourceConfig]
|
||||
|
||||
# PUT /api/config
|
||||
# 更新系统配置
|
||||
class ConfigUpdateRequest:
|
||||
database: Optional[DatabaseConfig]
|
||||
gemini_api: Optional[GeminiConfig]
|
||||
data_sources: Optional[Dict[str, DataSourceConfig]]
|
||||
|
||||
# POST /api/config/test
|
||||
# 测试配置连接
|
||||
class ConfigTestRequest:
|
||||
config_type: str # "database" | "gemini" | "data_source"
|
||||
config_data: Dict[str, Any]
|
||||
```
|
||||
|
||||
## 数据模型
|
||||
|
||||
### 数据库表结构
|
||||
|
||||
```sql
|
||||
-- 报告表
|
||||
CREATE TABLE reports (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
symbol VARCHAR(20) NOT NULL,
|
||||
market VARCHAR(20) NOT NULL,
|
||||
status VARCHAR(20) NOT NULL DEFAULT 'generating',
|
||||
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
|
||||
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
|
||||
UNIQUE(symbol, market)
|
||||
);
|
||||
|
||||
-- 分析模块表
|
||||
CREATE TABLE analysis_modules (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
report_id UUID REFERENCES reports(id) ON DELETE CASCADE,
|
||||
module_type VARCHAR(50) NOT NULL,
|
||||
module_order INTEGER NOT NULL,
|
||||
title VARCHAR(200) NOT NULL,
|
||||
content JSONB,
|
||||
status VARCHAR(20) NOT NULL DEFAULT 'pending',
|
||||
started_at TIMESTAMP WITH TIME ZONE,
|
||||
completed_at TIMESTAMP WITH TIME ZONE,
|
||||
error_message TEXT
|
||||
);
|
||||
|
||||
-- 进度追踪表
|
||||
CREATE TABLE progress_tracking (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
report_id UUID REFERENCES reports(id) ON DELETE CASCADE,
|
||||
step_name VARCHAR(100) NOT NULL,
|
||||
step_order INTEGER NOT NULL,
|
||||
status VARCHAR(20) NOT NULL DEFAULT 'pending',
|
||||
started_at TIMESTAMP WITH TIME ZONE,
|
||||
completed_at TIMESTAMP WITH TIME ZONE,
|
||||
duration_ms INTEGER,
|
||||
error_message TEXT
|
||||
);
|
||||
|
||||
-- 系统配置表
|
||||
CREATE TABLE system_config (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
config_key VARCHAR(100) UNIQUE NOT NULL,
|
||||
config_value JSONB NOT NULL,
|
||||
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
|
||||
);
|
||||
```
|
||||
|
||||
### 分析模块类型定义
|
||||
|
||||
```python
|
||||
class AnalysisModuleType(Enum):
|
||||
TRADING_VIEW_CHART = "trading_view_chart"
|
||||
FINANCIAL_DATA = "financial_data"
|
||||
BUSINESS_INFO = "business_info"
|
||||
FUNDAMENTAL_ANALYSIS = "fundamental_analysis"
|
||||
BULLISH_ANALYSIS = "bullish_analysis"
|
||||
BEARISH_ANALYSIS = "bearish_analysis"
|
||||
MARKET_ANALYSIS = "market_analysis"
|
||||
NEWS_ANALYSIS = "news_analysis"
|
||||
TRADING_ANALYSIS = "trading_analysis"
|
||||
INSIDER_ANALYSIS = "insider_analysis"
|
||||
FINAL_CONCLUSION = "final_conclusion"
|
||||
|
||||
class AnalysisModule(BaseModel):
|
||||
id: UUID
|
||||
module_type: AnalysisModuleType
|
||||
title: str
|
||||
content: Dict[str, Any]
|
||||
status: str
|
||||
duration_ms: Optional[int]
|
||||
error_message: Optional[str]
|
||||
```
|
||||
|
||||
## 错误处理
|
||||
|
||||
### 错误类型定义
|
||||
|
||||
```python
|
||||
class StockAnalysisError(Exception):
|
||||
"""基础异常类"""
|
||||
pass
|
||||
|
||||
class DataSourceError(StockAnalysisError):
|
||||
"""数据源错误"""
|
||||
pass
|
||||
|
||||
class AIAnalysisError(StockAnalysisError):
|
||||
"""AI分析错误"""
|
||||
pass
|
||||
|
||||
class ConfigurationError(StockAnalysisError):
|
||||
"""配置错误"""
|
||||
pass
|
||||
|
||||
class DatabaseError(StockAnalysisError):
|
||||
"""数据库错误"""
|
||||
pass
|
||||
```
|
||||
|
||||
### 错误处理策略
|
||||
|
||||
1. **数据获取失败**: 重试机制,最多3次重试,指数退避
|
||||
2. **AI分析失败**: 记录错误,继续其他模块,最后汇总失败信息
|
||||
3. **数据库连接失败**: 使用连接池,自动重连
|
||||
4. **配置错误**: 提供详细错误信息,阻止系统启动
|
||||
5. **前端错误**: Toast通知,错误边界组件
|
||||
|
||||
## 测试策略
|
||||
|
||||
### 测试层级
|
||||
|
||||
1. **单元测试**
|
||||
- 后端服务函数测试
|
||||
- 前端组件测试
|
||||
- 数据模型验证测试
|
||||
|
||||
2. **集成测试**
|
||||
- API端点测试
|
||||
- 数据库操作测试
|
||||
- 外部服务集成测试
|
||||
|
||||
3. **端到端测试**
|
||||
- 完整报告生成流程测试
|
||||
- 用户界面交互测试
|
||||
|
||||
### 测试工具
|
||||
|
||||
- **后端**: pytest, pytest-asyncio, httpx
|
||||
- **前端**: Jest, React Testing Library, Playwright
|
||||
- **数据库**: pytest-postgresql
|
||||
- **API测试**: FastAPI TestClient
|
||||
|
||||
### 测试数据
|
||||
|
||||
- 使用测试数据库和模拟数据
|
||||
- 外部API使用mock响应
|
||||
- 测试用例覆盖各种市场和股票类型
|
||||
112
.kiro/specs/fundamental-stock-analysis/requirements.md
Normal file
112
.kiro/specs/fundamental-stock-analysis/requirements.md
Normal file
@ -0,0 +1,112 @@
|
||||
# 需求文档 - MVP版本
|
||||
|
||||
## 介绍
|
||||
|
||||
基本面选股系统MVP是一个综合的中文网站,允许用户输入证券代码和交易市场,生成包含多维度分析的详细股票基本面报告。系统通过多个专业分析模块,结合财务数据、AI分析和市场信息,为用户提供全面的投资决策支持。
|
||||
|
||||
## 术语表
|
||||
|
||||
- **选股系统 (Stock_Selection_System)**: 提供基本面分析和报告生成的主要系统
|
||||
- **用户 (User)**: 使用系统进行股票分析的终端用户
|
||||
- **证券代码 (Security_Code)**: 股票在特定交易市场的唯一标识符
|
||||
- **交易市场 (Trading_Market)**: 股票交易的地理区域,包括中国、香港、美国、日本
|
||||
- **基本面报告 (Fundamental_Report)**: 包含九个分析模块的综合股票分析报告
|
||||
- **TradingView图表 (TradingView_Chart)**: 使用TradingView高级图表组件显示的股价图表
|
||||
- **Tushare_API**: 用于获取中国股票财务数据的数据源接口
|
||||
- **Gemini_Model**: Google的大语言模型,用于生成业务分析内容
|
||||
- **景林模型 (Jinglin_Model)**: 基本面分析师使用的问题集分析框架
|
||||
- **PostgreSQL数据库 (PostgreSQL_Database)**: 用于存储报告数据的关系型数据库
|
||||
- **分析模块 (Analysis_Module)**: 报告中的独立分析部分,每个模块对应一个显示页面
|
||||
|
||||
## 需求
|
||||
|
||||
### 需求 1
|
||||
|
||||
**用户故事:** 作为投资者,我希望能够输入股票代码和选择交易市场,以便获取该股票的综合基本面分析报告
|
||||
|
||||
#### 验收标准
|
||||
|
||||
1. 当用户访问首页时,选股系统应当显示证券代码输入框和交易市场选择器
|
||||
2. 当用户选择交易市场时,选股系统应当提供中国、香港、美国、日本四个选项
|
||||
3. 当用户提交证券代码和交易市场时,选股系统应当处理用户请求并跳转到报告页面
|
||||
|
||||
### 需求 2
|
||||
|
||||
**用户故事:** 作为投资者,我希望系统能够检查历史报告,以便决定是查看现有报告还是生成新报告
|
||||
|
||||
#### 验收标准
|
||||
|
||||
1. 当用户提交证券代码和交易市场后,选股系统应当在PostgreSQL数据库中查询对应的历史报告
|
||||
2. 如果存在历史报告,选股系统应当显示历史报告内容和"生成最新报告"按钮
|
||||
3. 如果不存在历史报告,选股系统应当自动启动九步报告生成流程
|
||||
|
||||
### 需求 3
|
||||
|
||||
**用户故事:** 作为投资者,我希望系统能够获取准确的财务数据,以便进行可靠的基本面分析
|
||||
|
||||
#### 验收标准
|
||||
|
||||
1. 当生成中国股票报告时,选股系统应当使用Tushare_API获取财务信息
|
||||
2. 当处理其他市场股票时,选股系统应当根据交易市场选择相应的数据源
|
||||
3. 当财务数据获取完成时,选股系统应当将数据作为后续分析的基础
|
||||
|
||||
### 需求 4
|
||||
|
||||
**用户故事:** 作为投资者,我希望系统能够通过AI分析获取公司业务信息,以便了解公司的全面情况
|
||||
|
||||
#### 验收标准
|
||||
|
||||
1. 当需要业务信息时,选股系统应当使用Gemini生成公司概览、主营业务、发展历程、核心团队、供应链、主要客户及销售模式、未来展望
|
||||
2. 当调用Gemini_Model时,选股系统应当使用配置的API密钥进行身份验证
|
||||
3. 当业务信息生成完成时,选股系统应当将内容整合到报告的第二部分
|
||||
|
||||
### 需求 5
|
||||
|
||||
**用户故事:** 作为投资者,我希望系统能够提供多维度的专业分析,以便获得全面的投资决策支持
|
||||
|
||||
#### 验收标准
|
||||
|
||||
1. 当生成报告时,选股系统应当按顺序执行10个分析模块:财务信息、业务信息、基本面分析、看涨分析、看跌分析、市场分析、新闻分析、交易分析、内部人与机构动向分析、最终结论
|
||||
2. 当执行基本面分析时,选股系统应当使用问题集进行分析
|
||||
3. 当执行看涨分析时,选股系统应当研究潜在隐藏资产和护城河竞争优势
|
||||
4. 当执行看跌分析时,选股系统应当分析公司价值底线和最坏情况
|
||||
5. 当执行市场分析时,选股系统应当研究市场情绪分歧点与变化驱动
|
||||
6. 当执行新闻分析时,选股系统应当研究股价催化剂与拐点预判
|
||||
7. 当执行交易分析时,选股系统应当研究市场体量与增长路径
|
||||
8. 当执行内部人分析时,选股系统应当研究内部人与机构动向
|
||||
9. 当生成最终结论时,选股系统应当指出关键矛盾与预期差以及拐点的临近
|
||||
|
||||
### 需求 6
|
||||
|
||||
**用户故事:** 作为投资者,我希望每个分析模块都能独立查看,以便专注于特定的分析维度
|
||||
|
||||
#### 验收标准
|
||||
|
||||
1. 当显示报告时,选股系统应当为每个分析模块提供独立的显示页面
|
||||
2. 当用户在模块间切换时,选股系统应当保持导航的流畅性
|
||||
3. 当所有模块完成时,选股系统应当将完整报告保存到PostgreSQL数据库
|
||||
|
||||
### 需求 7
|
||||
|
||||
**用户故事:** 作为投资者,我希望在报告生成过程中能够看到实时进度,以便了解当前状态和预估完成时间
|
||||
|
||||
#### 验收标准
|
||||
|
||||
1. 当开始生成报告时,选股系统应当显示进度指示器展示所有分析步骤
|
||||
2. 当执行每个分析步骤时,选股系统应当高亮显示当前正在进行的步骤
|
||||
3. 当每个步骤完成时,选股系统应当更新步骤状态为已完成
|
||||
4. 当执行分析步骤时,选股系统应当记录每个步骤的开始时间和完成时间
|
||||
5. 当显示进度时,选股系统应当展示每个步骤的耗时统计
|
||||
6. 当步骤执行失败时,选股系统应当显示错误状态和错误信息
|
||||
|
||||
### 需求 8
|
||||
|
||||
**用户故事:** 作为系统管理员,我希望能够配置系统参数,以便系统能够正常连接外部服务
|
||||
|
||||
#### 验收标准
|
||||
|
||||
1. 选股系统应当提供配置页面用于设置数据库连接参数
|
||||
2. 选股系统应当提供配置页面用于设置Gemini_API密钥
|
||||
3. 选股系统应当提供配置页面用于设置各市场的数据源配置
|
||||
4. 当配置更新时,选股系统应当验证配置的有效性
|
||||
5. 当配置保存时,选股系统应当将配置持久化存储
|
||||
167
.kiro/specs/fundamental-stock-analysis/tasks.md
Normal file
167
.kiro/specs/fundamental-stock-analysis/tasks.md
Normal file
@ -0,0 +1,167 @@
|
||||
# 实施计划
|
||||
|
||||
- [x] 1. 后端项目初始化和基础架构
|
||||
- 创建Python FastAPI项目结构
|
||||
- 设置虚拟环境和依赖管理(requirements.txt或pyproject.toml)
|
||||
- 配置FastAPI应用入口(main.py)
|
||||
- 创建核心目录结构(models, schemas, services, routers, core)
|
||||
- 设置基础配置管理(core/config.py)
|
||||
- _需求: 8.1, 8.2_
|
||||
|
||||
- [x] 2. 数据库设置和模型定义
|
||||
- 配置PostgreSQL数据库连接(core/database.py)
|
||||
- 创建SQLAlchemy数据模型(reports, analysis_modules, progress_tracking, system_config)
|
||||
- 设置Alembic数据库迁移工具
|
||||
- 创建初始数据库迁移脚本
|
||||
- 实现数据库会话管理和依赖注入
|
||||
- _需求: 6.3, 8.1_
|
||||
|
||||
- [x] 3. Pydantic模式和基础服务
|
||||
- 创建Pydantic数据验证模式(schemas/)
|
||||
- 实现配置管理服务(services/config_manager.py)
|
||||
- 创建数据获取服务基础架构(services/data_fetcher.py)
|
||||
- 实现基础错误处理和异常类
|
||||
- _需求: 8.2, 8.3, 8.4, 8.5_
|
||||
|
||||
- [x] 4. 外部API集成服务
|
||||
- 实现Tushare API集成(中国股票数据获取)
|
||||
- 实现Gemini API集成(AI分析服务)
|
||||
- 创建数据源配置和切换逻辑
|
||||
- 添加API调用错误处理和重试机制
|
||||
- _需求: 3.1, 3.2, 4.1, 4.2_
|
||||
|
||||
- [x] 5. 报告生成引擎核心
|
||||
- 创建报告生成服务(services/report_generator.py)
|
||||
- 实现分析模块执行框架
|
||||
- 创建进度追踪服务(services/progress_tracker.py)
|
||||
- 实现步骤计时和状态管理
|
||||
- _需求: 5.1, 7.1, 7.2, 7.3, 7.4, 7.5_
|
||||
|
||||
- [x] 6. 后端API路由实现
|
||||
- 实现报告相关API端点(routers/reports.py)
|
||||
- 创建配置管理API端点(routers/config.py)
|
||||
- 实现进度追踪API端点(routers/progress.py)
|
||||
- 添加API文档和验证
|
||||
- _需求: 2.1, 2.2, 2.3, 8.1, 8.2, 8.3_
|
||||
|
||||
- [x] 7. 前端项目初始化
|
||||
- 创建Next.js项目并配置TypeScript
|
||||
- 安装和配置shadcn/ui组件库
|
||||
- 设置Tailwind CSS和基础样式
|
||||
- 配置项目文件夹结构(components, lib, hooks, app)
|
||||
- 创建基础布局和主题配置
|
||||
- _需求: 1.1_
|
||||
|
||||
- [ ] 8. 前端核心组件开发
|
||||
- 安装和配置shadcn/ui基础组件
|
||||
- 实现StockSearchForm组件(使用Form, Input, Select, Button)
|
||||
- 创建ReportProgress组件(使用Progress, Badge, Card)
|
||||
- 实现AnalysisModule组件(使用Card, Tabs, Separator)
|
||||
- 创建FinancialDataTable组件(使用Table组件系列)
|
||||
- _需求: 1.1, 1.2, 7.1, 7.2_
|
||||
|
||||
- [ ] 9. 首页和股票搜索功能
|
||||
- 实现首页布局和设计(app/page.tsx)
|
||||
- 创建股票代码输入和市场选择功能
|
||||
- 实现表单验证和提交逻辑
|
||||
- 添加中文界面文本和错误提示
|
||||
- 连接前端表单到后端API
|
||||
- _需求: 1.1, 1.2, 1.3_
|
||||
|
||||
- [ ] 10. 报告页面和历史报告功能
|
||||
- 实现报告页面路由(app/report/[symbol]/page.tsx)
|
||||
- 创建历史报告检查和显示逻辑
|
||||
- 实现"生成最新报告"按钮功能
|
||||
- 添加报告加载状态和错误处理
|
||||
- _需求: 2.1, 2.2, 2.3_
|
||||
|
||||
- [ ] 11. TradingView图表集成
|
||||
- 集成TradingView高级图表组件
|
||||
- 实现图表配置和参数设置
|
||||
- 根据证券代码和市场配置图表
|
||||
- 处理图表加载错误和异常情况
|
||||
- _需求: 5.1, 5.2, 5.3, 5.4_
|
||||
|
||||
- [ ] 12. 财务数据分析模块
|
||||
- 实现财务数据获取和处理逻辑
|
||||
- 创建财务数据格式化和展示
|
||||
- 实现FinancialDataTable的数据绑定
|
||||
- 添加财务数据的错误处理和重试
|
||||
- _需求: 3.1, 3.2, 3.3_
|
||||
|
||||
- [ ] 13. AI业务信息分析模块
|
||||
- 实现Gemini API调用逻辑和提示词模板
|
||||
- 创建业务信息分析内容生成
|
||||
- 实现公司概览、主营业务、发展历程等内容
|
||||
- 添加AI分析结果的格式化和展示
|
||||
- _需求: 4.1, 4.2, 4.3_
|
||||
|
||||
- [ ] 14. 专业分析模块实现
|
||||
- 实现景林模型基本面分析模块
|
||||
- 创建看涨分析师模块(隐藏资产、护城河分析)
|
||||
- 实现看跌分析师模块(价值底线、最坏情况分析)
|
||||
- 创建市场分析师模块(市场情绪分歧点分析)
|
||||
- 实现新闻分析师模块(股价催化剂分析)
|
||||
- 创建交易分析模块(市场体量与增长路径)
|
||||
- 实现内部人与机构动向分析模块
|
||||
- 创建最终结论模块(关键矛盾与拐点分析)
|
||||
- _需求: 5.1, 5.2, 5.3, 5.4, 5.5, 5.6, 5.7, 5.8, 5.9_
|
||||
|
||||
- [ ] 15. 报告生成流程整合
|
||||
- 整合所有分析模块到报告生成引擎
|
||||
- 实现模块间的数据传递和依赖关系
|
||||
- 创建报告生成的错误处理和重试机制
|
||||
- 实现报告完成后的数据库保存
|
||||
- _需求: 5.1, 6.3_
|
||||
|
||||
- [ ] 16. 实时进度显示功能
|
||||
- 实现前端进度追踪钩子(useProgress)
|
||||
- 连接WebSocket或Server-Sent Events到进度显示
|
||||
- 添加步骤高亮和状态更新
|
||||
- 实现计时显示和预估完成时间
|
||||
- 添加错误状态显示
|
||||
- _需求: 7.1, 7.2, 7.3, 7.4, 7.5, 7.6_
|
||||
|
||||
- [ ] 17. 配置管理页面
|
||||
- 创建配置页面布局和表单(app/config/page.tsx)
|
||||
- 实现数据库配置界面
|
||||
- 添加Gemini API配置功能
|
||||
- 创建数据源配置管理
|
||||
- 实现配置验证和测试功能
|
||||
- _需求: 8.1, 8.2, 8.3, 8.4, 8.5_
|
||||
|
||||
- [ ] 18. 报告展示和导航优化
|
||||
- 实现分析模块的独立页面展示
|
||||
- 创建模块间的流畅导航
|
||||
- 添加报告概览和目录功能
|
||||
- 优化移动端响应式显示
|
||||
- _需求: 6.1, 6.2_
|
||||
|
||||
- [ ] 19. 错误处理和用户体验优化
|
||||
- 实现全局错误处理和错误边界
|
||||
- 添加Toast通知系统
|
||||
- 创建加载状态和骨架屏
|
||||
- 优化中文界面和用户反馈
|
||||
- 添加操作确认和提示
|
||||
- _需求: 7.6, 1.1_
|
||||
|
||||
- [ ]* 20. 测试实现
|
||||
- [ ]* 20.1 后端单元测试
|
||||
- 为数据获取服务编写单元测试
|
||||
- 为AI分析服务编写单元测试
|
||||
- 为报告生成引擎编写单元测试
|
||||
- 为配置管理服务编写单元测试
|
||||
|
||||
- [ ]* 20.2 前端组件测试
|
||||
- 为核心组件编写React Testing Library测试
|
||||
- 为表单组件编写交互测试
|
||||
- 为进度组件编写状态测试
|
||||
|
||||
- [ ]* 20.3 API集成测试
|
||||
- 为报告生成API编写集成测试
|
||||
- 为配置管理API编写集成测试
|
||||
- 为进度追踪API编写集成测试
|
||||
|
||||
- [ ]* 20.4 端到端测试
|
||||
- 编写完整报告生成流程的E2E测试
|
||||
- 编写配置管理流程的E2E测试
|
||||
228
backend/DATABASE_SETUP.md
Normal file
228
backend/DATABASE_SETUP.md
Normal file
@ -0,0 +1,228 @@
|
||||
# 数据库设置指南
|
||||
|
||||
## 概述
|
||||
|
||||
本项目使用PostgreSQL作为主数据库,SQLAlchemy作为ORM,Alembic作为数据库迁移工具。
|
||||
|
||||
## 数据库架构
|
||||
|
||||
### 表结构
|
||||
|
||||
1. **reports** - 报告主表
|
||||
- 存储股票分析报告的基本信息
|
||||
- 包含证券代码、市场、状态等字段
|
||||
|
||||
2. **analysis_modules** - 分析模块表
|
||||
- 存储报告中各个分析模块的内容
|
||||
- 与reports表一对多关系
|
||||
|
||||
3. **progress_tracking** - 进度追踪表
|
||||
- 记录报告生成过程中各步骤的执行状态
|
||||
- 与reports表一对多关系
|
||||
|
||||
4. **system_config** - 系统配置表
|
||||
- 存储系统配置信息
|
||||
- 使用JSONB格式存储配置值
|
||||
|
||||
## 环境配置
|
||||
|
||||
### 1. 安装PostgreSQL
|
||||
|
||||
```bash
|
||||
# macOS (使用Homebrew)
|
||||
brew install postgresql
|
||||
brew services start postgresql
|
||||
|
||||
# Ubuntu/Debian
|
||||
sudo apt-get install postgresql postgresql-contrib
|
||||
|
||||
# CentOS/RHEL
|
||||
sudo yum install postgresql-server postgresql-contrib
|
||||
```
|
||||
|
||||
### 2. 创建数据库
|
||||
|
||||
```sql
|
||||
-- 连接到PostgreSQL
|
||||
psql -U postgres
|
||||
|
||||
-- 创建数据库
|
||||
CREATE DATABASE stock_analysis;
|
||||
|
||||
-- 创建用户(可选)
|
||||
CREATE USER stock_user WITH PASSWORD 'your_password';
|
||||
GRANT ALL PRIVILEGES ON DATABASE stock_analysis TO stock_user;
|
||||
```
|
||||
|
||||
### 3. 配置环境变量
|
||||
|
||||
创建 `.env` 文件:
|
||||
|
||||
```bash
|
||||
# 数据库配置
|
||||
DATABASE_URL=postgresql+asyncpg://username:password@localhost:5432/stock_analysis
|
||||
DATABASE_ECHO=false
|
||||
|
||||
# API配置
|
||||
GEMINI_API_KEY=your_gemini_api_key
|
||||
TUSHARE_TOKEN=your_tushare_token
|
||||
```
|
||||
|
||||
## 数据库管理
|
||||
|
||||
### 使用管理脚本
|
||||
|
||||
```bash
|
||||
# 检查数据库连接
|
||||
python manage_db.py check
|
||||
|
||||
# 初始化数据库表
|
||||
python manage_db.py init
|
||||
|
||||
# 查看数据库状态
|
||||
python manage_db.py status
|
||||
```
|
||||
|
||||
### 使用Alembic迁移
|
||||
|
||||
```bash
|
||||
# 初始化Alembic(已完成)
|
||||
alembic init alembic
|
||||
|
||||
# 创建迁移文件
|
||||
alembic revision --autogenerate -m "描述信息"
|
||||
|
||||
# 应用迁移
|
||||
alembic upgrade head
|
||||
|
||||
# 查看迁移历史
|
||||
alembic history
|
||||
|
||||
# 回滚迁移
|
||||
alembic downgrade -1
|
||||
```
|
||||
|
||||
## 开发工具
|
||||
|
||||
### 1. 数据库连接检查
|
||||
|
||||
```bash
|
||||
python check_db.py
|
||||
```
|
||||
|
||||
### 2. 数据库初始化
|
||||
|
||||
```bash
|
||||
python init_db.py
|
||||
```
|
||||
|
||||
### 3. 综合管理工具
|
||||
|
||||
```bash
|
||||
python manage_db.py [check|init|status]
|
||||
```
|
||||
|
||||
## 模型使用示例
|
||||
|
||||
### 创建报告
|
||||
|
||||
```python
|
||||
from app.models import Report, AnalysisModule
|
||||
from app.core.database import get_db
|
||||
|
||||
async def create_report():
|
||||
async for db in get_db():
|
||||
# 创建报告
|
||||
report = Report(
|
||||
symbol="000001",
|
||||
market="中国",
|
||||
status="generating"
|
||||
)
|
||||
db.add(report)
|
||||
await db.commit()
|
||||
await db.refresh(report)
|
||||
|
||||
# 创建分析模块
|
||||
module = AnalysisModule(
|
||||
report_id=report.id,
|
||||
module_type="financial_data",
|
||||
module_order=1,
|
||||
title="财务数据分析",
|
||||
status="pending"
|
||||
)
|
||||
db.add(module)
|
||||
await db.commit()
|
||||
```
|
||||
|
||||
### 查询报告
|
||||
|
||||
```python
|
||||
from sqlalchemy import select
|
||||
from app.models import Report
|
||||
|
||||
async def get_report(symbol: str, market: str):
|
||||
async for db in get_db():
|
||||
stmt = select(Report).where(
|
||||
Report.symbol == symbol,
|
||||
Report.market == market
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
```
|
||||
|
||||
## 性能优化
|
||||
|
||||
### 索引
|
||||
|
||||
所有表都已配置适当的索引:
|
||||
|
||||
- reports: symbol+market, status, created_at
|
||||
- analysis_modules: report_id, module_type, status, module_order
|
||||
- progress_tracking: report_id, status, step_order
|
||||
- system_config: config_key, updated_at
|
||||
|
||||
### 连接池
|
||||
|
||||
数据库连接使用异步连接池,配置参数:
|
||||
|
||||
- pool_size: 10
|
||||
- max_overflow: 20
|
||||
- pool_timeout: 30秒
|
||||
- pool_recycle: 1小时
|
||||
|
||||
## 故障排除
|
||||
|
||||
### 常见问题
|
||||
|
||||
1. **连接失败**
|
||||
- 检查PostgreSQL服务是否运行
|
||||
- 验证数据库URL配置
|
||||
- 确认防火墙设置
|
||||
|
||||
2. **迁移失败**
|
||||
- 检查数据库权限
|
||||
- 验证表结构冲突
|
||||
- 查看Alembic日志
|
||||
|
||||
3. **性能问题**
|
||||
- 检查索引使用情况
|
||||
- 分析慢查询日志
|
||||
- 优化查询语句
|
||||
|
||||
### 日志配置
|
||||
|
||||
在 `config.py` 中设置 `DATABASE_ECHO=True` 可以查看SQL执行日志。
|
||||
|
||||
## 备份与恢复
|
||||
|
||||
### 备份
|
||||
|
||||
```bash
|
||||
pg_dump -U username -h localhost stock_analysis > backup.sql
|
||||
```
|
||||
|
||||
### 恢复
|
||||
|
||||
```bash
|
||||
psql -U username -h localhost stock_analysis < backup.sql
|
||||
```
|
||||
287
backend/EXTERNAL_API_INTEGRATION.md
Normal file
287
backend/EXTERNAL_API_INTEGRATION.md
Normal file
@ -0,0 +1,287 @@
|
||||
# 外部API集成文档
|
||||
|
||||
## 概述
|
||||
|
||||
本系统集成了多个外部API服务,用于获取股票数据和进行AI分析:
|
||||
|
||||
- **Tushare API**: 中国股票财务数据和市场数据
|
||||
- **Yahoo Finance API**: 全球股票数据(美国、香港、日本等)
|
||||
- **Google Gemini API**: AI分析和内容生成
|
||||
|
||||
## 架构设计
|
||||
|
||||
### 数据源管理器 (DataSourceManager)
|
||||
- 统一管理所有数据源
|
||||
- 支持数据源切换和故障转移
|
||||
- 提供健康检查和状态监控
|
||||
|
||||
### 外部API服务 (ExternalAPIService)
|
||||
- 提供统一的API接口
|
||||
- 处理错误和重试机制
|
||||
- 支持配置动态更新
|
||||
|
||||
## 支持的数据源
|
||||
|
||||
### 1. Tushare API
|
||||
- **用途**: 中国股票数据(A股、港股通等)
|
||||
- **数据类型**: 财务数据、市场数据、基本信息
|
||||
- **配置要求**: 需要Tushare API Token
|
||||
- **限制**: 有调用频率限制,需要付费账户获取完整数据
|
||||
|
||||
#### 配置示例
|
||||
```python
|
||||
{
|
||||
"tushare": {
|
||||
"enabled": True,
|
||||
"api_key": "your_tushare_token",
|
||||
"base_url": "http://api.tushare.pro",
|
||||
"timeout": 30,
|
||||
"max_retries": 3,
|
||||
"retry_delay": 1
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Yahoo Finance API
|
||||
- **用途**: 全球股票数据
|
||||
- **数据类型**: 财务数据、市场数据、基本信息
|
||||
- **配置要求**: 无需API密钥
|
||||
- **限制**: 有调用频率限制,可能被反爬虫机制阻止
|
||||
|
||||
#### 配置示例
|
||||
```python
|
||||
{
|
||||
"yahoo": {
|
||||
"enabled": True,
|
||||
"base_url": "https://query1.finance.yahoo.com",
|
||||
"timeout": 30,
|
||||
"max_retries": 3,
|
||||
"retry_delay": 1
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Google Gemini API
|
||||
- **用途**: AI分析和内容生成
|
||||
- **功能**: 业务分析、基本面分析、投资建议等
|
||||
- **配置要求**: 需要Google Cloud API密钥
|
||||
- **限制**: 有调用频率和配额限制
|
||||
|
||||
#### 配置示例
|
||||
```python
|
||||
{
|
||||
"gemini": {
|
||||
"enabled": True,
|
||||
"api_key": "your_gemini_api_key",
|
||||
"model": "gemini-pro",
|
||||
"timeout": 60,
|
||||
"max_retries": 3,
|
||||
"retry_delay": 2,
|
||||
"temperature": 0.7,
|
||||
"max_output_tokens": 8192
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 数据源切换逻辑
|
||||
|
||||
### 市场映射
|
||||
系统根据股票市场自动选择合适的数据源:
|
||||
|
||||
```python
|
||||
market_mapping = {
|
||||
"china": "tushare", # 中国A股使用Tushare
|
||||
"中国": "tushare",
|
||||
"hongkong": "yahoo", # 香港股票使用Yahoo Finance
|
||||
"香港": "yahoo",
|
||||
"usa": "yahoo", # 美国股票使用Yahoo Finance
|
||||
"美国": "yahoo",
|
||||
"japan": "yahoo", # 日本股票使用Yahoo Finance
|
||||
"日本": "yahoo"
|
||||
}
|
||||
```
|
||||
|
||||
### 故障转移
|
||||
当主要数据源不可用时,系统会自动切换到备用数据源:
|
||||
|
||||
```python
|
||||
fallback_sources = {
|
||||
"tushare": ["yahoo"], # Tushare失败时使用Yahoo
|
||||
"yahoo": ["tushare"] # Yahoo失败时使用Tushare
|
||||
}
|
||||
```
|
||||
|
||||
## 错误处理和重试机制
|
||||
|
||||
### 错误类型
|
||||
- `DataSourceError`: 数据源相关错误
|
||||
- `AIAnalysisError`: AI分析相关错误
|
||||
- `AuthenticationError`: 认证失败
|
||||
- `RateLimitError`: 调用频率超限
|
||||
- `APIError`: 通用API错误
|
||||
|
||||
### 重试策略
|
||||
- **指数退避**: 重试间隔逐渐增加
|
||||
- **最大重试次数**: 默认3次
|
||||
- **超时处理**: 每个请求都有超时限制
|
||||
- **错误分类**: 不同错误类型采用不同的重试策略
|
||||
|
||||
## API使用示例
|
||||
|
||||
### 1. 获取财务数据
|
||||
```python
|
||||
from app.services.external_api_service import get_external_api_service
|
||||
|
||||
service = get_external_api_service()
|
||||
|
||||
# 获取平安银行财务数据
|
||||
financial_data = await service.get_financial_data("000001", "中国")
|
||||
print(f"数据源: {financial_data.data_source}")
|
||||
print(f"总资产: {financial_data.balance_sheet.get('total_assets')}")
|
||||
```
|
||||
|
||||
### 2. 验证股票代码
|
||||
```python
|
||||
# 验证股票代码是否有效
|
||||
validation = await service.validate_stock_symbol("AAPL", "美国")
|
||||
if validation.is_valid:
|
||||
print(f"公司名称: {validation.company_name}")
|
||||
```
|
||||
|
||||
### 3. AI分析
|
||||
```python
|
||||
# 进行业务信息分析
|
||||
analysis = await service.analyze_business_info(
|
||||
"000001", "中国", financial_data.dict()
|
||||
)
|
||||
print(f"分析内容: {analysis.content['company_overview']}")
|
||||
```
|
||||
|
||||
### 4. 检查服务状态
|
||||
```python
|
||||
# 检查所有外部服务状态
|
||||
status = await service.check_all_services_status()
|
||||
print(f"整体状态: {status.overall_status}")
|
||||
|
||||
for source in status.sources:
|
||||
print(f"{source.name}: {'可用' if source.is_available else '不可用'}")
|
||||
```
|
||||
|
||||
## 配置管理
|
||||
|
||||
### 环境变量
|
||||
在`.env`文件中配置API密钥:
|
||||
|
||||
```bash
|
||||
# Tushare API Token
|
||||
TUSHARE_TOKEN=your_tushare_token_here
|
||||
|
||||
# Gemini API Key
|
||||
GEMINI_API_KEY=your_gemini_api_key_here
|
||||
```
|
||||
|
||||
### 动态配置更新
|
||||
```python
|
||||
# 更新配置
|
||||
new_config = {
|
||||
"data_sources": {
|
||||
"tushare": {
|
||||
"enabled": True,
|
||||
"api_key": "new_token"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
service.update_configuration(new_config)
|
||||
```
|
||||
|
||||
## 测试和调试
|
||||
|
||||
### 运行测试脚本
|
||||
```bash
|
||||
cd backend
|
||||
python test_external_apis.py
|
||||
```
|
||||
|
||||
### 测试单个数据源
|
||||
```python
|
||||
# 测试Tushare连接
|
||||
result = await service.test_data_source_connection("tushare", {
|
||||
"api_key": "your_token",
|
||||
"base_url": "http://api.tushare.pro"
|
||||
})
|
||||
print(f"连接成功: {result['success']}")
|
||||
```
|
||||
|
||||
### 测试AI服务
|
||||
```python
|
||||
# 测试Gemini连接
|
||||
result = await service.test_ai_service_connection("gemini", {
|
||||
"api_key": "your_api_key"
|
||||
})
|
||||
print(f"连接成功: {result['success']}")
|
||||
```
|
||||
|
||||
## 性能优化
|
||||
|
||||
### 缓存策略
|
||||
- 财务数据缓存1小时
|
||||
- 市场数据缓存5分钟
|
||||
- AI分析结果缓存24小时
|
||||
|
||||
### 并发控制
|
||||
- 限制同时进行的API调用数量
|
||||
- 使用连接池管理HTTP连接
|
||||
- 实现请求队列避免频率限制
|
||||
|
||||
### 监控和日志
|
||||
- 记录所有API调用和响应时间
|
||||
- 监控错误率和成功率
|
||||
- 设置告警机制
|
||||
|
||||
## 故障排除
|
||||
|
||||
### 常见问题
|
||||
|
||||
1. **Tushare API调用失败**
|
||||
- 检查API Token是否正确
|
||||
- 确认账户是否有足够的积分
|
||||
- 检查网络连接
|
||||
|
||||
2. **Gemini API调用失败**
|
||||
- 检查API Key是否有效
|
||||
- 确认配额是否充足
|
||||
- 检查请求格式是否正确
|
||||
|
||||
3. **Yahoo Finance被限制**
|
||||
- 降低请求频率
|
||||
- 使用代理服务器
|
||||
- 切换到其他数据源
|
||||
|
||||
### 调试技巧
|
||||
- 启用详细日志记录
|
||||
- 使用测试脚本验证配置
|
||||
- 检查网络连接和防火墙设置
|
||||
- 监控API调用的响应时间和错误率
|
||||
|
||||
## 扩展性
|
||||
|
||||
### 添加新的数据源
|
||||
1. 继承`DataFetcher`基类
|
||||
2. 实现必要的方法
|
||||
3. 在`DataFetcherFactory`中注册
|
||||
4. 更新配置文件
|
||||
|
||||
### 添加新的AI服务
|
||||
1. 创建新的分析器类
|
||||
2. 实现统一的接口
|
||||
3. 在`AIAnalyzerFactory`中注册
|
||||
4. 更新配置管理
|
||||
|
||||
## 安全考虑
|
||||
|
||||
- API密钥加密存储
|
||||
- 使用HTTPS进行所有外部调用
|
||||
- 实现访问控制和权限管理
|
||||
- 定期轮换API密钥
|
||||
- 监控异常访问模式
|
||||
100
backend/README.md
Normal file
100
backend/README.md
Normal file
@ -0,0 +1,100 @@
|
||||
# 基本面选股系统 - 后端服务
|
||||
|
||||
基于FastAPI的股票基本面分析后端服务,提供报告生成、配置管理和进度追踪功能。
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
backend/
|
||||
├── main.py # FastAPI应用入口
|
||||
├── requirements.txt # Python依赖
|
||||
├── .env.example # 环境变量示例
|
||||
├── README.md # 项目说明
|
||||
└── app/
|
||||
├── __init__.py
|
||||
├── core/ # 核心模块
|
||||
│ ├── config.py # 配置管理
|
||||
│ ├── database.py # 数据库连接
|
||||
│ └── dependencies.py # 依赖注入
|
||||
├── models/ # 数据模型
|
||||
│ ├── report.py
|
||||
│ ├── analysis_module.py
|
||||
│ ├── progress_tracking.py
|
||||
│ └── system_config.py
|
||||
├── schemas/ # Pydantic模式
|
||||
│ ├── report.py
|
||||
│ ├── config.py
|
||||
│ └── progress.py
|
||||
├── services/ # 业务服务
|
||||
│ ├── config_manager.py
|
||||
│ ├── data_fetcher.py
|
||||
│ ├── report_generator.py
|
||||
│ └── progress_tracker.py
|
||||
└── routers/ # API路由
|
||||
├── reports.py
|
||||
├── config.py
|
||||
└── progress.py
|
||||
```
|
||||
|
||||
## 快速开始
|
||||
|
||||
1. 安装依赖:
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
2. 配置环境变量:
|
||||
```bash
|
||||
cp .env.example .env
|
||||
# 编辑 .env 文件,设置数据库连接和API密钥
|
||||
```
|
||||
|
||||
3. 启动服务:
|
||||
```bash
|
||||
uvicorn main:app --reload --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
4. 访问API文档:
|
||||
- Swagger UI: http://localhost:8000/docs
|
||||
- ReDoc: http://localhost:8000/redoc
|
||||
|
||||
## API端点
|
||||
|
||||
### 报告相关
|
||||
- `GET /api/reports/{symbol}?market={market}` - 获取或创建报告
|
||||
- `POST /api/reports/{symbol}/regenerate?market={market}` - 重新生成报告
|
||||
- `GET /api/reports/{report_id}/details` - 获取报告详情
|
||||
|
||||
### 配置管理
|
||||
- `GET /api/config` - 获取系统配置
|
||||
- `PUT /api/config` - 更新系统配置
|
||||
- `POST /api/config/test` - 测试配置连接
|
||||
|
||||
### 进度追踪
|
||||
- `GET /api/progress/{report_id}` - 获取报告生成进度
|
||||
|
||||
## 数据库
|
||||
|
||||
系统使用PostgreSQL数据库,包含以下主要表:
|
||||
- `reports` - 报告基本信息
|
||||
- `analysis_modules` - 分析模块内容
|
||||
- `progress_tracking` - 进度追踪记录
|
||||
- `system_config` - 系统配置
|
||||
|
||||
## 开发
|
||||
|
||||
### 代码格式化
|
||||
```bash
|
||||
black app/
|
||||
isort app/
|
||||
```
|
||||
|
||||
### 类型检查
|
||||
```bash
|
||||
mypy app/
|
||||
```
|
||||
|
||||
### 运行测试
|
||||
```bash
|
||||
pytest
|
||||
```
|
||||
118
backend/alembic.ini
Normal file
118
backend/alembic.ini
Normal file
@ -0,0 +1,118 @@
|
||||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts
|
||||
# Use forward slashes (/) also on windows to provide an os agnostic path
|
||||
script_location = alembic
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||
# for all available tokens
|
||||
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory.
|
||||
prepend_sys_path = .
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the python>=3.9 or backports.zoneinfo library.
|
||||
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
|
||||
# string value is passed to ZoneInfo()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; This defaults
|
||||
# to alembic/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path.
|
||||
# The path separator used here should be the separator specified by "version_path_separator" below.
|
||||
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
|
||||
|
||||
# version path separator; As mentioned above, this is the character used to split
|
||||
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
|
||||
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
|
||||
# Valid values for version_path_separator are:
|
||||
#
|
||||
# version_path_separator = :
|
||||
# version_path_separator = ;
|
||||
# version_path_separator = space
|
||||
# version_path_separator = newline
|
||||
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
# new in Alembic version 1.10
|
||||
# recursive_version_locations = false
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
# sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||
# Database URL will be set programmatically in env.py
|
||||
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
|
||||
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||
# hooks = black
|
||||
# black.type = console_scripts
|
||||
# black.entrypoint = black
|
||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||
|
||||
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
|
||||
# hooks = ruff
|
||||
# ruff.type = exec
|
||||
# ruff.executable = %(here)s/.venv/bin/ruff
|
||||
# ruff.options = --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Logging configuration
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
1
backend/alembic/README
Normal file
1
backend/alembic/README
Normal file
@ -0,0 +1 @@
|
||||
Generic single-database configuration.
|
||||
102
backend/alembic/env.py
Normal file
102
backend/alembic/env.py
Normal file
@ -0,0 +1,102 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
from alembic import context
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
# 导入我们的模型和配置
|
||||
from app.core.database import Base
|
||||
from app.core.config import settings
|
||||
from app.models import * # 导入所有模型以确保它们被注册
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# add your model's MetaData object here
|
||||
# for 'autogenerate' support
|
||||
target_metadata = Base.metadata
|
||||
|
||||
# 设置数据库URL
|
||||
config.set_main_option("sqlalchemy.url", settings.DATABASE_URL.replace("+asyncpg", ""))
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
|
||||
This configures the context with just a URL
|
||||
and not an Engine, though an Engine is acceptable
|
||||
here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
|
||||
Calls to context.execute() here emit the given string to the
|
||||
script output.
|
||||
|
||||
"""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
compare_type=True,
|
||||
compare_server_default=True,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
"""运行迁移的核心逻辑"""
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
compare_type=True,
|
||||
compare_server_default=True,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
async def run_async_migrations() -> None:
|
||||
"""异步运行迁移"""
|
||||
connectable = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
async with connectable.connect() as connection:
|
||||
await connection.run_sync(do_run_migrations)
|
||||
|
||||
await connectable.dispose()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode."""
|
||||
asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
26
backend/alembic/script.py.mako
Normal file
26
backend/alembic/script.py.mako
Normal file
@ -0,0 +1,26 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
90
backend/alembic/versions/001_initial_migration.py
Normal file
90
backend/alembic/versions/001_initial_migration.py
Normal file
@ -0,0 +1,90 @@
|
||||
"""Initial migration: create all tables
|
||||
|
||||
Revision ID: 001
|
||||
Revises:
|
||||
Create Date: 2024-01-01 00:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '001'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
|
||||
# Create reports table
|
||||
op.create_table('reports',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('symbol', sa.String(length=20), nullable=False, comment='证券代码'),
|
||||
sa.Column('market', sa.String(length=20), nullable=False, comment='交易市场'),
|
||||
sa.Column('status', sa.String(length=20), nullable=False, comment='报告状态'),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
|
||||
# Create analysis_modules table
|
||||
op.create_table('analysis_modules',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('report_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('module_type', sa.String(length=50), nullable=False, comment='模块类型'),
|
||||
sa.Column('module_order', sa.Integer(), nullable=False, comment='模块顺序'),
|
||||
sa.Column('title', sa.String(length=200), nullable=False, comment='模块标题'),
|
||||
sa.Column('content', postgresql.JSONB(astext_type=sa.Text()), nullable=True, comment='模块内容'),
|
||||
sa.Column('status', sa.String(length=20), nullable=False, comment='模块状态'),
|
||||
sa.Column('started_at', sa.DateTime(timezone=True), nullable=True, comment='开始时间'),
|
||||
sa.Column('completed_at', sa.DateTime(timezone=True), nullable=True, comment='完成时间'),
|
||||
sa.Column('error_message', sa.Text(), nullable=True, comment='错误信息'),
|
||||
sa.ForeignKeyConstraint(['report_id'], ['reports.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
|
||||
# Create progress_tracking table
|
||||
op.create_table('progress_tracking',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('report_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('step_name', sa.String(length=100), nullable=False, comment='步骤名称'),
|
||||
sa.Column('step_order', sa.Integer(), nullable=False, comment='步骤顺序'),
|
||||
sa.Column('status', sa.String(length=20), nullable=False, comment='步骤状态'),
|
||||
sa.Column('started_at', sa.DateTime(timezone=True), nullable=True, comment='开始时间'),
|
||||
sa.Column('completed_at', sa.DateTime(timezone=True), nullable=True, comment='完成时间'),
|
||||
sa.Column('duration_ms', sa.Integer(), nullable=True, comment='耗时(毫秒)'),
|
||||
sa.Column('error_message', sa.Text(), nullable=True, comment='错误信息'),
|
||||
sa.ForeignKeyConstraint(['report_id'], ['reports.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
|
||||
# Create system_config table
|
||||
op.create_table('system_config',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('config_key', sa.String(length=100), nullable=False, comment='配置键'),
|
||||
sa.Column('config_value', postgresql.JSONB(astext_type=sa.Text()), nullable=False, comment='配置值'),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('config_key')
|
||||
)
|
||||
|
||||
# Set default values for status columns
|
||||
op.execute("ALTER TABLE reports ALTER COLUMN status SET DEFAULT 'generating'")
|
||||
op.execute("ALTER TABLE analysis_modules ALTER COLUMN status SET DEFAULT 'pending'")
|
||||
op.execute("ALTER TABLE progress_tracking ALTER COLUMN status SET DEFAULT 'pending'")
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('system_config')
|
||||
op.drop_table('progress_tracking')
|
||||
op.drop_table('analysis_modules')
|
||||
op.drop_table('reports')
|
||||
# ### end Alembic commands ###
|
||||
7
backend/app/__init__.py
Normal file
7
backend/app/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
"""
|
||||
基本面选股系统后端应用
|
||||
"""
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__author__ = "基本面选股系统开发团队"
|
||||
__description__ = "提供股票基本面分析和报告生成的后端服务"
|
||||
48
backend/app/core/__init__.py
Normal file
48
backend/app/core/__init__.py
Normal file
@ -0,0 +1,48 @@
|
||||
"""
|
||||
核心模块
|
||||
包含配置、数据库连接、依赖注入和异常处理
|
||||
"""
|
||||
|
||||
from .config import settings, db_config, api_config, data_source_config
|
||||
from .database import engine, AsyncSessionLocal, Base, get_db, init_db, close_db
|
||||
from .dependencies import get_database_session, verify_gemini_api_key, verify_tushare_token
|
||||
from .exceptions import (
|
||||
StockAnalysisError,
|
||||
DataSourceError,
|
||||
AIAnalysisError,
|
||||
ConfigurationError,
|
||||
DatabaseError,
|
||||
ValidationError,
|
||||
APIError,
|
||||
ReportGenerationError,
|
||||
SymbolNotFoundError,
|
||||
RateLimitError,
|
||||
AuthenticationError
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"settings",
|
||||
"db_config",
|
||||
"api_config",
|
||||
"data_source_config",
|
||||
"engine",
|
||||
"AsyncSessionLocal",
|
||||
"Base",
|
||||
"get_db",
|
||||
"init_db",
|
||||
"close_db",
|
||||
"get_database_session",
|
||||
"verify_gemini_api_key",
|
||||
"verify_tushare_token",
|
||||
"StockAnalysisError",
|
||||
"DataSourceError",
|
||||
"AIAnalysisError",
|
||||
"ConfigurationError",
|
||||
"DatabaseError",
|
||||
"ValidationError",
|
||||
"APIError",
|
||||
"ReportGenerationError",
|
||||
"SymbolNotFoundError",
|
||||
"RateLimitError",
|
||||
"AuthenticationError"
|
||||
]
|
||||
183
backend/app/core/config.py
Normal file
183
backend/app/core/config.py
Normal file
@ -0,0 +1,183 @@
|
||||
"""
|
||||
应用配置管理
|
||||
处理环境变量和系统配置
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from pydantic import validator
|
||||
from pydantic_settings import BaseSettings
|
||||
import os
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""应用设置类"""
|
||||
|
||||
# 应用基础配置
|
||||
APP_NAME: str = "基本面选股系统"
|
||||
APP_VERSION: str = "1.0.0"
|
||||
DEBUG: bool = False
|
||||
|
||||
# 数据库配置
|
||||
DATABASE_URL: str = "postgresql+asyncpg://user:password@localhost:5432/stock_analysis"
|
||||
DATABASE_ECHO: bool = False
|
||||
|
||||
# API配置
|
||||
API_V1_STR: str = "/api"
|
||||
ALLOWED_ORIGINS: List[str] = ["http://localhost:3000", "http://127.0.0.1:3000"]
|
||||
|
||||
# 外部服务配置
|
||||
GEMINI_API_KEY: Optional[str] = None
|
||||
TUSHARE_TOKEN: Optional[str] = None
|
||||
|
||||
# 数据源配置
|
||||
CHINA_DATA_SOURCE: str = "tushare"
|
||||
HK_DATA_SOURCE: str = "yahoo"
|
||||
US_DATA_SOURCE: str = "yahoo"
|
||||
JP_DATA_SOURCE: str = "yahoo"
|
||||
|
||||
# 报告生成配置
|
||||
MAX_CONCURRENT_REPORTS: int = 5
|
||||
REPORT_TIMEOUT_MINUTES: int = 30
|
||||
|
||||
# 缓存配置
|
||||
CACHE_TTL_SECONDS: int = 3600 # 1小时
|
||||
|
||||
@validator("ALLOWED_ORIGINS", pre=True)
|
||||
def assemble_cors_origins(cls, v):
|
||||
"""处理CORS origins配置"""
|
||||
if isinstance(v, str):
|
||||
return [i.strip() for i in v.split(",")]
|
||||
return v
|
||||
|
||||
@validator("DATABASE_URL", pre=True)
|
||||
def assemble_db_connection(cls, v):
|
||||
"""处理数据库连接字符串"""
|
||||
if v and not v.startswith("postgresql"):
|
||||
raise ValueError("数据库URL必须使用PostgreSQL")
|
||||
return v
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
case_sensitive = True
|
||||
|
||||
|
||||
# 创建全局设置实例
|
||||
settings = Settings()
|
||||
|
||||
|
||||
class DatabaseConfig:
|
||||
"""数据库配置类"""
|
||||
|
||||
def __init__(self):
|
||||
self.url = settings.DATABASE_URL
|
||||
self.echo = settings.DATABASE_ECHO
|
||||
self.pool_size = 10
|
||||
self.max_overflow = 20
|
||||
self.pool_timeout = 30
|
||||
self.pool_recycle = 3600
|
||||
|
||||
|
||||
class ExternalAPIConfig:
|
||||
"""外部API配置类"""
|
||||
|
||||
def __init__(self):
|
||||
self.gemini_api_key = settings.GEMINI_API_KEY
|
||||
self.tushare_token = settings.TUSHARE_TOKEN
|
||||
|
||||
# 数据源配置
|
||||
self.data_sources_config = {
|
||||
"tushare": {
|
||||
"enabled": bool(self.tushare_token),
|
||||
"api_key": self.tushare_token,
|
||||
"token": self.tushare_token,
|
||||
"base_url": "http://api.tushare.pro",
|
||||
"timeout": 30,
|
||||
"max_retries": 3,
|
||||
"retry_delay": 1,
|
||||
"name": "tushare"
|
||||
},
|
||||
"yahoo": {
|
||||
"enabled": True,
|
||||
"base_url": "https://query1.finance.yahoo.com",
|
||||
"timeout": 30,
|
||||
"max_retries": 3,
|
||||
"retry_delay": 1,
|
||||
"name": "yahoo"
|
||||
}
|
||||
}
|
||||
|
||||
# AI服务配置
|
||||
self.ai_services_config = {
|
||||
"gemini": {
|
||||
"enabled": bool(self.gemini_api_key),
|
||||
"api_key": self.gemini_api_key,
|
||||
"model": "gemini-pro",
|
||||
"base_url": "https://generativelanguage.googleapis.com/v1beta",
|
||||
"timeout": 60,
|
||||
"max_retries": 3,
|
||||
"retry_delay": 2,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.8,
|
||||
"top_k": 40,
|
||||
"max_output_tokens": 8192
|
||||
}
|
||||
}
|
||||
|
||||
def validate_gemini_config(self) -> bool:
|
||||
"""验证Gemini API配置"""
|
||||
return bool(self.gemini_api_key)
|
||||
|
||||
def validate_tushare_config(self) -> bool:
|
||||
"""验证Tushare API配置"""
|
||||
return bool(self.tushare_token)
|
||||
|
||||
def get_data_source_manager_config(self) -> dict:
|
||||
"""获取数据源管理器配置"""
|
||||
return {
|
||||
"data_sources": self.data_sources_config,
|
||||
"ai_services": self.ai_services_config,
|
||||
"market_mapping": {
|
||||
"china": "tushare",
|
||||
"中国": "tushare",
|
||||
"hongkong": "yahoo",
|
||||
"香港": "yahoo",
|
||||
"usa": "yahoo",
|
||||
"美国": "yahoo",
|
||||
"japan": "yahoo",
|
||||
"日本": "yahoo"
|
||||
},
|
||||
"fallback_sources": {
|
||||
"tushare": ["yahoo"],
|
||||
"yahoo": ["tushare"]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class DataSourceConfig:
|
||||
"""数据源配置类"""
|
||||
|
||||
def __init__(self):
|
||||
self.sources = {
|
||||
"china": settings.CHINA_DATA_SOURCE,
|
||||
"hongkong": settings.HK_DATA_SOURCE,
|
||||
"usa": settings.US_DATA_SOURCE,
|
||||
"japan": settings.JP_DATA_SOURCE
|
||||
}
|
||||
|
||||
def get_source_for_market(self, market: str) -> str:
|
||||
"""根据市场获取数据源"""
|
||||
market_mapping = {
|
||||
"中国": "china",
|
||||
"香港": "hongkong",
|
||||
"美国": "usa",
|
||||
"日本": "japan"
|
||||
}
|
||||
|
||||
market_key = market_mapping.get(market, "china")
|
||||
return self.sources.get(market_key, "tushare")
|
||||
|
||||
|
||||
# 创建配置实例
|
||||
db_config = DatabaseConfig()
|
||||
api_config = ExternalAPIConfig()
|
||||
data_source_config = DataSourceConfig()
|
||||
69
backend/app/core/database.py
Normal file
69
backend/app/core/database.py
Normal file
@ -0,0 +1,69 @@
|
||||
"""
|
||||
数据库连接和会话管理
|
||||
"""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.pool import NullPool
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from .config import settings
|
||||
|
||||
# 创建异步数据库引擎
|
||||
engine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
echo=settings.DATABASE_ECHO,
|
||||
poolclass=NullPool, # 对于异步使用NullPool
|
||||
future=True
|
||||
)
|
||||
|
||||
# 创建异步会话工厂
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False
|
||||
)
|
||||
|
||||
# 创建基础模型类
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
数据库会话依赖注入
|
||||
用于FastAPI路由中获取数据库会话
|
||||
"""
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
async def init_db():
|
||||
"""初始化数据库表"""
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
|
||||
async def close_db():
|
||||
"""关闭数据库连接"""
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
async def check_db_connection() -> bool:
|
||||
"""检查数据库连接是否正常"""
|
||||
try:
|
||||
async with AsyncSessionLocal() as session:
|
||||
await session.execute("SELECT 1")
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
async def get_db_session():
|
||||
"""获取数据库会话的简化版本"""
|
||||
return AsyncSessionLocal()
|
||||
51
backend/app/core/dependencies.py
Normal file
51
backend/app/core/dependencies.py
Normal file
@ -0,0 +1,51 @@
|
||||
"""
|
||||
FastAPI依赖注入
|
||||
"""
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from .database import get_db
|
||||
from .config import settings, api_config
|
||||
|
||||
|
||||
async def get_database_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""获取数据库会话依赖"""
|
||||
async for session in get_db():
|
||||
yield session
|
||||
|
||||
|
||||
def verify_gemini_api_key():
|
||||
"""验证Gemini API密钥"""
|
||||
if not api_config.validate_gemini_config():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Gemini API未配置或配置无效"
|
||||
)
|
||||
return api_config.gemini_api_key
|
||||
|
||||
|
||||
def verify_tushare_token():
|
||||
"""验证Tushare Token"""
|
||||
if not api_config.validate_tushare_config():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Tushare API未配置或配置无效"
|
||||
)
|
||||
return api_config.tushare_token
|
||||
|
||||
|
||||
class DatabaseDependency:
|
||||
"""数据库依赖类"""
|
||||
|
||||
def __init__(self):
|
||||
self.session = Depends(get_database_session)
|
||||
|
||||
|
||||
class APIKeyDependency:
|
||||
"""API密钥依赖类"""
|
||||
|
||||
def __init__(self):
|
||||
self.gemini_key = Depends(verify_gemini_api_key)
|
||||
self.tushare_token = Depends(verify_tushare_token)
|
||||
98
backend/app/core/exceptions.py
Normal file
98
backend/app/core/exceptions.py
Normal file
@ -0,0 +1,98 @@
|
||||
"""
|
||||
基础错误处理和异常类
|
||||
定义系统中使用的所有自定义异常
|
||||
"""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
|
||||
class StockAnalysisError(Exception):
|
||||
"""基础异常类"""
|
||||
|
||||
def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
|
||||
self.message = message
|
||||
self.details = details or {}
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class DataSourceError(StockAnalysisError):
|
||||
"""数据源错误"""
|
||||
|
||||
def __init__(self, message: str, data_source: Optional[str] = None, details: Optional[Dict[str, Any]] = None):
|
||||
self.data_source = data_source
|
||||
super().__init__(message, details)
|
||||
|
||||
|
||||
class AIAnalysisError(StockAnalysisError):
|
||||
"""AI分析错误"""
|
||||
|
||||
def __init__(self, message: str, model: Optional[str] = None, details: Optional[Dict[str, Any]] = None):
|
||||
self.model = model
|
||||
super().__init__(message, details)
|
||||
|
||||
|
||||
class ConfigurationError(StockAnalysisError):
|
||||
"""配置错误"""
|
||||
|
||||
def __init__(self, message: str, config_key: Optional[str] = None, details: Optional[Dict[str, Any]] = None):
|
||||
self.config_key = config_key
|
||||
super().__init__(message, details)
|
||||
|
||||
|
||||
class DatabaseError(StockAnalysisError):
|
||||
"""数据库错误"""
|
||||
|
||||
def __init__(self, message: str, operation: Optional[str] = None, details: Optional[Dict[str, Any]] = None):
|
||||
self.operation = operation
|
||||
super().__init__(message, details)
|
||||
|
||||
|
||||
class ValidationError(StockAnalysisError):
|
||||
"""数据验证错误"""
|
||||
|
||||
def __init__(self, message: str, field: Optional[str] = None, details: Optional[Dict[str, Any]] = None):
|
||||
self.field = field
|
||||
super().__init__(message, details)
|
||||
|
||||
|
||||
class APIError(StockAnalysisError):
|
||||
"""API调用错误"""
|
||||
|
||||
def __init__(self, message: str, status_code: Optional[int] = None, api_name: Optional[str] = None, details: Optional[Dict[str, Any]] = None):
|
||||
self.status_code = status_code
|
||||
self.api_name = api_name
|
||||
super().__init__(message, details)
|
||||
|
||||
|
||||
class ReportGenerationError(StockAnalysisError):
|
||||
"""报告生成错误"""
|
||||
|
||||
def __init__(self, message: str, module_type: Optional[str] = None, details: Optional[Dict[str, Any]] = None):
|
||||
self.module_type = module_type
|
||||
super().__init__(message, details)
|
||||
|
||||
|
||||
class SymbolNotFoundError(DataSourceError):
|
||||
"""证券代码未找到错误"""
|
||||
|
||||
def __init__(self, symbol: str, market: str, data_source: Optional[str] = None):
|
||||
message = f"证券代码 {symbol} 在市场 {market} 中未找到"
|
||||
details = {"symbol": symbol, "market": market}
|
||||
super().__init__(message, data_source, details)
|
||||
|
||||
|
||||
class RateLimitError(APIError):
|
||||
"""API调用频率限制错误"""
|
||||
|
||||
def __init__(self, api_name: str, retry_after: Optional[int] = None):
|
||||
message = f"API {api_name} 调用频率超限"
|
||||
details = {"retry_after": retry_after} if retry_after else {}
|
||||
super().__init__(message, 429, api_name, details)
|
||||
|
||||
|
||||
class AuthenticationError(APIError):
|
||||
"""API认证错误"""
|
||||
|
||||
def __init__(self, api_name: str, details: Optional[Dict[str, Any]] = None):
|
||||
message = f"API {api_name} 认证失败"
|
||||
super().__init__(message, 401, api_name, details)
|
||||
17
backend/app/models/__init__.py
Normal file
17
backend/app/models/__init__.py
Normal file
@ -0,0 +1,17 @@
|
||||
"""
|
||||
数据模型包
|
||||
包含SQLAlchemy数据模型定义
|
||||
"""
|
||||
|
||||
# 导入所有模型以确保它们被注册到Base.metadata
|
||||
from .report import Report
|
||||
from .analysis_module import AnalysisModule
|
||||
from .progress_tracking import ProgressTracking
|
||||
from .system_config import SystemConfig
|
||||
|
||||
__all__ = [
|
||||
"Report",
|
||||
"AnalysisModule",
|
||||
"ProgressTracking",
|
||||
"SystemConfig"
|
||||
]
|
||||
41
backend/app/models/analysis_module.py
Normal file
41
backend/app/models/analysis_module.py
Normal file
@ -0,0 +1,41 @@
|
||||
"""
|
||||
分析模块数据模型
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, Integer, DateTime, Text, ForeignKey, Index
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
import uuid
|
||||
|
||||
from ..core.database import Base
|
||||
|
||||
|
||||
class AnalysisModule(Base):
|
||||
"""分析模块表模型"""
|
||||
|
||||
__tablename__ = "analysis_modules"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
report_id = Column(UUID(as_uuid=True), ForeignKey("reports.id", ondelete="CASCADE"), nullable=False)
|
||||
module_type = Column(String(50), nullable=False, comment="模块类型")
|
||||
module_order = Column(Integer, nullable=False, comment="模块顺序")
|
||||
title = Column(String(200), nullable=False, comment="模块标题")
|
||||
content = Column(JSONB, comment="模块内容")
|
||||
status = Column(String(20), nullable=False, default="pending", comment="模块状态")
|
||||
started_at = Column(DateTime(timezone=True), comment="开始时间")
|
||||
completed_at = Column(DateTime(timezone=True), comment="完成时间")
|
||||
error_message = Column(Text, comment="错误信息")
|
||||
|
||||
# 关系
|
||||
report = relationship("Report", back_populates="analysis_modules")
|
||||
|
||||
# 索引
|
||||
__table_args__ = (
|
||||
Index('idx_analysis_module_report_id', 'report_id'),
|
||||
Index('idx_analysis_module_type', 'module_type'),
|
||||
Index('idx_analysis_module_status', 'status'),
|
||||
Index('idx_analysis_module_order', 'module_order'),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<AnalysisModule(id={self.id}, type={self.module_type}, status={self.status})>"
|
||||
39
backend/app/models/progress_tracking.py
Normal file
39
backend/app/models/progress_tracking.py
Normal file
@ -0,0 +1,39 @@
|
||||
"""
|
||||
进度追踪数据模型
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, Integer, DateTime, Text, ForeignKey, Index
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
import uuid
|
||||
|
||||
from ..core.database import Base
|
||||
|
||||
|
||||
class ProgressTracking(Base):
|
||||
"""进度追踪表模型"""
|
||||
|
||||
__tablename__ = "progress_tracking"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
report_id = Column(UUID(as_uuid=True), ForeignKey("reports.id", ondelete="CASCADE"), nullable=False)
|
||||
step_name = Column(String(100), nullable=False, comment="步骤名称")
|
||||
step_order = Column(Integer, nullable=False, comment="步骤顺序")
|
||||
status = Column(String(20), nullable=False, default="pending", comment="步骤状态")
|
||||
started_at = Column(DateTime(timezone=True), comment="开始时间")
|
||||
completed_at = Column(DateTime(timezone=True), comment="完成时间")
|
||||
duration_ms = Column(Integer, comment="耗时(毫秒)")
|
||||
error_message = Column(Text, comment="错误信息")
|
||||
|
||||
# 关系
|
||||
report = relationship("Report", back_populates="progress_tracking")
|
||||
|
||||
# 索引
|
||||
__table_args__ = (
|
||||
Index('idx_progress_tracking_report_id', 'report_id'),
|
||||
Index('idx_progress_tracking_status', 'status'),
|
||||
Index('idx_progress_tracking_order', 'step_order'),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ProgressTracking(id={self.id}, step={self.step_name}, status={self.status})>"
|
||||
38
backend/app/models/report.py
Normal file
38
backend/app/models/report.py
Normal file
@ -0,0 +1,38 @@
|
||||
"""
|
||||
报告数据模型
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, DateTime, Text, Index
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship
|
||||
import uuid
|
||||
|
||||
from ..core.database import Base
|
||||
|
||||
|
||||
class Report(Base):
|
||||
"""报告表模型"""
|
||||
|
||||
__tablename__ = "reports"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
symbol = Column(String(20), nullable=False, comment="证券代码")
|
||||
market = Column(String(20), nullable=False, comment="交易市场")
|
||||
status = Column(String(20), nullable=False, default="generating", comment="报告状态")
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), comment="创建时间")
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
|
||||
# 关系
|
||||
analysis_modules = relationship("AnalysisModule", back_populates="report", cascade="all, delete-orphan")
|
||||
progress_tracking = relationship("ProgressTracking", back_populates="report", cascade="all, delete-orphan")
|
||||
|
||||
# 索引
|
||||
__table_args__ = (
|
||||
Index('idx_report_symbol_market', 'symbol', 'market'),
|
||||
Index('idx_report_status', 'status'),
|
||||
Index('idx_report_created_at', 'created_at'),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Report(id={self.id}, symbol={self.symbol}, market={self.market}, status={self.status})>"
|
||||
30
backend/app/models/system_config.py
Normal file
30
backend/app/models/system_config.py
Normal file
@ -0,0 +1,30 @@
|
||||
"""
|
||||
系统配置数据模型
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, DateTime, Index
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from sqlalchemy.sql import func
|
||||
import uuid
|
||||
|
||||
from ..core.database import Base
|
||||
|
||||
|
||||
class SystemConfig(Base):
|
||||
"""系统配置表模型"""
|
||||
|
||||
__tablename__ = "system_config"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
config_key = Column(String(100), unique=True, nullable=False, comment="配置键")
|
||||
config_value = Column(JSONB, nullable=False, comment="配置值")
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
|
||||
# 索引
|
||||
__table_args__ = (
|
||||
Index('idx_system_config_key', 'config_key'),
|
||||
Index('idx_system_config_updated_at', 'updated_at'),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<SystemConfig(key={self.config_key})>"
|
||||
8
backend/app/routers/__init__.py
Normal file
8
backend/app/routers/__init__.py
Normal file
@ -0,0 +1,8 @@
|
||||
"""
|
||||
API路由包
|
||||
包含所有API端点定义
|
||||
"""
|
||||
|
||||
from . import reports, config, progress
|
||||
|
||||
__all__ = ["reports", "config", "progress"]
|
||||
124
backend/app/routers/config.py
Normal file
124
backend/app/routers/config.py
Normal file
@ -0,0 +1,124 @@
|
||||
"""
|
||||
配置相关API路由
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
import logging
|
||||
|
||||
from ..core.dependencies import get_database_session
|
||||
from ..schemas.config import ConfigResponse, ConfigUpdateRequest, ConfigTestRequest, ConfigTestResponse
|
||||
from ..services.config_manager import ConfigManager
|
||||
from ..core.exceptions import ConfigurationError, DatabaseError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/", response_model=ConfigResponse)
|
||||
async def get_config(
|
||||
db: AsyncSession = Depends(get_database_session)
|
||||
):
|
||||
"""获取系统配置"""
|
||||
|
||||
try:
|
||||
config_manager = ConfigManager(db)
|
||||
config = await config_manager.get_config()
|
||||
logger.info("获取系统配置成功")
|
||||
return config
|
||||
|
||||
except DatabaseError as e:
|
||||
logger.error(f"获取配置时数据库错误: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"数据库错误: {str(e)}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取配置失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"获取配置失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.put("/", response_model=ConfigResponse)
|
||||
async def update_config(
|
||||
config_update: ConfigUpdateRequest,
|
||||
db: AsyncSession = Depends(get_database_session)
|
||||
):
|
||||
"""更新系统配置"""
|
||||
|
||||
try:
|
||||
# 验证至少有一个配置项需要更新
|
||||
if not any([config_update.database, config_update.gemini_api, config_update.data_sources]):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="至少需要提供一个配置项进行更新"
|
||||
)
|
||||
|
||||
config_manager = ConfigManager(db)
|
||||
updated_config = await config_manager.update_config(config_update)
|
||||
logger.info("更新系统配置成功")
|
||||
return updated_config
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except ConfigurationError as e:
|
||||
logger.error(f"配置错误: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"配置错误: {str(e)}"
|
||||
)
|
||||
except DatabaseError as e:
|
||||
logger.error(f"更新配置时数据库错误: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"数据库错误: {str(e)}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"更新配置失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"更新配置失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/test", response_model=ConfigTestResponse)
|
||||
async def test_config(
|
||||
test_request: ConfigTestRequest,
|
||||
db: AsyncSession = Depends(get_database_session)
|
||||
):
|
||||
"""测试配置连接"""
|
||||
|
||||
try:
|
||||
# 验证配置类型
|
||||
valid_types = ["database", "gemini", "data_source"]
|
||||
if test_request.config_type not in valid_types:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"不支持的配置类型: {test_request.config_type},支持的类型: {valid_types}"
|
||||
)
|
||||
|
||||
config_manager = ConfigManager(db)
|
||||
test_result = await config_manager.test_config(
|
||||
test_request.config_type,
|
||||
test_request.config_data
|
||||
)
|
||||
|
||||
logger.info(f"配置测试完成: {test_request.config_type}, 结果: {test_result.success}")
|
||||
return test_result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except ConfigurationError as e:
|
||||
logger.error(f"配置测试错误: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"配置测试错误: {str(e)}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"配置测试失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"配置测试失败: {str(e)}"
|
||||
)
|
||||
82
backend/app/routers/progress.py
Normal file
82
backend/app/routers/progress.py
Normal file
@ -0,0 +1,82 @@
|
||||
"""
|
||||
进度追踪API路由
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from uuid import UUID
|
||||
import logging
|
||||
|
||||
from ..core.dependencies import get_database_session
|
||||
from ..schemas.progress import ProgressResponse
|
||||
from ..services.progress_tracker import ProgressTracker
|
||||
from ..core.exceptions import DatabaseError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/{report_id}", response_model=ProgressResponse)
|
||||
async def get_report_progress(
|
||||
report_id: UUID,
|
||||
db: AsyncSession = Depends(get_database_session)
|
||||
):
|
||||
"""获取报告生成进度"""
|
||||
|
||||
try:
|
||||
progress_tracker = ProgressTracker(db)
|
||||
progress = await progress_tracker.get_progress(report_id)
|
||||
logger.info(f"获取进度成功: {report_id}, 当前步骤: {progress.current_step}/{progress.total_steps}")
|
||||
return progress
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"报告不存在或无进度记录: {report_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e)
|
||||
)
|
||||
except DatabaseError as e:
|
||||
logger.error(f"获取进度时数据库错误: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"数据库错误: {str(e)}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取进度失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"获取进度失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{report_id}/reset")
|
||||
async def reset_report_progress(
|
||||
report_id: UUID,
|
||||
db: AsyncSession = Depends(get_database_session)
|
||||
):
|
||||
"""重置报告生成进度"""
|
||||
|
||||
try:
|
||||
progress_tracker = ProgressTracker(db)
|
||||
await progress_tracker.reset_progress(report_id)
|
||||
logger.info(f"重置进度成功: {report_id}")
|
||||
return {"message": "进度重置成功", "report_id": str(report_id)}
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"报告不存在: {report_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e)
|
||||
)
|
||||
except DatabaseError as e:
|
||||
logger.error(f"重置进度时数据库错误: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"数据库错误: {str(e)}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"重置进度失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"重置进度失败: {str(e)}"
|
||||
)
|
||||
298
backend/app/routers/reports.py
Normal file
298
backend/app/routers/reports.py
Normal file
@ -0,0 +1,298 @@
|
||||
"""
|
||||
报告相关API路由
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query, BackgroundTasks
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from typing import Optional, List
|
||||
from uuid import UUID
|
||||
import logging
|
||||
|
||||
from ..core.dependencies import get_database_session
|
||||
from ..models.report import Report
|
||||
from ..schemas.report import ReportResponse, RegenerateRequest
|
||||
from ..services.report_generator import ReportGenerator
|
||||
from ..services.config_manager import ConfigManager
|
||||
from ..core.exceptions import ReportGenerationError, DatabaseError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/{symbol}", response_model=ReportResponse)
|
||||
async def get_or_create_report(
|
||||
symbol: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
market: str = Query(..., description="交易市场"),
|
||||
db: AsyncSession = Depends(get_database_session)
|
||||
):
|
||||
"""获取或创建股票报告"""
|
||||
|
||||
try:
|
||||
# 验证输入参数
|
||||
symbol = symbol.upper().strip()
|
||||
market = market.lower().strip()
|
||||
|
||||
if not symbol:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="证券代码不能为空"
|
||||
)
|
||||
|
||||
if market not in ["china", "hongkong", "usa", "japan"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="不支持的交易市场"
|
||||
)
|
||||
|
||||
# 查询现有报告(包含关联的分析模块)
|
||||
result = await db.execute(
|
||||
select(Report)
|
||||
.options(selectinload(Report.analysis_modules))
|
||||
.where(
|
||||
Report.symbol == symbol,
|
||||
Report.market == market
|
||||
)
|
||||
)
|
||||
existing_report = result.scalar_one_or_none()
|
||||
|
||||
if existing_report:
|
||||
logger.info(f"找到现有报告: {symbol}-{market}, 状态: {existing_report.status}")
|
||||
return ReportResponse.from_attributes(existing_report)
|
||||
|
||||
# 创建新报告
|
||||
logger.info(f"开始生成新报告: {symbol}-{market}")
|
||||
config_manager = ConfigManager(db)
|
||||
report_generator = ReportGenerator(db, config_manager)
|
||||
|
||||
# 在后台任务中生成报告
|
||||
background_tasks.add_task(
|
||||
report_generator.generate_report_async,
|
||||
symbol,
|
||||
market
|
||||
)
|
||||
|
||||
# 创建初始报告记录
|
||||
new_report = Report(
|
||||
symbol=symbol,
|
||||
market=market,
|
||||
status="generating"
|
||||
)
|
||||
db.add(new_report)
|
||||
await db.commit()
|
||||
await db.refresh(new_report)
|
||||
|
||||
logger.info(f"创建报告记录: {new_report.id}")
|
||||
return ReportResponse.from_attributes(new_report)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取或创建报告失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"获取或创建报告失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{symbol}/regenerate", response_model=ReportResponse)
|
||||
async def regenerate_report(
|
||||
symbol: str,
|
||||
request: RegenerateRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
market: str = Query(..., description="交易市场"),
|
||||
db: AsyncSession = Depends(get_database_session)
|
||||
):
|
||||
"""重新生成报告"""
|
||||
|
||||
try:
|
||||
# 验证输入参数
|
||||
symbol = symbol.upper().strip()
|
||||
market = market.lower().strip()
|
||||
|
||||
if not symbol:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="证券代码不能为空"
|
||||
)
|
||||
|
||||
if market not in ["china", "hongkong", "usa", "japan"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="不支持的交易市场"
|
||||
)
|
||||
|
||||
# 查询现有报告
|
||||
result = await db.execute(
|
||||
select(Report)
|
||||
.options(selectinload(Report.analysis_modules))
|
||||
.where(
|
||||
Report.symbol == symbol,
|
||||
Report.market == market
|
||||
)
|
||||
)
|
||||
existing_report = result.scalar_one_or_none()
|
||||
|
||||
if existing_report and not request.force:
|
||||
# 如果报告存在且不强制重新生成,返回现有报告
|
||||
logger.info(f"返回现有报告: {symbol}-{market}")
|
||||
return ReportResponse.from_attributes(existing_report)
|
||||
|
||||
# 删除现有报告(如果存在)
|
||||
if existing_report:
|
||||
logger.info(f"删除现有报告: {existing_report.id}")
|
||||
await db.delete(existing_report)
|
||||
await db.commit()
|
||||
|
||||
# 创建新报告记录
|
||||
new_report = Report(
|
||||
symbol=symbol,
|
||||
market=market,
|
||||
status="generating"
|
||||
)
|
||||
db.add(new_report)
|
||||
await db.commit()
|
||||
await db.refresh(new_report)
|
||||
|
||||
# 在后台任务中生成报告
|
||||
config_manager = ConfigManager(db)
|
||||
report_generator = ReportGenerator(db, config_manager)
|
||||
background_tasks.add_task(
|
||||
report_generator.generate_report_async,
|
||||
symbol,
|
||||
market
|
||||
)
|
||||
|
||||
logger.info(f"开始重新生成报告: {new_report.id}")
|
||||
return ReportResponse.from_attributes(new_report)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"重新生成报告失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"重新生成报告失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{report_id}/details", response_model=ReportResponse)
|
||||
async def get_report_details(
|
||||
report_id: UUID,
|
||||
db: AsyncSession = Depends(get_database_session)
|
||||
):
|
||||
"""获取报告详情"""
|
||||
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(Report)
|
||||
.options(selectinload(Report.analysis_modules))
|
||||
.where(Report.id == report_id)
|
||||
)
|
||||
report = result.scalar_one_or_none()
|
||||
|
||||
if not report:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="报告不存在"
|
||||
)
|
||||
|
||||
logger.info(f"获取报告详情: {report_id}")
|
||||
return ReportResponse.from_attributes(report)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取报告详情失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"获取报告详情失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/", response_model=List[ReportResponse])
|
||||
async def list_reports(
|
||||
skip: int = Query(0, ge=0, description="跳过的记录数"),
|
||||
limit: int = Query(10, ge=1, le=100, description="返回的记录数"),
|
||||
market: Optional[str] = Query(None, description="按市场筛选"),
|
||||
status: Optional[str] = Query(None, description="按状态筛选"),
|
||||
db: AsyncSession = Depends(get_database_session)
|
||||
):
|
||||
"""获取报告列表"""
|
||||
|
||||
try:
|
||||
query = select(Report).options(selectinload(Report.analysis_modules))
|
||||
|
||||
# 添加筛选条件
|
||||
if market:
|
||||
market = market.lower().strip()
|
||||
if market not in ["china", "hongkong", "usa", "japan"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="不支持的交易市场"
|
||||
)
|
||||
query = query.where(Report.market == market)
|
||||
|
||||
if status:
|
||||
status_value = status.lower().strip()
|
||||
if status_value not in ["generating", "completed", "failed"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="不支持的状态值"
|
||||
)
|
||||
query = query.where(Report.status == status_value)
|
||||
|
||||
# 添加分页和排序
|
||||
query = query.order_by(Report.created_at.desc()).offset(skip).limit(limit)
|
||||
|
||||
result = await db.execute(query)
|
||||
reports = result.scalars().all()
|
||||
|
||||
logger.info(f"获取报告列表: {len(reports)} 条记录")
|
||||
return [ReportResponse.from_attributes(report) for report in reports]
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取报告列表失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"获取报告列表失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{report_id}")
|
||||
async def delete_report(
|
||||
report_id: UUID,
|
||||
db: AsyncSession = Depends(get_database_session)
|
||||
):
|
||||
"""删除报告"""
|
||||
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(Report).where(Report.id == report_id)
|
||||
)
|
||||
report = result.scalar_one_or_none()
|
||||
|
||||
if not report:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="报告不存在"
|
||||
)
|
||||
|
||||
await db.delete(report)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"删除报告成功: {report_id}")
|
||||
return {"message": "报告删除成功", "report_id": str(report_id)}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"删除报告失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"删除报告失败: {str(e)}"
|
||||
)
|
||||
48
backend/app/schemas/__init__.py
Normal file
48
backend/app/schemas/__init__.py
Normal file
@ -0,0 +1,48 @@
|
||||
"""
|
||||
Pydantic数据验证模式包
|
||||
"""
|
||||
|
||||
from .report import ReportResponse, ReportCreate, ReportUpdate
|
||||
from .config import (
|
||||
ConfigResponse,
|
||||
ConfigUpdateRequest,
|
||||
ConfigTestRequest,
|
||||
ConfigTestResponse,
|
||||
DatabaseConfig,
|
||||
GeminiConfig,
|
||||
DataSourceConfig
|
||||
)
|
||||
from .progress import ProgressResponse, StepTiming
|
||||
from .data import (
|
||||
FinancialDataRequest,
|
||||
MarketDataRequest,
|
||||
FinancialDataResponse,
|
||||
MarketDataResponse,
|
||||
SymbolValidationRequest,
|
||||
SymbolValidationResponse,
|
||||
DataSourceStatus,
|
||||
DataSourcesStatusResponse
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ReportResponse",
|
||||
"ReportCreate",
|
||||
"ReportUpdate",
|
||||
"ConfigResponse",
|
||||
"ConfigUpdateRequest",
|
||||
"ConfigTestRequest",
|
||||
"ConfigTestResponse",
|
||||
"DatabaseConfig",
|
||||
"GeminiConfig",
|
||||
"DataSourceConfig",
|
||||
"ProgressResponse",
|
||||
"StepTiming",
|
||||
"FinancialDataRequest",
|
||||
"MarketDataRequest",
|
||||
"FinancialDataResponse",
|
||||
"MarketDataResponse",
|
||||
"SymbolValidationRequest",
|
||||
"SymbolValidationResponse",
|
||||
"DataSourceStatus",
|
||||
"DataSourcesStatusResponse"
|
||||
]
|
||||
62
backend/app/schemas/config.py
Normal file
62
backend/app/schemas/config.py
Normal file
@ -0,0 +1,62 @@
|
||||
"""
|
||||
配置相关的Pydantic模式
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, Any, Optional, List
|
||||
|
||||
|
||||
class DatabaseConfig(BaseModel):
|
||||
"""数据库配置模式"""
|
||||
url: str = Field(..., description="数据库连接URL")
|
||||
echo: bool = Field(False, description="是否输出SQL日志")
|
||||
|
||||
|
||||
class GeminiConfig(BaseModel):
|
||||
"""Gemini API配置模式"""
|
||||
api_key: str = Field(..., description="Gemini API密钥")
|
||||
model: str = Field("gemini-pro", description="使用的模型")
|
||||
temperature: float = Field(0.7, description="生成温度")
|
||||
max_tokens: int = Field(2048, description="最大token数")
|
||||
|
||||
|
||||
class DataSourceConfig(BaseModel):
|
||||
"""数据源配置模式"""
|
||||
name: str = Field(..., description="数据源名称")
|
||||
api_key: Optional[str] = Field(None, description="API密钥")
|
||||
base_url: Optional[str] = Field(None, description="基础URL")
|
||||
timeout: int = Field(30, description="超时时间(秒)")
|
||||
|
||||
|
||||
class ConfigResponse(BaseModel):
|
||||
"""配置响应模式"""
|
||||
database: Optional[DatabaseConfig] = None
|
||||
gemini_api: Optional[GeminiConfig] = None
|
||||
data_sources: Dict[str, DataSourceConfig] = {}
|
||||
|
||||
|
||||
class ConfigUpdateRequest(BaseModel):
|
||||
"""配置更新请求模式"""
|
||||
database: Optional[DatabaseConfig] = None
|
||||
gemini_api: Optional[GeminiConfig] = None
|
||||
data_sources: Optional[Dict[str, DataSourceConfig]] = None
|
||||
|
||||
|
||||
class ConfigTestRequest(BaseModel):
|
||||
"""配置测试请求模式"""
|
||||
config_type: str = Field(..., description="配置类型")
|
||||
config_data: Dict[str, Any] = Field(..., description="配置数据")
|
||||
|
||||
|
||||
class ConfigTestResponse(BaseModel):
|
||||
"""配置测试响应模式"""
|
||||
success: bool = Field(..., description="测试是否成功")
|
||||
message: str = Field(..., description="测试结果消息")
|
||||
details: Optional[Dict[str, Any]] = Field(None, description="详细信息")
|
||||
|
||||
|
||||
class ConfigValidationResponse(BaseModel):
|
||||
"""配置验证响应模式"""
|
||||
valid: bool = Field(..., description="配置是否有效")
|
||||
errors: List[str] = Field([], description="验证错误列表")
|
||||
warnings: List[str] = Field([], description="验证警告列表")
|
||||
78
backend/app/schemas/data.py
Normal file
78
backend/app/schemas/data.py
Normal file
@ -0,0 +1,78 @@
|
||||
"""
|
||||
数据相关的Pydantic模式
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class FinancialDataRequest(BaseModel):
|
||||
"""财务数据请求模式"""
|
||||
symbol: str = Field(..., description="证券代码")
|
||||
market: str = Field(..., description="交易市场")
|
||||
data_type: str = Field("all", description="数据类型")
|
||||
period: Optional[str] = Field("annual", description="数据周期")
|
||||
|
||||
|
||||
class MarketDataRequest(BaseModel):
|
||||
"""市场数据请求模式"""
|
||||
symbol: str = Field(..., description="证券代码")
|
||||
market: str = Field(..., description="交易市场")
|
||||
start_date: Optional[datetime] = Field(None, description="开始日期")
|
||||
end_date: Optional[datetime] = Field(None, description="结束日期")
|
||||
|
||||
|
||||
class FinancialDataResponse(BaseModel):
|
||||
"""财务数据响应模式"""
|
||||
symbol: str
|
||||
market: str
|
||||
data_source: str
|
||||
last_updated: datetime
|
||||
balance_sheet: Optional[Dict[str, Any]] = None
|
||||
income_statement: Optional[Dict[str, Any]] = None
|
||||
cash_flow: Optional[Dict[str, Any]] = None
|
||||
key_metrics: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class MarketDataResponse(BaseModel):
|
||||
"""市场数据响应模式"""
|
||||
symbol: str
|
||||
market: str
|
||||
data_source: str
|
||||
last_updated: datetime
|
||||
price_data: Optional[Dict[str, Any]] = None
|
||||
volume_data: Optional[Dict[str, Any]] = None
|
||||
technical_indicators: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class SymbolValidationRequest(BaseModel):
|
||||
"""证券代码验证请求模式"""
|
||||
symbol: str = Field(..., description="证券代码")
|
||||
market: str = Field(..., description="交易市场")
|
||||
|
||||
|
||||
class SymbolValidationResponse(BaseModel):
|
||||
"""证券代码验证响应模式"""
|
||||
symbol: str
|
||||
market: str
|
||||
is_valid: bool
|
||||
company_name: Optional[str] = None
|
||||
sector: Optional[str] = None
|
||||
industry: Optional[str] = None
|
||||
message: Optional[str] = None
|
||||
|
||||
|
||||
class DataSourceStatus(BaseModel):
|
||||
"""数据源状态模式"""
|
||||
name: str
|
||||
is_available: bool
|
||||
last_check: datetime
|
||||
response_time_ms: Optional[int] = None
|
||||
error_message: Optional[str] = None
|
||||
|
||||
|
||||
class DataSourcesStatusResponse(BaseModel):
|
||||
"""数据源状态响应模式"""
|
||||
sources: List[DataSourceStatus]
|
||||
overall_status: str # "healthy", "degraded", "down"
|
||||
42
backend/app/schemas/progress.py
Normal file
42
backend/app/schemas/progress.py
Normal file
@ -0,0 +1,42 @@
|
||||
"""
|
||||
进度追踪相关的Pydantic模式
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
class StepTiming(BaseModel):
|
||||
"""步骤计时模式"""
|
||||
step_name: str = Field(..., description="步骤名称")
|
||||
step_order: int = Field(..., description="步骤顺序")
|
||||
status: str = Field(..., description="步骤状态")
|
||||
started_at: Optional[datetime] = Field(None, description="开始时间")
|
||||
completed_at: Optional[datetime] = Field(None, description="完成时间")
|
||||
duration_ms: Optional[int] = Field(None, description="耗时(毫秒)")
|
||||
error_message: Optional[str] = Field(None, description="错误信息")
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ProgressResponse(BaseModel):
|
||||
"""进度响应模式"""
|
||||
report_id: UUID = Field(..., description="报告ID")
|
||||
current_step: int = Field(..., description="当前步骤")
|
||||
total_steps: int = Field(..., description="总步骤数")
|
||||
current_step_name: str = Field(..., description="当前步骤名称")
|
||||
status: str = Field(..., description="整体状态")
|
||||
step_timings: List[StepTiming] = Field([], description="步骤计时列表")
|
||||
estimated_remaining: Optional[int] = Field(None, description="预估剩余时间(秒)")
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ProgressResetResponse(BaseModel):
|
||||
"""进度重置响应模式"""
|
||||
message: str = Field(..., description="操作结果消息")
|
||||
report_id: str = Field(..., description="报告ID")
|
||||
91
backend/app/schemas/report.py
Normal file
91
backend/app/schemas/report.py
Normal file
@ -0,0 +1,91 @@
|
||||
"""
|
||||
报告相关的Pydantic模式
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
class AnalysisModuleSchema(BaseModel):
|
||||
"""分析模块模式"""
|
||||
id: UUID
|
||||
module_type: str
|
||||
module_order: int
|
||||
title: str
|
||||
content: Optional[Dict[str, Any]] = None
|
||||
status: str
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
error_message: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ReportBase(BaseModel):
|
||||
"""报告基础模式"""
|
||||
symbol: str = Field(..., description="证券代码")
|
||||
market: str = Field(..., description="交易市场")
|
||||
|
||||
|
||||
class ReportCreate(ReportBase):
|
||||
"""创建报告请求模式"""
|
||||
pass
|
||||
|
||||
|
||||
class ReportUpdate(BaseModel):
|
||||
"""更新报告请求模式"""
|
||||
status: Optional[str] = None
|
||||
|
||||
|
||||
class ReportResponse(ReportBase):
|
||||
"""报告响应模式"""
|
||||
id: UUID
|
||||
status: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
analysis_modules: List[AnalysisModuleSchema] = []
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ReportListResponse(BaseModel):
|
||||
"""报告列表响应模式"""
|
||||
reports: List[ReportResponse]
|
||||
total: int
|
||||
skip: int
|
||||
limit: int
|
||||
|
||||
|
||||
class DeleteReportResponse(BaseModel):
|
||||
"""删除报告响应模式"""
|
||||
message: str
|
||||
report_id: str
|
||||
|
||||
|
||||
class RegenerateRequest(BaseModel):
|
||||
"""重新生成报告请求模式"""
|
||||
force: bool = Field(False, description="是否强制重新生成")
|
||||
|
||||
|
||||
class AIAnalysisRequest(BaseModel):
|
||||
"""AI分析请求模式"""
|
||||
symbol: str = Field(..., description="证券代码")
|
||||
market: str = Field(..., description="交易市场")
|
||||
analysis_type: str = Field(..., description="分析类型")
|
||||
context_data: Optional[Dict[str, Any]] = Field(None, description="上下文数据")
|
||||
|
||||
|
||||
class AIAnalysisResponse(BaseModel):
|
||||
"""AI分析响应模式"""
|
||||
symbol: str
|
||||
market: str
|
||||
analysis_type: str
|
||||
content: Dict[str, Any]
|
||||
model_used: str
|
||||
generated_at: datetime
|
||||
|
||||
model_config = {"protected_namespaces": ()}
|
||||
16
backend/app/services/__init__.py
Normal file
16
backend/app/services/__init__.py
Normal file
@ -0,0 +1,16 @@
|
||||
"""
|
||||
业务服务包
|
||||
包含核心业务逻辑实现
|
||||
"""
|
||||
|
||||
from .config_manager import ConfigManager
|
||||
from .data_fetcher import DataFetcher
|
||||
from .report_generator import ReportGenerator
|
||||
from .progress_tracker import ProgressTracker
|
||||
|
||||
__all__ = [
|
||||
"ConfigManager",
|
||||
"DataFetcher",
|
||||
"ReportGenerator",
|
||||
"ProgressTracker"
|
||||
]
|
||||
803
backend/app/services/ai_analyzer.py
Normal file
803
backend/app/services/ai_analyzer.py
Normal file
@ -0,0 +1,803 @@
|
||||
"""
|
||||
AI分析服务
|
||||
处理Gemini API集成和AI分析功能
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List
|
||||
import httpx
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from ..core.exceptions import (
|
||||
AIAnalysisError,
|
||||
APIError,
|
||||
AuthenticationError,
|
||||
RateLimitError
|
||||
)
|
||||
from ..schemas.report import AIAnalysisRequest, AIAnalysisResponse
|
||||
|
||||
|
||||
class GeminiAnalyzer:
|
||||
"""Gemini AI分析器"""
|
||||
|
||||
def __init__(self, api_key: str, config: Optional[Dict[str, Any]] = None):
|
||||
if not api_key:
|
||||
raise AuthenticationError("gemini", {"message": "Gemini API密钥未配置"})
|
||||
|
||||
self.api_key = api_key
|
||||
self.config = config or {}
|
||||
self.base_url = self.config.get("base_url", "https://generativelanguage.googleapis.com/v1beta")
|
||||
self.model = self.config.get("model", "gemini-pro")
|
||||
self.timeout = self.config.get("timeout", 60)
|
||||
self.max_retries = self.config.get("max_retries", 3)
|
||||
self.retry_delay = self.config.get("retry_delay", 2)
|
||||
|
||||
# 生成配置
|
||||
self.generation_config = {
|
||||
"temperature": self.config.get("temperature", 0.7),
|
||||
"top_p": self.config.get("top_p", 0.8),
|
||||
"top_k": self.config.get("top_k", 40),
|
||||
"max_output_tokens": self.config.get("max_output_tokens", 8192),
|
||||
}
|
||||
|
||||
async def analyze_business_info(self, symbol: str, market: str, financial_data: Dict[str, Any]) -> AIAnalysisResponse:
|
||||
"""分析公司业务信息"""
|
||||
prompt = self._build_business_info_prompt(symbol, market, financial_data)
|
||||
|
||||
try:
|
||||
result = await self._retry_request(self._call_gemini_api, prompt)
|
||||
|
||||
return AIAnalysisResponse(
|
||||
symbol=symbol,
|
||||
market=market,
|
||||
analysis_type="business_info",
|
||||
content=self._parse_business_info_response(result),
|
||||
model_used=self.model,
|
||||
generated_at=datetime.now()
|
||||
)
|
||||
except Exception as e:
|
||||
if isinstance(e, (AIAnalysisError, APIError)):
|
||||
raise
|
||||
raise AIAnalysisError(f"业务信息分析失败: {str(e)}", self.model)
|
||||
|
||||
async def analyze_fundamental(self, symbol: str, market: str, financial_data: Dict[str, Any], business_info: Dict[str, Any]) -> AIAnalysisResponse:
|
||||
"""基本面分析(景林模型)"""
|
||||
prompt = self._build_fundamental_analysis_prompt(symbol, market, financial_data, business_info)
|
||||
|
||||
try:
|
||||
result = await self._retry_request(self._call_gemini_api, prompt)
|
||||
|
||||
return AIAnalysisResponse(
|
||||
symbol=symbol,
|
||||
market=market,
|
||||
analysis_type="fundamental_analysis",
|
||||
content=self._parse_fundamental_response(result),
|
||||
model_used=self.model,
|
||||
generated_at=datetime.now()
|
||||
)
|
||||
except Exception as e:
|
||||
if isinstance(e, (AIAnalysisError, APIError)):
|
||||
raise
|
||||
raise AIAnalysisError(f"基本面分析失败: {str(e)}", self.model)
|
||||
|
||||
async def analyze_bullish_case(self, symbol: str, market: str, context_data: Dict[str, Any]) -> AIAnalysisResponse:
|
||||
"""看涨分析(隐藏资产、护城河分析)"""
|
||||
prompt = self._build_bullish_analysis_prompt(symbol, market, context_data)
|
||||
|
||||
try:
|
||||
result = await self._retry_request(self._call_gemini_api, prompt)
|
||||
|
||||
return AIAnalysisResponse(
|
||||
symbol=symbol,
|
||||
market=market,
|
||||
analysis_type="bullish_analysis",
|
||||
content=self._parse_bullish_response(result),
|
||||
model_used=self.model,
|
||||
generated_at=datetime.now()
|
||||
)
|
||||
except Exception as e:
|
||||
if isinstance(e, (AIAnalysisError, APIError)):
|
||||
raise
|
||||
raise AIAnalysisError(f"看涨分析失败: {str(e)}", self.model)
|
||||
|
||||
async def analyze_bearish_case(self, symbol: str, market: str, context_data: Dict[str, Any]) -> AIAnalysisResponse:
|
||||
"""看跌分析(价值底线、最坏情况分析)"""
|
||||
prompt = self._build_bearish_analysis_prompt(symbol, market, context_data)
|
||||
|
||||
try:
|
||||
result = await self._retry_request(self._call_gemini_api, prompt)
|
||||
|
||||
return AIAnalysisResponse(
|
||||
symbol=symbol,
|
||||
market=market,
|
||||
analysis_type="bearish_analysis",
|
||||
content=self._parse_bearish_response(result),
|
||||
model_used=self.model,
|
||||
generated_at=datetime.now()
|
||||
)
|
||||
except Exception as e:
|
||||
if isinstance(e, (AIAnalysisError, APIError)):
|
||||
raise
|
||||
raise AIAnalysisError(f"看跌分析失败: {str(e)}", self.model)
|
||||
|
||||
async def analyze_market_sentiment(self, symbol: str, market: str, context_data: Dict[str, Any]) -> AIAnalysisResponse:
|
||||
"""市场情绪分析"""
|
||||
prompt = self._build_market_analysis_prompt(symbol, market, context_data)
|
||||
|
||||
try:
|
||||
result = await self._retry_request(self._call_gemini_api, prompt)
|
||||
|
||||
return AIAnalysisResponse(
|
||||
symbol=symbol,
|
||||
market=market,
|
||||
analysis_type="market_analysis",
|
||||
content=self._parse_market_response(result),
|
||||
model_used=self.model,
|
||||
generated_at=datetime.now()
|
||||
)
|
||||
except Exception as e:
|
||||
if isinstance(e, (AIAnalysisError, APIError)):
|
||||
raise
|
||||
raise AIAnalysisError(f"市场分析失败: {str(e)}", self.model)
|
||||
|
||||
async def analyze_news_catalysts(self, symbol: str, market: str, context_data: Dict[str, Any]) -> AIAnalysisResponse:
|
||||
"""新闻催化剂分析"""
|
||||
prompt = self._build_news_analysis_prompt(symbol, market, context_data)
|
||||
|
||||
try:
|
||||
result = await self._retry_request(self._call_gemini_api, prompt)
|
||||
|
||||
return AIAnalysisResponse(
|
||||
symbol=symbol,
|
||||
market=market,
|
||||
analysis_type="news_analysis",
|
||||
content=self._parse_news_response(result),
|
||||
model_used=self.model,
|
||||
generated_at=datetime.now()
|
||||
)
|
||||
except Exception as e:
|
||||
if isinstance(e, (AIAnalysisError, APIError)):
|
||||
raise
|
||||
raise AIAnalysisError(f"新闻分析失败: {str(e)}", self.model)
|
||||
|
||||
async def analyze_trading_dynamics(self, symbol: str, market: str, context_data: Dict[str, Any]) -> AIAnalysisResponse:
|
||||
"""交易动态分析"""
|
||||
prompt = self._build_trading_analysis_prompt(symbol, market, context_data)
|
||||
|
||||
try:
|
||||
result = await self._retry_request(self._call_gemini_api, prompt)
|
||||
|
||||
return AIAnalysisResponse(
|
||||
symbol=symbol,
|
||||
market=market,
|
||||
analysis_type="trading_analysis",
|
||||
content=self._parse_trading_response(result),
|
||||
model_used=self.model,
|
||||
generated_at=datetime.now()
|
||||
)
|
||||
except Exception as e:
|
||||
if isinstance(e, (AIAnalysisError, APIError)):
|
||||
raise
|
||||
raise AIAnalysisError(f"交易分析失败: {str(e)}", self.model)
|
||||
|
||||
async def analyze_insider_institutional(self, symbol: str, market: str, context_data: Dict[str, Any]) -> AIAnalysisResponse:
|
||||
"""内部人与机构动向分析"""
|
||||
prompt = self._build_insider_analysis_prompt(symbol, market, context_data)
|
||||
|
||||
try:
|
||||
result = await self._retry_request(self._call_gemini_api, prompt)
|
||||
|
||||
return AIAnalysisResponse(
|
||||
symbol=symbol,
|
||||
market=market,
|
||||
analysis_type="insider_analysis",
|
||||
content=self._parse_insider_response(result),
|
||||
model_used=self.model,
|
||||
generated_at=datetime.now()
|
||||
)
|
||||
except Exception as e:
|
||||
if isinstance(e, (AIAnalysisError, APIError)):
|
||||
raise
|
||||
raise AIAnalysisError(f"内部人分析失败: {str(e)}", self.model)
|
||||
|
||||
async def generate_final_conclusion(self, symbol: str, market: str, all_analyses: List[Dict[str, Any]]) -> AIAnalysisResponse:
|
||||
"""生成最终结论"""
|
||||
prompt = self._build_conclusion_prompt(symbol, market, all_analyses)
|
||||
|
||||
try:
|
||||
result = await self._retry_request(self._call_gemini_api, prompt)
|
||||
|
||||
return AIAnalysisResponse(
|
||||
symbol=symbol,
|
||||
market=market,
|
||||
analysis_type="final_conclusion",
|
||||
content=self._parse_conclusion_response(result),
|
||||
model_used=self.model,
|
||||
generated_at=datetime.now()
|
||||
)
|
||||
except Exception as e:
|
||||
if isinstance(e, (AIAnalysisError, APIError)):
|
||||
raise
|
||||
raise AIAnalysisError(f"最终结论生成失败: {str(e)}", self.model)
|
||||
|
||||
async def _call_gemini_api(self, prompt: str) -> str:
|
||||
"""调用Gemini API"""
|
||||
url = f"{self.base_url}/models/{self.model}:generateContent"
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
data = {
|
||||
"contents": [
|
||||
{
|
||||
"parts": [
|
||||
{
|
||||
"text": prompt
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"generationConfig": self.generation_config
|
||||
}
|
||||
|
||||
params = {
|
||||
"key": self.api_key
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.post(url, json=data, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
|
||||
if "error" in result:
|
||||
error = result["error"]
|
||||
error_code = error.get("code", 0)
|
||||
error_message = error.get("message", "Unknown error")
|
||||
|
||||
if error_code == 401 or "API key" in error_message:
|
||||
raise AuthenticationError("gemini", {"message": error_message})
|
||||
elif error_code == 429 or "quota" in error_message.lower():
|
||||
raise RateLimitError("gemini")
|
||||
else:
|
||||
raise APIError(f"Gemini API错误: {error_message}", error_code, "gemini")
|
||||
|
||||
candidates = result.get("candidates", [])
|
||||
if not candidates:
|
||||
raise AIAnalysisError("Gemini API返回空结果", self.model)
|
||||
|
||||
content = candidates[0].get("content", {})
|
||||
parts = content.get("parts", [])
|
||||
if not parts:
|
||||
raise AIAnalysisError("Gemini API返回内容为空", self.model)
|
||||
|
||||
return parts[0].get("text", "")
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise AuthenticationError("gemini", {"status_code": e.response.status_code})
|
||||
elif e.response.status_code == 429:
|
||||
raise RateLimitError("gemini")
|
||||
else:
|
||||
raise APIError(f"HTTP错误: {e.response.status_code}", e.response.status_code, "gemini")
|
||||
except httpx.RequestError as e:
|
||||
raise AIAnalysisError(f"网络请求失败: {str(e)}", self.model)
|
||||
|
||||
async def _retry_request(self, func, *args, **kwargs):
|
||||
"""重试机制"""
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except (httpx.TimeoutException, httpx.ConnectError, AIAnalysisError) as e:
|
||||
last_exception = e
|
||||
if attempt < self.max_retries - 1:
|
||||
await asyncio.sleep(self.retry_delay * (2 ** attempt)) # 指数退避
|
||||
continue
|
||||
except (AuthenticationError, RateLimitError) as e:
|
||||
# 认证错误和频率限制错误不重试
|
||||
raise e
|
||||
except Exception as e:
|
||||
# 其他异常不重试
|
||||
raise e
|
||||
|
||||
# 所有重试都失败了
|
||||
raise AIAnalysisError(
|
||||
f"Gemini API请求失败,已重试 {self.max_retries} 次: {str(last_exception)}",
|
||||
self.model
|
||||
)
|
||||
|
||||
def _build_business_info_prompt(self, symbol: str, market: str, financial_data: Dict[str, Any]) -> str:
|
||||
"""构建业务信息分析提示词"""
|
||||
return f"""
|
||||
请对股票代码 {symbol}({market}市场)进行全面的业务信息分析。
|
||||
|
||||
基于以下财务数据:
|
||||
{json.dumps(financial_data, ensure_ascii=False, indent=2)}
|
||||
|
||||
请提供以下内容的详细分析:
|
||||
|
||||
1. 公司概览
|
||||
- 公司基本信息和历史背景
|
||||
- 主要业务领域和市场地位
|
||||
|
||||
2. 主营业务分析
|
||||
- 核心产品和服务
|
||||
- 业务模式和盈利模式
|
||||
- 主要收入来源构成
|
||||
|
||||
3. 发展历程
|
||||
- 重要发展里程碑
|
||||
- 业务转型和扩张历史
|
||||
|
||||
4. 核心团队
|
||||
- 管理层背景和经验
|
||||
- 关键人员变动情况
|
||||
|
||||
5. 供应链分析
|
||||
- 主要供应商和客户
|
||||
- 供应链风险和优势
|
||||
|
||||
6. 销售模式
|
||||
- 销售渠道和策略
|
||||
- 市场覆盖和客户群体
|
||||
|
||||
7. 未来展望
|
||||
- 发展战略和规划
|
||||
- 市场机遇和挑战
|
||||
|
||||
请用中文回答,内容要详实、客观,基于可获得的公开信息进行分析。
|
||||
"""
|
||||
|
||||
def _build_fundamental_analysis_prompt(self, symbol: str, market: str, financial_data: Dict[str, Any], business_info: Dict[str, Any]) -> str:
|
||||
"""构建基本面分析提示词(景林模型)"""
|
||||
return f"""
|
||||
请使用景林投资的基本面分析框架,对股票代码 {symbol}({market}市场)进行深度分析。
|
||||
|
||||
财务数据:
|
||||
{json.dumps(financial_data, ensure_ascii=False, indent=2)}
|
||||
|
||||
业务信息:
|
||||
{json.dumps(business_info, ensure_ascii=False, indent=2)}
|
||||
|
||||
请按照以下景林模型问题集进行分析:
|
||||
|
||||
1. 商业模式分析
|
||||
- 这是一门什么样的生意?
|
||||
- 商业模式的核心竞争力是什么?
|
||||
- 盈利模式是否可持续?
|
||||
|
||||
2. 行业地位分析
|
||||
- 公司在行业中的地位如何?
|
||||
- 市场份额和竞争优势?
|
||||
- 行业发展趋势对公司的影响?
|
||||
|
||||
3. 财务质量分析
|
||||
- 收入增长的质量如何?
|
||||
- 盈利能力和现金流状况?
|
||||
- 资产负债结构是否健康?
|
||||
|
||||
4. 管理层评估
|
||||
- 管理层的能力和诚信度?
|
||||
- 公司治理结构是否完善?
|
||||
- 股东利益是否得到保护?
|
||||
|
||||
5. 估值分析
|
||||
- 当前估值水平是否合理?
|
||||
- 与同行业公司比较如何?
|
||||
- 未来增长预期是否支撑估值?
|
||||
|
||||
请用中文提供详细、专业的分析,每个方面都要有具体的数据支撑和逻辑推理。
|
||||
"""
|
||||
|
||||
def _build_bullish_analysis_prompt(self, symbol: str, market: str, context_data: Dict[str, Any]) -> str:
|
||||
"""构建看涨分析提示词"""
|
||||
return f"""
|
||||
请对股票代码 {symbol}({market}市场)进行看涨分析,重点关注隐藏资产和护城河竞争优势。
|
||||
|
||||
基础数据:
|
||||
{json.dumps(context_data, ensure_ascii=False, indent=2)}
|
||||
|
||||
请从以下角度进行看涨分析:
|
||||
|
||||
1. 隐藏资产发现
|
||||
- 账面价值被低估的资产
|
||||
- 无形资产的真实价值
|
||||
- 潜在的资产重估机会
|
||||
|
||||
2. 护城河分析
|
||||
- 品牌价值和客户忠诚度
|
||||
- 技术壁垒和专利保护
|
||||
- 规模经济和网络效应
|
||||
- 转换成本和客户粘性
|
||||
|
||||
3. 成长潜力
|
||||
- 新业务和新市场机会
|
||||
- 产品创新和技术升级
|
||||
- 市场扩张的可能性
|
||||
|
||||
4. 催化剂识别
|
||||
- 短期可能的积极因素
|
||||
- 政策支持和行业利好
|
||||
- 公司内部改革和优化
|
||||
|
||||
5. 最佳情况假设
|
||||
- 如果一切顺利,公司价值可能达到什么水平?
|
||||
- 关键假设和实现路径
|
||||
|
||||
请用中文提供乐观但理性的分析,要有具体的逻辑支撑。
|
||||
"""
|
||||
|
||||
def _build_bearish_analysis_prompt(self, symbol: str, market: str, context_data: Dict[str, Any]) -> str:
|
||||
"""构建看跌分析提示词"""
|
||||
return f"""
|
||||
请对股票代码 {symbol}({market}市场)进行看跌分析,重点关注价值底线和最坏情况。
|
||||
|
||||
基础数据:
|
||||
{json.dumps(context_data, ensure_ascii=False, indent=2)}
|
||||
|
||||
请从以下角度进行看跌分析:
|
||||
|
||||
1. 价值底线分析
|
||||
- 清算价值估算
|
||||
- 资产的最低合理价值
|
||||
- 下行风险的底线在哪里
|
||||
|
||||
2. 主要风险因素
|
||||
- 行业周期性风险
|
||||
- 竞争加剧的威胁
|
||||
- 技术替代的可能性
|
||||
- 监管政策变化风险
|
||||
|
||||
3. 财务脆弱性
|
||||
- 债务压力和流动性风险
|
||||
- 现金流恶化的可能性
|
||||
- 盈利能力下降的风险
|
||||
|
||||
4. 管理层风险
|
||||
- 决策失误的历史
|
||||
- 治理结构的缺陷
|
||||
- 利益冲突的可能性
|
||||
|
||||
5. 最坏情况假设
|
||||
- 如果一切都出错,公司价值可能跌到什么水平?
|
||||
- 关键风险因素和触发条件
|
||||
|
||||
请用中文提供谨慎但客观的分析,要有具体的风险量化。
|
||||
"""
|
||||
|
||||
def _build_market_analysis_prompt(self, symbol: str, market: str, context_data: Dict[str, Any]) -> str:
|
||||
"""构建市场分析提示词"""
|
||||
return f"""
|
||||
请对股票代码 {symbol}({market}市场)进行市场情绪分析,重点关注分歧点与变化驱动。
|
||||
|
||||
基础数据:
|
||||
{json.dumps(context_data, ensure_ascii=False, indent=2)}
|
||||
|
||||
请从以下角度进行市场分析:
|
||||
|
||||
1. 市场情绪评估
|
||||
- 当前市场对该股票的主流观点
|
||||
- 机构投资者的持仓变化
|
||||
- 散户投资者的情绪指标
|
||||
|
||||
2. 分歧点识别
|
||||
- 市场存在哪些主要分歧?
|
||||
- 乐观派和悲观派的核心观点
|
||||
- 分歧的根本原因是什么?
|
||||
|
||||
3. 变化驱动因素
|
||||
- 什么因素可能改变市场共识?
|
||||
- 关键数据点和时间节点
|
||||
- 外部环境变化的影响
|
||||
|
||||
4. 资金流向分析
|
||||
- 主力资金的进出情况
|
||||
- 不同类型投资者的行为模式
|
||||
- 流动性状况评估
|
||||
|
||||
5. 市场预期vs现实
|
||||
- 市场预期是否过于乐观或悲观?
|
||||
- 预期差的投资机会在哪里?
|
||||
|
||||
请用中文提供专业的市场分析,要有数据支撑和逻辑推理。
|
||||
"""
|
||||
|
||||
def _build_news_analysis_prompt(self, symbol: str, market: str, context_data: Dict[str, Any]) -> str:
|
||||
"""构建新闻分析提示词"""
|
||||
return f"""
|
||||
请对股票代码 {symbol}({market}市场)进行新闻催化剂分析,重点关注股价拐点预判。
|
||||
|
||||
基础数据:
|
||||
{json.dumps(context_data, ensure_ascii=False, indent=2)}
|
||||
|
||||
请从以下角度进行新闻分析:
|
||||
|
||||
1. 近期重要新闻梳理
|
||||
- 公司公告和重大事件
|
||||
- 行业政策和监管变化
|
||||
- 宏观经济相关新闻
|
||||
|
||||
2. 催化剂识别
|
||||
- 正面催化剂(业绩超预期、政策利好等)
|
||||
- 负面催化剂(风险事件、竞争加剧等)
|
||||
- 中性但重要的信息
|
||||
|
||||
3. 拐点预判
|
||||
- 基本面拐点的可能时间
|
||||
- 市场情绪拐点的信号
|
||||
- 技术面拐点的确认
|
||||
|
||||
4. 新闻影响评估
|
||||
- 短期影响vs长期影响
|
||||
- 市场反应是否充分
|
||||
- 后续发展的可能路径
|
||||
|
||||
5. 关注要点
|
||||
- 未来需要重点关注的事件
|
||||
- 可能的风险点和机会点
|
||||
- 时间窗口和操作建议
|
||||
|
||||
请用中文提供前瞻性的分析,要有时间维度和影响程度的判断。
|
||||
"""
|
||||
|
||||
def _build_trading_analysis_prompt(self, symbol: str, market: str, context_data: Dict[str, Any]) -> str:
|
||||
"""构建交易分析提示词"""
|
||||
return f"""
|
||||
请对股票代码 {symbol}({market}市场)进行交易分析,重点关注市场体量与增长路径。
|
||||
|
||||
基础数据:
|
||||
{json.dumps(context_data, ensure_ascii=False, indent=2)}
|
||||
|
||||
请从以下角度进行交易分析:
|
||||
|
||||
1. 市场体量分析
|
||||
- 总市值和流通市值
|
||||
- 日均成交量和换手率
|
||||
- 市场容量和流动性评估
|
||||
|
||||
2. 增长路径分析
|
||||
- 历史增长轨迹和驱动因素
|
||||
- 未来增长的可能路径
|
||||
- 增长的可持续性评估
|
||||
|
||||
3. 交易特征分析
|
||||
- 股价波动特征和规律
|
||||
- 主要交易时段和模式
|
||||
- 大宗交易和异常交易情况
|
||||
|
||||
4. 技术面分析
|
||||
- 关键技术位和支撑阻力
|
||||
- 趋势线和形态分析
|
||||
- 技术指标的信号
|
||||
|
||||
5. 交易策略建议
|
||||
- 适合的交易时机和方式
|
||||
- 风险控制和仓位管理
|
||||
- 进出场点位的选择
|
||||
|
||||
请用中文提供实用的交易分析,要有具体的数据和操作建议。
|
||||
"""
|
||||
|
||||
def _build_insider_analysis_prompt(self, symbol: str, market: str, context_data: Dict[str, Any]) -> str:
|
||||
"""构建内部人分析提示词"""
|
||||
return f"""
|
||||
请对股票代码 {symbol}({market}市场)进行内部人与机构动向分析。
|
||||
|
||||
基础数据:
|
||||
{json.dumps(context_data, ensure_ascii=False, indent=2)}
|
||||
|
||||
请从以下角度进行分析:
|
||||
|
||||
1. 内部人交易分析
|
||||
- 高管和大股东的买卖行为
|
||||
- 内部人交易的时机和规模
|
||||
- 内部人交易的信号意义
|
||||
|
||||
2. 机构持仓分析
|
||||
- 主要机构投资者的持仓变化
|
||||
- 新进和退出的机构情况
|
||||
- 机构持仓集中度分析
|
||||
|
||||
3. 股东结构变化
|
||||
- 股权结构的演变趋势
|
||||
- 重要股东的进出情况
|
||||
- 股权激励和员工持股情况
|
||||
|
||||
4. 资金流向追踪
|
||||
- 大资金的进出时机
|
||||
- 不同类型资金的偏好
|
||||
- 资金成本和收益预期
|
||||
|
||||
5. 动向信号解读
|
||||
- 内部人和机构行为的一致性
|
||||
- 与股价走势的相关性
|
||||
- 对未来走势的指示意义
|
||||
|
||||
请用中文提供专业的分析,要有数据支撑和逻辑推理。
|
||||
"""
|
||||
|
||||
def _build_conclusion_prompt(self, symbol: str, market: str, all_analyses: List[Dict[str, Any]]) -> str:
|
||||
"""构建最终结论提示词"""
|
||||
analyses_text = "\n\n".join([
|
||||
f"{analysis.get('analysis_type', '未知分析')}:\n{json.dumps(analysis.get('content', {}), ensure_ascii=False, indent=2)}"
|
||||
for analysis in all_analyses
|
||||
])
|
||||
|
||||
return f"""
|
||||
基于以下所有分析结果,请对股票代码 {symbol}({market}市场)给出最终投资结论。
|
||||
|
||||
所有分析结果:
|
||||
{analyses_text}
|
||||
|
||||
请从以下角度进行综合分析:
|
||||
|
||||
1. 关键矛盾识别
|
||||
- 当前最核心的投资矛盾是什么?
|
||||
- 不同分析维度的结论是否一致?
|
||||
- 主要的不确定性因素有哪些?
|
||||
|
||||
2. 预期差分析
|
||||
- 市场预期与实际情况的差异
|
||||
- 可能被忽视或误解的关键信息
|
||||
- 预期差带来的投资机会
|
||||
|
||||
3. 拐点临近性判断
|
||||
- 基本面拐点的时间和概率
|
||||
- 市场情绪拐点的信号
|
||||
- 催化剂的时效性分析
|
||||
|
||||
4. 风险收益评估
|
||||
- 上行空间和下行风险的量化
|
||||
- 风险调整后的收益预期
|
||||
- 投资的风险收益比
|
||||
|
||||
5. 最终投资建议
|
||||
- 明确的投资观点(看多/看空/中性)
|
||||
- 建议的投资时间框架
|
||||
- 关键的跟踪指标和退出条件
|
||||
|
||||
请用中文提供清晰、明确的投资结论,要有逻辑性和可操作性。
|
||||
"""
|
||||
|
||||
def _parse_business_info_response(self, response: str) -> Dict[str, Any]:
|
||||
"""解析业务信息分析响应"""
|
||||
return {
|
||||
"company_overview": self._extract_section(response, "公司概览"),
|
||||
"main_business": self._extract_section(response, "主营业务"),
|
||||
"development_history": self._extract_section(response, "发展历程"),
|
||||
"core_team": self._extract_section(response, "核心团队"),
|
||||
"supply_chain": self._extract_section(response, "供应链"),
|
||||
"sales_model": self._extract_section(response, "销售模式"),
|
||||
"future_outlook": self._extract_section(response, "未来展望"),
|
||||
"full_analysis": response
|
||||
}
|
||||
|
||||
def _parse_fundamental_response(self, response: str) -> Dict[str, Any]:
|
||||
"""解析基本面分析响应"""
|
||||
return {
|
||||
"business_model": self._extract_section(response, "商业模式"),
|
||||
"industry_position": self._extract_section(response, "行业地位"),
|
||||
"financial_quality": self._extract_section(response, "财务质量"),
|
||||
"management_assessment": self._extract_section(response, "管理层"),
|
||||
"valuation_analysis": self._extract_section(response, "估值分析"),
|
||||
"full_analysis": response
|
||||
}
|
||||
|
||||
def _parse_bullish_response(self, response: str) -> Dict[str, Any]:
|
||||
"""解析看涨分析响应"""
|
||||
return {
|
||||
"hidden_assets": self._extract_section(response, "隐藏资产"),
|
||||
"moat_analysis": self._extract_section(response, "护城河"),
|
||||
"growth_potential": self._extract_section(response, "成长潜力"),
|
||||
"catalysts": self._extract_section(response, "催化剂"),
|
||||
"best_case": self._extract_section(response, "最佳情况"),
|
||||
"full_analysis": response
|
||||
}
|
||||
|
||||
def _parse_bearish_response(self, response: str) -> Dict[str, Any]:
|
||||
"""解析看跌分析响应"""
|
||||
return {
|
||||
"value_floor": self._extract_section(response, "价值底线"),
|
||||
"risk_factors": self._extract_section(response, "风险因素"),
|
||||
"financial_vulnerability": self._extract_section(response, "财务脆弱性"),
|
||||
"management_risks": self._extract_section(response, "管理层风险"),
|
||||
"worst_case": self._extract_section(response, "最坏情况"),
|
||||
"full_analysis": response
|
||||
}
|
||||
|
||||
def _parse_market_response(self, response: str) -> Dict[str, Any]:
|
||||
"""解析市场分析响应"""
|
||||
return {
|
||||
"market_sentiment": self._extract_section(response, "市场情绪"),
|
||||
"disagreement_points": self._extract_section(response, "分歧点"),
|
||||
"change_drivers": self._extract_section(response, "变化驱动"),
|
||||
"capital_flow": self._extract_section(response, "资金流向"),
|
||||
"expectation_vs_reality": self._extract_section(response, "预期vs现实"),
|
||||
"full_analysis": response
|
||||
}
|
||||
|
||||
def _parse_news_response(self, response: str) -> Dict[str, Any]:
|
||||
"""解析新闻分析响应"""
|
||||
return {
|
||||
"recent_news": self._extract_section(response, "重要新闻"),
|
||||
"catalysts": self._extract_section(response, "催化剂"),
|
||||
"inflection_points": self._extract_section(response, "拐点预判"),
|
||||
"news_impact": self._extract_section(response, "影响评估"),
|
||||
"focus_points": self._extract_section(response, "关注要点"),
|
||||
"full_analysis": response
|
||||
}
|
||||
|
||||
def _parse_trading_response(self, response: str) -> Dict[str, Any]:
|
||||
"""解析交易分析响应"""
|
||||
return {
|
||||
"market_size": self._extract_section(response, "市场体量"),
|
||||
"growth_path": self._extract_section(response, "增长路径"),
|
||||
"trading_characteristics": self._extract_section(response, "交易特征"),
|
||||
"technical_analysis": self._extract_section(response, "技术面"),
|
||||
"trading_strategy": self._extract_section(response, "交易策略"),
|
||||
"full_analysis": response
|
||||
}
|
||||
|
||||
def _parse_insider_response(self, response: str) -> Dict[str, Any]:
|
||||
"""解析内部人分析响应"""
|
||||
return {
|
||||
"insider_trading": self._extract_section(response, "内部人交易"),
|
||||
"institutional_holdings": self._extract_section(response, "机构持仓"),
|
||||
"ownership_changes": self._extract_section(response, "股东结构"),
|
||||
"capital_flow": self._extract_section(response, "资金流向"),
|
||||
"signal_interpretation": self._extract_section(response, "信号解读"),
|
||||
"full_analysis": response
|
||||
}
|
||||
|
||||
def _parse_conclusion_response(self, response: str) -> Dict[str, Any]:
|
||||
"""解析最终结论响应"""
|
||||
return {
|
||||
"key_contradictions": self._extract_section(response, "关键矛盾"),
|
||||
"expectation_gap": self._extract_section(response, "预期差"),
|
||||
"inflection_timing": self._extract_section(response, "拐点临近性"),
|
||||
"risk_return": self._extract_section(response, "风险收益"),
|
||||
"investment_recommendation": self._extract_section(response, "投资建议"),
|
||||
"full_analysis": response
|
||||
}
|
||||
|
||||
def _extract_section(self, text: str, section_name: str) -> str:
|
||||
"""从文本中提取特定章节内容"""
|
||||
lines = text.split('\n')
|
||||
section_content = []
|
||||
in_section = False
|
||||
|
||||
for line in lines:
|
||||
if section_name in line and ('.' in line or ':' in line or ':' in line):
|
||||
in_section = True
|
||||
section_content.append(line)
|
||||
continue
|
||||
|
||||
if in_section:
|
||||
if line.strip() and any(keyword in line for keyword in ['1.', '2.', '3.', '4.', '5.']) and section_name not in line:
|
||||
# 遇到下一个主要章节,停止
|
||||
break
|
||||
section_content.append(line)
|
||||
|
||||
return '\n'.join(section_content).strip()
|
||||
|
||||
|
||||
class AIAnalyzerFactory:
|
||||
"""AI分析器工厂"""
|
||||
|
||||
@classmethod
|
||||
def create_gemini_analyzer(cls, api_key: str, config: Optional[Dict[str, Any]] = None) -> GeminiAnalyzer:
|
||||
"""创建Gemini分析器"""
|
||||
return GeminiAnalyzer(api_key, config)
|
||||
|
||||
@classmethod
|
||||
def create_analyzer(cls, analyzer_type: str, **kwargs) -> GeminiAnalyzer:
|
||||
"""创建分析器(可扩展支持其他AI服务)"""
|
||||
if analyzer_type.lower() == "gemini":
|
||||
return cls.create_gemini_analyzer(kwargs.get("api_key"), kwargs.get("config"))
|
||||
else:
|
||||
raise AIAnalysisError(f"不支持的AI分析器类型: {analyzer_type}")
|
||||
260
backend/app/services/config_manager.py
Normal file
260
backend/app/services/config_manager.py
Normal file
@ -0,0 +1,260 @@
|
||||
"""
|
||||
配置管理服务
|
||||
处理系统配置的读取、更新和验证
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
import httpx
|
||||
import asyncio
|
||||
|
||||
from ..models.system_config import SystemConfig
|
||||
from ..schemas.config import ConfigResponse, ConfigUpdateRequest, ConfigTestResponse
|
||||
from ..core.exceptions import ConfigurationError, DatabaseError, APIError
|
||||
|
||||
|
||||
class ConfigManager:
|
||||
"""配置管理器"""
|
||||
|
||||
def __init__(self, db_session: AsyncSession):
|
||||
self.db = db_session
|
||||
|
||||
async def get_config(self) -> ConfigResponse:
|
||||
"""获取系统配置"""
|
||||
try:
|
||||
# 查询所有配置
|
||||
result = await self.db.execute(select(SystemConfig))
|
||||
configs = result.scalars().all()
|
||||
|
||||
# 组织配置数据
|
||||
config_dict = {config.config_key: config.config_value for config in configs}
|
||||
|
||||
return ConfigResponse(
|
||||
database=config_dict.get("database"),
|
||||
gemini_api=config_dict.get("gemini_api"),
|
||||
data_sources=config_dict.get("data_sources", {})
|
||||
)
|
||||
except Exception as e:
|
||||
raise DatabaseError(f"获取配置失败: {str(e)}", "get_config")
|
||||
|
||||
async def update_config(self, config_update: ConfigUpdateRequest) -> ConfigResponse:
|
||||
"""更新系统配置"""
|
||||
try:
|
||||
# 更新数据库配置
|
||||
if config_update.database:
|
||||
await self._update_config_item("database", config_update.database.dict())
|
||||
|
||||
# 更新Gemini API配置
|
||||
if config_update.gemini_api:
|
||||
await self._update_config_item("gemini_api", config_update.gemini_api.dict())
|
||||
|
||||
# 更新数据源配置
|
||||
if config_update.data_sources:
|
||||
data_sources_dict = {k: v.dict() for k, v in config_update.data_sources.items()}
|
||||
await self._update_config_item("data_sources", data_sources_dict)
|
||||
|
||||
await self.db.commit()
|
||||
|
||||
# 返回更新后的配置
|
||||
return await self.get_config()
|
||||
except Exception as e:
|
||||
await self.db.rollback()
|
||||
raise DatabaseError(f"更新配置失败: {str(e)}", "update_config")
|
||||
|
||||
async def test_config(self, config_type: str, config_data: Dict[str, Any]) -> ConfigTestResponse:
|
||||
"""测试配置连接"""
|
||||
|
||||
try:
|
||||
if config_type == "database":
|
||||
return await self._test_database_config(config_data)
|
||||
elif config_type == "gemini":
|
||||
return await self._test_gemini_config(config_data)
|
||||
elif config_type == "data_source":
|
||||
return await self._test_data_source_config(config_data)
|
||||
else:
|
||||
return ConfigTestResponse(
|
||||
success=False,
|
||||
message=f"不支持的配置类型: {config_type}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return ConfigTestResponse(
|
||||
success=False,
|
||||
message=f"配置测试失败: {str(e)}"
|
||||
)
|
||||
|
||||
async def _update_config_item(self, key: str, value: Dict[str, Any]):
|
||||
"""更新单个配置项"""
|
||||
# 查询现有配置
|
||||
result = await self.db.execute(
|
||||
select(SystemConfig).where(SystemConfig.config_key == key)
|
||||
)
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if config:
|
||||
# 更新现有配置
|
||||
config.config_value = value
|
||||
else:
|
||||
# 创建新配置
|
||||
config = SystemConfig(config_key=key, config_value=value)
|
||||
self.db.add(config)
|
||||
|
||||
async def _test_database_config(self, config_data: Dict[str, Any]) -> ConfigTestResponse:
|
||||
"""测试数据库配置"""
|
||||
try:
|
||||
# 尝试创建数据库连接
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
db_url = config_data.get("url")
|
||||
if not db_url:
|
||||
return ConfigTestResponse(
|
||||
success=False,
|
||||
message="数据库URL未配置"
|
||||
)
|
||||
|
||||
# 创建临时引擎测试连接
|
||||
test_engine = create_async_engine(db_url, echo=False)
|
||||
|
||||
# 测试连接
|
||||
async with test_engine.begin() as conn:
|
||||
await conn.execute("SELECT 1")
|
||||
|
||||
await test_engine.dispose()
|
||||
|
||||
return ConfigTestResponse(
|
||||
success=True,
|
||||
message="数据库连接测试成功"
|
||||
)
|
||||
except Exception as e:
|
||||
return ConfigTestResponse(
|
||||
success=False,
|
||||
message=f"数据库连接测试失败: {str(e)}"
|
||||
)
|
||||
|
||||
async def _test_gemini_config(self, config_data: Dict[str, Any]) -> ConfigTestResponse:
|
||||
"""测试Gemini API配置"""
|
||||
try:
|
||||
api_key = config_data.get("api_key")
|
||||
if not api_key:
|
||||
return ConfigTestResponse(
|
||||
success=False,
|
||||
message="Gemini API密钥未配置"
|
||||
)
|
||||
|
||||
# 测试API调用
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
# 这里应该调用实际的Gemini API端点进行测试
|
||||
# 暂时模拟成功
|
||||
await asyncio.sleep(0.1) # 模拟网络延迟
|
||||
|
||||
return ConfigTestResponse(
|
||||
success=True,
|
||||
message="Gemini API连接测试成功"
|
||||
)
|
||||
except Exception as e:
|
||||
return ConfigTestResponse(
|
||||
success=False,
|
||||
message=f"Gemini API连接测试失败: {str(e)}"
|
||||
)
|
||||
|
||||
async def get_data_source_config(self, market: str) -> Dict[str, Any]:
|
||||
"""获取指定市场的数据源配置"""
|
||||
try:
|
||||
result = await self.db.execute(
|
||||
select(SystemConfig).where(SystemConfig.config_key == "data_sources")
|
||||
)
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if not config:
|
||||
raise ConfigurationError("数据源配置未找到", "data_sources")
|
||||
|
||||
data_sources = config.config_value
|
||||
|
||||
# 根据市场选择数据源
|
||||
market_lower = market.lower()
|
||||
if market_lower == "china":
|
||||
if "tushare" in data_sources:
|
||||
return data_sources["tushare"]
|
||||
else:
|
||||
raise ConfigurationError("中国市场数据源(Tushare)未配置", "tushare")
|
||||
else:
|
||||
# 其他市场使用Yahoo Finance
|
||||
if "yahoo" in data_sources:
|
||||
return data_sources["yahoo"]
|
||||
else:
|
||||
raise ConfigurationError("国际市场数据源(Yahoo)未配置", "yahoo")
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, ConfigurationError):
|
||||
raise
|
||||
raise ConfigurationError(f"获取数据源配置失败: {str(e)}", "data_sources")
|
||||
|
||||
async def get_gemini_config(self) -> Dict[str, Any]:
|
||||
"""获取Gemini API配置"""
|
||||
try:
|
||||
result = await self.db.execute(
|
||||
select(SystemConfig).where(SystemConfig.config_key == "gemini_api")
|
||||
)
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if not config:
|
||||
raise ConfigurationError("Gemini API配置未找到", "gemini_api")
|
||||
|
||||
gemini_config = config.config_value
|
||||
|
||||
if not gemini_config.get("api_key"):
|
||||
raise ConfigurationError("Gemini API密钥未配置", "gemini_api")
|
||||
|
||||
return gemini_config
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, ConfigurationError):
|
||||
raise
|
||||
raise ConfigurationError(f"获取Gemini配置失败: {str(e)}", "gemini_api")
|
||||
|
||||
async def _test_data_source_config(self, config_data: Dict[str, Any]) -> ConfigTestResponse:
|
||||
"""测试数据源配置"""
|
||||
try:
|
||||
name = config_data.get("name")
|
||||
api_key = config_data.get("api_key")
|
||||
base_url = config_data.get("base_url")
|
||||
timeout = config_data.get("timeout", 30)
|
||||
|
||||
if not name:
|
||||
return ConfigTestResponse(
|
||||
success=False,
|
||||
message="数据源名称未配置"
|
||||
)
|
||||
|
||||
# 根据数据源类型进行不同的测试
|
||||
if name.lower() == "tushare":
|
||||
if not api_key:
|
||||
return ConfigTestResponse(
|
||||
success=False,
|
||||
message="Tushare API密钥未配置"
|
||||
)
|
||||
# 测试Tushare API
|
||||
# 暂时模拟成功
|
||||
await asyncio.sleep(0.1)
|
||||
elif name.lower() == "yahoo":
|
||||
# 测试Yahoo Finance API
|
||||
if base_url:
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
response = await client.get(f"{base_url}/health", timeout=timeout)
|
||||
if response.status_code != 200:
|
||||
return ConfigTestResponse(
|
||||
success=False,
|
||||
message=f"数据源API返回错误状态码: {response.status_code}"
|
||||
)
|
||||
|
||||
return ConfigTestResponse(
|
||||
success=True,
|
||||
message=f"数据源 {name} 连接测试成功"
|
||||
)
|
||||
except Exception as e:
|
||||
return ConfigTestResponse(
|
||||
success=False,
|
||||
message=f"数据源连接测试失败: {str(e)}"
|
||||
)
|
||||
673
backend/app/services/data_fetcher.py
Normal file
673
backend/app/services/data_fetcher.py
Normal file
@ -0,0 +1,673 @@
|
||||
"""
|
||||
数据获取服务基础架构
|
||||
处理外部数据源的数据获取
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from abc import ABC, abstractmethod
|
||||
import httpx
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
|
||||
from ..schemas.data import (
|
||||
FinancialDataResponse,
|
||||
MarketDataResponse,
|
||||
SymbolValidationResponse,
|
||||
DataSourceStatus
|
||||
)
|
||||
from ..core.exceptions import (
|
||||
DataSourceError,
|
||||
APIError,
|
||||
SymbolNotFoundError,
|
||||
RateLimitError,
|
||||
AuthenticationError
|
||||
)
|
||||
|
||||
|
||||
class DataFetcher(ABC):
|
||||
"""数据获取服务基类"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
self.config = config
|
||||
self.name = config.get("name", "unknown")
|
||||
self.timeout = config.get("timeout", 30)
|
||||
self.max_retries = config.get("max_retries", 3)
|
||||
self.retry_delay = config.get("retry_delay", 1)
|
||||
|
||||
@abstractmethod
|
||||
async def fetch_financial_data(self, symbol: str, market: str) -> FinancialDataResponse:
|
||||
"""获取财务数据"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def fetch_market_data(self, symbol: str, market: str) -> MarketDataResponse:
|
||||
"""获取市场数据"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def validate_symbol(self, symbol: str, market: str) -> SymbolValidationResponse:
|
||||
"""验证证券代码"""
|
||||
pass
|
||||
|
||||
async def check_status(self) -> DataSourceStatus:
|
||||
"""检查数据源状态"""
|
||||
start_time = datetime.now()
|
||||
try:
|
||||
# 尝试进行简单的健康检查
|
||||
await self._health_check()
|
||||
end_time = datetime.now()
|
||||
response_time = int((end_time - start_time).total_seconds() * 1000)
|
||||
|
||||
return DataSourceStatus(
|
||||
name=self.name,
|
||||
is_available=True,
|
||||
last_check=end_time,
|
||||
response_time_ms=response_time
|
||||
)
|
||||
except Exception as e:
|
||||
end_time = datetime.now()
|
||||
return DataSourceStatus(
|
||||
name=self.name,
|
||||
is_available=False,
|
||||
last_check=end_time,
|
||||
error_message=str(e)
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
async def _health_check(self):
|
||||
"""健康检查实现"""
|
||||
pass
|
||||
|
||||
async def _retry_request(self, func, *args, **kwargs):
|
||||
"""重试机制"""
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except (httpx.TimeoutException, httpx.ConnectError) as e:
|
||||
last_exception = e
|
||||
if attempt < self.max_retries - 1:
|
||||
await asyncio.sleep(self.retry_delay * (2 ** attempt)) # 指数退避
|
||||
continue
|
||||
except Exception as e:
|
||||
# 对于其他类型的异常,不重试
|
||||
raise e
|
||||
|
||||
# 所有重试都失败了
|
||||
raise DataSourceError(
|
||||
f"数据源 {self.name} 请求失败,已重试 {self.max_retries} 次",
|
||||
self.name,
|
||||
{"last_error": str(last_exception)}
|
||||
)
|
||||
|
||||
|
||||
class TushareDataFetcher(DataFetcher):
|
||||
"""Tushare数据获取器"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
super().__init__(config)
|
||||
self.token = config.get("api_key") or config.get("token")
|
||||
self.base_url = config.get("base_url", "http://api.tushare.pro")
|
||||
|
||||
if not self.token:
|
||||
raise AuthenticationError("tushare", {"message": "Tushare API token未配置"})
|
||||
|
||||
async def fetch_financial_data(self, symbol: str, market: str) -> FinancialDataResponse:
|
||||
"""获取财务数据"""
|
||||
try:
|
||||
# 转换证券代码格式
|
||||
ts_code = self._convert_symbol_format(symbol, market)
|
||||
|
||||
# TODO: 实现实际的Tushare API调用
|
||||
# 这里暂时返回模拟数据
|
||||
financial_data = await self._retry_request(self._fetch_tushare_financial, ts_code)
|
||||
|
||||
return FinancialDataResponse(
|
||||
symbol=symbol,
|
||||
market=market,
|
||||
data_source="tushare",
|
||||
last_updated=datetime.now(),
|
||||
balance_sheet=financial_data.get("balance_sheet"),
|
||||
income_statement=financial_data.get("income_statement"),
|
||||
cash_flow=financial_data.get("cash_flow"),
|
||||
key_metrics=financial_data.get("key_metrics")
|
||||
)
|
||||
except Exception as e:
|
||||
if isinstance(e, (DataSourceError, APIError)):
|
||||
raise
|
||||
raise DataSourceError(f"获取财务数据失败: {str(e)}", "tushare")
|
||||
|
||||
async def fetch_market_data(self, symbol: str, market: str) -> MarketDataResponse:
|
||||
"""获取市场数据"""
|
||||
try:
|
||||
ts_code = self._convert_symbol_format(symbol, market)
|
||||
|
||||
# TODO: 实现实际的Tushare API调用
|
||||
market_data = await self._retry_request(self._fetch_tushare_market, ts_code)
|
||||
|
||||
return MarketDataResponse(
|
||||
symbol=symbol,
|
||||
market=market,
|
||||
data_source="tushare",
|
||||
last_updated=datetime.now(),
|
||||
price_data=market_data.get("price_data"),
|
||||
volume_data=market_data.get("volume_data"),
|
||||
technical_indicators=market_data.get("technical_indicators")
|
||||
)
|
||||
except Exception as e:
|
||||
if isinstance(e, (DataSourceError, APIError)):
|
||||
raise
|
||||
raise DataSourceError(f"获取市场数据失败: {str(e)}", "tushare")
|
||||
|
||||
async def validate_symbol(self, symbol: str, market: str) -> SymbolValidationResponse:
|
||||
"""验证证券代码"""
|
||||
try:
|
||||
ts_code = self._convert_symbol_format(symbol, market)
|
||||
|
||||
# TODO: 实现实际的证券代码验证
|
||||
# 暂时模拟验证逻辑
|
||||
is_valid = await self._retry_request(self._validate_tushare_symbol, ts_code)
|
||||
|
||||
return SymbolValidationResponse(
|
||||
symbol=symbol,
|
||||
market=market,
|
||||
is_valid=is_valid,
|
||||
company_name="示例公司" if is_valid else None,
|
||||
message="证券代码有效" if is_valid else "证券代码无效"
|
||||
)
|
||||
except Exception as e:
|
||||
return SymbolValidationResponse(
|
||||
symbol=symbol,
|
||||
market=market,
|
||||
is_valid=False,
|
||||
message=f"验证失败: {str(e)}"
|
||||
)
|
||||
|
||||
async def _health_check(self):
|
||||
"""健康检查"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5) as client:
|
||||
# 尝试调用一个简单的API来测试连通性
|
||||
await self._call_tushare_api(client, "stock_basic", {"limit": 1})
|
||||
except Exception as e:
|
||||
raise DataSourceError(f"Tushare健康检查失败: {str(e)}", "tushare")
|
||||
|
||||
def _convert_symbol_format(self, symbol: str, market: str) -> str:
|
||||
"""转换证券代码格式为Tushare格式"""
|
||||
if market.lower() == "china":
|
||||
# 中国股票代码格式转换
|
||||
if symbol.startswith("6"):
|
||||
return f"{symbol}.SH" # 上海证券交易所
|
||||
elif symbol.startswith(("0", "3")):
|
||||
return f"{symbol}.SZ" # 深圳证券交易所
|
||||
return symbol
|
||||
|
||||
async def _fetch_tushare_financial(self, ts_code: str) -> Dict[str, Any]:
|
||||
"""获取Tushare财务数据"""
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
# 获取资产负债表
|
||||
balance_sheet_data = await self._call_tushare_api(
|
||||
client, "balancesheet", {"ts_code": ts_code, "period": "20231231"}
|
||||
)
|
||||
|
||||
# 获取利润表
|
||||
income_data = await self._call_tushare_api(
|
||||
client, "income", {"ts_code": ts_code, "period": "20231231"}
|
||||
)
|
||||
|
||||
# 获取现金流量表
|
||||
cashflow_data = await self._call_tushare_api(
|
||||
client, "cashflow", {"ts_code": ts_code, "period": "20231231"}
|
||||
)
|
||||
|
||||
# 获取基本财务指标
|
||||
fina_indicator_data = await self._call_tushare_api(
|
||||
client, "fina_indicator", {"ts_code": ts_code, "period": "20231231"}
|
||||
)
|
||||
|
||||
return {
|
||||
"balance_sheet": self._process_balance_sheet(balance_sheet_data),
|
||||
"income_statement": self._process_income_statement(income_data),
|
||||
"cash_flow": self._process_cash_flow(cashflow_data),
|
||||
"key_metrics": self._process_key_metrics(fina_indicator_data)
|
||||
}
|
||||
|
||||
async def _fetch_tushare_market(self, ts_code: str) -> Dict[str, Any]:
|
||||
"""获取Tushare市场数据"""
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
# 获取日线数据
|
||||
daily_data = await self._call_tushare_api(
|
||||
client, "daily", {"ts_code": ts_code, "start_date": "20240101", "end_date": "20241231"}
|
||||
)
|
||||
|
||||
# 获取基本信息
|
||||
stock_basic_data = await self._call_tushare_api(
|
||||
client, "stock_basic", {"ts_code": ts_code}
|
||||
)
|
||||
|
||||
return {
|
||||
"price_data": self._process_price_data(daily_data),
|
||||
"volume_data": self._process_volume_data(daily_data),
|
||||
"technical_indicators": self._calculate_technical_indicators(daily_data),
|
||||
"stock_info": self._process_stock_basic(stock_basic_data)
|
||||
}
|
||||
|
||||
async def _validate_tushare_symbol(self, ts_code: str) -> bool:
|
||||
"""验证Tushare证券代码"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
result = await self._call_tushare_api(
|
||||
client, "stock_basic", {"ts_code": ts_code}
|
||||
)
|
||||
return bool(result and len(result.get("items", [])) > 0)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def _call_tushare_api(self, client: httpx.AsyncClient, api_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""调用Tushare API"""
|
||||
request_data = {
|
||||
"api_name": api_name,
|
||||
"token": self.token,
|
||||
"params": params,
|
||||
"fields": ""
|
||||
}
|
||||
|
||||
try:
|
||||
response = await client.post(self.base_url, json=request_data)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
|
||||
if result.get("code") != 0:
|
||||
error_msg = result.get("msg", "Unknown error")
|
||||
if "权限" in error_msg or "token" in error_msg.lower():
|
||||
raise AuthenticationError("tushare", {"message": error_msg})
|
||||
elif "频率" in error_msg or "limit" in error_msg.lower():
|
||||
raise RateLimitError("tushare")
|
||||
else:
|
||||
raise APIError(f"Tushare API错误: {error_msg}", result.get("code"), "tushare")
|
||||
|
||||
return result.get("data", {})
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise AuthenticationError("tushare", {"status_code": e.response.status_code})
|
||||
elif e.response.status_code == 429:
|
||||
raise RateLimitError("tushare")
|
||||
else:
|
||||
raise APIError(f"HTTP错误: {e.response.status_code}", e.response.status_code, "tushare")
|
||||
except httpx.RequestError as e:
|
||||
raise DataSourceError(f"网络请求失败: {str(e)}", "tushare")
|
||||
|
||||
def _process_balance_sheet(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""处理资产负债表数据"""
|
||||
if not data or not data.get("items"):
|
||||
return {}
|
||||
|
||||
items = data["items"]
|
||||
if not items:
|
||||
return {}
|
||||
|
||||
# 取最新一期数据
|
||||
latest = items[0] if isinstance(items[0], list) else items
|
||||
fields = data.get("fields", [])
|
||||
|
||||
if isinstance(latest, list) and fields:
|
||||
# 将列表数据转换为字典
|
||||
balance_data = dict(zip(fields, latest))
|
||||
else:
|
||||
balance_data = latest
|
||||
|
||||
return {
|
||||
"total_assets": balance_data.get("total_assets", 0),
|
||||
"total_liab": balance_data.get("total_liab", 0),
|
||||
"total_hldr_eqy_exc_min_int": balance_data.get("total_hldr_eqy_exc_min_int", 0),
|
||||
"monetary_cap": balance_data.get("monetary_cap", 0),
|
||||
"accounts_receiv": balance_data.get("accounts_receiv", 0),
|
||||
"inventories": balance_data.get("inventories", 0)
|
||||
}
|
||||
|
||||
def _process_income_statement(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""处理利润表数据"""
|
||||
if not data or not data.get("items"):
|
||||
return {}
|
||||
|
||||
items = data["items"]
|
||||
if not items:
|
||||
return {}
|
||||
|
||||
latest = items[0] if isinstance(items[0], list) else items
|
||||
fields = data.get("fields", [])
|
||||
|
||||
if isinstance(latest, list) and fields:
|
||||
income_data = dict(zip(fields, latest))
|
||||
else:
|
||||
income_data = latest
|
||||
|
||||
return {
|
||||
"revenue": income_data.get("revenue", 0),
|
||||
"operate_profit": income_data.get("operate_profit", 0),
|
||||
"total_profit": income_data.get("total_profit", 0),
|
||||
"n_income": income_data.get("n_income", 0),
|
||||
"n_income_attr_p": income_data.get("n_income_attr_p", 0),
|
||||
"basic_eps": income_data.get("basic_eps", 0)
|
||||
}
|
||||
|
||||
def _process_cash_flow(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""处理现金流量表数据"""
|
||||
if not data or not data.get("items"):
|
||||
return {}
|
||||
|
||||
items = data["items"]
|
||||
if not items:
|
||||
return {}
|
||||
|
||||
latest = items[0] if isinstance(items[0], list) else items
|
||||
fields = data.get("fields", [])
|
||||
|
||||
if isinstance(latest, list) and fields:
|
||||
cashflow_data = dict(zip(fields, latest))
|
||||
else:
|
||||
cashflow_data = latest
|
||||
|
||||
return {
|
||||
"n_cashflow_act": cashflow_data.get("n_cashflow_act", 0),
|
||||
"n_cashflow_inv_act": cashflow_data.get("n_cashflow_inv_act", 0),
|
||||
"n_cashflow_fin_act": cashflow_data.get("n_cashflow_fin_act", 0),
|
||||
"c_cash_equ_end_period": cashflow_data.get("c_cash_equ_end_period", 0)
|
||||
}
|
||||
|
||||
def _process_key_metrics(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""处理关键财务指标数据"""
|
||||
if not data or not data.get("items"):
|
||||
return {}
|
||||
|
||||
items = data["items"]
|
||||
if not items:
|
||||
return {}
|
||||
|
||||
latest = items[0] if isinstance(items[0], list) else items
|
||||
fields = data.get("fields", [])
|
||||
|
||||
if isinstance(latest, list) and fields:
|
||||
metrics_data = dict(zip(fields, latest))
|
||||
else:
|
||||
metrics_data = latest
|
||||
|
||||
return {
|
||||
"pe": metrics_data.get("pe", 0),
|
||||
"pb": metrics_data.get("pb", 0),
|
||||
"ps": metrics_data.get("ps", 0),
|
||||
"roe": metrics_data.get("roe", 0),
|
||||
"roa": metrics_data.get("roa", 0),
|
||||
"gross_margin": metrics_data.get("gross_margin", 0),
|
||||
"debt_to_assets": metrics_data.get("debt_to_assets", 0)
|
||||
}
|
||||
|
||||
def _process_price_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""处理价格数据"""
|
||||
if not data or not data.get("items"):
|
||||
return {}
|
||||
|
||||
items = data["items"]
|
||||
if not items:
|
||||
return {}
|
||||
|
||||
# 取最新一天的数据
|
||||
latest = items[0] if isinstance(items[0], list) else items
|
||||
fields = data.get("fields", [])
|
||||
|
||||
if isinstance(latest, list) and fields:
|
||||
price_data = dict(zip(fields, latest))
|
||||
else:
|
||||
price_data = latest
|
||||
|
||||
return {
|
||||
"close": price_data.get("close", 0),
|
||||
"open": price_data.get("open", 0),
|
||||
"high": price_data.get("high", 0),
|
||||
"low": price_data.get("low", 0),
|
||||
"pre_close": price_data.get("pre_close", 0),
|
||||
"change": price_data.get("change", 0),
|
||||
"pct_chg": price_data.get("pct_chg", 0)
|
||||
}
|
||||
|
||||
def _process_volume_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""处理成交量数据"""
|
||||
if not data or not data.get("items"):
|
||||
return {}
|
||||
|
||||
items = data["items"]
|
||||
if not items:
|
||||
return {}
|
||||
|
||||
latest = items[0] if isinstance(items[0], list) else items
|
||||
fields = data.get("fields", [])
|
||||
|
||||
if isinstance(latest, list) and fields:
|
||||
volume_data = dict(zip(fields, latest))
|
||||
else:
|
||||
volume_data = latest
|
||||
|
||||
return {
|
||||
"vol": volume_data.get("vol", 0),
|
||||
"amount": volume_data.get("amount", 0)
|
||||
}
|
||||
|
||||
def _calculate_technical_indicators(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""计算技术指标"""
|
||||
if not data or not data.get("items"):
|
||||
return {}
|
||||
|
||||
items = data["items"]
|
||||
if not items or len(items) < 20:
|
||||
return {}
|
||||
|
||||
# 简单的移动平均计算
|
||||
closes = []
|
||||
for item in items[:20]: # 取最近20天
|
||||
if isinstance(item, list):
|
||||
fields = data.get("fields", [])
|
||||
close_idx = fields.index("close") if "close" in fields else -1
|
||||
if close_idx >= 0:
|
||||
closes.append(item[close_idx])
|
||||
else:
|
||||
closes.append(item.get("close", 0))
|
||||
|
||||
if len(closes) >= 5:
|
||||
ma_5 = sum(closes[:5]) / 5
|
||||
else:
|
||||
ma_5 = 0
|
||||
|
||||
if len(closes) >= 20:
|
||||
ma_20 = sum(closes) / 20
|
||||
else:
|
||||
ma_20 = 0
|
||||
|
||||
return {
|
||||
"ma_5": ma_5,
|
||||
"ma_20": ma_20,
|
||||
"ma_60": 0 # 需要更多数据计算
|
||||
}
|
||||
|
||||
def _process_stock_basic(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""处理股票基本信息"""
|
||||
if not data or not data.get("items"):
|
||||
return {}
|
||||
|
||||
items = data["items"]
|
||||
if not items:
|
||||
return {}
|
||||
|
||||
latest = items[0] if isinstance(items[0], list) else items
|
||||
fields = data.get("fields", [])
|
||||
|
||||
if isinstance(latest, list) and fields:
|
||||
basic_data = dict(zip(fields, latest))
|
||||
else:
|
||||
basic_data = latest
|
||||
|
||||
return {
|
||||
"name": basic_data.get("name", ""),
|
||||
"industry": basic_data.get("industry", ""),
|
||||
"market": basic_data.get("market", ""),
|
||||
"list_date": basic_data.get("list_date", "")
|
||||
}
|
||||
|
||||
|
||||
class YahooDataFetcher(DataFetcher):
|
||||
"""Yahoo Finance数据获取器"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
super().__init__(config)
|
||||
self.base_url = config.get("base_url", "https://query1.finance.yahoo.com")
|
||||
|
||||
async def fetch_financial_data(self, symbol: str, market: str) -> FinancialDataResponse:
|
||||
"""获取财务数据"""
|
||||
try:
|
||||
yahoo_symbol = self._convert_symbol_format(symbol, market)
|
||||
|
||||
# TODO: 实现实际的Yahoo Finance API调用
|
||||
financial_data = await self._retry_request(self._fetch_yahoo_financial, yahoo_symbol)
|
||||
|
||||
return FinancialDataResponse(
|
||||
symbol=symbol,
|
||||
market=market,
|
||||
data_source="yahoo",
|
||||
last_updated=datetime.now(),
|
||||
balance_sheet=financial_data.get("balance_sheet"),
|
||||
income_statement=financial_data.get("income_statement"),
|
||||
cash_flow=financial_data.get("cash_flow"),
|
||||
key_metrics=financial_data.get("key_metrics")
|
||||
)
|
||||
except Exception as e:
|
||||
if isinstance(e, (DataSourceError, APIError)):
|
||||
raise
|
||||
raise DataSourceError(f"获取财务数据失败: {str(e)}", "yahoo")
|
||||
|
||||
async def fetch_market_data(self, symbol: str, market: str) -> MarketDataResponse:
|
||||
"""获取市场数据"""
|
||||
try:
|
||||
yahoo_symbol = self._convert_symbol_format(symbol, market)
|
||||
|
||||
# TODO: 实现实际的Yahoo Finance API调用
|
||||
market_data = await self._retry_request(self._fetch_yahoo_market, yahoo_symbol)
|
||||
|
||||
return MarketDataResponse(
|
||||
symbol=symbol,
|
||||
market=market,
|
||||
data_source="yahoo",
|
||||
last_updated=datetime.now(),
|
||||
price_data=market_data.get("price_data"),
|
||||
volume_data=market_data.get("volume_data"),
|
||||
technical_indicators=market_data.get("technical_indicators")
|
||||
)
|
||||
except Exception as e:
|
||||
if isinstance(e, (DataSourceError, APIError)):
|
||||
raise
|
||||
raise DataSourceError(f"获取市场数据失败: {str(e)}", "yahoo")
|
||||
|
||||
async def validate_symbol(self, symbol: str, market: str) -> SymbolValidationResponse:
|
||||
"""验证证券代码"""
|
||||
try:
|
||||
yahoo_symbol = self._convert_symbol_format(symbol, market)
|
||||
|
||||
# TODO: 实现实际的证券代码验证
|
||||
is_valid = await self._retry_request(self._validate_yahoo_symbol, yahoo_symbol)
|
||||
|
||||
return SymbolValidationResponse(
|
||||
symbol=symbol,
|
||||
market=market,
|
||||
is_valid=is_valid,
|
||||
company_name="Example Company" if is_valid else None,
|
||||
message="Symbol is valid" if is_valid else "Symbol not found"
|
||||
)
|
||||
except Exception as e:
|
||||
return SymbolValidationResponse(
|
||||
symbol=symbol,
|
||||
market=market,
|
||||
is_valid=False,
|
||||
message=f"Validation failed: {str(e)}"
|
||||
)
|
||||
|
||||
async def _health_check(self):
|
||||
"""健康检查"""
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.get(f"{self.base_url}/v1/finance/search?q=AAPL")
|
||||
if response.status_code != 200:
|
||||
raise APIError(f"Yahoo Finance API返回状态码: {response.status_code}", response.status_code, "yahoo")
|
||||
|
||||
def _convert_symbol_format(self, symbol: str, market: str) -> str:
|
||||
"""转换证券代码格式为Yahoo Finance格式"""
|
||||
if market.lower() == "hongkong":
|
||||
return f"{symbol}.HK"
|
||||
elif market.lower() == "japan":
|
||||
return f"{symbol}.T"
|
||||
elif market.lower() == "china":
|
||||
# 中国股票在Yahoo Finance中的格式
|
||||
if symbol.startswith("6"):
|
||||
return f"{symbol}.SS" # 上海
|
||||
elif symbol.startswith(("0", "3")):
|
||||
return f"{symbol}.SZ" # 深圳
|
||||
return symbol
|
||||
|
||||
async def _fetch_yahoo_financial(self, yahoo_symbol: str) -> Dict[str, Any]:
|
||||
"""获取Yahoo Finance财务数据"""
|
||||
# TODO: 实现实际的API调用
|
||||
return {
|
||||
"balance_sheet": {"totalAssets": 2000000},
|
||||
"income_statement": {"totalRevenue": 800000},
|
||||
"cash_flow": {"operatingCashflow": 300000},
|
||||
"key_metrics": {"trailingPE": 18.5}
|
||||
}
|
||||
|
||||
async def _fetch_yahoo_market(self, yahoo_symbol: str) -> Dict[str, Any]:
|
||||
"""获取Yahoo Finance市场数据"""
|
||||
# TODO: 实现实际的API调用
|
||||
return {
|
||||
"price_data": {"regularMarketPrice": 150.0, "dayHigh": 155.0, "dayLow": 145.0},
|
||||
"volume_data": {"regularMarketVolume": 2000000},
|
||||
"technical_indicators": {"fiftyDayAverage": 148.0, "twoHundredDayAverage": 152.0}
|
||||
}
|
||||
|
||||
async def _validate_yahoo_symbol(self, yahoo_symbol: str) -> bool:
|
||||
"""验证Yahoo Finance证券代码"""
|
||||
# TODO: 实现实际的验证逻辑
|
||||
return True
|
||||
|
||||
|
||||
class DataFetcherFactory:
|
||||
"""数据获取器工厂"""
|
||||
|
||||
_fetchers = {
|
||||
"tushare": TushareDataFetcher,
|
||||
"yahoo": YahooDataFetcher,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create_fetcher(cls, data_source: str, config: Dict[str, Any]) -> DataFetcher:
|
||||
"""创建数据获取器"""
|
||||
data_source_lower = data_source.lower()
|
||||
|
||||
if data_source_lower not in cls._fetchers:
|
||||
raise DataSourceError(
|
||||
f"不支持的数据源: {data_source}",
|
||||
data_source,
|
||||
{"supported_sources": list(cls._fetchers.keys())}
|
||||
)
|
||||
|
||||
fetcher_class = cls._fetchers[data_source_lower]
|
||||
return fetcher_class(config)
|
||||
|
||||
@classmethod
|
||||
def get_supported_sources(cls) -> list:
|
||||
"""获取支持的数据源列表"""
|
||||
return list(cls._fetchers.keys())
|
||||
|
||||
@classmethod
|
||||
def register_fetcher(cls, name: str, fetcher_class: type):
|
||||
"""注册新的数据获取器"""
|
||||
if not issubclass(fetcher_class, DataFetcher):
|
||||
raise ValueError("数据获取器必须继承自DataFetcher基类")
|
||||
cls._fetchers[name.lower()] = fetcher_class
|
||||
357
backend/app/services/data_source_manager.py
Normal file
357
backend/app/services/data_source_manager.py
Normal file
@ -0,0 +1,357 @@
|
||||
"""
|
||||
数据源管理服务
|
||||
处理数据源配置和切换逻辑
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
|
||||
from .data_fetcher import DataFetcher, DataFetcherFactory
|
||||
from .ai_analyzer import GeminiAnalyzer, AIAnalyzerFactory
|
||||
from ..core.exceptions import (
|
||||
DataSourceError,
|
||||
ConfigurationError,
|
||||
AIAnalysisError
|
||||
)
|
||||
from ..schemas.data import (
|
||||
FinancialDataResponse,
|
||||
MarketDataResponse,
|
||||
SymbolValidationResponse,
|
||||
DataSourceStatus,
|
||||
DataSourcesStatusResponse
|
||||
)
|
||||
from ..schemas.report import AIAnalysisResponse
|
||||
|
||||
|
||||
class DataSourceManager:
|
||||
"""数据源管理器"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
self.config = config
|
||||
self._data_fetchers: Dict[str, DataFetcher] = {}
|
||||
self._ai_analyzer: Optional[GeminiAnalyzer] = None
|
||||
self._market_source_mapping = {
|
||||
"china": "tushare",
|
||||
"中国": "tushare",
|
||||
"hongkong": "yahoo",
|
||||
"香港": "yahoo",
|
||||
"usa": "yahoo",
|
||||
"美国": "yahoo",
|
||||
"japan": "yahoo",
|
||||
"日本": "yahoo"
|
||||
}
|
||||
|
||||
# 初始化数据获取器
|
||||
self._initialize_data_fetchers()
|
||||
|
||||
# 初始化AI分析器
|
||||
self._initialize_ai_analyzer()
|
||||
|
||||
def _initialize_data_fetchers(self):
|
||||
"""初始化数据获取器"""
|
||||
data_sources_config = self.config.get("data_sources", {})
|
||||
|
||||
for source_name, source_config in data_sources_config.items():
|
||||
try:
|
||||
if source_config.get("enabled", True):
|
||||
fetcher = DataFetcherFactory.create_fetcher(source_name, source_config)
|
||||
self._data_fetchers[source_name] = fetcher
|
||||
except Exception as e:
|
||||
print(f"警告: 初始化数据源 {source_name} 失败: {str(e)}")
|
||||
|
||||
def _initialize_ai_analyzer(self):
|
||||
"""初始化AI分析器"""
|
||||
ai_config = self.config.get("ai_services", {})
|
||||
gemini_config = ai_config.get("gemini", {})
|
||||
|
||||
if gemini_config.get("enabled", True) and gemini_config.get("api_key"):
|
||||
try:
|
||||
self._ai_analyzer = AIAnalyzerFactory.create_gemini_analyzer(
|
||||
gemini_config["api_key"],
|
||||
gemini_config
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"警告: 初始化Gemini分析器失败: {str(e)}")
|
||||
|
||||
def get_data_source_for_market(self, market: str) -> str:
|
||||
"""根据市场获取数据源"""
|
||||
market_lower = market.lower()
|
||||
|
||||
# 首先检查配置中的映射
|
||||
market_mapping = self.config.get("market_mapping", {})
|
||||
if market_lower in market_mapping:
|
||||
return market_mapping[market_lower]
|
||||
|
||||
# 使用默认映射
|
||||
return self._market_source_mapping.get(market_lower, "tushare")
|
||||
|
||||
def get_data_fetcher(self, data_source: str) -> DataFetcher:
|
||||
"""获取数据获取器"""
|
||||
if data_source not in self._data_fetchers:
|
||||
raise DataSourceError(f"数据源 {data_source} 未配置或不可用", data_source)
|
||||
|
||||
return self._data_fetchers[data_source]
|
||||
|
||||
def get_ai_analyzer(self) -> GeminiAnalyzer:
|
||||
"""获取AI分析器"""
|
||||
if not self._ai_analyzer:
|
||||
raise AIAnalysisError("AI分析器未配置或不可用", "gemini")
|
||||
|
||||
return self._ai_analyzer
|
||||
|
||||
async def fetch_financial_data(self, symbol: str, market: str, preferred_source: Optional[str] = None) -> FinancialDataResponse:
|
||||
"""获取财务数据(支持数据源切换)"""
|
||||
data_source = preferred_source or self.get_data_source_for_market(market)
|
||||
|
||||
try:
|
||||
fetcher = self.get_data_fetcher(data_source)
|
||||
return await fetcher.fetch_financial_data(symbol, market)
|
||||
except DataSourceError as e:
|
||||
# 尝试备用数据源
|
||||
fallback_sources = self._get_fallback_sources(data_source)
|
||||
|
||||
for fallback_source in fallback_sources:
|
||||
try:
|
||||
fallback_fetcher = self.get_data_fetcher(fallback_source)
|
||||
return await fallback_fetcher.fetch_financial_data(symbol, market)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 所有数据源都失败了
|
||||
raise e
|
||||
|
||||
async def fetch_market_data(self, symbol: str, market: str, preferred_source: Optional[str] = None) -> MarketDataResponse:
|
||||
"""获取市场数据(支持数据源切换)"""
|
||||
data_source = preferred_source or self.get_data_source_for_market(market)
|
||||
|
||||
try:
|
||||
fetcher = self.get_data_fetcher(data_source)
|
||||
return await fetcher.fetch_market_data(symbol, market)
|
||||
except DataSourceError as e:
|
||||
# 尝试备用数据源
|
||||
fallback_sources = self._get_fallback_sources(data_source)
|
||||
|
||||
for fallback_source in fallback_sources:
|
||||
try:
|
||||
fallback_fetcher = self.get_data_fetcher(fallback_source)
|
||||
return await fallback_fetcher.fetch_market_data(symbol, market)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 所有数据源都失败了
|
||||
raise e
|
||||
|
||||
async def validate_symbol(self, symbol: str, market: str, preferred_source: Optional[str] = None) -> SymbolValidationResponse:
|
||||
"""验证证券代码(支持数据源切换)"""
|
||||
data_source = preferred_source or self.get_data_source_for_market(market)
|
||||
|
||||
try:
|
||||
fetcher = self.get_data_fetcher(data_source)
|
||||
return await fetcher.validate_symbol(symbol, market)
|
||||
except DataSourceError as e:
|
||||
# 尝试备用数据源
|
||||
fallback_sources = self._get_fallback_sources(data_source)
|
||||
|
||||
for fallback_source in fallback_sources:
|
||||
try:
|
||||
fallback_fetcher = self.get_data_fetcher(fallback_source)
|
||||
return await fallback_fetcher.validate_symbol(symbol, market)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 所有数据源都失败了
|
||||
raise e
|
||||
|
||||
async def analyze_with_ai(self, analysis_type: str, symbol: str, market: str, context_data: Dict[str, Any]) -> AIAnalysisResponse:
|
||||
"""使用AI进行分析"""
|
||||
analyzer = self.get_ai_analyzer()
|
||||
|
||||
if analysis_type == "business_info":
|
||||
return await analyzer.analyze_business_info(symbol, market, context_data)
|
||||
elif analysis_type == "fundamental_analysis":
|
||||
business_info = context_data.get("business_info", {})
|
||||
financial_data = context_data.get("financial_data", {})
|
||||
return await analyzer.analyze_fundamental(symbol, market, financial_data, business_info)
|
||||
elif analysis_type == "bullish_analysis":
|
||||
return await analyzer.analyze_bullish_case(symbol, market, context_data)
|
||||
elif analysis_type == "bearish_analysis":
|
||||
return await analyzer.analyze_bearish_case(symbol, market, context_data)
|
||||
elif analysis_type == "market_analysis":
|
||||
return await analyzer.analyze_market_sentiment(symbol, market, context_data)
|
||||
elif analysis_type == "news_analysis":
|
||||
return await analyzer.analyze_news_catalysts(symbol, market, context_data)
|
||||
elif analysis_type == "trading_analysis":
|
||||
return await analyzer.analyze_trading_dynamics(symbol, market, context_data)
|
||||
elif analysis_type == "insider_analysis":
|
||||
return await analyzer.analyze_insider_institutional(symbol, market, context_data)
|
||||
elif analysis_type == "final_conclusion":
|
||||
all_analyses = context_data.get("all_analyses", [])
|
||||
return await analyzer.generate_final_conclusion(symbol, market, all_analyses)
|
||||
else:
|
||||
raise AIAnalysisError(f"不支持的分析类型: {analysis_type}", "gemini")
|
||||
|
||||
async def check_all_sources_status(self) -> DataSourcesStatusResponse:
|
||||
"""检查所有数据源状态"""
|
||||
status_tasks = []
|
||||
|
||||
# 检查数据获取器状态
|
||||
for source_name, fetcher in self._data_fetchers.items():
|
||||
status_tasks.append(fetcher.check_status())
|
||||
|
||||
# 检查AI分析器状态
|
||||
if self._ai_analyzer:
|
||||
status_tasks.append(self._check_ai_analyzer_status())
|
||||
|
||||
# 并发执行状态检查
|
||||
statuses = await asyncio.gather(*status_tasks, return_exceptions=True)
|
||||
|
||||
source_statuses = []
|
||||
healthy_count = 0
|
||||
|
||||
for i, status in enumerate(statuses):
|
||||
if isinstance(status, Exception):
|
||||
# 处理异常情况
|
||||
if i < len(self._data_fetchers):
|
||||
source_name = list(self._data_fetchers.keys())[i]
|
||||
else:
|
||||
source_name = "gemini"
|
||||
|
||||
source_statuses.append(DataSourceStatus(
|
||||
name=source_name,
|
||||
is_available=False,
|
||||
last_check=datetime.now(),
|
||||
error_message=str(status)
|
||||
))
|
||||
else:
|
||||
source_statuses.append(status)
|
||||
if status.is_available:
|
||||
healthy_count += 1
|
||||
|
||||
# 确定整体状态
|
||||
total_sources = len(source_statuses)
|
||||
if healthy_count == total_sources:
|
||||
overall_status = "healthy"
|
||||
elif healthy_count > 0:
|
||||
overall_status = "degraded"
|
||||
else:
|
||||
overall_status = "down"
|
||||
|
||||
return DataSourcesStatusResponse(
|
||||
sources=source_statuses,
|
||||
overall_status=overall_status
|
||||
)
|
||||
|
||||
async def _check_ai_analyzer_status(self) -> DataSourceStatus:
|
||||
"""检查AI分析器状态"""
|
||||
start_time = datetime.now()
|
||||
try:
|
||||
# 简单的健康检查 - 尝试生成一个很短的测试内容
|
||||
test_prompt = "请回答:1+1等于几?"
|
||||
await self._ai_analyzer._call_gemini_api(test_prompt)
|
||||
|
||||
end_time = datetime.now()
|
||||
response_time = int((end_time - start_time).total_seconds() * 1000)
|
||||
|
||||
return DataSourceStatus(
|
||||
name="gemini",
|
||||
is_available=True,
|
||||
last_check=end_time,
|
||||
response_time_ms=response_time
|
||||
)
|
||||
except Exception as e:
|
||||
end_time = datetime.now()
|
||||
return DataSourceStatus(
|
||||
name="gemini",
|
||||
is_available=False,
|
||||
last_check=end_time,
|
||||
error_message=str(e)
|
||||
)
|
||||
|
||||
def _get_fallback_sources(self, primary_source: str) -> List[str]:
|
||||
"""获取备用数据源列表"""
|
||||
fallback_config = self.config.get("fallback_sources", {})
|
||||
|
||||
if primary_source in fallback_config:
|
||||
return fallback_config[primary_source]
|
||||
|
||||
# 默认备用策略
|
||||
all_sources = list(self._data_fetchers.keys())
|
||||
return [source for source in all_sources if source != primary_source]
|
||||
|
||||
def update_config(self, new_config: Dict[str, Any]):
|
||||
"""更新配置"""
|
||||
self.config.update(new_config)
|
||||
|
||||
# 重新初始化数据获取器
|
||||
self._data_fetchers.clear()
|
||||
self._initialize_data_fetchers()
|
||||
|
||||
# 重新初始化AI分析器
|
||||
self._ai_analyzer = None
|
||||
self._initialize_ai_analyzer()
|
||||
|
||||
def get_supported_sources(self) -> List[str]:
|
||||
"""获取支持的数据源列表"""
|
||||
return DataFetcherFactory.get_supported_sources()
|
||||
|
||||
def get_available_sources(self) -> List[str]:
|
||||
"""获取当前可用的数据源列表"""
|
||||
return list(self._data_fetchers.keys())
|
||||
|
||||
def is_ai_analyzer_available(self) -> bool:
|
||||
"""检查AI分析器是否可用"""
|
||||
return self._ai_analyzer is not None
|
||||
|
||||
|
||||
def create_data_source_manager(config: Dict[str, Any]) -> DataSourceManager:
|
||||
"""创建数据源管理器"""
|
||||
return DataSourceManager(config)
|
||||
|
||||
|
||||
# 默认配置示例
|
||||
DEFAULT_CONFIG = {
|
||||
"data_sources": {
|
||||
"tushare": {
|
||||
"enabled": True,
|
||||
"api_key": "", # 需要从环境变量或配置文件获取
|
||||
"base_url": "http://api.tushare.pro",
|
||||
"timeout": 30,
|
||||
"max_retries": 3,
|
||||
"retry_delay": 1
|
||||
},
|
||||
"yahoo": {
|
||||
"enabled": True,
|
||||
"base_url": "https://query1.finance.yahoo.com",
|
||||
"timeout": 30,
|
||||
"max_retries": 3,
|
||||
"retry_delay": 1
|
||||
}
|
||||
},
|
||||
"ai_services": {
|
||||
"gemini": {
|
||||
"enabled": True,
|
||||
"api_key": "", # 需要从环境变量或配置文件获取
|
||||
"model": "gemini-pro",
|
||||
"timeout": 60,
|
||||
"max_retries": 3,
|
||||
"retry_delay": 2,
|
||||
"temperature": 0.7,
|
||||
"max_output_tokens": 8192
|
||||
}
|
||||
},
|
||||
"market_mapping": {
|
||||
"china": "tushare",
|
||||
"中国": "tushare",
|
||||
"hongkong": "yahoo",
|
||||
"香港": "yahoo",
|
||||
"usa": "yahoo",
|
||||
"美国": "yahoo",
|
||||
"japan": "yahoo",
|
||||
"日本": "yahoo"
|
||||
},
|
||||
"fallback_sources": {
|
||||
"tushare": ["yahoo"],
|
||||
"yahoo": ["tushare"]
|
||||
}
|
||||
}
|
||||
322
backend/app/services/external_api_service.py
Normal file
322
backend/app/services/external_api_service.py
Normal file
@ -0,0 +1,322 @@
|
||||
"""
|
||||
外部API集成服务
|
||||
统一管理所有外部API调用和数据源切换
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
|
||||
from .data_source_manager import DataSourceManager, create_data_source_manager
|
||||
from ..core.config import api_config
|
||||
from ..core.exceptions import (
|
||||
DataSourceError,
|
||||
AIAnalysisError,
|
||||
ConfigurationError
|
||||
)
|
||||
from ..schemas.data import (
|
||||
FinancialDataResponse,
|
||||
MarketDataResponse,
|
||||
SymbolValidationResponse,
|
||||
DataSourcesStatusResponse
|
||||
)
|
||||
from ..schemas.report import AIAnalysisResponse
|
||||
|
||||
|
||||
class ExternalAPIService:
|
||||
"""外部API服务"""
|
||||
|
||||
def __init__(self):
|
||||
self._data_source_manager: Optional[DataSourceManager] = None
|
||||
self._initialize_manager()
|
||||
|
||||
def _initialize_manager(self):
|
||||
"""初始化数据源管理器"""
|
||||
try:
|
||||
config = api_config.get_data_source_manager_config()
|
||||
self._data_source_manager = create_data_source_manager(config)
|
||||
except Exception as e:
|
||||
print(f"警告: 初始化数据源管理器失败: {str(e)}")
|
||||
|
||||
async def get_financial_data(self, symbol: str, market: str, preferred_source: Optional[str] = None) -> FinancialDataResponse:
|
||||
"""获取财务数据"""
|
||||
if not self._data_source_manager:
|
||||
raise ConfigurationError("数据源管理器未初始化")
|
||||
|
||||
try:
|
||||
return await self._data_source_manager.fetch_financial_data(symbol, market, preferred_source)
|
||||
except Exception as e:
|
||||
raise DataSourceError(f"获取财务数据失败: {str(e)}", preferred_source or "unknown")
|
||||
|
||||
async def get_market_data(self, symbol: str, market: str, preferred_source: Optional[str] = None) -> MarketDataResponse:
|
||||
"""获取市场数据"""
|
||||
if not self._data_source_manager:
|
||||
raise ConfigurationError("数据源管理器未初始化")
|
||||
|
||||
try:
|
||||
return await self._data_source_manager.fetch_market_data(symbol, market, preferred_source)
|
||||
except Exception as e:
|
||||
raise DataSourceError(f"获取市场数据失败: {str(e)}", preferred_source or "unknown")
|
||||
|
||||
async def validate_stock_symbol(self, symbol: str, market: str, preferred_source: Optional[str] = None) -> SymbolValidationResponse:
|
||||
"""验证股票代码"""
|
||||
if not self._data_source_manager:
|
||||
raise ConfigurationError("数据源管理器未初始化")
|
||||
|
||||
try:
|
||||
return await self._data_source_manager.validate_symbol(symbol, market, preferred_source)
|
||||
except Exception as e:
|
||||
# 验证失败时返回无效结果而不是抛出异常
|
||||
return SymbolValidationResponse(
|
||||
symbol=symbol,
|
||||
market=market,
|
||||
is_valid=False,
|
||||
message=f"验证失败: {str(e)}"
|
||||
)
|
||||
|
||||
async def analyze_business_info(self, symbol: str, market: str, financial_data: Dict[str, Any]) -> AIAnalysisResponse:
|
||||
"""分析公司业务信息"""
|
||||
if not self._data_source_manager:
|
||||
raise ConfigurationError("数据源管理器未初始化")
|
||||
|
||||
if not self._data_source_manager.is_ai_analyzer_available():
|
||||
raise AIAnalysisError("AI分析器不可用", "gemini")
|
||||
|
||||
try:
|
||||
return await self._data_source_manager.analyze_with_ai(
|
||||
"business_info", symbol, market, {"financial_data": financial_data}
|
||||
)
|
||||
except Exception as e:
|
||||
raise AIAnalysisError(f"业务信息分析失败: {str(e)}", "gemini")
|
||||
|
||||
async def analyze_fundamental(self, symbol: str, market: str, financial_data: Dict[str, Any], business_info: Dict[str, Any]) -> AIAnalysisResponse:
|
||||
"""基本面分析"""
|
||||
if not self._data_source_manager:
|
||||
raise ConfigurationError("数据源管理器未初始化")
|
||||
|
||||
if not self._data_source_manager.is_ai_analyzer_available():
|
||||
raise AIAnalysisError("AI分析器不可用", "gemini")
|
||||
|
||||
try:
|
||||
context_data = {
|
||||
"financial_data": financial_data,
|
||||
"business_info": business_info
|
||||
}
|
||||
return await self._data_source_manager.analyze_with_ai(
|
||||
"fundamental_analysis", symbol, market, context_data
|
||||
)
|
||||
except Exception as e:
|
||||
raise AIAnalysisError(f"基本面分析失败: {str(e)}", "gemini")
|
||||
|
||||
async def analyze_bullish_case(self, symbol: str, market: str, context_data: Dict[str, Any]) -> AIAnalysisResponse:
|
||||
"""看涨分析"""
|
||||
if not self._data_source_manager:
|
||||
raise ConfigurationError("数据源管理器未初始化")
|
||||
|
||||
if not self._data_source_manager.is_ai_analyzer_available():
|
||||
raise AIAnalysisError("AI分析器不可用", "gemini")
|
||||
|
||||
try:
|
||||
return await self._data_source_manager.analyze_with_ai(
|
||||
"bullish_analysis", symbol, market, context_data
|
||||
)
|
||||
except Exception as e:
|
||||
raise AIAnalysisError(f"看涨分析失败: {str(e)}", "gemini")
|
||||
|
||||
async def analyze_bearish_case(self, symbol: str, market: str, context_data: Dict[str, Any]) -> AIAnalysisResponse:
|
||||
"""看跌分析"""
|
||||
if not self._data_source_manager:
|
||||
raise ConfigurationError("数据源管理器未初始化")
|
||||
|
||||
if not self._data_source_manager.is_ai_analyzer_available():
|
||||
raise AIAnalysisError("AI分析器不可用", "gemini")
|
||||
|
||||
try:
|
||||
return await self._data_source_manager.analyze_with_ai(
|
||||
"bearish_analysis", symbol, market, context_data
|
||||
)
|
||||
except Exception as e:
|
||||
raise AIAnalysisError(f"看跌分析失败: {str(e)}", "gemini")
|
||||
|
||||
async def analyze_market_sentiment(self, symbol: str, market: str, context_data: Dict[str, Any]) -> AIAnalysisResponse:
|
||||
"""市场情绪分析"""
|
||||
if not self._data_source_manager:
|
||||
raise ConfigurationError("数据源管理器未初始化")
|
||||
|
||||
if not self._data_source_manager.is_ai_analyzer_available():
|
||||
raise AIAnalysisError("AI分析器不可用", "gemini")
|
||||
|
||||
try:
|
||||
return await self._data_source_manager.analyze_with_ai(
|
||||
"market_analysis", symbol, market, context_data
|
||||
)
|
||||
except Exception as e:
|
||||
raise AIAnalysisError(f"市场分析失败: {str(e)}", "gemini")
|
||||
|
||||
async def analyze_news_catalysts(self, symbol: str, market: str, context_data: Dict[str, Any]) -> AIAnalysisResponse:
|
||||
"""新闻催化剂分析"""
|
||||
if not self._data_source_manager:
|
||||
raise ConfigurationError("数据源管理器未初始化")
|
||||
|
||||
if not self._data_source_manager.is_ai_analyzer_available():
|
||||
raise AIAnalysisError("AI分析器不可用", "gemini")
|
||||
|
||||
try:
|
||||
return await self._data_source_manager.analyze_with_ai(
|
||||
"news_analysis", symbol, market, context_data
|
||||
)
|
||||
except Exception as e:
|
||||
raise AIAnalysisError(f"新闻分析失败: {str(e)}", "gemini")
|
||||
|
||||
async def analyze_trading_dynamics(self, symbol: str, market: str, context_data: Dict[str, Any]) -> AIAnalysisResponse:
|
||||
"""交易动态分析"""
|
||||
if not self._data_source_manager:
|
||||
raise ConfigurationError("数据源管理器未初始化")
|
||||
|
||||
if not self._data_source_manager.is_ai_analyzer_available():
|
||||
raise AIAnalysisError("AI分析器不可用", "gemini")
|
||||
|
||||
try:
|
||||
return await self._data_source_manager.analyze_with_ai(
|
||||
"trading_analysis", symbol, market, context_data
|
||||
)
|
||||
except Exception as e:
|
||||
raise AIAnalysisError(f"交易分析失败: {str(e)}", "gemini")
|
||||
|
||||
async def analyze_insider_institutional(self, symbol: str, market: str, context_data: Dict[str, Any]) -> AIAnalysisResponse:
|
||||
"""内部人与机构动向分析"""
|
||||
if not self._data_source_manager:
|
||||
raise ConfigurationError("数据源管理器未初始化")
|
||||
|
||||
if not self._data_source_manager.is_ai_analyzer_available():
|
||||
raise AIAnalysisError("AI分析器不可用", "gemini")
|
||||
|
||||
try:
|
||||
return await self._data_source_manager.analyze_with_ai(
|
||||
"insider_analysis", symbol, market, context_data
|
||||
)
|
||||
except Exception as e:
|
||||
raise AIAnalysisError(f"内部人分析失败: {str(e)}", "gemini")
|
||||
|
||||
async def generate_final_conclusion(self, symbol: str, market: str, all_analyses: List[Dict[str, Any]]) -> AIAnalysisResponse:
|
||||
"""生成最终结论"""
|
||||
if not self._data_source_manager:
|
||||
raise ConfigurationError("数据源管理器未初始化")
|
||||
|
||||
if not self._data_source_manager.is_ai_analyzer_available():
|
||||
raise AIAnalysisError("AI分析器不可用", "gemini")
|
||||
|
||||
try:
|
||||
context_data = {"all_analyses": all_analyses}
|
||||
return await self._data_source_manager.analyze_with_ai(
|
||||
"final_conclusion", symbol, market, context_data
|
||||
)
|
||||
except Exception as e:
|
||||
raise AIAnalysisError(f"最终结论生成失败: {str(e)}", "gemini")
|
||||
|
||||
async def check_all_services_status(self) -> DataSourcesStatusResponse:
|
||||
"""检查所有外部服务状态"""
|
||||
if not self._data_source_manager:
|
||||
raise ConfigurationError("数据源管理器未初始化")
|
||||
|
||||
try:
|
||||
return await self._data_source_manager.check_all_sources_status()
|
||||
except Exception as e:
|
||||
raise DataSourceError(f"检查服务状态失败: {str(e)}")
|
||||
|
||||
def get_supported_data_sources(self) -> List[str]:
|
||||
"""获取支持的数据源列表"""
|
||||
if not self._data_source_manager:
|
||||
return []
|
||||
|
||||
return self._data_source_manager.get_supported_sources()
|
||||
|
||||
def get_available_data_sources(self) -> List[str]:
|
||||
"""获取当前可用的数据源列表"""
|
||||
if not self._data_source_manager:
|
||||
return []
|
||||
|
||||
return self._data_source_manager.get_available_sources()
|
||||
|
||||
def is_ai_service_available(self) -> bool:
|
||||
"""检查AI服务是否可用"""
|
||||
if not self._data_source_manager:
|
||||
return False
|
||||
|
||||
return self._data_source_manager.is_ai_analyzer_available()
|
||||
|
||||
def get_data_source_for_market(self, market: str) -> str:
|
||||
"""根据市场获取推荐的数据源"""
|
||||
if not self._data_source_manager:
|
||||
return "tushare" # 默认值
|
||||
|
||||
return self._data_source_manager.get_data_source_for_market(market)
|
||||
|
||||
def update_configuration(self, new_config: Dict[str, Any]):
|
||||
"""更新配置"""
|
||||
if self._data_source_manager:
|
||||
self._data_source_manager.update_config(new_config)
|
||||
else:
|
||||
# 如果管理器未初始化,尝试重新初始化
|
||||
self._initialize_manager()
|
||||
|
||||
async def test_data_source_connection(self, data_source: str, config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""测试数据源连接"""
|
||||
try:
|
||||
# 创建临时的数据获取器进行测试
|
||||
from .data_fetcher import DataFetcherFactory
|
||||
|
||||
test_fetcher = DataFetcherFactory.create_fetcher(data_source, config)
|
||||
status = await test_fetcher.check_status()
|
||||
|
||||
return {
|
||||
"success": status.is_available,
|
||||
"response_time_ms": status.response_time_ms,
|
||||
"error_message": status.error_message
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error_message": str(e)
|
||||
}
|
||||
|
||||
async def test_ai_service_connection(self, service_type: str, config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""测试AI服务连接"""
|
||||
try:
|
||||
if service_type.lower() == "gemini":
|
||||
from .ai_analyzer import AIAnalyzerFactory
|
||||
|
||||
test_analyzer = AIAnalyzerFactory.create_gemini_analyzer(
|
||||
config.get("api_key"), config
|
||||
)
|
||||
|
||||
# 简单的测试调用
|
||||
start_time = datetime.now()
|
||||
await test_analyzer._call_gemini_api("测试连接,请回答:OK")
|
||||
end_time = datetime.now()
|
||||
|
||||
response_time = int((end_time - start_time).total_seconds() * 1000)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"response_time_ms": response_time
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"error_message": f"不支持的AI服务类型: {service_type}"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error_message": str(e)
|
||||
}
|
||||
|
||||
|
||||
# 创建全局服务实例
|
||||
external_api_service = ExternalAPIService()
|
||||
|
||||
|
||||
def get_external_api_service() -> ExternalAPIService:
|
||||
"""获取外部API服务实例"""
|
||||
return external_api_service
|
||||
163
backend/app/services/progress_tracker.py
Normal file
163
backend/app/services/progress_tracker.py
Normal file
@ -0,0 +1,163 @@
|
||||
"""
|
||||
进度追踪服务
|
||||
处理报告生成进度的追踪和管理
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
from datetime import datetime
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from ..models.progress_tracking import ProgressTracking
|
||||
from ..schemas.progress import ProgressResponse, StepTiming
|
||||
|
||||
|
||||
class ProgressTracker:
|
||||
"""进度追踪器"""
|
||||
|
||||
def __init__(self, db_session: AsyncSession):
|
||||
self.db = db_session
|
||||
|
||||
async def initialize_progress(self, report_id: UUID):
|
||||
"""初始化进度追踪"""
|
||||
|
||||
# 定义报告生成步骤
|
||||
steps = [
|
||||
"初始化报告",
|
||||
"获取财务数据",
|
||||
"生成业务信息",
|
||||
"执行基本面分析",
|
||||
"执行看涨分析",
|
||||
"执行看跌分析",
|
||||
"执行市场分析",
|
||||
"执行新闻分析",
|
||||
"执行交易分析",
|
||||
"执行内部人分析",
|
||||
"生成最终结论",
|
||||
"保存报告"
|
||||
]
|
||||
|
||||
# 创建进度记录
|
||||
for i, step_name in enumerate(steps, 1):
|
||||
progress = ProgressTracking(
|
||||
report_id=report_id,
|
||||
step_name=step_name,
|
||||
step_order=i,
|
||||
status="pending"
|
||||
)
|
||||
self.db.add(progress)
|
||||
|
||||
await self.db.flush()
|
||||
|
||||
async def start_step(self, report_id: UUID, step_name: str):
|
||||
"""开始执行步骤"""
|
||||
result = await self.db.execute(
|
||||
select(ProgressTracking).where(
|
||||
ProgressTracking.report_id == report_id,
|
||||
ProgressTracking.step_name == step_name
|
||||
)
|
||||
)
|
||||
progress = result.scalar_one_or_none()
|
||||
|
||||
if progress:
|
||||
progress.status = "running"
|
||||
progress.started_at = datetime.utcnow()
|
||||
await self.db.flush()
|
||||
|
||||
async def complete_step(self, report_id: UUID, step_name: str, success: bool = True, error_message: Optional[str] = None):
|
||||
"""完成步骤"""
|
||||
result = await self.db.execute(
|
||||
select(ProgressTracking).where(
|
||||
ProgressTracking.report_id == report_id,
|
||||
ProgressTracking.step_name == step_name
|
||||
)
|
||||
)
|
||||
progress = result.scalar_one_or_none()
|
||||
|
||||
if progress:
|
||||
progress.status = "completed" if success else "failed"
|
||||
progress.completed_at = datetime.utcnow()
|
||||
progress.error_message = error_message
|
||||
|
||||
# 计算耗时
|
||||
if progress.started_at:
|
||||
duration = progress.completed_at - progress.started_at
|
||||
progress.duration_ms = int(duration.total_seconds() * 1000)
|
||||
|
||||
await self.db.flush()
|
||||
|
||||
async def get_progress(self, report_id: UUID) -> ProgressResponse:
|
||||
"""获取进度信息"""
|
||||
result = await self.db.execute(
|
||||
select(ProgressTracking)
|
||||
.where(ProgressTracking.report_id == report_id)
|
||||
.order_by(ProgressTracking.step_order)
|
||||
)
|
||||
progress_records = result.scalars().all()
|
||||
|
||||
if not progress_records:
|
||||
raise ValueError(f"未找到报告 {report_id} 的进度信息")
|
||||
|
||||
# 计算当前步骤
|
||||
current_step = 1
|
||||
current_step_name = "初始化报告"
|
||||
overall_status = "running"
|
||||
|
||||
completed_count = 0
|
||||
failed_count = 0
|
||||
|
||||
for record in progress_records:
|
||||
if record.status == "completed":
|
||||
completed_count += 1
|
||||
elif record.status == "failed":
|
||||
failed_count += 1
|
||||
elif record.status == "running":
|
||||
current_step = record.step_order
|
||||
current_step_name = record.step_name
|
||||
|
||||
# 确定整体状态
|
||||
if failed_count > 0:
|
||||
overall_status = "failed"
|
||||
elif completed_count == len(progress_records):
|
||||
overall_status = "completed"
|
||||
|
||||
# 转换为StepTiming对象
|
||||
step_timings = [
|
||||
StepTiming(
|
||||
step_name=record.step_name,
|
||||
step_order=record.step_order,
|
||||
status=record.status,
|
||||
started_at=record.started_at,
|
||||
completed_at=record.completed_at,
|
||||
duration_ms=record.duration_ms,
|
||||
error_message=record.error_message
|
||||
)
|
||||
for record in progress_records
|
||||
]
|
||||
|
||||
return ProgressResponse(
|
||||
report_id=report_id,
|
||||
current_step=current_step,
|
||||
total_steps=len(progress_records),
|
||||
current_step_name=current_step_name,
|
||||
status=overall_status,
|
||||
step_timings=step_timings,
|
||||
estimated_remaining=self._estimate_remaining_time(step_timings)
|
||||
)
|
||||
|
||||
def _estimate_remaining_time(self, step_timings: List[StepTiming]) -> Optional[int]:
|
||||
"""估算剩余时间"""
|
||||
# 计算已完成步骤的平均耗时
|
||||
completed_durations = [
|
||||
timing.duration_ms for timing in step_timings
|
||||
if timing.status == "completed" and timing.duration_ms
|
||||
]
|
||||
|
||||
if not completed_durations:
|
||||
return None
|
||||
|
||||
avg_duration_ms = sum(completed_durations) / len(completed_durations)
|
||||
remaining_steps = len([t for t in step_timings if t.status == "pending"])
|
||||
|
||||
return int((avg_duration_ms * remaining_steps) / 1000) # 转换为秒
|
||||
643
backend/app/services/report_generator.py
Normal file
643
backend/app/services/report_generator.py
Normal file
@ -0,0 +1,643 @@
|
||||
"""
|
||||
报告生成服务
|
||||
处理股票基本面分析报告的生成和管理
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List
|
||||
from uuid import UUID
|
||||
from datetime import datetime
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from ..models.report import Report
|
||||
from ..models.analysis_module import AnalysisModule
|
||||
from ..schemas.report import ReportResponse, AnalysisModuleSchema
|
||||
from ..core.exceptions import (
|
||||
ReportGenerationError,
|
||||
DataSourceError,
|
||||
AIAnalysisError,
|
||||
DatabaseError
|
||||
)
|
||||
from .progress_tracker import ProgressTracker
|
||||
from .data_fetcher import DataFetcherFactory
|
||||
from .ai_analyzer import AIAnalyzerFactory
|
||||
from .config_manager import ConfigManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnalysisModuleType:
|
||||
"""分析模块类型常量"""
|
||||
TRADING_VIEW_CHART = "trading_view_chart"
|
||||
FINANCIAL_DATA = "financial_data"
|
||||
BUSINESS_INFO = "business_info"
|
||||
FUNDAMENTAL_ANALYSIS = "fundamental_analysis"
|
||||
BULLISH_ANALYSIS = "bullish_analysis"
|
||||
BEARISH_ANALYSIS = "bearish_analysis"
|
||||
MARKET_ANALYSIS = "market_analysis"
|
||||
NEWS_ANALYSIS = "news_analysis"
|
||||
TRADING_ANALYSIS = "trading_analysis"
|
||||
INSIDER_ANALYSIS = "insider_analysis"
|
||||
FINAL_CONCLUSION = "final_conclusion"
|
||||
|
||||
|
||||
class ReportGenerator:
|
||||
"""报告生成器"""
|
||||
|
||||
def __init__(self, db_session: AsyncSession, config_manager: ConfigManager):
|
||||
self.db = db_session
|
||||
self.config_manager = config_manager
|
||||
self.progress_tracker = ProgressTracker(db_session)
|
||||
|
||||
# 定义分析模块配置
|
||||
self.analysis_modules = [
|
||||
{
|
||||
"type": AnalysisModuleType.TRADING_VIEW_CHART,
|
||||
"title": "TradingView图表",
|
||||
"order": 1,
|
||||
"step_name": "获取财务数据"
|
||||
},
|
||||
{
|
||||
"type": AnalysisModuleType.FINANCIAL_DATA,
|
||||
"title": "财务数据分析",
|
||||
"order": 2,
|
||||
"step_name": "获取财务数据"
|
||||
},
|
||||
{
|
||||
"type": AnalysisModuleType.BUSINESS_INFO,
|
||||
"title": "业务信息分析",
|
||||
"order": 3,
|
||||
"step_name": "生成业务信息"
|
||||
},
|
||||
{
|
||||
"type": AnalysisModuleType.FUNDAMENTAL_ANALYSIS,
|
||||
"title": "基本面分析(景林模型)",
|
||||
"order": 4,
|
||||
"step_name": "执行基本面分析"
|
||||
},
|
||||
{
|
||||
"type": AnalysisModuleType.BULLISH_ANALYSIS,
|
||||
"title": "看涨分析师观点",
|
||||
"order": 5,
|
||||
"step_name": "执行看涨分析"
|
||||
},
|
||||
{
|
||||
"type": AnalysisModuleType.BEARISH_ANALYSIS,
|
||||
"title": "看跌分析师观点",
|
||||
"order": 6,
|
||||
"step_name": "执行看跌分析"
|
||||
},
|
||||
{
|
||||
"type": AnalysisModuleType.MARKET_ANALYSIS,
|
||||
"title": "市场分析师观点",
|
||||
"order": 7,
|
||||
"step_name": "执行市场分析"
|
||||
},
|
||||
{
|
||||
"type": AnalysisModuleType.NEWS_ANALYSIS,
|
||||
"title": "新闻分析师观点",
|
||||
"order": 8,
|
||||
"step_name": "执行新闻分析"
|
||||
},
|
||||
{
|
||||
"type": AnalysisModuleType.TRADING_ANALYSIS,
|
||||
"title": "交易分析师观点",
|
||||
"order": 9,
|
||||
"step_name": "执行交易分析"
|
||||
},
|
||||
{
|
||||
"type": AnalysisModuleType.INSIDER_ANALYSIS,
|
||||
"title": "内部人与机构动向分析",
|
||||
"order": 10,
|
||||
"step_name": "执行内部人分析"
|
||||
},
|
||||
{
|
||||
"type": AnalysisModuleType.FINAL_CONCLUSION,
|
||||
"title": "最终结论",
|
||||
"order": 11,
|
||||
"step_name": "生成最终结论"
|
||||
}
|
||||
]
|
||||
|
||||
async def generate_report(self, symbol: str, market: str, force_regenerate: bool = False) -> ReportResponse:
|
||||
"""生成股票分析报告"""
|
||||
try:
|
||||
# 检查是否存在现有报告
|
||||
existing_report = await self._get_existing_report(symbol, market)
|
||||
|
||||
if existing_report and not force_regenerate:
|
||||
if existing_report.status == "completed":
|
||||
logger.info(f"返回现有报告: {symbol} ({market})")
|
||||
return await self._build_report_response(existing_report)
|
||||
elif existing_report.status == "generating":
|
||||
logger.info(f"报告正在生成中: {symbol} ({market})")
|
||||
return await self._build_report_response(existing_report)
|
||||
|
||||
# 创建新报告或重新生成
|
||||
if existing_report and force_regenerate:
|
||||
report = existing_report
|
||||
report.status = "generating"
|
||||
report.updated_at = datetime.utcnow()
|
||||
# 清理现有的分析模块
|
||||
await self._cleanup_existing_modules(report.id)
|
||||
else:
|
||||
report = await self._create_new_report(symbol, market)
|
||||
|
||||
# 初始化进度追踪
|
||||
await self.progress_tracker.initialize_progress(report.id)
|
||||
|
||||
# 异步生成报告内容
|
||||
asyncio.create_task(self._generate_report_content(report))
|
||||
|
||||
return await self._build_report_response(report)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"报告生成失败: {symbol} ({market}) - {str(e)}")
|
||||
if isinstance(e, (ReportGenerationError, DataSourceError, AIAnalysisError)):
|
||||
raise
|
||||
raise ReportGenerationError(f"报告生成失败: {str(e)}")
|
||||
|
||||
async def get_report(self, symbol: str, market: str) -> Optional[ReportResponse]:
|
||||
"""获取现有报告"""
|
||||
try:
|
||||
report = await self._get_existing_report(symbol, market)
|
||||
if report:
|
||||
return await self._build_report_response(report)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取报告失败: {symbol} ({market}) - {str(e)}")
|
||||
raise ReportGenerationError(f"获取报告失败: {str(e)}")
|
||||
|
||||
async def get_report_by_id(self, report_id: UUID) -> Optional[ReportResponse]:
|
||||
"""根据ID获取报告"""
|
||||
try:
|
||||
result = await self.db.execute(
|
||||
select(Report).where(Report.id == report_id)
|
||||
)
|
||||
report = result.scalar_one_or_none()
|
||||
|
||||
if report:
|
||||
return await self._build_report_response(report)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取报告失败: {report_id} - {str(e)}")
|
||||
raise ReportGenerationError(f"获取报告失败: {str(e)}")
|
||||
|
||||
async def _get_existing_report(self, symbol: str, market: str) -> Optional[Report]:
|
||||
"""获取现有报告"""
|
||||
result = await self.db.execute(
|
||||
select(Report).where(
|
||||
Report.symbol == symbol,
|
||||
Report.market == market
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def _create_new_report(self, symbol: str, market: str) -> Report:
|
||||
"""创建新报告"""
|
||||
report = Report(
|
||||
symbol=symbol,
|
||||
market=market,
|
||||
status="generating"
|
||||
)
|
||||
self.db.add(report)
|
||||
await self.db.flush()
|
||||
return report
|
||||
|
||||
async def _cleanup_existing_modules(self, report_id: UUID):
|
||||
"""清理现有的分析模块"""
|
||||
result = await self.db.execute(
|
||||
select(AnalysisModule).where(AnalysisModule.report_id == report_id)
|
||||
)
|
||||
modules = result.scalars().all()
|
||||
|
||||
for module in modules:
|
||||
await self.db.delete(module)
|
||||
|
||||
await self.db.flush()
|
||||
|
||||
async def _generate_report_content(self, report: Report):
|
||||
"""生成报告内容(异步执行)"""
|
||||
try:
|
||||
logger.info(f"开始生成报告内容: {report.symbol} ({report.market})")
|
||||
|
||||
# 开始初始化步骤
|
||||
await self.progress_tracker.start_step(report.id, "初始化报告")
|
||||
|
||||
# 创建分析模块记录
|
||||
await self._create_analysis_modules(report)
|
||||
|
||||
await self.progress_tracker.complete_step(report.id, "初始化报告", True)
|
||||
|
||||
# 获取配置
|
||||
data_source_config = await self.config_manager.get_data_source_config(report.market)
|
||||
gemini_config = await self.config_manager.get_gemini_config()
|
||||
|
||||
# 创建数据获取器和AI分析器
|
||||
data_fetcher = DataFetcherFactory.create_fetcher(
|
||||
data_source_config["type"],
|
||||
data_source_config
|
||||
)
|
||||
ai_analyzer = AIAnalyzerFactory.create_gemini_analyzer(
|
||||
gemini_config["api_key"],
|
||||
gemini_config
|
||||
)
|
||||
|
||||
# 存储分析结果的上下文
|
||||
analysis_context = {}
|
||||
|
||||
# 按顺序执行各个分析模块
|
||||
for module_config in self.analysis_modules:
|
||||
try:
|
||||
await self._execute_analysis_module(
|
||||
report, module_config, data_fetcher, ai_analyzer, analysis_context
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"分析模块执行失败: {module_config['type']} - {str(e)}")
|
||||
# 标记模块为失败,但继续执行其他模块
|
||||
await self._mark_module_failed(report.id, module_config["type"], str(e))
|
||||
|
||||
# 完成报告生成
|
||||
await self.progress_tracker.start_step(report.id, "保存报告")
|
||||
|
||||
report.status = "completed"
|
||||
report.updated_at = datetime.utcnow()
|
||||
await self.db.commit()
|
||||
|
||||
await self.progress_tracker.complete_step(report.id, "保存报告", True)
|
||||
|
||||
logger.info(f"报告生成完成: {report.symbol} ({report.market})")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"报告生成过程失败: {report.symbol} ({report.market}) - {str(e)}")
|
||||
|
||||
# 标记报告为失败状态
|
||||
report.status = "failed"
|
||||
report.updated_at = datetime.utcnow()
|
||||
await self.db.commit()
|
||||
|
||||
# 标记当前步骤为失败
|
||||
try:
|
||||
progress = await self.progress_tracker.get_progress(report.id)
|
||||
current_step = progress.current_step_name
|
||||
await self.progress_tracker.complete_step(report.id, current_step, False, str(e))
|
||||
except Exception:
|
||||
pass # 忽略进度更新失败
|
||||
|
||||
async def _create_analysis_modules(self, report: Report):
|
||||
"""创建分析模块记录"""
|
||||
for module_config in self.analysis_modules:
|
||||
module = AnalysisModule(
|
||||
report_id=report.id,
|
||||
module_type=module_config["type"],
|
||||
module_order=module_config["order"],
|
||||
title=module_config["title"],
|
||||
status="pending"
|
||||
)
|
||||
self.db.add(module)
|
||||
|
||||
await self.db.flush()
|
||||
|
||||
async def _execute_analysis_module(
|
||||
self,
|
||||
report: Report,
|
||||
module_config: Dict[str, Any],
|
||||
data_fetcher,
|
||||
ai_analyzer,
|
||||
analysis_context: Dict[str, Any]
|
||||
):
|
||||
"""执行单个分析模块"""
|
||||
module_type = module_config["type"]
|
||||
step_name = module_config["step_name"]
|
||||
|
||||
logger.info(f"执行分析模块: {module_type}")
|
||||
|
||||
# 开始步骤
|
||||
await self.progress_tracker.start_step(report.id, step_name)
|
||||
|
||||
# 标记模块开始
|
||||
await self._mark_module_started(report.id, module_type)
|
||||
|
||||
try:
|
||||
# 根据模块类型执行相应的分析
|
||||
if module_type == AnalysisModuleType.FINANCIAL_DATA:
|
||||
content = await self._execute_financial_data_module(
|
||||
report.symbol, report.market, data_fetcher
|
||||
)
|
||||
analysis_context["financial_data"] = content
|
||||
|
||||
elif module_type == AnalysisModuleType.TRADING_VIEW_CHART:
|
||||
content = await self._execute_trading_view_module(
|
||||
report.symbol, report.market
|
||||
)
|
||||
|
||||
elif module_type == AnalysisModuleType.BUSINESS_INFO:
|
||||
content = await self._execute_business_info_module(
|
||||
report.symbol, report.market, ai_analyzer, analysis_context
|
||||
)
|
||||
analysis_context["business_info"] = content
|
||||
|
||||
elif module_type == AnalysisModuleType.FUNDAMENTAL_ANALYSIS:
|
||||
content = await self._execute_fundamental_analysis_module(
|
||||
report.symbol, report.market, ai_analyzer, analysis_context
|
||||
)
|
||||
analysis_context["fundamental_analysis"] = content
|
||||
|
||||
elif module_type == AnalysisModuleType.BULLISH_ANALYSIS:
|
||||
content = await self._execute_bullish_analysis_module(
|
||||
report.symbol, report.market, ai_analyzer, analysis_context
|
||||
)
|
||||
analysis_context["bullish_analysis"] = content
|
||||
|
||||
elif module_type == AnalysisModuleType.BEARISH_ANALYSIS:
|
||||
content = await self._execute_bearish_analysis_module(
|
||||
report.symbol, report.market, ai_analyzer, analysis_context
|
||||
)
|
||||
analysis_context["bearish_analysis"] = content
|
||||
|
||||
elif module_type == AnalysisModuleType.MARKET_ANALYSIS:
|
||||
content = await self._execute_market_analysis_module(
|
||||
report.symbol, report.market, ai_analyzer, analysis_context
|
||||
)
|
||||
analysis_context["market_analysis"] = content
|
||||
|
||||
elif module_type == AnalysisModuleType.NEWS_ANALYSIS:
|
||||
content = await self._execute_news_analysis_module(
|
||||
report.symbol, report.market, ai_analyzer, analysis_context
|
||||
)
|
||||
analysis_context["news_analysis"] = content
|
||||
|
||||
elif module_type == AnalysisModuleType.TRADING_ANALYSIS:
|
||||
content = await self._execute_trading_analysis_module(
|
||||
report.symbol, report.market, ai_analyzer, analysis_context
|
||||
)
|
||||
analysis_context["trading_analysis"] = content
|
||||
|
||||
elif module_type == AnalysisModuleType.INSIDER_ANALYSIS:
|
||||
content = await self._execute_insider_analysis_module(
|
||||
report.symbol, report.market, ai_analyzer, analysis_context
|
||||
)
|
||||
analysis_context["insider_analysis"] = content
|
||||
|
||||
elif module_type == AnalysisModuleType.FINAL_CONCLUSION:
|
||||
content = await self._execute_final_conclusion_module(
|
||||
report.symbol, report.market, ai_analyzer, analysis_context
|
||||
)
|
||||
analysis_context["final_conclusion"] = content
|
||||
|
||||
else:
|
||||
raise ReportGenerationError(f"未知的分析模块类型: {module_type}")
|
||||
|
||||
# 保存模块内容
|
||||
await self._save_module_content(report.id, module_type, content)
|
||||
|
||||
# 标记模块完成
|
||||
await self._mark_module_completed(report.id, module_type)
|
||||
|
||||
# 完成步骤
|
||||
await self.progress_tracker.complete_step(report.id, step_name, True)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"分析模块执行失败: {module_type} - {str(e)}")
|
||||
|
||||
# 标记模块失败
|
||||
await self._mark_module_failed(report.id, module_type, str(e))
|
||||
|
||||
# 完成步骤(失败)
|
||||
await self.progress_tracker.complete_step(report.id, step_name, False, str(e))
|
||||
|
||||
raise e
|
||||
|
||||
async def _execute_financial_data_module(self, symbol: str, market: str, data_fetcher) -> Dict[str, Any]:
|
||||
"""执行财务数据分析模块"""
|
||||
try:
|
||||
# 获取财务数据
|
||||
financial_data = await data_fetcher.fetch_financial_data(symbol, market)
|
||||
|
||||
# 获取市场数据
|
||||
market_data = await data_fetcher.fetch_market_data(symbol, market)
|
||||
|
||||
return {
|
||||
"financial_data": financial_data.dict(),
|
||||
"market_data": market_data.dict(),
|
||||
"summary": {
|
||||
"data_source": financial_data.data_source,
|
||||
"last_updated": financial_data.last_updated.isoformat(),
|
||||
"data_quality": "good" # 可以添加数据质量评估逻辑
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
raise DataSourceError(f"财务数据获取失败: {str(e)}")
|
||||
|
||||
async def _execute_trading_view_module(self, symbol: str, market: str) -> Dict[str, Any]:
|
||||
"""执行TradingView图表模块"""
|
||||
# 生成TradingView图表配置
|
||||
return {
|
||||
"chart_config": {
|
||||
"symbol": symbol,
|
||||
"market": market,
|
||||
"interval": "1D",
|
||||
"theme": "light",
|
||||
"style": "1", # 蜡烛图
|
||||
"toolbar_bg": "#f1f3f6",
|
||||
"enable_publishing": False,
|
||||
"withdateranges": True,
|
||||
"hide_side_toolbar": False,
|
||||
"allow_symbol_change": False,
|
||||
"studies": [
|
||||
"MASimple@tv-basicstudies", # 移动平均线
|
||||
"Volume@tv-basicstudies" # 成交量
|
||||
]
|
||||
},
|
||||
"display_settings": {
|
||||
"width": "100%",
|
||||
"height": 500,
|
||||
"autosize": True
|
||||
}
|
||||
}
|
||||
|
||||
async def _execute_business_info_module(self, symbol: str, market: str, ai_analyzer, analysis_context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行业务信息分析模块"""
|
||||
try:
|
||||
financial_data = analysis_context.get("financial_data", {})
|
||||
|
||||
result = await ai_analyzer.analyze_business_info(symbol, market, financial_data)
|
||||
|
||||
return result.content
|
||||
except Exception as e:
|
||||
raise AIAnalysisError(f"业务信息分析失败: {str(e)}")
|
||||
|
||||
async def _execute_fundamental_analysis_module(self, symbol: str, market: str, ai_analyzer, analysis_context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行基本面分析模块"""
|
||||
try:
|
||||
financial_data = analysis_context.get("financial_data", {})
|
||||
business_info = analysis_context.get("business_info", {})
|
||||
|
||||
result = await ai_analyzer.analyze_fundamental(symbol, market, financial_data, business_info)
|
||||
|
||||
return result.content
|
||||
except Exception as e:
|
||||
raise AIAnalysisError(f"基本面分析失败: {str(e)}")
|
||||
|
||||
async def _execute_bullish_analysis_module(self, symbol: str, market: str, ai_analyzer, analysis_context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行看涨分析模块"""
|
||||
try:
|
||||
result = await ai_analyzer.analyze_bullish_case(symbol, market, analysis_context)
|
||||
return result.content
|
||||
except Exception as e:
|
||||
raise AIAnalysisError(f"看涨分析失败: {str(e)}")
|
||||
|
||||
async def _execute_bearish_analysis_module(self, symbol: str, market: str, ai_analyzer, analysis_context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行看跌分析模块"""
|
||||
try:
|
||||
result = await ai_analyzer.analyze_bearish_case(symbol, market, analysis_context)
|
||||
return result.content
|
||||
except Exception as e:
|
||||
raise AIAnalysisError(f"看跌分析失败: {str(e)}")
|
||||
|
||||
async def _execute_market_analysis_module(self, symbol: str, market: str, ai_analyzer, analysis_context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行市场分析模块"""
|
||||
try:
|
||||
result = await ai_analyzer.analyze_market_sentiment(symbol, market, analysis_context)
|
||||
return result.content
|
||||
except Exception as e:
|
||||
raise AIAnalysisError(f"市场分析失败: {str(e)}")
|
||||
|
||||
async def _execute_news_analysis_module(self, symbol: str, market: str, ai_analyzer, analysis_context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行新闻分析模块"""
|
||||
try:
|
||||
result = await ai_analyzer.analyze_news_catalysts(symbol, market, analysis_context)
|
||||
return result.content
|
||||
except Exception as e:
|
||||
raise AIAnalysisError(f"新闻分析失败: {str(e)}")
|
||||
|
||||
async def _execute_trading_analysis_module(self, symbol: str, market: str, ai_analyzer, analysis_context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行交易分析模块"""
|
||||
try:
|
||||
result = await ai_analyzer.analyze_trading_dynamics(symbol, market, analysis_context)
|
||||
return result.content
|
||||
except Exception as e:
|
||||
raise AIAnalysisError(f"交易分析失败: {str(e)}")
|
||||
|
||||
async def _execute_insider_analysis_module(self, symbol: str, market: str, ai_analyzer, analysis_context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行内部人分析模块"""
|
||||
try:
|
||||
result = await ai_analyzer.analyze_insider_institutional(symbol, market, analysis_context)
|
||||
return result.content
|
||||
except Exception as e:
|
||||
raise AIAnalysisError(f"内部人分析失败: {str(e)}")
|
||||
|
||||
async def _execute_final_conclusion_module(self, symbol: str, market: str, ai_analyzer, analysis_context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行最终结论模块"""
|
||||
try:
|
||||
# 收集所有分析结果
|
||||
all_analyses = []
|
||||
for key, value in analysis_context.items():
|
||||
if key != "financial_data": # 排除原始财务数据
|
||||
all_analyses.append({
|
||||
"analysis_type": key,
|
||||
"content": value
|
||||
})
|
||||
|
||||
result = await ai_analyzer.generate_final_conclusion(symbol, market, all_analyses)
|
||||
return result.content
|
||||
except Exception as e:
|
||||
raise AIAnalysisError(f"最终结论生成失败: {str(e)}")
|
||||
|
||||
async def _mark_module_started(self, report_id: UUID, module_type: str):
|
||||
"""标记模块开始"""
|
||||
result = await self.db.execute(
|
||||
select(AnalysisModule).where(
|
||||
AnalysisModule.report_id == report_id,
|
||||
AnalysisModule.module_type == module_type
|
||||
)
|
||||
)
|
||||
module = result.scalar_one_or_none()
|
||||
|
||||
if module:
|
||||
module.status = "running"
|
||||
module.started_at = datetime.utcnow()
|
||||
await self.db.flush()
|
||||
|
||||
async def _mark_module_completed(self, report_id: UUID, module_type: str):
|
||||
"""标记模块完成"""
|
||||
result = await self.db.execute(
|
||||
select(AnalysisModule).where(
|
||||
AnalysisModule.report_id == report_id,
|
||||
AnalysisModule.module_type == module_type
|
||||
)
|
||||
)
|
||||
module = result.scalar_one_or_none()
|
||||
|
||||
if module:
|
||||
module.status = "completed"
|
||||
module.completed_at = datetime.utcnow()
|
||||
await self.db.flush()
|
||||
|
||||
async def _mark_module_failed(self, report_id: UUID, module_type: str, error_message: str):
|
||||
"""标记模块失败"""
|
||||
result = await self.db.execute(
|
||||
select(AnalysisModule).where(
|
||||
AnalysisModule.report_id == report_id,
|
||||
AnalysisModule.module_type == module_type
|
||||
)
|
||||
)
|
||||
module = result.scalar_one_or_none()
|
||||
|
||||
if module:
|
||||
module.status = "failed"
|
||||
module.completed_at = datetime.utcnow()
|
||||
module.error_message = error_message
|
||||
await self.db.flush()
|
||||
|
||||
async def _save_module_content(self, report_id: UUID, module_type: str, content: Dict[str, Any]):
|
||||
"""保存模块内容"""
|
||||
result = await self.db.execute(
|
||||
select(AnalysisModule).where(
|
||||
AnalysisModule.report_id == report_id,
|
||||
AnalysisModule.module_type == module_type
|
||||
)
|
||||
)
|
||||
module = result.scalar_one_or_none()
|
||||
|
||||
if module:
|
||||
module.content = content
|
||||
await self.db.flush()
|
||||
|
||||
async def _build_report_response(self, report: Report) -> ReportResponse:
|
||||
"""构建报告响应"""
|
||||
# 获取分析模块
|
||||
result = await self.db.execute(
|
||||
select(AnalysisModule)
|
||||
.where(AnalysisModule.report_id == report.id)
|
||||
.order_by(AnalysisModule.module_order)
|
||||
)
|
||||
modules = result.scalars().all()
|
||||
|
||||
# 转换为响应模式
|
||||
module_schemas = [
|
||||
AnalysisModuleSchema(
|
||||
id=module.id,
|
||||
module_type=module.module_type,
|
||||
module_order=module.module_order,
|
||||
title=module.title,
|
||||
content=module.content,
|
||||
status=module.status,
|
||||
started_at=module.started_at,
|
||||
completed_at=module.completed_at,
|
||||
error_message=module.error_message
|
||||
)
|
||||
for module in modules
|
||||
]
|
||||
|
||||
return ReportResponse(
|
||||
id=report.id,
|
||||
symbol=report.symbol,
|
||||
market=report.market,
|
||||
status=report.status,
|
||||
created_at=report.created_at,
|
||||
updated_at=report.updated_at,
|
||||
analysis_modules=module_schemas
|
||||
)
|
||||
45
backend/check_db.py
Executable file
45
backend/check_db.py
Executable file
@ -0,0 +1,45 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
数据库连接检查脚本
|
||||
用于验证数据库配置和连接状态
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
|
||||
from app.core.database import check_db_connection, close_db
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
async def main():
|
||||
"""主函数:检查数据库连接"""
|
||||
try:
|
||||
print(f"正在检查数据库连接...")
|
||||
print(f"数据库URL: {settings.DATABASE_URL}")
|
||||
|
||||
# 检查数据库连接
|
||||
is_connected = await check_db_connection()
|
||||
|
||||
if is_connected:
|
||||
print("✅ 数据库连接正常!")
|
||||
else:
|
||||
print("❌ 数据库连接失败!")
|
||||
print("请检查:")
|
||||
print("1. PostgreSQL服务是否运行")
|
||||
print("2. 数据库配置是否正确")
|
||||
print("3. 网络连接是否正常")
|
||||
sys.exit(1)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 数据库连接检查失败: {e}")
|
||||
sys.exit(1)
|
||||
finally:
|
||||
await close_db()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
41
backend/init_db.py
Executable file
41
backend/init_db.py
Executable file
@ -0,0 +1,41 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
数据库初始化脚本
|
||||
用于创建数据库表和运行初始迁移
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
|
||||
from app.core.database import init_db, close_db
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
async def main():
|
||||
"""主函数:初始化数据库"""
|
||||
try:
|
||||
print(f"正在连接数据库: {settings.DATABASE_URL}")
|
||||
print("正在创建数据库表...")
|
||||
|
||||
# 初始化数据库表
|
||||
await init_db()
|
||||
|
||||
print("✅ 数据库表创建成功!")
|
||||
print("💡 提示: 如果需要使用Alembic管理迁移,请运行:")
|
||||
print(" alembic stamp head")
|
||||
print(" alembic revision --autogenerate -m '描述'")
|
||||
print(" alembic upgrade head")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 数据库初始化失败: {e}")
|
||||
sys.exit(1)
|
||||
finally:
|
||||
await close_db()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
101
backend/main.py
Normal file
101
backend/main.py
Normal file
@ -0,0 +1,101 @@
|
||||
"""
|
||||
FastAPI应用入口点
|
||||
基本面选股系统后端服务
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import engine, Base
|
||||
from app.routers import reports, config, progress
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
# 启动时创建数据库表
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield
|
||||
# 关闭时清理资源
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
# 创建FastAPI应用实例
|
||||
app = FastAPI(
|
||||
title="基本面选股系统",
|
||||
description="""
|
||||
提供股票基本面分析和报告生成的API服务
|
||||
|
||||
## 功能特性
|
||||
|
||||
* **报告管理**: 创建、查询、更新和删除股票分析报告
|
||||
* **进度追踪**: 实时追踪报告生成进度
|
||||
* **配置管理**: 管理系统配置,包括数据库、API密钥等
|
||||
* **多市场支持**: 支持中国、香港、美国、日本股票市场
|
||||
|
||||
## 支持的市场
|
||||
|
||||
* `china` - 中国A股市场
|
||||
* `hongkong` - 香港股票市场
|
||||
* `usa` - 美国股票市场
|
||||
* `japan` - 日本股票市场
|
||||
""",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan,
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
openapi_url="/openapi.json"
|
||||
)
|
||||
|
||||
# 配置CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 注册路由
|
||||
app.include_router(
|
||||
reports.router,
|
||||
prefix="/api/reports",
|
||||
tags=["reports"],
|
||||
responses={
|
||||
404: {"description": "报告不存在"},
|
||||
500: {"description": "服务器内部错误"}
|
||||
}
|
||||
)
|
||||
app.include_router(
|
||||
config.router,
|
||||
prefix="/api/config",
|
||||
tags=["config"],
|
||||
responses={
|
||||
400: {"description": "配置参数错误"},
|
||||
500: {"description": "服务器内部错误"}
|
||||
}
|
||||
)
|
||||
app.include_router(
|
||||
progress.router,
|
||||
prefix="/api/progress",
|
||||
tags=["progress"],
|
||||
responses={
|
||||
404: {"description": "进度记录不存在"},
|
||||
500: {"description": "服务器内部错误"}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""根路径健康检查"""
|
||||
return {"message": "基本面选股系统API服务正在运行", "version": "1.0.0"}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""健康检查端点"""
|
||||
return {"status": "healthy", "service": "fundamental-stock-analysis"}
|
||||
125
backend/manage_db.py
Executable file
125
backend/manage_db.py
Executable file
@ -0,0 +1,125 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
数据库管理脚本
|
||||
提供数据库初始化、检查、迁移等功能
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
import argparse
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
|
||||
from app.core.database import init_db, close_db, check_db_connection
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
async def check_connection():
|
||||
"""检查数据库连接"""
|
||||
print(f"正在检查数据库连接...")
|
||||
print(f"数据库URL: {settings.DATABASE_URL}")
|
||||
|
||||
is_connected = await check_db_connection()
|
||||
|
||||
if is_connected:
|
||||
print("✅ 数据库连接正常!")
|
||||
return True
|
||||
else:
|
||||
print("❌ 数据库连接失败!")
|
||||
print("请检查:")
|
||||
print("1. PostgreSQL服务是否运行")
|
||||
print("2. 数据库配置是否正确")
|
||||
print("3. 网络连接是否正常")
|
||||
return False
|
||||
|
||||
|
||||
async def initialize_database():
|
||||
"""初始化数据库"""
|
||||
print(f"正在连接数据库: {settings.DATABASE_URL}")
|
||||
print("正在创建数据库表...")
|
||||
|
||||
try:
|
||||
# 初始化数据库表
|
||||
await init_db()
|
||||
|
||||
print("✅ 数据库表创建成功!")
|
||||
print("💡 提示: 如果需要使用Alembic管理迁移,请运行:")
|
||||
print(" alembic stamp head")
|
||||
print(" alembic revision --autogenerate -m '描述'")
|
||||
print(" alembic upgrade head")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 数据库初始化失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def show_status():
|
||||
"""显示数据库状态"""
|
||||
print("=== 数据库状态 ===")
|
||||
print(f"数据库URL: {settings.DATABASE_URL}")
|
||||
print(f"数据库Echo: {settings.DATABASE_ECHO}")
|
||||
|
||||
# 检查连接
|
||||
is_connected = await check_db_connection()
|
||||
print(f"连接状态: {'✅ 正常' if is_connected else '❌ 失败'}")
|
||||
|
||||
if is_connected:
|
||||
try:
|
||||
from app.core.database import AsyncSessionLocal
|
||||
async with AsyncSessionLocal() as session:
|
||||
# 检查表是否存在
|
||||
result = await session.execute("""
|
||||
SELECT table_name
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = 'public'
|
||||
AND table_name IN ('reports', 'analysis_modules', 'progress_tracking', 'system_config')
|
||||
ORDER BY table_name
|
||||
""")
|
||||
tables = [row[0] for row in result.fetchall()]
|
||||
|
||||
print(f"已创建的表: {', '.join(tables) if tables else '无'}")
|
||||
|
||||
if 'reports' in tables:
|
||||
result = await session.execute("SELECT COUNT(*) FROM reports")
|
||||
count = result.scalar()
|
||||
print(f"报告数量: {count}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取详细状态失败: {e}")
|
||||
|
||||
|
||||
async def main():
|
||||
"""主函数"""
|
||||
parser = argparse.ArgumentParser(description='数据库管理工具')
|
||||
parser.add_argument('command', choices=['check', 'init', 'status'],
|
||||
help='要执行的命令')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
if args.command == 'check':
|
||||
success = await check_connection()
|
||||
elif args.command == 'init':
|
||||
success = await initialize_database()
|
||||
elif args.command == 'status':
|
||||
await show_status()
|
||||
success = True
|
||||
else:
|
||||
print(f"未知命令: {args.command}")
|
||||
success = False
|
||||
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 执行失败: {e}")
|
||||
sys.exit(1)
|
||||
finally:
|
||||
await close_db()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
40
backend/requirements.txt
Normal file
40
backend/requirements.txt
Normal file
@ -0,0 +1,40 @@
|
||||
# FastAPI和相关依赖
|
||||
fastapi==0.104.1
|
||||
uvicorn[standard]==0.24.0
|
||||
python-multipart==0.0.6
|
||||
|
||||
# 数据库相关
|
||||
sqlalchemy[asyncio]==2.0.23
|
||||
asyncpg==0.29.0
|
||||
alembic==1.12.1
|
||||
|
||||
# 数据验证和序列化
|
||||
pydantic==2.5.0
|
||||
pydantic-settings==2.1.0
|
||||
|
||||
# HTTP客户端
|
||||
httpx==0.25.2
|
||||
aiohttp==3.9.1
|
||||
|
||||
# AI服务
|
||||
google-generativeai==0.3.2
|
||||
|
||||
# 环境变量管理
|
||||
python-dotenv==1.0.0
|
||||
|
||||
# 日志和监控
|
||||
structlog==23.2.0
|
||||
|
||||
# 开发和测试依赖
|
||||
pytest==7.4.3
|
||||
pytest-asyncio==0.21.1
|
||||
pytest-cov==4.1.0
|
||||
black==23.11.0
|
||||
isort==5.12.0
|
||||
flake8==6.1.0
|
||||
|
||||
# 类型检查
|
||||
mypy==1.7.1
|
||||
|
||||
# 安全相关
|
||||
cryptography==41.0.7
|
||||
152
backend/test_external_apis.py
Normal file
152
backend/test_external_apis.py
Normal file
@ -0,0 +1,152 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
外部API集成测试脚本
|
||||
用于验证Tushare和Gemini API集成是否正常工作
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from app.services.external_api_service import get_external_api_service
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
async def test_data_sources():
|
||||
"""测试数据源"""
|
||||
print("=== 测试数据源 ===")
|
||||
|
||||
service = get_external_api_service()
|
||||
|
||||
# 检查支持的数据源
|
||||
supported_sources = service.get_supported_data_sources()
|
||||
print(f"支持的数据源: {supported_sources}")
|
||||
|
||||
available_sources = service.get_available_data_sources()
|
||||
print(f"可用的数据源: {available_sources}")
|
||||
|
||||
# 测试数据源状态
|
||||
try:
|
||||
status_response = await service.check_all_services_status()
|
||||
print(f"整体状态: {status_response.overall_status}")
|
||||
|
||||
for source_status in status_response.sources:
|
||||
status_text = "✅ 可用" if source_status.is_available else "❌ 不可用"
|
||||
print(f" {source_status.name}: {status_text}")
|
||||
if source_status.response_time_ms:
|
||||
print(f" 响应时间: {source_status.response_time_ms}ms")
|
||||
if source_status.error_message:
|
||||
print(f" 错误信息: {source_status.error_message}")
|
||||
except Exception as e:
|
||||
print(f"检查状态失败: {e}")
|
||||
|
||||
|
||||
async def test_symbol_validation():
|
||||
"""测试证券代码验证"""
|
||||
print("\n=== 测试证券代码验证 ===")
|
||||
|
||||
service = get_external_api_service()
|
||||
|
||||
test_cases = [
|
||||
("000001", "中国"), # 平安银行
|
||||
("600036", "中国"), # 招商银行
|
||||
("AAPL", "美国"), # 苹果
|
||||
("INVALID", "中国") # 无效代码
|
||||
]
|
||||
|
||||
for symbol, market in test_cases:
|
||||
try:
|
||||
result = await service.validate_stock_symbol(symbol, market)
|
||||
status_text = "✅ 有效" if result.is_valid else "❌ 无效"
|
||||
print(f" {symbol} ({market}): {status_text}")
|
||||
if result.company_name:
|
||||
print(f" 公司名称: {result.company_name}")
|
||||
if result.message:
|
||||
print(f" 消息: {result.message}")
|
||||
except Exception as e:
|
||||
print(f" {symbol} ({market}): ❌ 验证失败 - {e}")
|
||||
|
||||
|
||||
async def test_financial_data():
|
||||
"""测试财务数据获取"""
|
||||
print("\n=== 测试财务数据获取 ===")
|
||||
|
||||
service = get_external_api_service()
|
||||
|
||||
test_cases = [
|
||||
("000001", "中国"), # 平安银行
|
||||
("AAPL", "美国") # 苹果
|
||||
]
|
||||
|
||||
for symbol, market in test_cases:
|
||||
try:
|
||||
result = await service.get_financial_data(symbol, market)
|
||||
print(f" {symbol} ({market}): ✅ 获取成功")
|
||||
print(f" 数据源: {result.data_source}")
|
||||
print(f" 更新时间: {result.last_updated}")
|
||||
|
||||
# 显示部分财务数据
|
||||
if result.balance_sheet:
|
||||
print(f" 资产负债表: {len(result.balance_sheet)} 项数据")
|
||||
if result.income_statement:
|
||||
print(f" 利润表: {len(result.income_statement)} 项数据")
|
||||
|
||||
except Exception as e:
|
||||
print(f" {symbol} ({market}): ❌ 获取失败 - {e}")
|
||||
|
||||
|
||||
async def test_ai_analysis():
|
||||
"""测试AI分析"""
|
||||
print("\n=== 测试AI分析 ===")
|
||||
|
||||
service = get_external_api_service()
|
||||
|
||||
if not service.is_ai_service_available():
|
||||
print("❌ AI服务不可用,请检查Gemini API配置")
|
||||
return
|
||||
|
||||
# 模拟财务数据
|
||||
mock_financial_data = {
|
||||
"balance_sheet": {"total_assets": 1000000, "total_liab": 600000},
|
||||
"income_statement": {"revenue": 500000, "n_income": 50000},
|
||||
"cash_flow": {"n_cashflow_act": 80000},
|
||||
"key_metrics": {"pe": 15.5, "pb": 1.2, "roe": 12.5}
|
||||
}
|
||||
|
||||
try:
|
||||
result = await service.analyze_business_info("000001", "中国", mock_financial_data)
|
||||
print(" 业务信息分析: ✅ 成功")
|
||||
print(f" 使用模型: {result.model_used}")
|
||||
print(f" 生成时间: {result.generated_at}")
|
||||
|
||||
# 显示部分分析内容
|
||||
if result.content.get("company_overview"):
|
||||
overview = result.content["company_overview"][:100] + "..." if len(result.content["company_overview"]) > 100 else result.content["company_overview"]
|
||||
print(f" 公司概览: {overview}")
|
||||
|
||||
except Exception as e:
|
||||
print(f" 业务信息分析: ❌ 失败 - {e}")
|
||||
|
||||
|
||||
async def main():
|
||||
"""主测试函数"""
|
||||
print("开始测试外部API集成...")
|
||||
print(f"Tushare Token: {'已配置' if settings.TUSHARE_TOKEN else '未配置'}")
|
||||
print(f"Gemini API Key: {'已配置' if settings.GEMINI_API_KEY else '未配置'}")
|
||||
print()
|
||||
|
||||
await test_data_sources()
|
||||
await test_symbol_validation()
|
||||
await test_financial_data()
|
||||
await test_ai_analysis()
|
||||
|
||||
print("\n测试完成!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
41
frontend/.gitignore
vendored
Normal file
41
frontend/.gitignore
vendored
Normal file
@ -0,0 +1,41 @@
|
||||
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
|
||||
|
||||
# dependencies
|
||||
/node_modules
|
||||
/.pnp
|
||||
.pnp.*
|
||||
.yarn/*
|
||||
!.yarn/patches
|
||||
!.yarn/plugins
|
||||
!.yarn/releases
|
||||
!.yarn/versions
|
||||
|
||||
# testing
|
||||
/coverage
|
||||
|
||||
# next.js
|
||||
/.next/
|
||||
/out/
|
||||
|
||||
# production
|
||||
/build
|
||||
|
||||
# misc
|
||||
.DS_Store
|
||||
*.pem
|
||||
|
||||
# debug
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
.pnpm-debug.log*
|
||||
|
||||
# env files (can opt-in for committing if needed)
|
||||
.env*
|
||||
|
||||
# vercel
|
||||
.vercel
|
||||
|
||||
# typescript
|
||||
*.tsbuildinfo
|
||||
next-env.d.ts
|
||||
86
frontend/README.md
Normal file
86
frontend/README.md
Normal file
@ -0,0 +1,86 @@
|
||||
# 基本面选股系统 - 前端
|
||||
|
||||
这是基本面选股系统的前端应用,使用 Next.js 14 和 TypeScript 构建。
|
||||
|
||||
## 技术栈
|
||||
|
||||
- **框架**: Next.js 14 (App Router)
|
||||
- **语言**: TypeScript
|
||||
- **样式**: Tailwind CSS
|
||||
- **UI组件**: shadcn/ui
|
||||
- **字体**: Noto Sans SC (中文支持)
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
src/
|
||||
├── app/ # Next.js App Router 页面
|
||||
│ ├── layout.tsx # 根布局
|
||||
│ ├── page.tsx # 首页
|
||||
│ └── globals.css # 全局样式
|
||||
├── components/ # React 组件
|
||||
│ └── ui/ # shadcn/ui 基础组件
|
||||
├── lib/ # 工具库
|
||||
│ ├── api.ts # API 客户端
|
||||
│ ├── types.ts # TypeScript 类型定义
|
||||
│ └── utils.ts # 工具函数
|
||||
└── hooks/ # 自定义 React Hooks
|
||||
├── useReport.ts # 报告数据钩子
|
||||
└── useProgress.ts # 进度追踪钩子
|
||||
```
|
||||
|
||||
## 开发命令
|
||||
|
||||
```bash
|
||||
# 安装依赖
|
||||
npm install
|
||||
|
||||
# 启动开发服务器
|
||||
npm run dev
|
||||
|
||||
# 构建生产版本
|
||||
npm run build
|
||||
|
||||
# 启动生产服务器
|
||||
npm start
|
||||
|
||||
# 代码检查
|
||||
npm run lint
|
||||
```
|
||||
|
||||
## 功能特性
|
||||
|
||||
- ✅ Next.js 14 项目初始化
|
||||
- ✅ TypeScript 配置
|
||||
- ✅ Tailwind CSS 样式系统
|
||||
- ✅ shadcn/ui 组件库集成
|
||||
- ✅ 中文字体支持 (Noto Sans SC)
|
||||
- ✅ 基础项目结构
|
||||
- ✅ API 客户端封装
|
||||
- ✅ 自定义 Hooks
|
||||
- ✅ 响应式布局
|
||||
|
||||
## 环境变量
|
||||
|
||||
创建 `.env.local` 文件并配置以下变量:
|
||||
|
||||
```env
|
||||
NEXT_PUBLIC_API_URL=http://localhost:8000
|
||||
```
|
||||
|
||||
## 开发说明
|
||||
|
||||
1. 项目使用 App Router 架构
|
||||
2. 所有页面组件位于 `src/app/` 目录
|
||||
3. 可复用组件位于 `src/components/` 目录
|
||||
4. API 调用通过 `src/lib/api.ts` 统一管理
|
||||
5. 类型定义集中在 `src/lib/types.ts`
|
||||
6. 自定义 Hooks 用于状态管理和数据获取
|
||||
|
||||
## 下一步开发
|
||||
|
||||
- 实现股票搜索表单组件
|
||||
- 创建报告页面和路由
|
||||
- 集成 TradingView 图表
|
||||
- 实现进度显示组件
|
||||
- 添加配置管理页面
|
||||
17
frontend/components.json
Normal file
17
frontend/components.json
Normal file
@ -0,0 +1,17 @@
|
||||
{
|
||||
"$schema": "https://ui.shadcn.com/schema.json",
|
||||
"style": "default",
|
||||
"rsc": true,
|
||||
"tsx": true,
|
||||
"tailwind": {
|
||||
"config": "tailwind.config.ts",
|
||||
"css": "src/app/globals.css",
|
||||
"baseColor": "slate",
|
||||
"cssVariables": true,
|
||||
"prefix": ""
|
||||
},
|
||||
"aliases": {
|
||||
"components": "@/components",
|
||||
"utils": "@/lib/utils"
|
||||
}
|
||||
}
|
||||
25
frontend/eslint.config.mjs
Normal file
25
frontend/eslint.config.mjs
Normal file
@ -0,0 +1,25 @@
|
||||
import { dirname } from "path";
|
||||
import { fileURLToPath } from "url";
|
||||
import { FlatCompat } from "@eslint/eslintrc";
|
||||
|
||||
const __filename = fileURLToPath(import.meta.url);
|
||||
const __dirname = dirname(__filename);
|
||||
|
||||
const compat = new FlatCompat({
|
||||
baseDirectory: __dirname,
|
||||
});
|
||||
|
||||
const eslintConfig = [
|
||||
...compat.extends("next/core-web-vitals", "next/typescript"),
|
||||
{
|
||||
ignores: [
|
||||
"node_modules/**",
|
||||
".next/**",
|
||||
"out/**",
|
||||
"build/**",
|
||||
"next-env.d.ts",
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
export default eslintConfig;
|
||||
7
frontend/next.config.ts
Normal file
7
frontend/next.config.ts
Normal file
@ -0,0 +1,7 @@
|
||||
import type { NextConfig } from "next";
|
||||
|
||||
const nextConfig: NextConfig = {
|
||||
/* config options here */
|
||||
};
|
||||
|
||||
export default nextConfig;
|
||||
6190
frontend/package-lock.json
generated
Normal file
6190
frontend/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
33
frontend/package.json
Normal file
33
frontend/package.json
Normal file
@ -0,0 +1,33 @@
|
||||
{
|
||||
"name": "frontend",
|
||||
"version": "0.1.0",
|
||||
"private": true,
|
||||
"scripts": {
|
||||
"dev": "next dev --turbopack",
|
||||
"build": "next build --turbopack",
|
||||
"start": "next start",
|
||||
"lint": "eslint"
|
||||
},
|
||||
"dependencies": {
|
||||
"@radix-ui/react-slot": "^1.2.3",
|
||||
"class-variance-authority": "^0.7.1",
|
||||
"clsx": "^2.1.1",
|
||||
"lucide-react": "^0.546.0",
|
||||
"next": "15.5.6",
|
||||
"react": "19.1.0",
|
||||
"react-dom": "19.1.0",
|
||||
"tailwind-merge": "^3.3.1",
|
||||
"tailwindcss-animate": "^1.0.7"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@eslint/eslintrc": "^3",
|
||||
"@tailwindcss/postcss": "^4",
|
||||
"@types/node": "^20",
|
||||
"@types/react": "^19",
|
||||
"@types/react-dom": "^19",
|
||||
"eslint": "^9",
|
||||
"eslint-config-next": "15.5.6",
|
||||
"tailwindcss": "^4",
|
||||
"typescript": "^5"
|
||||
}
|
||||
}
|
||||
5
frontend/postcss.config.mjs
Normal file
5
frontend/postcss.config.mjs
Normal file
@ -0,0 +1,5 @@
|
||||
const config = {
|
||||
plugins: ["@tailwindcss/postcss"],
|
||||
};
|
||||
|
||||
export default config;
|
||||
BIN
frontend/src/app/favicon.ico
Normal file
BIN
frontend/src/app/favicon.ico
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 25 KiB |
54
frontend/src/app/globals.css
Normal file
54
frontend/src/app/globals.css
Normal file
@ -0,0 +1,54 @@
|
||||
@tailwind base;
|
||||
@tailwind components;
|
||||
@tailwind utilities;
|
||||
|
||||
:root {
|
||||
--background: 0 0% 100%;
|
||||
--foreground: 222.2 84% 4.9%;
|
||||
--card: 0 0% 100%;
|
||||
--card-foreground: 222.2 84% 4.9%;
|
||||
--popover: 0 0% 100%;
|
||||
--popover-foreground: 222.2 84% 4.9%;
|
||||
--primary: 222.2 47.4% 11.2%;
|
||||
--primary-foreground: 210 40% 98%;
|
||||
--secondary: 210 40% 96%;
|
||||
--secondary-foreground: 222.2 47.4% 11.2%;
|
||||
--muted: 210 40% 96%;
|
||||
--muted-foreground: 215.4 16.3% 46.9%;
|
||||
--accent: 210 40% 96%;
|
||||
--accent-foreground: 222.2 47.4% 11.2%;
|
||||
--destructive: 0 84.2% 60.2%;
|
||||
--destructive-foreground: 210 40% 98%;
|
||||
--border: 214.3 31.8% 91.4%;
|
||||
--input: 214.3 31.8% 91.4%;
|
||||
--ring: 222.2 84% 4.9%;
|
||||
--radius: 0.5rem;
|
||||
}
|
||||
|
||||
.dark {
|
||||
--background: 222.2 84% 4.9%;
|
||||
--foreground: 210 40% 98%;
|
||||
--card: 222.2 84% 4.9%;
|
||||
--card-foreground: 210 40% 98%;
|
||||
--popover: 222.2 84% 4.9%;
|
||||
--popover-foreground: 210 40% 98%;
|
||||
--primary: 210 40% 98%;
|
||||
--primary-foreground: 222.2 47.4% 11.2%;
|
||||
--secondary: 217.2 32.6% 17.5%;
|
||||
--secondary-foreground: 210 40% 98%;
|
||||
--muted: 217.2 32.6% 17.5%;
|
||||
--muted-foreground: 215 20.2% 65.1%;
|
||||
--accent: 217.2 32.6% 17.5%;
|
||||
--accent-foreground: 210 40% 98%;
|
||||
--destructive: 0 62.8% 30.6%;
|
||||
--destructive-foreground: 210 40% 98%;
|
||||
--border: 217.2 32.6% 17.5%;
|
||||
--input: 217.2 32.6% 17.5%;
|
||||
--ring: 212.7 26.8% 83.9%;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: "Noto Sans SC", "PingFang SC", "Microsoft YaHei", sans-serif;
|
||||
background-color: hsl(var(--background));
|
||||
color: hsl(var(--foreground));
|
||||
}
|
||||
32
frontend/src/app/layout.tsx
Normal file
32
frontend/src/app/layout.tsx
Normal file
@ -0,0 +1,32 @@
|
||||
import type { Metadata } from "next";
|
||||
import { Noto_Sans_SC } from "next/font/google";
|
||||
import "./globals.css";
|
||||
|
||||
const notoSansSC = Noto_Sans_SC({
|
||||
subsets: ["latin"],
|
||||
weight: ["300", "400", "500", "700"],
|
||||
variable: "--font-noto-sans-sc",
|
||||
});
|
||||
|
||||
export const metadata: Metadata = {
|
||||
title: "基本面选股系统",
|
||||
description: "专业的股票基本面分析平台,提供全面的投资决策支持",
|
||||
};
|
||||
|
||||
export default function RootLayout({
|
||||
children,
|
||||
}: Readonly<{
|
||||
children: React.ReactNode;
|
||||
}>) {
|
||||
return (
|
||||
<html lang="zh-CN" suppressHydrationWarning>
|
||||
<body className={`${notoSansSC.variable} antialiased`}>
|
||||
<div className="min-h-screen bg-background font-sans antialiased">
|
||||
<main className="relative flex min-h-screen flex-col">
|
||||
{children}
|
||||
</main>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
);
|
||||
}
|
||||
17
frontend/src/app/page.tsx
Normal file
17
frontend/src/app/page.tsx
Normal file
@ -0,0 +1,17 @@
|
||||
export default function Home() {
|
||||
return (
|
||||
<div className="container mx-auto px-4 py-8">
|
||||
<div className="flex flex-col items-center justify-center min-h-[80vh] text-center">
|
||||
<h1 className="text-4xl font-bold text-foreground mb-6">
|
||||
基本面选股系统
|
||||
</h1>
|
||||
<p className="text-xl text-muted-foreground mb-8 max-w-2xl">
|
||||
专业的股票基本面分析平台,通过多维度分析为您提供全面的投资决策支持
|
||||
</p>
|
||||
<div className="text-sm text-muted-foreground">
|
||||
前端项目已成功初始化,准备开始开发核心功能
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
1
frontend/src/components/.gitkeep
Normal file
1
frontend/src/components/.gitkeep
Normal file
@ -0,0 +1 @@
|
||||
# This file ensures the components directory is tracked by git
|
||||
56
frontend/src/components/ui/button.tsx
Normal file
56
frontend/src/components/ui/button.tsx
Normal file
@ -0,0 +1,56 @@
|
||||
import * as React from "react"
|
||||
import { Slot } from "@radix-ui/react-slot"
|
||||
import { cva, type VariantProps } from "class-variance-authority"
|
||||
|
||||
import { cn } from "@/lib/utils"
|
||||
|
||||
const buttonVariants = cva(
|
||||
"inline-flex items-center justify-center whitespace-nowrap rounded-md text-sm font-medium ring-offset-background transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50",
|
||||
{
|
||||
variants: {
|
||||
variant: {
|
||||
default: "bg-primary text-primary-foreground hover:bg-primary/90",
|
||||
destructive:
|
||||
"bg-destructive text-destructive-foreground hover:bg-destructive/90",
|
||||
outline:
|
||||
"border border-input bg-background hover:bg-accent hover:text-accent-foreground",
|
||||
secondary:
|
||||
"bg-secondary text-secondary-foreground hover:bg-secondary/80",
|
||||
ghost: "hover:bg-accent hover:text-accent-foreground",
|
||||
link: "text-primary underline-offset-4 hover:underline",
|
||||
},
|
||||
size: {
|
||||
default: "h-10 px-4 py-2",
|
||||
sm: "h-9 rounded-md px-3",
|
||||
lg: "h-11 rounded-md px-8",
|
||||
icon: "h-10 w-10",
|
||||
},
|
||||
},
|
||||
defaultVariants: {
|
||||
variant: "default",
|
||||
size: "default",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
export interface ButtonProps
|
||||
extends React.ButtonHTMLAttributes<HTMLButtonElement>,
|
||||
VariantProps<typeof buttonVariants> {
|
||||
asChild?: boolean
|
||||
}
|
||||
|
||||
const Button = React.forwardRef<HTMLButtonElement, ButtonProps>(
|
||||
({ className, variant, size, asChild = false, ...props }, ref) => {
|
||||
const Comp = asChild ? Slot : "button"
|
||||
return (
|
||||
<Comp
|
||||
className={cn(buttonVariants({ variant, size, className }))}
|
||||
ref={ref}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
}
|
||||
)
|
||||
Button.displayName = "Button"
|
||||
|
||||
export { Button, buttonVariants }
|
||||
79
frontend/src/components/ui/card.tsx
Normal file
79
frontend/src/components/ui/card.tsx
Normal file
@ -0,0 +1,79 @@
|
||||
import * as React from "react"
|
||||
|
||||
import { cn } from "@/lib/utils"
|
||||
|
||||
const Card = React.forwardRef<
|
||||
HTMLDivElement,
|
||||
React.HTMLAttributes<HTMLDivElement>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<div
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"rounded-lg border bg-card text-card-foreground shadow-sm",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
))
|
||||
Card.displayName = "Card"
|
||||
|
||||
const CardHeader = React.forwardRef<
|
||||
HTMLDivElement,
|
||||
React.HTMLAttributes<HTMLDivElement>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<div
|
||||
ref={ref}
|
||||
className={cn("flex flex-col space-y-1.5 p-6", className)}
|
||||
{...props}
|
||||
/>
|
||||
))
|
||||
CardHeader.displayName = "CardHeader"
|
||||
|
||||
const CardTitle = React.forwardRef<
|
||||
HTMLParagraphElement,
|
||||
React.HTMLAttributes<HTMLHeadingElement>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<h3
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"text-2xl font-semibold leading-none tracking-tight",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
))
|
||||
CardTitle.displayName = "CardTitle"
|
||||
|
||||
const CardDescription = React.forwardRef<
|
||||
HTMLParagraphElement,
|
||||
React.HTMLAttributes<HTMLParagraphElement>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<p
|
||||
ref={ref}
|
||||
className={cn("text-sm text-muted-foreground", className)}
|
||||
{...props}
|
||||
/>
|
||||
))
|
||||
CardDescription.displayName = "CardDescription"
|
||||
|
||||
const CardContent = React.forwardRef<
|
||||
HTMLDivElement,
|
||||
React.HTMLAttributes<HTMLDivElement>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<div ref={ref} className={cn("p-6 pt-0", className)} {...props} />
|
||||
))
|
||||
CardContent.displayName = "CardContent"
|
||||
|
||||
const CardFooter = React.forwardRef<
|
||||
HTMLDivElement,
|
||||
React.HTMLAttributes<HTMLDivElement>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<div
|
||||
ref={ref}
|
||||
className={cn("flex items-center p-6 pt-0", className)}
|
||||
{...props}
|
||||
/>
|
||||
))
|
||||
CardFooter.displayName = "CardFooter"
|
||||
|
||||
export { Card, CardHeader, CardFooter, CardTitle, CardDescription, CardContent }
|
||||
1
frontend/src/hooks/.gitkeep
Normal file
1
frontend/src/hooks/.gitkeep
Normal file
@ -0,0 +1 @@
|
||||
# This file ensures the hooks directory is tracked by git
|
||||
68
frontend/src/hooks/useProgress.ts
Normal file
68
frontend/src/hooks/useProgress.ts
Normal file
@ -0,0 +1,68 @@
|
||||
import { useState, useEffect, useRef, useCallback } from "react";
|
||||
import { ProgressResponse } from "@/lib/types";
|
||||
import { apiClient } from "@/lib/api";
|
||||
|
||||
export function useProgress(reportId?: string, pollingInterval: number = 2000) {
|
||||
const [progress, setProgress] = useState<ProgressResponse | null>(null);
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const intervalRef = useRef<NodeJS.Timeout | null>(null);
|
||||
|
||||
const stopPolling = useCallback(() => {
|
||||
if (intervalRef.current) {
|
||||
clearInterval(intervalRef.current);
|
||||
intervalRef.current = null;
|
||||
}
|
||||
setLoading(false);
|
||||
}, []);
|
||||
|
||||
const fetchProgress = useCallback(async (id: string) => {
|
||||
try {
|
||||
const progressData = await apiClient.getReportProgress(id);
|
||||
setProgress(progressData);
|
||||
|
||||
// 如果报告已完成或失败,停止轮询
|
||||
if (progressData.status === "completed" || progressData.status === "failed") {
|
||||
stopPolling();
|
||||
}
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "获取进度失败");
|
||||
stopPolling();
|
||||
}
|
||||
}, [stopPolling]);
|
||||
|
||||
const startPolling = useCallback((id: string) => {
|
||||
if (intervalRef.current) {
|
||||
clearInterval(intervalRef.current);
|
||||
}
|
||||
|
||||
setLoading(true);
|
||||
setError(null);
|
||||
|
||||
// 立即获取一次进度
|
||||
fetchProgress(id);
|
||||
|
||||
// 开始轮询
|
||||
intervalRef.current = setInterval(() => {
|
||||
fetchProgress(id);
|
||||
}, pollingInterval);
|
||||
}, [fetchProgress, pollingInterval]);
|
||||
|
||||
useEffect(() => {
|
||||
if (reportId) {
|
||||
startPolling(reportId);
|
||||
}
|
||||
|
||||
return () => {
|
||||
stopPolling();
|
||||
};
|
||||
}, [reportId, startPolling, stopPolling]);
|
||||
|
||||
return {
|
||||
progress,
|
||||
loading,
|
||||
error,
|
||||
startPolling,
|
||||
stopPolling,
|
||||
};
|
||||
}
|
||||
49
frontend/src/hooks/useReport.ts
Normal file
49
frontend/src/hooks/useReport.ts
Normal file
@ -0,0 +1,49 @@
|
||||
import { useState, useEffect } from "react";
|
||||
import { Report, TradingMarket } from "@/lib/types";
|
||||
import { apiClient } from "@/lib/api";
|
||||
|
||||
export function useReport(symbol?: string, market?: TradingMarket) {
|
||||
const [report, setReport] = useState<Report | null>(null);
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
const fetchReport = async (sym: string, mkt: TradingMarket) => {
|
||||
setLoading(true);
|
||||
setError(null);
|
||||
try {
|
||||
const reportData = await apiClient.getReport(sym, mkt);
|
||||
setReport(reportData);
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "获取报告失败");
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const regenerateReport = async (sym: string, mkt: TradingMarket, force: boolean = false) => {
|
||||
setLoading(true);
|
||||
setError(null);
|
||||
try {
|
||||
const reportData = await apiClient.regenerateReport(sym, mkt, force);
|
||||
setReport(reportData);
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : "重新生成报告失败");
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (symbol && market) {
|
||||
fetchReport(symbol, market);
|
||||
}
|
||||
}, [symbol, market]);
|
||||
|
||||
return {
|
||||
report,
|
||||
loading,
|
||||
error,
|
||||
fetchReport,
|
||||
regenerateReport,
|
||||
};
|
||||
}
|
||||
79
frontend/tailwind.config.ts
Normal file
79
frontend/tailwind.config.ts
Normal file
@ -0,0 +1,79 @@
|
||||
import type { Config } from "tailwindcss"
|
||||
|
||||
const config: Config = {
|
||||
darkMode: "class",
|
||||
content: [
|
||||
"./src/pages/**/*.{js,ts,jsx,tsx,mdx}",
|
||||
"./src/components/**/*.{js,ts,jsx,tsx,mdx}",
|
||||
"./src/app/**/*.{js,ts,jsx,tsx,mdx}",
|
||||
],
|
||||
prefix: "",
|
||||
theme: {
|
||||
container: {
|
||||
center: true,
|
||||
padding: "2rem",
|
||||
screens: {
|
||||
"2xl": "1400px",
|
||||
},
|
||||
},
|
||||
extend: {
|
||||
colors: {
|
||||
border: "hsl(var(--border))",
|
||||
input: "hsl(var(--input))",
|
||||
ring: "hsl(var(--ring))",
|
||||
background: "hsl(var(--background))",
|
||||
foreground: "hsl(var(--foreground))",
|
||||
primary: {
|
||||
DEFAULT: "hsl(var(--primary))",
|
||||
foreground: "hsl(var(--primary-foreground))",
|
||||
},
|
||||
secondary: {
|
||||
DEFAULT: "hsl(var(--secondary))",
|
||||
foreground: "hsl(var(--secondary-foreground))",
|
||||
},
|
||||
destructive: {
|
||||
DEFAULT: "hsl(var(--destructive))",
|
||||
foreground: "hsl(var(--destructive-foreground))",
|
||||
},
|
||||
muted: {
|
||||
DEFAULT: "hsl(var(--muted))",
|
||||
foreground: "hsl(var(--muted-foreground))",
|
||||
},
|
||||
accent: {
|
||||
DEFAULT: "hsl(var(--accent))",
|
||||
foreground: "hsl(var(--accent-foreground))",
|
||||
},
|
||||
popover: {
|
||||
DEFAULT: "hsl(var(--popover))",
|
||||
foreground: "hsl(var(--popover-foreground))",
|
||||
},
|
||||
card: {
|
||||
DEFAULT: "hsl(var(--card))",
|
||||
foreground: "hsl(var(--card-foreground))",
|
||||
},
|
||||
},
|
||||
borderRadius: {
|
||||
lg: "var(--radius)",
|
||||
md: "calc(var(--radius) - 2px)",
|
||||
sm: "calc(var(--radius) - 4px)",
|
||||
},
|
||||
keyframes: {
|
||||
"accordion-down": {
|
||||
from: { height: "0" },
|
||||
to: { height: "var(--radix-accordion-content-height)" },
|
||||
},
|
||||
"accordion-up": {
|
||||
from: { height: "var(--radix-accordion-content-height)" },
|
||||
to: { height: "0" },
|
||||
},
|
||||
},
|
||||
animation: {
|
||||
"accordion-down": "accordion-down 0.2s ease-out",
|
||||
"accordion-up": "accordion-up 0.2s ease-out",
|
||||
},
|
||||
},
|
||||
},
|
||||
plugins: [require("tailwindcss-animate")],
|
||||
}
|
||||
|
||||
export default config
|
||||
27
frontend/tsconfig.json
Normal file
27
frontend/tsconfig.json
Normal file
@ -0,0 +1,27 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2017",
|
||||
"lib": ["dom", "dom.iterable", "esnext"],
|
||||
"allowJs": true,
|
||||
"skipLibCheck": true,
|
||||
"strict": true,
|
||||
"noEmit": true,
|
||||
"esModuleInterop": true,
|
||||
"module": "esnext",
|
||||
"moduleResolution": "bundler",
|
||||
"resolveJsonModule": true,
|
||||
"isolatedModules": true,
|
||||
"jsx": "preserve",
|
||||
"incremental": true,
|
||||
"plugins": [
|
||||
{
|
||||
"name": "next"
|
||||
}
|
||||
],
|
||||
"paths": {
|
||||
"@/*": ["./src/*"]
|
||||
}
|
||||
},
|
||||
"include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", ".next/types/**/*.ts"],
|
||||
"exclude": ["node_modules"]
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user