From 74d98bf4e010b3ea59f16c2901f68f8e6fb02489 Mon Sep 17 00:00:00 2001 From: sam Date: Thu, 16 Oct 2025 14:57:55 +0800 Subject: [PATCH] refactor portfolio tables with cfg_id and optimize record handling --- app/backtest/decision_env.py | 28 ++++----- app/backtest/engine.py | 108 ++++++++++++++++++++++++++++------- app/data/schema.py | 52 +++++++++++++++++ app/utils/portfolio_init.py | 55 ++++++++++++++++++ 4 files changed, 206 insertions(+), 37 deletions(-) diff --git a/app/backtest/decision_env.py b/app/backtest/decision_env.py index 5a74995..702afbb 100644 --- a/app/backtest/decision_env.py +++ b/app/backtest/decision_env.py @@ -508,25 +508,19 @@ class DecisionEnv: return self._last_action def _clear_portfolio_records(self) -> None: - start = self._template_cfg.start_date.isoformat() - end = self._template_cfg.end_date.isoformat() + cfg_id = self._template_cfg.id or "decision_env" try: with db_session() as conn: - conn.execute("DELETE FROM portfolio_positions") - conn.execute( - "DELETE FROM portfolio_snapshots WHERE trade_date BETWEEN ? AND ?", - (start, end), - ) - conn.execute( - "DELETE FROM portfolio_trades WHERE trade_date BETWEEN ? AND ?", - (start, end), - ) + conn.execute("DELETE FROM bt_portfolio_positions WHERE cfg_id = ?", (cfg_id,)) + conn.execute("DELETE FROM bt_portfolio_snapshots WHERE cfg_id = ?", (cfg_id,)) + conn.execute("DELETE FROM bt_portfolio_trades WHERE cfg_id = ?", (cfg_id,)) except Exception: # noqa: BLE001 LOGGER.exception("清理投资组合记录失败", extra=LOG_EXTRA) def _fetch_portfolio_records(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: start = self._template_cfg.start_date.isoformat() end = self._template_cfg.end_date.isoformat() + cfg_id = self._template_cfg.id or "decision_env" snapshots: List[Dict[str, Any]] = [] trades: List[Dict[str, Any]] = [] try: @@ -535,20 +529,20 @@ class DecisionEnv: """ SELECT trade_date, total_value, cash, invested_value, unrealized_pnl, realized_pnl, net_flow, exposure, metadata - FROM portfolio_snapshots - WHERE trade_date BETWEEN ? AND ? + FROM bt_portfolio_snapshots + WHERE cfg_id = ? AND trade_date BETWEEN ? AND ? ORDER BY trade_date """, - (start, end), + (cfg_id, start, end), ).fetchall() trade_rows = conn.execute( """ SELECT id, trade_date, ts_code, action, quantity, price, fee, source, metadata - FROM portfolio_trades - WHERE trade_date BETWEEN ? AND ? + FROM bt_portfolio_trades + WHERE cfg_id = ? AND trade_date BETWEEN ? AND ? ORDER BY trade_date, id """, - (start, end), + (cfg_id, start, end), ).fetchall() except Exception: # noqa: BLE001 LOGGER.exception("读取投资组合记录失败", extra=LOG_EXTRA) diff --git a/app/backtest/engine.py b/app/backtest/engine.py index f4cd47a..18f68bf 100644 --- a/app/backtest/engine.py +++ b/app/backtest/engine.py @@ -70,6 +70,7 @@ class PortfolioState: cost_basis: Dict[str, float] = field(default_factory=dict) opened_dates: Dict[str, str] = field(default_factory=dict) realized_pnl: float = 0.0 + realized_pnl_by_symbol: Dict[str, float] = field(default_factory=dict) @dataclass @@ -618,13 +619,48 @@ class BacktestEngine: price_map[ts_code] = price decisions_map[ts_code] = decision - if not price_map and state.holdings: - trade_date_compact = trade_date.strftime("%Y%m%d") - for ts_code in list(state.holdings.keys()): - fetched = self.data_broker.fetch_latest(ts_code, trade_date_compact, ["daily.close"], auto_refresh=False) - price = fetched.get("daily.close") - if price: - price_map[ts_code] = float(price) + trade_date_compact = trade_date.strftime("%Y%m%d") + missing_prices: set[str] = { + code for code in decisions_map.keys() if code not in price_map + } + missing_prices.update(code for code in state.holdings.keys() if code not in price_map) + if missing_prices: + for ts_code in sorted(missing_prices): + try: + fetched = self.data_broker.fetch_latest( + ts_code, + trade_date_compact, + ["daily.close"], + auto_refresh=False, + ) + except Exception: # noqa: BLE001 + LOGGER.debug( + "回补价格失败 ts_code=%s date=%s", + ts_code, + trade_date_compact, + extra=LOG_EXTRA, + ) + continue + fallback_price = fetched.get("daily.close") + if fallback_price is None: + continue + try: + price_map[ts_code] = float(fallback_price) + except (TypeError, ValueError): + LOGGER.debug( + "价格解析失败 ts_code=%s raw=%s", + ts_code, + fallback_price, + extra=LOG_EXTRA, + ) + unresolved = [code for code in missing_prices if code not in price_map] + if unresolved: + LOGGER.warning( + "缺少收盘价,回测将跳过估值:codes=%s date=%s", + unresolved, + trade_date_compact, + extra=LOG_EXTRA, + ) portfolio_value_before = state.cash for ts_code, qty in state.holdings.items(): @@ -883,6 +919,9 @@ class BacktestEngine: realized = (trade_price - current_cost_basis) * sell_qty - fee state.cash += proceeds state.realized_pnl += realized + state.realized_pnl_by_symbol[ts_code] = ( + state.realized_pnl_by_symbol.get(ts_code, 0.0) + realized + ) new_qty = current_qty - sell_qty if new_qty <= 1e-6: state.holdings.pop(ts_code, None) @@ -953,6 +992,7 @@ class BacktestEngine: try: self._persist_portfolio( + self.cfg.id, trade_date_str, state, market_value, @@ -1047,8 +1087,21 @@ class BacktestEngine: except sqlite3.Error: pass + def _reset_bt_portfolio_records(self) -> None: + cfg_id = self.cfg.id + if not cfg_id: + return + try: + with db_session() as conn: + conn.execute("DELETE FROM bt_portfolio_snapshots WHERE cfg_id = ?", (cfg_id,)) + conn.execute("DELETE FROM bt_portfolio_positions WHERE cfg_id = ?", (cfg_id,)) + conn.execute("DELETE FROM bt_portfolio_trades WHERE cfg_id = ?", (cfg_id,)) + except Exception: # noqa: BLE001 + LOGGER.exception("清理回测投资组合数据失败 cfg_id=%s", cfg_id, extra=LOG_EXTRA) + def _persist_portfolio( self, + cfg_id: str, trade_date: str, state: PortfolioState, market_value: float, @@ -1070,16 +1123,21 @@ class BacktestEngine: "last_action": decision.action.value if decision else None, "confidence": decision.confidence if decision else None, } + opened_date = state.opened_dates.get(ts_code, trade_date) + if hasattr(opened_date, "isoformat"): + opened_date = opened_date.isoformat() # type: ignore[attr-defined] holdings_rows.append( ( + cfg_id, + trade_date, ts_code, - state.opened_dates.get(ts_code, trade_date), + opened_date, None, qty, cost_basis, price, market_val, - state.realized_pnl, + state.realized_pnl_by_symbol.get(ts_code, 0.0), unrealized, target_weight, "open", @@ -1111,11 +1169,12 @@ class BacktestEngine: with db_session() as conn: conn.execute( """ - INSERT OR REPLACE INTO portfolio_snapshots - (trade_date, total_value, cash, invested_value, unrealized_pnl, realized_pnl, net_flow, exposure, notes, metadata) + INSERT OR REPLACE INTO bt_portfolio_snapshots + (cfg_id, trade_date, total_value, cash, invested_value, unrealized_pnl, realized_pnl, net_flow, exposure, metadata) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( + cfg_id, trade_date, market_value + state.cash, state.cash, @@ -1124,37 +1183,45 @@ class BacktestEngine: state.realized_pnl, net_flow, exposure, - None, json.dumps(snapshot_metadata, ensure_ascii=False), ), ) - conn.execute("DELETE FROM portfolio_positions") + conn.execute( + "DELETE FROM bt_portfolio_positions WHERE cfg_id = ? AND trade_date = ?", + (cfg_id, trade_date), + ) if holdings_rows: conn.executemany( """ - INSERT INTO portfolio_positions - (ts_code, opened_date, closed_date, quantity, cost_price, market_price, market_value, realized_pnl, unrealized_pnl, target_weight, status, notes, metadata) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO bt_portfolio_positions + (cfg_id, trade_date, ts_code, opened_date, closed_date, quantity, cost_price, market_price, market_value, realized_pnl, unrealized_pnl, target_weight, status, notes, metadata) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, holdings_rows, ) if trades: + conn.execute( + "DELETE FROM bt_portfolio_trades WHERE cfg_id = ? AND trade_date = ?", + (cfg_id, trade_date), + ) conn.executemany( """ - INSERT INTO portfolio_trades - (trade_date, ts_code, action, quantity, price, fee, order_id, source, notes, metadata) - VALUES (?, ?, ?, ?, ?, ?, NULL, 'backtest', NULL, ?) + INSERT INTO bt_portfolio_trades + (cfg_id, trade_date, ts_code, action, quantity, price, fee, source, metadata) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) """, [ ( + cfg_id, trade["trade_date"], trade["ts_code"], trade["action"], trade["quantity"], trade["price"], - trade.get("fee", 0.0), + float(trade.get("fee", 0.0)), + str(trade.get("source") or "backtest"), json.dumps(trade, ensure_ascii=False), ) for trade in trades @@ -1196,6 +1263,7 @@ class BacktestEngine: self, decision_callback: Optional[Callable[[str, date, AgentContext, Decision], None]] = None, ) -> BacktestResult: + self._reset_bt_portfolio_records() session = self.start_session() if session.current_date > self.cfg.end_date: return session.result diff --git a/app/data/schema.py b/app/data/schema.py index 3ee2953..86ce3f8 100644 --- a/app/data/schema.py +++ b/app/data/schema.py @@ -436,6 +436,55 @@ SCHEMA_STATEMENTS: Iterable[str] = ( ); """, """ + CREATE TABLE IF NOT EXISTS bt_portfolio_snapshots ( + cfg_id TEXT, + trade_date TEXT, + total_value REAL, + cash REAL, + invested_value REAL, + unrealized_pnl REAL, + realized_pnl REAL, + net_flow REAL, + exposure REAL, + metadata TEXT, + PRIMARY KEY (cfg_id, trade_date) + ); + """, + """ + CREATE TABLE IF NOT EXISTS bt_portfolio_positions ( + cfg_id TEXT, + trade_date TEXT, + ts_code TEXT, + opened_date TEXT, + closed_date TEXT, + quantity REAL, + cost_price REAL, + market_price REAL, + market_value REAL, + realized_pnl REAL, + unrealized_pnl REAL, + target_weight REAL, + status TEXT, + notes TEXT, + metadata TEXT, + PRIMARY KEY (cfg_id, trade_date, ts_code) + ); + """, + """ + CREATE TABLE IF NOT EXISTS bt_portfolio_trades ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + cfg_id TEXT NOT NULL, + trade_date TEXT NOT NULL, + ts_code TEXT NOT NULL, + action TEXT NOT NULL, + quantity REAL NOT NULL, + price REAL NOT NULL, + fee REAL DEFAULT 0, + source TEXT, + metadata TEXT + ); + """, + """ CREATE TABLE IF NOT EXISTS run_log ( ts TEXT PRIMARY KEY, stage TEXT, @@ -569,6 +618,9 @@ REQUIRED_TABLES = ( "bt_risk_events", "bt_nav", "bt_report", + "bt_portfolio_snapshots", + "bt_portfolio_positions", + "bt_portfolio_trades", "run_log", "agent_utils", "alloc_log", diff --git a/app/utils/portfolio_init.py b/app/utils/portfolio_init.py index e89df15..aa4423d 100644 --- a/app/utils/portfolio_init.py +++ b/app/utils/portfolio_init.py @@ -130,6 +130,61 @@ SCHEMA_STATEMENTS = [ metadata TEXT -- JSON object ); """, + + # 回测组合快照表 + """ + CREATE TABLE IF NOT EXISTS bt_portfolio_snapshots ( + cfg_id TEXT, + trade_date TEXT, + total_value REAL, + cash REAL, + invested_value REAL, + unrealized_pnl REAL, + realized_pnl REAL, + net_flow REAL, + exposure REAL, + metadata TEXT, + PRIMARY KEY (cfg_id, trade_date) + ); + """, + + # 回测持仓表 + """ + CREATE TABLE IF NOT EXISTS bt_portfolio_positions ( + cfg_id TEXT, + trade_date TEXT, + ts_code TEXT, + opened_date TEXT, + closed_date TEXT, + quantity REAL, + cost_price REAL, + market_price REAL, + market_value REAL, + realized_pnl REAL, + unrealized_pnl REAL, + target_weight REAL, + status TEXT, + notes TEXT, + metadata TEXT, + PRIMARY KEY (cfg_id, trade_date, ts_code) + ); + """, + + # 回测交易表 + """ + CREATE TABLE IF NOT EXISTS bt_portfolio_trades ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + cfg_id TEXT NOT NULL, + trade_date TEXT NOT NULL, + ts_code TEXT NOT NULL, + action TEXT NOT NULL, + quantity REAL NOT NULL, + price REAL NOT NULL, + fee REAL DEFAULT 0, + source TEXT, + metadata TEXT + ); + """, ]