This commit is contained in:
sam 2025-10-05 18:19:28 +08:00
parent b4bd9fc9c5
commit f29bb99b68
7 changed files with 210 additions and 23 deletions

View File

@ -1,14 +1,15 @@
"""Reinforcement-learning style environment wrapping the backtest engine."""
from __future__ import annotations
from dataclasses import dataclass, replace
from typing import Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
import json
import math
from dataclasses import dataclass, replace
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
from .engine import BacktestEngine, BacktestResult, BtConfig
from app.agents.game import Decision
from app.agents.registry import weight_map
from app.utils.db import db_session
from app.utils.logging import get_logger
LOGGER = get_logger(__name__)
@ -37,6 +38,7 @@ class EpisodeMetrics:
nav_series: List[Dict[str, float]]
trades: List[Dict[str, object]]
turnover: float
turnover_value: float
trade_count: int
risk_count: int
risk_breakdown: Dict[str, int]
@ -97,6 +99,8 @@ class DecisionEnv:
if self._disable_departments:
engine.department_manager = None
self._clear_portfolio_records()
try:
result = engine.run()
except Exception as exc: # noqa: BLE001
@ -104,7 +108,12 @@ class DecisionEnv:
info = {"error": str(exc)}
return {"failure": 1.0}, -1.0, True, info
metrics = self._compute_metrics(result)
snapshots, trades_override = self._fetch_portfolio_records()
metrics = self._compute_metrics(
result,
nav_override=snapshots if snapshots else None,
trades_override=trades_override if trades_override else None,
)
reward = float(self._reward_fn(metrics))
self._last_metrics = metrics
@ -114,6 +123,7 @@ class DecisionEnv:
"volatility": metrics.volatility,
"sharpe_like": metrics.sharpe_like,
"turnover": metrics.turnover,
"turnover_value": metrics.turnover_value,
"trade_count": float(metrics.trade_count),
"risk_count": float(metrics.risk_count),
}
@ -123,6 +133,8 @@ class DecisionEnv:
"weights": weights,
"risk_breakdown": metrics.risk_breakdown,
"risk_events": getattr(result, "risk_events", []),
"portfolio_snapshots": snapshots,
"portfolio_trades": trades_override,
}
return observation, reward, True, info
@ -137,8 +149,16 @@ class DecisionEnv:
LOGGER.debug("暂未支持的参数目标:%s", spec.target, extra=LOG_EXTRA)
return weights
def _compute_metrics(self, result: BacktestResult) -> EpisodeMetrics:
nav_series = result.nav_series or []
def _compute_metrics(
self,
result: BacktestResult,
*,
nav_override: Optional[List[Dict[str, Any]]] = None,
trades_override: Optional[List[Dict[str, Any]]] = None,
) -> EpisodeMetrics:
nav_series = nav_override if nav_override is not None else result.nav_series or []
trades = trades_override if trades_override is not None else result.trades
if not nav_series:
risk_breakdown: Dict[str, int] = {}
for event in getattr(result, "risk_events", []) or []:
@ -149,9 +169,10 @@ class DecisionEnv:
max_drawdown=0.0,
volatility=0.0,
nav_series=[],
trades=result.trades,
trades=trades or [],
turnover=0.0,
trade_count=len(result.trades or []),
turnover_value=0.0,
trade_count=len(trades or []),
risk_count=len(getattr(result, "risk_events", []) or []),
risk_breakdown=risk_breakdown,
)
@ -181,7 +202,25 @@ class DecisionEnv:
else:
volatility = 0.0
turnover = sum(float(row.get("turnover", 0.0) or 0.0) for row in nav_series)
turnover_value = 0.0
turnover_ratios: List[float] = []
for row in nav_series:
turnover_raw = float(row.get("turnover", 0.0) or 0.0)
turnover_value += turnover_raw
ratio = row.get("turnover_ratio")
if ratio is not None:
try:
turnover_ratios.append(float(ratio))
continue
except (TypeError, ValueError):
turnover_ratios.append(0.0)
continue
nav_val = float(row.get("nav", 0.0) or 0.0)
if nav_val > 0:
turnover_ratios.append(turnover_raw / nav_val)
else:
turnover_ratios.append(0.0)
avg_turnover_ratio = sum(turnover_ratios) / len(turnover_ratios) if turnover_ratios else 0.0
risk_events = getattr(result, "risk_events", []) or []
risk_breakdown: Dict[str, int] = {}
for event in risk_events:
@ -193,9 +232,10 @@ class DecisionEnv:
max_drawdown=float(max_drawdown),
volatility=volatility,
nav_series=nav_series,
trades=result.trades,
turnover=float(turnover),
trade_count=len(result.trades or []),
trades=trades or [],
turnover=float(avg_turnover_ratio),
turnover_value=float(turnover_value),
trade_count=len(trades or []),
risk_count=len(risk_events),
risk_breakdown=risk_breakdown,
)
@ -203,7 +243,7 @@ class DecisionEnv:
@staticmethod
def _default_reward(metrics: EpisodeMetrics) -> float:
risk_penalty = 0.05 * metrics.risk_count
turnover_penalty = 0.00001 * metrics.turnover
turnover_penalty = 0.1 * metrics.turnover
penalty = 0.5 * metrics.max_drawdown + risk_penalty + turnover_penalty
return metrics.total_return - penalty
@ -214,3 +254,100 @@ class DecisionEnv:
@property
def last_action(self) -> Optional[Tuple[float, ...]]:
return self._last_action
def _clear_portfolio_records(self) -> None:
start = self._template_cfg.start_date.isoformat()
end = self._template_cfg.end_date.isoformat()
try:
with db_session() as conn:
conn.execute("DELETE FROM portfolio_positions")
conn.execute(
"DELETE FROM portfolio_snapshots WHERE trade_date BETWEEN ? AND ?",
(start, end),
)
conn.execute(
"DELETE FROM portfolio_trades WHERE trade_date BETWEEN ? AND ?",
(start, end),
)
except Exception: # noqa: BLE001
LOGGER.exception("清理投资组合记录失败", extra=LOG_EXTRA)
def _fetch_portfolio_records(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
start = self._template_cfg.start_date.isoformat()
end = self._template_cfg.end_date.isoformat()
snapshots: List[Dict[str, Any]] = []
trades: List[Dict[str, Any]] = []
try:
with db_session(read_only=True) as conn:
snapshot_rows = conn.execute(
"""
SELECT trade_date, total_value, cash, invested_value,
unrealized_pnl, realized_pnl, net_flow, exposure, metadata
FROM portfolio_snapshots
WHERE trade_date BETWEEN ? AND ?
ORDER BY trade_date
""",
(start, end),
).fetchall()
trade_rows = conn.execute(
"""
SELECT id, trade_date, ts_code, action, quantity, price, fee, source, metadata
FROM portfolio_trades
WHERE trade_date BETWEEN ? AND ?
ORDER BY trade_date, id
""",
(start, end),
).fetchall()
except Exception: # noqa: BLE001
LOGGER.exception("读取投资组合记录失败", extra=LOG_EXTRA)
return snapshots, trades
for row in snapshot_rows:
metadata = self._loads(row["metadata"], {})
snapshots.append(
{
"trade_date": row["trade_date"],
"nav": float(row["total_value"] or 0.0),
"cash": float(row["cash"] or 0.0),
"market_value": float(row["invested_value"] or 0.0),
"unrealized_pnl": float(row["unrealized_pnl"] or 0.0),
"realized_pnl": float(row["realized_pnl"] or 0.0),
"net_flow": float(row["net_flow"] or 0.0),
"exposure": float(row["exposure"] or 0.0),
"turnover": float(metadata.get("turnover_value", 0.0) or 0.0),
"turnover_ratio": float(metadata.get("turnover_ratio", 0.0) or 0.0),
"holdings": metadata.get("holdings"),
"trade_count": metadata.get("trade_count"),
}
)
for row in trade_rows:
metadata = self._loads(row["metadata"], {})
trades.append(
{
"id": row["id"],
"trade_date": row["trade_date"],
"ts_code": row["ts_code"],
"action": row["action"],
"quantity": float(row["quantity"] or 0.0),
"price": float(row["price"] or 0.0),
"fee": float(row["fee"] or 0.0),
"source": row["source"],
"metadata": metadata,
}
)
return snapshots, trades
@staticmethod
def _loads(payload: Any, default: Any) -> Any:
if not payload:
return default
if isinstance(payload, (dict, list)):
return payload
if isinstance(payload, str):
try:
return json.loads(payload)
except json.JSONDecodeError:
return default
return default

View File

@ -704,6 +704,7 @@ class BacktestEngine:
unrealized_pnl += (price - cost_basis) * qty
nav = state.cash + market_value
turnover_ratio = daily_turnover / nav if nav else 0.0
result.nav_series.append(
{
"trade_date": trade_date_str,
@ -713,6 +714,7 @@ class BacktestEngine:
"realized_pnl": state.realized_pnl,
"unrealized_pnl": unrealized_pnl,
"turnover": daily_turnover,
"turnover_ratio": turnover_ratio,
}
)
if executed_trades:
@ -817,11 +819,26 @@ class BacktestEngine:
)
)
total_value = market_value + state.cash
turnover_ratio = daily_turnover / total_value if total_value else 0.0
snapshot_metadata = {
"holdings": len(state.holdings),
"turnover_value": daily_turnover,
"turnover_ratio": turnover_ratio,
"trade_count": len(trades),
}
exposure = (market_value / total_value) if total_value else 0.0
net_flow = 0.0
for trade in trades:
value = float(trade.get("value", 0.0) or 0.0)
fee = float(trade.get("fee", 0.0) or 0.0)
action = str(trade.get("action", "")).lower()
if action.startswith("buy"):
net_flow -= value + fee
elif action.startswith("sell"):
net_flow += value - fee
with db_session() as conn:
conn.execute(
"""
@ -836,8 +853,8 @@ class BacktestEngine:
market_value,
unrealized_pnl,
state.realized_pnl,
None,
None,
net_flow,
exposure,
None,
json.dumps(snapshot_metadata, ensure_ascii=False),
),

View File

@ -131,6 +131,7 @@ def _metrics_to_dict(metrics: EpisodeMetrics) -> Dict[str, float | Dict[str, int
"volatility": metrics.volatility,
"sharpe_like": metrics.sharpe_like,
"turnover": metrics.turnover,
"turnover_value": metrics.turnover_value,
"trade_count": float(metrics.trade_count),
"risk_count": float(metrics.risk_count),
}

View File

@ -1563,6 +1563,9 @@ def render_log_viewer() -> None:
"weights": info.get("weights", {}),
"nav_series": info.get("nav_series"),
"trades": info.get("trades"),
"portfolio_snapshots": info.get("portfolio_snapshots"),
"portfolio_trades": info.get("portfolio_trades"),
"risk_breakdown": info.get("risk_breakdown"),
"selected_agents": list(selected_agents),
"action_values": list(action_values),
"experiment_id": resolved_experiment_id,
@ -1587,6 +1590,14 @@ def render_log_viewer() -> None:
col_metrics[2].metric("波动率", f"{observation.get('volatility', 0.0):+.2%}")
col_metrics[3].metric("奖励", f"{reward:+.4f}")
turnover_ratio = float(observation.get("turnover", 0.0) or 0.0)
turnover_value = float(observation.get("turnover_value", 0.0) or 0.0)
risk_count = float(observation.get("risk_count", 0.0) or 0.0)
col_metrics_extra = st.columns(3)
col_metrics_extra[0].metric("平均换手率", f"{turnover_ratio:.2%}")
col_metrics_extra[1].metric("成交额", f"{turnover_value:,.0f}")
col_metrics_extra[2].metric("风险事件数", f"{int(risk_count)}")
weights_dict = single_result.get("weights") or {}
if weights_dict:
st.write("调参后权重:")
@ -1620,6 +1631,21 @@ def render_log_viewer() -> None:
st.write("成交记录:")
st.dataframe(pd.DataFrame(trades), hide_index=True, width='stretch')
snapshots = single_result.get("portfolio_snapshots") or []
if snapshots:
with st.expander("投资组合快照", expanded=False):
st.dataframe(pd.DataFrame(snapshots), hide_index=True, width='stretch')
portfolio_trades = single_result.get("portfolio_trades") or []
if portfolio_trades:
with st.expander("组合成交明细", expanded=False):
st.dataframe(pd.DataFrame(portfolio_trades), hide_index=True, width='stretch')
risk_breakdown = single_result.get("risk_breakdown") or {}
if risk_breakdown:
with st.expander("风险事件统计", expanded=False):
st.json(risk_breakdown)
if st.button("清除单次调参结果", key="clear_decision_env_single"):
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
st.success("已清除单次调参结果缓存。")

View File

@ -1,7 +1,6 @@
"""Initialize portfolio database tables."""
from __future__ import annotations
import json
from typing import Any
from .logging import get_logger
@ -78,7 +77,7 @@ SCHEMA_STATEMENTS = [
tags TEXT, -- JSON array
metadata TEXT, -- JSON object
PRIMARY KEY (trade_date, ts_code)
)
);
""",
# 数据获取任务表
@ -91,7 +90,6 @@ SCHEMA_STATEMENTS = [
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
error_msg TEXT,
metadata TEXT -- JSON object for additional info
)
);
""",

View File

@ -37,7 +37,8 @@ class DummyEnv:
volatility=0.05,
nav_series=[],
trades=[],
turnover=100.0,
turnover=0.1,
turnover_value=1000.0,
trade_count=0,
risk_count=1,
risk_breakdown={"test": 1},
@ -48,7 +49,8 @@ class DummyEnv:
"max_drawdown": 0.1,
"volatility": 0.05,
"sharpe_like": reward / 0.05,
"turnover": 100.0,
"turnover": 0.1,
"turnover_value": 1000.0,
"trade_count": 0.0,
"risk_count": 1.0,
}

View File

@ -26,6 +26,7 @@ class _StubEngine:
"realized_pnl": 1.0,
"unrealized_pnl": 1.0,
"turnover": 20000.0,
"turnover_ratio": 0.2,
}
]
result.trades = [
@ -65,14 +66,18 @@ def test_decision_env_returns_risk_metrics(monkeypatch):
env = DecisionEnv(bt_config=cfg, parameter_specs=specs, baseline_weights={"A_mom": 0.5})
monkeypatch.setattr("app.backtest.decision_env.BacktestEngine", _StubEngine)
monkeypatch.setattr(DecisionEnv, "_clear_portfolio_records", lambda self: None)
monkeypatch.setattr(DecisionEnv, "_fetch_portfolio_records", lambda self: ([], []))
obs, reward, done, info = env.step([0.8])
assert done is True
assert "risk_count" in obs and obs["risk_count"] == 1.0
assert obs["turnover"] == pytest.approx(20000.0)
assert obs["turnover"] == pytest.approx(0.2)
assert obs["turnover_value"] == pytest.approx(20000.0)
assert info["risk_events"][0]["reason"] == "limit_up"
assert info["risk_breakdown"]["limit_up"] == 1
assert info["nav_series"][0]["turnover_ratio"] == pytest.approx(0.2)
assert reward < obs["total_return"]
@ -83,10 +88,11 @@ def test_default_reward_penalizes_metrics():
volatility=0.05,
nav_series=[],
trades=[],
turnover=1000.0,
turnover=0.3,
turnover_value=5000.0,
trade_count=0,
risk_count=2,
risk_breakdown={"foo": 2},
)
reward = DecisionEnv._default_reward(metrics)
assert reward == pytest.approx(0.1 - (0.5 * 0.2 + 0.05 * 2 + 0.00001 * 1000.0))
assert reward == pytest.approx(0.1 - (0.5 * 0.2 + 0.05 * 2 + 0.1 * 0.3))