FA3-Datafetch/backend/app/services/analysis_service.py
2026-01-03 18:27:19 +08:00

177 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import sys
import os
import subprocess
import asyncio
import json
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import Report, ReportSection, AnalysisStatus
from app.services import llm_engine
from datetime import datetime
import google.genai as genai
from google.genai import types
def _extract_json_text(text: str) -> str:
    """Strip an optional Markdown code fence from a model reply.

    Prefers a ```json fence, then any ``` fence; with no fence the text is
    returned stripped. An unterminated fence yields everything after it.
    """
    marker = "```json" if "```json" in text else "```"
    fence = text.find(marker)
    if fence == -1:
        return text.strip()
    start = fence + len(marker)
    end = text.find("```", start)
    # find() returns -1 when the closing fence is missing (e.g. a truncated
    # reply); slicing to -1 would silently drop the payload's last character.
    body = text[start:] if end == -1 else text[start:end]
    return body.strip()


def _parse_json_payload(text: str):
    """Parse the JSON payload embedded in a model reply.

    The reply is free text (JSON mode cannot be combined with tools), so if
    the fenced/raw candidate does not parse, retry once on the span from the
    first '[' or '{' to the last closer of the same kind. If that also fails,
    the ORIGINAL json.JSONDecodeError is raised.
    """
    candidate = _extract_json_text(text)
    try:
        return json.loads(candidate)
    except json.JSONDecodeError as first_error:
        opens = [i for i in (candidate.find("["), candidate.find("{")) if i != -1]
        if opens:
            begin = min(opens)
            closer = "]" if candidate[begin] == "[" else "}"
            finish = candidate.rfind(closer)
            if finish > begin:
                try:
                    return json.loads(candidate[begin:finish + 1])
                except json.JSONDecodeError:
                    pass
        raise first_error


async def search_stock(query: str, api_key: str, model: str = "gemini-2.0-flash"):
    """
    Resolve a company name, alias or ticker to stock listings using Gemini
    with Google Search grounding.

    Args:
        query: User input (full/short company name, alias, or stock code).
        api_key: Gemini API key. A falsy value short-circuits with an error.
        model: Gemini model name passed to ``generate_content``.

    Returns:
        On success, a list parsed from the model's JSON reply; a single
        non-error object is wrapped in a list. The prompt asks for items with
        'market', 'symbol' and 'company_name', but items are NOT validated here.
        On failure, a dict with an "error" key (including the model's own
        {"error": ...} reply, which is passed through unchanged). API and
        parse errors are returned as error dicts rather than raised.
    """
    if not api_key:
        return {"error": "API Key not provided"}
    prompt = f"""
你是一个专业的股票代码查询助手。请识别公司 '{query}' 的股票市场和代码。
**重要提示**
1. 用户输入可能是公司全称、简称、别名或股票代码
2. 请仔细匹配,优先完全匹配,避免返回不相关的公司
3. 中国公司的简称经常省略"股份有限公司"等后缀
4. 例如:"茅台" = "贵州茅台酒股份有限公司" (600519.SH)
请返回一个 JSON 数组,包含所有匹配的公司。每个对象包含以下字段:
- 'market': 'CN' (中国), 'US' (美国), 'HK' (香港), 或 'JP' (日本) 之一
- 'symbol': 完整的股票代码 (例如 'AAPL', '600519.SH', '00700.HK', '688778.SH', '2503.T')
- 'company_name': 公司的中文简称(如果有的话,优先使用中文;如果只有英文名,则使用英文名)
**匹配规则**
- 如果查询词与某公司简称、全称或股票代码完全匹配,返回该公司
- 如果有多个可能的匹配,返回所有相关公司
- 如果公司在中国但用户没有指定市场,默认为 CN上海/深圳/北京交易所)
- 如果完全没找到匹配,返回 {{ "error": "未找到相关公司" }}
示例响应(单个结果):
[
{{
"market": "CN",
"symbol": "600519.SH",
"company_name": "贵州茅台"
}}
]
示例响应(多个结果):
[
{{
"market": "HK",
"symbol": "00700.HK",
"company_name": "腾讯控股"
}},
{{
"market": "US",
"symbol": "TCEHY",
"company_name": "Tencent Holdings ADR"
}}
]
现在请处理查询: '{query}'
"""
    # Bound before the try so the JSONDecodeError handler can always print it.
    response_text = ""
    try:
        # Created inside the try so client construction errors also become an
        # error dict, consistent with every other failure path.
        client = genai.Client(api_key=api_key)
        # Enable Google Search for more accurate results
        grounding_tool = types.Tool(google_search=types.GoogleSearch())
        # Note: Cannot use response_mime_type with tools, so the reply is free text.
        config = types.GenerateContentConfig(tools=[grounding_tool])
        # generate_content is a blocking call; run it in a worker thread so it
        # does not stall the event loop for the whole request.
        response = await asyncio.to_thread(
            client.models.generate_content,
            model=model,
            contents=prompt,
            config=config,
        )
        # response.text may be None (no text parts); treat it as empty so it is
        # reported as a parse failure instead of an AttributeError.
        response_text = (response.text or "").strip()
        print(f"Search API raw response: {response_text[:500]}")
        result = _parse_json_payload(response_text)
        # Ensure result is always an array for consistent handling
        if not isinstance(result, list):
            if isinstance(result, dict) and "error" in result:
                return result  # Return error as-is
            result = [result]  # Wrap single object in array
        return result
    except json.JSONDecodeError as e:
        print(f"JSON decode error: {e}, Response text: {response_text}")
        return {"error": f"无法解析搜索结果: {str(e)}"}
    except Exception as e:
        print(f"Search error: {e}")
        return {"error": f"搜索失败: {str(e)}"}
async def _run_data_fetch(market: str, symbol: str):
    """Run ``main.py <market> <symbol>`` from the project root.

    The project root is resolved as three directories above this file.
    Returns ``(returncode, stderr_text)``; stdout is captured and discarded.
    """
    root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))
    cmd = [sys.executable, "main.py", market, symbol]
    print(f"Executing data fetch command: {cmd} in {root_dir}")
    process = await asyncio.create_subprocess_exec(
        *cmd,
        cwd=root_dir,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    _stdout, stderr = await process.communicate()
    # The script's output is not guaranteed to be valid UTF-8; a strict decode
    # would raise here and mask the real failure reason.
    return process.returncode, stderr.decode(errors="replace")


async def _mark_report_failed(session: "AsyncSession", report_id: int) -> None:
    """Best-effort: set the report's status to FAILED after an exception.

    Rolls the session back first. After a failed flush/commit the session
    refuses further work until rolled back, so without this the status update
    itself would fail. This also discards any uncommitted changes from the
    failed run. Errors here are logged, never raised.
    """
    try:
        await session.rollback()
        report = await session.get(Report, report_id)
        if report:
            report.status = AnalysisStatus.FAILED
            await session.commit()
    except Exception as e:
        print(f"Failed to mark report {report_id} as FAILED: {e}")


async def run_analysis_task(report_id: int, market: str, symbol: str, api_key: str):
    """
    Background task to run the full analysis pipeline.
    Creates its own DB session.

    Steps:
      1. Load the report and mark it IN_PROGRESS.
      2. Run the data-fetching script ``main.py`` with ``market`` and ``symbol``.
      3. Run ``llm_engine.process_analysis_steps`` on the same session.
      4. Mark the report COMPLETED.

    If the report does not exist, nothing happens. A non-zero exit from
    ``main.py``, or any Exception raised inside the pipeline, marks the report
    FAILED. Such exceptions are logged via print and not re-raised.
    """
    print(f"Starting analysis for report {report_id}: {market} {symbol}")
    # Local import, presumably to avoid an import cycle at module load — verify.
    from app.database import AsyncSessionLocal

    async with AsyncSessionLocal() as session:
        try:
            report = await session.get(Report, report_id)
            if not report:
                print(f"Report {report_id} not found in background task")
                return

            # Read attributes BEFORE committing. If the session factory uses
            # expire_on_commit=True, reading them afterwards would trigger an
            # implicit refresh, which AsyncSession cannot perform lazily.
            company_name_for_prompt = report.company_name

            # 1. Mark as started.
            report.status = AnalysisStatus.IN_PROGRESS
            await session.commit()

            # 2. Run Main Data Fetching Script (main.py)
            returncode, error_msg = await _run_data_fetch(market, symbol)
            if returncode != 0:
                print(f"Data fetch failed: {error_msg}")
                report.status = AnalysisStatus.FAILED
                await session.commit()
                return
            print("Data fetch successful.")

            # 3. Perform Analysis Logic
            await llm_engine.process_analysis_steps(
                report_id=report_id,
                company_name=company_name_for_prompt,
                symbol=symbol,
                market=market,
                db=session,
                api_key=api_key,
            )

            # 4. Finalize
            report.status = AnalysisStatus.COMPLETED
            await session.commit()
            print(f"Analysis for report {report_id} completed.")
        except Exception as e:
            print(f"Analysis task exception: {e}")
            await _mark_report_failed(session, report_id)