# Source listing metadata: 542 lines, 19 KiB, Python
import os
|
|
import sys
|
|
import argparse
|
|
import subprocess
|
|
import shutil
|
|
import yaml
|
|
import re
|
|
import time
|
|
from datetime import datetime
|
|
from dotenv import load_dotenv
|
|
from google import genai
|
|
from google.genai import types
|
|
import markdown
|
|
|
|
def load_config():
    """Read the LLM API key and model name from the environment.

    ``load_dotenv(override=True)`` is called first so that values from a
    ``.env`` file take precedence over variables already set in the process.

    Returns:
        A ``(api_key, model_name)`` tuple. ``api_key`` comes from
        ``GEMINI_API_KEY``, falling back to ``OPENAI_API_KEY`` when the
        former is unset or empty (``None`` if neither is available).
        ``model_name`` comes from ``LLM_MODEL`` and defaults to
        ``"gemini-2.0-flash-exp"``.
    """
    load_dotenv(override=True)

    env = os.environ
    key = env.get("GEMINI_API_KEY") or env.get("OPENAI_API_KEY")
    model = env.get("LLM_MODEL", "gemini-2.0-flash-exp")
    return key, model
|
|
|
|
def run_main_script(market, symbol):
    """Run ``main.py`` as a child process for the given market and symbol.

    Args:
        market: Market code, passed to ``main.py`` as its first argument.
        symbol: Stock symbol, passed to ``main.py`` as its second argument.

    Exits the current process with status 1 when ``main.py`` finishes with a
    non-zero return code.
    """
    print(f"Running main.py for {market} {symbol}...")

    # main.py is looked up relative to the current working directory, and
    # sys.executable keeps the child in the same Python environment.
    command = [sys.executable, "main.py", market, symbol]
    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as err:
        print(f"Error running main.py: {err}")
        sys.exit(1)
|
|
|
|
def get_company_info(market, symbol):
    """Extract the company name from ``data/<market>/<symbol>/report.md``.

    The symbol directory is resolved by exact name first; if it does not
    exist, the first sub-directory of ``data/<market>`` whose name starts
    with ``symbol`` is used instead (e.g. ``688334`` -> ``688334.SH``).

    The name is taken from the first line when it looks like
    ``# Name (Code) ...``; otherwise from the second cell of the first table
    row that contains ``| <symbol>``.

    Args:
        market: Market code; the name of the directory under ``data``.
        symbol: Stock symbol, or a prefix of the symbol directory name.

    Returns:
        The extracted company name, or ``symbol`` when the report is empty,
        unreadable, or contains no recognisable name.

    Exits the process with status 1 when no ``report.md`` can be located.
    """
    base_dir = os.path.join("data", market)
    symbol_dir = os.path.join(base_dir, symbol)

    # Only scan for a prefix match when the market directory really exists;
    # otherwise os.listdir would raise instead of reaching the clean
    # "not found" exit below.
    if not os.path.exists(symbol_dir) and os.path.isdir(base_dir):
        candidates = [
            d for d in os.listdir(base_dir)
            if d.startswith(symbol) and os.path.isdir(os.path.join(base_dir, d))
        ]
        if candidates:
            # Use the first match (e.g., 688334.SH)
            symbol_dir = os.path.join(base_dir, candidates[0])
            print(f"Redirecting to found directory: {candidates[0]}")

    report_path = os.path.join(symbol_dir, "report.md")
    if not os.path.exists(report_path):
        print(f"Error: {report_path} not found.")
        sys.exit(1)

    company_name = symbol  # Fallback

    try:
        with open(report_path, 'r', encoding='utf-8') as f:
            lines = f.read().splitlines()
    except (OSError, UnicodeError) as e:
        # Best effort: an unreadable report must not abort the run.
        print(f"Warning: Could not extract company name: {e}")
        return company_name

    if not lines:
        return symbol

    # Attempt 1: Header Regex "# Name (Code) - Financial Report"
    header_match = re.search(r'^#\s+(.+?)\s+\(', lines[0])
    if header_match:
        return header_match.group(1).strip() or company_name

    # Attempt 2: Table | Code | Name | ...
    markers = (f"| {symbol}", f"| {symbol.upper()}")
    for line in lines:
        if any(marker in line for marker in markers):
            parts = line.split('|')
            if len(parts) >= 3:
                name = parts[2].strip()
                # An empty cell is not a usable name; keep scanning so the
                # symbol fallback is not overwritten with "".
                if name:
                    company_name = name
                    break

    return company_name
|
|
|
|
def create_report_file(market, symbol, company_name):
    """Create the Markdown analysis report and write its header.

    The file is created (or truncated) at
    ``Reports/<symbol>_<market>_<company_name>_<YYYYMMDD>_基本面分析报告.md``
    and initialised with a title line and the generation date.

    Args:
        market: Market code, used in the file name.
        symbol: Stock symbol, used in the file name and the title.
        company_name: Company name, used in the file name and the title.

    Returns:
        The path of the created report file.
    """
    date_str = datetime.now().strftime("%Y%m%d")
    folder_name = "Reports"

    # The company name is extracted from report text and may contain path
    # separators or other characters that are illegal in file names; replace
    # them so the file always lands inside ``Reports``. The title written
    # into the file keeps the original, unmodified name.
    raw_name = f"{symbol}_{market}_{company_name}_{date_str}_基本面分析报告.md"
    file_name = re.sub(r'[\\/:*?"<>|\x00-\x1f]', "_", raw_name)

    os.makedirs(folder_name, exist_ok=True)
    target_path = os.path.join(folder_name, file_name)

    # The header is the same whether or not data/<market>/<symbol>/report.html
    # exists, so no lookup of the data directory is needed here.
    with open(target_path, 'w', encoding='utf-8') as f:
        f.write(f"# {company_name} ({symbol}) - 基本面分析报告\n\n")
        f.write(f"**生成日期**: {date_str}\n\n")

    return target_path
|
|
|
|
def load_prompts():
    """Load the per-section prompt templates from the ``Prompt`` directory.

    Returns:
        A dict mapping section keys (``company_profile``,
        ``fundamental_analysis``, ``insider_analysis``, ``bullish_analysis``,
        ``bearish_analysis``) to the text of the matching prompt file.
        Sections whose file is missing are left out and a warning is printed.
    """
    prompt_files = (
        ("company_profile", "公司简介.md"),
        ("fundamental_analysis", "基本面分析.md"),
        ("insider_analysis", "内部人与机构动向分析.md"),
        ("bullish_analysis", "看涨分析.md"),
        ("bearish_analysis", "看跌分析.md"),
    )

    loaded = {}
    for section, filename in prompt_files:
        path = os.path.join("Prompt", filename)
        try:
            with open(path, 'r', encoding='utf-8') as handle:
                loaded[section] = handle.read()
        except FileNotFoundError:
            print(f"Warning: Prompt file {filename} not found.")
    return loaded
