Initial commit: Fundamental stock analysis project setup

This commit is contained in:
xucheng 2025-10-20 15:20:32 +08:00
commit 91f701139f
66 changed files with 13958 additions and 0 deletions

314
.gitignore vendored Normal file
View File

@ -0,0 +1,314 @@
# ===== 通用文件 =====
# 操作系统生成的文件
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db
# IDE 和编辑器
.vscode/
.idea/
*.swp
*.swo
*~
# 日志文件
*.log
logs/
# 临时文件
*.tmp
*.temp
.cache/
# ===== Python 后端 =====
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
Pipfile.lock
# poetry
poetry.lock
# pdm
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.env.*
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
.idea/
# ===== Node.js / Next.js 前端 =====
# Dependencies
node_modules/
npm-debug.log*
yarn-debug.log*
yarn-error.log*
.pnpm-debug.log*
# Runtime data
pids
*.pid
*.seed
*.pid.lock
# Coverage directory used by tools like istanbul
coverage/
*.lcov
# nyc test coverage
.nyc_output
# Grunt intermediate storage
.grunt
# Bower dependency directory
bower_components
# node-waf configuration
.lock-wscript
# Compiled binary addons
build/Release
# Dependency directories
jspm_packages/
# Snowpack dependency directory
web_modules/
# TypeScript cache
*.tsbuildinfo
# Optional npm cache directory
.npm
# Optional eslint cache
.eslintcache
# Optional stylelint cache
.stylelintcache
# Microbundle cache
.rpt2_cache/
.rts2_cache_cjs/
.rts2_cache_es/
.rts2_cache_umd/
# Optional REPL history
.node_repl_history
# Output of 'npm pack'
*.tgz
# Yarn Integrity file
.yarn-integrity
# parcel-bundler cache
.parcel-cache
# Next.js build output
.next/
out/
# Nuxt.js build / generate output
.nuxt
dist
# Gatsby files
.cache/
public
# Storybook build outputs
.out
.storybook-out
storybook-static
# Temporary folders
tmp/
temp/
# Vercel
.vercel
# Turbo
.turbo
# ===== 数据库 =====
# SQLite
*.sqlite
*.sqlite3
*.db
# PostgreSQL
*.sql
# MySQL
*.mysql
# ===== 配置和密钥 =====
# 环境变量文件
.env
.env.local
.env.development.local
.env.test.local
.env.production.local
# API 密钥和配置
config/secrets.json
secrets/
*.key
*.pem
*.p12
*.pfx
# ===== 项目特定 =====
# 数据文件
data/
*.csv
*.json.bak
*.xlsx
# 报告和输出
reports/
output/
exports/
# 备份文件
*.bak
*.backup
backup/
# 测试数据
test_data/
mock_data/

View File

@ -0,0 +1,351 @@
# 设计文档 - 基本面选股系统
## 概览
基本面选股系统是一个全栈Web应用采用前后端分离架构。前端使用Next.js和shadcn/ui构建响应式中文界面后端使用Python FastAPI提供API服务。系统通过多个专业分析模块结合财务数据API、AI大模型和实时数据为用户提供全面的股票基本面分析报告。
## 架构
### 系统架构图
```mermaid
graph TB
subgraph "前端层"
A[Next.js应用] --> B[shadcn/ui UI组件]
A --> C[TradingView图表组件]
end
subgraph "后端层"
D[FastAPI服务器] --> E[报告生成引擎]
D --> F[数据获取服务]
D --> G[配置管理服务]
end
subgraph "外部服务"
H[Tushare API]
I[Gemini API]
J[其他数据源APIs]
end
subgraph "数据层"
K[PostgreSQL数据库]
end
A --> D
F --> H
F --> I
F --> J
E --> K
G --> K
```
### 技术栈
**前端:**
- Next.js 14 (App Router)
- TypeScript
- shadcn/ui组件库 (https://ui.shadcn.com/)
- TradingView Charting Library
- Tailwind CSS
- Radix UI (shadcn/ui的底层组件)
**后端:**
- Python 3.11+
- FastAPI
- SQLAlchemy (ORM)
- Alembic (数据库迁移)
- Pydantic (数据验证)
- httpx (HTTP客户端)
**数据库:**
- PostgreSQL 15+
**外部服务:**
- Tushare API (中国股票数据)
- Google Gemini API (AI分析)
- 其他市场数据源APIs
## 组件和接口
### shadcn/ui组件使用规划
系统将使用shadcn/ui (https://ui.shadcn.com/) 的官方组件来构建一致的用户界面:
**核心组件使用:**
- `Button`: 主要操作按钮(生成报告、保存配置等)
- `Input`: 证券代码输入框
- `Select`: 交易市场选择器
- `Card`: 分析模块容器、报告卡片
- `Progress`: 报告生成进度条
- `Badge`: 状态标识(完成、进行中、失败)
- `Tabs`: 分析模块切换
- `Form`: 配置表单、搜索表单
- `Alert`: 错误提示、成功消息
- `Separator`: 内容分隔线
- `Skeleton`: 加载占位符
- `Toast`: 操作反馈通知
- `Table`: 财务数据展示表格(资产负债表、利润表、现金流量表等)
**主题配置:**
- 使用默认主题,支持深色/浅色模式切换
- 自定义中文字体配置
- 适配中文内容的间距和排版
### 前端组件结构
```
src/
├── app/
│ ├── page.tsx # 首页
│ ├── report/[symbol]/page.tsx # 报告页面
│ ├── config/page.tsx # 配置页面
│ └── layout.tsx # 根布局
├── components/
│ ├── ui/ # shadcn/ui基础组件 (Button, Input, Card, etc.)
│ ├── StockSearchForm.tsx # 股票搜索表单 (使用Form, Input, Select)
│ ├── ReportProgress.tsx # 报告生成进度 (使用Progress, Badge, Card)
│ ├── TradingViewChart.tsx # TradingView图表
│ ├── AnalysisModule.tsx # 分析模块容器 (使用Card, Tabs, Separator)
│ ├── FinancialDataTable.tsx # 财务数据表格 (使用Table, TableHeader, TableBody, TableRow, TableCell)
│ └── ConfigForm.tsx # 配置表单 (使用Form, Input, Button, Alert)
├── lib/
│ ├── api.ts # API客户端
│ ├── types.ts # TypeScript类型定义
│ └── utils.ts # 工具函数
└── hooks/
├── useReport.ts # 报告数据钩子
└── useProgress.ts # 进度追踪钩子
```
### 后端API结构
```
app/
├── main.py # FastAPI应用入口
├── models/
│ ├── __init__.py
│ ├── report.py # 报告数据模型
│ ├── config.py # 配置数据模型
│ └── progress.py # 进度追踪模型
├── schemas/
│ ├── __init__.py
│ ├── report.py # 报告Pydantic模式
│ ├── config.py # 配置Pydantic模式
│ └── progress.py # 进度Pydantic模式
├── services/
│ ├── __init__.py
│ ├── data_fetcher.py # 数据获取服务
│ ├── ai_analyzer.py # AI分析服务
│ ├── report_generator.py # 报告生成服务
│ └── config_manager.py # 配置管理服务
├── routers/
│ ├── __init__.py
│ ├── reports.py # 报告相关API
│ ├── config.py # 配置相关API
│ └── progress.py # 进度相关API
└── core/
├── __init__.py
├── database.py # 数据库连接
├── config.py # 应用配置
└── dependencies.py # 依赖注入
```
### 核心API接口
#### 报告相关API
```python
# GET /api/reports/{symbol}?market={market}
# 获取或生成股票报告
class ReportResponse:
symbol: str
market: str
report_id: str
status: str # "existing" | "generating" | "completed" | "failed"
created_at: datetime
updated_at: datetime
modules: List[AnalysisModule]
# POST /api/reports/{symbol}/regenerate?market={market}
# 重新生成报告
class RegenerateRequest:
force: bool = False
# GET /api/reports/{report_id}/progress
# 获取报告生成进度
class ProgressResponse:
report_id: str
current_step: int
total_steps: int
current_step_name: str
status: str # "running" | "completed" | "failed"
step_timings: List[StepTiming]
estimated_remaining: Optional[int]
```
#### 配置相关API
```python
# GET /api/config
# 获取系统配置
class ConfigResponse:
database: DatabaseConfig
gemini_api: GeminiConfig
data_sources: Dict[str, DataSourceConfig]
# PUT /api/config
# 更新系统配置
class ConfigUpdateRequest:
database: Optional[DatabaseConfig]
gemini_api: Optional[GeminiConfig]
data_sources: Optional[Dict[str, DataSourceConfig]]
# POST /api/config/test
# 测试配置连接
class ConfigTestRequest:
config_type: str # "database" | "gemini" | "data_source"
config_data: Dict[str, Any]
```
## 数据模型
### 数据库表结构
```sql
-- 报告表
CREATE TABLE reports (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
symbol VARCHAR(20) NOT NULL,
market VARCHAR(20) NOT NULL,
status VARCHAR(20) NOT NULL DEFAULT 'generating',
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
UNIQUE(symbol, market)
);
-- 分析模块表
CREATE TABLE analysis_modules (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
report_id UUID REFERENCES reports(id) ON DELETE CASCADE,
module_type VARCHAR(50) NOT NULL,
module_order INTEGER NOT NULL,
title VARCHAR(200) NOT NULL,
content JSONB,
status VARCHAR(20) NOT NULL DEFAULT 'pending',
started_at TIMESTAMP WITH TIME ZONE,
completed_at TIMESTAMP WITH TIME ZONE,
error_message TEXT
);
-- 进度追踪表
CREATE TABLE progress_tracking (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
report_id UUID REFERENCES reports(id) ON DELETE CASCADE,
step_name VARCHAR(100) NOT NULL,
step_order INTEGER NOT NULL,
status VARCHAR(20) NOT NULL DEFAULT 'pending',
started_at TIMESTAMP WITH TIME ZONE,
completed_at TIMESTAMP WITH TIME ZONE,
duration_ms INTEGER,
error_message TEXT
);
-- 系统配置表
CREATE TABLE system_config (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
config_key VARCHAR(100) UNIQUE NOT NULL,
config_value JSONB NOT NULL,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
);
```
### 分析模块类型定义
```python
class AnalysisModuleType(Enum):
TRADING_VIEW_CHART = "trading_view_chart"
FINANCIAL_DATA = "financial_data"
BUSINESS_INFO = "business_info"
FUNDAMENTAL_ANALYSIS = "fundamental_analysis"
BULLISH_ANALYSIS = "bullish_analysis"
BEARISH_ANALYSIS = "bearish_analysis"
MARKET_ANALYSIS = "market_analysis"
NEWS_ANALYSIS = "news_analysis"
TRADING_ANALYSIS = "trading_analysis"
INSIDER_ANALYSIS = "insider_analysis"
FINAL_CONCLUSION = "final_conclusion"
class AnalysisModule(BaseModel):
id: UUID
module_type: AnalysisModuleType
title: str
content: Dict[str, Any]
status: str
duration_ms: Optional[int]
error_message: Optional[str]
```
## 错误处理
### 错误类型定义
```python
class StockAnalysisError(Exception):
"""基础异常类"""
pass
class DataSourceError(StockAnalysisError):
"""数据源错误"""
pass
class AIAnalysisError(StockAnalysisError):
"""AI分析错误"""
pass
class ConfigurationError(StockAnalysisError):
"""配置错误"""
pass
class DatabaseError(StockAnalysisError):
"""数据库错误"""
pass
```
### 错误处理策略
1. **数据获取失败**: 重试机制最多3次重试指数退避
2. **AI分析失败**: 记录错误,继续其他模块,最后汇总失败信息
3. **数据库连接失败**: 使用连接池,自动重连
4. **配置错误**: 提供详细错误信息,阻止系统启动
5. **前端错误**: Toast通知错误边界组件
## 测试策略
### 测试层级
1. **单元测试**
- 后端服务函数测试
- 前端组件测试
- 数据模型验证测试
2. **集成测试**
- API端点测试
- 数据库操作测试
- 外部服务集成测试
3. **端到端测试**
- 完整报告生成流程测试
- 用户界面交互测试
### 测试工具
- **后端**: pytest, pytest-asyncio, httpx
- **前端**: Jest, React Testing Library, Playwright
- **数据库**: pytest-postgresql
- **API测试**: FastAPI TestClient
### 测试数据
- 使用测试数据库和模拟数据
- 外部API使用mock响应
- 测试用例覆盖各种市场和股票类型

View File

@ -0,0 +1,112 @@
# 需求文档 - MVP版本
## 介绍
基本面选股系统MVP是一个综合的中文网站允许用户输入证券代码和交易市场生成包含多维度分析的详细股票基本面报告。系统通过多个专业分析模块结合财务数据、AI分析和市场信息为用户提供全面的投资决策支持。
## 术语表
- **选股系统 (Stock_Selection_System)**: 提供基本面分析和报告生成的主要系统
- **用户 (User)**: 使用系统进行股票分析的终端用户
- **证券代码 (Security_Code)**: 股票在特定交易市场的唯一标识符
- **交易市场 (Trading_Market)**: 股票交易的地理区域,包括中国、香港、美国、日本
- **基本面报告 (Fundamental_Report)**: 包含九个分析模块的综合股票分析报告
- **TradingView图表 (TradingView_Chart)**: 使用TradingView高级图表组件显示的股价图表
- **Tushare_API**: 用于获取中国股票财务数据的数据源接口
- **Gemini_Model**: Google的大语言模型用于生成业务分析内容
- **景林模型 (Jinglin_Model)**: 基本面分析师使用的问题集分析框架
- **PostgreSQL数据库 (PostgreSQL_Database)**: 用于存储报告数据的关系型数据库
- **分析模块 (Analysis_Module)**: 报告中的独立分析部分,每个模块对应一个显示页面
## 需求
### 需求 1
**用户故事:** 作为投资者,我希望能够输入股票代码和选择交易市场,以便获取该股票的综合基本面分析报告
#### 验收标准
1. 当用户访问首页时,选股系统应当显示证券代码输入框和交易市场选择器
2. 当用户选择交易市场时,选股系统应当提供中国、香港、美国、日本四个选项
3. 当用户提交证券代码和交易市场时,选股系统应当处理用户请求并跳转到报告页面
### 需求 2
**用户故事:** 作为投资者,我希望系统能够检查历史报告,以便决定是查看现有报告还是生成新报告
#### 验收标准
1. 当用户提交证券代码和交易市场后选股系统应当在PostgreSQL数据库中查询对应的历史报告
2. 如果存在历史报告,选股系统应当显示历史报告内容和"生成最新报告"按钮
3. 如果不存在历史报告,选股系统应当自动启动九步报告生成流程
### 需求 3
**用户故事:** 作为投资者,我希望系统能够获取准确的财务数据,以便进行可靠的基本面分析
#### 验收标准
1. 当生成中国股票报告时选股系统应当使用Tushare_API获取财务信息
2. 当处理其他市场股票时,选股系统应当根据交易市场选择相应的数据源
3. 当财务数据获取完成时,选股系统应当将数据作为后续分析的基础
### 需求 4
**用户故事:** 作为投资者我希望系统能够通过AI分析获取公司业务信息以便了解公司的全面情况
#### 验收标准
1. 当需要业务信息时选股系统应当使用Gemini生成公司概览、主营业务、发展历程、核心团队、供应链、主要客户及销售模式、未来展望
2. 当调用Gemini_Model时选股系统应当使用配置的API密钥进行身份验证
3. 当业务信息生成完成时,选股系统应当将内容整合到报告的第二部分
### 需求 5
**用户故事:** 作为投资者,我希望系统能够提供多维度的专业分析,以便获得全面的投资决策支持
#### 验收标准
1. 当生成报告时选股系统应当按顺序执行10个分析模块财务信息、业务信息、基本面分析、看涨分析、看跌分析、市场分析、新闻分析、交易分析、内部人与机构动向分析、最终结论
2. 当执行基本面分析时,选股系统应当使用问题集进行分析
3. 当执行看涨分析时,选股系统应当研究潜在隐藏资产和护城河竞争优势
4. 当执行看跌分析时,选股系统应当分析公司价值底线和最坏情况
5. 当执行市场分析时,选股系统应当研究市场情绪分歧点与变化驱动
6. 当执行新闻分析时,选股系统应当研究股价催化剂与拐点预判
7. 当执行交易分析时,选股系统应当研究市场体量与增长路径
8. 当执行内部人分析时,选股系统应当研究内部人与机构动向
9. 当生成最终结论时,选股系统应当指出关键矛盾与预期差以及拐点的临近
### 需求 6
**用户故事:** 作为投资者,我希望每个分析模块都能独立查看,以便专注于特定的分析维度
#### 验收标准
1. 当显示报告时,选股系统应当为每个分析模块提供独立的显示页面
2. 当用户在模块间切换时,选股系统应当保持导航的流畅性
3. 当所有模块完成时选股系统应当将完整报告保存到PostgreSQL数据库
### 需求 7
**用户故事:** 作为投资者,我希望在报告生成过程中能够看到实时进度,以便了解当前状态和预估完成时间
#### 验收标准
1. 当开始生成报告时,选股系统应当显示进度指示器展示所有分析步骤
2. 当执行每个分析步骤时,选股系统应当高亮显示当前正在进行的步骤
3. 当每个步骤完成时,选股系统应当更新步骤状态为已完成
4. 当执行分析步骤时,选股系统应当记录每个步骤的开始时间和完成时间
5. 当显示进度时,选股系统应当展示每个步骤的耗时统计
6. 当步骤执行失败时,选股系统应当显示错误状态和错误信息
### 需求 8
**用户故事:** 作为系统管理员,我希望能够配置系统参数,以便系统能够正常连接外部服务
#### 验收标准
1. 选股系统应当提供配置页面用于设置数据库连接参数
2. 选股系统应当提供配置页面用于设置Gemini_API密钥
3. 选股系统应当提供配置页面用于设置各市场的数据源配置
4. 当配置更新时,选股系统应当验证配置的有效性
5. 当配置保存时,选股系统应当将配置持久化存储

View File

@ -0,0 +1,167 @@
# 实施计划
- [x] 1. 后端项目初始化和基础架构
- 创建Python FastAPI项目结构
- 设置虚拟环境和依赖管理requirements.txt或pyproject.toml
- 配置FastAPI应用入口main.py
- 创建核心目录结构models, schemas, services, routers, core
- 设置基础配置管理core/config.py
- _需求: 8.1, 8.2_
- [x] 2. 数据库设置和模型定义
- 配置PostgreSQL数据库连接core/database.py
- 创建SQLAlchemy数据模型reports, analysis_modules, progress_tracking, system_config
- 设置Alembic数据库迁移工具
- 创建初始数据库迁移脚本
- 实现数据库会话管理和依赖注入
- _需求: 6.3, 8.1_
- [x] 3. Pydantic模式和基础服务
- 创建Pydantic数据验证模式schemas/
- 实现配置管理服务services/config_manager.py
- 创建数据获取服务基础架构services/data_fetcher.py
- 实现基础错误处理和异常类
- _需求: 8.2, 8.3, 8.4, 8.5_
- [x] 4. 外部API集成服务
- 实现Tushare API集成中国股票数据获取
- 实现Gemini API集成AI分析服务
- 创建数据源配置和切换逻辑
- 添加API调用错误处理和重试机制
- _需求: 3.1, 3.2, 4.1, 4.2_
- [x] 5. 报告生成引擎核心
- 创建报告生成服务services/report_generator.py
- 实现分析模块执行框架
- 创建进度追踪服务services/progress_tracker.py
- 实现步骤计时和状态管理
- _需求: 5.1, 7.1, 7.2, 7.3, 7.4, 7.5_
- [x] 6. 后端API路由实现
- 实现报告相关API端点routers/reports.py
- 创建配置管理API端点routers/config.py
- 实现进度追踪API端点routers/progress.py
- 添加API文档和验证
- _需求: 2.1, 2.2, 2.3, 8.1, 8.2, 8.3_
- [x] 7. 前端项目初始化
- 创建Next.js项目并配置TypeScript
- 安装和配置shadcn/ui组件库
- 设置Tailwind CSS和基础样式
- 配置项目文件夹结构components, lib, hooks, app
- 创建基础布局和主题配置
- _需求: 1.1_
- [ ] 8. 前端核心组件开发
- 安装和配置shadcn/ui基础组件
- 实现StockSearchForm组件使用Form, Input, Select, Button
- 创建ReportProgress组件使用Progress, Badge, Card
- 实现AnalysisModule组件使用Card, Tabs, Separator
- 创建FinancialDataTable组件使用Table组件系列
- _需求: 1.1, 1.2, 7.1, 7.2_
- [ ] 9. 首页和股票搜索功能
- 实现首页布局和设计app/page.tsx
- 创建股票代码输入和市场选择功能
- 实现表单验证和提交逻辑
- 添加中文界面文本和错误提示
- 连接前端表单到后端API
- _需求: 1.1, 1.2, 1.3_
- [ ] 10. 报告页面和历史报告功能
- 实现报告页面路由app/report/[symbol]/page.tsx
- 创建历史报告检查和显示逻辑
- 实现"生成最新报告"按钮功能
- 添加报告加载状态和错误处理
- _需求: 2.1, 2.2, 2.3_
- [ ] 11. TradingView图表集成
- 集成TradingView高级图表组件
- 实现图表配置和参数设置
- 根据证券代码和市场配置图表
- 处理图表加载错误和异常情况
- _需求: 5.1, 5.2, 5.3, 5.4_
- [ ] 12. 财务数据分析模块
- 实现财务数据获取和处理逻辑
- 创建财务数据格式化和展示
- 实现FinancialDataTable的数据绑定
- 添加财务数据的错误处理和重试
- _需求: 3.1, 3.2, 3.3_
- [ ] 13. AI业务信息分析模块
- 实现Gemini API调用逻辑和提示词模板
- 创建业务信息分析内容生成
- 实现公司概览、主营业务、发展历程等内容
- 添加AI分析结果的格式化和展示
- _需求: 4.1, 4.2, 4.3_
- [ ] 14. 专业分析模块实现
- 实现景林模型基本面分析模块
- 创建看涨分析师模块(隐藏资产、护城河分析)
- 实现看跌分析师模块(价值底线、最坏情况分析)
- 创建市场分析师模块(市场情绪分歧点分析)
- 实现新闻分析师模块(股价催化剂分析)
- 创建交易分析模块(市场体量与增长路径)
- 实现内部人与机构动向分析模块
- 创建最终结论模块(关键矛盾与拐点分析)
- _需求: 5.1, 5.2, 5.3, 5.4, 5.5, 5.6, 5.7, 5.8, 5.9_
- [ ] 15. 报告生成流程整合
- 整合所有分析模块到报告生成引擎
- 实现模块间的数据传递和依赖关系
- 创建报告生成的错误处理和重试机制
- 实现报告完成后的数据库保存
- _需求: 5.1, 6.3_
- [ ] 16. 实时进度显示功能
- 实现前端进度追踪钩子useProgress
- 连接WebSocket或Server-Sent Events到进度显示
- 添加步骤高亮和状态更新
- 实现计时显示和预估完成时间
- 添加错误状态显示
- _需求: 7.1, 7.2, 7.3, 7.4, 7.5, 7.6_
- [ ] 17. 配置管理页面
- 创建配置页面布局和表单app/config/page.tsx
- 实现数据库配置界面
- 添加Gemini API配置功能
- 创建数据源配置管理
- 实现配置验证和测试功能
- _需求: 8.1, 8.2, 8.3, 8.4, 8.5_
- [ ] 18. 报告展示和导航优化
- 实现分析模块的独立页面展示
- 创建模块间的流畅导航
- 添加报告概览和目录功能
- 优化移动端响应式显示
- _需求: 6.1, 6.2_
- [ ] 19. 错误处理和用户体验优化
- 实现全局错误处理和错误边界
- 添加Toast通知系统
- 创建加载状态和骨架屏
- 优化中文界面和用户反馈
- 添加操作确认和提示
- _需求: 7.6, 1.1_
- [ ]* 20. 测试实现
- [ ]* 20.1 后端单元测试
- 为数据获取服务编写单元测试
- 为AI分析服务编写单元测试
- 为报告生成引擎编写单元测试
- 为配置管理服务编写单元测试
- [ ]* 20.2 前端组件测试
- 为核心组件编写React Testing Library测试
- 为表单组件编写交互测试
- 为进度组件编写状态测试
- [ ]* 20.3 API集成测试
- 为报告生成API编写集成测试
- 为配置管理API编写集成测试
- 为进度追踪API编写集成测试
- [ ]* 20.4 端到端测试
- 编写完整报告生成流程的E2E测试
- 编写配置管理流程的E2E测试

228
backend/DATABASE_SETUP.md Normal file
View File

@ -0,0 +1,228 @@
# 数据库设置指南
## 概述
本项目使用PostgreSQL作为主数据库SQLAlchemy作为ORMAlembic作为数据库迁移工具。
## 数据库架构
### 表结构
1. **reports** - 报告主表
- 存储股票分析报告的基本信息
- 包含证券代码、市场、状态等字段
2. **analysis_modules** - 分析模块表
- 存储报告中各个分析模块的内容
- 与reports表一对多关系
3. **progress_tracking** - 进度追踪表
- 记录报告生成过程中各步骤的执行状态
- 与reports表一对多关系
4. **system_config** - 系统配置表
- 存储系统配置信息
- 使用JSONB格式存储配置值
## 环境配置
### 1. 安装PostgreSQL
```bash
# macOS (使用Homebrew)
brew install postgresql
brew services start postgresql
# Ubuntu/Debian
sudo apt-get install postgresql postgresql-contrib
# CentOS/RHEL
sudo yum install postgresql-server postgresql-contrib
```
### 2. 创建数据库
```sql
-- 连接到PostgreSQL
psql -U postgres
-- 创建数据库
CREATE DATABASE stock_analysis;
-- 创建用户(可选)
CREATE USER stock_user WITH PASSWORD 'your_password';
GRANT ALL PRIVILEGES ON DATABASE stock_analysis TO stock_user;
```
### 3. 配置环境变量
创建 `.env` 文件:
```bash
# 数据库配置
DATABASE_URL=postgresql+asyncpg://username:password@localhost:5432/stock_analysis
DATABASE_ECHO=false
# API配置
GEMINI_API_KEY=your_gemini_api_key
TUSHARE_TOKEN=your_tushare_token
```
## 数据库管理
### 使用管理脚本
```bash
# 检查数据库连接
python manage_db.py check
# 初始化数据库表
python manage_db.py init
# 查看数据库状态
python manage_db.py status
```
### 使用Alembic迁移
```bash
# 初始化Alembic已完成
alembic init alembic
# 创建迁移文件
alembic revision --autogenerate -m "描述信息"
# 应用迁移
alembic upgrade head
# 查看迁移历史
alembic history
# 回滚迁移
alembic downgrade -1
```
## 开发工具
### 1. 数据库连接检查
```bash
python check_db.py
```
### 2. 数据库初始化
```bash
python init_db.py
```
### 3. 综合管理工具
```bash
python manage_db.py [check|init|status]
```
## 模型使用示例
### 创建报告
```python
from app.models import Report, AnalysisModule
from app.core.database import get_db
async def create_report():
async for db in get_db():
# 创建报告
report = Report(
symbol="000001",
market="中国",
status="generating"
)
db.add(report)
await db.commit()
await db.refresh(report)
# 创建分析模块
module = AnalysisModule(
report_id=report.id,
module_type="financial_data",
module_order=1,
title="财务数据分析",
status="pending"
)
db.add(module)
await db.commit()
```
### 查询报告
```python
from sqlalchemy import select
from app.models import Report
async def get_report(symbol: str, market: str):
async for db in get_db():
stmt = select(Report).where(
Report.symbol == symbol,
Report.market == market
)
result = await db.execute(stmt)
return result.scalar_one_or_none()
```
## 性能优化
### 索引
所有表都已配置适当的索引:
- reports: symbol+market, status, created_at
- analysis_modules: report_id, module_type, status, module_order
- progress_tracking: report_id, status, step_order
- system_config: config_key, updated_at
### 连接池
数据库连接使用异步连接池,配置参数:
- pool_size: 10
- max_overflow: 20
- pool_timeout: 30秒
- pool_recycle: 1小时
## 故障排除
### 常见问题
1. **连接失败**
- 检查PostgreSQL服务是否运行
- 验证数据库URL配置
- 确认防火墙设置
2. **迁移失败**
- 检查数据库权限
- 验证表结构冲突
- 查看Alembic日志
3. **性能问题**
- 检查索引使用情况
- 分析慢查询日志
- 优化查询语句
### 日志配置
`config.py` 中设置 `DATABASE_ECHO=True` 可以查看SQL执行日志。
## 备份与恢复
### 备份
```bash
pg_dump -U username -h localhost stock_analysis > backup.sql
```
### 恢复
```bash
psql -U username -h localhost stock_analysis < backup.sql
```

View File

@ -0,0 +1,287 @@
# 外部API集成文档
## 概述
本系统集成了多个外部API服务用于获取股票数据和进行AI分析
- **Tushare API**: 中国股票财务数据和市场数据
- **Yahoo Finance API**: 全球股票数据(美国、香港、日本等)
- **Google Gemini API**: AI分析和内容生成
## 架构设计
### 数据源管理器 (DataSourceManager)
- 统一管理所有数据源
- 支持数据源切换和故障转移
- 提供健康检查和状态监控
### 外部API服务 (ExternalAPIService)
- 提供统一的API接口
- 处理错误和重试机制
- 支持配置动态更新
## 支持的数据源
### 1. Tushare API
- **用途**: 中国股票数据A股、港股通等
- **数据类型**: 财务数据、市场数据、基本信息
- **配置要求**: 需要Tushare API Token
- **限制**: 有调用频率限制,需要付费账户获取完整数据
#### 配置示例
```python
{
"tushare": {
"enabled": True,
"api_key": "your_tushare_token",
"base_url": "http://api.tushare.pro",
"timeout": 30,
"max_retries": 3,
"retry_delay": 1
}
}
```
### 2. Yahoo Finance API
- **用途**: 全球股票数据
- **数据类型**: 财务数据、市场数据、基本信息
- **配置要求**: 无需API密钥
- **限制**: 有调用频率限制,可能被反爬虫机制阻止
#### 配置示例
```python
{
"yahoo": {
"enabled": True,
"base_url": "https://query1.finance.yahoo.com",
"timeout": 30,
"max_retries": 3,
"retry_delay": 1
}
}
```
### 3. Google Gemini API
- **用途**: AI分析和内容生成
- **功能**: 业务分析、基本面分析、投资建议等
- **配置要求**: 需要Google Cloud API密钥
- **限制**: 有调用频率和配额限制
#### 配置示例
```python
{
"gemini": {
"enabled": True,
"api_key": "your_gemini_api_key",
"model": "gemini-pro",
"timeout": 60,
"max_retries": 3,
"retry_delay": 2,
"temperature": 0.7,
"max_output_tokens": 8192
}
}
```
## 数据源切换逻辑
### 市场映射
系统根据股票市场自动选择合适的数据源:
```python
market_mapping = {
"china": "tushare", # 中国A股使用Tushare
"中国": "tushare",
"hongkong": "yahoo", # 香港股票使用Yahoo Finance
"香港": "yahoo",
"usa": "yahoo", # 美国股票使用Yahoo Finance
"美国": "yahoo",
"japan": "yahoo", # 日本股票使用Yahoo Finance
"日本": "yahoo"
}
```
### 故障转移
当主要数据源不可用时,系统会自动切换到备用数据源:
```python
fallback_sources = {
"tushare": ["yahoo"], # Tushare失败时使用Yahoo
"yahoo": ["tushare"] # Yahoo失败时使用Tushare
}
```
## 错误处理和重试机制
### 错误类型
- `DataSourceError`: 数据源相关错误
- `AIAnalysisError`: AI分析相关错误
- `AuthenticationError`: 认证失败
- `RateLimitError`: 调用频率超限
- `APIError`: 通用API错误
### 重试策略
- **指数退避**: 重试间隔逐渐增加
- **最大重试次数**: 默认3次
- **超时处理**: 每个请求都有超时限制
- **错误分类**: 不同错误类型采用不同的重试策略
## API使用示例
### 1. 获取财务数据
```python
from app.services.external_api_service import get_external_api_service
service = get_external_api_service()
# 获取平安银行财务数据
financial_data = await service.get_financial_data("000001", "中国")
print(f"数据源: {financial_data.data_source}")
print(f"总资产: {financial_data.balance_sheet.get('total_assets')}")
```
### 2. 验证股票代码
```python
# 验证股票代码是否有效
validation = await service.validate_stock_symbol("AAPL", "美国")
if validation.is_valid:
print(f"公司名称: {validation.company_name}")
```
### 3. AI分析
```python
# 进行业务信息分析
analysis = await service.analyze_business_info(
"000001", "中国", financial_data.dict()
)
print(f"分析内容: {analysis.content['company_overview']}")
```
### 4. 检查服务状态
```python
# 检查所有外部服务状态
status = await service.check_all_services_status()
print(f"整体状态: {status.overall_status}")
for source in status.sources:
print(f"{source.name}: {'可用' if source.is_available else '不可用'}")
```
## 配置管理
### 环境变量
在`.env`文件中配置API密钥
```bash
# Tushare API Token
TUSHARE_TOKEN=your_tushare_token_here
# Gemini API Key
GEMINI_API_KEY=your_gemini_api_key_here
```
### 动态配置更新
```python
# 更新配置
new_config = {
"data_sources": {
"tushare": {
"enabled": True,
"api_key": "new_token"
}
}
}
service.update_configuration(new_config)
```
## 测试和调试
### 运行测试脚本
```bash
cd backend
python test_external_apis.py
```
### 测试单个数据源
```python
# 测试Tushare连接
result = await service.test_data_source_connection("tushare", {
"api_key": "your_token",
"base_url": "http://api.tushare.pro"
})
print(f"连接成功: {result['success']}")
```
### 测试AI服务
```python
# 测试Gemini连接
result = await service.test_ai_service_connection("gemini", {
"api_key": "your_api_key"
})
print(f"连接成功: {result['success']}")
```
## 性能优化
### 缓存策略
- 财务数据缓存1小时
- 市场数据缓存5分钟
- AI分析结果缓存24小时
### 并发控制
- 限制同时进行的API调用数量
- 使用连接池管理HTTP连接
- 实现请求队列避免频率限制
### 监控和日志
- 记录所有API调用和响应时间
- 监控错误率和成功率
- 设置告警机制
## 故障排除
### 常见问题
1. **Tushare API调用失败**
- 检查API Token是否正确
- 确认账户是否有足够的积分
- 检查网络连接
2. **Gemini API调用失败**
- 检查API Key是否有效
- 确认配额是否充足
- 检查请求格式是否正确
3. **Yahoo Finance被限制**
- 降低请求频率
- 使用代理服务器
- 切换到其他数据源
### 调试技巧
- 启用详细日志记录
- 使用测试脚本验证配置
- 检查网络连接和防火墙设置
- 监控API调用的响应时间和错误率
## 扩展性
### 添加新的数据源
1. 继承`DataFetcher`基类
2. 实现必要的方法
3. 在`DataFetcherFactory`中注册
4. 更新配置文件
### 添加新的AI服务
1. 创建新的分析器类
2. 实现统一的接口
3. 在`AIAnalyzerFactory`中注册
4. 更新配置管理
## 安全考虑
- API密钥加密存储
- 使用HTTPS进行所有外部调用
- 实现访问控制和权限管理
- 定期轮换API密钥
- 监控异常访问模式

100
backend/README.md Normal file
View File

@ -0,0 +1,100 @@
# 基本面选股系统 - 后端服务
基于FastAPI的股票基本面分析后端服务提供报告生成、配置管理和进度追踪功能。
## 项目结构
```
backend/
├── main.py # FastAPI应用入口
├── requirements.txt # Python依赖
├── .env.example # 环境变量示例
├── README.md # 项目说明
└── app/
├── __init__.py
├── core/ # 核心模块
│ ├── config.py # 配置管理
│ ├── database.py # 数据库连接
│ └── dependencies.py # 依赖注入
├── models/ # 数据模型
│ ├── report.py
│ ├── analysis_module.py
│ ├── progress_tracking.py
│ └── system_config.py
├── schemas/ # Pydantic模式
│ ├── report.py
│ ├── config.py
│ └── progress.py
├── services/ # 业务服务
│ ├── config_manager.py
│ ├── data_fetcher.py
│ ├── report_generator.py
│ └── progress_tracker.py
└── routers/ # API路由
├── reports.py
├── config.py
└── progress.py
```
## 快速开始
1. 安装依赖:
```bash
pip install -r requirements.txt
```
2. 配置环境变量:
```bash
cp .env.example .env
# 编辑 .env 文件设置数据库连接和API密钥
```
3. 启动服务:
```bash
uvicorn main:app --reload --host 0.0.0.0 --port 8000
```
4. 访问API文档
- Swagger UI: http://localhost:8000/docs
- ReDoc: http://localhost:8000/redoc
## API端点
### 报告相关
- `GET /api/reports/{symbol}?market={market}` - 获取或创建报告
- `POST /api/reports/{symbol}/regenerate?market={market}` - 重新生成报告
- `GET /api/reports/{report_id}/details` - 获取报告详情
### 配置管理
- `GET /api/config` - 获取系统配置
- `PUT /api/config` - 更新系统配置
- `POST /api/config/test` - 测试配置连接
### 进度追踪
- `GET /api/progress/{report_id}` - 获取报告生成进度
## 数据库
系统使用PostgreSQL数据库包含以下主要表
- `reports` - 报告基本信息
- `analysis_modules` - 分析模块内容
- `progress_tracking` - 进度追踪记录
- `system_config` - 系统配置
## 开发
### 代码格式化
```bash
black app/
isort app/
```
### 类型检查
```bash
mypy app/
```
### 运行测试
```bash
pytest
```

118
backend/alembic.ini Normal file
View File

@ -0,0 +1,118 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
# Use forward slashes (/) also on windows to provide an os agnostic path
script_location = alembic
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to alembic/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
# version_path_separator = newline
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
# sqlalchemy.url = driver://user:pass@localhost/dbname
# Database URL will be set programmatically in env.py
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

1
backend/alembic/README Normal file
View File

@ -0,0 +1 @@
Generic single-database configuration.

102
backend/alembic/env.py Normal file
View File

@ -0,0 +1,102 @@
import asyncio
import os
import sys
from logging.config import fileConfig
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import create_async_engine
from alembic import context
# 添加项目根目录到Python路径
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
# 导入我们的模型和配置
from app.core.database import Base
from app.core.config import settings
from app.models import * # 导入所有模型以确保它们被注册
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
target_metadata = Base.metadata
# 设置数据库URL
config.set_main_option("sqlalchemy.url", settings.DATABASE_URL.replace("+asyncpg", ""))
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
compare_type=True,
compare_server_default=True,
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection: Connection) -> None:
"""运行迁移的核心逻辑"""
context.configure(
connection=connection,
target_metadata=target_metadata,
compare_type=True,
compare_server_default=True,
)
with context.begin_transaction():
context.run_migrations()
async def run_async_migrations() -> None:
"""异步运行迁移"""
connectable = create_async_engine(
settings.DATABASE_URL,
poolclass=pool.NullPool,
)
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
await connectable.dispose()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode."""
asyncio.run(run_async_migrations())
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@ -0,0 +1,26 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View File

@ -0,0 +1,90 @@
"""Initial migration: create all tables
Revision ID: 001
Revises:
Create Date: 2024-01-01 00:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = '001'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
# Create reports table
op.create_table('reports',
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('symbol', sa.String(length=20), nullable=False, comment='证券代码'),
sa.Column('market', sa.String(length=20), nullable=False, comment='交易市场'),
sa.Column('status', sa.String(length=20), nullable=False, comment='报告状态'),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
sa.PrimaryKeyConstraint('id')
)
# Create analysis_modules table
op.create_table('analysis_modules',
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('report_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('module_type', sa.String(length=50), nullable=False, comment='模块类型'),
sa.Column('module_order', sa.Integer(), nullable=False, comment='模块顺序'),
sa.Column('title', sa.String(length=200), nullable=False, comment='模块标题'),
sa.Column('content', postgresql.JSONB(astext_type=sa.Text()), nullable=True, comment='模块内容'),
sa.Column('status', sa.String(length=20), nullable=False, comment='模块状态'),
sa.Column('started_at', sa.DateTime(timezone=True), nullable=True, comment='开始时间'),
sa.Column('completed_at', sa.DateTime(timezone=True), nullable=True, comment='完成时间'),
sa.Column('error_message', sa.Text(), nullable=True, comment='错误信息'),
sa.ForeignKeyConstraint(['report_id'], ['reports.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
# Create progress_tracking table
op.create_table('progress_tracking',
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('report_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('step_name', sa.String(length=100), nullable=False, comment='步骤名称'),
sa.Column('step_order', sa.Integer(), nullable=False, comment='步骤顺序'),
sa.Column('status', sa.String(length=20), nullable=False, comment='步骤状态'),
sa.Column('started_at', sa.DateTime(timezone=True), nullable=True, comment='开始时间'),
sa.Column('completed_at', sa.DateTime(timezone=True), nullable=True, comment='完成时间'),
sa.Column('duration_ms', sa.Integer(), nullable=True, comment='耗时(毫秒)'),
sa.Column('error_message', sa.Text(), nullable=True, comment='错误信息'),
sa.ForeignKeyConstraint(['report_id'], ['reports.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
# Create system_config table
op.create_table('system_config',
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('config_key', sa.String(length=100), nullable=False, comment='配置键'),
sa.Column('config_value', postgresql.JSONB(astext_type=sa.Text()), nullable=False, comment='配置值'),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True, comment='更新时间'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('config_key')
)
# Set default values for status columns
op.execute("ALTER TABLE reports ALTER COLUMN status SET DEFAULT 'generating'")
op.execute("ALTER TABLE analysis_modules ALTER COLUMN status SET DEFAULT 'pending'")
op.execute("ALTER TABLE progress_tracking ALTER COLUMN status SET DEFAULT 'pending'")
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('system_config')
op.drop_table('progress_tracking')
op.drop_table('analysis_modules')
op.drop_table('reports')
# ### end Alembic commands ###

7
backend/app/__init__.py Normal file
View File

@ -0,0 +1,7 @@
"""
基本面选股系统后端应用
"""
__version__ = "1.0.0"
__author__ = "基本面选股系统开发团队"
__description__ = "提供股票基本面分析和报告生成的后端服务"

View File

@ -0,0 +1,48 @@
"""
核心模块
包含配置数据库连接依赖注入和异常处理
"""
from .config import settings, db_config, api_config, data_source_config
from .database import engine, AsyncSessionLocal, Base, get_db, init_db, close_db
from .dependencies import get_database_session, verify_gemini_api_key, verify_tushare_token
from .exceptions import (
StockAnalysisError,
DataSourceError,
AIAnalysisError,
ConfigurationError,
DatabaseError,
ValidationError,
APIError,
ReportGenerationError,
SymbolNotFoundError,
RateLimitError,
AuthenticationError
)
__all__ = [
"settings",
"db_config",
"api_config",
"data_source_config",
"engine",
"AsyncSessionLocal",
"Base",
"get_db",
"init_db",
"close_db",
"get_database_session",
"verify_gemini_api_key",
"verify_tushare_token",
"StockAnalysisError",
"DataSourceError",
"AIAnalysisError",
"ConfigurationError",
"DatabaseError",
"ValidationError",
"APIError",
"ReportGenerationError",
"SymbolNotFoundError",
"RateLimitError",
"AuthenticationError"
]

183
backend/app/core/config.py Normal file
View File

@ -0,0 +1,183 @@
"""
应用配置管理
处理环境变量和系统配置
"""
from typing import List, Optional
from pydantic import validator
from pydantic_settings import BaseSettings
import os
class Settings(BaseSettings):
"""应用设置类"""
# 应用基础配置
APP_NAME: str = "基本面选股系统"
APP_VERSION: str = "1.0.0"
DEBUG: bool = False
# 数据库配置
DATABASE_URL: str = "postgresql+asyncpg://user:password@localhost:5432/stock_analysis"
DATABASE_ECHO: bool = False
# API配置
API_V1_STR: str = "/api"
ALLOWED_ORIGINS: List[str] = ["http://localhost:3000", "http://127.0.0.1:3000"]
# 外部服务配置
GEMINI_API_KEY: Optional[str] = None
TUSHARE_TOKEN: Optional[str] = None
# 数据源配置
CHINA_DATA_SOURCE: str = "tushare"
HK_DATA_SOURCE: str = "yahoo"
US_DATA_SOURCE: str = "yahoo"
JP_DATA_SOURCE: str = "yahoo"
# 报告生成配置
MAX_CONCURRENT_REPORTS: int = 5
REPORT_TIMEOUT_MINUTES: int = 30
# 缓存配置
CACHE_TTL_SECONDS: int = 3600 # 1小时
@validator("ALLOWED_ORIGINS", pre=True)
def assemble_cors_origins(cls, v):
"""处理CORS origins配置"""
if isinstance(v, str):
return [i.strip() for i in v.split(",")]
return v
@validator("DATABASE_URL", pre=True)
def assemble_db_connection(cls, v):
"""处理数据库连接字符串"""
if v and not v.startswith("postgresql"):
raise ValueError("数据库URL必须使用PostgreSQL")
return v
class Config:
env_file = ".env"
case_sensitive = True
# 创建全局设置实例
settings = Settings()
class DatabaseConfig:
"""数据库配置类"""
def __init__(self):
self.url = settings.DATABASE_URL
self.echo = settings.DATABASE_ECHO
self.pool_size = 10
self.max_overflow = 20
self.pool_timeout = 30
self.pool_recycle = 3600
class ExternalAPIConfig:
"""外部API配置类"""
def __init__(self):
self.gemini_api_key = settings.GEMINI_API_KEY
self.tushare_token = settings.TUSHARE_TOKEN
# 数据源配置
self.data_sources_config = {
"tushare": {
"enabled": bool(self.tushare_token),
"api_key": self.tushare_token,
"token": self.tushare_token,
"base_url": "http://api.tushare.pro",
"timeout": 30,
"max_retries": 3,
"retry_delay": 1,
"name": "tushare"
},
"yahoo": {
"enabled": True,
"base_url": "https://query1.finance.yahoo.com",
"timeout": 30,
"max_retries": 3,
"retry_delay": 1,
"name": "yahoo"
}
}
# AI服务配置
self.ai_services_config = {
"gemini": {
"enabled": bool(self.gemini_api_key),
"api_key": self.gemini_api_key,
"model": "gemini-pro",
"base_url": "https://generativelanguage.googleapis.com/v1beta",
"timeout": 60,
"max_retries": 3,
"retry_delay": 2,
"temperature": 0.7,
"top_p": 0.8,
"top_k": 40,
"max_output_tokens": 8192
}
}
def validate_gemini_config(self) -> bool:
"""验证Gemini API配置"""
return bool(self.gemini_api_key)
def validate_tushare_config(self) -> bool:
"""验证Tushare API配置"""
return bool(self.tushare_token)
def get_data_source_manager_config(self) -> dict:
"""获取数据源管理器配置"""
return {
"data_sources": self.data_sources_config,
"ai_services": self.ai_services_config,
"market_mapping": {
"china": "tushare",
"中国": "tushare",
"hongkong": "yahoo",
"香港": "yahoo",
"usa": "yahoo",
"美国": "yahoo",
"japan": "yahoo",
"日本": "yahoo"
},
"fallback_sources": {
"tushare": ["yahoo"],
"yahoo": ["tushare"]
}
}
class DataSourceConfig:
"""数据源配置类"""
def __init__(self):
self.sources = {
"china": settings.CHINA_DATA_SOURCE,
"hongkong": settings.HK_DATA_SOURCE,
"usa": settings.US_DATA_SOURCE,
"japan": settings.JP_DATA_SOURCE
}
def get_source_for_market(self, market: str) -> str:
"""根据市场获取数据源"""
market_mapping = {
"中国": "china",
"香港": "hongkong",
"美国": "usa",
"日本": "japan"
}
market_key = market_mapping.get(market, "china")
return self.sources.get(market_key, "tushare")
# 创建配置实例
db_config = DatabaseConfig()
api_config = ExternalAPIConfig()
data_source_config = DataSourceConfig()

View File

@ -0,0 +1,69 @@
"""
数据库连接和会话管理
"""
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.pool import NullPool
from typing import AsyncGenerator
from .config import settings
# 创建异步数据库引擎
engine = create_async_engine(
settings.DATABASE_URL,
echo=settings.DATABASE_ECHO,
poolclass=NullPool, # 对于异步使用NullPool
future=True
)
# 创建异步会话工厂
AsyncSessionLocal = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False
)
# 创建基础模型类
Base = declarative_base()
async def get_db() -> AsyncGenerator[AsyncSession, None]:
"""
数据库会话依赖注入
用于FastAPI路由中获取数据库会话
"""
async with AsyncSessionLocal() as session:
try:
yield session
except Exception:
await session.rollback()
raise
finally:
await session.close()
async def init_db():
"""初始化数据库表"""
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async def close_db():
"""关闭数据库连接"""
await engine.dispose()
async def check_db_connection() -> bool:
"""检查数据库连接是否正常"""
try:
async with AsyncSessionLocal() as session:
await session.execute("SELECT 1")
return True
except Exception:
return False
async def get_db_session():
"""获取数据库会话的简化版本"""
return AsyncSessionLocal()

View File

@ -0,0 +1,51 @@
"""
FastAPI依赖注入
"""
from fastapi import Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from typing import AsyncGenerator
from .database import get_db
from .config import settings, api_config
async def get_database_session() -> AsyncGenerator[AsyncSession, None]:
"""获取数据库会话依赖"""
async for session in get_db():
yield session
def verify_gemini_api_key():
"""验证Gemini API密钥"""
if not api_config.validate_gemini_config():
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Gemini API未配置或配置无效"
)
return api_config.gemini_api_key
def verify_tushare_token():
"""验证Tushare Token"""
if not api_config.validate_tushare_config():
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Tushare API未配置或配置无效"
)
return api_config.tushare_token
class DatabaseDependency:
"""数据库依赖类"""
def __init__(self):
self.session = Depends(get_database_session)
class APIKeyDependency:
"""API密钥依赖类"""
def __init__(self):
self.gemini_key = Depends(verify_gemini_api_key)
self.tushare_token = Depends(verify_tushare_token)

View File

@ -0,0 +1,98 @@
"""
基础错误处理和异常类
定义系统中使用的所有自定义异常
"""
from typing import Optional, Dict, Any
class StockAnalysisError(Exception):
"""基础异常类"""
def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
self.message = message
self.details = details or {}
super().__init__(self.message)
class DataSourceError(StockAnalysisError):
"""数据源错误"""
def __init__(self, message: str, data_source: Optional[str] = None, details: Optional[Dict[str, Any]] = None):
self.data_source = data_source
super().__init__(message, details)
class AIAnalysisError(StockAnalysisError):
"""AI分析错误"""
def __init__(self, message: str, model: Optional[str] = None, details: Optional[Dict[str, Any]] = None):
self.model = model
super().__init__(message, details)
class ConfigurationError(StockAnalysisError):
"""配置错误"""
def __init__(self, message: str, config_key: Optional[str] = None, details: Optional[Dict[str, Any]] = None):
self.config_key = config_key
super().__init__(message, details)
class DatabaseError(StockAnalysisError):
"""数据库错误"""
def __init__(self, message: str, operation: Optional[str] = None, details: Optional[Dict[str, Any]] = None):
self.operation = operation
super().__init__(message, details)
class ValidationError(StockAnalysisError):
"""数据验证错误"""
def __init__(self, message: str, field: Optional[str] = None, details: Optional[Dict[str, Any]] = None):
self.field = field
super().__init__(message, details)
class APIError(StockAnalysisError):
"""API调用错误"""
def __init__(self, message: str, status_code: Optional[int] = None, api_name: Optional[str] = None, details: Optional[Dict[str, Any]] = None):
self.status_code = status_code
self.api_name = api_name
super().__init__(message, details)
class ReportGenerationError(StockAnalysisError):
"""报告生成错误"""
def __init__(self, message: str, module_type: Optional[str] = None, details: Optional[Dict[str, Any]] = None):
self.module_type = module_type
super().__init__(message, details)
class SymbolNotFoundError(DataSourceError):
"""证券代码未找到错误"""
def __init__(self, symbol: str, market: str, data_source: Optional[str] = None):
message = f"证券代码 {symbol} 在市场 {market} 中未找到"
details = {"symbol": symbol, "market": market}
super().__init__(message, data_source, details)
class RateLimitError(APIError):
"""API调用频率限制错误"""
def __init__(self, api_name: str, retry_after: Optional[int] = None):
message = f"API {api_name} 调用频率超限"
details = {"retry_after": retry_after} if retry_after else {}
super().__init__(message, 429, api_name, details)
class AuthenticationError(APIError):
"""API认证错误"""
def __init__(self, api_name: str, details: Optional[Dict[str, Any]] = None):
message = f"API {api_name} 认证失败"
super().__init__(message, 401, api_name, details)

View File

@ -0,0 +1,17 @@
"""
数据模型包
包含SQLAlchemy数据模型定义
"""
# 导入所有模型以确保它们被注册到Base.metadata
from .report import Report
from .analysis_module import AnalysisModule
from .progress_tracking import ProgressTracking
from .system_config import SystemConfig
__all__ = [
"Report",
"AnalysisModule",
"ProgressTracking",
"SystemConfig"
]

View File

@ -0,0 +1,41 @@
"""
分析模块数据模型
"""
from sqlalchemy import Column, String, Integer, DateTime, Text, ForeignKey, Index
from sqlalchemy.dialects.postgresql import UUID, JSONB
from sqlalchemy.orm import relationship
import uuid
from ..core.database import Base
class AnalysisModule(Base):
"""分析模块表模型"""
__tablename__ = "analysis_modules"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
report_id = Column(UUID(as_uuid=True), ForeignKey("reports.id", ondelete="CASCADE"), nullable=False)
module_type = Column(String(50), nullable=False, comment="模块类型")
module_order = Column(Integer, nullable=False, comment="模块顺序")
title = Column(String(200), nullable=False, comment="模块标题")
content = Column(JSONB, comment="模块内容")
status = Column(String(20), nullable=False, default="pending", comment="模块状态")
started_at = Column(DateTime(timezone=True), comment="开始时间")
completed_at = Column(DateTime(timezone=True), comment="完成时间")
error_message = Column(Text, comment="错误信息")
# 关系
report = relationship("Report", back_populates="analysis_modules")
# 索引
__table_args__ = (
Index('idx_analysis_module_report_id', 'report_id'),
Index('idx_analysis_module_type', 'module_type'),
Index('idx_analysis_module_status', 'status'),
Index('idx_analysis_module_order', 'module_order'),
)
def __repr__(self):
return f"<AnalysisModule(id={self.id}, type={self.module_type}, status={self.status})>"

View File

@ -0,0 +1,39 @@
"""
进度追踪数据模型
"""
from sqlalchemy import Column, String, Integer, DateTime, Text, ForeignKey, Index
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
import uuid
from ..core.database import Base
class ProgressTracking(Base):
"""进度追踪表模型"""
__tablename__ = "progress_tracking"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
report_id = Column(UUID(as_uuid=True), ForeignKey("reports.id", ondelete="CASCADE"), nullable=False)
step_name = Column(String(100), nullable=False, comment="步骤名称")
step_order = Column(Integer, nullable=False, comment="步骤顺序")
status = Column(String(20), nullable=False, default="pending", comment="步骤状态")
started_at = Column(DateTime(timezone=True), comment="开始时间")
completed_at = Column(DateTime(timezone=True), comment="完成时间")
duration_ms = Column(Integer, comment="耗时(毫秒)")
error_message = Column(Text, comment="错误信息")
# 关系
report = relationship("Report", back_populates="progress_tracking")
# 索引
__table_args__ = (
Index('idx_progress_tracking_report_id', 'report_id'),
Index('idx_progress_tracking_status', 'status'),
Index('idx_progress_tracking_order', 'step_order'),
)
def __repr__(self):
return f"<ProgressTracking(id={self.id}, step={self.step_name}, status={self.status})>"

View File

@ -0,0 +1,38 @@
"""
报告数据模型
"""
from sqlalchemy import Column, String, DateTime, Text, Index
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
import uuid
from ..core.database import Base
class Report(Base):
"""报告表模型"""
__tablename__ = "reports"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
symbol = Column(String(20), nullable=False, comment="证券代码")
market = Column(String(20), nullable=False, comment="交易市场")
status = Column(String(20), nullable=False, default="generating", comment="报告状态")
created_at = Column(DateTime(timezone=True), server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), comment="更新时间")
# 关系
analysis_modules = relationship("AnalysisModule", back_populates="report", cascade="all, delete-orphan")
progress_tracking = relationship("ProgressTracking", back_populates="report", cascade="all, delete-orphan")
# 索引
__table_args__ = (
Index('idx_report_symbol_market', 'symbol', 'market'),
Index('idx_report_status', 'status'),
Index('idx_report_created_at', 'created_at'),
)
def __repr__(self):
return f"<Report(id={self.id}, symbol={self.symbol}, market={self.market}, status={self.status})>"

View File

@ -0,0 +1,30 @@
"""
系统配置数据模型
"""
from sqlalchemy import Column, String, DateTime, Index
from sqlalchemy.dialects.postgresql import UUID, JSONB
from sqlalchemy.sql import func
import uuid
from ..core.database import Base
class SystemConfig(Base):
"""系统配置表模型"""
__tablename__ = "system_config"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
config_key = Column(String(100), unique=True, nullable=False, comment="配置键")
config_value = Column(JSONB, nullable=False, comment="配置值")
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), comment="更新时间")
# 索引
__table_args__ = (
Index('idx_system_config_key', 'config_key'),
Index('idx_system_config_updated_at', 'updated_at'),
)
def __repr__(self):
return f"<SystemConfig(key={self.config_key})>"

View File

@ -0,0 +1,8 @@
"""
API路由包
包含所有API端点定义
"""
from . import reports, config, progress
__all__ = ["reports", "config", "progress"]

View File

@ -0,0 +1,124 @@
"""
配置相关API路由
"""
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
import logging
from ..core.dependencies import get_database_session
from ..schemas.config import ConfigResponse, ConfigUpdateRequest, ConfigTestRequest, ConfigTestResponse
from ..services.config_manager import ConfigManager
from ..core.exceptions import ConfigurationError, DatabaseError
logger = logging.getLogger(__name__)
router = APIRouter()
@router.get("/", response_model=ConfigResponse)
async def get_config(
db: AsyncSession = Depends(get_database_session)
):
"""获取系统配置"""
try:
config_manager = ConfigManager(db)
config = await config_manager.get_config()
logger.info("获取系统配置成功")
return config
except DatabaseError as e:
logger.error(f"获取配置时数据库错误: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"数据库错误: {str(e)}"
)
except Exception as e:
logger.error(f"获取配置失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"获取配置失败: {str(e)}"
)
@router.put("/", response_model=ConfigResponse)
async def update_config(
config_update: ConfigUpdateRequest,
db: AsyncSession = Depends(get_database_session)
):
"""更新系统配置"""
try:
# 验证至少有一个配置项需要更新
if not any([config_update.database, config_update.gemini_api, config_update.data_sources]):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="至少需要提供一个配置项进行更新"
)
config_manager = ConfigManager(db)
updated_config = await config_manager.update_config(config_update)
logger.info("更新系统配置成功")
return updated_config
except HTTPException:
raise
except ConfigurationError as e:
logger.error(f"配置错误: {str(e)}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"配置错误: {str(e)}"
)
except DatabaseError as e:
logger.error(f"更新配置时数据库错误: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"数据库错误: {str(e)}"
)
except Exception as e:
logger.error(f"更新配置失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"更新配置失败: {str(e)}"
)
@router.post("/test", response_model=ConfigTestResponse)
async def test_config(
test_request: ConfigTestRequest,
db: AsyncSession = Depends(get_database_session)
):
"""测试配置连接"""
try:
# 验证配置类型
valid_types = ["database", "gemini", "data_source"]
if test_request.config_type not in valid_types:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"不支持的配置类型: {test_request.config_type},支持的类型: {valid_types}"
)
config_manager = ConfigManager(db)
test_result = await config_manager.test_config(
test_request.config_type,
test_request.config_data
)
logger.info(f"配置测试完成: {test_request.config_type}, 结果: {test_result.success}")
return test_result
except HTTPException:
raise
except ConfigurationError as e:
logger.error(f"配置测试错误: {str(e)}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"配置测试错误: {str(e)}"
)
except Exception as e:
logger.error(f"配置测试失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"配置测试失败: {str(e)}"
)

View File

@ -0,0 +1,82 @@
"""
进度追踪API路由
"""
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from uuid import UUID
import logging
from ..core.dependencies import get_database_session
from ..schemas.progress import ProgressResponse
from ..services.progress_tracker import ProgressTracker
from ..core.exceptions import DatabaseError
logger = logging.getLogger(__name__)
router = APIRouter()
@router.get("/{report_id}", response_model=ProgressResponse)
async def get_report_progress(
report_id: UUID,
db: AsyncSession = Depends(get_database_session)
):
"""获取报告生成进度"""
try:
progress_tracker = ProgressTracker(db)
progress = await progress_tracker.get_progress(report_id)
logger.info(f"获取进度成功: {report_id}, 当前步骤: {progress.current_step}/{progress.total_steps}")
return progress
except ValueError as e:
logger.warning(f"报告不存在或无进度记录: {report_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e)
)
except DatabaseError as e:
logger.error(f"获取进度时数据库错误: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"数据库错误: {str(e)}"
)
except Exception as e:
logger.error(f"获取进度失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"获取进度失败: {str(e)}"
)
@router.post("/{report_id}/reset")
async def reset_report_progress(
report_id: UUID,
db: AsyncSession = Depends(get_database_session)
):
"""重置报告生成进度"""
try:
progress_tracker = ProgressTracker(db)
await progress_tracker.reset_progress(report_id)
logger.info(f"重置进度成功: {report_id}")
return {"message": "进度重置成功", "report_id": str(report_id)}
except ValueError as e:
logger.warning(f"报告不存在: {report_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e)
)
except DatabaseError as e:
logger.error(f"重置进度时数据库错误: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"数据库错误: {str(e)}"
)
except Exception as e:
logger.error(f"重置进度失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"重置进度失败: {str(e)}"
)

View File

@ -0,0 +1,298 @@
"""
报告相关API路由
"""
from fastapi import APIRouter, Depends, HTTPException, status, Query, BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from typing import Optional, List
from uuid import UUID
import logging
from ..core.dependencies import get_database_session
from ..models.report import Report
from ..schemas.report import ReportResponse, RegenerateRequest
from ..services.report_generator import ReportGenerator
from ..services.config_manager import ConfigManager
from ..core.exceptions import ReportGenerationError, DatabaseError
logger = logging.getLogger(__name__)
router = APIRouter()
@router.get("/{symbol}", response_model=ReportResponse)
async def get_or_create_report(
symbol: str,
background_tasks: BackgroundTasks,
market: str = Query(..., description="交易市场"),
db: AsyncSession = Depends(get_database_session)
):
"""获取或创建股票报告"""
try:
# 验证输入参数
symbol = symbol.upper().strip()
market = market.lower().strip()
if not symbol:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="证券代码不能为空"
)
if market not in ["china", "hongkong", "usa", "japan"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="不支持的交易市场"
)
# 查询现有报告(包含关联的分析模块)
result = await db.execute(
select(Report)
.options(selectinload(Report.analysis_modules))
.where(
Report.symbol == symbol,
Report.market == market
)
)
existing_report = result.scalar_one_or_none()
if existing_report:
logger.info(f"找到现有报告: {symbol}-{market}, 状态: {existing_report.status}")
return ReportResponse.from_attributes(existing_report)
# 创建新报告
logger.info(f"开始生成新报告: {symbol}-{market}")
config_manager = ConfigManager(db)
report_generator = ReportGenerator(db, config_manager)
# 在后台任务中生成报告
background_tasks.add_task(
report_generator.generate_report_async,
symbol,
market
)
# 创建初始报告记录
new_report = Report(
symbol=symbol,
market=market,
status="generating"
)
db.add(new_report)
await db.commit()
await db.refresh(new_report)
logger.info(f"创建报告记录: {new_report.id}")
return ReportResponse.from_attributes(new_report)
except HTTPException:
raise
except Exception as e:
logger.error(f"获取或创建报告失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"获取或创建报告失败: {str(e)}"
)
@router.post("/{symbol}/regenerate", response_model=ReportResponse)
async def regenerate_report(
symbol: str,
request: RegenerateRequest,
background_tasks: BackgroundTasks,
market: str = Query(..., description="交易市场"),
db: AsyncSession = Depends(get_database_session)
):
"""重新生成报告"""
try:
# 验证输入参数
symbol = symbol.upper().strip()
market = market.lower().strip()
if not symbol:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="证券代码不能为空"
)
if market not in ["china", "hongkong", "usa", "japan"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="不支持的交易市场"
)
# 查询现有报告
result = await db.execute(
select(Report)
.options(selectinload(Report.analysis_modules))
.where(
Report.symbol == symbol,
Report.market == market
)
)
existing_report = result.scalar_one_or_none()
if existing_report and not request.force:
# 如果报告存在且不强制重新生成,返回现有报告
logger.info(f"返回现有报告: {symbol}-{market}")
return ReportResponse.from_attributes(existing_report)
# 删除现有报告(如果存在)
if existing_report:
logger.info(f"删除现有报告: {existing_report.id}")
await db.delete(existing_report)
await db.commit()
# 创建新报告记录
new_report = Report(
symbol=symbol,
market=market,
status="generating"
)
db.add(new_report)
await db.commit()
await db.refresh(new_report)
# 在后台任务中生成报告
config_manager = ConfigManager(db)
report_generator = ReportGenerator(db, config_manager)
background_tasks.add_task(
report_generator.generate_report_async,
symbol,
market
)
logger.info(f"开始重新生成报告: {new_report.id}")
return ReportResponse.from_attributes(new_report)
except HTTPException:
raise
except Exception as e:
logger.error(f"重新生成报告失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"重新生成报告失败: {str(e)}"
)
@router.get("/{report_id}/details", response_model=ReportResponse)
async def get_report_details(
report_id: UUID,
db: AsyncSession = Depends(get_database_session)
):
"""获取报告详情"""
try:
result = await db.execute(
select(Report)
.options(selectinload(Report.analysis_modules))
.where(Report.id == report_id)
)
report = result.scalar_one_or_none()
if not report:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="报告不存在"
)
logger.info(f"获取报告详情: {report_id}")
return ReportResponse.from_attributes(report)
except HTTPException:
raise
except Exception as e:
logger.error(f"获取报告详情失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"获取报告详情失败: {str(e)}"
)
@router.get("/", response_model=List[ReportResponse])
async def list_reports(
skip: int = Query(0, ge=0, description="跳过的记录数"),
limit: int = Query(10, ge=1, le=100, description="返回的记录数"),
market: Optional[str] = Query(None, description="按市场筛选"),
status: Optional[str] = Query(None, description="按状态筛选"),
db: AsyncSession = Depends(get_database_session)
):
"""获取报告列表"""
try:
query = select(Report).options(selectinload(Report.analysis_modules))
# 添加筛选条件
if market:
market = market.lower().strip()
if market not in ["china", "hongkong", "usa", "japan"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="不支持的交易市场"
)
query = query.where(Report.market == market)
if status:
status_value = status.lower().strip()
if status_value not in ["generating", "completed", "failed"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="不支持的状态值"
)
query = query.where(Report.status == status_value)
# 添加分页和排序
query = query.order_by(Report.created_at.desc()).offset(skip).limit(limit)
result = await db.execute(query)
reports = result.scalars().all()
logger.info(f"获取报告列表: {len(reports)} 条记录")
return [ReportResponse.from_attributes(report) for report in reports]
except HTTPException:
raise
except Exception as e:
logger.error(f"获取报告列表失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"获取报告列表失败: {str(e)}"
)
@router.delete("/{report_id}")
async def delete_report(
report_id: UUID,
db: AsyncSession = Depends(get_database_session)
):
"""删除报告"""
try:
result = await db.execute(
select(Report).where(Report.id == report_id)
)
report = result.scalar_one_or_none()
if not report:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="报告不存在"
)
await db.delete(report)
await db.commit()
logger.info(f"删除报告成功: {report_id}")
return {"message": "报告删除成功", "report_id": str(report_id)}
except HTTPException:
raise
except Exception as e:
logger.error(f"删除报告失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"删除报告失败: {str(e)}"
)

