"""Automate an LLM-written fundamental-analysis report for a single stock.

Flow (see ``main``): run ``main.py`` for the requested market/symbol, create a
Markdown report under ``Reports/``, append one LLM-generated section per
prompt file found in ``Prompt/``, then append an API-usage summary table.
"""
import argparse
import os
import re
import shutil
import subprocess
import sys
import time
from datetime import datetime

import yaml
from dotenv import load_dotenv
from openai import OpenAI

# Prefix of the text ``call_llm`` returns when the API call fails; ``main``
# uses it to tell an error placeholder apart from real generated content.
_LLM_ERROR_PREFIX = "\n\nError"

# Characters that are path separators or otherwise unsafe in a file name.
_UNSAFE_FILENAME_CHARS = re.compile(r'[\\/:*?"<>|\x00-\x1f]')

_ENDPOINT_SUFFIX = "/chat/completions"


def load_config():
    """Read the API settings from the environment (after loading ``.env``).

    Returns:
        tuple: ``(api_key, base_url, model_name)``. ``api_key`` is ``None``
        when ``OPENAI_API_KEY`` is not set; the other two have defaults.
    """
    load_dotenv(override=True)
    api_key = os.getenv("OPENAI_API_KEY")
    base_url = os.getenv("OPENAI_BASE_URL", "https://generativelanguage.googleapis.com/v1beta/openai/")
    model_name = os.getenv("LLM_MODEL", "gemini-1.5-flash")

    # Sanitize base_url: drop a trailing endpoint path. Only the suffix is
    # removed, so an identical substring earlier in the URL is left intact.
    if base_url and base_url.endswith(_ENDPOINT_SUFFIX):
        base_url = base_url[:-len(_ENDPOINT_SUFFIX)]

    return api_key, base_url, model_name


def run_main_script(market, symbol):
    """Run ``main.py`` for *market*/*symbol*; exit with status 1 if it fails."""
    print(f"Running main.py for {market} {symbol}...")
    try:
        # Assuming main.py is in current directory.
        # Using sys.executable to ensure the same python environment.
        cmd = [sys.executable, "main.py", market, symbol]
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Error running main.py: {e}")
        sys.exit(1)


def _resolve_symbol_dir(market, symbol):
    """Return the data directory for *symbol* under ``data/<market>``.

    An exact ``data/<market>/<symbol>`` match wins. Otherwise the first
    sub-directory (in sorted order, so the choice is deterministic) whose name
    starts with *symbol* is used, e.g. ``688334`` -> ``688334.SH``. When
    nothing matches, or ``data/<market>`` cannot be listed, the exact path is
    returned anyway so callers can report it as missing.
    """
    base_dir = os.path.join("data", market)
    exact_dir = os.path.join(base_dir, symbol)
    if os.path.exists(exact_dir):
        return exact_dir

    try:
        entries = sorted(os.listdir(base_dir))
    except OSError:
        return exact_dir

    for entry in entries:
        candidate = os.path.join(base_dir, entry)
        if entry.startswith(symbol) and os.path.isdir(candidate):
            return candidate
    return exact_dir


def _extract_company_name(lines, symbol):
    """Pull the company name out of the report lines, or return *symbol*.

    Attempt 1 reads a first-line header shaped like ``# Name (Code) ...``.
    Attempt 2 looks for a table row containing ``| <symbol>`` and takes the
    cell that follows it.
    """
    if not lines:
        return symbol

    header_match = re.search(r'^#\s+(.+?)\s+\(', lines[0])
    if header_match:
        return header_match.group(1).strip()

    needles = (f"| {symbol}", f"| {symbol.upper()}")
    for line in lines:
        if not any(needle in line for needle in needles):
            continue
        parts = line.split('|')
        # An empty cell is not a usable name; keep scanning in that case.
        if len(parts) >= 3 and parts[2].strip():
            return parts[2].strip()
    return symbol


def get_company_info(market, symbol):
    """Return the company name found in ``data/<market>/<symbol>/report.md``.

    Exits with status 1 when the report file does not exist. Falls back to
    *symbol* when the file cannot be read or no name can be extracted.
    """
    symbol_dir = _resolve_symbol_dir(market, symbol)
    resolved_name = os.path.basename(symbol_dir)
    if resolved_name != symbol:
        print(f"Redirecting to found directory: {resolved_name}")

    report_path = os.path.join(symbol_dir, "report.md")
    if not os.path.exists(report_path):
        print(f"Error: {report_path} not found.")
        sys.exit(1)

    try:
        with open(report_path, 'r', encoding='utf-8') as f:
            lines = f.read().splitlines()
    except (OSError, UnicodeError) as e:
        print(f"Warning: Could not extract company name: {e}")
        return symbol

    return _extract_company_name(lines, symbol)


def create_report_file(market, symbol, company_name):
    """Create the report file under ``Reports/`` and return its path.

    When ``report.html`` exists in the symbol's data directory, the new report
    starts with a heading, the generation date and a relative link to that
    HTML file; otherwise it starts empty.
    """
    date_str = datetime.now().strftime("%Y%m%d")
    folder_name = "Reports"
    # The name is built from external data, so replace characters that would
    # act as path separators or be rejected by the file system.
    file_name = _UNSAFE_FILENAME_CHARS.sub(
        "_", f"{symbol}_{market}_{company_name}_{date_str}_基本面分析报告.md"
    )

    os.makedirs(folder_name, exist_ok=True)
    target_path = os.path.join(folder_name, file_name)

    symbol_dir = _resolve_symbol_dir(market, symbol)
    source_html = os.path.join(symbol_dir, "report.html")

    if os.path.exists(source_html):
        # Create a relative link to the HTML file instead of copying content.
        report_dir = os.path.dirname(target_path)
        rel_path = os.path.relpath(source_html, report_dir)

        with open(target_path, 'w', encoding='utf-8') as f:
            f.write(f"# {company_name} ({symbol}) - 基本面分析报告\n\n")
            f.write(f"**生成日期**: {date_str}\n\n")
            f.write("## 财务报表\n\n")
            f.write(f"> [点击查看详细财务图表 ({os.path.basename(source_html)})]({rel_path})\n\n")
    else:
        print(f"Warning: {source_html} not found. Starting with empty report.")
        # Truncate/create the file so later appends start from a clean slate.
        with open(target_path, 'w', encoding='utf-8'):
            pass

    return target_path


def load_prompts():
    """Load the prompt templates from the ``Prompt`` directory.

    Returns:
        dict: section key -> template text. Keys whose file is missing are
        omitted (a warning is printed for each).
    """
    prompts = {}
    prompt_dir = "Prompt"

    mapping = {
        "company_profile": "公司简介.md",
        "fundamental_analysis": "基本面分析.md",
        "insider_analysis": "内部人与机构动向分析.md",
        "bullish_analysis": "看涨分析.md",
        "bearish_analysis": "看跌分析.md",
    }

    for key, filename in mapping.items():
        try:
            with open(os.path.join(prompt_dir, filename), 'r', encoding='utf-8') as f:
                prompts[key] = f.read()
        except FileNotFoundError:
            print(f"Warning: Prompt file {filename} not found.")

    return prompts


