sam 2025-09-28 16:38:56 +08:00
parent 15bd9d06a7
commit 8b3c6fe690
2 changed files with 326 additions and 7 deletions


@@ -4,12 +4,14 @@ from __future__ import annotations
 import json
 from dataclasses import dataclass, field
 from datetime import date
-from typing import Dict, Iterable, List, Mapping
+from statistics import mean, pstdev
+from typing import Any, Dict, Iterable, List, Mapping
 from app.agents.base import AgentContext
 from app.agents.departments import DepartmentManager
 from app.agents.game import Decision, decide
 from app.agents.registry import default_agents
+from app.utils.data_access import DataBroker
 from app.utils.config import get_config
 from app.utils.db import db_session
 from app.utils.logging import get_logger
@@ -19,6 +21,45 @@ LOGGER = get_logger(__name__)
 LOG_EXTRA = {"stage": "backtest"}
+
+
+def _compute_momentum(values: List[float], window: int) -> float:
+    if window <= 0 or len(values) < window:
+        return 0.0
+    latest = values[0]
+    past = values[window - 1]
+    if past is None or past == 0:
+        return 0.0
+    try:
+        return (latest / past) - 1.0
+    except ZeroDivisionError:
+        return 0.0
+
+
+def _compute_volatility(values: List[float], window: int) -> float:
+    if len(values) < 2 or window <= 1:
+        return 0.0
+    limit = min(window, len(values) - 1)
+    returns: List[float] = []
+    for idx in range(limit):
+        current = values[idx]
+        previous = values[idx + 1]
+        if previous is None or previous == 0:
+            continue
+        returns.append((current / previous) - 1.0)
+    if len(returns) < 2:
+        return 0.0
+    return float(pstdev(returns))
+
+
+def _normalize(value: Any, factor: float) -> float:
+    try:
+        numeric = float(value)
+    except (TypeError, ValueError):
+        return 0.0
+    if factor <= 0:
+        return max(0.0, min(1.0, numeric))
+    return max(0.0, min(1.0, numeric / factor))
+
+
 @dataclass
 class BtConfig:
     id: str
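
A quick sanity check of the three helpers above (an illustrative sketch, not part of the commit; series are ordered newest-first to match DataBroker.fetch_series, introduced below):

    # Illustrative values; index 0 is the most recent close.
    closes = [110.0, 108.0, 105.0, 100.0]

    _compute_momentum(closes, 4)    # (110.0 / 100.0) - 1.0 = 0.10
    _compute_momentum(closes, 10)   # window exceeds series length -> 0.0

    # Consecutive-day returns: 110/108-1, 108/105-1, 105/100-1;
    # result is the population stdev of those three values.
    _compute_volatility(closes, 20)

    _normalize(35.0, factor=50.0)   # 0.7, clamped into [0.0, 1.0]
    _normalize("n/a", factor=50.0)  # non-numeric input -> 0.0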
@@ -57,18 +98,137 @@ class BacktestEngine:
         self.department_manager = (
             DepartmentManager(app_cfg) if app_cfg.departments else None
         )
+        self.data_broker = DataBroker()
+        department_scope: set[str] = set()
+        for settings in app_cfg.departments.values():
+            department_scope.update(settings.data_scope)
+        base_scope = {
+            "daily.close",
+            "daily.open",
+            "daily.high",
+            "daily.low",
+            "daily.pct_chg",
+            "daily.vol",
+            "daily.amount",
+            "daily_basic.turnover_rate",
+            "daily_basic.turnover_rate_f",
+            "daily_basic.volume_ratio",
+            "stk_limit.up_limit",
+            "stk_limit.down_limit",
+        }
+        self.required_fields = sorted(base_scope | department_scope)
-    def load_market_data(self, trade_date: date) -> Mapping[str, Dict[str, float]]:
-        """Load per-stock feature vectors. Replace with real data access."""
-        _ = trade_date
-        return {}
+    def load_market_data(self, trade_date: date) -> Mapping[str, Dict[str, Any]]:
+        """Load per-stock feature vectors and context slices for the trade date."""
+        trade_date_str = trade_date.strftime("%Y%m%d")
+        feature_map: Dict[str, Dict[str, Any]] = {}
+        universe = self.cfg.universe or []
+        for ts_code in universe:
+            scope_values = self.data_broker.fetch_latest(
+                ts_code,
+                trade_date_str,
+                self.required_fields,
+            )
+            closes = self.data_broker.fetch_series(
+                "daily",
+                "close",
+                ts_code,
+                trade_date_str,
+                window=60,
+            )
+            close_values = [value for _date, value in closes]
+            mom20 = _compute_momentum(close_values, 20)
+            mom60 = _compute_momentum(close_values, 60)
+            volat20 = _compute_volatility(close_values, 20)
+            turnover_series = self.data_broker.fetch_series(
+                "daily_basic",
+                "turnover_rate",
+                ts_code,
+                trade_date_str,
+                window=20,
+            )
+            turnover_values = [value for _date, value in turnover_series]
+            turn20 = mean(turnover_values) if turnover_values else 0.0
+            liquidity_score = _normalize(turn20, factor=20.0)
+            cost_penalty = _normalize(scope_values.get("daily_basic.volume_ratio", 0.0), factor=50.0)
+            latest_close = scope_values.get("daily.close", 0.0)
+            latest_pct = scope_values.get("daily.pct_chg", 0.0)
+            latest_turnover = scope_values.get("daily_basic.turnover_rate", 0.0)
+            up_limit = scope_values.get("stk_limit.up_limit")
+            limit_up = False
+            if up_limit and latest_close:
+                limit_up = latest_close >= up_limit * 0.999
+            down_limit = scope_values.get("stk_limit.down_limit")
+            limit_down = False
+            if down_limit and latest_close:
+                limit_down = latest_close <= down_limit * 1.001
+            features = {
+                "mom_20": mom20,
+                "mom_60": mom60,
+                "volat_20": volat20,
+                "turn_20": turn20,
+                "liquidity_score": liquidity_score,
+                "cost_penalty": cost_penalty,
+                "news_heat": scope_values.get("news.heat_score", 0.0),
+                "news_sentiment": scope_values.get("news.sentiment_index", 0.0),
+                "industry_heat": scope_values.get("macro.industry_heat", 0.0),
+                "industry_relative_mom": scope_values.get(
+                    "macro.relative_strength",
+                    scope_values.get("index.performance_peers", 0.0),
+                ),
+                "risk_penalty": min(1.0, volat20 * 5.0),
+                "is_suspended": False,
+                "limit_up": limit_up,
+                "limit_down": limit_down,
+                "position_limit": False,
+            }
+            market_snapshot = {
+                "close": latest_close,
+                "pct_chg": latest_pct,
+                "turnover_rate": latest_turnover,
+                "volume": scope_values.get("daily.vol", 0.0),
+                "amount": scope_values.get("daily.amount", 0.0),
+                "up_limit": up_limit,
+                "down_limit": down_limit,
+            }
+            raw_payload = {
+                "scope_values": scope_values,
+                "close_series": closes,
+                "turnover_series": turnover_series,
+            }
+            feature_map[ts_code] = {
+                "features": features,
+                "market_snapshot": market_snapshot,
+                "raw": raw_payload,
+            }
+        return feature_map
     def simulate_day(self, trade_date: date, state: PortfolioState) -> List[Decision]:
         feature_map = self.load_market_data(trade_date)
         decisions: List[Decision] = []
-        for ts_code, features in feature_map.items():
-            context = AgentContext(ts_code=ts_code, trade_date=trade_date.isoformat(), features=features)
+        for ts_code, payload in feature_map.items():
+            features = payload.get("features", {})
+            market_snapshot = payload.get("market_snapshot", {})
+            raw = payload.get("raw", {})
+            context = AgentContext(
+                ts_code=ts_code,
+                trade_date=trade_date.isoformat(),
+                features=features,
+                market_snapshot=market_snapshot,
+                raw=raw,
+            )
             decision = decide(
                 context,
                 self.agents,
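
For reference, each entry in the mapping returned by load_market_data (and unpacked by simulate_day above) has the shape below; the ts_code and numbers are illustrative only:

    feature_map["600000.SH"] == {
        "features": {
            "mom_20": 0.04, "mom_60": 0.11, "volat_20": 0.02,
            "turn_20": 1.8, "liquidity_score": 0.09, "cost_penalty": 0.02,
            "news_heat": 0.0, "news_sentiment": 0.0,
            "industry_heat": 0.0, "industry_relative_mom": 0.0,
            "risk_penalty": 0.1, "is_suspended": False,
            "limit_up": False, "limit_down": False, "position_limit": False,
        },
        "market_snapshot": {
            "close": 10.52, "pct_chg": 1.25, "turnover_rate": 1.9,
            "volume": 182000.0, "amount": 191500.0,
            "up_limit": 11.57, "down_limit": 9.47,
        },
        "raw": {
            "scope_values": {...},    # flat "table.column" -> float mapping
            "close_series": [...],    # (trade_date, close) tuples, newest first
            "turnover_series": [...],
        },
    }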

app/utils/data_access.py (new file, +159)

@@ -0,0 +1,159 @@
"""Utility helpers to retrieve structured data slices for agents and departments."""

from __future__ import annotations

import re
from dataclasses import dataclass
from typing import Dict, Iterable, List, Sequence, Tuple

from .db import db_session
from .logging import get_logger

LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "data_broker"}

_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")


def _is_safe_identifier(name: str) -> bool:
    return bool(_IDENTIFIER_RE.match(name))


def _safe_split(path: str) -> Tuple[str, str] | None:
    if "." not in path:
        return None
    table, column = path.split(".", 1)
    table = table.strip()
    column = column.strip()
    if not table or not column:
        return None
    if not (_is_safe_identifier(table) and _is_safe_identifier(column)):
        LOGGER.debug("ignoring invalid field: %s", path, extra=LOG_EXTRA)
        return None
    return table, column
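
The identifier checks keep field paths safe to splice into SQL that is built by string formatting; the accept/reject behavior, illustratively:

    _safe_split("daily.close")                    # ("daily", "close")
    _safe_split("daily_basic.turnover_rate")      # ("daily_basic", "turnover_rate")
    _safe_split("close")                          # no table prefix -> None
    _safe_split("daily.close; DROP TABLE daily")  # fails the regex -> None (logged)
    _is_safe_identifier("stk_limit")              # True
    _is_safe_identifier("1daily")                 # False: leading digit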
@dataclass
class DataBroker:
    """Lightweight data access helper for agent/LLM consumption."""

    def fetch_latest(
        self,
        ts_code: str,
        trade_date: str,
        fields: Iterable[str],
    ) -> Dict[str, float]:
        """Fetch the latest value (<= trade_date) for each requested field."""
        grouped: Dict[str, List[str]] = {}
        for item in fields:
            if not item:
                continue
            normalized = _safe_split(str(item))
            if not normalized:
                continue
            table, column = normalized
            grouped.setdefault(table, [])
            if column not in grouped[table]:
                grouped[table].append(column)
        if not grouped:
            return {}
        results: Dict[str, float] = {}
        with db_session(read_only=True) as conn:
            for table, columns in grouped.items():
                joined_cols = ", ".join(columns)
                query = (
                    f"SELECT trade_date, {joined_cols} FROM {table} "
                    "WHERE ts_code = ? AND trade_date <= ? "
                    "ORDER BY trade_date DESC LIMIT 1"
                )
                try:
                    row = conn.execute(query, (ts_code, trade_date)).fetchone()
                except Exception as exc:  # noqa: BLE001
                    LOGGER.debug(
                        "query failed table=%s fields=%s err=%s",
                        table,
                        columns,
                        exc,
                        extra=LOG_EXTRA,
                    )
                    continue
                if not row:
                    continue
                for column in columns:
                    value = row[column]
                    if value is None:
                        continue
                    key = f"{table}.{column}"
                    results[key] = float(value)
        return results
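
A hedged usage sketch (assumes the backing tables exist; the ts_code is hypothetical). Fields group by table, so this issues one query each against daily and daily_basic, and anything that fails validation or is NULL simply drops out of the result:

    broker = DataBroker()
    values = broker.fetch_latest(
        "600000.SH",
        "20250926",
        ["daily.close", "daily.pct_chg", "daily_basic.turnover_rate", "not a field"],
    )
    # e.g. {"daily.close": 10.52, "daily.pct_chg": 1.25,
    #       "daily_basic.turnover_rate": 1.9}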
    def fetch_series(
        self,
        table: str,
        column: str,
        ts_code: str,
        end_date: str,
        window: int,
    ) -> List[Tuple[str, float]]:
        """Return descending time-series tuples within the specified window."""
        if window <= 0:
            return []
        if not (_is_safe_identifier(table) and _is_safe_identifier(column)):
            return []
        query = (
            f"SELECT trade_date, {column} FROM {table} "
            "WHERE ts_code = ? AND trade_date <= ? "
            "ORDER BY trade_date DESC LIMIT ?"
        )
        with db_session(read_only=True) as conn:
            try:
                rows = conn.execute(query, (ts_code, end_date, window)).fetchall()
            except Exception as exc:  # noqa: BLE001
                LOGGER.debug(
                    "series query failed table=%s column=%s err=%s",
                    table,
                    column,
                    exc,
                    extra=LOG_EXTRA,
                )
                return []
        series: List[Tuple[str, float]] = []
        for row in rows:
            value = row[column]
            if value is None:
                continue
            series.append((row["trade_date"], float(value)))
        return series
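
This is the series accessor the backtest engine uses for 60-day closes and 20-day turnover. Rows come back newest-first, which is exactly the ordering _compute_momentum and _compute_volatility assume; note that row["trade_date"] access presumes mapping-style rows (e.g. sqlite3.Row) from db_session. Illustratively:

    closes = broker.fetch_series("daily", "close", "600000.SH", "20250926", window=60)
    # [("20250926", 10.52), ("20250925", 10.39), ...]  NULL values are skipped
    close_values = [value for _date, value in closes]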
    def fetch_flags(
        self,
        table: str,
        ts_code: str,
        trade_date: str,
        where_clause: str,
        params: Sequence[object],
    ) -> bool:
        """Generic helper to test if a record exists (used for limit/suspend lookups)."""
        if not _is_safe_identifier(table):
            return False
        query = (
            f"SELECT 1 FROM {table} WHERE ts_code = ? AND {where_clause} LIMIT 1"
        )
        bind_params = (ts_code, *params)
        with db_session(read_only=True) as conn:
            try:
                row = conn.execute(query, bind_params).fetchone()
            except Exception as exc:  # noqa: BLE001
                LOGGER.debug(
                    "flag query failed table=%s where=%s err=%s",
                    table,
                    where_clause,
                    exc,
                    extra=LOG_EXTRA,
                )
                return False
        return row is not None
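
One caveat worth noting: only the table name is validated here, and the trade_date argument is not bound into the query, so callers must pass literal WHERE fragments and supply every date through params. A hypothetical suspension lookup (the suspend table and its columns are assumptions, not part of this commit):

    is_suspended = broker.fetch_flags(
        "suspend",
        "600000.SH",
        "20250926",
        where_clause="suspend_date <= ? AND (resume_date IS NULL OR resume_date > ?)",
        params=("20250926", "20250926"),
    )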