View File

@ -0,0 +1,48 @@
"""
Pydantic数据验证模式包
"""
from .report import ReportResponse, ReportCreate, ReportUpdate
from .config import (
ConfigResponse,
ConfigUpdateRequest,
ConfigTestRequest,
ConfigTestResponse,
DatabaseConfig,
GeminiConfig,
DataSourceConfig
)
from .progress import ProgressResponse, StepTiming
from .data import (
FinancialDataRequest,
MarketDataRequest,
FinancialDataResponse,
MarketDataResponse,
SymbolValidationRequest,
SymbolValidationResponse,
DataSourceStatus,
DataSourcesStatusResponse
)
__all__ = [
"ReportResponse",
"ReportCreate",
"ReportUpdate",
"ConfigResponse",
"ConfigUpdateRequest",
"ConfigTestRequest",
"ConfigTestResponse",
"DatabaseConfig",
"GeminiConfig",
"DataSourceConfig",
"ProgressResponse",
"StepTiming",
"FinancialDataRequest",
"MarketDataRequest",
"FinancialDataResponse",
"MarketDataResponse",
"SymbolValidationRequest",
"SymbolValidationResponse",
"DataSourceStatus",
"DataSourcesStatusResponse"
]

View File

@ -0,0 +1,62 @@
"""
配置相关的Pydantic模式
"""
from pydantic import BaseModel, Field
from typing import Dict, Any, Optional, List
class DatabaseConfig(BaseModel):
"""数据库配置模式"""
url: str = Field(..., description="数据库连接URL")
echo: bool = Field(False, description="是否输出SQL日志")
class GeminiConfig(BaseModel):
"""Gemini API配置模式"""
api_key: str = Field(..., description="Gemini API密钥")
model: str = Field("gemini-pro", description="使用的模型")
temperature: float = Field(0.7, description="生成温度")
max_tokens: int = Field(2048, description="最大token数")
class DataSourceConfig(BaseModel):
"""数据源配置模式"""
name: str = Field(..., description="数据源名称")
api_key: Optional[str] = Field(None, description="API密钥")
base_url: Optional[str] = Field(None, description="基础URL")
timeout: int = Field(30, description="超时时间(秒)")
class ConfigResponse(BaseModel):
"""配置响应模式"""
database: Optional[DatabaseConfig] = None
gemini_api: Optional[GeminiConfig] = None
data_sources: Dict[str, DataSourceConfig] = {}
class ConfigUpdateRequest(BaseModel):
"""配置更新请求模式"""
database: Optional[DatabaseConfig] = None
gemini_api: Optional[GeminiConfig] = None
data_sources: Optional[Dict[str, DataSourceConfig]] = None
class ConfigTestRequest(BaseModel):
"""配置测试请求模式"""
config_type: str = Field(..., description="配置类型")
config_data: Dict[str, Any] = Field(..., description="配置数据")
class ConfigTestResponse(BaseModel):
"""配置测试响应模式"""
success: bool = Field(..., description="测试是否成功")
message: str = Field(..., description="测试结果消息")
details: Optional[Dict[str, Any]] = Field(None, description="详细信息")
class ConfigValidationResponse(BaseModel):
"""配置验证响应模式"""
valid: bool = Field(..., description="配置是否有效")
errors: List[str] = Field([], description="验证错误列表")
warnings: List[str] = Field([], description="验证警告列表")

View File

@ -0,0 +1,78 @@
"""
数据相关的Pydantic模式
"""
from pydantic import BaseModel, Field
from typing import Dict, Any, Optional, List
from datetime import datetime
class FinancialDataRequest(BaseModel):
"""财务数据请求模式"""
symbol: str = Field(..., description="证券代码")
market: str = Field(..., description="交易市场")
data_type: str = Field("all", description="数据类型")
period: Optional[str] = Field("annual", description="数据周期")
class MarketDataRequest(BaseModel):
"""市场数据请求模式"""
symbol: str = Field(..., description="证券代码")
market: str = Field(..., description="交易市场")
start_date: Optional[datetime] = Field(None, description="开始日期")
end_date: Optional[datetime] = Field(None, description="结束日期")
class FinancialDataResponse(BaseModel):
"""财务数据响应模式"""
symbol: str
market: str
data_source: str
last_updated: datetime
balance_sheet: Optional[Dict[str, Any]] = None
income_statement: Optional[Dict[str, Any]] = None
cash_flow: Optional[Dict[str, Any]] = None
key_metrics: Optional[Dict[str, Any]] = None
class MarketDataResponse(BaseModel):
"""市场数据响应模式"""
symbol: str
market: str
data_source: str
last_updated: datetime
price_data: Optional[Dict[str, Any]] = None
volume_data: Optional[Dict[str, Any]] = None
technical_indicators: Optional[Dict[str, Any]] = None
class SymbolValidationRequest(BaseModel):
"""证券代码验证请求模式"""
symbol: str = Field(..., description="证券代码")
market: str = Field(..., description="交易市场")
class SymbolValidationResponse(BaseModel):
"""证券代码验证响应模式"""
symbol: str
market: str
is_valid: bool
company_name: Optional[str] = None
sector: Optional[str] = None
industry: Optional[str] = None
message: Optional[str] = None
class DataSourceStatus(BaseModel):
"""数据源状态模式"""
name: str
is_available: bool
last_check: datetime
response_time_ms: Optional[int] = None
error_message: Optional[str] = None
class DataSourcesStatusResponse(BaseModel):
"""数据源状态响应模式"""
sources: List[DataSourceStatus]
overall_status: str # "healthy", "degraded", "down"

View File

@ -0,0 +1,42 @@
"""
进度追踪相关的Pydantic模式
"""
from pydantic import BaseModel, Field
from typing import List, Optional
from datetime import datetime
from uuid import UUID
class StepTiming(BaseModel):
"""步骤计时模式"""
step_name: str = Field(..., description="步骤名称")
step_order: int = Field(..., description="步骤顺序")
status: str = Field(..., description="步骤状态")
started_at: Optional[datetime] = Field(None, description="开始时间")
completed_at: Optional[datetime] = Field(None, description="完成时间")
duration_ms: Optional[int] = Field(None, description="耗时(毫秒)")
error_message: Optional[str] = Field(None, description="错误信息")
class Config:
from_attributes = True
class ProgressResponse(BaseModel):
"""进度响应模式"""
report_id: UUID = Field(..., description="报告ID")
current_step: int = Field(..., description="当前步骤")
total_steps: int = Field(..., description="总步骤数")
current_step_name: str = Field(..., description="当前步骤名称")
status: str = Field(..., description="整体状态")
step_timings: List[StepTiming] = Field([], description="步骤计时列表")
estimated_remaining: Optional[int] = Field(None, description="预估剩余时间(秒)")
class Config:
from_attributes = True
class ProgressResetResponse(BaseModel):
"""进度重置响应模式"""
message: str = Field(..., description="操作结果消息")
report_id: str = Field(..., description="报告ID")

View File

@ -0,0 +1,91 @@
"""
报告相关的Pydantic模式
"""
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any
from datetime import datetime
from uuid import UUID
class AnalysisModuleSchema(BaseModel):
"""分析模块模式"""
id: UUID
module_type: str
module_order: int
title: str
content: Optional[Dict[str, Any]] = None
status: str
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
error_message: Optional[str] = None
class Config:
from_attributes = True
class ReportBase(BaseModel):
"""报告基础模式"""
symbol: str = Field(..., description="证券代码")
market: str = Field(..., description="交易市场")
class ReportCreate(ReportBase):
"""创建报告请求模式"""
pass
class ReportUpdate(BaseModel):
"""更新报告请求模式"""
status: Optional[str] = None
class ReportResponse(ReportBase):
"""报告响应模式"""
id: UUID
status: str
created_at: datetime
updated_at: datetime
analysis_modules: List[AnalysisModuleSchema] = []
class Config:
from_attributes = True
class ReportListResponse(BaseModel):
"""报告列表响应模式"""
reports: List[ReportResponse]
total: int
skip: int
limit: int
class DeleteReportResponse(BaseModel):
"""删除报告响应模式"""
message: str
report_id: str
class RegenerateRequest(BaseModel):
"""重新生成报告请求模式"""
force: bool = Field(False, description="是否强制重新生成")
class AIAnalysisRequest(BaseModel):
"""AI分析请求模式"""
symbol: str = Field(..., description="证券代码")
market: str = Field(..., description="交易市场")
analysis_type: str = Field(..., description="分析类型")
context_data: Optional[Dict[str, Any]] = Field(None, description="上下文数据")
class AIAnalysisResponse(BaseModel):
"""AI分析响应模式"""
symbol: str
market: str
analysis_type: str
content: Dict[str, Any]
model_used: str
generated_at: datetime
model_config = {"protected_namespaces": ()}

View File

@ -0,0 +1,16 @@
"""
业务服务包
包含核心业务逻辑实现
"""
from .config_manager import ConfigManager
from .data_fetcher import DataFetcher
from .report_generator import ReportGenerator
from .progress_tracker import ProgressTracker
__all__ = [
"ConfigManager",
"DataFetcher",
"ReportGenerator",
"ProgressTracker"
]

View File

@ -0,0 +1,803 @@
"""
AI分析服务
处理Gemini API集成和AI分析功能
"""
from typing import Dict, Any, Optional, List
import httpx
import asyncio
import json
from datetime import datetime
from ..core.exceptions import (
AIAnalysisError,
APIError,
AuthenticationError,
RateLimitError
)
from ..schemas.report import AIAnalysisRequest, AIAnalysisResponse
class GeminiAnalyzer:
"""Gemini AI分析器"""
def __init__(self, api_key: str, config: Optional[Dict[str, Any]] = None):
if not api_key:
raise AuthenticationError("gemini", {"message": "Gemini API密钥未配置"})
self.api_key = api_key
self.config = config or {}
self.base_url = self.config.get("base_url", "https://generativelanguage.googleapis.com/v1beta")
self.model = self.config.get("model", "gemini-pro")
self.timeout = self.config.get("timeout", 60)
self.max_retries = self.config.get("max_retries", 3)
self.retry_delay = self.config.get("retry_delay", 2)
# 生成配置
self.generation_config = {
"temperature": self.config.get("temperature", 0.7),
"top_p": self.config.get("top_p", 0.8),
"top_k": self.config.get("top_k", 40),
"max_output_tokens": self.config.get("max_output_tokens", 8192),
}
async def analyze_business_info(self, symbol: str, market: str, financial_data: Dict[str, Any]) -> AIAnalysisResponse:
"""分析公司业务信息"""
prompt = self._build_business_info_prompt(symbol, market, financial_data)
try:
result = await self._retry_request(self._call_gemini_api, prompt)
return AIAnalysisResponse(
symbol=symbol,
market=market,
analysis_type="business_info",
content=self._parse_business_info_response(result),
model_used=self.model,
generated_at=datetime.now()
)
except Exception as e:
if isinstance(e, (AIAnalysisError, APIError)):
raise
raise AIAnalysisError(f"业务信息分析失败: {str(e)}", self.model)
async def analyze_fundamental(self, symbol: str, market: str, financial_data: Dict[str, Any], business_info: Dict[str, Any]) -> AIAnalysisResponse:
"""基本面分析(景林模型)"""
prompt = self._build_fundamental_analysis_prompt(symbol, market, financial_data, business_info)
try:
result = await self._retry_request(self._call_gemini_api, prompt)
return AIAnalysisResponse(
symbol=symbol,
market=market,
analysis_type="fundamental_analysis",
content=self._parse_fundamental_response(result),
model_used=self.model,
generated_at=datetime.now()
)
except Exception as e:
if isinstance(e, (AIAnalysisError, APIError)):
raise
raise AIAnalysisError(f"基本面分析失败: {str(e)}", self.model)
async def analyze_bullish_case(self, symbol: str, market: str, context_data: Dict[str, Any]) -> AIAnalysisResponse:
"""看涨分析(隐藏资产、护城河分析)"""
prompt = self._build_bullish_analysis_prompt(symbol, market, context_data)
try:
result = await self._retry_request(self._call_gemini_api, prompt)
return AIAnalysisResponse(
symbol=symbol,
market=market,
analysis_type="bullish_analysis",
content=self._parse_bullish_response(result),
model_used=self.model,
generated_at=datetime.now()
)
except Exception as e:
if isinstance(e, (AIAnalysisError, APIError)):
raise
raise AIAnalysisError(f"看涨分析失败: {str(e)}", self.model)
async def analyze_bearish_case(self, symbol: str, market: str, context_data: Dict[str, Any]) -> AIAnalysisResponse:
"""看跌分析(价值底线、最坏情况分析)"""
prompt = self._build_bearish_analysis_prompt(symbol, market, context_data)
try:
result = await self._retry_request(self._call_gemini_api, prompt)
return AIAnalysisResponse(
symbol=symbol,
market=market,
analysis_type="bearish_analysis",
content=self._parse_bearish_response(result),
model_used=self.model,
generated_at=datetime.now()
)
except Exception as e:
if isinstance(e, (AIAnalysisError, APIError)):
raise
raise AIAnalysisError(f"看跌分析失败: {str(e)}", self.model)
async def analyze_market_sentiment(self, symbol: str, market: str, context_data: Dict[str, Any]) -> AIAnalysisResponse:
"""市场情绪分析"""
prompt = self._build_market_analysis_prompt(symbol, market, context_data)
try:
result = await self._retry_request(self._call_gemini_api, prompt)
return AIAnalysisResponse(
symbol=symbol,
market=market,
analysis_type="market_analysis",
content=self._parse_market_response(result),
model_used=self.model,
generated_at=datetime.now()
)
except Exception as e:
if isinstance(e, (AIAnalysisError, APIError)):
raise
raise AIAnalysisError(f"市场分析失败: {str(e)}", self.model)
async def analyze_news_catalysts(self, symbol: str, market: str, context_data: Dict[str, Any]) -> AIAnalysisResponse:
"""新闻催化剂分析"""
prompt = self._build_news_analysis_prompt(symbol, market, context_data)
try:
result = await self._retry_request(self._call_gemini_api, prompt)
return AIAnalysisResponse(
symbol=symbol,
market=market,
analysis_type="news_analysis",
content=self._parse_news_response(result),
model_used=self.model,
generated_at=datetime.now()
)
except Exception as e:
if isinstance(e, (AIAnalysisError, APIError)):
raise
raise AIAnalysisError(f"新闻分析失败: {str(e)}", self.model)
async def analyze_trading_dynamics(self, symbol: str, market: str, context_data: Dict[str, Any]) -> AIAnalysisResponse:
"""交易动态分析"""
prompt = self._build_trading_analysis_prompt(symbol, market, context_data)
try:
result = await self._retry_request(self._call_gemini_api, prompt)
return AIAnalysisResponse(
symbol=symbol,
market=market,
analysis_type="trading_analysis",
content=self._parse_trading_response(result),
model_used=self.model,
generated_at=datetime.now()
)
except Exception as e:
if isinstance(e, (AIAnalysisError, APIError)):
raise
raise AIAnalysisError(f"交易分析失败: {str(e)}", self.model)
async def analyze_insider_institutional(self, symbol: str, market: str, context_data: Dict[str, Any]) -> AIAnalysisResponse:
"""内部人与机构动向分析"""
prompt = self._build_insider_analysis_prompt(symbol, market, context_data)
try:
result = await self._retry_request(self._call_gemini_api, prompt)
return AIAnalysisResponse(
symbol=symbol,
market=market,
analysis_type="insider_analysis",
content=self._parse_insider_response(result),
model_used=self.model,
generated_at=datetime.now()
)
except Exception as e:
if isinstance(e, (AIAnalysisError, APIError)):
raise
raise AIAnalysisError(f"内部人分析失败: {str(e)}", self.model)
async def generate_final_conclusion(self, symbol: str, market: str, all_analyses: List[Dict[str, Any]]) -> AIAnalysisResponse:
"""生成最终结论"""
prompt = self._build_conclusion_prompt(symbol, market, all_analyses)
try:
result = await self._retry_request(self._call_gemini_api, prompt)
return AIAnalysisResponse(
symbol=symbol,
market=market,
analysis_type="final_conclusion",
content=self._parse_conclusion_response(result),
model_used=self.model,
generated_at=datetime.now()
)
except Exception as e:
if isinstance(e, (AIAnalysisError, APIError)):
raise
raise AIAnalysisError(f"最终结论生成失败: {str(e)}", self.model)
async def _call_gemini_api(self, prompt: str) -> str:
"""调用Gemini API"""
url = f"{self.base_url}/models/{self.model}:generateContent"
headers = {
"Content-Type": "application/json",
}
data = {
"contents": [
{
"parts": [
{
"text": prompt
}
]
}
],
"generationConfig": self.generation_config
}
params = {
"key": self.api_key
}
try:
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(url, json=data, headers=headers, params=params)
response.raise_for_status()
result = response.json()
if "error" in result:
error = result["error"]
error_code = error.get("code", 0)
error_message = error.get("message", "Unknown error")
if error_code == 401 or "API key" in error_message:
raise AuthenticationError("gemini", {"message": error_message})
elif error_code == 429 or "quota" in error_message.lower():
raise RateLimitError("gemini")
else:
raise APIError(f"Gemini API错误: {error_message}", error_code, "gemini")
candidates = result.get("candidates", [])
if not candidates:
raise AIAnalysisError("Gemini API返回空结果", self.model)
content = candidates[0].get("content", {})
parts = content.get("parts", [])
if not parts:
raise AIAnalysisError("Gemini API返回内容为空", self.model)
return parts[0].get("text", "")
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
raise AuthenticationError("gemini", {"status_code": e.response.status_code})
elif e.response.status_code == 429:
raise RateLimitError("gemini")
else:
raise APIError(f"HTTP错误: {e.response.status_code}", e.response.status_code, "gemini")
except httpx.RequestError as e:
raise AIAnalysisError(f"网络请求失败: {str(e)}", self.model)
async def _retry_request(self, func, *args, **kwargs):
"""重试机制"""
last_exception = None
for attempt in range(self.max_retries):
try:
return await func(*args, **kwargs)
except (httpx.TimeoutException, httpx.ConnectError, AIAnalysisError) as e:
last_exception = e
if attempt < self.max_retries - 1:
await asyncio.sleep(self.retry_delay * (2 ** attempt)) # 指数退避
continue
except (AuthenticationError, RateLimitError) as e:
# 认证错误和频率限制错误不重试
raise e
except Exception as e:
# 其他异常不重试
raise e
# 所有重试都失败了
raise AIAnalysisError(
f"Gemini API请求失败已重试 {self.max_retries} 次: {str(last_exception)}",
self.model
)
def _build_business_info_prompt(self, symbol: str, market: str, financial_data: Dict[str, Any]) -> str:
"""构建业务信息分析提示词"""
return f"""
请对股票代码 {symbol}{market}市场进行全面的业务信息分析
基于以下财务数据
{json.dumps(financial_data, ensure_ascii=False, indent=2)}
请提供以下内容的详细分析
1. 公司概览
- 公司基本信息和历史背景
- 主要业务领域和市场地位
2. 主营业务分析
- 核心产品和服务
- 业务模式和盈利模式
- 主要收入来源构成
3. 发展历程
- 重要发展里程碑
- 业务转型和扩张历史
4. 核心团队
- 管理层背景和经验
- 关键人员变动情况
5. 供应链分析
- 主要供应商和客户
- 供应链风险和优势
6. 销售模式
- 销售渠道和策略
- 市场覆盖和客户群体
7. 未来展望
- 发展战略和规划
- 市场机遇和挑战
请用中文回答内容要详实客观基于可获得的公开信息进行分析
"""
def _build_fundamental_analysis_prompt(self, symbol: str, market: str, financial_data: Dict[str, Any], business_info: Dict[str, Any]) -> str:
"""构建基本面分析提示词(景林模型)"""
return f"""
请使用景林投资的基本面分析框架对股票代码 {symbol}{market}市场进行深度分析
财务数据
{json.dumps(financial_data, ensure_ascii=False, indent=2)}
业务信息
{json.dumps(business_info, ensure_ascii=False, indent=2)}
请按照以下景林模型问题集进行分析
1. 商业模式分析
- 这是一门什么样的生意
- 商业模式的核心竞争力是什么
- 盈利模式是否可持续
2. 行业地位分析
- 公司在行业中的地位如何
- 市场份额和竞争优势
- 行业发展趋势对公司的影响
3. 财务质量分析
- 收入增长的质量如何
- 盈利能力和现金流状况
- 资产负债结构是否健康
4. 管理层评估
- 管理层的能力和诚信度
- 公司治理结构是否完善
- 股东利益是否得到保护
5. 估值分析
- 当前估值水平是否合理
- 与同行业公司比较如何
- 未来增长预期是否支撑估值
请用中文提供详细专业的分析每个方面都要有具体的数据支撑和逻辑推理
"""
def _build_bullish_analysis_prompt(self, symbol: str, market: str, context_data: Dict[str, Any]) -> str:
"""构建看涨分析提示词"""
return f"""
请对股票代码 {symbol}{market}市场进行看涨分析重点关注隐藏资产和护城河竞争优势
基础数据
{json.dumps(context_data, ensure_ascii=False, indent=2)}
请从以下角度进行看涨分析
1. 隐藏资产发现
- 账面价值被低估的资产
- 无形资产的真实价值
- 潜在的资产重估机会
2. 护城河分析
- 品牌价值和客户忠诚度
- 技术壁垒和专利保护
- 规模经济和网络效应
- 转换成本和客户粘性
3. 成长潜力
- 新业务和新市场机会
- 产品创新和技术升级
- 市场扩张的可能性
4. 催化剂识别
- 短期可能的积极因素
- 政策支持和行业利好
- 公司内部改革和优化
5. 最佳情况假设
- 如果一切顺利公司价值可能达到什么水平
- 关键假设和实现路径
请用中文提供乐观但理性的分析要有具体的逻辑支撑
"""
def _build_bearish_analysis_prompt(self, symbol: str, market: str, context_data: Dict[str, Any]) -> str:
"""构建看跌分析提示词"""
return f"""
请对股票代码 {symbol}{market}市场进行看跌分析重点关注价值底线和最坏情况
基础数据
{json.dumps(context_data, ensure_ascii=False, indent=2)}
请从以下角度进行看跌分析
1. 价值底线分析
- 清算价值估算
- 资产的最低合理价值
- 下行风险的底线在哪里
2. 主要风险因素
- 行业周期性风险
- 竞争加剧的威胁
- 技术替代的可能性
- 监管政策变化风险
3. 财务脆弱性
- 债务压力和流动性风险
- 现金流恶化的可能性
- 盈利能力下降的风险
4. 管理层风险
- 决策失误的历史
- 治理结构的缺陷
- 利益冲突的可能性
5. 最坏情况假设
- 如果一切都出错公司价值可能跌到什么水平
- 关键风险因素和触发条件
请用中文提供谨慎但客观的分析要有具体的风险量化
"""
def _build_market_analysis_prompt(self, symbol: str, market: str, context_data: Dict[str, Any]) -> str:
"""构建市场分析提示词"""
return f"""
请对股票代码 {symbol}{market}市场进行市场情绪分析重点关注分歧点与变化驱动
基础数据
{json.dumps(context_data, ensure_ascii=False, indent=2)}
请从以下角度进行市场分析
1. 市场情绪评估
- 当前市场对该股票的主流观点
- 机构投资者的持仓变化
- 散户投资者的情绪指标
2. 分歧点识别
- 市场存在哪些主要分歧
- 乐观派和悲观派的核心观点
- 分歧的根本原因是什么
3. 变化驱动因素
- 什么因素可能改变市场共识
- 关键数据点和时间节点
- 外部环境变化的影响
4. 资金流向分析
- 主力资金的进出情况
- 不同类型投资者的行为模式
- 流动性状况评估
5. 市场预期vs现实
- 市场预期是否过于乐观或悲观
- 预期差的投资机会在哪里
请用中文提供专业的市场分析要有数据支撑和逻辑推理
"""
def _build_news_analysis_prompt(self, symbol: str, market: str, context_data: Dict[str, Any]) -> str:
"""构建新闻分析提示词"""
return f"""
请对股票代码 {symbol}{market}市场进行新闻催化剂分析重点关注股价拐点预判
基础数据
{json.dumps(context_data, ensure_ascii=False, indent=2)}
请从以下角度进行新闻分析
1. 近期重要新闻梳理
- 公司公告和重大事件
- 行业政策和监管变化
- 宏观经济相关新闻
2. 催化剂识别
- 正面催化剂业绩超预期政策利好等
- 负面催化剂风险事件竞争加剧等
- 中性但重要的信息
3. 拐点预判
- 基本面拐点的可能时间
- 市场情绪拐点的信号
- 技术面拐点的确认
4. 新闻影响评估
- 短期影响vs长期影响
- 市场反应是否充分
- 后续发展的可能路径
5. 关注要点
- 未来需要重点关注的事件
- 可能的风险点和机会点
- 时间窗口和操作建议
请用中文提供前瞻性的分析要有时间维度和影响程度的判断
"""
def _build_trading_analysis_prompt(self, symbol: str, market: str, context_data: Dict[str, Any]) -> str:
"""构建交易分析提示词"""
return f"""
请对股票代码 {symbol}{market}市场进行交易分析重点关注市场体量与增长路径
基础数据
{json.dumps(context_data, ensure_ascii=False, indent=2)}
请从以下角度进行交易分析
1. 市场体量分析
- 总市值和流通市值
- 日均成交量和换手率
- 市场容量和流动性评估
2. 增长路径分析
- 历史增长轨迹和驱动因素
- 未来增长的可能路径
- 增长的可持续性评估
3. 交易特征分析
- 股价波动特征和规律
- 主要交易时段和模式
- 大宗交易和异常交易情况
4. 技术面分析
- 关键技术位和支撑阻力
- 趋势线和形态分析
- 技术指标的信号
5. 交易策略建议
- 适合的交易时机和方式
- 风险控制和仓位管理
- 进出场点位的选择
请用中文提供实用的交易分析要有具体的数据和操作建议
"""
def _build_insider_analysis_prompt(self, symbol: str, market: str, context_data: Dict[str, Any]) -> str:
"""构建内部人分析提示词"""
return f"""
请对股票代码 {symbol}{market}市场进行内部人与机构动向分析
基础数据
{json.dumps(context_data, ensure_ascii=False, indent=2)}
请从以下角度进行分析
1. 内部人交易分析
- 高管和大股东的买卖行为
- 内部人交易的时机和规模
- 内部人交易的信号意义
2. 机构持仓分析
- 主要机构投资者的持仓变化
- 新进和退出的机构情况
- 机构持仓集中度分析
3. 股东结构变化
- 股权结构的演变趋势
- 重要股东的进出情况
- 股权激励和员工持股情况
4. 资金流向追踪
- 大资金的进出时机
- 不同类型资金的偏好
- 资金成本和收益预期
5. 动向信号解读
- 内部人和机构行为的一致性
- 与股价走势的相关性
- 对未来走势的指示意义
请用中文提供专业的分析要有数据支撑和逻辑推理
"""
def _build_conclusion_prompt(self, symbol: str, market: str, all_analyses: List[Dict[str, Any]]) -> str:
"""构建最终结论提示词"""
analyses_text = "\n\n".join([
f"{analysis.get('analysis_type', '未知分析')}:\n{json.dumps(analysis.get('content', {}), ensure_ascii=False, indent=2)}"
for analysis in all_analyses
])
return f"""
基于以下所有分析结果请对股票代码 {symbol}{market}市场给出最终投资结论
所有分析结果
{analyses_text}
请从以下角度进行综合分析
1. 关键矛盾识别
- 当前最核心的投资矛盾是什么
- 不同分析维度的结论是否一致
- 主要的不确定性因素有哪些
2. 预期差分析
- 市场预期与实际情况的差异
- 可能被忽视或误解的关键信息
- 预期差带来的投资机会
3. 拐点临近性判断
- 基本面拐点的时间和概率
- 市场情绪拐点的信号
- 催化剂的时效性分析
4. 风险收益评估
- 上行空间和下行风险的量化
- 风险调整后的收益预期
- 投资的风险收益比
5. 最终投资建议
- 明确的投资观点看多/看空/中性
- 建议的投资时间框架
- 关键的跟踪指标和退出条件
请用中文提供清晰明确的投资结论要有逻辑性和可操作性
"""
def _parse_business_info_response(self, response: str) -> Dict[str, Any]:
"""解析业务信息分析响应"""
return {
"company_overview": self._extract_section(response, "公司概览"),
"main_business": self._extract_section(response, "主营业务"),
"development_history": self._extract_section(response, "发展历程"),
"core_team": self._extract_section(response, "核心团队"),
"supply_chain": self._extract_section(response, "供应链"),
"sales_model": self._extract_section(response, "销售模式"),
"future_outlook": self._extract_section(response, "未来展望"),
"full_analysis": response
}
def _parse_fundamental_response(self, response: str) -> Dict[str, Any]:
"""解析基本面分析响应"""
return {
"business_model": self._extract_section(response, "商业模式"),
"industry_position": self._extract_section(response, "行业地位"),
"financial_quality": self._extract_section(response, "财务质量"),
"management_assessment": self._extract_section(response, "管理层"),
"valuation_analysis": self._extract_section(response, "估值分析"),
"full_analysis": response
}
def _parse_bullish_response(self, response: str) -> Dict[str, Any]:
"""解析看涨分析响应"""
return {
"hidden_assets": self._extract_section(response, "隐藏资产"),
"moat_analysis": self._extract_section(response, "护城河"),
"growth_potential": self._extract_section(response, "成长潜力"),
"catalysts": self._extract_section(response, "催化剂"),
"best_case": self._extract_section(response, "最佳情况"),
"full_analysis": response
}
def _parse_bearish_response(self, response: str) -> Dict[str, Any]:
"""解析看跌分析响应"""
return {
"value_floor": self._extract_section(response, "价值底线"),
"risk_factors": self._extract_section(response, "风险因素"),
"financial_vulnerability": self._extract_section(response, "财务脆弱性"),
"management_risks": self._extract_section(response, "管理层风险"),
"worst_case": self._extract_section(response, "最坏情况"),
"full_analysis": response
}
def _parse_market_response(self, response: str) -> Dict[str, Any]:
"""解析市场分析响应"""
return {
"market_sentiment": self._extract_section(response, "市场情绪"),
"disagreement_points": self._extract_section(response, "分歧点"),
"change_drivers": self._extract_section(response, "变化驱动"),
"capital_flow": self._extract_section(response, "资金流向"),
"expectation_vs_reality": self._extract_section(response, "预期vs现实"),
"full_analysis": response
}
def _parse_news_response(self, response: str) -> Dict[str, Any]:
"""解析新闻分析响应"""
return {
"recent_news": self._extract_section(response, "重要新闻"),
"catalysts": self._extract_section(response, "催化剂"),
"inflection_points": self._extract_section(response, "拐点预判"),
"news_impact": self._extract_section(response, "影响评估"),
"focus_points": self._extract_section(response, "关注要点"),
"full_analysis": response
}
def _parse_trading_response(self, response: str) -> Dict[str, Any]:
"""解析交易分析响应"""
return {
"market_size": self._extract_section(response, "市场体量"),
"growth_path": self._extract_section(response, "增长路径"),
"trading_characteristics": self._extract_section(response, "交易特征"),
"technical_analysis": self._extract_section(response, "技术面"),
"trading_strategy": self._extract_section(response, "交易策略"),
"full_analysis": response
}
def _parse_insider_response(self, response: str) -> Dict[str, Any]:
"""解析内部人分析响应"""
return {
"insider_trading": self._extract_section(response, "内部人交易"),
"institutional_holdings": self._extract_section(response, "机构持仓"),
"ownership_changes": self._extract_section(response, "股东结构"),
"capital_flow": self._extract_section(response, "资金流向"),
"signal_interpretation": self._extract_section(response, "信号解读"),
"full_analysis": response
}
def _parse_conclusion_response(self, response: str) -> Dict[str, Any]:
"""解析最终结论响应"""
return {
"key_contradictions": self._extract_section(response, "关键矛盾"),
"expectation_gap": self._extract_section(response, "预期差"),
"inflection_timing": self._extract_section(response, "拐点临近性"),
"risk_return": self._extract_section(response, "风险收益"),
"investment_recommendation": self._extract_section(response, "投资建议"),
"full_analysis": response
}
def _extract_section(self, text: str, section_name: str) -> str:
"""从文本中提取特定章节内容"""
lines = text.split('\n')
section_content = []
in_section = False
for line in lines:
if section_name in line and ('.' in line or '' in line or ':' in line):
in_section = True
section_content.append(line)
continue
if in_section:
if line.strip() and any(keyword in line for keyword in ['1.', '2.', '3.', '4.', '5.']) and section_name not in line:
# 遇到下一个主要章节,停止
break
section_content.append(line)
return '\n'.join(section_content).strip()
class AIAnalyzerFactory:
"""AI分析器工厂"""
@classmethod
def create_gemini_analyzer(cls, api_key: str, config: Optional[Dict[str, Any]] = None) -> GeminiAnalyzer:
"""创建Gemini分析器"""
return GeminiAnalyzer(api_key, config)
@classmethod
def create_analyzer(cls, analyzer_type: str, **kwargs) -> GeminiAnalyzer:
"""创建分析器可扩展支持其他AI服务"""
if analyzer_type.lower() == "gemini":
return cls.create_gemini_analyzer(kwargs.get("api_key"), kwargs.get("config"))
else:
raise AIAnalysisError(f"不支持的AI分析器类型: {analyzer_type}")

View File

@ -0,0 +1,260 @@
"""
配置管理服务
处理系统配置的读取更新和验证
"""
from typing import Dict, Any, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
import httpx
import asyncio
from ..models.system_config import SystemConfig
from ..schemas.config import ConfigResponse, ConfigUpdateRequest, ConfigTestResponse
from ..core.exceptions import ConfigurationError, DatabaseError, APIError
class ConfigManager:
"""配置管理器"""
def __init__(self, db_session: AsyncSession):
self.db = db_session
async def get_config(self) -> ConfigResponse:
"""获取系统配置"""
try:
# 查询所有配置
result = await self.db.execute(select(SystemConfig))
configs = result.scalars().all()
# 组织配置数据
config_dict = {config.config_key: config.config_value for config in configs}
return ConfigResponse(
database=config_dict.get("database"),
gemini_api=config_dict.get("gemini_api"),
data_sources=config_dict.get("data_sources", {})
)
except Exception as e:
raise DatabaseError(f"获取配置失败: {str(e)}", "get_config")
async def update_config(self, config_update: ConfigUpdateRequest) -> ConfigResponse:
"""更新系统配置"""
try:
# 更新数据库配置
if config_update.database:
await self._update_config_item("database", config_update.database.dict())
# 更新Gemini API配置
if config_update.gemini_api:
await self._update_config_item("gemini_api", config_update.gemini_api.dict())
# 更新数据源配置
if config_update.data_sources:
data_sources_dict = {k: v.dict() for k, v in config_update.data_sources.items()}
await self._update_config_item("data_sources", data_sources_dict)
await self.db.commit()
# 返回更新后的配置
return await self.get_config()
except Exception as e:
await self.db.rollback()
raise DatabaseError(f"更新配置失败: {str(e)}", "update_config")
async def test_config(self, config_type: str, config_data: Dict[str, Any]) -> ConfigTestResponse:
"""测试配置连接"""
try:
if config_type == "database":
return await self._test_database_config(config_data)
elif config_type == "gemini":
return await self._test_gemini_config(config_data)
elif config_type == "data_source":
return await self._test_data_source_config(config_data)
else:
return ConfigTestResponse(
success=False,
message=f"不支持的配置类型: {config_type}"
)
except Exception as e:
return ConfigTestResponse(
success=False,
message=f"配置测试失败: {str(e)}"
)
async def _update_config_item(self, key: str, value: Dict[str, Any]):
"""更新单个配置项"""
# 查询现有配置
result = await self.db.execute(
select(SystemConfig).where(SystemConfig.config_key == key)
)
config = result.scalar_one_or_none()
if config:
# 更新现有配置
config.config_value = value
else:
# 创建新配置
config = SystemConfig(config_key=key, config_value=value)
self.db.add(config)
async def _test_database_config(self, config_data: Dict[str, Any]) -> ConfigTestResponse:
"""测试数据库配置"""
try:
# 尝试创建数据库连接
from sqlalchemy.ext.asyncio import create_async_engine
db_url = config_data.get("url")
if not db_url:
return ConfigTestResponse(
success=False,
message="数据库URL未配置"
)
# 创建临时引擎测试连接
test_engine = create_async_engine(db_url, echo=False)
# 测试连接
async with test_engine.begin() as conn:
await conn.execute("SELECT 1")
await test_engine.dispose()
return ConfigTestResponse(
success=True,
message="数据库连接测试成功"
)
except Exception as e:
return ConfigTestResponse(
success=False,
message=f"数据库连接测试失败: {str(e)}"
)
async def _test_gemini_config(self, config_data: Dict[str, Any]) -> ConfigTestResponse:
"""测试Gemini API配置"""
try:
api_key = config_data.get("api_key")
if not api_key:
return ConfigTestResponse(
success=False,
message="Gemini API密钥未配置"
)
# 测试API调用
async with httpx.AsyncClient(timeout=10.0) as client:
headers = {"Authorization": f"Bearer {api_key}"}
# 这里应该调用实际的Gemini API端点进行测试
# 暂时模拟成功
await asyncio.sleep(0.1) # 模拟网络延迟
return ConfigTestResponse(
success=True,
message="Gemini API连接测试成功"
)
except Exception as e:
return ConfigTestResponse(
success=False,
message=f"Gemini API连接测试失败: {str(e)}"
)
async def get_data_source_config(self, market: str) -> Dict[str, Any]:
"""获取指定市场的数据源配置"""
try:
result = await self.db.execute(
select(SystemConfig).where(SystemConfig.config_key == "data_sources")
)
config = result.scalar_one_or_none()
if not config:
raise ConfigurationError("数据源配置未找到", "data_sources")
data_sources = config.config_value
# 根据市场选择数据源
market_lower = market.lower()
if market_lower == "china":
if "tushare" in data_sources:
return data_sources["tushare"]
else:
raise ConfigurationError("中国市场数据源(Tushare)未配置", "tushare")
else:
# 其他市场使用Yahoo Finance
if "yahoo" in data_sources:
return data_sources["yahoo"]
else:
raise ConfigurationError("国际市场数据源(Yahoo)未配置", "yahoo")
except Exception as e:
if isinstance(e, ConfigurationError):
raise
raise ConfigurationError(f"获取数据源配置失败: {str(e)}", "data_sources")
async def get_gemini_config(self) -> Dict[str, Any]:
"""获取Gemini API配置"""
try:
result = await self.db.execute(
select(SystemConfig).where(SystemConfig.config_key == "gemini_api")
)
config = result.scalar_one_or_none()
if not config:
raise ConfigurationError("Gemini API配置未找到", "gemini_api")
gemini_config = config.config_value
if not gemini_config.get("api_key"):
raise ConfigurationError("Gemini API密钥未配置", "gemini_api")
return gemini_config
except Exception as e:
if isinstance(e, ConfigurationError):
raise
raise ConfigurationError(f"获取Gemini配置失败: {str(e)}", "gemini_api")
async def _test_data_source_config(self, config_data: Dict[str, Any]) -> ConfigTestResponse:
"""测试数据源配置"""
try:
name = config_data.get("name")
api_key = config_data.get("api_key")
base_url = config_data.get("base_url")
timeout = config_data.get("timeout", 30)
if not name:
return ConfigTestResponse(
success=False,
message="数据源名称未配置"
)
# 根据数据源类型进行不同的测试
if name.lower() == "tushare":
if not api_key:
return ConfigTestResponse(
success=False,
message="Tushare API密钥未配置"
)
# 测试Tushare API
# 暂时模拟成功
await asyncio.sleep(0.1)
elif name.lower() == "yahoo":
# 测试Yahoo Finance API
if base_url:
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.get(f"{base_url}/health", timeout=timeout)
if response.status_code != 200:
return ConfigTestResponse(
success=False,
message=f"数据源API返回错误状态码: {response.status_code}"
)
return ConfigTestResponse(
success=True,
message=f"数据源 {name} 连接测试成功"
)
except Exception as e:
return ConfigTestResponse(
success=False,
message=f"数据源连接测试失败: {str(e)}"
)

View File

@ -0,0 +1,673 @@
"""
数据获取服务基础架构
处理外部数据源的数据获取
"""
from typing import Dict, Any, Optional
from abc import ABC, abstractmethod
import httpx
import asyncio
from datetime import datetime
from ..schemas.data import (
FinancialDataResponse,
MarketDataResponse,
SymbolValidationResponse,
DataSourceStatus
)
from ..core.exceptions import (
DataSourceError,
APIError,
SymbolNotFoundError,
RateLimitError,
AuthenticationError
)
class DataFetcher(ABC):
"""数据获取服务基类"""
def __init__(self, config: Dict[str, Any]):
self.config = config
self.name = config.get("name", "unknown")
self.timeout = config.get("timeout", 30)
self.max_retries = config.get("max_retries", 3)
self.retry_delay = config.get("retry_delay", 1)
@abstractmethod
async def fetch_financial_data(self, symbol: str, market: str) -> FinancialDataResponse:
"""获取财务数据"""
pass
@abstractmethod
async def fetch_market_data(self, symbol: str, market: str) -> MarketDataResponse:
"""获取市场数据"""
pass
@abstractmethod
async def validate_symbol(self, symbol: str, market: str) -> SymbolValidationResponse:
"""验证证券代码"""
pass
async def check_status(self) -> DataSourceStatus:
"""检查数据源状态"""
start_time = datetime.now()
try:
# 尝试进行简单的健康检查
await self._health_check()
end_time = datetime.now()
response_time = int((end_time - start_time).total_seconds() * 1000)
return DataSourceStatus(
name=self.name,
is_available=True,
last_check=end_time,
response_time_ms=response_time
)
except Exception as e:
end_time = datetime.now()
return DataSourceStatus(
name=self.name,
is_available=False,
last_check=end_time,
error_message=str(e)
)
@abstractmethod
async def _health_check(self):
"""健康检查实现"""
pass
async def _retry_request(self, func, *args, **kwargs):
"""重试机制"""
last_exception = None
for attempt in range(self.max_retries):
try:
return await func(*args, **kwargs)
except (httpx.TimeoutException, httpx.ConnectError) as e:
last_exception = e
if attempt < self.max_retries - 1:
await asyncio.sleep(self.retry_delay * (2 ** attempt)) # 指数退避
continue
except Exception as e:
# 对于其他类型的异常,不重试
raise e
# 所有重试都失败了
raise DataSourceError(
f"数据源 {self.name} 请求失败,已重试 {self.max_retries}",
self.name,
{"last_error": str(last_exception)}
)
class TushareDataFetcher(DataFetcher):
"""Tushare数据获取器"""
def __init__(self, config: Dict[str, Any]):
super().__init__(config)
self.token = config.get("api_key") or config.get("token")
self.base_url = config.get("base_url", "http://api.tushare.pro")
if not self.token:
raise AuthenticationError("tushare", {"message": "Tushare API token未配置"})
async def fetch_financial_data(self, symbol: str, market: str) -> FinancialDataResponse:
"""获取财务数据"""
try:
# 转换证券代码格式
ts_code = self._convert_symbol_format(symbol, market)
# TODO: 实现实际的Tushare API调用
# 这里暂时返回模拟数据
financial_data = await self._retry_request(self._fetch_tushare_financial, ts_code)
return FinancialDataResponse(
symbol=symbol,
market=market,
data_source="tushare",
last_updated=datetime.now(),
balance_sheet=financial_data.get("balance_sheet"),
income_statement=financial_data.get("income_statement"),
cash_flow=financial_data.get("cash_flow"),
key_metrics=financial_data.get("key_metrics")
)
except Exception as e:
if isinstance(e, (DataSourceError, APIError)):
raise
raise DataSourceError(f"获取财务数据失败: {str(e)}", "tushare")
async def fetch_market_data(self, symbol: str, market: str) -> MarketDataResponse:
"""获取市场数据"""
try:
ts_code = self._convert_symbol_format(symbol, market)
# TODO: 实现实际的Tushare API调用
market_data = await self._retry_request(self._fetch_tushare_market, ts_code)
return MarketDataResponse(
symbol=symbol,
market=market,
data_source="tushare",
last_updated=datetime.now(),
price_data=market_data.get("price_data"),
volume_data=market_data.get("volume_data"),
technical_indicators=market_data.get("technical_indicators")
)
except Exception as e:
if isinstance(e, (DataSourceError, APIError)):
raise
raise DataSourceError(f"获取市场数据失败: {str(e)}", "tushare")
async def validate_symbol(self, symbol: str, market: str) -> SymbolValidationResponse:
"""验证证券代码"""
try:
ts_code = self._convert_symbol_format(symbol, market)
# TODO: 实现实际的证券代码验证
# 暂时模拟验证逻辑
is_valid = await self._retry_request(self._validate_tushare_symbol, ts_code)
return SymbolValidationResponse(
symbol=symbol,
market=market,
is_valid=is_valid,
company_name="示例公司" if is_valid else None,
message="证券代码有效" if is_valid else "证券代码无效"
)
except Exception as e:
return SymbolValidationResponse(
symbol=symbol,
market=market,
is_valid=False,
message=f"验证失败: {str(e)}"
)
async def _health_check(self):
"""健康检查"""
try:
async with httpx.AsyncClient(timeout=5) as client:
# 尝试调用一个简单的API来测试连通性
await self._call_tushare_api(client, "stock_basic", {"limit": 1})
except Exception as e:
raise DataSourceError(f"Tushare健康检查失败: {str(e)}", "tushare")
def _convert_symbol_format(self, symbol: str, market: str) -> str:
"""转换证券代码格式为Tushare格式"""
if market.lower() == "china":
# 中国股票代码格式转换
if symbol.startswith("6"):
return f"{symbol}.SH" # 上海证券交易所
elif symbol.startswith(("0", "3")):
return f"{symbol}.SZ" # 深圳证券交易所
return symbol
async def _fetch_tushare_financial(self, ts_code: str) -> Dict[str, Any]:
"""获取Tushare财务数据"""
async with httpx.AsyncClient(timeout=self.timeout) as client:
# 获取资产负债表
balance_sheet_data = await self._call_tushare_api(
client, "balancesheet", {"ts_code": ts_code, "period": "20231231"}
)
# 获取利润表
income_data = await self._call_tushare_api(
client, "income", {"ts_code": ts_code, "period": "20231231"}
)
# 获取现金流量表
cashflow_data = await self._call_tushare_api(
client, "cashflow", {"ts_code": ts_code, "period": "20231231"}
)
# 获取基本财务指标
fina_indicator_data = await self._call_tushare_api(
client, "fina_indicator", {"ts_code": ts_code, "period": "20231231"}
)
return {
"balance_sheet": self._process_balance_sheet(balance_sheet_data),
"income_statement": self._process_income_statement(income_data),
"cash_flow": self._process_cash_flow(cashflow_data),
"key_metrics": self._process_key_metrics(fina_indicator_data)
}
async def _fetch_tushare_market(self, ts_code: str) -> Dict[str, Any]:
"""获取Tushare市场数据"""
async with httpx.AsyncClient(timeout=self.timeout) as client:
# 获取日线数据
daily_data = await self._call_tushare_api(
client, "daily", {"ts_code": ts_code, "start_date": "20240101", "end_date": "20241231"}
)
# 获取基本信息
stock_basic_data = await self._call_tushare_api(
client, "stock_basic", {"ts_code": ts_code}
)
return {
"price_data": self._process_price_data(daily_data),
"volume_data": self._process_volume_data(daily_data),
"technical_indicators": self._calculate_technical_indicators(daily_data),
"stock_info": self._process_stock_basic(stock_basic_data)
}
async def _validate_tushare_symbol(self, ts_code: str) -> bool:
"""验证Tushare证券代码"""
try:
async with httpx.AsyncClient(timeout=self.timeout) as client:
result = await self._call_tushare_api(
client, "stock_basic", {"ts_code": ts_code}
)
return bool(result and len(result.get("items", [])) > 0)
except Exception:
return False
async def _call_tushare_api(self, client: httpx.AsyncClient, api_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
"""调用Tushare API"""
request_data = {
"api_name": api_name,
"token": self.token,
"params": params,
"fields": ""
}
try:
response = await client.post(self.base_url, json=request_data)
response.raise_for_status()
result = response.json()
if result.get("code") != 0:
error_msg = result.get("msg", "Unknown error")
if "权限" in error_msg or "token" in error_msg.lower():
raise AuthenticationError("tushare", {"message": error_msg})
elif "频率" in error_msg or "limit" in error_msg.lower():
raise RateLimitError("tushare")
else:
raise APIError(f"Tushare API错误: {error_msg}", result.get("code"), "tushare")
return result.get("data", {})
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
raise AuthenticationError("tushare", {"status_code": e.response.status_code})
elif e.response.status_code == 429:
raise RateLimitError("tushare")
else:
raise APIError(f"HTTP错误: {e.response.status_code}", e.response.status_code, "tushare")
except httpx.RequestError as e:
raise DataSourceError(f"网络请求失败: {str(e)}", "tushare")
def _process_balance_sheet(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""处理资产负债表数据"""
if not data or not data.get("items"):
return {}
items = data["items"]
if not items:
return {}
# 取最新一期数据
latest = items[0] if isinstance(items[0], list) else items
fields = data.get("fields", [])
if isinstance(latest, list) and fields:
# 将列表数据转换为字典
balance_data = dict(zip(fields, latest))
else:
balance_data = latest
return {
"total_assets": balance_data.get("total_assets", 0),
"total_liab": balance_data.get("total_liab", 0),
"total_hldr_eqy_exc_min_int": balance_data.get("total_hldr_eqy_exc_min_int", 0),
"monetary_cap": balance_data.get("monetary_cap", 0),
"accounts_receiv": balance_data.get("accounts_receiv", 0),
"inventories": balance_data.get("inventories", 0)
}
def _process_income_statement(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""处理利润表数据"""
if not data or not data.get("items"):
return {}
items = data["items"]
if not items:
return {}
latest = items[0] if isinstance(items[0], list) else items
fields = data.get("fields", [])
if isinstance(latest, list) and fields:
income_data = dict(zip(fields, latest))
else:
income_data = latest
return {
"revenue": income_data.get("revenue", 0),
"operate_profit": income_data.get("operate_profit", 0),
"total_profit": income_data.get("total_profit", 0),
"n_income": income_data.get("n_income", 0),
"n_income_attr_p": income_data.get("n_income_attr_p", 0),
"basic_eps": income_data.get("basic_eps", 0)
}
def _process_cash_flow(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""处理现金流量表数据"""
if not data or not data.get("items"):
return {}
items = data["items"]
if not items:
return {}
latest = items[0] if isinstance(items[0], list) else items
fields = data.get("fields", [])
if isinstance(latest, list) and fields:
cashflow_data = dict(zip(fields, latest))
else:
cashflow_data = latest
return {
"n_cashflow_act": cashflow_data.get("n_cashflow_act", 0),
"n_cashflow_inv_act": cashflow_data.get("n_cashflow_inv_act", 0),
"n_cashflow_fin_act": cashflow_data.get("n_cashflow_fin_act", 0),
"c_cash_equ_end_period": cashflow_data.get("c_cash_equ_end_period", 0)
}
def _process_key_metrics(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""处理关键财务指标数据"""
if not data or not data.get("items"):
return {}
items = data["items"]
if not items:
return {}
latest = items[0] if isinstance(items[0], list) else items
fields = data.get("fields", [])
if isinstance(latest, list) and fields:
metrics_data = dict(zip(fields, latest))
else:
metrics_data = latest
return {
"pe": metrics_data.get("pe", 0),
"pb": metrics_data.get("pb", 0),
"ps": metrics_data.get("ps", 0),
"roe": metrics_data.get("roe", 0),
"roa": metrics_data.get("roa", 0),
"gross_margin": metrics_data.get("gross_margin", 0),
"debt_to_assets": metrics_data.get("debt_to_assets", 0)
}
def _process_price_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""处理价格数据"""
if not data or not data.get("items"):
return {}
items = data["items"]
if not items:
return {}
# 取最新一天的数据
latest = items[0] if isinstance(items[0], list) else items
fields = data.get("fields", [])
if isinstance(latest, list) and fields:
price_data = dict(zip(fields, latest))
else:
price_data = latest
return {
"close": price_data.get("close", 0),
"open": price_data.get("open", 0),
"high": price_data.get("high", 0),
"low": price_data.get("low", 0),
"pre_close": price_data.get("pre_close", 0),
"change": price_data.get("change", 0),
"pct_chg": price_data.get("pct_chg", 0)
}
def _process_volume_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""处理成交量数据"""
if not data or not data.get("items"):
return {}
items = data["items"]
if not items:
return {}
latest = items[0] if isinstance(items[0], list) else items
fields = data.get("fields", [])
if isinstance(latest, list) and fields:
volume_data = dict(zip(fields, latest))
else:
volume_data = latest
return {
"vol": volume_data.get("vol", 0),
"amount": volume_data.get("amount", 0)
}
def _calculate_technical_indicators(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""计算技术指标"""
if not data or not data.get("items"):
return {}
items = data["items"]
if not items or len(items) < 20:
return {}
# 简单的移动平均计算
closes = []
for item in items[:20]: # 取最近20天
if isinstance(item, list):
fields = data.get("fields", [])
close_idx = fields.index("close") if "close" in fields else -1
if close_idx >= 0:
closes.append(item[close_idx])
else:
closes.append(item.get("close", 0))
if len(closes) >= 5:
ma_5 = sum(closes[:5]) / 5
else:
ma_5 = 0
if len(closes) >= 20:
ma_20 = sum(closes) / 20
else:
ma_20 = 0
return {
"ma_5": ma_5,
"ma_20": ma_20,
"ma_60": 0 # 需要更多数据计算
}
def _process_stock_basic(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""处理股票基本信息"""
if not data or not data.get("items"):
return {}
items = data["items"]
if not items:
return {}
latest = items[0] if isinstance(items[0], list) else items
fields = data.get("fields", [])
if isinstance(latest, list) and fields:
basic_data = dict(zip(fields, latest))
else:
basic_data = latest
return {
"name": basic_data.get("name", ""),
"industry": basic_data.get("industry", ""),
"market": basic_data.get("market", ""),
"list_date": basic_data.get("list_date", "")
}
class YahooDataFetcher(DataFetcher):
"""Yahoo Finance数据获取器"""
def __init__(self, config: Dict[str, Any]):
super().__init__(config)
self.base_url = config.get("base_url", "https://query1.finance.yahoo.com")
async def fetch_financial_data(self, symbol: str, market: str) -> FinancialDataResponse:
"""获取财务数据"""
try:
yahoo_symbol = self._convert_symbol_format(symbol, market)
# TODO: 实现实际的Yahoo Finance API调用
financial_data = await self._retry_request(self._fetch_yahoo_financial, yahoo_symbol)
return FinancialDataResponse(
symbol=symbol,
market=market,
data_source="yahoo",
last_updated=datetime.now(),
balance_sheet=financial_data.get("balance_sheet"),
income_statement=financial_data.get("income_statement"),
cash_flow=financial_data.get("cash_flow"),
key_metrics=financial_data.get("key_metrics")
)
except Exception as e:
if isinstance(e, (DataSourceError, APIError)):
raise
raise DataSourceError(f"获取财务数据失败: {str(e)}", "yahoo")
async def fetch_market_data(self, symbol: str, market: str) -> MarketDataResponse:
"""获取市场数据"""
try:
yahoo_symbol = self._convert_symbol_format(symbol, market)
# TODO: 实现实际的Yahoo Finance API调用
market_data = await self._retry_request(self._fetch_yahoo_market, yahoo_symbol)
return MarketDataResponse(
symbol=symbol,
market=market,
data_source="yahoo",
last_updated=datetime.now(),
price_data=market_data.get("price_data"),
volume_data=market_data.get("volume_data"),
technical_indicators=market_data.get("technical_indicators")
)
except Exception as e:
if isinstance(e, (DataSourceError, APIError)):
raise
raise DataSourceError(f"获取市场数据失败: {str(e)}", "yahoo")
async def validate_symbol(self, symbol: str, market: str) -> SymbolValidationResponse:
"""验证证券代码"""
try:
yahoo_symbol = self._convert_symbol_format(symbol, market)
# TODO: 实现实际的证券代码验证
is_valid = await self._retry_request(self._validate_yahoo_symbol, yahoo_symbol)
return SymbolValidationResponse(
symbol=symbol,
market=market,
is_valid=is_valid,
company_name="Example Company" if is_valid else None,
message="Symbol is valid" if is_valid else "Symbol not found"
)
except Exception as e:
return SymbolValidationResponse(
symbol=symbol,
market=market,
is_valid=False,
message=f"Validation failed: {str(e)}"
)
async def _health_check(self):
"""健康检查"""
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.get(f"{self.base_url}/v1/finance/search?q=AAPL")
if response.status_code != 200:
raise APIError(f"Yahoo Finance API返回状态码: {response.status_code}", response.status_code, "yahoo")
def _convert_symbol_format(self, symbol: str, market: str) -> str:
"""转换证券代码格式为Yahoo Finance格式"""
if market.lower() == "hongkong":
return f"{symbol}.HK"
elif market.lower() == "japan":
return f"{symbol}.T"
elif market.lower() == "china":
# 中国股票在Yahoo Finance中的格式
if symbol.startswith("6"):
return f"{symbol}.SS" # 上海
elif symbol.startswith(("0", "3")):
return f"{symbol}.SZ" # 深圳
return symbol
async def _fetch_yahoo_financial(self, yahoo_symbol: str) -> Dict[str, Any]:
"""获取Yahoo Finance财务数据"""
# TODO: 实现实际的API调用
return {
"balance_sheet": {"totalAssets": 2000000},
"income_statement": {"totalRevenue": 800000},
"cash_flow": {"operatingCashflow": 300000},
"key_metrics": {"trailingPE": 18.5}
}
async def _fetch_yahoo_market(self, yahoo_symbol: str) -> Dict[str, Any]:
"""获取Yahoo Finance市场数据"""
# TODO: 实现实际的API调用
return {
"price_data": {"regularMarketPrice": 150.0, "dayHigh": 155.0, "dayLow": 145.0},
"volume_data": {"regularMarketVolume": 2000000},
"technical_indicators": {"fiftyDayAverage": 148.0, "twoHundredDayAverage": 152.0}
}
async def _validate_yahoo_symbol(self, yahoo_symbol: str) -> bool:
"""验证Yahoo Finance证券代码"""
# TODO: 实现实际的验证逻辑
return True
class DataFetcherFactory:
"""数据获取器工厂"""
_fetchers = {
"tushare": TushareDataFetcher,
"yahoo": YahooDataFetcher,
}
@classmethod
def create_fetcher(cls, data_source: str, config: Dict[str, Any]) -> DataFetcher:
"""创建数据获取器"""
data_source_lower = data_source.lower()
if data_source_lower not in cls._fetchers:
raise DataSourceError(
f"不支持的数据源: {data_source}",
data_source,
{"supported_sources": list(cls._fetchers.keys())}
)
fetcher_class = cls._fetchers[data_source_lower]
return fetcher_class(config)
@classmethod
def get_supported_sources(cls) -> list:
"""获取支持的数据源列表"""
return list(cls._fetchers.keys())
@classmethod
def register_fetcher(cls, name: str, fetcher_class: type):
"""注册新的数据获取器"""
if not issubclass(fetcher_class, DataFetcher):
raise ValueError("数据获取器必须继承自DataFetcher基类")
cls._fetchers[name.lower()] = fetcher_class

View File

@ -0,0 +1,357 @@
"""
数据源管理服务
处理数据源配置和切换逻辑
"""
from typing import Dict, Any, Optional, List
import asyncio
from datetime import datetime
from .data_fetcher import DataFetcher, DataFetcherFactory
from .ai_analyzer import GeminiAnalyzer, AIAnalyzerFactory
from ..core.exceptions import (
DataSourceError,
ConfigurationError,
AIAnalysisError
)
from ..schemas.data import (
FinancialDataResponse,
MarketDataResponse,
SymbolValidationResponse,
DataSourceStatus,
DataSourcesStatusResponse
)
from ..schemas.report import AIAnalysisResponse
class DataSourceManager:
"""数据源管理器"""
def __init__(self, config: Dict[str, Any]):
self.config = config
self._data_fetchers: Dict[str, DataFetcher] = {}
self._ai_analyzer: Optional[GeminiAnalyzer] = None
self._market_source_mapping = {
"china": "tushare",
"中国": "tushare",
"hongkong": "yahoo",
"香港": "yahoo",
"usa": "yahoo",
"美国": "yahoo",
"japan": "yahoo",
"日本": "yahoo"
}
# 初始化数据获取器
self._initialize_data_fetchers()
# 初始化AI分析器
self._initialize_ai_analyzer()
def _initialize_data_fetchers(self):
"""初始化数据获取器"""
data_sources_config = self.config.get("data_sources", {})
for source_name, source_config in data_sources_config.items():
try:
if source_config.get("enabled", True):
fetcher = DataFetcherFactory.create_fetcher(source_name, source_config)
self._data_fetchers[source_name] = fetcher
except Exception as e:
print(f"警告: 初始化数据源 {source_name} 失败: {str(e)}")
def _initialize_ai_analyzer(self):
"""初始化AI分析器"""
ai_config = self.config.get("ai_services", {})
gemini_config = ai_config.get("gemini", {})
if gemini_config.get("enabled", True) and gemini_config.get("api_key"):
try:
self._ai_analyzer = AIAnalyzerFactory.create_gemini_analyzer(
gemini_config["api_key"],
gemini_config
)
except Exception as e:
print(f"警告: 初始化Gemini分析器失败: {str(e)}")
def get_data_source_for_market(self, market: str) -> str:
"""根据市场获取数据源"""
market_lower = market.lower()
# 首先检查配置中的映射
market_mapping = self.config.get("market_mapping", {})
if market_lower in market_mapping:
return market_mapping[market_lower]
# 使用默认映射
return self._market_source_mapping.get(market_lower, "tushare")
def get_data_fetcher(self, data_source: str) -> DataFetcher:
"""获取数据获取器"""
if data_source not in self._data_fetchers:
raise DataSourceError(f"数据源 {data_source} 未配置或不可用", data_source)
return self._data_fetchers[data_source]
def get_ai_analyzer(self) -> GeminiAnalyzer:
"""获取AI分析器"""
if not self._ai_analyzer:
raise AIAnalysisError("AI分析器未配置或不可用", "gemini")
return self._ai_analyzer
async def fetch_financial_data(self, symbol: str, market: str, preferred_source: Optional[str] = None) -> FinancialDataResponse:
"""获取财务数据(支持数据源切换)"""
data_source = preferred_source or self.get_data_source_for_market(market)
try:
fetcher = self.get_data_fetcher(data_source)
return await fetcher.fetch_financial_data(symbol, market)
except DataSourceError as e:
# 尝试备用数据源
fallback_sources = self._get_fallback_sources(data_source)
for fallback_source in fallback_sources:
try:
fallback_fetcher = self.get_data_fetcher(fallback_source)
return await fallback_fetcher.fetch_financial_data(symbol, market)
except Exception:
continue
# 所有数据源都失败了
raise e
async def fetch_market_data(self, symbol: str, market: str, preferred_source: Optional[str] = None) -> MarketDataResponse:
"""获取市场数据(支持数据源切换)"""
data_source = preferred_source or self.get_data_source_for_market(market)
try:
fetcher = self.get_data_fetcher(data_source)
return await fetcher.fetch_market_data(symbol, market)
except DataSourceError as e:
# 尝试备用数据源
fallback_sources = self._get_fallback_sources(data_source)
for fallback_source in fallback_sources:
try:
fallback_fetcher = self.get_data_fetcher(fallback_source)
return await fallback_fetcher.fetch_market_data(symbol, market)
except Exception:
continue
# 所有数据源都失败了
raise e
async def validate_symbol(self, symbol: str, market: str, preferred_source: Optional[str] = None) -> SymbolValidationResponse:
"""验证证券代码(支持数据源切换)"""
data_source = preferred_source or self.get_data_source_for_market(market)
try:
fetcher = self.get_data_fetcher(data_source)
return await fetcher.validate_symbol(symbol, market)
except DataSourceError as e:
# 尝试备用数据源
fallback_sources = self._get_fallback_sources(data_source)
for fallback_source in fallback_sources:
try:
fallback_fetcher = self.get_data_fetcher(fallback_source)
return await fallback_fetcher.validate_symbol(symbol, market)
except Exception:
continue
# 所有数据源都失败了
raise e
async def analyze_with_ai(self, analysis_type: str, symbol: str, market: str, context_data: Dict[str, Any]) -> AIAnalysisResponse:
"""使用AI进行分析"""
analyzer = self.get_ai_analyzer()
if analysis_type == "business_info":
return await analyzer.analyze_business_info(symbol, market, context_data)
elif analysis_type == "fundamental_analysis":
business_info = context_data.get("business_info", {})
financial_data = context_data.get("financial_data", {})
return await analyzer.analyze_fundamental(symbol, market, financial_data, business_info)
elif analysis_type == "bullish_analysis":
return await analyzer.analyze_bullish_case(symbol, market, context_data)
elif analysis_type == "bearish_analysis":
return await analyzer.analyze_bearish_case(symbol, market, context_data)
elif analysis_type == "market_analysis":
return await analyzer.analyze_market_sentiment(symbol, market, context_data)
elif analysis_type == "news_analysis":
return await analyzer.analyze_news_catalysts(symbol, market, context_data)
elif analysis_type == "trading_analysis":
return await analyzer.analyze_trading_dynamics(symbol, market, context_data)
elif analysis_type == "insider_analysis":
return await analyzer.analyze_insider_institutional(symbol, market, context_data)
elif analysis_type == "final_conclusion":
all_analyses = context_data.get("all_analyses", [])
return await analyzer.generate_final_conclusion(symbol, market, all_analyses)
else:
raise AIAnalysisError(f"不支持的分析类型: {analysis_type}", "gemini")
async def check_all_sources_status(self) -> DataSourcesStatusResponse:
"""检查所有数据源状态"""
status_tasks = []
# 检查数据获取器状态
for source_name, fetcher in self._data_fetchers.items():
status_tasks.append(fetcher.check_status())
# 检查AI分析器状态
if self._ai_analyzer:
status_tasks.append(self._check_ai_analyzer_status())
# 并发执行状态检查
statuses = await asyncio.gather(*status_tasks, return_exceptions=True)
source_statuses = []
healthy_count = 0
for i, status in enumerate(statuses):
if isinstance(status, Exception):
# 处理异常情况
if i < len(self._data_fetchers):
source_name = list(self._data_fetchers.keys())[i]
else:
source_name = "gemini"
source_statuses.append(DataSourceStatus(
name=source_name,
is_available=False,
last_check=datetime.now(),
error_message=str(status)
))
else:
source_statuses.append(status)
if status.is_available:
healthy_count += 1
# 确定整体状态
total_sources = len(source_statuses)
if healthy_count == total_sources:
overall_status = "healthy"
elif healthy_count > 0:
overall_status = "degraded"
else:
overall_status = "down"
return DataSourcesStatusResponse(
sources=source_statuses,
overall_status=overall_status
)
async def _check_ai_analyzer_status(self) -> DataSourceStatus:
"""检查AI分析器状态"""
start_time = datetime.now()
try:
# 简单的健康检查 - 尝试生成一个很短的测试内容
test_prompt = "请回答1+1等于几"
await self._ai_analyzer._call_gemini_api(test_prompt)
end_time = datetime.now()
response_time = int((end_time - start_time).total_seconds() * 1000)
return DataSourceStatus(
name="gemini",
is_available=True,
last_check=end_time,
response_time_ms=response_time
)
except Exception as e:
end_time = datetime.now()
return DataSourceStatus(
name="gemini",
is_available=False,
last_check=end_time,
error_message=str(e)
)
def _get_fallback_sources(self, primary_source: str) -> List[str]:
"""获取备用数据源列表"""
fallback_config = self.config.get("fallback_sources", {})
if primary_source in fallback_config:
return fallback_config[primary_source]
# 默认备用策略
all_sources = list(self._data_fetchers.keys())
return [source for source in all_sources if source != primary_source]
def update_config(self, new_config: Dict[str, Any]):
"""更新配置"""
self.config.update(new_config)
# 重新初始化数据获取器
self._data_fetchers.clear()
self._initialize_data_fetchers()
# 重新初始化AI分析器
self._ai_analyzer = None
self._initialize_ai_analyzer()
def get_supported_sources(self) -> List[str]:
"""获取支持的数据源列表"""
return DataFetcherFactory.get_supported_sources()
def get_available_sources(self) -> List[str]:
"""获取当前可用的数据源列表"""
return list(self._data_fetchers.keys())
def is_ai_analyzer_available(self) -> bool:
"""检查AI分析器是否可用"""
return self._ai_analyzer is not None
def create_data_source_manager(config: Dict[str, Any]) -> DataSourceManager:
"""创建数据源管理器"""
return DataSourceManager(config)
# 默认配置示例
DEFAULT_CONFIG = {
"data_sources": {
"tushare": {
"enabled": True,
"api_key": "", # 需要从环境变量或配置文件获取
"base_url": "http://api.tushare.pro",
"timeout": 30,
"max_retries": 3,
"retry_delay": 1
},
"yahoo": {
"enabled": True,
"base_url": "https://query1.finance.yahoo.com",
"timeout": 30,
"max_retries": 3,
"retry_delay": 1
}
},
"ai_services": {
"gemini": {
"enabled": True,
"api_key": "", # 需要从环境变量或配置文件获取
"model": "gemini-pro",
"timeout": 60,
"max_retries": 3,
"retry_delay": 2,
"temperature": 0.7,
"max_output_tokens": 8192
}
},
"market_mapping": {
"china": "tushare",
"中国": "tushare",
"hongkong": "yahoo",
"香港": "yahoo",
"usa": "yahoo",
"美国": "yahoo",
"japan": "yahoo",
"日本": "yahoo"
},
"fallback_sources": {
"tushare": ["yahoo"],
"yahoo": ["tushare"]
}
}

View File

@ -0,0 +1,322 @@
"""
外部API集成服务
统一管理所有外部API调用和数据源切换
"""
from typing import Dict, Any, Optional, List
import asyncio
from datetime import datetime
from .data_source_manager import DataSourceManager, create_data_source_manager
from ..core.config import api_config
from ..core.exceptions import (
DataSourceError,
AIAnalysisError,
ConfigurationError
)
from ..schemas.data import (
FinancialDataResponse,
MarketDataResponse,
SymbolValidationResponse,
DataSourcesStatusResponse
)
from ..schemas.report import AIAnalysisResponse
class ExternalAPIService:
"""外部API服务"""
def __init__(self):
self._data_source_manager: Optional[DataSourceManager] = None
self._initialize_manager()
def _initialize_manager(self):
"""初始化数据源管理器"""
try:
config = api_config.get_data_source_manager_config()
self._data_source_manager = create_data_source_manager(config)
except Exception as e:
print(f"警告: 初始化数据源管理器失败: {str(e)}")
async def get_financial_data(self, symbol: str, market: str, preferred_source: Optional[str] = None) -> FinancialDataResponse:
"""获取财务数据"""
if not self._data_source_manager:
raise ConfigurationError("数据源管理器未初始化")
try:
return await self._data_source_manager.fetch_financial_data(symbol, market, preferred_source)
except Exception as e:
raise DataSourceError(f"获取财务数据失败: {str(e)}", preferred_source or "unknown")
async def get_market_data(self, symbol: str, market: str, preferred_source: Optional[str] = None) -> MarketDataResponse:
"""获取市场数据"""
if not self._data_source_manager:
raise ConfigurationError("数据源管理器未初始化")
try:
return await self._data_source_manager.fetch_market_data(symbol, market, preferred_source)
except Exception as e:
raise DataSourceError(f"获取市场数据失败: {str(e)}", preferred_source or "unknown")
async def validate_stock_symbol(self, symbol: str, market: str, preferred_source: Optional[str] = None) -> SymbolValidationResponse:
"""验证股票代码"""
if not self._data_source_manager:
raise ConfigurationError("数据源管理器未初始化")
try:
return await self._data_source_manager.validate_symbol(symbol, market, preferred_source)
except Exception as e:
# 验证失败时返回无效结果而不是抛出异常
return SymbolValidationResponse(
symbol=symbol,
market=market,
is_valid=False,
message=f"验证失败: {str(e)}"
)
async def analyze_business_info(self, symbol: str, market: str, financial_data: Dict[str, Any]) -> AIAnalysisResponse:
"""分析公司业务信息"""
if not self._data_source_manager:
raise ConfigurationError("数据源管理器未初始化")
if not self._data_source_manager.is_ai_analyzer_available():
raise AIAnalysisError("AI分析器不可用", "gemini")
try:
return await self._data_source_manager.analyze_with_ai(
"business_info", symbol, market, {"financial_data": financial_data}
)
except Exception as e:
raise AIAnalysisError(f"业务信息分析失败: {str(e)}", "gemini")
async def analyze_fundamental(self, symbol: str, market: str, financial_data: Dict[str, Any], business_info: Dict[str, Any]) -> AIAnalysisResponse:
"""基本面分析"""
if not self._data_source_manager:
raise ConfigurationError("数据源管理器未初始化")
if not self._data_source_manager.is_ai_analyzer_available():
raise AIAnalysisError("AI分析器不可用", "gemini")
try:
context_data = {
"financial_data": financial_data,
"business_info": business_info
}
return await self._data_source_manager.analyze_with_ai(
"fundamental_analysis", symbol, market, context_data
)
except Exception as e:
raise AIAnalysisError(f"基本面分析失败: {str(e)}", "gemini")
async def analyze_bullish_case(self, symbol: str, market: str, context_data: Dict[str, Any]) -> AIAnalysisResponse:
"""看涨分析"""
if not self._data_source_manager:
raise ConfigurationError("数据源管理器未初始化")
if not self._data_source_manager.is_ai_analyzer_available():
raise AIAnalysisError("AI分析器不可用", "gemini")
try:
return await self._data_source_manager.analyze_with_ai(
"bullish_analysis", symbol, market, context_data
)
except Exception as e:
raise AIAnalysisError(f"看涨分析失败: {str(e)}", "gemini")
async def analyze_bearish_case(self, symbol: str, market: str, context_data: Dict[str, Any]) -> AIAnalysisResponse:
"""看跌分析"""
if not self._data_source_manager:
raise ConfigurationError("数据源管理器未初始化")
if not self._data_source_manager.is_ai_analyzer_available():
raise AIAnalysisError("AI分析器不可用", "gemini")
try:
return await self._data_source_manager.analyze_with_ai(
"bearish_analysis", symbol, market, context_data
)
except Exception as e:
raise AIAnalysisError(f"看跌分析失败: {str(e)}", "gemini")
async def analyze_market_sentiment(self, symbol: str, market: str, context_data: Dict[str, Any]) -> AIAnalysisResponse:
"""市场情绪分析"""
if not self._data_source_manager:
raise ConfigurationError("数据源管理器未初始化")
if not self._data_source_manager.is_ai_analyzer_available():
raise AIAnalysisError("AI分析器不可用", "gemini")
try:
return await self._data_source_manager.analyze_with_ai(
"market_analysis", symbol, market, context_data
)
except Exception as e:
raise AIAnalysisError(f"市场分析失败: {str(e)}", "gemini")
async def analyze_news_catalysts(self, symbol: str, market: str, context_data: Dict[str, Any]) -> AIAnalysisResponse:
"""新闻催化剂分析"""
if not self._data_source_manager:
raise ConfigurationError("数据源管理器未初始化")
if not self._data_source_manager.is_ai_analyzer_available():
raise AIAnalysisError("AI分析器不可用", "gemini")
try:
return await self._data_source_manager.analyze_with_ai(
"news_analysis", symbol, market, context_data
)
except Exception as e:
raise AIAnalysisError(f"新闻分析失败: {str(e)}", "gemini")
async def analyze_trading_dynamics(self, symbol: str, market: str, context_data: Dict[str, Any]) -> AIAnalysisResponse:
"""交易动态分析"""
if not self._data_source_manager:
raise ConfigurationError("数据源管理器未初始化")
if not self._data_source_manager.is_ai_analyzer_available():
raise AIAnalysisError("AI分析器不可用", "gemini")
try:
return await self._data_source_manager.analyze_with_ai(
"trading_analysis", symbol, market, context_data
)
except Exception as e:
raise AIAnalysisError(f"交易分析失败: {str(e)}", "gemini")
async def analyze_insider_institutional(self, symbol: str, market: str, context_data: Dict[str, Any]) -> AIAnalysisResponse:
"""内部人与机构动向分析"""
if not self._data_source_manager:
raise ConfigurationError("数据源管理器未初始化")
if not self._data_source_manager.is_ai_analyzer_available():
raise AIAnalysisError("AI分析器不可用", "gemini")
try:
return await self._data_source_manager.analyze_with_ai(
"insider_analysis", symbol, market, context_data
)
except Exception as e:
raise AIAnalysisError(f"内部人分析失败: {str(e)}", "gemini")
async def generate_final_conclusion(self, symbol: str, market: str, all_analyses: List[Dict[str, Any]]) -> AIAnalysisResponse:
"""生成最终结论"""
if not self._data_source_manager:
raise ConfigurationError("数据源管理器未初始化")
if not self._data_source_manager.is_ai_analyzer_available():
raise AIAnalysisError("AI分析器不可用", "gemini")
try:
context_data = {"all_analyses": all_analyses}
return await self._data_source_manager.analyze_with_ai(
"final_conclusion", symbol, market, context_data
)
except Exception as e:
raise AIAnalysisError(f"最终结论生成失败: {str(e)}", "gemini")
async def check_all_services_status(self) -> DataSourcesStatusResponse:
"""检查所有外部服务状态"""
if not self._data_source_manager:
raise ConfigurationError("数据源管理器未初始化")
try:
return await self._data_source_manager.check_all_sources_status()
except Exception as e:
raise DataSourceError(f"检查服务状态失败: {str(e)}")
def get_supported_data_sources(self) -> List[str]:
"""获取支持的数据源列表"""
if not self._data_source_manager:
return []
return self._data_source_manager.get_supported_sources()
def get_available_data_sources(self) -> List[str]:
"""获取当前可用的数据源列表"""
if not self._data_source_manager:
return []
return self._data_source_manager.get_available_sources()
def is_ai_service_available(self) -> bool:
"""检查AI服务是否可用"""
if not self._data_source_manager:
return False
return self._data_source_manager.is_ai_analyzer_available()
def get_data_source_for_market(self, market: str) -> str:
"""根据市场获取推荐的数据源"""
if not self._data_source_manager:
return "tushare" # 默认值
return self._data_source_manager.get_data_source_for_market(market)
def update_configuration(self, new_config: Dict[str, Any]):
"""更新配置"""
if self._data_source_manager:
self._data_source_manager.update_config(new_config)
else:
# 如果管理器未初始化,尝试重新初始化
self._initialize_manager()
async def test_data_source_connection(self, data_source: str, config: Dict[str, Any]) -> Dict[str, Any]:
"""测试数据源连接"""
try:
# 创建临时的数据获取器进行测试
from .data_fetcher import DataFetcherFactory
test_fetcher = DataFetcherFactory.create_fetcher(data_source, config)
status = await test_fetcher.check_status()
return {
"success": status.is_available,
"response_time_ms": status.response_time_ms,
"error_message": status.error_message
}
except Exception as e:
return {
"success": False,
"error_message": str(e)
}
async def test_ai_service_connection(self, service_type: str, config: Dict[str, Any]) -> Dict[str, Any]:
"""测试AI服务连接"""
try:
if service_type.lower() == "gemini":
from .ai_analyzer import AIAnalyzerFactory
test_analyzer = AIAnalyzerFactory.create_gemini_analyzer(
config.get("api_key"), config
)
# 简单的测试调用
start_time = datetime.now()
await test_analyzer._call_gemini_api("测试连接请回答OK")
end_time = datetime.now()
response_time = int((end_time - start_time).total_seconds() * 1000)
return {
"success": True,
"response_time_ms": response_time
}
else:
return {
"success": False,
"error_message": f"不支持的AI服务类型: {service_type}"
}
except Exception as e:
return {
"success": False,
"error_message": str(e)
}
# 创建全局服务实例
external_api_service = ExternalAPIService()
def get_external_api_service() -> ExternalAPIService:
"""获取外部API服务实例"""
return external_api_service

View File

@ -0,0 +1,163 @@
"""
进度追踪服务
处理报告生成进度的追踪和管理
"""
from typing import List, Optional
from uuid import UUID
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from ..models.progress_tracking import ProgressTracking
from ..schemas.progress import ProgressResponse, StepTiming
class ProgressTracker:
"""进度追踪器"""
def __init__(self, db_session: AsyncSession):
self.db = db_session
async def initialize_progress(self, report_id: UUID):
"""初始化进度追踪"""
# 定义报告生成步骤
steps = [
"初始化报告",
"获取财务数据",
"生成业务信息",
"执行基本面分析",
"执行看涨分析",
"执行看跌分析",
"执行市场分析",
"执行新闻分析",
"执行交易分析",
"执行内部人分析",
"生成最终结论",
"保存报告"
]
# 创建进度记录
for i, step_name in enumerate(steps, 1):
progress = ProgressTracking(
report_id=report_id,
step_name=step_name,
step_order=i,
status="pending"
)
self.db.add(progress)
await self.db.flush()
async def start_step(self, report_id: UUID, step_name: str):
"""开始执行步骤"""
result = await self.db.execute(
select(ProgressTracking).where(
ProgressTracking.report_id == report_id,
ProgressTracking.step_name == step_name
)
)
progress = result.scalar_one_or_none()
if progress:
progress.status = "running"
progress.started_at = datetime.utcnow()
await self.db.flush()
async def complete_step(self, report_id: UUID, step_name: str, success: bool = True, error_message: Optional[str] = None):
"""完成步骤"""
result = await self.db.execute(
select(ProgressTracking).where(
ProgressTracking.report_id == report_id,
ProgressTracking.step_name == step_name
)
)
progress = result.scalar_one_or_none()
if progress:
progress.status = "completed" if success else "failed"
progress.completed_at = datetime.utcnow()
progress.error_message = error_message
# 计算耗时
if progress.started_at:
duration = progress.completed_at - progress.started_at
progress.duration_ms = int(duration.total_seconds() * 1000)
await self.db.flush()
async def get_progress(self, report_id: UUID) -> ProgressResponse:
"""获取进度信息"""
result = await self.db.execute(
select(ProgressTracking)
.where(ProgressTracking.report_id == report_id)
.order_by(ProgressTracking.step_order)
)
progress_records = result.scalars().all()
if not progress_records:
raise ValueError(f"未找到报告 {report_id} 的进度信息")
# 计算当前步骤
current_step = 1
current_step_name = "初始化报告"
overall_status = "running"
completed_count = 0
failed_count = 0
for record in progress_records:
if record.status == "completed":
completed_count += 1
elif record.status == "failed":
failed_count += 1
elif record.status == "running":
current_step = record.step_order
current_step_name = record.step_name
# 确定整体状态
if failed_count > 0:
overall_status = "failed"
elif completed_count == len(progress_records):
overall_status = "completed"
# 转换为StepTiming对象
step_timings = [
StepTiming(
step_name=record.step_name,
step_order=record.step_order,
status=record.status,
started_at=record.started_at,
completed_at=record.completed_at,
duration_ms=record.duration_ms,
error_message=record.error_message
)
for record in progress_records
]
return ProgressResponse(
report_id=report_id,
current_step=current_step,
total_steps=len(progress_records),
current_step_name=current_step_name,
status=overall_status,
step_timings=step_timings,
estimated_remaining=self._estimate_remaining_time(step_timings)
)
def _estimate_remaining_time(self, step_timings: List[StepTiming]) -> Optional[int]:
"""估算剩余时间"""
# 计算已完成步骤的平均耗时
completed_durations = [
timing.duration_ms for timing in step_timings
if timing.status == "completed" and timing.duration_ms
]
if not completed_durations:
return None
avg_duration_ms = sum(completed_durations) / len(completed_durations)
remaining_steps = len([t for t in step_timings if t.status == "pending"])
return int((avg_duration_ms * remaining_steps) / 1000) # 转换为秒

View File

@ -0,0 +1,643 @@
"""
报告生成服务
处理股票基本面分析报告的生成和管理
"""
from typing import Dict, Any, Optional, List
from uuid import UUID
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
import asyncio
import logging
from ..models.report import Report
from ..models.analysis_module import AnalysisModule
from ..schemas.report import ReportResponse, AnalysisModuleSchema
from ..core.exceptions import (
ReportGenerationError,
DataSourceError,
AIAnalysisError,
DatabaseError
)
from .progress_tracker import ProgressTracker
from .data_fetcher import DataFetcherFactory
from .ai_analyzer import AIAnalyzerFactory
from .config_manager import ConfigManager
logger = logging.getLogger(__name__)
class AnalysisModuleType:
"""分析模块类型常量"""
TRADING_VIEW_CHART = "trading_view_chart"
FINANCIAL_DATA = "financial_data"
BUSINESS_INFO = "business_info"
FUNDAMENTAL_ANALYSIS = "fundamental_analysis"
BULLISH_ANALYSIS = "bullish_analysis"
BEARISH_ANALYSIS = "bearish_analysis"
MARKET_ANALYSIS = "market_analysis"
NEWS_ANALYSIS = "news_analysis"
TRADING_ANALYSIS = "trading_analysis"
INSIDER_ANALYSIS = "insider_analysis"
FINAL_CONCLUSION = "final_conclusion"
class ReportGenerator:
"""报告生成器"""
def __init__(self, db_session: AsyncSession, config_manager: ConfigManager):
self.db = db_session
self.config_manager = config_manager
self.progress_tracker = ProgressTracker(db_session)
# 定义分析模块配置
self.analysis_modules = [
{
"type": AnalysisModuleType.TRADING_VIEW_CHART,
"title": "TradingView图表",
"order": 1,
"step_name": "获取财务数据"
},
{
"type": AnalysisModuleType.FINANCIAL_DATA,
"title": "财务数据分析",
"order": 2,
"step_name": "获取财务数据"
},
{
"type": AnalysisModuleType.BUSINESS_INFO,
"title": "业务信息分析",
"order": 3,
"step_name": "生成业务信息"
},
{
"type": AnalysisModuleType.FUNDAMENTAL_ANALYSIS,
"title": "基本面分析(景林模型)",
"order": 4,
"step_name": "执行基本面分析"
},
{
"type": AnalysisModuleType.BULLISH_ANALYSIS,
"title": "看涨分析师观点",
"order": 5,
"step_name": "执行看涨分析"
},
{
"type": AnalysisModuleType.BEARISH_ANALYSIS,
"title": "看跌分析师观点",
"order": 6,
"step_name": "执行看跌分析"
},
{
"type": AnalysisModuleType.MARKET_ANALYSIS,
"title": "市场分析师观点",
"order": 7,
"step_name": "执行市场分析"
},
{
"type": AnalysisModuleType.NEWS_ANALYSIS,
"title": "新闻分析师观点",
"order": 8,
"step_name": "执行新闻分析"
},
{
"type": AnalysisModuleType.TRADING_ANALYSIS,
"title": "交易分析师观点",
"order": 9,
"step_name": "执行交易分析"
},
{
"type": AnalysisModuleType.INSIDER_ANALYSIS,
"title": "内部人与机构动向分析",
"order": 10,
"step_name": "执行内部人分析"
},
{
"type": AnalysisModuleType.FINAL_CONCLUSION,
"title": "最终结论",
"order": 11,
"step_name": "生成最终结论"
}
]
async def generate_report(self, symbol: str, market: str, force_regenerate: bool = False) -> ReportResponse:
"""生成股票分析报告"""
try:
# 检查是否存在现有报告
existing_report = await self._get_existing_report(symbol, market)
if existing_report and not force_regenerate:
if existing_report.status == "completed":
logger.info(f"返回现有报告: {symbol} ({market})")
return await self._build_report_response(existing_report)
elif existing_report.status == "generating":
logger.info(f"报告正在生成中: {symbol} ({market})")
return await self._build_report_response(existing_report)
# 创建新报告或重新生成
if existing_report and force_regenerate:
report = existing_report
report.status = "generating"
report.updated_at = datetime.utcnow()
# 清理现有的分析模块
await self._cleanup_existing_modules(report.id)
else:
report = await self._create_new_report(symbol, market)
# 初始化进度追踪
await self.progress_tracker.initialize_progress(report.id)
# 异步生成报告内容
asyncio.create_task(self._generate_report_content(report))
return await self._build_report_response(report)
except Exception as e:
logger.error(f"报告生成失败: {symbol} ({market}) - {str(e)}")
if isinstance(e, (ReportGenerationError, DataSourceError, AIAnalysisError)):
raise
raise ReportGenerationError(f"报告生成失败: {str(e)}")
async def get_report(self, symbol: str, market: str) -> Optional[ReportResponse]:
"""获取现有报告"""
try:
report = await self._get_existing_report(symbol, market)
if report:
return await self._build_report_response(report)
return None
except Exception as e:
logger.error(f"获取报告失败: {symbol} ({market}) - {str(e)}")
raise ReportGenerationError(f"获取报告失败: {str(e)}")
async def get_report_by_id(self, report_id: UUID) -> Optional[ReportResponse]:
"""根据ID获取报告"""
try:
result = await self.db.execute(
select(Report).where(Report.id == report_id)
)
report = result.scalar_one_or_none()
if report:
return await self._build_report_response(report)
return None
except Exception as e:
logger.error(f"获取报告失败: {report_id} - {str(e)}")
raise ReportGenerationError(f"获取报告失败: {str(e)}")
async def _get_existing_report(self, symbol: str, market: str) -> Optional[Report]:
"""获取现有报告"""
result = await self.db.execute(
select(Report).where(
Report.symbol == symbol,
Report.market == market
)
)
return result.scalar_one_or_none()
async def _create_new_report(self, symbol: str, market: str) -> Report:
"""创建新报告"""
report = Report(
symbol=symbol,
market=market,
status="generating"
)
self.db.add(report)
await self.db.flush()
return report
async def _cleanup_existing_modules(self, report_id: UUID):
"""清理现有的分析模块"""
result = await self.db.execute(
select(AnalysisModule).where(AnalysisModule.report_id == report_id)
)
modules = result.scalars().all()
for module in modules:
await self.db.delete(module)
await self.db.flush()
async def _generate_report_content(self, report: Report):
"""生成报告内容(异步执行)"""
try:
logger.info(f"开始生成报告内容: {report.symbol} ({report.market})")
# 开始初始化步骤
await self.progress_tracker.start_step(report.id, "初始化报告")
# 创建分析模块记录
await self._create_analysis_modules(report)
await self.progress_tracker.complete_step(report.id, "初始化报告", True)
# 获取配置
data_source_config = await self.config_manager.get_data_source_config(report.market)
gemini_config = await self.config_manager.get_gemini_config()
# 创建数据获取器和AI分析器
data_fetcher = DataFetcherFactory.create_fetcher(
data_source_config["type"],
data_source_config
)
ai_analyzer = AIAnalyzerFactory.create_gemini_analyzer(
gemini_config["api_key"],
gemini_config
)
# 存储分析结果的上下文
analysis_context = {}
# 按顺序执行各个分析模块
for module_config in self.analysis_modules:
try:
await self._execute_analysis_module(
report, module_config, data_fetcher, ai_analyzer, analysis_context
)
except Exception as e:
logger.error(f"分析模块执行失败: {module_config['type']} - {str(e)}")
# 标记模块为失败,但继续执行其他模块
await self._mark_module_failed(report.id, module_config["type"], str(e))
# 完成报告生成
await self.progress_tracker.start_step(report.id, "保存报告")
report.status = "completed"
report.updated_at = datetime.utcnow()
await self.db.commit()
await self.progress_tracker.complete_step(report.id, "保存报告", True)
logger.info(f"报告生成完成: {report.symbol} ({report.market})")
except Exception as e:
logger.error(f"报告生成过程失败: {report.symbol} ({report.market}) - {str(e)}")
# 标记报告为失败状态
report.status = "failed"
report.updated_at = datetime.utcnow()
await self.db.commit()
# 标记当前步骤为失败
try:
progress = await self.progress_tracker.get_progress(report.id)
current_step = progress.current_step_name
await self.progress_tracker.complete_step(report.id, current_step, False, str(e))
except Exception:
pass # 忽略进度更新失败
async def _create_analysis_modules(self, report: Report):
"""创建分析模块记录"""
for module_config in self.analysis_modules:
module = AnalysisModule(
report_id=report.id,
module_type=module_config["type"],
module_order=module_config["order"],
title=module_config["title"],
status="pending"
)
self.db.add(module)
await self.db.flush()
async def _execute_analysis_module(
self,
report: Report,
module_config: Dict[str, Any],
data_fetcher,
ai_analyzer,
analysis_context: Dict[str, Any]
):
"""执行单个分析模块"""
module_type = module_config["type"]
step_name = module_config["step_name"]
logger.info(f"执行分析模块: {module_type}")
# 开始步骤
await self.progress_tracker.start_step(report.id, step_name)
# 标记模块开始
await self._mark_module_started(report.id, module_type)
try:
# 根据模块类型执行相应的分析
if module_type == AnalysisModuleType.FINANCIAL_DATA:
content = await self._execute_financial_data_module(
report.symbol, report.market, data_fetcher
)
analysis_context["financial_data"] = content
elif module_type == AnalysisModuleType.TRADING_VIEW_CHART:
content = await self._execute_trading_view_module(
report.symbol, report.market
)
elif module_type == AnalysisModuleType.BUSINESS_INFO:
content = await self._execute_business_info_module(
report.symbol, report.market, ai_analyzer, analysis_context
)
analysis_context["business_info"] = content
elif module_type == AnalysisModuleType.FUNDAMENTAL_ANALYSIS:
content = await self._execute_fundamental_analysis_module(
report.symbol, report.market, ai_analyzer, analysis_context
)
analysis_context["fundamental_analysis"] = content
elif module_type == AnalysisModuleType.BULLISH_ANALYSIS:
content = await self._execute_bullish_analysis_module(
report.symbol, report.market, ai_analyzer, analysis_context
)
analysis_context["bullish_analysis"] = content
elif module_type == AnalysisModuleType.BEARISH_ANALYSIS:
content = await self._execute_bearish_analysis_module(
report.symbol, report.market, ai_analyzer, analysis_context
)
analysis_context["bearish_analysis"] = content
elif module_type == AnalysisModuleType.MARKET_ANALYSIS:
content = await self._execute_market_analysis_module(
report.symbol, report.market, ai_analyzer, analysis_context
)
analysis_context["market_analysis"] = content
elif module_type == AnalysisModuleType.NEWS_ANALYSIS:
content = await self._execute_news_analysis_module(
report.symbol, report.market, ai_analyzer, analysis_context
)
analysis_context["news_analysis"] = content
elif module_type == AnalysisModuleType.TRADING_ANALYSIS:
content = await self._execute_trading_analysis_module(
report.symbol, report.market, ai_analyzer, analysis_context
)
analysis_context["trading_analysis"] = content
elif module_type == AnalysisModuleType.INSIDER_ANALYSIS:
content = await self._execute_insider_analysis_module(
report.symbol, report.market, ai_analyzer, analysis_context
)
analysis_context["insider_analysis"] = content
elif module_type == AnalysisModuleType.FINAL_CONCLUSION:
content = await self._execute_final_conclusion_module(
report.symbol, report.market, ai_analyzer, analysis_context
)
analysis_context["final_conclusion"] = content
else:
raise ReportGenerationError(f"未知的分析模块类型: {module_type}")
# 保存模块内容
await self._save_module_content(report.id, module_type, content)
# 标记模块完成
await self._mark_module_completed(report.id, module_type)
# 完成步骤
await self.progress_tracker.complete_step(report.id, step_name, True)
except Exception as e:
logger.error(f"分析模块执行失败: {module_type} - {str(e)}")
# 标记模块失败
await self._mark_module_failed(report.id, module_type, str(e))
# 完成步骤(失败)
await self.progress_tracker.complete_step(report.id, step_name, False, str(e))
raise e
async def _execute_financial_data_module(self, symbol: str, market: str, data_fetcher) -> Dict[str, Any]:
"""执行财务数据分析模块"""
try:
# 获取财务数据
financial_data = await data_fetcher.fetch_financial_data(symbol, market)
# 获取市场数据
market_data = await data_fetcher.fetch_market_data(symbol, market)
return {
"financial_data": financial_data.dict(),
"market_data": market_data.dict(),
"summary": {
"data_source": financial_data.data_source,
"last_updated": financial_data.last_updated.isoformat(),
"data_quality": "good" # 可以添加数据质量评估逻辑
}
}
except Exception as e:
raise DataSourceError(f"财务数据获取失败: {str(e)}")
async def _execute_trading_view_module(self, symbol: str, market: str) -> Dict[str, Any]:
"""执行TradingView图表模块"""
# 生成TradingView图表配置
return {
"chart_config": {
"symbol": symbol,
"market": market,
"interval": "1D",
"theme": "light",
"style": "1", # 蜡烛图
"toolbar_bg": "#f1f3f6",
"enable_publishing": False,
"withdateranges": True,
"hide_side_toolbar": False,
"allow_symbol_change": False,
"studies": [
"MASimple@tv-basicstudies", # 移动平均线
"Volume@tv-basicstudies" # 成交量
]
},
"display_settings": {
"width": "100%",
"height": 500,
"autosize": True
}
}
async def _execute_business_info_module(self, symbol: str, market: str, ai_analyzer, analysis_context: Dict[str, Any]) -> Dict[str, Any]:
"""执行业务信息分析模块"""
try:
financial_data = analysis_context.get("financial_data", {})
result = await ai_analyzer.analyze_business_info(symbol, market, financial_data)
return result.content
except Exception as e:
raise AIAnalysisError(f"业务信息分析失败: {str(e)}")
async def _execute_fundamental_analysis_module(self, symbol: str, market: str, ai_analyzer, analysis_context: Dict[str, Any]) -> Dict[str, Any]:
"""执行基本面分析模块"""
try:
financial_data = analysis_context.get("financial_data", {})
business_info = analysis_context.get("business_info", {})
result = await ai_analyzer.analyze_fundamental(symbol, market, financial_data, business_info)
return result.content
except Exception as e:
raise AIAnalysisError(f"基本面分析失败: {str(e)}")
async def _execute_bullish_analysis_module(self, symbol: str, market: str, ai_analyzer, analysis_context: Dict[str, Any]) -> Dict[str, Any]:
"""执行看涨分析模块"""
try:
result = await ai_analyzer.analyze_bullish_case(symbol, market, analysis_context)
return result.content
except Exception as e:
raise AIAnalysisError(f"看涨分析失败: {str(e)}")
async def _execute_bearish_analysis_module(self, symbol: str, market: str, ai_analyzer, analysis_context: Dict[str, Any]) -> Dict[str, Any]:
"""执行看跌分析模块"""
try:
result = await ai_analyzer.analyze_bearish_case(symbol, market, analysis_context)
return result.content
except Exception as e:
raise AIAnalysisError(f"看跌分析失败: {str(e)}")
async def _execute_market_analysis_module(self, symbol: str, market: str, ai_analyzer, analysis_context: Dict[str, Any]) -> Dict[str, Any]:
"""执行市场分析模块"""
try:
result = await ai_analyzer.analyze_market_sentiment(symbol, market, analysis_context)
return result.content
except Exception as e:
raise AIAnalysisError(f"市场分析失败: {str(e)}")
async def _execute_news_analysis_module(self, symbol: str, market: str, ai_analyzer, analysis_context: Dict[str, Any]) -> Dict[str, Any]:
"""执行新闻分析模块"""
try:
result = await ai_analyzer.analyze_news_catalysts(symbol, market, analysis_context)
return result.content
except Exception as e:
raise AIAnalysisError(f"新闻分析失败: {str(e)}")
async def _execute_trading_analysis_module(self, symbol: str, market: str, ai_analyzer, analysis_context: Dict[str, Any]) -> Dict[str, Any]:
"""执行交易分析模块"""
try:
result = await ai_analyzer.analyze_trading_dynamics(symbol, market, analysis_context)
return result.content
except Exception as e:
raise AIAnalysisError(f"交易分析失败: {str(e)}")
async def _execute_insider_analysis_module(self, symbol: str, market: str, ai_analyzer, analysis_context: Dict[str, Any]) -> Dict[str, Any]:
"""执行内部人分析模块"""
try:
result = await ai_analyzer.analyze_insider_institutional(symbol, market, analysis_context)
return result.content
except Exception as e:
raise AIAnalysisError(f"内部人分析失败: {str(e)}")
async def _execute_final_conclusion_module(self, symbol: str, market: str, ai_analyzer, analysis_context: Dict[str, Any]) -> Dict[str, Any]:
"""执行最终结论模块"""
try:
# 收集所有分析结果
all_analyses = []
for key, value in analysis_context.items():
if key != "financial_data": # 排除原始财务数据
all_analyses.append({
"analysis_type": key,
"content": value
})
result = await ai_analyzer.generate_final_conclusion(symbol, market, all_analyses)
return result.content
except Exception as e:
raise AIAnalysisError(f"最终结论生成失败: {str(e)}")
async def _mark_module_started(self, report_id: UUID, module_type: str):
"""标记模块开始"""
result = await self.db.execute(
select(AnalysisModule).where(
AnalysisModule.report_id == report_id,
AnalysisModule.module_type == module_type
)
)
module = result.scalar_one_or_none()
if module:
module.status = "running"
module.started_at = datetime.utcnow()
await self.db.flush()
async def _mark_module_completed(self, report_id: UUID, module_type: str):
"""标记模块完成"""
result = await self.db.execute(
select(AnalysisModule).where(
AnalysisModule.report_id == report_id,
AnalysisModule.module_type == module_type
)
)
module = result.scalar_one_or_none()
if module:
module.status = "completed"
module.completed_at = datetime.utcnow()
await self.db.flush()
async def _mark_module_failed(self, report_id: UUID, module_type: str, error_message: str):
"""标记模块失败"""
result = await self.db.execute(
select(AnalysisModule).where(
AnalysisModule.report_id == report_id,
AnalysisModule.module_type == module_type
)
)
module = result.scalar_one_or_none()
if module:
module.status = "failed"
module.completed_at = datetime.utcnow()
module.error_message = error_message
await self.db.flush()
async def _save_module_content(self, report_id: UUID, module_type: str, content: Dict[str, Any]):
"""保存模块内容"""
result = await self.db.execute(
select(AnalysisModule).where(
AnalysisModule.report_id == report_id,
AnalysisModule.module_type == module_type
)
)
module = result.scalar_one_or_none()
if module:
module.content = content
await self.db.flush()
async def _build_report_response(self, report: Report) -> ReportResponse:
"""构建报告响应"""
# 获取分析模块
result = await self.db.execute(
select(AnalysisModule)
.where(AnalysisModule.report_id == report.id)
.order_by(AnalysisModule.module_order)
)
modules = result.scalars().all()
# 转换为响应模式
module_schemas = [
AnalysisModuleSchema(
id=module.id,
module_type=module.module_type,
module_order=module.module_order,
title=module.title,
content=module.content,
status=module.status,
started_at=module.started_at,
completed_at=module.completed_at,
error_message=module.error_message
)
for module in modules
]
return ReportResponse(
id=report.id,
symbol=report.symbol,
market=report.market,
status=report.status,
created_at=report.created_at,
updated_at=report.updated_at,
analysis_modules=module_schemas
)

45
backend/check_db.py Executable file
View File

@ -0,0 +1,45 @@
#!/usr/bin/env python3
"""
数据库连接检查脚本
用于验证数据库配置和连接状态
"""
import asyncio
import sys
import os
# 添加项目根目录到Python路径
sys.path.append(os.path.dirname(__file__))
from app.core.database import check_db_connection, close_db
from app.core.config import settings
async def main():
"""主函数:检查数据库连接"""
try:
print(f"正在检查数据库连接...")
print(f"数据库URL: {settings.DATABASE_URL}")
# 检查数据库连接
is_connected = await check_db_connection()
if is_connected:
print("✅ 数据库连接正常!")
else:
print("❌ 数据库连接失败!")
print("请检查:")
print("1. PostgreSQL服务是否运行")
print("2. 数据库配置是否正确")
print("3. 网络连接是否正常")
sys.exit(1)
except Exception as e:
print(f"❌ 数据库连接检查失败: {e}")
sys.exit(1)
finally:
await close_db()
if __name__ == "__main__":
asyncio.run(main())

41
backend/init_db.py Executable file
View File

@ -0,0 +1,41 @@
#!/usr/bin/env python3
"""
数据库初始化脚本
用于创建数据库表和运行初始迁移
"""
import asyncio
import sys
import os
# 添加项目根目录到Python路径
sys.path.append(os.path.dirname(__file__))
from app.core.database import init_db, close_db
from app.core.config import settings
async def main():
"""主函数:初始化数据库"""
try:
print(f"正在连接数据库: {settings.DATABASE_URL}")
print("正在创建数据库表...")
# 初始化数据库表
await init_db()
print("✅ 数据库表创建成功!")
print("💡 提示: 如果需要使用Alembic管理迁移请运行:")
print(" alembic stamp head")
print(" alembic revision --autogenerate -m '描述'")
print(" alembic upgrade head")
except Exception as e:
print(f"❌ 数据库初始化失败: {e}")
sys.exit(1)
finally:
await close_db()
if __name__ == "__main__":
asyncio.run(main())

101
backend/main.py Normal file
View File

@ -0,0 +1,101 @@
"""
FastAPI应用入口点
基本面选股系统后端服务
"""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from app.core.config import settings
from app.core.database import engine, Base
from app.routers import reports, config, progress
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期管理"""
# 启动时创建数据库表
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield
# 关闭时清理资源
await engine.dispose()
# 创建FastAPI应用实例
app = FastAPI(
title="基本面选股系统",
description="""
提供股票基本面分析和报告生成的API服务
## 功能特性
* **报告管理**: 创建查询更新和删除股票分析报告
* **进度追踪**: 实时追踪报告生成进度
* **配置管理**: 管理系统配置包括数据库API密钥等
* **多市场支持**: 支持中国香港美国日本股票市场
## 支持的市场
* `china` - 中国A股市场
* `hongkong` - 香港股票市场
* `usa` - 美国股票市场
* `japan` - 日本股票市场
""",
version="1.0.0",
lifespan=lifespan,
docs_url="/docs",
redoc_url="/redoc",
openapi_url="/openapi.json"
)
# 配置CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=settings.ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 注册路由
app.include_router(
reports.router,
prefix="/api/reports",
tags=["reports"],
responses={
404: {"description": "报告不存在"},
500: {"description": "服务器内部错误"}
}
)
app.include_router(
config.router,
prefix="/api/config",
tags=["config"],
responses={
400: {"description": "配置参数错误"},
500: {"description": "服务器内部错误"}
}
)
app.include_router(
progress.router,
prefix="/api/progress",
tags=["progress"],
responses={
404: {"description": "进度记录不存在"},
500: {"description": "服务器内部错误"}
}
)
@app.get("/")
async def root():
"""根路径健康检查"""
return {"message": "基本面选股系统API服务正在运行", "version": "1.0.0"}
@app.get("/health")
async def health_check():
"""健康检查端点"""
return {"status": "healthy", "service": "fundamental-stock-analysis"}

125
backend/manage_db.py Executable file
View File

@ -0,0 +1,125 @@
#!/usr/bin/env python3
"""
数据库管理脚本
提供数据库初始化检查迁移等功能
"""
import asyncio
import sys
import os
import argparse
# 添加项目根目录到Python路径
sys.path.append(os.path.dirname(__file__))
from app.core.database import init_db, close_db, check_db_connection
from app.core.config import settings
async def check_connection():
"""检查数据库连接"""
print(f"正在检查数据库连接...")
print(f"数据库URL: {settings.DATABASE_URL}")
is_connected = await check_db_connection()
if is_connected:
print("✅ 数据库连接正常!")
return True
else:
print("❌ 数据库连接失败!")
print("请检查:")
print("1. PostgreSQL服务是否运行")
print("2. 数据库配置是否正确")
print("3. 网络连接是否正常")
return False
async def initialize_database():
"""初始化数据库"""
print(f"正在连接数据库: {settings.DATABASE_URL}")
print("正在创建数据库表...")
try:
# 初始化数据库表
await init_db()
print("✅ 数据库表创建成功!")
print("💡 提示: 如果需要使用Alembic管理迁移请运行:")
print(" alembic stamp head")
print(" alembic revision --autogenerate -m '描述'")
print(" alembic upgrade head")
return True
except Exception as e:
print(f"❌ 数据库初始化失败: {e}")
return False
async def show_status():
"""显示数据库状态"""
print("=== 数据库状态 ===")
print(f"数据库URL: {settings.DATABASE_URL}")
print(f"数据库Echo: {settings.DATABASE_ECHO}")
# 检查连接
is_connected = await check_db_connection()
print(f"连接状态: {'✅ 正常' if is_connected else '❌ 失败'}")
if is_connected:
try:
from app.core.database import AsyncSessionLocal
async with AsyncSessionLocal() as session:
# 检查表是否存在
result = await session.execute("""
SELECT table_name
FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name IN ('reports', 'analysis_modules', 'progress_tracking', 'system_config')
ORDER BY table_name
""")
tables = [row[0] for row in result.fetchall()]
print(f"已创建的表: {', '.join(tables) if tables else ''}")
if 'reports' in tables:
result = await session.execute("SELECT COUNT(*) FROM reports")
count = result.scalar()
print(f"报告数量: {count}")
except Exception as e:
print(f"获取详细状态失败: {e}")
async def main():
"""主函数"""
parser = argparse.ArgumentParser(description='数据库管理工具')
parser.add_argument('command', choices=['check', 'init', 'status'],
help='要执行的命令')
args = parser.parse_args()
try:
if args.command == 'check':
success = await check_connection()
elif args.command == 'init':
success = await initialize_database()
elif args.command == 'status':
await show_status()
success = True
else:
print(f"未知命令: {args.command}")
success = False
if not success:
sys.exit(1)
except Exception as e:
print(f"❌ 执行失败: {e}")
sys.exit(1)
finally:
await close_db()
if __name__ == "__main__":
asyncio.run(main())

40
backend/requirements.txt Normal file
View File

@ -0,0 +1,40 @@
# FastAPI和相关依赖
fastapi==0.104.1
uvicorn[standard]==0.24.0
python-multipart==0.0.6
# 数据库相关
sqlalchemy[asyncio]==2.0.23
asyncpg==0.29.0
alembic==1.12.1
# 数据验证和序列化
pydantic==2.5.0
pydantic-settings==2.1.0
# HTTP客户端
httpx==0.25.2
aiohttp==3.9.1
# AI服务
google-generativeai==0.3.2
# 环境变量管理
python-dotenv==1.0.0
# 日志和监控
structlog==23.2.0
# 开发和测试依赖
pytest==7.4.3
pytest-asyncio==0.21.1
pytest-cov==4.1.0
black==23.11.0
isort==5.12.0
flake8==6.1.0
# 类型检查
mypy==1.7.1
# 安全相关
cryptography==41.0.7

View File

@ -0,0 +1,152 @@
#!/usr/bin/env python3
"""
外部API集成测试脚本
用于验证Tushare和Gemini API集成是否正常工作
"""
import asyncio
import os
import sys
from pathlib import Path
# 添加项目根目录到Python路径
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
from app.services.external_api_service import get_external_api_service
from app.core.config import settings
async def test_data_sources():
"""测试数据源"""
print("=== 测试数据源 ===")
service = get_external_api_service()
# 检查支持的数据源
supported_sources = service.get_supported_data_sources()
print(f"支持的数据源: {supported_sources}")
available_sources = service.get_available_data_sources()
print(f"可用的数据源: {available_sources}")
# 测试数据源状态
try:
status_response = await service.check_all_services_status()
print(f"整体状态: {status_response.overall_status}")
for source_status in status_response.sources:
status_text = "✅ 可用" if source_status.is_available else "❌ 不可用"
print(f" {source_status.name}: {status_text}")
if source_status.response_time_ms:
print(f" 响应时间: {source_status.response_time_ms}ms")
if source_status.error_message:
print(f" 错误信息: {source_status.error_message}")
except Exception as e:
print(f"检查状态失败: {e}")
async def test_symbol_validation():
"""测试证券代码验证"""
print("\n=== 测试证券代码验证 ===")
service = get_external_api_service()
test_cases = [
("000001", "中国"), # 平安银行
("600036", "中国"), # 招商银行
("AAPL", "美国"), # 苹果
("INVALID", "中国") # 无效代码
]
for symbol, market in test_cases:
try:
result = await service.validate_stock_symbol(symbol, market)
status_text = "✅ 有效" if result.is_valid else "❌ 无效"
print(f" {symbol} ({market}): {status_text}")
if result.company_name:
print(f" 公司名称: {result.company_name}")
if result.message:
print(f" 消息: {result.message}")
except Exception as e:
print(f" {symbol} ({market}): ❌ 验证失败 - {e}")
async def test_financial_data():
"""测试财务数据获取"""
print("\n=== 测试财务数据获取 ===")
service = get_external_api_service()
test_cases = [
("000001", "中国"), # 平安银行
("AAPL", "美国") # 苹果
]
for symbol, market in test_cases:
try:
result = await service.get_financial_data(symbol, market)
print(f" {symbol} ({market}): ✅ 获取成功")
print(f" 数据源: {result.data_source}")
print(f" 更新时间: {result.last_updated}")
# 显示部分财务数据
if result.balance_sheet:
print(f" 资产负债表: {len(result.balance_sheet)} 项数据")
if result.income_statement:
print(f" 利润表: {len(result.income_statement)} 项数据")
except Exception as e:
print(f" {symbol} ({market}): ❌ 获取失败 - {e}")
async def test_ai_analysis():
"""测试AI分析"""
print("\n=== 测试AI分析 ===")
service = get_external_api_service()
if not service.is_ai_service_available():
print("❌ AI服务不可用请检查Gemini API配置")
return
# 模拟财务数据
mock_financial_data = {
"balance_sheet": {"total_assets": 1000000, "total_liab": 600000},
"income_statement": {"revenue": 500000, "n_income": 50000},
"cash_flow": {"n_cashflow_act": 80000},
"key_metrics": {"pe": 15.5, "pb": 1.2, "roe": 12.5}
}
try:
result = await service.analyze_business_info("000001", "中国", mock_financial_data)
print(" 业务信息分析: ✅ 成功")
print(f" 使用模型: {result.model_used}")
print(f" 生成时间: {result.generated_at}")
# 显示部分分析内容
if result.content.get("company_overview"):
overview = result.content["company_overview"][:100] + "..." if len(result.content["company_overview"]) > 100 else result.content["company_overview"]
print(f" 公司概览: {overview}")
except Exception as e:
print(f" 业务信息分析: ❌ 失败 - {e}")
async def main():
"""主测试函数"""
print("开始测试外部API集成...")
print(f"Tushare Token: {'已配置' if settings.TUSHARE_TOKEN else '未配置'}")
print(f"Gemini API Key: {'已配置' if settings.GEMINI_API_KEY else '未配置'}")
print()
await test_data_sources()
await test_symbol_validation()
await test_financial_data()
await test_ai_analysis()
print("\n测试完成!")
if __name__ == "__main__":
asyncio.run(main())

41
frontend/.gitignore vendored Normal file
View File

@ -0,0 +1,41 @@
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
# dependencies
/node_modules
/.pnp
.pnp.*
.yarn/*
!.yarn/patches
!.yarn/plugins
!.yarn/releases
!.yarn/versions
# testing
/coverage
# next.js
/.next/
/out/
# production
/build
# misc
.DS_Store
*.pem
# debug
npm-debug.log*
yarn-debug.log*
yarn-error.log*
.pnpm-debug.log*
# env files (can opt-in for committing if needed)
.env*
# vercel
.vercel
# typescript
*.tsbuildinfo
next-env.d.ts

86
frontend/README.md Normal file
View File

@ -0,0 +1,86 @@
# 基本面选股系统 - 前端
这是基本面选股系统的前端应用,使用 Next.js 14 和 TypeScript 构建。
## 技术栈
- **框架**: Next.js 14 (App Router)
- **语言**: TypeScript
- **样式**: Tailwind CSS
- **UI组件**: shadcn/ui
- **字体**: Noto Sans SC (中文支持)
## 项目结构
```
src/
├── app/ # Next.js App Router 页面
│ ├── layout.tsx # 根布局
│ ├── page.tsx # 首页
│ └── globals.css # 全局样式
├── components/ # React 组件
│ └── ui/ # shadcn/ui 基础组件
├── lib/ # 工具库
│ ├── api.ts # API 客户端
│ ├── types.ts # TypeScript 类型定义
│ └── utils.ts # 工具函数
└── hooks/ # 自定义 React Hooks
├── useReport.ts # 报告数据钩子
└── useProgress.ts # 进度追踪钩子
```
## 开发命令
```bash
# 安装依赖
npm install
# 启动开发服务器
npm run dev
# 构建生产版本
npm run build
# 启动生产服务器
npm start
# 代码检查
npm run lint
```
## 功能特性
- ✅ Next.js 14 项目初始化
- ✅ TypeScript 配置
- ✅ Tailwind CSS 样式系统
- ✅ shadcn/ui 组件库集成
- ✅ 中文字体支持 (Noto Sans SC)
- ✅ 基础项目结构
- ✅ API 客户端封装
- ✅ 自定义 Hooks
- ✅ 响应式布局
## 环境变量
创建 `.env.local` 文件并配置以下变量:
```env
NEXT_PUBLIC_API_URL=http://localhost:8000
```
## 开发说明
1. 项目使用 App Router 架构
2. 所有页面组件位于 `src/app/` 目录
3. 可复用组件位于 `src/components/` 目录
4. API 调用通过 `src/lib/api.ts` 统一管理
5. 类型定义集中在 `src/lib/types.ts`
6. 自定义 Hooks 用于状态管理和数据获取
## 下一步开发
- 实现股票搜索表单组件
- 创建报告页面和路由
- 集成 TradingView 图表
- 实现进度显示组件
- 添加配置管理页面

17
frontend/components.json Normal file
View File

@ -0,0 +1,17 @@
{
"$schema": "https://ui.shadcn.com/schema.json",
"style": "default",
"rsc": true,
"tsx": true,
"tailwind": {
"config": "tailwind.config.ts",
"css": "src/app/globals.css",
"baseColor": "slate",
"cssVariables": true,
"prefix": ""
},
"aliases": {
"components": "@/components",
"utils": "@/lib/utils"
}
}

View File

@ -0,0 +1,25 @@
import { dirname } from "path";
import { fileURLToPath } from "url";
import { FlatCompat } from "@eslint/eslintrc";
const __filename = fileURLToPath(import.meta.url);
const __dirname = dirname(__filename);
const compat = new FlatCompat({
baseDirectory: __dirname,
});
const eslintConfig = [
...compat.extends("next/core-web-vitals", "next/typescript"),
{
ignores: [
"node_modules/**",
".next/**",
"out/**",
"build/**",
"next-env.d.ts",
],
},
];
export default eslintConfig;

7
frontend/next.config.ts Normal file
View File

@ -0,0 +1,7 @@
import type { NextConfig } from "next";
const nextConfig: NextConfig = {
/* config options here */
};
export default nextConfig;

6190
frontend/package-lock.json generated Normal file

File diff suppressed because it is too large Load Diff

33
frontend/package.json Normal file
View File

@ -0,0 +1,33 @@
{
"name": "frontend",
"version": "0.1.0",
"private": true,
"scripts": {
"dev": "next dev --turbopack",
"build": "next build --turbopack",
"start": "next start",
"lint": "eslint"
},
"dependencies": {
"@radix-ui/react-slot": "^1.2.3",
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"lucide-react": "^0.546.0",
"next": "15.5.6",
"react": "19.1.0",
"react-dom": "19.1.0",
"tailwind-merge": "^3.3.1",
"tailwindcss-animate": "^1.0.7"
},
"devDependencies": {
"@eslint/eslintrc": "^3",
"@tailwindcss/postcss": "^4",
"@types/node": "^20",
"@types/react": "^19",
"@types/react-dom": "^19",
"eslint": "^9",
"eslint-config-next": "15.5.6",
"tailwindcss": "^4",
"typescript": "^5"
}
}

View File

@ -0,0 +1,5 @@
const config = {
plugins: ["@tailwindcss/postcss"],
};
export default config;

Binary file not shown.

After

Width:  |  Height:  |  Size: 25 KiB

View File

@ -0,0 +1,54 @@
@tailwind base;
@tailwind components;
@tailwind utilities;
:root {
--background: 0 0% 100%;
--foreground: 222.2 84% 4.9%;
--card: 0 0% 100%;
--card-foreground: 222.2 84% 4.9%;
--popover: 0 0% 100%;
--popover-foreground: 222.2 84% 4.9%;
--primary: 222.2 47.4% 11.2%;
--primary-foreground: 210 40% 98%;
--secondary: 210 40% 96%;
--secondary-foreground: 222.2 47.4% 11.2%;
--muted: 210 40% 96%;
--muted-foreground: 215.4 16.3% 46.9%;
--accent: 210 40% 96%;
--accent-foreground: 222.2 47.4% 11.2%;
--destructive: 0 84.2% 60.2%;
--destructive-foreground: 210 40% 98%;
--border: 214.3 31.8% 91.4%;
--input: 214.3 31.8% 91.4%;
--ring: 222.2 84% 4.9%;
--radius: 0.5rem;
}
.dark {
--background: 222.2 84% 4.9%;
--foreground: 210 40% 98%;
--card: 222.2 84% 4.9%;
--card-foreground: 210 40% 98%;
--popover: 222.2 84% 4.9%;
--popover-foreground: 210 40% 98%;
--primary: 210 40% 98%;
--primary-foreground: 222.2 47.4% 11.2%;
--secondary: 217.2 32.6% 17.5%;
--secondary-foreground: 210 40% 98%;
--muted: 217.2 32.6% 17.5%;
--muted-foreground: 215 20.2% 65.1%;
--accent: 217.2 32.6% 17.5%;
--accent-foreground: 210 40% 98%;
--destructive: 0 62.8% 30.6%;
--destructive-foreground: 210 40% 98%;
--border: 217.2 32.6% 17.5%;
--input: 217.2 32.6% 17.5%;
--ring: 212.7 26.8% 83.9%;
}
body {
font-family: "Noto Sans SC", "PingFang SC", "Microsoft YaHei", sans-serif;
background-color: hsl(var(--background));
color: hsl(var(--foreground));
}

View File

@ -0,0 +1,32 @@
import type { Metadata } from "next";
import { Noto_Sans_SC } from "next/font/google";
import "./globals.css";
const notoSansSC = Noto_Sans_SC({
subsets: ["latin"],
weight: ["300", "400", "500", "700"],
variable: "--font-noto-sans-sc",
});
export const metadata: Metadata = {
title: "基本面选股系统",
description: "专业的股票基本面分析平台,提供全面的投资决策支持",
};
export default function RootLayout({
children,
}: Readonly<{
children: React.ReactNode;
}>) {
return (
<html lang="zh-CN" suppressHydrationWarning>
<body className={`${notoSansSC.variable} antialiased`}>
<div className="min-h-screen bg-background font-sans antialiased">
<main className="relative flex min-h-screen flex-col">
{children}
</main>
</div>
</body>
</html>
);
}

17
frontend/src/app/page.tsx Normal file
View File

@ -0,0 +1,17 @@
export default function Home() {
return (
<div className="container mx-auto px-4 py-8">
<div className="flex flex-col items-center justify-center min-h-[80vh] text-center">
<h1 className="text-4xl font-bold text-foreground mb-6">
</h1>
<p className="text-xl text-muted-foreground mb-8 max-w-2xl">
</p>
<div className="text-sm text-muted-foreground">
</div>
</div>
</div>
);
}

View File

@ -0,0 +1 @@
# This file ensures the components directory is tracked by git

View File

@ -0,0 +1,56 @@
import * as React from "react"
import { Slot } from "@radix-ui/react-slot"
import { cva, type VariantProps } from "class-variance-authority"
import { cn } from "@/lib/utils"
const buttonVariants = cva(
"inline-flex items-center justify-center whitespace-nowrap rounded-md text-sm font-medium ring-offset-background transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50",
{
variants: {
variant: {
default: "bg-primary text-primary-foreground hover:bg-primary/90",
destructive:
"bg-destructive text-destructive-foreground hover:bg-destructive/90",
outline:
"border border-input bg-background hover:bg-accent hover:text-accent-foreground",
secondary:
"bg-secondary text-secondary-foreground hover:bg-secondary/80",
ghost: "hover:bg-accent hover:text-accent-foreground",
link: "text-primary underline-offset-4 hover:underline",
},
size: {
default: "h-10 px-4 py-2",
sm: "h-9 rounded-md px-3",
lg: "h-11 rounded-md px-8",
icon: "h-10 w-10",
},
},
defaultVariants: {
variant: "default",
size: "default",
},
}
)
export interface ButtonProps
extends React.ButtonHTMLAttributes<HTMLButtonElement>,
VariantProps<typeof buttonVariants> {
asChild?: boolean
}
const Button = React.forwardRef<HTMLButtonElement, ButtonProps>(
({ className, variant, size, asChild = false, ...props }, ref) => {
const Comp = asChild ? Slot : "button"
return (
<Comp
className={cn(buttonVariants({ variant, size, className }))}
ref={ref}
{...props}
/>
)
}
)
Button.displayName = "Button"
export { Button, buttonVariants }

View File

@ -0,0 +1,79 @@
import * as React from "react"
import { cn } from "@/lib/utils"
const Card = React.forwardRef<
HTMLDivElement,
React.HTMLAttributes<HTMLDivElement>
>(({ className, ...props }, ref) => (
<div
ref={ref}
className={cn(
"rounded-lg border bg-card text-card-foreground shadow-sm",
className
)}
{...props}
/>
))
Card.displayName = "Card"
const CardHeader = React.forwardRef<
HTMLDivElement,
React.HTMLAttributes<HTMLDivElement>
>(({ className, ...props }, ref) => (
<div
ref={ref}
className={cn("flex flex-col space-y-1.5 p-6", className)}
{...props}
/>
))
CardHeader.displayName = "CardHeader"
const CardTitle = React.forwardRef<
HTMLParagraphElement,
React.HTMLAttributes<HTMLHeadingElement>
>(({ className, ...props }, ref) => (
<h3
ref={ref}
className={cn(
"text-2xl font-semibold leading-none tracking-tight",
className
)}
{...props}
/>
))
CardTitle.displayName = "CardTitle"
const CardDescription = React.forwardRef<
HTMLParagraphElement,
React.HTMLAttributes<HTMLParagraphElement>
>(({ className, ...props }, ref) => (
<p
ref={ref}
className={cn("text-sm text-muted-foreground", className)}
{...props}
/>
))
CardDescription.displayName = "CardDescription"
const CardContent = React.forwardRef<
HTMLDivElement,
React.HTMLAttributes<HTMLDivElement>
>(({ className, ...props }, ref) => (
<div ref={ref} className={cn("p-6 pt-0", className)} {...props} />
))
CardContent.displayName = "CardContent"
const CardFooter = React.forwardRef<
HTMLDivElement,
React.HTMLAttributes<HTMLDivElement>
>(({ className, ...props }, ref) => (
<div
ref={ref}
className={cn("flex items-center p-6 pt-0", className)}
{...props}
/>
))
CardFooter.displayName = "CardFooter"
export { Card, CardHeader, CardFooter, CardTitle, CardDescription, CardContent }

View File

@ -0,0 +1 @@
# This file ensures the hooks directory is tracked by git

View File

@ -0,0 +1,68 @@
import { useState, useEffect, useRef, useCallback } from "react";
import { ProgressResponse } from "@/lib/types";
import { apiClient } from "@/lib/api";
export function useProgress(reportId?: string, pollingInterval: number = 2000) {
const [progress, setProgress] = useState<ProgressResponse | null>(null);
const [loading, setLoading] = useState(false);
const [error, setError] = useState<string | null>(null);
const intervalRef = useRef<NodeJS.Timeout | null>(null);
const stopPolling = useCallback(() => {
if (intervalRef.current) {
clearInterval(intervalRef.current);
intervalRef.current = null;
}
setLoading(false);
}, []);
const fetchProgress = useCallback(async (id: string) => {
try {
const progressData = await apiClient.getReportProgress(id);
setProgress(progressData);
// 如果报告已完成或失败,停止轮询
if (progressData.status === "completed" || progressData.status === "failed") {
stopPolling();
}
} catch (err) {
setError(err instanceof Error ? err.message : "获取进度失败");
stopPolling();
}
}, [stopPolling]);
const startPolling = useCallback((id: string) => {
if (intervalRef.current) {
clearInterval(intervalRef.current);
}
setLoading(true);
setError(null);
// 立即获取一次进度
fetchProgress(id);
// 开始轮询
intervalRef.current = setInterval(() => {
fetchProgress(id);
}, pollingInterval);
}, [fetchProgress, pollingInterval]);
useEffect(() => {
if (reportId) {
startPolling(reportId);
}
return () => {
stopPolling();
};
}, [reportId, startPolling, stopPolling]);
return {
progress,
loading,
error,
startPolling,
stopPolling,
};
}

View File

@ -0,0 +1,49 @@
import { useState, useEffect } from "react";
import { Report, TradingMarket } from "@/lib/types";
import { apiClient } from "@/lib/api";
export function useReport(symbol?: string, market?: TradingMarket) {
const [report, setReport] = useState<Report | null>(null);
const [loading, setLoading] = useState(false);
const [error, setError] = useState<string | null>(null);
const fetchReport = async (sym: string, mkt: TradingMarket) => {
setLoading(true);
setError(null);
try {
const reportData = await apiClient.getReport(sym, mkt);
setReport(reportData);
} catch (err) {
setError(err instanceof Error ? err.message : "获取报告失败");
} finally {
setLoading(false);
}
};
const regenerateReport = async (sym: string, mkt: TradingMarket, force: boolean = false) => {
setLoading(true);
setError(null);
try {
const reportData = await apiClient.regenerateReport(sym, mkt, force);
setReport(reportData);
} catch (err) {
setError(err instanceof Error ? err.message : "重新生成报告失败");
} finally {
setLoading(false);
}
};
useEffect(() => {
if (symbol && market) {
fetchReport(symbol, market);
}
}, [symbol, market]);
return {
report,
loading,
error,
fetchReport,
regenerateReport,
};
}

View File

@ -0,0 +1,79 @@
import type { Config } from "tailwindcss"
const config: Config = {
darkMode: "class",
content: [
"./src/pages/**/*.{js,ts,jsx,tsx,mdx}",
"./src/components/**/*.{js,ts,jsx,tsx,mdx}",
"./src/app/**/*.{js,ts,jsx,tsx,mdx}",
],
prefix: "",
theme: {
container: {
center: true,
padding: "2rem",
screens: {
"2xl": "1400px",
},
},
extend: {
colors: {
border: "hsl(var(--border))",
input: "hsl(var(--input))",
ring: "hsl(var(--ring))",
background: "hsl(var(--background))",
foreground: "hsl(var(--foreground))",
primary: {
DEFAULT: "hsl(var(--primary))",
foreground: "hsl(var(--primary-foreground))",
},
secondary: {
DEFAULT: "hsl(var(--secondary))",
foreground: "hsl(var(--secondary-foreground))",
},
destructive: {
DEFAULT: "hsl(var(--destructive))",
foreground: "hsl(var(--destructive-foreground))",
},
muted: {
DEFAULT: "hsl(var(--muted))",
foreground: "hsl(var(--muted-foreground))",
},
accent: {
DEFAULT: "hsl(var(--accent))",
foreground: "hsl(var(--accent-foreground))",
},
popover: {
DEFAULT: "hsl(var(--popover))",
foreground: "hsl(var(--popover-foreground))",
},
card: {
DEFAULT: "hsl(var(--card))",
foreground: "hsl(var(--card-foreground))",
},
},
borderRadius: {
lg: "var(--radius)",
md: "calc(var(--radius) - 2px)",
sm: "calc(var(--radius) - 4px)",
},
keyframes: {
"accordion-down": {
from: { height: "0" },
to: { height: "var(--radix-accordion-content-height)" },
},
"accordion-up": {
from: { height: "var(--radix-accordion-content-height)" },
to: { height: "0" },
},
},
animation: {
"accordion-down": "accordion-down 0.2s ease-out",
"accordion-up": "accordion-up 0.2s ease-out",
},
},
},
plugins: [require("tailwindcss-animate")],
}
export default config

27
frontend/tsconfig.json Normal file
View File

@ -0,0 +1,27 @@
{
"compilerOptions": {
"target": "ES2017",
"lib": ["dom", "dom.iterable", "esnext"],
"allowJs": true,
"skipLibCheck": true,
"strict": true,
"noEmit": true,
"esModuleInterop": true,
"module": "esnext",
"moduleResolution": "bundler",
"resolveJsonModule": true,
"isolatedModules": true,
"jsx": "preserve",
"incremental": true,
"plugins": [
{
"name": "next"
}
],
"paths": {
"@/*": ["./src/*"]
}
},
"include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", ".next/types/**/*.ts"],
"exclude": ["node_modules"]
}