This commit is contained in:
sam 2025-09-30 17:23:18 +08:00
parent 5228ea1c41
commit 30007cc056
6 changed files with 458 additions and 13 deletions

View File

@ -63,6 +63,18 @@ SCHEMA_STATEMENTS: Iterable[str] = (
);
""",
"""
CREATE TABLE IF NOT EXISTS factors (
ts_code TEXT,
trade_date TEXT,
mom_20 REAL,
mom_60 REAL,
volat_20 REAL,
turn_20 REAL,
updated_at TEXT,
PRIMARY KEY (ts_code, trade_date)
);
""",
"""
CREATE TABLE IF NOT EXISTS adj_factor (
ts_code TEXT,
trade_date TEXT,
@ -442,6 +454,7 @@ REQUIRED_TABLES = (
"stock_basic",
"daily",
"daily_basic",
"factors",
"adj_factor",
"suspend",
"trade_calendar",

View File

@ -1,9 +1,21 @@
"""Feature engineering for signals and indicator computation."""
from __future__ import annotations
import re
from dataclasses import dataclass
from datetime import date
from typing import Iterable, List
from datetime import datetime, date, timezone
from typing import Dict, Iterable, List, Optional, Sequence
from app.core.indicators import momentum, rolling_mean, volatility
from app.data.schema import initialize_database
from app.utils.data_access import DataBroker
from app.utils.db import db_session
from app.utils.logging import get_logger
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "factor_compute"}
_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
@dataclass
@ -16,7 +28,7 @@ class FactorSpec:
class FactorResult:
ts_code: str
trade_date: date
values: dict
values: Dict[str, float | None]
DEFAULT_FACTORS: List[FactorSpec] = [
@ -27,13 +39,257 @@ DEFAULT_FACTORS: List[FactorSpec] = [
]
def compute_factors(
    trade_date: date,
    factors: Iterable[FactorSpec] = DEFAULT_FACTORS,
    *,
    ts_codes: Optional[Sequence[str]] = None,
    skip_existing: bool = False,
) -> List[FactorResult]:
    """Calculate and persist factor values for the requested date.

    ``ts_codes`` can be supplied to restrict computation to a subset of the
    universe. When ``skip_existing`` is True, securities that already have an
    entry for ``trade_date`` will be ignored.

    Args:
        trade_date: Trading day to compute factors for.
        factors: Factor specifications; specs with a non-positive window
            are discarded up front.
        ts_codes: Optional whitelist of security codes (normalised to
            upper case, blanks dropped).
        skip_existing: When True, codes that already have a ``factors``
            row for the day are not recomputed.

    Returns:
        One FactorResult per security for which at least one factor value
        could be produced. The same values are upserted into the
        ``factors`` table as a side effect.
    """
    # A window <= 0 cannot drive any rolling computation; drop such specs.
    specs = [spec for spec in factors if spec.window > 0]
    if not specs:
        return []
    initialize_database()
    trade_date_str = trade_date.strftime("%Y%m%d")
    # Make sure the factors table has a REAL column for every requested spec.
    _ensure_factor_columns(specs)
    allowed = {code.strip().upper() for code in ts_codes or () if code.strip()}
    universe = _load_universe(trade_date_str, allowed if allowed else None)
    if not universe:
        LOGGER.info("无可用标的生成因子 trade_date=%s", trade_date_str, extra=LOG_EXTRA)
        return []
    if skip_existing:
        existing = _existing_factor_codes(trade_date_str)
        universe = [code for code in universe if code not in existing]
        if not universe:
            LOGGER.debug(
                "目标交易日因子已存在 trade_date=%s universe_size=%s",
                trade_date_str,
                len(existing),
                extra=LOG_EXTRA,
            )
            return []
    broker = DataBroker()
    results: List[FactorResult] = []
    rows_to_persist: List[tuple[str, Dict[str, float | None]]] = []
    for ts_code in universe:
        values = _compute_security_factors(broker, ts_code, trade_date_str, specs)
        # Securities with no computable factor at all are skipped entirely.
        if not values:
            continue
        results.append(FactorResult(ts_code=ts_code, trade_date=trade_date, values=values))
        rows_to_persist.append((ts_code, values))
    if rows_to_persist:
        _persist_factor_rows(trade_date_str, rows_to_persist, specs)
    return results
def compute_factor_range(
    start: date,
    end: date,
    *,
    factors: Iterable[FactorSpec] = DEFAULT_FACTORS,
    ts_codes: Optional[Sequence[str]] = None,
    skip_existing: bool = True,
) -> List[FactorResult]:
    """Compute factors for all trading days within ``[start, end]`` inclusive."""
    if end < start:
        raise ValueError("end date must not precede start date")
    initialize_database()
    # Normalise the optional code filter: strip, upper-case, dedupe while
    # preserving first-seen order; an empty filter means "no restriction".
    allowed = None
    if ts_codes:
        cleaned = dict.fromkeys(code.strip().upper() for code in ts_codes if code.strip())
        if cleaned:
            allowed = tuple(cleaned)
    lower_bound = start.strftime("%Y%m%d")
    upper_bound = end.strftime("%Y%m%d")
    collected: List[FactorResult] = []
    for day_str in _list_trade_dates(lower_bound, upper_bound, allowed):
        trading_day = datetime.strptime(day_str, "%Y%m%d").date()
        collected.extend(
            compute_factors(
                trading_day,
                factors,
                ts_codes=allowed,
                skip_existing=skip_existing,
            )
        )
    return collected
def _load_universe(trade_date: str, allowed: Optional[set[str]] = None) -> List[str]:
    """Return the ts_codes with a ``daily`` row for ``trade_date``.

    When ``allowed`` is provided, only codes contained in it are kept
    (comparison is case-insensitive).
    """
    sql = "SELECT ts_code FROM daily WHERE trade_date = ? ORDER BY ts_code"
    with db_session(read_only=True) as conn:
        fetched = conn.execute(sql, (trade_date,)).fetchall()
    codes = [record["ts_code"] for record in fetched if record["ts_code"]]
    if not allowed:
        return codes
    wanted = {entry.upper() for entry in allowed}
    return [code for code in codes if code.upper() in wanted]
def _existing_factor_codes(trade_date: str) -> set[str]:
    """Return the set of ts_codes already persisted in ``factors`` for the day."""
    sql = "SELECT ts_code FROM factors WHERE trade_date = ?"
    with db_session(read_only=True) as conn:
        fetched = conn.execute(sql, (trade_date,)).fetchall()
    return {record["ts_code"] for record in fetched if record["ts_code"]}
def _list_trade_dates(
    start_date: str,
    end_date: str,
    allowed: Optional[Sequence[str]],
) -> List[str]:
    """List distinct ``daily`` trade dates in the inclusive range, ascending.

    ``allowed`` optionally narrows the scan to rows whose ts_code is in the
    given sequence (bound as SQL parameters, never interpolated).
    """
    sql = (
        "SELECT DISTINCT trade_date FROM daily "
        "WHERE trade_date BETWEEN ? AND ? "
    )
    args: List[str] = [start_date, end_date]
    if allowed:
        marks = ", ".join("?" for _ in allowed)
        sql += f"AND ts_code IN ({marks}) "
        args.extend(allowed)
    sql += "ORDER BY trade_date"
    with db_session(read_only=True) as conn:
        fetched = conn.execute(sql, args).fetchall()
    return [record["trade_date"] for record in fetched if record["trade_date"]]
def _compute_security_factors(
    broker: DataBroker,
    ts_code: str,
    trade_date: str,
    specs: Sequence[FactorSpec],
) -> Dict[str, float | None]:
    """Compute every requested factor for one security as of ``trade_date``.

    Close-price history feeds the ``mom``/``volat`` factors and turnover
    history feeds the ``turn`` factors; a factor whose history is too short
    maps to None, and unrecognised prefixes are logged and skipped.
    """
    close_needs = [s.window for s in specs if _factor_prefix(s.name) in {"mom", "volat"}]
    turn_needs = [s.window for s in specs if _factor_prefix(s.name) == "turn"]
    # Fetch each series once, sized to the largest window that needs it.
    closes = _fetch_series_values(
        broker,
        "daily",
        "close",
        ts_code,
        trade_date,
        max(close_needs, default=0),
    )
    turns = _fetch_series_values(
        broker,
        "daily_basic",
        "turnover_rate",
        ts_code,
        trade_date,
        max(turn_needs, default=0),
    )
    computed: Dict[str, float | None] = {}
    for spec in specs:
        kind = _factor_prefix(spec.name)
        if kind == "mom":
            computed[spec.name] = (
                momentum(closes, spec.window) if len(closes) >= spec.window else None
            )
        elif kind == "volat":
            # Volatility is defined from two observations onwards.
            computed[spec.name] = (
                volatility(closes, spec.window) if len(closes) >= 2 else None
            )
        elif kind == "turn":
            computed[spec.name] = (
                rolling_mean(turns, spec.window) if len(turns) >= spec.window else None
            )
        else:
            LOGGER.debug(
                "忽略未识别的因子 name=%s ts_code=%s",
                spec.name,
                ts_code,
                extra=LOG_EXTRA,
            )
    return computed
def _persist_factor_rows(
    trade_date: str,
    rows: Sequence[tuple[str, Dict[str, float | None]]],
    specs: Sequence[FactorSpec],
) -> None:
    """Upsert one ``factors`` row per security for ``trade_date``.

    Existing rows (same ts_code + trade_date) have their factor columns and
    ``updated_at`` overwritten via SQLite's ON CONFLICT upsert.
    """
    factor_cols = sorted({spec.name for spec in specs})
    stamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    all_cols = ["ts_code", "trade_date", "updated_at", *factor_cols]
    marks = ", ".join(["?"] * len(all_cols))
    assignments = ", ".join(
        f"{col}=excluded.{col}" for col in ["updated_at", *factor_cols]
    )
    statement = (
        f"INSERT INTO factors ({', '.join(all_cols)}) "
        f"VALUES ({marks}) "
        f"ON CONFLICT(ts_code, trade_date) DO UPDATE SET {assignments}"
    )
    with db_session() as conn:
        for code, factor_values in rows:
            record = [code, trade_date, stamp]
            # Missing factor names default to None -> SQL NULL.
            record.extend(factor_values.get(col) for col in factor_cols)
            conn.execute(statement, record)
def _ensure_factor_columns(specs: Sequence[FactorSpec]) -> None:
    """Add a REAL column to ``factors`` for any spec name not yet present.

    Only names matching the identifier regex are considered, so nothing
    unsafe can ever be interpolated into the ALTER TABLE statement.
    """
    wanted = {spec.name for spec in specs if _IDENTIFIER_RE.match(spec.name)}
    if not wanted:
        return
    with db_session() as conn:
        present = {
            row["name"] for row in conn.execute("PRAGMA table_info(factors)").fetchall()
        }
        for missing in sorted(wanted - present):
            conn.execute(f"ALTER TABLE factors ADD COLUMN {missing} REAL")
def _fetch_series_values(
broker: DataBroker,
table: str,
column: str,
ts_code: str,
trade_date: str,
window: int,
) -> List[float]:
if window <= 0:
return []
series = broker.fetch_series(table, column, ts_code, trade_date, window)
values: List[float] = []
for _dt, raw in series:
try:
values.append(float(raw))
except (TypeError, ValueError):
continue
return values
def _factor_prefix(name: str) -> str:
return name.split("_", 1)[0] if name else ""

View File

@ -1,3 +1,7 @@
# 记住,我们在开发可实战的投资助理工具,其业务水平要处在投资的前列。不要单纯只实现些简单的功能
# 项目待办清单
> 用于跟踪现阶段尚未完成或需要后续完善的工作,便于规划优先级。

9
tests/conftest.py Normal file
View File

@ -0,0 +1,9 @@
"""Pytest configuration shared across test modules."""
from __future__ import annotations
import sys
from pathlib import Path
# Repository root: one level above the tests/ directory containing this file.
ROOT = Path(__file__).resolve().parents[1]
# Prepend the root to sys.path (idempotently) so test modules can import the
# project packages without the project being installed.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

162
tests/test_factors.py Normal file
View File

@ -0,0 +1,162 @@
"""Tests for factor computation pipeline."""
from __future__ import annotations
from datetime import date, timedelta
import pytest
from app.core.indicators import momentum, rolling_mean, volatility
from app.data.schema import initialize_database
from app.features.factors import (
DEFAULT_FACTORS,
FactorResult,
FactorSpec,
compute_factor_range,
compute_factors,
)
from app.utils.config import DataPaths, get_config
from app.utils.data_access import DataBroker
from app.utils.db import db_session
@pytest.fixture()
def isolated_db(tmp_path):
    """Point the global config at a throwaway data directory for one test.

    Swaps ``cfg.data_paths`` to a tmp_path-backed location for the duration
    of the test and restores the original paths afterwards, so tests never
    touch the real database files.
    """
    cfg = get_config()
    original_paths = cfg.data_paths
    tmp_root = tmp_path / "data"
    tmp_root.mkdir(parents=True, exist_ok=True)
    cfg.data_paths = DataPaths(root=tmp_root)
    try:
        yield
    finally:
        # Always restore, even if the test body raised.
        cfg.data_paths = original_paths
def _populate_sample_data(ts_code: str, as_of: date) -> None:
    """Seed 60 calendar days of synthetic daily/daily_basic rows ending at ``as_of``.

    Close prices rise linearly (oldest day 100, ``as_of`` 159) and turnover
    rises from 5.0 to 10.9, so factor values are deterministic and easy to
    recompute inside assertions.
    """
    initialize_database()
    daily_sql = """
        INSERT OR REPLACE INTO daily
        (ts_code, trade_date, open, high, low, close, pct_chg, vol, amount)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
        """
    basic_sql = """
        INSERT OR REPLACE INTO daily_basic
        (ts_code, trade_date, turnover_rate, turnover_rate_f, volume_ratio)
        VALUES (?, ?, ?, ?, ?)
        """
    with db_session() as conn:
        for back in range(60):
            day_str = (as_of - timedelta(days=back)).strftime("%Y%m%d")
            px = 100 + (59 - back)
            turn = 5 + 0.1 * (59 - back)
            conn.execute(
                daily_sql,
                (ts_code, day_str, px, px, px, px, 0.0, 1000.0, 1_000_000.0),
            )
            conn.execute(
                basic_sql,
                (ts_code, day_str, turn, turn, 1.0),
            )
def test_compute_factors_persists_and_updates(isolated_db):
    """End-to-end: compute_factors returns values, persists them, and upserts on rerun."""
    ts_code = "000001.SZ"
    trade_day = date(2025, 1, 30)
    _populate_sample_data(ts_code, trade_day)
    # Extend the defaults with an extra spec to exercise dynamic column creation.
    specs = [*DEFAULT_FACTORS, FactorSpec("mom_5", 5)]
    results = compute_factors(trade_day, specs)
    assert results
    result_map = {result.ts_code: result for result in results}
    assert ts_code in result_map
    result: FactorResult = result_map[ts_code]
    # Recreate the synthetic series seeded by _populate_sample_data
    # (index 0 is the most recent day) and derive expected factor values
    # with the same indicator functions the pipeline uses.
    close_series = [100 + (59 - offset) for offset in range(60)]
    turnover_series = [5 + 0.1 * (59 - offset) for offset in range(60)]
    expected_mom20 = momentum(close_series, 20)
    expected_mom60 = momentum(close_series, 60)
    expected_mom5 = momentum(close_series, 5)
    expected_volat20 = volatility(close_series, 20)
    expected_turn20 = rolling_mean(turnover_series, 20)
    # In-memory results match the expectations.
    assert result.values["mom_20"] == pytest.approx(expected_mom20)
    assert result.values["mom_60"] == pytest.approx(expected_mom60)
    assert result.values["mom_5"] == pytest.approx(expected_mom5)
    assert result.values["volat_20"] == pytest.approx(expected_volat20)
    assert result.values["turn_20"] == pytest.approx(expected_turn20)
    trade_date_str = trade_day.strftime("%Y%m%d")
    # The same values were persisted into the factors table.
    with db_session(read_only=True) as conn:
        row = conn.execute(
            """
            SELECT mom_20, mom_60, mom_5, volat_20, turn_20
            FROM factors WHERE ts_code = ? AND trade_date = ?
            """,
            (ts_code, trade_date_str),
        ).fetchone()
    assert row is not None
    assert row["mom_20"] == pytest.approx(expected_mom20)
    assert row["mom_60"] == pytest.approx(expected_mom60)
    assert row["mom_5"] == pytest.approx(expected_mom5)
    assert row["volat_20"] == pytest.approx(expected_volat20)
    assert row["turn_20"] == pytest.approx(expected_turn20)
    # The persisted columns are reachable through the DataBroker facade too.
    broker = DataBroker()
    latest = broker.fetch_latest(ts_code, trade_date_str, ["factors.mom_5", "factors.turn_20"])
    assert latest["factors.mom_5"] == pytest.approx(expected_mom5)
    assert latest["factors.turn_20"] == pytest.approx(expected_turn20)
    # Calling compute_factors again should update existing rows without error.
    second_results = compute_factors(trade_day, specs)
    assert second_results
    assert broker.fetch_latest(ts_code, trade_date_str, ["factors.mom_20"])["factors.mom_20"] == pytest.approx(
        expected_mom20
    )
def test_compute_factors_skip_existing(isolated_db):
    """A second run with skip_existing=True must yield no results at all."""
    code = "000001.SZ"
    day = date(2025, 2, 10)
    _populate_sample_data(code, day)
    compute_factors(day)
    rerun = compute_factors(day, skip_existing=True)
    assert rerun == []
def test_compute_factor_range_filters_universe(isolated_db):
    """ts_codes filtering restricts both returned results and persisted rows."""
    first = "000001.SZ"
    second = "000002.SZ"
    last_day = date(2025, 3, 5)
    first_day = last_day - timedelta(days=1)
    _populate_sample_data(first, last_day)
    _populate_sample_data(second, last_day)
    produced = compute_factor_range(first_day, last_day, ts_codes=[first])
    assert produced
    assert {item.ts_code for item in produced} == {first}
    # Only the whitelisted code may appear in the persisted table.
    with db_session(read_only=True) as conn:
        stored = conn.execute("SELECT DISTINCT ts_code FROM factors").fetchall()
    assert {entry["ts_code"] for entry in stored} == {first}
    # Re-running (default skip_existing=True) produces nothing new.
    rerun = compute_factor_range(first_day, last_day, ts_codes=[first])
    assert rerun == []

View File

@ -28,16 +28,17 @@ def isolated_db(tmp_path):
def test_fetch_rss_feed_parses_entries(monkeypatch):
published = datetime.utcnow().strftime("%a, %d %b %Y %H:%M:%S GMT")
sample_feed = (
"""
<rss version="2.0">
f"""
<rss version=\"2.0\">
<channel>
<title>Example</title>
<item>
<title>新闻公司利好公告</title>
<link>https://example.com/a</link>
<description><![CDATA[内容包含 000001.SZ ]]></description>
<pubDate>Wed, 01 Jan 2025 08:30:00 GMT</pubDate>
<pubDate>{published}</pubDate>
<guid>a</guid>
</item>
</channel>