This commit is contained in:
sam 2025-09-29 15:41:35 +08:00
parent 1773929431
commit a6564cdced
2 changed files with 49 additions and 12 deletions

View File

@ -914,9 +914,14 @@ def render_backtest() -> None:
"保存这些权重为默认配置", "保存这些权重为默认配置",
key="save_decision_env_weights_single", key="save_decision_env_weights_single",
): ):
cfg.agent_weights.update_from_dict(weights_dict) try:
save_config(cfg) cfg.agent_weights.update_from_dict(weights_dict)
st.success("代理权重已写入 config.json") save_config(cfg)
except Exception as exc: # noqa: BLE001
LOGGER.exception("保存权重失败", extra={**LOG_EXTRA, "error": str(exc)})
st.error(f"写入配置失败:{exc}")
else:
st.success("代理权重已写入 config.json")
nav_series = info.get("nav_series") nav_series = info.get("nav_series")
if nav_series: if nav_series:
@ -1085,11 +1090,16 @@ def render_backtest() -> None:
"保存所选权重为默认配置", "保存所选权重为默认配置",
key="save_decision_env_weights_batch", key="save_decision_env_weights_batch",
): ):
cfg.agent_weights.update_from_dict(selected_row.get("权重", {})) try:
save_config(cfg) cfg.agent_weights.update_from_dict(selected_row.get("权重", {}))
st.success( save_config(cfg)
f"已将序号 {selected_row['序号']} 的权重写入 config.json" except Exception as exc: # noqa: BLE001
) LOGGER.exception("批量保存权重失败", extra={**LOG_EXTRA, "error": str(exc)})
st.error(f"写入配置失败:{exc}")
else:
st.success(
f"已将序号 {selected_row['序号']} 的权重写入 config.json"
)
else: else:
st.caption("暂无成功的结果可供保存。") st.caption("暂无成功的结果可供保存。")

View File

@ -3,11 +3,15 @@ from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
import json import json
import logging
import os import os
from pathlib import Path from pathlib import Path
from typing import Dict, Iterable, List, Mapping, Optional from typing import Dict, Iterable, List, Mapping, Optional
LOGGER = logging.getLogger(__name__)
def _default_root() -> Path: def _default_root() -> Path:
return Path(__file__).resolve().parents[2] / "app" / "data" return Path(__file__).resolve().parents[2] / "app" / "data"
@ -26,7 +30,13 @@ class DataPaths:
self.database = self.root / "llm_quant.db" self.database = self.root / "llm_quant.db"
self.backups = self.root / "backups" self.backups = self.root / "backups"
self.backups.mkdir(parents=True, exist_ok=True) self.backups.mkdir(parents=True, exist_ok=True)
self.config_file = self.root / "config.json" config_override = os.getenv("LLM_QUANT_CONFIG_PATH")
if config_override:
config_path = Path(config_override).expanduser()
config_path.parent.mkdir(parents=True, exist_ok=True)
else:
config_path = self.root / "config.json"
self.config_file = config_path
@dataclass @dataclass
@ -38,6 +48,7 @@ class AgentWeights:
news: float = 0.20 news: float = 0.20
liquidity: float = 0.15 liquidity: float = 0.15
macro: float = 0.15 macro: float = 0.15
risk: float = 1.0
def as_dict(self) -> Dict[str, float]: def as_dict(self) -> Dict[str, float]:
return { return {
@ -46,6 +57,7 @@ class AgentWeights:
"A_news": self.news, "A_news": self.news,
"A_liq": self.liquidity, "A_liq": self.liquidity,
"A_macro": self.macro, "A_macro": self.macro,
"A_risk": self.risk,
} }
def update_from_dict(self, data: Mapping[str, float]) -> None: def update_from_dict(self, data: Mapping[str, float]) -> None:
@ -60,6 +72,8 @@ class AgentWeights:
"liquidity": "liquidity", "liquidity": "liquidity",
"A_macro": "macro", "A_macro": "macro",
"macro": "macro", "macro": "macro",
"A_risk": "risk",
"risk": "risk",
} }
for key, attr in mapping.items(): for key, attr in mapping.items():
if key in data and data[key] is not None: if key in data and data[key] is not None:
@ -581,12 +595,25 @@ def save_config(cfg: AppConfig | None = None) -> None:
for code, dept in cfg.departments.items() for code, dept in cfg.departments.items()
}, },
} }
serialized = json.dumps(payload, ensure_ascii=False, indent=2)
try:
existing = path.read_text(encoding="utf-8")
except OSError:
existing = None
if existing == serialized:
LOGGER.info("配置未变更,跳过写入:%s", path)
return
try: try:
path.parent.mkdir(parents=True, exist_ok=True) path.parent.mkdir(parents=True, exist_ok=True)
with path.open("w", encoding="utf-8") as fh: tmp_path = path.with_suffix(path.suffix + ".tmp") if path.suffix else path.with_name(path.name + ".tmp")
json.dump(payload, fh, ensure_ascii=False, indent=2) tmp_path.write_text(serialized, encoding="utf-8")
tmp_path.replace(path)
LOGGER.info("配置已写入:%s", path)
except OSError: except OSError:
pass LOGGER.exception("配置写入失败:%s", path)
def _load_env_defaults(cfg: AppConfig) -> None: def _load_env_defaults(cfg: AppConfig) -> None: