From b3f2f5b4fcde61498f2b137920c338a22eaeabcc Mon Sep 17 00:00:00 2001 From: sam Date: Mon, 29 Sep 2025 16:01:37 +0800 Subject: [PATCH] update --- app/agents/departments.py | 10 +- app/backtest/engine.py | 73 +++--- app/core/__init__.py | 1 + app/core/indicators.py | 86 +++++++ app/utils/config.py | 26 ++- app/utils/data_access.py | 457 ++++++++++++++++++++++++++++++++------ 6 files changed, 530 insertions(+), 123 deletions(-) create mode 100644 app/core/__init__.py create mode 100644 app/core/indicators.py diff --git a/app/agents/departments.py b/app/agents/departments.py index f4fd32e..df3206f 100644 --- a/app/agents/departments.py +++ b/app/agents/departments.py @@ -74,7 +74,15 @@ class DepartmentDecision: class DepartmentAgent: """Wraps LLM ensemble logic for a single analytical department.""" - ALLOWED_TABLES: ClassVar[List[str]] = ["daily", "daily_basic"] + ALLOWED_TABLES: ClassVar[List[str]] = [ + "daily", + "daily_basic", + "stk_limit", + "suspend", + "heat_daily", + "news", + "index_daily", + ] def __init__( self, diff --git a/app/backtest/engine.py b/app/backtest/engine.py index 10698e3..41cf856 100644 --- a/app/backtest/engine.py +++ b/app/backtest/engine.py @@ -4,7 +4,6 @@ from __future__ import annotations import json from dataclasses import dataclass, field from datetime import date -from statistics import mean, pstdev from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional from app.agents.base import AgentAction, AgentContext @@ -16,50 +15,13 @@ from app.utils.data_access import DataBroker from app.utils.config import get_config from app.utils.db import db_session from app.utils.logging import get_logger +from app.core.indicators import momentum, normalize, rolling_mean, volatility LOGGER = get_logger(__name__) LOG_EXTRA = {"stage": "backtest"} -def _compute_momentum(values: List[float], window: int) -> float: - if window <= 0 or len(values) < window: - return 0.0 - latest = values[0] - past = values[window - 1] - if past is None or past == 0: - return 0.0 - try: - return (latest / past) - 1.0 - except ZeroDivisionError: - return 0.0 - - -def _compute_volatility(values: List[float], window: int) -> float: - if len(values) < 2 or window <= 1: - return 0.0 - limit = min(window, len(values) - 1) - returns: List[float] = [] - for idx in range(limit): - current = values[idx] - previous = values[idx + 1] - if previous is None or previous == 0: - continue - returns.append((current / previous) - 1.0) - if len(returns) < 2: - return 0.0 - return float(pstdev(returns)) - - -def _normalize(value: Any, factor: float) -> float: - try: - numeric = float(value) - except (TypeError, ValueError): - return 0.0 - if factor <= 0: - return max(0.0, min(1.0, numeric)) - return max(0.0, min(1.0, numeric / factor)) - @dataclass class BtConfig: @@ -143,9 +105,9 @@ class BacktestEngine: window=60, ) close_values = [value for _date, value in closes] - mom20 = _compute_momentum(close_values, 20) - mom60 = _compute_momentum(close_values, 60) - volat20 = _compute_volatility(close_values, 20) + mom20 = momentum(close_values, 20) + mom60 = momentum(close_values, 60) + volat20 = volatility(close_values, 20) turnover_series = self.data_broker.fetch_series( "daily_basic", @@ -155,10 +117,31 @@ class BacktestEngine: window=20, ) turnover_values = [value for _date, value in turnover_series] - turn20 = mean(turnover_values) if turnover_values else 0.0 + turn20 = rolling_mean(turnover_values, 20) - liquidity_score = _normalize(turn20, factor=20.0) - cost_penalty = _normalize(scope_values.get("daily_basic.volume_ratio", 0.0), factor=50.0) + liquidity_score = normalize(turn20, factor=20.0) + cost_penalty = normalize( + scope_values.get("daily_basic.volume_ratio", 0.0), + factor=50.0, + ) + + scope_values.setdefault("factors.mom_20", mom20) + scope_values.setdefault("factors.mom_60", mom60) + scope_values.setdefault("factors.volat_20", volat20) + scope_values.setdefault("factors.turn_20", turn20) + scope_values.setdefault("news.sentiment_index", 0.0) + scope_values.setdefault("news.heat_score", 0.0) + if scope_values.get("macro.industry_heat") is None: + scope_values["macro.industry_heat"] = 0.5 + if scope_values.get("macro.relative_strength") is None: + peer_strength = scope_values.get("index.performance_peers") + if peer_strength is None: + peer_strength = 0.5 + scope_values["macro.relative_strength"] = peer_strength + scope_values.setdefault( + "index.performance_peers", + scope_values.get("macro.relative_strength", 0.5), + ) latest_close = scope_values.get("daily.close", 0.0) latest_pct = scope_values.get("daily.pct_chg", 0.0) diff --git a/app/core/__init__.py b/app/core/__init__.py new file mode 100644 index 0000000..927aa1d --- /dev/null +++ b/app/core/__init__.py @@ -0,0 +1 @@ +"""Core utilities shared across application layers.""" diff --git a/app/core/indicators.py b/app/core/indicators.py new file mode 100644 index 0000000..5f3f342 --- /dev/null +++ b/app/core/indicators.py @@ -0,0 +1,86 @@ +"""Reusable quantitative indicator helpers.""" +from __future__ import annotations + +from statistics import pstdev +from typing import Iterable, Sequence + + +def _to_float_list(values: Iterable[object]) -> list[float]: + cleaned: list[float] = [] + for value in values: + try: + cleaned.append(float(value)) + except (TypeError, ValueError): + continue + return cleaned + + +def momentum(series: Sequence[object], window: int) -> float: + """Return simple momentum ratio over ``window`` periods. + + ``series`` is expected to be ordered from most recent to oldest. ``0.0`` is + returned when insufficient history or the denominator is invalid. + """ + + if window <= 0: + return 0.0 + numeric = _to_float_list(series) + if len(numeric) < window: + return 0.0 + latest = numeric[0] + past = numeric[window - 1] + if past == 0.0: + return 0.0 + try: + return (latest / past) - 1.0 + except ZeroDivisionError: + return 0.0 + + +def volatility(series: Sequence[object], window: int) -> float: + """Compute population standard deviation of simple returns.""" + + if window <= 1: + return 0.0 + numeric = _to_float_list(series) + if len(numeric) < 2: + return 0.0 + limit = min(window, len(numeric) - 1) + returns: list[float] = [] + for idx in range(limit): + current = numeric[idx] + previous = numeric[idx + 1] + if previous == 0.0: + continue + returns.append((current / previous) - 1.0) + if len(returns) < 2: + return 0.0 + return float(pstdev(returns)) + + +def rolling_mean(series: Sequence[object], window: int) -> float: + """Return the arithmetic mean over the latest ``window`` observations.""" + + if window <= 0: + return 0.0 + numeric = _to_float_list(series) + if not numeric: + return 0.0 + subset = numeric[: min(window, len(numeric))] + if not subset: + return 0.0 + return float(sum(subset) / len(subset)) + + +def normalize(value: object, *, factor: float | None = None, clamp: tuple[float, float] = (0.0, 1.0)) -> float: + """Clamp ``value`` into the ``clamp`` interval after optional scaling.""" + + if clamp[0] > clamp[1]: + raise ValueError("clamp minimum cannot exceed maximum") + try: + numeric = float(value) + except (TypeError, ValueError): + return clamp[0] + if factor and factor > 0: + numeric = numeric / factor + return max(clamp[0], min(clamp[1], numeric)) diff --git a/app/utils/config.py b/app/utils/config.py index 8204eff..5834dab 100644 --- a/app/utils/config.py +++ b/app/utils/config.py @@ -243,6 +243,7 @@ def _default_departments() -> Dict[str, DepartmentSettings]: "daily_basic.turnover_rate", "factors.mom_20", "factors.mom_60", + "factors.volat_20", ], "prompt": "你主导动量风格研究,关注价格与成交量的加速变化,需在保持纪律的前提下判定短期多空倾向。", }, @@ -253,8 +254,9 @@ def _default_departments() -> Dict[str, DepartmentSettings]: "data_scope": [ "daily_basic.pe", "daily_basic.pb", - "daily_basic.roe", - "fundamental.growth", + "daily_basic.ps", + "daily_basic.dv_ratio", + "factors.turn_20", ], "prompt": "你负责价值与质量评估,应结合估值分位、盈利持续性及安全边际给出配置建议。", }, @@ -265,7 +267,6 @@ def _default_departments() -> Dict[str, DepartmentSettings]: "data_scope": [ "news.sentiment_index", "news.heat_score", - "events.latest_headlines", ], "prompt": "你专注新闻和事件驱动,应评估正负面舆情对标的短线波动的可能影响。", }, @@ -275,8 +276,11 @@ def _default_departments() -> Dict[str, DepartmentSettings]: "description": "衡量成交活跃度与交易成本,控制进出场的实现可能性。", "data_scope": [ "daily_basic.volume_ratio", + "daily_basic.turnover_rate", "daily_basic.turnover_rate_f", - "market.spread_estimate", + "factors.turn_20", + "stk_limit.up_limit", + "stk_limit.down_limit", ], "prompt": "你负责评估该标的的流动性与滑点风险,需要提出可执行的仓位调整建议。", }, @@ -286,8 +290,8 @@ def _default_departments() -> Dict[str, DepartmentSettings]: "description": "追踪宏观与行业景气度,为行业配置和风险偏好提供参考。", "data_scope": [ "macro.industry_heat", - "macro.liquidity_cycle", "index.performance_peers", + "macro.relative_strength", ], "prompt": "你负责宏观与行业研判,应结合宏观周期、行业景气与相对强弱给出方向性意见。", }, @@ -296,9 +300,10 @@ def _default_departments() -> Dict[str, DepartmentSettings]: "title": "风险控制部门", "description": "监控极端风险、合规与交易限制,必要时行使否决。", "data_scope": [ - "market.limit_flags", - "portfolio.position", - "risk.alerts", + "daily.pct_chg", + "suspend.suspend_type", + "stk_limit.up_limit", + "stk_limit.down_limit", ], "prompt": "你负责风险控制,应识别停牌、涨跌停、持仓约束等因素,必要时提出减仓或观望建议。", }, @@ -608,9 +613,8 @@ def save_config(cfg: AppConfig | None = None) -> None: try: path.parent.mkdir(parents=True, exist_ok=True) - tmp_path = path.with_suffix(path.suffix + ".tmp") if path.suffix else path.with_name(path.name + ".tmp") - tmp_path.write_text(serialized, encoding="utf-8") - tmp_path.replace(path) + with path.open("w", encoding="utf-8") as fh: + fh.write(serialized) LOGGER.info("配置已写入:%s", path) except OSError: LOGGER.exception("配置写入失败:%s", path) diff --git a/app/utils/data_access.py b/app/utils/data_access.py index 8155f7a..337efb4 100644 --- a/app/utils/data_access.py +++ b/app/utils/data_access.py @@ -2,11 +2,14 @@ from __future__ import annotations import re +import sqlite3 from dataclasses import dataclass -from typing import ClassVar, Dict, Iterable, List, Optional, Sequence, Tuple +from datetime import datetime, timedelta +from typing import Any, ClassVar, Dict, Iterable, List, Optional, Sequence, Tuple from .db import db_session from .logging import get_logger +from app.core.indicators import momentum, normalize, rolling_mean, volatility LOGGER = get_logger(__name__) LOG_EXTRA = {"stage": "data_broker"} @@ -38,6 +41,27 @@ def parse_field_path(path: str) -> Tuple[str, str] | None: return _safe_split(path) +def _parse_trade_date(value: object) -> Optional[datetime]: + if value is None: + return None + text = str(value).strip() + if not text: + return None + text = text.replace("-", "") + try: + return datetime.strptime(text[:8], "%Y%m%d") + except ValueError: + return None + + +def _start_of_day(dt: datetime) -> str: + return dt.strftime("%Y-%m-%d 00:00:00") + + +def _end_of_day(dt: datetime) -> str: + return dt.strftime("%Y-%m-%d 23:59:59") + + @dataclass class DataBroker: """Lightweight data access helper for agent/LLM consumption.""" @@ -65,60 +89,77 @@ class DataBroker: }, } MAX_WINDOW: ClassVar[int] = 120 + BENCHMARK_INDEX: ClassVar[str] = "000300.SH" def fetch_latest( self, ts_code: str, trade_date: str, fields: Iterable[str], - ) -> Dict[str, float]: + ) -> Dict[str, Any]: """Fetch the latest value (<= trade_date) for each requested field.""" grouped: Dict[str, List[str]] = {} field_map: Dict[Tuple[str, str], List[str]] = {} + derived_cache: Dict[str, Any] = {} + results: Dict[str, Any] = {} for item in fields: if not item: continue - resolved = self.resolve_field(str(item)) + field_name = str(item) + resolved = self.resolve_field(field_name) if not resolved: + derived = self._resolve_derived_field( + ts_code, + trade_date, + field_name, + derived_cache, + ) + if derived is not None: + results[field_name] = derived continue table, column = resolved grouped.setdefault(table, []) if column not in grouped[table]: grouped[table].append(column) - field_map.setdefault((table, column), []).append(str(item)) + field_map.setdefault((table, column), []).append(field_name) if not grouped: - return {} + return results - results: Dict[str, float] = {} - with db_session(read_only=True) as conn: - for table, columns in grouped.items(): - joined_cols = ", ".join(columns) - query = ( - f"SELECT trade_date, {joined_cols} FROM {table} " - "WHERE ts_code = ? AND trade_date <= ? " - "ORDER BY trade_date DESC LIMIT 1" - ) - try: - row = conn.execute(query, (ts_code, trade_date)).fetchone() - except Exception as exc: # noqa: BLE001 - LOGGER.debug( - "查询失败 table=%s fields=%s err=%s", - table, - columns, - exc, - extra=LOG_EXTRA, + try: + with db_session(read_only=True) as conn: + for table, columns in grouped.items(): + joined_cols = ", ".join(columns) + query = ( + f"SELECT trade_date, {joined_cols} FROM {table} " + "WHERE ts_code = ? AND trade_date <= ? " + "ORDER BY trade_date DESC LIMIT 1" ) - continue - if not row: - continue - for column in columns: - value = row[column] - if value is None: + try: + row = conn.execute(query, (ts_code, trade_date)).fetchone() + except Exception as exc: # noqa: BLE001 + LOGGER.debug( + "查询失败 table=%s fields=%s err=%s", + table, + columns, + exc, + extra=LOG_EXTRA, + ) continue - for original in field_map.get((table, column), [f"{table}.{column}"]): - results[original] = float(value) + if not row: + continue + for column in columns: + value = row[column] + if value is None: + continue + for original in field_map.get((table, column), [f"{table}.{column}"]): + try: + results[original] = float(value) + except (TypeError, ValueError): + results[original] = value + except sqlite3.OperationalError as exc: + LOGGER.debug("数据库只读连接失败:%s", exc, extra=LOG_EXTRA) return results def fetch_series( @@ -149,18 +190,28 @@ class DataBroker: "WHERE ts_code = ? AND trade_date <= ? " "ORDER BY trade_date DESC LIMIT ?" ) - with db_session(read_only=True) as conn: - try: - rows = conn.execute(query, (ts_code, end_date, window)).fetchall() - except Exception as exc: # noqa: BLE001 - LOGGER.debug( - "时间序列查询失败 table=%s column=%s err=%s", - table, - column, - exc, - extra=LOG_EXTRA, - ) - return [] + try: + with db_session(read_only=True) as conn: + try: + rows = conn.execute(query, (ts_code, end_date, window)).fetchall() + except Exception as exc: # noqa: BLE001 + LOGGER.debug( + "时间序列查询失败 table=%s column=%s err=%s", + table, + column, + exc, + extra=LOG_EXTRA, + ) + return [] + except sqlite3.OperationalError as exc: + LOGGER.debug( + "时间序列连接失败 table=%s column=%s err=%s", + table, + column, + exc, + extra=LOG_EXTRA, + ) + return [] series: List[Tuple[str, float]] = [] for row in rows: value = row[resolved] @@ -185,18 +236,27 @@ class DataBroker: f"SELECT 1 FROM {table} WHERE ts_code = ? AND {where_clause} LIMIT 1" ) bind_params = (ts_code, *params) - with db_session(read_only=True) as conn: - try: - row = conn.execute(query, bind_params).fetchone() - except Exception as exc: # noqa: BLE001 - LOGGER.debug( - "flag 查询失败 table=%s where=%s err=%s", - table, - where_clause, - exc, - extra=LOG_EXTRA, - ) - return False + try: + with db_session(read_only=True) as conn: + try: + row = conn.execute(query, bind_params).fetchone() + except Exception as exc: # noqa: BLE001 + LOGGER.debug( + "flag 查询失败 table=%s where=%s err=%s", + table, + where_clause, + exc, + extra=LOG_EXTRA, + ) + return False + except sqlite3.OperationalError as exc: + LOGGER.debug( + "flag 查询连接失败 table=%s err=%s", + table, + exc, + extra=LOG_EXTRA, + ) + return False return row is not None def fetch_table_rows( @@ -231,23 +291,288 @@ class DataBroker: params = (ts_code, window) results: List[Dict[str, object]] = [] - with db_session(read_only=True) as conn: - try: - rows = conn.execute(query, params).fetchall() - except Exception as exc: # noqa: BLE001 - LOGGER.debug( - "表查询失败 table=%s err=%s", - table, - exc, - extra=LOG_EXTRA, - ) - return [] + try: + with db_session(read_only=True) as conn: + try: + rows = conn.execute(query, params).fetchall() + except Exception as exc: # noqa: BLE001 + LOGGER.debug( + "表查询失败 table=%s err=%s", + table, + exc, + extra=LOG_EXTRA, + ) + return [] + except sqlite3.OperationalError as exc: + LOGGER.debug( + "表连接失败 table=%s err=%s", + table, + exc, + extra=LOG_EXTRA, + ) + return [] for row in rows: record = {col: row[col] for col in columns} results.append(record) return results + def _resolve_derived_field( + self, + ts_code: str, + trade_date: str, + field: str, + cache: Dict[str, Any], + ) -> Optional[Any]: + if field in cache: + return cache[field] + + value: Optional[Any] = None + if field == "factors.mom_20": + value = self._derived_price_momentum(ts_code, trade_date, 20) + elif field == "factors.mom_60": + value = self._derived_price_momentum(ts_code, trade_date, 60) + elif field == "factors.volat_20": + value = self._derived_price_volatility(ts_code, trade_date, 20) + elif field == "factors.turn_20": + value = self._derived_turnover_mean(ts_code, trade_date, 20) + elif field == "news.sentiment_index": + rows = cache.get("__news_rows__") + if rows is None: + rows = self._fetch_recent_news(ts_code, trade_date) + cache["__news_rows__"] = rows + value = self._news_sentiment_from_rows(rows) + elif field == "news.heat_score": + rows = cache.get("__news_rows__") + if rows is None: + rows = self._fetch_recent_news(ts_code, trade_date) + cache["__news_rows__"] = rows + value = self._news_heat_from_rows(rows) + elif field == "macro.industry_heat": + value = self._derived_industry_heat(ts_code, trade_date) + elif field in {"macro.relative_strength", "index.performance_peers"}: + value = self._derived_relative_strength(ts_code, trade_date, cache) + + cache[field] = value + return value + + def _derived_price_momentum( + self, + ts_code: str, + trade_date: str, + window: int, + ) -> Optional[float]: + series = self.fetch_series("daily", "close", ts_code, trade_date, window) + values = [value for _dt, value in series] + if not values: + return None + return momentum(values, window) + + def _derived_price_volatility( + self, + ts_code: str, + trade_date: str, + window: int, + ) -> Optional[float]: + series = self.fetch_series("daily", "close", ts_code, trade_date, window) + values = [value for _dt, value in series] + if len(values) < 2: + return None + return volatility(values, window) + + def _derived_turnover_mean( + self, + ts_code: str, + trade_date: str, + window: int, + ) -> Optional[float]: + series = self.fetch_series( + "daily_basic", + "turnover_rate", + ts_code, + trade_date, + window, + ) + values = [value for _dt, value in series] + if not values: + return None + return rolling_mean(values, window) + + def _fetch_recent_news( + self, + ts_code: str, + trade_date: str, + days: int = 3, + limit: int = 120, + ) -> List[Dict[str, Any]]: + baseline = _parse_trade_date(trade_date) + if baseline is None: + return [] + start = _start_of_day(baseline - timedelta(days=days)) + end = _end_of_day(baseline) + query = ( + "SELECT sentiment, heat FROM news " + "WHERE ts_code = ? AND pub_time BETWEEN ? AND ? " + "ORDER BY pub_time DESC LIMIT ?" + ) + try: + with db_session(read_only=True) as conn: + rows = conn.execute(query, (ts_code, start, end, limit)).fetchall() + except sqlite3.OperationalError as exc: + LOGGER.debug( + "新闻查询连接失败 ts_code=%s err=%s", + ts_code, + exc, + extra=LOG_EXTRA, + ) + return [] + except Exception as exc: # noqa: BLE001 + LOGGER.debug( + "新闻查询失败 ts_code=%s err=%s", + ts_code, + exc, + extra=LOG_EXTRA, + ) + return [] + return [dict(row) for row in rows] + + @staticmethod + def _news_sentiment_from_rows(rows: List[Dict[str, Any]]) -> Optional[float]: + sentiments: List[float] = [] + for row in rows: + value = row.get("sentiment") + if value is None: + continue + try: + sentiments.append(float(value)) + except (TypeError, ValueError): + continue + if not sentiments: + return None + avg = sum(sentiments) / len(sentiments) + return max(-1.0, min(1.0, avg)) + + @staticmethod + def _news_heat_from_rows(rows: List[Dict[str, Any]]) -> Optional[float]: + if not rows: + return None + total_heat = 0.0 + for row in rows: + value = row.get("heat") + if value is None: + continue + try: + total_heat += max(float(value), 0.0) + except (TypeError, ValueError): + continue + if total_heat > 0: + return normalize(total_heat, factor=100.0) + return normalize(len(rows), factor=20.0) + + def _derived_industry_heat(self, ts_code: str, trade_date: str) -> Optional[float]: + industry = self._lookup_industry(ts_code) + if not industry: + return None + query = ( + "SELECT heat FROM heat_daily " + "WHERE scope = ? AND key = ? AND trade_date <= ? " + "ORDER BY trade_date DESC LIMIT 1" + ) + try: + with db_session(read_only=True) as conn: + row = conn.execute(query, ("industry", industry, trade_date)).fetchone() + except sqlite3.OperationalError as exc: + LOGGER.debug( + "行业热度查询失败 ts_code=%s err=%s", + ts_code, + exc, + extra=LOG_EXTRA, + ) + return None + except Exception as exc: # noqa: BLE001 + LOGGER.debug( + "行业热度读取异常 ts_code=%s err=%s", + ts_code, + exc, + extra=LOG_EXTRA, + ) + return None + if not row: + return None + heat_value = row["heat"] + if heat_value is None: + return None + return normalize(heat_value, factor=100.0) + + def _lookup_industry(self, ts_code: str) -> Optional[str]: + cache = getattr(self, "_industry_cache", None) + if cache is None: + cache = {} + self._industry_cache = cache + if ts_code in cache: + return cache[ts_code] + query = "SELECT industry FROM stock_basic WHERE ts_code = ?" + try: + with db_session(read_only=True) as conn: + row = conn.execute(query, (ts_code,)).fetchone() + except sqlite3.OperationalError as exc: + LOGGER.debug( + "行业查询连接失败 ts_code=%s err=%s", + ts_code, + exc, + extra=LOG_EXTRA, + ) + cache[ts_code] = None + return None + except Exception as exc: # noqa: BLE001 + LOGGER.debug( + "行业查询失败 ts_code=%s err=%s", + ts_code, + exc, + extra=LOG_EXTRA, + ) + cache[ts_code] = None + return None + industry = None + if row: + industry = row["industry"] + cache[ts_code] = industry + return industry + + def _derived_relative_strength( + self, + ts_code: str, + trade_date: str, + cache: Dict[str, Any], + ) -> Optional[float]: + window = 20 + series = self.fetch_series("daily", "close", ts_code, trade_date, max(window, 30)) + values = [value for _dt, value in series] + if not values: + return None + stock_momentum = momentum(values, window) + bench_key = f"__benchmark_mom_{window}" + benchmark = cache.get(bench_key) + if benchmark is None: + benchmark = self._index_momentum(trade_date, window) + cache[bench_key] = benchmark + diff = stock_momentum if benchmark is None else stock_momentum - benchmark + diff = max(-0.2, min(0.2, diff)) + return (diff + 0.2) / 0.4 + + def _index_momentum(self, trade_date: str, window: int) -> Optional[float]: + series = self.fetch_series( + "index_daily", + "close", + self.BENCHMARK_INDEX, + trade_date, + window, + ) + values = [value for _dt, value in series] + if not values: + return None + return momentum(values, window) + def resolve_field(self, field: str) -> Optional[Tuple[str, str]]: normalized = _safe_split(field) if not normalized: