update
This commit is contained in:
parent 15bd9d06a7
commit 8b3c6fe690
@@ -4,12 +4,14 @@ from __future__ import annotations
import json
from dataclasses import dataclass, field
from datetime import date
from typing import Dict, Iterable, List, Mapping
from statistics import mean, pstdev
from typing import Any, Dict, Iterable, List, Mapping

from app.agents.base import AgentContext
from app.agents.departments import DepartmentManager
from app.agents.game import Decision, decide
from app.agents.registry import default_agents
from app.utils.data_access import DataBroker
from app.utils.config import get_config
from app.utils.db import db_session
from app.utils.logging import get_logger
@@ -19,6 +21,45 @@ LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "backtest"}


def _compute_momentum(values: List[float], window: int) -> float:
    if window <= 0 or len(values) < window:
        return 0.0
    latest = values[0]
    past = values[window - 1]
    if past is None or past == 0:
        return 0.0
    try:
        return (latest / past) - 1.0
    except ZeroDivisionError:
        return 0.0


def _compute_volatility(values: List[float], window: int) -> float:
    if len(values) < 2 or window <= 1:
        return 0.0
    limit = min(window, len(values) - 1)
    returns: List[float] = []
    for idx in range(limit):
        current = values[idx]
        previous = values[idx + 1]
        if previous is None or previous == 0:
            continue
        returns.append((current / previous) - 1.0)
    if len(returns) < 2:
        return 0.0
    return float(pstdev(returns))


def _normalize(value: Any, factor: float) -> float:
    try:
        numeric = float(value)
    except (TypeError, ValueError):
        return 0.0
    if factor <= 0:
        return max(0.0, min(1.0, numeric))
    return max(0.0, min(1.0, numeric / factor))
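For reference, a minimal sketch of how these helpers behave on a newest-first close series; the numbers are invented and it assumes the three functions above are in scope:

# Hypothetical illustration only: a newest-first close series, as returned by fetch_series.
closes = [11.0, 10.8, 10.5, 10.2, 10.0]

# 5-day momentum: latest close (11.0) against the close 5 observations back (10.0).
print(_compute_momentum(closes, window=5))   # 0.10 -> +10%

# Window longer than the available history falls back to 0.0.
print(_compute_momentum(closes, window=20))  # 0.0

# Population std-dev of day-over-day returns within the window.
print(_compute_volatility(closes, window=5))

# Clamp a turnover-style value into [0, 1] by dividing by a scale factor.
print(_normalize(12.0, factor=20.0))         # 0.6
print(_normalize("bad", factor=20.0))        # 0.0 -> non-numeric input is ignored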


@dataclass
class BtConfig:
    id: str
@@ -57,18 +98,137 @@ class BacktestEngine:
        self.department_manager = (
            DepartmentManager(app_cfg) if app_cfg.departments else None
        )
        self.data_broker = DataBroker()
        department_scope: set[str] = set()
        for settings in app_cfg.departments.values():
            department_scope.update(settings.data_scope)
        base_scope = {
            "daily.close",
            "daily.open",
            "daily.high",
            "daily.low",
            "daily.pct_chg",
            "daily.vol",
            "daily.amount",
            "daily_basic.turnover_rate",
            "daily_basic.turnover_rate_f",
            "daily_basic.volume_ratio",
            "stk_limit.up_limit",
            "stk_limit.down_limit",
        }
        self.required_fields = sorted(base_scope | department_scope)
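A small sketch of how department data scopes extend the broker field list; the department entries below are invented and the base set is abridged:

# Hypothetical department configuration; data_scope entries use "table.column" paths.
department_scope = {"news.heat_score", "news.sentiment_index", "macro.industry_heat"}
base_scope = {"daily.close", "daily.pct_chg", "daily_basic.turnover_rate"}  # abridged

# Same union-and-sort as above: departments only ever widen the base scope.
required_fields = sorted(base_scope | department_scope)
print(required_fields)
# ['daily.close', 'daily.pct_chg', 'daily_basic.turnover_rate',
#  'macro.industry_heat', 'news.heat_score', 'news.sentiment_index']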

    def load_market_data(self, trade_date: date) -> Mapping[str, Dict[str, float]]:
        """Load per-stock feature vectors. Replace with real data access."""
    def load_market_data(self, trade_date: date) -> Mapping[str, Dict[str, Any]]:
        """Load per-stock feature vectors and context slices for the trade date."""

        _ = trade_date
        return {}
        trade_date_str = trade_date.strftime("%Y%m%d")
        feature_map: Dict[str, Dict[str, Any]] = {}
        universe = self.cfg.universe or []
        for ts_code in universe:
            scope_values = self.data_broker.fetch_latest(
                ts_code,
                trade_date_str,
                self.required_fields,
            )

            closes = self.data_broker.fetch_series(
                "daily",
                "close",
                ts_code,
                trade_date_str,
                window=60,
            )
            close_values = [value for _date, value in closes]
            mom20 = _compute_momentum(close_values, 20)
            mom60 = _compute_momentum(close_values, 60)
            volat20 = _compute_volatility(close_values, 20)

            turnover_series = self.data_broker.fetch_series(
                "daily_basic",
                "turnover_rate",
                ts_code,
                trade_date_str,
                window=20,
            )
            turnover_values = [value for _date, value in turnover_series]
            turn20 = mean(turnover_values) if turnover_values else 0.0

            liquidity_score = _normalize(turn20, factor=20.0)
            cost_penalty = _normalize(scope_values.get("daily_basic.volume_ratio", 0.0), factor=50.0)

            latest_close = scope_values.get("daily.close", 0.0)
            latest_pct = scope_values.get("daily.pct_chg", 0.0)
            latest_turnover = scope_values.get("daily_basic.turnover_rate", 0.0)

            up_limit = scope_values.get("stk_limit.up_limit")
            limit_up = False
            if up_limit and latest_close:
                limit_up = latest_close >= up_limit * 0.999

            down_limit = scope_values.get("stk_limit.down_limit")
            limit_down = False
            if down_limit and latest_close:
                limit_down = latest_close <= down_limit * 1.001

            features = {
                "mom_20": mom20,
                "mom_60": mom60,
                "volat_20": volat20,
                "turn_20": turn20,
                "liquidity_score": liquidity_score,
                "cost_penalty": cost_penalty,
                "news_heat": scope_values.get("news.heat_score", 0.0),
                "news_sentiment": scope_values.get("news.sentiment_index", 0.0),
                "industry_heat": scope_values.get("macro.industry_heat", 0.0),
                "industry_relative_mom": scope_values.get(
                    "macro.relative_strength",
                    scope_values.get("index.performance_peers", 0.0),
                ),
                "risk_penalty": min(1.0, volat20 * 5.0),
                "is_suspended": False,
                "limit_up": limit_up,
                "limit_down": limit_down,
                "position_limit": False,
            }

            market_snapshot = {
                "close": latest_close,
                "pct_chg": latest_pct,
                "turnover_rate": latest_turnover,
                "volume": scope_values.get("daily.vol", 0.0),
                "amount": scope_values.get("daily.amount", 0.0),
                "up_limit": up_limit,
                "down_limit": down_limit,
            }

            raw_payload = {
                "scope_values": scope_values,
                "close_series": closes,
                "turnover_series": turnover_series,
            }

            feature_map[ts_code] = {
                "features": features,
                "market_snapshot": market_snapshot,
                "raw": raw_payload,
            }

        return feature_map
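For orientation, a sketch of the per-stock payload shape that load_market_data now assembles; the ts_code and every number below are invented:

# Hypothetical single entry of the mapping returned by load_market_data.
# Keys mirror the dicts built above; the values are illustrative only.
example_payload = {
    "features": {
        "mom_20": 0.042,
        "mom_60": 0.118,
        "volat_20": 0.013,
        "turn_20": 3.4,
        "liquidity_score": 0.17,   # _normalize(turn_20, factor=20.0)
        "cost_penalty": 0.02,      # volume_ratio normalized by factor=50.0
        "risk_penalty": 0.065,     # min(1.0, volat_20 * 5.0)
        "limit_up": False,
        "limit_down": False,
    },
    "market_snapshot": {
        "close": 10.52,
        "pct_chg": 1.25,
        "turnover_rate": 3.4,
        "volume": 182_000.0,
        "amount": 191_500.0,
        "up_limit": 11.43,
        "down_limit": 9.35,
    },
    "raw": {
        "scope_values": {"daily.close": 10.52, "daily_basic.turnover_rate": 3.4},
        "close_series": [("20240105", 10.52), ("20240104", 10.39)],
        "turnover_series": [("20240105", 3.4), ("20240104", 3.1)],
    },
}
feature_map = {"600000.SH": example_payload}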

    def simulate_day(self, trade_date: date, state: PortfolioState) -> List[Decision]:
        feature_map = self.load_market_data(trade_date)
        decisions: List[Decision] = []
        for ts_code, features in feature_map.items():
            context = AgentContext(ts_code=ts_code, trade_date=trade_date.isoformat(), features=features)
        for ts_code, payload in feature_map.items():
            features = payload.get("features", {})
            market_snapshot = payload.get("market_snapshot", {})
            raw = payload.get("raw", {})
            context = AgentContext(
                ts_code=ts_code,
                trade_date=trade_date.isoformat(),
                features=features,
                market_snapshot=market_snapshot,
                raw=raw,
            )
            decision = decide(
                context,
                self.agents,

159  app/utils/data_access.py  Normal file
@@ -0,0 +1,159 @@
"""Utility helpers to retrieve structured data slices for agents and departments."""
from __future__ import annotations

import re
from dataclasses import dataclass
from typing import Dict, Iterable, List, Sequence, Tuple

from .db import db_session
from .logging import get_logger

LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "data_broker"}

_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")


def _is_safe_identifier(name: str) -> bool:
    return bool(_IDENTIFIER_RE.match(name))


def _safe_split(path: str) -> Tuple[str, str] | None:
    if "." not in path:
        return None
    table, column = path.split(".", 1)
    table = table.strip()
    column = column.strip()
    if not table or not column:
        return None
    if not (_is_safe_identifier(table) and _is_safe_identifier(column)):
        LOGGER.debug("Ignoring invalid field path: %s", path, extra=LOG_EXTRA)
        return None
    return table, column
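A quick illustration of the whitelisting behaviour; the field paths below are examples only:

# Field paths are only accepted as "<table>.<column>" with safe identifiers on
# both sides; anything else (including SQL metacharacters) is dropped and logged.
print(_safe_split("daily.close"))                # ("daily", "close")
print(_safe_split("daily_basic.turnover_rate"))  # ("daily_basic", "turnover_rate")
print(_safe_split("daily"))                      # None -> no dot
print(_safe_split("daily.close; DROP TABLE x"))  # None -> unsafe column rejected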


@dataclass
class DataBroker:
    """Lightweight data access helper for agent/LLM consumption."""

    def fetch_latest(
        self,
        ts_code: str,
        trade_date: str,
        fields: Iterable[str],
    ) -> Dict[str, float]:
        """Fetch the latest value (<= trade_date) for each requested field."""

        grouped: Dict[str, List[str]] = {}
        for item in fields:
            if not item:
                continue
            normalized = _safe_split(str(item))
            if not normalized:
                continue
            table, column = normalized
            grouped.setdefault(table, [])
            if column not in grouped[table]:
                grouped[table].append(column)

        if not grouped:
            return {}

        results: Dict[str, float] = {}
        with db_session(read_only=True) as conn:
            for table, columns in grouped.items():
                joined_cols = ", ".join(columns)
                query = (
                    f"SELECT trade_date, {joined_cols} FROM {table} "
                    "WHERE ts_code = ? AND trade_date <= ? "
                    "ORDER BY trade_date DESC LIMIT 1"
                )
                try:
                    row = conn.execute(query, (ts_code, trade_date)).fetchone()
                except Exception as exc:  # noqa: BLE001
                    LOGGER.debug(
                        "Query failed table=%s fields=%s err=%s",
                        table,
                        columns,
                        exc,
                        extra=LOG_EXTRA,
                    )
                    continue
                if not row:
                    continue
                for column in columns:
                    value = row[column]
                    if value is None:
                        continue
                    key = f"{table}.{column}"
                    results[key] = float(value)
        return results
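A hedged usage sketch for fetch_latest; it assumes db_session points at a populated database and that rows support column-name access, as the implementation expects:

# Hypothetical usage; the ts_code and dates are invented.
broker = DataBroker()
latest = broker.fetch_latest(
    "600000.SH",
    "20240105",
    ["daily.close", "daily.vol", "stk_limit.up_limit", "bad field"],  # invalid paths are skipped
)
# e.g. {"daily.close": 10.52, "daily.vol": 182000.0, "stk_limit.up_limit": 11.43}
print(latest.get("daily.close"))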

    def fetch_series(
        self,
        table: str,
        column: str,
        ts_code: str,
        end_date: str,
        window: int,
    ) -> List[Tuple[str, float]]:
        """Return descending time series tuples within the specified window."""

        if window <= 0:
            return []
        if not (_is_safe_identifier(table) and _is_safe_identifier(column)):
            return []
        query = (
            f"SELECT trade_date, {column} FROM {table} "
            "WHERE ts_code = ? AND trade_date <= ? "
            "ORDER BY trade_date DESC LIMIT ?"
        )
        with db_session(read_only=True) as conn:
            try:
                rows = conn.execute(query, (ts_code, end_date, window)).fetchall()
            except Exception as exc:  # noqa: BLE001
                LOGGER.debug(
                    "Time-series query failed table=%s column=%s err=%s",
                    table,
                    column,
                    exc,
                    extra=LOG_EXTRA,
                )
                return []
        series: List[Tuple[str, float]] = []
        for row in rows:
            value = row[column]
            if value is None:
                continue
            series.append((row["trade_date"], float(value)))
        return series
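A similar sketch for fetch_series, mirroring how the backtest engine unpacks the (trade_date, value) tuples; the database contents are again assumed:

# Hypothetical usage: last 20 closes up to and including 20240105, newest first.
broker = DataBroker()
closes = broker.fetch_series("daily", "close", "600000.SH", "20240105", window=20)
close_values = [value for _date, value in closes]  # same unpacking as load_market_data
# closes looks like [("20240105", 10.52), ("20240104", 10.39), ...], at most `window` items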

    def fetch_flags(
        self,
        table: str,
        ts_code: str,
        trade_date: str,
        where_clause: str,
        params: Sequence[object],
    ) -> bool:
        """Generic helper to test if a record exists (used for limit/suspend lookups)."""

        if not _is_safe_identifier(table):
            return False
        query = (
            f"SELECT 1 FROM {table} WHERE ts_code = ? AND {where_clause} LIMIT 1"
        )
        bind_params = (ts_code, *params)
        with db_session(read_only=True) as conn:
            try:
                row = conn.execute(query, bind_params).fetchone()
            except Exception as exc:  # noqa: BLE001
                LOGGER.debug(
                    "Flag query failed table=%s where=%s err=%s",
                    table,
                    where_clause,
                    exc,
                    extra=LOG_EXTRA,
                )
                return False
        return row is not None
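And one for fetch_flags; the suspend_d table name and where clause below are assumptions rather than part of this commit, and because where_clause is inlined into the SQL it should only ever come from trusted, hard-coded call sites:

# Hypothetical usage: check whether the stock has a suspension record on the date.
# Note that the trade_date argument is not bound by fetch_flags itself, so the
# caller passes it again through params for the where clause.
broker = DataBroker()
suspended = broker.fetch_flags(
    "suspend_d",
    "600000.SH",
    "20240105",
    where_clause="trade_date = ?",
    params=("20240105",),
)
print(suspended)  # True if a matching row exists, False otherwise (or on query error)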