"""Application configuration models and helpers.""" from __future__ import annotations from dataclasses import dataclass, field import json import os from pathlib import Path from typing import Dict, Iterable, List, Mapping, Optional def _default_root() -> Path: return Path(__file__).resolve().parents[2] / "app" / "data" @dataclass class DataPaths: """Holds filesystem locations for persistent artifacts.""" root: Path = field(default_factory=_default_root) database: Path = field(init=False) backups: Path = field(init=False) config_file: Path = field(init=False) def __post_init__(self) -> None: self.root.mkdir(parents=True, exist_ok=True) self.database = self.root / "llm_quant.db" self.backups = self.root / "backups" self.backups.mkdir(parents=True, exist_ok=True) self.config_file = self.root / "config.json" @dataclass class AgentWeights: """Default weighting for decision agents.""" momentum: float = 0.30 value: float = 0.20 news: float = 0.20 liquidity: float = 0.15 macro: float = 0.15 def as_dict(self) -> Dict[str, float]: return { "A_mom": self.momentum, "A_val": self.value, "A_news": self.news, "A_liq": self.liquidity, "A_macro": self.macro, } def update_from_dict(self, data: Mapping[str, float]) -> None: mapping = { "A_mom": "momentum", "momentum": "momentum", "A_val": "value", "value": "value", "A_news": "news", "news": "news", "A_liq": "liquidity", "liquidity": "liquidity", "A_macro": "macro", "macro": "macro", } for key, attr in mapping.items(): if key in data and data[key] is not None: try: setattr(self, attr, float(data[key])) except (TypeError, ValueError): continue @classmethod def from_dict(cls, data: Mapping[str, float]) -> "AgentWeights": inst = cls() inst.update_from_dict(data) return inst DEFAULT_LLM_MODEL_OPTIONS: Dict[str, Dict[str, object]] = { "ollama": { "models": ["llama3", "phi3", "qwen2"], "base_url": "http://localhost:11434", "temperature": 0.2, "timeout": 30.0, }, "openai": { "models": ["gpt-4o-mini", "gpt-4.1-mini", "gpt-3.5-turbo"], "base_url": "https://api.openai.com", "temperature": 0.2, "timeout": 30.0, }, "deepseek": { "models": ["deepseek-chat", "deepseek-coder"], "base_url": "https://api.deepseek.com", "temperature": 0.2, "timeout": 45.0, }, "wenxin": { "models": ["ERNIE-Speed", "ERNIE-Bot"], "base_url": "https://aip.baidubce.com", "temperature": 0.2, "timeout": 60.0, }, } DEFAULT_LLM_MODELS: Dict[str, str] = { provider: info["models"][0] for provider, info in DEFAULT_LLM_MODEL_OPTIONS.items() } DEFAULT_LLM_BASE_URLS: Dict[str, str] = { provider: info["base_url"] for provider, info in DEFAULT_LLM_MODEL_OPTIONS.items() } DEFAULT_LLM_TEMPERATURES: Dict[str, float] = { provider: float(info.get("temperature", 0.2)) for provider, info in DEFAULT_LLM_MODEL_OPTIONS.items() } DEFAULT_LLM_TIMEOUTS: Dict[str, float] = { provider: float(info.get("timeout", 30.0)) for provider, info in DEFAULT_LLM_MODEL_OPTIONS.items() } ALLOWED_LLM_STRATEGIES = {"single", "majority", "leader"} LLM_STRATEGY_ALIASES = {"leader-follower": "leader"} @dataclass class LLMProvider: """Provider level configuration shared across profiles and routes.""" key: str title: str = "" base_url: str = "" api_key: Optional[str] = None models: List[str] = field(default_factory=list) default_model: Optional[str] = None default_temperature: float = 0.2 default_timeout: float = 30.0 prompt_template: str = "" enabled: bool = True mode: str = "openai" # openai 或 ollama def to_dict(self) -> Dict[str, object]: return { "title": self.title, "base_url": self.base_url, "api_key": self.api_key, "models": list(self.models), "default_model": self.default_model, "default_temperature": self.default_temperature, "default_timeout": self.default_timeout, "prompt_template": self.prompt_template, "enabled": self.enabled, "mode": self.mode, } @dataclass class LLMEndpoint: """Resolved endpoint payload used for actual LLM calls.""" provider: str = "ollama" model: Optional[str] = None base_url: Optional[str] = None api_key: Optional[str] = None temperature: Optional[float] = None timeout: Optional[float] = None prompt_template: Optional[str] = None def __post_init__(self) -> None: self.provider = (self.provider or "ollama").lower() if self.temperature is not None: self.temperature = float(self.temperature) if self.timeout is not None: self.timeout = float(self.timeout) @dataclass class LLMConfig: """LLM configuration allowing single or ensemble strategies.""" primary: LLMEndpoint = field(default_factory=LLMEndpoint) ensemble: List[LLMEndpoint] = field(default_factory=list) strategy: str = "single" # Options: single, majority, leader majority_threshold: int = 3 def _default_llm_providers() -> Dict[str, LLMProvider]: providers: Dict[str, LLMProvider] = {} for provider, meta in DEFAULT_LLM_MODEL_OPTIONS.items(): models = list(meta.get("models", [])) mode = "ollama" if provider == "ollama" else "openai" providers[provider] = LLMProvider( key=provider, title=f"默认 {provider}", base_url=str(meta.get("base_url", DEFAULT_LLM_BASE_URLS.get(provider, "")) or ""), models=models, default_model=models[0] if models else DEFAULT_LLM_MODELS.get(provider), default_temperature=float(meta.get("temperature", DEFAULT_LLM_TEMPERATURES.get(provider, 0.2))), default_timeout=float(meta.get("timeout", DEFAULT_LLM_TIMEOUTS.get(provider, 30.0))), mode=mode, ) return providers @dataclass class DepartmentSettings: """Configuration for a single decision department.""" code: str title: str description: str = "" weight: float = 1.0 data_scope: List[str] = field(default_factory=list) prompt: str = "" llm: LLMConfig = field(default_factory=LLMConfig) def _default_departments() -> Dict[str, DepartmentSettings]: presets = [ { "code": "momentum", "title": "动量策略部门", "description": "跟踪价格动量与量价共振,评估短线趋势延续的概率。", "data_scope": [ "daily.close", "daily.open", "daily_basic.turnover_rate", "factors.mom_20", "factors.mom_60", ], "prompt": "你主导动量风格研究,关注价格与成交量的加速变化,需在保持纪律的前提下判定短期多空倾向。", }, { "code": "value", "title": "价值评估部门", "description": "衡量估值水平与盈利质量,为中期配置提供性价比判断。", "data_scope": [ "daily_basic.pe", "daily_basic.pb", "daily_basic.roe", "fundamental.growth", ], "prompt": "你负责价值与质量评估,应结合估值分位、盈利持续性及安全边际给出配置建议。", }, { "code": "news", "title": "新闻情绪部门", "description": "监控舆情热度与事件影响,识别情绪驱动的短期风险与机会。", "data_scope": [ "news.sentiment_index", "news.heat_score", "events.latest_headlines", ], "prompt": "你专注新闻和事件驱动,应评估正负面舆情对标的短线波动的可能影响。", }, { "code": "liquidity", "title": "流动性评估部门", "description": "衡量成交活跃度与交易成本,控制进出场的实现可能性。", "data_scope": [ "daily_basic.volume_ratio", "daily_basic.turnover_rate_f", "market.spread_estimate", ], "prompt": "你负责评估该标的的流动性与滑点风险,需要提出可执行的仓位调整建议。", }, { "code": "macro", "title": "宏观研究部门", "description": "追踪宏观与行业景气度,为行业配置和风险偏好提供参考。", "data_scope": [ "macro.industry_heat", "macro.liquidity_cycle", "index.performance_peers", ], "prompt": "你负责宏观与行业研判,应结合宏观周期、行业景气与相对强弱给出方向性意见。", }, { "code": "risk", "title": "风险控制部门", "description": "监控极端风险、合规与交易限制,必要时行使否决。", "data_scope": [ "market.limit_flags", "portfolio.position", "risk.alerts", ], "prompt": "你负责风险控制,应识别停牌、涨跌停、持仓约束等因素,必要时提出减仓或观望建议。", }, ] return { item["code"]: DepartmentSettings( code=item["code"], title=item["title"], description=item.get("description", ""), data_scope=list(item.get("data_scope", [])), prompt=item.get("prompt", ""), ) for item in presets } def _normalize_data_scope(raw: object) -> List[str]: if isinstance(raw, str): tokens = raw.replace(";", "\n").replace(",", "\n").splitlines() return [token.strip() for token in tokens if token.strip()] if isinstance(raw, Iterable) and not isinstance(raw, (bytes, bytearray, str)): return [str(item).strip() for item in raw if str(item).strip()] return [] @dataclass class AppConfig: """User configurable settings persisted in a simple structure.""" tushare_token: Optional[str] = None rss_sources: Dict[str, bool] = field(default_factory=dict) decision_method: str = "nash" data_paths: DataPaths = field(default_factory=DataPaths) agent_weights: AgentWeights = field(default_factory=AgentWeights) force_refresh: bool = False llm_providers: Dict[str, LLMProvider] = field(default_factory=_default_llm_providers) llm: LLMConfig = field(default_factory=LLMConfig) departments: Dict[str, DepartmentSettings] = field(default_factory=_default_departments) def resolve_llm(self, route: Optional[str] = None) -> LLMConfig: return self.llm def sync_runtime_llm(self) -> None: self.llm = self.resolve_llm() CONFIG = AppConfig() def _endpoint_to_dict(endpoint: LLMEndpoint) -> Dict[str, object]: return { "provider": endpoint.provider, "model": endpoint.model, "base_url": endpoint.base_url, "api_key": endpoint.api_key, "temperature": endpoint.temperature, "timeout": endpoint.timeout, "prompt_template": endpoint.prompt_template, } def _dict_to_endpoint(data: Dict[str, object]) -> LLMEndpoint: payload = { key: data.get(key) for key in ( "provider", "model", "base_url", "api_key", "temperature", "timeout", "prompt_template", ) if data.get(key) is not None } return LLMEndpoint(**payload) def _load_from_file(cfg: AppConfig) -> None: path = cfg.data_paths.config_file if not path.exists(): return try: with path.open("r", encoding="utf-8") as fh: payload = json.load(fh) except (json.JSONDecodeError, OSError): return if not isinstance(payload, dict): return if "tushare_token" in payload: cfg.tushare_token = payload.get("tushare_token") or None if "force_refresh" in payload: cfg.force_refresh = bool(payload.get("force_refresh")) if "decision_method" in payload: cfg.decision_method = str(payload.get("decision_method") or cfg.decision_method) weights_payload = payload.get("agent_weights") if isinstance(weights_payload, dict): cfg.agent_weights.update_from_dict(weights_payload) legacy_profiles: Dict[str, Dict[str, object]] = {} legacy_routes: Dict[str, Dict[str, object]] = {} providers_payload = payload.get("llm_providers") if isinstance(providers_payload, dict): providers: Dict[str, LLMProvider] = {} for key, data in providers_payload.items(): if not isinstance(data, dict): continue models_raw = data.get("models") if isinstance(models_raw, str): models = [item.strip() for item in models_raw.split(',') if item.strip()] elif isinstance(models_raw, list): models = [str(item).strip() for item in models_raw if str(item).strip()] else: models = [] provider = LLMProvider( key=str(key).lower(), title=str(data.get("title") or ""), base_url=str(data.get("base_url") or ""), api_key=data.get("api_key"), models=models, default_model=data.get("default_model") or (models[0] if models else None), default_temperature=float(data.get("default_temperature", 0.2)), default_timeout=float(data.get("default_timeout", 30.0)), prompt_template=str(data.get("prompt_template") or ""), enabled=bool(data.get("enabled", True)), mode=str(data.get("mode") or ("ollama" if str(key).lower() == "ollama" else "openai")), ) providers[provider.key] = provider if providers: cfg.llm_providers = providers profiles_payload = payload.get("llm_profiles") if isinstance(profiles_payload, dict): for key, data in profiles_payload.items(): if isinstance(data, dict): legacy_profiles[str(key)] = data routes_payload = payload.get("llm_routes") if isinstance(routes_payload, dict): for name, data in routes_payload.items(): if isinstance(data, dict): legacy_routes[str(name)] = data def _endpoint_from_payload(item: object) -> LLMEndpoint: if isinstance(item, dict): return _dict_to_endpoint(item) if isinstance(item, str): profile_data = legacy_profiles.get(item) if isinstance(profile_data, dict): return _dict_to_endpoint(profile_data) return LLMEndpoint(provider=item) return LLMEndpoint() def _resolve_route(route_name: str) -> Optional[LLMConfig]: route_data = legacy_routes.get(route_name) if not route_data: return None strategy_raw = str(route_data.get("strategy") or "single").lower() strategy = LLM_STRATEGY_ALIASES.get(strategy_raw, strategy_raw) primary_ref = route_data.get("primary") primary_ep = _endpoint_from_payload(primary_ref) ensemble_refs = route_data.get("ensemble", []) ensemble_eps = [ _endpoint_from_payload(ref) for ref in ensemble_refs if isinstance(ref, (dict, str)) ] cfg_obj = LLMConfig( primary=primary_ep, ensemble=ensemble_eps, strategy=strategy if strategy in ALLOWED_LLM_STRATEGIES else "single", majority_threshold=max(1, int(route_data.get("majority_threshold", 3) or 3)), ) return cfg_obj llm_payload = payload.get("llm") if isinstance(llm_payload, dict): route_value = llm_payload.get("route") resolved_cfg = None if isinstance(route_value, str) and route_value: resolved_cfg = _resolve_route(route_value) if resolved_cfg is None: resolved_cfg = LLMConfig() primary_data = llm_payload.get("primary") if isinstance(primary_data, dict): resolved_cfg.primary = _dict_to_endpoint(primary_data) ensemble_data = llm_payload.get("ensemble") if isinstance(ensemble_data, list): resolved_cfg.ensemble = [ _dict_to_endpoint(item) for item in ensemble_data if isinstance(item, dict) ] strategy_raw = llm_payload.get("strategy") if isinstance(strategy_raw, str): normalized = LLM_STRATEGY_ALIASES.get(strategy_raw, strategy_raw) if normalized in ALLOWED_LLM_STRATEGIES: resolved_cfg.strategy = normalized majority = llm_payload.get("majority_threshold") if isinstance(majority, int) and majority > 0: resolved_cfg.majority_threshold = majority cfg.llm = resolved_cfg departments_payload = payload.get("departments") if isinstance(departments_payload, dict): new_departments: Dict[str, DepartmentSettings] = {} for code, data in departments_payload.items(): if not isinstance(data, dict): continue title = data.get("title") or code description = data.get("description") or "" weight = float(data.get("weight", 1.0)) prompt_text = str(data.get("prompt") or "") data_scope = _normalize_data_scope(data.get("data_scope")) llm_cfg = LLMConfig() route_name = data.get("llm_route") resolved_cfg = None if isinstance(route_name, str) and route_name: resolved_cfg = _resolve_route(route_name) if resolved_cfg is None: llm_data = data.get("llm") if isinstance(llm_data, dict): primary_data = llm_data.get("primary") if isinstance(primary_data, dict): llm_cfg.primary = _dict_to_endpoint(primary_data) ensemble_data = llm_data.get("ensemble") if isinstance(ensemble_data, list): llm_cfg.ensemble = [ _dict_to_endpoint(item) for item in ensemble_data if isinstance(item, dict) ] strategy_raw = llm_data.get("strategy") if isinstance(strategy_raw, str): normalized = LLM_STRATEGY_ALIASES.get(strategy_raw, strategy_raw) if normalized in ALLOWED_LLM_STRATEGIES: llm_cfg.strategy = normalized majority_raw = llm_data.get("majority_threshold") if isinstance(majority_raw, int) and majority_raw > 0: llm_cfg.majority_threshold = majority_raw resolved_cfg = llm_cfg new_departments[code] = DepartmentSettings( code=code, title=title, description=description, weight=weight, data_scope=data_scope, prompt=prompt_text, llm=resolved_cfg, ) if new_departments: cfg.departments = new_departments cfg.sync_runtime_llm() def save_config(cfg: AppConfig | None = None) -> None: cfg = cfg or CONFIG cfg.sync_runtime_llm() path = cfg.data_paths.config_file payload = { "tushare_token": cfg.tushare_token, "force_refresh": cfg.force_refresh, "decision_method": cfg.decision_method, "agent_weights": cfg.agent_weights.as_dict(), "llm": { "strategy": cfg.llm.strategy if cfg.llm.strategy in ALLOWED_LLM_STRATEGIES else "single", "majority_threshold": cfg.llm.majority_threshold, "primary": _endpoint_to_dict(cfg.llm.primary), "ensemble": [_endpoint_to_dict(ep) for ep in cfg.llm.ensemble], }, "llm_providers": { key: provider.to_dict() for key, provider in cfg.llm_providers.items() }, "departments": { code: { "title": dept.title, "description": dept.description, "weight": dept.weight, "data_scope": list(dept.data_scope), "prompt": dept.prompt, "llm": { "strategy": dept.llm.strategy if dept.llm.strategy in ALLOWED_LLM_STRATEGIES else "single", "majority_threshold": dept.llm.majority_threshold, "primary": _endpoint_to_dict(dept.llm.primary), "ensemble": [_endpoint_to_dict(ep) for ep in dept.llm.ensemble], }, } for code, dept in cfg.departments.items() }, } try: path.parent.mkdir(parents=True, exist_ok=True) with path.open("w", encoding="utf-8") as fh: json.dump(payload, fh, ensure_ascii=False, indent=2) except OSError: pass def _load_env_defaults(cfg: AppConfig) -> None: """Populate sensitive fields from environment variables if present.""" token = os.getenv("TUSHARE_TOKEN") if token: cfg.tushare_token = token.strip() api_key = os.getenv("LLM_API_KEY") if api_key: sanitized = api_key.strip() cfg.llm.primary.api_key = sanitized provider_cfg = cfg.llm_providers.get(cfg.llm.primary.provider) if provider_cfg: provider_cfg.api_key = sanitized cfg.sync_runtime_llm() _load_from_file(CONFIG) _load_env_defaults(CONFIG) def get_config() -> AppConfig: """Return a mutable global configuration instance.""" return CONFIG