update

commit 07e5bb1b68
parent 8befd80cb7
@@ -36,6 +36,10 @@ class EpisodeMetrics:
    volatility: float
    nav_series: List[Dict[str, float]]
    trades: List[Dict[str, object]]
    turnover: float
    trade_count: int
    risk_count: int
    risk_breakdown: Dict[str, int]

    @property
    def sharpe_like(self) -> float:
@@ -109,11 +113,16 @@ class DecisionEnv:
            "max_drawdown": metrics.max_drawdown,
            "volatility": metrics.volatility,
            "sharpe_like": metrics.sharpe_like,
            "turnover": metrics.turnover,
            "trade_count": float(metrics.trade_count),
            "risk_count": float(metrics.risk_count),
        }
        info = {
            "nav_series": metrics.nav_series,
            "trades": metrics.trades,
            "weights": weights,
            "risk_breakdown": metrics.risk_breakdown,
            "risk_events": getattr(result, "risk_events", []),
        }
        return observation, reward, True, info
@@ -131,7 +140,21 @@ class DecisionEnv:
    def _compute_metrics(self, result: BacktestResult) -> EpisodeMetrics:
        nav_series = result.nav_series or []
        if not nav_series:
-            return EpisodeMetrics(0.0, 0.0, 0.0, [], result.trades)
+            risk_breakdown: Dict[str, int] = {}
+            for event in getattr(result, "risk_events", []) or []:
+                reason = str(event.get("reason") or "unknown")
+                risk_breakdown[reason] = risk_breakdown.get(reason, 0) + 1
+            return EpisodeMetrics(
+                total_return=0.0,
+                max_drawdown=0.0,
+                volatility=0.0,
+                nav_series=[],
+                trades=result.trades,
+                turnover=0.0,
+                trade_count=len(result.trades or []),
+                risk_count=len(getattr(result, "risk_events", []) or []),
+                risk_breakdown=risk_breakdown,
+            )

        nav_values = [row.get("nav", 0.0) for row in nav_series]
        if not nav_values or nav_values[0] == 0:
@@ -158,17 +181,30 @@ class DecisionEnv:
        else:
            volatility = 0.0

        turnover = sum(float(row.get("turnover", 0.0) or 0.0) for row in nav_series)
        risk_events = getattr(result, "risk_events", []) or []
        risk_breakdown: Dict[str, int] = {}
        for event in risk_events:
            reason = str(event.get("reason") or "unknown")
            risk_breakdown[reason] = risk_breakdown.get(reason, 0) + 1

        return EpisodeMetrics(
            total_return=float(total_return),
            max_drawdown=float(max_drawdown),
            volatility=volatility,
            nav_series=nav_series,
            trades=result.trades,
            turnover=float(turnover),
            trade_count=len(result.trades or []),
            risk_count=len(risk_events),
            risk_breakdown=risk_breakdown,
        )

    @staticmethod
    def _default_reward(metrics: EpisodeMetrics) -> float:
-        penalty = 0.5 * metrics.max_drawdown
+        risk_penalty = 0.05 * metrics.risk_count
+        turnover_penalty = 0.00001 * metrics.turnover
+        penalty = 0.5 * metrics.max_drawdown + risk_penalty + turnover_penalty
        return metrics.total_return - penalty

    @property
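Note: with these fields in place, the default reward nets drawdown, risk, and turnover penalties against total return. A minimal sketch of the arithmetic, standalone and with illustrative numbers only:

    from app.backtest.decision_env import DecisionEnv, EpisodeMetrics

    # Hypothetical episode: 10% return, 20% drawdown, two risk events, 1000 turnover.
    metrics = EpisodeMetrics(
        total_return=0.1, max_drawdown=0.2, volatility=0.05,
        nav_series=[], trades=[], turnover=1000.0,
        trade_count=0, risk_count=2, risk_breakdown={"limit_up": 2},
    )
    # penalty = 0.5 * 0.2 + 0.05 * 2 + 0.00001 * 1000.0 = 0.21
    reward = DecisionEnv._default_reward(metrics)
    print(round(reward, 6))  # -0.11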
@@ -98,6 +98,12 @@ class BacktestEngine:
            "daily_basic.volume_ratio",
            "stk_limit.up_limit",
            "stk_limit.down_limit",
            "factors.mom_20",
            "factors.mom_60",
            "factors.volat_20",
            "factors.turn_20",
            "news.sentiment_index",
            "news.heat_score",
        }
        self.required_fields = sorted(base_scope | department_scope)
@@ -121,9 +127,18 @@
            trade_date_str,
            window=60,
        )
-        close_values = [value for _date, value in closes]
+        close_values = [value for _date, value in closes if value is not None]

        mom20 = scope_values.get("factors.mom_20")
        if mom20 is None and len(close_values) >= 20:
            mom20 = momentum(close_values, 20)

        mom60 = scope_values.get("factors.mom_60")
        if mom60 is None and len(close_values) >= 60:
            mom60 = momentum(close_values, 60)

        volat20 = scope_values.get("factors.volat_20")
        if volat20 is None and len(close_values) >= 2:
            volat20 = volatility(close_values, 20)

        turnover_series = self.data_broker.fetch_series(
@@ -133,21 +148,36 @@
            trade_date_str,
            window=20,
        )
-        turnover_values = [value for _date, value in turnover_series]
+        turnover_values = [value for _date, value in turnover_series if value is not None]

        turn20 = scope_values.get("factors.turn_20")
        if turn20 is None and turnover_values:
            turn20 = rolling_mean(turnover_values, 20)

        if mom20 is None:
            mom20 = 0.0
        if mom60 is None:
            mom60 = 0.0
        if volat20 is None:
            volat20 = 0.0
        if turn20 is None:
            turn20 = 0.0

        liquidity_score = normalize(turn20, factor=20.0)
        cost_penalty = normalize(
            scope_values.get("daily_basic.volume_ratio", 0.0),
            factor=50.0,
        )

+        sentiment_index = scope_values.get("news.sentiment_index", 0.0)
+        heat_score = scope_values.get("news.heat_score", 0.0)
+        scope_values.setdefault("news.sentiment_index", sentiment_index)
+        scope_values.setdefault("news.heat_score", heat_score)

        scope_values.setdefault("factors.mom_20", mom20)
        scope_values.setdefault("factors.mom_60", mom60)
        scope_values.setdefault("factors.volat_20", volat20)
        scope_values.setdefault("factors.turn_20", turn20)
-        scope_values.setdefault("news.sentiment_index", 0.0)
-        scope_values.setdefault("news.heat_score", 0.0)
        if scope_values.get("macro.industry_heat") is None:
            scope_values["macro.industry_heat"] = 0.5
        if scope_values.get("macro.relative_strength") is None:
@@ -189,8 +219,8 @@
            "turn_20": turn20,
            "liquidity_score": liquidity_score,
            "cost_penalty": cost_penalty,
-            "news_heat": scope_values.get("news.heat_score", 0.0),
-            "news_sentiment": scope_values.get("news.sentiment_index", 0.0),
+            "news_heat": heat_score,
+            "news_sentiment": sentiment_index,
            "industry_heat": scope_values.get("macro.industry_heat", 0.0),
            "industry_relative_mom": scope_values.get(
                "macro.relative_strength",
@@ -818,6 +848,7 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None:

    nav_rows: List[tuple] = []
    trade_rows: List[tuple] = []
    risk_rows: List[tuple] = []
    summary_payload: Dict[str, object] = {}
    turnover_sum = 0.0
@@ -893,6 +924,10 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None:
                "confidence": trade.get("confidence"),
                "target_weight": trade.get("target_weight"),
                "value": trade.get("value"),
                "fee": trade.get("fee"),
                "slippage": trade.get("slippage"),
                "risk_penalty": trade.get("risk_penalty"),
                "liquidity_score": trade.get("liquidity_score"),
            }
            trade_rows.append(
                (
@@ -913,6 +948,18 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None:
    for event in result.risk_events:
        reason = str(event.get("reason") or "unknown")
        breakdown[reason] = breakdown.get(reason, 0) + 1
        risk_rows.append(
            (
                cfg.id,
                str(event.get("trade_date", "")),
                str(event.get("ts_code", "")),
                reason,
                str(event.get("action", "")),
                float(event.get("target_weight", 0.0) or 0.0),
                float(event.get("confidence", 0.0) or 0.0),
                json.dumps(event, ensure_ascii=False),
            )
        )
    summary_payload["risk_breakdown"] = breakdown

    cfg_payload = {
@@ -943,6 +990,7 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None:

    conn.execute("DELETE FROM bt_nav WHERE cfg_id = ?", (cfg.id,))
    conn.execute("DELETE FROM bt_trades WHERE cfg_id = ?", (cfg.id,))
    conn.execute("DELETE FROM bt_risk_events WHERE cfg_id = ?", (cfg.id,))
    conn.execute("DELETE FROM bt_report WHERE cfg_id = ?", (cfg.id,))

    if nav_rows:
@@ -963,6 +1011,15 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None:
            trade_rows,
        )

    if risk_rows:
        conn.executemany(
            """
            INSERT INTO bt_risk_events (cfg_id, trade_date, ts_code, reason, action, target_weight, confidence, metadata)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """,
            risk_rows,
        )

    summary_payload.setdefault("universe", cfg.universe)
    summary_payload.setdefault("method", cfg.method)
    conn.execute(
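Note: the reason-count aggregation above now appears both in DecisionEnv._compute_metrics and in this persistence path. An equivalent formulation with collections.Counter, shown only to clarify what the loop computes, not how the commit writes it:

    from collections import Counter

    risk_events = [{"reason": "limit_up"}, {"reason": "limit_up"}, {"reason": None}]
    # str(event.get("reason") or "unknown") folds missing or None reasons into "unknown".
    breakdown = Counter(str(event.get("reason") or "unknown") for event in risk_events)
    assert breakdown == {"limit_up": 2, "unknown": 1}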
app/backtest/optimizer.py (new file, 139 lines)
@@ -0,0 +1,139 @@
"""Optimization utilities for DecisionEnv-based parameter tuning."""
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Iterable, List, Sequence, Tuple
|
||||
|
||||
from app.backtest.decision_env import DecisionEnv, EpisodeMetrics
|
||||
from app.backtest.decision_env import ParameterSpec
|
||||
from app.utils.logging import get_logger
|
||||
from app.utils.tuning import log_tuning_result
|
||||
|
||||
LOGGER = get_logger(__name__)
|
||||
LOG_EXTRA = {"stage": "decision_bandit"}
|
||||
|
||||
|
||||
@dataclass
class BanditConfig:
    """Configuration for epsilon-greedy bandit optimization."""

    experiment_id: str
    strategy: str = "epsilon_greedy"
    episodes: int = 20
    epsilon: float = 0.2
    seed: int | None = None


@dataclass
class BanditEpisode:
    action: Dict[str, float]
    reward: float
    metrics: EpisodeMetrics
    observation: Dict[str, float]


@dataclass
class BanditSummary:
    episodes: List[BanditEpisode] = field(default_factory=list)

    @property
    def best_episode(self) -> BanditEpisode | None:
        if not self.episodes:
            return None
        return max(self.episodes, key=lambda item: item.reward)

    @property
    def average_reward(self) -> float:
        if not self.episodes:
            return 0.0
        return sum(item.reward for item in self.episodes) / len(self.episodes)

class EpsilonGreedyBandit:
    """Simple epsilon-greedy tuner using DecisionEnv as the reward oracle."""

    def __init__(self, env: DecisionEnv, config: BanditConfig) -> None:
        self.env = env
        self.config = config
        self._random = random.Random(config.seed)
        self._specs: List[ParameterSpec] = list(getattr(env, "_specs", []))
        if not self._specs:
            raise ValueError("DecisionEnv does not expose parameter specs")
        self._value_estimates: Dict[Tuple[float, ...], float] = {}
        self._counts: Dict[Tuple[float, ...], int] = {}
        self._history = BanditSummary()

    def run(self) -> BanditSummary:
        for episode in range(1, self.config.episodes + 1):
            action = self._select_action()
            self.env.reset()
            obs, reward, done, info = self.env.step(action)
            metrics = self.env.last_metrics
            if metrics is None:
                raise RuntimeError("DecisionEnv did not populate last_metrics")
            key = tuple(action)
            old_estimate = self._value_estimates.get(key, 0.0)
            count = self._counts.get(key, 0) + 1
            self._counts[key] = count
            self._value_estimates[key] = old_estimate + (reward - old_estimate) / count

            action_payload = self._action_to_mapping(action)
            metrics_payload = _metrics_to_dict(metrics)
            try:
                log_tuning_result(
                    experiment_id=self.config.experiment_id,
                    strategy=self.config.strategy,
                    action=action_payload,
                    reward=reward,
                    metrics=metrics_payload,
                    weights=info.get("weights"),
                )
            except Exception:  # noqa: BLE001
                LOGGER.exception("failed to log tuning result", extra=LOG_EXTRA)

            episode_record = BanditEpisode(
                action=action_payload,
                reward=reward,
                metrics=metrics,
                observation=obs,
            )
            self._history.episodes.append(episode_record)
            LOGGER.info(
                "Bandit episode=%s reward=%.4f action=%s",
                episode,
                reward,
                action_payload,
                extra=LOG_EXTRA,
            )
        return self._history

    def _select_action(self) -> List[float]:
        if self._value_estimates and self._random.random() > self.config.epsilon:
            best = max(self._value_estimates.items(), key=lambda item: item[1])[0]
            return list(best)
        return [
            self._random.uniform(spec.minimum, spec.maximum)
            for spec in self._specs
        ]

    def _action_to_mapping(self, action: Sequence[float]) -> Dict[str, float]:
        return {
            spec.name: float(value)
            for spec, value in zip(self._specs, action, strict=True)
        }

def _metrics_to_dict(metrics: EpisodeMetrics) -> Dict[str, float | Dict[str, int]]:
    payload: Dict[str, float | Dict[str, int]] = {
        "total_return": metrics.total_return,
        "max_drawdown": metrics.max_drawdown,
        "volatility": metrics.volatility,
        "sharpe_like": metrics.sharpe_like,
        "turnover": metrics.turnover,
        "trade_count": float(metrics.trade_count),
        "risk_count": float(metrics.risk_count),
    }
    if metrics.risk_breakdown:
        payload["risk_breakdown"] = dict(metrics.risk_breakdown)
    return payload
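Note: a minimal driver for the optimizer, assuming `env` is a DecisionEnv already built with parameter specs and baseline weights; the experiment id here is illustrative, and scripts/run_bandit_optimization.py below does the full wiring:

    from app.backtest.optimizer import BanditConfig, EpsilonGreedyBandit

    config = BanditConfig(experiment_id="exp_demo", episodes=10, epsilon=0.2, seed=7)
    summary = EpsilonGreedyBandit(env, config).run()
    best = summary.best_episode
    if best is not None:
        # Highest-reward parameter assignment observed across the 10 episodes.
        print(best.reward, best.action)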
@@ -327,6 +327,18 @@ SCHEMA_STATEMENTS: Iterable[str] = (
    );
    """,
    """
    CREATE TABLE IF NOT EXISTS bt_risk_events (
        cfg_id TEXT,
        trade_date TEXT,
        ts_code TEXT,
        reason TEXT,
        action TEXT,
        target_weight REAL,
        confidence REAL,
        metadata TEXT
    );
    """,
    """
    CREATE TABLE IF NOT EXISTS bt_nav (
        cfg_id TEXT,
        trade_date TEXT,
@@ -472,6 +484,7 @@ REQUIRED_TABLES = (
    "heat_daily",
    "bt_config",
    "bt_trades",
    "bt_risk_events",
    "bt_nav",
    "bt_report",
    "run_log",
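Note: with the bt_risk_events table in place, the persisted breakdown can be recomputed directly in SQL. A sketch, assuming the db_session helper used elsewhere in this commit and a hypothetical cfg_id:

    from app.utils.db import db_session

    with db_session(read_only=True) as conn:
        rows = conn.execute(
            "SELECT reason, COUNT(*) AS n FROM bt_risk_events "
            "WHERE cfg_id = ? GROUP BY reason ORDER BY n DESC",
            ("risk_cfg",),  # hypothetical config id
        ).fetchall()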
@@ -118,13 +118,12 @@ class DataBroker:
        if cached is not None:
            return deepcopy(cached)

-        grouped: Dict[str, List[str]] = {}
-        field_map: Dict[Tuple[str, str], List[str]] = {}
+        grouped: Dict[str, List[Tuple[str, str]]] = {}
        derived_cache: Dict[str, Any] = {}
        results: Dict[str, Any] = {}
        for field_name in field_list:
-            resolved = self.resolve_field(field_name)
-            if not resolved:
+            parsed = parse_field_path(field_name)
+            if not parsed:
                derived = self._resolve_derived_field(
                    ts_code,
                    trade_date,
@@ -134,11 +133,8 @@
                if derived is not None:
                    results[field_name] = derived
                continue
-            table, column = resolved
-            grouped.setdefault(table, [])
-            if column not in grouped[table]:
-                grouped[table].append(column)
-            field_map.setdefault((table, column), []).append(field_name)
+            table, column = parsed
+            grouped.setdefault(table, []).append((column, field_name))

        if not grouped:
            if cache_key is not None and results:
@@ -152,10 +148,9 @@

        try:
            with db_session(read_only=True) as conn:
-                for table, columns in grouped.items():
-                    joined_cols = ", ".join(columns)
+                for table, items in grouped.items():
                    query = (
-                        f"SELECT trade_date, {joined_cols} FROM {table} "
+                        f"SELECT * FROM {table} "
                        "WHERE ts_code = ? AND trade_date <= ? "
                        "ORDER BY trade_date DESC LIMIT 1"
                    )
@@ -165,18 +160,21 @@
                        LOGGER.debug(
                            "query failed table=%s fields=%s err=%s",
                            table,
-                            columns,
+                            [column for column, _field in items],
                            exc,
                            extra=LOG_EXTRA,
                        )
                        continue
                    if not row:
                        continue
-                    for column in columns:
-                        value = row[column]
+                    available = row.keys()
+                    for column, original in items:
+                        resolved_column = self._resolve_column_in_row(table, column, available)
+                        if resolved_column is None:
+                            continue
+                        value = row[resolved_column]
                        if value is None:
                            continue
-                        for original in field_map.get((table, column), [f"{table}.{column}"]):
                        try:
                            results[original] = float(value)
                        except (TypeError, ValueError):
@@ -698,6 +696,22 @@ class DataBroker:
        while len(cache) > limit:
            cache.popitem(last=False)

    def _resolve_column_in_row(
        self,
        table: str,
        column: str,
        available: Sequence[str],
    ) -> Optional[str]:
        alias_map = self.FIELD_ALIASES.get(table, {})
        candidate = alias_map.get(column, column)
        if candidate in available:
            return candidate
        lowered = candidate.lower()
        for name in available:
            if name.lower() == lowered:
                return name
        return None

    def _resolve_column(self, table: str, column: str) -> Optional[str]:
        columns = self._get_table_columns(table)
        if columns is None:
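Note: _resolve_column_in_row tolerates alias and case differences between the requested column and whatever the `SELECT *` row actually exposes. The lookup logic in isolation, as a standalone sketch rather than the class method itself:

    from typing import Optional, Sequence

    def resolve_column(alias_map: dict, column: str, available: Sequence[str]) -> Optional[str]:
        candidate = alias_map.get(column, column)  # prefer a configured alias
        if candidate in available:
            return candidate  # exact match
        lowered = candidate.lower()
        # fall back to a case-insensitive scan of the row's columns
        return next((name for name in available if name.lower() == lowered), None)

    assert resolve_column({"vol": "volume"}, "vol", ["VOLUME", "close"]) == "VOLUME"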
@@ -2,7 +2,8 @@
from __future__ import annotations

import json
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Mapping, Optional

from .db import db_session
from .logging import get_logger
@@ -40,3 +41,96 @@ def log_tuning_result(
        )
    except Exception:  # noqa: BLE001
        LOGGER.exception("failed to log tuning result", extra=LOG_EXTRA)

def select_best_tuning_result(
    experiment_id: str,
    *,
    metric: str = "reward",
    descending: bool = True,
    require_weights: bool = False,
) -> Optional[Dict[str, Any]]:
    """Return the best tuning result for the given experiment.

    ``metric`` may refer to ``reward`` (default) or any key inside the
    persisted metrics payload. When ``require_weights`` is True, rows lacking
    weight definitions are ignored.
    """

    with db_session(read_only=True) as conn:
        rows = conn.execute(
            """
            SELECT id, action, weights, reward, metrics, created_at
            FROM tuning_results
            WHERE experiment_id = ?
            """,
            (experiment_id,),
        ).fetchall()

    if not rows:
        return None

    best_row: Optional[Mapping[str, Any]] = None
    best_metrics: Dict[str, Any] = {}
    best_action: Dict[str, float] = {}
    best_weights: Dict[str, float] = {}
    best_score: Optional[float] = None

    for row in rows:
        action = _decode_json(row["action"])
        weights = _decode_json(row["weights"])
        metrics_payload = _decode_json(row["metrics"])
        reward_value = float(row["reward"] or 0.0)

        if require_weights and not weights:
            continue

        if metric == "reward":
            score = reward_value
        else:
            score_raw = metrics_payload.get(metric)
            if score_raw is None:
                continue
            try:
                score = float(score_raw)
            except (TypeError, ValueError):
                continue

        if best_score is None:
            choose = True
        else:
            choose = score > best_score if descending else score < best_score

        if choose:
            best_score = score
            best_row = row
            best_metrics = metrics_payload
            best_action = action
            best_weights = weights

    if best_row is None:
        return None

    return {
        "id": best_row["id"],
        "reward": float(best_row["reward"] or 0.0),
        "score": best_score,
        "metric": metric,
        "action": best_action,
        "weights": best_weights,
        "metrics": best_metrics,
        "created_at": best_row["created_at"],
    }

def _decode_json(payload: Any) -> Dict[str, Any]:
    if not payload:
        return {}
    if isinstance(payload, Mapping):
        return dict(payload)
    if isinstance(payload, str):
        try:
            return json.loads(payload)
        except json.JSONDecodeError:
            return {}
    return {}
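Note: a short usage sketch for the selector; the experiment id and metric are illustrative and mirror the tests added later in this commit:

    from app.utils.tuning import select_best_tuning_result

    # Lowest risk_count wins when sorting ascending; rows without weights are skipped.
    best = select_best_tuning_result(
        "exp_metric", metric="risk_count", descending=False, require_weights=True
    )
    if best is not None:
        print(best["score"], best["weights"])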
@@ -13,6 +13,7 @@

## 2. Data and feature layer
- Implement `compute_factors()` in `app/features/factors.py` to complete the factor computation and persistence pipeline.
- Change the DataBroker `fetch_latest` query to read whole rows and pick values as needed, avoiding exceptions from missing columns; subsequent data-access logic should follow this convention.
- Finish the RSS fetching and write-to-DB logic in `app/ingest/rss.py` to connect the news and sentiment data sources.
- Strengthen DataBroker's fetch validation, caching, and fallback strategies so market-data and feature backfilling stays fully automated with minimal manual intervention.
- Expand a lightweight, high-quality factor set around core signals such as momentum, valuation, and liquidity, all generated programmatically to support end-to-end automated decisions.
scripts/apply_best_weights.py (new file, 83 lines)
@@ -0,0 +1,83 @@
"""Apply or display the best tuning result for an experiment."""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
from app.utils.config import get_config, save_config
|
||||
from app.utils.tuning import select_best_tuning_result
|
||||
from app.utils.logging import get_logger
|
||||
|
||||
LOGGER = get_logger(__name__)
|
||||
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description="Apply best tuning weights")
|
||||
parser.add_argument("experiment_id", help="Experiment identifier")
|
||||
parser.add_argument(
|
||||
"--metric",
|
||||
default="reward",
|
||||
help="Metric name for ranking (default: reward)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ascending",
|
||||
action="store_true",
|
||||
help="Sort metric ascending instead of descending",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--require-weights",
|
||||
action="store_true",
|
||||
help="Ignore records without weight payload",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--apply-config",
|
||||
action="store_true",
|
||||
help="Update agent_weights in config with best result weights (fallback to action)",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def run_cli(argv: Iterable[str] | None = None) -> int:
|
||||
parser = build_parser()
|
||||
args = parser.parse_args(list(argv) if argv is not None else None)
|
||||
|
||||
best = select_best_tuning_result(
|
||||
args.experiment_id,
|
||||
metric=args.metric,
|
||||
descending=not args.ascending,
|
||||
require_weights=args.require_weights,
|
||||
)
|
||||
if not best:
|
||||
LOGGER.error("未找到实验结果 experiment_id=%s", args.experiment_id)
|
||||
return 1
|
||||
|
||||
print(json.dumps(best, ensure_ascii=False, indent=2))
|
||||
|
||||
if args.apply_config:
|
||||
weights = best.get("weights") or best.get("action")
|
||||
if not weights:
|
||||
LOGGER.error("最佳结果缺少权重信息,无法更新配置")
|
||||
return 2
|
||||
cfg = get_config()
|
||||
if not cfg.agent_weights:
|
||||
LOGGER.warning("配置缺少 agent_weights,初始化默认值")
|
||||
cfg.agent_weights.update_from_dict(weights)
|
||||
save_config(cfg)
|
||||
LOGGER.info("已写入新的 agent_weights 至配置")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def main() -> None:
|
||||
raise SystemExit(run_cli())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
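Note: an illustrative invocation of the new script; the experiment id is hypothetical:

    # Print the best row for exp_cli and copy its weights into agent_weights.
    python scripts/apply_best_weights.py exp_cli --require-weights --apply-config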
scripts/run_bandit_optimization.py (new file, 124 lines)
@@ -0,0 +1,124 @@
"""Run epsilon-greedy bandit tuning on DecisionEnv."""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from datetime import datetime, date
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
from app.agents.registry import default_agents
|
||||
from app.backtest.decision_env import DecisionEnv, ParameterSpec
|
||||
from app.backtest.engine import BtConfig
|
||||
from app.backtest.optimizer import BanditConfig, EpsilonGreedyBandit
|
||||
from app.utils.config import get_config
|
||||
|
||||
|
||||
def _parse_date(value: str) -> date:
|
||||
return datetime.strptime(value, "%Y%m%d").date()
|
||||
|
||||
|
||||
def _parse_param(text: str) -> ParameterSpec:
|
||||
parts = text.split(":")
|
||||
if len(parts) not in {3, 4}:
|
||||
raise argparse.ArgumentTypeError(
|
||||
"parameter format must be name:target:min[:max]"
|
||||
)
|
||||
name, target, minimum = parts[:3]
|
||||
maximum = parts[3] if len(parts) == 4 else "1.0"
|
||||
return ParameterSpec(
|
||||
name=name,
|
||||
target=target,
|
||||
minimum=float(minimum),
|
||||
maximum=float(maximum),
|
||||
)
|
||||
|
||||
|
||||
def _resolve_baseline_weights() -> dict:
|
||||
cfg = get_config()
|
||||
if cfg.agent_weights:
|
||||
return cfg.agent_weights.as_dict()
|
||||
return {agent.name: 1.0 for agent in default_agents()}
|
||||
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="DecisionEnv bandit optimizer")
    parser.add_argument("experiment_id", help="Experiment identifier to log results")
    parser.add_argument("name", help="Backtest config name")
    parser.add_argument("start", type=_parse_date, help="Start date YYYYMMDD")
    parser.add_argument("end", type=_parse_date, help="End date YYYYMMDD")
    parser.add_argument(
        "--universe",
        required=True,
        help="Comma separated ts_codes, e.g. 000001.SZ,000002.SZ",
    )
    parser.add_argument(
        "--param",
        action="append",
        required=True,
        help="Parameter spec name:target:min[:max] (target like agent_weights.A_mom)",
    )
    parser.add_argument("--episodes", type=int, default=20)
    parser.add_argument("--epsilon", type=float, default=0.2)
    parser.add_argument("--seed", type=int, default=None)
    return parser


def run_cli(argv: Iterable[str] | None = None) -> int:
    parser = build_parser()
    args = parser.parse_args(list(argv) if argv is not None else None)

    if args.end < args.start:
        parser.error("end date must not precede start date")

    specs: List[ParameterSpec] = [_parse_param(item) for item in args.param]
    universe = [token.strip() for token in args.universe.split(",") if token.strip()]
    bt_cfg = BtConfig(
        id=args.experiment_id,
        name=args.name,
        start_date=args.start,
        end_date=args.end,
        universe=universe,
        params={},
    )

    env = DecisionEnv(
        bt_config=bt_cfg,
        parameter_specs=specs,
        baseline_weights=_resolve_baseline_weights(),
    )
    optimizer = EpsilonGreedyBandit(
        env,
        BanditConfig(
            experiment_id=args.experiment_id,
            episodes=args.episodes,
            epsilon=args.epsilon,
            seed=args.seed,
        ),
    )
    summary = optimizer.run()
    best = summary.best_episode
    output = {
        "episodes": len(summary.episodes),
        "average_reward": summary.average_reward,
        "best": {
            "reward": best.reward if best else None,
            "action": best.action if best else None,
            "metrics": (best.metrics and json.dumps(best.metrics.risk_breakdown)) if best else None,
        },
    }
    print(json.dumps(output, ensure_ascii=False, indent=2))
    return 0


def main() -> None:
    raise SystemExit(run_cli())


if __name__ == "__main__":
    main()
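Note: an illustrative run; the experiment id, universe, and bounds are hypothetical, and the --param spec follows the name:target:min[:max] format parsed above:

    python scripts/run_bandit_optimization.py exp_demo demo_cfg 20250101 20250131 \
        --universe 000001.SZ,000002.SZ \
        --param w_mom:agent_weights.A_mom:0.0:1.0 \
        --episodes 20 --epsilon 0.2 --seed 42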
tests/test_backtest_engine_factors.py (new file, 57 lines)
@@ -0,0 +1,57 @@
"""Verify BacktestEngine consumes persisted factor fields."""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
|
||||
import pytest
|
||||
|
||||
from app.backtest.engine import BacktestEngine, BtConfig
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def engine(monkeypatch):
|
||||
cfg = BtConfig(
|
||||
id="test",
|
||||
name="factor",
|
||||
start_date=date(2025, 1, 10),
|
||||
end_date=date(2025, 1, 10),
|
||||
universe=["000001.SZ"],
|
||||
params={},
|
||||
)
|
||||
engine = BacktestEngine(cfg)
|
||||
|
||||
def fake_fetch_latest(ts_code, trade_date, fields): # noqa: D401
|
||||
assert "factors.mom_20" in fields
|
||||
return {
|
||||
"daily.close": 10.0,
|
||||
"daily.pct_chg": 0.02,
|
||||
"daily_basic.turnover_rate": 5.0,
|
||||
"daily_basic.volume_ratio": 15.0,
|
||||
"factors.mom_20": 0.12,
|
||||
"factors.mom_60": 0.25,
|
||||
"factors.volat_20": 0.05,
|
||||
"factors.turn_20": 3.0,
|
||||
"news.sentiment_index": 0.3,
|
||||
"news.heat_score": 0.4,
|
||||
"macro.industry_heat": 0.6,
|
||||
"macro.relative_strength": 0.7,
|
||||
}
|
||||
|
||||
monkeypatch.setattr(engine.data_broker, "fetch_latest", fake_fetch_latest)
|
||||
monkeypatch.setattr(engine.data_broker, "fetch_series", lambda *args, **kwargs: [])
|
||||
monkeypatch.setattr(engine.data_broker, "fetch_flags", lambda *args, **kwargs: False)
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
def test_load_market_data_prefers_factors(engine):
|
||||
data = engine.load_market_data(date(2025, 1, 10))
|
||||
record = data["000001.SZ"]
|
||||
features = record["features"]
|
||||
assert features["mom_20"] == pytest.approx(0.12)
|
||||
assert features["mom_60"] == pytest.approx(0.25)
|
||||
assert features["volat_20"] == pytest.approx(0.05)
|
||||
assert features["turn_20"] == pytest.approx(3.0)
|
||||
assert features["news_sentiment"] == pytest.approx(0.3)
|
||||
assert features["news_heat"] == pytest.approx(0.4)
|
||||
assert features["risk_penalty"] == pytest.approx(min(1.0, 0.05 * 5.0))
|
||||
@@ -5,9 +5,20 @@ from datetime import date

import pytest

import json

from app.agents.base import AgentAction, AgentContext
from app.agents.game import Decision
-from app.backtest.engine import BacktestEngine, BacktestResult, BtConfig, PortfolioState
+from app.backtest.engine import (
+    BacktestEngine,
+    BacktestResult,
+    BtConfig,
+    PortfolioState,
+    _persist_backtest_results,
+)
from app.data.schema import initialize_database
from app.utils.config import DataPaths, get_config
from app.utils.db import db_session


def _make_context(price: float, features: dict | None = None) -> AgentContext:
@@ -43,6 +54,20 @@ def _engine_with_params(params: dict[str, float]) -> BacktestEngine:
    return BacktestEngine(cfg)


@pytest.fixture()
def isolated_db(tmp_path):
    cfg = get_config()
    original_paths = cfg.data_paths
    tmp_root = tmp_path / "data"
    tmp_root.mkdir(parents=True, exist_ok=True)
    cfg.data_paths = DataPaths(root=tmp_root)
    initialize_database()
    try:
        yield
    finally:
        cfg.data_paths = original_paths

def test_buy_respects_risk_caps():
    engine = _engine_with_params(
        {
@@ -130,3 +155,56 @@ def test_sell_applies_slippage_and_fee():
    assert not state.holdings
    assert result.nav_series[0]["turnover"] == pytest.approx(trade["value"])
    assert not result.risk_events


def test_persist_backtest_results_saves_risk_events(isolated_db):
    cfg = BtConfig(
        id="risk_cfg",
        name="risk",
        start_date=date(2025, 1, 10),
        end_date=date(2025, 1, 10),
        universe=["000001.SZ"],
        params={},
    )
    result = BacktestResult()
    result.nav_series = [
        {
            "trade_date": "2025-01-10",
            "nav": 100.0,
            "cash": 100.0,
            "market_value": 0.0,
            "realized_pnl": 0.0,
            "unrealized_pnl": 0.0,
            "turnover": 0.0,
        }
    ]
    result.risk_events = [
        {
            "trade_date": "2025-01-10",
            "ts_code": "000001.SZ",
            "reason": "limit_up",
            "action": "buy_l",
            "target_weight": 0.3,
            "confidence": 0.8,
        }
    ]

    _persist_backtest_results(cfg, result)

    with db_session(read_only=True) as conn:
        risk_row = conn.execute(
            "SELECT reason, metadata FROM bt_risk_events WHERE cfg_id = ?",
            (cfg.id,),
        ).fetchone()
        assert risk_row is not None
        assert risk_row["reason"] == "limit_up"
        metadata = json.loads(risk_row["metadata"])
        assert metadata["action"] == "buy_l"

        summary_row = conn.execute(
            "SELECT summary FROM bt_report WHERE cfg_id = ?",
            (cfg.id,),
        ).fetchone()
        summary = json.loads(summary_row["summary"])
        assert summary["risk_events"] == 1
        assert summary["risk_breakdown"]["limit_up"] == 1
tests/test_bandit_optimizer.py (new file, 92 lines)
@@ -0,0 +1,92 @@
"""Tests for epsilon-greedy bandit optimizer."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.backtest.decision_env import EpisodeMetrics, ParameterSpec
|
||||
from app.backtest.optimizer import BanditConfig, EpsilonGreedyBandit
|
||||
from app.utils import tuning
|
||||
|
||||
|
||||
class DummyEnv:
|
||||
def __init__(self) -> None:
|
||||
self._specs = [
|
||||
ParameterSpec(name="w1", target="agent_weights.A_mom", minimum=0.0, maximum=1.0)
|
||||
]
|
||||
self._last_metrics: EpisodeMetrics | None = None
|
||||
self._episode = 0
|
||||
|
||||
@property
|
||||
def action_dim(self) -> int:
|
||||
return 1
|
||||
|
||||
@property
|
||||
def last_metrics(self) -> EpisodeMetrics | None:
|
||||
return self._last_metrics
|
||||
|
||||
def reset(self) -> dict:
|
||||
self._episode += 1
|
||||
return {"episode": float(self._episode)}
|
||||
|
||||
def step(self, action):
|
||||
value = float(action[0])
|
||||
reward = 1.0 - abs(value - 0.7)
|
||||
metrics = EpisodeMetrics(
|
||||
total_return=reward,
|
||||
max_drawdown=0.1,
|
||||
volatility=0.05,
|
||||
nav_series=[],
|
||||
trades=[],
|
||||
turnover=100.0,
|
||||
trade_count=0,
|
||||
risk_count=1,
|
||||
risk_breakdown={"test": 1},
|
||||
)
|
||||
self._last_metrics = metrics
|
||||
obs = {
|
||||
"total_return": reward,
|
||||
"max_drawdown": 0.1,
|
||||
"volatility": 0.05,
|
||||
"sharpe_like": reward / 0.05,
|
||||
"turnover": 100.0,
|
||||
"trade_count": 0.0,
|
||||
"risk_count": 1.0,
|
||||
}
|
||||
info = {
|
||||
"nav_series": [],
|
||||
"trades": [],
|
||||
"weights": {"A_mom": value},
|
||||
"risk_breakdown": metrics.risk_breakdown,
|
||||
"risk_events": [],
|
||||
}
|
||||
return obs, reward, True, info
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def patch_logging(monkeypatch):
    records = []

    def fake_log_tuning_result(**kwargs):
        records.append(kwargs)

    monkeypatch.setattr(tuning, "log_tuning_result", fake_log_tuning_result)
    from app.backtest import optimizer as optimizer_module

    monkeypatch.setattr(optimizer_module, "log_tuning_result", fake_log_tuning_result)
    return records


def test_bandit_optimizer_runs_and_logs(patch_logging):
    env = DummyEnv()
    optimizer = EpsilonGreedyBandit(
        env,
        BanditConfig(experiment_id="exp", episodes=5, epsilon=0.5, seed=42),
    )
    summary = optimizer.run()

    assert len(summary.episodes) == 5
    assert summary.best_episode is not None
    assert patch_logging and len(patch_logging) == 5
    payload = patch_logging[0]["metrics"]
    assert isinstance(payload, dict)
    assert "risk_breakdown" in payload
tests/test_decision_env.py (new file, 92 lines)
@@ -0,0 +1,92 @@
"""Tests for DecisionEnv risk-aware reward and info outputs."""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
|
||||
import pytest
|
||||
|
||||
from app.backtest.decision_env import DecisionEnv, EpisodeMetrics, ParameterSpec
|
||||
from app.backtest.engine import BacktestResult, BtConfig
|
||||
|
||||
|
||||
class _StubEngine:
|
||||
def __init__(self, cfg: BtConfig) -> None: # noqa: D401
|
||||
self.cfg = cfg
|
||||
self.weights = {}
|
||||
self.department_manager = None
|
||||
|
||||
def run(self) -> BacktestResult:
|
||||
result = BacktestResult()
|
||||
result.nav_series = [
|
||||
{
|
||||
"trade_date": "2025-01-10",
|
||||
"nav": 102.0,
|
||||
"cash": 50.0,
|
||||
"market_value": 52.0,
|
||||
"realized_pnl": 1.0,
|
||||
"unrealized_pnl": 1.0,
|
||||
"turnover": 20000.0,
|
||||
}
|
||||
]
|
||||
result.trades = [
|
||||
{
|
||||
"trade_date": "2025-01-10",
|
||||
"ts_code": "000001.SZ",
|
||||
"action": "buy",
|
||||
"quantity": 100.0,
|
||||
"price": 100.0,
|
||||
"value": 10000.0,
|
||||
"fee": 5.0,
|
||||
}
|
||||
]
|
||||
result.risk_events = [
|
||||
{
|
||||
"trade_date": "2025-01-10",
|
||||
"ts_code": "000002.SZ",
|
||||
"reason": "limit_up",
|
||||
"action": "buy_l",
|
||||
"confidence": 0.7,
|
||||
"target_weight": 0.2,
|
||||
}
|
||||
]
|
||||
return result
|
||||
|
||||
|
||||
def test_decision_env_returns_risk_metrics(monkeypatch):
    cfg = BtConfig(
        id="stub",
        name="stub",
        start_date=date(2025, 1, 10),
        end_date=date(2025, 1, 10),
        universe=["000001.SZ"],
        params={},
    )
    specs = [ParameterSpec(name="w_mom", target="agent_weights.A_mom", minimum=0.0, maximum=1.0)]
    env = DecisionEnv(bt_config=cfg, parameter_specs=specs, baseline_weights={"A_mom": 0.5})

    monkeypatch.setattr("app.backtest.decision_env.BacktestEngine", _StubEngine)

    obs, reward, done, info = env.step([0.8])

    assert done is True
    assert "risk_count" in obs and obs["risk_count"] == 1.0
    assert obs["turnover"] == pytest.approx(20000.0)
    assert info["risk_events"][0]["reason"] == "limit_up"
    assert info["risk_breakdown"]["limit_up"] == 1
    assert reward < obs["total_return"]


def test_default_reward_penalizes_metrics():
    metrics = EpisodeMetrics(
        total_return=0.1,
        max_drawdown=0.2,
        volatility=0.05,
        nav_series=[],
        trades=[],
        turnover=1000.0,
        trade_count=0,
        risk_count=2,
        risk_breakdown={"foo": 2},
    )
    reward = DecisionEnv._default_reward(metrics)
    assert reward == pytest.approx(0.1 - (0.5 * 0.2 + 0.05 * 2 + 0.00001 * 1000.0))
tests/test_tuning_utils.py (new file, 81 lines)
@@ -0,0 +1,81 @@
"""Tests for tuning result selection and CLI application."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from app.data.schema import initialize_database
|
||||
from app.utils.config import DataPaths, get_config
|
||||
from app.utils.db import db_session
|
||||
from app.utils.tuning import select_best_tuning_result
|
||||
|
||||
import scripts.apply_best_weights as apply_best_weights
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def isolated_env(tmp_path):
|
||||
cfg = get_config()
|
||||
original_paths = cfg.data_paths
|
||||
tmp_root = tmp_path / "data"
|
||||
tmp_root.mkdir(parents=True, exist_ok=True)
|
||||
cfg.data_paths = DataPaths(root=tmp_root)
|
||||
initialize_database()
|
||||
try:
|
||||
yield cfg
|
||||
finally:
|
||||
cfg.data_paths = original_paths
|
||||
|
||||
|
||||
def _insert_result(experiment: str, reward: float, metrics: dict, weights: dict | None = None, action: dict | None = None) -> None:
|
||||
with db_session() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO tuning_results (experiment_id, strategy, action, weights, reward, metrics)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
experiment,
|
||||
"test",
|
||||
json.dumps(action or {}, ensure_ascii=False),
|
||||
json.dumps(weights or {}, ensure_ascii=False),
|
||||
reward,
|
||||
json.dumps(metrics, ensure_ascii=False),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_select_best_by_reward(isolated_env):
    _insert_result("exp", 0.1, {"risk_count": 2}, {"A_mom": 0.3})
    _insert_result("exp", 0.25, {"risk_count": 4}, {"A_mom": 0.6})

    best = select_best_tuning_result("exp")
    assert best is not None
    assert best["reward"] == pytest.approx(0.25)
    assert best["weights"]["A_mom"] == pytest.approx(0.6)


def test_select_best_by_metric(isolated_env):
    _insert_result("exp_metric", 0.2, {"risk_count": 5}, {"A_mom": 0.4})
    _insert_result("exp_metric", 0.1, {"risk_count": 2}, {"A_mom": 0.7})

    best = select_best_tuning_result("exp_metric", metric="risk_count", descending=False)
    assert best is not None
    assert best["weights"]["A_mom"] == pytest.approx(0.7)
    assert best["metrics"]["risk_count"] == 2


def test_apply_best_weights_cli_updates_config(isolated_env, capsys):
    cfg = isolated_env
    _insert_result("exp_cli", 0.3, {"risk_count": 1}, {"A_mom": 0.65, "A_val": 0.2})
    exit_code = apply_best_weights.run_cli([
        "exp_cli",
        "--apply-config",
    ])
    assert exit_code == 0
    output = capsys.readouterr().out
    payload = json.loads(output)
    assert payload["metric"] == "reward"
    updated = cfg.agent_weights.as_dict()
    assert updated["A_mom"] == pytest.approx(0.65)
    assert updated["A_val"] == pytest.approx(0.2)