"""One-off migration that normalizes ``stock_code`` values in LLMUsageLog rows.

Each non-empty ``stock_code`` is rewritten to ``"Name (Symbol Market)"`` using
the matching row from the companies table. Rows whose symbol cannot be found
are left untouched.
"""
import asyncio
import sys
import os
import re

from sqlalchemy import select, update

# Add parent directory to path to import app modules
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from app.database import SessionLocal
from app.models import LLMUsageLog, Company

# Matches the first parenthesised group, e.g. the "300999" in "金龙鱼 (300999)".
_BRACKETED_RE = re.compile(r'\((.*?)\)')


def _extract_symbol(stock_code):
    """Return the symbol embedded in ``stock_code``, or ``None`` if there is none.

    Two input shapes are recognized:

    * ``"Name (Symbol)"`` or ``"Name (Symbol Market)"`` -- the first
      whitespace-separated token inside the first pair of brackets is returned.
    * ``"Symbol"`` -- with no brackets, the whole trimmed string is returned.

    ``None`` is returned when the brackets are empty or contain only
    whitespace, so the caller can skip the row instead of crashing.
    """
    match = _BRACKETED_RE.search(stock_code)
    if match is None:
        # No brackets, assume it is just the symbol.
        return stock_code.strip()

    # The bracket content may already carry market info, e.g. "300999 CH"
    # (space separated); only the leading token is the symbol.
    parts = match.group(1).split()
    return parts[0] if parts else None


async def migrate():
    """Rewrite ``LLMUsageLog.stock_code`` values to ``"Name (Symbol Market)"``.

    Loads every ``Company`` row into a symbol lookup, then walks every
    ``LLMUsageLog`` row. Rows with an empty ``stock_code`` are ignored. For the
    rest, the symbol is extracted and looked up; when found and the formatted
    value differs from the stored one, the row is updated. A single commit is
    issued at the end if at least one row changed. Progress and a final
    summary are printed to stdout.
    """
    async with SessionLocal() as session:
        print("🚀 Starting migration of stock codes in LLMUsageLogs...")

        # 1. Fetch all companies to build a lookup map: symbol -> Company.
        print("📦 Fetching company metadata...")
        result = await session.execute(select(Company))
        companies = result.scalars().all()

        # Symbols might not be unique across markets. Keep the FIRST company
        # seen for each symbol; setdefault never overwrites an existing entry.
        company_map = {}
        for company in companies:
            company_map.setdefault(company.symbol, company)
        print(f"✅ Loaded {len(company_map)} companies.")

        # 2. Fetch all logs
        print("📜 Fetching chat logs...")
        result = await session.execute(select(LLMUsageLog))
        logs = result.scalars().all()
        print(f"✅ Found {len(logs)} logs. Processing...")

        updated_count = 0
        # Counts rows left untouched: symbol not found, or already formatted.
        skipped_count = 0

        for log in logs:
            if not log.stock_code:
                continue

            original_code = log.stock_code
            symbol = _extract_symbol(original_code)

            # An empty/missing symbol can never identify a company.
            company = company_map.get(symbol) if symbol else None
            if company is None:
                skipped_count += 1
                continue

            # Format: "Name (Symbol Market)"
            new_code = f"{company.company_name} ({company.symbol} {company.market})"
            if new_code == original_code:
                # Already in the target format; report it as unchanged.
                skipped_count += 1
                continue

            log.stock_code = new_code
            updated_count += 1

        # 3. Commit changes
        if updated_count > 0:
            print(f"💾 Committing {updated_count} updates to database...")
            await session.commit()
            print("✅ Database updated successfully.")
        else:
            print("✨ No updates needed.")

        print(f"Done. Updated: {updated_count}, Skipped/Unchanged: {skipped_count}")


if __name__ == "__main__":
    asyncio.run(migrate())