|
|
|
|
def call_llm(api_key, model_name, system_prompt, user_prompt, context, enable_search=False):
    """Send one generation request to Gemini and return text plus usage.

    The system prompt, user prompt and context are concatenated into a single
    prompt string because the request is sent as plain ``contents``.

    Args:
        api_key: API key passed to ``genai.Client``.
        model_name: Name of the model to call.
        system_prompt: Role/instruction text placed first in the prompt.
        user_prompt: Task text placed after the system prompt.
        context: Existing report data appended to the prompt.
        enable_search: When true, the Google Search tool is attached to the
            request configuration.

    Returns:
        A ``(text, usage, duration)`` tuple. On success ``usage`` exposes
        ``prompt_tokens``, ``completion_tokens`` and ``total_tokens`` (always
        integers) and ``duration`` is the elapsed time in seconds. If the
        request raises, ``text`` is a message beginning with two newlines
        followed by "Error generating section:", ``usage`` is ``None`` and
        ``duration`` is 0.
    """
    start_time = time.time()

    # Combine system and user prompts for Gemini
    full_prompt = f"{system_prompt}\n\n{user_prompt}\n\nExisting Report Data for context:\n{context}"

    # Create client
    client = genai.Client(api_key=api_key)

    # Configure tools if search is enabled
    config_params = {}
    if enable_search:
        grounding_tool = types.Tool(google_search=types.GoogleSearch())
        config_params['tools'] = [grounding_tool]

    config = types.GenerateContentConfig(**config_params)

    class Usage:
        """Plain holder for the token counts of one request."""

        def __init__(self, prompt_tokens, completion_tokens, total_tokens):
            self.prompt_tokens = prompt_tokens
            self.completion_tokens = completion_tokens
            self.total_tokens = total_tokens

    try:
        response = client.models.generate_content(
            model=model_name,
            contents=full_prompt,
            config=config
        )
        duration = time.time() - start_time

        # usage_metadata and each of its counters are optional in the SDK:
        # the attribute can exist and still be None. Read them defensively
        # so a successful response is not turned into the error return, and
        # so callers can always sum the counts.
        metadata = getattr(response, 'usage_metadata', None)

        def token_count(field):
            return getattr(metadata, field, None) or 0

        usage_obj = Usage(
            token_count('prompt_token_count'),
            token_count('candidates_token_count'),
            token_count('total_token_count'),
        )

        return response.text, usage_obj, duration
    except Exception as e:
        # Boundary handler: report the failure and hand back an error marker
        # the caller can detect, rather than aborting the whole report.
        print(f"API Call Failed: {e}")
        import traceback
        traceback.print_exc()
        return f"\n\nError generating section: {e}\n\n", None, 0
|
|
|
|
def render_html_report(report_file, market, symbol):
    """Render the Markdown report to HTML with the financial-data HTML embedded.

    The converted report's first ``<h1>`` heading is placed at the top of the
    page, followed by the contents of ``data/<market>/<symbol>/report.html``
    (or a placeholder paragraph when that file is missing), followed by the
    rest of the report. The page is written to the ``Reports-HTML`` directory
    under the report's base name with an ``.html`` extension.

    Args:
        report_file: Path of the Markdown report file.
        market: Market code.
        symbol: Stock symbol (or a prefix of the symbol directory name).

    Returns:
        Path of the generated HTML file.
    """
    import html

    # Read the Markdown report
    with open(report_file, 'r', encoding='utf-8') as f:
        md_content = f.read()

    # Locate the financial-data HTML file
    base_dir = os.path.join("data", market)
    symbol_dir = os.path.join(base_dir, symbol)
    # Scan for a prefix match only when the market directory exists, so a
    # missing directory yields the placeholder instead of an exception.
    if not os.path.exists(symbol_dir) and os.path.isdir(base_dir):
        candidates = [
            d for d in os.listdir(base_dir)
            if d.startswith(symbol) and os.path.isdir(os.path.join(base_dir, d))
        ]
        if candidates:
            symbol_dir = os.path.join(base_dir, candidates[0])

    financial_html_path = os.path.join(symbol_dir, "report.html")
    financial_html_content = ""
    if os.path.exists(financial_html_path):
        with open(financial_html_path, 'r', encoding='utf-8') as f:
            financial_html_content = f.read()

    # Convert Markdown to HTML, enabling extensions for fuller Markdown support
    md = markdown.Markdown(extensions=[
        'tables',       # tables
        'fenced_code',  # fenced code blocks
        'codehilite',   # code highlighting
        'nl2br',        # newlines become <br>
        'sane_lists',   # stricter list handling
        'attr_list',    # attribute lists
        'def_list',     # definition lists
        'abbr',         # abbreviations
        'footnotes',    # footnotes
        'md_in_html'    # Markdown inside HTML blocks
    ])
    report_html_content = md.convert(md_content)

    # Split at the end of the first <h1> so the financial block can be placed
    # right below the title. Without an <h1> everything goes below the block;
    # emitting the content in both places would duplicate the whole report.
    head, closing_tag, rest = report_html_content.partition('</h1>')
    if closing_tag:
        heading_html = head + closing_tag
        body_html = rest
    else:
        heading_html = ""
        body_html = report_html_content

    # Strip only the real extension; a ".md" elsewhere in the name is kept.
    stem = os.path.splitext(os.path.basename(report_file))[0]
    # The file name contains the company name, so escape it for the <title>.
    page_title = html.escape(stem)

    # report.html is inserted verbatim, exactly as read from disk.
    financial_section = financial_html_content or '<p>财务数据图表未找到</p>'

    # Build the complete HTML document
    html_template = f"""<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>{page_title}</title>
    <style>
        body {{
            font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
            line-height: 1.6;
            margin: 0;
            padding: 20px;
            background-color: #f5f5f5;
            display: flex;
            justify-content: center;
            min-height: 100vh;
        }}
        .container {{
            background-color: white;
            padding: 40px;
            border-radius: 8px;
            box-shadow: 0 2px 8px rgba(0,0,0,0.1);
            max-width: 1200px;
            width: 100%;
        }}
        h1 {{
            color: #2c3e50;
            border-bottom: 3px solid #3498db;
            padding-bottom: 10px;
        }}
        h2 {{
            color: #34495e;
            margin-top: 30px;
            border-bottom: 2px solid #ecf0f1;
            padding-bottom: 8px;
        }}
        h3 {{
            color: #7f8c8d;
            margin-top: 20px;
        }}
        table {{
            border-collapse: collapse;
            width: 100%;
            margin: 20px 0;
            background-color: white;
        }}
        th, td {{
            border: 1px solid #ddd;
            padding: 12px;
            text-align: left;
        }}
        th {{
            background-color: #3498db;
            color: white;
            font-weight: bold;
        }}
        tr:nth-child(even) {{
            background-color: #f9f9f9;
        }}
        tr:hover {{
            background-color: #f5f5f5;
        }}
        blockquote {{
            border-left: 4px solid #3498db;
            padding-left: 20px;
            margin: 20px 0;
            color: #555;
            background-color: #f8f9fa;
            padding: 15px 20px;
            border-radius: 4px;
        }}
        code {{
            background-color: #f4f4f4;
            padding: 2px 6px;
            border-radius: 3px;
            font-family: "Courier New", monospace;
        }}
        pre {{
            background-color: #f4f4f4;
            padding: 15px;
            border-radius: 5px;
            overflow-x: auto;
        }}
        .financial-data {{
            margin: 30px 0;
            padding: 20px;
            background-color: #f8f9fa;
            border-radius: 8px;
            border: 2px solid #3498db;
        }}
        .financial-data h2 {{
            color: #3498db;
            margin-top: 0;
        }}
        a {{
            color: #3498db;
            text-decoration: none;
        }}
        a:hover {{
            text-decoration: underline;
        }}
        .toc {{
            background-color: #f8f9fa;
            padding: 20px;
            border-radius: 8px;
            margin: 20px 0;
        }}
        .toc h2 {{
            margin-top: 0;
            color: #2c3e50;
        }}
        .toc ul {{
            list-style-type: none;
            padding-left: 0;
        }}
        .toc li {{
            margin: 8px 0;
        }}
    </style>
</head>
<body>
    <div class="container">
        {heading_html}

        <div class="financial-data">
            <h2>📊 财务数据可视化</h2>
            {financial_section}
        </div>

        {body_html}
    </div>
</body>
</html>
"""

    # Save the HTML file into the Reports-HTML directory
    html_folder = "Reports-HTML"
    os.makedirs(html_folder, exist_ok=True)

    # Reuse the report's base name with an .html extension
    html_file = os.path.join(html_folder, stem + '.html')

    with open(html_file, 'w', encoding='utf-8') as f:
        f.write(html_template)

    print(f"HTML report generated: {html_file}")
    return html_file
|
|
|
|
|
|
def main():
    """Run the full pipeline: fetch data, query the LLM per section, write reports.

    Steps:
      1. Read market and symbol from the command line, prompting for both
         when either is missing.
      2. Run ``main.py`` to produce the data files.
      3. Create the Markdown report and append one LLM-generated section per
         prompt, retrying each section up to three times.
      4. Append an API usage summary and render the HTML version.

    Exits with status 1 when market or symbol cannot be obtained.
    """
    parser = argparse.ArgumentParser(description="Stock Analysis Automation")
    parser.add_argument("market", nargs='?', help="Market (CN/US/HK/JP)")
    parser.add_argument("symbol", nargs='?', help="Stock Symbol")
    parser.add_argument("--search", action="store_true", default=True, help="Enable Google Search for LLM (enabled by default)")
    parser.add_argument("--no-search", dest="search", action="store_false", help="Disable Google Search for LLM")

    args = parser.parse_args()

    market = args.market
    symbol = args.symbol

    if not market or not symbol:
        try:
            market = input("请输入市场 (CN/US/HK/JP): ").strip().upper()
            symbol = input("请输入股票代码: ").strip()
        except EOFError:
            print("\nError: Input needed but EOF received. Please provide arguments or run interactively.")
            sys.exit(1)

    if not market or not symbol:
        print("Market and Symbol are required.")
        sys.exit(1)

    # Normalise CN codes only after the interactive prompt, so values typed
    # at the prompt get the same suffix as values passed on the command line.
    if market == "CN":
        corrected = _with_cn_suffix(symbol)
        if corrected != symbol:
            symbol = corrected
            print(f"Auto-corrected symbol to: {symbol}")

    api_key, model_name = load_config()
    print(f"Configuration Loaded:\n  Model: {model_name}\n")

    if not api_key:
        print("Warning: GEMINI_API_KEY not found in .env. API calls might fail.")

    # Step 1: Get Data
    run_main_script(market, symbol)

    # Step 2: Init Report
    company_name = get_company_info(market, symbol)
    print(f"Identified Company: {company_name}")
    report_file = create_report_file(market, symbol, company_name)
    print(f"Report initialized at: {report_file}")

    # Load Prompts
    prompts = load_prompts()

    # Read Data Context
    base_dir = os.path.join("data", market)
    symbol_dir = os.path.join(base_dir, symbol)
    if not os.path.exists(symbol_dir) and os.path.isdir(base_dir):
        candidates = [
            d for d in os.listdir(base_dir)
            if d.startswith(symbol) and os.path.isdir(os.path.join(base_dir, d))
        ]
        if candidates:
            symbol_dir = os.path.join(base_dir, candidates[0])

    data_path = os.path.join(symbol_dir, "report.md")
    with open(data_path, 'r', encoding='utf-8') as f:
        data_context = f.read()

    # Read CSV data for bearish analysis
    csv_path = os.path.join(symbol_dir, "raw_balance_sheet_raw.csv")
    csv_context = ""
    if os.path.exists(csv_path):
        try:
            with open(csv_path, 'r', encoding='utf-8') as f:
                csv_content = f.read()
            csv_context = f"\n\nRaw Balance Sheet Data (CSV):\n{csv_content}\n"
        except (OSError, UnicodeError) as e:
            # The CSV is optional extra context; carry on without it.
            print(f"Warning: Could not read CSV file: {e}")

    steps = [
        ("company_profile", "3. 公司简介 (Company Profile)"),
        ("fundamental_analysis", "4. 基本面分析 (Fundamental Analysis)"),
        ("insider_analysis", "5. 内部人与机构动向 (Insider Analysis)"),
        ("bullish_analysis", "6. 看涨分析 (Bullish Analysis)"),
        ("bearish_analysis", "7. 看跌分析 (Bearish Analysis)")
    ]

    total_stats = []

    for key, name in steps:
        print(f"\nProcessing {name}...")
        prompt_template = prompts.get(key)
        if not prompt_template:
            print(f"Warning: Prompt for {key} not found.")
            continue

        # Prompt files are str.format templates: only {company_name} and
        # {ts_code} are supplied, so literal braces in a prompt file must be
        # doubled.
        formatted_prompt = prompt_template.format(company_name=company_name, ts_code=symbol)

        system_content = formatted_prompt
        user_content = "请根据上述角色设定和要求,结合提供的财务数据,撰写本章节的分析报告。"

        current_data_context = data_context
        # Inject CSV data only for bearish analysis
        if key == "bearish_analysis" and csv_context:
            current_data_context += csv_context

        # Retry logic
        max_retries = 3
        result = None
        usage = None
        duration = 0

        for attempt in range(max_retries):
            try:
                result, usage, duration = call_llm(api_key, model_name, system_content, user_content, current_data_context, enable_search=args.search)
                # call_llm reports failures as text starting with "\n\nError"
                # rather than by raising.
                if result and not result.startswith("\n\nError"):
                    break
                else:
                    print(f"  Attempt {attempt + 1} failed with error content, retrying...")
            except Exception as e:
                print(f"  Attempt {attempt + 1} failed: {e}, retrying...")

        if not result or result.startswith("\n\nError"):
            result = f"\n> [Error] Failed to generate section {name} after {max_retries} attempts.\n"
        elif usage:
            total_stats.append({
                "step": name,
                "duration": duration,
                "prompt_tokens": usage.prompt_tokens,
                "completion_tokens": usage.completion_tokens,
                "total_tokens": usage.total_tokens
            })

        with open(report_file, 'a', encoding='utf-8') as f:
            f.write("\n\n")
            f.write(result)
        print(f"Finished {name} in {duration:.2f}s.")

    # Append stats
    if total_stats:
        _append_usage_summary(report_file, total_stats)

    print(f"\nStock analysis for {company_name} ({symbol}) completed. Report saved to {report_file}")

    # Generate HTML version of the report
    try:
        html_file = render_html_report(report_file, market, symbol)
        print(f"HTML version available at: {html_file}")
    except Exception as e:
        # The Markdown report is already complete; HTML is a best-effort extra.
        print(f"Warning: Failed to generate HTML report: {e}")


