From 8befd80cb7ebd06af27c37ef596b696e96e50db2 Mon Sep 17 00:00:00 2001
From: sam
Date: Tue, 30 Sep 2025 18:08:15 +0800
Subject: [PATCH] update

---
 scripts/run_ingestion_job.py       | 118 +++++++++++++++++++++++++
 tests/test_backtest_engine_risk.py | 131 ++++++++++++++++++++++++++++
 tests/test_data_broker_cache.py    | 119 +++++++++++++++++++++++++
 tests/test_ingest_tushare.py       | 110 ++++++++++++++++++++++
 tests/test_run_ingestion_script.py |  63 ++++++++++++++
 5 files changed, 541 insertions(+)
 create mode 100644 scripts/run_ingestion_job.py
 create mode 100644 tests/test_backtest_engine_risk.py
 create mode 100644 tests/test_data_broker_cache.py
 create mode 100644 tests/test_ingest_tushare.py
 create mode 100644 tests/test_run_ingestion_script.py

diff --git a/scripts/run_ingestion_job.py b/scripts/run_ingestion_job.py
new file mode 100644
index 0000000..f061953
--- /dev/null
+++ b/scripts/run_ingestion_job.py
@@ -0,0 +1,118 @@
+"""Run TuShare ingestion for a date range and compute factors alongside."""
+from __future__ import annotations
+
+import argparse
+import sys
+from datetime import datetime, date
+from pathlib import Path
+from typing import Iterable, Sequence
+
+ROOT = Path(__file__).resolve().parents[1]
+if str(ROOT) not in sys.path:
+    sys.path.insert(0, str(ROOT))
+
+from app.ingest.tushare import FetchJob, run_ingestion
+from app.utils.logging import get_logger
+from app.utils import alerts
+
+LOGGER = get_logger(__name__)
+
+
+def _parse_date(text: str) -> date:
+    try:
+        return datetime.strptime(text, "%Y%m%d").date()
+    except ValueError as exc:
+        raise argparse.ArgumentTypeError(f"Cannot parse date: {text}") from exc
+
+
+def _parse_codes(raw: Sequence[str] | None) -> tuple[str, ...] | None:
+    if not raw:
+        return None
+    normalized = []
+    for item in raw:
+        token = item.strip().upper()
+        if token:
+            normalized.append(token)
+    return tuple(dict.fromkeys(normalized)) or None
+
+
+def build_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(
+        description="Run TuShare ingestion for a date range and update the factor tables",
+    )
+    parser.add_argument("start", type=_parse_date, help="First trading day (format: YYYYMMDD)")
+    parser.add_argument("end", type=_parse_date, help="Last trading day (format: YYYYMMDD)")
+    parser.add_argument(
+        "--codes",
+        nargs="*",
+        default=None,
+        help="Optional list of stock codes (e.g. 000001.SZ); omit to process the whole market",
+    )
+    parser.add_argument(
+        "--include-limits",
+        action="store_true",
+        help="Also fetch extended data such as limit-up/down and suspensions (off by default for quick trial runs)",
+    )
+    parser.add_argument(
+        "--name",
+        default="daily_ingestion",
+        help="Job name, used to tag logs and alerts",
+    )
+    parser.add_argument(
+        "--granularity",
+        default="daily",
+        choices=("daily", "weekly"),
+        help="Job granularity; currently only daily triggers factor computation",
+    )
+    return parser
+
+
+def run_cli(argv: Iterable[str] | None = None) -> int:
+    parser = build_parser()
+    args = parser.parse_args(list(argv) if argv is not None else None)
+
+    if args.end < args.start:
+        parser.error("end date must not be earlier than start date")
+
+    codes = _parse_codes(args.codes)
+    job = FetchJob(
+        name=str(args.name),
+        start=args.start,
+        end=args.end,
+        granularity=str(args.granularity),
+        ts_codes=codes,
+    )
+
+    LOGGER.info(
+        "Starting ingestion job name=%s start=%s end=%s codes=%s granularity=%s",
+        job.name,
+        job.start,
+        job.end,
+        job.ts_codes,
+        job.granularity,
+    )
+
+    try:
+        run_ingestion(job, include_limits=bool(args.include_limits))
+    except Exception:  # noqa: BLE001
+        LOGGER.exception("Ingestion job failed")
+        return 1
+
+    warnings = alerts.get_warnings()
+    if warnings:
+        LOGGER.warning("Job finished but produced warnings: %s", warnings)
+        return 2
+
+    LOGGER.info("Job finished with no warnings")
+    return 0
+
+
+def main() -> None:
+    exit_code = run_cli()
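+    # Exit status contract: 0 = success, 1 = ingestion failure,
+    # 2 = finished with warnings; surface it so schedulers can alert.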
+    raise SystemExit(exit_code)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/test_backtest_engine_risk.py b/tests/test_backtest_engine_risk.py
new file mode 100644
index 0000000..e7df919
--- /dev/null
+++ b/tests/test_backtest_engine_risk.py
@@ -0,0 +1,131 @@
+"""Tests for BacktestEngine risk-aware execution."""
+from __future__ import annotations
+
+from datetime import date
+
+import pytest
+
+from app.agents.base import AgentAction, AgentContext
+from app.agents.game import Decision
+from app.backtest.engine import BacktestEngine, BacktestResult, BtConfig, PortfolioState
+
+
+def _make_context(price: float, features: dict | None = None) -> AgentContext:
+    scope_values = {"daily.close": price}
+    return AgentContext(
+        ts_code="000001.SZ",
+        trade_date="2025-01-10",
+        features=features or {},
+        market_snapshot={},
+        raw={"scope_values": scope_values},
+    )
+
+
+def _make_decision(action: AgentAction, target_weight: float = 0.0) -> Decision:
+    return Decision(
+        action=action,
+        confidence=0.8,
+        target_weight=target_weight,
+        feasible_actions=[],
+        utilities={},
+    )
+
+
+def _engine_with_params(params: dict[str, float]) -> BacktestEngine:
+    cfg = BtConfig(
+        id="test",
+        name="test",
+        start_date=date(2025, 1, 10),
+        end_date=date(2025, 1, 10),
+        universe=["000001.SZ"],
+        params=params,
+    )
+    return BacktestEngine(cfg)
+
+
+def test_buy_respects_risk_caps():
+    engine = _engine_with_params(
+        {
+            "max_position_weight": 0.2,
+            "fee_rate": 0.0,
+            "slippage_bps": 0.0,
+            "max_daily_turnover_ratio": 1.0,
+        }
+    )
+    state = PortfolioState(cash=100_000.0)
+    result = BacktestResult()
+    features = {
+        "liquidity_score": 0.7,
+        "risk_penalty": 0.25,
+    }
+    context = _make_context(100.0, features)
+    decision = _make_decision(AgentAction.BUY_L, target_weight=0.5)
+
+    engine._apply_portfolio_updates(
+        date(2025, 1, 10),
+        state,
+        [("000001.SZ", context, decision)],
+        result,
+    )
+
+    expected_qty = (100_000.0 * 0.2 * (1 - 0.25)) / 100.0
+    assert state.holdings["000001.SZ"] == pytest.approx(expected_qty)
+    assert state.cash == pytest.approx(100_000.0 - expected_qty * 100.0)
+    assert result.trades and result.trades[0]["status"] == "executed"
+    assert result.nav_series[0]["turnover"] == pytest.approx(expected_qty * 100.0)
+
+
+def test_buy_blocked_by_limit_up_records_risk():
+    engine = _engine_with_params({})
+    state = PortfolioState(cash=50_000.0)
+    result = BacktestResult()
+    features = {"limit_up": True}
+    context = _make_context(100.0, features)
+    decision = _make_decision(AgentAction.BUY_M, target_weight=0.1)
+
+    engine._apply_portfolio_updates(
+        date(2025, 1, 10),
+        state,
+        [("000001.SZ", context, decision)],
+        result,
+    )
+
+    assert "000001.SZ" not in state.holdings
+    assert not result.trades
+    assert result.risk_events
+    assert result.risk_events[0]["reason"] == "limit_up"
+
+
+def test_sell_applies_slippage_and_fee():
+    engine = _engine_with_params(
+        {
+            "max_position_weight": 0.3,
+            "fee_rate": 0.001,
+            "slippage_bps": 20.0,
+            "max_daily_turnover_ratio": 1.0,
+        }
+    )
+    state = PortfolioState(
+        cash=0.0,
+        holdings={"000001.SZ": 100.0},
+        cost_basis={"000001.SZ": 90.0},
+        opened_dates={"000001.SZ": "2024-12-01"},
+    )
+    result = BacktestResult()
+    context = _make_context(100.0, {})
+    decision = _make_decision(AgentAction.SELL)
+
+    engine._apply_portfolio_updates(
+        date(2025, 1, 10),
+        state,
+        [("000001.SZ", context, decision)],
+        result,
+    )
+
+    trade = result.trades[0]
+    assert pytest.approx(trade["price"], rel=1e-6) == 100.0 * (1 - 0.002)
+    assert pytest.approx(trade["fee"], rel=1e-6) == trade["value"] * 0.001
+    assert state.cash == pytest.approx(trade["value"] - trade["fee"])
+    assert state.realized_pnl == pytest.approx((trade["price"] - 90.0) * 100 - trade["fee"])
+    assert not state.holdings
+    assert result.nav_series[0]["turnover"] == pytest.approx(trade["value"])
+    assert not result.risk_events
diff --git a/tests/test_data_broker_cache.py b/tests/test_data_broker_cache.py
new file mode 100644
index 0000000..e31fb3c
--- /dev/null
+++ b/tests/test_data_broker_cache.py
@@ -0,0 +1,119 @@
+"""Verify DataBroker caching and fallback behaviour."""
+from __future__ import annotations
+
+import sqlite3
+from collections import defaultdict
+from contextlib import contextmanager
+from typing import Any, Dict, Iterable, List, Tuple
+
+import pytest
+
+from app.utils import data_access
+from app.utils.data_access import DataBroker
+
+
+class _FakeRow(dict):
+    def __getitem__(self, key: str) -> Any:  # type: ignore[override]
+        return dict.get(self, key)
+
+
+class _FakeCursor:
+    def __init__(self, rows: Iterable[Dict[str, Any]]):
+        self._rows = list(rows)
+
+    def fetchone(self) -> Dict[str, Any] | None:
+        return self._rows[0] if self._rows else None
+
+    def fetchall(self) -> List[Dict[str, Any]]:
+        return list(self._rows)
+
+
+class _FakeConn:
+    def __init__(self, calls: List[str], failure_flags: Dict[str, bool]) -> None:
+        self._calls = calls
+        self._failure_flags = failure_flags
+        self._daily_series = [
+            _FakeRow({"trade_date": "20250112", "close": 12.3, "open": 12.1}),
+            _FakeRow({"trade_date": "20250111", "close": 11.9, "open": 11.6}),
+        ]
+        self._turn_series = [
+            _FakeRow({"trade_date": "20250112", "turnover_rate": 2.5}),
+            _FakeRow({"trade_date": "20250111", "turnover_rate": 2.2}),
+        ]
+
+    def execute(self, query: str, params: Tuple[Any, ...] | None = None):
+        self._calls.append(query)
+        params = params or ()
+        upper = query.upper()
+        if upper.startswith("PRAGMA TABLE_INFO"):
+            if "DAILY_BASIC" in upper:
+                rows = [_FakeRow({"name": "trade_date"}), _FakeRow({"name": "turnover_rate"})]
+            else:
+                rows = [_FakeRow({"name": "trade_date"}), _FakeRow({"name": "close"}), _FakeRow({"name": "open"})]
+            return _FakeCursor(rows)
+        if "FROM DAILY " in upper:
+            if self._failure_flags.get("daily"):
+                raise sqlite3.OperationalError("stub failure")
+            if "LIMIT 1" in upper:
+                return _FakeCursor([self._daily_series[0]])
+            limit = params[-1] if params else len(self._daily_series)
+            return _FakeCursor(self._daily_series[: int(limit)])
+        if "FROM DAILY_BASIC" in upper:
+            limit = params[-1] if params else len(self._turn_series)
+            return _FakeCursor(self._turn_series[: int(limit)])
+        raise AssertionError(f"Unexpected query: {query}")
+
+    def close(self) -> None:  # pragma: no cover - compatibility
+        return None
+
+
+@pytest.fixture()
+def patched_db(monkeypatch):
+    calls: List[str] = []
+    failure_flags: Dict[str, bool] = defaultdict(bool)
+
+    @contextmanager
+    def _session(read_only: bool = False):  # noqa: D401 - contextmanager stub
+        conn = _FakeConn(calls, failure_flags)
+        yield conn
+
+    monkeypatch.setattr(data_access, "db_session", _session)
+    yield calls, failure_flags
+
+
+def test_fetch_latest_uses_cache(patched_db):
+    calls, failure = patched_db
+    broker = DataBroker()
+
+    result = broker.fetch_latest("000001.SZ", "20250112", ["daily.close", "daily.open"])
+    assert result["daily.close"] == pytest.approx(12.3)
+    first_count = len(calls)
+
+    result_cached = broker.fetch_latest("000001.SZ", "20250112", ["daily.open", "daily.close"])
+    assert result_cached == result
+    assert len(calls) == first_count
+
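+    # Flip the stub into failure mode: if the broker bypassed its cache,
+    # the next daily query would raise sqlite3.OperationalError.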
failure["daily"] = True + still_cached = broker.fetch_latest("000001.SZ", "20250112", ["daily.close", "daily.open"]) + assert still_cached["daily.close"] == pytest.approx(12.3) + assert len(calls) == first_count + + +def test_fetch_series_cache_and_disable(patched_db): + calls, _ = patched_db + broker = DataBroker(series_cache_size=4) + + series = broker.fetch_series("daily", "close", "000001.SZ", "20250112", 2) + assert len(series) == 2 + first_count = len(calls) + + series_cached = broker.fetch_series("daily", "close", "000001.SZ", "20250112", 2) + assert series_cached == series + assert len(calls) == first_count + + broker_no_cache = DataBroker(enable_cache=False) + calls_before = len(calls) + broker_no_cache.fetch_series("daily", "close", "000001.SZ", "20250112", 2) + assert len(calls) > calls_before diff --git a/tests/test_ingest_tushare.py b/tests/test_ingest_tushare.py new file mode 100644 index 0000000..6aebbd0 --- /dev/null +++ b/tests/test_ingest_tushare.py @@ -0,0 +1,108 @@ +"""验证 TuShare 拉数流程与因子计算的集成行为。""" +from __future__ import annotations + +import sys +import types +from datetime import date + +import pytest + +# 某些环境下 pandas 可能存在二进制依赖问题,这里提供最小桩避免导入失败 +try: # pragma: no cover - 测试运行环境中若 pandas 可用则直接复用 + import pandas as _pd # type: ignore +except Exception: # pragma: no cover - stub fallback + pandas_stub = types.ModuleType("pandas") + + class _DummyFrame: # pylint: disable=too-few-public-methods + empty = True + + def __init__(self, *args, **kwargs): # noqa: D401 + """轻量占位,避免测试期调用实际逻辑。""" + + def to_dict(self, *_args, **_kwargs): + return {} + + def reindex(self, *_args, **_kwargs): + return self + + def where(self, *_args, **_kwargs): + return self + + pandas_stub.DataFrame = _DummyFrame + pandas_stub.Series = _DummyFrame + pandas_stub.concat = lambda *args, **kwargs: _DummyFrame() # type: ignore[arg-type] + pandas_stub.Timestamp = lambda *args, **kwargs: None # type: ignore[assignment] + pandas_stub.to_datetime = lambda value, **kwargs: value # type: ignore[assignment] + pandas_stub.isna = lambda value: False # type: ignore[assignment] + pandas_stub.notna = lambda value: True # type: ignore[assignment] + + sys.modules.setdefault("pandas", pandas_stub) +else: # pragma: no cover + sys.modules.setdefault("pandas", _pd) + +from app.ingest.tushare import FetchJob, run_ingestion +from app.utils import alerts + + +@pytest.fixture(autouse=True) +def clear_alerts(): + alerts.clear_warnings() + yield + alerts.clear_warnings() + + +def test_run_ingestion_triggers_factor_range(monkeypatch): + job = FetchJob( + name="daily_job", + start=date(2025, 1, 10), + end=date(2025, 1, 11), + ts_codes=("000001.SZ",), + ) + + coverage_called = {} + + def fake_coverage(*args, **kwargs): + coverage_called["args"] = (args, kwargs) + + monkeypatch.setattr("app.ingest.tushare.ensure_data_coverage", fake_coverage) + + captured: dict = {} + + def fake_compute(start, end, **kwargs): + captured["start"] = start + captured["end"] = end + captured["kwargs"] = kwargs + return [] + + monkeypatch.setattr("app.ingest.tushare.compute_factor_range", fake_compute) + + run_ingestion(job, include_limits=False) + + assert "args" in coverage_called + assert captured["start"] == job.start + assert captured["end"] == job.end + assert captured["kwargs"] == {"ts_codes": job.ts_codes, "skip_existing": False} + + +def test_run_ingestion_skips_factors_for_non_daily(monkeypatch): + job = FetchJob( + name="weekly_job", + start=date(2025, 1, 10), + end=date(2025, 1, 17), + granularity="weekly", + ts_codes=None, + ) + + 
+    monkeypatch.setattr("app.ingest.tushare.ensure_data_coverage", lambda *_, **__: None)
+
+    invoked = {"count": 0}
+
+    def fake_compute(*args, **kwargs):
+        invoked["count"] += 1
+        return []
+
+    monkeypatch.setattr("app.ingest.tushare.compute_factor_range", fake_compute)
+
+    run_ingestion(job)
+
+    assert invoked["count"] == 0
diff --git a/tests/test_run_ingestion_script.py b/tests/test_run_ingestion_script.py
new file mode 100644
index 0000000..73ff0d7
--- /dev/null
+++ b/tests/test_run_ingestion_script.py
@@ -0,0 +1,63 @@
+"""CLI behaviour tests for scripts/run_ingestion_job.py."""
+from __future__ import annotations
+
+import pytest
+
+import scripts.run_ingestion_job as cli
+
+
+@pytest.fixture(autouse=True)
+def reset_alerts(monkeypatch):
+    monkeypatch.setattr(cli.alerts, "clear_warnings", lambda *args, **kwargs: None)
+    yield
+
+
+def test_cli_invokes_run_ingestion_with_codes(monkeypatch):
+    captured: dict = {}
+
+    def fake_run(job, include_limits):
+        captured["job"] = job
+        captured["include_limits"] = include_limits
+
+    monkeypatch.setattr(cli, "run_ingestion", fake_run)
+    monkeypatch.setattr(cli.alerts, "get_warnings", lambda: [])
+
+    exit_code = cli.run_cli(
+        [
+            "20250110",
+            "20250112",
+            "--codes",
+            "000001.SZ",
+            "000002.SZ",
+            "--include-limits",
+            "--name",
+            "test_job",
+        ]
+    )
+
+    assert exit_code == 0
+    job = captured["job"]
+    assert job.name == "test_job"
+    assert job.start.isoformat() == "2025-01-10"
+    assert job.end.isoformat() == "2025-01-12"
+    assert job.ts_codes == ("000001.SZ", "000002.SZ")
+    assert captured["include_limits"] is True
+
+
+def test_cli_returns_warning_status(monkeypatch):
+    monkeypatch.setattr(cli, "run_ingestion", lambda *args, **kwargs: None)
+    monkeypatch.setattr(
+        cli.alerts,
+        "get_warnings",
+        lambda: [{"source": "Factors", "message": "mock warning"}],
+    )
+
+    exit_code = cli.run_cli(["20250101", "20250102"])
+
+    assert exit_code == 2
+
+
+def test_cli_validates_date_order():
+    # parser.error() raises SystemExit when the end date precedes the start date.
+    with pytest.raises(SystemExit):
+        cli.run_cli(["20250105", "20250101"])