From 07e5bb1b6868d7ef575bb37084a83c4a63d288dd Mon Sep 17 00:00:00 2001 From: sam Date: Tue, 30 Sep 2025 18:34:29 +0800 Subject: [PATCH] update --- app/backtest/decision_env.py | 40 +++++++- app/backtest/engine.py | 77 ++++++++++++-- app/backtest/optimizer.py | 139 ++++++++++++++++++++++++++ app/data/schema.py | 13 +++ app/utils/data_access.py | 54 ++++++---- app/utils/tuning.py | 96 +++++++++++++++++- docs/TODO.md | 1 + scripts/apply_best_weights.py | 83 +++++++++++++++ scripts/run_bandit_optimization.py | 124 +++++++++++++++++++++++ tests/test_backtest_engine_factors.py | 57 +++++++++++ tests/test_backtest_engine_risk.py | 80 ++++++++++++++- tests/test_bandit_optimizer.py | 92 +++++++++++++++++ tests/test_decision_env.py | 92 +++++++++++++++++ tests/test_tuning_utils.py | 81 +++++++++++++++ 14 files changed, 995 insertions(+), 34 deletions(-) create mode 100644 app/backtest/optimizer.py create mode 100644 scripts/apply_best_weights.py create mode 100644 scripts/run_bandit_optimization.py create mode 100644 tests/test_backtest_engine_factors.py create mode 100644 tests/test_bandit_optimizer.py create mode 100644 tests/test_decision_env.py create mode 100644 tests/test_tuning_utils.py diff --git a/app/backtest/decision_env.py b/app/backtest/decision_env.py index 6577e3a..394838b 100644 --- a/app/backtest/decision_env.py +++ b/app/backtest/decision_env.py @@ -36,6 +36,10 @@ class EpisodeMetrics: volatility: float nav_series: List[Dict[str, float]] trades: List[Dict[str, object]] + turnover: float + trade_count: int + risk_count: int + risk_breakdown: Dict[str, int] @property def sharpe_like(self) -> float: @@ -109,11 +113,16 @@ class DecisionEnv: "max_drawdown": metrics.max_drawdown, "volatility": metrics.volatility, "sharpe_like": metrics.sharpe_like, + "turnover": metrics.turnover, + "trade_count": float(metrics.trade_count), + "risk_count": float(metrics.risk_count), } info = { "nav_series": metrics.nav_series, "trades": metrics.trades, "weights": weights, + "risk_breakdown": metrics.risk_breakdown, + "risk_events": getattr(result, "risk_events", []), } return observation, reward, True, info @@ -131,7 +140,21 @@ class DecisionEnv: def _compute_metrics(self, result: BacktestResult) -> EpisodeMetrics: nav_series = result.nav_series or [] if not nav_series: - return EpisodeMetrics(0.0, 0.0, 0.0, [], result.trades) + risk_breakdown: Dict[str, int] = {} + for event in getattr(result, "risk_events", []) or []: + reason = str(event.get("reason") or "unknown") + risk_breakdown[reason] = risk_breakdown.get(reason, 0) + 1 + return EpisodeMetrics( + total_return=0.0, + max_drawdown=0.0, + volatility=0.0, + nav_series=[], + trades=result.trades, + turnover=0.0, + trade_count=len(result.trades or []), + risk_count=len(getattr(result, "risk_events", []) or []), + risk_breakdown=risk_breakdown, + ) nav_values = [row.get("nav", 0.0) for row in nav_series] if not nav_values or nav_values[0] == 0: @@ -158,17 +181,30 @@ class DecisionEnv: else: volatility = 0.0 + turnover = sum(float(row.get("turnover", 0.0) or 0.0) for row in nav_series) + risk_events = getattr(result, "risk_events", []) or [] + risk_breakdown: Dict[str, int] = {} + for event in risk_events: + reason = str(event.get("reason") or "unknown") + risk_breakdown[reason] = risk_breakdown.get(reason, 0) + 1 + return EpisodeMetrics( total_return=float(total_return), max_drawdown=float(max_drawdown), volatility=volatility, nav_series=nav_series, trades=result.trades, + turnover=float(turnover), + trade_count=len(result.trades or []), + 
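
# A minimal sketch of the per-reason tally that risk_breakdown carries here,
# assuming events are dicts with an optional "reason" key as in this patch;
# collections.Counter is equivalent to the hand-rolled dict loop above
# (tally_risk_reasons is an illustrative helper, not part of the patch).
from collections import Counter
from typing import Dict, List

def tally_risk_reasons(risk_events: List[dict]) -> Dict[str, int]:
    """Count risk events per reason; missing reasons collapse to 'unknown'."""
    return dict(
        Counter(str(event.get("reason") or "unknown") for event in risk_events)
    )

# tally_risk_reasons([{"reason": "limit_up"}, {}]) == {"limit_up": 1, "unknown": 1}
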
risk_count=len(risk_events), + risk_breakdown=risk_breakdown, ) @staticmethod def _default_reward(metrics: EpisodeMetrics) -> float: - penalty = 0.5 * metrics.max_drawdown + risk_penalty = 0.05 * metrics.risk_count + turnover_penalty = 0.00001 * metrics.turnover + penalty = 0.5 * metrics.max_drawdown + risk_penalty + turnover_penalty return metrics.total_return - penalty @property diff --git a/app/backtest/engine.py b/app/backtest/engine.py index b092ce9..0c4b4b9 100644 --- a/app/backtest/engine.py +++ b/app/backtest/engine.py @@ -98,6 +98,12 @@ class BacktestEngine: "daily_basic.volume_ratio", "stk_limit.up_limit", "stk_limit.down_limit", + "factors.mom_20", + "factors.mom_60", + "factors.volat_20", + "factors.turn_20", + "news.sentiment_index", + "news.heat_score", } self.required_fields = sorted(base_scope | department_scope) @@ -121,10 +127,19 @@ class BacktestEngine: trade_date_str, window=60, ) - close_values = [value for _date, value in closes] - mom20 = momentum(close_values, 20) - mom60 = momentum(close_values, 60) - volat20 = volatility(close_values, 20) + close_values = [value for _date, value in closes if value is not None] + + mom20 = scope_values.get("factors.mom_20") + if mom20 is None and len(close_values) >= 20: + mom20 = momentum(close_values, 20) + + mom60 = scope_values.get("factors.mom_60") + if mom60 is None and len(close_values) >= 60: + mom60 = momentum(close_values, 60) + + volat20 = scope_values.get("factors.volat_20") + if volat20 is None and len(close_values) >= 2: + volat20 = volatility(close_values, 20) turnover_series = self.data_broker.fetch_series( "daily_basic", @@ -133,8 +148,20 @@ class BacktestEngine: trade_date_str, window=20, ) - turnover_values = [value for _date, value in turnover_series] - turn20 = rolling_mean(turnover_values, 20) + turnover_values = [value for _date, value in turnover_series if value is not None] + + turn20 = scope_values.get("factors.turn_20") + if turn20 is None and turnover_values: + turn20 = rolling_mean(turnover_values, 20) + + if mom20 is None: + mom20 = 0.0 + if mom60 is None: + mom60 = 0.0 + if volat20 is None: + volat20 = 0.0 + if turn20 is None: + turn20 = 0.0 liquidity_score = normalize(turn20, factor=20.0) cost_penalty = normalize( @@ -142,12 +169,15 @@ class BacktestEngine: factor=50.0, ) + sentiment_index = scope_values.get("news.sentiment_index", 0.0) + heat_score = scope_values.get("news.heat_score", 0.0) + scope_values.setdefault("news.sentiment_index", sentiment_index) + scope_values.setdefault("news.heat_score", heat_score) + scope_values.setdefault("factors.mom_20", mom20) scope_values.setdefault("factors.mom_60", mom60) scope_values.setdefault("factors.volat_20", volat20) scope_values.setdefault("factors.turn_20", turn20) - scope_values.setdefault("news.sentiment_index", 0.0) - scope_values.setdefault("news.heat_score", 0.0) if scope_values.get("macro.industry_heat") is None: scope_values["macro.industry_heat"] = 0.5 if scope_values.get("macro.relative_strength") is None: @@ -189,8 +219,8 @@ class BacktestEngine: "turn_20": turn20, "liquidity_score": liquidity_score, "cost_penalty": cost_penalty, - "news_heat": scope_values.get("news.heat_score", 0.0), - "news_sentiment": scope_values.get("news.sentiment_index", 0.0), + "news_heat": heat_score, + "news_sentiment": sentiment_index, "industry_heat": scope_values.get("macro.industry_heat", 0.0), "industry_relative_mom": scope_values.get( "macro.relative_strength", @@ -818,6 +848,7 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None: 
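
# The factor lookups in load_market_data above follow a three-step
# precedence: persisted factors.* value, else recompute from price history
# when enough bars exist, else a neutral 0.0. A minimal sketch of that rule,
# assuming numeric scope values; resolve_factor, compute_fn and min_bars are
# illustrative stand-ins for the momentum/volatility/rolling_mean helpers.
from typing import Callable, Dict, Sequence

def resolve_factor(
    scope_values: Dict[str, float | None],
    key: str,
    closes: Sequence[float],
    min_bars: int,
    compute_fn: Callable[[Sequence[float]], float],
) -> float:
    value = scope_values.get(key)          # 1) prefer the persisted factor
    if value is None and len(closes) >= min_bars:
        value = compute_fn(closes)         # 2) fall back to recomputing
    return float(value) if value is not None else 0.0  # 3) neutral default

# e.g. resolve_factor(scope, "factors.mom_20", closes, 20, lambda c: momentum(c, 20))
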
nav_rows: List[tuple] = [] trade_rows: List[tuple] = [] + risk_rows: List[tuple] = [] summary_payload: Dict[str, object] = {} turnover_sum = 0.0 @@ -893,6 +924,10 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None: "confidence": trade.get("confidence"), "target_weight": trade.get("target_weight"), "value": trade.get("value"), + "fee": trade.get("fee"), + "slippage": trade.get("slippage"), + "risk_penalty": trade.get("risk_penalty"), + "liquidity_score": trade.get("liquidity_score"), } trade_rows.append( ( @@ -913,6 +948,18 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None: for event in result.risk_events: reason = str(event.get("reason") or "unknown") breakdown[reason] = breakdown.get(reason, 0) + 1 + risk_rows.append( + ( + cfg.id, + str(event.get("trade_date", "")), + str(event.get("ts_code", "")), + reason, + str(event.get("action", "")), + float(event.get("target_weight", 0.0) or 0.0), + float(event.get("confidence", 0.0) or 0.0), + json.dumps(event, ensure_ascii=False), + ) + ) summary_payload["risk_breakdown"] = breakdown cfg_payload = { @@ -943,6 +990,7 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None: conn.execute("DELETE FROM bt_nav WHERE cfg_id = ?", (cfg.id,)) conn.execute("DELETE FROM bt_trades WHERE cfg_id = ?", (cfg.id,)) + conn.execute("DELETE FROM bt_risk_events WHERE cfg_id = ?", (cfg.id,)) conn.execute("DELETE FROM bt_report WHERE cfg_id = ?", (cfg.id,)) if nav_rows: @@ -963,6 +1011,15 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None: trade_rows, ) + if risk_rows: + conn.executemany( + """ + INSERT INTO bt_risk_events (cfg_id, trade_date, ts_code, reason, action, target_weight, confidence, metadata) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) 
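
# Sketch of the bt_risk_events round trip in isolation: an in-memory DB with
# the table this patch adds in app/data/schema.py, one event inserted the
# same way as risk_rows here, then read back with its JSON metadata.
import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE bt_risk_events (cfg_id TEXT, trade_date TEXT, ts_code TEXT,"
    " reason TEXT, action TEXT, target_weight REAL, confidence REAL, metadata TEXT)"
)
event = {"trade_date": "2025-01-10", "ts_code": "000001.SZ", "reason": "limit_up"}
conn.execute(
    "INSERT INTO bt_risk_events VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
    ("cfg1", event["trade_date"], event["ts_code"], event["reason"], "buy_l",
     0.3, 0.8, json.dumps(event, ensure_ascii=False)),
)
row = conn.execute("SELECT reason, metadata FROM bt_risk_events").fetchone()
assert row == ("limit_up", json.dumps(event, ensure_ascii=False))
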
+ """, + risk_rows, + ) + summary_payload.setdefault("universe", cfg.universe) summary_payload.setdefault("method", cfg.method) conn.execute( diff --git a/app/backtest/optimizer.py b/app/backtest/optimizer.py new file mode 100644 index 0000000..8e630a9 --- /dev/null +++ b/app/backtest/optimizer.py @@ -0,0 +1,139 @@ +"""Optimization utilities for DecisionEnv-based parameter tuning.""" +from __future__ import annotations + +import random +from dataclasses import dataclass, field +from typing import Dict, Iterable, List, Sequence, Tuple + +from app.backtest.decision_env import DecisionEnv, EpisodeMetrics +from app.backtest.decision_env import ParameterSpec +from app.utils.logging import get_logger +from app.utils.tuning import log_tuning_result + +LOGGER = get_logger(__name__) +LOG_EXTRA = {"stage": "decision_bandit"} + + +@dataclass +class BanditConfig: + """Configuration for epsilon-greedy bandit optimization.""" + + experiment_id: str + strategy: str = "epsilon_greedy" + episodes: int = 20 + epsilon: float = 0.2 + seed: int | None = None + + +@dataclass +class BanditEpisode: + action: Dict[str, float] + reward: float + metrics: EpisodeMetrics + observation: Dict[str, float] + + +@dataclass +class BanditSummary: + episodes: List[BanditEpisode] = field(default_factory=list) + + @property + def best_episode(self) -> BanditEpisode | None: + if not self.episodes: + return None + return max(self.episodes, key=lambda item: item.reward) + + @property + def average_reward(self) -> float: + if not self.episodes: + return 0.0 + return sum(item.reward for item in self.episodes) / len(self.episodes) + + +class EpsilonGreedyBandit: + """Simple epsilon-greedy tuner using DecisionEnv as the reward oracle.""" + + def __init__(self, env: DecisionEnv, config: BanditConfig) -> None: + self.env = env + self.config = config + self._random = random.Random(config.seed) + self._specs: List[ParameterSpec] = list(getattr(env, "_specs", [])) + if not self._specs: + raise ValueError("DecisionEnv does not expose parameter specs") + self._value_estimates: Dict[Tuple[float, ...], float] = {} + self._counts: Dict[Tuple[float, ...], int] = {} + self._history = BanditSummary() + + def run(self) -> BanditSummary: + for episode in range(1, self.config.episodes + 1): + action = self._select_action() + self.env.reset() + obs, reward, done, info = self.env.step(action) + metrics = self.env.last_metrics + if metrics is None: + raise RuntimeError("DecisionEnv did not populate last_metrics") + key = tuple(action) + old_estimate = self._value_estimates.get(key, 0.0) + count = self._counts.get(key, 0) + 1 + self._counts[key] = count + self._value_estimates[key] = old_estimate + (reward - old_estimate) / count + + action_payload = self._action_to_mapping(action) + metrics_payload = _metrics_to_dict(metrics) + try: + log_tuning_result( + experiment_id=self.config.experiment_id, + strategy=self.config.strategy, + action=action_payload, + reward=reward, + metrics=metrics_payload, + weights=info.get("weights"), + ) + except Exception: # noqa: BLE001 + LOGGER.exception("failed to log tuning result", extra=LOG_EXTRA) + + episode_record = BanditEpisode( + action=action_payload, + reward=reward, + metrics=metrics, + observation=obs, + ) + self._history.episodes.append(episode_record) + LOGGER.info( + "Bandit episode=%s reward=%.4f action=%s", + episode, + reward, + action_payload, + extra=LOG_EXTRA, + ) + return self._history + + def _select_action(self) -> List[float]: + if self._value_estimates and self._random.random() > 
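
# _select_action is epsilon-greedy over a continuous box: with probability
# epsilon it samples uniformly inside each spec's [minimum, maximum],
# otherwise it replays the best exact action tuple seen so far. A compact
# standalone sketch of the same policy (select_action is illustrative;
# estimates maps action tuples to the running mean rewards kept in run()):
import random
from typing import Dict, List, Tuple

def select_action(
    estimates: Dict[Tuple[float, ...], float],
    bounds: List[Tuple[float, float]],
    epsilon: float,
    rng: random.Random,
) -> List[float]:
    if estimates and rng.random() > epsilon:
        best = max(estimates.items(), key=lambda item: item[1])[0]  # exploit
        return list(best)
    return [rng.uniform(lo, hi) for lo, hi in bounds]               # explore

# Because actions are keyed by exact tuples, exploitation can only replay
# points that exploration has already sampled; wider reuse would require
# discretizing the action space first.
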
self.config.epsilon: + best = max(self._value_estimates.items(), key=lambda item: item[1])[0] + return list(best) + return [ + self._random.uniform(spec.minimum, spec.maximum) + for spec in self._specs + ] + + def _action_to_mapping(self, action: Sequence[float]) -> Dict[str, float]: + return { + spec.name: float(value) + for spec, value in zip(self._specs, action, strict=True) + } + + +def _metrics_to_dict(metrics: EpisodeMetrics) -> Dict[str, float | Dict[str, int]]: + payload: Dict[str, float | Dict[str, int]] = { + "total_return": metrics.total_return, + "max_drawdown": metrics.max_drawdown, + "volatility": metrics.volatility, + "sharpe_like": metrics.sharpe_like, + "turnover": metrics.turnover, + "trade_count": float(metrics.trade_count), + "risk_count": float(metrics.risk_count), + } + if metrics.risk_breakdown: + payload["risk_breakdown"] = dict(metrics.risk_breakdown) + return payload diff --git a/app/data/schema.py b/app/data/schema.py index 282ab2b..4fa435e 100644 --- a/app/data/schema.py +++ b/app/data/schema.py @@ -327,6 +327,18 @@ SCHEMA_STATEMENTS: Iterable[str] = ( ); """, """ + CREATE TABLE IF NOT EXISTS bt_risk_events ( + cfg_id TEXT, + trade_date TEXT, + ts_code TEXT, + reason TEXT, + action TEXT, + target_weight REAL, + confidence REAL, + metadata TEXT + ); + """, + """ CREATE TABLE IF NOT EXISTS bt_nav ( cfg_id TEXT, trade_date TEXT, @@ -472,6 +484,7 @@ REQUIRED_TABLES = ( "heat_daily", "bt_config", "bt_trades", + "bt_risk_events", "bt_nav", "bt_report", "run_log", diff --git a/app/utils/data_access.py b/app/utils/data_access.py index f52546f..a7ad04f 100644 --- a/app/utils/data_access.py +++ b/app/utils/data_access.py @@ -118,13 +118,12 @@ class DataBroker: if cached is not None: return deepcopy(cached) - grouped: Dict[str, List[str]] = {} - field_map: Dict[Tuple[str, str], List[str]] = {} + grouped: Dict[str, List[Tuple[str, str]]] = {} derived_cache: Dict[str, Any] = {} results: Dict[str, Any] = {} for field_name in field_list: - resolved = self.resolve_field(field_name) - if not resolved: + parsed = parse_field_path(field_name) + if not parsed: derived = self._resolve_derived_field( ts_code, trade_date, @@ -134,11 +133,8 @@ class DataBroker: if derived is not None: results[field_name] = derived continue - table, column = resolved - grouped.setdefault(table, []) - if column not in grouped[table]: - grouped[table].append(column) - field_map.setdefault((table, column), []).append(field_name) + table, column = parsed + grouped.setdefault(table, []).append((column, field_name)) if not grouped: if cache_key is not None and results: @@ -152,10 +148,9 @@ class DataBroker: try: with db_session(read_only=True) as conn: - for table, columns in grouped.items(): - joined_cols = ", ".join(columns) + for table, items in grouped.items(): query = ( - f"SELECT trade_date, {joined_cols} FROM {table} " + f"SELECT * FROM {table} " "WHERE ts_code = ? AND trade_date <= ? 
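
# fetch_latest now reads the whole latest row per table and resolves columns
# afterwards, so a missing or renamed column degrades to "no value" instead
# of failing the entire query with sqlite3.OperationalError. A standalone
# sketch of the pattern (alias handling via FIELD_ALIASES omitted;
# fetch_latest_row and pick_column are illustrative helpers):
import sqlite3
from typing import Optional

def fetch_latest_row(
    conn: sqlite3.Connection, table: str, ts_code: str, trade_date: str
) -> Optional[sqlite3.Row]:
    conn.row_factory = sqlite3.Row
    return conn.execute(
        f"SELECT * FROM {table} WHERE ts_code = ? AND trade_date <= ? "
        "ORDER BY trade_date DESC LIMIT 1",
        (ts_code, trade_date),
    ).fetchone()

def pick_column(row: sqlite3.Row, column: str):
    names = {name.lower(): name for name in row.keys()}  # case-insensitive
    actual = names.get(column.lower())
    return row[actual] if actual is not None else None
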
" "ORDER BY trade_date DESC LIMIT 1" ) @@ -165,22 +160,25 @@ class DataBroker: LOGGER.debug( "查询失败 table=%s fields=%s err=%s", table, - columns, + [column for column, _field in items], exc, extra=LOG_EXTRA, ) continue if not row: continue - for column in columns: - value = row[column] + available = row.keys() + for column, original in items: + resolved_column = self._resolve_column_in_row(table, column, available) + if resolved_column is None: + continue + value = row[resolved_column] if value is None: continue - for original in field_map.get((table, column), [f"{table}.{column}"]): - try: - results[original] = float(value) - except (TypeError, ValueError): - results[original] = value + try: + results[original] = float(value) + except (TypeError, ValueError): + results[original] = value except sqlite3.OperationalError as exc: LOGGER.debug("数据库只读连接失败:%s", exc, extra=LOG_EXTRA) if cache_key is not None: @@ -698,6 +696,22 @@ class DataBroker: while len(cache) > limit: cache.popitem(last=False) + def _resolve_column_in_row( + self, + table: str, + column: str, + available: Sequence[str], + ) -> Optional[str]: + alias_map = self.FIELD_ALIASES.get(table, {}) + candidate = alias_map.get(column, column) + if candidate in available: + return candidate + lowered = candidate.lower() + for name in available: + if name.lower() == lowered: + return name + return None + def _resolve_column(self, table: str, column: str) -> Optional[str]: columns = self._get_table_columns(table) if columns is None: diff --git a/app/utils/tuning.py b/app/utils/tuning.py index c4bd4c3..fb269a0 100644 --- a/app/utils/tuning.py +++ b/app/utils/tuning.py @@ -2,7 +2,8 @@ from __future__ import annotations import json -from typing import Any, Dict, Optional + +from typing import Any, Dict, Mapping, Optional from .db import db_session from .logging import get_logger @@ -40,3 +41,96 @@ def log_tuning_result( ) except Exception: # noqa: BLE001 LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA) + + +def select_best_tuning_result( + experiment_id: str, + *, + metric: str = "reward", + descending: bool = True, + require_weights: bool = False, +) -> Optional[Dict[str, Any]]: + """Return the best tuning result for the given experiment. + + ``metric`` may refer to ``reward`` (default) or any key inside the + persisted metrics payload. When ``require_weights`` is True, rows lacking + weight definitions are ignored. + """ + + with db_session(read_only=True) as conn: + rows = conn.execute( + """ + SELECT id, action, weights, reward, metrics, created_at + FROM tuning_results + WHERE experiment_id = ? 
+ """, + (experiment_id,), + ).fetchall() + + if not rows: + return None + + best_row: Optional[Mapping[str, Any]] = None + best_metrics: Dict[str, Any] = {} + best_action: Dict[str, float] = {} + best_weights: Dict[str, float] = {} + best_score: Optional[float] = None + + for row in rows: + action = _decode_json(row["action"]) + weights = _decode_json(row["weights"]) + metrics_payload = _decode_json(row["metrics"]) + reward_value = float(row["reward"] or 0.0) + + if require_weights and not weights: + continue + + if metric == "reward": + score = reward_value + else: + score_raw = metrics_payload.get(metric) + if score_raw is None: + continue + try: + score = float(score_raw) + except (TypeError, ValueError): + continue + + if best_score is None: + choose = True + else: + choose = score > best_score if descending else score < best_score + + if choose: + best_score = score + best_row = row + best_metrics = metrics_payload + best_action = action + best_weights = weights + + if best_row is None: + return None + + return { + "id": best_row["id"], + "reward": float(best_row["reward"] or 0.0), + "score": best_score, + "metric": metric, + "action": best_action, + "weights": best_weights, + "metrics": best_metrics, + "created_at": best_row["created_at"], + } + + +def _decode_json(payload: Any) -> Dict[str, Any]: + if not payload: + return {} + if isinstance(payload, Mapping): + return dict(payload) + if isinstance(payload, str): + try: + return json.loads(payload) + except json.JSONDecodeError: + return {} + return {} diff --git a/docs/TODO.md b/docs/TODO.md index 241e874..072d72a 100644 --- a/docs/TODO.md +++ b/docs/TODO.md @@ -13,6 +13,7 @@ ## 2. 数据与特征层 - 实现 `app/features/factors.py` 中的 `compute_factors()`,补齐因子计算与持久化流程。 +- DataBroker `fetch_latest` 查询改为读取整行字段,使用时按需取值,避免列缺失导致的异常,后续取数逻辑遵循该约定。 - 完成 `app/ingest/rss.py` 的 RSS 拉取与写库逻辑,打通新闻与情绪数据源。 - 强化 `DataBroker` 的取数校验、缓存与回退策略,确保行情/特征补数统一自动化,减少人工兜底。 - 围绕动量、估值、流动性等核心信号扩展轻量高质量因子集,全部由程序生成,满足端到端自动决策需求。 diff --git a/scripts/apply_best_weights.py b/scripts/apply_best_weights.py new file mode 100644 index 0000000..3170efa --- /dev/null +++ b/scripts/apply_best_weights.py @@ -0,0 +1,83 @@ +"""Apply or display the best tuning result for an experiment.""" +from __future__ import annotations + +import argparse +import json +import sys +from pathlib import Path +from typing import Iterable + +ROOT = Path(__file__).resolve().parents[1] +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + +from app.utils.config import get_config, save_config +from app.utils.tuning import select_best_tuning_result +from app.utils.logging import get_logger + +LOGGER = get_logger(__name__) + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Apply best tuning weights") + parser.add_argument("experiment_id", help="Experiment identifier") + parser.add_argument( + "--metric", + default="reward", + help="Metric name for ranking (default: reward)", + ) + parser.add_argument( + "--ascending", + action="store_true", + help="Sort metric ascending instead of descending", + ) + parser.add_argument( + "--require-weights", + action="store_true", + help="Ignore records without weight payload", + ) + parser.add_argument( + "--apply-config", + action="store_true", + help="Update agent_weights in config with best result weights (fallback to action)", + ) + return parser + + +def run_cli(argv: Iterable[str] | None = None) -> int: + parser = build_parser() + args = parser.parse_args(list(argv) if argv is not None else None) + + best = 
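
# Assumed usage of this CLI (experiment ids and metric names are examples):
#   python scripts/apply_best_weights.py exp_001 --metric sharpe_like
#   python scripts/apply_best_weights.py exp_001 --require-weights --apply-config
# The first form only prints the best record as JSON; --apply-config also
# writes the winning weights (falling back to the action payload) into
# agent_weights via save_config.
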
select_best_tuning_result( + args.experiment_id, + metric=args.metric, + descending=not args.ascending, + require_weights=args.require_weights, + ) + if not best: + LOGGER.error("未找到实验结果 experiment_id=%s", args.experiment_id) + return 1 + + print(json.dumps(best, ensure_ascii=False, indent=2)) + + if args.apply_config: + weights = best.get("weights") or best.get("action") + if not weights: + LOGGER.error("最佳结果缺少权重信息,无法更新配置") + return 2 + cfg = get_config() + if not cfg.agent_weights: + LOGGER.warning("配置缺少 agent_weights,初始化默认值") + cfg.agent_weights.update_from_dict(weights) + save_config(cfg) + LOGGER.info("已写入新的 agent_weights 至配置") + + return 0 + + +def main() -> None: + raise SystemExit(run_cli()) + + +if __name__ == "__main__": + main() diff --git a/scripts/run_bandit_optimization.py b/scripts/run_bandit_optimization.py new file mode 100644 index 0000000..55b728b --- /dev/null +++ b/scripts/run_bandit_optimization.py @@ -0,0 +1,124 @@ +"""Run epsilon-greedy bandit tuning on DecisionEnv.""" +from __future__ import annotations + +import argparse +import json +import sys +from datetime import datetime, date +from pathlib import Path +from typing import Iterable, List + +ROOT = Path(__file__).resolve().parents[1] +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + +from app.agents.registry import default_agents +from app.backtest.decision_env import DecisionEnv, ParameterSpec +from app.backtest.engine import BtConfig +from app.backtest.optimizer import BanditConfig, EpsilonGreedyBandit +from app.utils.config import get_config + + +def _parse_date(value: str) -> date: + return datetime.strptime(value, "%Y%m%d").date() + + +def _parse_param(text: str) -> ParameterSpec: + parts = text.split(":") + if len(parts) not in {3, 4}: + raise argparse.ArgumentTypeError( + "parameter format must be name:target:min[:max]" + ) + name, target, minimum = parts[:3] + maximum = parts[3] if len(parts) == 4 else "1.0" + return ParameterSpec( + name=name, + target=target, + minimum=float(minimum), + maximum=float(maximum), + ) + + +def _resolve_baseline_weights() -> dict: + cfg = get_config() + if cfg.agent_weights: + return cfg.agent_weights.as_dict() + return {agent.name: 1.0 for agent in default_agents()} + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="DecisionEnv bandit optimizer") + parser.add_argument("experiment_id", help="Experiment identifier to log results") + parser.add_argument("name", help="Backtest config name") + parser.add_argument("start", type=_parse_date, help="Start date YYYYMMDD") + parser.add_argument("end", type=_parse_date, help="End date YYYYMMDD") + parser.add_argument( + "--universe", + required=True, + help="Comma separated ts_codes, e.g. 
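
# _parse_param accepts name:target:min[:max]; the max defaults to 1.0:
#   "w_mom:agent_weights.A_mom:0.0"      -> bounds [0.0, 1.0]
#   "w_mom:agent_weights.A_mom:0.2:2.0"  -> bounds [0.2, 2.0]
# The target path (e.g. agent_weights.A_mom) is whatever DecisionEnv's
# ParameterSpec expects; the names shown here are examples.
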
000001.SZ,000002.SZ", + ) + parser.add_argument( + "--param", + action="append", + required=True, + help="Parameter spec name:target:min[:max] (target like agent_weights.A_mom)", + ) + parser.add_argument("--episodes", type=int, default=20) + parser.add_argument("--epsilon", type=float, default=0.2) + parser.add_argument("--seed", type=int, default=None) + return parser + + +def run_cli(argv: Iterable[str] | None = None) -> int: + parser = build_parser() + args = parser.parse_args(list(argv) if argv is not None else None) + + if args.end < args.start: + parser.error("end date must not precede start date") + + specs: List[ParameterSpec] = [_parse_param(item) for item in args.param] + universe = [token.strip() for token in args.universe.split(",") if token.strip()] + bt_cfg = BtConfig( + id=args.experiment_id, + name=args.name, + start_date=args.start, + end_date=args.end, + universe=universe, + params={}, + ) + + env = DecisionEnv( + bt_config=bt_cfg, + parameter_specs=specs, + baseline_weights=_resolve_baseline_weights(), + ) + optimizer = EpsilonGreedyBandit( + env, + BanditConfig( + experiment_id=args.experiment_id, + episodes=args.episodes, + epsilon=args.epsilon, + seed=args.seed, + ), + ) + summary = optimizer.run() + best = summary.best_episode + output = { + "episodes": len(summary.episodes), + "average_reward": summary.average_reward, + "best": { + "reward": best.reward if best else None, + "action": best.action if best else None, + "metrics": (best.metrics and json.dumps(best.metrics.risk_breakdown)) if best else None, + }, + } + print(json.dumps(output, ensure_ascii=False, indent=2)) + return 0 + + +def main() -> None: + raise SystemExit(run_cli()) + + +if __name__ == "__main__": + main() diff --git a/tests/test_backtest_engine_factors.py b/tests/test_backtest_engine_factors.py new file mode 100644 index 0000000..2da5661 --- /dev/null +++ b/tests/test_backtest_engine_factors.py @@ -0,0 +1,57 @@ +"""Verify BacktestEngine consumes persisted factor fields.""" +from __future__ import annotations + +from datetime import date + +import pytest + +from app.backtest.engine import BacktestEngine, BtConfig + + +@pytest.fixture() +def engine(monkeypatch): + cfg = BtConfig( + id="test", + name="factor", + start_date=date(2025, 1, 10), + end_date=date(2025, 1, 10), + universe=["000001.SZ"], + params={}, + ) + engine = BacktestEngine(cfg) + + def fake_fetch_latest(ts_code, trade_date, fields): # noqa: D401 + assert "factors.mom_20" in fields + return { + "daily.close": 10.0, + "daily.pct_chg": 0.02, + "daily_basic.turnover_rate": 5.0, + "daily_basic.volume_ratio": 15.0, + "factors.mom_20": 0.12, + "factors.mom_60": 0.25, + "factors.volat_20": 0.05, + "factors.turn_20": 3.0, + "news.sentiment_index": 0.3, + "news.heat_score": 0.4, + "macro.industry_heat": 0.6, + "macro.relative_strength": 0.7, + } + + monkeypatch.setattr(engine.data_broker, "fetch_latest", fake_fetch_latest) + monkeypatch.setattr(engine.data_broker, "fetch_series", lambda *args, **kwargs: []) + monkeypatch.setattr(engine.data_broker, "fetch_flags", lambda *args, **kwargs: False) + + return engine + + +def test_load_market_data_prefers_factors(engine): + data = engine.load_market_data(date(2025, 1, 10)) + record = data["000001.SZ"] + features = record["features"] + assert features["mom_20"] == pytest.approx(0.12) + assert features["mom_60"] == pytest.approx(0.25) + assert features["volat_20"] == pytest.approx(0.05) + assert features["turn_20"] == pytest.approx(3.0) + assert features["news_sentiment"] == pytest.approx(0.3) + 
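
# Assumed end-to-end invocation of the optimizer CLI above (dates, codes and
# the parameter spec are illustrative):
#   python scripts/run_bandit_optimization.py exp_001 demo 20250101 20250131 \
#       --universe 000001.SZ,000002.SZ \
#       --param w_mom:agent_weights.A_mom:0.0:2.0 \
#       --episodes 30 --epsilon 0.2 --seed 42
# Every episode is logged through log_tuning_result, so the best action can
# later be inspected or promoted with scripts/apply_best_weights.py.
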
assert features["news_heat"] == pytest.approx(0.4) + assert features["risk_penalty"] == pytest.approx(min(1.0, 0.05 * 5.0)) diff --git a/tests/test_backtest_engine_risk.py b/tests/test_backtest_engine_risk.py index e7df919..794f9e8 100644 --- a/tests/test_backtest_engine_risk.py +++ b/tests/test_backtest_engine_risk.py @@ -5,9 +5,20 @@ from datetime import date import pytest +import json + from app.agents.base import AgentAction, AgentContext from app.agents.game import Decision -from app.backtest.engine import BacktestEngine, BacktestResult, BtConfig, PortfolioState +from app.backtest.engine import ( + BacktestEngine, + BacktestResult, + BtConfig, + PortfolioState, + _persist_backtest_results, +) +from app.data.schema import initialize_database +from app.utils.config import DataPaths, get_config +from app.utils.db import db_session def _make_context(price: float, features: dict | None = None) -> AgentContext: @@ -43,6 +54,20 @@ def _engine_with_params(params: dict[str, float]) -> BacktestEngine: return BacktestEngine(cfg) +@pytest.fixture() +def isolated_db(tmp_path): + cfg = get_config() + original_paths = cfg.data_paths + tmp_root = tmp_path / "data" + tmp_root.mkdir(parents=True, exist_ok=True) + cfg.data_paths = DataPaths(root=tmp_root) + initialize_database() + try: + yield + finally: + cfg.data_paths = original_paths + + def test_buy_respects_risk_caps(): engine = _engine_with_params( { @@ -130,3 +155,56 @@ def test_sell_applies_slippage_and_fee(): assert not state.holdings assert result.nav_series[0]["turnover"] == pytest.approx(trade["value"]) assert not result.risk_events + + +def test_persist_backtest_results_saves_risk_events(isolated_db): + cfg = BtConfig( + id="risk_cfg", + name="risk", + start_date=date(2025, 1, 10), + end_date=date(2025, 1, 10), + universe=["000001.SZ"], + params={}, + ) + result = BacktestResult() + result.nav_series = [ + { + "trade_date": "2025-01-10", + "nav": 100.0, + "cash": 100.0, + "market_value": 0.0, + "realized_pnl": 0.0, + "unrealized_pnl": 0.0, + "turnover": 0.0, + } + ] + result.risk_events = [ + { + "trade_date": "2025-01-10", + "ts_code": "000001.SZ", + "reason": "limit_up", + "action": "buy_l", + "target_weight": 0.3, + "confidence": 0.8, + } + ] + + _persist_backtest_results(cfg, result) + + with db_session(read_only=True) as conn: + risk_row = conn.execute( + "SELECT reason, metadata FROM bt_risk_events WHERE cfg_id = ?", + (cfg.id,), + ).fetchone() + assert risk_row is not None + assert risk_row["reason"] == "limit_up" + metadata = json.loads(risk_row["metadata"]) + assert metadata["action"] == "buy_l" + + summary_row = conn.execute( + "SELECT summary FROM bt_report WHERE cfg_id = ?", + (cfg.id,), + ).fetchone() + summary = json.loads(summary_row["summary"]) + assert summary["risk_events"] == 1 + assert summary["risk_breakdown"]["limit_up"] == 1 diff --git a/tests/test_bandit_optimizer.py b/tests/test_bandit_optimizer.py new file mode 100644 index 0000000..9ebae46 --- /dev/null +++ b/tests/test_bandit_optimizer.py @@ -0,0 +1,92 @@ +"""Tests for epsilon-greedy bandit optimizer.""" +from __future__ import annotations + +import pytest + +from app.backtest.decision_env import EpisodeMetrics, ParameterSpec +from app.backtest.optimizer import BanditConfig, EpsilonGreedyBandit +from app.utils import tuning + + +class DummyEnv: + def __init__(self) -> None: + self._specs = [ + ParameterSpec(name="w1", target="agent_weights.A_mom", minimum=0.0, maximum=1.0) + ] + self._last_metrics: EpisodeMetrics | None = None + self._episode = 0 + + @property + 
def action_dim(self) -> int: + return 1 + + @property + def last_metrics(self) -> EpisodeMetrics | None: + return self._last_metrics + + def reset(self) -> dict: + self._episode += 1 + return {"episode": float(self._episode)} + + def step(self, action): + value = float(action[0]) + reward = 1.0 - abs(value - 0.7) + metrics = EpisodeMetrics( + total_return=reward, + max_drawdown=0.1, + volatility=0.05, + nav_series=[], + trades=[], + turnover=100.0, + trade_count=0, + risk_count=1, + risk_breakdown={"test": 1}, + ) + self._last_metrics = metrics + obs = { + "total_return": reward, + "max_drawdown": 0.1, + "volatility": 0.05, + "sharpe_like": reward / 0.05, + "turnover": 100.0, + "trade_count": 0.0, + "risk_count": 1.0, + } + info = { + "nav_series": [], + "trades": [], + "weights": {"A_mom": value}, + "risk_breakdown": metrics.risk_breakdown, + "risk_events": [], + } + return obs, reward, True, info + + +@pytest.fixture(autouse=True) +def patch_logging(monkeypatch): + records = [] + + def fake_log_tuning_result(**kwargs): + records.append(kwargs) + + monkeypatch.setattr(tuning, "log_tuning_result", fake_log_tuning_result) + from app.backtest import optimizer as optimizer_module + + monkeypatch.setattr(optimizer_module, "log_tuning_result", fake_log_tuning_result) + return records + + +def test_bandit_optimizer_runs_and_logs(patch_logging): + env = DummyEnv() + optimizer = EpsilonGreedyBandit( + env, + BanditConfig(experiment_id="exp", episodes=5, epsilon=0.5, seed=42), + ) + summary = optimizer.run() + + assert len(summary.episodes) == 5 + assert summary.best_episode is not None + assert patch_logging and len(patch_logging) == 5 + payload = patch_logging[0]["metrics"] + assert isinstance(payload, dict) + assert "risk_breakdown" in payload diff --git a/tests/test_decision_env.py b/tests/test_decision_env.py new file mode 100644 index 0000000..c3e1d41 --- /dev/null +++ b/tests/test_decision_env.py @@ -0,0 +1,92 @@ +"""Tests for DecisionEnv risk-aware reward and info outputs.""" +from __future__ import annotations + +from datetime import date + +import pytest + +from app.backtest.decision_env import DecisionEnv, EpisodeMetrics, ParameterSpec +from app.backtest.engine import BacktestResult, BtConfig + + +class _StubEngine: + def __init__(self, cfg: BtConfig) -> None: # noqa: D401 + self.cfg = cfg + self.weights = {} + self.department_manager = None + + def run(self) -> BacktestResult: + result = BacktestResult() + result.nav_series = [ + { + "trade_date": "2025-01-10", + "nav": 102.0, + "cash": 50.0, + "market_value": 52.0, + "realized_pnl": 1.0, + "unrealized_pnl": 1.0, + "turnover": 20000.0, + } + ] + result.trades = [ + { + "trade_date": "2025-01-10", + "ts_code": "000001.SZ", + "action": "buy", + "quantity": 100.0, + "price": 100.0, + "value": 10000.0, + "fee": 5.0, + } + ] + result.risk_events = [ + { + "trade_date": "2025-01-10", + "ts_code": "000002.SZ", + "reason": "limit_up", + "action": "buy_l", + "confidence": 0.7, + "target_weight": 0.2, + } + ] + return result + + +def test_decision_env_returns_risk_metrics(monkeypatch): + cfg = BtConfig( + id="stub", + name="stub", + start_date=date(2025, 1, 10), + end_date=date(2025, 1, 10), + universe=["000001.SZ"], + params={}, + ) + specs = [ParameterSpec(name="w_mom", target="agent_weights.A_mom", minimum=0.0, maximum=1.0)] + env = DecisionEnv(bt_config=cfg, parameter_specs=specs, baseline_weights={"A_mom": 0.5}) + + monkeypatch.setattr("app.backtest.decision_env.BacktestEngine", _StubEngine) + + obs, reward, done, info = env.step([0.8]) + + 
assert done is True + assert "risk_count" in obs and obs["risk_count"] == 1.0 + assert obs["turnover"] == pytest.approx(20000.0) + assert info["risk_events"][0]["reason"] == "limit_up" + assert info["risk_breakdown"]["limit_up"] == 1 + assert reward < obs["total_return"] + + +def test_default_reward_penalizes_metrics(): + metrics = EpisodeMetrics( + total_return=0.1, + max_drawdown=0.2, + volatility=0.05, + nav_series=[], + trades=[], + turnover=1000.0, + trade_count=0, + risk_count=2, + risk_breakdown={"foo": 2}, + ) + reward = DecisionEnv._default_reward(metrics) + assert reward == pytest.approx(0.1 - (0.5 * 0.2 + 0.05 * 2 + 0.00001 * 1000.0)) diff --git a/tests/test_tuning_utils.py b/tests/test_tuning_utils.py new file mode 100644 index 0000000..0e5753d --- /dev/null +++ b/tests/test_tuning_utils.py @@ -0,0 +1,81 @@ +"""Tests for tuning result selection and CLI application.""" +from __future__ import annotations + +import json + +import pytest + +from app.data.schema import initialize_database +from app.utils.config import DataPaths, get_config +from app.utils.db import db_session +from app.utils.tuning import select_best_tuning_result + +import scripts.apply_best_weights as apply_best_weights + + +@pytest.fixture() +def isolated_env(tmp_path): + cfg = get_config() + original_paths = cfg.data_paths + tmp_root = tmp_path / "data" + tmp_root.mkdir(parents=True, exist_ok=True) + cfg.data_paths = DataPaths(root=tmp_root) + initialize_database() + try: + yield cfg + finally: + cfg.data_paths = original_paths + + +def _insert_result(experiment: str, reward: float, metrics: dict, weights: dict | None = None, action: dict | None = None) -> None: + with db_session() as conn: + conn.execute( + """ + INSERT INTO tuning_results (experiment_id, strategy, action, weights, reward, metrics) + VALUES (?, ?, ?, ?, ?, ?) + """, + ( + experiment, + "test", + json.dumps(action or {}, ensure_ascii=False), + json.dumps(weights or {}, ensure_ascii=False), + reward, + json.dumps(metrics, ensure_ascii=False), + ), + ) + + +def test_select_best_by_reward(isolated_env): + _insert_result("exp", 0.1, {"risk_count": 2}, {"A_mom": 0.3}) + _insert_result("exp", 0.25, {"risk_count": 4}, {"A_mom": 0.6}) + + best = select_best_tuning_result("exp") + assert best is not None + assert best["reward"] == pytest.approx(0.25) + assert best["weights"]["A_mom"] == pytest.approx(0.6) + + +def test_select_best_by_metric(isolated_env): + _insert_result("exp_metric", 0.2, {"risk_count": 5}, {"A_mom": 0.4}) + _insert_result("exp_metric", 0.1, {"risk_count": 2}, {"A_mom": 0.7}) + + best = select_best_tuning_result("exp_metric", metric="risk_count", descending=False) + assert best is not None + assert best["weights"]["A_mom"] == pytest.approx(0.7) + assert best["metrics"]["risk_count"] == 2 + + +def test_apply_best_weights_cli_updates_config(isolated_env, capsys): + cfg = isolated_env + _insert_result("exp_cli", 0.3, {"risk_count": 1}, {"A_mom": 0.65, "A_val": 0.2}) + exit_code = apply_best_weights.run_cli([ + "exp_cli", + "--apply-config", + ]) + assert exit_code == 0 + output = capsys.readouterr().out + payload = json.loads(output) + assert payload["metric"] == "reward" + updated = cfg.agent_weights.as_dict() + assert updated["A_mom"] == pytest.approx(0.65) + assert updated["A_val"] == pytest.approx(0.2)
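
For reference, the default reward shaping asserted in test_decision_env works out as follows; a standalone check using the coefficients from _default_reward in this patch:

    # reward = total_return - (0.5 * max_drawdown + 0.05 * risk_count + 0.00001 * turnover)
    total_return, max_drawdown, risk_count, turnover = 0.1, 0.2, 2, 1000.0
    reward = total_return - (0.5 * max_drawdown + 0.05 * risk_count + 0.00001 * turnover)
    assert abs(reward - (-0.11)) < 1e-9  # 0.1 - (0.1 + 0.1 + 0.01) = -0.11
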