add blacklist risk control and alert dispatching
parent d2a056d7c0 · commit 1ca2f2be19
@@ -660,6 +660,7 @@ def _risk_review_message(reason: str) -> str:
        "risk_penalty_extreme": "风险评分极高,建议暂停加仓",
        "risk_penalty_high": "风险评分偏高,建议复核",
        "external_alert": "外部风险告警触发复核",
        "blacklist": "标的命中黑名单,禁止交易",
    }
    return mapping.get(reason, "触发风险复核,需人工确认")
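A quick illustrative check of the lookup (not part of the diff): a known reason resolves to its mapped text, anything else falls back to the generic review message.

assert _risk_review_message("blacklist") == "标的命中黑名单,禁止交易"
assert _risk_review_message("unknown_reason") == "触发风险复核,需人工确认"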
@@ -49,6 +49,8 @@ class RiskAgent(Agent):
    def feasible(self, context: AgentContext, action: AgentAction) -> bool:
        if action is AgentAction.SELL:
            return True
        if context.features.get("is_blacklisted", False) and action not in (AgentAction.SELL, AgentAction.HOLD):
            return False
        if context.features.get("is_suspended", False):
            return False
        if context.features.get("limit_up", False) and action not in (AgentAction.SELL, AgentAction.HOLD):
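For a blacklisted symbol the gate keeps exits open but rejects any buy; an illustrative check (mirrors test_blacklist_constraints_buy_feasibility added later in this commit; _make_context is the test fixture):

agent = RiskAgent()
context = _make_context(is_blacklisted=True)
assert agent.feasible(context, AgentAction.SELL) is True
assert agent.feasible(context, AgentAction.HOLD) is True
assert agent.feasible(context, AgentAction.BUY_S) is False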
@@ -74,6 +76,15 @@ class RiskAgent(Agent):
                notes={"trigger": "is_suspended"},
            )

        if bool(features.get("is_blacklisted")):
            fallback = AgentAction.SELL if decision_action is AgentAction.SELL else AgentAction.HOLD
            return RiskRecommendation(
                status="blocked",
                reason="blacklist",
                recommended_action=fallback,
                notes={"trigger": "is_blacklisted"},
            )

        if bool(features.get("limit_up")) and decision_action in {
            AgentAction.BUY_S,
            AgentAction.BUY_M,
@@ -698,6 +698,7 @@ class BacktestEngine:
        action_override: Optional[AgentAction] = None,
        target_weight_override: Optional[float] = None,
    ) -> None:
        reason_str = str(reason)
        payload = {
            "trade_date": trade_date_str,
            "ts_code": ts_code,
@@ -708,21 +709,43 @@ class BacktestEngine:
                else decision.target_weight
            ),
            "confidence": decision.confidence,
            "reason": reason,
            "reason": reason_str,
        }
        if extra:
            payload.update(extra)
        risk_events.append(payload)
        risk_meta = payload.get("risk_assessment") if isinstance(payload.get("risk_assessment"), dict) else extra.get("risk_assessment") if extra else None
        status = None
        risk_meta = None
        if isinstance(payload.get("risk_assessment"), dict):
            risk_meta = payload.get("risk_assessment")
        elif extra and isinstance(extra.get("risk_assessment"), dict):
            risk_meta = extra.get("risk_assessment")
        status: Optional[str] = None
        if isinstance(risk_meta, dict):
            status = risk_meta.get("status")
            status = str(risk_meta.get("status") or "")
        payload.setdefault("risk_status", status)
        if status == "blocked":
            try:
                message = f"{ts_code} 风险阻断: {reason_str}"
                alerts.add_warning(
                    "backtest_risk",
                    f"{ts_code} 风险阻断: {reason}",
                    message,
                    detail=json.dumps(payload, ensure_ascii=False),
                    level="error",
                    tags=["risk", reason_str, status],
                    payload=payload,
                )
            except Exception:  # noqa: BLE001
                LOGGER.debug("记录风险告警失败", extra=LOG_EXTRA)
        elif status and status != "ok":
            try:
                message = f"{ts_code} 风险提示: {reason_str}"
                alerts.add_warning(
                    "backtest_risk",
                    message,
                    detail=json.dumps(payload, ensure_ascii=False),
                    level="warning",
                    tags=["risk", reason_str, status],
                    payload=payload,
                )
            except Exception:  # noqa: BLE001
                LOGGER.debug("记录风险告警失败", extra=LOG_EXTRA)
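The recorded status drives alert severity: a blocked event is pushed at error level, any other non-ok status at warning, and the status is also copied onto the event itself. A sketch of the kind of entry that ends up in result.risk_events for a blacklisted buy (field values are illustrative; the keys follow the code above):

{
    "trade_date": "2025-01-10",
    "ts_code": "000001.SZ",
    "reason": "blacklist",
    "risk_status": "blocked",
    "risk_assessment": {"status": "blocked", "reason": "blacklist"},
}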
app/utils/alert_dispatcher.py (new file, 169 lines)
@@ -0,0 +1,169 @@
"""Dispatch structured alerts to external channels."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, Mapping, MutableMapping, Optional, Sequence
|
||||
|
||||
import requests
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from app.utils.config import AlertChannelSettings
|
||||
else:
|
||||
AlertChannelSettings = Any # type: ignore[assignment]
|
||||
|
||||
_LEVEL_RANK: Dict[str, int] = {
|
||||
"debug": 10,
|
||||
"info": 20,
|
||||
"warning": 30,
|
||||
"error": 40,
|
||||
"critical": 50,
|
||||
}
|
||||
|
||||
|
||||
class _Channel:
|
||||
"""Runtime wrapper around a configured alert channel."""
|
||||
|
||||
__slots__ = (
|
||||
"settings",
|
||||
"_lock",
|
||||
"_last_signature",
|
||||
"_last_sent",
|
||||
)
|
||||
|
||||
def __init__(self, settings: AlertChannelSettings) -> None:
|
||||
self.settings = settings
|
||||
self._lock = threading.Lock()
|
||||
self._last_signature: Optional[str] = None
|
||||
self._last_sent: float = 0.0
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return getattr(self.settings, "key", "channel")
|
||||
|
||||
def send(self, entry: Mapping[str, Any]) -> None:
|
||||
if not self._should_send(entry):
|
||||
return
|
||||
payload = self._build_payload(entry)
|
||||
self._deliver(payload)
|
||||
|
||||
def _should_send(self, entry: Mapping[str, Any]) -> bool:
|
||||
level = str(entry.get("level", "warning") or "warning").lower()
|
||||
level_rank = _LEVEL_RANK.get(level, _LEVEL_RANK["warning"])
|
||||
|
||||
threshold = str(getattr(self.settings, "level", "warning") or "warning").lower()
|
||||
threshold_rank = _LEVEL_RANK.get(threshold, _LEVEL_RANK["warning"])
|
||||
if level_rank < threshold_rank:
|
||||
return False
|
||||
|
||||
channel_tags: Sequence[str] = getattr(self.settings, "tags", []) or []
|
||||
if channel_tags:
|
||||
event_tags = entry.get("tags") or []
|
||||
if not isinstance(event_tags, Iterable):
|
||||
event_tags = []
|
||||
if not set(str(tag) for tag in event_tags).intersection(str(tag) for tag in channel_tags):
|
||||
return False
|
||||
|
||||
cooldown = float(getattr(self.settings, "cooldown_seconds", 0.0) or 0.0)
|
||||
signature = f"{entry.get('source')}|{entry.get('message')}|{level}"
|
||||
if cooldown > 0:
|
||||
now = time.monotonic()
|
||||
with self._lock:
|
||||
if self._last_signature == signature and (now - self._last_sent) < cooldown:
|
||||
return False
|
||||
self._last_signature = signature
|
||||
self._last_sent = now
|
||||
return True
|
||||
|
||||
def _build_payload(self, entry: Mapping[str, Any]) -> Dict[str, Any]:
|
||||
payload: Dict[str, Any] = {
|
||||
"source": entry.get("source"),
|
||||
"message": entry.get("message"),
|
||||
"detail": entry.get("detail"),
|
||||
"timestamp": entry.get("timestamp"),
|
||||
"level": entry.get("level"),
|
||||
}
|
||||
tags = entry.get("tags")
|
||||
if isinstance(tags, Iterable) and not isinstance(tags, (str, bytes)):
|
||||
payload["tags"] = list(tags)
|
||||
if "payload" in entry and entry["payload"] is not None:
|
||||
payload["payload"] = entry["payload"]
|
||||
extra_params = getattr(self.settings, "extra_params", None)
|
||||
if isinstance(extra_params, Mapping):
|
||||
for key, value in extra_params.items():
|
||||
payload.setdefault(key, value)
|
||||
return payload
|
||||
|
||||
def _deliver(self, payload: Mapping[str, Any]) -> None:
|
||||
url = getattr(self.settings, "url", "")
|
||||
if not url:
|
||||
return
|
||||
method = str(getattr(self.settings, "method", "POST") or "POST").upper()
|
||||
timeout = float(getattr(self.settings, "timeout", 3.0) or 3.0)
|
||||
headers: MutableMapping[str, str] = {}
|
||||
raw_headers = getattr(self.settings, "headers", None)
|
||||
if isinstance(raw_headers, Mapping):
|
||||
headers = {str(k): str(v) for k, v in raw_headers.items()}
|
||||
|
||||
headers.setdefault("Content-Type", "application/json")
|
||||
body = json.dumps(payload, ensure_ascii=False)
|
||||
|
||||
signing_secret = getattr(self.settings, "signing_secret", None)
|
||||
if signing_secret:
|
||||
import hashlib
|
||||
import hmac
|
||||
|
||||
digest = hmac.new(
|
||||
str(signing_secret).encode("utf-8"),
|
||||
body.encode("utf-8"),
|
||||
hashlib.sha256,
|
||||
).hexdigest()
|
||||
headers.setdefault("X-Signature", digest)
|
||||
|
||||
try:
|
||||
requests.request(
|
||||
method=method,
|
||||
url=url,
|
||||
data=body,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.exception("发送告警失败: channel=%s", self.name)
|
||||
|
||||
|
||||
class AlertDispatcher:
|
||||
"""Singleton-style dispatcher coordinating channel delivery."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._channels: Dict[str, _Channel] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def configure(self, configs: Mapping[str, AlertChannelSettings]) -> None:
|
||||
active: Dict[str, _Channel] = {}
|
||||
for key, cfg in configs.items():
|
||||
if not cfg or not getattr(cfg, "enabled", True):
|
||||
continue
|
||||
if not getattr(cfg, "url", ""):
|
||||
continue
|
||||
channel = _Channel(cfg)
|
||||
active[key] = channel
|
||||
with self._lock:
|
||||
self._channels = active
|
||||
|
||||
def dispatch(self, entry: Mapping[str, Any]) -> None:
|
||||
if not self._channels:
|
||||
return
|
||||
for channel in list(self._channels.values()):
|
||||
channel.send(entry)
|
||||
|
||||
|
||||
_DISPATCHER = AlertDispatcher()
|
||||
|
||||
|
||||
def get_dispatcher() -> AlertDispatcher:
|
||||
return _DISPATCHER
|
||||
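A minimal usage sketch of the dispatcher on its own (the channel key, URL, and values are placeholders; in the application the same wiring happens through alerts.configure_channels and alerts.add_warning):

from app.utils.alert_dispatcher import get_dispatcher
from app.utils.config import AlertChannelSettings

# Hypothetical channel: only error-or-worse events tagged "risk" reach it.
get_dispatcher().configure(
    {
        "ops": AlertChannelSettings(
            key="ops",
            url="https://example.com/webhook",
            level="error",
            tags=["risk"],
            cooldown_seconds=60.0,
        )
    }
)
get_dispatcher().dispatch(
    {
        "source": "backtest_risk",
        "message": "000001.SZ 风险阻断: blacklist",
        "level": "error",
        "tags": ["risk", "blacklist", "blocked"],
    }
)

With cooldown_seconds set, a repeat of the same source/message/level combination inside the window is silently skipped, which keeps a noisy backtest from flooding the webhook.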
@@ -1,46 +1,124 @@
"""Runtime data warning registry for surfacing ingestion issues in UI."""
"""Runtime data warning registry with external dispatch support."""
from __future__ import annotations

import logging
from datetime import datetime
from threading import Lock
from typing import Dict, List, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Sequence

from .alert_dispatcher import get_dispatcher

_ALERTS: List[Dict[str, str]] = []
if TYPE_CHECKING:  # pragma: no cover
    from app.utils.config import AlertChannelSettings

LOGGER = logging.getLogger(__name__)

_ALERTS: List[Dict[str, Any]] = []
_SINKS: List[Callable[[Dict[str, Any]], None]] = []
_LOCK = Lock()
_MAX_ALERTS = 50


def add_warning(source: str, message: str, detail: Optional[str] = None) -> None:
    """Register or update a warning entry."""
def configure_channels(channels: Mapping[str, "AlertChannelSettings"]) -> None:
    """Configure external dispatch channels."""

    source = source.strip() or "unknown"
    message = message.strip() or "发生未知异常"
    try:
        get_dispatcher().configure(channels)
    except Exception:  # noqa: BLE001
        LOGGER.debug("配置外部告警通道失败", exc_info=True)


def register_sink(sink: Callable[[Dict[str, Any]], None]) -> None:
    """Attach an additional sink to receive alert payloads."""

    with _LOCK:
        if sink not in _SINKS:
            _SINKS.append(sink)


def unregister_sink(sink: Callable[[Dict[str, Any]], None]) -> None:
    """Detach a previously registered sink."""

    with _LOCK:
        _SINKS[:] = [existing for existing in _SINKS if existing is not sink]


def add_warning(
    source: str,
    message: str,
    detail: Optional[str] = None,
    *,
    level: str = "warning",
    tags: Optional[Sequence[str]] = None,
    payload: Optional[Mapping[str, Any]] = None,
) -> None:
    """Register or update a warning entry and dispatch to sinks."""

    source = (source or "").strip() or "unknown"
    message = (message or "").strip() or "发生未知异常"
    normalized_level = str(level or "warning").lower()
    timestamp = datetime.utcnow().isoformat(timespec="seconds") + "Z"
    normalized_tags: List[str] = []
    if tags:
        normalized_tags = [
            str(tag).strip()
            for tag in tags
            if isinstance(tag, str) and tag.strip()
        ]

    snapshot: Dict[str, Any] = {}
    sinks: List[Callable[[Dict[str, Any]], None]] = []

    with _LOCK:
        for alert in _ALERTS:
            if alert["source"] == source and alert["message"] == message:
                alert["timestamp"] = timestamp
                alert["level"] = normalized_level
                if detail:
                    alert["detail"] = detail
                return
        entry = {
                if normalized_tags:
                    alert["tags"] = list(normalized_tags)
                if payload is not None:
                    alert["payload"] = dict(payload) if isinstance(payload, Mapping) else payload
                snapshot = dict(alert)
                sinks = list(_SINKS)
                break
        else:
            entry: Dict[str, Any] = {
                "source": source,
                "message": message,
                "timestamp": timestamp,
                "level": normalized_level,
            }
            if detail:
                entry["detail"] = detail
            if normalized_tags:
                entry["tags"] = list(normalized_tags)
            if payload is not None:
                entry["payload"] = dict(payload) if isinstance(payload, Mapping) else payload
            _ALERTS.append(entry)
            if len(_ALERTS) > 50:
                del _ALERTS[:-50]
            if len(_ALERTS) > _MAX_ALERTS:
                del _ALERTS[:-_MAX_ALERTS]
            snapshot = dict(entry)
            sinks = list(_SINKS)

    for sink in sinks:
        try:
            sink(dict(snapshot))
        except Exception:  # noqa: BLE001
            LOGGER.debug("执行告警 sink 失败:%s", getattr(sink, "__name__", sink), exc_info=True)

    try:
        get_dispatcher().dispatch(snapshot)
    except Exception:  # noqa: BLE001
        LOGGER.debug("外部告警发送失败 source=%s", source, exc_info=True)


def get_warnings() -> List[Dict[str, str]]:
def get_warnings() -> List[Dict[str, Any]]:
    """Return a copy of current warning entries."""

    with _LOCK:
        return list(_ALERTS)
        return [dict(alert) for alert in _ALERTS]


def clear_warnings(source: Optional[str] = None) -> None:
@@ -50,6 +128,5 @@ def clear_warnings(source: Optional[str] = None) -> None:
    if source is None:
        _ALERTS.clear()
        return
    source = source.strip()
    _ALERTS[:] = [alert for alert in _ALERTS if alert["source"] != source]

    source_key = source.strip()
    _ALERTS[:] = [alert for alert in _ALERTS if alert["source"] != source_key]
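The sink hooks have no caller in this commit; as an illustrative sketch (the sink body is assumed), a process could mirror every alert into its own log before the external dispatch happens:

from app.utils import alerts

def _log_alert(entry: dict) -> None:
    # Hypothetical sink: entries carry source/message/level plus optional tags/payload.
    print(f"[{entry.get('level')}] {entry.get('source')}: {entry.get('message')}")

alerts.register_sink(_log_alert)
alerts.add_warning("demo", "sink 演示", level="info", tags=["demo"])
alerts.unregister_sink(_log_alert)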
@@ -58,6 +58,43 @@ class PortfolioSettings:
    max_sector_exposure: float = 0.35  # 行业敞口上限 35%


@dataclass
class AlertChannelSettings:
    """Configuration for external alert delivery channels."""

    key: str
    kind: str = "webhook"
    url: str = ""
    enabled: bool = True
    level: str = "warning"
    tags: List[str] = field(default_factory=list)
    headers: Dict[str, str] = field(default_factory=dict)
    timeout: float = 3.0
    method: str = "POST"
    template: str = ""
    signing_secret: Optional[str] = None
    cooldown_seconds: float = 0.0
    extra_params: Dict[str, object] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, object]:
        payload: Dict[str, object] = {
            "kind": self.kind,
            "url": self.url,
            "enabled": self.enabled,
            "level": self.level,
            "tags": list(self.tags),
            "headers": dict(self.headers),
            "timeout": self.timeout,
            "method": self.method,
            "template": self.template,
            "cooldown_seconds": self.cooldown_seconds,
            "extra_params": dict(self.extra_params),
        }
        if self.signing_secret:
            payload["signing_secret"] = self.signing_secret
        return payload


@dataclass
class AgentWeights:
    """Default weighting for decision agents."""
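When signing_secret is set, the dispatcher signs the JSON body with HMAC-SHA256 and sends the hex digest in an X-Signature header; a receiving endpoint could verify it roughly like this (sketch only, the surrounding web framework and handler are assumptions):

import hashlib
import hmac

def verify_signature(raw_body: bytes, header_signature: str, secret: str) -> bool:
    # Recompute the digest over the exact request body and compare in constant time.
    expected = hmac.new(secret.encode("utf-8"), raw_body, hashlib.sha256).hexdigest()
    return hmac.compare_digest(expected, header_signature or "")

Using hmac.compare_digest avoids leaking timing information while comparing the digests.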
@@ -528,6 +565,7 @@ class AppConfig:
    llm_cost: LLMCostSettings = field(default_factory=LLMCostSettings)
    departments: Dict[str, DepartmentSettings] = field(default_factory=_default_departments)
    portfolio: PortfolioSettings = field(default_factory=PortfolioSettings)
    alert_channels: Dict[str, AlertChannelSettings] = field(default_factory=dict)

    def resolve_llm(self, route: Optional[str] = None) -> LLMConfig:
        return self.llm
@@ -634,6 +672,52 @@ def _load_from_file(cfg: AppConfig) -> None:
        )
        cfg.portfolio = updated_portfolio

    alert_channels_payload = payload.get("alert_channels")
    if isinstance(alert_channels_payload, dict):
        channels: Dict[str, AlertChannelSettings] = {}
        for key, data in alert_channels_payload.items():
            if not isinstance(data, dict):
                continue
            normalized_key = str(key)
            raw_tags = data.get("tags")
            tags: List[str] = []
            if isinstance(raw_tags, list):
                tags = [
                    str(tag).strip()
                    for tag in raw_tags
                    if isinstance(tag, str) and tag.strip()
                ]
            headers: Dict[str, str] = {}
            raw_headers = data.get("headers")
            if isinstance(raw_headers, Mapping):
                headers = {
                    str(h_key): str(h_val)
                    for h_key, h_val in raw_headers.items()
                    if h_key is not None
                }
            extra_params: Dict[str, object] = {}
            raw_extra = data.get("extra_params")
            if isinstance(raw_extra, Mapping):
                extra_params = dict(raw_extra)
            channel = AlertChannelSettings(
                key=normalized_key,
                kind=str(data.get("kind") or "webhook"),
                url=str(data.get("url") or ""),
                enabled=bool(data.get("enabled", True)),
                level=str(data.get("level") or "warning"),
                tags=tags,
                headers=headers,
                timeout=float(data.get("timeout", 3.0) or 3.0),
                method=str(data.get("method") or "POST").upper(),
                template=str(data.get("template") or ""),
                signing_secret=str(data.get("signing_secret")) if data.get("signing_secret") else None,
                cooldown_seconds=float(data.get("cooldown_seconds", 0.0) or 0.0),
                extra_params=extra_params,
            )
            if channel.url:
                channels[channel.key] = channel
        cfg.alert_channels = channels

    cost_payload = payload.get("llm_cost")
    if isinstance(cost_payload, dict):
        cfg.llm_cost.update_from_dict(cost_payload)
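Concretely, an alert_channels block in the persisted config payload might look like the sketch below. The serialization format of the config file is not shown in this diff, so this is a Python-dict view of the parsed payload; only the keys come from the loader above, the values are placeholders.

"alert_channels": {
    "ops_webhook": {
        "kind": "webhook",
        "url": "https://example.com/hooks/quant-risk",
        "enabled": True,
        "level": "error",
        "tags": ["risk"],
        "headers": {"X-Token": "***"},
        "timeout": 3.0,
        "method": "POST",
        "cooldown_seconds": 60.0,
        "signing_secret": "change-me",
    }
}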
@@ -865,6 +949,10 @@ def save_config(cfg: AppConfig | None = None) -> None:
                "max_sector_exposure": cfg.portfolio.max_sector_exposure,
            },
        },
        "alert_channels": {
            name: channel.to_dict()
            for name, channel in cfg.alert_channels.items()
        },
        "llm": {
            "strategy": cfg.llm.strategy if cfg.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
            "majority_threshold": cfg.llm.majority_threshold,
@@ -918,6 +1006,13 @@ def save_config(cfg: AppConfig | None = None) -> None:
        LOGGER.info("配置已写入:%s", path)
    except OSError:
        LOGGER.exception("配置写入失败:%s", path)
        return

    try:
        from app.utils import alerts as _alerts  # 延迟导入以避免循环
        _alerts.configure_channels(cfg.alert_channels)
    except Exception:  # noqa: BLE001
        LOGGER.debug("更新告警通道失败", exc_info=True)


def _load_env_defaults(cfg: AppConfig) -> None:
@@ -935,12 +1030,34 @@ def _load_env_defaults(cfg: AppConfig) -> None:
        if provider_cfg:
            provider_cfg.api_key = sanitized

    webhook = os.getenv("LLM_QUANT_ALERT_WEBHOOK")
    if webhook:
        key = "env_webhook"
        channel = AlertChannelSettings(
            key=key,
            kind="webhook",
            url=webhook.strip(),
            headers={"Content-Type": "application/json"},
            enabled=True,
            level=str(os.getenv("LLM_QUANT_ALERT_LEVEL", "warning") or "warning"),
        )
        tags_raw = os.getenv("LLM_QUANT_ALERT_TAGS")
        if tags_raw:
            channel.tags = [tag.strip() for tag in tags_raw.split(",") if tag.strip()]
        cfg.alert_channels[key] = channel

    cfg.sync_runtime_llm()


_load_from_file(CONFIG)
_load_env_defaults(CONFIG)

try:
    from app.utils import alerts as _alerts_module  # 延迟导入避免循环依赖
    _alerts_module.configure_channels(CONFIG.alert_channels)
except Exception:  # noqa: BLE001
    LOGGER.debug("初始化告警通道失败", exc_info=True)


def get_config() -> AppConfig:
    """Return a mutable global configuration instance."""
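A deployment can therefore wire up a channel purely from the environment; a sketch of the variables the loader reads (values are placeholders, and they must be set before app.utils.config is first imported):

import os

os.environ["LLM_QUANT_ALERT_WEBHOOK"] = "https://example.com/hooks/quant-risk"
os.environ["LLM_QUANT_ALERT_LEVEL"] = "error"
os.environ["LLM_QUANT_ALERT_TAGS"] = "risk,blocked"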
@@ -50,9 +50,9 @@
| 工作项 | 状态 | 说明 |
| --- | --- | --- |
| 风险代理决策闭环 | ✅ | `risk_round` 支持按场景回写 `risk_assessment`、触发人工/自动兜底策略,并已接入决策追踪报表。 |
| 风险事件持久化 | 🔄 | 已形成 `risk_round` → `risk_events` 的事件映射草案;下个迭代补齐 ORM/批量落库、事件去重、以及风险面板的逐条 Drill-down 展示。 |
| 实时告警接入 | ⏳ | 需对接外部告警渠道,支撑影子运行与上线验证。 |
| 风险场景测试 | ⏳ | 补充停牌、仓位超限、黑名单等自动化测试样例。 |
| 风险事件持久化 | ✅ | `risk_round` 事件现已批量写入 `bt_risk_events`,附带风险状态/元数据并在回测面板支持 Drill-down。 |
| 实时告警接入 | ✅ | 引入告警分发器支持配置化 webhook 通道,风险阻断/复核即时推送至外部渠道。 |
| 风险场景测试 | ✅ | 新增停牌、仓位超限、黑名单等集成测试覆盖,验证风险闭环执行。 |

## 测试与质量保障
tests/test_alert_dispatcher.py (new file, 70 lines)
@@ -0,0 +1,70 @@
"""Tests for alert dispatcher configuration and delivery."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from app.utils import alerts
|
||||
from app.utils.config import AlertChannelSettings
|
||||
|
||||
|
||||
def test_alert_dispatcher_posts_payload(monkeypatch):
|
||||
calls: list[dict[str, object]] = []
|
||||
|
||||
def fake_request(*, method, url, data=None, headers=None, timeout=None):
|
||||
calls.append(
|
||||
{
|
||||
"method": method,
|
||||
"url": url,
|
||||
"data": data,
|
||||
"headers": headers,
|
||||
"timeout": timeout,
|
||||
}
|
||||
)
|
||||
|
||||
class _Resp:
|
||||
status_code = 200
|
||||
|
||||
return _Resp()
|
||||
|
||||
monkeypatch.setattr("app.utils.alert_dispatcher.requests.request", fake_request)
|
||||
|
||||
alerts.clear_warnings()
|
||||
alerts.configure_channels(
|
||||
{
|
||||
"ops": AlertChannelSettings(
|
||||
key="ops",
|
||||
kind="webhook",
|
||||
url="https://example.com/webhook",
|
||||
enabled=True,
|
||||
level="info",
|
||||
headers={"X-Test": "1"},
|
||||
extra_params={"channel": "risk"},
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
alerts.add_warning(
|
||||
"risk_system",
|
||||
"阻断测试",
|
||||
detail="blocked",
|
||||
level="error",
|
||||
tags=["risk", "blocked"],
|
||||
payload={"reason": "blocked"},
|
||||
)
|
||||
|
||||
assert calls, "expected dispatcher to send webhook call"
|
||||
call = calls[0]
|
||||
assert call["method"] == "POST"
|
||||
assert call["url"] == "https://example.com/webhook"
|
||||
assert call["headers"]["X-Test"] == "1"
|
||||
payload = json.loads(call["data"])
|
||||
assert payload["message"] == "阻断测试"
|
||||
assert payload["channel"] == "risk"
|
||||
assert payload["payload"]["reason"] == "blocked"
|
||||
|
||||
warnings = alerts.get_warnings()
|
||||
assert warnings
|
||||
assert warnings[0]["level"] == "error"
|
||||
assert "blocked" in warnings[0].get("tags", [])
|
||||
|
||||
alerts.configure_channels({})
|
||||
@@ -8,7 +8,7 @@ import pytest
import json

from app.agents.base import AgentAction, AgentContext
from app.agents.game import Decision
from app.agents.game import Decision, RiskAssessment
from app.backtest.engine import (
    BacktestEngine,
    BacktestResult,
@@ -18,6 +18,7 @@ from app.backtest.engine import (
)
from app.data.schema import initialize_database
from app.utils.config import DataPaths, get_config
from app.utils import alerts
from app.utils.db import db_session
@@ -121,6 +122,79 @@ def test_buy_blocked_by_limit_up_records_risk():
    assert result.risk_events[0]["reason"] == "limit_up"


def test_position_limit_triggers_risk_event_and_adjusts_execution():
    alerts.clear_warnings()
    engine = _engine_with_params(
        {
            "max_position_weight": 0.3,
            "fee_rate": 0.0,
            "slippage_bps": 0.0,
            "max_daily_turnover_ratio": 1.0,
        }
    )
    state = PortfolioState(cash=100_000.0)
    result = BacktestResult()
    context = _make_context(100.0, {"position_limit": True})
    decision = _make_decision(AgentAction.BUY_L, target_weight=0.4)
    decision.risk_assessment = RiskAssessment(
        status="pending_review",
        reason="position_limit",
        recommended_action=AgentAction.BUY_S,
        notes={"trigger": "position_limit"},
    )

    engine._apply_portfolio_updates(
        date(2025, 1, 10),
        state,
        [("000001.SZ", context, decision)],
        result,
    )

    assert not result.trades, "position limit should block execution despite adjustment"
    assert result.risk_events
    event_with_status = next(
        (event for event in result.risk_events if event.get("risk_status")),
        None,
    )
    assert event_with_status is not None
    assert event_with_status["reason"] == "position_limit"
    assert event_with_status.get("risk_status") == "pending_review"
    warning_messages = [item["message"] for item in alerts.get_warnings()]
    assert any("风险提示" in msg for msg in warning_messages)
    alerts.clear_warnings()


def test_blacklist_blocks_execution_and_warns():
    alerts.clear_warnings()
    engine = _engine_with_params({})
    state = PortfolioState(cash=50_000.0)
    result = BacktestResult()
    context = _make_context(100.0, {"is_blacklisted": True})
    decision = _make_decision(AgentAction.BUY_M, target_weight=0.2)
    decision.risk_assessment = RiskAssessment(
        status="blocked",
        reason="blacklist",
        recommended_action=AgentAction.HOLD,
        notes={"trigger": "is_blacklisted"},
    )

    engine._apply_portfolio_updates(
        date(2025, 1, 10),
        state,
        [("000001.SZ", context, decision)],
        result,
    )

    assert not result.trades
    assert result.risk_events
    event = result.risk_events[0]
    assert event["reason"] == "blacklist"
    assert event.get("risk_status") == "blocked"
    warning_messages = [item["message"] for item in alerts.get_warnings()]
    assert any("风险阻断" in msg for msg in warning_messages)
    alerts.clear_warnings()


def test_sell_applies_slippage_and_fee():
    engine = _engine_with_params(
        {
@@ -117,3 +117,26 @@ def test_backtest_engine_applies_risk_adjusted_execution(monkeypatch):
    assert not state.holdings
    assert not result.trades
    assert result.nav_series[0]["nav"] == pytest.approx(100_000.0)


def test_decide_records_suspension_risk_round():
    agents = default_agents()
    context = _make_context({"is_suspended": True})

    decision = decide(
        context,
        agents,
        weights={agent.name: 1.0 for agent in agents},
        department_manager=None,
    )

    assert decision.requires_review is True
    assert decision.risk_assessment
    assert decision.risk_assessment.status == "blocked"
    assert decision.risk_assessment.reason == "suspended"

    risk_rounds = [summary for summary in decision.rounds if summary.agenda == "risk_review"]
    assert risk_rounds
    notes = risk_rounds[0].notes
    assert notes.get("status") == "blocked"
    assert notes.get("reason") == "suspended"

@@ -41,6 +41,23 @@ def test_risk_agent_pending_on_conflict() -> None:
    assert recommendation.reason == "conflict_threshold"


def test_risk_agent_blocks_on_blacklist() -> None:
    agent = RiskAgent()
    context = _make_context(is_blacklisted=True)
    recommendation = agent.assess(context, AgentAction.BUY_M, conflict_flag=False)
    assert recommendation.status == "blocked"
    assert recommendation.reason == "blacklist"
    assert recommendation.recommended_action == AgentAction.HOLD


def test_blacklist_constraints_buy_feasibility() -> None:
    agent = RiskAgent()
    context = _make_context(is_blacklisted=True)
    assert agent.feasible(context, AgentAction.SELL) is True
    assert agent.feasible(context, AgentAction.HOLD) is True
    assert agent.feasible(context, AgentAction.BUY_S) is False


def test_evaluate_risk_external_alerts() -> None:
    agent = RiskAgent()
    context = AgentContext(