# Listing metadata from the file viewer (not part of the program): 295 lines, 11 KiB, Python
import os
|
|
import sys
|
|
import argparse
|
|
import subprocess
|
|
import shutil
|
|
import yaml
|
|
import re
|
|
import time
|
|
from datetime import datetime
|
|
from dotenv import load_dotenv
|
|
from openai import OpenAI
|
|
|
|
def load_config():
    """Load the LLM connection settings from the environment / ``.env`` file.

    ``load_dotenv(override=True)`` runs first, so values found in ``.env``
    replace variables already present in the process environment.

    Returns:
        tuple: ``(api_key, base_url, model_name)``. ``api_key`` is ``None``
        when ``OPENAI_API_KEY`` is not set; ``base_url`` and ``model_name``
        fall back to the defaults given below.
    """
    load_dotenv(override=True)

    api_key = os.getenv("OPENAI_API_KEY")
    base_url = os.getenv(
        "OPENAI_BASE_URL",
        "https://generativelanguage.googleapis.com/v1beta/openai/",
    )
    model_name = os.getenv("LLM_MODEL", "gemini-1.5-flash")

    # A full endpoint URL may have been configured; keep only the API root.
    # Slice off the trailing suffix instead of using str.replace, which would
    # also delete the same text if it appeared earlier in the URL.
    endpoint_suffix = "/chat/completions"
    if base_url and base_url.endswith(endpoint_suffix):
        base_url = base_url[: -len(endpoint_suffix)]

    return api_key, base_url, model_name
|
|
|
|
def run_main_script(market, symbol):
    """Run ``main.py`` for the given market and symbol in a child process.

    ``main.py`` is resolved relative to the current working directory and is
    launched with ``sys.executable`` so it uses the same Python environment as
    this script. If the child exits with a non-zero status, the error is
    printed and this process exits with status 1.

    Args:
        market: Market identifier passed to ``main.py`` as its first argument.
        symbol: Stock symbol passed to ``main.py`` as its second argument.
    """
    print(f"Running main.py for {market} {symbol}...")

    command = [sys.executable, "main.py", market, symbol]
    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as failure:
        print(f"Error running main.py: {failure}")
        sys.exit(1)
|
|
|
|
def get_company_info(market, symbol):
    """Return the company name found in ``data/<market>/<symbol>/report.md``.

    The symbol directory is looked up by exact name first; if it does not
    exist, the first sub-directory of ``data/<market>`` whose name starts with
    ``symbol`` is used instead (e.g. ``688334`` -> ``688334.SH``).

    The name is taken from, in order:
      1. the first line, when it looks like ``# Name (Code) ...``;
      2. the second cell of the first table row that contains ``| <symbol>``
         and has a non-empty second cell.
    If neither yields a name, ``symbol`` itself is returned.

    Args:
        market: Market identifier, used as the directory name under ``data``.
        symbol: Stock symbol (or a prefix of the symbol directory name).

    Returns:
        str: The extracted company name, or ``symbol`` as a fallback.

    Exits the process with status 1 when ``report.md`` cannot be located.
    """
    base_dir = os.path.join("data", market)
    symbol_dir = os.path.join(base_dir, symbol)

    # Only scan for a prefix match when the market directory really exists;
    # os.listdir on a missing directory would raise instead of reaching the
    # "not found" message below.
    if not os.path.exists(symbol_dir) and os.path.isdir(base_dir):
        candidates = [
            entry
            for entry in os.listdir(base_dir)
            if entry.startswith(symbol)
            and os.path.isdir(os.path.join(base_dir, entry))
        ]
        if candidates:
            # Use the first match (e.g., 688334.SH)
            symbol_dir = os.path.join(base_dir, candidates[0])
            print(f"Redirecting to found directory: {candidates[0]}")

    report_path = os.path.join(symbol_dir, "report.md")
    if not os.path.exists(report_path):
        print(f"Error: {report_path} not found.")
        sys.exit(1)

    # Reading is best-effort: an unreadable report only costs us the name.
    try:
        with open(report_path, 'r', encoding='utf-8') as f:
            lines = f.read().splitlines()
    except (OSError, UnicodeError) as e:
        print(f"Warning: Could not extract company name: {e}")
        return symbol

    if not lines:
        return symbol

    # Attempt 1: header of the form "# Name (Code) - Financial Report"
    header_match = re.search(r'^#\s+(.+?)\s+\(', lines[0])
    if header_match:
        return header_match.group(1).strip()

    # Attempt 2: table row "| Code | Name | ..."
    markers = (f"| {symbol}", f"| {symbol.upper()}")
    for line in lines:
        if not any(marker in line for marker in markers):
            continue
        # parts[0] is the text before the leading pipe, parts[1] the code.
        parts = line.split('|')
        if len(parts) >= 3:
            name = parts[2].strip()
            if name:  # a blank cell must not replace the symbol fallback
                return name

    return symbol
|
|
|
|
def create_report_file(market, symbol, company_name):
    """Create today's analysis report under ``reports/`` and return its path.

    The file is named
    ``<symbol>_<market>_<company_name>_<YYYYMMDD>_基本面分析报告.md``; characters
    that are path separators or are not allowed in file names on common
    platforms are replaced with ``_`` (company names such as ``*ST ...`` or
    ones containing ``/`` would otherwise break the path).

    When ``report.html`` exists in the symbol's data directory, the new report
    starts with a title, the generation date and a relative Markdown link to
    that HTML file; otherwise an empty file is created and a warning printed.
    Any existing file with the same name is overwritten.

    Args:
        market: Market identifier, used as the directory name under ``data``.
        symbol: Stock symbol (or a prefix of the symbol directory name).
        company_name: Display name used in the file name and the title.

    Returns:
        str: Path of the created report file (relative to the working dir).
    """
    date_str = datetime.now().strftime("%Y%m%d")
    folder_name = "reports"

    raw_name = f"{symbol}_{market}_{company_name}_{date_str}_基本面分析报告.md"
    file_name = re.sub(r'[\\/:*?"<>|]', "_", raw_name)

    os.makedirs(folder_name, exist_ok=True)
    target_path = os.path.join(folder_name, file_name)

    # Locate the data directory: exact name first, then the first directory
    # whose name starts with the symbol. Skip the scan when the market
    # directory is missing so os.listdir cannot raise.
    base_dir = os.path.join("data", market)
    symbol_dir = os.path.join(base_dir, symbol)
    if not os.path.exists(symbol_dir) and os.path.isdir(base_dir):
        candidates = [
            entry
            for entry in os.listdir(base_dir)
            if entry.startswith(symbol)
            and os.path.isdir(os.path.join(base_dir, entry))
        ]
        if candidates:
            symbol_dir = os.path.join(base_dir, candidates[0])

    source_html = os.path.join(symbol_dir, "report.html")
    if os.path.exists(source_html):
        # Link to the HTML file instead of copying its content. Markdown
        # links are URLs, so always use forward slashes.
        report_dir = os.path.dirname(target_path)
        rel_path = os.path.relpath(source_html, report_dir).replace(os.sep, "/")

        with open(target_path, 'w', encoding='utf-8') as f:
            f.write(f"# {company_name} ({symbol}) - 基本面分析报告\n\n")
            f.write(f"**生成日期**: {date_str}\n\n")
            f.write("## 财务报表\n\n")
            f.write(f"> [点击查看详细财务图表 ({os.path.basename(source_html)})]({rel_path})\n\n")
    else:
        print(f"Warning: {source_html} not found. Starting with empty report.")
        # Create / truncate the file; the context manager closes the handle.
        with open(target_path, 'w', encoding='utf-8'):
            pass

    return target_path
|
|
|
|
def load_prompts():
    """Load the prompt templates from ``prompts.yaml`` in the working directory.

    Returns:
        dict: Mapping read from the YAML file. An empty file yields ``{}``
        (``yaml.safe_load`` returns ``None`` for an empty document, which
        callers using ``.get`` could not handle).

    Exits the process with status 1 when the file is missing or when its
    top-level value is not a mapping. YAML syntax errors are not caught here.
    """
    try:
        with open("prompts.yaml", 'r', encoding='utf-8') as f:
            prompts = yaml.safe_load(f)
    except FileNotFoundError:
        print("Error: prompts.yaml not found.")
        sys.exit(1)

    if prompts is None:
        return {}

    if not isinstance(prompts, dict):
        print("Error: prompts.yaml must contain a mapping of prompt names to templates.")
        sys.exit(1)

    return prompts
|
|
|
|
def call_llm(client, model, system_prompt, user_prompt, context, enable_search=False):
    """Send one chat-completion request and return its text, usage and timing.

    The user message is ``user_prompt`` followed by ``context`` under an
    "Existing Report Data for context" heading. When ``enable_search`` is
    true, a ``googleSearch`` tool entry is sent via ``extra_body``; otherwise
    ``extra_body`` is ``None``.

    Args:
        client: Object exposing ``chat.completions.create(...)``.
        model: Model name passed through to the API.
        system_prompt: Content of the system message.
        user_prompt: Leading part of the user message.
        context: Report data appended to the user message.
        enable_search: Whether to request the Google Search tool.

    Returns:
        tuple: ``(content, usage, seconds)`` on success. On any exception the
        error is printed and ``("\\n\\nError generating section: ...", None, 0)``
        is returned instead of raising; callers detect failure by that prefix.
    """
    started_at = time.time()

    user_message = f"{user_prompt}\n\nExisting Report Data for context:\n{context}"
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
    ]

    # Configuration for the Gemini Google Search tool, sent only on request.
    search_options = {"tools": [{"googleSearch": {}}]} if enable_search else None

    try:
        reply = client.chat.completions.create(
            model=model,
            messages=conversation,
            extra_body=search_options,
        )
        elapsed = time.time() - started_at
        return reply.choices[0].message.content, reply.usage, elapsed
    except Exception as err:
        print(f"API Call Failed: {err}")
        return f"\n\nError generating section: {err}\n\n", None, 0
|
|
|
|
def _normalize_cn_symbol(symbol):
    """Append an exchange suffix to a bare all-digit CN stock code.

    Codes starting with 6 get ``.SH``, with 0 or 3 ``.SZ``, with 4 or 8
    ``.BJ``. Anything else is returned unchanged.
    """
    if not symbol.isdigit():
        return symbol
    if symbol.startswith("6"):
        return f"{symbol}.SH"
    if symbol.startswith(("0", "3")):
        return f"{symbol}.SZ"
    if symbol.startswith(("4", "8")):
        return f"{symbol}.BJ"
    return symbol


def _find_symbol_dir(market, symbol):
    """Return ``data/<market>/<symbol>`` or the first prefix-matching sibling."""
    base_dir = os.path.join("data", market)
    symbol_dir = os.path.join(base_dir, symbol)
    # Skip the scan when the market directory is missing: os.listdir would raise.
    if not os.path.exists(symbol_dir) and os.path.isdir(base_dir):
        candidates = [
            entry
            for entry in os.listdir(base_dir)
            if entry.startswith(symbol)
            and os.path.isdir(os.path.join(base_dir, entry))
        ]
        if candidates:
            symbol_dir = os.path.join(base_dir, candidates[0])
    return symbol_dir


def _generate_section(client, model_name, system_content, user_content,
                      data_context, enable_search, max_retries=3):
    """Call the LLM up to ``max_retries`` times; return the last attempt's result."""
    result, usage, duration = None, None, 0
    for attempt in range(max_retries):
        try:
            result, usage, duration = call_llm(
                client, model_name, system_content, user_content,
                data_context, enable_search=enable_search,
            )
            if result and not result.startswith("\n\nError"):
                break
            print(f" Attempt {attempt + 1} failed with error content, retrying...")
        except Exception as e:
            print(f" Attempt {attempt + 1} failed: {e}, retrying...")
        # Back off briefly (1s, 2s, ...) before the next attempt so that
        # rate-limit or transient errors are not retried immediately.
        if attempt + 1 < max_retries:
            time.sleep(2 ** attempt)
    return result, usage, duration


def _append_usage_summary(report_file, total_stats):
    """Append a Markdown table of per-step timing and token usage to the report."""
    if not total_stats:
        return

    with open(report_file, 'a', encoding='utf-8') as f:
        f.write("\n\n## API Usage Summary\n\n")
        f.write("| Step | Duration (s) | Prompt Tokens | Completion Tokens | Total Tokens |\n")
        f.write("|---|---|---|---|---|\n")

        for stat in total_stats:
            f.write(f"| {stat['step']} | {stat['duration']:.2f} | {stat['prompt_tokens']} | {stat['completion_tokens']} | {stat['total_tokens']} |\n")

        total_duration = sum(stat['duration'] for stat in total_stats)
        total_prompt = sum(stat['prompt_tokens'] for stat in total_stats)
        total_completion = sum(stat['completion_tokens'] for stat in total_stats)
        total_tokens = sum(stat['total_tokens'] for stat in total_stats)
        f.write(f"| **Total** | **{total_duration:.2f}** | **{total_prompt}** | **{total_completion}** | **{total_tokens}** |\n")


def main():
    """Command-line entry point: fetch data, then write an LLM analysis report.

    Steps:
      1. Read ``market`` and ``symbol`` from the command line, prompting only
         for whichever is missing. The market is upper-cased and, for ``CN``,
         bare numeric codes get their exchange suffix. This happens after
         prompting so typed-in values are normalized the same way as
         command-line ones.
      2. Load the API configuration and build the OpenAI client.
      3. Run ``main.py`` to produce the data files.
      4. Create the report file and read ``report.md`` as LLM context.
      5. For each report section, format its prompt, call the LLM with
         retries, and append the section to the report.
      6. Append an API usage summary table for the successful sections.

    Exits with status 1 when the market or symbol cannot be obtained; the
    helpers it calls may also exit on missing files.
    """
    parser = argparse.ArgumentParser(description="Stock Analysis Automation")
    parser.add_argument("market", nargs='?', help="Market (CN/US/HK/JP)")
    parser.add_argument("symbol", nargs='?', help="Stock Symbol")
    parser.add_argument("--search", action="store_true", help="Enable Google Search for LLM")
    args = parser.parse_args()

    market = args.market
    symbol = args.symbol

    if not market or not symbol:
        try:
            if not market:
                market = input("请输入市场 (CN/US/HK/JP): ")
            if not symbol:
                symbol = input("请输入股票代码: ")
        except EOFError:
            print("\nError: Input needed but EOF received. Please provide arguments or run interactively.")
            sys.exit(1)

    market = (market or "").strip().upper()
    symbol = (symbol or "").strip()

    if not market or not symbol:
        print("Market and Symbol are required.")
        sys.exit(1)

    if market == "CN":
        corrected = _normalize_cn_symbol(symbol)
        if corrected != symbol:
            symbol = corrected
            print(f"Auto-corrected symbol to: {symbol}")

    api_key, base_url, model_name = load_config()
    print(f"Configuration Loaded:\n Base URL: {base_url}\n Model: {model_name}\n")

    if not api_key:
        print("Warning: OPENAI_API_KEY not found in .env. API calls might fail.")

    # The client requires some key value, so a placeholder is used when none
    # is configured; the warning above already told the user.
    client = OpenAI(api_key=api_key or "dummy", base_url=base_url, timeout=60.0)

    # Step 1: Get Data
    run_main_script(market, symbol)

    # Step 2: Init Report
    company_name = get_company_info(market, symbol)
    print(f"Identified Company: {company_name}")
    report_file = create_report_file(market, symbol, company_name)
    print(f"Report initialized at: {report_file}")

    # Load Prompts (an empty prompts file may come back as None)
    prompts = load_prompts() or {}

    # Read Data Context
    data_path = os.path.join(_find_symbol_dir(market, symbol), "report.md")
    with open(data_path, 'r', encoding='utf-8') as f:
        data_context = f.read()

    steps = [
        ("company_profile", "3. 公司简介 (Company Profile)"),
        ("fundamental_analysis", "4. 基本面分析 (Fundamental Analysis)"),
        ("insider_analysis", "5. 内部人与机构动向 (Insider Analysis)"),
        ("bullish_analysis", "6. 看涨分析 (Bullish Analysis)"),
        ("bearish_analysis", "7. 看跌分析 (Bearish Analysis)"),
    ]

    max_retries = 3
    total_stats = []

    for key, name in steps:
        print(f"\nProcessing {name}...")

        prompt_template = prompts.get(key)
        if not prompt_template:
            print(f"Warning: Prompt for {key} not found.")
            continue

        # A template with an unknown placeholder or stray brace should cost
        # only this section, not the whole run.
        try:
            system_content = prompt_template.format(company_name=company_name, ts_code=symbol)
        except (KeyError, IndexError, ValueError) as e:
            print(f"Warning: Prompt for {key} could not be formatted: {e}")
            continue

        user_content = "请根据上述角色设定和要求,结合提供的财务数据,撰写本章节的分析报告。"

        result, usage, duration = _generate_section(
            client, model_name, system_content, user_content,
            data_context, args.search, max_retries=max_retries,
        )

        if not result or result.startswith("\n\nError"):
            result = f"\n> [Error] Failed to generate section {name} after {max_retries} attempts.\n"
        elif usage:
            total_stats.append({
                "step": name,
                "duration": duration,
                "prompt_tokens": usage.prompt_tokens,
                "completion_tokens": usage.completion_tokens,
                "total_tokens": usage.total_tokens,
            })

        # Drop the leading "N. " numbering from the section heading.
        _, separator, remainder = name.partition('. ')
        heading = remainder if separator else name

        with open(report_file, 'a', encoding='utf-8') as f:
            f.write(f"\n\n# {heading}\n\n")
            f.write(result)

        print(f"Finished {name} in {duration:.2f}s.")

    # Append stats
    _append_usage_summary(report_file, total_stats)

    print(f"\nStock analysis for {company_name} ({symbol}) completed. Report saved to {report_file}")
|
|
|
|
# Run the pipeline only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()
|