update
This commit is contained in:
parent
b4bd9fc9c5
commit
f29bb99b68
@ -1,14 +1,15 @@
|
|||||||
"""Reinforcement-learning style environment wrapping the backtest engine."""
|
"""Reinforcement-learning style environment wrapping the backtest engine."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, replace
|
import json
|
||||||
from typing import Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
|
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
from dataclasses import dataclass, replace
|
||||||
|
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
|
||||||
|
|
||||||
from .engine import BacktestEngine, BacktestResult, BtConfig
|
from .engine import BacktestEngine, BacktestResult, BtConfig
|
||||||
from app.agents.game import Decision
|
from app.agents.game import Decision
|
||||||
from app.agents.registry import weight_map
|
from app.agents.registry import weight_map
|
||||||
|
from app.utils.db import db_session
|
||||||
from app.utils.logging import get_logger
|
from app.utils.logging import get_logger
|
||||||
|
|
||||||
LOGGER = get_logger(__name__)
|
LOGGER = get_logger(__name__)
|
||||||
@ -37,6 +38,7 @@ class EpisodeMetrics:
|
|||||||
nav_series: List[Dict[str, float]]
|
nav_series: List[Dict[str, float]]
|
||||||
trades: List[Dict[str, object]]
|
trades: List[Dict[str, object]]
|
||||||
turnover: float
|
turnover: float
|
||||||
|
turnover_value: float
|
||||||
trade_count: int
|
trade_count: int
|
||||||
risk_count: int
|
risk_count: int
|
||||||
risk_breakdown: Dict[str, int]
|
risk_breakdown: Dict[str, int]
|
||||||
@ -97,6 +99,8 @@ class DecisionEnv:
|
|||||||
if self._disable_departments:
|
if self._disable_departments:
|
||||||
engine.department_manager = None
|
engine.department_manager = None
|
||||||
|
|
||||||
|
self._clear_portfolio_records()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = engine.run()
|
result = engine.run()
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
@ -104,7 +108,12 @@ class DecisionEnv:
|
|||||||
info = {"error": str(exc)}
|
info = {"error": str(exc)}
|
||||||
return {"failure": 1.0}, -1.0, True, info
|
return {"failure": 1.0}, -1.0, True, info
|
||||||
|
|
||||||
metrics = self._compute_metrics(result)
|
snapshots, trades_override = self._fetch_portfolio_records()
|
||||||
|
metrics = self._compute_metrics(
|
||||||
|
result,
|
||||||
|
nav_override=snapshots if snapshots else None,
|
||||||
|
trades_override=trades_override if trades_override else None,
|
||||||
|
)
|
||||||
reward = float(self._reward_fn(metrics))
|
reward = float(self._reward_fn(metrics))
|
||||||
self._last_metrics = metrics
|
self._last_metrics = metrics
|
||||||
|
|
||||||
@ -114,6 +123,7 @@ class DecisionEnv:
|
|||||||
"volatility": metrics.volatility,
|
"volatility": metrics.volatility,
|
||||||
"sharpe_like": metrics.sharpe_like,
|
"sharpe_like": metrics.sharpe_like,
|
||||||
"turnover": metrics.turnover,
|
"turnover": metrics.turnover,
|
||||||
|
"turnover_value": metrics.turnover_value,
|
||||||
"trade_count": float(metrics.trade_count),
|
"trade_count": float(metrics.trade_count),
|
||||||
"risk_count": float(metrics.risk_count),
|
"risk_count": float(metrics.risk_count),
|
||||||
}
|
}
|
||||||
@ -123,6 +133,8 @@ class DecisionEnv:
|
|||||||
"weights": weights,
|
"weights": weights,
|
||||||
"risk_breakdown": metrics.risk_breakdown,
|
"risk_breakdown": metrics.risk_breakdown,
|
||||||
"risk_events": getattr(result, "risk_events", []),
|
"risk_events": getattr(result, "risk_events", []),
|
||||||
|
"portfolio_snapshots": snapshots,
|
||||||
|
"portfolio_trades": trades_override,
|
||||||
}
|
}
|
||||||
return observation, reward, True, info
|
return observation, reward, True, info
|
||||||
|
|
||||||
@ -137,8 +149,16 @@ class DecisionEnv:
|
|||||||
LOGGER.debug("暂未支持的参数目标:%s", spec.target, extra=LOG_EXTRA)
|
LOGGER.debug("暂未支持的参数目标:%s", spec.target, extra=LOG_EXTRA)
|
||||||
return weights
|
return weights
|
||||||
|
|
||||||
def _compute_metrics(self, result: BacktestResult) -> EpisodeMetrics:
|
def _compute_metrics(
|
||||||
nav_series = result.nav_series or []
|
self,
|
||||||
|
result: BacktestResult,
|
||||||
|
*,
|
||||||
|
nav_override: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
trades_override: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
) -> EpisodeMetrics:
|
||||||
|
nav_series = nav_override if nav_override is not None else result.nav_series or []
|
||||||
|
trades = trades_override if trades_override is not None else result.trades
|
||||||
|
|
||||||
if not nav_series:
|
if not nav_series:
|
||||||
risk_breakdown: Dict[str, int] = {}
|
risk_breakdown: Dict[str, int] = {}
|
||||||
for event in getattr(result, "risk_events", []) or []:
|
for event in getattr(result, "risk_events", []) or []:
|
||||||
@ -149,9 +169,10 @@ class DecisionEnv:
|
|||||||
max_drawdown=0.0,
|
max_drawdown=0.0,
|
||||||
volatility=0.0,
|
volatility=0.0,
|
||||||
nav_series=[],
|
nav_series=[],
|
||||||
trades=result.trades,
|
trades=trades or [],
|
||||||
turnover=0.0,
|
turnover=0.0,
|
||||||
trade_count=len(result.trades or []),
|
turnover_value=0.0,
|
||||||
|
trade_count=len(trades or []),
|
||||||
risk_count=len(getattr(result, "risk_events", []) or []),
|
risk_count=len(getattr(result, "risk_events", []) or []),
|
||||||
risk_breakdown=risk_breakdown,
|
risk_breakdown=risk_breakdown,
|
||||||
)
|
)
|
||||||
@ -181,7 +202,25 @@ class DecisionEnv:
|
|||||||
else:
|
else:
|
||||||
volatility = 0.0
|
volatility = 0.0
|
||||||
|
|
||||||
turnover = sum(float(row.get("turnover", 0.0) or 0.0) for row in nav_series)
|
turnover_value = 0.0
|
||||||
|
turnover_ratios: List[float] = []
|
||||||
|
for row in nav_series:
|
||||||
|
turnover_raw = float(row.get("turnover", 0.0) or 0.0)
|
||||||
|
turnover_value += turnover_raw
|
||||||
|
ratio = row.get("turnover_ratio")
|
||||||
|
if ratio is not None:
|
||||||
|
try:
|
||||||
|
turnover_ratios.append(float(ratio))
|
||||||
|
continue
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
turnover_ratios.append(0.0)
|
||||||
|
continue
|
||||||
|
nav_val = float(row.get("nav", 0.0) or 0.0)
|
||||||
|
if nav_val > 0:
|
||||||
|
turnover_ratios.append(turnover_raw / nav_val)
|
||||||
|
else:
|
||||||
|
turnover_ratios.append(0.0)
|
||||||
|
avg_turnover_ratio = sum(turnover_ratios) / len(turnover_ratios) if turnover_ratios else 0.0
|
||||||
risk_events = getattr(result, "risk_events", []) or []
|
risk_events = getattr(result, "risk_events", []) or []
|
||||||
risk_breakdown: Dict[str, int] = {}
|
risk_breakdown: Dict[str, int] = {}
|
||||||
for event in risk_events:
|
for event in risk_events:
|
||||||
@ -193,9 +232,10 @@ class DecisionEnv:
|
|||||||
max_drawdown=float(max_drawdown),
|
max_drawdown=float(max_drawdown),
|
||||||
volatility=volatility,
|
volatility=volatility,
|
||||||
nav_series=nav_series,
|
nav_series=nav_series,
|
||||||
trades=result.trades,
|
trades=trades or [],
|
||||||
turnover=float(turnover),
|
turnover=float(avg_turnover_ratio),
|
||||||
trade_count=len(result.trades or []),
|
turnover_value=float(turnover_value),
|
||||||
|
trade_count=len(trades or []),
|
||||||
risk_count=len(risk_events),
|
risk_count=len(risk_events),
|
||||||
risk_breakdown=risk_breakdown,
|
risk_breakdown=risk_breakdown,
|
||||||
)
|
)
|
||||||
@ -203,7 +243,7 @@ class DecisionEnv:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _default_reward(metrics: EpisodeMetrics) -> float:
|
def _default_reward(metrics: EpisodeMetrics) -> float:
|
||||||
risk_penalty = 0.05 * metrics.risk_count
|
risk_penalty = 0.05 * metrics.risk_count
|
||||||
turnover_penalty = 0.00001 * metrics.turnover
|
turnover_penalty = 0.1 * metrics.turnover
|
||||||
penalty = 0.5 * metrics.max_drawdown + risk_penalty + turnover_penalty
|
penalty = 0.5 * metrics.max_drawdown + risk_penalty + turnover_penalty
|
||||||
return metrics.total_return - penalty
|
return metrics.total_return - penalty
|
||||||
|
|
||||||
@ -214,3 +254,100 @@ class DecisionEnv:
|
|||||||
@property
|
@property
|
||||||
def last_action(self) -> Optional[Tuple[float, ...]]:
|
def last_action(self) -> Optional[Tuple[float, ...]]:
|
||||||
return self._last_action
|
return self._last_action
|
||||||
|
|
||||||
|
def _clear_portfolio_records(self) -> None:
|
||||||
|
start = self._template_cfg.start_date.isoformat()
|
||||||
|
end = self._template_cfg.end_date.isoformat()
|
||||||
|
try:
|
||||||
|
with db_session() as conn:
|
||||||
|
conn.execute("DELETE FROM portfolio_positions")
|
||||||
|
conn.execute(
|
||||||
|
"DELETE FROM portfolio_snapshots WHERE trade_date BETWEEN ? AND ?",
|
||||||
|
(start, end),
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
"DELETE FROM portfolio_trades WHERE trade_date BETWEEN ? AND ?",
|
||||||
|
(start, end),
|
||||||
|
)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
LOGGER.exception("清理投资组合记录失败", extra=LOG_EXTRA)
|
||||||
|
|
||||||
|
def _fetch_portfolio_records(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||||
|
start = self._template_cfg.start_date.isoformat()
|
||||||
|
end = self._template_cfg.end_date.isoformat()
|
||||||
|
snapshots: List[Dict[str, Any]] = []
|
||||||
|
trades: List[Dict[str, Any]] = []
|
||||||
|
try:
|
||||||
|
with db_session(read_only=True) as conn:
|
||||||
|
snapshot_rows = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT trade_date, total_value, cash, invested_value,
|
||||||
|
unrealized_pnl, realized_pnl, net_flow, exposure, metadata
|
||||||
|
FROM portfolio_snapshots
|
||||||
|
WHERE trade_date BETWEEN ? AND ?
|
||||||
|
ORDER BY trade_date
|
||||||
|
""",
|
||||||
|
(start, end),
|
||||||
|
).fetchall()
|
||||||
|
trade_rows = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT id, trade_date, ts_code, action, quantity, price, fee, source, metadata
|
||||||
|
FROM portfolio_trades
|
||||||
|
WHERE trade_date BETWEEN ? AND ?
|
||||||
|
ORDER BY trade_date, id
|
||||||
|
""",
|
||||||
|
(start, end),
|
||||||
|
).fetchall()
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
LOGGER.exception("读取投资组合记录失败", extra=LOG_EXTRA)
|
||||||
|
return snapshots, trades
|
||||||
|
|
||||||
|
for row in snapshot_rows:
|
||||||
|
metadata = self._loads(row["metadata"], {})
|
||||||
|
snapshots.append(
|
||||||
|
{
|
||||||
|
"trade_date": row["trade_date"],
|
||||||
|
"nav": float(row["total_value"] or 0.0),
|
||||||
|
"cash": float(row["cash"] or 0.0),
|
||||||
|
"market_value": float(row["invested_value"] or 0.0),
|
||||||
|
"unrealized_pnl": float(row["unrealized_pnl"] or 0.0),
|
||||||
|
"realized_pnl": float(row["realized_pnl"] or 0.0),
|
||||||
|
"net_flow": float(row["net_flow"] or 0.0),
|
||||||
|
"exposure": float(row["exposure"] or 0.0),
|
||||||
|
"turnover": float(metadata.get("turnover_value", 0.0) or 0.0),
|
||||||
|
"turnover_ratio": float(metadata.get("turnover_ratio", 0.0) or 0.0),
|
||||||
|
"holdings": metadata.get("holdings"),
|
||||||
|
"trade_count": metadata.get("trade_count"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
for row in trade_rows:
|
||||||
|
metadata = self._loads(row["metadata"], {})
|
||||||
|
trades.append(
|
||||||
|
{
|
||||||
|
"id": row["id"],
|
||||||
|
"trade_date": row["trade_date"],
|
||||||
|
"ts_code": row["ts_code"],
|
||||||
|
"action": row["action"],
|
||||||
|
"quantity": float(row["quantity"] or 0.0),
|
||||||
|
"price": float(row["price"] or 0.0),
|
||||||
|
"fee": float(row["fee"] or 0.0),
|
||||||
|
"source": row["source"],
|
||||||
|
"metadata": metadata,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return snapshots, trades
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _loads(payload: Any, default: Any) -> Any:
|
||||||
|
if not payload:
|
||||||
|
return default
|
||||||
|
if isinstance(payload, (dict, list)):
|
||||||
|
return payload
|
||||||
|
if isinstance(payload, str):
|
||||||
|
try:
|
||||||
|
return json.loads(payload)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return default
|
||||||
|
return default
|
||||||
|
|||||||
@ -704,6 +704,7 @@ class BacktestEngine:
|
|||||||
unrealized_pnl += (price - cost_basis) * qty
|
unrealized_pnl += (price - cost_basis) * qty
|
||||||
|
|
||||||
nav = state.cash + market_value
|
nav = state.cash + market_value
|
||||||
|
turnover_ratio = daily_turnover / nav if nav else 0.0
|
||||||
result.nav_series.append(
|
result.nav_series.append(
|
||||||
{
|
{
|
||||||
"trade_date": trade_date_str,
|
"trade_date": trade_date_str,
|
||||||
@ -713,6 +714,7 @@ class BacktestEngine:
|
|||||||
"realized_pnl": state.realized_pnl,
|
"realized_pnl": state.realized_pnl,
|
||||||
"unrealized_pnl": unrealized_pnl,
|
"unrealized_pnl": unrealized_pnl,
|
||||||
"turnover": daily_turnover,
|
"turnover": daily_turnover,
|
||||||
|
"turnover_ratio": turnover_ratio,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
if executed_trades:
|
if executed_trades:
|
||||||
@ -817,11 +819,26 @@ class BacktestEngine:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
total_value = market_value + state.cash
|
||||||
|
turnover_ratio = daily_turnover / total_value if total_value else 0.0
|
||||||
snapshot_metadata = {
|
snapshot_metadata = {
|
||||||
"holdings": len(state.holdings),
|
"holdings": len(state.holdings),
|
||||||
"turnover_value": daily_turnover,
|
"turnover_value": daily_turnover,
|
||||||
|
"turnover_ratio": turnover_ratio,
|
||||||
|
"trade_count": len(trades),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
exposure = (market_value / total_value) if total_value else 0.0
|
||||||
|
net_flow = 0.0
|
||||||
|
for trade in trades:
|
||||||
|
value = float(trade.get("value", 0.0) or 0.0)
|
||||||
|
fee = float(trade.get("fee", 0.0) or 0.0)
|
||||||
|
action = str(trade.get("action", "")).lower()
|
||||||
|
if action.startswith("buy"):
|
||||||
|
net_flow -= value + fee
|
||||||
|
elif action.startswith("sell"):
|
||||||
|
net_flow += value - fee
|
||||||
|
|
||||||
with db_session() as conn:
|
with db_session() as conn:
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"""
|
"""
|
||||||
@ -836,8 +853,8 @@ class BacktestEngine:
|
|||||||
market_value,
|
market_value,
|
||||||
unrealized_pnl,
|
unrealized_pnl,
|
||||||
state.realized_pnl,
|
state.realized_pnl,
|
||||||
None,
|
net_flow,
|
||||||
None,
|
exposure,
|
||||||
None,
|
None,
|
||||||
json.dumps(snapshot_metadata, ensure_ascii=False),
|
json.dumps(snapshot_metadata, ensure_ascii=False),
|
||||||
),
|
),
|
||||||
|
|||||||
@ -131,6 +131,7 @@ def _metrics_to_dict(metrics: EpisodeMetrics) -> Dict[str, float | Dict[str, int
|
|||||||
"volatility": metrics.volatility,
|
"volatility": metrics.volatility,
|
||||||
"sharpe_like": metrics.sharpe_like,
|
"sharpe_like": metrics.sharpe_like,
|
||||||
"turnover": metrics.turnover,
|
"turnover": metrics.turnover,
|
||||||
|
"turnover_value": metrics.turnover_value,
|
||||||
"trade_count": float(metrics.trade_count),
|
"trade_count": float(metrics.trade_count),
|
||||||
"risk_count": float(metrics.risk_count),
|
"risk_count": float(metrics.risk_count),
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1563,6 +1563,9 @@ def render_log_viewer() -> None:
|
|||||||
"weights": info.get("weights", {}),
|
"weights": info.get("weights", {}),
|
||||||
"nav_series": info.get("nav_series"),
|
"nav_series": info.get("nav_series"),
|
||||||
"trades": info.get("trades"),
|
"trades": info.get("trades"),
|
||||||
|
"portfolio_snapshots": info.get("portfolio_snapshots"),
|
||||||
|
"portfolio_trades": info.get("portfolio_trades"),
|
||||||
|
"risk_breakdown": info.get("risk_breakdown"),
|
||||||
"selected_agents": list(selected_agents),
|
"selected_agents": list(selected_agents),
|
||||||
"action_values": list(action_values),
|
"action_values": list(action_values),
|
||||||
"experiment_id": resolved_experiment_id,
|
"experiment_id": resolved_experiment_id,
|
||||||
@ -1587,6 +1590,14 @@ def render_log_viewer() -> None:
|
|||||||
col_metrics[2].metric("波动率", f"{observation.get('volatility', 0.0):+.2%}")
|
col_metrics[2].metric("波动率", f"{observation.get('volatility', 0.0):+.2%}")
|
||||||
col_metrics[3].metric("奖励", f"{reward:+.4f}")
|
col_metrics[3].metric("奖励", f"{reward:+.4f}")
|
||||||
|
|
||||||
|
turnover_ratio = float(observation.get("turnover", 0.0) or 0.0)
|
||||||
|
turnover_value = float(observation.get("turnover_value", 0.0) or 0.0)
|
||||||
|
risk_count = float(observation.get("risk_count", 0.0) or 0.0)
|
||||||
|
col_metrics_extra = st.columns(3)
|
||||||
|
col_metrics_extra[0].metric("平均换手率", f"{turnover_ratio:.2%}")
|
||||||
|
col_metrics_extra[1].metric("成交额", f"{turnover_value:,.0f}")
|
||||||
|
col_metrics_extra[2].metric("风险事件数", f"{int(risk_count)}")
|
||||||
|
|
||||||
weights_dict = single_result.get("weights") or {}
|
weights_dict = single_result.get("weights") or {}
|
||||||
if weights_dict:
|
if weights_dict:
|
||||||
st.write("调参后权重:")
|
st.write("调参后权重:")
|
||||||
@ -1620,6 +1631,21 @@ def render_log_viewer() -> None:
|
|||||||
st.write("成交记录:")
|
st.write("成交记录:")
|
||||||
st.dataframe(pd.DataFrame(trades), hide_index=True, width='stretch')
|
st.dataframe(pd.DataFrame(trades), hide_index=True, width='stretch')
|
||||||
|
|
||||||
|
snapshots = single_result.get("portfolio_snapshots") or []
|
||||||
|
if snapshots:
|
||||||
|
with st.expander("投资组合快照", expanded=False):
|
||||||
|
st.dataframe(pd.DataFrame(snapshots), hide_index=True, width='stretch')
|
||||||
|
|
||||||
|
portfolio_trades = single_result.get("portfolio_trades") or []
|
||||||
|
if portfolio_trades:
|
||||||
|
with st.expander("组合成交明细", expanded=False):
|
||||||
|
st.dataframe(pd.DataFrame(portfolio_trades), hide_index=True, width='stretch')
|
||||||
|
|
||||||
|
risk_breakdown = single_result.get("risk_breakdown") or {}
|
||||||
|
if risk_breakdown:
|
||||||
|
with st.expander("风险事件统计", expanded=False):
|
||||||
|
st.json(risk_breakdown)
|
||||||
|
|
||||||
if st.button("清除单次调参结果", key="clear_decision_env_single"):
|
if st.button("清除单次调参结果", key="clear_decision_env_single"):
|
||||||
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
|
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
|
||||||
st.success("已清除单次调参结果缓存。")
|
st.success("已清除单次调参结果缓存。")
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
"""Initialize portfolio database tables."""
|
"""Initialize portfolio database tables."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from .logging import get_logger
|
from .logging import get_logger
|
||||||
@ -78,7 +77,7 @@ SCHEMA_STATEMENTS = [
|
|||||||
tags TEXT, -- JSON array
|
tags TEXT, -- JSON array
|
||||||
metadata TEXT, -- JSON object
|
metadata TEXT, -- JSON object
|
||||||
PRIMARY KEY (trade_date, ts_code)
|
PRIMARY KEY (trade_date, ts_code)
|
||||||
)
|
);
|
||||||
""",
|
""",
|
||||||
|
|
||||||
# 数据获取任务表
|
# 数据获取任务表
|
||||||
@ -91,7 +90,6 @@ SCHEMA_STATEMENTS = [
|
|||||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
error_msg TEXT,
|
error_msg TEXT,
|
||||||
metadata TEXT -- JSON object for additional info
|
metadata TEXT -- JSON object for additional info
|
||||||
)
|
|
||||||
);
|
);
|
||||||
""",
|
""",
|
||||||
|
|
||||||
|
|||||||
@ -37,7 +37,8 @@ class DummyEnv:
|
|||||||
volatility=0.05,
|
volatility=0.05,
|
||||||
nav_series=[],
|
nav_series=[],
|
||||||
trades=[],
|
trades=[],
|
||||||
turnover=100.0,
|
turnover=0.1,
|
||||||
|
turnover_value=1000.0,
|
||||||
trade_count=0,
|
trade_count=0,
|
||||||
risk_count=1,
|
risk_count=1,
|
||||||
risk_breakdown={"test": 1},
|
risk_breakdown={"test": 1},
|
||||||
@ -48,7 +49,8 @@ class DummyEnv:
|
|||||||
"max_drawdown": 0.1,
|
"max_drawdown": 0.1,
|
||||||
"volatility": 0.05,
|
"volatility": 0.05,
|
||||||
"sharpe_like": reward / 0.05,
|
"sharpe_like": reward / 0.05,
|
||||||
"turnover": 100.0,
|
"turnover": 0.1,
|
||||||
|
"turnover_value": 1000.0,
|
||||||
"trade_count": 0.0,
|
"trade_count": 0.0,
|
||||||
"risk_count": 1.0,
|
"risk_count": 1.0,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -26,6 +26,7 @@ class _StubEngine:
|
|||||||
"realized_pnl": 1.0,
|
"realized_pnl": 1.0,
|
||||||
"unrealized_pnl": 1.0,
|
"unrealized_pnl": 1.0,
|
||||||
"turnover": 20000.0,
|
"turnover": 20000.0,
|
||||||
|
"turnover_ratio": 0.2,
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
result.trades = [
|
result.trades = [
|
||||||
@ -65,14 +66,18 @@ def test_decision_env_returns_risk_metrics(monkeypatch):
|
|||||||
env = DecisionEnv(bt_config=cfg, parameter_specs=specs, baseline_weights={"A_mom": 0.5})
|
env = DecisionEnv(bt_config=cfg, parameter_specs=specs, baseline_weights={"A_mom": 0.5})
|
||||||
|
|
||||||
monkeypatch.setattr("app.backtest.decision_env.BacktestEngine", _StubEngine)
|
monkeypatch.setattr("app.backtest.decision_env.BacktestEngine", _StubEngine)
|
||||||
|
monkeypatch.setattr(DecisionEnv, "_clear_portfolio_records", lambda self: None)
|
||||||
|
monkeypatch.setattr(DecisionEnv, "_fetch_portfolio_records", lambda self: ([], []))
|
||||||
|
|
||||||
obs, reward, done, info = env.step([0.8])
|
obs, reward, done, info = env.step([0.8])
|
||||||
|
|
||||||
assert done is True
|
assert done is True
|
||||||
assert "risk_count" in obs and obs["risk_count"] == 1.0
|
assert "risk_count" in obs and obs["risk_count"] == 1.0
|
||||||
assert obs["turnover"] == pytest.approx(20000.0)
|
assert obs["turnover"] == pytest.approx(0.2)
|
||||||
|
assert obs["turnover_value"] == pytest.approx(20000.0)
|
||||||
assert info["risk_events"][0]["reason"] == "limit_up"
|
assert info["risk_events"][0]["reason"] == "limit_up"
|
||||||
assert info["risk_breakdown"]["limit_up"] == 1
|
assert info["risk_breakdown"]["limit_up"] == 1
|
||||||
|
assert info["nav_series"][0]["turnover_ratio"] == pytest.approx(0.2)
|
||||||
assert reward < obs["total_return"]
|
assert reward < obs["total_return"]
|
||||||
|
|
||||||
|
|
||||||
@ -83,10 +88,11 @@ def test_default_reward_penalizes_metrics():
|
|||||||
volatility=0.05,
|
volatility=0.05,
|
||||||
nav_series=[],
|
nav_series=[],
|
||||||
trades=[],
|
trades=[],
|
||||||
turnover=1000.0,
|
turnover=0.3,
|
||||||
|
turnover_value=5000.0,
|
||||||
trade_count=0,
|
trade_count=0,
|
||||||
risk_count=2,
|
risk_count=2,
|
||||||
risk_breakdown={"foo": 2},
|
risk_breakdown={"foo": 2},
|
||||||
)
|
)
|
||||||
reward = DecisionEnv._default_reward(metrics)
|
reward = DecisionEnv._default_reward(metrics)
|
||||||
assert reward == pytest.approx(0.1 - (0.5 * 0.2 + 0.05 * 2 + 0.00001 * 1000.0))
|
assert reward == pytest.approx(0.1 - (0.5 * 0.2 + 0.05 * 2 + 0.1 * 0.3))
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user