add suspended stock filtering and reward chart visualization

This commit is contained in:
sam 2025-10-20 08:22:35 +08:00
parent 2779d21d97
commit d85efae082
3 changed files with 86 additions and 6 deletions

View File

@ -12,6 +12,7 @@ from datetime import date
from .engine import BacktestEngine, BacktestResult, BacktestSession, BtConfig
from app.agents.registry import weight_map
from app.utils.db import db_session
from app.utils.data_access import DataBroker
from app.utils.logging import get_logger
LOGGER = get_logger(__name__)
@ -83,6 +84,7 @@ class DecisionEnv:
self._session: Optional[BacktestSession] = None
self._cumulative_reward = 0.0
self._day_index = 0
self._data_broker = DataBroker()
@property
def action_dim(self) -> int:
@ -101,6 +103,9 @@ class DecisionEnv:
self._day_index = 0
cfg = replace(self._template_cfg)
filtered_universe = self._filter_active_universe(cfg.universe, cfg.start_date, cfg.end_date)
if filtered_universe:
cfg = replace(cfg, universe=filtered_universe)
self._engine = BacktestEngine(cfg)
self._engine.weights = weight_map(self._baseline_weights)
if self._disable_departments:
@ -145,7 +150,8 @@ class DecisionEnv:
if engine is None or session is None:
raise RuntimeError("environment not initialised; call reset() before step()")
engine.weights = weight_map(weights)
normalized_weights = weight_map(weights)
engine.weights = normalized_weights
if self._disable_departments:
applied_controls = {}
engine.department_manager = None
@ -165,7 +171,7 @@ class DecisionEnv:
observation["failure"] = 1.0
info = {
"error": str(exc),
"weights": weights,
"weights": normalized_weights,
"department_controls": applied_controls,
"nav_series": failure_metrics.nav_series,
"trades": failure_metrics.trades,
@ -192,7 +198,7 @@ class DecisionEnv:
info = {
"nav_series": metrics.nav_series,
"trades": metrics.trades,
"weights": weights,
"weights": normalized_weights,
"risk_breakdown": metrics.risk_breakdown,
"risk_events": getattr(session.result, "risk_events", []),
"portfolio_snapshots": snapshots,
@ -585,6 +591,64 @@ class DecisionEnv:
return snapshots, trades
def _filter_active_universe(
    self,
    universe: Sequence[str],
    start_date: date,
    end_date: date,
) -> List[str]:
    """Drop symbols that are flagged as suspended at both ends of the window.

    A symbol is excluded only when the data broker reports a ``suspend``
    flag for it on ``start_date`` *and* on ``end_date``. Symbols whose
    status cannot be determined (the lookup raises) are kept, so a data
    problem never shrinks the universe.

    Args:
        universe: Candidate ``ts_code`` identifiers; order is preserved.
        start_date: First day of the backtest window.
        end_date: Last day of the backtest window.

    Returns:
        The symbols considered tradable, in their original order. If every
        symbol would be removed, the full original universe is returned
        instead, so callers never receive an empty list for a non-empty
        input. An empty (or ``None``) universe yields an empty list.
    """
    if not universe:
        # ``or ()`` keeps a stray ``None`` from blowing up inside ``list``.
        return list(universe or ())

    broker = self._data_broker
    start_key = start_date.strftime("%Y%m%d")
    end_key = end_date.strftime("%Y%m%d")

    def _is_suspended(ts_code: str, trade_key: str) -> bool:
        # Only the truthiness of the broker's answer is used here.
        # auto_refresh=False is passed so the check does not request a
        # refresh from the broker.
        return bool(
            broker.fetch_flags(
                "suspend",
                ts_code,
                trade_key,
                "",
                [],
                auto_refresh=False,
            )
        )

    active: List[str] = []
    filtered: List[str] = []
    for ts_code in universe:
        try:
            # ``and`` short-circuits: the end-date lookup is skipped when the
            # symbol is already known to be tradable at the start, since the
            # outcome (keep the symbol) cannot change.
            suspended = _is_suspended(ts_code, start_key) and _is_suspended(
                ts_code, end_key
            )
        except Exception:  # noqa: BLE001
            # Best effort: keep the symbol, but record why the check failed.
            LOGGER.debug(
                "检测停牌状态失败 ts_code=%s start=%s end=%s",
                ts_code,
                start_key,
                end_key,
                exc_info=True,
                extra=LOG_EXTRA,
            )
            suspended = False

        if suspended:
            filtered.append(ts_code)
        else:
            active.append(ts_code)

    if filtered:
        LOGGER.info(
            "过滤停牌标的 %s/%s%s",
            len(filtered),
            len(universe),
            filtered[:10],  # cap the sample so the log line stays short
            extra=LOG_EXTRA,
        )
    # Fall back to the unfiltered universe rather than returning nothing.
    return active or list(universe)
@staticmethod
def _loads(payload: Any, default: Any) -> Any:
if not payload:

View File

@ -87,7 +87,8 @@ def _render_bandit_summary(
if weights_payload:
st.write("对应代理权重:")
st.json(weights_payload)
if st.button("将最佳权重写入默认配置", key="save_decision_env_bandit_weights"):
button_key = f"save_decision_env_bandit_weights_{bandit_state.get('experiment_id','current')}"
if st.button("将最佳权重写入默认配置", key=button_key):
try:
app_cfg.agent_weights.update_from_dict(weights_payload)
save_config(app_cfg)
@ -107,6 +108,19 @@ def _render_bandit_summary(
st.caption("完整的 RL/BOHB 日志请切换到“RL/BOHB 日志”标签查看。")
episodes = bandit_state.get("episodes") or []
if episodes:
df_rewards = pd.DataFrame(episodes)
reward_columns = [col for col in df_rewards.columns if "奖励" in col]
index_column = next((col for col in df_rewards.columns if "序号" in col), None)
if reward_columns and index_column:
chart_df = (
df_rewards[[index_column, reward_columns[0]]]
.rename(columns={index_column: "迭代序号", reward_columns[0]: "奖励"})
.set_index("迭代序号")
)
st.line_chart(chart_df, height=200)
def _render_bandit_logs(bandit_state: Optional[Dict[str, object]]) -> None:
"""Render the detailed BOHB/Bandit episode logs."""
@ -554,7 +568,7 @@ def _render_experiment_management(
selected_agents = st.multiselect(
"选择调参的代理权重",
agent_names,
default=agent_names[:2],
default=agent_names,
key="decision_env_agents",
)
@ -614,7 +628,7 @@ def _render_experiment_management(
selected_departments = st.multiselect(
"选择需要调整的部门",
dept_codes,
default=[],
default=dept_codes,
key="decision_env_departments",
)
tool_policy_values = ["auto", "none", "required"]

View File

@ -770,6 +770,8 @@ class DataBroker:
query = (
"SELECT 1 FROM suspend "
"WHERE ts_code = ? "
"AND suspend_date IS NOT NULL "
"AND suspend_date <> '' "
"AND suspend_date <= ? "
"AND (resume_date IS NULL OR resume_date = '' OR resume_date > ?) "
"LIMIT 1"