1175 lines
45 KiB
Python
1175 lines
45 KiB
Python
"""Backtest engine skeleton for daily bar simulation."""
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
from dataclasses import dataclass, field
|
||
from datetime import date
|
||
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple
|
||
|
||
from app.agents.base import AgentAction, AgentContext
|
||
from app.agents.departments import DepartmentManager
|
||
from app.agents.game import Decision, decide
|
||
from app.llm.metrics import record_decision as metrics_record_decision
|
||
from app.agents.registry import default_agents
|
||
from app.data.schema import initialize_database
|
||
from app.utils.data_access import DataBroker
|
||
from app.utils.config import get_config
|
||
from app.utils.db import db_session
|
||
from app.utils.logging import get_logger
|
||
from app.core.indicators import momentum, normalize, rolling_mean, volatility
|
||
|
||
|
||
LOGGER = get_logger(__name__)
|
||
LOG_EXTRA = {"stage": "backtest"}
|
||
|
||
|
||
def _valuation_score(value: object, scale: float) -> float:
|
||
try:
|
||
numeric = float(value)
|
||
except (TypeError, ValueError):
|
||
return 0.0
|
||
if numeric <= 0:
|
||
return 0.0
|
||
score = scale / (scale + numeric)
|
||
return max(0.0, min(1.0, score))
|
||
|
||
|
||
def _volume_ratio_score(value: object) -> float:
|
||
try:
|
||
numeric = float(value)
|
||
except (TypeError, ValueError):
|
||
return 0.0
|
||
if numeric < 0:
|
||
numeric = 0.0
|
||
return max(0.0, min(1.0, numeric / 10.0))
|
||
|
||
|
||
|
||
@dataclass
|
||
class BtConfig:
|
||
id: str
|
||
name: str
|
||
start_date: date
|
||
end_date: date
|
||
universe: List[str]
|
||
params: Dict[str, float]
|
||
method: str = "nash"
|
||
|
||
|
||
@dataclass
|
||
class PortfolioState:
|
||
cash: float = 1_000_000.0
|
||
holdings: Dict[str, float] = field(default_factory=dict)
|
||
cost_basis: Dict[str, float] = field(default_factory=dict)
|
||
opened_dates: Dict[str, str] = field(default_factory=dict)
|
||
realized_pnl: float = 0.0
|
||
|
||
|
||
@dataclass
|
||
class BacktestResult:
|
||
nav_series: List[Dict[str, float]] = field(default_factory=list)
|
||
trades: List[Dict[str, str]] = field(default_factory=list)
|
||
risk_events: List[Dict[str, object]] = field(default_factory=list)
|
||
|
||
|
||
@dataclass
|
||
class BacktestSession:
|
||
"""Holds the mutable state for incremental backtest execution."""
|
||
|
||
state: PortfolioState
|
||
result: BacktestResult
|
||
current_date: date
|
||
|
||
|
||
class BacktestEngine:
|
||
"""Runs the multi-agent game inside a daily event-driven loop."""
|
||
|
||
def __init__(self, cfg: BtConfig) -> None:
|
||
self.cfg = cfg
|
||
self.agents = default_agents()
|
||
app_cfg = get_config()
|
||
weight_config = app_cfg.agent_weights.as_dict() if app_cfg.agent_weights else {}
|
||
if weight_config:
|
||
self.weights = weight_config
|
||
else:
|
||
self.weights = {agent.name: 1.0 for agent in self.agents}
|
||
initialize_database()
|
||
self.department_manager = (
|
||
DepartmentManager(app_cfg) if app_cfg.departments else None
|
||
)
|
||
self.data_broker = DataBroker()
|
||
params = cfg.params or {}
|
||
self.risk_params = {
|
||
"max_position_weight": float(params.get("max_position_weight", 0.2)),
|
||
"max_daily_turnover_ratio": float(params.get("max_daily_turnover_ratio", 0.25)),
|
||
"fee_rate": float(params.get("fee_rate", 0.0005)),
|
||
"slippage_bps": float(params.get("slippage_bps", 10.0)),
|
||
}
|
||
self._fee_rate = max(self.risk_params["fee_rate"], 0.0)
|
||
self._slippage_rate = max(self.risk_params["slippage_bps"], 0.0) / 10_000.0
|
||
self._turnover_cap = max(self.risk_params["max_daily_turnover_ratio"], 0.0)
|
||
self._buy_actions = {
|
||
AgentAction.BUY_S,
|
||
AgentAction.BUY_M,
|
||
AgentAction.BUY_L,
|
||
}
|
||
self._sell_actions = {AgentAction.SELL}
|
||
department_scope: set[str] = set()
|
||
for settings in app_cfg.departments.values():
|
||
department_scope.update(settings.data_scope)
|
||
base_scope = {
|
||
"daily.close",
|
||
"daily.open",
|
||
"daily.high",
|
||
"daily.low",
|
||
"daily.pct_chg",
|
||
"daily.vol",
|
||
"daily.amount",
|
||
"daily_basic.turnover_rate",
|
||
"daily_basic.turnover_rate_f",
|
||
"daily_basic.volume_ratio",
|
||
"stk_limit.up_limit",
|
||
"stk_limit.down_limit",
|
||
"factors.mom_20",
|
||
"factors.mom_60",
|
||
"factors.volat_20",
|
||
"factors.turn_20",
|
||
}
|
||
self.required_fields = sorted(base_scope | department_scope)
|
||
|
||
def load_market_data(self, trade_date: date) -> Mapping[str, Dict[str, Any]]:
|
||
"""Load per-stock feature vectors and context slices for the trade date."""
|
||
|
||
trade_date_str = trade_date.strftime("%Y%m%d")
|
||
feature_map: Dict[str, Dict[str, Any]] = {}
|
||
universe = self.cfg.universe or []
|
||
for ts_code in universe:
|
||
scope_values = self.data_broker.fetch_latest(
|
||
ts_code,
|
||
trade_date_str,
|
||
self.required_fields,
|
||
auto_refresh=False # 避免回测时触发自动补数
|
||
)
|
||
|
||
closes = self.data_broker.fetch_series(
|
||
"daily",
|
||
"close",
|
||
ts_code,
|
||
trade_date_str,
|
||
window=60,
|
||
auto_refresh=False, # 避免在回测中触发自动补数
|
||
)
|
||
close_values = [value for _date, value in closes if value is not None]
|
||
|
||
mom5 = scope_values.get("factors.mom_5")
|
||
if mom5 is None and len(close_values) >= 5:
|
||
mom5 = momentum(close_values, 5)
|
||
mom20 = scope_values.get("factors.mom_20")
|
||
if mom20 is None and len(close_values) >= 20:
|
||
mom20 = momentum(close_values, 20)
|
||
|
||
mom60 = scope_values.get("factors.mom_60")
|
||
if mom60 is None and len(close_values) >= 60:
|
||
mom60 = momentum(close_values, 60)
|
||
|
||
volat20 = scope_values.get("factors.volat_20")
|
||
if volat20 is None and len(close_values) >= 2:
|
||
volat20 = volatility(close_values, 20)
|
||
|
||
turnover_series = self.data_broker.fetch_series(
|
||
"daily_basic",
|
||
"turnover_rate",
|
||
ts_code,
|
||
trade_date_str,
|
||
window=20,
|
||
auto_refresh=False, # 避免在回测中触发自动补数
|
||
)
|
||
turnover_values = [value for _date, value in turnover_series if value is not None]
|
||
|
||
turn20 = scope_values.get("factors.turn_20")
|
||
if turn20 is None and turnover_values:
|
||
turn20 = rolling_mean(turnover_values, 20)
|
||
turn5 = scope_values.get("factors.turn_5")
|
||
if turn5 is None and len(turnover_values) >= 5:
|
||
turn5 = rolling_mean(turnover_values, 5)
|
||
|
||
if mom20 is None:
|
||
mom20 = 0.0
|
||
if mom60 is None:
|
||
mom60 = 0.0
|
||
if volat20 is None:
|
||
volat20 = 0.0
|
||
if turn20 is None:
|
||
turn20 = 0.0
|
||
if mom5 is None:
|
||
mom5 = 0.0
|
||
if turn5 is None:
|
||
turn5 = 0.0
|
||
|
||
liquidity_score = normalize(turn20, factor=20.0)
|
||
cost_penalty = normalize(
|
||
scope_values.get("daily_basic.volume_ratio", 0.0),
|
||
factor=50.0,
|
||
)
|
||
|
||
val_pe = scope_values.get("factors.val_pe_score")
|
||
if val_pe is None:
|
||
val_pe = _valuation_score(scope_values.get("daily_basic.pe"), scale=12.0)
|
||
|
||
val_pb = scope_values.get("factors.val_pb_score")
|
||
if val_pb is None:
|
||
val_pb = _valuation_score(scope_values.get("daily_basic.pb"), scale=2.5)
|
||
|
||
volume_ratio_score = scope_values.get("factors.volume_ratio_score")
|
||
if volume_ratio_score is None:
|
||
volume_ratio_score = _volume_ratio_score(scope_values.get("daily_basic.volume_ratio"))
|
||
|
||
sentiment_index = scope_values.get("news.sentiment_index", 0.0)
|
||
heat_score = scope_values.get("news.heat_score", 0.0)
|
||
scope_values.setdefault("news.sentiment_index", sentiment_index)
|
||
scope_values.setdefault("news.heat_score", heat_score)
|
||
|
||
scope_values.setdefault("factors.mom_5", mom5)
|
||
scope_values.setdefault("factors.mom_20", mom20)
|
||
scope_values.setdefault("factors.mom_60", mom60)
|
||
scope_values.setdefault("factors.volat_20", volat20)
|
||
scope_values.setdefault("factors.turn_20", turn20)
|
||
scope_values.setdefault("factors.turn_5", turn5)
|
||
scope_values.setdefault("factors.val_pe_score", val_pe)
|
||
scope_values.setdefault("factors.val_pb_score", val_pb)
|
||
scope_values.setdefault("factors.volume_ratio_score", volume_ratio_score)
|
||
if scope_values.get("macro.industry_heat") is None:
|
||
scope_values["macro.industry_heat"] = 0.5
|
||
if scope_values.get("macro.relative_strength") is None:
|
||
peer_strength = scope_values.get("index.performance_peers")
|
||
if peer_strength is None:
|
||
peer_strength = 0.5
|
||
scope_values["macro.relative_strength"] = peer_strength
|
||
scope_values.setdefault(
|
||
"index.performance_peers",
|
||
scope_values.get("macro.relative_strength", 0.5),
|
||
)
|
||
|
||
latest_close = scope_values.get("daily.close", 0.0)
|
||
latest_pct = scope_values.get("daily.pct_chg", 0.0)
|
||
latest_turnover = scope_values.get("daily_basic.turnover_rate", 0.0)
|
||
|
||
up_limit = scope_values.get("stk_limit.up_limit")
|
||
limit_up = False
|
||
if up_limit and latest_close:
|
||
limit_up = latest_close >= up_limit * 0.999
|
||
|
||
down_limit = scope_values.get("stk_limit.down_limit")
|
||
limit_down = False
|
||
if down_limit and latest_close:
|
||
limit_down = latest_close <= down_limit * 1.001
|
||
|
||
is_suspended = self.data_broker.fetch_flags(
|
||
"suspend",
|
||
ts_code,
|
||
trade_date,
|
||
"ts_code = ?",
|
||
[ts_code],
|
||
auto_refresh=False, # 避免在回测中触发自动补数
|
||
)
|
||
|
||
features = {
|
||
"mom_5": mom5,
|
||
"mom_20": mom20,
|
||
"mom_60": mom60,
|
||
"volat_20": volat20,
|
||
"turn_20": turn20,
|
||
"turn_5": turn5,
|
||
"liquidity_score": liquidity_score,
|
||
"cost_penalty": cost_penalty,
|
||
"news_heat": heat_score,
|
||
"news_sentiment": sentiment_index,
|
||
"industry_heat": scope_values.get("macro.industry_heat", 0.0),
|
||
"industry_relative_mom": scope_values.get(
|
||
"macro.relative_strength",
|
||
scope_values.get("index.performance_peers", 0.0),
|
||
),
|
||
"risk_penalty": min(1.0, volat20 * 5.0),
|
||
"valuation_pe_score": val_pe,
|
||
"valuation_pb_score": val_pb,
|
||
"volume_ratio_score": volume_ratio_score,
|
||
"is_suspended": is_suspended,
|
||
"limit_up": limit_up,
|
||
"limit_down": limit_down,
|
||
"position_limit": False,
|
||
}
|
||
|
||
market_snapshot = {
|
||
"close": latest_close,
|
||
"pct_chg": latest_pct,
|
||
"turnover_rate": latest_turnover,
|
||
"volume": scope_values.get("daily.vol", 0.0),
|
||
"amount": scope_values.get("daily.amount", 0.0),
|
||
"up_limit": up_limit,
|
||
"down_limit": down_limit,
|
||
}
|
||
|
||
raw_payload = {
|
||
"scope_values": scope_values,
|
||
"close_series": closes,
|
||
"turnover_series": turnover_series,
|
||
"required_fields": self.required_fields,
|
||
}
|
||
|
||
feature_map[ts_code] = {
|
||
"features": features,
|
||
"market_snapshot": market_snapshot,
|
||
"raw": raw_payload,
|
||
}
|
||
|
||
return feature_map
|
||
|
||
def simulate_day(
|
||
self,
|
||
trade_date: date,
|
||
state: PortfolioState,
|
||
decision_callback: Optional[Callable[[str, date, AgentContext, Decision], None]] = None,
|
||
) -> List[tuple[str, AgentContext, Decision]]:
|
||
feature_map = self.load_market_data(trade_date)
|
||
records: List[tuple[str, AgentContext, Decision]] = []
|
||
for ts_code, payload in feature_map.items():
|
||
features = payload.get("features", {})
|
||
market_snapshot = payload.get("market_snapshot", {})
|
||
raw = payload.get("raw", {})
|
||
context = AgentContext(
|
||
ts_code=ts_code,
|
||
trade_date=trade_date.isoformat(),
|
||
features=features,
|
||
market_snapshot=market_snapshot,
|
||
raw=raw,
|
||
)
|
||
decision = decide(
|
||
context,
|
||
self.agents,
|
||
self.weights,
|
||
method=self.cfg.method,
|
||
department_manager=self.department_manager,
|
||
)
|
||
try:
|
||
metrics_record_decision(
|
||
ts_code=ts_code,
|
||
trade_date=context.trade_date,
|
||
action=decision.action.value,
|
||
confidence=decision.confidence,
|
||
summary=_extract_summary(decision),
|
||
source="backtest",
|
||
departments={
|
||
code: dept.to_dict()
|
||
for code, dept in decision.department_decisions.items()
|
||
},
|
||
)
|
||
except Exception: # noqa: BLE001
|
||
LOGGER.debug("记录决策指标失败", extra=LOG_EXTRA)
|
||
records.append((ts_code, context, decision))
|
||
self.record_agent_state(context, decision)
|
||
if decision_callback:
|
||
try:
|
||
decision_callback(ts_code, trade_date, context, decision)
|
||
except Exception: # noqa: BLE001
|
||
LOGGER.exception("决策回调执行失败", extra=LOG_EXTRA)
|
||
return records
|
||
|
||
def record_agent_state(self, context: AgentContext, decision: Decision) -> None:
|
||
payload = {
|
||
"trade_date": context.trade_date,
|
||
"ts_code": context.ts_code,
|
||
"action": decision.action.value,
|
||
"confidence": decision.confidence,
|
||
"department_votes": decision.department_votes,
|
||
"requires_review": decision.requires_review,
|
||
"departments": {
|
||
code: dept.to_dict()
|
||
for code, dept in decision.department_decisions.items()
|
||
},
|
||
}
|
||
combined_weights = dict(self.weights)
|
||
if self.department_manager:
|
||
for code, agent in self.department_manager.agents.items():
|
||
key = f"dept_{code}"
|
||
combined_weights[key] = agent.settings.weight
|
||
|
||
feasible_json = json.dumps(
|
||
[action.value for action in decision.feasible_actions],
|
||
ensure_ascii=False,
|
||
)
|
||
rows = []
|
||
for agent_name, weight in combined_weights.items():
|
||
action_scores = {
|
||
action.value: float(decision.utilities.get(action, {}).get(agent_name, 0.0))
|
||
for action in decision.utilities.keys()
|
||
}
|
||
best_action = decision.action.value
|
||
if action_scores:
|
||
best_action = max(action_scores.items(), key=lambda item: item[1])[0]
|
||
metadata: Dict[str, object] = {}
|
||
if agent_name.startswith("dept_"):
|
||
dept_code = agent_name.split("dept_", 1)[-1]
|
||
dept_decision = decision.department_decisions.get(dept_code)
|
||
if dept_decision:
|
||
metadata = {
|
||
"_summary": dept_decision.summary,
|
||
"_signals": dept_decision.signals,
|
||
"_risks": dept_decision.risks,
|
||
"_confidence": dept_decision.confidence,
|
||
}
|
||
if dept_decision.supplements:
|
||
metadata["_supplements"] = dept_decision.supplements
|
||
if dept_decision.dialogue:
|
||
metadata["_dialogue"] = dept_decision.dialogue
|
||
if dept_decision.telemetry:
|
||
metadata["_telemetry"] = dept_decision.telemetry
|
||
payload_json = {**action_scores, **metadata}
|
||
rows.append(
|
||
(
|
||
context.trade_date,
|
||
context.ts_code,
|
||
agent_name,
|
||
best_action,
|
||
json.dumps(payload_json, ensure_ascii=False),
|
||
feasible_json,
|
||
float(weight),
|
||
)
|
||
)
|
||
|
||
global_payload = {
|
||
"_confidence": decision.confidence,
|
||
"_target_weight": decision.target_weight,
|
||
"_department_votes": decision.department_votes,
|
||
"_requires_review": decision.requires_review,
|
||
"_scope_values": context.raw.get("scope_values", {}),
|
||
"_close_series": context.raw.get("close_series", []),
|
||
"_turnover_series": context.raw.get("turnover_series", []),
|
||
"_department_supplements": {
|
||
code: dept.supplements
|
||
for code, dept in decision.department_decisions.items()
|
||
if dept.supplements
|
||
},
|
||
"_department_dialogue": {
|
||
code: dept.dialogue
|
||
for code, dept in decision.department_decisions.items()
|
||
if dept.dialogue
|
||
},
|
||
"_department_telemetry": {
|
||
code: dept.telemetry
|
||
for code, dept in decision.department_decisions.items()
|
||
if dept.telemetry
|
||
},
|
||
}
|
||
rows.append(
|
||
(
|
||
context.trade_date,
|
||
context.ts_code,
|
||
"global",
|
||
decision.action.value,
|
||
json.dumps(global_payload, ensure_ascii=False),
|
||
feasible_json,
|
||
1.0,
|
||
)
|
||
)
|
||
|
||
try:
|
||
with db_session() as conn:
|
||
conn.executemany(
|
||
"""
|
||
INSERT OR REPLACE INTO agent_utils
|
||
(trade_date, ts_code, agent, action, utils, feasible, weight)
|
||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||
""",
|
||
rows,
|
||
)
|
||
except Exception:
|
||
LOGGER.exception("写入 agent_utils 失败", extra=LOG_EXTRA)
|
||
_ = payload
|
||
# TODO: persist payload into bt_trades / audit tables when schema is ready.
|
||
|
||
try:
|
||
self._record_investment_candidate(context, decision)
|
||
except Exception: # noqa: BLE001
|
||
LOGGER.exception("写入 investment_pool 失败", extra=LOG_EXTRA)
|
||
|
||
def _apply_portfolio_updates(
|
||
self,
|
||
trade_date: date,
|
||
state: PortfolioState,
|
||
records: List[tuple[str, AgentContext, Decision]],
|
||
result: BacktestResult,
|
||
) -> None:
|
||
trade_date_str = trade_date.isoformat()
|
||
price_map: Dict[str, float] = {}
|
||
decisions_map: Dict[str, Decision] = {}
|
||
feature_cache: Dict[str, Mapping[str, Any]] = {}
|
||
for ts_code, context, decision in records:
|
||
features = context.features or {}
|
||
if not isinstance(features, Mapping):
|
||
features = {}
|
||
feature_cache[ts_code] = features
|
||
scope_values = context.raw.get("scope_values") if context.raw else {}
|
||
if not isinstance(scope_values, Mapping):
|
||
scope_values = {}
|
||
price = scope_values.get("daily.close") or scope_values.get("close")
|
||
if price is None:
|
||
continue
|
||
try:
|
||
price = float(price)
|
||
except (TypeError, ValueError):
|
||
continue
|
||
price_map[ts_code] = price
|
||
decisions_map[ts_code] = decision
|
||
|
||
if not price_map and state.holdings:
|
||
trade_date_compact = trade_date.strftime("%Y%m%d")
|
||
for ts_code in list(state.holdings.keys()):
|
||
fetched = self.data_broker.fetch_latest(ts_code, trade_date_compact, ["daily.close"], auto_refresh=False)
|
||
price = fetched.get("daily.close")
|
||
if price:
|
||
price_map[ts_code] = float(price)
|
||
|
||
portfolio_value_before = state.cash
|
||
for ts_code, qty in state.holdings.items():
|
||
price = price_map.get(ts_code)
|
||
if price is None:
|
||
continue
|
||
portfolio_value_before += qty * price
|
||
|
||
if portfolio_value_before <= 0:
|
||
portfolio_value_before = state.cash or 1.0
|
||
|
||
daily_turnover = 0.0
|
||
executed_trades: List[Dict[str, Any]] = []
|
||
risk_events: List[Dict[str, Any]] = []
|
||
|
||
def _record_risk(ts_code: str, reason: str, decision: Decision, extra: Optional[Dict[str, Any]] = None) -> None:
|
||
payload = {
|
||
"trade_date": trade_date_str,
|
||
"ts_code": ts_code,
|
||
"action": decision.action.value,
|
||
"target_weight": decision.target_weight,
|
||
"confidence": decision.confidence,
|
||
"reason": reason,
|
||
}
|
||
if extra:
|
||
payload.update(extra)
|
||
risk_events.append(payload)
|
||
|
||
for ts_code, decision in decisions_map.items():
|
||
price = price_map.get(ts_code)
|
||
if price is None or price <= 0:
|
||
continue
|
||
features = feature_cache.get(ts_code, {})
|
||
current_qty = state.holdings.get(ts_code, 0.0)
|
||
liquidity_score = float(features.get("liquidity_score") or 0.0)
|
||
risk_penalty = float(features.get("risk_penalty") or 0.0)
|
||
is_suspended = bool(features.get("is_suspended"))
|
||
limit_up = bool(features.get("limit_up"))
|
||
limit_down = bool(features.get("limit_down"))
|
||
position_limit = bool(features.get("position_limit"))
|
||
|
||
if is_suspended:
|
||
_record_risk(ts_code, "suspended", decision)
|
||
continue
|
||
if decision.action in self._buy_actions:
|
||
if limit_up:
|
||
_record_risk(ts_code, "limit_up", decision)
|
||
continue
|
||
if position_limit:
|
||
_record_risk(ts_code, "position_limit", decision)
|
||
continue
|
||
if risk_penalty >= 0.95:
|
||
_record_risk(ts_code, "risk_penalty", decision, {"risk_penalty": risk_penalty})
|
||
continue
|
||
if decision.action in self._sell_actions and limit_down:
|
||
_record_risk(ts_code, "limit_down", decision)
|
||
continue
|
||
|
||
effective_weight = max(decision.target_weight, 0.0)
|
||
if decision.action in self._buy_actions:
|
||
capped_weight = min(effective_weight, self.risk_params["max_position_weight"])
|
||
effective_weight = capped_weight * max(0.0, 1.0 - risk_penalty)
|
||
elif decision.action in self._sell_actions:
|
||
effective_weight = 0.0
|
||
|
||
desired_qty = current_qty
|
||
if decision.action in self._sell_actions:
|
||
desired_qty = 0.0
|
||
elif decision.action in self._buy_actions or effective_weight >= 0.0:
|
||
desired_value = max(effective_weight, 0.0) * portfolio_value_before
|
||
desired_qty = desired_value / price if price > 0 else current_qty
|
||
|
||
delta = desired_qty - current_qty
|
||
if abs(delta) < 1e-6:
|
||
continue
|
||
|
||
if delta > 0 and self._turnover_cap > 0:
|
||
liquidity_scalar = max(liquidity_score, 0.1)
|
||
max_trade_value = self._turnover_cap * portfolio_value_before * liquidity_scalar
|
||
if max_trade_value > 0 and delta * price > max_trade_value:
|
||
delta = max_trade_value / price
|
||
desired_qty = current_qty + delta
|
||
|
||
if delta > 0:
|
||
trade_price = price * (1.0 + self._slippage_rate)
|
||
per_share_cost = trade_price * (1.0 + self._fee_rate)
|
||
if per_share_cost <= 0:
|
||
_record_risk(ts_code, "invalid_price", decision)
|
||
continue
|
||
max_affordable = state.cash / per_share_cost if per_share_cost > 0 else 0.0
|
||
if delta > max_affordable:
|
||
if max_affordable <= 1e-6:
|
||
_record_risk(ts_code, "insufficient_cash", decision)
|
||
continue
|
||
delta = max_affordable
|
||
desired_qty = current_qty + delta
|
||
|
||
trade_value = delta * trade_price
|
||
fee = trade_value * self._fee_rate
|
||
total_cash_needed = trade_value + fee
|
||
if total_cash_needed <= 0:
|
||
_record_risk(ts_code, "invalid_trade", decision)
|
||
continue
|
||
|
||
previous_cost = state.cost_basis.get(ts_code, 0.0) * current_qty
|
||
new_qty = current_qty + delta
|
||
state.cost_basis[ts_code] = (
|
||
(previous_cost + trade_value + fee) / new_qty if new_qty > 0 else 0.0
|
||
)
|
||
state.cash -= total_cash_needed
|
||
state.holdings[ts_code] = new_qty
|
||
state.opened_dates.setdefault(ts_code, trade_date_str)
|
||
daily_turnover += trade_value
|
||
executed_trades.append(
|
||
{
|
||
"trade_date": trade_date_str,
|
||
"ts_code": ts_code,
|
||
"action": "buy",
|
||
"quantity": float(delta),
|
||
"price": trade_price,
|
||
"base_price": price,
|
||
"value": trade_value,
|
||
"fee": fee,
|
||
"slippage": trade_price - price,
|
||
"confidence": decision.confidence,
|
||
"target_weight": decision.target_weight,
|
||
"effective_weight": effective_weight,
|
||
"risk_penalty": risk_penalty,
|
||
"liquidity_score": liquidity_score,
|
||
"status": "executed",
|
||
}
|
||
)
|
||
else:
|
||
sell_qty = min(abs(delta), current_qty)
|
||
if sell_qty <= 1e-6:
|
||
continue
|
||
trade_price = price * (1.0 - self._slippage_rate)
|
||
trade_price = max(trade_price, 0.0)
|
||
gross_value = sell_qty * trade_price
|
||
fee = gross_value * self._fee_rate
|
||
proceeds = gross_value - fee
|
||
cost_basis = state.cost_basis.get(ts_code, 0.0)
|
||
realized = (trade_price - cost_basis) * sell_qty - fee
|
||
state.cash += proceeds
|
||
state.realized_pnl += realized
|
||
new_qty = current_qty - sell_qty
|
||
if new_qty <= 1e-6:
|
||
state.holdings.pop(ts_code, None)
|
||
state.cost_basis.pop(ts_code, None)
|
||
state.opened_dates.pop(ts_code, None)
|
||
else:
|
||
state.holdings[ts_code] = new_qty
|
||
daily_turnover += gross_value
|
||
executed_trades.append(
|
||
{
|
||
"trade_date": trade_date_str,
|
||
"ts_code": ts_code,
|
||
"action": "sell",
|
||
"quantity": float(sell_qty),
|
||
"price": trade_price,
|
||
"base_price": price,
|
||
"value": gross_value,
|
||
"fee": fee,
|
||
"slippage": price - trade_price,
|
||
"confidence": decision.confidence,
|
||
"target_weight": decision.target_weight,
|
||
"effective_weight": effective_weight,
|
||
"risk_penalty": risk_penalty,
|
||
"liquidity_score": liquidity_score,
|
||
"realized_pnl": realized,
|
||
"status": "executed",
|
||
}
|
||
)
|
||
|
||
market_value = 0.0
|
||
unrealized_pnl = 0.0
|
||
for ts_code, qty in state.holdings.items():
|
||
price = price_map.get(ts_code)
|
||
if price is None:
|
||
continue
|
||
market_value += qty * price
|
||
cost_basis = state.cost_basis.get(ts_code, 0.0)
|
||
unrealized_pnl += (price - cost_basis) * qty
|
||
|
||
nav = state.cash + market_value
|
||
turnover_ratio = daily_turnover / nav if nav else 0.0
|
||
result.nav_series.append(
|
||
{
|
||
"trade_date": trade_date_str,
|
||
"nav": nav,
|
||
"cash": state.cash,
|
||
"market_value": market_value,
|
||
"realized_pnl": state.realized_pnl,
|
||
"unrealized_pnl": unrealized_pnl,
|
||
"turnover": daily_turnover,
|
||
"turnover_ratio": turnover_ratio,
|
||
}
|
||
)
|
||
if executed_trades:
|
||
result.trades.extend(executed_trades)
|
||
if risk_events:
|
||
result.risk_events.extend(risk_events)
|
||
|
||
try:
|
||
self._persist_portfolio(
|
||
trade_date_str,
|
||
state,
|
||
market_value,
|
||
unrealized_pnl,
|
||
executed_trades,
|
||
price_map,
|
||
decisions_map,
|
||
daily_turnover,
|
||
)
|
||
except Exception: # noqa: BLE001
|
||
LOGGER.exception("持仓数据写入失败", extra=LOG_EXTRA)
|
||
|
||
def _record_investment_candidate(
|
||
self, context: AgentContext, decision: Decision
|
||
) -> None:
|
||
status = _candidate_status(decision.action, decision.requires_review)
|
||
summary = _extract_summary(decision)
|
||
if not summary:
|
||
collected_signals: List[str] = []
|
||
for dept in decision.department_decisions.values():
|
||
collected_signals.extend(dept.signals)
|
||
summary = ";".join(str(sig) for sig in collected_signals[:3])
|
||
|
||
metadata = {
|
||
"target_weight": decision.target_weight,
|
||
"feasible_actions": [action.value for action in decision.feasible_actions],
|
||
"department_votes": decision.department_votes,
|
||
"requires_review": decision.requires_review,
|
||
"confidence": decision.confidence,
|
||
}
|
||
if decision.department_decisions:
|
||
metadata["departments"] = {
|
||
code: dept.to_dict()
|
||
for code, dept in decision.department_decisions.items()
|
||
}
|
||
|
||
with db_session() as conn:
|
||
conn.execute(
|
||
"""
|
||
INSERT OR REPLACE INTO investment_pool
|
||
(trade_date, ts_code, score, status, rationale, tags, metadata)
|
||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||
""",
|
||
(
|
||
context.trade_date,
|
||
context.ts_code,
|
||
float(decision.confidence or 0.0),
|
||
status,
|
||
summary or None,
|
||
json.dumps(_department_tags(decision), ensure_ascii=False),
|
||
json.dumps(metadata, ensure_ascii=False),
|
||
),
|
||
)
|
||
|
||
def _persist_portfolio(
|
||
self,
|
||
trade_date: str,
|
||
state: PortfolioState,
|
||
market_value: float,
|
||
unrealized_pnl: float,
|
||
trades: List[Dict[str, Any]],
|
||
price_map: Dict[str, float],
|
||
decisions_map: Dict[str, Decision],
|
||
daily_turnover: float,
|
||
) -> None:
|
||
holdings_rows: List[tuple] = []
|
||
for ts_code, qty in state.holdings.items():
|
||
price = price_map.get(ts_code)
|
||
market_val = qty * price if price is not None else None
|
||
cost_basis = state.cost_basis.get(ts_code, 0.0)
|
||
unrealized = (price - cost_basis) * qty if price is not None else None
|
||
decision = decisions_map.get(ts_code)
|
||
target_weight = decision.target_weight if decision else None
|
||
metadata = {
|
||
"last_action": decision.action.value if decision else None,
|
||
"confidence": decision.confidence if decision else None,
|
||
}
|
||
holdings_rows.append(
|
||
(
|
||
ts_code,
|
||
state.opened_dates.get(ts_code, trade_date),
|
||
None,
|
||
qty,
|
||
cost_basis,
|
||
price,
|
||
market_val,
|
||
state.realized_pnl,
|
||
unrealized,
|
||
target_weight,
|
||
"open",
|
||
None,
|
||
json.dumps(metadata, ensure_ascii=False),
|
||
)
|
||
)
|
||
|
||
total_value = market_value + state.cash
|
||
turnover_ratio = daily_turnover / total_value if total_value else 0.0
|
||
snapshot_metadata = {
|
||
"holdings": len(state.holdings),
|
||
"turnover_value": daily_turnover,
|
||
"turnover_ratio": turnover_ratio,
|
||
"trade_count": len(trades),
|
||
}
|
||
|
||
exposure = (market_value / total_value) if total_value else 0.0
|
||
net_flow = 0.0
|
||
for trade in trades:
|
||
value = float(trade.get("value", 0.0) or 0.0)
|
||
fee = float(trade.get("fee", 0.0) or 0.0)
|
||
action = str(trade.get("action", "")).lower()
|
||
if action.startswith("buy"):
|
||
net_flow -= value + fee
|
||
elif action.startswith("sell"):
|
||
net_flow += value - fee
|
||
|
||
with db_session() as conn:
|
||
conn.execute(
|
||
"""
|
||
INSERT OR REPLACE INTO portfolio_snapshots
|
||
(trade_date, total_value, cash, invested_value, unrealized_pnl, realized_pnl, net_flow, exposure, notes, metadata)
|
||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||
""",
|
||
(
|
||
trade_date,
|
||
market_value + state.cash,
|
||
state.cash,
|
||
market_value,
|
||
unrealized_pnl,
|
||
state.realized_pnl,
|
||
net_flow,
|
||
exposure,
|
||
None,
|
||
json.dumps(snapshot_metadata, ensure_ascii=False),
|
||
),
|
||
)
|
||
|
||
conn.execute("DELETE FROM portfolio_positions")
|
||
if holdings_rows:
|
||
conn.executemany(
|
||
"""
|
||
INSERT INTO portfolio_positions
|
||
(ts_code, opened_date, closed_date, quantity, cost_price, market_price, market_value, realized_pnl, unrealized_pnl, target_weight, status, notes, metadata)
|
||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||
""",
|
||
holdings_rows,
|
||
)
|
||
|
||
if trades:
|
||
conn.executemany(
|
||
"""
|
||
INSERT INTO portfolio_trades
|
||
(trade_date, ts_code, action, quantity, price, fee, order_id, source, notes, metadata)
|
||
VALUES (?, ?, ?, ?, ?, ?, NULL, 'backtest', NULL, ?)
|
||
""",
|
||
[
|
||
(
|
||
trade["trade_date"],
|
||
trade["ts_code"],
|
||
trade["action"],
|
||
trade["quantity"],
|
||
trade["price"],
|
||
trade.get("fee", 0.0),
|
||
json.dumps(trade, ensure_ascii=False),
|
||
)
|
||
for trade in trades
|
||
],
|
||
)
|
||
|
||
def start_session(self) -> BacktestSession:
|
||
"""Initialise a new incremental backtest session."""
|
||
|
||
return BacktestSession(
|
||
state=PortfolioState(),
|
||
result=BacktestResult(),
|
||
current_date=self.cfg.start_date,
|
||
)
|
||
|
||
def step_session(
|
||
self,
|
||
session: BacktestSession,
|
||
decision_callback: Optional[Callable[[str, date, AgentContext, Decision], None]] = None,
|
||
) -> Tuple[Iterable[Dict[str, Any]], bool]:
|
||
"""Advance the session by a single trade date.
|
||
|
||
Returns ``(records, done)`` where ``records`` is the raw output of
|
||
:meth:`simulate_day` and ``done`` indicates whether the session
|
||
reached the end date after this step.
|
||
"""
|
||
|
||
if session.current_date > self.cfg.end_date:
|
||
return [], True
|
||
|
||
trade_date = session.current_date
|
||
records = self.simulate_day(trade_date, session.state, decision_callback)
|
||
self._apply_portfolio_updates(trade_date, session.state, records, session.result)
|
||
session.current_date = date.fromordinal(trade_date.toordinal() + 1)
|
||
done = session.current_date > self.cfg.end_date
|
||
return records, done
|
||
|
||
def run(
|
||
self,
|
||
decision_callback: Optional[Callable[[str, date, AgentContext, Decision], None]] = None,
|
||
) -> BacktestResult:
|
||
session = self.start_session()
|
||
if session.current_date > self.cfg.end_date:
|
||
return session.result
|
||
|
||
while session.current_date <= self.cfg.end_date:
|
||
self.step_session(session, decision_callback)
|
||
return session.result
|
||
|
||
|
||
def run_backtest(
|
||
cfg: BtConfig,
|
||
*,
|
||
decision_callback: Optional[Callable[[str, date, AgentContext, Decision], None]] = None,
|
||
) -> BacktestResult:
|
||
engine = BacktestEngine(cfg)
|
||
result = engine.run(decision_callback=decision_callback)
|
||
_persist_backtest_results(cfg, result)
|
||
return result
|
||
|
||
|
||
def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None:
|
||
"""Persist backtest configuration, NAV path, trades and summary metrics."""
|
||
|
||
nav_rows: List[tuple] = []
|
||
trade_rows: List[tuple] = []
|
||
risk_rows: List[tuple] = []
|
||
summary_payload: Dict[str, object] = {}
|
||
turnover_sum = 0.0
|
||
|
||
if result.nav_series:
|
||
first_nav = float(result.nav_series[0].get("nav", 0.0) or 0.0)
|
||
peak_nav = first_nav
|
||
prev_nav: Optional[float] = None
|
||
max_drawdown = 0.0
|
||
for entry in result.nav_series:
|
||
trade_date = str(entry.get("trade_date", ""))
|
||
nav_val = float(entry.get("nav", 0.0) or 0.0)
|
||
cash = float(entry.get("cash", 0.0) or 0.0)
|
||
market_value = float(entry.get("market_value", 0.0) or 0.0)
|
||
realized = float(entry.get("realized_pnl", 0.0) or 0.0)
|
||
unrealized = float(entry.get("unrealized_pnl", 0.0) or 0.0)
|
||
turnover = float(entry.get("turnover", 0.0) or 0.0)
|
||
|
||
if nav_val > peak_nav:
|
||
peak_nav = nav_val
|
||
drawdown = (peak_nav - nav_val) / peak_nav if peak_nav else 0.0
|
||
max_drawdown = max(max_drawdown, drawdown)
|
||
|
||
if prev_nav is None or prev_nav == 0.0:
|
||
ret_val = 0.0
|
||
else:
|
||
ret_val = (nav_val / prev_nav) - 1.0
|
||
prev_nav = nav_val
|
||
|
||
info_payload = {
|
||
"cash": cash,
|
||
"market_value": market_value,
|
||
"realized_pnl": realized,
|
||
"unrealized_pnl": unrealized,
|
||
"turnover": turnover,
|
||
}
|
||
turnover_sum += turnover
|
||
nav_rows.append(
|
||
(
|
||
cfg.id,
|
||
trade_date,
|
||
nav_val,
|
||
float(ret_val),
|
||
None,
|
||
None,
|
||
float(drawdown),
|
||
json.dumps(info_payload, ensure_ascii=False),
|
||
)
|
||
)
|
||
|
||
last_nav = float(result.nav_series[-1].get("nav", 0.0) or 0.0)
|
||
total_return = (last_nav / first_nav - 1.0) if first_nav else 0.0
|
||
summary_payload.update(
|
||
{
|
||
"start_nav": first_nav,
|
||
"end_nav": last_nav,
|
||
"total_return": total_return,
|
||
"max_drawdown": max_drawdown,
|
||
"days": len(result.nav_series),
|
||
}
|
||
)
|
||
if turnover_sum:
|
||
summary_payload["total_turnover"] = turnover_sum
|
||
summary_payload["avg_turnover"] = turnover_sum / max(len(result.nav_series), 1)
|
||
|
||
if result.trades:
|
||
for trade in result.trades:
|
||
trade_date = str(trade.get("trade_date", ""))
|
||
ts_code = str(trade.get("ts_code", ""))
|
||
side = str(trade.get("action", "")).lower()
|
||
price = float(trade.get("price", 0.0) or 0.0)
|
||
qty = float(trade.get("quantity", 0.0) or 0.0)
|
||
reason_payload = {
|
||
"confidence": trade.get("confidence"),
|
||
"target_weight": trade.get("target_weight"),
|
||
"value": trade.get("value"),
|
||
"fee": trade.get("fee"),
|
||
"slippage": trade.get("slippage"),
|
||
"risk_penalty": trade.get("risk_penalty"),
|
||
"liquidity_score": trade.get("liquidity_score"),
|
||
}
|
||
trade_rows.append(
|
||
(
|
||
cfg.id,
|
||
ts_code,
|
||
trade_date,
|
||
side,
|
||
price,
|
||
qty,
|
||
json.dumps(reason_payload, ensure_ascii=False),
|
||
)
|
||
)
|
||
summary_payload["trade_count"] = len(trade_rows)
|
||
|
||
if result.risk_events:
|
||
summary_payload["risk_events"] = len(result.risk_events)
|
||
breakdown: Dict[str, int] = {}
|
||
for event in result.risk_events:
|
||
reason = str(event.get("reason") or "unknown")
|
||
breakdown[reason] = breakdown.get(reason, 0) + 1
|
||
risk_rows.append(
|
||
(
|
||
cfg.id,
|
||
str(event.get("trade_date", "")),
|
||
str(event.get("ts_code", "")),
|
||
reason,
|
||
str(event.get("action", "")),
|
||
float(event.get("target_weight", 0.0) or 0.0),
|
||
float(event.get("confidence", 0.0) or 0.0),
|
||
json.dumps(event, ensure_ascii=False),
|
||
)
|
||
)
|
||
summary_payload["risk_breakdown"] = breakdown
|
||
|
||
cfg_payload = {
|
||
"id": cfg.id,
|
||
"name": cfg.name,
|
||
"start_date": cfg.start_date.isoformat(),
|
||
"end_date": cfg.end_date.isoformat(),
|
||
"universe": cfg.universe,
|
||
"params": cfg.params,
|
||
"method": cfg.method,
|
||
}
|
||
|
||
with db_session() as conn:
|
||
conn.execute(
|
||
"""
|
||
INSERT OR REPLACE INTO bt_config (id, name, start_date, end_date, universe, params)
|
||
VALUES (?, ?, ?, ?, ?, ?)
|
||
""",
|
||
(
|
||
cfg.id,
|
||
cfg.name,
|
||
cfg.start_date.isoformat(),
|
||
cfg.end_date.isoformat(),
|
||
",".join(cfg.universe),
|
||
json.dumps(cfg.params, ensure_ascii=False),
|
||
),
|
||
)
|
||
|
||
conn.execute("DELETE FROM bt_nav WHERE cfg_id = ?", (cfg.id,))
|
||
conn.execute("DELETE FROM bt_trades WHERE cfg_id = ?", (cfg.id,))
|
||
conn.execute("DELETE FROM bt_risk_events WHERE cfg_id = ?", (cfg.id,))
|
||
conn.execute("DELETE FROM bt_report WHERE cfg_id = ?", (cfg.id,))
|
||
|
||
if nav_rows:
|
||
conn.executemany(
|
||
"""
|
||
INSERT INTO bt_nav (cfg_id, trade_date, nav, ret, pos_count, turnover, dd, info)
|
||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||
""",
|
||
nav_rows,
|
||
)
|
||
|
||
if trade_rows:
|
||
conn.executemany(
|
||
"""
|
||
INSERT INTO bt_trades (cfg_id, ts_code, trade_date, side, price, qty, reason)
|
||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||
""",
|
||
trade_rows,
|
||
)
|
||
|
||
if risk_rows:
|
||
conn.executemany(
|
||
"""
|
||
INSERT INTO bt_risk_events (cfg_id, trade_date, ts_code, reason, action, target_weight, confidence, metadata)
|
||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||
""",
|
||
risk_rows,
|
||
)
|
||
|
||
summary_payload.setdefault("universe", cfg.universe)
|
||
summary_payload.setdefault("method", cfg.method)
|
||
conn.execute(
|
||
"""
|
||
INSERT OR REPLACE INTO bt_report (cfg_id, summary)
|
||
VALUES (?, ?)
|
||
""",
|
||
(cfg.id, json.dumps(summary_payload, ensure_ascii=False, default=str)),
|
||
)
|
||
|
||
|
||
def _candidate_status(action: AgentAction, requires_review: bool) -> str:
|
||
mapping = {
|
||
AgentAction.SELL: "exit",
|
||
AgentAction.HOLD: "watch",
|
||
AgentAction.BUY_S: "buy_s",
|
||
AgentAction.BUY_M: "buy_m",
|
||
AgentAction.BUY_L: "buy_l",
|
||
}
|
||
base = mapping.get(action, "candidate")
|
||
if requires_review:
|
||
return f"{base}_review"
|
||
return base
|
||
def _extract_summary(decision: Decision) -> str:
|
||
for dept_decision in decision.department_decisions.values():
|
||
summary = getattr(dept_decision, "summary", "")
|
||
if summary:
|
||
return str(summary)
|
||
return ""
|
||
|
||
|
||
def _department_tags(decision: Decision) -> List[str]:
|
||
tags: List[str] = []
|
||
for code, dept in decision.department_decisions.items():
|
||
action = getattr(dept, "action", None)
|
||
if action is None:
|
||
continue
|
||
tags.append(f"{code}:{action.value}")
|
||
return sorted(set(tags))
|