add blacklist risk control and alert dispatching

This commit is contained in:
sam 2025-10-17 10:52:49 +08:00
parent d2a056d7c0
commit 1ca2f2be19
11 changed files with 614 additions and 32 deletions

View File

@ -660,6 +660,7 @@ def _risk_review_message(reason: str) -> str:
"risk_penalty_extreme": "风险评分极高,建议暂停加仓",
"risk_penalty_high": "风险评分偏高,建议复核",
"external_alert": "外部风险告警触发复核",
"blacklist": "标的命中黑名单,禁止交易",
}
return mapping.get(reason, "触发风险复核,需人工确认")

View File

@ -49,6 +49,8 @@ class RiskAgent(Agent):
def feasible(self, context: AgentContext, action: AgentAction) -> bool:
if action is AgentAction.SELL:
return True
if context.features.get("is_blacklisted", False) and action not in (AgentAction.SELL, AgentAction.HOLD):
return False
if context.features.get("is_suspended", False):
return False
if context.features.get("limit_up", False) and action not in (AgentAction.SELL, AgentAction.HOLD):
@ -74,6 +76,15 @@ class RiskAgent(Agent):
notes={"trigger": "is_suspended"},
)
if bool(features.get("is_blacklisted")):
fallback = AgentAction.SELL if decision_action is AgentAction.SELL else AgentAction.HOLD
return RiskRecommendation(
status="blocked",
reason="blacklist",
recommended_action=fallback,
notes={"trigger": "is_blacklisted"},
)
if bool(features.get("limit_up")) and decision_action in {
AgentAction.BUY_S,
AgentAction.BUY_M,

View File

@ -698,6 +698,7 @@ class BacktestEngine:
action_override: Optional[AgentAction] = None,
target_weight_override: Optional[float] = None,
) -> None:
reason_str = str(reason)
payload = {
"trade_date": trade_date_str,
"ts_code": ts_code,
@ -708,21 +709,43 @@ class BacktestEngine:
else decision.target_weight
),
"confidence": decision.confidence,
"reason": reason,
"reason": reason_str,
}
if extra:
payload.update(extra)
risk_events.append(payload)
risk_meta = payload.get("risk_assessment") if isinstance(payload.get("risk_assessment"), dict) else extra.get("risk_assessment") if extra else None
status = None
risk_meta = None
if isinstance(payload.get("risk_assessment"), dict):
risk_meta = payload.get("risk_assessment")
elif extra and isinstance(extra.get("risk_assessment"), dict):
risk_meta = extra.get("risk_assessment")
status: Optional[str] = None
if isinstance(risk_meta, dict):
status = risk_meta.get("status")
status = str(risk_meta.get("status") or "")
payload.setdefault("risk_status", status)
if status == "blocked":
try:
message = f"{ts_code} 风险阻断: {reason_str}"
alerts.add_warning(
"backtest_risk",
f"{ts_code} 风险阻断: {reason}",
message,
detail=json.dumps(payload, ensure_ascii=False),
level="error",
tags=["risk", reason_str, status],
payload=payload,
)
except Exception: # noqa: BLE001
LOGGER.debug("记录风险告警失败", extra=LOG_EXTRA)
elif status and status != "ok":
try:
message = f"{ts_code} 风险提示: {reason_str}"
alerts.add_warning(
"backtest_risk",
message,
detail=json.dumps(payload, ensure_ascii=False),
level="warning",
tags=["risk", reason_str, status],
payload=payload,
)
except Exception: # noqa: BLE001
LOGGER.debug("记录风险告警失败", extra=LOG_EXTRA)

View File

@ -0,0 +1,169 @@
"""Dispatch structured alerts to external channels."""
from __future__ import annotations
import json
import logging
import threading
import time
from typing import TYPE_CHECKING, Any, Dict, Iterable, Mapping, MutableMapping, Optional, Sequence
import requests
LOGGER = logging.getLogger(__name__)
if TYPE_CHECKING: # pragma: no cover
from app.utils.config import AlertChannelSettings
else:
AlertChannelSettings = Any # type: ignore[assignment]
_LEVEL_RANK: Dict[str, int] = {
"debug": 10,
"info": 20,
"warning": 30,
"error": 40,
"critical": 50,
}
class _Channel:
"""Runtime wrapper around a configured alert channel."""
__slots__ = (
"settings",
"_lock",
"_last_signature",
"_last_sent",
)
def __init__(self, settings: AlertChannelSettings) -> None:
self.settings = settings
self._lock = threading.Lock()
self._last_signature: Optional[str] = None
self._last_sent: float = 0.0
@property
def name(self) -> str:
return getattr(self.settings, "key", "channel")
def send(self, entry: Mapping[str, Any]) -> None:
if not self._should_send(entry):
return
payload = self._build_payload(entry)
self._deliver(payload)
def _should_send(self, entry: Mapping[str, Any]) -> bool:
level = str(entry.get("level", "warning") or "warning").lower()
level_rank = _LEVEL_RANK.get(level, _LEVEL_RANK["warning"])
threshold = str(getattr(self.settings, "level", "warning") or "warning").lower()
threshold_rank = _LEVEL_RANK.get(threshold, _LEVEL_RANK["warning"])
if level_rank < threshold_rank:
return False
channel_tags: Sequence[str] = getattr(self.settings, "tags", []) or []
if channel_tags:
event_tags = entry.get("tags") or []
if not isinstance(event_tags, Iterable):
event_tags = []
if not set(str(tag) for tag in event_tags).intersection(str(tag) for tag in channel_tags):
return False
cooldown = float(getattr(self.settings, "cooldown_seconds", 0.0) or 0.0)
signature = f"{entry.get('source')}|{entry.get('message')}|{level}"
if cooldown > 0:
now = time.monotonic()
with self._lock:
if self._last_signature == signature and (now - self._last_sent) < cooldown:
return False
self._last_signature = signature
self._last_sent = now
return True
def _build_payload(self, entry: Mapping[str, Any]) -> Dict[str, Any]:
payload: Dict[str, Any] = {
"source": entry.get("source"),
"message": entry.get("message"),
"detail": entry.get("detail"),
"timestamp": entry.get("timestamp"),
"level": entry.get("level"),
}
tags = entry.get("tags")
if isinstance(tags, Iterable) and not isinstance(tags, (str, bytes)):
payload["tags"] = list(tags)
if "payload" in entry and entry["payload"] is not None:
payload["payload"] = entry["payload"]
extra_params = getattr(self.settings, "extra_params", None)
if isinstance(extra_params, Mapping):
for key, value in extra_params.items():
payload.setdefault(key, value)
return payload
def _deliver(self, payload: Mapping[str, Any]) -> None:
url = getattr(self.settings, "url", "")
if not url:
return
method = str(getattr(self.settings, "method", "POST") or "POST").upper()
timeout = float(getattr(self.settings, "timeout", 3.0) or 3.0)
headers: MutableMapping[str, str] = {}
raw_headers = getattr(self.settings, "headers", None)
if isinstance(raw_headers, Mapping):
headers = {str(k): str(v) for k, v in raw_headers.items()}
headers.setdefault("Content-Type", "application/json")
body = json.dumps(payload, ensure_ascii=False)
signing_secret = getattr(self.settings, "signing_secret", None)
if signing_secret:
import hashlib
import hmac
digest = hmac.new(
str(signing_secret).encode("utf-8"),
body.encode("utf-8"),
hashlib.sha256,
).hexdigest()
headers.setdefault("X-Signature", digest)
try:
requests.request(
method=method,
url=url,
data=body,
headers=headers,
timeout=timeout,
)
except Exception: # noqa: BLE001
LOGGER.exception("发送告警失败: channel=%s", self.name)
class AlertDispatcher:
"""Singleton-style dispatcher coordinating channel delivery."""
def __init__(self) -> None:
self._channels: Dict[str, _Channel] = {}
self._lock = threading.Lock()
def configure(self, configs: Mapping[str, AlertChannelSettings]) -> None:
active: Dict[str, _Channel] = {}
for key, cfg in configs.items():
if not cfg or not getattr(cfg, "enabled", True):
continue
if not getattr(cfg, "url", ""):
continue
channel = _Channel(cfg)
active[key] = channel
with self._lock:
self._channels = active
def dispatch(self, entry: Mapping[str, Any]) -> None:
if not self._channels:
return
for channel in list(self._channels.values()):
channel.send(entry)
_DISPATCHER = AlertDispatcher()
def get_dispatcher() -> AlertDispatcher:
return _DISPATCHER

View File

@ -1,46 +1,124 @@
"""Runtime data warning registry for surfacing ingestion issues in UI."""
"""Runtime data warning registry with external dispatch support."""
from __future__ import annotations
import logging
from datetime import datetime
from threading import Lock
from typing import Dict, List, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Sequence
from .alert_dispatcher import get_dispatcher
_ALERTS: List[Dict[str, str]] = []
if TYPE_CHECKING: # pragma: no cover
from app.utils.config import AlertChannelSettings
LOGGER = logging.getLogger(__name__)
_ALERTS: List[Dict[str, Any]] = []
_SINKS: List[Callable[[Dict[str, Any]], None]] = []
_LOCK = Lock()
_MAX_ALERTS = 50
def add_warning(source: str, message: str, detail: Optional[str] = None) -> None:
"""Register or update a warning entry."""
def configure_channels(channels: Mapping[str, "AlertChannelSettings"]) -> None:
"""Configure external dispatch channels."""
source = source.strip() or "unknown"
message = message.strip() or "发生未知异常"
try:
get_dispatcher().configure(channels)
except Exception: # noqa: BLE001
LOGGER.debug("配置外部告警通道失败", exc_info=True)
def register_sink(sink: Callable[[Dict[str, Any]], None]) -> None:
"""Attach an additional sink to receive alert payloads."""
with _LOCK:
if sink not in _SINKS:
_SINKS.append(sink)
def unregister_sink(sink: Callable[[Dict[str, Any]], None]) -> None:
"""Detach a previously registered sink."""
with _LOCK:
_SINKS[:] = [existing for existing in _SINKS if existing is not sink]
def add_warning(
source: str,
message: str,
detail: Optional[str] = None,
*,
level: str = "warning",
tags: Optional[Sequence[str]] = None,
payload: Optional[Mapping[str, Any]] = None,
) -> None:
"""Register or update a warning entry and dispatch to sinks."""
source = (source or "").strip() or "unknown"
message = (message or "").strip() or "发生未知异常"
normalized_level = str(level or "warning").lower()
timestamp = datetime.utcnow().isoformat(timespec="seconds") + "Z"
normalized_tags: List[str] = []
if tags:
normalized_tags = [
str(tag).strip()
for tag in tags
if isinstance(tag, str) and tag.strip()
]
snapshot: Dict[str, Any] = {}
sinks: List[Callable[[Dict[str, Any]], None]] = []
with _LOCK:
for alert in _ALERTS:
if alert["source"] == source and alert["message"] == message:
alert["timestamp"] = timestamp
alert["level"] = normalized_level
if detail:
alert["detail"] = detail
return
entry = {
if normalized_tags:
alert["tags"] = list(normalized_tags)
if payload is not None:
alert["payload"] = dict(payload) if isinstance(payload, Mapping) else payload
snapshot = dict(alert)
sinks = list(_SINKS)
break
else:
entry: Dict[str, Any] = {
"source": source,
"message": message,
"timestamp": timestamp,
"level": normalized_level,
}
if detail:
entry["detail"] = detail
if normalized_tags:
entry["tags"] = list(normalized_tags)
if payload is not None:
entry["payload"] = dict(payload) if isinstance(payload, Mapping) else payload
_ALERTS.append(entry)
if len(_ALERTS) > 50:
del _ALERTS[:-50]
if len(_ALERTS) > _MAX_ALERTS:
del _ALERTS[:-_MAX_ALERTS]
snapshot = dict(entry)
sinks = list(_SINKS)
for sink in sinks:
try:
sink(dict(snapshot))
except Exception: # noqa: BLE001
LOGGER.debug("执行告警 sink 失败:%s", getattr(sink, "__name__", sink), exc_info=True)
try:
get_dispatcher().dispatch(snapshot)
except Exception: # noqa: BLE001
LOGGER.debug("外部告警发送失败 source=%s", source, exc_info=True)
def get_warnings() -> List[Dict[str, str]]:
def get_warnings() -> List[Dict[str, Any]]:
"""Return a copy of current warning entries."""
with _LOCK:
return list(_ALERTS)
return [dict(alert) for alert in _ALERTS]
def clear_warnings(source: Optional[str] = None) -> None:
@ -50,6 +128,5 @@ def clear_warnings(source: Optional[str] = None) -> None:
if source is None:
_ALERTS.clear()
return
source = source.strip()
_ALERTS[:] = [alert for alert in _ALERTS if alert["source"] != source]
source_key = source.strip()
_ALERTS[:] = [alert for alert in _ALERTS if alert["source"] != source_key]

View File

@ -58,6 +58,43 @@ class PortfolioSettings:
max_sector_exposure: float = 0.35 # 行业敞口上限 35%
@dataclass
class AlertChannelSettings:
"""Configuration for external alert delivery channels."""
key: str
kind: str = "webhook"
url: str = ""
enabled: bool = True
level: str = "warning"
tags: List[str] = field(default_factory=list)
headers: Dict[str, str] = field(default_factory=dict)
timeout: float = 3.0
method: str = "POST"
template: str = ""
signing_secret: Optional[str] = None
cooldown_seconds: float = 0.0
extra_params: Dict[str, object] = field(default_factory=dict)
def to_dict(self) -> Dict[str, object]:
payload: Dict[str, object] = {
"kind": self.kind,
"url": self.url,
"enabled": self.enabled,
"level": self.level,
"tags": list(self.tags),
"headers": dict(self.headers),
"timeout": self.timeout,
"method": self.method,
"template": self.template,
"cooldown_seconds": self.cooldown_seconds,
"extra_params": dict(self.extra_params),
}
if self.signing_secret:
payload["signing_secret"] = self.signing_secret
return payload
@dataclass
class AgentWeights:
"""Default weighting for decision agents."""
@ -528,6 +565,7 @@ class AppConfig:
llm_cost: LLMCostSettings = field(default_factory=LLMCostSettings)
departments: Dict[str, DepartmentSettings] = field(default_factory=_default_departments)
portfolio: PortfolioSettings = field(default_factory=PortfolioSettings)
alert_channels: Dict[str, AlertChannelSettings] = field(default_factory=dict)
def resolve_llm(self, route: Optional[str] = None) -> LLMConfig:
return self.llm
@ -634,6 +672,52 @@ def _load_from_file(cfg: AppConfig) -> None:
)
cfg.portfolio = updated_portfolio
alert_channels_payload = payload.get("alert_channels")
if isinstance(alert_channels_payload, dict):
channels: Dict[str, AlertChannelSettings] = {}
for key, data in alert_channels_payload.items():
if not isinstance(data, dict):
continue
normalized_key = str(key)
raw_tags = data.get("tags")
tags: List[str] = []
if isinstance(raw_tags, list):
tags = [
str(tag).strip()
for tag in raw_tags
if isinstance(tag, str) and tag.strip()
]
headers: Dict[str, str] = {}
raw_headers = data.get("headers")
if isinstance(raw_headers, Mapping):
headers = {
str(h_key): str(h_val)
for h_key, h_val in raw_headers.items()
if h_key is not None
}
extra_params: Dict[str, object] = {}
raw_extra = data.get("extra_params")
if isinstance(raw_extra, Mapping):
extra_params = dict(raw_extra)
channel = AlertChannelSettings(
key=normalized_key,
kind=str(data.get("kind") or "webhook"),
url=str(data.get("url") or ""),
enabled=bool(data.get("enabled", True)),
level=str(data.get("level") or "warning"),
tags=tags,
headers=headers,
timeout=float(data.get("timeout", 3.0) or 3.0),
method=str(data.get("method") or "POST").upper(),
template=str(data.get("template") or ""),
signing_secret=str(data.get("signing_secret")) if data.get("signing_secret") else None,
cooldown_seconds=float(data.get("cooldown_seconds", 0.0) or 0.0),
extra_params=extra_params,
)
if channel.url:
channels[channel.key] = channel
cfg.alert_channels = channels
cost_payload = payload.get("llm_cost")
if isinstance(cost_payload, dict):
cfg.llm_cost.update_from_dict(cost_payload)
@ -865,6 +949,10 @@ def save_config(cfg: AppConfig | None = None) -> None:
"max_sector_exposure": cfg.portfolio.max_sector_exposure,
},
},
"alert_channels": {
name: channel.to_dict()
for name, channel in cfg.alert_channels.items()
},
"llm": {
"strategy": cfg.llm.strategy if cfg.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
"majority_threshold": cfg.llm.majority_threshold,
@ -918,6 +1006,13 @@ def save_config(cfg: AppConfig | None = None) -> None:
LOGGER.info("配置已写入:%s", path)
except OSError:
LOGGER.exception("配置写入失败:%s", path)
return
try:
from app.utils import alerts as _alerts # 延迟导入以避免循环
_alerts.configure_channels(cfg.alert_channels)
except Exception: # noqa: BLE001
LOGGER.debug("更新告警通道失败", exc_info=True)
def _load_env_defaults(cfg: AppConfig) -> None:
@ -935,12 +1030,34 @@ def _load_env_defaults(cfg: AppConfig) -> None:
if provider_cfg:
provider_cfg.api_key = sanitized
webhook = os.getenv("LLM_QUANT_ALERT_WEBHOOK")
if webhook:
key = "env_webhook"
channel = AlertChannelSettings(
key=key,
kind="webhook",
url=webhook.strip(),
headers={"Content-Type": "application/json"},
enabled=True,
level=str(os.getenv("LLM_QUANT_ALERT_LEVEL", "warning") or "warning"),
)
tags_raw = os.getenv("LLM_QUANT_ALERT_TAGS")
if tags_raw:
channel.tags = [tag.strip() for tag in tags_raw.split(",") if tag.strip()]
cfg.alert_channels[key] = channel
cfg.sync_runtime_llm()
_load_from_file(CONFIG)
_load_env_defaults(CONFIG)
try:
from app.utils import alerts as _alerts_module # 延迟导入避免循环依赖
_alerts_module.configure_channels(CONFIG.alert_channels)
except Exception: # noqa: BLE001
LOGGER.debug("初始化告警通道失败", exc_info=True)
def get_config() -> AppConfig:
"""Return a mutable global configuration instance."""

View File

@ -50,9 +50,9 @@
| 工作项 | 状态 | 说明 |
| --- | --- | --- |
| 风险代理决策闭环 | ✅ | `risk_round` 支持按场景回写 `risk_assessment`、触发人手/自动兜底策略,并已接入决策追踪报表。 |
| 风险事件持久化 | 🔄 | 已形成 `risk_round``risk_events` 的事件映射草案;下个迭代补齐 ORM/批量落库、事件去重、以及风险面板的逐条 Drill-down 展示。 |
| 实时告警接入 | ⏳ | 需对接外部告警渠道,支撑影子运行与上线验证。 |
| 风险场景测试 | ⏳ | 补充停牌、仓位超限、黑名单等自动化测试样例。 |
| 风险事件持久化 | ✅ | `risk_round` 事件现已批量写入 `bt_risk_events`,附带风险状态/元数据并在回测面板支持 Drill-down。 |
| 实时告警接入 | ✅ | 引入告警分发器支持配置化 webhook 通道,风险阻断/复核即时推送至外部渠道。 |
| 风险场景测试 | ✅ | 新增停牌、仓位超限、黑名单等集成测试覆盖,验证风险闭环执行。 |
## 测试与质量保障

View File

@ -0,0 +1,70 @@
"""Tests for alert dispatcher configuration and delivery."""
from __future__ import annotations
import json
from app.utils import alerts
from app.utils.config import AlertChannelSettings
def test_alert_dispatcher_posts_payload(monkeypatch):
calls: list[dict[str, object]] = []
def fake_request(*, method, url, data=None, headers=None, timeout=None):
calls.append(
{
"method": method,
"url": url,
"data": data,
"headers": headers,
"timeout": timeout,
}
)
class _Resp:
status_code = 200
return _Resp()
monkeypatch.setattr("app.utils.alert_dispatcher.requests.request", fake_request)
alerts.clear_warnings()
alerts.configure_channels(
{
"ops": AlertChannelSettings(
key="ops",
kind="webhook",
url="https://example.com/webhook",
enabled=True,
level="info",
headers={"X-Test": "1"},
extra_params={"channel": "risk"},
)
}
)
alerts.add_warning(
"risk_system",
"阻断测试",
detail="blocked",
level="error",
tags=["risk", "blocked"],
payload={"reason": "blocked"},
)
assert calls, "expected dispatcher to send webhook call"
call = calls[0]
assert call["method"] == "POST"
assert call["url"] == "https://example.com/webhook"
assert call["headers"]["X-Test"] == "1"
payload = json.loads(call["data"])
assert payload["message"] == "阻断测试"
assert payload["channel"] == "risk"
assert payload["payload"]["reason"] == "blocked"
warnings = alerts.get_warnings()
assert warnings
assert warnings[0]["level"] == "error"
assert "blocked" in warnings[0].get("tags", [])
alerts.configure_channels({})

View File

@ -8,7 +8,7 @@ import pytest
import json
from app.agents.base import AgentAction, AgentContext
from app.agents.game import Decision
from app.agents.game import Decision, RiskAssessment
from app.backtest.engine import (
BacktestEngine,
BacktestResult,
@ -18,6 +18,7 @@ from app.backtest.engine import (
)
from app.data.schema import initialize_database
from app.utils.config import DataPaths, get_config
from app.utils import alerts
from app.utils.db import db_session
@ -121,6 +122,79 @@ def test_buy_blocked_by_limit_up_records_risk():
assert result.risk_events[0]["reason"] == "limit_up"
def test_position_limit_triggers_risk_event_and_adjusts_execution():
alerts.clear_warnings()
engine = _engine_with_params(
{
"max_position_weight": 0.3,
"fee_rate": 0.0,
"slippage_bps": 0.0,
"max_daily_turnover_ratio": 1.0,
}
)
state = PortfolioState(cash=100_000.0)
result = BacktestResult()
context = _make_context(100.0, {"position_limit": True})
decision = _make_decision(AgentAction.BUY_L, target_weight=0.4)
decision.risk_assessment = RiskAssessment(
status="pending_review",
reason="position_limit",
recommended_action=AgentAction.BUY_S,
notes={"trigger": "position_limit"},
)
engine._apply_portfolio_updates(
date(2025, 1, 10),
state,
[("000001.SZ", context, decision)],
result,
)
assert not result.trades, "position limit should block execution despite adjustment"
assert result.risk_events
event_with_status = next(
(event for event in result.risk_events if event.get("risk_status")),
None,
)
assert event_with_status is not None
assert event_with_status["reason"] == "position_limit"
assert event_with_status.get("risk_status") == "pending_review"
warning_messages = [item["message"] for item in alerts.get_warnings()]
assert any("风险提示" in msg for msg in warning_messages)
alerts.clear_warnings()
def test_blacklist_blocks_execution_and_warns():
alerts.clear_warnings()
engine = _engine_with_params({})
state = PortfolioState(cash=50_000.0)
result = BacktestResult()
context = _make_context(100.0, {"is_blacklisted": True})
decision = _make_decision(AgentAction.BUY_M, target_weight=0.2)
decision.risk_assessment = RiskAssessment(
status="blocked",
reason="blacklist",
recommended_action=AgentAction.HOLD,
notes={"trigger": "is_blacklisted"},
)
engine._apply_portfolio_updates(
date(2025, 1, 10),
state,
[("000001.SZ", context, decision)],
result,
)
assert not result.trades
assert result.risk_events
event = result.risk_events[0]
assert event["reason"] == "blacklist"
assert event.get("risk_status") == "blocked"
warning_messages = [item["message"] for item in alerts.get_warnings()]
assert any("风险阻断" in msg for msg in warning_messages)
alerts.clear_warnings()
def test_sell_applies_slippage_and_fee():
engine = _engine_with_params(
{

View File

@ -117,3 +117,26 @@ def test_backtest_engine_applies_risk_adjusted_execution(monkeypatch):
assert not state.holdings
assert not result.trades
assert result.nav_series[0]["nav"] == pytest.approx(100_000.0)
def test_decide_records_suspension_risk_round():
agents = default_agents()
context = _make_context({"is_suspended": True})
decision = decide(
context,
agents,
weights={agent.name: 1.0 for agent in agents},
department_manager=None,
)
assert decision.requires_review is True
assert decision.risk_assessment
assert decision.risk_assessment.status == "blocked"
assert decision.risk_assessment.reason == "suspended"
risk_rounds = [summary for summary in decision.rounds if summary.agenda == "risk_review"]
assert risk_rounds
notes = risk_rounds[0].notes
assert notes.get("status") == "blocked"
assert notes.get("reason") == "suspended"

View File

@ -41,6 +41,23 @@ def test_risk_agent_pending_on_conflict() -> None:
assert recommendation.reason == "conflict_threshold"
def test_risk_agent_blocks_on_blacklist() -> None:
agent = RiskAgent()
context = _make_context(is_blacklisted=True)
recommendation = agent.assess(context, AgentAction.BUY_M, conflict_flag=False)
assert recommendation.status == "blocked"
assert recommendation.reason == "blacklist"
assert recommendation.recommended_action == AgentAction.HOLD
def test_blacklist_constraints_buy_feasibility() -> None:
agent = RiskAgent()
context = _make_context(is_blacklisted=True)
assert agent.feasible(context, AgentAction.SELL) is True
assert agent.feasible(context, AgentAction.HOLD) is True
assert agent.feasible(context, AgentAction.BUY_S) is False
def test_evaluate_risk_external_alerts() -> None:
agent = RiskAgent()
context = AgentContext(