From f29bb99b68192cea1cb7e5fe6e06dbcfcc4bc349 Mon Sep 17 00:00:00 2001 From: sam Date: Sun, 5 Oct 2025 18:19:28 +0800 Subject: [PATCH] update --- app/backtest/decision_env.py | 163 ++++++++++++++++++++++++++++++--- app/backtest/engine.py | 21 ++++- app/backtest/optimizer.py | 1 + app/ui/streamlit_app.py | 26 ++++++ app/utils/portfolio_init.py | 4 +- tests/test_bandit_optimizer.py | 6 +- tests/test_decision_env.py | 12 ++- 7 files changed, 210 insertions(+), 23 deletions(-) diff --git a/app/backtest/decision_env.py b/app/backtest/decision_env.py index 394838b..06c3a21 100644 --- a/app/backtest/decision_env.py +++ b/app/backtest/decision_env.py @@ -1,14 +1,15 @@ """Reinforcement-learning style environment wrapping the backtest engine.""" from __future__ import annotations -from dataclasses import dataclass, replace -from typing import Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple - +import json import math +from dataclasses import dataclass, replace +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple from .engine import BacktestEngine, BacktestResult, BtConfig from app.agents.game import Decision from app.agents.registry import weight_map +from app.utils.db import db_session from app.utils.logging import get_logger LOGGER = get_logger(__name__) @@ -37,6 +38,7 @@ class EpisodeMetrics: nav_series: List[Dict[str, float]] trades: List[Dict[str, object]] turnover: float + turnover_value: float trade_count: int risk_count: int risk_breakdown: Dict[str, int] @@ -97,6 +99,8 @@ class DecisionEnv: if self._disable_departments: engine.department_manager = None + self._clear_portfolio_records() + try: result = engine.run() except Exception as exc: # noqa: BLE001 @@ -104,7 +108,12 @@ class DecisionEnv: info = {"error": str(exc)} return {"failure": 1.0}, -1.0, True, info - metrics = self._compute_metrics(result) + snapshots, trades_override = self._fetch_portfolio_records() + metrics = self._compute_metrics( + result, + nav_override=snapshots if snapshots else None, + trades_override=trades_override if trades_override else None, + ) reward = float(self._reward_fn(metrics)) self._last_metrics = metrics @@ -114,6 +123,7 @@ class DecisionEnv: "volatility": metrics.volatility, "sharpe_like": metrics.sharpe_like, "turnover": metrics.turnover, + "turnover_value": metrics.turnover_value, "trade_count": float(metrics.trade_count), "risk_count": float(metrics.risk_count), } @@ -123,6 +133,8 @@ class DecisionEnv: "weights": weights, "risk_breakdown": metrics.risk_breakdown, "risk_events": getattr(result, "risk_events", []), + "portfolio_snapshots": snapshots, + "portfolio_trades": trades_override, } return observation, reward, True, info @@ -137,8 +149,16 @@ class DecisionEnv: LOGGER.debug("暂未支持的参数目标:%s", spec.target, extra=LOG_EXTRA) return weights - def _compute_metrics(self, result: BacktestResult) -> EpisodeMetrics: - nav_series = result.nav_series or [] + def _compute_metrics( + self, + result: BacktestResult, + *, + nav_override: Optional[List[Dict[str, Any]]] = None, + trades_override: Optional[List[Dict[str, Any]]] = None, + ) -> EpisodeMetrics: + nav_series = nav_override if nav_override is not None else result.nav_series or [] + trades = trades_override if trades_override is not None else result.trades + if not nav_series: risk_breakdown: Dict[str, int] = {} for event in getattr(result, "risk_events", []) or []: @@ -149,9 +169,10 @@ class DecisionEnv: max_drawdown=0.0, volatility=0.0, nav_series=[], - trades=result.trades, + trades=trades or [], turnover=0.0, - trade_count=len(result.trades or []), + turnover_value=0.0, + trade_count=len(trades or []), risk_count=len(getattr(result, "risk_events", []) or []), risk_breakdown=risk_breakdown, ) @@ -181,7 +202,25 @@ class DecisionEnv: else: volatility = 0.0 - turnover = sum(float(row.get("turnover", 0.0) or 0.0) for row in nav_series) + turnover_value = 0.0 + turnover_ratios: List[float] = [] + for row in nav_series: + turnover_raw = float(row.get("turnover", 0.0) or 0.0) + turnover_value += turnover_raw + ratio = row.get("turnover_ratio") + if ratio is not None: + try: + turnover_ratios.append(float(ratio)) + continue + except (TypeError, ValueError): + turnover_ratios.append(0.0) + continue + nav_val = float(row.get("nav", 0.0) or 0.0) + if nav_val > 0: + turnover_ratios.append(turnover_raw / nav_val) + else: + turnover_ratios.append(0.0) + avg_turnover_ratio = sum(turnover_ratios) / len(turnover_ratios) if turnover_ratios else 0.0 risk_events = getattr(result, "risk_events", []) or [] risk_breakdown: Dict[str, int] = {} for event in risk_events: @@ -193,9 +232,10 @@ class DecisionEnv: max_drawdown=float(max_drawdown), volatility=volatility, nav_series=nav_series, - trades=result.trades, - turnover=float(turnover), - trade_count=len(result.trades or []), + trades=trades or [], + turnover=float(avg_turnover_ratio), + turnover_value=float(turnover_value), + trade_count=len(trades or []), risk_count=len(risk_events), risk_breakdown=risk_breakdown, ) @@ -203,7 +243,7 @@ class DecisionEnv: @staticmethod def _default_reward(metrics: EpisodeMetrics) -> float: risk_penalty = 0.05 * metrics.risk_count - turnover_penalty = 0.00001 * metrics.turnover + turnover_penalty = 0.1 * metrics.turnover penalty = 0.5 * metrics.max_drawdown + risk_penalty + turnover_penalty return metrics.total_return - penalty @@ -214,3 +254,100 @@ class DecisionEnv: @property def last_action(self) -> Optional[Tuple[float, ...]]: return self._last_action + + def _clear_portfolio_records(self) -> None: + start = self._template_cfg.start_date.isoformat() + end = self._template_cfg.end_date.isoformat() + try: + with db_session() as conn: + conn.execute("DELETE FROM portfolio_positions") + conn.execute( + "DELETE FROM portfolio_snapshots WHERE trade_date BETWEEN ? AND ?", + (start, end), + ) + conn.execute( + "DELETE FROM portfolio_trades WHERE trade_date BETWEEN ? AND ?", + (start, end), + ) + except Exception: # noqa: BLE001 + LOGGER.exception("清理投资组合记录失败", extra=LOG_EXTRA) + + def _fetch_portfolio_records(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + start = self._template_cfg.start_date.isoformat() + end = self._template_cfg.end_date.isoformat() + snapshots: List[Dict[str, Any]] = [] + trades: List[Dict[str, Any]] = [] + try: + with db_session(read_only=True) as conn: + snapshot_rows = conn.execute( + """ + SELECT trade_date, total_value, cash, invested_value, + unrealized_pnl, realized_pnl, net_flow, exposure, metadata + FROM portfolio_snapshots + WHERE trade_date BETWEEN ? AND ? + ORDER BY trade_date + """, + (start, end), + ).fetchall() + trade_rows = conn.execute( + """ + SELECT id, trade_date, ts_code, action, quantity, price, fee, source, metadata + FROM portfolio_trades + WHERE trade_date BETWEEN ? AND ? + ORDER BY trade_date, id + """, + (start, end), + ).fetchall() + except Exception: # noqa: BLE001 + LOGGER.exception("读取投资组合记录失败", extra=LOG_EXTRA) + return snapshots, trades + + for row in snapshot_rows: + metadata = self._loads(row["metadata"], {}) + snapshots.append( + { + "trade_date": row["trade_date"], + "nav": float(row["total_value"] or 0.0), + "cash": float(row["cash"] or 0.0), + "market_value": float(row["invested_value"] or 0.0), + "unrealized_pnl": float(row["unrealized_pnl"] or 0.0), + "realized_pnl": float(row["realized_pnl"] or 0.0), + "net_flow": float(row["net_flow"] or 0.0), + "exposure": float(row["exposure"] or 0.0), + "turnover": float(metadata.get("turnover_value", 0.0) or 0.0), + "turnover_ratio": float(metadata.get("turnover_ratio", 0.0) or 0.0), + "holdings": metadata.get("holdings"), + "trade_count": metadata.get("trade_count"), + } + ) + + for row in trade_rows: + metadata = self._loads(row["metadata"], {}) + trades.append( + { + "id": row["id"], + "trade_date": row["trade_date"], + "ts_code": row["ts_code"], + "action": row["action"], + "quantity": float(row["quantity"] or 0.0), + "price": float(row["price"] or 0.0), + "fee": float(row["fee"] or 0.0), + "source": row["source"], + "metadata": metadata, + } + ) + + return snapshots, trades + + @staticmethod + def _loads(payload: Any, default: Any) -> Any: + if not payload: + return default + if isinstance(payload, (dict, list)): + return payload + if isinstance(payload, str): + try: + return json.loads(payload) + except json.JSONDecodeError: + return default + return default diff --git a/app/backtest/engine.py b/app/backtest/engine.py index 4e8006c..52cba53 100644 --- a/app/backtest/engine.py +++ b/app/backtest/engine.py @@ -704,6 +704,7 @@ class BacktestEngine: unrealized_pnl += (price - cost_basis) * qty nav = state.cash + market_value + turnover_ratio = daily_turnover / nav if nav else 0.0 result.nav_series.append( { "trade_date": trade_date_str, @@ -713,6 +714,7 @@ class BacktestEngine: "realized_pnl": state.realized_pnl, "unrealized_pnl": unrealized_pnl, "turnover": daily_turnover, + "turnover_ratio": turnover_ratio, } ) if executed_trades: @@ -817,11 +819,26 @@ class BacktestEngine: ) ) + total_value = market_value + state.cash + turnover_ratio = daily_turnover / total_value if total_value else 0.0 snapshot_metadata = { "holdings": len(state.holdings), "turnover_value": daily_turnover, + "turnover_ratio": turnover_ratio, + "trade_count": len(trades), } + exposure = (market_value / total_value) if total_value else 0.0 + net_flow = 0.0 + for trade in trades: + value = float(trade.get("value", 0.0) or 0.0) + fee = float(trade.get("fee", 0.0) or 0.0) + action = str(trade.get("action", "")).lower() + if action.startswith("buy"): + net_flow -= value + fee + elif action.startswith("sell"): + net_flow += value - fee + with db_session() as conn: conn.execute( """ @@ -836,8 +853,8 @@ class BacktestEngine: market_value, unrealized_pnl, state.realized_pnl, - None, - None, + net_flow, + exposure, None, json.dumps(snapshot_metadata, ensure_ascii=False), ), diff --git a/app/backtest/optimizer.py b/app/backtest/optimizer.py index 8e630a9..a7a5f1b 100644 --- a/app/backtest/optimizer.py +++ b/app/backtest/optimizer.py @@ -131,6 +131,7 @@ def _metrics_to_dict(metrics: EpisodeMetrics) -> Dict[str, float | Dict[str, int "volatility": metrics.volatility, "sharpe_like": metrics.sharpe_like, "turnover": metrics.turnover, + "turnover_value": metrics.turnover_value, "trade_count": float(metrics.trade_count), "risk_count": float(metrics.risk_count), } diff --git a/app/ui/streamlit_app.py b/app/ui/streamlit_app.py index 56304db..ecd2b4d 100644 --- a/app/ui/streamlit_app.py +++ b/app/ui/streamlit_app.py @@ -1563,6 +1563,9 @@ def render_log_viewer() -> None: "weights": info.get("weights", {}), "nav_series": info.get("nav_series"), "trades": info.get("trades"), + "portfolio_snapshots": info.get("portfolio_snapshots"), + "portfolio_trades": info.get("portfolio_trades"), + "risk_breakdown": info.get("risk_breakdown"), "selected_agents": list(selected_agents), "action_values": list(action_values), "experiment_id": resolved_experiment_id, @@ -1587,6 +1590,14 @@ def render_log_viewer() -> None: col_metrics[2].metric("波动率", f"{observation.get('volatility', 0.0):+.2%}") col_metrics[3].metric("奖励", f"{reward:+.4f}") + turnover_ratio = float(observation.get("turnover", 0.0) or 0.0) + turnover_value = float(observation.get("turnover_value", 0.0) or 0.0) + risk_count = float(observation.get("risk_count", 0.0) or 0.0) + col_metrics_extra = st.columns(3) + col_metrics_extra[0].metric("平均换手率", f"{turnover_ratio:.2%}") + col_metrics_extra[1].metric("成交额", f"{turnover_value:,.0f}") + col_metrics_extra[2].metric("风险事件数", f"{int(risk_count)}") + weights_dict = single_result.get("weights") or {} if weights_dict: st.write("调参后权重:") @@ -1620,6 +1631,21 @@ def render_log_viewer() -> None: st.write("成交记录:") st.dataframe(pd.DataFrame(trades), hide_index=True, width='stretch') + snapshots = single_result.get("portfolio_snapshots") or [] + if snapshots: + with st.expander("投资组合快照", expanded=False): + st.dataframe(pd.DataFrame(snapshots), hide_index=True, width='stretch') + + portfolio_trades = single_result.get("portfolio_trades") or [] + if portfolio_trades: + with st.expander("组合成交明细", expanded=False): + st.dataframe(pd.DataFrame(portfolio_trades), hide_index=True, width='stretch') + + risk_breakdown = single_result.get("risk_breakdown") or {} + if risk_breakdown: + with st.expander("风险事件统计", expanded=False): + st.json(risk_breakdown) + if st.button("清除单次调参结果", key="clear_decision_env_single"): st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None) st.success("已清除单次调参结果缓存。") diff --git a/app/utils/portfolio_init.py b/app/utils/portfolio_init.py index 72c6b34..6995e01 100644 --- a/app/utils/portfolio_init.py +++ b/app/utils/portfolio_init.py @@ -1,7 +1,6 @@ """Initialize portfolio database tables.""" from __future__ import annotations -import json from typing import Any from .logging import get_logger @@ -78,7 +77,7 @@ SCHEMA_STATEMENTS = [ tags TEXT, -- JSON array metadata TEXT, -- JSON object PRIMARY KEY (trade_date, ts_code) - ) + ); """, # 数据获取任务表 @@ -91,7 +90,6 @@ SCHEMA_STATEMENTS = [ updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, error_msg TEXT, metadata TEXT -- JSON object for additional info - ) ); """, diff --git a/tests/test_bandit_optimizer.py b/tests/test_bandit_optimizer.py index 9ebae46..f82d5c8 100644 --- a/tests/test_bandit_optimizer.py +++ b/tests/test_bandit_optimizer.py @@ -37,7 +37,8 @@ class DummyEnv: volatility=0.05, nav_series=[], trades=[], - turnover=100.0, + turnover=0.1, + turnover_value=1000.0, trade_count=0, risk_count=1, risk_breakdown={"test": 1}, @@ -48,7 +49,8 @@ class DummyEnv: "max_drawdown": 0.1, "volatility": 0.05, "sharpe_like": reward / 0.05, - "turnover": 100.0, + "turnover": 0.1, + "turnover_value": 1000.0, "trade_count": 0.0, "risk_count": 1.0, } diff --git a/tests/test_decision_env.py b/tests/test_decision_env.py index c3e1d41..18e5724 100644 --- a/tests/test_decision_env.py +++ b/tests/test_decision_env.py @@ -26,6 +26,7 @@ class _StubEngine: "realized_pnl": 1.0, "unrealized_pnl": 1.0, "turnover": 20000.0, + "turnover_ratio": 0.2, } ] result.trades = [ @@ -65,14 +66,18 @@ def test_decision_env_returns_risk_metrics(monkeypatch): env = DecisionEnv(bt_config=cfg, parameter_specs=specs, baseline_weights={"A_mom": 0.5}) monkeypatch.setattr("app.backtest.decision_env.BacktestEngine", _StubEngine) + monkeypatch.setattr(DecisionEnv, "_clear_portfolio_records", lambda self: None) + monkeypatch.setattr(DecisionEnv, "_fetch_portfolio_records", lambda self: ([], [])) obs, reward, done, info = env.step([0.8]) assert done is True assert "risk_count" in obs and obs["risk_count"] == 1.0 - assert obs["turnover"] == pytest.approx(20000.0) + assert obs["turnover"] == pytest.approx(0.2) + assert obs["turnover_value"] == pytest.approx(20000.0) assert info["risk_events"][0]["reason"] == "limit_up" assert info["risk_breakdown"]["limit_up"] == 1 + assert info["nav_series"][0]["turnover_ratio"] == pytest.approx(0.2) assert reward < obs["total_return"] @@ -83,10 +88,11 @@ def test_default_reward_penalizes_metrics(): volatility=0.05, nav_series=[], trades=[], - turnover=1000.0, + turnover=0.3, + turnover_value=5000.0, trade_count=0, risk_count=2, risk_breakdown={"foo": 2}, ) reward = DecisionEnv._default_reward(metrics) - assert reward == pytest.approx(0.1 - (0.5 * 0.2 + 0.05 * 2 + 0.00001 * 1000.0)) + assert reward == pytest.approx(0.1 - (0.5 * 0.2 + 0.05 * 2 + 0.1 * 0.3))