def call_llm(client, model, system_prompt, user_prompt, context, enable_search=False):
    """Send one chat-completion request and time it.

    Args:
        client: Object exposing ``chat.completions.create`` (the OpenAI client).
        model: Model name passed to the API.
        system_prompt: Content of the system message.
        user_prompt: Instruction placed at the start of the user message.
        context: Report data appended to the user message.
        enable_search: When true, request the ``google_search`` tool via
            ``extra_body``.

    Returns:
        tuple: ``(content, usage, duration_seconds)`` on success. On any
        exception the error is printed and ``(error_text, None, 0)`` is
        returned, where ``error_text`` starts with ``"\\n\\nError"``.
    """
    start_time = time.time()

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"{user_prompt}\n\nExisting Report Data for context:\n{context}"},
    ]

    # Configuration for Gemini Google Search tool.
    # Use google_search (with underscore) as per Gemini API specification.
    extra_body = {"tools": [{"google_search": {}}]} if enable_search else None

    try:
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            extra_body=extra_body,
        )
        duration = time.time() - start_time
        return response.choices[0].message.content, response.usage, duration
    except Exception as e:
        # Deliberate catch-all: a failed section must not abort the report.
        print(f"API Call Failed: {e}")
        return f"{_LLM_ERROR_PREFIX} generating section: {e}\n\n", None, 0


def _normalize_cn_symbol(symbol):
    """Append the exchange suffix to a bare numeric CN stock code.

    Codes starting with 6 get ``.SH``, 0/3 get ``.SZ`` and 4/8 get ``.BJ``.
    Anything else is returned unchanged.
    """
    if not symbol.isdigit():
        return symbol
    if symbol.startswith("6"):
        return f"{symbol}.SH"
    if symbol.startswith(("0", "3")):
        return f"{symbol}.SZ"
    if symbol.startswith(("4", "8")):
        return f"{symbol}.BJ"
    return symbol


def _render_prompt(template, company_name, ts_code):
    """Fill ``{company_name}`` and ``{ts_code}`` into a prompt template.

    ``str.format`` is tried first. If the template contains other literal
    braces that ``format`` rejects, fall back to substituting only the two
    known placeholders instead of aborting the run.
    """
    try:
        return template.format(company_name=company_name, ts_code=ts_code)
    except (KeyError, IndexError, ValueError) as e:
        print(f"Warning: Prompt contains unescaped braces ({e!r}); using plain substitution.")
        return template.replace("{company_name}", company_name).replace("{ts_code}", ts_code)


def _is_llm_failure(result):
    """Return True when *result* is empty or an error placeholder from ``call_llm``."""
    return not result or result.startswith(_LLM_ERROR_PREFIX)


