llm-quant/app/backtest/decision_env.py

703 lines
27 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Reinforcement-learning style environment wrapping the backtest engine."""
from __future__ import annotations
import json
import math
import copy
from dataclasses import dataclass, replace
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
from datetime import date
from .engine import BacktestEngine, BacktestResult, BacktestSession, BtConfig
from app.agents.registry import weight_map
from app.utils.db import db_session
from app.utils.data_access import DataBroker
from app.utils.logging import get_logger
# Module-level logger. Every log call in this module passes LOG_EXTRA (or a
# superset of it), tagging records with the "decision_env" stage.
LOGGER = get_logger(__name__)
LOG_EXTRA = dict(stage="decision_env")
@dataclass(frozen=True)
class ParameterSpec:
    """Describe how one action component maps onto a strategy parameter.

    The raw action value is clipped to ``[0, 1]``.  Without ``values`` it is
    then scaled linearly into ``[minimum, maximum]``; with ``values`` it
    selects the choice whose index is nearest to ``value * (len(values) - 1)``.
    """

    name: str
    # Dotted destination such as ``agent_weights.<agent>`` or
    # ``department.<code>.<field>``.
    target: str
    minimum: float = 0.0
    maximum: float = 1.0
    # Optional discrete choices; when set they replace the continuous range.
    values: Optional[Sequence[Any]] = None

    def clamp(self, value: float) -> float:
        """Clip *value* into ``[0, 1]`` and map it linearly onto ``[minimum, maximum]``."""
        unit = max(0.0, min(1.0, float(value)))
        span = self.maximum - self.minimum
        return self.minimum + unit * span

    def resolve(self, value: float) -> Any:
        """Translate one action component into the concrete parameter value.

        Raises:
            ValueError: If ``values`` was configured as an empty sequence.
        """
        choices = self.values
        if choices is None:
            return self.clamp(value)
        if not choices:
            raise ValueError(f"ParameterSpec {self.name} configured with empty values list")
        unit = max(0.0, min(1.0, float(value)))
        position = int(round(unit * (len(choices) - 1)))
        return choices[position]
@dataclass
class EpisodeMetrics:
    """Cumulative statistics of the current backtest episode.

    Built by ``DecisionEnv`` after each step and handed to the reward
    function; field order is part of the positional constructor API.
    """

    # Fractional return of the latest NAV relative to the episode's base NAV.
    total_return: float
    # Largest peak-to-trough NAV decline, expressed as a fraction of the peak.
    max_drawdown: float
    # Std-dev of NAV changes between consecutive rows, scaled by the base NAV.
    volatility: float
    # total_return / volatility (0.0 when volatility is ~0).
    sharpe_like: float
    # total_return / max_drawdown (falls back to total_return when drawdown is ~0).
    calmar_like: float
    # Per-row NAV records; each row carries at least a "nav" entry.
    nav_series: List[Dict[str, float]]
    trades: List[Dict[str, object]]
    # Mean per-row turnover ratio.
    turnover: float
    # Sum of the raw per-row turnover values.
    turnover_value: float
    trade_count: int
    risk_count: int
    # Number of risk events per reason string.
    risk_breakdown: Dict[str, int]
class DecisionEnv:
    """Thin RL-friendly wrapper that evaluates parameter actions via backtest.

    Lifecycle: :meth:`reset` builds a fresh :class:`BacktestEngine` from the
    template config and starts a session; each :meth:`step` maps an action
    vector onto agent weights and department controls, advances the session
    with one ``engine.step_session`` call, and scores the cumulative episode
    metrics with ``reward_fn``.
    """

    def __init__(
        self,
        *,
        bt_config: BtConfig,
        parameter_specs: Sequence[ParameterSpec],
        baseline_weights: Mapping[str, float],
        reward_fn: Optional[Callable[[EpisodeMetrics], float]] = None,
        disable_departments: bool = False,
    ) -> None:
        """Configure the environment; no engine exists until :meth:`reset`.

        Args:
            bt_config: Template backtest configuration copied at every reset.
            parameter_specs: One spec per action component, in action order.
            baseline_weights: Agent weights that actions override per step.
            reward_fn: Maps cumulative episode metrics to a scalar score;
                defaults to :meth:`_default_reward`.
            disable_departments: When true, the engine's ``department_manager``
                is set to ``None`` and department controls are ignored.
        """
        self._template_cfg = bt_config
        self._specs = list(parameter_specs)
        self._baseline_weights = dict(baseline_weights)
        self._reward_fn = reward_fn or self._default_reward
        self._last_metrics: Optional[EpisodeMetrics] = None
        self._last_action: Optional[Tuple[float, ...]] = None
        self._last_department_controls: Optional[Dict[str, Dict[str, Any]]] = None
        self._episode = 0
        self._disable_departments = bool(disable_departments)
        self._engine: Optional[BacktestEngine] = None
        self._session: Optional[BacktestSession] = None
        # reward_fn value at the previous step; per-step rewards are its deltas.
        self._cumulative_reward = 0.0
        self._day_index = 0
        self._data_broker = DataBroker()

    @property
    def action_dim(self) -> int:
        """Number of action components :meth:`step` expects (one per spec)."""
        return len(self._specs)

    @property
    def last_department_controls(self) -> Optional[Dict[str, Dict[str, Any]]]:
        """Department controls actually applied during the most recent step."""
        return self._last_department_controls

    @property
    def last_metrics(self) -> Optional[EpisodeMetrics]:
        """Metrics from the most recent step (``None`` right after reset)."""
        return self._last_metrics

    @property
    def last_action(self) -> Optional[Tuple[float, ...]]:
        """The most recent action passed to :meth:`step`, converted to floats."""
        return self._last_action

    def reset(self) -> Dict[str, float]:
        """Start a new episode and return its initial observation.

        Tickers flagged as suspended on both the start and end dates are
        dropped from the universe, portfolio records stored under the config
        id are deleted, and a new backtest session is started.
        """
        self._episode += 1
        self._last_metrics = None
        self._last_action = None
        self._last_department_controls = None
        self._cumulative_reward = 0.0
        self._day_index = 0

        cfg = replace(self._template_cfg)
        filtered_universe = self._filter_active_universe(cfg.universe, cfg.start_date, cfg.end_date)
        if filtered_universe:
            cfg = replace(cfg, universe=filtered_universe)

        self._engine = BacktestEngine(cfg)
        self._engine.weights = weight_map(self._baseline_weights)
        if self._disable_departments:
            self._engine.department_manager = None
        # Remove records from earlier episodes: metrics are read back by cfg id.
        self._clear_portfolio_records()
        self._session = self._engine.start_session()
        return {
            "episode": float(self._episode),
            "day_index": 0.0,
            "date_ord": float(self._template_cfg.start_date.toordinal()),
            "nav": float(self._session.state.cash),
            "total_return": 0.0,
            "max_drawdown": 0.0,
            "volatility": 0.0,
            "turnover": 0.0,
            "sharpe_like": 0.0,
            "calmar_like": 0.0,
            "trade_count": 0.0,
            "risk_count": 0.0,
        }

    def step(self, action: Sequence[float]) -> Tuple[Dict[str, float], float, bool, Dict[str, object]]:
        """Apply *action* and advance the backtest session by one engine step.

        Args:
            action: One value per parameter spec, in spec order.

        Returns:
            ``(observation, reward, done, info)``.  ``reward`` is the change of
            ``reward_fn`` over the cumulative episode metrics since the previous
            step.  If the engine raises, the episode ends with ``reward=-1.0``,
            ``done=True`` and ``observation["failure"] == 1.0``.

        Raises:
            ValueError: If ``len(action)`` differs from :attr:`action_dim`.
            RuntimeError: If :meth:`reset` has not been called yet.
        """
        if len(action) != self.action_dim:
            raise ValueError(f"expected action length {self.action_dim}, got {len(action)}")
        engine = self._engine
        session = self._session
        # Check initialisation before mutating any state so a premature call
        # leaves the environment untouched.
        if engine is None or session is None:
            raise RuntimeError("environment not initialised; call reset() before step()")

        action_array = [float(val) for val in action]
        self._last_action = tuple(action_array)
        weights, department_controls = self._prepare_actions(action_array)
        LOGGER.info(
            "episode=%s action=%s weights=%s controls=%s",
            self._episode,
            action_array,
            weights,
            department_controls,
            extra=LOG_EXTRA,
        )

        normalized_weights = weight_map(weights)
        engine.weights = normalized_weights
        applied_controls: Dict[str, Dict[str, Any]]
        if self._disable_departments:
            applied_controls = {}
            engine.department_manager = None
        else:
            applied_controls = self._apply_department_controls(engine, department_controls)

        records_list: List[Dict[str, Any]] = []
        try:
            records, done = engine.step_session(session)
            records_list = list(records) if records is not None else []
        except Exception as exc:  # noqa: BLE001
            LOGGER.exception("backtest failed under action", extra={**LOG_EXTRA, "error": str(exc)})
            failure_metrics = self._empty_metrics(getattr(session, "result", None))
            self._last_metrics = failure_metrics
            self._last_department_controls = applied_controls
            observation = self._build_observation(failure_metrics, records_list, True)
            observation["failure"] = 1.0
            failure_info: Dict[str, object] = {
                "error": str(exc),
                "weights": normalized_weights,
                "department_controls": applied_controls,
                "nav_series": failure_metrics.nav_series,
                "trades": failure_metrics.trades,
                "risk_breakdown": failure_metrics.risk_breakdown,
                "session_done": True,
                "raw_records": records_list,
            }
            # A failed engine step terminates the episode with a fixed penalty.
            return observation, -1.0, True, failure_info

        # Prefer persisted portfolio records; fall back to the in-memory result.
        snapshots, trades_override = self._fetch_portfolio_records()
        metrics = self._compute_metrics(
            session.result,
            nav_override=snapshots if snapshots else None,
            trades_override=trades_override if trades_override else None,
        )
        # reward_fn scores the whole episode so far; the step reward is its change.
        total_reward = float(self._reward_fn(metrics))
        reward = total_reward - self._cumulative_reward
        self._cumulative_reward = total_reward
        self._last_metrics = metrics

        observation = self._build_observation(metrics, records_list, done)
        observation["turnover_value"] = metrics.turnover_value
        info: Dict[str, object] = {
            "nav_series": metrics.nav_series,
            "trades": metrics.trades,
            "weights": normalized_weights,
            "risk_breakdown": metrics.risk_breakdown,
            "risk_events": getattr(session.result, "risk_events", []),
            "portfolio_snapshots": snapshots,
            "portfolio_trades": trades_override,
            "department_controls": applied_controls,
            "session_done": done,
            "raw_records": records_list,
        }
        self._last_department_controls = applied_controls
        self._day_index += 1
        return observation, reward, done, info

    def _prepare_actions(
        self,
        action: Sequence[float],
    ) -> Tuple[Dict[str, float], Dict[str, Dict[str, Any]]]:
        """Resolve *action* through the parameter specs.

        Returns:
            ``(weights, department_controls)``: ``weights`` is the baseline
            weights with every ``agent_weights.<agent>`` target overridden, and
            ``department_controls`` maps a department code to ``{field: value}``
            for ``department.<code>.<field>`` targets.  Specs that fail to
            resolve, yield a non-numeric weight, or use an unsupported target
            are skipped (and logged).
        """
        weights = dict(self._baseline_weights)
        department_controls: Dict[str, Dict[str, Any]] = {}
        for idx, spec in enumerate(self._specs):
            try:
                resolved = spec.resolve(action[idx])
            except ValueError as exc:
                LOGGER.warning("参数 %s 解析失败:%s", spec.name, exc, extra=LOG_EXTRA)
                continue

            if spec.target.startswith("agent_weights."):
                agent_name = spec.target.split(".", 1)[1]
                try:
                    weights[agent_name] = float(resolved)
                except (TypeError, ValueError):
                    LOGGER.debug(
                        "spec %s produced non-numeric weight %s; skipping",
                        spec.name,
                        resolved,
                        extra=LOG_EXTRA,
                    )
                continue

            if spec.target.startswith("department."):
                target_path = spec.target.split(".")[1:]
                if len(target_path) < 2:
                    LOGGER.debug("未识别的部门目标:%s", spec.target, extra=LOG_EXTRA)
                    continue
                dept_code = target_path[0]
                field = ".".join(target_path[1:])
                department_controls.setdefault(dept_code, {})[field] = resolved
                continue

            LOGGER.debug("暂未支持的参数目标:%s", spec.target, extra=LOG_EXTRA)
        return weights, department_controls

    def _apply_department_controls(
        self,
        engine: BacktestEngine,
        controls: Mapping[str, Mapping[str, Any]],
    ) -> Dict[str, Dict[str, Any]]:
        """Apply per-department control overrides to the engine's agents.

        Each targeted agent first receives a copy of its settings (with a deep
        copy of ``settings.llm``) so edits are not written into the original
        settings object.  Unknown departments, non-mapping payloads and
        unrecognised fields are skipped.

        Returns:
            ``{dept_code: {field: applied_value}}`` for the fields that were set.
        """
        manager = getattr(engine, "department_manager", None)
        if not manager or not getattr(manager, "agents", None):
            return {}
        applied: Dict[str, Dict[str, Any]] = {}
        for dept_code, payload in controls.items():
            agent = manager.agents.get(dept_code)
            if not agent or not isinstance(payload, Mapping):
                continue
            applied_fields: Dict[str, Any] = {}
            # Ensure mutable settings are cloned to avoid global side-effects.
            try:
                original_settings = agent.settings
                cloned_settings = replace(original_settings)
                cloned_settings.llm = copy.deepcopy(original_settings.llm)
                agent.settings = cloned_settings
            except Exception as exc:  # noqa: BLE001
                LOGGER.warning(
                    "复制部门 %s 配置失败:%s",
                    dept_code,
                    exc,
                    extra=LOG_EXTRA,
                )
                continue

            for raw_field, value in payload.items():
                field = str(raw_field).lower()
                if field == "function_policy":
                    field = "tool_choice"
                if field in {"prompt", "instruction"}:
                    agent.settings.prompt = str(value)
                    applied_fields[field] = agent.settings.prompt
                    continue
                if field == "description":
                    agent.settings.description = str(value)
                    applied_fields[field] = agent.settings.description
                    continue
                if field in {"prompt_template_id", "prompt_template"}:
                    agent.settings.prompt_template_id = str(value)
                    applied_fields["prompt_template_id"] = agent.settings.prompt_template_id
                    continue
                if field == "prompt_template_version":
                    agent.settings.prompt_template_version = str(value)
                    applied_fields["prompt_template_version"] = agent.settings.prompt_template_version
                    continue
                if field in {"temperature", "llm.temperature"}:
                    try:
                        temperature = float(value)
                        # NaN would otherwise slip through min/max as 2.0.
                        if math.isnan(temperature):
                            raise ValueError("temperature is NaN")
                        temperature = max(0.0, min(2.0, temperature))
                        agent.settings.llm.primary.temperature = temperature
                        applied_fields["temperature"] = temperature
                    except (TypeError, ValueError):
                        LOGGER.debug(
                            "无效的温度值 %s for %s",
                            value,
                            dept_code,
                            extra=LOG_EXTRA,
                        )
                    continue
                if field in {"tool_choice", "tool_strategy"}:
                    try:
                        agent.tool_choice = value
                        applied_fields["tool_choice"] = agent.tool_choice
                    except ValueError:
                        LOGGER.debug(
                            "部门 %s 工具策略 %s 无效",
                            dept_code,
                            value,
                            extra=LOG_EXTRA,
                        )
                    continue
                if field == "max_rounds":
                    try:
                        agent.max_rounds = value
                        applied_fields["max_rounds"] = agent.max_rounds
                    except ValueError:
                        LOGGER.debug(
                            "部门 %s max_rounds %s 无效",
                            dept_code,
                            value,
                            extra=LOG_EXTRA,
                        )
                    continue
                if field == "prompt_template_override":
                    agent.settings.prompt = str(value)
                    applied_fields["prompt"] = agent.settings.prompt
                    continue
                LOGGER.debug(
                    "部门 %s 未识别的控制项 %s",
                    dept_code,
                    raw_field,
                    extra=LOG_EXTRA,
                )
            if applied_fields:
                applied[dept_code] = applied_fields
        return applied

    def _compute_metrics(
        self,
        result: BacktestResult,
        *,
        nav_override: Optional[List[Dict[str, Any]]] = None,
        trades_override: Optional[List[Dict[str, Any]]] = None,
    ) -> EpisodeMetrics:
        """Summarise the episode so far.

        Args:
            result: Backtest result; supplies NAV rows and trades unless
                overridden, and always supplies ``risk_events``.
            nav_override: NAV rows to use instead of ``result.nav_series``.
            trades_override: Trades to use instead of ``result.trades``.

        Returns:
            Episode metrics; all performance figures are ``0.0`` when there
            are no NAV rows.
        """
        nav_series = nav_override if nav_override is not None else (result.nav_series or [])
        trades = trades_override if trades_override is not None else result.trades
        trades = trades or []
        risk_events = list(getattr(result, "risk_events", None) or [])
        risk_breakdown = self._risk_breakdown(risk_events)

        if not nav_series:
            return EpisodeMetrics(
                total_return=0.0,
                max_drawdown=0.0,
                volatility=0.0,
                sharpe_like=0.0,
                calmar_like=0.0,
                nav_series=[],
                trades=trades,
                turnover=0.0,
                turnover_value=0.0,
                trade_count=len(trades),
                risk_count=len(risk_events),
                risk_breakdown=risk_breakdown,
            )

        # Coerce NAVs so a missing/None value cannot break the arithmetic.
        nav_values = [self._safe_float(row.get("nav")) for row in nav_series]
        total_return, max_drawdown, volatility = self._nav_statistics(nav_values)
        sharpe_like = total_return / volatility if abs(volatility) > 1e-9 else 0.0
        calmar_like = total_return / max_drawdown if max_drawdown > 1e-6 else total_return
        turnover_value, avg_turnover_ratio = self._turnover_stats(nav_series)

        return EpisodeMetrics(
            total_return=float(total_return),
            max_drawdown=float(max_drawdown),
            volatility=volatility,
            sharpe_like=float(sharpe_like),
            calmar_like=float(calmar_like),
            nav_series=nav_series,
            trades=trades,
            turnover=float(avg_turnover_ratio),
            turnover_value=float(turnover_value),
            trade_count=len(trades),
            risk_count=len(risk_events),
            risk_breakdown=risk_breakdown,
        )

    @staticmethod
    def _nav_statistics(nav_values: Sequence[float]) -> Tuple[float, float, float]:
        """Return ``(total_return, max_drawdown, volatility)`` for a NAV path.

        Returns are measured against the first non-zero NAV, so a series that
        opens at zero does not divide by zero; when no non-zero NAV exists all
        three values are ``0.0``.  Volatility is the population standard
        deviation of consecutive NAV changes divided by ``|base NAV|``.
        """
        base_nav = next((nav for nav in nav_values if nav), None)
        if base_nav is None:
            return 0.0, 0.0, 0.0
        total_return = nav_values[-1] / base_nav - 1.0

        peak = nav_values[0]
        max_drawdown = 0.0
        for nav in nav_values:
            if nav > peak:
                peak = nav
            # Drawdown is only meaningful relative to a positive peak.
            if peak > 0:
                max_drawdown = max(max_drawdown, (peak - nav) / peak)

        diffs = [curr - prev for prev, curr in zip(nav_values, nav_values[1:])]
        if diffs:
            mean_diff = sum(diffs) / len(diffs)
            variance = sum((diff - mean_diff) ** 2 for diff in diffs) / len(diffs)
            volatility = math.sqrt(variance) / abs(base_nav)
        else:
            volatility = 0.0
        return float(total_return), float(max_drawdown), float(volatility)

    @classmethod
    def _turnover_stats(cls, nav_series: Sequence[Mapping[str, Any]]) -> Tuple[float, float]:
        """Return ``(total_turnover_value, mean_turnover_ratio)`` over *nav_series*.

        A row's ``turnover_ratio`` is used when present (unparseable values
        count as ``0.0``); otherwise the ratio is ``turnover / nav``, or
        ``0.0`` when ``nav`` is not positive.
        """
        turnover_value = 0.0
        ratios: List[float] = []
        for row in nav_series:
            turnover_raw = cls._safe_float(row.get("turnover"))
            turnover_value += turnover_raw
            ratio = row.get("turnover_ratio")
            if ratio is not None:
                ratios.append(cls._safe_float(ratio))
                continue
            nav_val = cls._safe_float(row.get("nav"))
            ratios.append(turnover_raw / nav_val if nav_val > 0 else 0.0)
        average = sum(ratios) / len(ratios) if ratios else 0.0
        return turnover_value, average

    @staticmethod
    def _risk_breakdown(risk_events: Iterable[Any]) -> Dict[str, int]:
        """Count risk events by ``reason``; missing reasons count as ``"unknown"``."""
        breakdown: Dict[str, int] = {}
        for event in risk_events:
            reason = event.get("reason") if isinstance(event, Mapping) else None
            key = str(reason or "unknown")
            breakdown[key] = breakdown.get(key, 0) + 1
        return breakdown

    @staticmethod
    def _safe_float(value: Any, default: float = 0.0) -> float:
        """Convert *value* to ``float``; ``None`` or unparseable input yields *default*."""
        if value is None:
            return default
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _default_reward(metrics: EpisodeMetrics) -> float:
        """Default score.

        The score is ``total_return + 0.1*sharpe_like + 0.05*calmar_like``
        minus ``0.5*max_drawdown + 0.05*risk_count + 0.1*turnover``.
        """
        risk_penalty = 0.05 * metrics.risk_count
        turnover_penalty = 0.1 * metrics.turnover
        drawdown_penalty = 0.5 * metrics.max_drawdown
        bonus = 0.1 * metrics.sharpe_like + 0.05 * metrics.calmar_like
        return metrics.total_return + bonus - (drawdown_penalty + risk_penalty + turnover_penalty)

    def _build_observation(
        self,
        metrics: EpisodeMetrics,
        records: Sequence[Dict[str, Any]] | None,
        done: bool,
    ) -> Dict[str, float]:
        """Flatten *metrics* and the latest NAV row into a float-valued observation.

        ``nav``/``cash``/``market_value``/``date_ord``/``turnover_ratio`` are
        only present when a NAV row exists (``date_ord`` only when its
        ``trade_date`` is a date or ISO string); ``record_count`` only when
        *records* is non-empty.
        """
        observation: Dict[str, float] = {
            "day_index": float(self._day_index + 1),
            "total_return": metrics.total_return,
            "max_drawdown": metrics.max_drawdown,
            "volatility": metrics.volatility,
            "sharpe_like": metrics.sharpe_like,
            "calmar_like": metrics.calmar_like,
            "turnover": metrics.turnover,
            "trade_count": float(metrics.trade_count),
            "risk_count": float(metrics.risk_count),
            "done": 1.0 if done else 0.0,
        }
        latest_snapshot = metrics.nav_series[-1] if metrics.nav_series else None
        if latest_snapshot:
            to_float = self._safe_float
            observation["nav"] = to_float(latest_snapshot.get("nav"))
            observation["cash"] = to_float(latest_snapshot.get("cash"))
            observation["market_value"] = to_float(latest_snapshot.get("market_value"))
            trade_date = latest_snapshot.get("trade_date")
            parsed: Optional[date] = None
            if isinstance(trade_date, date):
                parsed = trade_date
            elif isinstance(trade_date, str):
                try:
                    parsed = date.fromisoformat(trade_date)
                except ValueError:
                    parsed = None
            if parsed is not None:
                observation["date_ord"] = float(parsed.toordinal())
            ratio = latest_snapshot.get("turnover_ratio")
            if ratio is not None:
                observation["turnover_ratio"] = to_float(ratio)
        # Include a simple proxy for action effect size when available.
        if records:
            observation["record_count"] = float(len(records))
        return observation

    def _clear_portfolio_records(self) -> None:
        """Delete persisted positions, snapshots and trades for this config id.

        Best effort: database errors are logged, not raised.
        """
        cfg_id = self._template_cfg.id or "decision_env"
        try:
            with db_session() as conn:
                conn.execute("DELETE FROM bt_portfolio_positions WHERE cfg_id = ?", (cfg_id,))
                conn.execute("DELETE FROM bt_portfolio_snapshots WHERE cfg_id = ?", (cfg_id,))
                conn.execute("DELETE FROM bt_portfolio_trades WHERE cfg_id = ?", (cfg_id,))
        except Exception:  # noqa: BLE001
            LOGGER.exception("清理投资组合记录失败", extra=LOG_EXTRA)

    def _fetch_portfolio_records(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """Load persisted snapshots and trades within the template's date range.

        Returns:
            ``(snapshots, trades)`` ordered by trade date; both are empty when
            the query fails (the failure is logged).
        """
        start = self._template_cfg.start_date.isoformat()
        end = self._template_cfg.end_date.isoformat()
        cfg_id = self._template_cfg.id or "decision_env"
        snapshots: List[Dict[str, Any]] = []
        trades: List[Dict[str, Any]] = []
        try:
            with db_session(read_only=True) as conn:
                snapshot_rows = conn.execute(
                    """
                    SELECT trade_date, total_value, cash, invested_value,
                           unrealized_pnl, realized_pnl, net_flow, exposure, metadata
                    FROM bt_portfolio_snapshots
                    WHERE cfg_id = ? AND trade_date BETWEEN ? AND ?
                    ORDER BY trade_date
                    """,
                    (cfg_id, start, end),
                ).fetchall()
                trade_rows = conn.execute(
                    """
                    SELECT id, trade_date, ts_code, action, quantity, price, fee, source, metadata
                    FROM bt_portfolio_trades
                    WHERE cfg_id = ? AND trade_date BETWEEN ? AND ?
                    ORDER BY trade_date, id
                    """,
                    (cfg_id, start, end),
                ).fetchall()
        except Exception:  # noqa: BLE001
            LOGGER.exception("读取投资组合记录失败", extra=LOG_EXTRA)
            return snapshots, trades

        to_float = self._safe_float
        for row in snapshot_rows:
            metadata = self._loads(row["metadata"], {})
            if not isinstance(metadata, Mapping):
                # Snapshot metadata is read with .get(); ignore non-object JSON.
                metadata = {}
            snapshots.append(
                {
                    "trade_date": row["trade_date"],
                    "nav": to_float(row["total_value"]),
                    "cash": to_float(row["cash"]),
                    "market_value": to_float(row["invested_value"]),
                    "unrealized_pnl": to_float(row["unrealized_pnl"]),
                    "realized_pnl": to_float(row["realized_pnl"]),
                    "net_flow": to_float(row["net_flow"]),
                    "exposure": to_float(row["exposure"]),
                    "turnover": to_float(metadata.get("turnover_value")),
                    "turnover_ratio": to_float(metadata.get("turnover_ratio")),
                    "holdings": metadata.get("holdings"),
                    "trade_count": metadata.get("trade_count"),
                }
            )
        for row in trade_rows:
            trades.append(
                {
                    "id": row["id"],
                    "trade_date": row["trade_date"],
                    "ts_code": row["ts_code"],
                    "action": row["action"],
                    "quantity": to_float(row["quantity"]),
                    "price": to_float(row["price"]),
                    "fee": to_float(row["fee"]),
                    "source": row["source"],
                    "metadata": self._loads(row["metadata"], {}),
                }
            )
        return snapshots, trades

    def _filter_active_universe(
        self,
        universe: Sequence[str],
        start_date: date,
        end_date: date,
    ) -> List[str]:
        """Drop tickers flagged as suspended on both *start_date* and *end_date*.

        Tickers whose suspension lookup raises are kept.  If every ticker
        would be dropped, the unfiltered universe is returned instead.
        """
        if not universe:
            return list(universe)
        broker = self._data_broker
        start_key = start_date.strftime("%Y%m%d")
        end_key = end_date.strftime("%Y%m%d")
        active: List[str] = []
        filtered: List[str] = []
        for ts_code in universe:
            try:
                suspended_start = broker.fetch_flags(
                    "suspend",
                    ts_code,
                    start_key,
                    "",
                    [],
                    auto_refresh=False,
                )
                suspended_end = broker.fetch_flags(
                    "suspend",
                    ts_code,
                    end_key,
                    "",
                    [],
                    auto_refresh=False,
                )
            except Exception:  # noqa: BLE001
                LOGGER.debug(
                    "检测停牌状态失败 ts_code=%s start=%s end=%s",
                    ts_code,
                    start_key,
                    end_key,
                    extra=LOG_EXTRA,
                )
                active.append(ts_code)
                continue
            if suspended_start and suspended_end:
                filtered.append(ts_code)
            else:
                active.append(ts_code)
        if filtered:
            LOGGER.info(
                "过滤停牌标的 %s/%s%s",
                len(filtered),
                len(universe),
                filtered[:10],
                extra=LOG_EXTRA,
            )
        return active or list(universe)

    @staticmethod
    def _loads(payload: Any, default: Any) -> Any:
        """Decode a JSON column value, returning *default* when empty or invalid.

        Dicts and lists are returned unchanged; ``str``, ``bytes`` and
        ``bytearray`` are parsed with :func:`json.loads`.
        """
        if not payload:
            return default
        if isinstance(payload, (dict, list)):
            return payload
        if isinstance(payload, (str, bytes, bytearray)):
            try:
                return json.loads(payload)
            except ValueError:
                # JSONDecodeError and UnicodeDecodeError both subclass ValueError.
                return default
        return default

    def _empty_metrics(self, result: Optional[BacktestResult]) -> EpisodeMetrics:
        """Build zero-valued metrics, salvaging whatever *result* still exposes.

        Used after a failed step: NAV rows, trades and risk events are copied
        from *result* when readable, while every performance figure is ``0.0``.
        """
        nav_series: List[Dict[str, Any]] = []
        trades: List[Dict[str, Any]] = []
        risk_events: List[Dict[str, Any]] = []
        if result is not None:
            try:
                nav_series = list(result.nav_series or [])
            except Exception:  # noqa: BLE001
                nav_series = []
            try:
                trades = list(result.trades or [])
            except Exception:  # noqa: BLE001
                trades = []
            try:
                risk_events = list(getattr(result, "risk_events", []) or [])
            except Exception:  # noqa: BLE001
                risk_events = []
        return EpisodeMetrics(
            total_return=0.0,
            max_drawdown=0.0,
            volatility=0.0,
            sharpe_like=0.0,
            calmar_like=0.0,
            nav_series=nav_series,
            trades=trades,
            turnover=0.0,
            turnover_value=0.0,
            trade_count=len(trades),
            risk_count=len(risk_events),
            risk_breakdown=self._risk_breakdown(risk_events),
        )