refactor portfolio tables with cfg_id and optimize record handling
This commit is contained in:
parent
2220b5084e
commit
74d98bf4e0
@ -508,25 +508,19 @@ class DecisionEnv:
|
||||
return self._last_action
|
||||
|
||||
def _clear_portfolio_records(self) -> None:
|
||||
start = self._template_cfg.start_date.isoformat()
|
||||
end = self._template_cfg.end_date.isoformat()
|
||||
cfg_id = self._template_cfg.id or "decision_env"
|
||||
try:
|
||||
with db_session() as conn:
|
||||
conn.execute("DELETE FROM portfolio_positions")
|
||||
conn.execute(
|
||||
"DELETE FROM portfolio_snapshots WHERE trade_date BETWEEN ? AND ?",
|
||||
(start, end),
|
||||
)
|
||||
conn.execute(
|
||||
"DELETE FROM portfolio_trades WHERE trade_date BETWEEN ? AND ?",
|
||||
(start, end),
|
||||
)
|
||||
conn.execute("DELETE FROM bt_portfolio_positions WHERE cfg_id = ?", (cfg_id,))
|
||||
conn.execute("DELETE FROM bt_portfolio_snapshots WHERE cfg_id = ?", (cfg_id,))
|
||||
conn.execute("DELETE FROM bt_portfolio_trades WHERE cfg_id = ?", (cfg_id,))
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.exception("清理投资组合记录失败", extra=LOG_EXTRA)
|
||||
|
||||
def _fetch_portfolio_records(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||
start = self._template_cfg.start_date.isoformat()
|
||||
end = self._template_cfg.end_date.isoformat()
|
||||
cfg_id = self._template_cfg.id or "decision_env"
|
||||
snapshots: List[Dict[str, Any]] = []
|
||||
trades: List[Dict[str, Any]] = []
|
||||
try:
|
||||
@ -535,20 +529,20 @@ class DecisionEnv:
|
||||
"""
|
||||
SELECT trade_date, total_value, cash, invested_value,
|
||||
unrealized_pnl, realized_pnl, net_flow, exposure, metadata
|
||||
FROM portfolio_snapshots
|
||||
WHERE trade_date BETWEEN ? AND ?
|
||||
FROM bt_portfolio_snapshots
|
||||
WHERE cfg_id = ? AND trade_date BETWEEN ? AND ?
|
||||
ORDER BY trade_date
|
||||
""",
|
||||
(start, end),
|
||||
(cfg_id, start, end),
|
||||
).fetchall()
|
||||
trade_rows = conn.execute(
|
||||
"""
|
||||
SELECT id, trade_date, ts_code, action, quantity, price, fee, source, metadata
|
||||
FROM portfolio_trades
|
||||
WHERE trade_date BETWEEN ? AND ?
|
||||
FROM bt_portfolio_trades
|
||||
WHERE cfg_id = ? AND trade_date BETWEEN ? AND ?
|
||||
ORDER BY trade_date, id
|
||||
""",
|
||||
(start, end),
|
||||
(cfg_id, start, end),
|
||||
).fetchall()
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.exception("读取投资组合记录失败", extra=LOG_EXTRA)
|
||||
|
||||
@ -70,6 +70,7 @@ class PortfolioState:
|
||||
cost_basis: Dict[str, float] = field(default_factory=dict)
|
||||
opened_dates: Dict[str, str] = field(default_factory=dict)
|
||||
realized_pnl: float = 0.0
|
||||
realized_pnl_by_symbol: Dict[str, float] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -618,13 +619,48 @@ class BacktestEngine:
|
||||
price_map[ts_code] = price
|
||||
decisions_map[ts_code] = decision
|
||||
|
||||
if not price_map and state.holdings:
|
||||
trade_date_compact = trade_date.strftime("%Y%m%d")
|
||||
for ts_code in list(state.holdings.keys()):
|
||||
fetched = self.data_broker.fetch_latest(ts_code, trade_date_compact, ["daily.close"], auto_refresh=False)
|
||||
price = fetched.get("daily.close")
|
||||
if price:
|
||||
price_map[ts_code] = float(price)
|
||||
missing_prices: set[str] = {
|
||||
code for code in decisions_map.keys() if code not in price_map
|
||||
}
|
||||
missing_prices.update(code for code in state.holdings.keys() if code not in price_map)
|
||||
if missing_prices:
|
||||
for ts_code in sorted(missing_prices):
|
||||
try:
|
||||
fetched = self.data_broker.fetch_latest(
|
||||
ts_code,
|
||||
trade_date_compact,
|
||||
["daily.close"],
|
||||
auto_refresh=False,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.debug(
|
||||
"回补价格失败 ts_code=%s date=%s",
|
||||
ts_code,
|
||||
trade_date_compact,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
continue
|
||||
fallback_price = fetched.get("daily.close")
|
||||
if fallback_price is None:
|
||||
continue
|
||||
try:
|
||||
price_map[ts_code] = float(fallback_price)
|
||||
except (TypeError, ValueError):
|
||||
LOGGER.debug(
|
||||
"价格解析失败 ts_code=%s raw=%s",
|
||||
ts_code,
|
||||
fallback_price,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
unresolved = [code for code in missing_prices if code not in price_map]
|
||||
if unresolved:
|
||||
LOGGER.warning(
|
||||
"缺少收盘价,回测将跳过估值:codes=%s date=%s",
|
||||
unresolved,
|
||||
trade_date_compact,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
|
||||
portfolio_value_before = state.cash
|
||||
for ts_code, qty in state.holdings.items():
|
||||
@ -883,6 +919,9 @@ class BacktestEngine:
|
||||
realized = (trade_price - current_cost_basis) * sell_qty - fee
|
||||
state.cash += proceeds
|
||||
state.realized_pnl += realized
|
||||
state.realized_pnl_by_symbol[ts_code] = (
|
||||
state.realized_pnl_by_symbol.get(ts_code, 0.0) + realized
|
||||
)
|
||||
new_qty = current_qty - sell_qty
|
||||
if new_qty <= 1e-6:
|
||||
state.holdings.pop(ts_code, None)
|
||||
@ -953,6 +992,7 @@ class BacktestEngine:
|
||||
|
||||
try:
|
||||
self._persist_portfolio(
|
||||
self.cfg.id,
|
||||
trade_date_str,
|
||||
state,
|
||||
market_value,
|
||||
@ -1047,8 +1087,21 @@ class BacktestEngine:
|
||||
except sqlite3.Error:
|
||||
pass
|
||||
|
||||
def _reset_bt_portfolio_records(self) -> None:
|
||||
cfg_id = self.cfg.id
|
||||
if not cfg_id:
|
||||
return
|
||||
try:
|
||||
with db_session() as conn:
|
||||
conn.execute("DELETE FROM bt_portfolio_snapshots WHERE cfg_id = ?", (cfg_id,))
|
||||
conn.execute("DELETE FROM bt_portfolio_positions WHERE cfg_id = ?", (cfg_id,))
|
||||
conn.execute("DELETE FROM bt_portfolio_trades WHERE cfg_id = ?", (cfg_id,))
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.exception("清理回测投资组合数据失败 cfg_id=%s", cfg_id, extra=LOG_EXTRA)
|
||||
|
||||
def _persist_portfolio(
|
||||
self,
|
||||
cfg_id: str,
|
||||
trade_date: str,
|
||||
state: PortfolioState,
|
||||
market_value: float,
|
||||
@ -1070,16 +1123,21 @@ class BacktestEngine:
|
||||
"last_action": decision.action.value if decision else None,
|
||||
"confidence": decision.confidence if decision else None,
|
||||
}
|
||||
opened_date = state.opened_dates.get(ts_code, trade_date)
|
||||
if hasattr(opened_date, "isoformat"):
|
||||
opened_date = opened_date.isoformat() # type: ignore[attr-defined]
|
||||
holdings_rows.append(
|
||||
(
|
||||
cfg_id,
|
||||
trade_date,
|
||||
ts_code,
|
||||
state.opened_dates.get(ts_code, trade_date),
|
||||
opened_date,
|
||||
None,
|
||||
qty,
|
||||
cost_basis,
|
||||
price,
|
||||
market_val,
|
||||
state.realized_pnl,
|
||||
state.realized_pnl_by_symbol.get(ts_code, 0.0),
|
||||
unrealized,
|
||||
target_weight,
|
||||
"open",
|
||||
@ -1111,11 +1169,12 @@ class BacktestEngine:
|
||||
with db_session() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO portfolio_snapshots
|
||||
(trade_date, total_value, cash, invested_value, unrealized_pnl, realized_pnl, net_flow, exposure, notes, metadata)
|
||||
INSERT OR REPLACE INTO bt_portfolio_snapshots
|
||||
(cfg_id, trade_date, total_value, cash, invested_value, unrealized_pnl, realized_pnl, net_flow, exposure, metadata)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
cfg_id,
|
||||
trade_date,
|
||||
market_value + state.cash,
|
||||
state.cash,
|
||||
@ -1124,37 +1183,45 @@ class BacktestEngine:
|
||||
state.realized_pnl,
|
||||
net_flow,
|
||||
exposure,
|
||||
None,
|
||||
json.dumps(snapshot_metadata, ensure_ascii=False),
|
||||
),
|
||||
)
|
||||
|
||||
conn.execute("DELETE FROM portfolio_positions")
|
||||
conn.execute(
|
||||
"DELETE FROM bt_portfolio_positions WHERE cfg_id = ? AND trade_date = ?",
|
||||
(cfg_id, trade_date),
|
||||
)
|
||||
if holdings_rows:
|
||||
conn.executemany(
|
||||
"""
|
||||
INSERT INTO portfolio_positions
|
||||
(ts_code, opened_date, closed_date, quantity, cost_price, market_price, market_value, realized_pnl, unrealized_pnl, target_weight, status, notes, metadata)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
INSERT INTO bt_portfolio_positions
|
||||
(cfg_id, trade_date, ts_code, opened_date, closed_date, quantity, cost_price, market_price, market_value, realized_pnl, unrealized_pnl, target_weight, status, notes, metadata)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
holdings_rows,
|
||||
)
|
||||
|
||||
if trades:
|
||||
conn.execute(
|
||||
"DELETE FROM bt_portfolio_trades WHERE cfg_id = ? AND trade_date = ?",
|
||||
(cfg_id, trade_date),
|
||||
)
|
||||
conn.executemany(
|
||||
"""
|
||||
INSERT INTO portfolio_trades
|
||||
(trade_date, ts_code, action, quantity, price, fee, order_id, source, notes, metadata)
|
||||
VALUES (?, ?, ?, ?, ?, ?, NULL, 'backtest', NULL, ?)
|
||||
INSERT INTO bt_portfolio_trades
|
||||
(cfg_id, trade_date, ts_code, action, quantity, price, fee, source, metadata)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
[
|
||||
(
|
||||
cfg_id,
|
||||
trade["trade_date"],
|
||||
trade["ts_code"],
|
||||
trade["action"],
|
||||
trade["quantity"],
|
||||
trade["price"],
|
||||
trade.get("fee", 0.0),
|
||||
float(trade.get("fee", 0.0)),
|
||||
str(trade.get("source") or "backtest"),
|
||||
json.dumps(trade, ensure_ascii=False),
|
||||
)
|
||||
for trade in trades
|
||||
@ -1196,6 +1263,7 @@ class BacktestEngine:
|
||||
self,
|
||||
decision_callback: Optional[Callable[[str, date, AgentContext, Decision], None]] = None,
|
||||
) -> BacktestResult:
|
||||
self._reset_bt_portfolio_records()
|
||||
session = self.start_session()
|
||||
if session.current_date > self.cfg.end_date:
|
||||
return session.result
|
||||
|
||||
@ -436,6 +436,55 @@ SCHEMA_STATEMENTS: Iterable[str] = (
|
||||
);
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS bt_portfolio_snapshots (
|
||||
cfg_id TEXT,
|
||||
trade_date TEXT,
|
||||
total_value REAL,
|
||||
cash REAL,
|
||||
invested_value REAL,
|
||||
unrealized_pnl REAL,
|
||||
realized_pnl REAL,
|
||||
net_flow REAL,
|
||||
exposure REAL,
|
||||
metadata TEXT,
|
||||
PRIMARY KEY (cfg_id, trade_date)
|
||||
);
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS bt_portfolio_positions (
|
||||
cfg_id TEXT,
|
||||
trade_date TEXT,
|
||||
ts_code TEXT,
|
||||
opened_date TEXT,
|
||||
closed_date TEXT,
|
||||
quantity REAL,
|
||||
cost_price REAL,
|
||||
market_price REAL,
|
||||
market_value REAL,
|
||||
realized_pnl REAL,
|
||||
unrealized_pnl REAL,
|
||||
target_weight REAL,
|
||||
status TEXT,
|
||||
notes TEXT,
|
||||
metadata TEXT,
|
||||
PRIMARY KEY (cfg_id, trade_date, ts_code)
|
||||
);
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS bt_portfolio_trades (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
cfg_id TEXT NOT NULL,
|
||||
trade_date TEXT NOT NULL,
|
||||
ts_code TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
quantity REAL NOT NULL,
|
||||
price REAL NOT NULL,
|
||||
fee REAL DEFAULT 0,
|
||||
source TEXT,
|
||||
metadata TEXT
|
||||
);
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS run_log (
|
||||
ts TEXT PRIMARY KEY,
|
||||
stage TEXT,
|
||||
@ -569,6 +618,9 @@ REQUIRED_TABLES = (
|
||||
"bt_risk_events",
|
||||
"bt_nav",
|
||||
"bt_report",
|
||||
"bt_portfolio_snapshots",
|
||||
"bt_portfolio_positions",
|
||||
"bt_portfolio_trades",
|
||||
"run_log",
|
||||
"agent_utils",
|
||||
"alloc_log",
|
||||
|
||||
@ -130,6 +130,61 @@ SCHEMA_STATEMENTS = [
|
||||
metadata TEXT -- JSON object
|
||||
);
|
||||
""",
|
||||
|
||||
# 回测组合快照表
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS bt_portfolio_snapshots (
|
||||
cfg_id TEXT,
|
||||
trade_date TEXT,
|
||||
total_value REAL,
|
||||
cash REAL,
|
||||
invested_value REAL,
|
||||
unrealized_pnl REAL,
|
||||
realized_pnl REAL,
|
||||
net_flow REAL,
|
||||
exposure REAL,
|
||||
metadata TEXT,
|
||||
PRIMARY KEY (cfg_id, trade_date)
|
||||
);
|
||||
""",
|
||||
|
||||
# 回测持仓表
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS bt_portfolio_positions (
|
||||
cfg_id TEXT,
|
||||
trade_date TEXT,
|
||||
ts_code TEXT,
|
||||
opened_date TEXT,
|
||||
closed_date TEXT,
|
||||
quantity REAL,
|
||||
cost_price REAL,
|
||||
market_price REAL,
|
||||
market_value REAL,
|
||||
realized_pnl REAL,
|
||||
unrealized_pnl REAL,
|
||||
target_weight REAL,
|
||||
status TEXT,
|
||||
notes TEXT,
|
||||
metadata TEXT,
|
||||
PRIMARY KEY (cfg_id, trade_date, ts_code)
|
||||
);
|
||||
""",
|
||||
|
||||
# 回测交易表
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS bt_portfolio_trades (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
cfg_id TEXT NOT NULL,
|
||||
trade_date TEXT NOT NULL,
|
||||
ts_code TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
quantity REAL NOT NULL,
|
||||
price REAL NOT NULL,
|
||||
fee REAL DEFAULT 0,
|
||||
source TEXT,
|
||||
metadata TEXT
|
||||
);
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user