def _with_cn_suffix(symbol):
    """Return a purely numeric CN code with its suffix added, chosen by leading digit."""
    if not symbol.isdigit():
        return symbol
    if symbol.startswith("6"):
        return f"{symbol}.SH"
    if symbol.startswith(("0", "3")):
        return f"{symbol}.SZ"
    if symbol.startswith(("4", "8")):
        return f"{symbol}.BJ"
    return symbol


def _append_usage_summary(report_file, total_stats):
    """Append a Markdown table of per-step API usage and a totals row to the report."""
    total_duration = sum(stat['duration'] for stat in total_stats)
    total_prompt = sum(stat['prompt_tokens'] for stat in total_stats)
    total_completion = sum(stat['completion_tokens'] for stat in total_stats)
    total_tokens = sum(stat['total_tokens'] for stat in total_stats)

    with open(report_file, 'a', encoding='utf-8') as f:
        f.write("\n\n## API Usage Summary\n\n")
        f.write("| Step | Duration (s) | Prompt Tokens | Completion Tokens | Total Tokens |\n")
        f.write("|---|---|---|---|---|\n")
        for stat in total_stats:
            f.write(f"| {stat['step']} | {stat['duration']:.2f} | {stat['prompt_tokens']} | {stat['completion_tokens']} | {stat['total_tokens']} |\n")
        f.write(f"| **Total** | **{total_duration:.2f}** | **{total_prompt}** | **{total_completion}** | **{total_tokens}** |\n")
|
|
|
|
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|