184 lines
7.3 KiB
Python
184 lines
7.3 KiB
Python
import sys
|
||
import os
|
||
import subprocess
|
||
import asyncio
|
||
import json
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from app.models import Report, ReportSection, AnalysisStatus
|
||
from app.services import llm_engine
|
||
from datetime import datetime
|
||
import google.genai as genai
|
||
from google.genai import types
|
||
|
||
async def search_stock(query: str, api_key: str, model: str = "gemini-2.0-flash"):
|
||
if not api_key:
|
||
return {"error": "API Key not provided"}
|
||
|
||
client = genai.Client(api_key=api_key)
|
||
prompt = f"""
|
||
你是一个专业的股票代码查询助手。请识别公司 '{query}' 的股票市场和代码。
|
||
|
||
**重要提示**:
|
||
1. 用户输入可能是公司全称、简称、别名或股票代码
|
||
2. 请仔细匹配,优先完全匹配,避免返回不相关的公司
|
||
3. 中国公司的简称经常省略"股份有限公司"等后缀
|
||
4. 例如:"茅台" = "贵州茅台酒股份有限公司" (600519.SH)
|
||
|
||
请返回一个 JSON 数组,包含所有匹配的公司。每个对象包含以下字段:
|
||
- 'market': 'CN' (中国), 'US' (美国), 'HK' (香港), 'JP' (日本), 或 'VN' (越南) 之一
|
||
- 'symbol': 完整的股票代码 (例如 'AAPL', '600519.SH', '00700.HK', '688778.SH', '2503.T', 'SAB.HM')
|
||
- 'company_name': 公司的中文简称(如果有的话,优先使用中文;如果只有英文名,则使用英文名)
|
||
|
||
**匹配规则**:
|
||
- 如果查询词与某公司简称、全称或股票代码完全匹配,返回该公司
|
||
- 如果有多个可能的匹配,返回所有相关公司
|
||
- 如果公司在中国但用户没有指定市场,默认为 CN(上海/深圳/北京交易所)
|
||
- **越南股票后缀规则(iFinD数据源)**:
|
||
- 胡志明交易所 (HOSE) -> 后缀 **.HM** (例如: VNM.HM, SAB.HM, VCB.HM)
|
||
- 河内交易所 (HNX) -> 后缀 **.HN** (例如: PVS.HN, SHS.HN)
|
||
- UPCoM 市场 -> iFinD 通常也使用 .HN 或 .HM,或者特定的 UPCoM 后缀,但绝不要使用 .VN (这是 Bloomberg 格式)
|
||
- 示例:MCH (Masan Consumer) -> MCH.HN 或 MCH.HM (UPCoM 在 iFinD 中可能归类不一,但 MCH.HN 是常见尝试,MCH.HM 也可以尝试。请根据搜索结果确认 iFinD 使用哪一个。**强烈倾向于使用 .HM 或 .HN,严禁使用 .VN**)
|
||
- MCH 具体案例:用户指出应为 MCH.HM (或 MCH.HN),绝非 MCH.VN。请只返回 .HM 或 .HN。
|
||
- 如果不确定是 HM 还是 HN,优先返回 .HM。
|
||
- 如果完全没找到匹配,返回 {{ "error": "未找到相关公司" }}
|
||
|
||
示例响应(单个结果):
|
||
[
|
||
{{
|
||
"market": "CN",
|
||
"symbol": "600519.SH",
|
||
"company_name": "贵州茅台"
|
||
}}
|
||
]
|
||
|
||
示例响应(多个结果):
|
||
[
|
||
{{
|
||
"market": "HK",
|
||
"symbol": "00700.HK",
|
||
"company_name": "腾讯控股"
|
||
}},
|
||
{{
|
||
"market": "VN",
|
||
"symbol": "VNM.HM",
|
||
"company_name": "Vinamilk"
|
||
}}
|
||
]
|
||
|
||
现在请处理查询: '{query}'
|
||
"""
|
||
|
||
try:
|
||
# Enable Google Search for more accurate results
|
||
grounding_tool = types.Tool(google_search=types.GoogleSearch())
|
||
# Note: Cannot use response_mime_type with tools
|
||
config = types.GenerateContentConfig(tools=[grounding_tool])
|
||
|
||
response = client.models.generate_content(
|
||
model=model,
|
||
contents=prompt,
|
||
config=config
|
||
)
|
||
|
||
response_text = response.text.strip()
|
||
print(f"Search API raw response: {response_text[:500]}")
|
||
|
||
# Extract JSON from response (may be wrapped in markdown code blocks)
|
||
if "```json" in response_text:
|
||
# Extract JSON from code block
|
||
start = response_text.find("```json") + 7
|
||
end = response_text.find("```", start)
|
||
json_str = response_text[start:end].strip()
|
||
elif "```" in response_text:
|
||
# Extract from generic code block
|
||
start = response_text.find("```") + 3
|
||
end = response_text.find("```", start)
|
||
json_str = response_text[start:end].strip()
|
||
else:
|
||
json_str = response_text
|
||
|
||
result = json.loads(json_str)
|
||
|
||
# Ensure result is always an array for consistent handling
|
||
if not isinstance(result, list):
|
||
if isinstance(result, dict) and "error" in result:
|
||
return result # Return error as-is
|
||
result = [result] # Wrap single object in array
|
||
|
||
return result
|
||
except json.JSONDecodeError as e:
|
||
print(f"JSON decode error: {e}, Response text: {response_text}")
|
||
return {"error": f"无法解析搜索结果: {str(e)}"}
|
||
except Exception as e:
|
||
print(f"Search error: {e}")
|
||
return {"error": f"搜索失败: {str(e)}"}
|
||
|
||
async def run_analysis_task(report_id: int, market: str, symbol: str, api_key: str):
|
||
"""
|
||
Background task to run the full analysis pipeline.
|
||
Creates its own DB session.
|
||
"""
|
||
print(f"Starting analysis for report {report_id}: {market} {symbol}")
|
||
|
||
# Create new session
|
||
from app.database import AsyncSessionLocal
|
||
|
||
async with AsyncSessionLocal() as session:
|
||
try:
|
||
report = await session.get(Report, report_id)
|
||
if not report:
|
||
print(f"Report {report_id} not found in background task")
|
||
return
|
||
|
||
report.status = AnalysisStatus.IN_PROGRESS
|
||
await session.commit()
|
||
|
||
company_name_for_prompt = report.company_name
|
||
|
||
# 2. Run Main Data Fetching Script (run_fetcher.py)
|
||
root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))
|
||
cmd = [sys.executable, "run_fetcher.py", market, symbol]
|
||
|
||
print(f"Executing data fetch command: {cmd} in {root_dir}")
|
||
process = await asyncio.create_subprocess_exec(
|
||
*cmd,
|
||
cwd=root_dir,
|
||
stdout=asyncio.subprocess.PIPE,
|
||
stderr=asyncio.subprocess.PIPE
|
||
)
|
||
stdout, stderr = await process.communicate()
|
||
|
||
if process.returncode != 0:
|
||
error_msg = stderr.decode()
|
||
print(f"Data fetch failed: {error_msg}")
|
||
report.status = AnalysisStatus.FAILED
|
||
await session.commit()
|
||
return
|
||
|
||
print("Data fetch successful.")
|
||
|
||
# 3. Perform Analysis Logic
|
||
await llm_engine.process_analysis_steps(
|
||
report_id=report_id,
|
||
company_name=company_name_for_prompt,
|
||
symbol=symbol,
|
||
market=market,
|
||
db=session,
|
||
api_key=api_key
|
||
)
|
||
|
||
# 4. Finalize
|
||
report.status = AnalysisStatus.COMPLETED
|
||
await session.commit()
|
||
print(f"Analysis for report {report_id} completed.")
|
||
|
||
except Exception as e:
|
||
print(f"Analysis task exception: {e}")
|
||
try:
|
||
report = await session.get(Report, report_id)
|
||
if report:
|
||
report.status = AnalysisStatus.FAILED
|
||
await session.commit()
|
||
except:
|
||
pass
|