This commit is contained in:
sam 2025-09-30 18:08:15 +08:00
parent 8f820e441e
commit 8befd80cb7
5 changed files with 536 additions and 0 deletions

View File

@ -0,0 +1,117 @@
"""命令行脚本:按日期区间执行 TuShare 拉数并同步计算因子。"""
from __future__ import annotations
import argparse
import sys
from datetime import datetime, date
from pathlib import Path
from typing import Iterable, Sequence
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
from app.ingest.tushare import FetchJob, run_ingestion
from app.utils.config import get_config
from app.utils.logging import get_logger
from app.utils import alerts
LOGGER = get_logger(__name__)
def _parse_date(text: str) -> date:
try:
return datetime.strptime(text, "%Y%m%d").date()
except ValueError as exc: # noqa: BLE001
raise argparse.ArgumentTypeError(f"无法解析日期:{text}") from exc
def _parse_codes(raw: Sequence[str] | None) -> tuple[str, ...] | None:
if not raw:
return None
normalized = []
for item in raw:
token = item.strip().upper()
if token:
normalized.append(token)
return tuple(dict.fromkeys(normalized)) or None
def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description="按日期区间执行 TuShare 拉数并同步更新因子表",
)
parser.add_argument("start", type=_parse_date, help="起始交易日格式YYYYMMDD")
parser.add_argument("end", type=_parse_date, help="结束交易日格式YYYYMMDD")
parser.add_argument(
"--codes",
nargs="*",
default=None,
help="可选的股票代码列表(如 000001.SZ不传则处理全市场",
)
parser.add_argument(
"--include-limits",
action="store_true",
help="是否同步涨跌停/停牌等扩展数据(默认关闭,便于快速试跑)",
)
parser.add_argument(
"--name",
default="daily_ingestion",
help="任务名称,用于日志与告警标记",
)
parser.add_argument(
"--granularity",
default="daily",
choices=("daily", "weekly"),
help="任务粒度,目前仅 daily 会触发因子计算",
)
return parser
def run_cli(argv: Iterable[str] | None = None) -> int:
parser = build_parser()
args = parser.parse_args(list(argv) if argv is not None else None)
if args.end < args.start:
parser.error("结束日期不能早于起始日期")
codes = _parse_codes(args.codes)
job = FetchJob(
name=str(args.name),
start=args.start,
end=args.end,
granularity=str(args.granularity),
ts_codes=codes,
)
LOGGER.info(
"准备执行拉数任务 name=%s start=%s end=%s codes=%s granularity=%s",
job.name,
job.start,
job.end,
job.ts_codes,
job.granularity,
)
try:
run_ingestion(job, include_limits=bool(args.include_limits))
except Exception: # noqa: BLE001
LOGGER.exception("拉数任务执行失败")
return 1
warnings = alerts.get_warnings()
if warnings:
LOGGER.warning("任务完成但存在告警:%s", warnings)
return 2
LOGGER.info("任务执行完成,无告警")
return 0
def main() -> None:
exit_code = run_cli()
raise SystemExit(exit_code)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,132 @@
"""Tests for BacktestEngine risk-aware execution."""
from __future__ import annotations
from datetime import date
import pytest
from app.agents.base import AgentAction, AgentContext
from app.agents.game import Decision
from app.backtest.engine import BacktestEngine, BacktestResult, BtConfig, PortfolioState
def _make_context(price: float, features: dict | None = None) -> AgentContext:
scope_values = {"daily.close": price}
return AgentContext(
ts_code="000001.SZ",
trade_date="2025-01-10",
features=features or {},
market_snapshot={},
raw={"scope_values": scope_values},
)
def _make_decision(action: AgentAction, target_weight: float = 0.0) -> Decision:
return Decision(
action=action,
confidence=0.8,
target_weight=target_weight,
feasible_actions=[],
utilities={},
)
def _engine_with_params(params: dict[str, float]) -> BacktestEngine:
cfg = BtConfig(
id="test",
name="test",
start_date=date(2025, 1, 10),
end_date=date(2025, 1, 10),
universe=["000001.SZ"],
params=params,
)
return BacktestEngine(cfg)
def test_buy_respects_risk_caps():
engine = _engine_with_params(
{
"max_position_weight": 0.2,
"fee_rate": 0.0,
"slippage_bps": 0.0,
"max_daily_turnover_ratio": 1.0,
}
)
state = PortfolioState(cash=100_000.0)
result = BacktestResult()
features = {
"liquidity_score": 0.7,
"risk_penalty": 0.25,
}
context = _make_context(100.0, features)
decision = _make_decision(AgentAction.BUY_L, target_weight=0.5)
engine._apply_portfolio_updates(
date(2025, 1, 10),
state,
[("000001.SZ", context, decision)],
result,
)
expected_qty = (100_000.0 * 0.2 * (1 - 0.25)) / 100.0
assert state.holdings["000001.SZ"] == pytest.approx(expected_qty)
assert state.cash == pytest.approx(100_000.0 - expected_qty * 100.0)
assert result.trades and result.trades[0]["status"] == "executed"
assert result.nav_series[0]["turnover"] == pytest.approx(expected_qty * 100.0)
def test_buy_blocked_by_limit_up_records_risk():
engine = _engine_with_params({})
state = PortfolioState(cash=50_000.0)
result = BacktestResult()
features = {"limit_up": True}
context = _make_context(100.0, features)
decision = _make_decision(AgentAction.BUY_M, target_weight=0.1)
engine._apply_portfolio_updates(
date(2025, 1, 10),
state,
[("000001.SZ", context, decision)],
result,
)
assert "000001.SZ" not in state.holdings
assert not result.trades
assert result.risk_events
assert result.risk_events[0]["reason"] == "limit_up"
def test_sell_applies_slippage_and_fee():
engine = _engine_with_params(
{
"max_position_weight": 0.3,
"fee_rate": 0.001,
"slippage_bps": 20.0,
"max_daily_turnover_ratio": 1.0,
}
)
state = PortfolioState(
cash=0.0,
holdings={"000001.SZ": 100.0},
cost_basis={"000001.SZ": 90.0},
opened_dates={"000001.SZ": "2024-12-01"},
)
result = BacktestResult()
context = _make_context(100.0, {})
decision = _make_decision(AgentAction.SELL)
engine._apply_portfolio_updates(
date(2025, 1, 10),
state,
[("000001.SZ", context, decision)],
result,
)
trade = result.trades[0]
assert pytest.approx(trade["price"], rel=1e-6) == 100.0 * (1 - 0.002)
assert pytest.approx(trade["fee"], rel=1e-6) == trade["value"] * 0.001
assert state.cash == pytest.approx(trade["value"] - trade["fee"])
assert state.realized_pnl == pytest.approx((trade["price"] - 90.0) * 100 - trade["fee"])
assert not state.holdings
assert result.nav_series[0]["turnover"] == pytest.approx(trade["value"])
assert not result.risk_events

View File

@ -0,0 +1,117 @@
"""验证 DataBroker 的缓存与回退行为。"""
from __future__ import annotations
import sqlite3
from collections import defaultdict
from contextlib import contextmanager
from typing import Any, Dict, Iterable, List, Tuple
import pytest
from app.utils import data_access
from app.utils.data_access import DataBroker
class _FakeRow(dict):
def __getitem__(self, key: str) -> Any: # type: ignore[override]
return dict.get(self, key)
class _FakeCursor:
def __init__(self, rows: Iterable[Dict[str, Any]]):
self._rows = list(rows)
def fetchone(self) -> Dict[str, Any] | None:
return self._rows[0] if self._rows else None
def fetchall(self) -> List[Dict[str, Any]]:
return list(self._rows)
class _FakeConn:
def __init__(self, calls: List[str], failure_flags: Dict[str, bool]) -> None:
self._calls = calls
self._failure_flags = failure_flags
self._daily_series = [
_FakeRow({"trade_date": "20250112", "close": 12.3, "open": 12.1}),
_FakeRow({"trade_date": "20250111", "close": 11.9, "open": 11.6}),
]
self._turn_series = [
_FakeRow({"trade_date": "20250112", "turnover_rate": 2.5}),
_FakeRow({"trade_date": "20250111", "turnover_rate": 2.2}),
]
def execute(self, query: str, params: Tuple[Any, ...] | None = None):
self._calls.append(query)
params = params or ()
upper = query.upper()
if upper.startswith("PRAGMA TABLE_INFO"):
if "DAILY_BASIC" in upper:
rows = [_FakeRow({"name": "trade_date"}), _FakeRow({"name": "turnover_rate"})]
else:
rows = [_FakeRow({"name": "trade_date"}), _FakeRow({"name": "close"}), _FakeRow({"name": "open"})]
return _FakeCursor(rows)
if "FROM DAILY " in upper:
if self._failure_flags.get("daily"):
raise sqlite3.OperationalError("stub failure")
if "LIMIT 1" in upper:
return _FakeCursor([self._daily_series[0]])
limit = params[-1] if params else len(self._daily_series)
return _FakeCursor(self._daily_series[: int(limit)])
if "FROM DAILY_BASIC" in upper:
limit = params[-1] if params else len(self._turn_series)
return _FakeCursor(self._turn_series[: int(limit)])
raise AssertionError(f"Unexpected query: {query}")
def close(self) -> None: # pragma: no cover - compatibility
return None
@pytest.fixture()
def patched_db(monkeypatch):
calls: List[str] = []
failure_flags: Dict[str, bool] = defaultdict(bool)
@contextmanager
def _session(read_only: bool = False): # noqa: D401 - contextmanager stub
conn = _FakeConn(calls, failure_flags)
yield conn
monkeypatch.setattr(data_access, "db_session", _session)
yield calls, failure_flags
def test_fetch_latest_uses_cache(patched_db):
calls, failure = patched_db
broker = DataBroker()
result = broker.fetch_latest("000001.SZ", "20250112", ["daily.close", "daily.open"])
assert result["daily.close"] == pytest.approx(12.3)
first_count = len(calls)
result_cached = broker.fetch_latest("000001.SZ", "20250112", ["daily.open", "daily.close"])
assert result_cached == result
assert len(calls) == first_count
failure["daily"] = True
still_cached = broker.fetch_latest("000001.SZ", "20250112", ["daily.close", "daily.open"])
assert still_cached["daily.close"] == pytest.approx(12.3)
assert len(calls) == first_count
def test_fetch_series_cache_and_disable(patched_db):
calls, _ = patched_db
broker = DataBroker(series_cache_size=4)
series = broker.fetch_series("daily", "close", "000001.SZ", "20250112", 2)
assert len(series) == 2
first_count = len(calls)
series_cached = broker.fetch_series("daily", "close", "000001.SZ", "20250112", 2)
assert series_cached == series
assert len(calls) == first_count
broker_no_cache = DataBroker(enable_cache=False)
calls_before = len(calls)
broker_no_cache.fetch_series("daily", "close", "000001.SZ", "20250112", 2)
assert len(calls) > calls_before

View File

@ -0,0 +1,108 @@
"""验证 TuShare 拉数流程与因子计算的集成行为。"""
from __future__ import annotations
import sys
import types
from datetime import date
import pytest
# 某些环境下 pandas 可能存在二进制依赖问题,这里提供最小桩避免导入失败
try: # pragma: no cover - 测试运行环境中若 pandas 可用则直接复用
import pandas as _pd # type: ignore
except Exception: # pragma: no cover - stub fallback
pandas_stub = types.ModuleType("pandas")
class _DummyFrame: # pylint: disable=too-few-public-methods
empty = True
def __init__(self, *args, **kwargs): # noqa: D401
"""轻量占位,避免测试期调用实际逻辑。"""
def to_dict(self, *_args, **_kwargs):
return {}
def reindex(self, *_args, **_kwargs):
return self
def where(self, *_args, **_kwargs):
return self
pandas_stub.DataFrame = _DummyFrame
pandas_stub.Series = _DummyFrame
pandas_stub.concat = lambda *args, **kwargs: _DummyFrame() # type: ignore[arg-type]
pandas_stub.Timestamp = lambda *args, **kwargs: None # type: ignore[assignment]
pandas_stub.to_datetime = lambda value, **kwargs: value # type: ignore[assignment]
pandas_stub.isna = lambda value: False # type: ignore[assignment]
pandas_stub.notna = lambda value: True # type: ignore[assignment]
sys.modules.setdefault("pandas", pandas_stub)
else: # pragma: no cover
sys.modules.setdefault("pandas", _pd)
from app.ingest.tushare import FetchJob, run_ingestion
from app.utils import alerts
@pytest.fixture(autouse=True)
def clear_alerts():
alerts.clear_warnings()
yield
alerts.clear_warnings()
def test_run_ingestion_triggers_factor_range(monkeypatch):
job = FetchJob(
name="daily_job",
start=date(2025, 1, 10),
end=date(2025, 1, 11),
ts_codes=("000001.SZ",),
)
coverage_called = {}
def fake_coverage(*args, **kwargs):
coverage_called["args"] = (args, kwargs)
monkeypatch.setattr("app.ingest.tushare.ensure_data_coverage", fake_coverage)
captured: dict = {}
def fake_compute(start, end, **kwargs):
captured["start"] = start
captured["end"] = end
captured["kwargs"] = kwargs
return []
monkeypatch.setattr("app.ingest.tushare.compute_factor_range", fake_compute)
run_ingestion(job, include_limits=False)
assert "args" in coverage_called
assert captured["start"] == job.start
assert captured["end"] == job.end
assert captured["kwargs"] == {"ts_codes": job.ts_codes, "skip_existing": False}
def test_run_ingestion_skips_factors_for_non_daily(monkeypatch):
job = FetchJob(
name="weekly_job",
start=date(2025, 1, 10),
end=date(2025, 1, 17),
granularity="weekly",
ts_codes=None,
)
monkeypatch.setattr("app.ingest.tushare.ensure_data_coverage", lambda *_, **__: None)
invoked = {"count": 0}
def fake_compute(*args, **kwargs):
invoked["count"] += 1
return []
monkeypatch.setattr("app.ingest.tushare.compute_factor_range", fake_compute)
run_ingestion(job)
assert invoked["count"] == 0

View File

@ -0,0 +1,62 @@
"""针对 scripts/run_ingestion_job.py 的 CLI 行为测试。"""
from __future__ import annotations
import pytest
import scripts.run_ingestion_job as cli
@pytest.fixture(autouse=True)
def reset_alerts(monkeypatch):
monkeypatch.setattr(cli.alerts, "clear_warnings", lambda *args, **kwargs: None)
yield
def test_cli_invokes_run_ingestion_with_codes(monkeypatch):
captured: dict = {}
def fake_run(job, include_limits):
captured["job"] = job
captured["include_limits"] = include_limits
monkeypatch.setattr(cli, "run_ingestion", fake_run)
monkeypatch.setattr(cli.alerts, "get_warnings", lambda: [])
exit_code = cli.run_cli(
[
"20250110",
"20250112",
"--codes",
"000001.SZ",
"000002.SZ",
"--include-limits",
"--name",
"test_job",
]
)
assert exit_code == 0
job = captured["job"]
assert job.name == "test_job"
assert job.start.isoformat() == "2025-01-10"
assert job.end.isoformat() == "2025-01-12"
assert job.ts_codes == ("000001.SZ", "000002.SZ")
assert captured["include_limits"] is True
def test_cli_returns_warning_status(monkeypatch):
monkeypatch.setattr(cli, "run_ingestion", lambda *args, **kwargs: None)
monkeypatch.setattr(
cli.alerts,
"get_warnings",
lambda: [{"source": "Factors", "message": "mock warning"}],
)
exit_code = cli.run_cli(["20250101", "20250102"])
assert exit_code == 2
def test_cli_validates_date_order(monkeypatch):
with pytest.raises(SystemExit):
cli.run_cli(["20250105", "20250101"])