refactor portfolio tables with cfg_id and optimize record handling

This commit is contained in:
sam 2025-10-16 14:57:55 +08:00
parent 2220b5084e
commit 74d98bf4e0
4 changed files with 206 additions and 37 deletions

View File

@ -508,25 +508,19 @@ class DecisionEnv:
return self._last_action
def _clear_portfolio_records(self) -> None:
start = self._template_cfg.start_date.isoformat()
end = self._template_cfg.end_date.isoformat()
cfg_id = self._template_cfg.id or "decision_env"
try:
with db_session() as conn:
conn.execute("DELETE FROM portfolio_positions")
conn.execute(
"DELETE FROM portfolio_snapshots WHERE trade_date BETWEEN ? AND ?",
(start, end),
)
conn.execute(
"DELETE FROM portfolio_trades WHERE trade_date BETWEEN ? AND ?",
(start, end),
)
conn.execute("DELETE FROM bt_portfolio_positions WHERE cfg_id = ?", (cfg_id,))
conn.execute("DELETE FROM bt_portfolio_snapshots WHERE cfg_id = ?", (cfg_id,))
conn.execute("DELETE FROM bt_portfolio_trades WHERE cfg_id = ?", (cfg_id,))
except Exception: # noqa: BLE001
LOGGER.exception("清理投资组合记录失败", extra=LOG_EXTRA)
def _fetch_portfolio_records(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
start = self._template_cfg.start_date.isoformat()
end = self._template_cfg.end_date.isoformat()
cfg_id = self._template_cfg.id or "decision_env"
snapshots: List[Dict[str, Any]] = []
trades: List[Dict[str, Any]] = []
try:
@ -535,20 +529,20 @@ class DecisionEnv:
"""
SELECT trade_date, total_value, cash, invested_value,
unrealized_pnl, realized_pnl, net_flow, exposure, metadata
FROM portfolio_snapshots
WHERE trade_date BETWEEN ? AND ?
FROM bt_portfolio_snapshots
WHERE cfg_id = ? AND trade_date BETWEEN ? AND ?
ORDER BY trade_date
""",
(start, end),
(cfg_id, start, end),
).fetchall()
trade_rows = conn.execute(
"""
SELECT id, trade_date, ts_code, action, quantity, price, fee, source, metadata
FROM portfolio_trades
WHERE trade_date BETWEEN ? AND ?
FROM bt_portfolio_trades
WHERE cfg_id = ? AND trade_date BETWEEN ? AND ?
ORDER BY trade_date, id
""",
(start, end),
(cfg_id, start, end),
).fetchall()
except Exception: # noqa: BLE001
LOGGER.exception("读取投资组合记录失败", extra=LOG_EXTRA)

View File

@ -70,6 +70,7 @@ class PortfolioState:
cost_basis: Dict[str, float] = field(default_factory=dict)
opened_dates: Dict[str, str] = field(default_factory=dict)
realized_pnl: float = 0.0
realized_pnl_by_symbol: Dict[str, float] = field(default_factory=dict)
@dataclass
@ -618,13 +619,48 @@ class BacktestEngine:
price_map[ts_code] = price
decisions_map[ts_code] = decision
if not price_map and state.holdings:
trade_date_compact = trade_date.strftime("%Y%m%d")
for ts_code in list(state.holdings.keys()):
fetched = self.data_broker.fetch_latest(ts_code, trade_date_compact, ["daily.close"], auto_refresh=False)
price = fetched.get("daily.close")
if price:
price_map[ts_code] = float(price)
trade_date_compact = trade_date.strftime("%Y%m%d")
missing_prices: set[str] = {
code for code in decisions_map.keys() if code not in price_map
}
missing_prices.update(code for code in state.holdings.keys() if code not in price_map)
if missing_prices:
for ts_code in sorted(missing_prices):
try:
fetched = self.data_broker.fetch_latest(
ts_code,
trade_date_compact,
["daily.close"],
auto_refresh=False,
)
except Exception: # noqa: BLE001
LOGGER.debug(
"回补价格失败 ts_code=%s date=%s",
ts_code,
trade_date_compact,
extra=LOG_EXTRA,
)
continue
fallback_price = fetched.get("daily.close")
if fallback_price is None:
continue
try:
price_map[ts_code] = float(fallback_price)
except (TypeError, ValueError):
LOGGER.debug(
"价格解析失败 ts_code=%s raw=%s",
ts_code,
fallback_price,
extra=LOG_EXTRA,
)
unresolved = [code for code in missing_prices if code not in price_map]
if unresolved:
LOGGER.warning(
"缺少收盘价回测将跳过估值codes=%s date=%s",
unresolved,
trade_date_compact,
extra=LOG_EXTRA,
)
portfolio_value_before = state.cash
for ts_code, qty in state.holdings.items():
@ -883,6 +919,9 @@ class BacktestEngine:
realized = (trade_price - current_cost_basis) * sell_qty - fee
state.cash += proceeds
state.realized_pnl += realized
state.realized_pnl_by_symbol[ts_code] = (
state.realized_pnl_by_symbol.get(ts_code, 0.0) + realized
)
new_qty = current_qty - sell_qty
if new_qty <= 1e-6:
state.holdings.pop(ts_code, None)
@ -953,6 +992,7 @@ class BacktestEngine:
try:
self._persist_portfolio(
self.cfg.id,
trade_date_str,
state,
market_value,
@ -1047,8 +1087,21 @@ class BacktestEngine:
except sqlite3.Error:
pass
def _reset_bt_portfolio_records(self) -> None:
cfg_id = self.cfg.id
if not cfg_id:
return
try:
with db_session() as conn:
conn.execute("DELETE FROM bt_portfolio_snapshots WHERE cfg_id = ?", (cfg_id,))
conn.execute("DELETE FROM bt_portfolio_positions WHERE cfg_id = ?", (cfg_id,))
conn.execute("DELETE FROM bt_portfolio_trades WHERE cfg_id = ?", (cfg_id,))
except Exception: # noqa: BLE001
LOGGER.exception("清理回测投资组合数据失败 cfg_id=%s", cfg_id, extra=LOG_EXTRA)
def _persist_portfolio(
self,
cfg_id: str,
trade_date: str,
state: PortfolioState,
market_value: float,
@ -1070,16 +1123,21 @@ class BacktestEngine:
"last_action": decision.action.value if decision else None,
"confidence": decision.confidence if decision else None,
}
opened_date = state.opened_dates.get(ts_code, trade_date)
if hasattr(opened_date, "isoformat"):
opened_date = opened_date.isoformat() # type: ignore[attr-defined]
holdings_rows.append(
(
cfg_id,
trade_date,
ts_code,
state.opened_dates.get(ts_code, trade_date),
opened_date,
None,
qty,
cost_basis,
price,
market_val,
state.realized_pnl,
state.realized_pnl_by_symbol.get(ts_code, 0.0),
unrealized,
target_weight,
"open",
@ -1111,11 +1169,12 @@ class BacktestEngine:
with db_session() as conn:
conn.execute(
"""
INSERT OR REPLACE INTO portfolio_snapshots
(trade_date, total_value, cash, invested_value, unrealized_pnl, realized_pnl, net_flow, exposure, notes, metadata)
INSERT OR REPLACE INTO bt_portfolio_snapshots
(cfg_id, trade_date, total_value, cash, invested_value, unrealized_pnl, realized_pnl, net_flow, exposure, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
cfg_id,
trade_date,
market_value + state.cash,
state.cash,
@ -1124,37 +1183,45 @@ class BacktestEngine:
state.realized_pnl,
net_flow,
exposure,
None,
json.dumps(snapshot_metadata, ensure_ascii=False),
),
)
conn.execute("DELETE FROM portfolio_positions")
conn.execute(
"DELETE FROM bt_portfolio_positions WHERE cfg_id = ? AND trade_date = ?",
(cfg_id, trade_date),
)
if holdings_rows:
conn.executemany(
"""
INSERT INTO portfolio_positions
(ts_code, opened_date, closed_date, quantity, cost_price, market_price, market_value, realized_pnl, unrealized_pnl, target_weight, status, notes, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
INSERT INTO bt_portfolio_positions
(cfg_id, trade_date, ts_code, opened_date, closed_date, quantity, cost_price, market_price, market_value, realized_pnl, unrealized_pnl, target_weight, status, notes, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
holdings_rows,
)
if trades:
conn.execute(
"DELETE FROM bt_portfolio_trades WHERE cfg_id = ? AND trade_date = ?",
(cfg_id, trade_date),
)
conn.executemany(
"""
INSERT INTO portfolio_trades
(trade_date, ts_code, action, quantity, price, fee, order_id, source, notes, metadata)
VALUES (?, ?, ?, ?, ?, ?, NULL, 'backtest', NULL, ?)
INSERT INTO bt_portfolio_trades
(cfg_id, trade_date, ts_code, action, quantity, price, fee, source, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
[
(
cfg_id,
trade["trade_date"],
trade["ts_code"],
trade["action"],
trade["quantity"],
trade["price"],
trade.get("fee", 0.0),
float(trade.get("fee", 0.0)),
str(trade.get("source") or "backtest"),
json.dumps(trade, ensure_ascii=False),
)
for trade in trades
@ -1196,6 +1263,7 @@ class BacktestEngine:
self,
decision_callback: Optional[Callable[[str, date, AgentContext, Decision], None]] = None,
) -> BacktestResult:
self._reset_bt_portfolio_records()
session = self.start_session()
if session.current_date > self.cfg.end_date:
return session.result

View File

@ -436,6 +436,55 @@ SCHEMA_STATEMENTS: Iterable[str] = (
);
""",
"""
CREATE TABLE IF NOT EXISTS bt_portfolio_snapshots (
cfg_id TEXT,
trade_date TEXT,
total_value REAL,
cash REAL,
invested_value REAL,
unrealized_pnl REAL,
realized_pnl REAL,
net_flow REAL,
exposure REAL,
metadata TEXT,
PRIMARY KEY (cfg_id, trade_date)
);
""",
"""
CREATE TABLE IF NOT EXISTS bt_portfolio_positions (
cfg_id TEXT,
trade_date TEXT,
ts_code TEXT,
opened_date TEXT,
closed_date TEXT,
quantity REAL,
cost_price REAL,
market_price REAL,
market_value REAL,
realized_pnl REAL,
unrealized_pnl REAL,
target_weight REAL,
status TEXT,
notes TEXT,
metadata TEXT,
PRIMARY KEY (cfg_id, trade_date, ts_code)
);
""",
"""
CREATE TABLE IF NOT EXISTS bt_portfolio_trades (
id INTEGER PRIMARY KEY AUTOINCREMENT,
cfg_id TEXT NOT NULL,
trade_date TEXT NOT NULL,
ts_code TEXT NOT NULL,
action TEXT NOT NULL,
quantity REAL NOT NULL,
price REAL NOT NULL,
fee REAL DEFAULT 0,
source TEXT,
metadata TEXT
);
""",
"""
CREATE TABLE IF NOT EXISTS run_log (
ts TEXT PRIMARY KEY,
stage TEXT,
@ -569,6 +618,9 @@ REQUIRED_TABLES = (
"bt_risk_events",
"bt_nav",
"bt_report",
"bt_portfolio_snapshots",
"bt_portfolio_positions",
"bt_portfolio_trades",
"run_log",
"agent_utils",
"alloc_log",

View File

@ -130,6 +130,61 @@ SCHEMA_STATEMENTS = [
metadata TEXT -- JSON object
);
""",
# 回测组合快照表
"""
CREATE TABLE IF NOT EXISTS bt_portfolio_snapshots (
cfg_id TEXT,
trade_date TEXT,
total_value REAL,
cash REAL,
invested_value REAL,
unrealized_pnl REAL,
realized_pnl REAL,
net_flow REAL,
exposure REAL,
metadata TEXT,
PRIMARY KEY (cfg_id, trade_date)
);
""",
# 回测持仓表
"""
CREATE TABLE IF NOT EXISTS bt_portfolio_positions (
cfg_id TEXT,
trade_date TEXT,
ts_code TEXT,
opened_date TEXT,
closed_date TEXT,
quantity REAL,
cost_price REAL,
market_price REAL,
market_value REAL,
realized_pnl REAL,
unrealized_pnl REAL,
target_weight REAL,
status TEXT,
notes TEXT,
metadata TEXT,
PRIMARY KEY (cfg_id, trade_date, ts_code)
);
""",
# 回测交易表
"""
CREATE TABLE IF NOT EXISTS bt_portfolio_trades (
id INTEGER PRIMARY KEY AUTOINCREMENT,
cfg_id TEXT NOT NULL,
trade_date TEXT NOT NULL,
ts_code TEXT NOT NULL,
action TEXT NOT NULL,
quantity REAL NOT NULL,
price REAL NOT NULL,
fee REAL DEFAULT 0,
source TEXT,
metadata TEXT
);
""",
]