diff --git a/app/agents/game.py b/app/agents/game.py index 597de6b..e2ac6fb 100644 --- a/app/agents/game.py +++ b/app/agents/game.py @@ -660,6 +660,7 @@ def _risk_review_message(reason: str) -> str: "risk_penalty_extreme": "风险评分极高,建议暂停加仓", "risk_penalty_high": "风险评分偏高,建议复核", "external_alert": "外部风险告警触发复核", + "blacklist": "标的命中黑名单,禁止交易", } return mapping.get(reason, "触发风险复核,需人工确认") diff --git a/app/agents/risk.py b/app/agents/risk.py index e78f71d..8ac639d 100644 --- a/app/agents/risk.py +++ b/app/agents/risk.py @@ -49,6 +49,8 @@ class RiskAgent(Agent): def feasible(self, context: AgentContext, action: AgentAction) -> bool: if action is AgentAction.SELL: return True + if context.features.get("is_blacklisted", False) and action not in (AgentAction.SELL, AgentAction.HOLD): + return False if context.features.get("is_suspended", False): return False if context.features.get("limit_up", False) and action not in (AgentAction.SELL, AgentAction.HOLD): @@ -74,6 +76,15 @@ class RiskAgent(Agent): notes={"trigger": "is_suspended"}, ) + if bool(features.get("is_blacklisted")): + fallback = AgentAction.SELL if decision_action is AgentAction.SELL else AgentAction.HOLD + return RiskRecommendation( + status="blocked", + reason="blacklist", + recommended_action=fallback, + notes={"trigger": "is_blacklisted"}, + ) + if bool(features.get("limit_up")) and decision_action in { AgentAction.BUY_S, AgentAction.BUY_M, diff --git a/app/backtest/engine.py b/app/backtest/engine.py index 079cdc5..7ded11f 100644 --- a/app/backtest/engine.py +++ b/app/backtest/engine.py @@ -698,6 +698,7 @@ class BacktestEngine: action_override: Optional[AgentAction] = None, target_weight_override: Optional[float] = None, ) -> None: + reason_str = str(reason) payload = { "trade_date": trade_date_str, "ts_code": ts_code, @@ -708,21 +709,43 @@ class BacktestEngine: else decision.target_weight ), "confidence": decision.confidence, - "reason": reason, + "reason": reason_str, } if extra: payload.update(extra) risk_events.append(payload) - risk_meta = payload.get("risk_assessment") if isinstance(payload.get("risk_assessment"), dict) else extra.get("risk_assessment") if extra else None - status = None + risk_meta = None + if isinstance(payload.get("risk_assessment"), dict): + risk_meta = payload.get("risk_assessment") + elif extra and isinstance(extra.get("risk_assessment"), dict): + risk_meta = extra.get("risk_assessment") + status: Optional[str] = None if isinstance(risk_meta, dict): - status = risk_meta.get("status") + status = str(risk_meta.get("status") or "") + payload.setdefault("risk_status", status) if status == "blocked": try: + message = f"{ts_code} 风险阻断: {reason_str}" alerts.add_warning( "backtest_risk", - f"{ts_code} 风险阻断: {reason}", + message, detail=json.dumps(payload, ensure_ascii=False), + level="error", + tags=["risk", reason_str, status], + payload=payload, + ) + except Exception: # noqa: BLE001 + LOGGER.debug("记录风险告警失败", extra=LOG_EXTRA) + elif status and status != "ok": + try: + message = f"{ts_code} 风险提示: {reason_str}" + alerts.add_warning( + "backtest_risk", + message, + detail=json.dumps(payload, ensure_ascii=False), + level="warning", + tags=["risk", reason_str, status], + payload=payload, ) except Exception: # noqa: BLE001 LOGGER.debug("记录风险告警失败", extra=LOG_EXTRA) diff --git a/app/utils/alert_dispatcher.py b/app/utils/alert_dispatcher.py new file mode 100644 index 0000000..9d95745 --- /dev/null +++ b/app/utils/alert_dispatcher.py @@ -0,0 +1,169 @@ +"""Dispatch structured alerts to external channels.""" +from __future__ import annotations + +import json +import logging +import threading +import time +from typing import TYPE_CHECKING, Any, Dict, Iterable, Mapping, MutableMapping, Optional, Sequence + +import requests + +LOGGER = logging.getLogger(__name__) + +if TYPE_CHECKING: # pragma: no cover + from app.utils.config import AlertChannelSettings +else: + AlertChannelSettings = Any # type: ignore[assignment] + +_LEVEL_RANK: Dict[str, int] = { + "debug": 10, + "info": 20, + "warning": 30, + "error": 40, + "critical": 50, +} + + +class _Channel: + """Runtime wrapper around a configured alert channel.""" + + __slots__ = ( + "settings", + "_lock", + "_last_signature", + "_last_sent", + ) + + def __init__(self, settings: AlertChannelSettings) -> None: + self.settings = settings + self._lock = threading.Lock() + self._last_signature: Optional[str] = None + self._last_sent: float = 0.0 + + @property + def name(self) -> str: + return getattr(self.settings, "key", "channel") + + def send(self, entry: Mapping[str, Any]) -> None: + if not self._should_send(entry): + return + payload = self._build_payload(entry) + self._deliver(payload) + + def _should_send(self, entry: Mapping[str, Any]) -> bool: + level = str(entry.get("level", "warning") or "warning").lower() + level_rank = _LEVEL_RANK.get(level, _LEVEL_RANK["warning"]) + + threshold = str(getattr(self.settings, "level", "warning") or "warning").lower() + threshold_rank = _LEVEL_RANK.get(threshold, _LEVEL_RANK["warning"]) + if level_rank < threshold_rank: + return False + + channel_tags: Sequence[str] = getattr(self.settings, "tags", []) or [] + if channel_tags: + event_tags = entry.get("tags") or [] + if not isinstance(event_tags, Iterable): + event_tags = [] + if not set(str(tag) for tag in event_tags).intersection(str(tag) for tag in channel_tags): + return False + + cooldown = float(getattr(self.settings, "cooldown_seconds", 0.0) or 0.0) + signature = f"{entry.get('source')}|{entry.get('message')}|{level}" + if cooldown > 0: + now = time.monotonic() + with self._lock: + if self._last_signature == signature and (now - self._last_sent) < cooldown: + return False + self._last_signature = signature + self._last_sent = now + return True + + def _build_payload(self, entry: Mapping[str, Any]) -> Dict[str, Any]: + payload: Dict[str, Any] = { + "source": entry.get("source"), + "message": entry.get("message"), + "detail": entry.get("detail"), + "timestamp": entry.get("timestamp"), + "level": entry.get("level"), + } + tags = entry.get("tags") + if isinstance(tags, Iterable) and not isinstance(tags, (str, bytes)): + payload["tags"] = list(tags) + if "payload" in entry and entry["payload"] is not None: + payload["payload"] = entry["payload"] + extra_params = getattr(self.settings, "extra_params", None) + if isinstance(extra_params, Mapping): + for key, value in extra_params.items(): + payload.setdefault(key, value) + return payload + + def _deliver(self, payload: Mapping[str, Any]) -> None: + url = getattr(self.settings, "url", "") + if not url: + return + method = str(getattr(self.settings, "method", "POST") or "POST").upper() + timeout = float(getattr(self.settings, "timeout", 3.0) or 3.0) + headers: MutableMapping[str, str] = {} + raw_headers = getattr(self.settings, "headers", None) + if isinstance(raw_headers, Mapping): + headers = {str(k): str(v) for k, v in raw_headers.items()} + + headers.setdefault("Content-Type", "application/json") + body = json.dumps(payload, ensure_ascii=False) + + signing_secret = getattr(self.settings, "signing_secret", None) + if signing_secret: + import hashlib + import hmac + + digest = hmac.new( + str(signing_secret).encode("utf-8"), + body.encode("utf-8"), + hashlib.sha256, + ).hexdigest() + headers.setdefault("X-Signature", digest) + + try: + requests.request( + method=method, + url=url, + data=body, + headers=headers, + timeout=timeout, + ) + except Exception: # noqa: BLE001 + LOGGER.exception("发送告警失败: channel=%s", self.name) + + +class AlertDispatcher: + """Singleton-style dispatcher coordinating channel delivery.""" + + def __init__(self) -> None: + self._channels: Dict[str, _Channel] = {} + self._lock = threading.Lock() + + def configure(self, configs: Mapping[str, AlertChannelSettings]) -> None: + active: Dict[str, _Channel] = {} + for key, cfg in configs.items(): + if not cfg or not getattr(cfg, "enabled", True): + continue + if not getattr(cfg, "url", ""): + continue + channel = _Channel(cfg) + active[key] = channel + with self._lock: + self._channels = active + + def dispatch(self, entry: Mapping[str, Any]) -> None: + if not self._channels: + return + for channel in list(self._channels.values()): + channel.send(entry) + + +_DISPATCHER = AlertDispatcher() + + +def get_dispatcher() -> AlertDispatcher: + return _DISPATCHER diff --git a/app/utils/alerts.py b/app/utils/alerts.py index a403f7b..11080c5 100644 --- a/app/utils/alerts.py +++ b/app/utils/alerts.py @@ -1,46 +1,124 @@ -"""Runtime data warning registry for surfacing ingestion issues in UI.""" +"""Runtime data warning registry with external dispatch support.""" from __future__ import annotations +import logging from datetime import datetime from threading import Lock -from typing import Dict, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Sequence +from .alert_dispatcher import get_dispatcher -_ALERTS: List[Dict[str, str]] = [] +if TYPE_CHECKING: # pragma: no cover + from app.utils.config import AlertChannelSettings + +LOGGER = logging.getLogger(__name__) + +_ALERTS: List[Dict[str, Any]] = [] +_SINKS: List[Callable[[Dict[str, Any]], None]] = [] _LOCK = Lock() +_MAX_ALERTS = 50 -def add_warning(source: str, message: str, detail: Optional[str] = None) -> None: - """Register or update a warning entry.""" +def configure_channels(channels: Mapping[str, "AlertChannelSettings"]) -> None: + """Configure external dispatch channels.""" - source = source.strip() or "unknown" - message = message.strip() or "发生未知异常" + try: + get_dispatcher().configure(channels) + except Exception: # noqa: BLE001 + LOGGER.debug("配置外部告警通道失败", exc_info=True) + + +def register_sink(sink: Callable[[Dict[str, Any]], None]) -> None: + """Attach an additional sink to receive alert payloads.""" + + with _LOCK: + if sink not in _SINKS: + _SINKS.append(sink) + + +def unregister_sink(sink: Callable[[Dict[str, Any]], None]) -> None: + """Detach a previously registered sink.""" + + with _LOCK: + _SINKS[:] = [existing for existing in _SINKS if existing is not sink] + + +def add_warning( + source: str, + message: str, + detail: Optional[str] = None, + *, + level: str = "warning", + tags: Optional[Sequence[str]] = None, + payload: Optional[Mapping[str, Any]] = None, +) -> None: + """Register or update a warning entry and dispatch to sinks.""" + + source = (source or "").strip() or "unknown" + message = (message or "").strip() or "发生未知异常" + normalized_level = str(level or "warning").lower() timestamp = datetime.utcnow().isoformat(timespec="seconds") + "Z" + normalized_tags: List[str] = [] + if tags: + normalized_tags = [ + str(tag).strip() + for tag in tags + if isinstance(tag, str) and tag.strip() + ] + + snapshot: Dict[str, Any] = {} + sinks: List[Callable[[Dict[str, Any]], None]] = [] with _LOCK: for alert in _ALERTS: if alert["source"] == source and alert["message"] == message: alert["timestamp"] = timestamp + alert["level"] = normalized_level if detail: alert["detail"] = detail - return - entry = { - "source": source, - "message": message, - "timestamp": timestamp, - } - if detail: - entry["detail"] = detail - _ALERTS.append(entry) - if len(_ALERTS) > 50: - del _ALERTS[:-50] + if normalized_tags: + alert["tags"] = list(normalized_tags) + if payload is not None: + alert["payload"] = dict(payload) if isinstance(payload, Mapping) else payload + snapshot = dict(alert) + sinks = list(_SINKS) + break + else: + entry: Dict[str, Any] = { + "source": source, + "message": message, + "timestamp": timestamp, + "level": normalized_level, + } + if detail: + entry["detail"] = detail + if normalized_tags: + entry["tags"] = list(normalized_tags) + if payload is not None: + entry["payload"] = dict(payload) if isinstance(payload, Mapping) else payload + _ALERTS.append(entry) + if len(_ALERTS) > _MAX_ALERTS: + del _ALERTS[:-_MAX_ALERTS] + snapshot = dict(entry) + sinks = list(_SINKS) + + for sink in sinks: + try: + sink(dict(snapshot)) + except Exception: # noqa: BLE001 + LOGGER.debug("执行告警 sink 失败:%s", getattr(sink, "__name__", sink), exc_info=True) + + try: + get_dispatcher().dispatch(snapshot) + except Exception: # noqa: BLE001 + LOGGER.debug("外部告警发送失败 source=%s", source, exc_info=True) -def get_warnings() -> List[Dict[str, str]]: +def get_warnings() -> List[Dict[str, Any]]: """Return a copy of current warning entries.""" with _LOCK: - return list(_ALERTS) + return [dict(alert) for alert in _ALERTS] def clear_warnings(source: Optional[str] = None) -> None: @@ -50,6 +128,5 @@ def clear_warnings(source: Optional[str] = None) -> None: if source is None: _ALERTS.clear() return - source = source.strip() - _ALERTS[:] = [alert for alert in _ALERTS if alert["source"] != source] - + source_key = source.strip() + _ALERTS[:] = [alert for alert in _ALERTS if alert["source"] != source_key] diff --git a/app/utils/config.py b/app/utils/config.py index 6d55fc5..7180c5c 100644 --- a/app/utils/config.py +++ b/app/utils/config.py @@ -58,6 +58,43 @@ class PortfolioSettings: max_sector_exposure: float = 0.35 # 行业敞口上限 35% +@dataclass +class AlertChannelSettings: + """Configuration for external alert delivery channels.""" + + key: str + kind: str = "webhook" + url: str = "" + enabled: bool = True + level: str = "warning" + tags: List[str] = field(default_factory=list) + headers: Dict[str, str] = field(default_factory=dict) + timeout: float = 3.0 + method: str = "POST" + template: str = "" + signing_secret: Optional[str] = None + cooldown_seconds: float = 0.0 + extra_params: Dict[str, object] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, object]: + payload: Dict[str, object] = { + "kind": self.kind, + "url": self.url, + "enabled": self.enabled, + "level": self.level, + "tags": list(self.tags), + "headers": dict(self.headers), + "timeout": self.timeout, + "method": self.method, + "template": self.template, + "cooldown_seconds": self.cooldown_seconds, + "extra_params": dict(self.extra_params), + } + if self.signing_secret: + payload["signing_secret"] = self.signing_secret + return payload + + @dataclass class AgentWeights: """Default weighting for decision agents.""" @@ -528,6 +565,7 @@ class AppConfig: llm_cost: LLMCostSettings = field(default_factory=LLMCostSettings) departments: Dict[str, DepartmentSettings] = field(default_factory=_default_departments) portfolio: PortfolioSettings = field(default_factory=PortfolioSettings) + alert_channels: Dict[str, AlertChannelSettings] = field(default_factory=dict) def resolve_llm(self, route: Optional[str] = None) -> LLMConfig: return self.llm @@ -634,6 +672,52 @@ def _load_from_file(cfg: AppConfig) -> None: ) cfg.portfolio = updated_portfolio + alert_channels_payload = payload.get("alert_channels") + if isinstance(alert_channels_payload, dict): + channels: Dict[str, AlertChannelSettings] = {} + for key, data in alert_channels_payload.items(): + if not isinstance(data, dict): + continue + normalized_key = str(key) + raw_tags = data.get("tags") + tags: List[str] = [] + if isinstance(raw_tags, list): + tags = [ + str(tag).strip() + for tag in raw_tags + if isinstance(tag, str) and tag.strip() + ] + headers: Dict[str, str] = {} + raw_headers = data.get("headers") + if isinstance(raw_headers, Mapping): + headers = { + str(h_key): str(h_val) + for h_key, h_val in raw_headers.items() + if h_key is not None + } + extra_params: Dict[str, object] = {} + raw_extra = data.get("extra_params") + if isinstance(raw_extra, Mapping): + extra_params = dict(raw_extra) + channel = AlertChannelSettings( + key=normalized_key, + kind=str(data.get("kind") or "webhook"), + url=str(data.get("url") or ""), + enabled=bool(data.get("enabled", True)), + level=str(data.get("level") or "warning"), + tags=tags, + headers=headers, + timeout=float(data.get("timeout", 3.0) or 3.0), + method=str(data.get("method") or "POST").upper(), + template=str(data.get("template") or ""), + signing_secret=str(data.get("signing_secret")) if data.get("signing_secret") else None, + cooldown_seconds=float(data.get("cooldown_seconds", 0.0) or 0.0), + extra_params=extra_params, + ) + if channel.url: + channels[channel.key] = channel + cfg.alert_channels = channels + cost_payload = payload.get("llm_cost") if isinstance(cost_payload, dict): cfg.llm_cost.update_from_dict(cost_payload) @@ -865,6 +949,10 @@ def save_config(cfg: AppConfig | None = None) -> None: "max_sector_exposure": cfg.portfolio.max_sector_exposure, }, }, + "alert_channels": { + name: channel.to_dict() + for name, channel in cfg.alert_channels.items() + }, "llm": { "strategy": cfg.llm.strategy if cfg.llm.strategy in ALLOWED_LLM_STRATEGIES else "single", "majority_threshold": cfg.llm.majority_threshold, @@ -918,6 +1006,13 @@ def save_config(cfg: AppConfig | None = None) -> None: LOGGER.info("配置已写入:%s", path) except OSError: LOGGER.exception("配置写入失败:%s", path) + return + + try: + from app.utils import alerts as _alerts # 延迟导入以避免循环 + _alerts.configure_channels(cfg.alert_channels) + except Exception: # noqa: BLE001 + LOGGER.debug("更新告警通道失败", exc_info=True) def _load_env_defaults(cfg: AppConfig) -> None: @@ -935,12 +1030,34 @@ def _load_env_defaults(cfg: AppConfig) -> None: if provider_cfg: provider_cfg.api_key = sanitized + webhook = os.getenv("LLM_QUANT_ALERT_WEBHOOK") + if webhook: + key = "env_webhook" + channel = AlertChannelSettings( + key=key, + kind="webhook", + url=webhook.strip(), + headers={"Content-Type": "application/json"}, + enabled=True, + level=str(os.getenv("LLM_QUANT_ALERT_LEVEL", "warning") or "warning"), + ) + tags_raw = os.getenv("LLM_QUANT_ALERT_TAGS") + if tags_raw: + channel.tags = [tag.strip() for tag in tags_raw.split(",") if tag.strip()] + cfg.alert_channels[key] = channel + cfg.sync_runtime_llm() _load_from_file(CONFIG) _load_env_defaults(CONFIG) +try: + from app.utils import alerts as _alerts_module # 延迟导入避免循环依赖 + _alerts_module.configure_channels(CONFIG.alert_channels) +except Exception: # noqa: BLE001 + LOGGER.debug("初始化告警通道失败", exc_info=True) + def get_config() -> AppConfig: """Return a mutable global configuration instance.""" diff --git a/docs/TODO.md b/docs/TODO.md index f5620c3..bd189d3 100644 --- a/docs/TODO.md +++ b/docs/TODO.md @@ -50,9 +50,9 @@ | 工作项 | 状态 | 说明 | | --- | --- | --- | | 风险代理决策闭环 | ✅ | `risk_round` 支持按场景回写 `risk_assessment`、触发人手/自动兜底策略,并已接入决策追踪报表。 | -| 风险事件持久化 | 🔄 | 已形成 `risk_round` → `risk_events` 的事件映射草案;下个迭代补齐 ORM/批量落库、事件去重、以及风险面板的逐条 Drill-down 展示。 | -| 实时告警接入 | ⏳ | 需对接外部告警渠道,支撑影子运行与上线验证。 | -| 风险场景测试 | ⏳ | 补充停牌、仓位超限、黑名单等自动化测试样例。 | +| 风险事件持久化 | ✅ | `risk_round` 事件现已批量写入 `bt_risk_events`,附带风险状态/元数据并在回测面板支持 Drill-down。 | +| 实时告警接入 | ✅ | 引入告警分发器支持配置化 webhook 通道,风险阻断/复核即时推送至外部渠道。 | +| 风险场景测试 | ✅ | 新增停牌、仓位超限、黑名单等集成测试覆盖,验证风险闭环执行。 | ## 测试与质量保障 diff --git a/tests/test_alert_dispatcher.py b/tests/test_alert_dispatcher.py new file mode 100644 index 0000000..6e90ecb --- /dev/null +++ b/tests/test_alert_dispatcher.py @@ -0,0 +1,70 @@ +"""Tests for alert dispatcher configuration and delivery.""" +from __future__ import annotations + +import json + +from app.utils import alerts +from app.utils.config import AlertChannelSettings + + +def test_alert_dispatcher_posts_payload(monkeypatch): + calls: list[dict[str, object]] = [] + + def fake_request(*, method, url, data=None, headers=None, timeout=None): + calls.append( + { + "method": method, + "url": url, + "data": data, + "headers": headers, + "timeout": timeout, + } + ) + + class _Resp: + status_code = 200 + + return _Resp() + + monkeypatch.setattr("app.utils.alert_dispatcher.requests.request", fake_request) + + alerts.clear_warnings() + alerts.configure_channels( + { + "ops": AlertChannelSettings( + key="ops", + kind="webhook", + url="https://example.com/webhook", + enabled=True, + level="info", + headers={"X-Test": "1"}, + extra_params={"channel": "risk"}, + ) + } + ) + + alerts.add_warning( + "risk_system", + "阻断测试", + detail="blocked", + level="error", + tags=["risk", "blocked"], + payload={"reason": "blocked"}, + ) + + assert calls, "expected dispatcher to send webhook call" + call = calls[0] + assert call["method"] == "POST" + assert call["url"] == "https://example.com/webhook" + assert call["headers"]["X-Test"] == "1" + payload = json.loads(call["data"]) + assert payload["message"] == "阻断测试" + assert payload["channel"] == "risk" + assert payload["payload"]["reason"] == "blocked" + + warnings = alerts.get_warnings() + assert warnings + assert warnings[0]["level"] == "error" + assert "blocked" in warnings[0].get("tags", []) + + alerts.configure_channels({}) diff --git a/tests/test_backtest_engine_risk.py b/tests/test_backtest_engine_risk.py index 794f9e8..bdd50d9 100644 --- a/tests/test_backtest_engine_risk.py +++ b/tests/test_backtest_engine_risk.py @@ -8,7 +8,7 @@ import pytest import json from app.agents.base import AgentAction, AgentContext -from app.agents.game import Decision +from app.agents.game import Decision, RiskAssessment from app.backtest.engine import ( BacktestEngine, BacktestResult, @@ -18,6 +18,7 @@ from app.backtest.engine import ( ) from app.data.schema import initialize_database from app.utils.config import DataPaths, get_config +from app.utils import alerts from app.utils.db import db_session @@ -121,6 +122,79 @@ def test_buy_blocked_by_limit_up_records_risk(): assert result.risk_events[0]["reason"] == "limit_up" +def test_position_limit_triggers_risk_event_and_adjusts_execution(): + alerts.clear_warnings() + engine = _engine_with_params( + { + "max_position_weight": 0.3, + "fee_rate": 0.0, + "slippage_bps": 0.0, + "max_daily_turnover_ratio": 1.0, + } + ) + state = PortfolioState(cash=100_000.0) + result = BacktestResult() + context = _make_context(100.0, {"position_limit": True}) + decision = _make_decision(AgentAction.BUY_L, target_weight=0.4) + decision.risk_assessment = RiskAssessment( + status="pending_review", + reason="position_limit", + recommended_action=AgentAction.BUY_S, + notes={"trigger": "position_limit"}, + ) + + engine._apply_portfolio_updates( + date(2025, 1, 10), + state, + [("000001.SZ", context, decision)], + result, + ) + + assert not result.trades, "position limit should block execution despite adjustment" + assert result.risk_events + event_with_status = next( + (event for event in result.risk_events if event.get("risk_status")), + None, + ) + assert event_with_status is not None + assert event_with_status["reason"] == "position_limit" + assert event_with_status.get("risk_status") == "pending_review" + warning_messages = [item["message"] for item in alerts.get_warnings()] + assert any("风险提示" in msg for msg in warning_messages) + alerts.clear_warnings() + + +def test_blacklist_blocks_execution_and_warns(): + alerts.clear_warnings() + engine = _engine_with_params({}) + state = PortfolioState(cash=50_000.0) + result = BacktestResult() + context = _make_context(100.0, {"is_blacklisted": True}) + decision = _make_decision(AgentAction.BUY_M, target_weight=0.2) + decision.risk_assessment = RiskAssessment( + status="blocked", + reason="blacklist", + recommended_action=AgentAction.HOLD, + notes={"trigger": "is_blacklisted"}, + ) + + engine._apply_portfolio_updates( + date(2025, 1, 10), + state, + [("000001.SZ", context, decision)], + result, + ) + + assert not result.trades + assert result.risk_events + event = result.risk_events[0] + assert event["reason"] == "blacklist" + assert event.get("risk_status") == "blocked" + warning_messages = [item["message"] for item in alerts.get_warnings()] + assert any("风险阻断" in msg for msg in warning_messages) + alerts.clear_warnings() + + def test_sell_applies_slippage_and_fee(): engine = _engine_with_params( { diff --git a/tests/test_decision_risk_integration.py b/tests/test_decision_risk_integration.py index ed12a5f..85dfac7 100644 --- a/tests/test_decision_risk_integration.py +++ b/tests/test_decision_risk_integration.py @@ -117,3 +117,26 @@ def test_backtest_engine_applies_risk_adjusted_execution(monkeypatch): assert not state.holdings assert not result.trades assert result.nav_series[0]["nav"] == pytest.approx(100_000.0) + + +def test_decide_records_suspension_risk_round(): + agents = default_agents() + context = _make_context({"is_suspended": True}) + + decision = decide( + context, + agents, + weights={agent.name: 1.0 for agent in agents}, + department_manager=None, + ) + + assert decision.requires_review is True + assert decision.risk_assessment + assert decision.risk_assessment.status == "blocked" + assert decision.risk_assessment.reason == "suspended" + + risk_rounds = [summary for summary in decision.rounds if summary.agenda == "risk_review"] + assert risk_rounds + notes = risk_rounds[0].notes + assert notes.get("status") == "blocked" + assert notes.get("reason") == "suspended" diff --git a/tests/test_risk_agent.py b/tests/test_risk_agent.py index 019aa78..d19f7bc 100644 --- a/tests/test_risk_agent.py +++ b/tests/test_risk_agent.py @@ -41,6 +41,23 @@ def test_risk_agent_pending_on_conflict() -> None: assert recommendation.reason == "conflict_threshold" +def test_risk_agent_blocks_on_blacklist() -> None: + agent = RiskAgent() + context = _make_context(is_blacklisted=True) + recommendation = agent.assess(context, AgentAction.BUY_M, conflict_flag=False) + assert recommendation.status == "blocked" + assert recommendation.reason == "blacklist" + assert recommendation.recommended_action == AgentAction.HOLD + + +def test_blacklist_constraints_buy_feasibility() -> None: + agent = RiskAgent() + context = _make_context(is_blacklisted=True) + assert agent.feasible(context, AgentAction.SELL) is True + assert agent.feasible(context, AgentAction.HOLD) is True + assert agent.feasible(context, AgentAction.BUY_S) is False + + def test_evaluate_risk_external_alerts() -> None: agent = RiskAgent() context = AgentContext(