This commit is contained in:
sam 2025-09-29 15:41:35 +08:00
parent 1773929431
commit a6564cdced
2 changed files with 49 additions and 12 deletions

View File

@ -914,8 +914,13 @@ def render_backtest() -> None:
"保存这些权重为默认配置",
key="save_decision_env_weights_single",
):
try:
cfg.agent_weights.update_from_dict(weights_dict)
save_config(cfg)
except Exception as exc: # noqa: BLE001
LOGGER.exception("保存权重失败", extra={**LOG_EXTRA, "error": str(exc)})
st.error(f"写入配置失败:{exc}")
else:
st.success("代理权重已写入 config.json")
nav_series = info.get("nav_series")
@ -1085,8 +1090,13 @@ def render_backtest() -> None:
"保存所选权重为默认配置",
key="save_decision_env_weights_batch",
):
try:
cfg.agent_weights.update_from_dict(selected_row.get("权重", {}))
save_config(cfg)
except Exception as exc: # noqa: BLE001
LOGGER.exception("批量保存权重失败", extra={**LOG_EXTRA, "error": str(exc)})
st.error(f"写入配置失败:{exc}")
else:
st.success(
f"已将序号 {selected_row['序号']} 的权重写入 config.json"
)

View File

@ -3,11 +3,15 @@ from __future__ import annotations
from dataclasses import dataclass, field
import json
import logging
import os
from pathlib import Path
from typing import Dict, Iterable, List, Mapping, Optional
LOGGER = logging.getLogger(__name__)
def _default_root() -> Path:
return Path(__file__).resolve().parents[2] / "app" / "data"
@ -26,7 +30,13 @@ class DataPaths:
self.database = self.root / "llm_quant.db"
self.backups = self.root / "backups"
self.backups.mkdir(parents=True, exist_ok=True)
self.config_file = self.root / "config.json"
config_override = os.getenv("LLM_QUANT_CONFIG_PATH")
if config_override:
config_path = Path(config_override).expanduser()
config_path.parent.mkdir(parents=True, exist_ok=True)
else:
config_path = self.root / "config.json"
self.config_file = config_path
@dataclass
@ -38,6 +48,7 @@ class AgentWeights:
news: float = 0.20
liquidity: float = 0.15
macro: float = 0.15
risk: float = 1.0
def as_dict(self) -> Dict[str, float]:
return {
@ -46,6 +57,7 @@ class AgentWeights:
"A_news": self.news,
"A_liq": self.liquidity,
"A_macro": self.macro,
"A_risk": self.risk,
}
def update_from_dict(self, data: Mapping[str, float]) -> None:
@ -60,6 +72,8 @@ class AgentWeights:
"liquidity": "liquidity",
"A_macro": "macro",
"macro": "macro",
"A_risk": "risk",
"risk": "risk",
}
for key, attr in mapping.items():
if key in data and data[key] is not None:
@ -581,12 +595,25 @@ def save_config(cfg: AppConfig | None = None) -> None:
for code, dept in cfg.departments.items()
},
}
serialized = json.dumps(payload, ensure_ascii=False, indent=2)
try:
existing = path.read_text(encoding="utf-8")
except OSError:
existing = None
if existing == serialized:
LOGGER.info("配置未变更,跳过写入:%s", path)
return
try:
path.parent.mkdir(parents=True, exist_ok=True)
with path.open("w", encoding="utf-8") as fh:
json.dump(payload, fh, ensure_ascii=False, indent=2)
tmp_path = path.with_suffix(path.suffix + ".tmp") if path.suffix else path.with_name(path.name + ".tmp")
tmp_path.write_text(serialized, encoding="utf-8")
tmp_path.replace(path)
LOGGER.info("配置已写入:%s", path)
except OSError:
pass
LOGGER.exception("配置写入失败:%s", path)
def _load_env_defaults(cfg: AppConfig) -> None: