703 lines
27 KiB
Python
703 lines
27 KiB
Python
"""Reinforcement-learning style environment wrapping the backtest engine."""
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import math
|
||
import copy
|
||
from dataclasses import dataclass, replace
|
||
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
|
||
|
||
from datetime import date
|
||
|
||
from .engine import BacktestEngine, BacktestResult, BacktestSession, BtConfig
|
||
from app.agents.registry import weight_map
|
||
from app.utils.db import db_session
|
||
from app.utils.data_access import DataBroker
|
||
from app.utils.logging import get_logger
|
||
|
||
# Module-level logger for the decision environment.
LOGGER = get_logger(__name__)
# Passed as ``extra=`` on this module's log calls to tag each record with the
# ``decision_env`` stage.
LOG_EXTRA = dict(stage="decision_env")
|
||
|
||
|
||
@dataclass(frozen=True)
class ParameterSpec:
    """Defines how an action dimension maps to strategy parameters or behaviors.

    Action components are nominally in ``[0, 1]``; anything outside that range
    is clipped. Continuous specs (``values is None``) scale the clipped value
    linearly onto ``[minimum, maximum]``. Discrete specs pick the entry of
    ``values`` whose position is nearest to the clipped value.

    Attributes:
        name: Identifier used in error and log messages.
        target: Dotted destination path, e.g. ``"agent_weights.<agent>"`` or
            ``"department.<code>.<field>"``.
        minimum: Value produced for an action of ``0.0`` (continuous specs).
        maximum: Value produced for an action of ``1.0`` (continuous specs).
        values: Optional discrete choices; when set, ``resolve`` selects from
            them instead of scaling onto the numeric range.
    """

    name: str
    target: str
    minimum: float = 0.0
    maximum: float = 1.0
    values: Optional[Sequence[Any]] = None

    def _unit_interval(self, value: float) -> float:
        """Convert *value* to float and clip it into ``[0, 1]``.

        Raises:
            ValueError: if *value* is NaN or a non-numeric string.
            TypeError: if *value* cannot be converted to float at all.
        """
        numeric = float(value)
        if math.isnan(numeric):
            # min()/max() against NaN silently return the bound (1.0 here),
            # which would push the parameter to its maximum. Reject instead.
            raise ValueError(f"ParameterSpec {self.name} received NaN action value")
        return max(0.0, min(1.0, numeric))

    def clamp(self, value: float) -> float:
        """Map *value* (clipped to ``[0, 1]``) linearly onto ``[minimum, maximum]``."""
        return self.minimum + self._unit_interval(value) * (self.maximum - self.minimum)

    def resolve(self, value: float) -> Any:
        """Return the concrete parameter value selected by *value*.

        Raises:
            ValueError: if ``values`` is an empty sequence, or *value* is NaN.
        """
        if self.values is None:
            return self.clamp(value)
        # len() instead of truthiness so array-like ``values`` also work.
        if len(self.values) == 0:
            raise ValueError(f"ParameterSpec {self.name} configured with empty values list")
        index = int(round(self._unit_interval(value) * (len(self.values) - 1)))
        return self.values[index]
|
||
|
||
|
||
@dataclass
class EpisodeMetrics:
    """Summary statistics for one evaluation of an action by the environment.

    Field order is part of the positional constructor signature; do not
    reorder.
    """

    # Fractional change of the final NAV relative to the starting NAV.
    total_return: float
    # Largest peak-to-trough NAV decline, as a fraction of the running peak.
    max_drawdown: float
    # Population std-dev of day-over-day NAV changes, scaled by the base NAV.
    volatility: float
    # total_return / volatility (0.0 when volatility is ~0).
    sharpe_like: float
    # total_return / max_drawdown (total_return when drawdown is ~0).
    calmar_like: float
    # Per-day rows; each has at least a "nav" key, and may also carry
    # non-numeric entries such as "trade_date" or "holdings".
    nav_series: List[Dict[str, Any]]
    # Trade records for the episode.
    trades: List[Dict[str, Any]]
    # Mean per-row turnover ratio.
    turnover: float
    # Summed turnover amount across rows.
    turnover_value: float
    # Number of trades (len(trades)).
    trade_count: int
    # Number of risk events raised by the backtest.
    risk_count: int
    # Risk events counted by their "reason" field.
    risk_breakdown: Dict[str, int]
|
||
|
||
|
||
class DecisionEnv:
    """Thin RL-friendly wrapper that evaluates parameter actions via backtest.

    Usage: call :meth:`reset` to build a fresh :class:`BacktestEngine` and
    session, then call :meth:`step` with action vectors of length
    :attr:`action_dim`. Each action component is resolved through its
    :class:`ParameterSpec` into either an agent weight
    (``agent_weights.<agent>``) or a department control
    (``department.<code>.<field>``) before ``engine.step_session`` is called.

    ``step`` returns ``(observation, reward, done, info)``. The reward is
    incremental: ``reward_fn`` scores the metrics computed so far in the
    episode, and each step returns the change in that score since the
    previous step.
    """

    def __init__(
        self,
        *,
        bt_config: BtConfig,
        parameter_specs: Sequence[ParameterSpec],
        baseline_weights: Mapping[str, float],
        reward_fn: Optional[Callable[[EpisodeMetrics], float]] = None,
        disable_departments: bool = False,
    ) -> None:
        """Create the environment.

        Args:
            bt_config: Template backtest configuration; copied on every reset.
            parameter_specs: One spec per action dimension, in action order.
            baseline_weights: Agent weights used before action overrides.
            reward_fn: Maps episode metrics to a scalar score; defaults to
                :meth:`_default_reward`.
            disable_departments: When true, the engine's department manager
                is removed and department controls are ignored.
        """
        self._template_cfg = bt_config
        self._specs = list(parameter_specs)
        self._baseline_weights = dict(baseline_weights)
        self._reward_fn = reward_fn or self._default_reward
        self._last_metrics: Optional[EpisodeMetrics] = None
        self._last_action: Optional[Tuple[float, ...]] = None
        self._last_department_controls: Optional[Dict[str, Dict[str, Any]]] = None
        self._episode = 0
        self._disable_departments = bool(disable_departments)
        self._engine: Optional[BacktestEngine] = None
        self._session: Optional[BacktestSession] = None
        # Score reached after the previous step; step rewards are deltas.
        self._cumulative_reward = 0.0
        self._day_index = 0
        self._data_broker = DataBroker()

    @property
    def action_dim(self) -> int:
        """Number of action components (one per ParameterSpec)."""
        return len(self._specs)

    @property
    def last_department_controls(self) -> Optional[Dict[str, Dict[str, Any]]]:
        """Department controls actually applied by the most recent step."""
        return self._last_department_controls

    @property
    def last_metrics(self) -> Optional[EpisodeMetrics]:
        """Metrics computed by the most recent step (None after reset)."""
        return self._last_metrics

    @property
    def last_action(self) -> Optional[Tuple[float, ...]]:
        """Most recent action passed to :meth:`step`, as floats."""
        return self._last_action

    def reset(self) -> Dict[str, float]:
        """Start a new episode and return its initial observation.

        Builds a new engine from a copy of the template config (universe
        narrowed by :meth:`_filter_active_universe`), applies the baseline
        weights, deletes stored portfolio rows for this config id and starts
        a fresh backtest session.
        """
        self._episode += 1
        self._last_metrics = None
        self._last_action = None
        self._last_department_controls = None
        self._cumulative_reward = 0.0
        self._day_index = 0

        cfg = replace(self._template_cfg)
        filtered_universe = self._filter_active_universe(cfg.universe, cfg.start_date, cfg.end_date)
        if filtered_universe:
            cfg = replace(cfg, universe=filtered_universe)
        self._engine = BacktestEngine(cfg)
        self._engine.weights = weight_map(self._baseline_weights)
        if self._disable_departments:
            self._engine.department_manager = None

        # Portfolio rows are looked up by cfg id and date range only, so rows
        # from a previous episode would otherwise leak into this one's metrics.
        self._clear_portfolio_records()

        self._session = self._engine.start_session()
        return {
            "episode": float(self._episode),
            "day_index": 0.0,
            "date_ord": float(self._template_cfg.start_date.toordinal()),
            "nav": float(self._session.state.cash),
            "total_return": 0.0,
            "max_drawdown": 0.0,
            "volatility": 0.0,
            "turnover": 0.0,
            "sharpe_like": 0.0,
            "calmar_like": 0.0,
            "trade_count": 0.0,
            "risk_count": 0.0,
        }

    def step(self, action: Sequence[float]) -> Tuple[Dict[str, float], float, bool, Dict[str, object]]:
        """Apply *action* and advance the backtest session once.

        Args:
            action: One value per ParameterSpec, nominally in ``[0, 1]``.

        Returns:
            ``(observation, reward, done, info)``. If the backtest raises,
            the episode ends with reward ``-1.0``, ``observation["failure"]``
            set to 1.0 and ``info["error"]`` holding the message.

        Raises:
            ValueError: if ``len(action) != action_dim``.
            RuntimeError: if :meth:`reset` has not been called.
        """
        if len(action) != self.action_dim:
            raise ValueError(f"expected action length {self.action_dim}, got {len(action)}")
        engine = self._engine
        session = self._session
        if engine is None or session is None:
            raise RuntimeError("environment not initialised; call reset() before step()")

        action_array = [float(val) for val in action]
        self._last_action = tuple(action_array)

        weights, department_controls = self._prepare_actions(action_array)
        LOGGER.info(
            "episode=%s action=%s weights=%s controls=%s",
            self._episode,
            action_array,
            weights,
            department_controls,
            extra=LOG_EXTRA,
        )

        normalized_weights = weight_map(weights)
        engine.weights = normalized_weights
        if self._disable_departments:
            applied_controls: Dict[str, Dict[str, Any]] = {}
            engine.department_manager = None
        else:
            applied_controls = self._apply_department_controls(engine, department_controls)

        records_list: List[Dict[str, Any]] = []
        try:
            records, done = engine.step_session(session)
            records_list = list(records) if records is not None else []
        except Exception as exc:  # noqa: BLE001
            # Any engine failure terminates the episode with a fixed penalty.
            LOGGER.exception("backtest failed under action", extra={**LOG_EXTRA, "error": str(exc)})
            failure_metrics = self._empty_metrics(getattr(session, "result", None))
            self._last_metrics = failure_metrics
            self._last_department_controls = applied_controls
            observation = self._build_observation(failure_metrics, records_list, True)
            observation["failure"] = 1.0
            info = {
                "error": str(exc),
                "weights": normalized_weights,
                "department_controls": applied_controls,
                "nav_series": failure_metrics.nav_series,
                "trades": failure_metrics.trades,
                "risk_breakdown": failure_metrics.risk_breakdown,
                "session_done": True,
                "raw_records": records_list,
            }
            return observation, -1.0, True, info

        snapshots, trades_override = self._fetch_portfolio_records()
        metrics = self._compute_metrics(
            session.result,
            nav_override=snapshots if snapshots else None,
            trades_override=trades_override if trades_override else None,
        )

        # Reward shaping: emit only the improvement of the cumulative score.
        total_reward = float(self._reward_fn(metrics))
        reward = total_reward - self._cumulative_reward
        self._cumulative_reward = total_reward
        self._last_metrics = metrics
        self._last_department_controls = applied_controls

        observation = self._build_observation(metrics, records_list, done)
        observation["turnover_value"] = metrics.turnover_value
        info = {
            "nav_series": metrics.nav_series,
            "trades": metrics.trades,
            "weights": normalized_weights,
            "risk_breakdown": metrics.risk_breakdown,
            "risk_events": getattr(session.result, "risk_events", []),
            "portfolio_snapshots": snapshots,
            "portfolio_trades": trades_override,
            "department_controls": applied_controls,
            "session_done": done,
            "raw_records": records_list,
        }
        self._day_index += 1
        return observation, reward, done, info

    def _prepare_actions(
        self,
        action: Sequence[float],
    ) -> Tuple[Dict[str, float], Dict[str, Dict[str, Any]]]:
        """Resolve *action* into agent weights and per-department controls.

        Returns:
            ``(weights, department_controls)``. ``weights`` starts from the
            baseline weights with overrides applied. ``department_controls``
            maps a department code to ``{field: value}``. Specs that fail to
            resolve, or that have unknown targets, are logged and skipped.
        """
        weights: Dict[str, float] = dict(self._baseline_weights)
        department_controls: Dict[str, Dict[str, Any]] = {}
        # step() has already checked len(action) == len(self._specs).
        for spec, raw_value in zip(self._specs, action):
            try:
                resolved = spec.resolve(raw_value)
            except ValueError as exc:
                LOGGER.warning("参数 %s 解析失败:%s", spec.name, exc, extra=LOG_EXTRA)
                continue

            target = spec.target
            if target.startswith("agent_weights."):
                agent_name = target.split(".", 1)[1]
                try:
                    weights[agent_name] = float(resolved)
                except (TypeError, ValueError):
                    LOGGER.debug(
                        "spec %s produced non-numeric weight %s; skipping",
                        spec.name,
                        resolved,
                        extra=LOG_EXTRA,
                    )
            elif target.startswith("department."):
                target_path = target.split(".")[1:]
                if len(target_path) < 2:
                    LOGGER.debug("未识别的部门目标:%s", target, extra=LOG_EXTRA)
                    continue
                dept_code = target_path[0]
                # Remaining segments form the (possibly dotted) field name.
                field = ".".join(target_path[1:])
                department_controls.setdefault(dept_code, {})[field] = resolved
            else:
                LOGGER.debug("暂未支持的参数目标:%s", target, extra=LOG_EXTRA)
        return weights, department_controls

    def _apply_department_controls(
        self,
        engine: BacktestEngine,
        controls: Mapping[str, Mapping[str, Any]],
    ) -> Dict[str, Dict[str, Any]]:
        """Apply department *controls* to the engine's department agents.

        Returns:
            ``{dept_code: {field: applied_value}}`` for the fields that were
            actually applied. The result is empty when the engine has no
            department manager or no agents.
        """
        manager = getattr(engine, "department_manager", None)
        if not manager or not getattr(manager, "agents", None):
            return {}

        applied: Dict[str, Dict[str, Any]] = {}
        for dept_code, payload in controls.items():
            agent = manager.agents.get(dept_code)
            if not agent or not isinstance(payload, Mapping):
                continue

            applied_fields: Dict[str, Any] = {}

            # Work on a clone of the settings (with a deep-copied llm config)
            # so the original settings object is not mutated.
            try:
                original_settings = agent.settings
                cloned_settings = replace(original_settings)
                cloned_settings.llm = copy.deepcopy(original_settings.llm)
                agent.settings = cloned_settings
            except Exception as exc:  # noqa: BLE001
                LOGGER.warning(
                    "复制部门 %s 配置失败:%s",
                    dept_code,
                    exc,
                    extra=LOG_EXTRA,
                )
                continue

            for raw_field, value in payload.items():
                field = raw_field.lower()
                if field == "function_policy":
                    field = "tool_choice"  # legacy alias

                if field in {"prompt", "instruction"}:
                    agent.settings.prompt = str(value)
                    applied_fields[field] = agent.settings.prompt
                elif field == "description":
                    agent.settings.description = str(value)
                    applied_fields[field] = agent.settings.description
                elif field in {"prompt_template_id", "prompt_template"}:
                    agent.settings.prompt_template_id = str(value)
                    applied_fields["prompt_template_id"] = agent.settings.prompt_template_id
                elif field == "prompt_template_version":
                    agent.settings.prompt_template_version = str(value)
                    applied_fields["prompt_template_version"] = agent.settings.prompt_template_version
                elif field in {"temperature", "llm.temperature"}:
                    try:
                        temperature = max(0.0, min(2.0, float(value)))
                        agent.settings.llm.primary.temperature = temperature
                        applied_fields["temperature"] = temperature
                    except (TypeError, ValueError):
                        LOGGER.debug(
                            "无效的温度值 %s for %s",
                            value,
                            dept_code,
                            extra=LOG_EXTRA,
                        )
                elif field in {"tool_choice", "tool_strategy"}:
                    # The setter validates; TypeError covers non-string inputs.
                    try:
                        agent.tool_choice = value
                        applied_fields["tool_choice"] = agent.tool_choice
                    except (TypeError, ValueError):
                        LOGGER.debug(
                            "部门 %s 工具策略 %s 无效",
                            dept_code,
                            value,
                            extra=LOG_EXTRA,
                        )
                elif field == "max_rounds":
                    try:
                        agent.max_rounds = value
                        applied_fields["max_rounds"] = agent.max_rounds
                    except (TypeError, ValueError):
                        LOGGER.debug(
                            "部门 %s max_rounds %s 无效",
                            dept_code,
                            value,
                            extra=LOG_EXTRA,
                        )
                elif field == "prompt_template_override":
                    agent.settings.prompt = str(value)
                    applied_fields["prompt"] = agent.settings.prompt
                else:
                    LOGGER.debug(
                        "部门 %s 未识别的控制项 %s",
                        dept_code,
                        raw_field,
                        extra=LOG_EXTRA,
                    )

            if applied_fields:
                applied[dept_code] = applied_fields

        return applied

    @staticmethod
    def _as_float(value: Any, default: float = 0.0) -> float:
        """Best-effort float conversion; None, "" or unparseable -> *default*."""
        if value is None or value == "":
            return default
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _risk_breakdown(risk_events: Iterable[Mapping[str, Any]]) -> Dict[str, int]:
        """Count risk events per ``reason``; missing or empty reasons -> "unknown"."""
        breakdown: Dict[str, int] = {}
        for event in risk_events:
            reason = str(event.get("reason") or "unknown")
            breakdown[reason] = breakdown.get(reason, 0) + 1
        return breakdown

    @staticmethod
    def _nav_statistics(nav_values: Sequence[float]) -> Tuple[float, float, float]:
        """Return ``(total_return, max_drawdown, volatility)`` for a NAV path.

        Returns and volatility are measured against the first non-zero NAV,
        so a leading zero snapshot cannot divide by zero. When the sequence
        is empty or every NAV is zero, all three statistics are 0.0.
        Volatility is the population std-dev of consecutive NAV differences,
        divided by that base NAV.
        """
        base_nav = next((nav for nav in nav_values if nav), 0.0)
        if not base_nav:
            return 0.0, 0.0, 0.0
        total_return = nav_values[-1] / base_nav - 1.0

        peak = nav_values[0]
        max_drawdown = 0.0
        for nav in nav_values:
            peak = max(peak, nav)
            if peak:
                max_drawdown = max(max_drawdown, (peak - nav) / peak)

        diffs = [cur - prev for prev, cur in zip(nav_values, nav_values[1:])]
        if diffs:
            mean_diff = sum(diffs) / len(diffs)
            variance = sum((diff - mean_diff) ** 2 for diff in diffs) / len(diffs)
            volatility = math.sqrt(variance) / base_nav
        else:
            volatility = 0.0
        return total_return, max_drawdown, volatility

    @classmethod
    def _turnover_statistics(cls, nav_series: Sequence[Mapping[str, Any]]) -> Tuple[float, float]:
        """Return ``(total turnover value, mean per-row turnover ratio)``.

        A row's explicit ``turnover_ratio`` wins, and an unparseable one
        counts as 0.0. Rows without a ratio fall back to ``turnover / nav``,
        which is 0.0 when the NAV is not positive.
        """
        turnover_value = 0.0
        ratios: List[float] = []
        for row in nav_series:
            turnover_raw = cls._as_float(row.get("turnover"))
            turnover_value += turnover_raw
            ratio = row.get("turnover_ratio")
            if ratio is not None:
                ratios.append(cls._as_float(ratio))
                continue
            nav_val = cls._as_float(row.get("nav"))
            ratios.append(turnover_raw / nav_val if nav_val > 0 else 0.0)
        avg_ratio = sum(ratios) / len(ratios) if ratios else 0.0
        return turnover_value, avg_ratio

    def _compute_metrics(
        self,
        result: BacktestResult,
        *,
        nav_override: Optional[List[Dict[str, Any]]] = None,
        trades_override: Optional[List[Dict[str, Any]]] = None,
    ) -> EpisodeMetrics:
        """Derive :class:`EpisodeMetrics` from a backtest result.

        ``nav_override`` and ``trades_override`` (the persisted portfolio
        rows, when available) take precedence over ``result.nav_series`` and
        ``result.trades``.
        """
        nav_series = nav_override if nav_override is not None else result.nav_series or []
        trades = (trades_override if trades_override is not None else result.trades) or []
        risk_events = getattr(result, "risk_events", []) or []
        risk_breakdown = self._risk_breakdown(risk_events)

        if not nav_series:
            return EpisodeMetrics(
                total_return=0.0,
                max_drawdown=0.0,
                volatility=0.0,
                sharpe_like=0.0,
                calmar_like=0.0,
                nav_series=[],
                trades=trades,
                turnover=0.0,
                turnover_value=0.0,
                trade_count=len(trades),
                risk_count=len(risk_events),
                risk_breakdown=risk_breakdown,
            )

        nav_values = [self._as_float(row.get("nav")) for row in nav_series]
        total_return, max_drawdown, volatility = self._nav_statistics(nav_values)
        sharpe_like = total_return / volatility if abs(volatility) > 1e-9 else 0.0
        calmar_like = total_return / max_drawdown if max_drawdown > 1e-6 else total_return
        turnover_value, avg_turnover_ratio = self._turnover_statistics(nav_series)

        return EpisodeMetrics(
            total_return=float(total_return),
            max_drawdown=float(max_drawdown),
            volatility=volatility,
            sharpe_like=float(sharpe_like),
            calmar_like=float(calmar_like),
            nav_series=nav_series,
            trades=trades,
            turnover=float(avg_turnover_ratio),
            turnover_value=float(turnover_value),
            trade_count=len(trades),
            risk_count=len(risk_events),
            risk_breakdown=risk_breakdown,
        )

    @staticmethod
    def _default_reward(metrics: EpisodeMetrics) -> float:
        """Return total return plus risk-adjusted bonuses minus penalties.

        Penalties cover drawdown, risk-event count and average turnover.
        """
        risk_penalty = 0.05 * metrics.risk_count
        turnover_penalty = 0.1 * metrics.turnover
        drawdown_penalty = 0.5 * metrics.max_drawdown
        bonus = 0.1 * metrics.sharpe_like + 0.05 * metrics.calmar_like
        return metrics.total_return + bonus - (drawdown_penalty + risk_penalty + turnover_penalty)

    def _build_observation(
        self,
        metrics: EpisodeMetrics,
        records: Sequence[Dict[str, Any]] | None,
        done: bool,
    ) -> Dict[str, float]:
        """Flatten *metrics* and the latest NAV row into a float-valued dict.

        Keys ``nav``, ``cash``, ``market_value``, ``date_ord`` and
        ``turnover_ratio`` are only present when the latest NAV row provides
        them. ``record_count`` is only present when *records* is non-empty.
        """
        observation: Dict[str, float] = {
            "day_index": float(self._day_index + 1),
            "total_return": metrics.total_return,
            "max_drawdown": metrics.max_drawdown,
            "volatility": metrics.volatility,
            "sharpe_like": metrics.sharpe_like,
            "calmar_like": metrics.calmar_like,
            "turnover": metrics.turnover,
            "trade_count": float(metrics.trade_count),
            "risk_count": float(metrics.risk_count),
            "done": 1.0 if done else 0.0,
        }

        latest_snapshot = metrics.nav_series[-1] if metrics.nav_series else None
        if latest_snapshot:
            observation["nav"] = self._as_float(latest_snapshot.get("nav"))
            observation["cash"] = self._as_float(latest_snapshot.get("cash"))
            observation["market_value"] = self._as_float(latest_snapshot.get("market_value"))
            trade_date = latest_snapshot.get("trade_date")
            if isinstance(trade_date, date):
                observation["date_ord"] = float(trade_date.toordinal())
            elif isinstance(trade_date, str):
                try:
                    observation["date_ord"] = float(date.fromisoformat(trade_date).toordinal())
                except ValueError:
                    pass  # unparseable date string: omit date_ord
            turnover_ratio = latest_snapshot.get("turnover_ratio")
            if turnover_ratio is not None:
                observation["turnover_ratio"] = self._as_float(turnover_ratio)

        # Include a simple proxy for action effect size when available.
        if records:
            observation["record_count"] = float(len(records))

        return observation

    def _portfolio_cfg_id(self) -> str:
        """Key under which this environment's portfolio rows are stored."""
        return self._template_cfg.id or "decision_env"

    def _clear_portfolio_records(self) -> None:
        """Delete stored positions, snapshots and trades for this config id.

        Failures are logged and otherwise ignored (best effort).
        """
        cfg_id = self._portfolio_cfg_id()
        try:
            with db_session() as conn:
                conn.execute("DELETE FROM bt_portfolio_positions WHERE cfg_id = ?", (cfg_id,))
                conn.execute("DELETE FROM bt_portfolio_snapshots WHERE cfg_id = ?", (cfg_id,))
                conn.execute("DELETE FROM bt_portfolio_trades WHERE cfg_id = ?", (cfg_id,))
        except Exception:  # noqa: BLE001
            LOGGER.exception("清理投资组合记录失败", extra=LOG_EXTRA)

    @classmethod
    def _snapshot_from_row(cls, row: Any) -> Dict[str, Any]:
        """Convert a ``bt_portfolio_snapshots`` row into a NAV-series dict.

        Metadata that is missing, unparseable or not a JSON object is treated
        as ``{}`` rather than crashing the step.
        """
        metadata = cls._loads(row["metadata"], {})
        if not isinstance(metadata, Mapping):
            metadata = {}
        return {
            "trade_date": row["trade_date"],
            "nav": cls._as_float(row["total_value"]),
            "cash": cls._as_float(row["cash"]),
            "market_value": cls._as_float(row["invested_value"]),
            "unrealized_pnl": cls._as_float(row["unrealized_pnl"]),
            "realized_pnl": cls._as_float(row["realized_pnl"]),
            "net_flow": cls._as_float(row["net_flow"]),
            "exposure": cls._as_float(row["exposure"]),
            "turnover": cls._as_float(metadata.get("turnover_value")),
            "turnover_ratio": cls._as_float(metadata.get("turnover_ratio")),
            "holdings": metadata.get("holdings"),
            "trade_count": metadata.get("trade_count"),
        }

    @classmethod
    def _trade_from_row(cls, row: Any) -> Dict[str, Any]:
        """Convert a ``bt_portfolio_trades`` row into a trade dict."""
        return {
            "id": row["id"],
            "trade_date": row["trade_date"],
            "ts_code": row["ts_code"],
            "action": row["action"],
            "quantity": cls._as_float(row["quantity"]),
            "price": cls._as_float(row["price"]),
            "fee": cls._as_float(row["fee"]),
            "source": row["source"],
            "metadata": cls._loads(row["metadata"], {}),
        }

    def _fetch_portfolio_records(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """Read persisted snapshots and trades for the template date range.

        Returns:
            ``(snapshots, trades)``, both ordered by trade date. Both are
            empty if the database read fails (the failure is logged).
        """
        start = self._template_cfg.start_date.isoformat()
        end = self._template_cfg.end_date.isoformat()
        cfg_id = self._portfolio_cfg_id()
        try:
            with db_session(read_only=True) as conn:
                snapshot_rows = conn.execute(
                    """
                    SELECT trade_date, total_value, cash, invested_value,
                           unrealized_pnl, realized_pnl, net_flow, exposure, metadata
                    FROM bt_portfolio_snapshots
                    WHERE cfg_id = ? AND trade_date BETWEEN ? AND ?
                    ORDER BY trade_date
                    """,
                    (cfg_id, start, end),
                ).fetchall()
                trade_rows = conn.execute(
                    """
                    SELECT id, trade_date, ts_code, action, quantity, price, fee, source, metadata
                    FROM bt_portfolio_trades
                    WHERE cfg_id = ? AND trade_date BETWEEN ? AND ?
                    ORDER BY trade_date, id
                    """,
                    (cfg_id, start, end),
                ).fetchall()
        except Exception:  # noqa: BLE001
            LOGGER.exception("读取投资组合记录失败", extra=LOG_EXTRA)
            return [], []

        snapshots = [self._snapshot_from_row(row) for row in snapshot_rows]
        trades = [self._trade_from_row(row) for row in trade_rows]
        return snapshots, trades

    def _filter_active_universe(
        self,
        universe: Sequence[str],
        start_date: date,
        end_date: date,
    ) -> List[str]:
        """Drop symbols whose ``suspend`` flag is reported at both dates.

        A symbol whose flag lookup raises is kept. If every symbol would be
        dropped, the full universe is returned instead.
        """
        if not universe:
            return list(universe or [])

        broker = self._data_broker
        start_key = start_date.strftime("%Y%m%d")
        end_key = end_date.strftime("%Y%m%d")
        active: List[str] = []
        filtered: List[str] = []
        for ts_code in universe:
            try:
                suspended_start = broker.fetch_flags(
                    "suspend",
                    ts_code,
                    start_key,
                    "",
                    [],
                    auto_refresh=False,
                )
                suspended_end = broker.fetch_flags(
                    "suspend",
                    ts_code,
                    end_key,
                    "",
                    [],
                    auto_refresh=False,
                )
            except Exception:  # noqa: BLE001
                LOGGER.debug(
                    "检测停牌状态失败 ts_code=%s start=%s end=%s",
                    ts_code,
                    start_key,
                    end_key,
                    extra=LOG_EXTRA,
                )
                active.append(ts_code)
                continue

            if suspended_start and suspended_end:
                filtered.append(ts_code)
            else:
                active.append(ts_code)

        if filtered:
            LOGGER.info(
                "过滤停牌标的 %s/%s:%s",
                len(filtered),
                len(universe),
                filtered[:10],
                extra=LOG_EXTRA,
            )
        return active or list(universe)

    @staticmethod
    def _loads(payload: Any, default: Any) -> Any:
        """Decode a JSON metadata column.

        Already-decoded dicts and lists pass through unchanged. ``str`` and
        ``bytes`` payloads are parsed as JSON. Empty, unparseable or
        other-typed payloads yield *default*.
        """
        if not payload:
            return default
        if isinstance(payload, (dict, list)):
            return payload
        if isinstance(payload, (str, bytes, bytearray)):
            try:
                return json.loads(payload)
            except ValueError:  # JSONDecodeError and UnicodeDecodeError
                return default
        return default

    def _empty_metrics(self, result: Optional[BacktestResult]) -> EpisodeMetrics:
        """Build zero-valued metrics after a failure, keeping what *result* has.

        NAV rows, trades and risk events are copied from *result* when they
        can be read. Every numeric statistic is 0.0.
        """
        nav_series: List[Dict[str, Any]] = []
        trades: List[Dict[str, Any]] = []
        risk_events: List[Dict[str, Any]] = []

        if result is not None:
            try:
                nav_series = list(result.nav_series or [])
            except Exception:  # noqa: BLE001
                nav_series = []
            try:
                trades = list(result.trades or [])
            except Exception:  # noqa: BLE001
                trades = []
            try:
                risk_events = list(getattr(result, "risk_events", []) or [])
            except Exception:  # noqa: BLE001
                risk_events = []

        return EpisodeMetrics(
            total_return=0.0,
            max_drawdown=0.0,
            volatility=0.0,
            sharpe_like=0.0,
            calmar_like=0.0,
            nav_series=nav_series,
            trades=trades,
            turnover=0.0,
            turnover_value=0.0,
            trade_count=len(trades),
            risk_count=len(risk_events),
            risk_breakdown=self._risk_breakdown(risk_events),
        )
|