This commit is contained in:
sam 2025-09-30 18:07:47 +08:00
parent 30007cc056
commit 8f820e441e
3 changed files with 273 additions and 44 deletions

View File

@ -47,6 +47,7 @@ class PortfolioState:
class BacktestResult: class BacktestResult:
nav_series: List[Dict[str, float]] = field(default_factory=list) nav_series: List[Dict[str, float]] = field(default_factory=list)
trades: List[Dict[str, str]] = field(default_factory=list) trades: List[Dict[str, str]] = field(default_factory=list)
risk_events: List[Dict[str, object]] = field(default_factory=list)
class BacktestEngine: class BacktestEngine:
@ -65,6 +66,22 @@ class BacktestEngine:
DepartmentManager(app_cfg) if app_cfg.departments else None DepartmentManager(app_cfg) if app_cfg.departments else None
) )
self.data_broker = DataBroker() self.data_broker = DataBroker()
params = cfg.params or {}
self.risk_params = {
"max_position_weight": float(params.get("max_position_weight", 0.2)),
"max_daily_turnover_ratio": float(params.get("max_daily_turnover_ratio", 0.25)),
"fee_rate": float(params.get("fee_rate", 0.0005)),
"slippage_bps": float(params.get("slippage_bps", 10.0)),
}
self._fee_rate = max(self.risk_params["fee_rate"], 0.0)
self._slippage_rate = max(self.risk_params["slippage_bps"], 0.0) / 10_000.0
self._turnover_cap = max(self.risk_params["max_daily_turnover_ratio"], 0.0)
self._buy_actions = {
AgentAction.BUY_S,
AgentAction.BUY_M,
AgentAction.BUY_L,
}
self._sell_actions = {AgentAction.SELL}
department_scope: set[str] = set() department_scope: set[str] = set()
for settings in app_cfg.departments.values(): for settings in app_cfg.departments.values():
department_scope.update(settings.data_scope) department_scope.update(settings.data_scope)
@ -389,7 +406,12 @@ class BacktestEngine:
trade_date_str = trade_date.isoformat() trade_date_str = trade_date.isoformat()
price_map: Dict[str, float] = {} price_map: Dict[str, float] = {}
decisions_map: Dict[str, Decision] = {} decisions_map: Dict[str, Decision] = {}
feature_cache: Dict[str, Mapping[str, Any]] = {}
for ts_code, context, decision in records: for ts_code, context, decision in records:
features = context.features or {}
if not isinstance(features, Mapping):
features = {}
feature_cache[ts_code] = features
scope_values = context.raw.get("scope_values") if context.raw else {} scope_values = context.raw.get("scope_values") if context.raw else {}
if not isinstance(scope_values, Mapping): if not isinstance(scope_values, Mapping):
scope_values = {} scope_values = {}
@ -405,7 +427,7 @@ class BacktestEngine:
if not price_map and state.holdings: if not price_map and state.holdings:
trade_date_compact = trade_date.strftime("%Y%m%d") trade_date_compact = trade_date.strftime("%Y%m%d")
for ts_code in state.holdings.keys(): for ts_code in list(state.holdings.keys()):
fetched = self.data_broker.fetch_latest(ts_code, trade_date_compact, ["daily.close"]) fetched = self.data_broker.fetch_latest(ts_code, trade_date_compact, ["daily.close"])
price = fetched.get("daily.close") price = fetched.get("daily.close")
if price: if price:
@ -421,84 +443,166 @@ class BacktestEngine:
if portfolio_value_before <= 0: if portfolio_value_before <= 0:
portfolio_value_before = state.cash or 1.0 portfolio_value_before = state.cash or 1.0
trades_records: List[Dict[str, Any]] = [] daily_turnover = 0.0
executed_trades: List[Dict[str, Any]] = []
risk_events: List[Dict[str, Any]] = []
def _record_risk(ts_code: str, reason: str, decision: Decision, extra: Optional[Dict[str, Any]] = None) -> None:
payload = {
"trade_date": trade_date_str,
"ts_code": ts_code,
"action": decision.action.value,
"target_weight": decision.target_weight,
"confidence": decision.confidence,
"reason": reason,
}
if extra:
payload.update(extra)
risk_events.append(payload)
for ts_code, decision in decisions_map.items(): for ts_code, decision in decisions_map.items():
price = price_map.get(ts_code) price = price_map.get(ts_code)
if price is None or price <= 0: if price is None or price <= 0:
continue continue
features = feature_cache.get(ts_code, {})
current_qty = state.holdings.get(ts_code, 0.0) current_qty = state.holdings.get(ts_code, 0.0)
liquidity_score = float(features.get("liquidity_score") or 0.0)
risk_penalty = float(features.get("risk_penalty") or 0.0)
is_suspended = bool(features.get("is_suspended"))
limit_up = bool(features.get("limit_up"))
limit_down = bool(features.get("limit_down"))
position_limit = bool(features.get("position_limit"))
if is_suspended:
_record_risk(ts_code, "suspended", decision)
continue
if decision.action in self._buy_actions:
if limit_up:
_record_risk(ts_code, "limit_up", decision)
continue
if position_limit:
_record_risk(ts_code, "position_limit", decision)
continue
if risk_penalty >= 0.95:
_record_risk(ts_code, "risk_penalty", decision, {"risk_penalty": risk_penalty})
continue
if decision.action in self._sell_actions and limit_down:
_record_risk(ts_code, "limit_down", decision)
continue
effective_weight = max(decision.target_weight, 0.0)
if decision.action in self._buy_actions:
capped_weight = min(effective_weight, self.risk_params["max_position_weight"])
effective_weight = capped_weight * max(0.0, 1.0 - risk_penalty)
elif decision.action in self._sell_actions:
effective_weight = 0.0
desired_qty = current_qty desired_qty = current_qty
if decision.action is AgentAction.SELL: if decision.action in self._sell_actions:
desired_qty = 0.0 desired_qty = 0.0
elif decision.action is AgentAction.HOLD: elif decision.action in self._buy_actions or effective_weight >= 0.0:
desired_qty = current_qty desired_value = max(effective_weight, 0.0) * portfolio_value_before
else: desired_qty = desired_value / price if price > 0 else current_qty
target_weight = max(decision.target_weight, 0.0)
desired_value = target_weight * portfolio_value_before
if desired_value > 0:
desired_qty = desired_value / price
else:
desired_qty = current_qty
delta = desired_qty - current_qty delta = desired_qty - current_qty
if abs(delta) < 1e-6: if abs(delta) < 1e-6:
continue continue
if delta > 0: if delta > 0 and self._turnover_cap > 0:
cost = delta * price liquidity_scalar = max(liquidity_score, 0.1)
if cost > state.cash: max_trade_value = self._turnover_cap * portfolio_value_before * liquidity_scalar
affordable_qty = state.cash / price if price > 0 else 0.0 if max_trade_value > 0 and delta * price > max_trade_value:
delta = max(0.0, affordable_qty) delta = max_trade_value / price
cost = delta * price
desired_qty = current_qty + delta desired_qty = current_qty + delta
if delta <= 0:
if delta > 0:
trade_price = price * (1.0 + self._slippage_rate)
per_share_cost = trade_price * (1.0 + self._fee_rate)
if per_share_cost <= 0:
_record_risk(ts_code, "invalid_price", decision)
continue continue
total_cost = state.cost_basis.get(ts_code, 0.0) * current_qty + cost max_affordable = state.cash / per_share_cost if per_share_cost > 0 else 0.0
if delta > max_affordable:
if max_affordable <= 1e-6:
_record_risk(ts_code, "insufficient_cash", decision)
continue
delta = max_affordable
desired_qty = current_qty + delta
trade_value = delta * trade_price
fee = trade_value * self._fee_rate
total_cash_needed = trade_value + fee
if total_cash_needed <= 0:
_record_risk(ts_code, "invalid_trade", decision)
continue
previous_cost = state.cost_basis.get(ts_code, 0.0) * current_qty
new_qty = current_qty + delta new_qty = current_qty + delta
state.cost_basis[ts_code] = total_cost / new_qty if new_qty > 0 else 0.0 state.cost_basis[ts_code] = (
state.cash -= cost (previous_cost + trade_value + fee) / new_qty if new_qty > 0 else 0.0
)
state.cash -= total_cash_needed
state.holdings[ts_code] = new_qty state.holdings[ts_code] = new_qty
state.opened_dates.setdefault(ts_code, trade_date_str) state.opened_dates.setdefault(ts_code, trade_date_str)
trades_records.append( daily_turnover += trade_value
executed_trades.append(
{ {
"trade_date": trade_date_str, "trade_date": trade_date_str,
"ts_code": ts_code, "ts_code": ts_code,
"action": "buy", "action": "buy",
"quantity": float(delta), "quantity": float(delta),
"price": price, "price": trade_price,
"value": cost, "base_price": price,
"value": trade_value,
"fee": fee,
"slippage": trade_price - price,
"confidence": decision.confidence, "confidence": decision.confidence,
"target_weight": decision.target_weight, "target_weight": decision.target_weight,
"effective_weight": effective_weight,
"risk_penalty": risk_penalty,
"liquidity_score": liquidity_score,
"status": "executed",
} }
) )
else: else:
sell_qty = abs(delta) sell_qty = min(abs(delta), current_qty)
if sell_qty > current_qty: if sell_qty <= 1e-6:
sell_qty = current_qty continue
delta = -sell_qty trade_price = price * (1.0 - self._slippage_rate)
proceeds = sell_qty * price trade_price = max(trade_price, 0.0)
gross_value = sell_qty * trade_price
fee = gross_value * self._fee_rate
proceeds = gross_value - fee
cost_basis = state.cost_basis.get(ts_code, 0.0) cost_basis = state.cost_basis.get(ts_code, 0.0)
realized = (price - cost_basis) * sell_qty realized = (trade_price - cost_basis) * sell_qty - fee
state.cash += proceeds state.cash += proceeds
state.realized_pnl += realized state.realized_pnl += realized
new_qty = current_qty + delta new_qty = current_qty - sell_qty
if new_qty <= 1e-6: if new_qty <= 1e-6:
state.holdings.pop(ts_code, None) state.holdings.pop(ts_code, None)
state.cost_basis.pop(ts_code, None) state.cost_basis.pop(ts_code, None)
state.opened_dates.pop(ts_code, None) state.opened_dates.pop(ts_code, None)
else: else:
state.holdings[ts_code] = new_qty state.holdings[ts_code] = new_qty
trades_records.append( daily_turnover += gross_value
executed_trades.append(
{ {
"trade_date": trade_date_str, "trade_date": trade_date_str,
"ts_code": ts_code, "ts_code": ts_code,
"action": "sell", "action": "sell",
"quantity": float(sell_qty), "quantity": float(sell_qty),
"price": price, "price": trade_price,
"value": proceeds, "base_price": price,
"value": gross_value,
"fee": fee,
"slippage": price - trade_price,
"confidence": decision.confidence, "confidence": decision.confidence,
"target_weight": decision.target_weight, "target_weight": decision.target_weight,
"effective_weight": effective_weight,
"risk_penalty": risk_penalty,
"liquidity_score": liquidity_score,
"realized_pnl": realized, "realized_pnl": realized,
"status": "executed",
} }
) )
@ -521,10 +625,13 @@ class BacktestEngine:
"market_value": market_value, "market_value": market_value,
"realized_pnl": state.realized_pnl, "realized_pnl": state.realized_pnl,
"unrealized_pnl": unrealized_pnl, "unrealized_pnl": unrealized_pnl,
"turnover": daily_turnover,
} }
) )
if trades_records: if executed_trades:
result.trades.extend(trades_records) result.trades.extend(executed_trades)
if risk_events:
result.risk_events.extend(risk_events)
try: try:
self._persist_portfolio( self._persist_portfolio(
@ -532,9 +639,10 @@ class BacktestEngine:
state, state,
market_value, market_value,
unrealized_pnl, unrealized_pnl,
trades_records, executed_trades,
price_map, price_map,
decisions_map, decisions_map,
daily_turnover,
) )
except Exception: # noqa: BLE001 except Exception: # noqa: BLE001
LOGGER.exception("持仓数据写入失败", extra=LOG_EXTRA) LOGGER.exception("持仓数据写入失败", extra=LOG_EXTRA)
@ -590,6 +698,7 @@ class BacktestEngine:
trades: List[Dict[str, Any]], trades: List[Dict[str, Any]],
price_map: Dict[str, float], price_map: Dict[str, float],
decisions_map: Dict[str, Decision], decisions_map: Dict[str, Decision],
daily_turnover: float,
) -> None: ) -> None:
holdings_rows: List[tuple] = [] holdings_rows: List[tuple] = []
for ts_code, qty in state.holdings.items(): for ts_code, qty in state.holdings.items():
@ -623,6 +732,7 @@ class BacktestEngine:
snapshot_metadata = { snapshot_metadata = {
"holdings": len(state.holdings), "holdings": len(state.holdings),
"turnover_value": daily_turnover,
} }
with db_session() as conn: with db_session() as conn:
@ -662,7 +772,7 @@ class BacktestEngine:
""" """
INSERT INTO portfolio_trades INSERT INTO portfolio_trades
(trade_date, ts_code, action, quantity, price, fee, order_id, source, notes, metadata) (trade_date, ts_code, action, quantity, price, fee, order_id, source, notes, metadata)
VALUES (?, ?, ?, ?, ?, 0, NULL, 'backtest', NULL, ?) VALUES (?, ?, ?, ?, ?, ?, NULL, 'backtest', NULL, ?)
""", """,
[ [
( (
@ -671,6 +781,7 @@ class BacktestEngine:
trade["action"], trade["action"],
trade["quantity"], trade["quantity"],
trade["price"], trade["price"],
trade.get("fee", 0.0),
json.dumps(trade, ensure_ascii=False), json.dumps(trade, ensure_ascii=False),
) )
for trade in trades for trade in trades
@ -708,6 +819,7 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None:
nav_rows: List[tuple] = [] nav_rows: List[tuple] = []
trade_rows: List[tuple] = [] trade_rows: List[tuple] = []
summary_payload: Dict[str, object] = {} summary_payload: Dict[str, object] = {}
turnover_sum = 0.0
if result.nav_series: if result.nav_series:
first_nav = float(result.nav_series[0].get("nav", 0.0) or 0.0) first_nav = float(result.nav_series[0].get("nav", 0.0) or 0.0)
@ -721,6 +833,7 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None:
market_value = float(entry.get("market_value", 0.0) or 0.0) market_value = float(entry.get("market_value", 0.0) or 0.0)
realized = float(entry.get("realized_pnl", 0.0) or 0.0) realized = float(entry.get("realized_pnl", 0.0) or 0.0)
unrealized = float(entry.get("unrealized_pnl", 0.0) or 0.0) unrealized = float(entry.get("unrealized_pnl", 0.0) or 0.0)
turnover = float(entry.get("turnover", 0.0) or 0.0)
if nav_val > peak_nav: if nav_val > peak_nav:
peak_nav = nav_val peak_nav = nav_val
@ -738,7 +851,9 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None:
"market_value": market_value, "market_value": market_value,
"realized_pnl": realized, "realized_pnl": realized,
"unrealized_pnl": unrealized, "unrealized_pnl": unrealized,
"turnover": turnover,
} }
turnover_sum += turnover
nav_rows.append( nav_rows.append(
( (
cfg.id, cfg.id,
@ -763,6 +878,9 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None:
"days": len(result.nav_series), "days": len(result.nav_series),
} }
) )
if turnover_sum:
summary_payload["total_turnover"] = turnover_sum
summary_payload["avg_turnover"] = turnover_sum / max(len(result.nav_series), 1)
if result.trades: if result.trades:
for trade in result.trades: for trade in result.trades:
@ -789,6 +907,14 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None:
) )
summary_payload["trade_count"] = len(trade_rows) summary_payload["trade_count"] = len(trade_rows)
if result.risk_events:
summary_payload["risk_events"] = len(result.risk_events)
breakdown: Dict[str, int] = {}
for event in result.risk_events:
reason = str(event.get("reason") or "unknown")
breakdown[reason] = breakdown.get(reason, 0) + 1
summary_payload["risk_breakdown"] = breakdown
cfg_payload = { cfg_payload = {
"id": cfg.id, "id": cfg.id,
"name": cfg.name, "name": cfg.name,

View File

@ -21,6 +21,7 @@ from app.utils.config import get_config
from app.utils.db import db_session from app.utils.db import db_session
from app.data.schema import initialize_database from app.data.schema import initialize_database
from app.utils.logging import get_logger from app.utils.logging import get_logger
from app.features.factors import compute_factor_range
LOGGER = get_logger(__name__) LOGGER = get_logger(__name__)
@ -1616,4 +1617,20 @@ def run_ingestion(job: FetchJob, include_limits: bool = True) -> None:
raise raise
else: else:
alerts.clear_warnings("TuShare") alerts.clear_warnings("TuShare")
if job.granularity == "daily":
try:
LOGGER.info("开始计算因子:%s", job.name, extra=LOG_EXTRA)
compute_factor_range(
job.start,
job.end,
ts_codes=job.ts_codes,
skip_existing=False,
)
except Exception as exc:
alerts.add_warning("Factors", f"因子计算失败:{job.name}", str(exc))
LOGGER.exception("因子计算失败 job=%s", job.name, extra=LOG_EXTRA)
raise
else:
alerts.clear_warnings("Factors")
LOGGER.info("因子计算完成:%s", job.name, extra=LOG_EXTRA)
LOGGER.info("任务 %s 完成", job.name, extra=LOG_EXTRA) LOGGER.info("任务 %s 完成", job.name, extra=LOG_EXTRA)

View File

@ -3,7 +3,9 @@ from __future__ import annotations
import re import re
import sqlite3 import sqlite3
from dataclasses import dataclass from collections import OrderedDict
from copy import deepcopy
from dataclasses import dataclass, field
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, ClassVar, Dict, Iterable, List, Optional, Sequence, Tuple from typing import Any, ClassVar, Dict, Iterable, List, Optional, Sequence, Tuple
@ -91,6 +93,16 @@ class DataBroker:
MAX_WINDOW: ClassVar[int] = 120 MAX_WINDOW: ClassVar[int] = 120
BENCHMARK_INDEX: ClassVar[str] = "000300.SH" BENCHMARK_INDEX: ClassVar[str] = "000300.SH"
enable_cache: bool = True
latest_cache_size: int = 256
series_cache_size: int = 512
_latest_cache: OrderedDict = field(init=False, repr=False)
_series_cache: OrderedDict = field(init=False, repr=False)
def __post_init__(self) -> None:
self._latest_cache = OrderedDict()
self._series_cache = OrderedDict()
def fetch_latest( def fetch_latest(
self, self,
ts_code: str, ts_code: str,
@ -98,15 +110,19 @@ class DataBroker:
fields: Iterable[str], fields: Iterable[str],
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Fetch the latest value (<= trade_date) for each requested field.""" """Fetch the latest value (<= trade_date) for each requested field."""
field_list = [str(item) for item in fields if item]
cache_key: Optional[Tuple[Any, ...]] = None
if self.enable_cache and field_list:
cache_key = (ts_code, trade_date, tuple(sorted(field_list)))
cached = self._cache_lookup(self._latest_cache, cache_key)
if cached is not None:
return deepcopy(cached)
grouped: Dict[str, List[str]] = {} grouped: Dict[str, List[str]] = {}
field_map: Dict[Tuple[str, str], List[str]] = {} field_map: Dict[Tuple[str, str], List[str]] = {}
derived_cache: Dict[str, Any] = {} derived_cache: Dict[str, Any] = {}
results: Dict[str, Any] = {} results: Dict[str, Any] = {}
for item in fields: for field_name in field_list:
if not item:
continue
field_name = str(item)
resolved = self.resolve_field(field_name) resolved = self.resolve_field(field_name)
if not resolved: if not resolved:
derived = self._resolve_derived_field( derived = self._resolve_derived_field(
@ -125,6 +141,13 @@ class DataBroker:
field_map.setdefault((table, column), []).append(field_name) field_map.setdefault((table, column), []).append(field_name)
if not grouped: if not grouped:
if cache_key is not None and results:
self._cache_store(
self._latest_cache,
cache_key,
deepcopy(results),
self.latest_cache_size,
)
return results return results
try: try:
@ -160,6 +183,23 @@ class DataBroker:
results[original] = value results[original] = value
except sqlite3.OperationalError as exc: except sqlite3.OperationalError as exc:
LOGGER.debug("数据库只读连接失败:%s", exc, extra=LOG_EXTRA) LOGGER.debug("数据库只读连接失败:%s", exc, extra=LOG_EXTRA)
if cache_key is not None:
cached = self._cache_lookup(self._latest_cache, cache_key)
if cached is not None:
LOGGER.debug(
"使用缓存结果 ts_code=%s trade_date=%s",
ts_code,
trade_date,
extra=LOG_EXTRA,
)
return deepcopy(cached)
if cache_key is not None and results:
self._cache_store(
self._latest_cache,
cache_key,
deepcopy(results),
self.latest_cache_size,
)
return results return results
def fetch_series( def fetch_series(
@ -185,6 +225,14 @@ class DataBroker:
) )
return [] return []
table, resolved = resolved_field table, resolved = resolved_field
cache_key: Optional[Tuple[Any, ...]] = None
if self.enable_cache:
cache_key = (table, resolved, ts_code, end_date, window)
cached = self._cache_lookup(self._series_cache, cache_key)
if cached is not None:
return [tuple(item) for item in cached]
query = ( query = (
f"SELECT trade_date, {resolved} FROM {table} " f"SELECT trade_date, {resolved} FROM {table} "
"WHERE ts_code = ? AND trade_date <= ? " "WHERE ts_code = ? AND trade_date <= ? "
@ -211,6 +259,17 @@ class DataBroker:
exc, exc,
extra=LOG_EXTRA, extra=LOG_EXTRA,
) )
if cache_key is not None:
cached = self._cache_lookup(self._series_cache, cache_key)
if cached is not None:
LOGGER.debug(
"使用缓存时间序列 table=%s column=%s ts_code=%s",
table,
resolved,
ts_code,
extra=LOG_EXTRA,
)
return [tuple(item) for item in cached]
return [] return []
series: List[Tuple[str, float]] = [] series: List[Tuple[str, float]] = []
for row in rows: for row in rows:
@ -218,6 +277,13 @@ class DataBroker:
if value is None: if value is None:
continue continue
series.append((row["trade_date"], float(value))) series.append((row["trade_date"], float(value)))
if cache_key is not None and series:
self._cache_store(
self._series_cache,
cache_key,
tuple(series),
self.series_cache_size,
)
return series return series
def fetch_flags( def fetch_flags(
@ -612,6 +678,26 @@ class DataBroker:
cache[table] = columns cache[table] = columns
return columns return columns
def _cache_lookup(self, cache: OrderedDict, key: Tuple[Any, ...]) -> Optional[Any]:
if key in cache:
cache.move_to_end(key)
return cache[key]
return None
def _cache_store(
self,
cache: OrderedDict,
key: Tuple[Any, ...],
value: Any,
limit: int,
) -> None:
if not self.enable_cache or limit <= 0:
return
cache[key] = value
cache.move_to_end(key)
while len(cache) > limit:
cache.popitem(last=False)
def _resolve_column(self, table: str, column: str) -> Optional[str]: def _resolve_column(self, table: str, column: str) -> Optional[str]:
columns = self._get_table_columns(table) columns = self._get_table_columns(table)
if columns is None: if columns is None: