refactor portfolio tables with cfg_id and optimize record handling

This commit is contained in:
sam 2025-10-16 14:57:55 +08:00
parent 2220b5084e
commit 74d98bf4e0
4 changed files with 206 additions and 37 deletions

View File

@ -508,25 +508,19 @@ class DecisionEnv:
return self._last_action return self._last_action
def _clear_portfolio_records(self) -> None: def _clear_portfolio_records(self) -> None:
start = self._template_cfg.start_date.isoformat() cfg_id = self._template_cfg.id or "decision_env"
end = self._template_cfg.end_date.isoformat()
try: try:
with db_session() as conn: with db_session() as conn:
conn.execute("DELETE FROM portfolio_positions") conn.execute("DELETE FROM bt_portfolio_positions WHERE cfg_id = ?", (cfg_id,))
conn.execute( conn.execute("DELETE FROM bt_portfolio_snapshots WHERE cfg_id = ?", (cfg_id,))
"DELETE FROM portfolio_snapshots WHERE trade_date BETWEEN ? AND ?", conn.execute("DELETE FROM bt_portfolio_trades WHERE cfg_id = ?", (cfg_id,))
(start, end),
)
conn.execute(
"DELETE FROM portfolio_trades WHERE trade_date BETWEEN ? AND ?",
(start, end),
)
except Exception: # noqa: BLE001 except Exception: # noqa: BLE001
LOGGER.exception("清理投资组合记录失败", extra=LOG_EXTRA) LOGGER.exception("清理投资组合记录失败", extra=LOG_EXTRA)
def _fetch_portfolio_records(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: def _fetch_portfolio_records(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
start = self._template_cfg.start_date.isoformat() start = self._template_cfg.start_date.isoformat()
end = self._template_cfg.end_date.isoformat() end = self._template_cfg.end_date.isoformat()
cfg_id = self._template_cfg.id or "decision_env"
snapshots: List[Dict[str, Any]] = [] snapshots: List[Dict[str, Any]] = []
trades: List[Dict[str, Any]] = [] trades: List[Dict[str, Any]] = []
try: try:
@ -535,20 +529,20 @@ class DecisionEnv:
""" """
SELECT trade_date, total_value, cash, invested_value, SELECT trade_date, total_value, cash, invested_value,
unrealized_pnl, realized_pnl, net_flow, exposure, metadata unrealized_pnl, realized_pnl, net_flow, exposure, metadata
FROM portfolio_snapshots FROM bt_portfolio_snapshots
WHERE trade_date BETWEEN ? AND ? WHERE cfg_id = ? AND trade_date BETWEEN ? AND ?
ORDER BY trade_date ORDER BY trade_date
""", """,
(start, end), (cfg_id, start, end),
).fetchall() ).fetchall()
trade_rows = conn.execute( trade_rows = conn.execute(
""" """
SELECT id, trade_date, ts_code, action, quantity, price, fee, source, metadata SELECT id, trade_date, ts_code, action, quantity, price, fee, source, metadata
FROM portfolio_trades FROM bt_portfolio_trades
WHERE trade_date BETWEEN ? AND ? WHERE cfg_id = ? AND trade_date BETWEEN ? AND ?
ORDER BY trade_date, id ORDER BY trade_date, id
""", """,
(start, end), (cfg_id, start, end),
).fetchall() ).fetchall()
except Exception: # noqa: BLE001 except Exception: # noqa: BLE001
LOGGER.exception("读取投资组合记录失败", extra=LOG_EXTRA) LOGGER.exception("读取投资组合记录失败", extra=LOG_EXTRA)

View File

@ -70,6 +70,7 @@ class PortfolioState:
cost_basis: Dict[str, float] = field(default_factory=dict) cost_basis: Dict[str, float] = field(default_factory=dict)
opened_dates: Dict[str, str] = field(default_factory=dict) opened_dates: Dict[str, str] = field(default_factory=dict)
realized_pnl: float = 0.0 realized_pnl: float = 0.0
realized_pnl_by_symbol: Dict[str, float] = field(default_factory=dict)
@dataclass @dataclass
@ -618,13 +619,48 @@ class BacktestEngine:
price_map[ts_code] = price price_map[ts_code] = price
decisions_map[ts_code] = decision decisions_map[ts_code] = decision
if not price_map and state.holdings: trade_date_compact = trade_date.strftime("%Y%m%d")
trade_date_compact = trade_date.strftime("%Y%m%d") missing_prices: set[str] = {
for ts_code in list(state.holdings.keys()): code for code in decisions_map.keys() if code not in price_map
fetched = self.data_broker.fetch_latest(ts_code, trade_date_compact, ["daily.close"], auto_refresh=False) }
price = fetched.get("daily.close") missing_prices.update(code for code in state.holdings.keys() if code not in price_map)
if price: if missing_prices:
price_map[ts_code] = float(price) for ts_code in sorted(missing_prices):
try:
fetched = self.data_broker.fetch_latest(
ts_code,
trade_date_compact,
["daily.close"],
auto_refresh=False,
)
except Exception: # noqa: BLE001
LOGGER.debug(
"回补价格失败 ts_code=%s date=%s",
ts_code,
trade_date_compact,
extra=LOG_EXTRA,
)
continue
fallback_price = fetched.get("daily.close")
if fallback_price is None:
continue
try:
price_map[ts_code] = float(fallback_price)
except (TypeError, ValueError):
LOGGER.debug(
"价格解析失败 ts_code=%s raw=%s",
ts_code,
fallback_price,
extra=LOG_EXTRA,
)
unresolved = [code for code in missing_prices if code not in price_map]
if unresolved:
LOGGER.warning(
"缺少收盘价回测将跳过估值codes=%s date=%s",
unresolved,
trade_date_compact,
extra=LOG_EXTRA,
)
portfolio_value_before = state.cash portfolio_value_before = state.cash
for ts_code, qty in state.holdings.items(): for ts_code, qty in state.holdings.items():
@ -883,6 +919,9 @@ class BacktestEngine:
realized = (trade_price - current_cost_basis) * sell_qty - fee realized = (trade_price - current_cost_basis) * sell_qty - fee
state.cash += proceeds state.cash += proceeds
state.realized_pnl += realized state.realized_pnl += realized
state.realized_pnl_by_symbol[ts_code] = (
state.realized_pnl_by_symbol.get(ts_code, 0.0) + realized
)
new_qty = current_qty - sell_qty new_qty = current_qty - sell_qty
if new_qty <= 1e-6: if new_qty <= 1e-6:
state.holdings.pop(ts_code, None) state.holdings.pop(ts_code, None)
@ -953,6 +992,7 @@ class BacktestEngine:
try: try:
self._persist_portfolio( self._persist_portfolio(
self.cfg.id,
trade_date_str, trade_date_str,
state, state,
market_value, market_value,
@ -1047,8 +1087,21 @@ class BacktestEngine:
except sqlite3.Error: except sqlite3.Error:
pass pass
def _reset_bt_portfolio_records(self) -> None:
cfg_id = self.cfg.id
if not cfg_id:
return
try:
with db_session() as conn:
conn.execute("DELETE FROM bt_portfolio_snapshots WHERE cfg_id = ?", (cfg_id,))
conn.execute("DELETE FROM bt_portfolio_positions WHERE cfg_id = ?", (cfg_id,))
conn.execute("DELETE FROM bt_portfolio_trades WHERE cfg_id = ?", (cfg_id,))
except Exception: # noqa: BLE001
LOGGER.exception("清理回测投资组合数据失败 cfg_id=%s", cfg_id, extra=LOG_EXTRA)
def _persist_portfolio( def _persist_portfolio(
self, self,
cfg_id: str,
trade_date: str, trade_date: str,
state: PortfolioState, state: PortfolioState,
market_value: float, market_value: float,
@ -1070,16 +1123,21 @@ class BacktestEngine:
"last_action": decision.action.value if decision else None, "last_action": decision.action.value if decision else None,
"confidence": decision.confidence if decision else None, "confidence": decision.confidence if decision else None,
} }
opened_date = state.opened_dates.get(ts_code, trade_date)
if hasattr(opened_date, "isoformat"):
opened_date = opened_date.isoformat() # type: ignore[attr-defined]
holdings_rows.append( holdings_rows.append(
( (
cfg_id,
trade_date,
ts_code, ts_code,
state.opened_dates.get(ts_code, trade_date), opened_date,
None, None,
qty, qty,
cost_basis, cost_basis,
price, price,
market_val, market_val,
state.realized_pnl, state.realized_pnl_by_symbol.get(ts_code, 0.0),
unrealized, unrealized,
target_weight, target_weight,
"open", "open",
@ -1111,11 +1169,12 @@ class BacktestEngine:
with db_session() as conn: with db_session() as conn:
conn.execute( conn.execute(
""" """
INSERT OR REPLACE INTO portfolio_snapshots INSERT OR REPLACE INTO bt_portfolio_snapshots
(trade_date, total_value, cash, invested_value, unrealized_pnl, realized_pnl, net_flow, exposure, notes, metadata) (cfg_id, trade_date, total_value, cash, invested_value, unrealized_pnl, realized_pnl, net_flow, exposure, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", """,
( (
cfg_id,
trade_date, trade_date,
market_value + state.cash, market_value + state.cash,
state.cash, state.cash,
@ -1124,37 +1183,45 @@ class BacktestEngine:
state.realized_pnl, state.realized_pnl,
net_flow, net_flow,
exposure, exposure,
None,
json.dumps(snapshot_metadata, ensure_ascii=False), json.dumps(snapshot_metadata, ensure_ascii=False),
), ),
) )
conn.execute("DELETE FROM portfolio_positions") conn.execute(
"DELETE FROM bt_portfolio_positions WHERE cfg_id = ? AND trade_date = ?",
(cfg_id, trade_date),
)
if holdings_rows: if holdings_rows:
conn.executemany( conn.executemany(
""" """
INSERT INTO portfolio_positions INSERT INTO bt_portfolio_positions
(ts_code, opened_date, closed_date, quantity, cost_price, market_price, market_value, realized_pnl, unrealized_pnl, target_weight, status, notes, metadata) (cfg_id, trade_date, ts_code, opened_date, closed_date, quantity, cost_price, market_price, market_value, realized_pnl, unrealized_pnl, target_weight, status, notes, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", """,
holdings_rows, holdings_rows,
) )
if trades: if trades:
conn.execute(
"DELETE FROM bt_portfolio_trades WHERE cfg_id = ? AND trade_date = ?",
(cfg_id, trade_date),
)
conn.executemany( conn.executemany(
""" """
INSERT INTO portfolio_trades INSERT INTO bt_portfolio_trades
(trade_date, ts_code, action, quantity, price, fee, order_id, source, notes, metadata) (cfg_id, trade_date, ts_code, action, quantity, price, fee, source, metadata)
VALUES (?, ?, ?, ?, ?, ?, NULL, 'backtest', NULL, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""", """,
[ [
( (
cfg_id,
trade["trade_date"], trade["trade_date"],
trade["ts_code"], trade["ts_code"],
trade["action"], trade["action"],
trade["quantity"], trade["quantity"],
trade["price"], trade["price"],
trade.get("fee", 0.0), float(trade.get("fee", 0.0)),
str(trade.get("source") or "backtest"),
json.dumps(trade, ensure_ascii=False), json.dumps(trade, ensure_ascii=False),
) )
for trade in trades for trade in trades
@ -1196,6 +1263,7 @@ class BacktestEngine:
self, self,
decision_callback: Optional[Callable[[str, date, AgentContext, Decision], None]] = None, decision_callback: Optional[Callable[[str, date, AgentContext, Decision], None]] = None,
) -> BacktestResult: ) -> BacktestResult:
self._reset_bt_portfolio_records()
session = self.start_session() session = self.start_session()
if session.current_date > self.cfg.end_date: if session.current_date > self.cfg.end_date:
return session.result return session.result

View File

@ -436,6 +436,55 @@ SCHEMA_STATEMENTS: Iterable[str] = (
); );
""", """,
""" """
CREATE TABLE IF NOT EXISTS bt_portfolio_snapshots (
cfg_id TEXT,
trade_date TEXT,
total_value REAL,
cash REAL,
invested_value REAL,
unrealized_pnl REAL,
realized_pnl REAL,
net_flow REAL,
exposure REAL,
metadata TEXT,
PRIMARY KEY (cfg_id, trade_date)
);
""",
"""
CREATE TABLE IF NOT EXISTS bt_portfolio_positions (
cfg_id TEXT,
trade_date TEXT,
ts_code TEXT,
opened_date TEXT,
closed_date TEXT,
quantity REAL,
cost_price REAL,
market_price REAL,
market_value REAL,
realized_pnl REAL,
unrealized_pnl REAL,
target_weight REAL,
status TEXT,
notes TEXT,
metadata TEXT,
PRIMARY KEY (cfg_id, trade_date, ts_code)
);
""",
"""
CREATE TABLE IF NOT EXISTS bt_portfolio_trades (
id INTEGER PRIMARY KEY AUTOINCREMENT,
cfg_id TEXT NOT NULL,
trade_date TEXT NOT NULL,
ts_code TEXT NOT NULL,
action TEXT NOT NULL,
quantity REAL NOT NULL,
price REAL NOT NULL,
fee REAL DEFAULT 0,
source TEXT,
metadata TEXT
);
""",
"""
CREATE TABLE IF NOT EXISTS run_log ( CREATE TABLE IF NOT EXISTS run_log (
ts TEXT PRIMARY KEY, ts TEXT PRIMARY KEY,
stage TEXT, stage TEXT,
@ -569,6 +618,9 @@ REQUIRED_TABLES = (
"bt_risk_events", "bt_risk_events",
"bt_nav", "bt_nav",
"bt_report", "bt_report",
"bt_portfolio_snapshots",
"bt_portfolio_positions",
"bt_portfolio_trades",
"run_log", "run_log",
"agent_utils", "agent_utils",
"alloc_log", "alloc_log",

View File

@ -130,6 +130,61 @@ SCHEMA_STATEMENTS = [
metadata TEXT -- JSON object metadata TEXT -- JSON object
); );
""", """,
# 回测组合快照表
"""
CREATE TABLE IF NOT EXISTS bt_portfolio_snapshots (
cfg_id TEXT,
trade_date TEXT,
total_value REAL,
cash REAL,
invested_value REAL,
unrealized_pnl REAL,
realized_pnl REAL,
net_flow REAL,
exposure REAL,
metadata TEXT,
PRIMARY KEY (cfg_id, trade_date)
);
""",
# 回测持仓表
"""
CREATE TABLE IF NOT EXISTS bt_portfolio_positions (
cfg_id TEXT,
trade_date TEXT,
ts_code TEXT,
opened_date TEXT,
closed_date TEXT,
quantity REAL,
cost_price REAL,
market_price REAL,
market_value REAL,
realized_pnl REAL,
unrealized_pnl REAL,
target_weight REAL,
status TEXT,
notes TEXT,
metadata TEXT,
PRIMARY KEY (cfg_id, trade_date, ts_code)
);
""",
# 回测交易表
"""
CREATE TABLE IF NOT EXISTS bt_portfolio_trades (
id INTEGER PRIMARY KEY AUTOINCREMENT,
cfg_id TEXT NOT NULL,
trade_date TEXT NOT NULL,
ts_code TEXT NOT NULL,
action TEXT NOT NULL,
quantity REAL NOT NULL,
price REAL NOT NULL,
fee REAL DEFAULT 0,
source TEXT,
metadata TEXT
);
""",
] ]