def main():
    """Command-line entry point: build the full analysis report for one stock."""
    parser = argparse.ArgumentParser(description="Stock Analysis Automation")
    parser.add_argument("market", nargs='?', help="Market (CN/US/HK/JP)")
    parser.add_argument("symbol", nargs='?', help="Stock Symbol")
    parser.add_argument("--search", action="store_true", help="Enable Google Search for LLM")

    args = parser.parse_args()

    market = args.market
    symbol = args.symbol

    if not market or not symbol:
        try:
            market = input("请输入市场 (CN/US/HK/JP): ").strip().upper()
            symbol = input("请输入股票代码: ").strip()
        except EOFError:
            print("\nError: Input needed but EOF received. Please provide arguments or run interactively.")
            sys.exit(1)

    if not market or not symbol:
        print("Market and Symbol are required.")
        sys.exit(1)

    # Applied after input collection so interactive and CLI runs agree.
    if market == "CN":
        corrected = _normalize_cn_symbol(symbol)
        if corrected != symbol:
            symbol = corrected
            print(f"Auto-corrected symbol to: {symbol}")

    api_key, base_url, model_name = load_config()
    print(f"Configuration Loaded:\n  Base URL: {base_url}\n  Model: {model_name}\n")

    if not api_key:
        print("Warning: OPENAI_API_KEY not found in .env. API calls might fail.")

    client = OpenAI(api_key=api_key if api_key else "dummy", base_url=base_url, timeout=60.0)

    # Step 1: Get Data
    run_main_script(market, symbol)

    # Step 2: Init Report
    company_name = get_company_info(market, symbol)
    print(f"Identified Company: {company_name}")

    report_file = create_report_file(market, symbol, company_name)
    print(f"Report initialized at: {report_file}")

    # Load Prompts
    prompts = load_prompts()

    # Read Data Context
    symbol_dir = _resolve_symbol_dir(market, symbol)
    data_path = os.path.join(symbol_dir, "report.md")
    with open(data_path, 'r', encoding='utf-8') as f:
        data_context = f.read()

    # Read CSV data for bearish analysis
    csv_path = os.path.join(symbol_dir, "raw_balance_sheet_raw.csv")
    csv_context = ""
    if os.path.exists(csv_path):
        try:
            with open(csv_path, 'r', encoding='utf-8') as f:
                csv_content = f.read()
            csv_context = f"\n\nRaw Balance Sheet Data (CSV):\n{csv_content}\n"
        except (OSError, UnicodeError) as e:
            print(f"Warning: Could not read CSV file: {e}")

    steps = [
        ("company_profile", "3. 公司简介 (Company Profile)"),
        ("fundamental_analysis", "4. 基本面分析 (Fundamental Analysis)"),
        ("insider_analysis", "5. 内部人与机构动向 (Insider Analysis)"),
        ("bullish_analysis", "6. 看涨分析 (Bullish Analysis)"),
        ("bearish_analysis", "7. 看跌分析 (Bearish Analysis)"),
    ]

    user_content = "请根据上述角色设定和要求,结合提供的财务数据,撰写本章节的分析报告。"
    max_retries = 3
    total_stats = []

    for key, name in steps:
        print(f"\nProcessing {name}...")
        prompt_template = prompts.get(key)
        if not prompt_template:
            print(f"Warning: Prompt for {key} not found.")
            continue

        system_content = _render_prompt(prompt_template, company_name, symbol)

        # Inject CSV data only for bearish analysis.
        current_data_context = data_context
        if key == "bearish_analysis" and csv_context:
            current_data_context += csv_context

        # Retry logic. call_llm never raises: it reports failure through an
        # error placeholder, so inspecting the returned text is sufficient.
        result = None
        usage = None
        duration = 0
        for attempt in range(max_retries):
            result, usage, duration = call_llm(
                client,
                model_name,
                system_content,
                user_content,
                current_data_context,
                enable_search=args.search,
            )
            if not _is_llm_failure(result):
                break
            print(f"  Attempt {attempt + 1} failed with error content, retrying...")

        if _is_llm_failure(result):
            result = f"\n> [Error] Failed to generate section {name} after {max_retries} attempts.\n"
        elif usage:
            total_stats.append({
                "step": name,
                "duration": duration,
                "prompt_tokens": usage.prompt_tokens,
                "completion_tokens": usage.completion_tokens,
                "total_tokens": usage.total_tokens,
            })

        # Section heading is the step name without its leading "N. " number.
        heading = name.split('. ')[1] if '. ' in name else name
        with open(report_file, 'a', encoding='utf-8') as f:
            f.write(f"\n\n# {heading}\n\n")
            f.write(result)

        print(f"Finished {name} in {duration:.2f}s.")

    # Append stats
    if total_stats:
        total_duration = sum(stat['duration'] for stat in total_stats)
        total_prompt = sum(stat['prompt_tokens'] for stat in total_stats)
        total_completion = sum(stat['completion_tokens'] for stat in total_stats)
        total_tokens = sum(stat['total_tokens'] for stat in total_stats)

        with open(report_file, 'a', encoding='utf-8') as f:
            f.write("\n\n## API Usage Summary\n\n")
            f.write("| Step | Duration (s) | Prompt Tokens | Completion Tokens | Total Tokens |\n")
            f.write("|---|---|---|---|---|\n")
            for stat in total_stats:
                f.write(f"| {stat['step']} | {stat['duration']:.2f} | {stat['prompt_tokens']} | {stat['completion_tokens']} | {stat['total_tokens']} |\n")
            f.write(f"| **Total** | **{total_duration:.2f}** | **{total_prompt}** | **{total_completion}** | **{total_tokens}** |\n")

    print(f"\nStock analysis for {company_name} ({symbol}) completed. Report saved to {report_file}")


if __name__ == "__main__":
    main()