# Migration script: normalize LLMUsageLog.stock_code values to "Name (Symbol Market)".
# (Original file listing: 94 lines, 3.6 KiB, Python.)
import asyncio
import os
import re
import sys

from sqlalchemy import select, update  # NOTE(review): `update` appears unused in this script

# The `app` package lives in the parent of this script's directory. Put that
# directory on sys.path so the imports below resolve when the file is run
# directly as a script.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(_PROJECT_ROOT)

# These must come after the sys.path tweak above.
from app.database import SessionLocal  # noqa: E402
from app.models import LLMUsageLog, Company  # noqa: E402
|
|
# Matches one bracketed group, ASCII "(...)" or full-width "（...）", without
# nested brackets; the captured text is the bracket content.
_STOCK_CODE_BRACKETS = re.compile(r"[(（]([^()（）]*)[)）]")


def _parse_stock_code(raw):
    """Split a stored stock_code string into ``(symbol, market)``.

    Accepted shapes:
      * ``"Name (Symbol)"`` / ``"Name (Symbol Market)"`` -- the content of the
        LAST bracketed group is used, so a name that itself contains brackets
        (``"Foo (Holdings) (300999)"``) still yields the trailing symbol.
      * ``"Symbol"`` -- no brackets; the whole stripped string is the symbol.

    Returns:
        ``(symbol, market)``. ``market`` is None when no market token follows
        the symbol; ``(None, None)`` when no symbol can be extracted (empty
        string, whitespace, or empty brackets such as ``"Name ()"``).
    """
    groups = _STOCK_CODE_BRACKETS.findall(raw)
    if not groups:
        symbol = raw.strip()
        return (symbol or None), None

    parts = groups[-1].split()
    if not parts:
        # "Name ()" -- previously this crashed with IndexError.
        return None, None
    market = parts[1] if len(parts) > 1 else None
    return parts[0], market


def _index_companies(companies):
    """Group companies by symbol: ``{symbol: [company, ...]}``.

    A list is kept per symbol because the same symbol may exist in several
    markets; collapsing them into one entry would silently pick one.
    """
    index = {}
    for company in companies:
        index.setdefault(company.symbol, []).append(company)
    return index


def _match_companies(index, symbol, market):
    """Return every company a parsed ``(symbol, market)`` pair could refer to.

    When the stored code already names a market, only companies whose market
    renders (case-insensitively) to that token are returned. This prevents an
    explicit market from being overwritten by a same-symbol company in another
    market.
    """
    candidates = index.get(symbol, [])
    if market is None:
        return candidates
    wanted = market.casefold()
    # Compare against f"{market}" -- the same rendering used when writing.
    return [c for c in candidates if f"{c.market}".casefold() == wanted]


def _format_stock_code(company):
    """Render the canonical ``"Name (Symbol Market)"`` form for a company."""
    return f"{company.company_name} ({company.symbol} {company.market})"


async def migrate():
    """Rewrite LLMUsageLog.stock_code values into ``"Name (Symbol Market)"``.

    Every log with a non-empty stock_code is parsed with ``_parse_stock_code``
    and matched against the Company table. A log is rewritten only when
    exactly one company matches:

    * no match -> left untouched (counted as "not found");
    * several matches (same symbol in multiple markets and the log does not
      say which) -> left untouched (counted as "ambiguous"), rather than
      guessing from unordered query results.

    All changes are committed in a single transaction at the end, and only
    if at least one row changed. Re-running is safe: rows already in
    canonical form are counted as "unchanged".
    """
    async with SessionLocal() as session:
        print("🚀 Starting migration of stock codes in LLMUsageLogs...")

        # 1. Build a symbol -> [Company] lookup.
        print("📦 Fetching company metadata...")
        result = await session.execute(select(Company))
        companies = result.scalars().all()
        company_index = _index_companies(companies)
        multi_market = sum(1 for group in company_index.values() if len(group) > 1)
        print(f"✅ Loaded {len(companies)} companies ({len(company_index)} distinct symbols).")
        if multi_market:
            print(
                f"⚠️ {multi_market} symbols appear more than once; logs that do not "
                "name a market for them will be skipped as ambiguous."
            )

        # 2. Fetch only logs that actually carry a stock code.
        print("📜 Fetching chat logs...")
        result = await session.execute(
            select(LLMUsageLog).where(
                LLMUsageLog.stock_code.is_not(None),
                LLMUsageLog.stock_code != "",
            )
        )
        logs = result.scalars().all()
        print(f"✅ Found {len(logs)} logs. Processing...")

        updated_count = 0
        unchanged_count = 0
        not_found_count = 0
        ambiguous_count = 0

        for log in logs:
            original_code = log.stock_code
            if not original_code:
                continue  # defensive; the query already excludes these

            symbol, market = _parse_stock_code(original_code)
            matches = _match_companies(company_index, symbol, market) if symbol else []

            if not matches:
                not_found_count += 1
                continue
            if len(matches) > 1:
                ambiguous_count += 1
                continue

            new_code = _format_stock_code(matches[0])
            if new_code == original_code:
                unchanged_count += 1
                continue

            log.stock_code = new_code
            updated_count += 1

        # 3. Commit changes (one transaction for the whole run).
        if updated_count > 0:
            print(f"💾 Committing {updated_count} updates to database...")
            await session.commit()
            print("✅ Database updated successfully.")
        else:
            print("✨ No updates needed.")

        print(
            f"Done. Updated: {updated_count}, Unchanged: {unchanged_count}, "
            f"Skipped (not found): {not_found_count}, "
            f"Skipped (ambiguous): {ambiguous_count}"
        )
|
|
def _main():
    """Script entry point: run the async `migrate()` coroutine to completion."""
    asyncio.run(migrate())


if __name__ == "__main__":
    _main()