# FA3-Datafetch/backend/scripts/migrate_stock_codes.py
#
# 94 lines
# 3.6 KiB
# Python

import asyncio
import sys
import os
import re
from sqlalchemy import select, update
# Add parent directory to path to import app modules
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from app.database import SessionLocal
from app.models import LLMUsageLog, Company
# Captures the text inside the first pair of parentheses,
# e.g. "300999" in "金龙鱼 (300999)" or "300999 CH" in "金龙鱼 (300999 CH)".
_BRACKETED_RE = re.compile(r'\((.*?)\)')


def _extract_symbol(stock_code):
    """Return the bare ticker symbol embedded in *stock_code*, or ``None``.

    Two input shapes are recognised:

    * ``"Name (Symbol)"`` / ``"Name (Symbol Market)"`` -- the first
      whitespace-separated token inside the first bracket pair is the symbol.
    * ``"Symbol"`` -- no brackets; the whole stripped string is the symbol.

    ``None`` is returned when no symbol can be recovered (empty or
    whitespace-only brackets such as ``"Name ()"``, or a blank string), so
    callers can skip the row instead of failing on an empty token list.
    """
    match = _BRACKETED_RE.search(stock_code)
    if match is None:
        # No brackets: assume the value is just the symbol.
        return stock_code.strip() or None
    tokens = match.group(1).split()
    return tokens[0] if tokens else None


async def migrate():
    """Rewrite ``LLMUsageLog.stock_code`` values to ``"Name (Symbol Market)"``.

    Loads every ``Company`` row into a symbol lookup, then walks every
    ``LLMUsageLog`` row: the symbol is parsed out of the existing
    ``stock_code`` and, when a matching company exists, the column is replaced
    with ``"<company_name> (<symbol> <market>)"``. Rows with an empty
    ``stock_code`` are ignored. Rows whose symbol cannot be parsed or is not in
    the companies table, and rows that are already in the target format, are
    counted as skipped/unchanged. A single commit is issued at the end, and
    only if at least one row changed.

    Progress and a final summary are printed to stdout; nothing is returned.
    """
    async with SessionLocal() as session:
        print("🚀 Starting migration of stock codes in LLMUsageLogs...")

        # 1. Fetch all companies to build a symbol -> Company lookup.
        print("📦 Fetching company metadata...")
        result = await session.execute(select(Company))
        companies = result.scalars().all()
        # NOTE: symbols may not be unique across markets. When several rows
        # share a symbol, the row returned last by the query wins, because
        # later entries overwrite earlier ones in the dict.
        company_map = {c.symbol: c for c in companies}
        print(f"✅ Loaded {len(company_map)} companies.")

        # 2. Fetch all logs.
        print("📜 Fetching chat logs...")
        result = await session.execute(select(LLMUsageLog))
        logs = result.scalars().all()
        print(f"✅ Found {len(logs)} logs. Processing...")

        updated_count = 0
        skipped_count = 0
        for log in logs:
            original_code = log.stock_code
            if not original_code:
                continue

            symbol = _extract_symbol(original_code)
            company = company_map.get(symbol) if symbol is not None else None
            if company is None:
                # Unparseable value, or symbol missing from the companies table.
                skipped_count += 1
                continue

            # Target format: "Name (Symbol Market)"
            new_code = f"{company.company_name} ({company.symbol} {company.market})"
            if new_code == original_code:
                # Already migrated; counted so the summary matches its
                # "Skipped/Unchanged" label.
                skipped_count += 1
                continue

            log.stock_code = new_code
            updated_count += 1

        # 3. Commit changes (only when something was actually modified).
        if updated_count > 0:
            print(f"💾 Committing {updated_count} updates to database...")
            await session.commit()
            print("✅ Database updated successfully.")
        else:
            print("✨ No updates needed.")
        print(f"Done. Updated: {updated_count}, Skipped/Unchanged: {skipped_count}")
if __name__ == "__main__":
    # Script entry point: build the migration coroutine and drive it to
    # completion on a fresh event loop. Skipped when the module is imported.
    migration = migrate()
    asyncio.run(migration)