add blacklist risk control and alert dispatching

This commit is contained in:
sam 2025-10-17 10:52:49 +08:00
parent d2a056d7c0
commit 1ca2f2be19
11 changed files with 614 additions and 32 deletions

View File

@ -660,6 +660,7 @@ def _risk_review_message(reason: str) -> str:
"risk_penalty_extreme": "风险评分极高,建议暂停加仓", "risk_penalty_extreme": "风险评分极高,建议暂停加仓",
"risk_penalty_high": "风险评分偏高,建议复核", "risk_penalty_high": "风险评分偏高,建议复核",
"external_alert": "外部风险告警触发复核", "external_alert": "外部风险告警触发复核",
"blacklist": "标的命中黑名单,禁止交易",
} }
return mapping.get(reason, "触发风险复核,需人工确认") return mapping.get(reason, "触发风险复核,需人工确认")

View File

@ -49,6 +49,8 @@ class RiskAgent(Agent):
def feasible(self, context: AgentContext, action: AgentAction) -> bool: def feasible(self, context: AgentContext, action: AgentAction) -> bool:
if action is AgentAction.SELL: if action is AgentAction.SELL:
return True return True
if context.features.get("is_blacklisted", False) and action not in (AgentAction.SELL, AgentAction.HOLD):
return False
if context.features.get("is_suspended", False): if context.features.get("is_suspended", False):
return False return False
if context.features.get("limit_up", False) and action not in (AgentAction.SELL, AgentAction.HOLD): if context.features.get("limit_up", False) and action not in (AgentAction.SELL, AgentAction.HOLD):
@ -74,6 +76,15 @@ class RiskAgent(Agent):
notes={"trigger": "is_suspended"}, notes={"trigger": "is_suspended"},
) )
if bool(features.get("is_blacklisted")):
fallback = AgentAction.SELL if decision_action is AgentAction.SELL else AgentAction.HOLD
return RiskRecommendation(
status="blocked",
reason="blacklist",
recommended_action=fallback,
notes={"trigger": "is_blacklisted"},
)
if bool(features.get("limit_up")) and decision_action in { if bool(features.get("limit_up")) and decision_action in {
AgentAction.BUY_S, AgentAction.BUY_S,
AgentAction.BUY_M, AgentAction.BUY_M,

View File

@ -698,6 +698,7 @@ class BacktestEngine:
action_override: Optional[AgentAction] = None, action_override: Optional[AgentAction] = None,
target_weight_override: Optional[float] = None, target_weight_override: Optional[float] = None,
) -> None: ) -> None:
reason_str = str(reason)
payload = { payload = {
"trade_date": trade_date_str, "trade_date": trade_date_str,
"ts_code": ts_code, "ts_code": ts_code,
@ -708,21 +709,43 @@ class BacktestEngine:
else decision.target_weight else decision.target_weight
), ),
"confidence": decision.confidence, "confidence": decision.confidence,
"reason": reason, "reason": reason_str,
} }
if extra: if extra:
payload.update(extra) payload.update(extra)
risk_events.append(payload) risk_events.append(payload)
risk_meta = payload.get("risk_assessment") if isinstance(payload.get("risk_assessment"), dict) else extra.get("risk_assessment") if extra else None risk_meta = None
status = None if isinstance(payload.get("risk_assessment"), dict):
risk_meta = payload.get("risk_assessment")
elif extra and isinstance(extra.get("risk_assessment"), dict):
risk_meta = extra.get("risk_assessment")
status: Optional[str] = None
if isinstance(risk_meta, dict): if isinstance(risk_meta, dict):
status = risk_meta.get("status") status = str(risk_meta.get("status") or "")
payload.setdefault("risk_status", status)
if status == "blocked": if status == "blocked":
try: try:
message = f"{ts_code} 风险阻断: {reason_str}"
alerts.add_warning( alerts.add_warning(
"backtest_risk", "backtest_risk",
f"{ts_code} 风险阻断: {reason}", message,
detail=json.dumps(payload, ensure_ascii=False), detail=json.dumps(payload, ensure_ascii=False),
level="error",
tags=["risk", reason_str, status],
payload=payload,
)
except Exception: # noqa: BLE001
LOGGER.debug("记录风险告警失败", extra=LOG_EXTRA)
elif status and status != "ok":
try:
message = f"{ts_code} 风险提示: {reason_str}"
alerts.add_warning(
"backtest_risk",
message,
detail=json.dumps(payload, ensure_ascii=False),
level="warning",
tags=["risk", reason_str, status],
payload=payload,
) )
except Exception: # noqa: BLE001 except Exception: # noqa: BLE001
LOGGER.debug("记录风险告警失败", extra=LOG_EXTRA) LOGGER.debug("记录风险告警失败", extra=LOG_EXTRA)

View File

@ -0,0 +1,169 @@
"""Dispatch structured alerts to external channels."""
from __future__ import annotations
import json
import logging
import threading
import time
from typing import TYPE_CHECKING, Any, Dict, Iterable, Mapping, MutableMapping, Optional, Sequence
import requests
LOGGER = logging.getLogger(__name__)
if TYPE_CHECKING: # pragma: no cover
from app.utils.config import AlertChannelSettings
else:
AlertChannelSettings = Any # type: ignore[assignment]
_LEVEL_RANK: Dict[str, int] = {
"debug": 10,
"info": 20,
"warning": 30,
"error": 40,
"critical": 50,
}
class _Channel:
"""Runtime wrapper around a configured alert channel."""
__slots__ = (
"settings",
"_lock",
"_last_signature",
"_last_sent",
)
def __init__(self, settings: AlertChannelSettings) -> None:
self.settings = settings
self._lock = threading.Lock()
self._last_signature: Optional[str] = None
self._last_sent: float = 0.0
@property
def name(self) -> str:
return getattr(self.settings, "key", "channel")
def send(self, entry: Mapping[str, Any]) -> None:
if not self._should_send(entry):
return
payload = self._build_payload(entry)
self._deliver(payload)
def _should_send(self, entry: Mapping[str, Any]) -> bool:
level = str(entry.get("level", "warning") or "warning").lower()
level_rank = _LEVEL_RANK.get(level, _LEVEL_RANK["warning"])
threshold = str(getattr(self.settings, "level", "warning") or "warning").lower()
threshold_rank = _LEVEL_RANK.get(threshold, _LEVEL_RANK["warning"])
if level_rank < threshold_rank:
return False
channel_tags: Sequence[str] = getattr(self.settings, "tags", []) or []
if channel_tags:
event_tags = entry.get("tags") or []
if not isinstance(event_tags, Iterable):
event_tags = []
if not set(str(tag) for tag in event_tags).intersection(str(tag) for tag in channel_tags):
return False
cooldown = float(getattr(self.settings, "cooldown_seconds", 0.0) or 0.0)
signature = f"{entry.get('source')}|{entry.get('message')}|{level}"
if cooldown > 0:
now = time.monotonic()
with self._lock:
if self._last_signature == signature and (now - self._last_sent) < cooldown:
return False
self._last_signature = signature
self._last_sent = now
return True
def _build_payload(self, entry: Mapping[str, Any]) -> Dict[str, Any]:
payload: Dict[str, Any] = {
"source": entry.get("source"),
"message": entry.get("message"),
"detail": entry.get("detail"),
"timestamp": entry.get("timestamp"),
"level": entry.get("level"),
}
tags = entry.get("tags")
if isinstance(tags, Iterable) and not isinstance(tags, (str, bytes)):
payload["tags"] = list(tags)
if "payload" in entry and entry["payload"] is not None:
payload["payload"] = entry["payload"]
extra_params = getattr(self.settings, "extra_params", None)
if isinstance(extra_params, Mapping):
for key, value in extra_params.items():
payload.setdefault(key, value)
return payload
def _deliver(self, payload: Mapping[str, Any]) -> None:
url = getattr(self.settings, "url", "")
if not url:
return
method = str(getattr(self.settings, "method", "POST") or "POST").upper()
timeout = float(getattr(self.settings, "timeout", 3.0) or 3.0)
headers: MutableMapping[str, str] = {}
raw_headers = getattr(self.settings, "headers", None)
if isinstance(raw_headers, Mapping):
headers = {str(k): str(v) for k, v in raw_headers.items()}
headers.setdefault("Content-Type", "application/json")
body = json.dumps(payload, ensure_ascii=False)
signing_secret = getattr(self.settings, "signing_secret", None)
if signing_secret:
import hashlib
import hmac
digest = hmac.new(
str(signing_secret).encode("utf-8"),
body.encode("utf-8"),
hashlib.sha256,
).hexdigest()
headers.setdefault("X-Signature", digest)
try:
requests.request(
method=method,
url=url,
data=body,
headers=headers,
timeout=timeout,
)
except Exception: # noqa: BLE001
LOGGER.exception("发送告警失败: channel=%s", self.name)
class AlertDispatcher:
"""Singleton-style dispatcher coordinating channel delivery."""
def __init__(self) -> None:
self._channels: Dict[str, _Channel] = {}
self._lock = threading.Lock()
def configure(self, configs: Mapping[str, AlertChannelSettings]) -> None:
active: Dict[str, _Channel] = {}
for key, cfg in configs.items():
if not cfg or not getattr(cfg, "enabled", True):
continue
if not getattr(cfg, "url", ""):
continue
channel = _Channel(cfg)
active[key] = channel
with self._lock:
self._channels = active
def dispatch(self, entry: Mapping[str, Any]) -> None:
if not self._channels:
return
for channel in list(self._channels.values()):
channel.send(entry)
_DISPATCHER = AlertDispatcher()
def get_dispatcher() -> AlertDispatcher:
return _DISPATCHER

View File

@ -1,46 +1,124 @@
"""Runtime data warning registry for surfacing ingestion issues in UI.""" """Runtime data warning registry with external dispatch support."""
from __future__ import annotations from __future__ import annotations
import logging
from datetime import datetime from datetime import datetime
from threading import Lock from threading import Lock
from typing import Dict, List, Optional from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Sequence
from .alert_dispatcher import get_dispatcher
_ALERTS: List[Dict[str, str]] = [] if TYPE_CHECKING: # pragma: no cover
from app.utils.config import AlertChannelSettings
LOGGER = logging.getLogger(__name__)
_ALERTS: List[Dict[str, Any]] = []
_SINKS: List[Callable[[Dict[str, Any]], None]] = []
_LOCK = Lock() _LOCK = Lock()
_MAX_ALERTS = 50
def add_warning(source: str, message: str, detail: Optional[str] = None) -> None: def configure_channels(channels: Mapping[str, "AlertChannelSettings"]) -> None:
"""Register or update a warning entry.""" """Configure external dispatch channels."""
source = source.strip() or "unknown" try:
message = message.strip() or "发生未知异常" get_dispatcher().configure(channels)
except Exception: # noqa: BLE001
LOGGER.debug("配置外部告警通道失败", exc_info=True)
def register_sink(sink: Callable[[Dict[str, Any]], None]) -> None:
"""Attach an additional sink to receive alert payloads."""
with _LOCK:
if sink not in _SINKS:
_SINKS.append(sink)
def unregister_sink(sink: Callable[[Dict[str, Any]], None]) -> None:
"""Detach a previously registered sink."""
with _LOCK:
_SINKS[:] = [existing for existing in _SINKS if existing is not sink]
def add_warning(
source: str,
message: str,
detail: Optional[str] = None,
*,
level: str = "warning",
tags: Optional[Sequence[str]] = None,
payload: Optional[Mapping[str, Any]] = None,
) -> None:
"""Register or update a warning entry and dispatch to sinks."""
source = (source or "").strip() or "unknown"
message = (message or "").strip() or "发生未知异常"
normalized_level = str(level or "warning").lower()
timestamp = datetime.utcnow().isoformat(timespec="seconds") + "Z" timestamp = datetime.utcnow().isoformat(timespec="seconds") + "Z"
normalized_tags: List[str] = []
if tags:
normalized_tags = [
str(tag).strip()
for tag in tags
if isinstance(tag, str) and tag.strip()
]
snapshot: Dict[str, Any] = {}
sinks: List[Callable[[Dict[str, Any]], None]] = []
with _LOCK: with _LOCK:
for alert in _ALERTS: for alert in _ALERTS:
if alert["source"] == source and alert["message"] == message: if alert["source"] == source and alert["message"] == message:
alert["timestamp"] = timestamp alert["timestamp"] = timestamp
alert["level"] = normalized_level
if detail: if detail:
alert["detail"] = detail alert["detail"] = detail
return if normalized_tags:
entry = { alert["tags"] = list(normalized_tags)
if payload is not None:
alert["payload"] = dict(payload) if isinstance(payload, Mapping) else payload
snapshot = dict(alert)
sinks = list(_SINKS)
break
else:
entry: Dict[str, Any] = {
"source": source, "source": source,
"message": message, "message": message,
"timestamp": timestamp, "timestamp": timestamp,
"level": normalized_level,
} }
if detail: if detail:
entry["detail"] = detail entry["detail"] = detail
if normalized_tags:
entry["tags"] = list(normalized_tags)
if payload is not None:
entry["payload"] = dict(payload) if isinstance(payload, Mapping) else payload
_ALERTS.append(entry) _ALERTS.append(entry)
if len(_ALERTS) > 50: if len(_ALERTS) > _MAX_ALERTS:
del _ALERTS[:-50] del _ALERTS[:-_MAX_ALERTS]
snapshot = dict(entry)
sinks = list(_SINKS)
for sink in sinks:
try:
sink(dict(snapshot))
except Exception: # noqa: BLE001
LOGGER.debug("执行告警 sink 失败:%s", getattr(sink, "__name__", sink), exc_info=True)
try:
get_dispatcher().dispatch(snapshot)
except Exception: # noqa: BLE001
LOGGER.debug("外部告警发送失败 source=%s", source, exc_info=True)
def get_warnings() -> List[Dict[str, str]]: def get_warnings() -> List[Dict[str, Any]]:
"""Return a copy of current warning entries.""" """Return a copy of current warning entries."""
with _LOCK: with _LOCK:
return list(_ALERTS) return [dict(alert) for alert in _ALERTS]
def clear_warnings(source: Optional[str] = None) -> None: def clear_warnings(source: Optional[str] = None) -> None:
@ -50,6 +128,5 @@ def clear_warnings(source: Optional[str] = None) -> None:
if source is None: if source is None:
_ALERTS.clear() _ALERTS.clear()
return return
source = source.strip() source_key = source.strip()
_ALERTS[:] = [alert for alert in _ALERTS if alert["source"] != source] _ALERTS[:] = [alert for alert in _ALERTS if alert["source"] != source_key]

