125 lines
3.8 KiB
Python
125 lines
3.8 KiB
Python
"""Run epsilon-greedy bandit tuning on DecisionEnv."""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import sys
|
|
from datetime import datetime, date
|
|
from pathlib import Path
|
|
from typing import Iterable, List
|
|
|
|
# Project root: two levels above this script. It is prepended to sys.path
# (once), presumably so the `app.*` imports below resolve when the file is
# executed directly rather than as an installed package.
ROOT = Path(__file__).resolve().parents[1]
_ROOT_STR = str(ROOT)
if _ROOT_STR not in sys.path:
    sys.path.insert(0, _ROOT_STR)
|
|
|
|
from app.agents.registry import default_agents
|
|
from app.backtest.decision_env import DecisionEnv, ParameterSpec
|
|
from app.backtest.engine import BtConfig
|
|
from app.backtest.optimizer import BanditConfig, EpsilonGreedyBandit
|
|
from app.utils.config import get_config
|
|
|
|
|
|
def _parse_date(value: str) -> date:
|
|
return datetime.strptime(value, "%Y%m%d").date()
|
|
|
|
|
|
def _parse_param(text: str) -> ParameterSpec:
|
|
parts = text.split(":")
|
|
if len(parts) not in {3, 4}:
|
|
raise argparse.ArgumentTypeError(
|
|
"parameter format must be name:target:min[:max]"
|
|
)
|
|
name, target, minimum = parts[:3]
|
|
maximum = parts[3] if len(parts) == 4 else "1.0"
|
|
return ParameterSpec(
|
|
name=name,
|
|
target=target,
|
|
minimum=float(minimum),
|
|
maximum=float(maximum),
|
|
)
|
|
|
|
|
|
def _resolve_baseline_weights() -> dict:
    """Return the agent weights the DecisionEnv starts from.

    Uses ``agent_weights`` from the application config when it is truthy;
    otherwise every agent returned by ``default_agents()`` gets weight 1.0.
    """
    configured = get_config().agent_weights
    if not configured:
        # Nothing configured: fall back to uniform weights over the default agents.
        return {agent.name: 1.0 for agent in default_agents()}
    return configured.as_dict()
|
|
|
|
|
|
def build_parser() -> argparse.ArgumentParser:
|
|
parser = argparse.ArgumentParser(description="DecisionEnv bandit optimizer")
|
|
parser.add_argument("experiment_id", help="Experiment identifier to log results")
|
|
parser.add_argument("name", help="Backtest config name")
|
|
parser.add_argument("start", type=_parse_date, help="Start date YYYYMMDD")
|
|
parser.add_argument("end", type=_parse_date, help="End date YYYYMMDD")
|
|
parser.add_argument(
|
|
"--universe",
|
|
required=True,
|
|
help="Comma separated ts_codes, e.g. 000001.SZ,000002.SZ",
|
|
)
|
|
parser.add_argument(
|
|
"--param",
|
|
action="append",
|
|
required=True,
|
|
help="Parameter spec name:target:min[:max] (target like agent_weights.A_mom)",
|
|
)
|
|
parser.add_argument("--episodes", type=int, default=20)
|
|
parser.add_argument("--epsilon", type=float, default=0.2)
|
|
parser.add_argument("--seed", type=int, default=None)
|
|
return parser
|
|
|
|
|
|
def _summarize(summary) -> dict:
    """Build the JSON-serializable report that :func:`run_cli` prints.

    The ``best`` fields are all ``None`` when the run has no best episode.
    """
    best = summary.best_episode
    if best is None:
        best_report = {"reward": None, "action": None, "metrics": None}
    else:
        metrics = best.metrics
        best_report = {
            "reward": best.reward,
            "action": best.action,
            # Nest the risk breakdown as a real JSON value rather than a
            # pre-encoded json.dumps string, so the printed report is not
            # double-encoded (and keeps ensure_ascii=False for its contents).
            "metrics": metrics.risk_breakdown if metrics is not None else None,
        }
    return {
        "episodes": len(summary.episodes),
        "average_reward": summary.average_reward,
        "best": best_report,
    }


def run_cli(argv: Iterable[str] | None = None) -> int:
    """Parse arguments, run the epsilon-greedy bandit and print a JSON report.

    Args:
        argv: Arguments excluding the program name; ``None`` lets argparse
            read the process command line.

    Returns:
        ``0`` on success. Invalid arguments end the process through
        ``parser.error`` (``SystemExit`` with status 2) instead of returning.
    """
    parser = build_parser()
    args = parser.parse_args(list(argv) if argv is not None else None)

    if args.end < args.start:
        parser.error("end date must not precede start date")

    # --param values are parsed after parse_args (not via `type=`), so turn
    # malformed specs into a normal usage error instead of a traceback.
    try:
        specs: List[ParameterSpec] = [_parse_param(item) for item in args.param]
    except (argparse.ArgumentTypeError, ValueError) as exc:
        parser.error(f"argument --param: {exc}")

    universe = [token.strip() for token in args.universe.split(",") if token.strip()]
    if not universe:
        parser.error("argument --universe: expected at least one ts_code")

    bt_cfg = BtConfig(
        id=args.experiment_id,
        name=args.name,
        start_date=args.start,
        end_date=args.end,
        universe=universe,
        params={},
    )
    env = DecisionEnv(
        bt_config=bt_cfg,
        parameter_specs=specs,
        baseline_weights=_resolve_baseline_weights(),
    )
    optimizer = EpsilonGreedyBandit(
        env,
        BanditConfig(
            experiment_id=args.experiment_id,
            episodes=args.episodes,
            epsilon=args.epsilon,
            seed=args.seed,
        ),
    )

    summary = optimizer.run()
    print(json.dumps(_summarize(summary), ensure_ascii=False, indent=2))
    return 0
|
|
|
|
|
|
def main() -> None:
    """Console entry point: run the CLI and exit with its status code."""
    status = run_cli()
    raise SystemExit(status)


if __name__ == "__main__":
    # Launch the optimizer only when executed as a script, not on import.
    main()
|