From 8f820e441e71e98557f39adb71ccc480962a0a9c Mon Sep 17 00:00:00 2001
From: sam
Date: Tue, 30 Sep 2025 18:07:47 +0800
Subject: [PATCH] backtest: add risk gating, trading costs and turnover
 tracking; cache DataBroker reads; compute factors after daily ingestion

---
 app/backtest/engine.py   | 204 +++++++++++++++++++++++++++++++--------
 app/ingest/tushare.py    |  17 ++++
 app/utils/data_access.py |  96 +++++++++++++++++-
 3 files changed, 273 insertions(+), 44 deletions(-)

diff --git a/app/backtest/engine.py b/app/backtest/engine.py
index 0a21ec8..b092ce9 100644
--- a/app/backtest/engine.py
+++ b/app/backtest/engine.py
@@ -47,6 +47,7 @@ class PortfolioState:
 class BacktestResult:
     nav_series: List[Dict[str, float]] = field(default_factory=list)
     trades: List[Dict[str, str]] = field(default_factory=list)
+    risk_events: List[Dict[str, object]] = field(default_factory=list)
 
 
 class BacktestEngine:
@@ -65,6 +66,22 @@ class BacktestEngine:
             DepartmentManager(app_cfg) if app_cfg.departments else None
         )
         self.data_broker = DataBroker()
+        params = cfg.params or {}
+        self.risk_params = {
+            "max_position_weight": float(params.get("max_position_weight", 0.2)),
+            "max_daily_turnover_ratio": float(params.get("max_daily_turnover_ratio", 0.25)),
+            "fee_rate": float(params.get("fee_rate", 0.0005)),
+            "slippage_bps": float(params.get("slippage_bps", 10.0)),
+        }
+        self._fee_rate = max(self.risk_params["fee_rate"], 0.0)
+        self._slippage_rate = max(self.risk_params["slippage_bps"], 0.0) / 10_000.0
+        self._turnover_cap = max(self.risk_params["max_daily_turnover_ratio"], 0.0)
+        self._buy_actions = {
+            AgentAction.BUY_S,
+            AgentAction.BUY_M,
+            AgentAction.BUY_L,
+        }
+        self._sell_actions = {AgentAction.SELL}
         department_scope: set[str] = set()
         for settings in app_cfg.departments.values():
             department_scope.update(settings.data_scope)
@@ -389,7 +406,12 @@ class BacktestEngine:
         trade_date_str = trade_date.isoformat()
         price_map: Dict[str, float] = {}
         decisions_map: Dict[str, Decision] = {}
+        feature_cache: Dict[str, Mapping[str, Any]] = {}
         for ts_code, context, decision in records:
+            features = context.features or {}
+            if not isinstance(features, Mapping):
+                features = {}
+            feature_cache[ts_code] = features
             scope_values = context.raw.get("scope_values") if context.raw else {}
             if not isinstance(scope_values, Mapping):
                 scope_values = {}
@@ -405,7 +427,7 @@ class BacktestEngine:
 
         if not price_map and state.holdings:
             trade_date_compact = trade_date.strftime("%Y%m%d")
-            for ts_code in state.holdings.keys():
+            for ts_code in list(state.holdings.keys()):
                 fetched = self.data_broker.fetch_latest(ts_code, trade_date_compact, ["daily.close"])
                 price = fetched.get("daily.close")
                 if price:
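The cost model configured in the constructor hunk above reduces to simple per-share arithmetic: slippage moves the fill price against the trade, and the fee is then charged on the slipped value. A standalone sketch, not engine code; the helper names are hypothetical and the defaults match the patch:

    def effective_buy_cost(close: float, fee_rate: float = 0.0005, slippage_bps: float = 10.0) -> float:
        """Per-share cash outlay for a buy: slippage raises the fill, then the fee applies."""
        trade_price = close * (1.0 + slippage_bps / 10_000.0)
        return trade_price * (1.0 + fee_rate)

    def effective_sell_proceeds(close: float, fee_rate: float = 0.0005, slippage_bps: float = 10.0) -> float:
        """Per-share cash received for a sell: slippage lowers the fill, then the fee is deducted."""
        trade_price = close * (1.0 - slippage_bps / 10_000.0)
        return trade_price * (1.0 - fee_rate)

    print(effective_buy_cost(10.0))       # ~10.0150
    print(effective_sell_proceeds(10.0))  # ~9.9850, i.e. roughly a 0.30% round-trip drag

With the defaults, that drag is what the NAV now pays on every rebalance, which is why the execution hunk below also accumulates daily turnover.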
@@ -421,84 +443,166 @@ class BacktestEngine:
         if portfolio_value_before <= 0:
             portfolio_value_before = state.cash or 1.0
 
-        trades_records: List[Dict[str, Any]] = []
+        daily_turnover = 0.0
+        executed_trades: List[Dict[str, Any]] = []
+        risk_events: List[Dict[str, Any]] = []
+
+        def _record_risk(ts_code: str, reason: str, decision: Decision, extra: Optional[Dict[str, Any]] = None) -> None:
+            payload = {
+                "trade_date": trade_date_str,
+                "ts_code": ts_code,
+                "action": decision.action.value,
+                "target_weight": decision.target_weight,
+                "confidence": decision.confidence,
+                "reason": reason,
+            }
+            if extra:
+                payload.update(extra)
+            risk_events.append(payload)
+
         for ts_code, decision in decisions_map.items():
             price = price_map.get(ts_code)
             if price is None or price <= 0:
                 continue
+            features = feature_cache.get(ts_code, {})
             current_qty = state.holdings.get(ts_code, 0.0)
+            liquidity_score = float(features.get("liquidity_score") or 0.0)
+            risk_penalty = float(features.get("risk_penalty") or 0.0)
+            is_suspended = bool(features.get("is_suspended"))
+            limit_up = bool(features.get("limit_up"))
+            limit_down = bool(features.get("limit_down"))
+            position_limit = bool(features.get("position_limit"))
+
+            if is_suspended:
+                _record_risk(ts_code, "suspended", decision)
+                continue
+            if decision.action in self._buy_actions:
+                if limit_up:
+                    _record_risk(ts_code, "limit_up", decision)
+                    continue
+                if position_limit:
+                    _record_risk(ts_code, "position_limit", decision)
+                    continue
+                if risk_penalty >= 0.95:
+                    _record_risk(ts_code, "risk_penalty", decision, {"risk_penalty": risk_penalty})
+                    continue
+            if decision.action in self._sell_actions and limit_down:
+                _record_risk(ts_code, "limit_down", decision)
+                continue
+
+            effective_weight = max(decision.target_weight, 0.0)
+            if decision.action in self._buy_actions:
+                capped_weight = min(effective_weight, self.risk_params["max_position_weight"])
+                effective_weight = capped_weight * max(0.0, 1.0 - risk_penalty)
+            elif decision.action in self._sell_actions:
+                effective_weight = 0.0
+
             desired_qty = current_qty
-            if decision.action is AgentAction.SELL:
+            if decision.action in self._sell_actions:
                 desired_qty = 0.0
-            elif decision.action is AgentAction.HOLD:
-                desired_qty = current_qty
-            else:
-                target_weight = max(decision.target_weight, 0.0)
-                desired_value = target_weight * portfolio_value_before
-                if desired_value > 0:
-                    desired_qty = desired_value / price
-                else:
-                    desired_qty = current_qty
+            elif decision.action in self._buy_actions or effective_weight >= 0.0:
+                desired_value = max(effective_weight, 0.0) * portfolio_value_before
+                desired_qty = desired_value / price if price > 0 else current_qty
 
             delta = desired_qty - current_qty
             if abs(delta) < 1e-6:
                 continue
 
+            if delta > 0 and self._turnover_cap > 0:
+                liquidity_scalar = max(liquidity_score, 0.1)
+                max_trade_value = self._turnover_cap * portfolio_value_before * liquidity_scalar
+                if max_trade_value > 0 and delta * price > max_trade_value:
+                    delta = max_trade_value / price
+                    desired_qty = current_qty + delta
+
             if delta > 0:
-                cost = delta * price
-                if cost > state.cash:
-                    affordable_qty = state.cash / price if price > 0 else 0.0
-                    delta = max(0.0, affordable_qty)
-                    cost = delta * price
-                    desired_qty = current_qty + delta
-                if delta <= 0:
+                trade_price = price * (1.0 + self._slippage_rate)
+                per_share_cost = trade_price * (1.0 + self._fee_rate)
+                if per_share_cost <= 0:
+                    _record_risk(ts_code, "invalid_price", decision)
                     continue
-                total_cost = state.cost_basis.get(ts_code, 0.0) * current_qty + cost
+                max_affordable = state.cash / per_share_cost if per_share_cost > 0 else 0.0
+                if delta > max_affordable:
+                    if max_affordable <= 1e-6:
+                        _record_risk(ts_code, "insufficient_cash", decision)
+                        continue
+                    delta = max_affordable
+                    desired_qty = current_qty + delta
+
+                trade_value = delta * trade_price
+                fee = trade_value * self._fee_rate
+                total_cash_needed = trade_value + fee
+                if total_cash_needed <= 0:
+                    _record_risk(ts_code, "invalid_trade", decision)
+                    continue
+
+                previous_cost = state.cost_basis.get(ts_code, 0.0) * current_qty
                 new_qty = current_qty + delta
-                state.cost_basis[ts_code] = total_cost / new_qty if new_qty > 0 else 0.0
-                state.cash -= cost
+                state.cost_basis[ts_code] = (
+                    (previous_cost + trade_value + fee) / new_qty if new_qty > 0 else 0.0
+                )
+                state.cash -= total_cash_needed
                 state.holdings[ts_code] = new_qty
                 state.opened_dates.setdefault(ts_code, trade_date_str)
-                trades_records.append(
+                daily_turnover += trade_value
+                executed_trades.append(
                     {
                         "trade_date": trade_date_str,
                         "ts_code": ts_code,
                         "action": "buy",
                         "quantity": float(delta),
-                        "price": price,
-                        "value": cost,
+                        "price": trade_price,
+                        "base_price": price,
+                        "value": trade_value,
+                        "fee": fee,
+                        "slippage": trade_price - price,
                         "confidence": decision.confidence,
                         "target_weight": decision.target_weight,
+                        "effective_weight": effective_weight,
+                        "risk_penalty": risk_penalty,
+                        "liquidity_score": liquidity_score,
+                        "status": "executed",
                     }
                 )
             else:
-                sell_qty = abs(delta)
-                if sell_qty > current_qty:
-                    sell_qty = current_qty
-                    delta = -sell_qty
-                proceeds = sell_qty * price
+                sell_qty = min(abs(delta), current_qty)
+                if sell_qty <= 1e-6:
+                    continue
+                trade_price = price * (1.0 - self._slippage_rate)
+                trade_price = max(trade_price, 0.0)
+                gross_value = sell_qty * trade_price
+                fee = gross_value * self._fee_rate
+                proceeds = gross_value - fee
                 cost_basis = state.cost_basis.get(ts_code, 0.0)
-                realized = (price - cost_basis) * sell_qty
+                realized = (trade_price - cost_basis) * sell_qty - fee
                 state.cash += proceeds
                 state.realized_pnl += realized
-                new_qty = current_qty + delta
+                new_qty = current_qty - sell_qty
                 if new_qty <= 1e-6:
                     state.holdings.pop(ts_code, None)
                     state.cost_basis.pop(ts_code, None)
                     state.opened_dates.pop(ts_code, None)
                 else:
                     state.holdings[ts_code] = new_qty
-                trades_records.append(
+                daily_turnover += gross_value
+                executed_trades.append(
                     {
                         "trade_date": trade_date_str,
                         "ts_code": ts_code,
                         "action": "sell",
                         "quantity": float(sell_qty),
-                        "price": price,
-                        "value": proceeds,
+                        "price": trade_price,
+                        "base_price": price,
+                        "value": gross_value,
+                        "fee": fee,
+                        "slippage": price - trade_price,
                         "confidence": decision.confidence,
                         "target_weight": decision.target_weight,
+                        "effective_weight": effective_weight,
+                        "risk_penalty": risk_penalty,
+                        "liquidity_score": liquidity_score,
                         "realized_pnl": realized,
+                        "status": "executed",
                     }
                 )
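The buy path above chains four constraints before any cash moves: the per-name weight cap, risk-penalty shrinkage, the liquidity-scaled daily turnover cap, and the fee/slippage-adjusted cash clamp. A compressed standalone sketch of the same pipeline, assuming a flat starting position so the trade delta equals the target quantity (names mirror the patch, values are illustrative):

    def size_buy(target_weight: float, risk_penalty: float, liquidity_score: float,
                 portfolio_value: float, cash: float, price: float,
                 max_position_weight: float = 0.2, turnover_cap: float = 0.25,
                 fee_rate: float = 0.0005, slippage_bps: float = 10.0) -> float:
        weight = min(max(target_weight, 0.0), max_position_weight)
        weight *= max(0.0, 1.0 - risk_penalty)               # shrink risky names
        qty = weight * portfolio_value / price               # ideal quantity
        max_value = turnover_cap * portfolio_value * max(liquidity_score, 0.1)
        qty = min(qty, max_value / price)                    # turnover/liquidity cap
        per_share_cost = price * (1.0 + slippage_bps / 10_000.0) * (1.0 + fee_rate)
        return min(qty, cash / per_share_cost)               # cash clamp

    # A 30% target on a 1M portfolio is capped to 20%, shrunk to 16% by a 0.2
    # penalty (16,000 shares at 10.0), then cut to 12,500 by the turnover cap.
    print(size_buy(0.3, 0.2, 0.5, 1_000_000.0, 300_000.0, 10.0))  # 12500.0

Orders rejected earlier (suspension, limit-up, position limits, penalty >= 0.95) never reach this sizing step; they are recorded as risk_events instead.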
"action": "buy", "quantity": float(delta), - "price": price, - "value": cost, + "price": trade_price, + "base_price": price, + "value": trade_value, + "fee": fee, + "slippage": trade_price - price, "confidence": decision.confidence, "target_weight": decision.target_weight, + "effective_weight": effective_weight, + "risk_penalty": risk_penalty, + "liquidity_score": liquidity_score, + "status": "executed", } ) else: - sell_qty = abs(delta) - if sell_qty > current_qty: - sell_qty = current_qty - delta = -sell_qty - proceeds = sell_qty * price + sell_qty = min(abs(delta), current_qty) + if sell_qty <= 1e-6: + continue + trade_price = price * (1.0 - self._slippage_rate) + trade_price = max(trade_price, 0.0) + gross_value = sell_qty * trade_price + fee = gross_value * self._fee_rate + proceeds = gross_value - fee cost_basis = state.cost_basis.get(ts_code, 0.0) - realized = (price - cost_basis) * sell_qty + realized = (trade_price - cost_basis) * sell_qty - fee state.cash += proceeds state.realized_pnl += realized - new_qty = current_qty + delta + new_qty = current_qty - sell_qty if new_qty <= 1e-6: state.holdings.pop(ts_code, None) state.cost_basis.pop(ts_code, None) state.opened_dates.pop(ts_code, None) else: state.holdings[ts_code] = new_qty - trades_records.append( + daily_turnover += gross_value + executed_trades.append( { "trade_date": trade_date_str, "ts_code": ts_code, "action": "sell", "quantity": float(sell_qty), - "price": price, - "value": proceeds, + "price": trade_price, + "base_price": price, + "value": gross_value, + "fee": fee, + "slippage": price - trade_price, "confidence": decision.confidence, "target_weight": decision.target_weight, + "effective_weight": effective_weight, + "risk_penalty": risk_penalty, + "liquidity_score": liquidity_score, "realized_pnl": realized, + "status": "executed", } ) @@ -521,10 +625,13 @@ class BacktestEngine: "market_value": market_value, "realized_pnl": state.realized_pnl, "unrealized_pnl": unrealized_pnl, + "turnover": daily_turnover, } ) - if trades_records: - result.trades.extend(trades_records) + if executed_trades: + result.trades.extend(executed_trades) + if risk_events: + result.risk_events.extend(risk_events) try: self._persist_portfolio( @@ -532,9 +639,10 @@ class BacktestEngine: state, market_value, unrealized_pnl, - trades_records, + executed_trades, price_map, decisions_map, + daily_turnover, ) except Exception: # noqa: BLE001 LOGGER.exception("持仓数据写入失败", extra=LOG_EXTRA) @@ -590,6 +698,7 @@ class BacktestEngine: trades: List[Dict[str, Any]], price_map: Dict[str, float], decisions_map: Dict[str, Decision], + daily_turnover: float, ) -> None: holdings_rows: List[tuple] = [] for ts_code, qty in state.holdings.items(): @@ -623,6 +732,7 @@ class BacktestEngine: snapshot_metadata = { "holdings": len(state.holdings), + "turnover_value": daily_turnover, } with db_session() as conn: @@ -662,7 +772,7 @@ class BacktestEngine: """ INSERT INTO portfolio_trades (trade_date, ts_code, action, quantity, price, fee, order_id, source, notes, metadata) - VALUES (?, ?, ?, ?, ?, 0, NULL, 'backtest', NULL, ?) + VALUES (?, ?, ?, ?, ?, ?, NULL, 'backtest', NULL, ?) 
""", [ ( @@ -671,6 +781,7 @@ class BacktestEngine: trade["action"], trade["quantity"], trade["price"], + trade.get("fee", 0.0), json.dumps(trade, ensure_ascii=False), ) for trade in trades @@ -708,6 +819,7 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None: nav_rows: List[tuple] = [] trade_rows: List[tuple] = [] summary_payload: Dict[str, object] = {} + turnover_sum = 0.0 if result.nav_series: first_nav = float(result.nav_series[0].get("nav", 0.0) or 0.0) @@ -721,6 +833,7 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None: market_value = float(entry.get("market_value", 0.0) or 0.0) realized = float(entry.get("realized_pnl", 0.0) or 0.0) unrealized = float(entry.get("unrealized_pnl", 0.0) or 0.0) + turnover = float(entry.get("turnover", 0.0) or 0.0) if nav_val > peak_nav: peak_nav = nav_val @@ -738,7 +851,9 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None: "market_value": market_value, "realized_pnl": realized, "unrealized_pnl": unrealized, + "turnover": turnover, } + turnover_sum += turnover nav_rows.append( ( cfg.id, @@ -763,6 +878,9 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None: "days": len(result.nav_series), } ) + if turnover_sum: + summary_payload["total_turnover"] = turnover_sum + summary_payload["avg_turnover"] = turnover_sum / max(len(result.nav_series), 1) if result.trades: for trade in result.trades: @@ -789,6 +907,14 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None: ) summary_payload["trade_count"] = len(trade_rows) + if result.risk_events: + summary_payload["risk_events"] = len(result.risk_events) + breakdown: Dict[str, int] = {} + for event in result.risk_events: + reason = str(event.get("reason") or "unknown") + breakdown[reason] = breakdown.get(reason, 0) + 1 + summary_payload["risk_breakdown"] = breakdown + cfg_payload = { "id": cfg.id, "name": cfg.name, diff --git a/app/ingest/tushare.py b/app/ingest/tushare.py index 816d6a2..8343f4a 100644 --- a/app/ingest/tushare.py +++ b/app/ingest/tushare.py @@ -21,6 +21,7 @@ from app.utils.config import get_config from app.utils.db import db_session from app.data.schema import initialize_database from app.utils.logging import get_logger +from app.features.factors import compute_factor_range LOGGER = get_logger(__name__) @@ -1616,4 +1617,20 @@ def run_ingestion(job: FetchJob, include_limits: bool = True) -> None: raise else: alerts.clear_warnings("TuShare") + if job.granularity == "daily": + try: + LOGGER.info("开始计算因子:%s", job.name, extra=LOG_EXTRA) + compute_factor_range( + job.start, + job.end, + ts_codes=job.ts_codes, + skip_existing=False, + ) + except Exception as exc: + alerts.add_warning("Factors", f"因子计算失败:{job.name}", str(exc)) + LOGGER.exception("因子计算失败 job=%s", job.name, extra=LOG_EXTRA) + raise + else: + alerts.clear_warnings("Factors") + LOGGER.info("因子计算完成:%s", job.name, extra=LOG_EXTRA) LOGGER.info("任务 %s 完成", job.name, extra=LOG_EXTRA) diff --git a/app/utils/data_access.py b/app/utils/data_access.py index 337efb4..f52546f 100644 --- a/app/utils/data_access.py +++ b/app/utils/data_access.py @@ -3,7 +3,9 @@ from __future__ import annotations import re import sqlite3 -from dataclasses import dataclass +from collections import OrderedDict +from copy import deepcopy +from dataclasses import dataclass, field from datetime import datetime, timedelta from typing import Any, ClassVar, Dict, Iterable, List, Optional, Sequence, Tuple @@ -91,6 +93,16 @@ class DataBroker: 
diff --git a/app/utils/data_access.py b/app/utils/data_access.py
index 337efb4..f52546f 100644
--- a/app/utils/data_access.py
+++ b/app/utils/data_access.py
@@ -3,7 +3,9 @@ from __future__ import annotations
 
 import re
 import sqlite3
-from dataclasses import dataclass
+from collections import OrderedDict
+from copy import deepcopy
+from dataclasses import dataclass, field
 from datetime import datetime, timedelta
 from typing import Any, ClassVar, Dict, Iterable, List, Optional, Sequence, Tuple
 
@@ -91,6 +93,16 @@ class DataBroker:
     MAX_WINDOW: ClassVar[int] = 120
     BENCHMARK_INDEX: ClassVar[str] = "000300.SH"
 
+    enable_cache: bool = True
+    latest_cache_size: int = 256
+    series_cache_size: int = 512
+    _latest_cache: OrderedDict = field(init=False, repr=False)
+    _series_cache: OrderedDict = field(init=False, repr=False)
+
+    def __post_init__(self) -> None:
+        self._latest_cache = OrderedDict()
+        self._series_cache = OrderedDict()
+
     def fetch_latest(
         self,
         ts_code: str,
@@ -98,15 +110,19 @@ class DataBroker:
         fields: Iterable[str],
     ) -> Dict[str, Any]:
         """Fetch the latest value (<= trade_date) for each requested field."""
+        field_list = [str(item) for item in fields if item]
+        cache_key: Optional[Tuple[Any, ...]] = None
+        if self.enable_cache and field_list:
+            cache_key = (ts_code, trade_date, tuple(sorted(field_list)))
+            cached = self._cache_lookup(self._latest_cache, cache_key)
+            if cached is not None:
+                return deepcopy(cached)
         grouped: Dict[str, List[str]] = {}
         field_map: Dict[Tuple[str, str], List[str]] = {}
         derived_cache: Dict[str, Any] = {}
         results: Dict[str, Any] = {}
-        for item in fields:
-            if not item:
-                continue
-            field_name = str(item)
+        for field_name in field_list:
             resolved = self.resolve_field(field_name)
             if not resolved:
                 derived = self._resolve_derived_field(
@@ -125,6 +141,13 @@ class DataBroker:
                 field_map.setdefault((table, column), []).append(field_name)
 
         if not grouped:
+            if cache_key is not None and results:
+                self._cache_store(
+                    self._latest_cache,
+                    cache_key,
+                    deepcopy(results),
+                    self.latest_cache_size,
+                )
             return results
 
         try:
@@ -160,6 +183,23 @@ class DataBroker:
                     results[original] = value
         except sqlite3.OperationalError as exc:
             LOGGER.debug("数据库只读连接失败:%s", exc, extra=LOG_EXTRA)
+            if cache_key is not None:
+                cached = self._cache_lookup(self._latest_cache, cache_key)
+                if cached is not None:
+                    LOGGER.debug(
+                        "使用缓存结果 ts_code=%s trade_date=%s",
+                        ts_code,
+                        trade_date,
+                        extra=LOG_EXTRA,
+                    )
+                    return deepcopy(cached)
+        if cache_key is not None and results:
+            self._cache_store(
+                self._latest_cache,
+                cache_key,
+                deepcopy(results),
+                self.latest_cache_size,
+            )
         return results
 
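Both caches rely on the OrderedDict LRU discipline implemented by _cache_lookup/_cache_store at the end of this file: a hit is moved to the back, and once the size limit is exceeded the entry at the front, the least recently used one, is evicted. A standalone illustration of the pattern:

    from collections import OrderedDict

    cache: OrderedDict = OrderedDict()
    LIMIT = 2

    def store(key, value):
        cache[key] = value
        cache.move_to_end(key)
        while len(cache) > LIMIT:
            cache.popitem(last=False)   # evict the least recently used entry

    store("a", 1); store("b", 2)
    cache.move_to_end("a")              # a lookup hit refreshes "a"
    store("c", 3)                       # evicts "b", now the stale entry
    print(list(cache))                  # ['a', 'c']

fetch_latest additionally deep-copies values on both store and hit, so callers can mutate returned dicts without poisoning the cached copy.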
" @@ -211,6 +259,17 @@ class DataBroker: exc, extra=LOG_EXTRA, ) + if cache_key is not None: + cached = self._cache_lookup(self._series_cache, cache_key) + if cached is not None: + LOGGER.debug( + "使用缓存时间序列 table=%s column=%s ts_code=%s", + table, + resolved, + ts_code, + extra=LOG_EXTRA, + ) + return [tuple(item) for item in cached] return [] series: List[Tuple[str, float]] = [] for row in rows: @@ -218,6 +277,13 @@ class DataBroker: if value is None: continue series.append((row["trade_date"], float(value))) + if cache_key is not None and series: + self._cache_store( + self._series_cache, + cache_key, + tuple(series), + self.series_cache_size, + ) return series def fetch_flags( @@ -612,6 +678,26 @@ class DataBroker: cache[table] = columns return columns + def _cache_lookup(self, cache: OrderedDict, key: Tuple[Any, ...]) -> Optional[Any]: + if key in cache: + cache.move_to_end(key) + return cache[key] + return None + + def _cache_store( + self, + cache: OrderedDict, + key: Tuple[Any, ...], + value: Any, + limit: int, + ) -> None: + if not self.enable_cache or limit <= 0: + return + cache[key] = value + cache.move_to_end(key) + while len(cache) > limit: + cache.popitem(last=False) + def _resolve_column(self, table: str, column: str) -> Optional[str]: columns = self._get_table_columns(table) if columns is None: