improve prompt template handling with safer variable substitution and missing var fallback

This commit is contained in:
Your Name 2025-10-11 21:03:48 +08:00
parent 3563220385
commit c57fb7edd1
2 changed files with 19 additions and 5 deletions

View File

@ -94,6 +94,15 @@ def department_prompt(
"supplements": supplements.strip() or "- 当前无追加数据", "supplements": supplements.strip() or "- 当前无追加数据",
"action": "" # 添加 action 变量以避免模板格式化错误 "action": "" # 添加 action 变量以避免模板格式化错误
} }
template_vars.setdefault("scratchpad", "")
# Ensure all declared template variables exist to avoid KeyError
try:
declared_vars = list(getattr(template, "variables", []) or [])
except Exception: # noqa: BLE001
declared_vars = []
for var in declared_vars:
template_vars.setdefault(var, "")
# Get template and format prompt # Get template and format prompt
return template.format(template_vars) return template.format(template_vars)

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import json import json
import logging import logging
import re
from pathlib import Path from pathlib import Path
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional, TYPE_CHECKING from typing import Any, Dict, List, Optional, TYPE_CHECKING
@ -55,11 +56,15 @@ class PromptTemplate:
if missing: if missing:
raise ValueError(f"Missing required context: {', '.join(missing)}") raise ValueError(f"Missing required context: {', '.join(missing)}")
# Format template pattern = re.compile(r"\{([^{}]+)\}")
try:
result = self.template.format(**context) def _replace(match: re.Match[str]) -> str:
except KeyError as e: token = match.group(1)
raise ValueError(f"Missing template variable: {e}") if token in context:
return str(context[token])
return match.group(0)
result = pattern.sub(_replace, self.template)
# Truncate if needed, preserving exact number of characters # Truncate if needed, preserving exact number of characters
if self.max_length > 0 and len(result) > self.max_length: if self.max_length > 0 and len(result) > self.max_length: