From 30007cc056a26e949677789512dec985df13e1ac Mon Sep 17 00:00:00 2001 From: sam Date: Tue, 30 Sep 2025 17:23:18 +0800 Subject: [PATCH] update --- app/data/schema.py | 13 ++ app/features/factors.py | 276 +++++++++++++++++++++++++++++++++++++-- docs/TODO.md | 4 + tests/conftest.py | 9 ++ tests/test_factors.py | 162 +++++++++++++++++++++++ tests/test_rss_ingest.py | 7 +- 6 files changed, 458 insertions(+), 13 deletions(-) create mode 100644 tests/conftest.py create mode 100644 tests/test_factors.py diff --git a/app/data/schema.py b/app/data/schema.py index 56ffcba..282ab2b 100644 --- a/app/data/schema.py +++ b/app/data/schema.py @@ -63,6 +63,18 @@ SCHEMA_STATEMENTS: Iterable[str] = ( ); """, """ + CREATE TABLE IF NOT EXISTS factors ( + ts_code TEXT, + trade_date TEXT, + mom_20 REAL, + mom_60 REAL, + volat_20 REAL, + turn_20 REAL, + updated_at TEXT, + PRIMARY KEY (ts_code, trade_date) + ); + """, + """ CREATE TABLE IF NOT EXISTS adj_factor ( ts_code TEXT, trade_date TEXT, @@ -442,6 +454,7 @@ REQUIRED_TABLES = ( "stock_basic", "daily", "daily_basic", + "factors", "adj_factor", "suspend", "trade_calendar", diff --git a/app/features/factors.py b/app/features/factors.py index 485884f..681cd14 100644 --- a/app/features/factors.py +++ b/app/features/factors.py @@ -1,9 +1,21 @@ """Feature engineering for signals and indicator computation.""" from __future__ import annotations +import re from dataclasses import dataclass -from datetime import date -from typing import Iterable, List +from datetime import datetime, date, timezone +from typing import Dict, Iterable, List, Optional, Sequence + +from app.core.indicators import momentum, rolling_mean, volatility +from app.data.schema import initialize_database +from app.utils.data_access import DataBroker +from app.utils.db import db_session +from app.utils.logging import get_logger + + +LOGGER = get_logger(__name__) +LOG_EXTRA = {"stage": "factor_compute"} +_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") @dataclass @@ -16,7 +28,7 @@ class FactorSpec: class FactorResult: ts_code: str trade_date: date - values: dict + values: Dict[str, float | None] DEFAULT_FACTORS: List[FactorSpec] = [ @@ -27,13 +39,257 @@ DEFAULT_FACTORS: List[FactorSpec] = [ ] -def compute_factors(trade_date: date, factors: Iterable[FactorSpec] = DEFAULT_FACTORS) -> List[FactorResult]: - """Calculate factor values for the requested date. +def compute_factors( + trade_date: date, + factors: Iterable[FactorSpec] = DEFAULT_FACTORS, + *, + ts_codes: Optional[Sequence[str]] = None, + skip_existing: bool = False, +) -> List[FactorResult]: + """Calculate and persist factor values for the requested date. - This function should join historical price data, apply rolling windows, and - persist results into an factors table. The implementation is left as future - work. + ``ts_codes`` can be supplied to restrict computation to a subset of the + universe. When ``skip_existing`` is True, securities that already have an + entry for ``trade_date`` will be ignored. """ - _ = trade_date, factors - raise NotImplementedError + specs = [spec for spec in factors if spec.window > 0] + if not specs: + return [] + + initialize_database() + trade_date_str = trade_date.strftime("%Y%m%d") + + _ensure_factor_columns(specs) + + allowed = {code.strip().upper() for code in ts_codes or () if code.strip()} + universe = _load_universe(trade_date_str, allowed if allowed else None) + if not universe: + LOGGER.info("无可用标的生成因子 trade_date=%s", trade_date_str, extra=LOG_EXTRA) + return [] + + if skip_existing: + existing = _existing_factor_codes(trade_date_str) + universe = [code for code in universe if code not in existing] + if not universe: + LOGGER.debug( + "目标交易日因子已存在 trade_date=%s universe_size=%s", + trade_date_str, + len(existing), + extra=LOG_EXTRA, + ) + return [] + + broker = DataBroker() + results: List[FactorResult] = [] + rows_to_persist: List[tuple[str, Dict[str, float | None]]] = [] + for ts_code in universe: + values = _compute_security_factors(broker, ts_code, trade_date_str, specs) + if not values: + continue + results.append(FactorResult(ts_code=ts_code, trade_date=trade_date, values=values)) + rows_to_persist.append((ts_code, values)) + + if rows_to_persist: + _persist_factor_rows(trade_date_str, rows_to_persist, specs) + return results + + +def compute_factor_range( + start: date, + end: date, + *, + factors: Iterable[FactorSpec] = DEFAULT_FACTORS, + ts_codes: Optional[Sequence[str]] = None, + skip_existing: bool = True, +) -> List[FactorResult]: + """Compute factors for all trading days within ``[start, end]`` inclusive.""" + + if end < start: + raise ValueError("end date must not precede start date") + + initialize_database() + allowed = None + if ts_codes: + allowed = tuple(dict.fromkeys(code.strip().upper() for code in ts_codes if code.strip())) + if not allowed: + allowed = None + + start_str = start.strftime("%Y%m%d") + end_str = end.strftime("%Y%m%d") + trade_dates = _list_trade_dates(start_str, end_str, allowed) + + aggregated: List[FactorResult] = [] + for trade_date_str in trade_dates: + trade_day = datetime.strptime(trade_date_str, "%Y%m%d").date() + aggregated.extend( + compute_factors( + trade_day, + factors, + ts_codes=allowed, + skip_existing=skip_existing, + ) + ) + return aggregated + + +def _load_universe(trade_date: str, allowed: Optional[set[str]] = None) -> List[str]: + query = "SELECT ts_code FROM daily WHERE trade_date = ? ORDER BY ts_code" + with db_session(read_only=True) as conn: + rows = conn.execute(query, (trade_date,)).fetchall() + codes = [row["ts_code"] for row in rows if row["ts_code"]] + if allowed: + allowed_upper = {code.upper() for code in allowed} + return [code for code in codes if code.upper() in allowed_upper] + return codes + + +def _existing_factor_codes(trade_date: str) -> set[str]: + with db_session(read_only=True) as conn: + rows = conn.execute( + "SELECT ts_code FROM factors WHERE trade_date = ?", + (trade_date,), + ).fetchall() + return {row["ts_code"] for row in rows if row["ts_code"]} + + +def _list_trade_dates( + start_date: str, + end_date: str, + allowed: Optional[Sequence[str]], +) -> List[str]: + params: List[str] = [start_date, end_date] + if allowed: + placeholders = ", ".join("?" for _ in allowed) + query = ( + "SELECT DISTINCT trade_date FROM daily " + "WHERE trade_date BETWEEN ? AND ? " + f"AND ts_code IN ({placeholders}) " + "ORDER BY trade_date" + ) + params.extend(allowed) + else: + query = ( + "SELECT DISTINCT trade_date FROM daily " + "WHERE trade_date BETWEEN ? AND ? " + "ORDER BY trade_date" + ) + with db_session(read_only=True) as conn: + rows = conn.execute(query, params).fetchall() + return [row["trade_date"] for row in rows if row["trade_date"]] + + +def _compute_security_factors( + broker: DataBroker, + ts_code: str, + trade_date: str, + specs: Sequence[FactorSpec], +) -> Dict[str, float | None]: + close_windows = [spec.window for spec in specs if _factor_prefix(spec.name) in {"mom", "volat"}] + turnover_windows = [spec.window for spec in specs if _factor_prefix(spec.name) == "turn"] + max_close_window = max(close_windows) if close_windows else 0 + max_turn_window = max(turnover_windows) if turnover_windows else 0 + + close_series = _fetch_series_values( + broker, + "daily", + "close", + ts_code, + trade_date, + max_close_window, + ) + turnover_series = _fetch_series_values( + broker, + "daily_basic", + "turnover_rate", + ts_code, + trade_date, + max_turn_window, + ) + + results: Dict[str, float | None] = {} + for spec in specs: + prefix = _factor_prefix(spec.name) + if prefix == "mom": + if len(close_series) >= spec.window: + results[spec.name] = momentum(close_series, spec.window) + else: + results[spec.name] = None + elif prefix == "volat": + if len(close_series) >= 2: + results[spec.name] = volatility(close_series, spec.window) + else: + results[spec.name] = None + elif prefix == "turn": + if len(turnover_series) >= spec.window: + results[spec.name] = rolling_mean(turnover_series, spec.window) + else: + results[spec.name] = None + else: + LOGGER.debug( + "忽略未识别的因子 name=%s ts_code=%s", + spec.name, + ts_code, + extra=LOG_EXTRA, + ) + return results + + +def _persist_factor_rows( + trade_date: str, + rows: Sequence[tuple[str, Dict[str, float | None]]], + specs: Sequence[FactorSpec], +) -> None: + columns = sorted({spec.name for spec in specs}) + timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + insert_columns = ["ts_code", "trade_date", "updated_at", *columns] + placeholders = ", ".join(["?"] * len(insert_columns)) + update_clause = ", ".join( + f"{column}=excluded.{column}" for column in ["updated_at", *columns] + ) + sql = ( + f"INSERT INTO factors ({', '.join(insert_columns)}) " + f"VALUES ({placeholders}) " + f"ON CONFLICT(ts_code, trade_date) DO UPDATE SET {update_clause}" + ) + + with db_session() as conn: + for ts_code, values in rows: + payload = [ts_code, trade_date, timestamp] + payload.extend(values.get(column) for column in columns) + conn.execute(sql, payload) + + +def _ensure_factor_columns(specs: Sequence[FactorSpec]) -> None: + pending = {spec.name for spec in specs if _IDENTIFIER_RE.match(spec.name)} + if not pending: + return + with db_session() as conn: + existing_rows = conn.execute("PRAGMA table_info(factors)").fetchall() + existing = {row["name"] for row in existing_rows} + for column in sorted(pending - existing): + conn.execute(f"ALTER TABLE factors ADD COLUMN {column} REAL") + + +def _fetch_series_values( + broker: DataBroker, + table: str, + column: str, + ts_code: str, + trade_date: str, + window: int, +) -> List[float]: + if window <= 0: + return [] + series = broker.fetch_series(table, column, ts_code, trade_date, window) + values: List[float] = [] + for _dt, raw in series: + try: + values.append(float(raw)) + except (TypeError, ValueError): + continue + return values + + +def _factor_prefix(name: str) -> str: + return name.split("_", 1)[0] if name else "" diff --git a/docs/TODO.md b/docs/TODO.md index 6b81d02..241e874 100644 --- a/docs/TODO.md +++ b/docs/TODO.md @@ -1,3 +1,7 @@ +# 记住,我们在开发可实战的投资助理工具,其业务水平要处在投资的前列。不要单纯只实现些简单的功能 + + + # 项目待办清单 > 用于跟踪现阶段尚未完成或需要后续完善的工作,便于规划优先级。 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..6d615d3 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,9 @@ +"""Pytest configuration shared across test modules.""" +from __future__ import annotations + +import sys +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) diff --git a/tests/test_factors.py b/tests/test_factors.py new file mode 100644 index 0000000..2c3897c --- /dev/null +++ b/tests/test_factors.py @@ -0,0 +1,162 @@ +"""Tests for factor computation pipeline.""" +from __future__ import annotations + +from datetime import date, timedelta + +import pytest + +from app.core.indicators import momentum, rolling_mean, volatility +from app.data.schema import initialize_database +from app.features.factors import ( + DEFAULT_FACTORS, + FactorResult, + FactorSpec, + compute_factor_range, + compute_factors, +) +from app.utils.config import DataPaths, get_config +from app.utils.data_access import DataBroker +from app.utils.db import db_session + + +@pytest.fixture() +def isolated_db(tmp_path): + cfg = get_config() + original_paths = cfg.data_paths + tmp_root = tmp_path / "data" + tmp_root.mkdir(parents=True, exist_ok=True) + cfg.data_paths = DataPaths(root=tmp_root) + try: + yield + finally: + cfg.data_paths = original_paths + + +def _populate_sample_data(ts_code: str, as_of: date) -> None: + initialize_database() + with db_session() as conn: + for offset in range(60): + current_day = as_of - timedelta(days=offset) + trade_date = current_day.strftime("%Y%m%d") + close = 100 + (59 - offset) + turnover = 5 + 0.1 * (59 - offset) + conn.execute( + """ + INSERT OR REPLACE INTO daily + (ts_code, trade_date, open, high, low, close, pct_chg, vol, amount) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + ts_code, + trade_date, + close, + close, + close, + close, + 0.0, + 1000.0, + 1_000_000.0, + ), + ) + conn.execute( + """ + INSERT OR REPLACE INTO daily_basic + (ts_code, trade_date, turnover_rate, turnover_rate_f, volume_ratio) + VALUES (?, ?, ?, ?, ?) + """, + ( + ts_code, + trade_date, + turnover, + turnover, + 1.0, + ), + ) + + +def test_compute_factors_persists_and_updates(isolated_db): + ts_code = "000001.SZ" + trade_day = date(2025, 1, 30) + _populate_sample_data(ts_code, trade_day) + + specs = [*DEFAULT_FACTORS, FactorSpec("mom_5", 5)] + results = compute_factors(trade_day, specs) + + assert results + result_map = {result.ts_code: result for result in results} + assert ts_code in result_map + result: FactorResult = result_map[ts_code] + + close_series = [100 + (59 - offset) for offset in range(60)] + turnover_series = [5 + 0.1 * (59 - offset) for offset in range(60)] + + expected_mom20 = momentum(close_series, 20) + expected_mom60 = momentum(close_series, 60) + expected_mom5 = momentum(close_series, 5) + expected_volat20 = volatility(close_series, 20) + expected_turn20 = rolling_mean(turnover_series, 20) + + assert result.values["mom_20"] == pytest.approx(expected_mom20) + assert result.values["mom_60"] == pytest.approx(expected_mom60) + assert result.values["mom_5"] == pytest.approx(expected_mom5) + assert result.values["volat_20"] == pytest.approx(expected_volat20) + assert result.values["turn_20"] == pytest.approx(expected_turn20) + + trade_date_str = trade_day.strftime("%Y%m%d") + with db_session(read_only=True) as conn: + row = conn.execute( + """ + SELECT mom_20, mom_60, mom_5, volat_20, turn_20 + FROM factors WHERE ts_code = ? AND trade_date = ? + """, + (ts_code, trade_date_str), + ).fetchone() + assert row is not None + assert row["mom_20"] == pytest.approx(expected_mom20) + assert row["mom_60"] == pytest.approx(expected_mom60) + assert row["mom_5"] == pytest.approx(expected_mom5) + assert row["volat_20"] == pytest.approx(expected_volat20) + assert row["turn_20"] == pytest.approx(expected_turn20) + + broker = DataBroker() + latest = broker.fetch_latest(ts_code, trade_date_str, ["factors.mom_5", "factors.turn_20"]) + assert latest["factors.mom_5"] == pytest.approx(expected_mom5) + assert latest["factors.turn_20"] == pytest.approx(expected_turn20) + + # Calling compute_factors again should update existing rows without error. + second_results = compute_factors(trade_day, specs) + assert second_results + assert broker.fetch_latest(ts_code, trade_date_str, ["factors.mom_20"])["factors.mom_20"] == pytest.approx( + expected_mom20 + ) + + +def test_compute_factors_skip_existing(isolated_db): + ts_code = "000001.SZ" + trade_day = date(2025, 2, 10) + _populate_sample_data(ts_code, trade_day) + + compute_factors(trade_day) + skipped = compute_factors(trade_day, skip_existing=True) + assert skipped == [] + + +def test_compute_factor_range_filters_universe(isolated_db): + code_a = "000001.SZ" + code_b = "000002.SZ" + end_day = date(2025, 3, 5) + start_day = end_day - timedelta(days=1) + + _populate_sample_data(code_a, end_day) + _populate_sample_data(code_b, end_day) + + results = compute_factor_range(start_day, end_day, ts_codes=[code_a]) + assert results + assert {result.ts_code for result in results} == {code_a} + + with db_session(read_only=True) as conn: + rows = conn.execute("SELECT DISTINCT ts_code FROM factors").fetchall() + assert {row["ts_code"] for row in rows} == {code_a} + + repeated = compute_factor_range(start_day, end_day, ts_codes=[code_a]) + assert repeated == [] diff --git a/tests/test_rss_ingest.py b/tests/test_rss_ingest.py index e7dab9e..686a197 100644 --- a/tests/test_rss_ingest.py +++ b/tests/test_rss_ingest.py @@ -28,16 +28,17 @@ def isolated_db(tmp_path): def test_fetch_rss_feed_parses_entries(monkeypatch): + published = datetime.utcnow().strftime("%a, %d %b %Y %H:%M:%S GMT") sample_feed = ( - """ - + f""" + Example 新闻:公司利好公告 https://example.com/a - Wed, 01 Jan 2025 08:30:00 GMT + {published} a