diff --git a/app/backtest/engine.py b/app/backtest/engine.py
index 140d51d..8f37013 100644
--- a/app/backtest/engine.py
+++ b/app/backtest/engine.py
@@ -4,12 +4,14 @@ from __future__ import annotations
 import json
 from dataclasses import dataclass, field
 from datetime import date
-from typing import Dict, Iterable, List, Mapping
+from statistics import mean, pstdev
+from typing import Any, Dict, Iterable, List, Mapping
 
 from app.agents.base import AgentContext
 from app.agents.departments import DepartmentManager
 from app.agents.game import Decision, decide
 from app.agents.registry import default_agents
 from app.utils.config import get_config
+from app.utils.data_access import DataBroker
 from app.utils.db import db_session
 from app.utils.logging import get_logger
@@ -19,6 +21,45 @@
 LOGGER = get_logger(__name__)
 LOG_EXTRA = {"stage": "backtest"}
 
+def _compute_momentum(values: List[float], window: int) -> float:
+    if window <= 0 or len(values) < window:
+        return 0.0
+    latest = values[0]
+    past = values[window - 1]
+    if past is None or past == 0:
+        return 0.0
+    try:
+        return (latest / past) - 1.0
+    except ZeroDivisionError:
+        return 0.0
+
+
+def _compute_volatility(values: List[float], window: int) -> float:
+    if len(values) < 2 or window <= 1:
+        return 0.0
+    limit = min(window, len(values) - 1)
+    returns: List[float] = []
+    for idx in range(limit):
+        current = values[idx]
+        previous = values[idx + 1]
+        if previous is None or previous == 0:
+            continue
+        returns.append((current / previous) - 1.0)
+    if len(returns) < 2:
+        return 0.0
+    return float(pstdev(returns))
+
+
+def _normalize(value: Any, factor: float) -> float:
+    try:
+        numeric = float(value)
+    except (TypeError, ValueError):
+        return 0.0
+    if factor <= 0:
+        return max(0.0, min(1.0, numeric))
+    return max(0.0, min(1.0, numeric / factor))
+
+
 @dataclass
 class BtConfig:
     id: str
@@ -57,18 +98,137 @@ class BacktestEngine:
         self.department_manager = (
             DepartmentManager(app_cfg) if app_cfg.departments else None
         )
+        self.data_broker = DataBroker()
+        department_scope: set[str] = set()
+        for settings in (app_cfg.departments or {}).values():
+            department_scope.update(settings.data_scope)
+        base_scope = {
+            "daily.close",
+            "daily.open",
+            "daily.high",
+            "daily.low",
+            "daily.pct_chg",
+            "daily.vol",
+            "daily.amount",
+            "daily_basic.turnover_rate",
+            "daily_basic.turnover_rate_f",
+            "daily_basic.volume_ratio",
+            "stk_limit.up_limit",
+            "stk_limit.down_limit",
+        }
+        self.required_fields = sorted(base_scope | department_scope)
 
-    def load_market_data(self, trade_date: date) -> Mapping[str, Dict[str, float]]:
-        """Load per-stock feature vectors. Replace with real data access."""
+    def load_market_data(self, trade_date: date) -> Mapping[str, Dict[str, Any]]:
+        """Load per-stock feature vectors and context slices for the trade date."""
 
-        _ = trade_date
-        return {}
+        trade_date_str = trade_date.strftime("%Y%m%d")
+        feature_map: Dict[str, Dict[str, Any]] = {}
+        universe = self.cfg.universe or []
+        for ts_code in universe:
+            scope_values = self.data_broker.fetch_latest(
+                ts_code,
+                trade_date_str,
+                self.required_fields,
+            )
+
+            closes = self.data_broker.fetch_series(
+                "daily",
+                "close",
+                ts_code,
+                trade_date_str,
+                window=60,
+            )
+            close_values = [value for _date, value in closes]
+            mom20 = _compute_momentum(close_values, 20)
+            mom60 = _compute_momentum(close_values, 60)
+            volat20 = _compute_volatility(close_values, 20)
+
+            turnover_series = self.data_broker.fetch_series(
+                "daily_basic",
+                "turnover_rate",
+                ts_code,
+                trade_date_str,
+                window=20,
+            )
+            turnover_values = [value for _date, value in turnover_series]
+            turn20 = mean(turnover_values) if turnover_values else 0.0
+
+            liquidity_score = _normalize(turn20, factor=20.0)
+            cost_penalty = _normalize(scope_values.get("daily_basic.volume_ratio", 0.0), factor=50.0)
+
+            latest_close = scope_values.get("daily.close", 0.0)
+            latest_pct = scope_values.get("daily.pct_chg", 0.0)
+            latest_turnover = scope_values.get("daily_basic.turnover_rate", 0.0)
+
+            up_limit = scope_values.get("stk_limit.up_limit")
+            limit_up = False
+            if up_limit and latest_close:
+                limit_up = latest_close >= up_limit * 0.999
+
+            down_limit = scope_values.get("stk_limit.down_limit")
+            limit_down = False
+            if down_limit and latest_close:
+                limit_down = latest_close <= down_limit * 1.001
+
+            features = {
+                "mom_20": mom20,
+                "mom_60": mom60,
+                "volat_20": volat20,
+                "turn_20": turn20,
+                "liquidity_score": liquidity_score,
+                "cost_penalty": cost_penalty,
+                "news_heat": scope_values.get("news.heat_score", 0.0),
+                "news_sentiment": scope_values.get("news.sentiment_index", 0.0),
+                "industry_heat": scope_values.get("macro.industry_heat", 0.0),
+                "industry_relative_mom": scope_values.get(
+                    "macro.relative_strength",
+                    scope_values.get("index.performance_peers", 0.0),
+                ),
+                "risk_penalty": min(1.0, volat20 * 5.0),
+                "is_suspended": False,
+                "limit_up": limit_up,
+                "limit_down": limit_down,
+                "position_limit": False,
+            }
+
+            market_snapshot = {
+                "close": latest_close,
+                "pct_chg": latest_pct,
+                "turnover_rate": latest_turnover,
+                "volume": scope_values.get("daily.vol", 0.0),
+                "amount": scope_values.get("daily.amount", 0.0),
+                "up_limit": up_limit,
+                "down_limit": down_limit,
+            }
+
+            raw_payload = {
+                "scope_values": scope_values,
+                "close_series": closes,
+                "turnover_series": turnover_series,
+            }
+
+            feature_map[ts_code] = {
+                "features": features,
+                "market_snapshot": market_snapshot,
+                "raw": raw_payload,
+            }
+
+        return feature_map
 
     def simulate_day(self, trade_date: date, state: PortfolioState) -> List[Decision]:
         feature_map = self.load_market_data(trade_date)
         decisions: List[Decision] = []
-        for ts_code, features in feature_map.items():
-            context = AgentContext(ts_code=ts_code, trade_date=trade_date.isoformat(), features=features)
+        for ts_code, payload in feature_map.items():
+            features = payload.get("features", {})
+            market_snapshot = payload.get("market_snapshot", {})
+            raw = payload.get("raw", {})
+            context = AgentContext(
+                ts_code=ts_code,
+                trade_date=trade_date.isoformat(),
+                features=features,
+                market_snapshot=market_snapshot,
+                raw=raw,
+            )
             decision = decide(
                 context,
                 self.agents,
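Both helpers assume the series arrives newest-first, matching the ORDER BY
trade_date DESC ordering of DataBroker.fetch_series: _compute_momentum(values, n)
compares values[0] against values[n - 1], and _compute_volatility derives
day-over-day returns from adjacent rows before taking the population std-dev.
A minimal standalone sketch of that math (illustrative values only):

    from statistics import pstdev

    closes = [110.0, 108.0, 105.0, 100.0]  # newest -> oldest

    # 4-row momentum: newest close vs. the close 3 rows back (110/100 - 1).
    momentum = closes[0] / closes[4 - 1] - 1.0
    assert abs(momentum - 0.10) < 1e-9

    # Adjacent-row returns (current/previous - 1), then population std-dev,
    # mirroring _compute_volatility.
    returns = [closes[i] / closes[i + 1] - 1.0 for i in range(len(closes) - 1)]
    print(momentum, pstdev(returns))
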
diff --git a/app/utils/data_access.py b/app/utils/data_access.py
new file mode 100644
index 0000000..d5b9f73
--- /dev/null
+++ b/app/utils/data_access.py
@@ -0,0 +1,159 @@
+"""Utility helpers to retrieve structured data slices for agents and departments."""
+from __future__ import annotations
+
+import re
+from dataclasses import dataclass
+from typing import Dict, Iterable, List, Sequence, Tuple
+
+from .db import db_session
+from .logging import get_logger
+
+LOGGER = get_logger(__name__)
+LOG_EXTRA = {"stage": "data_broker"}
+
+_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
+
+
+def _is_safe_identifier(name: str) -> bool:
+    return bool(_IDENTIFIER_RE.match(name))
+
+
+def _safe_split(path: str) -> Tuple[str, str] | None:
+    if "." not in path:
+        return None
+    table, column = path.split(".", 1)
+    table = table.strip()
+    column = column.strip()
+    if not table or not column:
+        return None
+    if not (_is_safe_identifier(table) and _is_safe_identifier(column)):
+        LOGGER.debug("Ignoring invalid field path: %s", path, extra=LOG_EXTRA)
+        return None
+    return table, column
+
+
+@dataclass
+class DataBroker:
+    """Lightweight data access helper for agent/LLM consumption."""
+
+    def fetch_latest(
+        self,
+        ts_code: str,
+        trade_date: str,
+        fields: Iterable[str],
+    ) -> Dict[str, float]:
+        """Fetch the latest value (<= trade_date) for each requested field."""
+
+        grouped: Dict[str, List[str]] = {}
+        for item in fields:
+            if not item:
+                continue
+            normalized = _safe_split(str(item))
+            if not normalized:
+                continue
+            table, column = normalized
+            grouped.setdefault(table, [])
+            if column not in grouped[table]:
+                grouped[table].append(column)
+
+        if not grouped:
+            return {}
+
+        results: Dict[str, float] = {}
+        with db_session(read_only=True) as conn:
+            for table, columns in grouped.items():
+                joined_cols = ", ".join(columns)
+                query = (
+                    f"SELECT trade_date, {joined_cols} FROM {table} "
+                    "WHERE ts_code = ? AND trade_date <= ? "
+                    "ORDER BY trade_date DESC LIMIT 1"
+                )
+                try:
+                    row = conn.execute(query, (ts_code, trade_date)).fetchone()
+                except Exception as exc:  # noqa: BLE001
+                    LOGGER.debug(
+                        "Latest-value query failed table=%s fields=%s err=%s",
+                        table,
+                        columns,
+                        exc,
+                        extra=LOG_EXTRA,
+                    )
+                    continue
+                if not row:
+                    continue
+                for column in columns:
+                    value = row[column]
+                    if value is None:
+                        continue
+                    key = f"{table}.{column}"
+                    results[key] = float(value)
+        return results
+
+    def fetch_series(
+        self,
+        table: str,
+        column: str,
+        ts_code: str,
+        end_date: str,
+        window: int,
+    ) -> List[Tuple[str, float]]:
+        """Return descending (trade_date, value) tuples within the window."""
+
+        if window <= 0:
+            return []
+        if not (_is_safe_identifier(table) and _is_safe_identifier(column)):
+            return []
+        query = (
+            f"SELECT trade_date, {column} FROM {table} "
+            "WHERE ts_code = ? AND trade_date <= ? "
+            "ORDER BY trade_date DESC LIMIT ?"
+        )
+        with db_session(read_only=True) as conn:
+            try:
+                rows = conn.execute(query, (ts_code, end_date, window)).fetchall()
+            except Exception as exc:  # noqa: BLE001
+                LOGGER.debug(
+                    "Series query failed table=%s column=%s err=%s",
+                    table,
+                    column,
+                    exc,
+                    extra=LOG_EXTRA,
+                )
+                return []
+        series: List[Tuple[str, float]] = []
+        for row in rows:
+            value = row[column]
+            if value is None:
+                continue
+            series.append((row["trade_date"], float(value)))
+        return series
+
+    def fetch_flags(
+        self,
+        table: str,
+        ts_code: str,
+        where_clause: str,
+        params: Sequence[object],
+    ) -> bool:
+        """Generic helper to test if a record exists (used for limit/suspend lookups)."""
+
+        if not _is_safe_identifier(table):
+            return False
+        # where_clause is interpolated verbatim; callers must pass trusted SQL only.
+        query = (
+            f"SELECT 1 FROM {table} WHERE ts_code = ? AND {where_clause} LIMIT 1"
+        )
+        bind_params = (ts_code, *params)
+        with db_session(read_only=True) as conn:
+            try:
+                row = conn.execute(query, bind_params).fetchone()
+            except Exception as exc:  # noqa: BLE001
+                LOGGER.debug(
+                    "Flag query failed table=%s where=%s err=%s",
+                    table,
+                    where_clause,
+                    exc,
+                    extra=LOG_EXTRA,
+                )
+                return False
+        return row is not None
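
A hedged usage sketch of the broker for reference. The table and column names
here (daily, daily_basic, suspend_d) are examples of the schema the queries
above imply (tables keyed by ts_code + trade_date with ?-style placeholders),
not a confirmed schema:

    from app.utils.data_access import DataBroker

    broker = DataBroker()

    # Latest values at or before the date, keyed "table.column";
    # malformed field paths are skipped rather than raising.
    snapshot = broker.fetch_latest(
        "000001.SZ",
        "20240510",
        ["daily.close", "daily_basic.turnover_rate", "not a field"],
    )

    # Newest-first (trade_date, value) tuples, at most `window` rows.
    closes = broker.fetch_series("daily", "close", "000001.SZ", "20240510", window=60)

    # Existence test with a caller-supplied (trusted) WHERE fragment.
    suspended = broker.fetch_flags(
        "suspend_d", "000001.SZ", "trade_date = ?", ["20240510"],
    )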