This commit is contained in:
sam 2025-10-05 16:44:28 +08:00
parent a619a24440
commit 4a0d8d4226
3 changed files with 57 additions and 13 deletions

View File

@ -0,0 +1,20 @@
"""LLM module exports."""
from .templates import PromptTemplate, TemplateRegistry, DEFAULT_TEMPLATES
from .context import (
Context,
ContextConfig,
ContextManager,
DataAccessConfig,
Message,
)
__all__ = [
"Context",
"ContextConfig",
"ContextManager",
"DataAccessConfig",
"Message",
"PromptTemplate",
"TemplateRegistry",
"DEFAULT_TEMPLATES",
]

View File

@ -28,13 +28,13 @@ class DataAccessConfig:
start_ts = time.strptime(start_date, "%Y%m%d") start_ts = time.strptime(start_date, "%Y%m%d")
if end_date: if end_date:
end_ts = time.strptime(end_date, "%Y%m%d") end_ts = time.strptime(end_date, "%Y%m%d")
days = (time.mktime(end_ts) - time.mktime(start_ts)) / (24 * 3600) delta_days = (time.mktime(end_ts) - time.mktime(start_ts)) / (24 * 3600)
if days > self.max_history_days: if delta_days < 0:
errors.append(
f"Date range exceeds max {self.max_history_days} days"
)
if days < 0:
errors.append("End date before start date") errors.append("End date before start date")
elif delta_days > self.max_history_days:
errors.append(
f"Date range ({int(delta_days)} days) exceeds max {self.max_history_days} days"
)
except ValueError: except ValueError:
errors.append("Invalid date format (expected YYYYMMDD)") errors.append("Invalid date format (expected YYYYMMDD)")

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import json import json
import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
@ -56,9 +57,13 @@ class PromptTemplate:
except KeyError as e: except KeyError as e:
raise ValueError(f"Missing template variable: {e}") raise ValueError(f"Missing template variable: {e}")
# Truncate if needed # Truncate if needed, preserving exact number of characters
if len(result) > self.max_length: if len(result) > self.max_length:
result = result[:self.max_length-3] + "..." target = self.max_length - 3 # Reserve space for "..."
if target > 0: # Only truncate if we have space for content
result = result[:target] + "..."
else:
result = "..." # If max_length <= 3, just return "..."
return result return result
@ -116,7 +121,7 @@ class TemplateRegistry:
cls._templates.clear() cls._templates.clear()
# Register default templates # Default template definitions
DEFAULT_TEMPLATES = { DEFAULT_TEMPLATES = {
"department_base": { "department_base": {
"name": "部门基础模板", "name": "部门基础模板",
@ -214,6 +219,25 @@ DEFAULT_TEMPLATES = {
} }
} }
# Register default templates
def register_default_templates() -> None:
"""Register all default templates from DEFAULT_TEMPLATES."""
for template_id, cfg in DEFAULT_TEMPLATES.items(): for template_id, cfg in DEFAULT_TEMPLATES.items():
TemplateRegistry.register(PromptTemplate(**{"id": template_id, **cfg})) template_config = {
"id": template_id,
"name": cfg.get("name", template_id),
"description": cfg.get("description", ""),
"template": cfg.get("template", ""),
"variables": cfg.get("variables", []),
"max_length": cfg.get("max_length", 4000),
"required_context": cfg.get("required_context", []),
"validation_rules": cfg.get("validation_rules", [])
}
try:
TemplateRegistry.register(PromptTemplate(**template_config))
except ValueError as e:
logging.warning(f"Failed to register template {template_id}: {e}")
# Auto-register default templates on module import
register_default_templates()