diff --git a/app/backtest/decision_env.py b/app/backtest/decision_env.py index 702afbb..462f013 100644 --- a/app/backtest/decision_env.py +++ b/app/backtest/decision_env.py @@ -12,6 +12,7 @@ from datetime import date from .engine import BacktestEngine, BacktestResult, BacktestSession, BtConfig from app.agents.registry import weight_map from app.utils.db import db_session +from app.utils.data_access import DataBroker from app.utils.logging import get_logger LOGGER = get_logger(__name__) @@ -83,6 +84,7 @@ class DecisionEnv: self._session: Optional[BacktestSession] = None self._cumulative_reward = 0.0 self._day_index = 0 + self._data_broker = DataBroker() @property def action_dim(self) -> int: @@ -101,6 +103,9 @@ class DecisionEnv: self._day_index = 0 cfg = replace(self._template_cfg) + filtered_universe = self._filter_active_universe(cfg.universe, cfg.start_date, cfg.end_date) + if filtered_universe: + cfg = replace(cfg, universe=filtered_universe) self._engine = BacktestEngine(cfg) self._engine.weights = weight_map(self._baseline_weights) if self._disable_departments: @@ -145,7 +150,8 @@ class DecisionEnv: if engine is None or session is None: raise RuntimeError("environment not initialised; call reset() before step()") - engine.weights = weight_map(weights) + normalized_weights = weight_map(weights) + engine.weights = normalized_weights if self._disable_departments: applied_controls = {} engine.department_manager = None @@ -165,7 +171,7 @@ class DecisionEnv: observation["failure"] = 1.0 info = { "error": str(exc), - "weights": weights, + "weights": normalized_weights, "department_controls": applied_controls, "nav_series": failure_metrics.nav_series, "trades": failure_metrics.trades, @@ -192,7 +198,7 @@ class DecisionEnv: info = { "nav_series": metrics.nav_series, "trades": metrics.trades, - "weights": weights, + "weights": normalized_weights, "risk_breakdown": metrics.risk_breakdown, "risk_events": getattr(session.result, "risk_events", []), "portfolio_snapshots": 
def _filter_active_universe(
    self,
    universe: Sequence[str],
    start_date: date,
    end_date: date,
) -> List[str]:
    """Drop securities flagged as suspended on both window boundaries.

    A security is removed only when the broker reports a "suspend" flag
    for it on ``start_date`` *and* on ``end_date``. Anything else stays in
    the universe, including securities whose lookup fails.

    Args:
        universe: Security codes to screen.
        start_date: First day of the backtest window.
        end_date: Last day of the backtest window.

    Returns:
        The codes that were not filtered, in their original order. If every
        code would be filtered, the full original universe is returned so
        the caller never ends up with an empty universe.
    """
    if not universe:
        # ``universe or []`` keeps a ``None`` universe from raising in list().
        return list(universe or [])

    broker = self._data_broker
    start_key = start_date.strftime("%Y%m%d")
    end_key = end_date.strftime("%Y%m%d")

    def _is_suspended(ts_code: str, trade_date: str) -> bool:
        # NOTE(review): the "" and [] arguments are passed through unchanged
        # from the original call; confirm their meaning against
        # DataBroker.fetch_flags.
        return bool(
            broker.fetch_flags(
                "suspend",
                ts_code,
                trade_date,
                "",
                [],
                auto_refresh=False,
            )
        )

    active: List[str] = []
    filtered: List[str] = []
    for ts_code in universe:
        try:
            # Short-circuit: the end-date lookup is only needed when the
            # start-date lookup already reports a suspension.
            suspended = _is_suspended(ts_code, start_key) and _is_suspended(
                ts_code, end_key
            )
        except Exception:  # noqa: BLE001
            # Best effort: a failed lookup must not shrink the universe, but
            # keep the traceback so the failure can be diagnosed.
            LOGGER.debug(
                "检测停牌状态失败 ts_code=%s start=%s end=%s",
                ts_code,
                start_key,
                end_key,
                exc_info=True,
                extra=LOG_EXTRA,
            )
            active.append(ts_code)
            continue

        if suspended:
            filtered.append(ts_code)
        else:
            active.append(ts_code)

    if filtered:
        LOGGER.info(
            "过滤停牌标的 %s/%s:%s",
            len(filtered),
            len(universe),
            filtered[:10],  # cap the sample so the log line stays short
            extra=LOG_EXTRA,
        )
    # Fall back to the unfiltered universe rather than returning nothing.
    return active or list(universe)
df_rewards.columns if "奖励" in col] + index_column = next((col for col in df_rewards.columns if "序号" in col), None) + if reward_columns and index_column: + chart_df = ( + df_rewards[[index_column, reward_columns[0]]] + .rename(columns={index_column: "迭代序号", reward_columns[0]: "奖励"}) + .set_index("迭代序号") + ) + st.line_chart(chart_df, height=200) + def _render_bandit_logs(bandit_state: Optional[Dict[str, object]]) -> None: """Render the detailed BOHB/Bandit episode logs.""" @@ -554,7 +568,7 @@ def _render_experiment_management( selected_agents = st.multiselect( "选择调参的代理权重", agent_names, - default=agent_names[:2], + default=agent_names, key="decision_env_agents", ) @@ -614,7 +628,7 @@ def _render_experiment_management( selected_departments = st.multiselect( "选择需要调整的部门", dept_codes, - default=[], + default=dept_codes, key="decision_env_departments", ) tool_policy_values = ["auto", "none", "required"] diff --git a/app/utils/data_access.py b/app/utils/data_access.py index d7a0692..b8f88d4 100644 --- a/app/utils/data_access.py +++ b/app/utils/data_access.py @@ -770,6 +770,8 @@ class DataBroker: query = ( "SELECT 1 FROM suspend " "WHERE ts_code = ? " + "AND suspend_date IS NOT NULL " + "AND suspend_date <> '' " "AND suspend_date <= ? " "AND (resume_date IS NULL OR resume_date = '' OR resume_date > ?) " "LIMIT 1"