sam 2025-09-30 18:34:29 +08:00
parent 8befd80cb7
commit 07e5bb1b68
14 changed files with 995 additions and 34 deletions

View File

@@ -36,6 +36,10 @@ class EpisodeMetrics:
    volatility: float
    nav_series: List[Dict[str, float]]
    trades: List[Dict[str, object]]
    turnover: float
    trade_count: int
    risk_count: int
    risk_breakdown: Dict[str, int]

    @property
    def sharpe_like(self) -> float:
@@ -109,11 +113,16 @@ class DecisionEnv:
            "max_drawdown": metrics.max_drawdown,
            "volatility": metrics.volatility,
            "sharpe_like": metrics.sharpe_like,
            "turnover": metrics.turnover,
            "trade_count": float(metrics.trade_count),
            "risk_count": float(metrics.risk_count),
        }
        info = {
            "nav_series": metrics.nav_series,
            "trades": metrics.trades,
            "weights": weights,
            "risk_breakdown": metrics.risk_breakdown,
            "risk_events": getattr(result, "risk_events", []),
        }
        return observation, reward, True, info
@@ -131,7 +140,21 @@ class DecisionEnv:
    def _compute_metrics(self, result: BacktestResult) -> EpisodeMetrics:
        nav_series = result.nav_series or []
        if not nav_series:
            risk_breakdown: Dict[str, int] = {}
            for event in getattr(result, "risk_events", []) or []:
                reason = str(event.get("reason") or "unknown")
                risk_breakdown[reason] = risk_breakdown.get(reason, 0) + 1
            return EpisodeMetrics(
                total_return=0.0,
                max_drawdown=0.0,
                volatility=0.0,
                nav_series=[],
                trades=result.trades,
                turnover=0.0,
                trade_count=len(result.trades or []),
                risk_count=len(getattr(result, "risk_events", []) or []),
                risk_breakdown=risk_breakdown,
            )
        nav_values = [row.get("nav", 0.0) for row in nav_series]
        if not nav_values or nav_values[0] == 0:
@@ -158,17 +181,30 @@ class DecisionEnv:
        else:
            volatility = 0.0
        turnover = sum(float(row.get("turnover", 0.0) or 0.0) for row in nav_series)
        risk_events = getattr(result, "risk_events", []) or []
        risk_breakdown: Dict[str, int] = {}
        for event in risk_events:
            reason = str(event.get("reason") or "unknown")
            risk_breakdown[reason] = risk_breakdown.get(reason, 0) + 1
        return EpisodeMetrics(
            total_return=float(total_return),
            max_drawdown=float(max_drawdown),
            volatility=volatility,
            nav_series=nav_series,
            trades=result.trades,
            turnover=float(turnover),
            trade_count=len(result.trades or []),
            risk_count=len(risk_events),
            risk_breakdown=risk_breakdown,
        )
    @staticmethod
    def _default_reward(metrics: EpisodeMetrics) -> float:
        risk_penalty = 0.05 * metrics.risk_count
        turnover_penalty = 0.00001 * metrics.turnover
        penalty = 0.5 * metrics.max_drawdown + risk_penalty + turnover_penalty
        return metrics.total_return - penalty

    @property
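
The shaped reward subtracts three penalties from total_return: drawdown (weighted 0.5), risk events (0.05 each), and turnover (0.00001 per unit). A minimal sanity check of that arithmetic, mirroring test_default_reward_penalizes_metrics later in this commit:

from app.backtest.decision_env import DecisionEnv, EpisodeMetrics

metrics = EpisodeMetrics(
    total_return=0.1,
    max_drawdown=0.2,
    volatility=0.05,
    nav_series=[],
    trades=[],
    turnover=1000.0,
    trade_count=0,
    risk_count=2,
    risk_breakdown={"limit_up": 2},
)
# penalty = 0.5 * 0.2 + 0.05 * 2 + 0.00001 * 1000.0 = 0.21
assert abs(DecisionEnv._default_reward(metrics) - (-0.11)) < 1e-9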

View File

@@ -98,6 +98,12 @@ class BacktestEngine:
            "daily_basic.volume_ratio",
            "stk_limit.up_limit",
            "stk_limit.down_limit",
            "factors.mom_20",
            "factors.mom_60",
            "factors.volat_20",
            "factors.turn_20",
            "news.sentiment_index",
            "news.heat_score",
        }
        self.required_fields = sorted(base_scope | department_scope)
@@ -121,9 +127,18 @@
            trade_date_str,
            window=60,
        )
        close_values = [value for _date, value in closes if value is not None]
        mom20 = scope_values.get("factors.mom_20")
        if mom20 is None and len(close_values) >= 20:
            mom20 = momentum(close_values, 20)
        mom60 = scope_values.get("factors.mom_60")
        if mom60 is None and len(close_values) >= 60:
            mom60 = momentum(close_values, 60)
        volat20 = scope_values.get("factors.volat_20")
        if volat20 is None and len(close_values) >= 2:
            volat20 = volatility(close_values, 20)
        turnover_series = self.data_broker.fetch_series(
@@ -133,21 +148,36 @@
            trade_date_str,
            window=20,
        )
        turnover_values = [value for _date, value in turnover_series if value is not None]
        turn20 = scope_values.get("factors.turn_20")
        if turn20 is None and turnover_values:
            turn20 = rolling_mean(turnover_values, 20)
        if mom20 is None:
            mom20 = 0.0
        if mom60 is None:
            mom60 = 0.0
        if volat20 is None:
            volat20 = 0.0
        if turn20 is None:
            turn20 = 0.0
        liquidity_score = normalize(turn20, factor=20.0)
        cost_penalty = normalize(
            scope_values.get("daily_basic.volume_ratio", 0.0),
            factor=50.0,
        )
        sentiment_index = scope_values.get("news.sentiment_index", 0.0)
        heat_score = scope_values.get("news.heat_score", 0.0)
        scope_values.setdefault("news.sentiment_index", sentiment_index)
        scope_values.setdefault("news.heat_score", heat_score)
        scope_values.setdefault("factors.mom_20", mom20)
        scope_values.setdefault("factors.mom_60", mom60)
        scope_values.setdefault("factors.volat_20", volat20)
        scope_values.setdefault("factors.turn_20", turn20)
        if scope_values.get("macro.industry_heat") is None:
            scope_values["macro.industry_heat"] = 0.5
        if scope_values.get("macro.relative_strength") is None:
@@ -189,8 +219,8 @@
            "turn_20": turn20,
            "liquidity_score": liquidity_score,
            "cost_penalty": cost_penalty,
            "news_heat": heat_score,
            "news_sentiment": sentiment_index,
            "industry_heat": scope_values.get("macro.industry_heat", 0.0),
            "industry_relative_mom": scope_values.get(
                "macro.relative_strength",
@@ -818,6 +848,7 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None:
    nav_rows: List[tuple] = []
    trade_rows: List[tuple] = []
    risk_rows: List[tuple] = []
    summary_payload: Dict[str, object] = {}
    turnover_sum = 0.0
@@ -893,6 +924,10 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None:
            "confidence": trade.get("confidence"),
            "target_weight": trade.get("target_weight"),
            "value": trade.get("value"),
            "fee": trade.get("fee"),
            "slippage": trade.get("slippage"),
            "risk_penalty": trade.get("risk_penalty"),
            "liquidity_score": trade.get("liquidity_score"),
        }
        trade_rows.append(
            (
@@ -913,6 +948,18 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None:
    for event in result.risk_events:
        reason = str(event.get("reason") or "unknown")
        breakdown[reason] = breakdown.get(reason, 0) + 1
        risk_rows.append(
            (
                cfg.id,
                str(event.get("trade_date", "")),
                str(event.get("ts_code", "")),
                reason,
                str(event.get("action", "")),
                float(event.get("target_weight", 0.0) or 0.0),
                float(event.get("confidence", 0.0) or 0.0),
                json.dumps(event, ensure_ascii=False),
            )
        )
    summary_payload["risk_breakdown"] = breakdown
    cfg_payload = {
@@ -943,6 +990,7 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None:
        conn.execute("DELETE FROM bt_nav WHERE cfg_id = ?", (cfg.id,))
        conn.execute("DELETE FROM bt_trades WHERE cfg_id = ?", (cfg.id,))
        conn.execute("DELETE FROM bt_risk_events WHERE cfg_id = ?", (cfg.id,))
        conn.execute("DELETE FROM bt_report WHERE cfg_id = ?", (cfg.id,))
        if nav_rows:
@@ -963,6 +1011,15 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None:
                trade_rows,
            )
        if risk_rows:
            conn.executemany(
                """
                INSERT INTO bt_risk_events (cfg_id, trade_date, ts_code, reason, action, target_weight, confidence, metadata)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """,
                risk_rows,
            )
        summary_payload.setdefault("universe", cfg.universe)
        summary_payload.setdefault("method", cfg.method)
        conn.execute(
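
Once persisted, the per-reason breakdown stored in bt_report can be recomputed straight from bt_risk_events. A sketch using the same db_session API the tests use; the cfg_id "risk_cfg" is borrowed from the persistence test below:

from app.utils.db import db_session

with db_session(read_only=True) as conn:
    rows = conn.execute(
        """
        SELECT reason, COUNT(*) AS n
        FROM bt_risk_events
        WHERE cfg_id = ?
        GROUP BY reason
        ORDER BY n DESC
        """,
        ("risk_cfg",),
    ).fetchall()
for row in rows:
    print(row["reason"], row["n"])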

app/backtest/optimizer.py (new file, 139 lines)
View File

@@ -0,0 +1,139 @@
"""Optimization utilities for DecisionEnv-based parameter tuning."""
from __future__ import annotations
import random
from dataclasses import dataclass, field
from typing import Dict, Iterable, List, Sequence, Tuple
from app.backtest.decision_env import DecisionEnv, EpisodeMetrics, ParameterSpec
from app.utils.logging import get_logger
from app.utils.tuning import log_tuning_result
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "decision_bandit"}
@dataclass
class BanditConfig:
"""Configuration for epsilon-greedy bandit optimization."""
experiment_id: str
strategy: str = "epsilon_greedy"
episodes: int = 20
epsilon: float = 0.2
seed: int | None = None
@dataclass
class BanditEpisode:
action: Dict[str, float]
reward: float
metrics: EpisodeMetrics
observation: Dict[str, float]
@dataclass
class BanditSummary:
episodes: List[BanditEpisode] = field(default_factory=list)
@property
def best_episode(self) -> BanditEpisode | None:
if not self.episodes:
return None
return max(self.episodes, key=lambda item: item.reward)
@property
def average_reward(self) -> float:
if not self.episodes:
return 0.0
return sum(item.reward for item in self.episodes) / len(self.episodes)
class EpsilonGreedyBandit:
"""Simple epsilon-greedy tuner using DecisionEnv as the reward oracle."""
def __init__(self, env: DecisionEnv, config: BanditConfig) -> None:
self.env = env
self.config = config
self._random = random.Random(config.seed)
self._specs: List[ParameterSpec] = list(getattr(env, "_specs", []))
if not self._specs:
raise ValueError("DecisionEnv does not expose parameter specs")
self._value_estimates: Dict[Tuple[float, ...], float] = {}
self._counts: Dict[Tuple[float, ...], int] = {}
self._history = BanditSummary()
def run(self) -> BanditSummary:
for episode in range(1, self.config.episodes + 1):
action = self._select_action()
self.env.reset()
obs, reward, done, info = self.env.step(action)
metrics = self.env.last_metrics
if metrics is None:
raise RuntimeError("DecisionEnv did not populate last_metrics")
key = tuple(action)
old_estimate = self._value_estimates.get(key, 0.0)
count = self._counts.get(key, 0) + 1
self._counts[key] = count
self._value_estimates[key] = old_estimate + (reward - old_estimate) / count
action_payload = self._action_to_mapping(action)
metrics_payload = _metrics_to_dict(metrics)
try:
log_tuning_result(
experiment_id=self.config.experiment_id,
strategy=self.config.strategy,
action=action_payload,
reward=reward,
metrics=metrics_payload,
weights=info.get("weights"),
)
except Exception: # noqa: BLE001
LOGGER.exception("failed to log tuning result", extra=LOG_EXTRA)
episode_record = BanditEpisode(
action=action_payload,
reward=reward,
metrics=metrics,
observation=obs,
)
self._history.episodes.append(episode_record)
LOGGER.info(
"Bandit episode=%s reward=%.4f action=%s",
episode,
reward,
action_payload,
extra=LOG_EXTRA,
)
return self._history
def _select_action(self) -> List[float]:
if self._value_estimates and self._random.random() > self.config.epsilon:
best = max(self._value_estimates.items(), key=lambda item: item[1])[0]
return list(best)
return [
self._random.uniform(spec.minimum, spec.maximum)
for spec in self._specs
]
def _action_to_mapping(self, action: Sequence[float]) -> Dict[str, float]:
return {
spec.name: float(value)
for spec, value in zip(self._specs, action, strict=True)
}
def _metrics_to_dict(metrics: EpisodeMetrics) -> Dict[str, float | Dict[str, int]]:
payload: Dict[str, float | Dict[str, int]] = {
"total_return": metrics.total_return,
"max_drawdown": metrics.max_drawdown,
"volatility": metrics.volatility,
"sharpe_like": metrics.sharpe_like,
"turnover": metrics.turnover,
"trade_count": float(metrics.trade_count),
"risk_count": float(metrics.risk_count),
}
if metrics.risk_breakdown:
payload["risk_breakdown"] = dict(metrics.risk_breakdown)
return payload
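
A compact usage sketch of the optimizer against a real DecisionEnv. The id, dates, and universe are illustrative, and running it assumes backtest data is available; the bandit CLI script later in this commit wires up the same pieces:

from datetime import date

from app.backtest.decision_env import DecisionEnv, ParameterSpec
from app.backtest.engine import BtConfig
from app.backtest.optimizer import BanditConfig, EpsilonGreedyBandit

bt_cfg = BtConfig(
    id="exp",
    name="demo",
    start_date=date(2025, 1, 10),
    end_date=date(2025, 1, 10),
    universe=["000001.SZ"],
    params={},
)
specs = [ParameterSpec(name="w_mom", target="agent_weights.A_mom", minimum=0.0, maximum=1.0)]
env = DecisionEnv(bt_config=bt_cfg, parameter_specs=specs, baseline_weights={"A_mom": 0.5})
summary = EpsilonGreedyBandit(env, BanditConfig(experiment_id="exp", episodes=10, seed=42)).run()
if summary.best_episode is not None:
    print(summary.best_episode.reward, summary.best_episode.action)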

View File

@@ -327,6 +327,18 @@ SCHEMA_STATEMENTS: Iterable[str] = (
    );
    """,
    """
    CREATE TABLE IF NOT EXISTS bt_risk_events (
        cfg_id TEXT,
        trade_date TEXT,
        ts_code TEXT,
        reason TEXT,
        action TEXT,
        target_weight REAL,
        confidence REAL,
        metadata TEXT
    );
    """,
    """
    CREATE TABLE IF NOT EXISTS bt_nav (
        cfg_id TEXT,
        trade_date TEXT,
@@ -472,6 +484,7 @@ REQUIRED_TABLES = (
    "heat_daily",
    "bt_config",
    "bt_trades",
    "bt_risk_events",
    "bt_nav",
    "bt_report",
    "run_log",

View File

@@ -118,13 +118,12 @@ class DataBroker:
        if cached is not None:
            return deepcopy(cached)
        grouped: Dict[str, List[Tuple[str, str]]] = {}
        derived_cache: Dict[str, Any] = {}
        results: Dict[str, Any] = {}
        for field_name in field_list:
            parsed = parse_field_path(field_name)
            if not parsed:
                derived = self._resolve_derived_field(
                    ts_code,
                    trade_date,
@@ -134,11 +133,8 @@
                if derived is not None:
                    results[field_name] = derived
                continue
            table, column = parsed
            grouped.setdefault(table, []).append((column, field_name))
        if not grouped:
            if cache_key is not None and results:
@@ -152,10 +148,9 @@
        try:
            with db_session(read_only=True) as conn:
                for table, items in grouped.items():
                    query = (
                        f"SELECT * FROM {table} "
                        "WHERE ts_code = ? AND trade_date <= ? "
                        "ORDER BY trade_date DESC LIMIT 1"
                    )
@@ -165,18 +160,21 @@
                        LOGGER.debug(
                            "查询失败 table=%s fields=%s err=%s",
                            table,
                            [column for column, _field in items],
                            exc,
                            extra=LOG_EXTRA,
                        )
                        continue
                    if not row:
                        continue
                    available = row.keys()
                    for column, original in items:
                        resolved_column = self._resolve_column_in_row(table, column, available)
                        if resolved_column is None:
                            continue
                        value = row[resolved_column]
                        if value is None:
                            continue
                        try:
                            results[original] = float(value)
                        except (TypeError, ValueError):
@@ -698,6 +696,22 @@
        while len(cache) > limit:
            cache.popitem(last=False)

    def _resolve_column_in_row(
        self,
        table: str,
        column: str,
        available: Sequence[str],
    ) -> Optional[str]:
        alias_map = self.FIELD_ALIASES.get(table, {})
        candidate = alias_map.get(column, column)
        if candidate in available:
            return candidate
        lowered = candidate.lower()
        for name in available:
            if name.lower() == lowered:
                return name
        return None

    def _resolve_column(self, table: str, column: str) -> Optional[str]:
        columns = self._get_table_columns(table)
        if columns is None:
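
parse_field_path itself is not part of this diff; judging from how fetch_latest uses it, a minimal stand-in would split a "table.column" path and return None for anything else (which then falls through to _resolve_derived_field):

from typing import Optional, Tuple

def parse_field_path(field_name: str) -> Optional[Tuple[str, str]]:
    """Hypothetical stand-in, assuming dotted 'table.column' field paths."""
    parts = field_name.split(".", 1)
    if len(parts) != 2 or not all(parts):
        return None
    return parts[0], parts[1]

assert parse_field_path("daily_basic.turnover_rate") == ("daily_basic", "turnover_rate")
assert parse_field_path("not_a_path") is None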

View File

@@ -2,7 +2,8 @@
from __future__ import annotations
import json
from typing import Any, Dict, Mapping, Optional
from .db import db_session
from .logging import get_logger
@@ -40,3 +41,96 @@ def log_tuning_result(
        )
    except Exception:  # noqa: BLE001
        LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA)
def select_best_tuning_result(
experiment_id: str,
*,
metric: str = "reward",
descending: bool = True,
require_weights: bool = False,
) -> Optional[Dict[str, Any]]:
"""Return the best tuning result for the given experiment.
``metric`` may refer to ``reward`` (default) or any key inside the
persisted metrics payload. When ``require_weights`` is True, rows lacking
weight definitions are ignored.
"""
with db_session(read_only=True) as conn:
rows = conn.execute(
"""
SELECT id, action, weights, reward, metrics, created_at
FROM tuning_results
WHERE experiment_id = ?
""",
(experiment_id,),
).fetchall()
if not rows:
return None
best_row: Optional[Mapping[str, Any]] = None
best_metrics: Dict[str, Any] = {}
best_action: Dict[str, float] = {}
best_weights: Dict[str, float] = {}
best_score: Optional[float] = None
for row in rows:
action = _decode_json(row["action"])
weights = _decode_json(row["weights"])
metrics_payload = _decode_json(row["metrics"])
reward_value = float(row["reward"] or 0.0)
if require_weights and not weights:
continue
if metric == "reward":
score = reward_value
else:
score_raw = metrics_payload.get(metric)
if score_raw is None:
continue
try:
score = float(score_raw)
except (TypeError, ValueError):
continue
if best_score is None:
choose = True
else:
choose = score > best_score if descending else score < best_score
if choose:
best_score = score
best_row = row
best_metrics = metrics_payload
best_action = action
best_weights = weights
if best_row is None:
return None
return {
"id": best_row["id"],
"reward": float(best_row["reward"] or 0.0),
"score": best_score,
"metric": metric,
"action": best_action,
"weights": best_weights,
"metrics": best_metrics,
"created_at": best_row["created_at"],
}
def _decode_json(payload: Any) -> Dict[str, Any]:
if not payload:
return {}
if isinstance(payload, Mapping):
return dict(payload)
if isinstance(payload, str):
try:
return json.loads(payload)
except json.JSONDecodeError:
return {}
return {}
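
Usage sketch: ranking by a key inside the persisted metrics payload instead of reward, mirroring the new tuning tests at the end of this commit:

from app.utils.tuning import select_best_tuning_result

# Pick the run with the fewest risk events rather than the highest reward.
best = select_best_tuning_result("exp_metric", metric="risk_count", descending=False)
if best is not None:
    print(best["score"], best["weights"])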

View File

@@ -13,6 +13,7 @@
## 2. Data and Feature Layer
- Implement `compute_factors()` in `app/features/factors.py` to complete factor computation and the persistence pipeline.
- Switch the DataBroker `fetch_latest` query to read whole rows and pick fields on demand, avoiding exceptions from missing columns; later data-access logic should follow this convention.
- Finish the RSS fetching and write-to-DB logic in `app/ingest/rss.py`, wiring up the news and sentiment data sources.
- Harden DataBroker's fetch validation, caching, and fallback strategies so market/feature backfills are fully automated, reducing manual intervention.
- Expand a lightweight, high-quality factor set around core signals such as momentum, valuation, and liquidity, all generated programmatically to support end-to-end automated decisions.

View File

@@ -0,0 +1,83 @@
"""Apply or display the best tuning result for an experiment."""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
from typing import Iterable
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
from app.utils.config import get_config, save_config
from app.utils.tuning import select_best_tuning_result
from app.utils.logging import get_logger
LOGGER = get_logger(__name__)
def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="Apply best tuning weights")
parser.add_argument("experiment_id", help="Experiment identifier")
parser.add_argument(
"--metric",
default="reward",
help="Metric name for ranking (default: reward)",
)
parser.add_argument(
"--ascending",
action="store_true",
help="Sort metric ascending instead of descending",
)
parser.add_argument(
"--require-weights",
action="store_true",
help="Ignore records without weight payload",
)
parser.add_argument(
"--apply-config",
action="store_true",
help="Update agent_weights in config with best result weights (fallback to action)",
)
return parser
def run_cli(argv: Iterable[str] | None = None) -> int:
parser = build_parser()
args = parser.parse_args(list(argv) if argv is not None else None)
best = select_best_tuning_result(
args.experiment_id,
metric=args.metric,
descending=not args.ascending,
require_weights=args.require_weights,
)
if not best:
LOGGER.error("未找到实验结果 experiment_id=%s", args.experiment_id)
return 1
print(json.dumps(best, ensure_ascii=False, indent=2))
if args.apply_config:
weights = best.get("weights") or best.get("action")
if not weights:
LOGGER.error("最佳结果缺少权重信息,无法更新配置")
return 2
cfg = get_config()
if not cfg.agent_weights:
LOGGER.warning("配置缺少 agent_weights初始化默认值")
cfg.agent_weights.update_from_dict(weights)
save_config(cfg)
LOGGER.info("已写入新的 agent_weights 至配置")
return 0
def main() -> None:
raise SystemExit(run_cli())
if __name__ == "__main__":
main()
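
The CLI can also be driven in-process, as the new tuning test does; this assumes tuning_results already holds rows for the experiment:

import scripts.apply_best_weights as apply_best_weights

# Prints the best record as JSON and, with --apply-config, writes its
# weights into agent_weights.
exit_code = apply_best_weights.run_cli(["exp_cli", "--apply-config"])
assert exit_code == 0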

View File

@@ -0,0 +1,124 @@
"""Run epsilon-greedy bandit tuning on DecisionEnv."""
from __future__ import annotations
import argparse
import json
import sys
from datetime import datetime, date
from pathlib import Path
from typing import Iterable, List
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
from app.agents.registry import default_agents
from app.backtest.decision_env import DecisionEnv, ParameterSpec
from app.backtest.engine import BtConfig
from app.backtest.optimizer import BanditConfig, EpsilonGreedyBandit
from app.utils.config import get_config
def _parse_date(value: str) -> date:
return datetime.strptime(value, "%Y%m%d").date()
def _parse_param(text: str) -> ParameterSpec:
parts = text.split(":")
if len(parts) not in {3, 4}:
raise argparse.ArgumentTypeError(
"parameter format must be name:target:min[:max]"
)
name, target, minimum = parts[:3]
maximum = parts[3] if len(parts) == 4 else "1.0"
return ParameterSpec(
name=name,
target=target,
minimum=float(minimum),
maximum=float(maximum),
)
def _resolve_baseline_weights() -> dict:
cfg = get_config()
if cfg.agent_weights:
return cfg.agent_weights.as_dict()
return {agent.name: 1.0 for agent in default_agents()}
def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="DecisionEnv bandit optimizer")
parser.add_argument("experiment_id", help="Experiment identifier to log results")
parser.add_argument("name", help="Backtest config name")
parser.add_argument("start", type=_parse_date, help="Start date YYYYMMDD")
parser.add_argument("end", type=_parse_date, help="End date YYYYMMDD")
parser.add_argument(
"--universe",
required=True,
help="Comma separated ts_codes, e.g. 000001.SZ,000002.SZ",
)
parser.add_argument(
"--param",
action="append",
required=True,
help="Parameter spec name:target:min[:max] (target like agent_weights.A_mom)",
)
parser.add_argument("--episodes", type=int, default=20)
parser.add_argument("--epsilon", type=float, default=0.2)
parser.add_argument("--seed", type=int, default=None)
return parser
def run_cli(argv: Iterable[str] | None = None) -> int:
parser = build_parser()
args = parser.parse_args(list(argv) if argv is not None else None)
if args.end < args.start:
parser.error("end date must not precede start date")
specs: List[ParameterSpec] = [_parse_param(item) for item in args.param]
universe = [token.strip() for token in args.universe.split(",") if token.strip()]
bt_cfg = BtConfig(
id=args.experiment_id,
name=args.name,
start_date=args.start,
end_date=args.end,
universe=universe,
params={},
)
env = DecisionEnv(
bt_config=bt_cfg,
parameter_specs=specs,
baseline_weights=_resolve_baseline_weights(),
)
optimizer = EpsilonGreedyBandit(
env,
BanditConfig(
experiment_id=args.experiment_id,
episodes=args.episodes,
epsilon=args.epsilon,
seed=args.seed,
),
)
summary = optimizer.run()
best = summary.best_episode
output = {
"episodes": len(summary.episodes),
"average_reward": summary.average_reward,
"best": {
"reward": best.reward if best else None,
"action": best.action if best else None,
"metrics": (best.metrics and json.dumps(best.metrics.risk_breakdown)) if best else None,
},
}
print(json.dumps(output, ensure_ascii=False, indent=2))
return 0
def main() -> None:
raise SystemExit(run_cli())
if __name__ == "__main__":
main()
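
For reference, the --param grammar in action. The script's module path is not shown in this view; the import below assumes it is saved as scripts/run_decision_bandit.py:

from scripts.run_decision_bandit import _parse_param

spec = _parse_param("w_mom:agent_weights.A_mom:0.0:1.5")
assert (spec.name, spec.target, spec.minimum, spec.maximum) == (
    "w_mom", "agent_weights.A_mom", 0.0, 1.5,
)
assert _parse_param("w_val:agent_weights.A_val:0.0").maximum == 1.0  # max defaults to 1.0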

View File

@@ -0,0 +1,57 @@
"""Verify BacktestEngine consumes persisted factor fields."""
from __future__ import annotations
from datetime import date
import pytest
from app.backtest.engine import BacktestEngine, BtConfig
@pytest.fixture()
def engine(monkeypatch):
cfg = BtConfig(
id="test",
name="factor",
start_date=date(2025, 1, 10),
end_date=date(2025, 1, 10),
universe=["000001.SZ"],
params={},
)
engine = BacktestEngine(cfg)
def fake_fetch_latest(ts_code, trade_date, fields): # noqa: D401
assert "factors.mom_20" in fields
return {
"daily.close": 10.0,
"daily.pct_chg": 0.02,
"daily_basic.turnover_rate": 5.0,
"daily_basic.volume_ratio": 15.0,
"factors.mom_20": 0.12,
"factors.mom_60": 0.25,
"factors.volat_20": 0.05,
"factors.turn_20": 3.0,
"news.sentiment_index": 0.3,
"news.heat_score": 0.4,
"macro.industry_heat": 0.6,
"macro.relative_strength": 0.7,
}
monkeypatch.setattr(engine.data_broker, "fetch_latest", fake_fetch_latest)
monkeypatch.setattr(engine.data_broker, "fetch_series", lambda *args, **kwargs: [])
monkeypatch.setattr(engine.data_broker, "fetch_flags", lambda *args, **kwargs: False)
return engine
def test_load_market_data_prefers_factors(engine):
data = engine.load_market_data(date(2025, 1, 10))
record = data["000001.SZ"]
features = record["features"]
assert features["mom_20"] == pytest.approx(0.12)
assert features["mom_60"] == pytest.approx(0.25)
assert features["volat_20"] == pytest.approx(0.05)
assert features["turn_20"] == pytest.approx(3.0)
assert features["news_sentiment"] == pytest.approx(0.3)
assert features["news_heat"] == pytest.approx(0.4)
assert features["risk_penalty"] == pytest.approx(min(1.0, 0.05 * 5.0))

View File

@@ -5,9 +5,20 @@ from datetime import date
import pytest
import json
from app.agents.base import AgentAction, AgentContext
from app.agents.game import Decision
from app.backtest.engine import (
    BacktestEngine,
    BacktestResult,
    BtConfig,
    PortfolioState,
    _persist_backtest_results,
)
from app.data.schema import initialize_database
from app.utils.config import DataPaths, get_config
from app.utils.db import db_session
def _make_context(price: float, features: dict | None = None) -> AgentContext:
@@ -43,6 +54,20 @@ def _engine_with_params(params: dict[str, float]) -> BacktestEngine:
    return BacktestEngine(cfg)
@pytest.fixture()
def isolated_db(tmp_path):
cfg = get_config()
original_paths = cfg.data_paths
tmp_root = tmp_path / "data"
tmp_root.mkdir(parents=True, exist_ok=True)
cfg.data_paths = DataPaths(root=tmp_root)
initialize_database()
try:
yield
finally:
cfg.data_paths = original_paths
def test_buy_respects_risk_caps():
    engine = _engine_with_params(
        {
@@ -130,3 +155,56 @@ def test_sell_applies_slippage_and_fee():
    assert not state.holdings
    assert result.nav_series[0]["turnover"] == pytest.approx(trade["value"])
    assert not result.risk_events
def test_persist_backtest_results_saves_risk_events(isolated_db):
cfg = BtConfig(
id="risk_cfg",
name="risk",
start_date=date(2025, 1, 10),
end_date=date(2025, 1, 10),
universe=["000001.SZ"],
params={},
)
result = BacktestResult()
result.nav_series = [
{
"trade_date": "2025-01-10",
"nav": 100.0,
"cash": 100.0,
"market_value": 0.0,
"realized_pnl": 0.0,
"unrealized_pnl": 0.0,
"turnover": 0.0,
}
]
result.risk_events = [
{
"trade_date": "2025-01-10",
"ts_code": "000001.SZ",
"reason": "limit_up",
"action": "buy_l",
"target_weight": 0.3,
"confidence": 0.8,
}
]
_persist_backtest_results(cfg, result)
with db_session(read_only=True) as conn:
risk_row = conn.execute(
"SELECT reason, metadata FROM bt_risk_events WHERE cfg_id = ?",
(cfg.id,),
).fetchone()
assert risk_row is not None
assert risk_row["reason"] == "limit_up"
metadata = json.loads(risk_row["metadata"])
assert metadata["action"] == "buy_l"
summary_row = conn.execute(
"SELECT summary FROM bt_report WHERE cfg_id = ?",
(cfg.id,),
).fetchone()
summary = json.loads(summary_row["summary"])
assert summary["risk_events"] == 1
assert summary["risk_breakdown"]["limit_up"] == 1

View File

@@ -0,0 +1,92 @@
"""Tests for epsilon-greedy bandit optimizer."""
from __future__ import annotations
import pytest
from app.backtest.decision_env import EpisodeMetrics, ParameterSpec
from app.backtest.optimizer import BanditConfig, EpsilonGreedyBandit
from app.utils import tuning
class DummyEnv:
def __init__(self) -> None:
self._specs = [
ParameterSpec(name="w1", target="agent_weights.A_mom", minimum=0.0, maximum=1.0)
]
self._last_metrics: EpisodeMetrics | None = None
self._episode = 0
@property
def action_dim(self) -> int:
return 1
@property
def last_metrics(self) -> EpisodeMetrics | None:
return self._last_metrics
def reset(self) -> dict:
self._episode += 1
return {"episode": float(self._episode)}
def step(self, action):
value = float(action[0])
reward = 1.0 - abs(value - 0.7)
metrics = EpisodeMetrics(
total_return=reward,
max_drawdown=0.1,
volatility=0.05,
nav_series=[],
trades=[],
turnover=100.0,
trade_count=0,
risk_count=1,
risk_breakdown={"test": 1},
)
self._last_metrics = metrics
obs = {
"total_return": reward,
"max_drawdown": 0.1,
"volatility": 0.05,
"sharpe_like": reward / 0.05,
"turnover": 100.0,
"trade_count": 0.0,
"risk_count": 1.0,
}
info = {
"nav_series": [],
"trades": [],
"weights": {"A_mom": value},
"risk_breakdown": metrics.risk_breakdown,
"risk_events": [],
}
return obs, reward, True, info
@pytest.fixture(autouse=True)
def patch_logging(monkeypatch):
records = []
def fake_log_tuning_result(**kwargs):
records.append(kwargs)
monkeypatch.setattr(tuning, "log_tuning_result", fake_log_tuning_result)
from app.backtest import optimizer as optimizer_module
monkeypatch.setattr(optimizer_module, "log_tuning_result", fake_log_tuning_result)
return records
def test_bandit_optimizer_runs_and_logs(patch_logging):
env = DummyEnv()
optimizer = EpsilonGreedyBandit(
env,
BanditConfig(experiment_id="exp", episodes=5, epsilon=0.5, seed=42),
)
summary = optimizer.run()
assert len(summary.episodes) == 5
assert summary.best_episode is not None
assert patch_logging and len(patch_logging) == 5
payload = patch_logging[0]["metrics"]
assert isinstance(payload, dict)
assert "risk_breakdown" in payload

View File

@@ -0,0 +1,92 @@
"""Tests for DecisionEnv risk-aware reward and info outputs."""
from __future__ import annotations
from datetime import date
import pytest
from app.backtest.decision_env import DecisionEnv, EpisodeMetrics, ParameterSpec
from app.backtest.engine import BacktestResult, BtConfig
class _StubEngine:
def __init__(self, cfg: BtConfig) -> None: # noqa: D401
self.cfg = cfg
self.weights = {}
self.department_manager = None
def run(self) -> BacktestResult:
result = BacktestResult()
result.nav_series = [
{
"trade_date": "2025-01-10",
"nav": 102.0,
"cash": 50.0,
"market_value": 52.0,
"realized_pnl": 1.0,
"unrealized_pnl": 1.0,
"turnover": 20000.0,
}
]
result.trades = [
{
"trade_date": "2025-01-10",
"ts_code": "000001.SZ",
"action": "buy",
"quantity": 100.0,
"price": 100.0,
"value": 10000.0,
"fee": 5.0,
}
]
result.risk_events = [
{
"trade_date": "2025-01-10",
"ts_code": "000002.SZ",
"reason": "limit_up",
"action": "buy_l",
"confidence": 0.7,
"target_weight": 0.2,
}
]
return result
def test_decision_env_returns_risk_metrics(monkeypatch):
cfg = BtConfig(
id="stub",
name="stub",
start_date=date(2025, 1, 10),
end_date=date(2025, 1, 10),
universe=["000001.SZ"],
params={},
)
specs = [ParameterSpec(name="w_mom", target="agent_weights.A_mom", minimum=0.0, maximum=1.0)]
env = DecisionEnv(bt_config=cfg, parameter_specs=specs, baseline_weights={"A_mom": 0.5})
monkeypatch.setattr("app.backtest.decision_env.BacktestEngine", _StubEngine)
obs, reward, done, info = env.step([0.8])
assert done is True
assert "risk_count" in obs and obs["risk_count"] == 1.0
assert obs["turnover"] == pytest.approx(20000.0)
assert info["risk_events"][0]["reason"] == "limit_up"
assert info["risk_breakdown"]["limit_up"] == 1
assert reward < obs["total_return"]
def test_default_reward_penalizes_metrics():
metrics = EpisodeMetrics(
total_return=0.1,
max_drawdown=0.2,
volatility=0.05,
nav_series=[],
trades=[],
turnover=1000.0,
trade_count=0,
risk_count=2,
risk_breakdown={"foo": 2},
)
reward = DecisionEnv._default_reward(metrics)
assert reward == pytest.approx(0.1 - (0.5 * 0.2 + 0.05 * 2 + 0.00001 * 1000.0))

View File

@@ -0,0 +1,81 @@
"""Tests for tuning result selection and CLI application."""
from __future__ import annotations
import json
import pytest
from app.data.schema import initialize_database
from app.utils.config import DataPaths, get_config
from app.utils.db import db_session
from app.utils.tuning import select_best_tuning_result
import scripts.apply_best_weights as apply_best_weights
@pytest.fixture()
def isolated_env(tmp_path):
cfg = get_config()
original_paths = cfg.data_paths
tmp_root = tmp_path / "data"
tmp_root.mkdir(parents=True, exist_ok=True)
cfg.data_paths = DataPaths(root=tmp_root)
initialize_database()
try:
yield cfg
finally:
cfg.data_paths = original_paths
def _insert_result(experiment: str, reward: float, metrics: dict, weights: dict | None = None, action: dict | None = None) -> None:
with db_session() as conn:
conn.execute(
"""
INSERT INTO tuning_results (experiment_id, strategy, action, weights, reward, metrics)
VALUES (?, ?, ?, ?, ?, ?)
""",
(
experiment,
"test",
json.dumps(action or {}, ensure_ascii=False),
json.dumps(weights or {}, ensure_ascii=False),
reward,
json.dumps(metrics, ensure_ascii=False),
),
)
def test_select_best_by_reward(isolated_env):
_insert_result("exp", 0.1, {"risk_count": 2}, {"A_mom": 0.3})
_insert_result("exp", 0.25, {"risk_count": 4}, {"A_mom": 0.6})
best = select_best_tuning_result("exp")
assert best is not None
assert best["reward"] == pytest.approx(0.25)
assert best["weights"]["A_mom"] == pytest.approx(0.6)
def test_select_best_by_metric(isolated_env):
_insert_result("exp_metric", 0.2, {"risk_count": 5}, {"A_mom": 0.4})
_insert_result("exp_metric", 0.1, {"risk_count": 2}, {"A_mom": 0.7})
best = select_best_tuning_result("exp_metric", metric="risk_count", descending=False)
assert best is not None
assert best["weights"]["A_mom"] == pytest.approx(0.7)
assert best["metrics"]["risk_count"] == 2
def test_apply_best_weights_cli_updates_config(isolated_env, capsys):
cfg = isolated_env
_insert_result("exp_cli", 0.3, {"risk_count": 1}, {"A_mom": 0.65, "A_val": 0.2})
exit_code = apply_best_weights.run_cli([
"exp_cli",
"--apply-config",
])
assert exit_code == 0
output = capsys.readouterr().out
payload = json.loads(output)
assert payload["metric"] == "reward"
updated = cfg.agent_weights.as_dict()
assert updated["A_mom"] == pytest.approx(0.65)
assert updated["A_val"] == pytest.approx(0.2)