update
This commit is contained in:
parent 8f820e441e
commit 8befd80cb7

117 scripts/run_ingestion_job.py Normal file
@@ -0,0 +1,117 @@
"""Command-line script: run TuShare ingestion over a date range and compute factors in sync."""
from __future__ import annotations

import argparse
import sys
from datetime import datetime, date
from pathlib import Path
from typing import Iterable, Sequence

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from app.ingest.tushare import FetchJob, run_ingestion
from app.utils.config import get_config
from app.utils.logging import get_logger
from app.utils import alerts

LOGGER = get_logger(__name__)


def _parse_date(text: str) -> date:
    try:
        return datetime.strptime(text, "%Y%m%d").date()
    except ValueError as exc:  # noqa: BLE001
        raise argparse.ArgumentTypeError(f"Cannot parse date: {text}") from exc


def _parse_codes(raw: Sequence[str] | None) -> tuple[str, ...] | None:
    if not raw:
        return None
    normalized = []
    for item in raw:
        token = item.strip().upper()
        if token:
            normalized.append(token)
    return tuple(dict.fromkeys(normalized)) or None


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Run TuShare ingestion over a date range and update the factor tables in sync",
    )
    parser.add_argument("start", type=_parse_date, help="Start trade date (format: YYYYMMDD)")
    parser.add_argument("end", type=_parse_date, help="End trade date (format: YYYYMMDD)")
    parser.add_argument(
        "--codes",
        nargs="*",
        default=None,
        help="Optional list of stock codes (e.g. 000001.SZ); omit to process the whole market",
    )
    parser.add_argument(
        "--include-limits",
        action="store_true",
        help="Also sync extended data such as limit-up/limit-down and suspensions (off by default for quick trial runs)",
    )
    parser.add_argument(
        "--name",
        default="daily_ingestion",
        help="Job name, used to tag logs and alerts",
    )
    parser.add_argument(
        "--granularity",
        default="daily",
        choices=("daily", "weekly"),
        help="Job granularity; currently only daily triggers factor computation",
    )
    return parser


def run_cli(argv: Iterable[str] | None = None) -> int:
    parser = build_parser()
    args = parser.parse_args(list(argv) if argv is not None else None)

    if args.end < args.start:
        parser.error("End date must not be earlier than start date")

    codes = _parse_codes(args.codes)
    job = FetchJob(
        name=str(args.name),
        start=args.start,
        end=args.end,
        granularity=str(args.granularity),
        ts_codes=codes,
    )

    LOGGER.info(
        "Preparing ingestion job name=%s start=%s end=%s codes=%s granularity=%s",
        job.name,
        job.start,
        job.end,
        job.ts_codes,
        job.granularity,
    )

    try:
        run_ingestion(job, include_limits=bool(args.include_limits))
    except Exception:  # noqa: BLE001
        LOGGER.exception("Ingestion job failed")
        return 1

    warnings = alerts.get_warnings()
    if warnings:
        LOGGER.warning("Job finished but produced warnings: %s", warnings)
        return 2

    LOGGER.info("Job finished with no warnings")
    return 0


def main() -> None:
    exit_code = run_cli()
    raise SystemExit(exit_code)


if __name__ == "__main__":
    main()
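For reference, a minimal sketch of driving the exit-code contract above from another Python entry point; the dates and stock code below are illustrative values, not taken from this commit:

# Illustrative driver for run_cli's exit-code contract:
# 0 = success, 1 = ingestion raised, 2 = finished with warnings.
import sys

from scripts.run_ingestion_job import run_cli

code = run_cli(["20250110", "20250112", "--codes", "000001.SZ"])
if code == 2:
    print("ingestion finished but produced warnings; check the logs", file=sys.stderr)
sys.exit(code)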
132 tests/test_backtest_engine_risk.py Normal file
@@ -0,0 +1,132 @@
"""Tests for BacktestEngine risk-aware execution."""
from __future__ import annotations

from datetime import date

import pytest

from app.agents.base import AgentAction, AgentContext
from app.agents.game import Decision
from app.backtest.engine import BacktestEngine, BacktestResult, BtConfig, PortfolioState


def _make_context(price: float, features: dict | None = None) -> AgentContext:
    scope_values = {"daily.close": price}
    return AgentContext(
        ts_code="000001.SZ",
        trade_date="2025-01-10",
        features=features or {},
        market_snapshot={},
        raw={"scope_values": scope_values},
    )


def _make_decision(action: AgentAction, target_weight: float = 0.0) -> Decision:
    return Decision(
        action=action,
        confidence=0.8,
        target_weight=target_weight,
        feasible_actions=[],
        utilities={},
    )


def _engine_with_params(params: dict[str, float]) -> BacktestEngine:
    cfg = BtConfig(
        id="test",
        name="test",
        start_date=date(2025, 1, 10),
        end_date=date(2025, 1, 10),
        universe=["000001.SZ"],
        params=params,
    )
    return BacktestEngine(cfg)


def test_buy_respects_risk_caps():
    engine = _engine_with_params(
        {
            "max_position_weight": 0.2,
            "fee_rate": 0.0,
            "slippage_bps": 0.0,
            "max_daily_turnover_ratio": 1.0,
        }
    )
    state = PortfolioState(cash=100_000.0)
    result = BacktestResult()
    features = {
        "liquidity_score": 0.7,
        "risk_penalty": 0.25,
    }
    context = _make_context(100.0, features)
    decision = _make_decision(AgentAction.BUY_L, target_weight=0.5)

    engine._apply_portfolio_updates(
        date(2025, 1, 10),
        state,
        [("000001.SZ", context, decision)],
        result,
    )

    expected_qty = (100_000.0 * 0.2 * (1 - 0.25)) / 100.0
    assert state.holdings["000001.SZ"] == pytest.approx(expected_qty)
    assert state.cash == pytest.approx(100_000.0 - expected_qty * 100.0)
    assert result.trades and result.trades[0]["status"] == "executed"
    assert result.nav_series[0]["turnover"] == pytest.approx(expected_qty * 100.0)


def test_buy_blocked_by_limit_up_records_risk():
    engine = _engine_with_params({})
    state = PortfolioState(cash=50_000.0)
    result = BacktestResult()
    features = {"limit_up": True}
    context = _make_context(100.0, features)
    decision = _make_decision(AgentAction.BUY_M, target_weight=0.1)

    engine._apply_portfolio_updates(
        date(2025, 1, 10),
        state,
        [("000001.SZ", context, decision)],
        result,
    )

    assert "000001.SZ" not in state.holdings
    assert not result.trades
    assert result.risk_events
    assert result.risk_events[0]["reason"] == "limit_up"


def test_sell_applies_slippage_and_fee():
    engine = _engine_with_params(
        {
            "max_position_weight": 0.3,
            "fee_rate": 0.001,
            "slippage_bps": 20.0,
            "max_daily_turnover_ratio": 1.0,
        }
    )
    state = PortfolioState(
        cash=0.0,
        holdings={"000001.SZ": 100.0},
        cost_basis={"000001.SZ": 90.0},
        opened_dates={"000001.SZ": "2024-12-01"},
    )
    result = BacktestResult()
    context = _make_context(100.0, {})
    decision = _make_decision(AgentAction.SELL)

    engine._apply_portfolio_updates(
        date(2025, 1, 10),
        state,
        [("000001.SZ", context, decision)],
        result,
    )

    trade = result.trades[0]
    assert pytest.approx(trade["price"], rel=1e-6) == 100.0 * (1 - 0.002)
    assert pytest.approx(trade["fee"], rel=1e-6) == trade["value"] * 0.001
    assert state.cash == pytest.approx(trade["value"] - trade["fee"])
    assert state.realized_pnl == pytest.approx((trade["price"] - 90.0) * 100 - trade["fee"])
    assert not state.holdings
    assert result.nav_series[0]["turnover"] == pytest.approx(trade["value"])
    assert not result.risk_events
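The sizing assertion in test_buy_respects_risk_caps follows from the parameters alone: the target notional is capped at cash × max_position_weight, then scaled by (1 − risk_penalty). A standalone restatement of that arithmetic (it mirrors what the test asserts; it is not an excerpt from the engine):

# Worked numbers behind expected_qty in test_buy_respects_risk_caps.
cash = 100_000.0
max_position_weight = 0.2  # cap on a single position's share of equity
risk_penalty = 0.25        # feature that shrinks the target notional
price = 100.0

notional = cash * max_position_weight * (1 - risk_penalty)  # 15_000.0
expected_qty = notional / price                             # 150.0
assert expected_qty == 150.0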
117 tests/test_data_broker_cache.py Normal file
@@ -0,0 +1,117 @@
"""Verify DataBroker caching and fallback behaviour."""
from __future__ import annotations

import sqlite3
from collections import defaultdict
from contextlib import contextmanager
from typing import Any, Dict, Iterable, List, Tuple

import pytest

from app.utils import data_access
from app.utils.data_access import DataBroker


class _FakeRow(dict):
    def __getitem__(self, key: str) -> Any:  # type: ignore[override]
        return dict.get(self, key)


class _FakeCursor:
    def __init__(self, rows: Iterable[Dict[str, Any]]):
        self._rows = list(rows)

    def fetchone(self) -> Dict[str, Any] | None:
        return self._rows[0] if self._rows else None

    def fetchall(self) -> List[Dict[str, Any]]:
        return list(self._rows)


class _FakeConn:
    def __init__(self, calls: List[str], failure_flags: Dict[str, bool]) -> None:
        self._calls = calls
        self._failure_flags = failure_flags
        self._daily_series = [
            _FakeRow({"trade_date": "20250112", "close": 12.3, "open": 12.1}),
            _FakeRow({"trade_date": "20250111", "close": 11.9, "open": 11.6}),
        ]
        self._turn_series = [
            _FakeRow({"trade_date": "20250112", "turnover_rate": 2.5}),
            _FakeRow({"trade_date": "20250111", "turnover_rate": 2.2}),
        ]

    def execute(self, query: str, params: Tuple[Any, ...] | None = None):
        self._calls.append(query)
        params = params or ()
        upper = query.upper()
        if upper.startswith("PRAGMA TABLE_INFO"):
            if "DAILY_BASIC" in upper:
                rows = [_FakeRow({"name": "trade_date"}), _FakeRow({"name": "turnover_rate"})]
            else:
                rows = [_FakeRow({"name": "trade_date"}), _FakeRow({"name": "close"}), _FakeRow({"name": "open"})]
            return _FakeCursor(rows)
        if "FROM DAILY " in upper:
            if self._failure_flags.get("daily"):
                raise sqlite3.OperationalError("stub failure")
            if "LIMIT 1" in upper:
                return _FakeCursor([self._daily_series[0]])
            limit = params[-1] if params else len(self._daily_series)
            return _FakeCursor(self._daily_series[: int(limit)])
        if "FROM DAILY_BASIC" in upper:
            limit = params[-1] if params else len(self._turn_series)
            return _FakeCursor(self._turn_series[: int(limit)])
        raise AssertionError(f"Unexpected query: {query}")

    def close(self) -> None:  # pragma: no cover - compatibility
        return None


@pytest.fixture()
def patched_db(monkeypatch):
    calls: List[str] = []
    failure_flags: Dict[str, bool] = defaultdict(bool)

    @contextmanager
    def _session(read_only: bool = False):  # noqa: D401 - contextmanager stub
        conn = _FakeConn(calls, failure_flags)
        yield conn

    monkeypatch.setattr(data_access, "db_session", _session)
    yield calls, failure_flags


def test_fetch_latest_uses_cache(patched_db):
    calls, failure = patched_db
    broker = DataBroker()

    result = broker.fetch_latest("000001.SZ", "20250112", ["daily.close", "daily.open"])
    assert result["daily.close"] == pytest.approx(12.3)
    first_count = len(calls)

    result_cached = broker.fetch_latest("000001.SZ", "20250112", ["daily.open", "daily.close"])
    assert result_cached == result
    assert len(calls) == first_count

    failure["daily"] = True
    still_cached = broker.fetch_latest("000001.SZ", "20250112", ["daily.close", "daily.open"])
    assert still_cached["daily.close"] == pytest.approx(12.3)
    assert len(calls) == first_count


def test_fetch_series_cache_and_disable(patched_db):
    calls, _ = patched_db
    broker = DataBroker(series_cache_size=4)

    series = broker.fetch_series("daily", "close", "000001.SZ", "20250112", 2)
    assert len(series) == 2
    first_count = len(calls)

    series_cached = broker.fetch_series("daily", "close", "000001.SZ", "20250112", 2)
    assert series_cached == series
    assert len(calls) == first_count

    broker_no_cache = DataBroker(enable_cache=False)
    calls_before = len(calls)
    broker_no_cache.fetch_series("daily", "close", "000001.SZ", "20250112", 2)
    assert len(calls) > calls_before
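Taken together, these tests pin three DataBroker behaviours: identical requests are answered from cache (the field list is evidently normalised, since reordering ["daily.close", "daily.open"] still hits), cached values keep serving after the underlying DB starts failing, and enable_cache=False always goes to the DB. A minimal sketch of that cache-in-front-of-loader pattern; _CachingFetcher and its internals are hypothetical, not the real DataBroker implementation:

# Hypothetical sketch of the cache-in-front-of-loader pattern these tests pin
# down; not the actual DataBroker internals.
from collections import OrderedDict
from typing import Any, Callable, Hashable


class _CachingFetcher:
    def __init__(self, load: Callable[[Hashable], Any], maxsize: int = 128, enabled: bool = True) -> None:
        self._load = load  # hits the database and may raise
        self._cache: OrderedDict[Hashable, Any] = OrderedDict()
        self._maxsize = maxsize
        self._enabled = enabled

    def get(self, key: Hashable) -> Any:
        if self._enabled and key in self._cache:
            self._cache.move_to_end(key)  # LRU bump; a failing DB is never touched
            return self._cache[key]
        value = self._load(key)  # e.g. raises sqlite3.OperationalError on DB failure
        if self._enabled:
            self._cache[key] = value
            if len(self._cache) > self._maxsize:
                self._cache.popitem(last=False)
        return value


# An order-insensitive key would explain the reordered-fields cache hit:
# key = ("000001.SZ", "20250112", tuple(sorted(fields)))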
108 tests/test_ingest_tushare.py Normal file
@@ -0,0 +1,108 @@
"""Verify the integration between the TuShare ingestion flow and factor computation."""
from __future__ import annotations

import sys
import types
from datetime import date

import pytest

# pandas may have binary-dependency issues in some environments; provide a
# minimal stub here so the import does not fail there.
try:  # pragma: no cover - reuse pandas directly when it is available in the test environment
    import pandas as _pd  # type: ignore
except Exception:  # pragma: no cover - stub fallback
    pandas_stub = types.ModuleType("pandas")

    class _DummyFrame:  # pylint: disable=too-few-public-methods
        empty = True

        def __init__(self, *args, **kwargs):  # noqa: D401
            """Lightweight placeholder to avoid invoking real logic during tests."""

        def to_dict(self, *_args, **_kwargs):
            return {}

        def reindex(self, *_args, **_kwargs):
            return self

        def where(self, *_args, **_kwargs):
            return self

    pandas_stub.DataFrame = _DummyFrame
    pandas_stub.Series = _DummyFrame
    pandas_stub.concat = lambda *args, **kwargs: _DummyFrame()  # type: ignore[arg-type]
    pandas_stub.Timestamp = lambda *args, **kwargs: None  # type: ignore[assignment]
    pandas_stub.to_datetime = lambda value, **kwargs: value  # type: ignore[assignment]
    pandas_stub.isna = lambda value: False  # type: ignore[assignment]
    pandas_stub.notna = lambda value: True  # type: ignore[assignment]

    sys.modules.setdefault("pandas", pandas_stub)
else:  # pragma: no cover
    sys.modules.setdefault("pandas", _pd)

from app.ingest.tushare import FetchJob, run_ingestion
from app.utils import alerts


@pytest.fixture(autouse=True)
def clear_alerts():
    alerts.clear_warnings()
    yield
    alerts.clear_warnings()


def test_run_ingestion_triggers_factor_range(monkeypatch):
    job = FetchJob(
        name="daily_job",
        start=date(2025, 1, 10),
        end=date(2025, 1, 11),
        ts_codes=("000001.SZ",),
    )

    coverage_called = {}

    def fake_coverage(*args, **kwargs):
        coverage_called["args"] = (args, kwargs)

    monkeypatch.setattr("app.ingest.tushare.ensure_data_coverage", fake_coverage)

    captured: dict = {}

    def fake_compute(start, end, **kwargs):
        captured["start"] = start
        captured["end"] = end
        captured["kwargs"] = kwargs
        return []

    monkeypatch.setattr("app.ingest.tushare.compute_factor_range", fake_compute)

    run_ingestion(job, include_limits=False)

    assert "args" in coverage_called
    assert captured["start"] == job.start
    assert captured["end"] == job.end
    assert captured["kwargs"] == {"ts_codes": job.ts_codes, "skip_existing": False}


def test_run_ingestion_skips_factors_for_non_daily(monkeypatch):
    job = FetchJob(
        name="weekly_job",
        start=date(2025, 1, 10),
        end=date(2025, 1, 17),
        granularity="weekly",
        ts_codes=None,
    )

    monkeypatch.setattr("app.ingest.tushare.ensure_data_coverage", lambda *_, **__: None)

    invoked = {"count": 0}

    def fake_compute(*args, **kwargs):
        invoked["count"] += 1
        return []

    monkeypatch.setattr("app.ingest.tushare.compute_factor_range", fake_compute)

    run_ingestion(job)

    assert invoked["count"] == 0
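Read together, the two tests constrain run_ingestion's control flow: ensure_data_coverage is always called, and compute_factor_range runs over the job's full window with skip_existing=False only for daily granularity (so FetchJob presumably defaults to daily). A hypothetical skeleton consistent with those assertions; the real body lives in app/ingest/tushare.py, is not part of this commit, and the ensure_data_coverage argument names are guesses:

# Hypothetical skeleton consistent with the assertions above; stub
# collaborators are defined only so the sketch stands alone.
def ensure_data_coverage(start, end, **kwargs) -> None:  # stub for illustration
    ...


def compute_factor_range(start, end, **kwargs) -> list:  # stub for illustration
    return []


def run_ingestion(job, include_limits: bool = True) -> None:
    # Always bring raw data coverage up to date first (argument names assumed).
    ensure_data_coverage(job.start, job.end, ts_codes=job.ts_codes, include_limits=include_limits)
    # Factor computation is only wired up for the daily granularity.
    if job.granularity == "daily":
        compute_factor_range(job.start, job.end, ts_codes=job.ts_codes, skip_existing=False)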
62 tests/test_run_ingestion_script.py Normal file
@@ -0,0 +1,62 @@
"""CLI behaviour tests for scripts/run_ingestion_job.py."""
from __future__ import annotations

import pytest

import scripts.run_ingestion_job as cli


@pytest.fixture(autouse=True)
def reset_alerts(monkeypatch):
    monkeypatch.setattr(cli.alerts, "clear_warnings", lambda *args, **kwargs: None)
    yield


def test_cli_invokes_run_ingestion_with_codes(monkeypatch):
    captured: dict = {}

    def fake_run(job, include_limits):
        captured["job"] = job
        captured["include_limits"] = include_limits

    monkeypatch.setattr(cli, "run_ingestion", fake_run)
    monkeypatch.setattr(cli.alerts, "get_warnings", lambda: [])

    exit_code = cli.run_cli(
        [
            "20250110",
            "20250112",
            "--codes",
            "000001.SZ",
            "000002.SZ",
            "--include-limits",
            "--name",
            "test_job",
        ]
    )

    assert exit_code == 0
    job = captured["job"]
    assert job.name == "test_job"
    assert job.start.isoformat() == "2025-01-10"
    assert job.end.isoformat() == "2025-01-12"
    assert job.ts_codes == ("000001.SZ", "000002.SZ")
    assert captured["include_limits"] is True


def test_cli_returns_warning_status(monkeypatch):
    monkeypatch.setattr(cli, "run_ingestion", lambda *args, **kwargs: None)
    monkeypatch.setattr(
        cli.alerts,
        "get_warnings",
        lambda: [{"source": "Factors", "message": "mock warning"}],
    )

    exit_code = cli.run_cli(["20250101", "20250102"])

    assert exit_code == 2


def test_cli_validates_date_order(monkeypatch):
    with pytest.raises(SystemExit):
        cli.run_cli(["20250105", "20250101"])
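A note on test_cli_validates_date_order: parser.error() is argparse's standard failure path, which prints the message and raises SystemExit with status 2, hence pytest.raises(SystemExit). A companion check, not part of this commit, that makes the exit status explicit:

# Companion check (not in the commit): argparse's parser.error() exits with
# status 2, which is what pytest.raises(SystemExit) catches above.
import pytest

from scripts.run_ingestion_job import build_parser


def test_parser_error_exit_status():
    with pytest.raises(SystemExit) as excinfo:
        build_parser().error("boom")
    assert excinfo.value.code == 2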