update
This commit is contained in:
parent
a619a24440
commit
4a0d8d4226
@ -0,0 +1,20 @@
|
||||
"""LLM module exports."""
|
||||
from .templates import PromptTemplate, TemplateRegistry, DEFAULT_TEMPLATES
|
||||
from .context import (
|
||||
Context,
|
||||
ContextConfig,
|
||||
ContextManager,
|
||||
DataAccessConfig,
|
||||
Message,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Context",
|
||||
"ContextConfig",
|
||||
"ContextManager",
|
||||
"DataAccessConfig",
|
||||
"Message",
|
||||
"PromptTemplate",
|
||||
"TemplateRegistry",
|
||||
"DEFAULT_TEMPLATES",
|
||||
]
|
||||
@ -28,13 +28,13 @@ class DataAccessConfig:
|
||||
start_ts = time.strptime(start_date, "%Y%m%d")
|
||||
if end_date:
|
||||
end_ts = time.strptime(end_date, "%Y%m%d")
|
||||
days = (time.mktime(end_ts) - time.mktime(start_ts)) / (24 * 3600)
|
||||
if days > self.max_history_days:
|
||||
errors.append(
|
||||
f"Date range exceeds max {self.max_history_days} days"
|
||||
)
|
||||
if days < 0:
|
||||
delta_days = (time.mktime(end_ts) - time.mktime(start_ts)) / (24 * 3600)
|
||||
if delta_days < 0:
|
||||
errors.append("End date before start date")
|
||||
elif delta_days > self.max_history_days:
|
||||
errors.append(
|
||||
f"Date range ({int(delta_days)} days) exceeds max {self.max_history_days} days"
|
||||
)
|
||||
except ValueError:
|
||||
errors.append("Invalid date format (expected YYYYMMDD)")
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
@ -56,10 +57,14 @@ class PromptTemplate:
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Missing template variable: {e}")
|
||||
|
||||
# Truncate if needed
|
||||
# Truncate if needed, preserving exact number of characters
|
||||
if len(result) > self.max_length:
|
||||
result = result[:self.max_length-3] + "..."
|
||||
|
||||
target = self.max_length - 3 # Reserve space for "..."
|
||||
if target > 0: # Only truncate if we have space for content
|
||||
result = result[:target] + "..."
|
||||
else:
|
||||
result = "..." # If max_length <= 3, just return "..."
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@ -116,7 +121,7 @@ class TemplateRegistry:
|
||||
cls._templates.clear()
|
||||
|
||||
|
||||
# Register default templates
|
||||
# Default template definitions
|
||||
DEFAULT_TEMPLATES = {
|
||||
"department_base": {
|
||||
"name": "部门基础模板",
|
||||
@ -214,6 +219,25 @@ DEFAULT_TEMPLATES = {
|
||||
}
|
||||
}
|
||||
|
||||
# Register default templates
|
||||
for template_id, cfg in DEFAULT_TEMPLATES.items():
|
||||
TemplateRegistry.register(PromptTemplate(**{"id": template_id, **cfg}))
|
||||
|
||||
def register_default_templates() -> None:
|
||||
"""Register all default templates from DEFAULT_TEMPLATES."""
|
||||
for template_id, cfg in DEFAULT_TEMPLATES.items():
|
||||
template_config = {
|
||||
"id": template_id,
|
||||
"name": cfg.get("name", template_id),
|
||||
"description": cfg.get("description", ""),
|
||||
"template": cfg.get("template", ""),
|
||||
"variables": cfg.get("variables", []),
|
||||
"max_length": cfg.get("max_length", 4000),
|
||||
"required_context": cfg.get("required_context", []),
|
||||
"validation_rules": cfg.get("validation_rules", [])
|
||||
}
|
||||
try:
|
||||
TemplateRegistry.register(PromptTemplate(**template_config))
|
||||
except ValueError as e:
|
||||
logging.warning(f"Failed to register template {template_id}: {e}")
|
||||
|
||||
|
||||
# Auto-register default templates on module import
|
||||
register_default_templates()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user