View File

@ -58,6 +58,43 @@ class PortfolioSettings:
max_sector_exposure: float = 0.35 # 行业敞口上限 35% max_sector_exposure: float = 0.35 # 行业敞口上限 35%
@dataclass
class AlertChannelSettings:
"""Configuration for external alert delivery channels."""
key: str
kind: str = "webhook"
url: str = ""
enabled: bool = True
level: str = "warning"
tags: List[str] = field(default_factory=list)
headers: Dict[str, str] = field(default_factory=dict)
timeout: float = 3.0
method: str = "POST"
template: str = ""
signing_secret: Optional[str] = None
cooldown_seconds: float = 0.0
extra_params: Dict[str, object] = field(default_factory=dict)
def to_dict(self) -> Dict[str, object]:
payload: Dict[str, object] = {
"kind": self.kind,
"url": self.url,
"enabled": self.enabled,
"level": self.level,
"tags": list(self.tags),
"headers": dict(self.headers),
"timeout": self.timeout,
"method": self.method,
"template": self.template,
"cooldown_seconds": self.cooldown_seconds,
"extra_params": dict(self.extra_params),
}
if self.signing_secret:
payload["signing_secret"] = self.signing_secret
return payload
@dataclass @dataclass
class AgentWeights: class AgentWeights:
"""Default weighting for decision agents.""" """Default weighting for decision agents."""
@ -528,6 +565,7 @@ class AppConfig:
llm_cost: LLMCostSettings = field(default_factory=LLMCostSettings) llm_cost: LLMCostSettings = field(default_factory=LLMCostSettings)
departments: Dict[str, DepartmentSettings] = field(default_factory=_default_departments) departments: Dict[str, DepartmentSettings] = field(default_factory=_default_departments)
portfolio: PortfolioSettings = field(default_factory=PortfolioSettings) portfolio: PortfolioSettings = field(default_factory=PortfolioSettings)
alert_channels: Dict[str, AlertChannelSettings] = field(default_factory=dict)
def resolve_llm(self, route: Optional[str] = None) -> LLMConfig: def resolve_llm(self, route: Optional[str] = None) -> LLMConfig:
return self.llm return self.llm
@ -634,6 +672,52 @@ def _load_from_file(cfg: AppConfig) -> None:
) )
cfg.portfolio = updated_portfolio cfg.portfolio = updated_portfolio
alert_channels_payload = payload.get("alert_channels")
if isinstance(alert_channels_payload, dict):
channels: Dict[str, AlertChannelSettings] = {}
for key, data in alert_channels_payload.items():
if not isinstance(data, dict):
continue
normalized_key = str(key)
raw_tags = data.get("tags")
tags: List[str] = []
if isinstance(raw_tags, list):
tags = [
str(tag).strip()
for tag in raw_tags
if isinstance(tag, str) and tag.strip()
]
headers: Dict[str, str] = {}
raw_headers = data.get("headers")
if isinstance(raw_headers, Mapping):
headers = {
str(h_key): str(h_val)
for h_key, h_val in raw_headers.items()
if h_key is not None
}
extra_params: Dict[str, object] = {}
raw_extra = data.get("extra_params")
if isinstance(raw_extra, Mapping):
extra_params = dict(raw_extra)
channel = AlertChannelSettings(
key=normalized_key,
kind=str(data.get("kind") or "webhook"),
url=str(data.get("url") or ""),
enabled=bool(data.get("enabled", True)),
level=str(data.get("level") or "warning"),
tags=tags,
headers=headers,
timeout=float(data.get("timeout", 3.0) or 3.0),
method=str(data.get("method") or "POST").upper(),
template=str(data.get("template") or ""),
signing_secret=str(data.get("signing_secret")) if data.get("signing_secret") else None,
cooldown_seconds=float(data.get("cooldown_seconds", 0.0) or 0.0),
extra_params=extra_params,
)
if channel.url:
channels[channel.key] = channel
cfg.alert_channels = channels
cost_payload = payload.get("llm_cost") cost_payload = payload.get("llm_cost")
if isinstance(cost_payload, dict): if isinstance(cost_payload, dict):
cfg.llm_cost.update_from_dict(cost_payload) cfg.llm_cost.update_from_dict(cost_payload)
@ -865,6 +949,10 @@ def save_config(cfg: AppConfig | None = None) -> None:
"max_sector_exposure": cfg.portfolio.max_sector_exposure, "max_sector_exposure": cfg.portfolio.max_sector_exposure,
}, },
}, },
"alert_channels": {
name: channel.to_dict()
for name, channel in cfg.alert_channels.items()
},
"llm": { "llm": {
"strategy": cfg.llm.strategy if cfg.llm.strategy in ALLOWED_LLM_STRATEGIES else "single", "strategy": cfg.llm.strategy if cfg.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
"majority_threshold": cfg.llm.majority_threshold, "majority_threshold": cfg.llm.majority_threshold,
@ -918,6 +1006,13 @@ def save_config(cfg: AppConfig | None = None) -> None:
LOGGER.info("配置已写入:%s", path) LOGGER.info("配置已写入:%s", path)
except OSError: except OSError:
LOGGER.exception("配置写入失败:%s", path) LOGGER.exception("配置写入失败:%s", path)
return
try:
from app.utils import alerts as _alerts # 延迟导入以避免循环
_alerts.configure_channels(cfg.alert_channels)
except Exception: # noqa: BLE001
LOGGER.debug("更新告警通道失败", exc_info=True)
def _load_env_defaults(cfg: AppConfig) -> None: def _load_env_defaults(cfg: AppConfig) -> None:
@ -935,12 +1030,34 @@ def _load_env_defaults(cfg: AppConfig) -> None:
if provider_cfg: if provider_cfg:
provider_cfg.api_key = sanitized provider_cfg.api_key = sanitized
webhook = os.getenv("LLM_QUANT_ALERT_WEBHOOK")
if webhook:
key = "env_webhook"
channel = AlertChannelSettings(
key=key,
kind="webhook",
url=webhook.strip(),
headers={"Content-Type": "application/json"},
enabled=True,
level=str(os.getenv("LLM_QUANT_ALERT_LEVEL", "warning") or "warning"),
)
tags_raw = os.getenv("LLM_QUANT_ALERT_TAGS")
if tags_raw:
channel.tags = [tag.strip() for tag in tags_raw.split(",") if tag.strip()]
cfg.alert_channels[key] = channel
cfg.sync_runtime_llm() cfg.sync_runtime_llm()
_load_from_file(CONFIG) _load_from_file(CONFIG)
_load_env_defaults(CONFIG) _load_env_defaults(CONFIG)
try:
from app.utils import alerts as _alerts_module # 延迟导入避免循环依赖
_alerts_module.configure_channels(CONFIG.alert_channels)
except Exception: # noqa: BLE001
LOGGER.debug("初始化告警通道失败", exc_info=True)
def get_config() -> AppConfig: def get_config() -> AppConfig:
"""Return a mutable global configuration instance.""" """Return a mutable global configuration instance."""

View File

@ -50,9 +50,9 @@
| 工作项 | 状态 | 说明 | | 工作项 | 状态 | 说明 |
| --- | --- | --- | | --- | --- | --- |
| 风险代理决策闭环 | ✅ | `risk_round` 支持按场景回写 `risk_assessment`、触发人手/自动兜底策略,并已接入决策追踪报表。 | | 风险代理决策闭环 | ✅ | `risk_round` 支持按场景回写 `risk_assessment`、触发人手/自动兜底策略,并已接入决策追踪报表。 |
| 风险事件持久化 | 🔄 | 已形成 `risk_round``risk_events` 的事件映射草案;下个迭代补齐 ORM/批量落库、事件去重、以及风险面板的逐条 Drill-down 展示。 | | 风险事件持久化 | ✅ | `risk_round` 事件现已批量写入 `bt_risk_events`,附带风险状态/元数据并在回测面板支持 Drill-down。 |
| 实时告警接入 | ⏳ | 需对接外部告警渠道,支撑影子运行与上线验证。 | | 实时告警接入 | ✅ | 引入告警分发器支持配置化 webhook 通道,风险阻断/复核即时推送至外部渠道。 |
| 风险场景测试 | ⏳ | 补充停牌、仓位超限、黑名单等自动化测试样例。 | | 风险场景测试 | ✅ | 新增停牌、仓位超限、黑名单等集成测试覆盖,验证风险闭环执行。 |
## 测试与质量保障 ## 测试与质量保障

View File

@ -0,0 +1,70 @@
"""Tests for alert dispatcher configuration and delivery."""
from __future__ import annotations
import json
from app.utils import alerts
from app.utils.config import AlertChannelSettings
def test_alert_dispatcher_posts_payload(monkeypatch):
calls: list[dict[str, object]] = []
def fake_request(*, method, url, data=None, headers=None, timeout=None):
calls.append(
{
"method": method,
"url": url,
"data": data,
"headers": headers,
"timeout": timeout,
}
)
class _Resp:
status_code = 200
return _Resp()
monkeypatch.setattr("app.utils.alert_dispatcher.requests.request", fake_request)
alerts.clear_warnings()
alerts.configure_channels(
{
"ops": AlertChannelSettings(
key="ops",
kind="webhook",
url="https://example.com/webhook",
enabled=True,
level="info",
headers={"X-Test": "1"},
extra_params={"channel": "risk"},
)
}
)
alerts.add_warning(
"risk_system",
"阻断测试",
detail="blocked",
level="error",
tags=["risk", "blocked"],
payload={"reason": "blocked"},
)
assert calls, "expected dispatcher to send webhook call"
call = calls[0]
assert call["method"] == "POST"
assert call["url"] == "https://example.com/webhook"
assert call["headers"]["X-Test"] == "1"
payload = json.loads(call["data"])
assert payload["message"] == "阻断测试"
assert payload["channel"] == "risk"
assert payload["payload"]["reason"] == "blocked"
warnings = alerts.get_warnings()
assert warnings
assert warnings[0]["level"] == "error"
assert "blocked" in warnings[0].get("tags", [])
alerts.configure_channels({})

View File

@ -8,7 +8,7 @@ import pytest
import json import json
from app.agents.base import AgentAction, AgentContext from app.agents.base import AgentAction, AgentContext
from app.agents.game import Decision from app.agents.game import Decision, RiskAssessment
from app.backtest.engine import ( from app.backtest.engine import (
BacktestEngine, BacktestEngine,
BacktestResult, BacktestResult,
@ -18,6 +18,7 @@ from app.backtest.engine import (
) )
from app.data.schema import initialize_database from app.data.schema import initialize_database
from app.utils.config import DataPaths, get_config from app.utils.config import DataPaths, get_config
from app.utils import alerts
from app.utils.db import db_session from app.utils.db import db_session
@ -121,6 +122,79 @@ def test_buy_blocked_by_limit_up_records_risk():
assert result.risk_events[0]["reason"] == "limit_up" assert result.risk_events[0]["reason"] == "limit_up"
def test_position_limit_triggers_risk_event_and_adjusts_execution():
alerts.clear_warnings()
engine = _engine_with_params(
{
"max_position_weight": 0.3,
"fee_rate": 0.0,
"slippage_bps": 0.0,
"max_daily_turnover_ratio": 1.0,
}
)
state = PortfolioState(cash=100_000.0)
result = BacktestResult()
context = _make_context(100.0, {"position_limit": True})
decision = _make_decision(AgentAction.BUY_L, target_weight=0.4)
decision.risk_assessment = RiskAssessment(
status="pending_review",
reason="position_limit",
recommended_action=AgentAction.BUY_S,
notes={"trigger": "position_limit"},
)
engine._apply_portfolio_updates(
date(2025, 1, 10),
state,
[("000001.SZ", context, decision)],
result,
)
assert not result.trades, "position limit should block execution despite adjustment"
assert result.risk_events
event_with_status = next(
(event for event in result.risk_events if event.get("risk_status")),
None,
)
assert event_with_status is not None
assert event_with_status["reason"] == "position_limit"
assert event_with_status.get("risk_status") == "pending_review"
warning_messages = [item["message"] for item in alerts.get_warnings()]
assert any("风险提示" in msg for msg in warning_messages)
alerts.clear_warnings()
def test_blacklist_blocks_execution_and_warns():
alerts.clear_warnings()
engine = _engine_with_params({})
state = PortfolioState(cash=50_000.0)
result = BacktestResult()
context = _make_context(100.0, {"is_blacklisted": True})
decision = _make_decision(AgentAction.BUY_M, target_weight=0.2)
decision.risk_assessment = RiskAssessment(
status="blocked",
reason="blacklist",
recommended_action=AgentAction.HOLD,
notes={"trigger": "is_blacklisted"},
)
engine._apply_portfolio_updates(
date(2025, 1, 10),
state,
[("000001.SZ", context, decision)],
result,
)
assert not result.trades
assert result.risk_events
event = result.risk_events[0]
assert event["reason"] == "blacklist"
assert event.get("risk_status") == "blocked"
warning_messages = [item["message"] for item in alerts.get_warnings()]
assert any("风险阻断" in msg for msg in warning_messages)
alerts.clear_warnings()
def test_sell_applies_slippage_and_fee(): def test_sell_applies_slippage_and_fee():
engine = _engine_with_params( engine = _engine_with_params(
{ {

View File

@ -117,3 +117,26 @@ def test_backtest_engine_applies_risk_adjusted_execution(monkeypatch):
assert not state.holdings assert not state.holdings
assert not result.trades assert not result.trades
assert result.nav_series[0]["nav"] == pytest.approx(100_000.0) assert result.nav_series[0]["nav"] == pytest.approx(100_000.0)
def test_decide_records_suspension_risk_round():
agents = default_agents()
context = _make_context({"is_suspended": True})
decision = decide(
context,
agents,
weights={agent.name: 1.0 for agent in agents},
department_manager=None,
)
assert decision.requires_review is True
assert decision.risk_assessment
assert decision.risk_assessment.status == "blocked"
assert decision.risk_assessment.reason == "suspended"
risk_rounds = [summary for summary in decision.rounds if summary.agenda == "risk_review"]
assert risk_rounds
notes = risk_rounds[0].notes
assert notes.get("status") == "blocked"
assert notes.get("reason") == "suspended"

View File

@ -41,6 +41,23 @@ def test_risk_agent_pending_on_conflict() -> None:
assert recommendation.reason == "conflict_threshold" assert recommendation.reason == "conflict_threshold"
def test_risk_agent_blocks_on_blacklist() -> None:
agent = RiskAgent()
context = _make_context(is_blacklisted=True)
recommendation = agent.assess(context, AgentAction.BUY_M, conflict_flag=False)
assert recommendation.status == "blocked"
assert recommendation.reason == "blacklist"
assert recommendation.recommended_action == AgentAction.HOLD
def test_blacklist_constraints_buy_feasibility() -> None:
agent = RiskAgent()
context = _make_context(is_blacklisted=True)
assert agent.feasible(context, AgentAction.SELL) is True
assert agent.feasible(context, AgentAction.HOLD) is True
assert agent.feasible(context, AgentAction.BUY_S) is False
def test_evaluate_risk_external_alerts() -> None: def test_evaluate_risk_external_alerts() -> None:
agent = RiskAgent() agent = RiskAgent()
context = AgentContext( context = AgentContext(