diff --git a/app/llm/__init__.py b/app/llm/__init__.py index e69de29..cfcce48 100644 --- a/app/llm/__init__.py +++ b/app/llm/__init__.py @@ -0,0 +1,20 @@ +"""LLM module exports.""" +from .templates import PromptTemplate, TemplateRegistry, DEFAULT_TEMPLATES +from .context import ( + Context, + ContextConfig, + ContextManager, + DataAccessConfig, + Message, +) + +__all__ = [ + "Context", + "ContextConfig", + "ContextManager", + "DataAccessConfig", + "Message", + "PromptTemplate", + "TemplateRegistry", + "DEFAULT_TEMPLATES", +] diff --git a/app/llm/context.py b/app/llm/context.py index d05ab37..c75d19a 100644 --- a/app/llm/context.py +++ b/app/llm/context.py @@ -28,13 +28,13 @@ class DataAccessConfig: start_ts = time.strptime(start_date, "%Y%m%d") if end_date: end_ts = time.strptime(end_date, "%Y%m%d") - days = (time.mktime(end_ts) - time.mktime(start_ts)) / (24 * 3600) - if days > self.max_history_days: - errors.append( - f"Date range exceeds max {self.max_history_days} days" - ) - if days < 0: + delta_days = (time.mktime(end_ts) - time.mktime(start_ts)) / (24 * 3600) + if delta_days < 0: errors.append("End date before start date") + elif delta_days > self.max_history_days: + errors.append( + f"Date range ({int(delta_days)} days) exceeds max {self.max_history_days} days" + ) except ValueError: errors.append("Invalid date format (expected YYYYMMDD)") diff --git a/app/llm/templates.py b/app/llm/templates.py index 3a7940e..b4bb536 100644 --- a/app/llm/templates.py +++ b/app/llm/templates.py @@ -2,6 +2,7 @@ from __future__ import annotations import json +import logging from dataclasses import dataclass from typing import Any, Dict, List, Optional @@ -56,10 +57,14 @@ class PromptTemplate: except KeyError as e: raise ValueError(f"Missing template variable: {e}") - # Truncate if needed + # Truncate if needed, preserving exact number of characters if len(result) > self.max_length: - result = result[:self.max_length-3] + "..." - + target = self.max_length - 3 # Reserve space for "..." + if target > 0: # Only truncate if we have space for content + result = result[:target] + "..." + else: + result = "..." # If max_length <= 3, just return "..." + return result @@ -116,7 +121,7 @@ class TemplateRegistry: cls._templates.clear() -# Register default templates +# Default template definitions DEFAULT_TEMPLATES = { "department_base": { "name": "部门基础模板", @@ -214,6 +219,25 @@ DEFAULT_TEMPLATES = { } } -# Register default templates -for template_id, cfg in DEFAULT_TEMPLATES.items(): - TemplateRegistry.register(PromptTemplate(**{"id": template_id, **cfg})) + +def register_default_templates() -> None: + """Register all default templates from DEFAULT_TEMPLATES.""" + for template_id, cfg in DEFAULT_TEMPLATES.items(): + template_config = { + "id": template_id, + "name": cfg.get("name", template_id), + "description": cfg.get("description", ""), + "template": cfg.get("template", ""), + "variables": cfg.get("variables", []), + "max_length": cfg.get("max_length", 4000), + "required_context": cfg.get("required_context", []), + "validation_rules": cfg.get("validation_rules", []) + } + try: + TemplateRegistry.register(PromptTemplate(**template_config)) + except ValueError as e: + logging.warning(f"Failed to register template {template_id}: {e}") + + +# Auto-register default templates on module import +register_default_templates()