sam 2025-09-30 18:34:29 +08:00
parent 8befd80cb7
commit 07e5bb1b68
14 changed files with 995 additions and 34 deletions

View File

@ -36,6 +36,10 @@ class EpisodeMetrics:
volatility: float
nav_series: List[Dict[str, float]]
trades: List[Dict[str, object]]
turnover: float
trade_count: int
risk_count: int
risk_breakdown: Dict[str, int]
@property
def sharpe_like(self) -> float:
@ -109,11 +113,16 @@ class DecisionEnv:
"max_drawdown": metrics.max_drawdown,
"volatility": metrics.volatility,
"sharpe_like": metrics.sharpe_like,
"turnover": metrics.turnover,
"trade_count": float(metrics.trade_count),
"risk_count": float(metrics.risk_count),
}
info = {
"nav_series": metrics.nav_series,
"trades": metrics.trades,
"weights": weights,
"risk_breakdown": metrics.risk_breakdown,
"risk_events": getattr(result, "risk_events", []),
}
return observation, reward, True, info
@ -131,7 +140,21 @@ class DecisionEnv:
def _compute_metrics(self, result: BacktestResult) -> EpisodeMetrics:
nav_series = result.nav_series or []
if not nav_series:
return EpisodeMetrics(0.0, 0.0, 0.0, [], result.trades)
risk_breakdown: Dict[str, int] = {}
for event in getattr(result, "risk_events", []) or []:
reason = str(event.get("reason") or "unknown")
risk_breakdown[reason] = risk_breakdown.get(reason, 0) + 1
return EpisodeMetrics(
total_return=0.0,
max_drawdown=0.0,
volatility=0.0,
nav_series=[],
trades=result.trades,
turnover=0.0,
trade_count=len(result.trades or []),
risk_count=len(getattr(result, "risk_events", []) or []),
risk_breakdown=risk_breakdown,
)
nav_values = [row.get("nav", 0.0) for row in nav_series]
if not nav_values or nav_values[0] == 0:
@ -158,17 +181,30 @@ class DecisionEnv:
else:
volatility = 0.0
turnover = sum(float(row.get("turnover", 0.0) or 0.0) for row in nav_series)
risk_events = getattr(result, "risk_events", []) or []
risk_breakdown: Dict[str, int] = {}
for event in risk_events:
reason = str(event.get("reason") or "unknown")
risk_breakdown[reason] = risk_breakdown.get(reason, 0) + 1
return EpisodeMetrics(
total_return=float(total_return),
max_drawdown=float(max_drawdown),
volatility=volatility,
nav_series=nav_series,
trades=result.trades,
turnover=float(turnover),
trade_count=len(result.trades or []),
risk_count=len(risk_events),
risk_breakdown=risk_breakdown,
)
@staticmethod
def _default_reward(metrics: EpisodeMetrics) -> float:
penalty = 0.5 * metrics.max_drawdown
risk_penalty = 0.05 * metrics.risk_count
turnover_penalty = 0.00001 * metrics.turnover
penalty = 0.5 * metrics.max_drawdown + risk_penalty + turnover_penalty
return metrics.total_return - penalty
@property
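For reference, the default reward now subtracts three penalties from the episode return. A minimal standalone check of the arithmetic, with illustrative values (the same numbers the new DecisionEnv test uses):

# Illustrative only: recompute DecisionEnv._default_reward by hand.
total_return = 0.10   # episode return
max_drawdown = 0.20   # peak-to-trough drawdown
risk_count = 2        # number of risk events
turnover = 1000.0     # summed turnover over the episode

penalty = 0.5 * max_drawdown + 0.05 * risk_count + 0.00001 * turnover
reward = total_return - penalty
assert abs(reward - (-0.11)) < 1e-9  # 0.10 - (0.10 + 0.10 + 0.01)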

View File

@ -98,6 +98,12 @@ class BacktestEngine:
"daily_basic.volume_ratio",
"stk_limit.up_limit",
"stk_limit.down_limit",
"factors.mom_20",
"factors.mom_60",
"factors.volat_20",
"factors.turn_20",
"news.sentiment_index",
"news.heat_score",
}
self.required_fields = sorted(base_scope | department_scope)
@ -121,10 +127,19 @@ class BacktestEngine:
trade_date_str,
window=60,
)
close_values = [value for _date, value in closes]
mom20 = momentum(close_values, 20)
mom60 = momentum(close_values, 60)
volat20 = volatility(close_values, 20)
close_values = [value for _date, value in closes if value is not None]
mom20 = scope_values.get("factors.mom_20")
if mom20 is None and len(close_values) >= 20:
mom20 = momentum(close_values, 20)
mom60 = scope_values.get("factors.mom_60")
if mom60 is None and len(close_values) >= 60:
mom60 = momentum(close_values, 60)
volat20 = scope_values.get("factors.volat_20")
if volat20 is None and len(close_values) >= 2:
volat20 = volatility(close_values, 20)
turnover_series = self.data_broker.fetch_series(
"daily_basic",
@ -133,8 +148,20 @@ class BacktestEngine:
trade_date_str,
window=20,
)
turnover_values = [value for _date, value in turnover_series]
turn20 = rolling_mean(turnover_values, 20)
turnover_values = [value for _date, value in turnover_series if value is not None]
turn20 = scope_values.get("factors.turn_20")
if turn20 is None and turnover_values:
turn20 = rolling_mean(turnover_values, 20)
if mom20 is None:
mom20 = 0.0
if mom60 is None:
mom60 = 0.0
if volat20 is None:
volat20 = 0.0
if turn20 is None:
turn20 = 0.0
liquidity_score = normalize(turn20, factor=20.0)
cost_penalty = normalize(
@ -142,12 +169,15 @@ class BacktestEngine:
factor=50.0,
)
sentiment_index = scope_values.get("news.sentiment_index", 0.0)
heat_score = scope_values.get("news.heat_score", 0.0)
scope_values.setdefault("news.sentiment_index", sentiment_index)
scope_values.setdefault("news.heat_score", heat_score)
scope_values.setdefault("factors.mom_20", mom20)
scope_values.setdefault("factors.mom_60", mom60)
scope_values.setdefault("factors.volat_20", volat20)
scope_values.setdefault("factors.turn_20", turn20)
scope_values.setdefault("news.sentiment_index", 0.0)
scope_values.setdefault("news.heat_score", 0.0)
if scope_values.get("macro.industry_heat") is None:
scope_values["macro.industry_heat"] = 0.5
if scope_values.get("macro.relative_strength") is None:
@ -189,8 +219,8 @@ class BacktestEngine:
"turn_20": turn20,
"liquidity_score": liquidity_score,
"cost_penalty": cost_penalty,
"news_heat": scope_values.get("news.heat_score", 0.0),
"news_sentiment": scope_values.get("news.sentiment_index", 0.0),
"news_heat": heat_score,
"news_sentiment": sentiment_index,
"industry_heat": scope_values.get("macro.industry_heat", 0.0),
"industry_relative_mom": scope_values.get(
"macro.relative_strength",
@ -818,6 +848,7 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None:
nav_rows: List[tuple] = []
trade_rows: List[tuple] = []
risk_rows: List[tuple] = []
summary_payload: Dict[str, object] = {}
turnover_sum = 0.0
@ -893,6 +924,10 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None:
"confidence": trade.get("confidence"),
"target_weight": trade.get("target_weight"),
"value": trade.get("value"),
"fee": trade.get("fee"),
"slippage": trade.get("slippage"),
"risk_penalty": trade.get("risk_penalty"),
"liquidity_score": trade.get("liquidity_score"),
}
trade_rows.append(
(
@ -913,6 +948,18 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None:
for event in result.risk_events:
reason = str(event.get("reason") or "unknown")
breakdown[reason] = breakdown.get(reason, 0) + 1
risk_rows.append(
(
cfg.id,
str(event.get("trade_date", "")),
str(event.get("ts_code", "")),
reason,
str(event.get("action", "")),
float(event.get("target_weight", 0.0) or 0.0),
float(event.get("confidence", 0.0) or 0.0),
json.dumps(event, ensure_ascii=False),
)
)
summary_payload["risk_breakdown"] = breakdown
cfg_payload = {
@ -943,6 +990,7 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None:
conn.execute("DELETE FROM bt_nav WHERE cfg_id = ?", (cfg.id,))
conn.execute("DELETE FROM bt_trades WHERE cfg_id = ?", (cfg.id,))
conn.execute("DELETE FROM bt_risk_events WHERE cfg_id = ?", (cfg.id,))
conn.execute("DELETE FROM bt_report WHERE cfg_id = ?", (cfg.id,))
if nav_rows:
@ -963,6 +1011,15 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None:
trade_rows,
)
if risk_rows:
conn.executemany(
"""
INSERT INTO bt_risk_events (cfg_id, trade_date, ts_code, reason, action, target_weight, confidence, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
risk_rows,
)
summary_payload.setdefault("universe", cfg.universe)
summary_payload.setdefault("method", cfg.method)
conn.execute(
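The factor hunks above all follow one precedence rule: a persisted `factors.*` value wins, recomputing from the fetched series is the fallback, and 0.0 is the final floor. A minimal sketch of that rule as a helper (the helper itself is illustrative, not part of the engine; `scope_values` and `momentum` mirror names from the diff):

from typing import Callable, Dict, Optional

def resolve_factor(
    scope_values: Dict[str, float],
    key: str,
    compute: Callable[[], Optional[float]],
) -> float:
    # Prefer the persisted value, fall back to recomputation, floor at 0.0.
    value = scope_values.get(key)
    if value is None:
        value = compute()
    return 0.0 if value is None else float(value)

# e.g. mom20 = resolve_factor(scope_values, "factors.mom_20",
#     lambda: momentum(close_values, 20) if len(close_values) >= 20 else None)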

app/backtest/optimizer.py (new file, 139 lines)
View File

@ -0,0 +1,139 @@
"""Optimization utilities for DecisionEnv-based parameter tuning."""
from __future__ import annotations
import random
from dataclasses import dataclass, field
from typing import Dict, Iterable, List, Sequence, Tuple
from app.backtest.decision_env import DecisionEnv, EpisodeMetrics, ParameterSpec
from app.utils.logging import get_logger
from app.utils.tuning import log_tuning_result
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "decision_bandit"}
@dataclass
class BanditConfig:
"""Configuration for epsilon-greedy bandit optimization."""
experiment_id: str
strategy: str = "epsilon_greedy"
episodes: int = 20
epsilon: float = 0.2
seed: int | None = None
@dataclass
class BanditEpisode:
action: Dict[str, float]
reward: float
metrics: EpisodeMetrics
observation: Dict[str, float]
@dataclass
class BanditSummary:
episodes: List[BanditEpisode] = field(default_factory=list)
@property
def best_episode(self) -> BanditEpisode | None:
if not self.episodes:
return None
return max(self.episodes, key=lambda item: item.reward)
@property
def average_reward(self) -> float:
if not self.episodes:
return 0.0
return sum(item.reward for item in self.episodes) / len(self.episodes)
class EpsilonGreedyBandit:
"""Simple epsilon-greedy tuner using DecisionEnv as the reward oracle."""
def __init__(self, env: DecisionEnv, config: BanditConfig) -> None:
self.env = env
self.config = config
self._random = random.Random(config.seed)
self._specs: List[ParameterSpec] = list(getattr(env, "_specs", []))
if not self._specs:
raise ValueError("DecisionEnv does not expose parameter specs")
self._value_estimates: Dict[Tuple[float, ...], float] = {}
self._counts: Dict[Tuple[float, ...], int] = {}
self._history = BanditSummary()
def run(self) -> BanditSummary:
for episode in range(1, self.config.episodes + 1):
action = self._select_action()
self.env.reset()
obs, reward, done, info = self.env.step(action)
metrics = self.env.last_metrics
if metrics is None:
raise RuntimeError("DecisionEnv did not populate last_metrics")
key = tuple(action)
old_estimate = self._value_estimates.get(key, 0.0)
count = self._counts.get(key, 0) + 1
self._counts[key] = count
self._value_estimates[key] = old_estimate + (reward - old_estimate) / count
action_payload = self._action_to_mapping(action)
metrics_payload = _metrics_to_dict(metrics)
try:
log_tuning_result(
experiment_id=self.config.experiment_id,
strategy=self.config.strategy,
action=action_payload,
reward=reward,
metrics=metrics_payload,
weights=info.get("weights"),
)
except Exception: # noqa: BLE001
LOGGER.exception("failed to log tuning result", extra=LOG_EXTRA)
episode_record = BanditEpisode(
action=action_payload,
reward=reward,
metrics=metrics,
observation=obs,
)
self._history.episodes.append(episode_record)
LOGGER.info(
"Bandit episode=%s reward=%.4f action=%s",
episode,
reward,
action_payload,
extra=LOG_EXTRA,
)
return self._history
def _select_action(self) -> List[float]:
if self._value_estimates and self._random.random() > self.config.epsilon:
best = max(self._value_estimates.items(), key=lambda item: item[1])[0]
return list(best)
return [
self._random.uniform(spec.minimum, spec.maximum)
for spec in self._specs
]
def _action_to_mapping(self, action: Sequence[float]) -> Dict[str, float]:
return {
spec.name: float(value)
for spec, value in zip(self._specs, action, strict=True)
}
def _metrics_to_dict(metrics: EpisodeMetrics) -> Dict[str, float | Dict[str, int]]:
payload: Dict[str, float | Dict[str, int]] = {
"total_return": metrics.total_return,
"max_drawdown": metrics.max_drawdown,
"volatility": metrics.volatility,
"sharpe_like": metrics.sharpe_like,
"turnover": metrics.turnover,
"trade_count": float(metrics.trade_count),
"risk_count": float(metrics.risk_count),
}
if metrics.risk_breakdown:
payload["risk_breakdown"] = dict(metrics.risk_breakdown)
return payload
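The `_value_estimates` update in `run()` is the standard incremental mean, so each arm's estimate equals the average reward observed for it. A quick standalone check:

# Incremental mean: after n rewards the estimate equals their average.
rewards = [0.2, 0.4, 0.9]
estimate, count = 0.0, 0
for r in rewards:
    count += 1
    estimate += (r - estimate) / count
assert abs(estimate - sum(rewards) / len(rewards)) < 1e-12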

View File

@ -327,6 +327,18 @@ SCHEMA_STATEMENTS: Iterable[str] = (
);
""",
"""
CREATE TABLE IF NOT EXISTS bt_risk_events (
cfg_id TEXT,
trade_date TEXT,
ts_code TEXT,
reason TEXT,
action TEXT,
target_weight REAL,
confidence REAL,
metadata TEXT
);
""",
"""
CREATE TABLE IF NOT EXISTS bt_nav (
cfg_id TEXT,
trade_date TEXT,
@ -472,6 +484,7 @@ REQUIRED_TABLES = (
"heat_daily",
"bt_config",
"bt_trades",
"bt_risk_events",
"bt_nav",
"bt_report",
"run_log",

View File

@ -118,13 +118,12 @@ class DataBroker:
if cached is not None:
return deepcopy(cached)
grouped: Dict[str, List[str]] = {}
field_map: Dict[Tuple[str, str], List[str]] = {}
grouped: Dict[str, List[Tuple[str, str]]] = {}
derived_cache: Dict[str, Any] = {}
results: Dict[str, Any] = {}
for field_name in field_list:
resolved = self.resolve_field(field_name)
if not resolved:
parsed = parse_field_path(field_name)
if not parsed:
derived = self._resolve_derived_field(
ts_code,
trade_date,
@ -134,11 +133,8 @@ class DataBroker:
if derived is not None:
results[field_name] = derived
continue
table, column = resolved
grouped.setdefault(table, [])
if column not in grouped[table]:
grouped[table].append(column)
field_map.setdefault((table, column), []).append(field_name)
table, column = parsed
grouped.setdefault(table, []).append((column, field_name))
if not grouped:
if cache_key is not None and results:
@ -152,10 +148,9 @@ class DataBroker:
try:
with db_session(read_only=True) as conn:
for table, columns in grouped.items():
joined_cols = ", ".join(columns)
for table, items in grouped.items():
query = (
f"SELECT trade_date, {joined_cols} FROM {table} "
f"SELECT * FROM {table} "
"WHERE ts_code = ? AND trade_date <= ? "
"ORDER BY trade_date DESC LIMIT 1"
)
@ -165,22 +160,25 @@ class DataBroker:
LOGGER.debug(
"查询失败 table=%s fields=%s err=%s",
table,
columns,
[column for column, _field in items],
exc,
extra=LOG_EXTRA,
)
continue
if not row:
continue
for column in columns:
value = row[column]
available = row.keys()
for column, original in items:
resolved_column = self._resolve_column_in_row(table, column, available)
if resolved_column is None:
continue
value = row[resolved_column]
if value is None:
continue
for original in field_map.get((table, column), [f"{table}.{column}"]):
try:
results[original] = float(value)
except (TypeError, ValueError):
results[original] = value
try:
results[original] = float(value)
except (TypeError, ValueError):
results[original] = value
except sqlite3.OperationalError as exc:
LOGGER.debug("数据库只读连接失败:%s", exc, extra=LOG_EXTRA)
if cache_key is not None:
@ -698,6 +696,22 @@ class DataBroker:
while len(cache) > limit:
cache.popitem(last=False)
def _resolve_column_in_row(
self,
table: str,
column: str,
available: Sequence[str],
) -> Optional[str]:
alias_map = self.FIELD_ALIASES.get(table, {})
candidate = alias_map.get(column, column)
if candidate in available:
return candidate
lowered = candidate.lower()
for name in available:
if name.lower() == lowered:
return name
return None
def _resolve_column(self, table: str, column: str) -> Optional[str]:
columns = self._get_table_columns(table)
if columns is None:
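`_resolve_column_in_row` tries the table's alias first, then an exact name match, then a case-insensitive match. A self-contained model of that lookup (names are illustrative):

from typing import Dict, Optional, Sequence

def resolve_column(
    alias_map: Dict[str, str],
    column: str,
    available: Sequence[str],
) -> Optional[str]:
    candidate = alias_map.get(column, column)   # alias wins if defined
    if candidate in available:                  # exact match
        return candidate
    lowered = candidate.lower()
    for name in available:                      # case-insensitive fallback
        if name.lower() == lowered:
            return name
    return None

assert resolve_column({"turn_20": "turnover_20"}, "turn_20",
                      ["ts_code", "TURNOVER_20"]) == "TURNOVER_20"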

View File

@ -2,7 +2,8 @@
from __future__ import annotations
import json
from typing import Any, Dict, Optional
from typing import Any, Dict, Mapping, Optional
from .db import db_session
from .logging import get_logger
@ -40,3 +41,96 @@ def log_tuning_result(
)
except Exception: # noqa: BLE001
LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA)
def select_best_tuning_result(
experiment_id: str,
*,
metric: str = "reward",
descending: bool = True,
require_weights: bool = False,
) -> Optional[Dict[str, Any]]:
"""Return the best tuning result for the given experiment.
``metric`` may refer to ``reward`` (default) or any key inside the
persisted metrics payload. When ``require_weights`` is True, rows lacking
weight definitions are ignored.
"""
with db_session(read_only=True) as conn:
rows = conn.execute(
"""
SELECT id, action, weights, reward, metrics, created_at
FROM tuning_results
WHERE experiment_id = ?
""",
(experiment_id,),
).fetchall()
if not rows:
return None
best_row: Optional[Mapping[str, Any]] = None
best_metrics: Dict[str, Any] = {}
best_action: Dict[str, float] = {}
best_weights: Dict[str, float] = {}
best_score: Optional[float] = None
for row in rows:
action = _decode_json(row["action"])
weights = _decode_json(row["weights"])
metrics_payload = _decode_json(row["metrics"])
reward_value = float(row["reward"] or 0.0)
if require_weights and not weights:
continue
if metric == "reward":
score = reward_value
else:
score_raw = metrics_payload.get(metric)
if score_raw is None:
continue
try:
score = float(score_raw)
except (TypeError, ValueError):
continue
if best_score is None:
choose = True
else:
choose = score > best_score if descending else score < best_score
if choose:
best_score = score
best_row = row
best_metrics = metrics_payload
best_action = action
best_weights = weights
if best_row is None:
return None
return {
"id": best_row["id"],
"reward": float(best_row["reward"] or 0.0),
"score": best_score,
"metric": metric,
"action": best_action,
"weights": best_weights,
"metrics": best_metrics,
"created_at": best_row["created_at"],
}
def _decode_json(payload: Any) -> Dict[str, Any]:
if not payload:
return {}
if isinstance(payload, Mapping):
return dict(payload)
if isinstance(payload, str):
try:
return json.loads(payload)
except json.JSONDecodeError:
return {}
return {}
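Hypothetical call sites for the selector above, ranking by reward (the default) or by any numeric key stored in the metrics payload (the experiment id is illustrative):

from app.utils.tuning import select_best_tuning_result

best = select_best_tuning_result("exp_cli")  # highest reward
safest = select_best_tuning_result(
    "exp_cli",
    metric="risk_count",
    descending=False,       # fewer risk events is better
    require_weights=True,   # skip rows without a weights payload
)
if safest is not None:
    print(safest["weights"], safest["score"])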

View File

@ -13,6 +13,7 @@
## 2. Data and Feature Layer
- Implement `compute_factors()` in `app/features/factors.py`, completing factor computation and the persistence flow.
- DataBroker `fetch_latest` queries now read the entire row and pick fields on demand, so a missing column no longer raises; subsequent data-access logic follows this convention (see the sketch after this list).
- Finish the RSS fetch and write-to-DB logic in `app/ingest/rss.py`, wiring up the news and sentiment data sources.
- Strengthen `DataBroker` validation, caching, and fallback strategies so market-data/feature backfills are uniformly automated, reducing manual intervention.
- Expand a lightweight, high-quality factor set around core signals such as momentum, valuation, and liquidity, all generated programmatically, to support end-to-end automated decision-making.
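A minimal sketch of the `fetch_latest` convention described above, assuming a plain `sqlite3` connection and the `daily_basic` table from this commit: select the whole latest row, then pick columns defensively so a missing column degrades to "no value" instead of raising.

import sqlite3
from typing import Any, Optional

def latest_value(conn: sqlite3.Connection, ts_code: str,
                 trade_date: str, column: str) -> Optional[Any]:
    conn.row_factory = sqlite3.Row
    row = conn.execute(
        "SELECT * FROM daily_basic WHERE ts_code = ? AND trade_date <= ? "
        "ORDER BY trade_date DESC LIMIT 1",
        (ts_code, trade_date),
    ).fetchone()
    if row is None or column not in row.keys():
        return None  # missing row or column: no value, no exception
    return row[column]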

View File

@ -0,0 +1,83 @@
"""Apply or display the best tuning result for an experiment."""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
from typing import Iterable
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
from app.utils.config import get_config, save_config
from app.utils.tuning import select_best_tuning_result
from app.utils.logging import get_logger
LOGGER = get_logger(__name__)
def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="Apply best tuning weights")
parser.add_argument("experiment_id", help="Experiment identifier")
parser.add_argument(
"--metric",
default="reward",
help="Metric name for ranking (default: reward)",
)
parser.add_argument(
"--ascending",
action="store_true",
help="Sort metric ascending instead of descending",
)
parser.add_argument(
"--require-weights",
action="store_true",
help="Ignore records without weight payload",
)
parser.add_argument(
"--apply-config",
action="store_true",
help="Update agent_weights in config with best result weights (fallback to action)",
)
return parser
def run_cli(argv: Iterable[str] | None = None) -> int:
parser = build_parser()
args = parser.parse_args(list(argv) if argv is not None else None)
best = select_best_tuning_result(
args.experiment_id,
metric=args.metric,
descending=not args.ascending,
require_weights=args.require_weights,
)
if not best:
LOGGER.error("未找到实验结果 experiment_id=%s", args.experiment_id)
return 1
print(json.dumps(best, ensure_ascii=False, indent=2))
if args.apply_config:
weights = best.get("weights") or best.get("action")
if not weights:
LOGGER.error("最佳结果缺少权重信息,无法更新配置")
return 2
cfg = get_config()
if not cfg.agent_weights:
LOGGER.warning("配置缺少 agent_weights初始化默认值")
cfg.agent_weights.update_from_dict(weights)
save_config(cfg)
LOGGER.info("已写入新的 agent_weights 至配置")
return 0
def main() -> None:
raise SystemExit(run_cli())
if __name__ == "__main__":
main()
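Illustrative invocations of the CLI above (the experiment id is hypothetical; `run_cli` returns the process exit code):

import scripts.apply_best_weights as apply_best_weights

# Print the best row by reward without touching the config.
apply_best_weights.run_cli(["exp_cli"])

# Rank by a metrics key instead, ascending, and write agent_weights back.
apply_best_weights.run_cli([
    "exp_cli", "--metric", "risk_count", "--ascending",
    "--require-weights", "--apply-config",
])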

View File

@ -0,0 +1,124 @@
"""Run epsilon-greedy bandit tuning on DecisionEnv."""
from __future__ import annotations
import argparse
import json
import sys
from datetime import datetime, date
from pathlib import Path
from typing import Iterable, List
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
from app.agents.registry import default_agents
from app.backtest.decision_env import DecisionEnv, ParameterSpec
from app.backtest.engine import BtConfig
from app.backtest.optimizer import BanditConfig, EpsilonGreedyBandit
from app.utils.config import get_config
def _parse_date(value: str) -> date:
return datetime.strptime(value, "%Y%m%d").date()
def _parse_param(text: str) -> ParameterSpec:
parts = text.split(":")
if len(parts) not in {3, 4}:
raise argparse.ArgumentTypeError(
"parameter format must be name:target:min[:max]"
)
name, target, minimum = parts[:3]
maximum = parts[3] if len(parts) == 4 else "1.0"
return ParameterSpec(
name=name,
target=target,
minimum=float(minimum),
maximum=float(maximum),
)
def _resolve_baseline_weights() -> dict:
cfg = get_config()
if cfg.agent_weights:
return cfg.agent_weights.as_dict()
return {agent.name: 1.0 for agent in default_agents()}
def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="DecisionEnv bandit optimizer")
parser.add_argument("experiment_id", help="Experiment identifier to log results")
parser.add_argument("name", help="Backtest config name")
parser.add_argument("start", type=_parse_date, help="Start date YYYYMMDD")
parser.add_argument("end", type=_parse_date, help="End date YYYYMMDD")
parser.add_argument(
"--universe",
required=True,
help="Comma separated ts_codes, e.g. 000001.SZ,000002.SZ",
)
parser.add_argument(
"--param",
action="append",
required=True,
help="Parameter spec name:target:min[:max] (target like agent_weights.A_mom)",
)
parser.add_argument("--episodes", type=int, default=20)
parser.add_argument("--epsilon", type=float, default=0.2)
parser.add_argument("--seed", type=int, default=None)
return parser
def run_cli(argv: Iterable[str] | None = None) -> int:
parser = build_parser()
args = parser.parse_args(list(argv) if argv is not None else None)
if args.end < args.start:
parser.error("end date must not precede start date")
specs: List[ParameterSpec] = [_parse_param(item) for item in args.param]
universe = [token.strip() for token in args.universe.split(",") if token.strip()]
bt_cfg = BtConfig(
id=args.experiment_id,
name=args.name,
start_date=args.start,
end_date=args.end,
universe=universe,
params={},
)
env = DecisionEnv(
bt_config=bt_cfg,
parameter_specs=specs,
baseline_weights=_resolve_baseline_weights(),
)
optimizer = EpsilonGreedyBandit(
env,
BanditConfig(
experiment_id=args.experiment_id,
episodes=args.episodes,
epsilon=args.epsilon,
seed=args.seed,
),
)
summary = optimizer.run()
best = summary.best_episode
output = {
"episodes": len(summary.episodes),
"average_reward": summary.average_reward,
"best": {
"reward": best.reward if best else None,
"action": best.action if best else None,
"metrics": (best.metrics and json.dumps(best.metrics.risk_breakdown)) if best else None,
},
}
print(json.dumps(output, ensure_ascii=False, indent=2))
return 0
def main() -> None:
raise SystemExit(run_cli())
if __name__ == "__main__":
main()
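The `--param` grammar is `name:target:min[:max]`, with `max` defaulting to 1.0. A quick check of `_parse_param` against that contract (values are illustrative; assumes the repo root is on sys.path so the script imports as a module):

from scripts.run_decision_bandit import _parse_param

spec = _parse_param("w_mom:agent_weights.A_mom:0.0:1.5")
assert (spec.name, spec.target) == ("w_mom", "agent_weights.A_mom")
assert (spec.minimum, spec.maximum) == (0.0, 1.5)
assert _parse_param("w_val:agent_weights.A_val:0.2").maximum == 1.0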

View File

@ -0,0 +1,57 @@
"""Verify BacktestEngine consumes persisted factor fields."""
from __future__ import annotations
from datetime import date
import pytest
from app.backtest.engine import BacktestEngine, BtConfig
@pytest.fixture()
def engine(monkeypatch):
cfg = BtConfig(
id="test",
name="factor",
start_date=date(2025, 1, 10),
end_date=date(2025, 1, 10),
universe=["000001.SZ"],
params={},
)
engine = BacktestEngine(cfg)
def fake_fetch_latest(ts_code, trade_date, fields): # noqa: D401
assert "factors.mom_20" in fields
return {
"daily.close": 10.0,
"daily.pct_chg": 0.02,
"daily_basic.turnover_rate": 5.0,
"daily_basic.volume_ratio": 15.0,
"factors.mom_20": 0.12,
"factors.mom_60": 0.25,
"factors.volat_20": 0.05,
"factors.turn_20": 3.0,
"news.sentiment_index": 0.3,
"news.heat_score": 0.4,
"macro.industry_heat": 0.6,
"macro.relative_strength": 0.7,
}
monkeypatch.setattr(engine.data_broker, "fetch_latest", fake_fetch_latest)
monkeypatch.setattr(engine.data_broker, "fetch_series", lambda *args, **kwargs: [])
monkeypatch.setattr(engine.data_broker, "fetch_flags", lambda *args, **kwargs: False)
return engine
def test_load_market_data_prefers_factors(engine):
data = engine.load_market_data(date(2025, 1, 10))
record = data["000001.SZ"]
features = record["features"]
assert features["mom_20"] == pytest.approx(0.12)
assert features["mom_60"] == pytest.approx(0.25)
assert features["volat_20"] == pytest.approx(0.05)
assert features["turn_20"] == pytest.approx(3.0)
assert features["news_sentiment"] == pytest.approx(0.3)
assert features["news_heat"] == pytest.approx(0.4)
assert features["risk_penalty"] == pytest.approx(min(1.0, 0.05 * 5.0))

View File

@ -5,9 +5,20 @@ from datetime import date
import pytest
import json
from app.agents.base import AgentAction, AgentContext
from app.agents.game import Decision
from app.backtest.engine import BacktestEngine, BacktestResult, BtConfig, PortfolioState
from app.backtest.engine import (
BacktestEngine,
BacktestResult,
BtConfig,
PortfolioState,
_persist_backtest_results,
)
from app.data.schema import initialize_database
from app.utils.config import DataPaths, get_config
from app.utils.db import db_session
def _make_context(price: float, features: dict | None = None) -> AgentContext:
@ -43,6 +54,20 @@ def _engine_with_params(params: dict[str, float]) -> BacktestEngine:
return BacktestEngine(cfg)
@pytest.fixture()
def isolated_db(tmp_path):
cfg = get_config()
original_paths = cfg.data_paths
tmp_root = tmp_path / "data"
tmp_root.mkdir(parents=True, exist_ok=True)
cfg.data_paths = DataPaths(root=tmp_root)
initialize_database()
try:
yield
finally:
cfg.data_paths = original_paths
def test_buy_respects_risk_caps():
engine = _engine_with_params(
{
@ -130,3 +155,56 @@ def test_sell_applies_slippage_and_fee():
assert not state.holdings
assert result.nav_series[0]["turnover"] == pytest.approx(trade["value"])
assert not result.risk_events
def test_persist_backtest_results_saves_risk_events(isolated_db):
cfg = BtConfig(
id="risk_cfg",
name="risk",
start_date=date(2025, 1, 10),
end_date=date(2025, 1, 10),
universe=["000001.SZ"],
params={},
)
result = BacktestResult()
result.nav_series = [
{
"trade_date": "2025-01-10",
"nav": 100.0,
"cash": 100.0,
"market_value": 0.0,
"realized_pnl": 0.0,
"unrealized_pnl": 0.0,
"turnover": 0.0,
}
]
result.risk_events = [
{
"trade_date": "2025-01-10",
"ts_code": "000001.SZ",
"reason": "limit_up",
"action": "buy_l",
"target_weight": 0.3,
"confidence": 0.8,
}
]
_persist_backtest_results(cfg, result)
with db_session(read_only=True) as conn:
risk_row = conn.execute(
"SELECT reason, metadata FROM bt_risk_events WHERE cfg_id = ?",
(cfg.id,),
).fetchone()
assert risk_row is not None
assert risk_row["reason"] == "limit_up"
metadata = json.loads(risk_row["metadata"])
assert metadata["action"] == "buy_l"
summary_row = conn.execute(
"SELECT summary FROM bt_report WHERE cfg_id = ?",
(cfg.id,),
).fetchone()
summary = json.loads(summary_row["summary"])
assert summary["risk_events"] == 1
assert summary["risk_breakdown"]["limit_up"] == 1

View File

@ -0,0 +1,92 @@
"""Tests for epsilon-greedy bandit optimizer."""
from __future__ import annotations
import pytest
from app.backtest.decision_env import EpisodeMetrics, ParameterSpec
from app.backtest.optimizer import BanditConfig, EpsilonGreedyBandit
from app.utils import tuning
class DummyEnv:
def __init__(self) -> None:
self._specs = [
ParameterSpec(name="w1", target="agent_weights.A_mom", minimum=0.0, maximum=1.0)
]
self._last_metrics: EpisodeMetrics | None = None
self._episode = 0
@property
def action_dim(self) -> int:
return 1
@property
def last_metrics(self) -> EpisodeMetrics | None:
return self._last_metrics
def reset(self) -> dict:
self._episode += 1
return {"episode": float(self._episode)}
def step(self, action):
value = float(action[0])
reward = 1.0 - abs(value - 0.7)
metrics = EpisodeMetrics(
total_return=reward,
max_drawdown=0.1,
volatility=0.05,
nav_series=[],
trades=[],
turnover=100.0,
trade_count=0,
risk_count=1,
risk_breakdown={"test": 1},
)
self._last_metrics = metrics
obs = {
"total_return": reward,
"max_drawdown": 0.1,
"volatility": 0.05,
"sharpe_like": reward / 0.05,
"turnover": 100.0,
"trade_count": 0.0,
"risk_count": 1.0,
}
info = {
"nav_series": [],
"trades": [],
"weights": {"A_mom": value},
"risk_breakdown": metrics.risk_breakdown,
"risk_events": [],
}
return obs, reward, True, info
@pytest.fixture(autouse=True)
def patch_logging(monkeypatch):
records = []
def fake_log_tuning_result(**kwargs):
records.append(kwargs)
monkeypatch.setattr(tuning, "log_tuning_result", fake_log_tuning_result)
from app.backtest import optimizer as optimizer_module
monkeypatch.setattr(optimizer_module, "log_tuning_result", fake_log_tuning_result)
return records
def test_bandit_optimizer_runs_and_logs(patch_logging):
env = DummyEnv()
optimizer = EpsilonGreedyBandit(
env,
BanditConfig(experiment_id="exp", episodes=5, epsilon=0.5, seed=42),
)
summary = optimizer.run()
assert len(summary.episodes) == 5
assert summary.best_episode is not None
assert patch_logging and len(patch_logging) == 5
payload = patch_logging[0]["metrics"]
assert isinstance(payload, dict)
assert "risk_breakdown" in payload

View File

@ -0,0 +1,92 @@
"""Tests for DecisionEnv risk-aware reward and info outputs."""
from __future__ import annotations
from datetime import date
import pytest
from app.backtest.decision_env import DecisionEnv, EpisodeMetrics, ParameterSpec
from app.backtest.engine import BacktestResult, BtConfig
class _StubEngine:
def __init__(self, cfg: BtConfig) -> None: # noqa: D401
self.cfg = cfg
self.weights = {}
self.department_manager = None
def run(self) -> BacktestResult:
result = BacktestResult()
result.nav_series = [
{
"trade_date": "2025-01-10",
"nav": 102.0,
"cash": 50.0,
"market_value": 52.0,
"realized_pnl": 1.0,
"unrealized_pnl": 1.0,
"turnover": 20000.0,
}
]
result.trades = [
{
"trade_date": "2025-01-10",
"ts_code": "000001.SZ",
"action": "buy",
"quantity": 100.0,
"price": 100.0,
"value": 10000.0,
"fee": 5.0,
}
]
result.risk_events = [
{
"trade_date": "2025-01-10",
"ts_code": "000002.SZ",
"reason": "limit_up",
"action": "buy_l",
"confidence": 0.7,
"target_weight": 0.2,
}
]
return result
def test_decision_env_returns_risk_metrics(monkeypatch):
cfg = BtConfig(
id="stub",
name="stub",
start_date=date(2025, 1, 10),
end_date=date(2025, 1, 10),
universe=["000001.SZ"],
params={},
)
specs = [ParameterSpec(name="w_mom", target="agent_weights.A_mom", minimum=0.0, maximum=1.0)]
env = DecisionEnv(bt_config=cfg, parameter_specs=specs, baseline_weights={"A_mom": 0.5})
monkeypatch.setattr("app.backtest.decision_env.BacktestEngine", _StubEngine)
obs, reward, done, info = env.step([0.8])
assert done is True
assert "risk_count" in obs and obs["risk_count"] == 1.0
assert obs["turnover"] == pytest.approx(20000.0)
assert info["risk_events"][0]["reason"] == "limit_up"
assert info["risk_breakdown"]["limit_up"] == 1
assert reward < obs["total_return"]
def test_default_reward_penalizes_metrics():
metrics = EpisodeMetrics(
total_return=0.1,
max_drawdown=0.2,
volatility=0.05,
nav_series=[],
trades=[],
turnover=1000.0,
trade_count=0,
risk_count=2,
risk_breakdown={"foo": 2},
)
reward = DecisionEnv._default_reward(metrics)
assert reward == pytest.approx(0.1 - (0.5 * 0.2 + 0.05 * 2 + 0.00001 * 1000.0))

View File

@ -0,0 +1,81 @@
"""Tests for tuning result selection and CLI application."""
from __future__ import annotations
import json
import pytest
from app.data.schema import initialize_database
from app.utils.config import DataPaths, get_config
from app.utils.db import db_session
from app.utils.tuning import select_best_tuning_result
import scripts.apply_best_weights as apply_best_weights
@pytest.fixture()
def isolated_env(tmp_path):
cfg = get_config()
original_paths = cfg.data_paths
tmp_root = tmp_path / "data"
tmp_root.mkdir(parents=True, exist_ok=True)
cfg.data_paths = DataPaths(root=tmp_root)
initialize_database()
try:
yield cfg
finally:
cfg.data_paths = original_paths
def _insert_result(experiment: str, reward: float, metrics: dict, weights: dict | None = None, action: dict | None = None) -> None:
with db_session() as conn:
conn.execute(
"""
INSERT INTO tuning_results (experiment_id, strategy, action, weights, reward, metrics)
VALUES (?, ?, ?, ?, ?, ?)
""",
(
experiment,
"test",
json.dumps(action or {}, ensure_ascii=False),
json.dumps(weights or {}, ensure_ascii=False),
reward,
json.dumps(metrics, ensure_ascii=False),
),
)
def test_select_best_by_reward(isolated_env):
_insert_result("exp", 0.1, {"risk_count": 2}, {"A_mom": 0.3})
_insert_result("exp", 0.25, {"risk_count": 4}, {"A_mom": 0.6})
best = select_best_tuning_result("exp")
assert best is not None
assert best["reward"] == pytest.approx(0.25)
assert best["weights"]["A_mom"] == pytest.approx(0.6)
def test_select_best_by_metric(isolated_env):
_insert_result("exp_metric", 0.2, {"risk_count": 5}, {"A_mom": 0.4})
_insert_result("exp_metric", 0.1, {"risk_count": 2}, {"A_mom": 0.7})
best = select_best_tuning_result("exp_metric", metric="risk_count", descending=False)
assert best is not None
assert best["weights"]["A_mom"] == pytest.approx(0.7)
assert best["metrics"]["risk_count"] == 2
def test_apply_best_weights_cli_updates_config(isolated_env, capsys):
cfg = isolated_env
_insert_result("exp_cli", 0.3, {"risk_count": 1}, {"A_mom": 0.65, "A_val": 0.2})
exit_code = apply_best_weights.run_cli([
"exp_cli",
"--apply-config",
])
assert exit_code == 0
output = capsys.readouterr().out
payload = json.loads(output)
assert payload["metric"] == "reward"
updated = cfg.agent_weights.as_dict()
assert updated["A_mom"] == pytest.approx(0.65)
assert updated["A_val"] == pytest.approx(0.2)