# FA3-Datafetch/stock_analysis.py
#
# Repository viewer metadata: 527 lines, 18 KiB, Python

import os
import sys
import argparse
import subprocess
import shutil
import yaml
import re
import time
from datetime import datetime
from dotenv import load_dotenv
from google import genai
from google.genai import types
import markdown
def load_config():
    """Read the LLM credentials and model name from the environment.

    ``.env`` is loaded first with ``override=True``, so values in that file
    take precedence over variables already set in the process environment.

    Returns:
        A ``(api_key, model_name)`` tuple. ``api_key`` is the value of
        ``GEMINI_API_KEY`` or, when that is unset or empty, the value of
        ``OPENAI_API_KEY`` (possibly ``None``). ``model_name`` comes from
        ``LLM_MODEL`` and defaults to ``"gemini-2.0-flash-exp"``.
    """
    load_dotenv(override=True)

    key = os.getenv("GEMINI_API_KEY")
    if not key:
        key = os.getenv("OPENAI_API_KEY")

    model = os.getenv("LLM_MODEL", "gemini-2.0-flash-exp")
    return key, model
def run_main_script(market, symbol):
    """Run ``main.py`` for the given market and symbol as a child process.

    ``main.py`` is looked up relative to the current working directory and
    is started with ``sys.executable`` so it runs under the same Python
    environment as this script. The process exits with status 1 when the
    child returns a non-zero exit code.
    """
    print(f"Running main.py for {market} {symbol}...")

    argv = [sys.executable, "main.py", market, symbol]
    try:
        subprocess.run(argv, check=True)
    except subprocess.CalledProcessError as err:
        print(f"Error running main.py: {err}")
        sys.exit(1)
def get_company_info(market, symbol):
    """Return the company name found in ``data/<market>/<symbol>/report.md``.

    The symbol directory is looked up by exact name first; if that does not
    exist, the first sub-directory of ``data/<market>`` whose name starts
    with ``symbol`` is used instead (e.g. ``688334`` -> ``688334.SH``).

    The name is extracted from the first line when it looks like
    ``# Name (Code) ...``; otherwise from the first table row containing
    ``| <symbol>``, taking the cell that follows the code.

    Args:
        market: Market code, used as the directory name under ``data``.
        symbol: Stock symbol, or a prefix of the symbol directory name.

    Returns:
        The extracted company name, or ``symbol`` when no name can be found
        or the report cannot be read.

    Raises:
        SystemExit: With status 1 when ``report.md`` cannot be located.
    """
    base_dir = os.path.join("data", market)
    symbol_dir = os.path.join(base_dir, symbol)

    # Only scan for a prefix match when the market directory exists;
    # os.listdir() on a missing directory would raise FileNotFoundError
    # instead of reaching the "not found" exit below.
    if not os.path.exists(symbol_dir) and os.path.isdir(base_dir):
        candidates = [
            d for d in os.listdir(base_dir)
            if d.startswith(symbol) and os.path.isdir(os.path.join(base_dir, d))
        ]
        if candidates:
            # Use the first match (e.g., 688334.SH)
            symbol_dir = os.path.join(base_dir, candidates[0])
            print(f"Redirecting to found directory: {candidates[0]}")

    report_path = os.path.join(symbol_dir, "report.md")
    if not os.path.exists(report_path):
        print(f"Error: {report_path} not found.")
        sys.exit(1)

    company_name = symbol  # Fallback when nothing better is found
    try:
        with open(report_path, 'r', encoding='utf-8') as f:
            lines = f.read().splitlines()
    except (OSError, UnicodeError) as e:
        # Best effort: an unreadable report degrades to the symbol.
        print(f"Warning: Could not extract company name: {e}")
        return company_name

    if not lines:
        return symbol

    # Attempt 1: header of the form "# Name (Code) - Financial Report"
    header_match = re.search(r'^#\s+(.+?)\s+\(', lines[0])
    if header_match:
        return header_match.group(1).strip()

    # Attempt 2: table row "| Code | Name | ..."
    row_markers = (f"| {symbol}", f"| {symbol.upper()}")
    for line in lines:
        if any(marker in line for marker in row_markers):
            parts = line.split('|')
            # Ignore rows whose name cell is blank so the result is never "".
            if len(parts) >= 3 and parts[2].strip():
                company_name = parts[2].strip()
                break
    return company_name
def create_report_file(market, symbol, company_name):
    """Create the Markdown report file and write its title block.

    The file is created (or truncated) at
    ``Reports/<symbol>_<market>_<company_name>_<YYYYMMDD>_基本面分析报告.md``
    relative to the current working directory. Characters that are path
    separators or not allowed in file names on common platforms are replaced
    with ``_`` in the file name only; the heading written into the file keeps
    the original ``company_name``.

    Args:
        market: Market code used in the file name.
        symbol: Stock symbol used in the file name and heading.
        company_name: Company name used in the file name and heading.

    Returns:
        The path of the created report file.
    """
    def _safe(component):
        # A "/" or similar in a name would otherwise point into a
        # non-existent sub-directory or be rejected by the OS.
        return re.sub(r'[\\/:*?"<>|\r\n\t]+', "_", str(component))

    date_str = datetime.now().strftime("%Y%m%d")
    folder_name = "Reports"
    file_name = (
        f"{_safe(symbol)}_{_safe(market)}_{_safe(company_name)}"
        f"_{date_str}_基本面分析报告.md"
    )
    os.makedirs(folder_name, exist_ok=True)
    target_path = os.path.join(folder_name, file_name)

    # The previous version looked up report.html and computed a relative
    # link that was never written; both branches produced this same header,
    # so the lookup (which could crash on a missing data directory) is gone.
    with open(target_path, 'w', encoding='utf-8') as f:
        f.write(f"# {company_name} ({symbol}) - 基本面分析报告\n\n")
        f.write(f"**生成日期**: {date_str}\n\n")
    return target_path
def load_prompts(prompt_dir="Prompt"):
    """Load the per-section prompt templates from ``prompt_dir``.

    Args:
        prompt_dir: Directory containing the prompt Markdown files. Defaults
            to ``"Prompt"`` relative to the current working directory.

    Returns:
        A dict mapping section keys (``company_profile``,
        ``fundamental_analysis``, ``insider_analysis``, ``bullish_analysis``,
        ``bearish_analysis``) to the text of the matching file. Keys whose
        file is missing are omitted and a warning is printed for each.
    """
    mapping = {
        "company_profile": "公司简介.md",
        "fundamental_analysis": "基本面分析.md",
        "insider_analysis": "内部人与机构动向分析.md",
        "bullish_analysis": "看涨分析.md",
        "bearish_analysis": "看跌分析.md",
    }

    prompts = {}
    for key, filename in mapping.items():
        try:
            with open(os.path.join(prompt_dir, filename), 'r', encoding='utf-8') as f:
                prompts[key] = f.read()
        except FileNotFoundError:
            print(f"Warning: Prompt file {filename} not found.")
    return prompts
def call_llm(api_key, model_name, system_prompt, user_prompt, context, enable_search=False):
    """Send one generation request to the model and return text plus usage.

    The system prompt, user prompt and report context are concatenated into
    a single prompt string. When ``enable_search`` is true, a Google Search
    tool is attached to the request configuration.

    Args:
        api_key: API key passed to ``genai.Client``.
        model_name: Model identifier passed to ``generate_content``.
        system_prompt: Role/instruction text placed first in the prompt.
        user_prompt: Task text placed after the system prompt.
        context: Existing report data appended as context.
        enable_search: Whether to attach the Google Search tool.

    Returns:
        ``(text, usage, duration)``. On success ``text`` is
        ``response.text``, ``usage`` is an object with ``prompt_tokens``,
        ``completion_tokens`` and ``total_tokens`` integer attributes, and
        ``duration`` is the elapsed wall-clock time in seconds. When the
        request raises, ``text`` is a string starting with
        ``"\\n\\nError generating section"``, ``usage`` is ``None`` and
        ``duration`` is ``0``.
    """
    class Usage:
        """Plain container for the token counters of one call."""

        def __init__(self, prompt_tokens, completion_tokens, total_tokens):
            self.prompt_tokens = prompt_tokens
            self.completion_tokens = completion_tokens
            self.total_tokens = total_tokens

    start_time = time.time()

    # Combine system and user prompts for Gemini
    full_prompt = f"{system_prompt}\n\n{user_prompt}\n\nExisting Report Data for context:\n{context}"

    client = genai.Client(api_key=api_key)

    # Configure tools if search is enabled
    config_params = {}
    if enable_search:
        grounding_tool = types.Tool(google_search=types.GoogleSearch())
        config_params['tools'] = [grounding_tool]
    config = types.GenerateContentConfig(**config_params)

    try:
        response = client.models.generate_content(
            model=model_name,
            contents=full_prompt,
            config=config
        )
        duration = time.time() - start_time

        # The metadata attribute can exist but be None, and individual
        # counters can be None as well; hasattr() alone would let an
        # AttributeError discard an otherwise valid response, and None
        # counters would break later summation. Normalize all to ints.
        metadata = getattr(response, 'usage_metadata', None)
        usage_obj = Usage(
            getattr(metadata, 'prompt_token_count', None) or 0,
            getattr(metadata, 'candidates_token_count', None) or 0,
            getattr(metadata, 'total_token_count', None) or 0,
        )
        return response.text, usage_obj, duration
    except Exception as e:
        # Deliberately broad: the caller detects failure via the returned
        # error text and decides whether to retry.
        print(f"API Call Failed: {e}")
        import traceback
        traceback.print_exc()
        return f"\n\nError generating section: {e}\n\n", None, 0
def render_html_report(report_file, market, symbol):
    """Render the Markdown report as a standalone HTML page.

    The Markdown is converted to HTML and the contents of the financial-data
    ``report.html`` found under ``data/<market>/<symbol>`` (or the first
    sub-directory whose name starts with ``symbol``) are embedded right
    after the report's first ``<h1>`` heading. When the rendered report has
    no ``<h1>``, the financial-data block is placed before the report body.

    Args:
        report_file: Path of the Markdown report file.
        market: Market code.
        symbol: Stock symbol.

    Returns:
        Path of the generated HTML file inside ``Reports-HTML``.
    """
    import html  # stdlib; imported locally so the file's import block is untouched

    # Read the Markdown report
    with open(report_file, 'r', encoding='utf-8') as f:
        md_content = f.read()

    # Locate the financial-data HTML file
    base_dir = os.path.join("data", market)
    symbol_dir = os.path.join(base_dir, symbol)
    # Guard the scan: os.listdir() raises if the market directory is missing.
    if not os.path.exists(symbol_dir) and os.path.isdir(base_dir):
        candidates = [d for d in os.listdir(base_dir) if d.startswith(symbol) and os.path.isdir(os.path.join(base_dir, d))]
        if candidates:
            symbol_dir = os.path.join(base_dir, candidates[0])

    financial_html_path = os.path.join(symbol_dir, "report.html")
    financial_html_content = ""
    if os.path.exists(financial_html_path):
        with open(financial_html_path, 'r', encoding='utf-8') as f:
            financial_html_content = f.read()

    # Convert Markdown to HTML
    md = markdown.Markdown(extensions=['tables', 'fenced_code', 'codehilite'])
    report_html_content = md.convert(md_content)

    # Split at the first </h1>. Without an <h1>, the old split()-based
    # template emitted the whole report twice with a stray </h1> between.
    head, closing_tag, tail = report_html_content.partition('</h1>')
    if closing_tag:
        title_html = head + closing_tag
        body_html = tail
    else:
        title_html = ""
        body_html = report_html_content

    financial_block = financial_html_content if financial_html_content else '<p>财务数据图表未找到</p>'

    # Strip only the real extension and escape the name for use in <title>.
    report_stem = os.path.splitext(os.path.basename(report_file))[0]
    page_title = html.escape(report_stem)

    # Build the complete HTML document
    html_template = f"""<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>{page_title}</title>
<style>
body {{
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
line-height: 1.6;
max-width: 1200px;
margin: 0 auto;
padding: 20px;
background-color: #f5f5f5;
}}
.container {{
background-color: white;
padding: 30px;
border-radius: 8px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}}
h1 {{
color: #2c3e50;
border-bottom: 3px solid #3498db;
padding-bottom: 10px;
}}
h2 {{
color: #34495e;
margin-top: 30px;
border-bottom: 2px solid #ecf0f1;
padding-bottom: 8px;
}}
h3 {{
color: #7f8c8d;
margin-top: 20px;
}}
table {{
border-collapse: collapse;
width: 100%;
margin: 20px 0;
background-color: white;
}}
th, td {{
border: 1px solid #ddd;
padding: 12px;
text-align: left;
}}
th {{
background-color: #3498db;
color: white;
font-weight: bold;
}}
tr:nth-child(even) {{
background-color: #f9f9f9;
}}
tr:hover {{
background-color: #f5f5f5;
}}
blockquote {{
border-left: 4px solid #3498db;
padding-left: 20px;
margin: 20px 0;
color: #555;
background-color: #f8f9fa;
padding: 15px 20px;
border-radius: 4px;
}}
code {{
background-color: #f4f4f4;
padding: 2px 6px;
border-radius: 3px;
font-family: "Courier New", monospace;
}}
pre {{
background-color: #f4f4f4;
padding: 15px;
border-radius: 5px;
overflow-x: auto;
}}
.financial-data {{
margin: 30px 0;
padding: 20px;
background-color: #f8f9fa;
border-radius: 8px;
border: 2px solid #3498db;
}}
.financial-data h2 {{
color: #3498db;
margin-top: 0;
}}
a {{
color: #3498db;
text-decoration: none;
}}
a:hover {{
text-decoration: underline;
}}
.toc {{
background-color: #f8f9fa;
padding: 20px;
border-radius: 8px;
margin: 20px 0;
}}
.toc h2 {{
margin-top: 0;
color: #2c3e50;
}}
.toc ul {{
list-style-type: none;
padding-left: 0;
}}
.toc li {{
margin: 8px 0;
}}
</style>
</head>
<body>
<div class="container">
{title_html}
<div class="financial-data">
<h2>📊 财务数据可视化</h2>
{financial_block}
</div>
{body_html}
</div>
</body>
</html>
"""

    # Save the HTML file into the Reports-HTML directory
    html_folder = "Reports-HTML"
    os.makedirs(html_folder, exist_ok=True)

    # Reuse the report's base name with an .html extension
    html_file = os.path.join(html_folder, report_stem + '.html')
    with open(html_file, 'w', encoding='utf-8') as f:
        f.write(html_template)

    print(f"HTML report generated: {html_file}")
    return html_file
