# llm-quant/app/utils/config.py
"""Application configuration models and helpers."""
from __future__ import annotations

import json
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional


def _default_root() -> Path:
    return Path(__file__).resolve().parents[2] / "app" / "data"


@dataclass
class DataPaths:
    """Holds filesystem locations for persistent artifacts."""

    root: Path = field(default_factory=_default_root)
    database: Path = field(init=False)
    backups: Path = field(init=False)
    config_file: Path = field(init=False)

    def __post_init__(self) -> None:
        self.root.mkdir(parents=True, exist_ok=True)
        self.database = self.root / "llm_quant.db"
        self.backups = self.root / "backups"
        self.backups.mkdir(parents=True, exist_ok=True)
        self.config_file = self.root / "config.json"
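
# Resulting on-disk layout (sketch): the directories are created eagerly in
# __post_init__; the database and config file paths are only assigned here,
# not created.
#
#     app/data/
#         llm_quant.db   # database file
#         backups/       # backup directory
#         config.json    # payload written by save_config below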


@dataclass
class AgentWeights:
    """Default weighting for decision agents."""

    momentum: float = 0.30
    value: float = 0.20
    news: float = 0.20
    liquidity: float = 0.15
    macro: float = 0.15

    def as_dict(self) -> Dict[str, float]:
        return {
            "A_mom": self.momentum,
            "A_val": self.value,
            "A_news": self.news,
            "A_liq": self.liquidity,
            "A_macro": self.macro,
        }
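
# Usage sketch: the defaults sum to 1.00 (0.30 + 0.20 + 0.20 + 0.15 + 0.15),
# and as_dict() exposes them under the agent keys ("A_mom", "A_val", ...)
# presumably consumed by the decision layer:
#
#     AgentWeights().as_dict()
#     # -> {'A_mom': 0.3, 'A_val': 0.2, 'A_news': 0.2, 'A_liq': 0.15, 'A_macro': 0.15}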


DEFAULT_LLM_MODEL_OPTIONS: Dict[str, Dict[str, object]] = {
    "ollama": {
        "models": ["llama3", "phi3", "qwen2"],
        "base_url": "http://localhost:11434",
        "temperature": 0.2,
        "timeout": 30.0,
    },
    "openai": {
        "models": ["gpt-4o-mini", "gpt-4.1-mini", "gpt-3.5-turbo"],
        "base_url": "https://api.openai.com",
        "temperature": 0.2,
        "timeout": 30.0,
    },
    "deepseek": {
        "models": ["deepseek-chat", "deepseek-coder"],
        "base_url": "https://api.deepseek.com",
        "temperature": 0.2,
        "timeout": 45.0,
    },
    "wenxin": {
        "models": ["ERNIE-Speed", "ERNIE-Bot"],
        "base_url": "https://aip.baidubce.com",
        "temperature": 0.2,
        "timeout": 60.0,
    },
}

DEFAULT_LLM_MODELS: Dict[str, str] = {
    provider: info["models"][0]
    for provider, info in DEFAULT_LLM_MODEL_OPTIONS.items()
}
DEFAULT_LLM_BASE_URLS: Dict[str, str] = {
    provider: info["base_url"]
    for provider, info in DEFAULT_LLM_MODEL_OPTIONS.items()
}
DEFAULT_LLM_TEMPERATURES: Dict[str, float] = {
    provider: float(info.get("temperature", 0.2))
    for provider, info in DEFAULT_LLM_MODEL_OPTIONS.items()
}
DEFAULT_LLM_TIMEOUTS: Dict[str, float] = {
    provider: float(info.get("timeout", 30.0))
    for provider, info in DEFAULT_LLM_MODEL_OPTIONS.items()
}
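
# Derived lookup examples (values follow from DEFAULT_LLM_MODEL_OPTIONS above):
#
#     DEFAULT_LLM_MODELS["ollama"]    # 'llama3' (first entry in "models")
#     DEFAULT_LLM_TIMEOUTS["wenxin"]  # 60.0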


@dataclass
class LLMEndpoint:
    """Single LLM endpoint configuration."""

    provider: str = "ollama"
    model: Optional[str] = None
    base_url: Optional[str] = None
    api_key: Optional[str] = None
    # ``None`` means "not set explicitly"; the provider default is applied in
    # ``__post_init__``. (Using the literal defaults 0.2/30.0 as sentinels
    # would clobber a user who deliberately chose those values.)
    temperature: Optional[float] = None
    timeout: Optional[float] = None

    def __post_init__(self) -> None:
        self.provider = (self.provider or "ollama").lower()
        if not self.model:
            self.model = DEFAULT_LLM_MODELS.get(self.provider, DEFAULT_LLM_MODELS["ollama"])
        if not self.base_url:
            self.base_url = DEFAULT_LLM_BASE_URLS.get(self.provider)
        if self.temperature is None:
            self.temperature = DEFAULT_LLM_TEMPERATURES.get(self.provider, 0.2)
        if self.timeout is None:
            self.timeout = DEFAULT_LLM_TIMEOUTS.get(self.provider, 30.0)
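
# Defaulting sketch: constructing an LLMEndpoint from a provider name alone
# fills in that provider's model, base URL, temperature, and timeout from the
# tables above:
#
#     ep = LLMEndpoint(provider="DeepSeek")
#     # -> ep.provider == 'deepseek', ep.model == 'deepseek-chat',
#     #    ep.base_url == 'https://api.deepseek.com', ep.timeout == 45.0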


@dataclass
class LLMConfig:
    """LLM configuration allowing single or ensemble strategies."""

    primary: LLMEndpoint = field(default_factory=LLMEndpoint)
    ensemble: List[LLMEndpoint] = field(default_factory=list)
    strategy: str = "single"  # Options: "single", "majority".
    majority_threshold: int = 3


@dataclass
class AppConfig:
    """User-configurable settings, persisted to ``config.json``."""

    tushare_token: Optional[str] = None
    rss_sources: Dict[str, bool] = field(default_factory=dict)
    decision_method: str = "nash"
    data_paths: DataPaths = field(default_factory=DataPaths)
    agent_weights: AgentWeights = field(default_factory=AgentWeights)
    force_refresh: bool = False
    llm: LLMConfig = field(default_factory=LLMConfig)


# Process-wide singleton, mutated in place by the loaders below.
CONFIG = AppConfig()


def _endpoint_to_dict(endpoint: LLMEndpoint) -> Dict[str, object]:
    return {
        "provider": endpoint.provider,
        "model": endpoint.model,
        "base_url": endpoint.base_url,
        "api_key": endpoint.api_key,
        "temperature": endpoint.temperature,
        "timeout": endpoint.timeout,
    }


def _dict_to_endpoint(data: Dict[str, object]) -> LLMEndpoint:
    payload = {
        key: data.get(key)
        for key in ("provider", "model", "base_url", "api_key", "temperature", "timeout")
        if data.get(key) is not None
    }
    return LLMEndpoint(**payload)
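
# Round-trip sketch: _endpoint_to_dict / _dict_to_endpoint invert each other;
# None values are dropped on the way back in so the dataclass defaults (and
# provider fallbacks) can reapply:
#
#     ep = _dict_to_endpoint(_endpoint_to_dict(LLMEndpoint(provider="openai")))
#     # -> ep.model == 'gpt-4o-mini'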


def _load_from_file(cfg: AppConfig) -> None:
    """Merge settings persisted in ``config.json`` into ``cfg`` (best effort)."""
    path = cfg.data_paths.config_file
    if not path.exists():
        return
    try:
        with path.open("r", encoding="utf-8") as fh:
            payload = json.load(fh)
    except (json.JSONDecodeError, OSError):
        # A missing or corrupt file simply leaves the defaults in place.
        return
    if not isinstance(payload, dict):
        return
    if "tushare_token" in payload:
        cfg.tushare_token = payload.get("tushare_token") or None
    if "force_refresh" in payload:
        cfg.force_refresh = bool(payload.get("force_refresh"))
    if "decision_method" in payload:
        cfg.decision_method = str(payload.get("decision_method") or cfg.decision_method)
    llm_payload = payload.get("llm")
    if isinstance(llm_payload, dict):
        primary_data = llm_payload.get("primary")
        if isinstance(primary_data, dict):
            cfg.llm.primary = _dict_to_endpoint(primary_data)
        ensemble_data = llm_payload.get("ensemble")
        if isinstance(ensemble_data, list):
            cfg.llm.ensemble = [
                _dict_to_endpoint(item)
                for item in ensemble_data
                if isinstance(item, dict)
            ]
        strategy = llm_payload.get("strategy")
        if strategy in {"single", "majority"}:
            cfg.llm.strategy = strategy
        majority = llm_payload.get("majority_threshold")
        if isinstance(majority, int) and majority > 0:
            cfg.llm.majority_threshold = majority
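
# Shape of the config.json payload this loader understands (it mirrors what
# save_config below writes); every key is optional:
#
#     {
#       "tushare_token": "...",
#       "force_refresh": false,
#       "decision_method": "nash",
#       "llm": {
#         "strategy": "single",
#         "majority_threshold": 3,
#         "primary": {"provider": "ollama", "model": "llama3", ...},
#         "ensemble": []
#       }
#     }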


def save_config(cfg: AppConfig | None = None) -> None:
    """Write ``cfg`` (defaulting to the global CONFIG) to ``config.json``."""
    cfg = cfg or CONFIG
    path = cfg.data_paths.config_file
    payload = {
        "tushare_token": cfg.tushare_token,
        "force_refresh": cfg.force_refresh,
        "decision_method": cfg.decision_method,
        "llm": {
            "strategy": cfg.llm.strategy,
            "majority_threshold": cfg.llm.majority_threshold,
            "primary": _endpoint_to_dict(cfg.llm.primary),
            "ensemble": [_endpoint_to_dict(ep) for ep in cfg.llm.ensemble],
        },
    }
    try:
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open("w", encoding="utf-8") as fh:
            json.dump(payload, fh, ensure_ascii=False, indent=2)
    except OSError:
        # Persistence is best effort; an unwritable filesystem is not fatal.
        pass


def _load_env_defaults(cfg: AppConfig) -> None:
    """Populate sensitive fields from environment variables if present."""
    token = os.getenv("TUSHARE_TOKEN")
    if token:
        cfg.tushare_token = token.strip()
    api_key = os.getenv("LLM_API_KEY")
    if api_key:
        cfg.llm.primary.api_key = api_key.strip()


# Initialise the global CONFIG at import time: persisted file values are
# applied first, then environment variables override them.
_load_from_file(CONFIG)
_load_env_defaults(CONFIG)


def get_config() -> AppConfig:
    """Return the mutable global configuration instance."""
    return CONFIG
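

if __name__ == "__main__":
    # Minimal smoke test (illustrative only): show the effective primary
    # endpoint and round-trip the configuration to disk.
    cfg = get_config()
    print(cfg.llm.primary.provider, cfg.llm.primary.model, cfg.llm.primary.base_url)
    save_config(cfg)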