update

commit 07e5bb1b68 (parent 8befd80cb7)
app/backtest/decision_env.py
@@ -36,6 +36,10 @@ class EpisodeMetrics:
     volatility: float
     nav_series: List[Dict[str, float]]
     trades: List[Dict[str, object]]
+    turnover: float
+    trade_count: int
+    risk_count: int
+    risk_breakdown: Dict[str, int]
 
     @property
     def sharpe_like(self) -> float:
@@ -109,11 +113,16 @@ class DecisionEnv:
             "max_drawdown": metrics.max_drawdown,
             "volatility": metrics.volatility,
             "sharpe_like": metrics.sharpe_like,
+            "turnover": metrics.turnover,
+            "trade_count": float(metrics.trade_count),
+            "risk_count": float(metrics.risk_count),
         }
         info = {
             "nav_series": metrics.nav_series,
             "trades": metrics.trades,
             "weights": weights,
+            "risk_breakdown": metrics.risk_breakdown,
+            "risk_events": getattr(result, "risk_events", []),
         }
         return observation, reward, True, info
 
@@ -131,7 +140,21 @@ class DecisionEnv:
     def _compute_metrics(self, result: BacktestResult) -> EpisodeMetrics:
         nav_series = result.nav_series or []
         if not nav_series:
-            return EpisodeMetrics(0.0, 0.0, 0.0, [], result.trades)
+            risk_breakdown: Dict[str, int] = {}
+            for event in getattr(result, "risk_events", []) or []:
+                reason = str(event.get("reason") or "unknown")
+                risk_breakdown[reason] = risk_breakdown.get(reason, 0) + 1
+            return EpisodeMetrics(
+                total_return=0.0,
+                max_drawdown=0.0,
+                volatility=0.0,
+                nav_series=[],
+                trades=result.trades,
+                turnover=0.0,
+                trade_count=len(result.trades or []),
+                risk_count=len(getattr(result, "risk_events", []) or []),
+                risk_breakdown=risk_breakdown,
+            )
 
         nav_values = [row.get("nav", 0.0) for row in nav_series]
         if not nav_values or nav_values[0] == 0:
@@ -158,17 +181,30 @@ class DecisionEnv:
         else:
             volatility = 0.0
 
+        turnover = sum(float(row.get("turnover", 0.0) or 0.0) for row in nav_series)
+        risk_events = getattr(result, "risk_events", []) or []
+        risk_breakdown: Dict[str, int] = {}
+        for event in risk_events:
+            reason = str(event.get("reason") or "unknown")
+            risk_breakdown[reason] = risk_breakdown.get(reason, 0) + 1
+
         return EpisodeMetrics(
             total_return=float(total_return),
             max_drawdown=float(max_drawdown),
             volatility=volatility,
             nav_series=nav_series,
             trades=result.trades,
+            turnover=float(turnover),
+            trade_count=len(result.trades or []),
+            risk_count=len(risk_events),
+            risk_breakdown=risk_breakdown,
         )
 
     @staticmethod
     def _default_reward(metrics: EpisodeMetrics) -> float:
-        penalty = 0.5 * metrics.max_drawdown
+        risk_penalty = 0.05 * metrics.risk_count
+        turnover_penalty = 0.00001 * metrics.turnover
+        penalty = 0.5 * metrics.max_drawdown + risk_penalty + turnover_penalty
        return metrics.total_return - penalty
 
     @property
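For orientation, the updated default reward works out as below. The numbers are the ones exercised by test_default_reward_penalizes_metrics later in this commit; the snippet is plain arithmetic, not part of the committed code.

# reward = total_return - (0.5 * max_drawdown + 0.05 * risk_count + 1e-5 * turnover)
penalty = 0.5 * 0.2 + 0.05 * 2 + 0.00001 * 1000.0  # 0.1 + 0.1 + 0.01 = 0.21
reward = 0.1 - penalty                              # -0.11
assert abs(reward - (-0.11)) < 1e-9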
app/backtest/engine.py
@@ -98,6 +98,12 @@ class BacktestEngine:
             "daily_basic.volume_ratio",
             "stk_limit.up_limit",
             "stk_limit.down_limit",
+            "factors.mom_20",
+            "factors.mom_60",
+            "factors.volat_20",
+            "factors.turn_20",
+            "news.sentiment_index",
+            "news.heat_score",
         }
         self.required_fields = sorted(base_scope | department_scope)
 
@@ -121,10 +127,19 @@ class BacktestEngine:
             trade_date_str,
             window=60,
         )
-        close_values = [value for _date, value in closes]
-        mom20 = momentum(close_values, 20)
-        mom60 = momentum(close_values, 60)
-        volat20 = volatility(close_values, 20)
+        close_values = [value for _date, value in closes if value is not None]
+
+        mom20 = scope_values.get("factors.mom_20")
+        if mom20 is None and len(close_values) >= 20:
+            mom20 = momentum(close_values, 20)
+
+        mom60 = scope_values.get("factors.mom_60")
+        if mom60 is None and len(close_values) >= 60:
+            mom60 = momentum(close_values, 60)
+
+        volat20 = scope_values.get("factors.volat_20")
+        if volat20 is None and len(close_values) >= 2:
+            volat20 = volatility(close_values, 20)
 
         turnover_series = self.data_broker.fetch_series(
             "daily_basic",
@@ -133,8 +148,20 @@ class BacktestEngine:
             trade_date_str,
             window=20,
         )
-        turnover_values = [value for _date, value in turnover_series]
-        turn20 = rolling_mean(turnover_values, 20)
+        turnover_values = [value for _date, value in turnover_series if value is not None]
+
+        turn20 = scope_values.get("factors.turn_20")
+        if turn20 is None and turnover_values:
+            turn20 = rolling_mean(turnover_values, 20)
+
+        if mom20 is None:
+            mom20 = 0.0
+        if mom60 is None:
+            mom60 = 0.0
+        if volat20 is None:
+            volat20 = 0.0
+        if turn20 is None:
+            turn20 = 0.0
 
         liquidity_score = normalize(turn20, factor=20.0)
         cost_penalty = normalize(
@@ -142,12 +169,15 @@ class BacktestEngine:
             factor=50.0,
         )
 
+        sentiment_index = scope_values.get("news.sentiment_index", 0.0)
+        heat_score = scope_values.get("news.heat_score", 0.0)
+        scope_values.setdefault("news.sentiment_index", sentiment_index)
+        scope_values.setdefault("news.heat_score", heat_score)
+
         scope_values.setdefault("factors.mom_20", mom20)
         scope_values.setdefault("factors.mom_60", mom60)
         scope_values.setdefault("factors.volat_20", volat20)
         scope_values.setdefault("factors.turn_20", turn20)
-        scope_values.setdefault("news.sentiment_index", 0.0)
-        scope_values.setdefault("news.heat_score", 0.0)
         if scope_values.get("macro.industry_heat") is None:
             scope_values["macro.industry_heat"] = 0.5
         if scope_values.get("macro.relative_strength") is None:
@@ -189,8 +219,8 @@ class BacktestEngine:
             "turn_20": turn20,
             "liquidity_score": liquidity_score,
             "cost_penalty": cost_penalty,
-            "news_heat": scope_values.get("news.heat_score", 0.0),
-            "news_sentiment": scope_values.get("news.sentiment_index", 0.0),
+            "news_heat": heat_score,
+            "news_sentiment": sentiment_index,
             "industry_heat": scope_values.get("macro.industry_heat", 0.0),
             "industry_relative_mom": scope_values.get(
                 "macro.relative_strength",
@@ -818,6 +848,7 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None:
 
     nav_rows: List[tuple] = []
     trade_rows: List[tuple] = []
+    risk_rows: List[tuple] = []
     summary_payload: Dict[str, object] = {}
     turnover_sum = 0.0
 
@@ -893,6 +924,10 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None:
             "confidence": trade.get("confidence"),
             "target_weight": trade.get("target_weight"),
             "value": trade.get("value"),
+            "fee": trade.get("fee"),
+            "slippage": trade.get("slippage"),
+            "risk_penalty": trade.get("risk_penalty"),
+            "liquidity_score": trade.get("liquidity_score"),
         }
         trade_rows.append(
             (
@@ -913,6 +948,18 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None:
     for event in result.risk_events:
         reason = str(event.get("reason") or "unknown")
         breakdown[reason] = breakdown.get(reason, 0) + 1
+        risk_rows.append(
+            (
+                cfg.id,
+                str(event.get("trade_date", "")),
+                str(event.get("ts_code", "")),
+                reason,
+                str(event.get("action", "")),
+                float(event.get("target_weight", 0.0) or 0.0),
+                float(event.get("confidence", 0.0) or 0.0),
+                json.dumps(event, ensure_ascii=False),
+            )
+        )
     summary_payload["risk_breakdown"] = breakdown
 
     cfg_payload = {
@@ -943,6 +990,7 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None:
 
         conn.execute("DELETE FROM bt_nav WHERE cfg_id = ?", (cfg.id,))
         conn.execute("DELETE FROM bt_trades WHERE cfg_id = ?", (cfg.id,))
+        conn.execute("DELETE FROM bt_risk_events WHERE cfg_id = ?", (cfg.id,))
        conn.execute("DELETE FROM bt_report WHERE cfg_id = ?", (cfg.id,))
 
         if nav_rows:
@@ -963,6 +1011,15 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None:
                 trade_rows,
             )
 
+        if risk_rows:
+            conn.executemany(
+                """
+                INSERT INTO bt_risk_events (cfg_id, trade_date, ts_code, reason, action, target_weight, confidence, metadata)
+                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
+                """,
+                risk_rows,
+            )
+
         summary_payload.setdefault("universe", cfg.universe)
         summary_payload.setdefault("method", cfg.method)
         conn.execute(
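The fallback pattern the engine hunks repeat four times, shown in isolation: prefer the persisted factor from scope_values, recompute from the price window only when it is absent and enough history exists, then fall back to a neutral zero. This is a sketch only; the momentum() body below is an illustrative stand-in, not the engine's actual helper.

def momentum(values, window):
    # Stand-in definition for illustration: simple window return.
    return values[-1] / values[-window] - 1.0

scope_values = {"factors.mom_20": None}           # nothing persisted yet
close_values = [float(100 + i) for i in range(60)]

mom20 = scope_values.get("factors.mom_20")
if mom20 is None and len(close_values) >= 20:     # enough history to recompute
    mom20 = momentum(close_values, 20)
if mom20 is None:
    mom20 = 0.0                                   # final neutral fallback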
app/backtest/optimizer.py (new file, 139 lines)
@@ -0,0 +1,139 @@
+"""Optimization utilities for DecisionEnv-based parameter tuning."""
+from __future__ import annotations
+
+import random
+from dataclasses import dataclass, field
+from typing import Dict, Iterable, List, Sequence, Tuple
+
+from app.backtest.decision_env import DecisionEnv, EpisodeMetrics
+from app.backtest.decision_env import ParameterSpec
+from app.utils.logging import get_logger
+from app.utils.tuning import log_tuning_result
+
+LOGGER = get_logger(__name__)
+LOG_EXTRA = {"stage": "decision_bandit"}
+
+
+@dataclass
+class BanditConfig:
+    """Configuration for epsilon-greedy bandit optimization."""
+
+    experiment_id: str
+    strategy: str = "epsilon_greedy"
+    episodes: int = 20
+    epsilon: float = 0.2
+    seed: int | None = None
+
+
+@dataclass
+class BanditEpisode:
+    action: Dict[str, float]
+    reward: float
+    metrics: EpisodeMetrics
+    observation: Dict[str, float]
+
+
+@dataclass
+class BanditSummary:
+    episodes: List[BanditEpisode] = field(default_factory=list)
+
+    @property
+    def best_episode(self) -> BanditEpisode | None:
+        if not self.episodes:
+            return None
+        return max(self.episodes, key=lambda item: item.reward)
+
+    @property
+    def average_reward(self) -> float:
+        if not self.episodes:
+            return 0.0
+        return sum(item.reward for item in self.episodes) / len(self.episodes)
+
+
+class EpsilonGreedyBandit:
+    """Simple epsilon-greedy tuner using DecisionEnv as the reward oracle."""
+
+    def __init__(self, env: DecisionEnv, config: BanditConfig) -> None:
+        self.env = env
+        self.config = config
+        self._random = random.Random(config.seed)
+        self._specs: List[ParameterSpec] = list(getattr(env, "_specs", []))
+        if not self._specs:
+            raise ValueError("DecisionEnv does not expose parameter specs")
+        self._value_estimates: Dict[Tuple[float, ...], float] = {}
+        self._counts: Dict[Tuple[float, ...], int] = {}
+        self._history = BanditSummary()
+
+    def run(self) -> BanditSummary:
+        for episode in range(1, self.config.episodes + 1):
+            action = self._select_action()
+            self.env.reset()
+            obs, reward, done, info = self.env.step(action)
+            metrics = self.env.last_metrics
+            if metrics is None:
+                raise RuntimeError("DecisionEnv did not populate last_metrics")
+            key = tuple(action)
+            old_estimate = self._value_estimates.get(key, 0.0)
+            count = self._counts.get(key, 0) + 1
+            self._counts[key] = count
+            self._value_estimates[key] = old_estimate + (reward - old_estimate) / count
+
+            action_payload = self._action_to_mapping(action)
+            metrics_payload = _metrics_to_dict(metrics)
+            try:
+                log_tuning_result(
+                    experiment_id=self.config.experiment_id,
+                    strategy=self.config.strategy,
+                    action=action_payload,
+                    reward=reward,
+                    metrics=metrics_payload,
+                    weights=info.get("weights"),
+                )
+            except Exception:  # noqa: BLE001
+                LOGGER.exception("failed to log tuning result", extra=LOG_EXTRA)
+
+            episode_record = BanditEpisode(
+                action=action_payload,
+                reward=reward,
+                metrics=metrics,
+                observation=obs,
+            )
+            self._history.episodes.append(episode_record)
+            LOGGER.info(
+                "Bandit episode=%s reward=%.4f action=%s",
+                episode,
+                reward,
+                action_payload,
+                extra=LOG_EXTRA,
+            )
+        return self._history
+
+    def _select_action(self) -> List[float]:
+        if self._value_estimates and self._random.random() > self.config.epsilon:
+            best = max(self._value_estimates.items(), key=lambda item: item[1])[0]
+            return list(best)
+        return [
+            self._random.uniform(spec.minimum, spec.maximum)
+            for spec in self._specs
+        ]
+
+    def _action_to_mapping(self, action: Sequence[float]) -> Dict[str, float]:
+        return {
+            spec.name: float(value)
+            for spec, value in zip(self._specs, action, strict=True)
+        }
+
+
+def _metrics_to_dict(metrics: EpisodeMetrics) -> Dict[str, float | Dict[str, int]]:
+    payload: Dict[str, float | Dict[str, int]] = {
+        "total_return": metrics.total_return,
+        "max_drawdown": metrics.max_drawdown,
+        "volatility": metrics.volatility,
+        "sharpe_like": metrics.sharpe_like,
+        "turnover": metrics.turnover,
+        "trade_count": float(metrics.trade_count),
+        "risk_count": float(metrics.risk_count),
+    }
+    if metrics.risk_breakdown:
+        payload["risk_breakdown"] = dict(metrics.risk_breakdown)
+    return payload
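The per-action value estimate maintained in run() is an incremental mean: new = old + (reward - old) / count. A standalone check of the update rule, assuming three pulls of a single action with rewards 0.2, 0.4, and 0.6:

estimate, count = 0.0, 0
for reward in (0.2, 0.4, 0.6):
    count += 1
    estimate = estimate + (reward - estimate) / count
assert abs(estimate - 0.4) < 1e-12  # equals the plain arithmetic average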
app/data/schema.py
@@ -327,6 +327,18 @@ SCHEMA_STATEMENTS: Iterable[str] = (
     );
     """,
     """
+    CREATE TABLE IF NOT EXISTS bt_risk_events (
+        cfg_id TEXT,
+        trade_date TEXT,
+        ts_code TEXT,
+        reason TEXT,
+        action TEXT,
+        target_weight REAL,
+        confidence REAL,
+        metadata TEXT
+    );
+    """,
+    """
     CREATE TABLE IF NOT EXISTS bt_nav (
         cfg_id TEXT,
         trade_date TEXT,
@@ -472,6 +484,7 @@ REQUIRED_TABLES = (
     "heat_daily",
     "bt_config",
     "bt_trades",
+    "bt_risk_events",
     "bt_nav",
     "bt_report",
     "run_log",
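With the new table in place, the per-reason breakdown can also be read back directly from SQLite. A hedged inspection snippet; the cfg_id "risk_cfg" matches the id used by the persistence test later in this commit:

from app.utils.db import db_session

with db_session(read_only=True) as conn:
    rows = conn.execute(
        "SELECT reason, COUNT(*) AS n FROM bt_risk_events "
        "WHERE cfg_id = ? GROUP BY reason ORDER BY n DESC",
        ("risk_cfg",),
    ).fetchall()
for row in rows:
    print(row["reason"], row["n"])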
@@ -118,13 +118,12 @@ class DataBroker:
         if cached is not None:
             return deepcopy(cached)
 
-        grouped: Dict[str, List[str]] = {}
-        field_map: Dict[Tuple[str, str], List[str]] = {}
+        grouped: Dict[str, List[Tuple[str, str]]] = {}
         derived_cache: Dict[str, Any] = {}
         results: Dict[str, Any] = {}
         for field_name in field_list:
-            resolved = self.resolve_field(field_name)
-            if not resolved:
+            parsed = parse_field_path(field_name)
+            if not parsed:
                 derived = self._resolve_derived_field(
                     ts_code,
                     trade_date,
@@ -134,11 +133,8 @@ class DataBroker:
                 if derived is not None:
                     results[field_name] = derived
                 continue
-            table, column = resolved
-            grouped.setdefault(table, [])
-            if column not in grouped[table]:
-                grouped[table].append(column)
-            field_map.setdefault((table, column), []).append(field_name)
+            table, column = parsed
+            grouped.setdefault(table, []).append((column, field_name))
 
         if not grouped:
             if cache_key is not None and results:
@@ -152,10 +148,9 @@ class DataBroker:
 
         try:
             with db_session(read_only=True) as conn:
-                for table, columns in grouped.items():
-                    joined_cols = ", ".join(columns)
+                for table, items in grouped.items():
                     query = (
-                        f"SELECT trade_date, {joined_cols} FROM {table} "
+                        f"SELECT * FROM {table} "
                         "WHERE ts_code = ? AND trade_date <= ? "
                         "ORDER BY trade_date DESC LIMIT 1"
                     )
@@ -165,22 +160,25 @@ class DataBroker:
                         LOGGER.debug(
                             "查询失败 table=%s fields=%s err=%s",
                             table,
-                            columns,
+                            [column for column, _field in items],
                             exc,
                             extra=LOG_EXTRA,
                         )
                         continue
                     if not row:
                         continue
-                    for column in columns:
-                        value = row[column]
+                    available = row.keys()
+                    for column, original in items:
+                        resolved_column = self._resolve_column_in_row(table, column, available)
+                        if resolved_column is None:
+                            continue
+                        value = row[resolved_column]
                         if value is None:
                             continue
-                        for original in field_map.get((table, column), [f"{table}.{column}"]):
-                            try:
-                                results[original] = float(value)
-                            except (TypeError, ValueError):
-                                results[original] = value
+                        try:
+                            results[original] = float(value)
+                        except (TypeError, ValueError):
+                            results[original] = value
         except sqlite3.OperationalError as exc:
             LOGGER.debug("数据库只读连接失败:%s", exc, extra=LOG_EXTRA)
         if cache_key is not None:
@@ -698,6 +696,22 @@ class DataBroker:
         while len(cache) > limit:
             cache.popitem(last=False)
 
+    def _resolve_column_in_row(
+        self,
+        table: str,
+        column: str,
+        available: Sequence[str],
+    ) -> Optional[str]:
+        alias_map = self.FIELD_ALIASES.get(table, {})
+        candidate = alias_map.get(column, column)
+        if candidate in available:
+            return candidate
+        lowered = candidate.lower()
+        for name in available:
+            if name.lower() == lowered:
+                return name
+        return None
+
     def _resolve_column(self, table: str, column: str) -> Optional[str]:
         columns = self._get_table_columns(table)
         if columns is None:
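The resolution order implemented by _resolve_column_in_row, reduced to a standalone sketch: alias mapping first, then an exact match, then a case-insensitive match. The alias table contents below are invented for illustration.

FIELD_ALIASES = {"daily_basic": {"turnover": "turnover_rate"}}  # illustrative only

def resolve(table: str, column: str, available: list[str]) -> str | None:
    # 1) translate through the per-table alias map, 2) exact match,
    # 3) case-insensitive fallback, else give up.
    candidate = FIELD_ALIASES.get(table, {}).get(column, column)
    if candidate in available:
        return candidate
    lowered = candidate.lower()
    for name in available:
        if name.lower() == lowered:
            return name
    return None

assert resolve("daily_basic", "turnover", ["turnover_rate"]) == "turnover_rate"
assert resolve("daily", "CLOSE", ["close"]) == "close"
assert resolve("daily", "missing", ["close"]) is None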
app/utils/tuning.py
@@ -2,7 +2,8 @@
 from __future__ import annotations
 
 import json
-from typing import Any, Dict, Optional
+
+from typing import Any, Dict, Mapping, Optional
 
 from .db import db_session
 from .logging import get_logger
@@ -40,3 +41,96 @@ def log_tuning_result(
         )
     except Exception:  # noqa: BLE001
         LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA)
+
+
+def select_best_tuning_result(
+    experiment_id: str,
+    *,
+    metric: str = "reward",
+    descending: bool = True,
+    require_weights: bool = False,
+) -> Optional[Dict[str, Any]]:
+    """Return the best tuning result for the given experiment.
+
+    ``metric`` may refer to ``reward`` (default) or any key inside the
+    persisted metrics payload. When ``require_weights`` is True, rows lacking
+    weight definitions are ignored.
+    """
+
+    with db_session(read_only=True) as conn:
+        rows = conn.execute(
+            """
+            SELECT id, action, weights, reward, metrics, created_at
+            FROM tuning_results
+            WHERE experiment_id = ?
+            """,
+            (experiment_id,),
+        ).fetchall()
+
+    if not rows:
+        return None
+
+    best_row: Optional[Mapping[str, Any]] = None
+    best_metrics: Dict[str, Any] = {}
+    best_action: Dict[str, float] = {}
+    best_weights: Dict[str, float] = {}
+    best_score: Optional[float] = None
+
+    for row in rows:
+        action = _decode_json(row["action"])
+        weights = _decode_json(row["weights"])
+        metrics_payload = _decode_json(row["metrics"])
+        reward_value = float(row["reward"] or 0.0)
+
+        if require_weights and not weights:
+            continue
+
+        if metric == "reward":
+            score = reward_value
+        else:
+            score_raw = metrics_payload.get(metric)
+            if score_raw is None:
+                continue
+            try:
+                score = float(score_raw)
+            except (TypeError, ValueError):
+                continue
+
+        if best_score is None:
+            choose = True
+        else:
+            choose = score > best_score if descending else score < best_score
+
+        if choose:
+            best_score = score
+            best_row = row
+            best_metrics = metrics_payload
+            best_action = action
+            best_weights = weights
+
+    if best_row is None:
+        return None
+
+    return {
+        "id": best_row["id"],
+        "reward": float(best_row["reward"] or 0.0),
+        "score": best_score,
+        "metric": metric,
+        "action": best_action,
+        "weights": best_weights,
+        "metrics": best_metrics,
+        "created_at": best_row["created_at"],
+    }
+
+
+def _decode_json(payload: Any) -> Dict[str, Any]:
+    if not payload:
+        return {}
+    if isinstance(payload, Mapping):
+        return dict(payload)
+    if isinstance(payload, str):
+        try:
+            return json.loads(payload)
+        except json.JSONDecodeError:
+            return {}
+    return {}
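Typical use of the new selector, mirroring tests/test_tuning_utils.py later in this commit: take the highest reward by default, or rank by any key inside the persisted metrics payload.

from app.utils.tuning import select_best_tuning_result

best = select_best_tuning_result("exp")  # highest reward
safest = select_best_tuning_result(
    "exp", metric="risk_count", descending=False  # fewest risk events
)
if best is not None:
    print(best["reward"], best["weights"])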
@@ -13,6 +13,7 @@
 
 ## 2. 数据与特征层
 - 实现 `app/features/factors.py` 中的 `compute_factors()`,补齐因子计算与持久化流程。
+- DataBroker `fetch_latest` 查询改为读取整行字段,使用时按需取值,避免列缺失导致的异常,后续取数逻辑遵循该约定。
 - 完成 `app/ingest/rss.py` 的 RSS 拉取与写库逻辑,打通新闻与情绪数据源。
 - 强化 `DataBroker` 的取数校验、缓存与回退策略,确保行情/特征补数统一自动化,减少人工兜底。
 - 围绕动量、估值、流动性等核心信号扩展轻量高质量因子集,全部由程序生成,满足端到端自动决策需求。
scripts/apply_best_weights.py (new file, 83 lines)
@@ -0,0 +1,83 @@
+"""Apply or display the best tuning result for an experiment."""
+from __future__ import annotations
+
+import argparse
+import json
+import sys
+from pathlib import Path
+from typing import Iterable
+
+ROOT = Path(__file__).resolve().parents[1]
+if str(ROOT) not in sys.path:
+    sys.path.insert(0, str(ROOT))
+
+from app.utils.config import get_config, save_config
+from app.utils.tuning import select_best_tuning_result
+from app.utils.logging import get_logger
+
+LOGGER = get_logger(__name__)
+
+
+def build_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(description="Apply best tuning weights")
+    parser.add_argument("experiment_id", help="Experiment identifier")
+    parser.add_argument(
+        "--metric",
+        default="reward",
+        help="Metric name for ranking (default: reward)",
+    )
+    parser.add_argument(
+        "--ascending",
+        action="store_true",
+        help="Sort metric ascending instead of descending",
+    )
+    parser.add_argument(
+        "--require-weights",
+        action="store_true",
+        help="Ignore records without weight payload",
+    )
+    parser.add_argument(
+        "--apply-config",
+        action="store_true",
+        help="Update agent_weights in config with best result weights (fallback to action)",
+    )
+    return parser
+
+
+def run_cli(argv: Iterable[str] | None = None) -> int:
+    parser = build_parser()
+    args = parser.parse_args(list(argv) if argv is not None else None)
+
+    best = select_best_tuning_result(
+        args.experiment_id,
+        metric=args.metric,
+        descending=not args.ascending,
+        require_weights=args.require_weights,
+    )
+    if not best:
+        LOGGER.error("未找到实验结果 experiment_id=%s", args.experiment_id)
+        return 1
+
+    print(json.dumps(best, ensure_ascii=False, indent=2))
+
+    if args.apply_config:
+        weights = best.get("weights") or best.get("action")
+        if not weights:
+            LOGGER.error("最佳结果缺少权重信息,无法更新配置")
+            return 2
+        cfg = get_config()
+        if not cfg.agent_weights:
+            LOGGER.warning("配置缺少 agent_weights,初始化默认值")
+        cfg.agent_weights.update_from_dict(weights)
+        save_config(cfg)
+        LOGGER.info("已写入新的 agent_weights 至配置")
+
+    return 0
+
+
+def main() -> None:
+    raise SystemExit(run_cli())
+
+
+if __name__ == "__main__":
+    main()
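The CLI can also be driven in-process, which is how the test suite exercises it; the experiment id below is illustrative and assumes results have already been logged for it.

import scripts.apply_best_weights as apply_best_weights

exit_code = apply_best_weights.run_cli(["exp_cli", "--apply-config"])
assert exit_code == 0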
scripts/run_bandit_optimization.py (new file, 124 lines)
@@ -0,0 +1,124 @@
+"""Run epsilon-greedy bandit tuning on DecisionEnv."""
+from __future__ import annotations
+
+import argparse
+import json
+import sys
+from datetime import datetime, date
+from pathlib import Path
+from typing import Iterable, List
+
+ROOT = Path(__file__).resolve().parents[1]
+if str(ROOT) not in sys.path:
+    sys.path.insert(0, str(ROOT))
+
+from app.agents.registry import default_agents
+from app.backtest.decision_env import DecisionEnv, ParameterSpec
+from app.backtest.engine import BtConfig
+from app.backtest.optimizer import BanditConfig, EpsilonGreedyBandit
+from app.utils.config import get_config
+
+
+def _parse_date(value: str) -> date:
+    return datetime.strptime(value, "%Y%m%d").date()
+
+
+def _parse_param(text: str) -> ParameterSpec:
+    parts = text.split(":")
+    if len(parts) not in {3, 4}:
+        raise argparse.ArgumentTypeError(
+            "parameter format must be name:target:min[:max]"
+        )
+    name, target, minimum = parts[:3]
+    maximum = parts[3] if len(parts) == 4 else "1.0"
+    return ParameterSpec(
+        name=name,
+        target=target,
+        minimum=float(minimum),
+        maximum=float(maximum),
+    )
+
+
+def _resolve_baseline_weights() -> dict:
+    cfg = get_config()
+    if cfg.agent_weights:
+        return cfg.agent_weights.as_dict()
+    return {agent.name: 1.0 for agent in default_agents()}
+
+
+def build_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(description="DecisionEnv bandit optimizer")
+    parser.add_argument("experiment_id", help="Experiment identifier to log results")
+    parser.add_argument("name", help="Backtest config name")
+    parser.add_argument("start", type=_parse_date, help="Start date YYYYMMDD")
+    parser.add_argument("end", type=_parse_date, help="End date YYYYMMDD")
+    parser.add_argument(
+        "--universe",
+        required=True,
+        help="Comma separated ts_codes, e.g. 000001.SZ,000002.SZ",
+    )
+    parser.add_argument(
+        "--param",
+        action="append",
+        required=True,
+        help="Parameter spec name:target:min[:max] (target like agent_weights.A_mom)",
+    )
+    parser.add_argument("--episodes", type=int, default=20)
+    parser.add_argument("--epsilon", type=float, default=0.2)
+    parser.add_argument("--seed", type=int, default=None)
+    return parser
+
+
+def run_cli(argv: Iterable[str] | None = None) -> int:
+    parser = build_parser()
+    args = parser.parse_args(list(argv) if argv is not None else None)
+
+    if args.end < args.start:
+        parser.error("end date must not precede start date")
+
+    specs: List[ParameterSpec] = [_parse_param(item) for item in args.param]
+    universe = [token.strip() for token in args.universe.split(",") if token.strip()]
+    bt_cfg = BtConfig(
+        id=args.experiment_id,
+        name=args.name,
+        start_date=args.start,
+        end_date=args.end,
+        universe=universe,
+        params={},
+    )
+
+    env = DecisionEnv(
+        bt_config=bt_cfg,
+        parameter_specs=specs,
+        baseline_weights=_resolve_baseline_weights(),
+    )
+    optimizer = EpsilonGreedyBandit(
+        env,
+        BanditConfig(
+            experiment_id=args.experiment_id,
+            episodes=args.episodes,
+            epsilon=args.epsilon,
+            seed=args.seed,
+        ),
+    )
+    summary = optimizer.run()
+    best = summary.best_episode
+    output = {
+        "episodes": len(summary.episodes),
+        "average_reward": summary.average_reward,
+        "best": {
+            "reward": best.reward if best else None,
+            "action": best.action if best else None,
+            "metrics": (best.metrics and json.dumps(best.metrics.risk_breakdown)) if best else None,
+        },
+    }
+    print(json.dumps(output, ensure_ascii=False, indent=2))
+    return 0
+
+
+def main() -> None:
+    raise SystemExit(run_cli())
+
+
+if __name__ == "__main__":
+    main()
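A sketch of an in-process invocation; argument shapes follow build_parser above, while the experiment id, stock codes, and dates are placeholders. Running it performs real backtest episodes against the configured database.

import scripts.run_bandit_optimization as run_bandit

run_bandit.run_cli([
    "exp_demo", "demo", "20250101", "20250131",
    "--universe", "000001.SZ,000002.SZ",
    "--param", "w_mom:agent_weights.A_mom:0.0:1.0",
    "--episodes", "10",
])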
tests/test_backtest_engine_factors.py (new file, 57 lines)
@@ -0,0 +1,57 @@
+"""Verify BacktestEngine consumes persisted factor fields."""
+from __future__ import annotations
+
+from datetime import date
+
+import pytest
+
+from app.backtest.engine import BacktestEngine, BtConfig
+
+
+@pytest.fixture()
+def engine(monkeypatch):
+    cfg = BtConfig(
+        id="test",
+        name="factor",
+        start_date=date(2025, 1, 10),
+        end_date=date(2025, 1, 10),
+        universe=["000001.SZ"],
+        params={},
+    )
+    engine = BacktestEngine(cfg)
+
+    def fake_fetch_latest(ts_code, trade_date, fields):  # noqa: D401
+        assert "factors.mom_20" in fields
+        return {
+            "daily.close": 10.0,
+            "daily.pct_chg": 0.02,
+            "daily_basic.turnover_rate": 5.0,
+            "daily_basic.volume_ratio": 15.0,
+            "factors.mom_20": 0.12,
+            "factors.mom_60": 0.25,
+            "factors.volat_20": 0.05,
+            "factors.turn_20": 3.0,
+            "news.sentiment_index": 0.3,
+            "news.heat_score": 0.4,
+            "macro.industry_heat": 0.6,
+            "macro.relative_strength": 0.7,
+        }
+
+    monkeypatch.setattr(engine.data_broker, "fetch_latest", fake_fetch_latest)
+    monkeypatch.setattr(engine.data_broker, "fetch_series", lambda *args, **kwargs: [])
+    monkeypatch.setattr(engine.data_broker, "fetch_flags", lambda *args, **kwargs: False)
+
+    return engine
+
+
+def test_load_market_data_prefers_factors(engine):
+    data = engine.load_market_data(date(2025, 1, 10))
+    record = data["000001.SZ"]
+    features = record["features"]
+    assert features["mom_20"] == pytest.approx(0.12)
+    assert features["mom_60"] == pytest.approx(0.25)
+    assert features["volat_20"] == pytest.approx(0.05)
+    assert features["turn_20"] == pytest.approx(3.0)
+    assert features["news_sentiment"] == pytest.approx(0.3)
+    assert features["news_heat"] == pytest.approx(0.4)
+    assert features["risk_penalty"] == pytest.approx(min(1.0, 0.05 * 5.0))
@@ -5,9 +5,20 @@ from datetime import date
 
 import pytest
 
+import json
+
 from app.agents.base import AgentAction, AgentContext
 from app.agents.game import Decision
-from app.backtest.engine import BacktestEngine, BacktestResult, BtConfig, PortfolioState
+from app.backtest.engine import (
+    BacktestEngine,
+    BacktestResult,
+    BtConfig,
+    PortfolioState,
+    _persist_backtest_results,
+)
+from app.data.schema import initialize_database
+from app.utils.config import DataPaths, get_config
+from app.utils.db import db_session
 
 
 def _make_context(price: float, features: dict | None = None) -> AgentContext:
@@ -43,6 +54,20 @@ def _engine_with_params(params: dict[str, float]) -> BacktestEngine:
     return BacktestEngine(cfg)
 
 
+@pytest.fixture()
+def isolated_db(tmp_path):
+    cfg = get_config()
+    original_paths = cfg.data_paths
+    tmp_root = tmp_path / "data"
+    tmp_root.mkdir(parents=True, exist_ok=True)
+    cfg.data_paths = DataPaths(root=tmp_root)
+    initialize_database()
+    try:
+        yield
+    finally:
+        cfg.data_paths = original_paths
+
+
 def test_buy_respects_risk_caps():
     engine = _engine_with_params(
         {
@@ -130,3 +155,56 @@ def test_sell_applies_slippage_and_fee():
     assert not state.holdings
     assert result.nav_series[0]["turnover"] == pytest.approx(trade["value"])
     assert not result.risk_events
+
+
+def test_persist_backtest_results_saves_risk_events(isolated_db):
+    cfg = BtConfig(
+        id="risk_cfg",
+        name="risk",
+        start_date=date(2025, 1, 10),
+        end_date=date(2025, 1, 10),
+        universe=["000001.SZ"],
+        params={},
+    )
+    result = BacktestResult()
+    result.nav_series = [
+        {
+            "trade_date": "2025-01-10",
+            "nav": 100.0,
+            "cash": 100.0,
+            "market_value": 0.0,
+            "realized_pnl": 0.0,
+            "unrealized_pnl": 0.0,
+            "turnover": 0.0,
+        }
+    ]
+    result.risk_events = [
+        {
+            "trade_date": "2025-01-10",
+            "ts_code": "000001.SZ",
+            "reason": "limit_up",
+            "action": "buy_l",
+            "target_weight": 0.3,
+            "confidence": 0.8,
+        }
+    ]
+
+    _persist_backtest_results(cfg, result)
+
+    with db_session(read_only=True) as conn:
+        risk_row = conn.execute(
+            "SELECT reason, metadata FROM bt_risk_events WHERE cfg_id = ?",
+            (cfg.id,),
+        ).fetchone()
+        assert risk_row is not None
+        assert risk_row["reason"] == "limit_up"
+        metadata = json.loads(risk_row["metadata"])
+        assert metadata["action"] == "buy_l"
+
+        summary_row = conn.execute(
+            "SELECT summary FROM bt_report WHERE cfg_id = ?",
+            (cfg.id,),
+        ).fetchone()
+        summary = json.loads(summary_row["summary"])
+        assert summary["risk_events"] == 1
+        assert summary["risk_breakdown"]["limit_up"] == 1
tests/test_bandit_optimizer.py (new file, 92 lines)
@@ -0,0 +1,92 @@
+"""Tests for epsilon-greedy bandit optimizer."""
+from __future__ import annotations
+
+import pytest
+
+from app.backtest.decision_env import EpisodeMetrics, ParameterSpec
+from app.backtest.optimizer import BanditConfig, EpsilonGreedyBandit
+from app.utils import tuning
+
+
+class DummyEnv:
+    def __init__(self) -> None:
+        self._specs = [
+            ParameterSpec(name="w1", target="agent_weights.A_mom", minimum=0.0, maximum=1.0)
+        ]
+        self._last_metrics: EpisodeMetrics | None = None
+        self._episode = 0
+
+    @property
+    def action_dim(self) -> int:
+        return 1
+
+    @property
+    def last_metrics(self) -> EpisodeMetrics | None:
+        return self._last_metrics
+
+    def reset(self) -> dict:
+        self._episode += 1
+        return {"episode": float(self._episode)}
+
+    def step(self, action):
+        value = float(action[0])
+        reward = 1.0 - abs(value - 0.7)
+        metrics = EpisodeMetrics(
+            total_return=reward,
+            max_drawdown=0.1,
+            volatility=0.05,
+            nav_series=[],
+            trades=[],
+            turnover=100.0,
+            trade_count=0,
+            risk_count=1,
+            risk_breakdown={"test": 1},
+        )
+        self._last_metrics = metrics
+        obs = {
+            "total_return": reward,
+            "max_drawdown": 0.1,
+            "volatility": 0.05,
+            "sharpe_like": reward / 0.05,
+            "turnover": 100.0,
+            "trade_count": 0.0,
+            "risk_count": 1.0,
+        }
+        info = {
+            "nav_series": [],
+            "trades": [],
+            "weights": {"A_mom": value},
+            "risk_breakdown": metrics.risk_breakdown,
+            "risk_events": [],
+        }
+        return obs, reward, True, info
+
+
+@pytest.fixture(autouse=True)
+def patch_logging(monkeypatch):
+    records = []
+
+    def fake_log_tuning_result(**kwargs):
+        records.append(kwargs)
+
+    monkeypatch.setattr(tuning, "log_tuning_result", fake_log_tuning_result)
+    from app.backtest import optimizer as optimizer_module
+
+    monkeypatch.setattr(optimizer_module, "log_tuning_result", fake_log_tuning_result)
+    return records
+
+
+def test_bandit_optimizer_runs_and_logs(patch_logging):
+    env = DummyEnv()
+    optimizer = EpsilonGreedyBandit(
+        env,
+        BanditConfig(experiment_id="exp", episodes=5, epsilon=0.5, seed=42),
+    )
+    summary = optimizer.run()
+
+    assert len(summary.episodes) == 5
+    assert summary.best_episode is not None
+    assert patch_logging and len(patch_logging) == 5
+    payload = patch_logging[0]["metrics"]
+    assert isinstance(payload, dict)
+    assert "risk_breakdown" in payload
tests/test_decision_env.py (new file, 92 lines)
@@ -0,0 +1,92 @@
+"""Tests for DecisionEnv risk-aware reward and info outputs."""
+from __future__ import annotations
+
+from datetime import date
+
+import pytest
+
+from app.backtest.decision_env import DecisionEnv, EpisodeMetrics, ParameterSpec
+from app.backtest.engine import BacktestResult, BtConfig
+
+
+class _StubEngine:
+    def __init__(self, cfg: BtConfig) -> None:  # noqa: D401
+        self.cfg = cfg
+        self.weights = {}
+        self.department_manager = None
+
+    def run(self) -> BacktestResult:
+        result = BacktestResult()
+        result.nav_series = [
+            {
+                "trade_date": "2025-01-10",
+                "nav": 102.0,
+                "cash": 50.0,
+                "market_value": 52.0,
+                "realized_pnl": 1.0,
+                "unrealized_pnl": 1.0,
+                "turnover": 20000.0,
+            }
+        ]
+        result.trades = [
+            {
+                "trade_date": "2025-01-10",
+                "ts_code": "000001.SZ",
+                "action": "buy",
+                "quantity": 100.0,
+                "price": 100.0,
+                "value": 10000.0,
+                "fee": 5.0,
+            }
+        ]
+        result.risk_events = [
+            {
+                "trade_date": "2025-01-10",
+                "ts_code": "000002.SZ",
+                "reason": "limit_up",
+                "action": "buy_l",
+                "confidence": 0.7,
+                "target_weight": 0.2,
+            }
+        ]
+        return result
+
+
+def test_decision_env_returns_risk_metrics(monkeypatch):
+    cfg = BtConfig(
+        id="stub",
+        name="stub",
+        start_date=date(2025, 1, 10),
+        end_date=date(2025, 1, 10),
+        universe=["000001.SZ"],
+        params={},
+    )
+    specs = [ParameterSpec(name="w_mom", target="agent_weights.A_mom", minimum=0.0, maximum=1.0)]
+    env = DecisionEnv(bt_config=cfg, parameter_specs=specs, baseline_weights={"A_mom": 0.5})
+
+    monkeypatch.setattr("app.backtest.decision_env.BacktestEngine", _StubEngine)
+
+    obs, reward, done, info = env.step([0.8])
+
+    assert done is True
+    assert "risk_count" in obs and obs["risk_count"] == 1.0
+    assert obs["turnover"] == pytest.approx(20000.0)
+    assert info["risk_events"][0]["reason"] == "limit_up"
+    assert info["risk_breakdown"]["limit_up"] == 1
+    assert reward < obs["total_return"]
+
+
+def test_default_reward_penalizes_metrics():
+    metrics = EpisodeMetrics(
+        total_return=0.1,
+        max_drawdown=0.2,
+        volatility=0.05,
+        nav_series=[],
+        trades=[],
+        turnover=1000.0,
+        trade_count=0,
+        risk_count=2,
+        risk_breakdown={"foo": 2},
+    )
+    reward = DecisionEnv._default_reward(metrics)
+    assert reward == pytest.approx(0.1 - (0.5 * 0.2 + 0.05 * 2 + 0.00001 * 1000.0))
tests/test_tuning_utils.py (new file, 81 lines)
@@ -0,0 +1,81 @@
+"""Tests for tuning result selection and CLI application."""
+from __future__ import annotations
+
+import json
+
+import pytest
+
+from app.data.schema import initialize_database
+from app.utils.config import DataPaths, get_config
+from app.utils.db import db_session
+from app.utils.tuning import select_best_tuning_result
+
+import scripts.apply_best_weights as apply_best_weights
+
+
+@pytest.fixture()
+def isolated_env(tmp_path):
+    cfg = get_config()
+    original_paths = cfg.data_paths
+    tmp_root = tmp_path / "data"
+    tmp_root.mkdir(parents=True, exist_ok=True)
+    cfg.data_paths = DataPaths(root=tmp_root)
+    initialize_database()
+    try:
+        yield cfg
+    finally:
+        cfg.data_paths = original_paths
+
+
+def _insert_result(experiment: str, reward: float, metrics: dict, weights: dict | None = None, action: dict | None = None) -> None:
+    with db_session() as conn:
+        conn.execute(
+            """
+            INSERT INTO tuning_results (experiment_id, strategy, action, weights, reward, metrics)
+            VALUES (?, ?, ?, ?, ?, ?)
+            """,
+            (
+                experiment,
+                "test",
+                json.dumps(action or {}, ensure_ascii=False),
+                json.dumps(weights or {}, ensure_ascii=False),
+                reward,
+                json.dumps(metrics, ensure_ascii=False),
+            ),
+        )
+
+
+def test_select_best_by_reward(isolated_env):
+    _insert_result("exp", 0.1, {"risk_count": 2}, {"A_mom": 0.3})
+    _insert_result("exp", 0.25, {"risk_count": 4}, {"A_mom": 0.6})
+
+    best = select_best_tuning_result("exp")
+    assert best is not None
+    assert best["reward"] == pytest.approx(0.25)
+    assert best["weights"]["A_mom"] == pytest.approx(0.6)
+
+
+def test_select_best_by_metric(isolated_env):
+    _insert_result("exp_metric", 0.2, {"risk_count": 5}, {"A_mom": 0.4})
+    _insert_result("exp_metric", 0.1, {"risk_count": 2}, {"A_mom": 0.7})
+
+    best = select_best_tuning_result("exp_metric", metric="risk_count", descending=False)
+    assert best is not None
+    assert best["weights"]["A_mom"] == pytest.approx(0.7)
+    assert best["metrics"]["risk_count"] == 2
+
+
+def test_apply_best_weights_cli_updates_config(isolated_env, capsys):
+    cfg = isolated_env
+    _insert_result("exp_cli", 0.3, {"risk_count": 1}, {"A_mom": 0.65, "A_val": 0.2})
+    exit_code = apply_best_weights.run_cli([
+        "exp_cli",
+        "--apply-config",
+    ])
+    assert exit_code == 0
+    output = capsys.readouterr().out
+    payload = json.loads(output)
+    assert payload["metric"] == "reward"
+    updated = cfg.agent_weights.as_dict()
+    assert updated["A_mom"] == pytest.approx(0.65)
+    assert updated["A_val"] == pytest.approx(0.2)