llm-quant/app/llm/templates.py
2025-10-05 16:44:28 +08:00

244 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""LLM prompt templates management with configuration driven design."""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
@dataclass
class PromptTemplate:
    """Configuration driven prompt template.

    Attributes:
        id: Identifier of the template.
        name: Human-readable name (not interpreted by this class).
        description: Human-readable description (not interpreted by this class).
        template: ``str.format``-style text. Literal braces must be doubled
            (``{{`` / ``}}``) because the text is rendered with ``str.format``.
        variables: Names that ``validate`` expects to appear as replacement
            fields in ``template``.
        max_length: Maximum length of the rendered prompt; longer output is
            truncated and terminated with ``"..."``.
        required_context: Keys that must be present in the context passed to
            ``format``. ``None`` or empty disables the check.
        validation_rules: Free-form rule strings. This class only checks that
            each entry is non-empty; it never evaluates them.
    """

    id: str
    name: str
    description: str
    template: str
    variables: List[str]
    max_length: int = 4000
    required_context: Optional[List[str]] = None
    validation_rules: Optional[List[str]] = None

    def _placeholder_names(self) -> set:
        """Return the root names of all replacement fields in ``template``.

        ``{price:.2f}``, ``{item!r}``, ``{obj.attr}`` and ``{seq[0]}`` yield
        ``price``, ``item``, ``obj`` and ``seq``. Fields nested inside a format
        spec (``{x:{width}}``) are included as well. If the template is not
        valid format syntax, the names collected up to the error are returned.
        """
        # Function-scope import keeps this block self-contained.
        import string

        names = set()
        pending = [self.template]
        try:
            while pending:
                for _, field_name, format_spec, _ in string.Formatter().parse(pending.pop()):
                    if field_name:
                        # "a.b" and "a[0]" are both looked up through keyword "a".
                        names.add(field_name.replace("[", ".").split(".", 1)[0])
                    if format_spec:
                        pending.append(format_spec)
        except ValueError:
            # Malformed format syntax: deliberately best-effort, the caller
            # still has its literal-substring check to rely on.
            return names
        return names

    def validate(self) -> List[str]:
        """Validate template configuration.

        Returns:
            A list of human-readable error messages; empty when the
            configuration is valid.
        """
        errors = []

        # Check template contains all variables. The literal "{var}" test runs
        # first so everything that was accepted before is still accepted; the
        # parsed field names additionally recognise forms such as "{var:.2f}".
        parsed = None
        for var in self.variables:
            if f"{{{var}}}" in self.template:
                continue
            if parsed is None:
                parsed = self._placeholder_names()
            if str(var) not in parsed:
                errors.append(f"Template missing variable: {var}")

        # Check required context fields
        if self.required_context:
            for field in self.required_context:
                if not field:
                    errors.append("Empty required context field")

        # Check validation rules format
        if self.validation_rules:
            for rule in self.validation_rules:
                if not rule:
                    errors.append("Empty validation rule")

        return errors

    def format(self, context: Dict[str, Any]) -> str:
        """Format template with provided context.

        Args:
            context: Mapping of placeholder names to values.

        Returns:
            The rendered prompt, at most ``max_length`` characters long when
            ``max_length`` is greater than 3 (see the truncation note below).

        Raises:
            ValueError: If a ``required_context`` key is absent from
                ``context``, or the template references a name that
                ``context`` does not provide.
        """
        # Validate required context
        if self.required_context:
            missing = [f for f in self.required_context if f not in context]
            if missing:
                raise ValueError(f"Missing required context: {', '.join(missing)}")

        # Format template
        try:
            result = self.template.format(**context)
        except KeyError as e:
            # Chain explicitly so the original KeyError is kept as __cause__.
            raise ValueError(f"Missing template variable: {e}") from e

        # Truncate if needed, preserving exact number of characters
        if len(result) > self.max_length:
            target = self.max_length - 3  # Reserve space for "..."
            if target > 0:  # Only truncate if we have space for content
                result = result[:target] + "..."
            else:
                result = "..."  # If max_length <= 3, just return "..."

        return result
class TemplateRegistry:
    """Global registry for prompt templates.

    Templates are stored in a class-level dict keyed by template id, so the
    registry is shared by every caller; all operations are classmethods.
    """

    _templates: Dict[str, PromptTemplate] = {}

    @classmethod
    def register(cls, template: PromptTemplate) -> None:
        """Register a new template.

        A template whose id is already registered is replaced.

        Raises:
            ValueError: If ``template.validate()`` reports any error.
        """
        errors = template.validate()
        if errors:
            raise ValueError(f"Invalid template {template.id}: {'; '.join(errors)}")
        cls._templates[template.id] = template

    @classmethod
    def get(cls, template_id: str) -> Optional[PromptTemplate]:
        """Get template by ID, or ``None`` when the id is not registered."""
        return cls._templates.get(template_id)

    @classmethod
    def list(cls) -> List[PromptTemplate]:
        """List all registered templates."""
        # Inside a method body ``list`` resolves to the builtin, not to this
        # classmethod: class scope is not an enclosing scope for functions.
        return list(cls._templates.values())

    @classmethod
    def load_from_json(cls, json_str: str) -> None:
        """Load templates from JSON string.

        The JSON root must be an object mapping template ids to template
        objects with the optional keys ``name``, ``description``,
        ``template``, ``variables``, ``max_length``, ``required_context`` and
        ``validation_rules``.

        Entries are registered one at a time in document order, so entries
        preceding an invalid one remain registered when an error is raised.

        Raises:
            ValueError: If the text is not valid JSON, the root or an entry is
                not a JSON object, or a template fails validation.
        """
        try:
            data = json.loads(json_str)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON: {e}") from e

        if not isinstance(data, dict):
            raise ValueError("JSON root must be an object")

        for template_id, cfg in data.items():
            # Without this guard a scalar/array entry would fail on ``cfg.get``
            # with AttributeError instead of the ValueError used elsewhere.
            if not isinstance(cfg, dict):
                raise ValueError(
                    f"Template {template_id} must be a JSON object, "
                    f"got {type(cfg).__name__}"
                )
            template = PromptTemplate(
                id=template_id,
                name=cfg.get("name", template_id),
                description=cfg.get("description", ""),
                template=cfg.get("template", ""),
                variables=cfg.get("variables", []),
                max_length=cfg.get("max_length", 4000),
                required_context=cfg.get("required_context", []),
                validation_rules=cfg.get("validation_rules", []),
            )
            cls.register(template)

    @classmethod
    def clear(cls) -> None:
        """Clear all registered templates."""
        cls._templates.clear()
# Default template definitions
# Maps template id -> configuration dict. register_default_templates() reads
# the keys "name", "description", "template", "variables", "max_length",
# "required_context" and "validation_rules" from each entry (missing keys fall
# back to defaults there) and builds one PromptTemplate per entry.
DEFAULT_TEMPLATES = {
# Generic department prompt: every descriptive field is a placeholder.
"department_base": {
"name": "部门基础模板",
"description": "通用的部门分析提示模板",
# Rendered with str.format, so the literal braces of the JSON example are
# written doubled ({{ and }}); single braces mark placeholders.
"template": """
部门名称:{title}
股票代码:{ts_code}
交易日:{trade_date}
角色说明:{description}
职责指令:{instruction}
【可用数据范围】
{data_scope}
【核心特征】
{features}
【市场背景】
{market_snapshot}
【追加数据】
{supplements}
请基于以上数据给出该部门对当前股票的操作建议。输出必须是 JSON字段如下
{{
"action": "BUY|BUY_S|BUY_M|BUY_L|SELL|HOLD",
"confidence": 0-1 之间的小数,表示信心,
"summary": "一句话概括理由",
"signals": ["详细要点", "..."],
"risks": ["风险点", "..."]
}}
如需额外数据,请调用工具 `fetch_data`,仅支持请求 `daily` 或 `daily_basic` 表。
请严格返回单个 JSON 对象,不要添加额外文本。
""",
# Every name listed here must appear as a placeholder in the template text
# (checked by PromptTemplate.validate).
"variables": [
"title", "ts_code", "trade_date", "description", "instruction",
"data_scope", "features", "market_snapshot", "supplements"
],
# Only these keys are pre-checked by PromptTemplate.format; the remaining
# placeholders (title, description, instruction, data_scope, supplements)
# must still be supplied, otherwise str.format fails on the missing name.
"required_context": [
"ts_code", "trade_date", "features", "market_snapshot"
],
# Rule strings written as Python-like expressions. PromptTemplate.validate
# only checks that each entry is non-empty; nothing in this module
# evaluates them (NOTE(review): confirm whether another module does).
"validation_rules": [
"len(features) > 0",
"len(market_snapshot) > 0"
]
},
# Momentum department prompt: name, role and instructions are fixed text, so
# it declares fewer variables than "department_base".
"momentum_dept": {
"name": "动量研究部门",
"description": "专注于动量因子分析的部门模板",
# Same str.format conventions as above: doubled braces are literal.
"template": """
部门名称:动量研究部门
股票代码:{ts_code}
交易日:{trade_date}
角色说明:专注于分析股票价格动量、成交量动量和技术指标动量
职责指令:重点关注以下方面:
1. 价格趋势强度和持续性
2. 成交量配合度
3. 技术指标背离
【可用数据范围】
{data_scope}
【动量特征】
{features}
【市场背景】
{market_snapshot}
【追加数据】
{supplements}
请基于以上数据进行动量分析并给出操作建议。输出必须是 JSON字段如下
{{
"action": "BUY|BUY_S|BUY_M|BUY_L|SELL|HOLD",
"confidence": 0-1 之间的小数,表示信心,
"summary": "一句话概括动量分析结论",
"signals": ["动量信号要点", "..."],
"risks": ["动量风险点", "..."]
}}
""",
"variables": [
"ts_code", "trade_date", "data_scope",
"features", "market_snapshot", "supplements"
],
# data_scope and supplements are placeholders but not pre-checked keys.
"required_context": [
"ts_code", "trade_date", "features", "market_snapshot"
],
# The second rule calls features.keys(), i.e. it presumes features is a
# mapping — TODO confirm against the caller that builds the context.
"validation_rules": [
"len(features) > 0",
"'momentum' in ' '.join(features.keys()).lower()"
]
}
}
def register_default_templates() -> None:
    """Register every entry of DEFAULT_TEMPLATES with TemplateRegistry.

    Each configuration dict is turned into a PromptTemplate, with defaults
    substituted for absent keys. An entry that is rejected with ValueError is
    logged as a warning and skipped, so one bad entry does not stop the rest.
    """
    for template_id, cfg in DEFAULT_TEMPLATES.items():
        # (key, fallback) pairs; rebuilt per entry so each template gets its
        # own fresh list objects for the list-valued fallbacks.
        fallbacks = (
            ("name", template_id),
            ("description", ""),
            ("template", ""),
            ("variables", []),
            ("max_length", 4000),
            ("required_context", []),
            ("validation_rules", []),
        )
        kwargs = {key: cfg.get(key, fallback) for key, fallback in fallbacks}

        try:
            TemplateRegistry.register(PromptTemplate(id=template_id, **kwargs))
        except ValueError as e:
            logging.warning(f"Failed to register template {template_id}: {e}")
# Auto-register default templates on module import.
# This is an import-time side effect: importing this module populates
# TemplateRegistry. Defaults that fail validation are logged as warnings by
# register_default_templates() rather than raised.
register_default_templates()