llm-quant/scripts/run_factor_pipeline.py

244 lines
7.5 KiB
Python

"""Command-line helper for running the factor computation pipeline."""
from __future__ import annotations
import argparse
from datetime import date, datetime
from typing import Iterable, List, Optional, Sequence
from app.features.factors import (
DEFAULT_FACTORS,
FactorResult,
FactorSpec,
compute_factor_range,
compute_factors,
compute_factors_incremental,
lookup_factor_spec,
)
from app.features.factor_audit import audit_factors
from app.utils.logging import get_logger
# Module-level logger used by main() and every summary/audit printer in this script.
LOGGER = get_logger(__name__)
def main() -> None:
    """Run the factor pipeline according to the command-line arguments.

    Three modes are supported:

    * ``single`` -- compute factors for one ``--trade-date``.
    * ``range`` -- compute factors for every date between ``--start`` and ``--end``.
    * anything else (``incremental``) -- delegate to ``compute_factors_incremental``,
      limited by ``--max-days``.

    After the computation, when ``--audit`` is set, ``audit_factors`` is run for
    each date that was processed and its summary is logged.

    Raises:
        SystemExit: if the arguments required by the selected mode are missing,
            or if a supplied date cannot be parsed.
    """
    args = _build_parser().parse_args()
    persist = not args.no_persist
    factor_specs = _resolve_factor_specs(args.factors)
    ts_codes = _normalize_codes(args.ts_codes)
    # A batch size of 0 falls back to 100 as well.
    batch_size = args.batch_size or 100

    # Keyword arguments that every compute entry point receives.
    shared_kwargs = {
        "ts_codes": ts_codes,
        "skip_existing": args.skip_existing,
        "persist": persist,
    }

    audit_dates = []
    if args.mode == "single":
        if not args.trade_date:
            raise SystemExit("--trade-date is required in single mode")
        trade_day = _parse_date(args.trade_date)
        results = compute_factors(
            trade_day,
            factor_specs,
            batch_size=batch_size,
            **shared_kwargs,
        )
        _print_summary_single(trade_day, results, persist)
        if args.audit:
            audit_dates = [trade_day]
    elif args.mode == "range":
        if not args.start or not args.end:
            raise SystemExit("--start and --end are required in range mode")
        start = _parse_date(args.start)
        end = _parse_date(args.end)
        results = compute_factor_range(
            start,
            end,
            factors=factor_specs,
            **shared_kwargs,
        )
        _print_summary_range(start, end, results, persist)
        if args.audit:
            # Audit each distinct date that produced results, in ascending order.
            audit_dates = sorted({item.trade_date for item in results})
    else:
        incremental = compute_factors_incremental(
            factors=factor_specs,
            max_trading_days=args.max_days,
            **shared_kwargs,
        )
        _print_summary_incremental(incremental, persist)
        if args.audit:
            audit_dates = incremental.get("trade_dates", [])

    if not args.audit:
        return
    if not audit_dates:
        LOGGER.info("无可审计的日期,跳过因子审计步骤")
        return
    for audit_date in audit_dates:
        report = audit_factors(
            audit_date,
            factors=factor_specs,
            tolerance=args.audit_tolerance,
            max_issues=args.max_audit_issues,
        )
        _print_audit_summary(report)
def _build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="Run factor computation pipeline.")
parser.add_argument(
"--mode",
choices=("single", "range", "incremental"),
default="single",
help="Pipeline mode (default: single).",
)
parser.add_argument("--trade-date", help="Trade date (YYYYMMDD) for single mode.")
parser.add_argument("--start", help="Start date (YYYYMMDD) for range mode.")
parser.add_argument("--end", help="End date (YYYYMMDD) for range mode.")
parser.add_argument(
"--max-days",
type=int,
default=5,
help="Limit of trading days for incremental mode.",
)
parser.add_argument(
"--ts-code",
dest="ts_codes",
action="append",
help="Limit computation to specific ts_code. Can be provided multiple times.",
)
parser.add_argument(
"--factor",
dest="factors",
action="append",
help="Factor name to include. Defaults to the built-in set.",
)
parser.add_argument(
"--skip-existing",
action="store_true",
help="Skip securities that already have persisted values for the target date(s).",
)
parser.add_argument(
"--no-persist",
action="store_true",
help="Dry-run mode; compute factors without writing to the database.",
)
parser.add_argument(
"--batch-size",
type=int,
default=100,
help="Override default batch size when computing factors.",
)
parser.add_argument(
"--audit",
action="store_true",
help="Run formula audit after computation completes.",
)
parser.add_argument(
"--audit-tolerance",
type=float,
default=1e-6,
help="Allowed absolute difference when auditing factors.",
)
parser.add_argument(
"--max-audit-issues",
type=int,
default=50,
help="Maximum number of detailed audit issues to print.",
)
return parser
def _resolve_factor_specs(names: Optional[Sequence[str]]) -> List[FactorSpec]:
    """Translate requested factor names into factor specs.

    Args:
        names: Factor names given on the command line, or ``None``/empty when
            the option was not used.

    Returns:
        The specs returned by ``lookup_factor_spec`` for the requested names,
        in first-seen order with repeats (by ``spec.name``) dropped. Names for
        which the lookup returns ``None`` are logged and ignored. When nothing
        was requested, or nothing could be resolved, a copy of
        ``DEFAULT_FACTORS`` is returned instead.
    """
    if not names:
        return list(DEFAULT_FACTORS)

    # Insertion-ordered mapping: spec name -> first spec seen under that name.
    by_name: dict[str, FactorSpec] = {}
    for requested in names:
        spec = lookup_factor_spec(requested)
        if spec is None:
            LOGGER.warning("未知因子,忽略: %s", requested)
        elif spec.name not in by_name:
            by_name[spec.name] = spec

    if by_name:
        return list(by_name.values())
    return list(DEFAULT_FACTORS)
def _normalize_codes(codes: Optional[Iterable[str]]) -> List[str] | None:
if not codes:
return None
normalized = []
for code in codes:
text = (code or "").strip().upper()
if text:
normalized.append(text)
return normalized or None
def _parse_date(value: str) -> date:
value = value.strip()
for fmt in ("%Y%m%d", "%Y-%m-%d"):
try:
return datetime.strptime(value, fmt).date()
except ValueError:
continue
raise SystemExit(f"Invalid date: {value}")
def _print_summary_single(trade_day: date, results: Sequence[FactorResult], persist: bool) -> None:
    """Log a one-line summary of a single-day run.

    Args:
        trade_day: The date that was computed.
        results: The factor results returned for that date; only the count is logged.
        persist: Whether results were written to storage.
    """
    day_text = trade_day.isoformat()
    row_count = len(results)
    persisted = bool(persist)
    LOGGER.info(
        "单日因子计算完成 trade_date=%s rows=%s persist=%s",
        day_text,
        row_count,
        persisted,
    )
def _print_summary_range(start: date, end: date, results: Sequence[FactorResult], persist: bool) -> None:
    """Log a one-line summary of a date-range run.

    Args:
        start: First date of the requested range.
        end: Last date of the requested range.
        results: All factor results produced; the log reports the row count
            and the number of distinct ``trade_date`` values among them.
        persist: Whether results were written to storage.
    """
    distinct_days = set()
    for item in results:
        distinct_days.add(item.trade_date)
    # Sorting mirrors the ordering step used elsewhere; only the count is logged.
    ordered_days = sorted(distinct_days)

    LOGGER.info(
        "区间因子计算完成 start=%s end=%s days=%s rows=%s persist=%s",
        start.isoformat(),
        end.isoformat(),
        len(ordered_days),
        len(results),
        bool(persist),
    )
def _print_summary_incremental(summary: dict, persist: bool) -> None:
    """Log a one-line summary of an incremental run.

    Args:
        summary: Mapping returned by the incremental computation. The
            ``"trade_dates"`` entry (missing or empty is allowed) supplies the
            first/last date and day count; ``"count"`` supplies the row count
            and defaults to 0.
        persist: Whether results were written to storage.
    """
    days = summary.get("trade_dates") or []
    if days:
        first_day = days[0].isoformat()
        last_day = days[-1].isoformat()
    else:
        # Nothing was processed: log placeholders instead of dates.
        first_day = None
        last_day = None

    LOGGER.info(
        "增量因子计算完成 start=%s end=%s days=%s rows=%s persist=%s",
        first_day,
        last_day,
        len(days),
        summary.get("count", 0),
        bool(persist),
    )
def _print_audit_summary(summary) -> None:
    """Log the outcome of one factor audit.

    One info line reports the audit's date and counters, followed by one
    warning line per entry in ``summary.issues``.

    Args:
        summary: Audit result object exposing ``trade_date``, ``mismatched``,
            ``evaluated``, ``missing_persisted``, ``missing_recomputed`` and
            ``issues`` (each issue exposing ``ts_code``, ``factor``,
            ``stored``, ``recomputed`` and ``difference``).
    """
    header_fields = (
        summary.trade_date.isoformat(),
        summary.mismatched,
        summary.evaluated,
        summary.missing_persisted,
        summary.missing_recomputed,
        len(summary.issues),
    )
    LOGGER.info(
        "因子审计 trade_date=%s mismatched=%s evaluated=%s missing_persisted=%s missing_recomputed=%s issues=%s",
        *header_fields,
    )

    for entry in summary.issues:
        LOGGER.warning(
            "审计异常 ts_code=%s factor=%s stored=%s recomputed=%s diff=%s",
            entry.ts_code,
            entry.factor,
            entry.stored,
            entry.recomputed,
            entry.difference,
        )
# Run the pipeline only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()