update
This commit is contained in:
parent
b4bd9fc9c5
commit
f29bb99b68
@ -1,14 +1,15 @@
|
||||
"""Reinforcement-learning style environment wrapping the backtest engine."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
|
||||
|
||||
import json
|
||||
import math
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
|
||||
|
||||
from .engine import BacktestEngine, BacktestResult, BtConfig
|
||||
from app.agents.game import Decision
|
||||
from app.agents.registry import weight_map
|
||||
from app.utils.db import db_session
|
||||
from app.utils.logging import get_logger
|
||||
|
||||
LOGGER = get_logger(__name__)
|
||||
@ -37,6 +38,7 @@ class EpisodeMetrics:
|
||||
nav_series: List[Dict[str, float]]
|
||||
trades: List[Dict[str, object]]
|
||||
turnover: float
|
||||
turnover_value: float
|
||||
trade_count: int
|
||||
risk_count: int
|
||||
risk_breakdown: Dict[str, int]
|
||||
@ -97,6 +99,8 @@ class DecisionEnv:
|
||||
if self._disable_departments:
|
||||
engine.department_manager = None
|
||||
|
||||
self._clear_portfolio_records()
|
||||
|
||||
try:
|
||||
result = engine.run()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
@ -104,7 +108,12 @@ class DecisionEnv:
|
||||
info = {"error": str(exc)}
|
||||
return {"failure": 1.0}, -1.0, True, info
|
||||
|
||||
metrics = self._compute_metrics(result)
|
||||
snapshots, trades_override = self._fetch_portfolio_records()
|
||||
metrics = self._compute_metrics(
|
||||
result,
|
||||
nav_override=snapshots if snapshots else None,
|
||||
trades_override=trades_override if trades_override else None,
|
||||
)
|
||||
reward = float(self._reward_fn(metrics))
|
||||
self._last_metrics = metrics
|
||||
|
||||
@ -114,6 +123,7 @@ class DecisionEnv:
|
||||
"volatility": metrics.volatility,
|
||||
"sharpe_like": metrics.sharpe_like,
|
||||
"turnover": metrics.turnover,
|
||||
"turnover_value": metrics.turnover_value,
|
||||
"trade_count": float(metrics.trade_count),
|
||||
"risk_count": float(metrics.risk_count),
|
||||
}
|
||||
@ -123,6 +133,8 @@ class DecisionEnv:
|
||||
"weights": weights,
|
||||
"risk_breakdown": metrics.risk_breakdown,
|
||||
"risk_events": getattr(result, "risk_events", []),
|
||||
"portfolio_snapshots": snapshots,
|
||||
"portfolio_trades": trades_override,
|
||||
}
|
||||
return observation, reward, True, info
|
||||
|
||||
@ -137,8 +149,16 @@ class DecisionEnv:
|
||||
LOGGER.debug("暂未支持的参数目标:%s", spec.target, extra=LOG_EXTRA)
|
||||
return weights
|
||||
|
||||
def _compute_metrics(self, result: BacktestResult) -> EpisodeMetrics:
|
||||
nav_series = result.nav_series or []
|
||||
def _compute_metrics(
|
||||
self,
|
||||
result: BacktestResult,
|
||||
*,
|
||||
nav_override: Optional[List[Dict[str, Any]]] = None,
|
||||
trades_override: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> EpisodeMetrics:
|
||||
nav_series = nav_override if nav_override is not None else result.nav_series or []
|
||||
trades = trades_override if trades_override is not None else result.trades
|
||||
|
||||
if not nav_series:
|
||||
risk_breakdown: Dict[str, int] = {}
|
||||
for event in getattr(result, "risk_events", []) or []:
|
||||
@ -149,9 +169,10 @@ class DecisionEnv:
|
||||
max_drawdown=0.0,
|
||||
volatility=0.0,
|
||||
nav_series=[],
|
||||
trades=result.trades,
|
||||
trades=trades or [],
|
||||
turnover=0.0,
|
||||
trade_count=len(result.trades or []),
|
||||
turnover_value=0.0,
|
||||
trade_count=len(trades or []),
|
||||
risk_count=len(getattr(result, "risk_events", []) or []),
|
||||
risk_breakdown=risk_breakdown,
|
||||
)
|
||||
@ -181,7 +202,25 @@ class DecisionEnv:
|
||||
else:
|
||||
volatility = 0.0
|
||||
|
||||
turnover = sum(float(row.get("turnover", 0.0) or 0.0) for row in nav_series)
|
||||
turnover_value = 0.0
|
||||
turnover_ratios: List[float] = []
|
||||
for row in nav_series:
|
||||
turnover_raw = float(row.get("turnover", 0.0) or 0.0)
|
||||
turnover_value += turnover_raw
|
||||
ratio = row.get("turnover_ratio")
|
||||
if ratio is not None:
|
||||
try:
|
||||
turnover_ratios.append(float(ratio))
|
||||
continue
|
||||
except (TypeError, ValueError):
|
||||
turnover_ratios.append(0.0)
|
||||
continue
|
||||
nav_val = float(row.get("nav", 0.0) or 0.0)
|
||||
if nav_val > 0:
|
||||
turnover_ratios.append(turnover_raw / nav_val)
|
||||
else:
|
||||
turnover_ratios.append(0.0)
|
||||
avg_turnover_ratio = sum(turnover_ratios) / len(turnover_ratios) if turnover_ratios else 0.0
|
||||
risk_events = getattr(result, "risk_events", []) or []
|
||||
risk_breakdown: Dict[str, int] = {}
|
||||
for event in risk_events:
|
||||
@ -193,9 +232,10 @@ class DecisionEnv:
|
||||
max_drawdown=float(max_drawdown),
|
||||
volatility=volatility,
|
||||
nav_series=nav_series,
|
||||
trades=result.trades,
|
||||
turnover=float(turnover),
|
||||
trade_count=len(result.trades or []),
|
||||
trades=trades or [],
|
||||
turnover=float(avg_turnover_ratio),
|
||||
turnover_value=float(turnover_value),
|
||||
trade_count=len(trades or []),
|
||||
risk_count=len(risk_events),
|
||||
risk_breakdown=risk_breakdown,
|
||||
)
|
||||
@ -203,7 +243,7 @@ class DecisionEnv:
|
||||
@staticmethod
|
||||
def _default_reward(metrics: EpisodeMetrics) -> float:
|
||||
risk_penalty = 0.05 * metrics.risk_count
|
||||
turnover_penalty = 0.00001 * metrics.turnover
|
||||
turnover_penalty = 0.1 * metrics.turnover
|
||||
penalty = 0.5 * metrics.max_drawdown + risk_penalty + turnover_penalty
|
||||
return metrics.total_return - penalty
|
||||
|
||||
@ -214,3 +254,100 @@ class DecisionEnv:
|
||||
@property
|
||||
def last_action(self) -> Optional[Tuple[float, ...]]:
|
||||
return self._last_action
|
||||
|
||||
def _clear_portfolio_records(self) -> None:
    """Delete previously persisted portfolio rows.

    Wipes the whole ``portfolio_positions`` table and removes the rows of
    ``portfolio_snapshots`` and ``portfolio_trades`` whose ``trade_date``
    falls between the template config's start and end dates (ISO strings).
    Best-effort: any failure is logged and swallowed.
    """
    first_day = self._template_cfg.start_date.isoformat()
    last_day = self._template_cfg.end_date.isoformat()
    date_range = (first_day, last_day)
    ranged_deletes = (
        "DELETE FROM portfolio_snapshots WHERE trade_date BETWEEN ? AND ?",
        "DELETE FROM portfolio_trades WHERE trade_date BETWEEN ? AND ?",
    )
    try:
        with db_session() as conn:
            # Positions are removed without any date filter.
            conn.execute("DELETE FROM portfolio_positions")
            for statement in ranged_deletes:
                conn.execute(statement, date_range)
    except Exception:  # noqa: BLE001
        LOGGER.exception("清理投资组合记录失败", extra=LOG_EXTRA)
|
||||
|
||||
def _fetch_portfolio_records(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Read persisted portfolio snapshots and trades for the configured range.

    Rows are selected from ``portfolio_snapshots`` and ``portfolio_trades``
    where ``trade_date`` lies between the template config's start and end
    dates (ISO strings), ordered by date (trades additionally by ``id``).

    Returns:
        ``(snapshots, trades)`` as lists of plain dicts. Snapshot dicts
        expose ``nav``, ``cash``, ``market_value``, PnL fields, ``net_flow``,
        ``exposure`` plus ``turnover``, ``turnover_ratio``, ``holdings`` and
        ``trade_count`` taken from the JSON ``metadata`` column.
        ``turnover_ratio`` is ``None`` when the metadata does not carry a
        usable value, so consumers can tell "missing" from a real ``0.0``.
        Both lists are empty when the database read fails (the error is
        logged, not raised).
    """

    def _number(value: Any, fallback: Optional[float]) -> Optional[float]:
        # Coerce a metadata value to float; missing or malformed -> fallback.
        if value is None:
            return fallback
        try:
            return float(value)
        except (TypeError, ValueError):
            return fallback

    start = self._template_cfg.start_date.isoformat()
    end = self._template_cfg.end_date.isoformat()
    snapshots: List[Dict[str, Any]] = []
    trades: List[Dict[str, Any]] = []
    try:
        with db_session(read_only=True) as conn:
            snapshot_rows = conn.execute(
                """
                SELECT trade_date, total_value, cash, invested_value,
                       unrealized_pnl, realized_pnl, net_flow, exposure, metadata
                FROM portfolio_snapshots
                WHERE trade_date BETWEEN ? AND ?
                ORDER BY trade_date
                """,
                (start, end),
            ).fetchall()
            trade_rows = conn.execute(
                """
                SELECT id, trade_date, ts_code, action, quantity, price, fee, source, metadata
                FROM portfolio_trades
                WHERE trade_date BETWEEN ? AND ?
                ORDER BY trade_date, id
                """,
                (start, end),
            ).fetchall()
    except Exception:  # noqa: BLE001
        LOGGER.exception("读取投资组合记录失败", extra=LOG_EXTRA)
        return snapshots, trades

    for row in snapshot_rows:
        metadata = self._loads(row["metadata"], {})
        if not isinstance(metadata, dict):
            # _loads passes JSON arrays through unchanged; only objects
            # support the key lookups below.
            metadata = {}
        snapshots.append(
            {
                "trade_date": row["trade_date"],
                "nav": float(row["total_value"] or 0.0),
                "cash": float(row["cash"] or 0.0),
                "market_value": float(row["invested_value"] or 0.0),
                "unrealized_pnl": float(row["unrealized_pnl"] or 0.0),
                "realized_pnl": float(row["realized_pnl"] or 0.0),
                "net_flow": float(row["net_flow"] or 0.0),
                "exposure": float(row["exposure"] or 0.0),
                "turnover": _number(metadata.get("turnover_value"), 0.0),
                # Keep None when absent instead of forcing 0.0, so a
                # missing ratio is not mistaken for "no turnover".
                "turnover_ratio": _number(metadata.get("turnover_ratio"), None),
                "holdings": metadata.get("holdings"),
                "trade_count": metadata.get("trade_count"),
            }
        )

    for row in trade_rows:
        trades.append(
            {
                "id": row["id"],
                "trade_date": row["trade_date"],
                "ts_code": row["ts_code"],
                "action": row["action"],
                "quantity": float(row["quantity"] or 0.0),
                "price": float(row["price"] or 0.0),
                "fee": float(row["fee"] or 0.0),
                "source": row["source"],
                # Trade metadata is forwarded as decoded, whatever its shape.
                "metadata": self._loads(row["metadata"], {}),
            }
        )

    return snapshots, trades
|
||||
|
||||
@staticmethod
|
||||
def _loads(payload: Any, default: Any) -> Any:
|
||||
if not payload:
|
||||
return default
|
||||
if isinstance(payload, (dict, list)):
|
||||
return payload
|
||||
if isinstance(payload, str):
|
||||
try:
|
||||
return json.loads(payload)
|
||||
except json.JSONDecodeError:
|
||||
return default
|
||||
return default
|
||||
|
||||
@ -704,6 +704,7 @@ class BacktestEngine:
|
||||
unrealized_pnl += (price - cost_basis) * qty
|
||||
|
||||
nav = state.cash + market_value
|
||||
turnover_ratio = daily_turnover / nav if nav else 0.0
|
||||
result.nav_series.append(
|
||||
{
|
||||
"trade_date": trade_date_str,
|
||||
@ -713,6 +714,7 @@ class BacktestEngine:
|
||||
"realized_pnl": state.realized_pnl,
|
||||
"unrealized_pnl": unrealized_pnl,
|
||||
"turnover": daily_turnover,
|
||||
"turnover_ratio": turnover_ratio,
|
||||
}
|
||||
)
|
||||
if executed_trades:
|
||||
@ -817,11 +819,26 @@ class BacktestEngine:
|
||||
)
|
||||
)
|
||||
|
||||
total_value = market_value + state.cash
|
||||
turnover_ratio = daily_turnover / total_value if total_value else 0.0
|
||||
snapshot_metadata = {
|
||||
"holdings": len(state.holdings),
|
||||
"turnover_value": daily_turnover,
|
||||
"turnover_ratio": turnover_ratio,
|
||||
"trade_count": len(trades),
|
||||
}
|
||||
|
||||
exposure = (market_value / total_value) if total_value else 0.0
|
||||
net_flow = 0.0
|
||||
for trade in trades:
|
||||
value = float(trade.get("value", 0.0) or 0.0)
|
||||
fee = float(trade.get("fee", 0.0) or 0.0)
|
||||
action = str(trade.get("action", "")).lower()
|
||||
if action.startswith("buy"):
|
||||
net_flow -= value + fee
|
||||
elif action.startswith("sell"):
|
||||
net_flow += value - fee
|
||||
|
||||
with db_session() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
@ -836,8 +853,8 @@ class BacktestEngine:
|
||||
market_value,
|
||||
unrealized_pnl,
|
||||
state.realized_pnl,
|
||||
None,
|
||||
None,
|
||||
net_flow,
|
||||
exposure,
|
||||
None,
|
||||
json.dumps(snapshot_metadata, ensure_ascii=False),
|
||||
),
|
||||
|
||||
@ -131,6 +131,7 @@ def _metrics_to_dict(metrics: EpisodeMetrics) -> Dict[str, float | Dict[str, int
|
||||
"volatility": metrics.volatility,
|
||||
"sharpe_like": metrics.sharpe_like,
|
||||
"turnover": metrics.turnover,
|
||||
"turnover_value": metrics.turnover_value,
|
||||
"trade_count": float(metrics.trade_count),
|
||||
"risk_count": float(metrics.risk_count),
|
||||
}
|
||||
|
||||
@ -1563,6 +1563,9 @@ def render_log_viewer() -> None:
|
||||
"weights": info.get("weights", {}),
|
||||
"nav_series": info.get("nav_series"),
|
||||
"trades": info.get("trades"),
|
||||
"portfolio_snapshots": info.get("portfolio_snapshots"),
|
||||
"portfolio_trades": info.get("portfolio_trades"),
|
||||
"risk_breakdown": info.get("risk_breakdown"),
|
||||
"selected_agents": list(selected_agents),
|
||||
"action_values": list(action_values),
|
||||
"experiment_id": resolved_experiment_id,
|
||||
@ -1587,6 +1590,14 @@ def render_log_viewer() -> None:
|
||||
col_metrics[2].metric("波动率", f"{observation.get('volatility', 0.0):+.2%}")
|
||||
col_metrics[3].metric("奖励", f"{reward:+.4f}")
|
||||
|
||||
turnover_ratio = float(observation.get("turnover", 0.0) or 0.0)
|
||||
turnover_value = float(observation.get("turnover_value", 0.0) or 0.0)
|
||||
risk_count = float(observation.get("risk_count", 0.0) or 0.0)
|
||||
col_metrics_extra = st.columns(3)
|
||||
col_metrics_extra[0].metric("平均换手率", f"{turnover_ratio:.2%}")
|
||||
col_metrics_extra[1].metric("成交额", f"{turnover_value:,.0f}")
|
||||
col_metrics_extra[2].metric("风险事件数", f"{int(risk_count)}")
|
||||
|
||||
weights_dict = single_result.get("weights") or {}
|
||||
if weights_dict:
|
||||
st.write("调参后权重:")
|
||||
@ -1620,6 +1631,21 @@ def render_log_viewer() -> None:
|
||||
st.write("成交记录:")
|
||||
st.dataframe(pd.DataFrame(trades), hide_index=True, width='stretch')
|
||||
|
||||
snapshots = single_result.get("portfolio_snapshots") or []
|
||||
if snapshots:
|
||||
with st.expander("投资组合快照", expanded=False):
|
||||
st.dataframe(pd.DataFrame(snapshots), hide_index=True, width='stretch')
|
||||
|
||||
portfolio_trades = single_result.get("portfolio_trades") or []
|
||||
if portfolio_trades:
|
||||
with st.expander("组合成交明细", expanded=False):
|
||||
st.dataframe(pd.DataFrame(portfolio_trades), hide_index=True, width='stretch')
|
||||
|
||||
risk_breakdown = single_result.get("risk_breakdown") or {}
|
||||
if risk_breakdown:
|
||||
with st.expander("风险事件统计", expanded=False):
|
||||
st.json(risk_breakdown)
|
||||
|
||||
if st.button("清除单次调参结果", key="clear_decision_env_single"):
|
||||
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
|
||||
st.success("已清除单次调参结果缓存。")
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
"""Initialize portfolio database tables."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from .logging import get_logger
|
||||
@ -78,7 +77,7 @@ SCHEMA_STATEMENTS = [
|
||||
tags TEXT, -- JSON array
|
||||
metadata TEXT, -- JSON object
|
||||
PRIMARY KEY (trade_date, ts_code)
|
||||
)
|
||||
);
|
||||
""",
|
||||
|
||||
# 数据获取任务表
|
||||
@ -91,7 +90,6 @@ SCHEMA_STATEMENTS = [
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
error_msg TEXT,
|
||||
metadata TEXT -- JSON object for additional info
|
||||
)
|
||||
);
|
||||
""",
|
||||
|
||||
|
||||
@ -37,7 +37,8 @@ class DummyEnv:
|
||||
volatility=0.05,
|
||||
nav_series=[],
|
||||
trades=[],
|
||||
turnover=100.0,
|
||||
turnover=0.1,
|
||||
turnover_value=1000.0,
|
||||
trade_count=0,
|
||||
risk_count=1,
|
||||
risk_breakdown={"test": 1},
|
||||
@ -48,7 +49,8 @@ class DummyEnv:
|
||||
"max_drawdown": 0.1,
|
||||
"volatility": 0.05,
|
||||
"sharpe_like": reward / 0.05,
|
||||
"turnover": 100.0,
|
||||
"turnover": 0.1,
|
||||
"turnover_value": 1000.0,
|
||||
"trade_count": 0.0,
|
||||
"risk_count": 1.0,
|
||||
}
|
||||
|
||||
@ -26,6 +26,7 @@ class _StubEngine:
|
||||
"realized_pnl": 1.0,
|
||||
"unrealized_pnl": 1.0,
|
||||
"turnover": 20000.0,
|
||||
"turnover_ratio": 0.2,
|
||||
}
|
||||
]
|
||||
result.trades = [
|
||||
@ -65,14 +66,18 @@ def test_decision_env_returns_risk_metrics(monkeypatch):
|
||||
env = DecisionEnv(bt_config=cfg, parameter_specs=specs, baseline_weights={"A_mom": 0.5})
|
||||
|
||||
monkeypatch.setattr("app.backtest.decision_env.BacktestEngine", _StubEngine)
|
||||
monkeypatch.setattr(DecisionEnv, "_clear_portfolio_records", lambda self: None)
|
||||
monkeypatch.setattr(DecisionEnv, "_fetch_portfolio_records", lambda self: ([], []))
|
||||
|
||||
obs, reward, done, info = env.step([0.8])
|
||||
|
||||
assert done is True
|
||||
assert "risk_count" in obs and obs["risk_count"] == 1.0
|
||||
assert obs["turnover"] == pytest.approx(20000.0)
|
||||
assert obs["turnover"] == pytest.approx(0.2)
|
||||
assert obs["turnover_value"] == pytest.approx(20000.0)
|
||||
assert info["risk_events"][0]["reason"] == "limit_up"
|
||||
assert info["risk_breakdown"]["limit_up"] == 1
|
||||
assert info["nav_series"][0]["turnover_ratio"] == pytest.approx(0.2)
|
||||
assert reward < obs["total_return"]
|
||||
|
||||
|
||||
@ -83,10 +88,11 @@ def test_default_reward_penalizes_metrics():
|
||||
volatility=0.05,
|
||||
nav_series=[],
|
||||
trades=[],
|
||||
turnover=1000.0,
|
||||
turnover=0.3,
|
||||
turnover_value=5000.0,
|
||||
trade_count=0,
|
||||
risk_count=2,
|
||||
risk_breakdown={"foo": 2},
|
||||
)
|
||||
reward = DecisionEnv._default_reward(metrics)
|
||||
assert reward == pytest.approx(0.1 - (0.5 * 0.2 + 0.05 * 2 + 0.00001 * 1000.0))
|
||||
assert reward == pytest.approx(0.1 - (0.5 * 0.2 + 0.05 * 2 + 0.1 * 0.3))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user