diff --git a/app/ui/streamlit_app.py b/app/ui/streamlit_app.py index 8ada8ad..874a133 100644 --- a/app/ui/streamlit_app.py +++ b/app/ui/streamlit_app.py @@ -914,9 +914,14 @@ def render_backtest() -> None: "保存这些权重为默认配置", key="save_decision_env_weights_single", ): - cfg.agent_weights.update_from_dict(weights_dict) - save_config(cfg) - st.success("代理权重已写入 config.json") + try: + cfg.agent_weights.update_from_dict(weights_dict) + save_config(cfg) + except Exception as exc: # noqa: BLE001 + LOGGER.exception("保存权重失败", extra={**LOG_EXTRA, "error": str(exc)}) + st.error(f"写入配置失败:{exc}") + else: + st.success("代理权重已写入 config.json") nav_series = info.get("nav_series") if nav_series: @@ -1085,11 +1090,16 @@ def render_backtest() -> None: "保存所选权重为默认配置", key="save_decision_env_weights_batch", ): - cfg.agent_weights.update_from_dict(selected_row.get("权重", {})) - save_config(cfg) - st.success( - f"已将序号 {selected_row['序号']} 的权重写入 config.json" - ) + try: + cfg.agent_weights.update_from_dict(selected_row.get("权重", {})) + save_config(cfg) + except Exception as exc: # noqa: BLE001 + LOGGER.exception("批量保存权重失败", extra={**LOG_EXTRA, "error": str(exc)}) + st.error(f"写入配置失败:{exc}") + else: + st.success( + f"已将序号 {selected_row['序号']} 的权重写入 config.json" + ) else: st.caption("暂无成功的结果可供保存。") diff --git a/app/utils/config.py b/app/utils/config.py index 653cc14..8204eff 100644 --- a/app/utils/config.py +++ b/app/utils/config.py @@ -3,11 +3,15 @@ from __future__ import annotations from dataclasses import dataclass, field import json +import logging import os from pathlib import Path from typing import Dict, Iterable, List, Mapping, Optional +LOGGER = logging.getLogger(__name__) + + def _default_root() -> Path: return Path(__file__).resolve().parents[2] / "app" / "data" @@ -26,7 +30,13 @@ class DataPaths: self.database = self.root / "llm_quant.db" self.backups = self.root / "backups" self.backups.mkdir(parents=True, exist_ok=True) - self.config_file = self.root / "config.json" + config_override = os.getenv("LLM_QUANT_CONFIG_PATH") + if config_override: + config_path = Path(config_override).expanduser() + config_path.parent.mkdir(parents=True, exist_ok=True) + else: + config_path = self.root / "config.json" + self.config_file = config_path @dataclass @@ -38,6 +48,7 @@ class AgentWeights: news: float = 0.20 liquidity: float = 0.15 macro: float = 0.15 + risk: float = 1.0 def as_dict(self) -> Dict[str, float]: return { @@ -46,6 +57,7 @@ class AgentWeights: "A_news": self.news, "A_liq": self.liquidity, "A_macro": self.macro, + "A_risk": self.risk, } def update_from_dict(self, data: Mapping[str, float]) -> None: @@ -60,6 +72,8 @@ class AgentWeights: "liquidity": "liquidity", "A_macro": "macro", "macro": "macro", + "A_risk": "risk", + "risk": "risk", } for key, attr in mapping.items(): if key in data and data[key] is not None: @@ -581,12 +595,25 @@ def save_config(cfg: AppConfig | None = None) -> None: for code, dept in cfg.departments.items() }, } + serialized = json.dumps(payload, ensure_ascii=False, indent=2) + + try: + existing = path.read_text(encoding="utf-8") + except OSError: + existing = None + + if existing == serialized: + LOGGER.info("配置未变更,跳过写入:%s", path) + return + try: path.parent.mkdir(parents=True, exist_ok=True) - with path.open("w", encoding="utf-8") as fh: - json.dump(payload, fh, ensure_ascii=False, indent=2) + tmp_path = path.with_suffix(path.suffix + ".tmp") if path.suffix else path.with_name(path.name + ".tmp") + tmp_path.write_text(serialized, encoding="utf-8") + tmp_path.replace(path) + LOGGER.info("配置已写入:%s", path) except OSError: - pass + LOGGER.exception("配置写入失败:%s", path) def _load_env_defaults(cfg: AppConfig) -> None: