llm-quant/scripts/run_bandit_optimization.py
2025-09-30 18:34:29 +08:00

125 lines
3.8 KiB
Python

"""Run epsilon-greedy bandit tuning on DecisionEnv."""
from __future__ import annotations
import argparse
import json
import sys
from datetime import datetime, date
from pathlib import Path
from typing import Iterable, List
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
from app.agents.registry import default_agents
from app.backtest.decision_env import DecisionEnv, ParameterSpec
from app.backtest.engine import BtConfig
from app.backtest.optimizer import BanditConfig, EpsilonGreedyBandit
from app.utils.config import get_config
def _parse_date(value: str) -> date:
    """Convert a compact ``YYYYMMDD`` string into a ``datetime.date``.

    Used as the ``type=`` converter for the ``start`` and ``end`` arguments.

    Raises:
        ValueError: If ``value`` does not match the ``%Y%m%d`` format.
    """
    parsed = datetime.strptime(value, "%Y%m%d")
    return parsed.date()
def _parse_param(text: str) -> ParameterSpec:
    """Parse one ``--param`` value of the form ``name:target:min[:max]``.

    ``max`` defaults to ``1.0`` when the fourth field is omitted.

    Args:
        text: Raw command-line value, e.g.
            ``w_mom:agent_weights.A_mom:0.0:1.0``.

    Returns:
        A ``ParameterSpec`` built from the parsed name, target and bounds.

    Raises:
        argparse.ArgumentTypeError: If the number of fields is wrong, the
            name or target is empty, a bound is not a valid number, or
            ``min`` is greater than ``max``.
    """
    parts = text.split(":")
    if len(parts) not in {3, 4}:
        raise argparse.ArgumentTypeError(
            "parameter format must be name:target:min[:max]"
        )
    name, target, raw_min = parts[:3]
    raw_max = parts[3] if len(parts) == 4 else "1.0"

    if not name or not target:
        raise argparse.ArgumentTypeError(
            f"parameter name and target must not be empty: {text!r}"
        )

    # Convert the bounds up front so a malformed number is reported with the
    # same exception type as every other format problem, not a bare ValueError.
    try:
        minimum = float(raw_min)
        maximum = float(raw_max)
    except ValueError as err:
        raise argparse.ArgumentTypeError(
            f"parameter bounds must be numbers: {text!r}"
        ) from err

    # An inverted range is easy to produce by accident because ``max``
    # silently defaults to 1.0 when only ``min`` is given.
    if minimum > maximum:
        raise argparse.ArgumentTypeError(
            f"parameter min ({minimum}) must not exceed max ({maximum}): {text!r}"
        )

    return ParameterSpec(
        name=name,
        target=target,
        minimum=minimum,
        maximum=maximum,
    )
def _resolve_baseline_weights() -> dict:
    """Return the baseline agent weights handed to ``DecisionEnv``.

    Uses ``agent_weights`` from the loaded config when it is truthy;
    otherwise assigns a weight of ``1.0`` to every agent returned by
    ``default_agents()``, keyed by agent name.
    """
    configured = get_config().agent_weights
    if not configured:
        # Nothing configured: weight every default agent equally.
        return {agent.name: 1.0 for agent in default_agents()}
    return configured.as_dict()
def build_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the bandit optimization CLI.

    Positional arguments are ``experiment_id``, ``name``, ``start`` and
    ``end`` (the dates are converted by ``_parse_date``). ``--universe`` and
    ``--param`` are required options, and ``--param`` may be repeated.
    ``--episodes``, ``--epsilon`` and ``--seed`` are optional.
    """
    cli = argparse.ArgumentParser(description="DecisionEnv bandit optimizer")

    # Registration order is significant: it fixes the positional order on the
    # command line and the order of entries in ``--help``.
    arguments = (
        ("experiment_id", {"help": "Experiment identifier to log results"}),
        ("name", {"help": "Backtest config name"}),
        ("start", {"type": _parse_date, "help": "Start date YYYYMMDD"}),
        ("end", {"type": _parse_date, "help": "End date YYYYMMDD"}),
        (
            "--universe",
            {
                "required": True,
                "help": "Comma separated ts_codes, e.g. 000001.SZ,000002.SZ",
            },
        ),
        (
            "--param",
            {
                "action": "append",
                "required": True,
                "help": "Parameter spec name:target:min[:max] (target like agent_weights.A_mom)",
            },
        ),
        ("--episodes", {"type": int, "default": 20}),
        ("--epsilon", {"type": float, "default": 0.2}),
        ("--seed", {"type": int, "default": None}),
    )
    for flag, options in arguments:
        cli.add_argument(flag, **options)
    return cli
def run_cli(argv: Iterable[str] | None = None) -> int:
    """Parse CLI arguments, run the bandit optimizer and print a JSON summary.

    Args:
        argv: Argument strings to parse. ``None`` lets ``argparse`` fall back
            to ``sys.argv``.

    Returns:
        ``0`` once the summary has been printed. Invalid arguments are
        reported through ``parser.error``, which exits the process rather
        than returning.
    """
    parser = build_parser()
    args = parser.parse_args(list(argv) if argv is not None else None)
    if args.end < args.start:
        parser.error("end date must not precede start date")

    # ``--param`` values are parsed here instead of through ``type=``, so
    # argparse does not intercept their errors. Route them to ``parser.error``
    # so a malformed spec yields a usage message rather than a traceback.
    try:
        specs: List[ParameterSpec] = [_parse_param(item) for item in args.param]
    except (argparse.ArgumentTypeError, ValueError) as exc:
        parser.error(f"invalid --param value: {exc}")

    universe = [token.strip() for token in args.universe.split(",") if token.strip()]
    if not universe:
        # e.g. ``--universe ","`` passes argparse's ``required`` check but
        # leaves nothing to backtest.
        parser.error("--universe must contain at least one ts_code")

    bt_cfg = BtConfig(
        id=args.experiment_id,
        name=args.name,
        start_date=args.start,
        end_date=args.end,
        universe=universe,
        params={},
    )
    env = DecisionEnv(
        bt_config=bt_cfg,
        parameter_specs=specs,
        baseline_weights=_resolve_baseline_weights(),
    )
    optimizer = EpsilonGreedyBandit(
        env,
        BanditConfig(
            experiment_id=args.experiment_id,
            episodes=args.episodes,
            epsilon=args.epsilon,
            seed=args.seed,
        ),
    )
    summary = optimizer.run()

    best = summary.best_episode
    best_payload = {"reward": None, "action": None, "metrics": None}
    if best:
        best_payload = {
            "reward": best.reward,
            "action": best.action,
            # NOTE(review): the risk breakdown is embedded as a JSON-encoded
            # string, so it is escaped a second time by the outer dumps below;
            # a falsy ``best.metrics`` is passed through unchanged. Kept as-is
            # to preserve the existing output format.
            "metrics": best.metrics and json.dumps(best.metrics.risk_breakdown),
        }

    output = {
        "episodes": len(summary.episodes),
        "average_reward": summary.average_reward,
        "best": best_payload,
    }
    print(json.dumps(output, ensure_ascii=False, indent=2))
    return 0
def main() -> None:
    """Script entry point: run the CLI and exit with its return code."""
    exit_code = run_cli()
    raise SystemExit(exit_code)


if __name__ == "__main__":
    main()