add suspended stock filtering and reward chart visualization
This commit is contained in:
parent
2779d21d97
commit
d85efae082
@ -12,6 +12,7 @@ from datetime import date
|
||||
from .engine import BacktestEngine, BacktestResult, BacktestSession, BtConfig
|
||||
from app.agents.registry import weight_map
|
||||
from app.utils.db import db_session
|
||||
from app.utils.data_access import DataBroker
|
||||
from app.utils.logging import get_logger
|
||||
|
||||
LOGGER = get_logger(__name__)
|
||||
@ -83,6 +84,7 @@ class DecisionEnv:
|
||||
self._session: Optional[BacktestSession] = None
|
||||
self._cumulative_reward = 0.0
|
||||
self._day_index = 0
|
||||
self._data_broker = DataBroker()
|
||||
|
||||
@property
|
||||
def action_dim(self) -> int:
|
||||
@ -101,6 +103,9 @@ class DecisionEnv:
|
||||
self._day_index = 0
|
||||
|
||||
cfg = replace(self._template_cfg)
|
||||
filtered_universe = self._filter_active_universe(cfg.universe, cfg.start_date, cfg.end_date)
|
||||
if filtered_universe:
|
||||
cfg = replace(cfg, universe=filtered_universe)
|
||||
self._engine = BacktestEngine(cfg)
|
||||
self._engine.weights = weight_map(self._baseline_weights)
|
||||
if self._disable_departments:
|
||||
@ -145,7 +150,8 @@ class DecisionEnv:
|
||||
if engine is None or session is None:
|
||||
raise RuntimeError("environment not initialised; call reset() before step()")
|
||||
|
||||
engine.weights = weight_map(weights)
|
||||
normalized_weights = weight_map(weights)
|
||||
engine.weights = normalized_weights
|
||||
if self._disable_departments:
|
||||
applied_controls = {}
|
||||
engine.department_manager = None
|
||||
@ -165,7 +171,7 @@ class DecisionEnv:
|
||||
observation["failure"] = 1.0
|
||||
info = {
|
||||
"error": str(exc),
|
||||
"weights": weights,
|
||||
"weights": normalized_weights,
|
||||
"department_controls": applied_controls,
|
||||
"nav_series": failure_metrics.nav_series,
|
||||
"trades": failure_metrics.trades,
|
||||
@ -192,7 +198,7 @@ class DecisionEnv:
|
||||
info = {
|
||||
"nav_series": metrics.nav_series,
|
||||
"trades": metrics.trades,
|
||||
"weights": weights,
|
||||
"weights": normalized_weights,
|
||||
"risk_breakdown": metrics.risk_breakdown,
|
||||
"risk_events": getattr(session.result, "risk_events", []),
|
||||
"portfolio_snapshots": snapshots,
|
||||
@ -585,6 +591,64 @@ class DecisionEnv:
|
||||
|
||||
return snapshots, trades
|
||||
|
||||
def _filter_active_universe(
|
||||
self,
|
||||
universe: Sequence[str],
|
||||
start_date: date,
|
||||
end_date: date,
|
||||
) -> List[str]:
|
||||
if not universe:
|
||||
return list(universe)
|
||||
|
||||
broker = self._data_broker
|
||||
start_key = start_date.strftime("%Y%m%d")
|
||||
end_key = end_date.strftime("%Y%m%d")
|
||||
active: List[str] = []
|
||||
filtered: List[str] = []
|
||||
for ts_code in universe:
|
||||
try:
|
||||
suspended_start = broker.fetch_flags(
|
||||
"suspend",
|
||||
ts_code,
|
||||
start_key,
|
||||
"",
|
||||
[],
|
||||
auto_refresh=False,
|
||||
)
|
||||
suspended_end = broker.fetch_flags(
|
||||
"suspend",
|
||||
ts_code,
|
||||
end_key,
|
||||
"",
|
||||
[],
|
||||
auto_refresh=False,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.debug(
|
||||
"检测停牌状态失败 ts_code=%s start=%s end=%s",
|
||||
ts_code,
|
||||
start_key,
|
||||
end_key,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
active.append(ts_code)
|
||||
continue
|
||||
|
||||
if suspended_start and suspended_end:
|
||||
filtered.append(ts_code)
|
||||
continue
|
||||
active.append(ts_code)
|
||||
|
||||
if filtered:
|
||||
LOGGER.info(
|
||||
"过滤停牌标的 %s/%s:%s",
|
||||
len(filtered),
|
||||
len(universe),
|
||||
filtered[:10],
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
return active or list(universe)
|
||||
|
||||
@staticmethod
|
||||
def _loads(payload: Any, default: Any) -> Any:
|
||||
if not payload:
|
||||
|
||||
@ -87,7 +87,8 @@ def _render_bandit_summary(
|
||||
if weights_payload:
|
||||
st.write("对应代理权重:")
|
||||
st.json(weights_payload)
|
||||
if st.button("将最佳权重写入默认配置", key="save_decision_env_bandit_weights"):
|
||||
button_key = f"save_decision_env_bandit_weights_{bandit_state.get('experiment_id','current')}"
|
||||
if st.button("将最佳权重写入默认配置", key=button_key):
|
||||
try:
|
||||
app_cfg.agent_weights.update_from_dict(weights_payload)
|
||||
save_config(app_cfg)
|
||||
@ -107,6 +108,19 @@ def _render_bandit_summary(
|
||||
|
||||
st.caption("完整的 RL/BOHB 日志请切换到“RL/BOHB 日志”标签查看。")
|
||||
|
||||
episodes = bandit_state.get("episodes") or []
|
||||
if episodes:
|
||||
df_rewards = pd.DataFrame(episodes)
|
||||
reward_columns = [col for col in df_rewards.columns if "奖励" in col]
|
||||
index_column = next((col for col in df_rewards.columns if "序号" in col), None)
|
||||
if reward_columns and index_column:
|
||||
chart_df = (
|
||||
df_rewards[[index_column, reward_columns[0]]]
|
||||
.rename(columns={index_column: "迭代序号", reward_columns[0]: "奖励"})
|
||||
.set_index("迭代序号")
|
||||
)
|
||||
st.line_chart(chart_df, height=200)
|
||||
|
||||
|
||||
def _render_bandit_logs(bandit_state: Optional[Dict[str, object]]) -> None:
|
||||
"""Render the detailed BOHB/Bandit episode logs."""
|
||||
@ -554,7 +568,7 @@ def _render_experiment_management(
|
||||
selected_agents = st.multiselect(
|
||||
"选择调参的代理权重",
|
||||
agent_names,
|
||||
default=agent_names[:2],
|
||||
default=agent_names,
|
||||
key="decision_env_agents",
|
||||
)
|
||||
|
||||
@ -614,7 +628,7 @@ def _render_experiment_management(
|
||||
selected_departments = st.multiselect(
|
||||
"选择需要调整的部门",
|
||||
dept_codes,
|
||||
default=[],
|
||||
default=dept_codes,
|
||||
key="decision_env_departments",
|
||||
)
|
||||
tool_policy_values = ["auto", "none", "required"]
|
||||
|
||||
@ -770,6 +770,8 @@ class DataBroker:
|
||||
query = (
|
||||
"SELECT 1 FROM suspend "
|
||||
"WHERE ts_code = ? "
|
||||
"AND suspend_date IS NOT NULL "
|
||||
"AND suspend_date <> '' "
|
||||
"AND suspend_date <= ? "
|
||||
"AND (resume_date IS NULL OR resume_date = '' OR resume_date > ?) "
|
||||
"LIMIT 1"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user