add suspended stock filtering and reward chart visualization
This commit is contained in:
parent
2779d21d97
commit
d85efae082
@ -12,6 +12,7 @@ from datetime import date
|
|||||||
from .engine import BacktestEngine, BacktestResult, BacktestSession, BtConfig
|
from .engine import BacktestEngine, BacktestResult, BacktestSession, BtConfig
|
||||||
from app.agents.registry import weight_map
|
from app.agents.registry import weight_map
|
||||||
from app.utils.db import db_session
|
from app.utils.db import db_session
|
||||||
|
from app.utils.data_access import DataBroker
|
||||||
from app.utils.logging import get_logger
|
from app.utils.logging import get_logger
|
||||||
|
|
||||||
LOGGER = get_logger(__name__)
|
LOGGER = get_logger(__name__)
|
||||||
@ -83,6 +84,7 @@ class DecisionEnv:
|
|||||||
self._session: Optional[BacktestSession] = None
|
self._session: Optional[BacktestSession] = None
|
||||||
self._cumulative_reward = 0.0
|
self._cumulative_reward = 0.0
|
||||||
self._day_index = 0
|
self._day_index = 0
|
||||||
|
self._data_broker = DataBroker()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def action_dim(self) -> int:
|
def action_dim(self) -> int:
|
||||||
@ -101,6 +103,9 @@ class DecisionEnv:
|
|||||||
self._day_index = 0
|
self._day_index = 0
|
||||||
|
|
||||||
cfg = replace(self._template_cfg)
|
cfg = replace(self._template_cfg)
|
||||||
|
filtered_universe = self._filter_active_universe(cfg.universe, cfg.start_date, cfg.end_date)
|
||||||
|
if filtered_universe:
|
||||||
|
cfg = replace(cfg, universe=filtered_universe)
|
||||||
self._engine = BacktestEngine(cfg)
|
self._engine = BacktestEngine(cfg)
|
||||||
self._engine.weights = weight_map(self._baseline_weights)
|
self._engine.weights = weight_map(self._baseline_weights)
|
||||||
if self._disable_departments:
|
if self._disable_departments:
|
||||||
@ -145,7 +150,8 @@ class DecisionEnv:
|
|||||||
if engine is None or session is None:
|
if engine is None or session is None:
|
||||||
raise RuntimeError("environment not initialised; call reset() before step()")
|
raise RuntimeError("environment not initialised; call reset() before step()")
|
||||||
|
|
||||||
engine.weights = weight_map(weights)
|
normalized_weights = weight_map(weights)
|
||||||
|
engine.weights = normalized_weights
|
||||||
if self._disable_departments:
|
if self._disable_departments:
|
||||||
applied_controls = {}
|
applied_controls = {}
|
||||||
engine.department_manager = None
|
engine.department_manager = None
|
||||||
@ -165,7 +171,7 @@ class DecisionEnv:
|
|||||||
observation["failure"] = 1.0
|
observation["failure"] = 1.0
|
||||||
info = {
|
info = {
|
||||||
"error": str(exc),
|
"error": str(exc),
|
||||||
"weights": weights,
|
"weights": normalized_weights,
|
||||||
"department_controls": applied_controls,
|
"department_controls": applied_controls,
|
||||||
"nav_series": failure_metrics.nav_series,
|
"nav_series": failure_metrics.nav_series,
|
||||||
"trades": failure_metrics.trades,
|
"trades": failure_metrics.trades,
|
||||||
@ -192,7 +198,7 @@ class DecisionEnv:
|
|||||||
info = {
|
info = {
|
||||||
"nav_series": metrics.nav_series,
|
"nav_series": metrics.nav_series,
|
||||||
"trades": metrics.trades,
|
"trades": metrics.trades,
|
||||||
"weights": weights,
|
"weights": normalized_weights,
|
||||||
"risk_breakdown": metrics.risk_breakdown,
|
"risk_breakdown": metrics.risk_breakdown,
|
||||||
"risk_events": getattr(session.result, "risk_events", []),
|
"risk_events": getattr(session.result, "risk_events", []),
|
||||||
"portfolio_snapshots": snapshots,
|
"portfolio_snapshots": snapshots,
|
||||||
@ -585,6 +591,64 @@ class DecisionEnv:
|
|||||||
|
|
||||||
return snapshots, trades
|
return snapshots, trades
|
||||||
|
|
||||||
|
def _filter_active_universe(
|
||||||
|
self,
|
||||||
|
universe: Sequence[str],
|
||||||
|
start_date: date,
|
||||||
|
end_date: date,
|
||||||
|
) -> List[str]:
|
||||||
|
if not universe:
|
||||||
|
return list(universe)
|
||||||
|
|
||||||
|
broker = self._data_broker
|
||||||
|
start_key = start_date.strftime("%Y%m%d")
|
||||||
|
end_key = end_date.strftime("%Y%m%d")
|
||||||
|
active: List[str] = []
|
||||||
|
filtered: List[str] = []
|
||||||
|
for ts_code in universe:
|
||||||
|
try:
|
||||||
|
suspended_start = broker.fetch_flags(
|
||||||
|
"suspend",
|
||||||
|
ts_code,
|
||||||
|
start_key,
|
||||||
|
"",
|
||||||
|
[],
|
||||||
|
auto_refresh=False,
|
||||||
|
)
|
||||||
|
suspended_end = broker.fetch_flags(
|
||||||
|
"suspend",
|
||||||
|
ts_code,
|
||||||
|
end_key,
|
||||||
|
"",
|
||||||
|
[],
|
||||||
|
auto_refresh=False,
|
||||||
|
)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
LOGGER.debug(
|
||||||
|
"检测停牌状态失败 ts_code=%s start=%s end=%s",
|
||||||
|
ts_code,
|
||||||
|
start_key,
|
||||||
|
end_key,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
|
active.append(ts_code)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if suspended_start and suspended_end:
|
||||||
|
filtered.append(ts_code)
|
||||||
|
continue
|
||||||
|
active.append(ts_code)
|
||||||
|
|
||||||
|
if filtered:
|
||||||
|
LOGGER.info(
|
||||||
|
"过滤停牌标的 %s/%s:%s",
|
||||||
|
len(filtered),
|
||||||
|
len(universe),
|
||||||
|
filtered[:10],
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
|
return active or list(universe)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _loads(payload: Any, default: Any) -> Any:
|
def _loads(payload: Any, default: Any) -> Any:
|
||||||
if not payload:
|
if not payload:
|
||||||
|
|||||||
@ -87,7 +87,8 @@ def _render_bandit_summary(
|
|||||||
if weights_payload:
|
if weights_payload:
|
||||||
st.write("对应代理权重:")
|
st.write("对应代理权重:")
|
||||||
st.json(weights_payload)
|
st.json(weights_payload)
|
||||||
if st.button("将最佳权重写入默认配置", key="save_decision_env_bandit_weights"):
|
button_key = f"save_decision_env_bandit_weights_{bandit_state.get('experiment_id','current')}"
|
||||||
|
if st.button("将最佳权重写入默认配置", key=button_key):
|
||||||
try:
|
try:
|
||||||
app_cfg.agent_weights.update_from_dict(weights_payload)
|
app_cfg.agent_weights.update_from_dict(weights_payload)
|
||||||
save_config(app_cfg)
|
save_config(app_cfg)
|
||||||
@ -107,6 +108,19 @@ def _render_bandit_summary(
|
|||||||
|
|
||||||
st.caption("完整的 RL/BOHB 日志请切换到“RL/BOHB 日志”标签查看。")
|
st.caption("完整的 RL/BOHB 日志请切换到“RL/BOHB 日志”标签查看。")
|
||||||
|
|
||||||
|
episodes = bandit_state.get("episodes") or []
|
||||||
|
if episodes:
|
||||||
|
df_rewards = pd.DataFrame(episodes)
|
||||||
|
reward_columns = [col for col in df_rewards.columns if "奖励" in col]
|
||||||
|
index_column = next((col for col in df_rewards.columns if "序号" in col), None)
|
||||||
|
if reward_columns and index_column:
|
||||||
|
chart_df = (
|
||||||
|
df_rewards[[index_column, reward_columns[0]]]
|
||||||
|
.rename(columns={index_column: "迭代序号", reward_columns[0]: "奖励"})
|
||||||
|
.set_index("迭代序号")
|
||||||
|
)
|
||||||
|
st.line_chart(chart_df, height=200)
|
||||||
|
|
||||||
|
|
||||||
def _render_bandit_logs(bandit_state: Optional[Dict[str, object]]) -> None:
|
def _render_bandit_logs(bandit_state: Optional[Dict[str, object]]) -> None:
|
||||||
"""Render the detailed BOHB/Bandit episode logs."""
|
"""Render the detailed BOHB/Bandit episode logs."""
|
||||||
@ -554,7 +568,7 @@ def _render_experiment_management(
|
|||||||
selected_agents = st.multiselect(
|
selected_agents = st.multiselect(
|
||||||
"选择调参的代理权重",
|
"选择调参的代理权重",
|
||||||
agent_names,
|
agent_names,
|
||||||
default=agent_names[:2],
|
default=agent_names,
|
||||||
key="decision_env_agents",
|
key="decision_env_agents",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -614,7 +628,7 @@ def _render_experiment_management(
|
|||||||
selected_departments = st.multiselect(
|
selected_departments = st.multiselect(
|
||||||
"选择需要调整的部门",
|
"选择需要调整的部门",
|
||||||
dept_codes,
|
dept_codes,
|
||||||
default=[],
|
default=dept_codes,
|
||||||
key="decision_env_departments",
|
key="decision_env_departments",
|
||||||
)
|
)
|
||||||
tool_policy_values = ["auto", "none", "required"]
|
tool_policy_values = ["auto", "none", "required"]
|
||||||
|
|||||||
@ -770,6 +770,8 @@ class DataBroker:
|
|||||||
query = (
|
query = (
|
||||||
"SELECT 1 FROM suspend "
|
"SELECT 1 FROM suspend "
|
||||||
"WHERE ts_code = ? "
|
"WHERE ts_code = ? "
|
||||||
|
"AND suspend_date IS NOT NULL "
|
||||||
|
"AND suspend_date <> '' "
|
||||||
"AND suspend_date <= ? "
|
"AND suspend_date <= ? "
|
||||||
"AND (resume_date IS NULL OR resume_date = '' OR resume_date > ?) "
|
"AND (resume_date IS NULL OR resume_date = '' OR resume_date > ?) "
|
||||||
"LIMIT 1"
|
"LIMIT 1"
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user