update
This commit is contained in:
parent
b3f2f5b4fc
commit
37fd7f80ce
@ -2,7 +2,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
import time
|
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from datetime import date, datetime, timedelta
|
from datetime import date, datetime, timedelta
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -62,58 +61,42 @@ from app.utils.tuning import log_tuning_result
|
|||||||
|
|
||||||
LOGGER = get_logger(__name__)
|
LOGGER = get_logger(__name__)
|
||||||
LOG_EXTRA = {"stage": "ui"}
|
LOG_EXTRA = {"stage": "ui"}
|
||||||
_SIDEBAR_THROTTLE_SECONDS = 0.75
|
_DECISION_ENV_SINGLE_RESULT_KEY = "decision_env_single_result"
|
||||||
|
_DECISION_ENV_BATCH_RESULTS_KEY = "decision_env_batch_results"
|
||||||
|
_DASHBOARD_CONTAINERS: Optional[tuple[object, object]] = None
|
||||||
def _sidebar_metrics_listener(metrics: Dict[str, object]) -> None:
|
_DASHBOARD_ELEMENTS: Optional[Dict[str, object]] = None
|
||||||
_update_dashboard_sidebar(metrics, throttled=True)
|
|
||||||
|
|
||||||
|
|
||||||
def render_global_dashboard() -> None:
|
def render_global_dashboard() -> None:
|
||||||
"""Render a persistent sidebar with realtime LLM stats and recent decisions."""
|
"""Render a persistent sidebar with realtime LLM stats and recent decisions."""
|
||||||
|
|
||||||
|
global _DASHBOARD_CONTAINERS
|
||||||
|
global _DASHBOARD_ELEMENTS
|
||||||
|
|
||||||
metrics_container = st.sidebar.container()
|
metrics_container = st.sidebar.container()
|
||||||
decisions_container = st.sidebar.container()
|
decisions_container = st.sidebar.container()
|
||||||
st.session_state["dashboard_containers"] = (metrics_container, decisions_container)
|
_DASHBOARD_CONTAINERS = (metrics_container, decisions_container)
|
||||||
_ensure_dashboard_elements(metrics_container, decisions_container)
|
_DASHBOARD_ELEMENTS = _ensure_dashboard_elements(metrics_container, decisions_container)
|
||||||
if not st.session_state.get("dashboard_listener_registered"):
|
|
||||||
register_llm_metrics_listener(_sidebar_metrics_listener)
|
|
||||||
st.session_state["dashboard_listener_registered"] = True
|
|
||||||
_update_dashboard_sidebar()
|
_update_dashboard_sidebar()
|
||||||
|
|
||||||
|
|
||||||
def _update_dashboard_sidebar(
|
def _update_dashboard_sidebar(
|
||||||
metrics: Optional[Dict[str, object]] = None,
|
metrics: Optional[Dict[str, object]] = None,
|
||||||
*,
|
|
||||||
throttled: bool = False,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
containers = st.session_state.get("dashboard_containers")
|
global _DASHBOARD_CONTAINERS
|
||||||
|
global _DASHBOARD_ELEMENTS
|
||||||
|
|
||||||
|
containers = _DASHBOARD_CONTAINERS
|
||||||
if not containers:
|
if not containers:
|
||||||
return
|
return
|
||||||
metrics_container, decisions_container = containers
|
metrics_container, decisions_container = containers
|
||||||
elements = st.session_state.get("dashboard_elements")
|
elements = _DASHBOARD_ELEMENTS
|
||||||
if elements is None:
|
if elements is None:
|
||||||
elements = _ensure_dashboard_elements(metrics_container, decisions_container)
|
elements = _ensure_dashboard_elements(metrics_container, decisions_container)
|
||||||
|
_DASHBOARD_ELEMENTS = elements
|
||||||
if throttled:
|
|
||||||
now = time.monotonic()
|
|
||||||
last_update = st.session_state.get("dashboard_last_update_ts", 0.0)
|
|
||||||
if now - last_update < _SIDEBAR_THROTTLE_SECONDS:
|
|
||||||
if metrics is not None:
|
|
||||||
st.session_state["dashboard_pending_metrics"] = metrics
|
|
||||||
return
|
|
||||||
st.session_state["dashboard_last_update_ts"] = now
|
|
||||||
else:
|
|
||||||
st.session_state["dashboard_last_update_ts"] = time.monotonic()
|
|
||||||
|
|
||||||
if metrics is None:
|
if metrics is None:
|
||||||
metrics = st.session_state.pop("dashboard_pending_metrics", None)
|
metrics = snapshot_llm_metrics()
|
||||||
if metrics is None:
|
|
||||||
metrics = snapshot_llm_metrics()
|
|
||||||
else:
|
|
||||||
st.session_state.pop("dashboard_pending_metrics", None)
|
|
||||||
|
|
||||||
metrics = metrics or snapshot_llm_metrics()
|
|
||||||
|
|
||||||
elements["metrics_calls"].metric("LLM 调用", metrics.get("total_calls", 0))
|
elements["metrics_calls"].metric("LLM 调用", metrics.get("total_calls", 0))
|
||||||
elements["metrics_prompt"].metric("Prompt Tokens", metrics.get("total_prompt_tokens", 0))
|
elements["metrics_prompt"].metric("Prompt Tokens", metrics.get("total_prompt_tokens", 0))
|
||||||
@ -160,10 +143,6 @@ def _update_dashboard_sidebar(
|
|||||||
|
|
||||||
|
|
||||||
def _ensure_dashboard_elements(metrics_container, decisions_container) -> Dict[str, object]:
|
def _ensure_dashboard_elements(metrics_container, decisions_container) -> Dict[str, object]:
|
||||||
elements = st.session_state.get("dashboard_elements")
|
|
||||||
if elements:
|
|
||||||
return elements
|
|
||||||
|
|
||||||
metrics_container.header("系统监控")
|
metrics_container.header("系统监控")
|
||||||
col_a, col_b, col_c = metrics_container.columns(3)
|
col_a, col_b, col_c = metrics_container.columns(3)
|
||||||
metrics_calls = col_a.empty()
|
metrics_calls = col_a.empty()
|
||||||
@ -184,7 +163,6 @@ def _ensure_dashboard_elements(metrics_container, decisions_container) -> Dict[s
|
|||||||
"model_distribution": model_distribution,
|
"model_distribution": model_distribution,
|
||||||
"decisions_list": decisions_list,
|
"decisions_list": decisions_list,
|
||||||
}
|
}
|
||||||
st.session_state["dashboard_elements"] = elements
|
|
||||||
return elements
|
return elements
|
||||||
|
|
||||||
def _discover_provider_models(provider: LLMProvider, base_override: str = "", api_override: Optional[str] = None) -> tuple[list[str], Optional[str]]:
|
def _discover_provider_models(provider: LLMProvider, base_override: str = "", api_override: Optional[str] = None) -> tuple[list[str], Optional[str]]:
|
||||||
@ -835,12 +813,20 @@ def render_backtest() -> None:
|
|||||||
action_values.append(action_val)
|
action_values.append(action_val)
|
||||||
|
|
||||||
run_decision_env = st.button("执行单次调参", key="run_decision_env_button")
|
run_decision_env = st.button("执行单次调参", key="run_decision_env_button")
|
||||||
|
just_finished_single = False
|
||||||
if run_decision_env:
|
if run_decision_env:
|
||||||
if not selected_agents:
|
if not selected_agents:
|
||||||
st.warning("请至少选择一个代理进行调参。")
|
st.warning("请至少选择一个代理进行调参。")
|
||||||
elif not range_valid:
|
elif not range_valid:
|
||||||
st.error("请确保所有代理的最大权重大于最小权重。")
|
st.error("请确保所有代理的最大权重大于最小权重。")
|
||||||
else:
|
else:
|
||||||
|
LOGGER.info(
|
||||||
|
"离线调参(单次)按钮点击,已选择代理=%s 动作=%s disable_departments=%s",
|
||||||
|
selected_agents,
|
||||||
|
action_values,
|
||||||
|
disable_departments,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
baseline_weights = cfg.agent_weights.as_dict()
|
baseline_weights = cfg.agent_weights.as_dict()
|
||||||
for agent in agent_objects:
|
for agent in agent_objects:
|
||||||
baseline_weights.setdefault(agent.name, 1.0)
|
baseline_weights.setdefault(agent.name, 1.0)
|
||||||
@ -869,74 +855,126 @@ def render_backtest() -> None:
|
|||||||
disable_departments=disable_departments,
|
disable_departments=disable_departments,
|
||||||
)
|
)
|
||||||
env.reset()
|
env.reset()
|
||||||
|
LOGGER.debug(
|
||||||
|
"离线调参(单次)启动 DecisionEnv:cfg=%s 参数维度=%s",
|
||||||
|
bt_cfg_env,
|
||||||
|
len(specs),
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
with st.spinner("正在执行离线调参……"):
|
with st.spinner("正在执行离线调参……"):
|
||||||
try:
|
try:
|
||||||
observation, reward, done, info = env.step(action_values)
|
observation, reward, done, info = env.step(action_values)
|
||||||
|
LOGGER.info(
|
||||||
|
"离线调参(单次)完成,obs=%s reward=%.4f done=%s",
|
||||||
|
observation,
|
||||||
|
reward,
|
||||||
|
done,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
LOGGER.exception("DecisionEnv 调用失败", extra=LOG_EXTRA)
|
LOGGER.exception("DecisionEnv 调用失败", extra=LOG_EXTRA)
|
||||||
st.error(f"离线调参失败:{exc}")
|
st.error(f"离线调参失败:{exc}")
|
||||||
|
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
|
||||||
else:
|
else:
|
||||||
if observation.get("failure"):
|
if observation.get("failure"):
|
||||||
st.error("调参失败:回测执行未完成,可能是 LLM 网络不可用或参数异常。")
|
st.error("调参失败:回测执行未完成,可能是 LLM 网络不可用或参数异常。")
|
||||||
st.json(observation)
|
st.json(observation)
|
||||||
|
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
|
||||||
else:
|
else:
|
||||||
st.success("离线调参完成")
|
resolved_experiment_id = experiment_id or str(uuid.uuid4())
|
||||||
col_metrics = st.columns(4)
|
resolved_strategy = strategy_label or "DecisionEnv"
|
||||||
col_metrics[0].metric("总收益", f"{observation.get('total_return', 0.0):+.2%}")
|
|
||||||
col_metrics[1].metric("最大回撤", f"{observation.get('max_drawdown', 0.0):+.2%}")
|
|
||||||
col_metrics[2].metric("波动率", f"{observation.get('volatility', 0.0):+.2%}")
|
|
||||||
col_metrics[3].metric("奖励", f"{reward:+.4f}")
|
|
||||||
|
|
||||||
st.write("调参后权重:")
|
|
||||||
weights_dict = info.get("weights", {})
|
|
||||||
st.json(weights_dict)
|
|
||||||
action_payload = {
|
action_payload = {
|
||||||
name: value
|
name: value
|
||||||
for name, value in zip(selected_agents, action_values)
|
for name, value in zip(selected_agents, action_values)
|
||||||
}
|
}
|
||||||
metrics_payload = dict(observation)
|
metrics_payload = dict(observation)
|
||||||
metrics_payload["reward"] = reward
|
metrics_payload["reward"] = reward
|
||||||
|
log_success = False
|
||||||
try:
|
try:
|
||||||
log_tuning_result(
|
log_tuning_result(
|
||||||
experiment_id=experiment_id or str(uuid.uuid4()),
|
experiment_id=resolved_experiment_id,
|
||||||
strategy=strategy_label or "DecisionEnv",
|
strategy=resolved_strategy,
|
||||||
action=action_payload,
|
action=action_payload,
|
||||||
reward=reward,
|
reward=reward,
|
||||||
metrics=metrics_payload,
|
metrics=metrics_payload,
|
||||||
weights=weights_dict,
|
weights=info.get("weights", {}),
|
||||||
)
|
)
|
||||||
st.caption("调参结果已写入 tuning_results 表。")
|
|
||||||
except Exception: # noqa: BLE001
|
except Exception: # noqa: BLE001
|
||||||
LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA)
|
LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA)
|
||||||
|
else:
|
||||||
|
log_success = True
|
||||||
|
LOGGER.info(
|
||||||
|
"离线调参(单次)日志写入成功:experiment=%s strategy=%s",
|
||||||
|
resolved_experiment_id,
|
||||||
|
resolved_strategy,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
|
st.session_state[_DECISION_ENV_SINGLE_RESULT_KEY] = {
|
||||||
|
"observation": dict(observation),
|
||||||
|
"reward": float(reward),
|
||||||
|
"weights": info.get("weights", {}),
|
||||||
|
"nav_series": info.get("nav_series"),
|
||||||
|
"trades": info.get("trades"),
|
||||||
|
"selected_agents": list(selected_agents),
|
||||||
|
"action_values": list(action_values),
|
||||||
|
"experiment_id": resolved_experiment_id,
|
||||||
|
"strategy_label": resolved_strategy,
|
||||||
|
"logged": log_success,
|
||||||
|
}
|
||||||
|
just_finished_single = True
|
||||||
|
single_result = st.session_state.get(_DECISION_ENV_SINGLE_RESULT_KEY)
|
||||||
|
if single_result:
|
||||||
|
if just_finished_single:
|
||||||
|
st.success("离线调参完成")
|
||||||
|
else:
|
||||||
|
st.success("离线调参结果(最近一次运行)")
|
||||||
|
st.caption(
|
||||||
|
f"实验 ID:{single_result.get('experiment_id', '-') } | 策略:{single_result.get('strategy_label', 'DecisionEnv')}"
|
||||||
|
)
|
||||||
|
observation = single_result.get("observation", {})
|
||||||
|
reward = float(single_result.get("reward", 0.0))
|
||||||
|
col_metrics = st.columns(4)
|
||||||
|
col_metrics[0].metric("总收益", f"{observation.get('total_return', 0.0):+.2%}")
|
||||||
|
col_metrics[1].metric("最大回撤", f"{observation.get('max_drawdown', 0.0):+.2%}")
|
||||||
|
col_metrics[2].metric("波动率", f"{observation.get('volatility', 0.0):+.2%}")
|
||||||
|
col_metrics[3].metric("奖励", f"{reward:+.4f}")
|
||||||
|
|
||||||
if weights_dict:
|
weights_dict = single_result.get("weights") or {}
|
||||||
if st.button(
|
if weights_dict:
|
||||||
"保存这些权重为默认配置",
|
st.write("调参后权重:")
|
||||||
key="save_decision_env_weights_single",
|
st.json(weights_dict)
|
||||||
):
|
if st.button("保存这些权重为默认配置", key="save_decision_env_weights_single"):
|
||||||
try:
|
try:
|
||||||
cfg.agent_weights.update_from_dict(weights_dict)
|
cfg.agent_weights.update_from_dict(weights_dict)
|
||||||
save_config(cfg)
|
save_config(cfg)
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
LOGGER.exception("保存权重失败", extra={**LOG_EXTRA, "error": str(exc)})
|
LOGGER.exception("保存权重失败", extra={**LOG_EXTRA, "error": str(exc)})
|
||||||
st.error(f"写入配置失败:{exc}")
|
st.error(f"写入配置失败:{exc}")
|
||||||
else:
|
else:
|
||||||
st.success("代理权重已写入 config.json")
|
st.success("代理权重已写入 config.json")
|
||||||
|
|
||||||
nav_series = info.get("nav_series")
|
if single_result.get("logged"):
|
||||||
if nav_series:
|
st.caption("调参结果已写入 tuning_results 表。")
|
||||||
try:
|
|
||||||
nav_df = pd.DataFrame(nav_series)
|
nav_series = single_result.get("nav_series") or []
|
||||||
if {"trade_date", "nav"}.issubset(nav_df.columns):
|
if nav_series:
|
||||||
nav_df = nav_df.sort_values("trade_date")
|
try:
|
||||||
nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"])
|
nav_df = pd.DataFrame(nav_series)
|
||||||
st.line_chart(nav_df.set_index("trade_date")["nav"], height=220)
|
if {"trade_date", "nav"}.issubset(nav_df.columns):
|
||||||
except Exception: # noqa: BLE001
|
nav_df = nav_df.sort_values("trade_date")
|
||||||
LOGGER.debug("导航曲线绘制失败", extra=LOG_EXTRA)
|
nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"])
|
||||||
trades = info.get("trades")
|
st.line_chart(nav_df.set_index("trade_date")["nav"], height=220)
|
||||||
if trades:
|
except Exception: # noqa: BLE001
|
||||||
st.write("成交记录:")
|
LOGGER.debug("导航曲线绘制失败", extra=LOG_EXTRA)
|
||||||
st.dataframe(pd.DataFrame(trades), hide_index=True, width='stretch')
|
|
||||||
|
trades = single_result.get("trades") or []
|
||||||
|
if trades:
|
||||||
|
st.write("成交记录:")
|
||||||
|
st.dataframe(pd.DataFrame(trades), hide_index=True, width='stretch')
|
||||||
|
|
||||||
|
if st.button("清除单次调参结果", key="clear_decision_env_single"):
|
||||||
|
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
|
||||||
|
st.success("已清除单次调参结果缓存。")
|
||||||
|
|
||||||
st.divider()
|
st.divider()
|
||||||
st.caption("批量调参:在下方输入多组动作,每行表示一组 0-1 之间的值,用逗号分隔。")
|
st.caption("批量调参:在下方输入多组动作,每行表示一组 0-1 之间的值,用逗号分隔。")
|
||||||
@ -954,16 +992,28 @@ def render_backtest() -> None:
|
|||||||
key="decision_env_batch_actions",
|
key="decision_env_batch_actions",
|
||||||
)
|
)
|
||||||
run_batch = st.button("批量执行调参", key="run_decision_env_batch")
|
run_batch = st.button("批量执行调参", key="run_decision_env_batch")
|
||||||
|
batch_just_ran = False
|
||||||
if run_batch:
|
if run_batch:
|
||||||
if not selected_agents:
|
if not selected_agents:
|
||||||
st.warning("请先选择调参代理。")
|
st.warning("请先选择调参代理。")
|
||||||
elif not range_valid:
|
elif not range_valid:
|
||||||
st.error("请确保所有代理的最大权重大于最小权重。")
|
st.error("请确保所有代理的最大权重大于最小权重。")
|
||||||
else:
|
else:
|
||||||
|
LOGGER.info(
|
||||||
|
"离线调参(批量)按钮点击,已选择代理=%s disable_departments=%s",
|
||||||
|
selected_agents,
|
||||||
|
disable_departments,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
lines = [line.strip() for line in action_grid_raw.splitlines() if line.strip()]
|
lines = [line.strip() for line in action_grid_raw.splitlines() if line.strip()]
|
||||||
if not lines:
|
if not lines:
|
||||||
st.warning("请在文本框中输入至少一组动作。")
|
st.warning("请在文本框中输入至少一组动作。")
|
||||||
else:
|
else:
|
||||||
|
LOGGER.debug(
|
||||||
|
"离线调参(批量)原始输入=%s",
|
||||||
|
lines,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
parsed_actions: List[List[float]] = []
|
parsed_actions: List[List[float]] = []
|
||||||
for line in lines:
|
for line in lines:
|
||||||
try:
|
try:
|
||||||
@ -978,6 +1028,11 @@ def render_backtest() -> None:
|
|||||||
break
|
break
|
||||||
parsed_actions.append(values)
|
parsed_actions.append(values)
|
||||||
if parsed_actions:
|
if parsed_actions:
|
||||||
|
LOGGER.info(
|
||||||
|
"离线调参(批量)解析动作成功,数量=%s",
|
||||||
|
len(parsed_actions),
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
baseline_weights = cfg.agent_weights.as_dict()
|
baseline_weights = cfg.agent_weights.as_dict()
|
||||||
for agent in agent_objects:
|
for agent in agent_objects:
|
||||||
baseline_weights.setdefault(agent.name, 1.0)
|
baseline_weights.setdefault(agent.name, 1.0)
|
||||||
@ -1006,6 +1061,14 @@ def render_backtest() -> None:
|
|||||||
disable_departments=disable_departments,
|
disable_departments=disable_departments,
|
||||||
)
|
)
|
||||||
results: List[Dict[str, object]] = []
|
results: List[Dict[str, object]] = []
|
||||||
|
resolved_experiment_id = experiment_id or str(uuid.uuid4())
|
||||||
|
resolved_strategy = strategy_label or "DecisionEnv"
|
||||||
|
LOGGER.debug(
|
||||||
|
"离线调参(批量)启动 DecisionEnv:cfg=%s 动作组=%s",
|
||||||
|
bt_cfg_env,
|
||||||
|
len(parsed_actions),
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
with st.spinner("正在批量执行调参……"):
|
with st.spinner("正在批量执行调参……"):
|
||||||
for idx, action_vals in enumerate(parsed_actions, start=1):
|
for idx, action_vals in enumerate(parsed_actions, start=1):
|
||||||
env.reset()
|
env.reset()
|
||||||
@ -1032,6 +1095,13 @@ def render_backtest() -> None:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
LOGGER.info(
|
||||||
|
"离线调参(批量)第 %s 组完成,reward=%.4f obs=%s",
|
||||||
|
idx,
|
||||||
|
reward,
|
||||||
|
observation,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
action_payload = {
|
action_payload = {
|
||||||
name: value
|
name: value
|
||||||
for name, value in zip(selected_agents, action_vals)
|
for name, value in zip(selected_agents, action_vals)
|
||||||
@ -1041,8 +1111,8 @@ def render_backtest() -> None:
|
|||||||
weights_payload = info.get("weights", {})
|
weights_payload = info.get("weights", {})
|
||||||
try:
|
try:
|
||||||
log_tuning_result(
|
log_tuning_result(
|
||||||
experiment_id=experiment_id or str(uuid.uuid4()),
|
experiment_id=resolved_experiment_id,
|
||||||
strategy=strategy_label or "DecisionEnv",
|
strategy=resolved_strategy,
|
||||||
action=action_payload,
|
action=action_payload,
|
||||||
reward=reward,
|
reward=reward,
|
||||||
metrics=metrics_payload,
|
metrics=metrics_payload,
|
||||||
@ -1062,46 +1132,74 @@ def render_backtest() -> None:
|
|||||||
"权重": weights_payload,
|
"权重": weights_payload,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
if results:
|
st.session_state[_DECISION_ENV_BATCH_RESULTS_KEY] = {
|
||||||
st.write("批量调参结果:")
|
"results": results,
|
||||||
results_df = pd.DataFrame(results)
|
"selectable": [
|
||||||
st.dataframe(results_df, hide_index=True, width='stretch')
|
|
||||||
selectable = [
|
|
||||||
row
|
row
|
||||||
for row in results
|
for row in results
|
||||||
if row.get("状态") == "ok" and row.get("权重")
|
if row.get("状态") == "ok" and row.get("权重")
|
||||||
]
|
],
|
||||||
if selectable:
|
"experiment_id": resolved_experiment_id,
|
||||||
option_labels = [
|
"strategy_label": resolved_strategy,
|
||||||
f"序号 {row['序号']} | 奖励 {row.get('奖励', 0.0):+.4f}"
|
}
|
||||||
for row in selectable
|
batch_just_ran = True
|
||||||
]
|
LOGGER.info(
|
||||||
selected_label = st.selectbox(
|
"离线调参(批量)执行结束,总结果条数=%s",
|
||||||
"选择要保存的记录",
|
len(results),
|
||||||
option_labels,
|
extra=LOG_EXTRA,
|
||||||
key="decision_env_batch_select",
|
)
|
||||||
)
|
batch_state = st.session_state.get(_DECISION_ENV_BATCH_RESULTS_KEY)
|
||||||
selected_row = None
|
if batch_state:
|
||||||
for label, row in zip(option_labels, selectable):
|
results = batch_state.get("results") or []
|
||||||
if label == selected_label:
|
if results:
|
||||||
selected_row = row
|
if batch_just_ran:
|
||||||
break
|
st.success("批量调参完成")
|
||||||
if selected_row and st.button(
|
else:
|
||||||
"保存所选权重为默认配置",
|
st.success("批量调参结果(最近一次运行)")
|
||||||
key="save_decision_env_weights_batch",
|
st.caption(
|
||||||
):
|
f"实验 ID:{batch_state.get('experiment_id', '-') } | 策略:{batch_state.get('strategy_label', 'DecisionEnv')}"
|
||||||
try:
|
)
|
||||||
cfg.agent_weights.update_from_dict(selected_row.get("权重", {}))
|
results_df = pd.DataFrame(results)
|
||||||
save_config(cfg)
|
st.write("批量调参结果:")
|
||||||
except Exception as exc: # noqa: BLE001
|
st.dataframe(results_df, hide_index=True, width='stretch')
|
||||||
LOGGER.exception("批量保存权重失败", extra={**LOG_EXTRA, "error": str(exc)})
|
selectable = batch_state.get("selectable") or []
|
||||||
st.error(f"写入配置失败:{exc}")
|
if selectable:
|
||||||
else:
|
option_labels = [
|
||||||
st.success(
|
f"序号 {row['序号']} | 奖励 {row.get('奖励', 0.0):+.4f}"
|
||||||
f"已将序号 {selected_row['序号']} 的权重写入 config.json"
|
for row in selectable
|
||||||
)
|
]
|
||||||
else:
|
selected_label = st.selectbox(
|
||||||
st.caption("暂无成功的结果可供保存。")
|
"选择要保存的记录",
|
||||||
|
option_labels,
|
||||||
|
key="decision_env_batch_select",
|
||||||
|
)
|
||||||
|
selected_row = None
|
||||||
|
for label, row in zip(option_labels, selectable):
|
||||||
|
if label == selected_label:
|
||||||
|
selected_row = row
|
||||||
|
break
|
||||||
|
if selected_row and st.button(
|
||||||
|
"保存所选权重为默认配置",
|
||||||
|
key="save_decision_env_weights_batch",
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
cfg.agent_weights.update_from_dict(selected_row.get("权重", {}))
|
||||||
|
save_config(cfg)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
LOGGER.exception("批量保存权重失败", extra={**LOG_EXTRA, "error": str(exc)})
|
||||||
|
st.error(f"写入配置失败:{exc}")
|
||||||
|
else:
|
||||||
|
st.success(
|
||||||
|
f"已将序号 {selected_row['序号']} 的权重写入 config.json"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
st.caption("暂无成功的结果可供保存。")
|
||||||
|
else:
|
||||||
|
st.caption("批量调参在最近一次执行中未产生结果。")
|
||||||
|
if st.button("清除批量调参结果", key="clear_decision_env_batch"):
|
||||||
|
st.session_state.pop(_DECISION_ENV_BATCH_RESULTS_KEY, None)
|
||||||
|
st.session_state.pop("decision_env_batch_select", None)
|
||||||
|
st.success("已清除批量调参结果缓存。")
|
||||||
|
|
||||||
|
|
||||||
def render_settings() -> None:
|
def render_settings() -> None:
|
||||||
@ -1673,7 +1771,7 @@ def render_tests() -> None:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
candle_fig.update_layout(height=420, margin=dict(l=10, r=10, t=40, b=10))
|
candle_fig.update_layout(height=420, margin=dict(l=10, r=10, t=40, b=10))
|
||||||
st.plotly_chart(candle_fig, width='stretch')
|
st.plotly_chart(candle_fig, use_container_width=True)
|
||||||
|
|
||||||
vol_fig = px.bar(
|
vol_fig = px.bar(
|
||||||
df_reset,
|
df_reset,
|
||||||
@ -1683,7 +1781,7 @@ def render_tests() -> None:
|
|||||||
title="成交量",
|
title="成交量",
|
||||||
)
|
)
|
||||||
vol_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10))
|
vol_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10))
|
||||||
st.plotly_chart(vol_fig, width='stretch')
|
st.plotly_chart(vol_fig, use_container_width=True)
|
||||||
|
|
||||||
amt_fig = px.bar(
|
amt_fig = px.bar(
|
||||||
df_reset,
|
df_reset,
|
||||||
@ -1693,7 +1791,7 @@ def render_tests() -> None:
|
|||||||
title="成交额",
|
title="成交额",
|
||||||
)
|
)
|
||||||
amt_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10))
|
amt_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10))
|
||||||
st.plotly_chart(amt_fig, width='stretch')
|
st.plotly_chart(amt_fig, use_container_width=True)
|
||||||
|
|
||||||
df_reset["月份"] = df_reset["交易日"].dt.to_period("M").astype(str)
|
df_reset["月份"] = df_reset["交易日"].dt.to_period("M").astype(str)
|
||||||
box_fig = px.box(
|
box_fig = px.box(
|
||||||
@ -1704,7 +1802,7 @@ def render_tests() -> None:
|
|||||||
title="月度收盘价分布",
|
title="月度收盘价分布",
|
||||||
)
|
)
|
||||||
box_fig.update_layout(height=320, margin=dict(l=10, r=10, t=40, b=10))
|
box_fig.update_layout(height=320, margin=dict(l=10, r=10, t=40, b=10))
|
||||||
st.plotly_chart(box_fig, width='stretch')
|
st.plotly_chart(box_fig, use_container_width=True)
|
||||||
|
|
||||||
st.caption("提示:成交量单位为手,成交额以千元显示。箱线图按月展示收盘价分布。")
|
st.caption("提示:成交量单位为手,成交额以千元显示。箱线图按月展示收盘价分布。")
|
||||||
st.dataframe(df_reset.tail(20), width='stretch')
|
st.dataframe(df_reset.tail(20), width='stretch')
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user