FA3-Datafetch/stock_analysis.py

295 lines
11 KiB
Python

import os
import sys
import argparse
import subprocess
import shutil
import yaml
import re
import time
from datetime import datetime
from dotenv import load_dotenv
from openai import OpenAI
def load_config():
    """Load LLM connection settings from the environment and a ``.env`` file.

    ``.env`` values override variables already set in the process
    environment (``load_dotenv(override=True)``).

    Returns:
        tuple: ``(api_key, base_url, model_name)``.
        ``api_key`` is ``None`` when ``OPENAI_API_KEY`` is unset.
        ``base_url`` defaults to Google's OpenAI-compatible Gemini endpoint;
        a trailing ``/chat/completions`` is removed because
        ``client.chat.completions.create`` appends that path itself.
        ``model_name`` defaults to ``"gemini-1.5-flash"``.
    """
    load_dotenv(override=True)
    api_key = os.getenv("OPENAI_API_KEY")
    # `or` instead of a getenv() default so that a variable which is present
    # but empty (e.g. `OPENAI_BASE_URL=` in .env) also falls back to the
    # default rather than producing an unusable empty value.
    base_url = os.getenv("OPENAI_BASE_URL") or "https://generativelanguage.googleapis.com/v1beta/openai/"
    model_name = os.getenv("LLM_MODEL") or "gemini-1.5-flash"

    # Users often paste the full endpoint URL. Strip only the trailing
    # "/chat/completions" (optionally followed by "/"); str.replace would also
    # rewrite an earlier occurrence elsewhere in the URL.
    suffix = "/chat/completions"
    trimmed = base_url.rstrip("/")
    if trimmed.endswith(suffix):
        base_url = trimmed[: -len(suffix)]
    return api_key, base_url, model_name
def run_main_script(market, symbol):
    """Run the data-fetching script ``main.py`` for one ticker.

    ``main.py`` is looked up in the current working directory and executed
    with the same interpreter as this script (``sys.executable``) so both
    share one environment. A non-zero exit status of the child is reported
    and terminates this process with status 1.
    """
    print(f"Running main.py for {market} {symbol}...")
    command = [sys.executable, "main.py", market, symbol]
    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as err:
        print(f"Error running main.py: {err}")
        sys.exit(1)
def get_company_info(market, symbol):
    """Return the company name recorded in ``data/<market>/<symbol>/report.md``.

    If ``data/<market>/<symbol>`` does not exist, the first entry of
    ``data/<market>`` (directory-listing order) that is a directory whose
    name starts with ``symbol`` is used instead (e.g. ``688334`` ->
    ``688334.SH``).

    The name is taken from a first-line header of the form
    ``# Name (Code) ...``; failing that, from the second cell of the first
    Markdown table row containing ``| <symbol>``.

    Args:
        market: Market code; the directory name under ``data/``.
        symbol: Ticker symbol, possibly without its exchange suffix.

    Returns:
        The extracted company name, or ``symbol`` when no non-empty name is
        found or the report cannot be read.

    Raises:
        SystemExit: With status 1 if the market directory or ``report.md``
            is missing.
    """
    base_dir = os.path.join("data", market)
    symbol_dir = os.path.join(base_dir, symbol)
    if not os.path.exists(symbol_dir):
        if not os.path.isdir(base_dir):
            # Without this check os.listdir raises an unhandled
            # FileNotFoundError instead of the clean exit used below.
            print(f"Error: {base_dir} not found.")
            sys.exit(1)
        candidates = [
            d for d in os.listdir(base_dir)
            if d.startswith(symbol) and os.path.isdir(os.path.join(base_dir, d))
        ]
        if candidates:
            symbol_dir = os.path.join(base_dir, candidates[0])
            print(f"Redirecting to found directory: {candidates[0]}")

    report_path = os.path.join(symbol_dir, "report.md")
    if not os.path.exists(report_path):
        print(f"Error: {report_path} not found.")
        sys.exit(1)

    company_name = symbol  # Fallback when no name can be extracted.
    try:
        with open(report_path, 'r', encoding='utf-8') as f:
            lines = f.read().splitlines()
    except (OSError, UnicodeDecodeError) as e:
        # Best effort: the symbol is still a usable display name.
        print(f"Warning: Could not extract company name: {e}")
        return company_name
    if not lines:
        return company_name

    # Attempt 1: header "# Name (Code) - Financial Report"
    header_match = re.search(r'^#\s+(.+?)\s+\(', lines[0])
    if header_match:
        return header_match.group(1).strip() or company_name

    # Attempt 2: table row "| Code | Name | ..."
    markers = (f"| {symbol}", f"| {symbol.upper()}")
    for line in lines:
        if any(marker in line for marker in markers):
            parts = line.split('|')
            if len(parts) >= 3:
                # An empty Name cell would yield "" (and an odd report file
                # name downstream), so keep the symbol fallback instead.
                return parts[2].strip() or company_name
    return company_name
def create_report_file(market, symbol, company_name):
    """Create today's Markdown report skeleton under ``reports/``.

    The file is named
    ``<symbol>_<market>_<company>_<YYYYMMDD>_基本面分析报告.md``; an existing
    file with that name is overwritten. If ``report.html`` exists in the
    symbol's data directory, the skeleton contains a title, the generation
    date and a relative link to that HTML file; otherwise an empty file is
    created.

    Args:
        market: Market code; the directory name under ``data/``.
        symbol: Ticker symbol. If ``data/<market>/<symbol>`` is missing, the
            first directory (listing order) whose name starts with it is used.
        company_name: Display name for the title. Characters that are not
            valid in file names (``\\ / : * ? " < > |`` and control
            characters) are replaced with ``_`` in the file name only.

    Returns:
        Path of the created report file.
    """
    date_str = datetime.now().strftime("%Y%m%d")
    folder_name = "reports"
    # Names such as "*ST康美" or ones containing "/" would otherwise give an
    # invalid path on Windows or point into a non-existent sub-directory.
    safe_name = re.sub(r'[\\/:*?"<>|\x00-\x1f]', "_", company_name)
    file_name = f"{symbol}_{market}_{safe_name}_{date_str}_基本面分析报告.md"
    os.makedirs(folder_name, exist_ok=True)
    target_path = os.path.join(folder_name, file_name)

    # Locate the data directory: exact match first, else a prefix match.
    base_dir = os.path.join("data", market)
    symbol_dir = os.path.join(base_dir, symbol)
    if not os.path.exists(symbol_dir) and os.path.isdir(base_dir):
        candidates = [
            d for d in os.listdir(base_dir)
            if d.startswith(symbol) and os.path.isdir(os.path.join(base_dir, d))
        ]
        if candidates:
            symbol_dir = os.path.join(base_dir, candidates[0])
    source_html = os.path.join(symbol_dir, "report.html")

    if os.path.exists(source_html):
        # Link to the HTML instead of copying its content. Markdown links
        # need "/" separators, so normalise relpath's OS-specific output
        # (backslashes on Windows).
        rel_path = os.path.relpath(source_html, os.path.dirname(target_path))
        rel_path = rel_path.replace(os.sep, "/")
        with open(target_path, 'w', encoding='utf-8') as f:
            f.write(f"# {company_name} ({symbol}) - 基本面分析报告\n\n")
            f.write(f"**生成日期**: {date_str}\n\n")
            f.write("## 财务报表\n\n")
            f.write(f"> [点击查看详细财务图表 ({os.path.basename(source_html)})]({rel_path})\n\n")
    else:
        print(f"Warning: {source_html} not found. Starting with empty report.")
        open(target_path, 'w', encoding='utf-8').close()
    return target_path
def load_prompts():
    """Load the per-section prompt templates from ``prompts.yaml``.

    The file is read from the current working directory with
    ``yaml.safe_load``, which does not construct arbitrary Python objects.

    Returns:
        dict: Mapping of section key (the keys ``main`` looks up, e.g.
        ``"company_profile"``) to prompt template. An empty file yields
        ``{}`` rather than ``None``.

    Raises:
        SystemExit: With status 1 if the file is missing, is not valid YAML,
            or its top level is not a mapping.
    """
    try:
        with open("prompts.yaml", 'r', encoding='utf-8') as f:
            prompts = yaml.safe_load(f)
    except FileNotFoundError:
        print("Error: prompts.yaml not found.")
        sys.exit(1)
    except yaml.YAMLError as e:
        print(f"Error: prompts.yaml is not valid YAML: {e}")
        sys.exit(1)
    if prompts is None:
        # safe_load returns None for an empty document; callers use .get().
        return {}
    if not isinstance(prompts, dict):
        print("Error: prompts.yaml must contain a mapping of section keys to prompts.")
        sys.exit(1)
    return prompts
def call_llm(client, model, system_prompt, user_prompt, context, enable_search=False):
    """Send one chat-completion request and measure how long it took.

    Args:
        client: An ``openai.OpenAI`` client (or an object with the same
            ``chat.completions.create`` interface).
        model: Model name passed to the API.
        system_prompt: Content of the system message.
        user_prompt: User instruction; ``context`` is appended after it under
            an "Existing Report Data for context" heading.
        context: Report data given to the model as context.
        enable_search: If True, request the Gemini ``googleSearch`` tool via
            ``extra_body``.

    Returns:
        tuple: ``(content, usage, duration_seconds)`` on success. If anything
        raises, the error is printed and ``(sentinel, None, 0)`` is returned,
        where ``sentinel`` is two newlines followed by
        ``"Error generating section: <exception>"``; callers detect failure
        by that prefix.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"{user_prompt}\n\nExisting Report Data for context:\n{context}"},
    ]
    # Configuration for the Gemini Google Search tool; passed through
    # extra_body, which carries fields not in the SDK method signature.
    extra_body = {"tools": [{"googleSearch": {}}]} if enable_search else None

    # perf_counter is monotonic, so the measured duration is not skewed by
    # wall-clock adjustments the way time.time() can be.
    start = time.perf_counter()
    try:
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            extra_body=extra_body,
        )
        duration = time.perf_counter() - start
        return response.choices[0].message.content, response.usage, duration
    except Exception as e:
        # Boundary with an external service: report and return a sentinel
        # instead of raising, so the caller can retry or skip the section.
        print(f"API Call Failed: {e}")
        return f"\n\nError generating section: {e}\n\n", None, 0
def _normalize_cn_symbol(symbol):
    """Append the exchange suffix to a bare numeric China A-share code.

    Codes starting with 6 get ``.SH``, 0/3 get ``.SZ`` and 4/8 get ``.BJ``;
    anything else (already suffixed, non-numeric, other prefixes) is returned
    unchanged.
    """
    if not symbol.isdigit():
        return symbol
    if symbol.startswith("6"):
        return f"{symbol}.SH"
    if symbol.startswith(("0", "3")):
        return f"{symbol}.SZ"
    if symbol.startswith(("4", "8")):
        return f"{symbol}.BJ"
    return symbol


def _resolve_data_dir(market, symbol):
    """Return the data directory for ``symbol`` under ``data/<market>``.

    Uses ``data/<market>/<symbol>`` if it exists, otherwise the first entry
    (directory-listing order) that is a directory whose name starts with
    ``symbol``; falls back to the exact path when nothing matches or the
    market directory is missing.
    """
    base_dir = os.path.join("data", market)
    symbol_dir = os.path.join(base_dir, symbol)
    if os.path.exists(symbol_dir) or not os.path.isdir(base_dir):
        return symbol_dir
    for entry in os.listdir(base_dir):
        if entry.startswith(symbol) and os.path.isdir(os.path.join(base_dir, entry)):
            return os.path.join(base_dir, entry)
    return symbol_dir


def _generate_section(client, model_name, system_content, user_content,
                      data_context, enable_search, name, max_retries=3):
    """Generate one report section, retrying failed LLM calls with backoff.

    An attempt fails when ``call_llm`` raises, returns empty content, or
    returns its error sentinel (text starting with two newlines and
    "Error"). Between attempts it sleeps 1s, 2s, 4s, ... so transient rate
    limits or outages have a chance to clear.

    Returns:
        ``(markdown, usage, duration)``. After ``max_retries`` failures,
        ``markdown`` is an ``[Error]`` placeholder and ``usage`` is ``None``.
    """
    duration = 0
    for attempt in range(1, max_retries + 1):
        try:
            result, usage, duration = call_llm(
                client, model_name, system_content, user_content, data_context,
                enable_search=enable_search,
            )
        except Exception as e:
            # Defensive: call_llm normally converts errors into its sentinel.
            reason = f": {e}"
        else:
            if result and not result.startswith("\n\nError"):
                return result, usage, duration
            reason = " with error content"
        if attempt < max_retries:
            print(f" Attempt {attempt} failed{reason}, retrying...")
            time.sleep(2 ** (attempt - 1))
        else:
            print(f" Attempt {attempt} failed{reason}.")
    failure = f"\n> [Error] Failed to generate section {name} after {max_retries} attempts.\n"
    return failure, None, duration


def _append_usage_summary(report_file, total_stats):
    """Append a Markdown table of per-step duration/token usage plus totals.

    Does nothing when ``total_stats`` is empty.
    """
    if not total_stats:
        return
    lines = [
        "\n\n## API Usage Summary\n\n",
        "| Step | Duration (s) | Prompt Tokens | Completion Tokens | Total Tokens |\n",
        "|---|---|---|---|---|\n",
    ]
    for stat in total_stats:
        lines.append(
            f"| {stat['step']} | {stat['duration']:.2f} | {stat['prompt_tokens']} "
            f"| {stat['completion_tokens']} | {stat['total_tokens']} |\n"
        )
    total_duration = sum(stat["duration"] for stat in total_stats)
    total_prompt, total_completion, total_tokens = (
        sum(stat[column] for stat in total_stats)
        for column in ("prompt_tokens", "completion_tokens", "total_tokens")
    )
    lines.append(
        f"| **Total** | **{total_duration:.2f}** | **{total_prompt}** "
        f"| **{total_completion}** | **{total_tokens}** |\n"
    )
    with open(report_file, 'a', encoding='utf-8') as f:
        f.writelines(lines)


def main():
    """Command-line entry point for the stock analysis pipeline.

    1. Read ``market``/``symbol`` from the command line, prompting for any
       that are missing. The market is upper-cased and bare numeric CN codes
       get their exchange suffix.
    2. Run ``main.py`` to fetch data, then create the report skeleton.
    3. For each analysis section, format its prompt from ``prompts.yaml``,
       ask the LLM (with retries) using the fetched ``report.md`` as context,
       and append the answer to the report.
    4. Append an API latency/token usage table.

    Exits with status 1 when market or symbol is missing; the helper steps
    (data fetch, company lookup, prompt loading) also exit on failure.
    """
    parser = argparse.ArgumentParser(description="Stock Analysis Automation")
    parser.add_argument("market", nargs='?', help="Market (CN/US/HK/JP)")
    parser.add_argument("symbol", nargs='?', help="Stock Symbol")
    parser.add_argument("--search", action="store_true", help="Enable Google Search for LLM")
    args = parser.parse_args()

    # Upper-case the CLI market too, matching the interactive path.
    market = (args.market or "").strip().upper()
    symbol = (args.symbol or "").strip()
    try:
        # Only ask for what was not supplied on the command line.
        if not market:
            market = input("请输入市场 (CN/US/HK/JP): ").strip().upper()
        if not symbol:
            symbol = input("请输入股票代码: ").strip()
    except EOFError:
        print("\nError: Input needed but EOF received. Please provide arguments or run interactively.")
        sys.exit(1)
    if not market or not symbol:
        print("Market and Symbol are required.")
        sys.exit(1)

    # Normalise after any interactive input, and only announce real changes.
    if market == "CN":
        corrected = _normalize_cn_symbol(symbol)
        if corrected != symbol:
            symbol = corrected
            print(f"Auto-corrected symbol to: {symbol}")

    api_key, base_url, model_name = load_config()
    print(f"Configuration Loaded:\n Base URL: {base_url}\n Model: {model_name}\n")
    if not api_key:
        print("Warning: OPENAI_API_KEY not found in .env. API calls might fail.")
    # The OpenAI client refuses to construct without a key, so pass a
    # placeholder and let individual calls fail (and be reported) instead.
    client = OpenAI(api_key=api_key if api_key else "dummy", base_url=base_url, timeout=60.0)

    # Step 1: Get Data
    run_main_script(market, symbol)

    # Step 2: Init Report
    company_name = get_company_info(market, symbol)
    print(f"Identified Company: {company_name}")
    report_file = create_report_file(market, symbol, company_name)
    print(f"Report initialized at: {report_file}")

    # `or {}` guards against an empty prompts.yaml (safe_load -> None).
    prompts = load_prompts() or {}

    data_path = os.path.join(_resolve_data_dir(market, symbol), "report.md")
    with open(data_path, 'r', encoding='utf-8') as f:
        data_context = f.read()

    steps = [
        ("company_profile", "3. 公司简介 (Company Profile)"),
        ("fundamental_analysis", "4. 基本面分析 (Fundamental Analysis)"),
        ("insider_analysis", "5. 内部人与机构动向 (Insider Analysis)"),
        ("bullish_analysis", "6. 看涨分析 (Bullish Analysis)"),
        ("bearish_analysis", "7. 看跌分析 (Bearish Analysis)")
    ]
    user_content = "请根据上述角色设定和要求,结合提供的财务数据,撰写本章节的分析报告。"
    total_stats = []
    for key, name in steps:
        print(f"\nProcessing {name}...")
        prompt_template = prompts.get(key)
        if not prompt_template:
            print(f"Warning: Prompt for {key} not found.")
            continue
        try:
            system_content = prompt_template.format(company_name=company_name, ts_code=symbol)
        except (KeyError, IndexError, ValueError) as e:
            # Literal braces (e.g. JSON examples) or unknown placeholders in a
            # template would otherwise abort the run after earlier sections
            # were already paid for; use the template as-is instead.
            print(f"Warning: Could not format prompt for {key} ({e!r}); using it unformatted.")
            system_content = prompt_template

        result, usage, duration = _generate_section(
            client, model_name, system_content, user_content, data_context,
            args.search, name,
        )
        if usage:
            # Some OpenAI-compatible providers report missing counts as None.
            total_stats.append({
                "step": name,
                "duration": duration,
                "prompt_tokens": usage.prompt_tokens or 0,
                "completion_tokens": usage.completion_tokens or 0,
                "total_tokens": usage.total_tokens or 0,
            })

        # "3. 公司简介 (...)" -> "公司简介 (...)" as the section heading.
        heading = name.split('. ', 1)[1] if '. ' in name else name
        with open(report_file, 'a', encoding='utf-8') as f:
            f.write(f"\n\n# {heading}\n\n")
            f.write(result)
        print(f"Finished {name} in {duration:.2f}s.")

    _append_usage_summary(report_file, total_stats)
    print(f"\nStock analysis for {company_name} ({symbol}) completed. Report saved to {report_file}")


if __name__ == "__main__":
    main()