def _normalize_cn_symbol(market, symbol):
    """Append the exchange suffix to a purely numeric CN stock code."""
    if market != "CN" or not symbol or not symbol.isdigit():
        return symbol
    if symbol.startswith("6"):
        return f"{symbol}.SH"
    if symbol.startswith(("0", "3")):
        return f"{symbol}.SZ"
    if symbol.startswith(("4", "8")):
        return f"{symbol}.BJ"
    return symbol


def _append_usage_summary(report_file, total_stats):
    """Append the per-step API usage table, with a totals row, to the report."""
    total_duration = 0
    total_prompt = 0
    total_completion = 0
    total_tokens = 0
    with open(report_file, 'a', encoding='utf-8') as f:
        f.write("\n\n## API Usage Summary\n\n")
        f.write("| Step | Duration (s) | Prompt Tokens | Completion Tokens | Total Tokens |\n")
        f.write("|---|---|---|---|---|\n")
        for stat in total_stats:
            f.write(f"| {stat['step']} | {stat['duration']:.2f} | {stat['prompt_tokens']} | {stat['completion_tokens']} | {stat['total_tokens']} |\n")
            total_duration += stat['duration']
            total_prompt += stat['prompt_tokens']
            total_completion += stat['completion_tokens']
            total_tokens += stat['total_tokens']
        f.write(f"| **Total** | **{total_duration:.2f}** | **{total_prompt}** | **{total_completion}** | **{total_tokens}** |\n")


def main():
    """Run the full pipeline: fetch data, generate each section, render HTML.

    Market and symbol come from the command line, or are prompted for when
    either is missing. For each report section the matching prompt is sent
    to the LLM (up to three attempts) and the result is appended to the
    Markdown report; an API usage table is appended for the sections that
    succeeded, and finally an HTML version of the report is rendered.
    """
    parser = argparse.ArgumentParser(description="Stock Analysis Automation")
    parser.add_argument("market", nargs='?', help="Market (CN/US/HK/JP)")
    parser.add_argument("symbol", nargs='?', help="Stock Symbol")
    parser.add_argument("--search", action="store_true", default=True, help="Enable Google Search for LLM (enabled by default)")
    parser.add_argument("--no-search", dest="search", action="store_false", help="Disable Google Search for LLM")
    args = parser.parse_args()

    market = args.market
    symbol = args.symbol
    if not market or not symbol:
        try:
            market = input("请输入市场 (CN/US/HK/JP): ").strip().upper()
            symbol = input("请输入股票代码: ").strip()
        except EOFError:
            print("\nError: Input needed but EOF received. Please provide arguments or run interactively.")
            sys.exit(1)

    # Treat command-line and interactive input the same way.
    market = (market or "").strip().upper()
    symbol = (symbol or "").strip()
    if not market or not symbol:
        print("Market and Symbol are required.")
        sys.exit(1)

    # Normalize after the prompt so interactively entered codes get the
    # exchange suffix too, and only report a correction that happened.
    normalized_symbol = _normalize_cn_symbol(market, symbol)
    if normalized_symbol != symbol:
        symbol = normalized_symbol
        print(f"Auto-corrected symbol to: {symbol}")

    api_key, model_name = load_config()
    print(f"Configuration Loaded:\n Model: {model_name}\n")
    if not api_key:
        print("Warning: GEMINI_API_KEY not found in .env. API calls might fail.")

    # Step 1: Get Data
    run_main_script(market, symbol)

    # Step 2: Init Report
    company_name = get_company_info(market, symbol)
    print(f"Identified Company: {company_name}")
    report_file = create_report_file(market, symbol, company_name)
    print(f"Report initialized at: {report_file}")

    # Load Prompts
    prompts = load_prompts()

    # Read Data Context
    base_dir = os.path.join("data", market)
    symbol_dir = os.path.join(base_dir, symbol)
    if not os.path.exists(symbol_dir):
        candidates = [d for d in os.listdir(base_dir) if d.startswith(symbol) and os.path.isdir(os.path.join(base_dir, d))]
        if candidates:
            symbol_dir = os.path.join(base_dir, candidates[0])
    data_path = os.path.join(symbol_dir, "report.md")
    with open(data_path, 'r', encoding='utf-8') as f:
        data_context = f.read()

    # Read CSV data for bearish analysis
    csv_path = os.path.join(symbol_dir, "raw_balance_sheet_raw.csv")
    csv_context = ""
    if os.path.exists(csv_path):
        try:
            with open(csv_path, 'r', encoding='utf-8') as f:
                csv_content = f.read()
            csv_context = f"\n\nRaw Balance Sheet Data (CSV):\n{csv_content}\n"
        except Exception as e:
            # Best effort: the CSV is optional extra context.
            print(f"Warning: Could not read CSV file: {e}")

    steps = [
        ("company_profile", "3. 公司简介 (Company Profile)"),
        ("fundamental_analysis", "4. 基本面分析 (Fundamental Analysis)"),
        ("insider_analysis", "5. 内部人与机构动向 (Insider Analysis)"),
        ("bullish_analysis", "6. 看涨分析 (Bullish Analysis)"),
        ("bearish_analysis", "7. 看跌分析 (Bearish Analysis)")
    ]
    total_stats = []
    max_retries = 3

    for key, name in steps:
        print(f"\nProcessing {name}...")
        prompt_template = prompts.get(key)
        if not prompt_template:
            print(f"Warning: Prompt for {key} not found.")
            continue

        try:
            formatted_prompt = prompt_template.format(company_name=company_name, ts_code=symbol)
        except (KeyError, IndexError, ValueError) as e:
            # Literal braces in a prompt file (e.g. a JSON sample) make
            # str.format() fail; substitute only the two known placeholders.
            print(f"Warning: Could not format prompt for {key} ({e}); using plain substitution.")
            formatted_prompt = (
                prompt_template
                .replace("{company_name}", company_name)
                .replace("{ts_code}", symbol)
            )
        system_content = formatted_prompt
        user_content = "请根据上述角色设定和要求,结合提供的财务数据,撰写本章节的分析报告。"

        current_data_context = data_context
        # Inject CSV data only for bearish analysis
        if key == "bearish_analysis" and csv_context:
            current_data_context += csv_context

        # Retry logic
        result = None
        usage = None
        duration = 0
        for attempt in range(max_retries):
            try:
                result, usage, duration = call_llm(api_key, model_name, system_content, user_content, current_data_context, enable_search=args.search)
                if result and not result.startswith("\n\nError"):
                    break
                print(f" Attempt {attempt + 1} failed with error content, retrying...")
            except Exception as e:
                print(f" Attempt {attempt + 1} failed: {e}, retrying...")
            # Back off before the next attempt so a transient or rate-limit
            # failure is not retried immediately.
            if attempt < max_retries - 1:
                time.sleep(2 ** attempt)

        if not result or result.startswith("\n\nError"):
            result = f"\n> [Error] Failed to generate section {name} after {max_retries} attempts.\n"
        elif usage:
            # Counters may be missing; treat them as 0 so totals still add up.
            total_stats.append({
                "step": name,
                "duration": duration,
                "prompt_tokens": usage.prompt_tokens or 0,
                "completion_tokens": usage.completion_tokens or 0,
                "total_tokens": usage.total_tokens or 0
            })

        with open(report_file, 'a', encoding='utf-8') as f:
            # Drop the "N. " numbering prefix from the section heading.
            f.write(f"\n\n# {name.split('. ')[1] if '. ' in name else name}\n\n")
            f.write(result)
        print(f"Finished {name} in {duration:.2f}s.")

    # Append stats
    if total_stats:
        _append_usage_summary(report_file, total_stats)

    print(f"\nStock analysis for {company_name} ({symbol}) completed. Report saved to {report_file}")

    # Generate HTML version of the report
    try:
        html_file = render_html_report(report_file, market, symbol)
        print(f"HTML version available at: {html_file}")
    except Exception as e:
        # The Markdown report is already saved; HTML is a best-effort extra.
        print(f"Warning: Failed to generate HTML report: {e}")


if __name__ == "__main__":
    main()