update
This commit is contained in:
parent
5228ea1c41
commit
30007cc056
@ -63,6 +63,18 @@ SCHEMA_STATEMENTS: Iterable[str] = (
|
|||||||
);
|
);
|
||||||
""",
|
""",
|
||||||
"""
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS factors (
|
||||||
|
ts_code TEXT,
|
||||||
|
trade_date TEXT,
|
||||||
|
mom_20 REAL,
|
||||||
|
mom_60 REAL,
|
||||||
|
volat_20 REAL,
|
||||||
|
turn_20 REAL,
|
||||||
|
updated_at TEXT,
|
||||||
|
PRIMARY KEY (ts_code, trade_date)
|
||||||
|
);
|
||||||
|
""",
|
||||||
|
"""
|
||||||
CREATE TABLE IF NOT EXISTS adj_factor (
|
CREATE TABLE IF NOT EXISTS adj_factor (
|
||||||
ts_code TEXT,
|
ts_code TEXT,
|
||||||
trade_date TEXT,
|
trade_date TEXT,
|
||||||
@ -442,6 +454,7 @@ REQUIRED_TABLES = (
|
|||||||
"stock_basic",
|
"stock_basic",
|
||||||
"daily",
|
"daily",
|
||||||
"daily_basic",
|
"daily_basic",
|
||||||
|
"factors",
|
||||||
"adj_factor",
|
"adj_factor",
|
||||||
"suspend",
|
"suspend",
|
||||||
"trade_calendar",
|
"trade_calendar",
|
||||||
|
|||||||
@ -1,9 +1,21 @@
|
|||||||
"""Feature engineering for signals and indicator computation."""
|
"""Feature engineering for signals and indicator computation."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import date
|
from datetime import datetime, date, timezone
|
||||||
from typing import Iterable, List
|
from typing import Dict, Iterable, List, Optional, Sequence
|
||||||
|
|
||||||
|
from app.core.indicators import momentum, rolling_mean, volatility
|
||||||
|
from app.data.schema import initialize_database
|
||||||
|
from app.utils.data_access import DataBroker
|
||||||
|
from app.utils.db import db_session
|
||||||
|
from app.utils.logging import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
LOGGER = get_logger(__name__)
|
||||||
|
LOG_EXTRA = {"stage": "factor_compute"}
|
||||||
|
_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -16,7 +28,7 @@ class FactorSpec:
|
|||||||
class FactorResult:
|
class FactorResult:
|
||||||
ts_code: str
|
ts_code: str
|
||||||
trade_date: date
|
trade_date: date
|
||||||
values: dict
|
values: Dict[str, float | None]
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_FACTORS: List[FactorSpec] = [
|
DEFAULT_FACTORS: List[FactorSpec] = [
|
||||||
@ -27,13 +39,257 @@ DEFAULT_FACTORS: List[FactorSpec] = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def compute_factors(trade_date: date, factors: Iterable[FactorSpec] = DEFAULT_FACTORS) -> List[FactorResult]:
|
def compute_factors(
|
||||||
"""Calculate factor values for the requested date.
|
trade_date: date,
|
||||||
|
factors: Iterable[FactorSpec] = DEFAULT_FACTORS,
|
||||||
|
*,
|
||||||
|
ts_codes: Optional[Sequence[str]] = None,
|
||||||
|
skip_existing: bool = False,
|
||||||
|
) -> List[FactorResult]:
|
||||||
|
"""Calculate and persist factor values for the requested date.
|
||||||
|
|
||||||
This function should join historical price data, apply rolling windows, and
|
``ts_codes`` can be supplied to restrict computation to a subset of the
|
||||||
persist results into an factors table. The implementation is left as future
|
universe. When ``skip_existing`` is True, securities that already have an
|
||||||
work.
|
entry for ``trade_date`` will be ignored.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_ = trade_date, factors
|
specs = [spec for spec in factors if spec.window > 0]
|
||||||
raise NotImplementedError
|
if not specs:
|
||||||
|
return []
|
||||||
|
|
||||||
|
initialize_database()
|
||||||
|
trade_date_str = trade_date.strftime("%Y%m%d")
|
||||||
|
|
||||||
|
_ensure_factor_columns(specs)
|
||||||
|
|
||||||
|
allowed = {code.strip().upper() for code in ts_codes or () if code.strip()}
|
||||||
|
universe = _load_universe(trade_date_str, allowed if allowed else None)
|
||||||
|
if not universe:
|
||||||
|
LOGGER.info("无可用标的生成因子 trade_date=%s", trade_date_str, extra=LOG_EXTRA)
|
||||||
|
return []
|
||||||
|
|
||||||
|
if skip_existing:
|
||||||
|
existing = _existing_factor_codes(trade_date_str)
|
||||||
|
universe = [code for code in universe if code not in existing]
|
||||||
|
if not universe:
|
||||||
|
LOGGER.debug(
|
||||||
|
"目标交易日因子已存在 trade_date=%s universe_size=%s",
|
||||||
|
trade_date_str,
|
||||||
|
len(existing),
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
broker = DataBroker()
|
||||||
|
results: List[FactorResult] = []
|
||||||
|
rows_to_persist: List[tuple[str, Dict[str, float | None]]] = []
|
||||||
|
for ts_code in universe:
|
||||||
|
values = _compute_security_factors(broker, ts_code, trade_date_str, specs)
|
||||||
|
if not values:
|
||||||
|
continue
|
||||||
|
results.append(FactorResult(ts_code=ts_code, trade_date=trade_date, values=values))
|
||||||
|
rows_to_persist.append((ts_code, values))
|
||||||
|
|
||||||
|
if rows_to_persist:
|
||||||
|
_persist_factor_rows(trade_date_str, rows_to_persist, specs)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def compute_factor_range(
|
||||||
|
start: date,
|
||||||
|
end: date,
|
||||||
|
*,
|
||||||
|
factors: Iterable[FactorSpec] = DEFAULT_FACTORS,
|
||||||
|
ts_codes: Optional[Sequence[str]] = None,
|
||||||
|
skip_existing: bool = True,
|
||||||
|
) -> List[FactorResult]:
|
||||||
|
"""Compute factors for all trading days within ``[start, end]`` inclusive."""
|
||||||
|
|
||||||
|
if end < start:
|
||||||
|
raise ValueError("end date must not precede start date")
|
||||||
|
|
||||||
|
initialize_database()
|
||||||
|
allowed = None
|
||||||
|
if ts_codes:
|
||||||
|
allowed = tuple(dict.fromkeys(code.strip().upper() for code in ts_codes if code.strip()))
|
||||||
|
if not allowed:
|
||||||
|
allowed = None
|
||||||
|
|
||||||
|
start_str = start.strftime("%Y%m%d")
|
||||||
|
end_str = end.strftime("%Y%m%d")
|
||||||
|
trade_dates = _list_trade_dates(start_str, end_str, allowed)
|
||||||
|
|
||||||
|
aggregated: List[FactorResult] = []
|
||||||
|
for trade_date_str in trade_dates:
|
||||||
|
trade_day = datetime.strptime(trade_date_str, "%Y%m%d").date()
|
||||||
|
aggregated.extend(
|
||||||
|
compute_factors(
|
||||||
|
trade_day,
|
||||||
|
factors,
|
||||||
|
ts_codes=allowed,
|
||||||
|
skip_existing=skip_existing,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return aggregated
|
||||||
|
|
||||||
|
|
||||||
|
def _load_universe(trade_date: str, allowed: Optional[set[str]] = None) -> List[str]:
|
||||||
|
query = "SELECT ts_code FROM daily WHERE trade_date = ? ORDER BY ts_code"
|
||||||
|
with db_session(read_only=True) as conn:
|
||||||
|
rows = conn.execute(query, (trade_date,)).fetchall()
|
||||||
|
codes = [row["ts_code"] for row in rows if row["ts_code"]]
|
||||||
|
if allowed:
|
||||||
|
allowed_upper = {code.upper() for code in allowed}
|
||||||
|
return [code for code in codes if code.upper() in allowed_upper]
|
||||||
|
return codes
|
||||||
|
|
||||||
|
|
||||||
|
def _existing_factor_codes(trade_date: str) -> set[str]:
|
||||||
|
with db_session(read_only=True) as conn:
|
||||||
|
rows = conn.execute(
|
||||||
|
"SELECT ts_code FROM factors WHERE trade_date = ?",
|
||||||
|
(trade_date,),
|
||||||
|
).fetchall()
|
||||||
|
return {row["ts_code"] for row in rows if row["ts_code"]}
|
||||||
|
|
||||||
|
|
||||||
|
def _list_trade_dates(
|
||||||
|
start_date: str,
|
||||||
|
end_date: str,
|
||||||
|
allowed: Optional[Sequence[str]],
|
||||||
|
) -> List[str]:
|
||||||
|
params: List[str] = [start_date, end_date]
|
||||||
|
if allowed:
|
||||||
|
placeholders = ", ".join("?" for _ in allowed)
|
||||||
|
query = (
|
||||||
|
"SELECT DISTINCT trade_date FROM daily "
|
||||||
|
"WHERE trade_date BETWEEN ? AND ? "
|
||||||
|
f"AND ts_code IN ({placeholders}) "
|
||||||
|
"ORDER BY trade_date"
|
||||||
|
)
|
||||||
|
params.extend(allowed)
|
||||||
|
else:
|
||||||
|
query = (
|
||||||
|
"SELECT DISTINCT trade_date FROM daily "
|
||||||
|
"WHERE trade_date BETWEEN ? AND ? "
|
||||||
|
"ORDER BY trade_date"
|
||||||
|
)
|
||||||
|
with db_session(read_only=True) as conn:
|
||||||
|
rows = conn.execute(query, params).fetchall()
|
||||||
|
return [row["trade_date"] for row in rows if row["trade_date"]]
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_security_factors(
|
||||||
|
broker: DataBroker,
|
||||||
|
ts_code: str,
|
||||||
|
trade_date: str,
|
||||||
|
specs: Sequence[FactorSpec],
|
||||||
|
) -> Dict[str, float | None]:
|
||||||
|
close_windows = [spec.window for spec in specs if _factor_prefix(spec.name) in {"mom", "volat"}]
|
||||||
|
turnover_windows = [spec.window for spec in specs if _factor_prefix(spec.name) == "turn"]
|
||||||
|
max_close_window = max(close_windows) if close_windows else 0
|
||||||
|
max_turn_window = max(turnover_windows) if turnover_windows else 0
|
||||||
|
|
||||||
|
close_series = _fetch_series_values(
|
||||||
|
broker,
|
||||||
|
"daily",
|
||||||
|
"close",
|
||||||
|
ts_code,
|
||||||
|
trade_date,
|
||||||
|
max_close_window,
|
||||||
|
)
|
||||||
|
turnover_series = _fetch_series_values(
|
||||||
|
broker,
|
||||||
|
"daily_basic",
|
||||||
|
"turnover_rate",
|
||||||
|
ts_code,
|
||||||
|
trade_date,
|
||||||
|
max_turn_window,
|
||||||
|
)
|
||||||
|
|
||||||
|
results: Dict[str, float | None] = {}
|
||||||
|
for spec in specs:
|
||||||
|
prefix = _factor_prefix(spec.name)
|
||||||
|
if prefix == "mom":
|
||||||
|
if len(close_series) >= spec.window:
|
||||||
|
results[spec.name] = momentum(close_series, spec.window)
|
||||||
|
else:
|
||||||
|
results[spec.name] = None
|
||||||
|
elif prefix == "volat":
|
||||||
|
if len(close_series) >= 2:
|
||||||
|
results[spec.name] = volatility(close_series, spec.window)
|
||||||
|
else:
|
||||||
|
results[spec.name] = None
|
||||||
|
elif prefix == "turn":
|
||||||
|
if len(turnover_series) >= spec.window:
|
||||||
|
results[spec.name] = rolling_mean(turnover_series, spec.window)
|
||||||
|
else:
|
||||||
|
results[spec.name] = None
|
||||||
|
else:
|
||||||
|
LOGGER.debug(
|
||||||
|
"忽略未识别的因子 name=%s ts_code=%s",
|
||||||
|
spec.name,
|
||||||
|
ts_code,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def _persist_factor_rows(
|
||||||
|
trade_date: str,
|
||||||
|
rows: Sequence[tuple[str, Dict[str, float | None]]],
|
||||||
|
specs: Sequence[FactorSpec],
|
||||||
|
) -> None:
|
||||||
|
columns = sorted({spec.name for spec in specs})
|
||||||
|
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||||
|
insert_columns = ["ts_code", "trade_date", "updated_at", *columns]
|
||||||
|
placeholders = ", ".join(["?"] * len(insert_columns))
|
||||||
|
update_clause = ", ".join(
|
||||||
|
f"{column}=excluded.{column}" for column in ["updated_at", *columns]
|
||||||
|
)
|
||||||
|
sql = (
|
||||||
|
f"INSERT INTO factors ({', '.join(insert_columns)}) "
|
||||||
|
f"VALUES ({placeholders}) "
|
||||||
|
f"ON CONFLICT(ts_code, trade_date) DO UPDATE SET {update_clause}"
|
||||||
|
)
|
||||||
|
|
||||||
|
with db_session() as conn:
|
||||||
|
for ts_code, values in rows:
|
||||||
|
payload = [ts_code, trade_date, timestamp]
|
||||||
|
payload.extend(values.get(column) for column in columns)
|
||||||
|
conn.execute(sql, payload)
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_factor_columns(specs: Sequence[FactorSpec]) -> None:
|
||||||
|
pending = {spec.name for spec in specs if _IDENTIFIER_RE.match(spec.name)}
|
||||||
|
if not pending:
|
||||||
|
return
|
||||||
|
with db_session() as conn:
|
||||||
|
existing_rows = conn.execute("PRAGMA table_info(factors)").fetchall()
|
||||||
|
existing = {row["name"] for row in existing_rows}
|
||||||
|
for column in sorted(pending - existing):
|
||||||
|
conn.execute(f"ALTER TABLE factors ADD COLUMN {column} REAL")
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_series_values(
|
||||||
|
broker: DataBroker,
|
||||||
|
table: str,
|
||||||
|
column: str,
|
||||||
|
ts_code: str,
|
||||||
|
trade_date: str,
|
||||||
|
window: int,
|
||||||
|
) -> List[float]:
|
||||||
|
if window <= 0:
|
||||||
|
return []
|
||||||
|
series = broker.fetch_series(table, column, ts_code, trade_date, window)
|
||||||
|
values: List[float] = []
|
||||||
|
for _dt, raw in series:
|
||||||
|
try:
|
||||||
|
values.append(float(raw))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
continue
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
def _factor_prefix(name: str) -> str:
|
||||||
|
return name.split("_", 1)[0] if name else ""
|
||||||
|
|||||||
@ -1,3 +1,7 @@
|
|||||||
|
# 记住,我们在开发可实战的投资助理工具,其业务水平要处在投资的前列。不要单纯只实现些简单的功能
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 项目待办清单
|
# 项目待办清单
|
||||||
|
|
||||||
> 用于跟踪现阶段尚未完成或需要后续完善的工作,便于规划优先级。
|
> 用于跟踪现阶段尚未完成或需要后续完善的工作,便于规划优先级。
|
||||||
|
|||||||
9
tests/conftest.py
Normal file
9
tests/conftest.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
"""Pytest configuration shared across test modules."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
if str(ROOT) not in sys.path:
|
||||||
|
sys.path.insert(0, str(ROOT))
|
||||||
162
tests/test_factors.py
Normal file
162
tests/test_factors.py
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
"""Tests for factor computation pipeline."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import date, timedelta
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.core.indicators import momentum, rolling_mean, volatility
|
||||||
|
from app.data.schema import initialize_database
|
||||||
|
from app.features.factors import (
|
||||||
|
DEFAULT_FACTORS,
|
||||||
|
FactorResult,
|
||||||
|
FactorSpec,
|
||||||
|
compute_factor_range,
|
||||||
|
compute_factors,
|
||||||
|
)
|
||||||
|
from app.utils.config import DataPaths, get_config
|
||||||
|
from app.utils.data_access import DataBroker
|
||||||
|
from app.utils.db import db_session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def isolated_db(tmp_path):
|
||||||
|
cfg = get_config()
|
||||||
|
original_paths = cfg.data_paths
|
||||||
|
tmp_root = tmp_path / "data"
|
||||||
|
tmp_root.mkdir(parents=True, exist_ok=True)
|
||||||
|
cfg.data_paths = DataPaths(root=tmp_root)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
cfg.data_paths = original_paths
|
||||||
|
|
||||||
|
|
||||||
|
def _populate_sample_data(ts_code: str, as_of: date) -> None:
|
||||||
|
initialize_database()
|
||||||
|
with db_session() as conn:
|
||||||
|
for offset in range(60):
|
||||||
|
current_day = as_of - timedelta(days=offset)
|
||||||
|
trade_date = current_day.strftime("%Y%m%d")
|
||||||
|
close = 100 + (59 - offset)
|
||||||
|
turnover = 5 + 0.1 * (59 - offset)
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT OR REPLACE INTO daily
|
||||||
|
(ts_code, trade_date, open, high, low, close, pct_chg, vol, amount)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
ts_code,
|
||||||
|
trade_date,
|
||||||
|
close,
|
||||||
|
close,
|
||||||
|
close,
|
||||||
|
close,
|
||||||
|
0.0,
|
||||||
|
1000.0,
|
||||||
|
1_000_000.0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT OR REPLACE INTO daily_basic
|
||||||
|
(ts_code, trade_date, turnover_rate, turnover_rate_f, volume_ratio)
|
||||||
|
VALUES (?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
ts_code,
|
||||||
|
trade_date,
|
||||||
|
turnover,
|
||||||
|
turnover,
|
||||||
|
1.0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_compute_factors_persists_and_updates(isolated_db):
|
||||||
|
ts_code = "000001.SZ"
|
||||||
|
trade_day = date(2025, 1, 30)
|
||||||
|
_populate_sample_data(ts_code, trade_day)
|
||||||
|
|
||||||
|
specs = [*DEFAULT_FACTORS, FactorSpec("mom_5", 5)]
|
||||||
|
results = compute_factors(trade_day, specs)
|
||||||
|
|
||||||
|
assert results
|
||||||
|
result_map = {result.ts_code: result for result in results}
|
||||||
|
assert ts_code in result_map
|
||||||
|
result: FactorResult = result_map[ts_code]
|
||||||
|
|
||||||
|
close_series = [100 + (59 - offset) for offset in range(60)]
|
||||||
|
turnover_series = [5 + 0.1 * (59 - offset) for offset in range(60)]
|
||||||
|
|
||||||
|
expected_mom20 = momentum(close_series, 20)
|
||||||
|
expected_mom60 = momentum(close_series, 60)
|
||||||
|
expected_mom5 = momentum(close_series, 5)
|
||||||
|
expected_volat20 = volatility(close_series, 20)
|
||||||
|
expected_turn20 = rolling_mean(turnover_series, 20)
|
||||||
|
|
||||||
|
assert result.values["mom_20"] == pytest.approx(expected_mom20)
|
||||||
|
assert result.values["mom_60"] == pytest.approx(expected_mom60)
|
||||||
|
assert result.values["mom_5"] == pytest.approx(expected_mom5)
|
||||||
|
assert result.values["volat_20"] == pytest.approx(expected_volat20)
|
||||||
|
assert result.values["turn_20"] == pytest.approx(expected_turn20)
|
||||||
|
|
||||||
|
trade_date_str = trade_day.strftime("%Y%m%d")
|
||||||
|
with db_session(read_only=True) as conn:
|
||||||
|
row = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT mom_20, mom_60, mom_5, volat_20, turn_20
|
||||||
|
FROM factors WHERE ts_code = ? AND trade_date = ?
|
||||||
|
""",
|
||||||
|
(ts_code, trade_date_str),
|
||||||
|
).fetchone()
|
||||||
|
assert row is not None
|
||||||
|
assert row["mom_20"] == pytest.approx(expected_mom20)
|
||||||
|
assert row["mom_60"] == pytest.approx(expected_mom60)
|
||||||
|
assert row["mom_5"] == pytest.approx(expected_mom5)
|
||||||
|
assert row["volat_20"] == pytest.approx(expected_volat20)
|
||||||
|
assert row["turn_20"] == pytest.approx(expected_turn20)
|
||||||
|
|
||||||
|
broker = DataBroker()
|
||||||
|
latest = broker.fetch_latest(ts_code, trade_date_str, ["factors.mom_5", "factors.turn_20"])
|
||||||
|
assert latest["factors.mom_5"] == pytest.approx(expected_mom5)
|
||||||
|
assert latest["factors.turn_20"] == pytest.approx(expected_turn20)
|
||||||
|
|
||||||
|
# Calling compute_factors again should update existing rows without error.
|
||||||
|
second_results = compute_factors(trade_day, specs)
|
||||||
|
assert second_results
|
||||||
|
assert broker.fetch_latest(ts_code, trade_date_str, ["factors.mom_20"])["factors.mom_20"] == pytest.approx(
|
||||||
|
expected_mom20
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_compute_factors_skip_existing(isolated_db):
|
||||||
|
ts_code = "000001.SZ"
|
||||||
|
trade_day = date(2025, 2, 10)
|
||||||
|
_populate_sample_data(ts_code, trade_day)
|
||||||
|
|
||||||
|
compute_factors(trade_day)
|
||||||
|
skipped = compute_factors(trade_day, skip_existing=True)
|
||||||
|
assert skipped == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_compute_factor_range_filters_universe(isolated_db):
|
||||||
|
code_a = "000001.SZ"
|
||||||
|
code_b = "000002.SZ"
|
||||||
|
end_day = date(2025, 3, 5)
|
||||||
|
start_day = end_day - timedelta(days=1)
|
||||||
|
|
||||||
|
_populate_sample_data(code_a, end_day)
|
||||||
|
_populate_sample_data(code_b, end_day)
|
||||||
|
|
||||||
|
results = compute_factor_range(start_day, end_day, ts_codes=[code_a])
|
||||||
|
assert results
|
||||||
|
assert {result.ts_code for result in results} == {code_a}
|
||||||
|
|
||||||
|
with db_session(read_only=True) as conn:
|
||||||
|
rows = conn.execute("SELECT DISTINCT ts_code FROM factors").fetchall()
|
||||||
|
assert {row["ts_code"] for row in rows} == {code_a}
|
||||||
|
|
||||||
|
repeated = compute_factor_range(start_day, end_day, ts_codes=[code_a])
|
||||||
|
assert repeated == []
|
||||||
@ -28,16 +28,17 @@ def isolated_db(tmp_path):
|
|||||||
|
|
||||||
|
|
||||||
def test_fetch_rss_feed_parses_entries(monkeypatch):
|
def test_fetch_rss_feed_parses_entries(monkeypatch):
|
||||||
|
published = datetime.utcnow().strftime("%a, %d %b %Y %H:%M:%S GMT")
|
||||||
sample_feed = (
|
sample_feed = (
|
||||||
"""
|
f"""
|
||||||
<rss version="2.0">
|
<rss version=\"2.0\">
|
||||||
<channel>
|
<channel>
|
||||||
<title>Example</title>
|
<title>Example</title>
|
||||||
<item>
|
<item>
|
||||||
<title>新闻:公司利好公告</title>
|
<title>新闻:公司利好公告</title>
|
||||||
<link>https://example.com/a</link>
|
<link>https://example.com/a</link>
|
||||||
<description><![CDATA[内容包含 000001.SZ ]]></description>
|
<description><![CDATA[内容包含 000001.SZ ]]></description>
|
||||||
<pubDate>Wed, 01 Jan 2025 08:30:00 GMT</pubDate>
|
<pubDate>{published}</pubDate>
|
||||||
<guid>a</guid>
|
<guid>a</guid>
|
||||||
</item>
|
</item>
|
||||||
</channel>
|
</channel>
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user