add risk assessment and protocol tracking to game engine

This commit is contained in:
sam 2025-10-07 20:15:25 +08:00
parent 721d59d1cf
commit 3c15d443d3
11 changed files with 1336 additions and 147 deletions

View File

@ -8,6 +8,15 @@ from typing import Dict, Iterable, List, Mapping, Optional, Tuple
from .base import Agent, AgentAction, AgentContext, UtilityMatrix
from .departments import DepartmentContext, DepartmentDecision, DepartmentManager
from .registry import weight_map
from .risk import RiskAgent, RiskRecommendation
from .protocols import (
DialogueMessage,
DialogueRole,
GameStructure,
MessageType,
ProtocolHost,
RoundSummary,
)
ACTIONS: Tuple[AgentAction, ...] = (
@ -23,6 +32,30 @@ def _clamp(value: float) -> float:
return max(0.0, min(1.0, value))
@dataclass
class BeliefUpdate:
belief: Dict[str, object]
rationale: Optional[str] = None
@dataclass
class RiskAssessment:
status: str
reason: str
recommended_action: Optional[AgentAction] = None
notes: Dict[str, object] = field(default_factory=dict)
def to_dict(self) -> Dict[str, object]:
payload: Dict[str, object] = {
"status": self.status,
"reason": self.reason,
"notes": dict(self.notes),
}
if self.recommended_action is not None:
payload["recommended_action"] = self.recommended_action.value
return payload
@dataclass
class Decision:
action: AgentAction
@ -33,6 +66,9 @@ class Decision:
department_decisions: Dict[str, DepartmentDecision] = field(default_factory=dict)
department_votes: Dict[str, float] = field(default_factory=dict)
requires_review: bool = False
rounds: List[RoundSummary] = field(default_factory=list)
risk_assessment: Optional[RiskAssessment] = None
belief_updates: Dict[str, BeliefUpdate] = field(default_factory=dict)
def compute_utilities(agents: Iterable[Agent], context: AgentContext) -> UtilityMatrix:
@ -132,6 +168,16 @@ def decide(
raw_weights = dict(weights)
department_decisions: Dict[str, DepartmentDecision] = {}
department_votes: Dict[str, float] = {}
host = ProtocolHost()
host_trace = host.bootstrap_trace(
session_id=f"{context.ts_code}:{context.trade_date}",
ts_code=context.ts_code,
trade_date=context.trade_date,
)
department_round: Optional[RoundSummary] = None
risk_round: Optional[RoundSummary] = None
execution_round: Optional[RoundSummary] = None
belief_updates: Dict[str, BeliefUpdate] = {}
if department_manager:
dept_context = department_context
@ -144,6 +190,12 @@ def decide(
raw=dict(getattr(context, "raw", {}) or {}),
)
department_decisions = department_manager.evaluate(dept_context)
if department_decisions:
department_round = host.start_round(
host_trace,
agenda="department_consensus",
structure=GameStructure.REPEATED,
)
for code, decision in department_decisions.items():
agent_key = f"dept_{code}"
dept_agent = department_manager.agents.get(code)
@ -155,6 +207,17 @@ def decide(
bucket = _department_vote_bucket(decision.action)
if bucket:
department_votes[bucket] = department_votes.get(bucket, 0.0) + weight * decision.confidence
if department_round:
message = _department_message(code, decision)
host.handle_message(department_round, message)
belief_updates[code] = BeliefUpdate(
belief={
"action": decision.action.value,
"confidence": decision.confidence,
"signals": decision.signals,
},
rationale=decision.summary,
)
filtered_utilities = {action: utilities[action] for action in feas_actions}
hold_scores = utilities.get(AgentAction.HOLD, {})
@ -168,7 +231,111 @@ def decide(
action, confidence = vote(filtered_utilities, norm_weights)
weight = target_weight_for_action(action)
requires_review = _department_conflict_flag(department_votes)
conflict_flag = _department_conflict_flag(department_votes)
risk_agent = _find_risk_agent(agent_list)
risk_assessment = _evaluate_risk(
context,
action,
department_votes,
conflict_flag,
risk_agent,
)
requires_review = risk_assessment.status != "ok"
if department_round:
department_round.notes.setdefault("department_votes", dict(department_votes))
department_round.outcome = action.value
host.finalize_round(department_round)
if requires_review:
risk_round = host.ensure_round(
host_trace,
agenda="risk_review",
structure=GameStructure.CUSTOM,
)
review_message = DialogueMessage(
sender="risk_guard",
role=DialogueRole.RISK,
message_type=MessageType.COUNTER,
content=_risk_review_message(risk_assessment.reason),
confidence=1.0,
references=list(department_votes.keys()),
annotations={
"department_votes": dict(department_votes),
"risk_reason": risk_assessment.reason,
"recommended_action": (
risk_assessment.recommended_action.value
if risk_assessment.recommended_action
else None
),
"notes": dict(risk_assessment.notes),
},
)
host.handle_message(risk_round, review_message)
risk_round.notes.setdefault("status", risk_assessment.status)
risk_round.notes.setdefault("reason", risk_assessment.reason)
if risk_assessment.recommended_action:
risk_round.notes.setdefault(
"recommended_action",
risk_assessment.recommended_action.value,
)
risk_round.outcome = "REVIEW"
host.finalize_round(risk_round)
belief_updates["risk_guard"] = BeliefUpdate(
belief={
"status": risk_assessment.status,
"reason": risk_assessment.reason,
"recommended_action": (
risk_assessment.recommended_action.value
if risk_assessment.recommended_action
else None
),
},
)
execution_round = host.ensure_round(
host_trace,
agenda="execution_summary",
structure=GameStructure.REPEATED,
)
exec_action = action
exec_weight = weight
exec_status = "normal"
if requires_review and risk_assessment.recommended_action:
exec_action = risk_assessment.recommended_action
exec_status = "risk_adjusted"
exec_weight = target_weight_for_action(exec_action)
execution_message = DialogueMessage(
sender="execution_engine",
role=DialogueRole.EXECUTION,
message_type=MessageType.DIRECTIVE,
content=f"执行操作 {exec_action.value}",
confidence=1.0,
annotations={
"target_weight": exec_weight,
"requires_review": requires_review,
"execution_status": exec_status,
},
)
host.handle_message(execution_round, execution_message)
execution_round.outcome = exec_action.value
execution_round.notes.setdefault("execution_status", exec_status)
if exec_action is not action:
execution_round.notes.setdefault("original_action", action.value)
belief_updates["execution"] = BeliefUpdate(
belief={
"execution_status": exec_status,
"action": exec_action.value,
"target_weight": exec_weight,
},
)
host.finalize_round(execution_round)
host.close(host_trace)
rounds = host_trace.rounds if host_trace.rounds else _build_round_summaries(
department_decisions,
action,
department_votes,
)
return Decision(
action=action,
confidence=confidence,
@ -178,6 +345,9 @@ def decide(
department_decisions=department_decisions,
department_votes=department_votes,
requires_review=requires_review,
rounds=rounds,
risk_assessment=risk_assessment,
belief_updates=belief_updates,
)
@ -231,3 +401,145 @@ def _department_conflict_flag(votes: Mapping[str, float]) -> bool:
if len(sorted_votes) >= 2 and (sorted_votes[0] - sorted_votes[1]) < total * 0.1:
return True
return False
def _department_message(code: str, decision: DepartmentDecision) -> DialogueMessage:
content = decision.summary or decision.raw_response or decision.action.value
references = decision.signals or []
annotations: Dict[str, object] = {
"risks": decision.risks,
"supplements": decision.supplements,
}
if decision.dialogue:
annotations["dialogue"] = decision.dialogue
if decision.telemetry:
annotations["telemetry"] = decision.telemetry
return DialogueMessage(
sender=code,
role=DialogueRole.PREDICTION,
message_type=MessageType.DECISION,
content=content,
confidence=decision.confidence,
references=references,
annotations=annotations,
)
def _evaluate_risk(
context: AgentContext,
action: AgentAction,
department_votes: Mapping[str, float],
conflict_flag: bool,
risk_agent: Optional[RiskAgent],
) -> RiskAssessment:
external_alerts = []
if getattr(context, "raw", None):
alerts = context.raw.get("risk_alerts", [])
if alerts:
external_alerts = list(alerts)
if risk_agent:
recommendation = risk_agent.assess(context, action, conflict_flag)
notes = dict(recommendation.notes)
notes.setdefault("department_votes", dict(department_votes))
if external_alerts:
notes.setdefault("external_alerts", external_alerts)
if recommendation.status == "ok":
recommendation = RiskRecommendation(
status="pending_review",
reason="external_alert",
recommended_action=recommendation.recommended_action or AgentAction.HOLD,
notes=notes,
)
else:
recommendation.notes = notes
return RiskAssessment(
status=recommendation.status,
reason=recommendation.reason,
recommended_action=recommendation.recommended_action,
notes=recommendation.notes,
)
notes: Dict[str, object] = {
"conflict": conflict_flag,
"department_votes": dict(department_votes),
}
if external_alerts:
notes["external_alerts"] = external_alerts
return RiskAssessment(
status="pending_review",
reason="external_alert",
recommended_action=AgentAction.HOLD,
notes=notes,
)
if conflict_flag:
return RiskAssessment(
status="pending_review",
reason="conflict_threshold",
notes=notes,
)
return RiskAssessment(status="ok", reason="clear", notes=notes)
def _find_risk_agent(agents: Iterable[Agent]) -> Optional[RiskAgent]:
for agent in agents:
if isinstance(agent, RiskAgent):
return agent
return None
def _risk_review_message(reason: str) -> str:
mapping = {
"conflict_threshold": "部门意见分歧,触发风险复核",
"suspended": "标的停牌,需冻结执行",
"limit_up": "标的涨停,执行需调整",
"position_limit": "仓位限制已触发,需调整目标",
"risk_penalty_extreme": "风险评分极高,建议暂停加仓",
"risk_penalty_high": "风险评分偏高,建议复核",
"external_alert": "外部风险告警触发复核",
}
return mapping.get(reason, "触发风险复核,需人工确认")
def _build_round_summaries(
department_decisions: Mapping[str, DepartmentDecision],
final_action: AgentAction,
department_votes: Mapping[str, float],
) -> List[RoundSummary]:
if not department_decisions:
return []
messages: List[DialogueMessage] = []
for code, decision in department_decisions.items():
content = decision.summary or decision.raw_response or decision.action.value
references = decision.signals or []
annotations: Dict[str, object] = {
"risks": decision.risks,
"supplements": decision.supplements,
}
if decision.dialogue:
annotations["dialogue"] = decision.dialogue
if decision.telemetry:
annotations["telemetry"] = decision.telemetry
message = DialogueMessage(
sender=code,
role=DialogueRole.PREDICTION,
message_type=MessageType.DECISION,
content=content,
confidence=decision.confidence,
references=references,
annotations=annotations,
)
messages.append(message)
notes: Dict[str, object] = {
"department_votes": dict(department_votes),
}
summary = RoundSummary(
index=0,
agenda="department_consensus",
structure=GameStructure.REPEATED,
resolved=True,
outcome=final_action.value,
messages=messages,
notes=notes,
)
return [summary]

270
app/agents/protocols.py Normal file
View File

@ -0,0 +1,270 @@
"""Protocols and data structures for multi-round multi-agent games."""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol
class GameStructure(str, Enum):
"""Supported multi-agent game topologies."""
REPEATED = "repeated"
SIGNALING = "signaling"
BAYESIAN = "bayesian"
CUSTOM = "custom"
class DialogueRole(str, Enum):
"""Roles participating in the negotiation agenda."""
HOST = "host"
PREDICTION = "prediction"
RISK = "risk"
EXECUTION = "execution"
OBSERVER = "observer"
class MessageType(str, Enum):
"""High-level classification of dialogue intents."""
HYPOTHESIS = "hypothesis"
EVIDENCE = "evidence"
COUNTER = "counter"
DECISION = "decision"
DIRECTIVE = "directive"
META = "meta"
@dataclass
class DialogueMessage:
"""Single utterance in the multi-round dialogue."""
sender: str
role: DialogueRole
message_type: MessageType
content: str
confidence: float = 0.0
references: List[str] = field(default_factory=list)
timestamp: Optional[str] = None
annotations: Dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> Dict[str, Any]:
return {
"sender": self.sender,
"role": self.role.value,
"message_type": self.message_type.value,
"content": self.content,
"confidence": self.confidence,
"references": list(self.references),
"timestamp": self.timestamp,
"annotations": dict(self.annotations),
}
@dataclass
class BeliefSnapshot:
"""Belief state emitted by an agent before or after revision."""
agent: str
role: DialogueRole
belief: Dict[str, Any]
confidence: float
rationale: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
return {
"agent": self.agent,
"role": self.role.value,
"belief": dict(self.belief),
"confidence": self.confidence,
"rationale": self.rationale,
}
@dataclass
class BeliefRevision:
"""Tracks belief updates triggered during a round."""
before: BeliefSnapshot
after: BeliefSnapshot
justification: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
return {
"before": self.before.to_dict(),
"after": self.after.to_dict(),
"justification": self.justification,
}
@dataclass
class RoundSummary:
"""Aggregated view of a single negotiation round."""
index: int
agenda: str
structure: GameStructure
resolved: bool
outcome: Optional[str] = None
messages: List[DialogueMessage] = field(default_factory=list)
revisions: List[BeliefRevision] = field(default_factory=list)
constraints_triggered: List[str] = field(default_factory=list)
notes: Dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> Dict[str, Any]:
return {
"index": self.index,
"agenda": self.agenda,
"structure": self.structure.value,
"resolved": self.resolved,
"outcome": self.outcome,
"messages": [message.to_dict() for message in self.messages],
"revisions": [revision.to_dict() for revision in self.revisions],
"constraints_triggered": list(self.constraints_triggered),
"notes": dict(self.notes),
}
@dataclass
class DialogueTrace:
"""Ordered collection of round summaries for auditing."""
session_id: str
ts_code: str
trade_date: str
rounds: List[RoundSummary] = field(default_factory=list)
metadata: Dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> Dict[str, Any]:
return {
"session_id": self.session_id,
"ts_code": self.ts_code,
"trade_date": self.trade_date,
"rounds": [summary.to_dict() for summary in self.rounds],
"metadata": dict(self.metadata),
}
class GameProtocol(Protocol):
"""Extension point for hosting multi-round agent negotiations."""
def bootstrap(self, trace: DialogueTrace) -> None:
"""Prepare the protocol-specific state before rounds begin."""
def start_round(self, trace: DialogueTrace, agenda: str, structure: GameStructure) -> RoundSummary:
"""Open a new round with the given agenda and structure descriptor."""
def handle_message(self, summary: RoundSummary, message: DialogueMessage) -> None:
"""Process a single dialogue message emitted by an agent."""
def apply_revision(self, summary: RoundSummary, revision: BeliefRevision) -> None:
"""Register a belief revision triggered by debate or new evidence."""
def finalize_round(self, summary: RoundSummary) -> None:
"""Mark the round as resolved and perform protocol-specific bookkeeping."""
def close(self, trace: DialogueTrace) -> None:
"""Finish the negotiation session and emit protocol artifacts."""
class ProtocolHost(GameProtocol):
"""Base implementation for agenda-driven negotiation protocols."""
def __init__(self) -> None:
self._current_round: Optional[RoundSummary] = None
self._trace: Optional[DialogueTrace] = None
self._round_index: int = 0
def bootstrap(self, trace: DialogueTrace) -> None:
trace.metadata.setdefault("host_started", True)
self._trace = trace
self._round_index = len(trace.rounds)
def bootstrap_trace(
self,
*,
session_id: str,
ts_code: str,
trade_date: str,
metadata: Optional[Dict[str, Any]] = None,
) -> DialogueTrace:
trace = DialogueTrace(
session_id=session_id,
ts_code=ts_code,
trade_date=trade_date,
metadata=dict(metadata or {}),
)
self.bootstrap(trace)
return trace
def start_round(
self,
trace: DialogueTrace,
agenda: str,
structure: GameStructure,
) -> RoundSummary:
index = self._round_index
summary = RoundSummary(
index=index,
agenda=agenda,
structure=structure,
resolved=False,
)
trace.rounds.append(summary)
self._current_round = summary
self._round_index += 1
return summary
def handle_message(self, summary: RoundSummary, message: DialogueMessage) -> None:
summary.messages.append(message)
def apply_revision(self, summary: RoundSummary, revision: BeliefRevision) -> None:
summary.revisions.append(revision)
def finalize_round(self, summary: RoundSummary) -> None:
summary.resolved = True
summary.notes.setdefault("message_count", len(summary.messages))
self._current_round = None
def close(self, trace: DialogueTrace) -> None:
trace.metadata.setdefault("host_finished", True)
self._trace = trace
def current_round(self) -> Optional[RoundSummary]:
return self._current_round
@property
def trace(self) -> Optional[DialogueTrace]:
return self._trace
def ensure_round(
self,
trace: DialogueTrace,
agenda: str,
structure: GameStructure,
) -> RoundSummary:
if self._current_round and not self._current_round.resolved:
return self._current_round
return self.start_round(trace, agenda, structure)
def round_to_dict(summary: RoundSummary) -> Dict[str, Any]:
"""Serialize a round summary for persistence layers."""
return summary.to_dict()
__all__ = [
"GameStructure",
"DialogueRole",
"MessageType",
"DialogueMessage",
"BeliefSnapshot",
"BeliefRevision",
"RoundSummary",
"DialogueTrace",
"GameProtocol",
"ProtocolHost",
"round_to_dict",
]

View File

@ -4,6 +4,35 @@ from __future__ import annotations
from .base import Agent, AgentAction, AgentContext
class RiskRecommendation:
"""Represents structured recommendation from the risk agent."""
__slots__ = ("status", "reason", "recommended_action", "notes")
def __init__(
self,
*,
status: str,
reason: str,
recommended_action: AgentAction | None = None,
notes: dict | None = None,
) -> None:
self.status = status
self.reason = reason
self.recommended_action = recommended_action
self.notes = notes or {}
def to_dict(self) -> dict:
payload = {
"status": self.status,
"reason": self.reason,
"notes": dict(self.notes),
}
if self.recommended_action is not None:
payload["recommended_action"] = self.recommended_action.value
return payload
class RiskAgent(Agent):
def __init__(self) -> None:
super().__init__(name="A_risk")
@ -27,3 +56,74 @@ class RiskAgent(Agent):
if context.features.get("position_limit", False) and action in (AgentAction.BUY_M, AgentAction.BUY_L):
return False
return True
def assess(
self,
context: AgentContext,
decision_action: AgentAction,
conflict_flag: bool,
) -> RiskRecommendation:
features = dict(context.features or {})
risk_penalty = float(features.get("risk_penalty") or 0.0)
if bool(features.get("is_suspended")):
return RiskRecommendation(
status="blocked",
reason="suspended",
recommended_action=AgentAction.HOLD,
notes={"trigger": "is_suspended"},
)
if bool(features.get("limit_up")) and decision_action in {
AgentAction.BUY_S,
AgentAction.BUY_M,
AgentAction.BUY_L,
}:
return RiskRecommendation(
status="blocked",
reason="limit_up",
recommended_action=AgentAction.HOLD,
notes={"trigger": "limit_up"},
)
if bool(features.get("position_limit")) and decision_action in {
AgentAction.BUY_M,
AgentAction.BUY_L,
}:
return RiskRecommendation(
status="pending_review",
reason="position_limit",
recommended_action=AgentAction.BUY_S,
notes={"trigger": "position_limit"},
)
if risk_penalty >= 0.9 and decision_action in {
AgentAction.BUY_S,
AgentAction.BUY_M,
AgentAction.BUY_L,
}:
return RiskRecommendation(
status="blocked",
reason="risk_penalty_extreme",
recommended_action=AgentAction.HOLD,
notes={"risk_penalty": risk_penalty},
)
if risk_penalty >= 0.7 and decision_action in {
AgentAction.BUY_S,
AgentAction.BUY_M,
AgentAction.BUY_L,
}:
return RiskRecommendation(
status="pending_review",
reason="risk_penalty_high",
recommended_action=AgentAction.HOLD,
notes={"risk_penalty": risk_penalty},
)
if conflict_flag:
return RiskRecommendation(
status="pending_review",
reason="conflict_threshold",
)
return RiskRecommendation(status="ok", reason="clear")

View File

@ -8,7 +8,8 @@ from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple
from app.agents.base import AgentAction, AgentContext
from app.agents.departments import DepartmentManager
from app.agents.game import Decision, decide
from app.agents.game import Decision, decide, target_weight_for_action
from app.agents.protocols import round_to_dict
from app.llm.metrics import record_decision as metrics_record_decision
from app.agents.registry import default_agents
from app.data.schema import initialize_database
@ -16,6 +17,7 @@ from app.utils.data_access import DataBroker
from app.utils.config import get_config
from app.utils.db import db_session
from app.utils.logging import get_logger
from app.utils import alerts
from app.core.indicators import momentum, normalize, rolling_mean, volatility
@ -436,6 +438,7 @@ class BacktestEngine:
)
)
round_payload = [round_to_dict(summary) for summary in decision.rounds]
global_payload = {
"_confidence": decision.confidence,
"_target_weight": decision.target_weight,
@ -459,6 +462,12 @@ class BacktestEngine:
for code, dept in decision.department_decisions.items()
if dept.telemetry
},
"_rounds": round_payload,
"_risk_assessment": (
decision.risk_assessment.to_dict()
if decision.risk_assessment
else None
),
}
rows.append(
(
@ -543,18 +552,42 @@ class BacktestEngine:
executed_trades: List[Dict[str, Any]] = []
risk_events: List[Dict[str, Any]] = []
def _record_risk(ts_code: str, reason: str, decision: Decision, extra: Optional[Dict[str, Any]] = None) -> None:
def _record_risk(
ts_code: str,
reason: str,
decision: Decision,
extra: Optional[Dict[str, Any]] = None,
action_override: Optional[AgentAction] = None,
target_weight_override: Optional[float] = None,
) -> None:
payload = {
"trade_date": trade_date_str,
"ts_code": ts_code,
"action": decision.action.value,
"target_weight": decision.target_weight,
"action": (action_override or decision.action).value,
"target_weight": (
target_weight_override
if target_weight_override is not None
else decision.target_weight
),
"confidence": decision.confidence,
"reason": reason,
}
if extra:
payload.update(extra)
risk_events.append(payload)
risk_meta = payload.get("risk_assessment") if isinstance(payload.get("risk_assessment"), dict) else extra.get("risk_assessment") if extra else None
status = None
if isinstance(risk_meta, dict):
status = risk_meta.get("status")
if status == "blocked":
try:
alerts.add_warning(
"backtest_risk",
f"{ts_code} 风险阻断: {reason}",
detail=json.dumps(payload, ensure_ascii=False),
)
except Exception: # noqa: BLE001
LOGGER.debug("记录风险告警失败", extra=LOG_EXTRA)
for ts_code, decision in decisions_map.items():
price = price_map.get(ts_code)
@ -569,35 +602,54 @@ class BacktestEngine:
limit_down = bool(features.get("limit_down"))
position_limit = bool(features.get("position_limit"))
risk = decision.risk_assessment
effective_action = decision.action
effective_weight = decision.target_weight
if risk:
risk_payload = risk.to_dict()
risk_payload.setdefault("applied_action", effective_action.value)
if risk.recommended_action:
effective_action = risk.recommended_action
risk_payload["applied_action"] = effective_action.value
effective_weight = target_weight_for_action(effective_action)
if risk.status != "ok":
_record_risk(
ts_code,
risk.reason,
decision,
extra={"risk_assessment": risk_payload},
action_override=effective_action,
target_weight_override=effective_weight,
)
if risk.status == "blocked":
continue
if is_suspended:
_record_risk(ts_code, "suspended", decision)
continue
if decision.action in self._buy_actions:
if effective_action in self._buy_actions:
if limit_up:
_record_risk(ts_code, "limit_up", decision)
_record_risk(ts_code, "limit_up", decision, action_override=effective_action)
continue
if position_limit:
_record_risk(ts_code, "position_limit", decision)
_record_risk(ts_code, "position_limit", decision, action_override=effective_action)
continue
if risk_penalty >= 0.95:
_record_risk(ts_code, "risk_penalty", decision, {"risk_penalty": risk_penalty})
continue
if decision.action in self._sell_actions and limit_down:
_record_risk(ts_code, "limit_down", decision)
if effective_action in self._sell_actions and limit_down:
_record_risk(ts_code, "limit_down", decision, action_override=effective_action)
continue
effective_weight = max(decision.target_weight, 0.0)
if decision.action in self._buy_actions:
effective_weight_value = max(effective_weight, 0.0)
if effective_action in self._buy_actions:
capped_weight = min(effective_weight, self.risk_params["max_position_weight"])
effective_weight = capped_weight * max(0.0, 1.0 - risk_penalty)
elif decision.action in self._sell_actions:
effective_weight = 0.0
effective_weight_value = capped_weight * max(0.0, 1.0 - risk_penalty)
elif effective_action in self._sell_actions:
effective_weight_value = 0.0
desired_qty = current_qty
if decision.action in self._sell_actions:
if effective_action in self._sell_actions:
desired_qty = 0.0
elif decision.action in self._buy_actions or effective_weight >= 0.0:
desired_value = max(effective_weight, 0.0) * portfolio_value_before
elif effective_action in self._buy_actions or effective_weight_value > 0.0:
desired_value = max(effective_weight_value, 0.0) * portfolio_value_before
desired_qty = desired_value / price if price > 0 else current_qty
delta = desired_qty - current_qty
@ -654,7 +706,8 @@ class BacktestEngine:
"slippage": trade_price - price,
"confidence": decision.confidence,
"target_weight": decision.target_weight,
"effective_weight": effective_weight,
"effective_weight": effective_weight_value,
"effective_action": effective_action.value,
"risk_penalty": risk_penalty,
"liquidity_score": liquidity_score,
"status": "executed",
@ -694,7 +747,8 @@ class BacktestEngine:
"slippage": price - trade_price,
"confidence": decision.confidence,
"target_weight": decision.target_weight,
"effective_weight": effective_weight,
"effective_weight": effective_weight_value,
"effective_action": effective_action.value,
"risk_penalty": risk_penalty,
"liquidity_score": liquidity_score,
"realized_pnl": realized,

View File

@ -201,6 +201,7 @@ def render_backtest_review() -> None:
selected_ids = [label.split(" | ")[0].strip() for label in selected_labels]
nav_df = pd.DataFrame()
rpt_df = pd.DataFrame()
risk_df = pd.DataFrame()
if selected_ids:
try:
with db_session(read_only=True) as conn:
@ -214,11 +215,20 @@ def render_backtest_review() -> None:
conn,
params=tuple(selected_ids),
)
risk_df = pd.read_sql_query(
"SELECT cfg_id, trade_date, ts_code, reason, action, target_weight, confidence, metadata "
"FROM bt_risk_events WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))),
conn,
params=tuple(selected_ids),
)
except Exception: # noqa: BLE001
LOGGER.exception("读取回测结果失败", extra=LOG_EXTRA)
st.error("读取回测结果失败")
nav_df = pd.DataFrame()
rpt_df = pd.DataFrame()
risk_df = pd.DataFrame()
start_filter: Optional[date] = None
end_filter: Optional[date] = None
if not nav_df.empty:
try:
nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"], errors="coerce")
@ -274,6 +284,9 @@ def render_backtest_review() -> None:
"交易数": summary.get("trade_count"),
"平均换手": summary.get("avg_turnover"),
"风险事件": summary.get("risk_events"),
"风险分布": json.dumps(summary.get("risk_breakdown"), ensure_ascii=False)
if summary.get("risk_breakdown")
else None,
}
metrics_rows.append({k: v for k, v in record.items() if (k == "cfg_id" or k in selected_metrics)})
if metrics_rows:
@ -291,6 +304,70 @@ def render_backtest_review() -> None:
pass
except Exception: # noqa: BLE001
LOGGER.debug("渲染指标表失败", extra=LOG_EXTRA)
if not risk_df.empty:
try:
risk_df["trade_date"] = pd.to_datetime(risk_df["trade_date"], errors="coerce")
risk_df = risk_df.dropna(subset=["trade_date"])
if start_filter is None or end_filter is None:
start_filter = pd.to_datetime(risk_df["trade_date"].min()).date()
end_filter = pd.to_datetime(risk_df["trade_date"].max()).date()
risk_df = risk_df[
(risk_df["trade_date"].dt.date >= start_filter)
& (risk_df["trade_date"].dt.date <= end_filter)
]
parsed_cols: List[Dict[str, object]] = []
for _, row in risk_df.iterrows():
try:
metadata = json.loads(row["metadata"]) if isinstance(row["metadata"], str) else (row["metadata"] or {})
except json.JSONDecodeError:
metadata = {}
assessment = metadata.get("risk_assessment") or {}
parsed_cols.append(
{
"cfg_id": row["cfg_id"],
"trade_date": row["trade_date"].date().isoformat(),
"ts_code": row["ts_code"],
"reason": row["reason"],
"action": row["action"],
"target_weight": row["target_weight"],
"confidence": row["confidence"],
"risk_status": assessment.get("status"),
"recommended_action": assessment.get("recommended_action"),
"execution_status": metadata.get("execution_status"),
"metadata": metadata,
}
)
risk_detail_df = pd.DataFrame(parsed_cols)
with st.expander("风险事件明细", expanded=False):
st.dataframe(risk_detail_df.drop(columns=["metadata"], errors="ignore"), hide_index=True, width='stretch')
try:
st.download_button(
"下载风险事件(CSV)",
data=risk_detail_df.to_csv(index=False),
file_name="bt_risk_events.csv",
mime="text/csv",
key="dl_risk_events",
)
except Exception:
pass
agg = risk_detail_df.groupby(["cfg_id", "reason", "risk_status"], dropna=False).size().reset_index(name="count")
st.dataframe(agg, hide_index=True, width='stretch')
try:
if not agg.empty:
agg_fig = px.bar(
agg,
x="reason",
y="count",
color="risk_status",
facet_col="cfg_id",
title="风险事件分布",
)
agg_fig.update_layout(height=320, margin=dict(l=10, r=10, t=40, b=20))
st.plotly_chart(agg_fig, use_container_width=True)
except Exception: # noqa: BLE001
LOGGER.debug("绘制风险事件分布失败", extra=LOG_EXTRA)
except Exception: # noqa: BLE001
LOGGER.debug("渲染风险事件失败", extra=LOG_EXTRA)
else:
st.info("请选择至少一个配置进行对比。")

View File

@ -14,6 +14,7 @@ from .config import get_config
from .db import db_session
from .logging import get_logger
from app.core.indicators import momentum, normalize, rolling_mean, volatility
from app.utils.db_query import BrokerQueryEngine
# 延迟导入,避免循环依赖
collect_data_coverage = None
@ -60,6 +61,39 @@ def _safe_split(path: str) -> Tuple[str, str] | None:
return table, column
@dataclass
class _RefreshCoordinator:
"""Orchestrates background refresh requests for the broker."""
broker: "DataBroker"
def ensure_for_latest(self, trade_date: str, fields: Iterable[str]) -> None:
parsed_date = _parse_trade_date(trade_date)
if not parsed_date:
return
normalized = parsed_date.strftime("%Y%m%d")
tables = self._collect_tables(fields)
if tables and self.broker.check_data_availability(normalized, tables):
self.broker._trigger_background_refresh(normalized)
def ensure_for_series(self, end_date: str, table: str) -> None:
parsed_date = _parse_trade_date(end_date)
if not parsed_date:
return
normalized = parsed_date.strftime("%Y%m%d")
if self.broker.check_data_availability(normalized, {table}):
self.broker._trigger_background_refresh(normalized)
def _collect_tables(self, fields: Iterable[str]) -> Set[str]:
tables: Set[str] = set()
for field_name in fields:
resolved = self.broker.resolve_field(field_name)
if resolved:
table, _ = resolved
tables.add(table)
return tables
def parse_field_path(path: str) -> Tuple[str, str] | None:
"""Validate and split a `table.column` field expression."""
@ -87,6 +121,17 @@ def _end_of_day(dt: datetime) -> str:
return dt.strftime("%Y-%m-%d 23:59:59")
def _coerce_date(value: object) -> Optional[date]:
if value is None:
return None
if isinstance(value, date):
return value
parsed = _parse_trade_date(value)
if parsed:
return parsed.date()
return None
@dataclass
class DataBroker:
"""Lightweight data access helper with automated data fetching capabilities."""
@ -130,6 +175,8 @@ class DataBroker:
_refresh_in_progress: Dict[str, bool] = field(init=False, repr=False)
_refresh_callbacks: Dict[str, List[Callable]] = field(init=False, repr=False)
_coverage_cache: Dict[str, Dict] = field(init=False, repr=False)
_refresh: _RefreshCoordinator = field(init=False, repr=False)
_query_engine: BrokerQueryEngine = field(init=False, repr=False)
def __post_init__(self) -> None:
self._latest_cache = OrderedDict()
@ -139,6 +186,8 @@ class DataBroker:
self._refresh_in_progress = {}
self._refresh_callbacks = {}
self._coverage_cache = {}
self._refresh = _RefreshCoordinator(self)
self._query_engine = BrokerQueryEngine(db_session)
if initialize_database is not None:
initialize_database() # 确保数据库已初始化
else:
@ -169,22 +218,7 @@ class DataBroker:
# 检查是否需要自动补数
if auto_refresh:
# 解析交易日以确定是否需要补数
parsed_date = _parse_trade_date(trade_date)
if parsed_date:
# 检查最近交易日的数据是否存在
recent_trade_date = parsed_date.strftime('%Y%m%d')
# 对涉及的表进行数据可用性检查
tables = set()
for field_name in field_list:
resolved = self.resolve_field(field_name)
if resolved:
table, _ = resolved
tables.add(table)
if tables and self.check_data_availability(recent_trade_date, tables):
# 数据不足,触发后台补数
self._trigger_background_refresh(recent_trade_date)
self._refresh.ensure_for_latest(trade_date, field_list)
grouped: Dict[str, List[str]] = {}
field_map: Dict[Tuple[str, str], List[str]] = {}
@ -208,59 +242,41 @@ class DataBroker:
grouped[table].append(column)
field_map.setdefault((table, column), []).append(field_name)
if not grouped:
if cache_key is not None and results:
self._cache_store(
self._latest_cache,
cache_key,
deepcopy(results),
self.latest_cache_size,
)
return results
try:
with db_session(read_only=True) as conn:
for table, columns in grouped.items():
joined_cols = ", ".join(columns)
query = (
f"SELECT trade_date, {joined_cols} FROM {table} "
"WHERE ts_code = ? AND trade_date <= ? "
"ORDER BY trade_date DESC LIMIT 1"
)
try:
row = conn.execute(query, (ts_code, trade_date)).fetchone()
except Exception as exc: # noqa: BLE001
LOGGER.debug(
"查询失败 table=%s fields=%s err=%s",
table,
columns,
exc,
extra=LOG_EXTRA,
)
continue
if not row:
continue
for column in columns:
value = row[column]
if value is None:
continue
for original in field_map.get((table, column), [f"{table}.{column}"]):
try:
results[original] = float(value)
except (TypeError, ValueError):
results[original] = value
except sqlite3.OperationalError as exc:
LOGGER.debug("数据库只读连接失败:%s", exc, extra=LOG_EXTRA)
if cache_key is not None:
cached = self._cache_lookup(self._latest_cache, cache_key)
if cached is not None:
if grouped:
for table, columns in grouped.items():
try:
row = self._query_engine.fetch_latest(table, ts_code, trade_date, columns)
except Exception as exc: # noqa: BLE001
LOGGER.debug(
"使用缓存结果 ts_code=%s trade_date=%s",
ts_code,
trade_date,
"查询失败 table=%s fields=%s err=%s",
table,
columns,
exc,
extra=LOG_EXTRA,
)
return deepcopy(cached)
continue
if not row:
continue
for column in columns:
value = row[column]
if value is None:
continue
for original in field_map.get((table, column), [f"{table}.{column}"]):
try:
results[original] = float(value)
except (TypeError, ValueError):
results[original] = value
if cache_key is not None and not results:
cached = self._cache_lookup(self._latest_cache, cache_key)
if cached is not None:
LOGGER.debug(
"使用缓存结果 ts_code=%s trade_date=%s",
ts_code,
trade_date,
extra=LOG_EXTRA,
)
return deepcopy(cached)
if cache_key is not None and results:
self._cache_store(
self._latest_cache,
@ -306,9 +322,7 @@ class DataBroker:
# 检查是否需要自动补数
if auto_refresh:
parsed_date = _parse_trade_date(end_date)
if parsed_date and self.check_data_availability(end_date, {table}):
self._trigger_background_refresh(end_date)
self._refresh.ensure_for_series(end_date, table)
cache_key: Optional[Tuple[Any, ...]] = None
if self.enable_cache:
@ -323,21 +337,10 @@ class DataBroker:
"ORDER BY trade_date DESC LIMIT ?"
)
try:
with db_session(read_only=True) as conn:
try:
rows = conn.execute(query, (ts_code, end_date, window)).fetchall()
except Exception as exc: # noqa: BLE001
LOGGER.debug(
"时间序列查询失败 table=%s column=%s err=%s",
table,
column,
exc,
extra=LOG_EXTRA,
)
return []
except sqlite3.OperationalError as exc:
rows = self._query_engine.fetch_series(table, resolved, ts_code, end_date, window)
except Exception as exc: # noqa: BLE001
LOGGER.debug(
"时间序列连接失败 table=%s column=%s err=%s",
"时间序列查询失败 table=%s column=%s err=%s",
table,
column,
exc,
@ -358,9 +361,13 @@ class DataBroker:
series: List[Tuple[str, float]] = []
for row in rows:
value = row[resolved]
if value is None:
trade_dt = row["trade_date"]
if value is None or trade_dt is None:
continue
try:
series.append((trade_dt, float(value)))
except (TypeError, ValueError):
continue
series.append((row["trade_date"], float(value)))
if cache_key is not None and series:
self._cache_store(
self._series_cache,
@ -370,6 +377,32 @@ class DataBroker:
)
return series
def register_refresh_callback(
self,
start: date | str,
end: date | str,
callback: Callable[[], None],
) -> None:
"""Register a hook invoked after background refresh completes for the window."""
if callback is None:
return
start_date = _coerce_date(start)
end_date = _coerce_date(end)
if not start_date or not end_date:
LOGGER.debug(
"忽略无效补数回调窗口 start=%s end=%s",
start,
end,
extra=LOG_EXTRA,
)
return
key = f"{start_date}_{end_date}"
with self._refresh_lock:
bucket = self._refresh_callbacks.setdefault(key, [])
if callback not in bucket:
bucket.append(callback)
def fetch_flags(
self,
table: str,
@ -436,48 +469,19 @@ class DataBroker:
LOGGER.debug("表不存在或无字段 table=%s", table, extra=LOG_EXTRA)
return []
column_list = ", ".join(columns)
has_trade_date = "trade_date" in columns
if has_trade_date:
query = (
f"SELECT {column_list} FROM {table} "
"WHERE ts_code = ? AND trade_date <= ? "
"ORDER BY trade_date DESC LIMIT ?"
)
params: Tuple[object, ...] = (ts_code, trade_date, window)
else:
query = (
f"SELECT {column_list} FROM {table} "
"WHERE ts_code = ? ORDER BY rowid DESC LIMIT ?"
)
params = (ts_code, window)
results: List[Dict[str, object]] = []
try:
with db_session(read_only=True) as conn:
try:
rows = conn.execute(query, params).fetchall()
except Exception as exc: # noqa: BLE001
LOGGER.debug(
"表查询失败 table=%s err=%s",
table,
exc,
extra=LOG_EXTRA,
)
return []
except sqlite3.OperationalError as exc:
LOGGER.debug(
"表连接失败 table=%s err=%s",
rows = self._query_engine.fetch_table(
table,
exc,
extra=LOG_EXTRA,
columns,
ts_code,
trade_date if "trade_date" in columns else None,
window,
)
except Exception as exc: # noqa: BLE001
LOGGER.debug("表查询失败 table=%s err=%s", table, exc, extra=LOG_EXTRA)
return []
for row in rows:
record = {col: row[col] for col in columns}
results.append(record)
return results
return [{col: row[col] for col in columns} for row in rows]
def _resolve_derived_field(
self,
@ -924,13 +928,20 @@ class DataBroker:
with self._refresh_lock:
callbacks = self._refresh_callbacks.pop(refresh_key, [])
self._refresh_in_progress[refresh_key] = False
if callbacks:
LOGGER.info(
"执行补数回调 count=%s key=%s",
len(callbacks),
refresh_key,
extra=LOG_EXTRA,
)
for callback in callbacks:
try:
callback()
except Exception as exc:
LOGGER.exception("补数回调执行失败: %s", exc, extra=LOG_EXTRA)
except Exception as exc:
LOGGER.exception("后台数据补数失败: %s", exc, extra=LOG_EXTRA)
with self._refresh_lock:

73
app/utils/db_query.py Normal file
View File

@ -0,0 +1,73 @@
"""Shared read-only query helpers for database access."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Iterable, List, Mapping, Optional, Sequence
@dataclass
class BrokerQueryEngine:
"""Lightweight wrapper around standard query patterns."""
session_factory: Callable[..., object]
def fetch_latest(
self,
table: str,
ts_code: str,
trade_date: str,
columns: Sequence[str],
) -> Optional[Mapping[str, object]]:
if not columns:
return None
joined_cols = ", ".join(columns)
query = (
f"SELECT trade_date, {joined_cols} FROM {table} "
"WHERE ts_code = ? AND trade_date <= ? "
"ORDER BY trade_date DESC LIMIT 1"
)
with self.session_factory(read_only=True) as conn:
return conn.execute(query, (ts_code, trade_date)).fetchone()
def fetch_series(
self,
table: str,
column: str,
ts_code: str,
end_date: str,
limit: int,
) -> List[Mapping[str, object]]:
query = (
f"SELECT trade_date, {column} FROM {table} "
"WHERE ts_code = ? AND trade_date <= ? "
"ORDER BY trade_date DESC LIMIT ?"
)
with self.session_factory(read_only=True) as conn:
rows = conn.execute(query, (ts_code, end_date, limit)).fetchall()
return list(rows)
def fetch_table(
self,
table: str,
columns: Iterable[str],
ts_code: str,
trade_date: Optional[str],
limit: int,
) -> List[Mapping[str, object]]:
cols = ", ".join(columns)
if trade_date is None:
query = (
f"SELECT {cols} FROM {table} "
"WHERE ts_code = ? ORDER BY rowid DESC LIMIT ?"
)
params: Sequence[object] = (ts_code, limit)
else:
query = (
f"SELECT {cols} FROM {table} "
"WHERE ts_code = ? AND trade_date <= ? "
"ORDER BY trade_date DESC LIMIT ?"
)
params = (ts_code, trade_date, limit)
with self.session_factory(read_only=True) as conn:
rows = conn.execute(query, params).fetchall()
return list(rows)

41
docs/RISK_AGENT_PLAN.md Normal file
View File

@ -0,0 +1,41 @@
# 风险代理集成规划
## 目标
- 将 `risk_guard` 回合由占位消息升级为可执行决策:低置信度的部门共识需自动进入复核流程,根据风险因子、仓位约束、外部告警调整最终指令。
- 兼容历史回测与实时监控,确保补数、验证和告警闭环共享风险上下文。
## 数据与信号需求
- **输入特征**`risk_penalty`、`position_limit`、`is_suspended` 等现有布尔/得分信号新增日内波动、VaR、行业集中度、外部事件标签停牌预警、合规黑名单
- **实时事件**:补数触发的异常、执行失败回报、仓位越界日志;需要通过 `DataBroker.register_refresh_callback()` 和交易执行层的回调管理。
- **上下文结构**:在 `AgentContext.raw` 中携带 `risk_flags`、`compliance_notes`,供主持器和前端展示。
## 决策流程改造
1. **风险评估阶段**
- 在部门回合后触发 `risk_round`,由 `RiskAgent` 复写 `Decision``requires_review``target_weight`,必要时生成“回滚/减仓”建议。
- 引入 `RiskAssessment` 数据类,存储风险来源、建议操作、置信度,序列化到 `Decision.rounds`
2. **执行协调阶段**
- 若风险回合给出回滚指令,则 `execution_round` 应记录“冻结执行”或“调整仓位”而非直接落地。
- 将风险回合结论写入 `risk_events` 表(或新建),供 UI 与监控使用。
3. **日志与监控**
- 在 `risk_round` 中附加 `annotations` 字段:`{"breach_metrics": {...}, "actions": [...]}`。
- 通过补数回调和执行回调,将风险事件推送到监控指标,如“复核触发率”“回滚成功率”。
## 实施里程碑
1. **原型阶段**
- ✅ 重构 `RiskAgent` 以返回策略建议(持仓调整、止损触发)。
- ✅ 扩展 `Decision` 结构,增加 `risk_assessment` 字段。
- ✅ 更新 `ProtocolHost` 将风险建议纳入 `risk_round.notes`
2. **集成验证**
- 在回测环境构造冲突样例,验证“冲突→风险建议→执行调整”链路。
- 新增测试:模拟停牌、仓位超限、黑名单事件,确认风险代理逻辑。
3. **实盘准备**
- 接入实时告警渠道(如风控系统/合规接口)。
- 监控接入:统计复核频率、回滚动作、失败告警。
4. **上线迭代**
- 影子运行记录风险建议与实际执行差异。
- 评估模型表现,迭代规则或引入强化学习风险控制策略。
## 未决事项
- 风险代理与执行模块的数据同步接口(数据库/消息队列)。
- 与合规团队确认风险阈值与回滚条件。
- 定义UI展示风险回合的建议、证据引用、执行结果。

View File

@ -27,6 +27,75 @@
- 构建实时持仓/成交数据写入链路,使线上监控与离线调参共用同一数据源。
- 借鉴 TradingAgents-CN 的做法:拆分环境与策略、提供训练脚本/配置,并输出丰富的评估指标(如 Sharpe、Sortino、基准对比
- 完善 `BacktestEngine` 的成交撮合、风险阈值与指标输出,让回测信号直接对接执行端,形成无人值守的自动闭环。
- 新增多智能体多轮逻辑博弈线路:把投资场景建成重复博弈,设定主持、预测、风险等角色分轮交换信号、提出假设、相互反驳,再由执行代理落地操作;通过信念修正规则与合规约束驱动策略共识,而非依赖优化搜索。
### 3.2 多智能体多轮逻辑博弈实施方案
1. 角色与知识库建模
- 定义主持(议程控制)、预测(市场观点)、风险(合规阈值)、执行(指令生成)等核心代理,并划分共享/私有信息。
- 为每个代理建立信号源与知识库接口(行情、新闻、风险限额),设计可信度权重与更新频率。
- 预设重复博弈、信号博弈、贝叶斯博弈等多种结构模板,支持根据市场场景切换或组合,保障架构拓展性。
2. 对话协议与消息格式
- 设计多轮流程:议程发布→观点陈述→证据提交→反驳与驳回→共识决议→执行提示。
- 约定消息 schema动作类型、置信度、引用证据及主持代理的轮次控制逻辑确保所有代理均可追踪历史发言。
3. 信念修正与逻辑推理引擎
- 实现迭代信念修正模块,结合可信度权重与显式逻辑规则更新各代理的市场信念。
- 引入论证框架Argumentation Framework或模态逻辑规则库支持观点检验、冲突解决与风险约束自动否决。
4. 模拟与测试环境
- 搭建历史回测驱动的多轮博弈仿真环境,让代理在真实行情序列上互动。
- 编写回放脚本记录每轮对话、信念轨迹与最终行动,产出可视化报告用于策略审查。
5. 执行映射与风控集成
- 将共识决议映射到具体操作(换仓、对冲、仓位调节),并与 `BacktestEngine`/实盘执行接口打通。
- 风险代理负责实时校验 VaR、仓位集中度等阈值违规时触发主持代理重新议程或引入驳回指令。
6. 监控与评估指标
- 设计协作质量指标(轮次收敛时间、冲突率、执行合规度)和业绩指标(收益、回撤、超额收益稳定性)。
- 将指标纳入监控面板,与现有回测、实盘监控体系统一展示。
7. 增量迭代计划
- 从线下模块联调 → 回测闭环 → 影子运行 → 小资金试点的节奏推进,每阶段设定成功退出标准。
- 收集各阶段的失败案例与风险事件,迭代规则库、信念修正规则与角色职责。
### 3.3 代码改造计划(多轮博弈适配)
1. 架构基线评估
- ⏳ 绘制代理/部门/回测调用图,补充日志字段(缺数告警、补数来源、议程标识)并形成诊断报告。
- ✅ 定义多轮博弈上下文结构(消息历史、信念状态、引用证据),输出数据类与通信协议草稿。
- ✅ 在 `app/agents/protocols.py` 基础上补充主持/执行状态管理,实现 `DialogueTrace` 与部门上下文的对接路径。
- ✅ 扩展 `Decision.rounds``RoundSummary` 采集策略,用于串联部门结论与多轮议程结果。
- ✅ 基于 `ProtocolHost` 设计主持驱动的议程模板,明确多结构策略扩展方式。
- ✅ 扩展主持对风险复核、执行总结等议程的支持,规划风控代理与回滚逻辑。
- ✅ 落地风险议程触发条件(冲突阈值、风控代理信号、外部警报),同步更新回测策略。
- ✅ 结合 `docs/RISK_AGENT_PLAN.md` 明确 RiskAgent 升级里程碑、数据接口及前端展示需求。
- ✅ 定义 `risk_assessment` 数据结构(状态、原因、建议动作),在决策输出与存储层保持一致。
- ✅ 将风险建议与执行回合对齐,执行阶段识别 `risk_adjusted` 并记录原始动作。
2. 数据与因子重构
- ✅ 拆分 `DataBroker` 查询层(`BrokerQueryEngine`),补数逻辑独立于查询管道。
- ⏳ 按主题拆分因子模块,存储缺口/异常标签,改写 `load_market_data()` 为“缺失即说明”。
- ⏳ 维护博弈结构 → 数据 scope 映射,支持角色按结构加载差异化字段。
- ✅ 基于 `_RefreshCoordinator` 落地刷新队列与监控事件,拆分查询与补数路径。
- ✅ 暴露 `DataBroker.register_refresh_callback()` 钩子,结合监控系统记录补数进度与失败重试。
- ⏳ 统一补数回调日志格式(`LOG_EXTRA.stage=data_broker`),为后续指标预留数据源。
3. 多轮博弈框架
- ✅ 在 `app/agents/game.py` 抽象 `GameProtocol` 接口,扩展 `Decision` 记录多轮对话。
- ✅ 实现主持调度器驱动议程(信息→陈述→反驳→共识→执行),挂载风险复核机制。
- ⏳ 引入信念修正规则与论证框架,支持证据引用和冲突裁决。
4. 执行与回测集成
- ✅ 将回测循环改造成“每日多轮→执行摘要”,完成风控校验与冲突重议流程。
- ⏳ 擦合订单映射层,明确多轮结果对应目标仓位、执行节奏、异常回滚策略。
- ✅ 回测执行根据 `risk_assessment` 调整动作与目标权重,记录执行状态。
5. 监控与可视化
- ⏳ 在数据库/UI 增补轮次日志、信念轨迹、冲突原因与执行结果的可视化。
- ⏳ 定义多结构绩效指标并纳入监控面板(回测视图已提供风险事件统计与分布图)。
- ⏳ 补充补数回调监控告警策略(成功率、排队时长、重试次数)。
- ✅ 回测视图展示 `_risk_assessment``execution_status`,后续扩展到实时监控面板。
- ⏳ 完善风险告警(`alerts.backtest_risk`)阈值与通知策略并扩展至实盘。
6. 测试与迭代治理
- ⏳ 构建历史行情驱动的多轮仿真回放,覆盖结构切换与异常场景。
- 新增单元/集成测试:议程流程、信念修正、补数竞态、执行映射。
- ✅ 风险执行集成测试完成(`tests/test_decision_risk_integration.py`),验证风险调节与执行日志。
- 规划上线节奏:线下 PoC → 回测闭环 → 影子运行 → 小资金试点,记录失败案例反馈改进。
- 在测试计划中涵盖风险/执行议程的触发条件与回滚路径,确保主持扩展的行为可验证。
- 新增针对风险复核和执行议程的集成测试场景,覆盖“冲突→复核→执行/回滚”链路。
- 对照 `docs/RISK_AGENT_PLAN.md` 拆解阶段性验收项:数据校验、监控指标、影子运行报告。
- 为 `risk_assessment` 输出设计单元测试(冲突触发、无冲突、外部告警模拟),确保状态与原因正确覆盖。
- 编写执行阶段测试,验证 `risk_adjusted` 状态下的目标仓位与日志记录。
### 3.1 实施步骤(建议顺序)
1. 环境重构:扩展 `DecisionEnv` 支持逐日状态/动作/奖励,完善 `BacktestEngine` 的状态保存与恢复接口,并补充必要的数据库读写钩子。

View File

@ -0,0 +1,119 @@
"""Integration-style tests for risk-aware decision execution."""
from __future__ import annotations
from datetime import date
import pytest
from app.agents.base import AgentAction, AgentContext
from app.agents.game import decide
from app.agents.registry import default_agents
from app.agents.risk import RiskAgent, RiskRecommendation
from app.backtest.engine import BacktestEngine, BacktestResult, BtConfig, PortfolioState
def _make_context(features: dict, *, alerts: list[str] | None = None) -> AgentContext:
raw = {
"scope_values": {"daily.close": 100.0},
}
if alerts:
raw["risk_alerts"] = alerts
return AgentContext(
ts_code="000001.SZ",
trade_date="2025-01-10",
features=features,
market_snapshot={},
raw=raw,
)
class _StubRiskAgent(RiskAgent):
def __init__(self, recommendation: RiskRecommendation) -> None:
super().__init__()
self._recommendation = recommendation
def assess(
self,
context: AgentContext,
decision_action: AgentAction,
conflict_flag: bool,
) -> RiskRecommendation:
return self._recommendation
def test_decide_adjusts_execution_on_risk_recommendation(monkeypatch):
agents = default_agents()
recommendation = RiskRecommendation(
status="blocked",
reason="risk_penalty_extreme",
recommended_action=AgentAction.HOLD,
notes={"risk_penalty": 0.95},
)
stub_risk = _StubRiskAgent(recommendation)
agents = [stub_risk if isinstance(agent, RiskAgent) else agent for agent in agents]
context = _make_context({"risk_penalty": 0.95})
decision = decide(
context,
agents,
weights={agent.name: 1.0 for agent in agents},
department_manager=None,
)
assert decision.requires_review is True
assert decision.risk_assessment
assert decision.risk_assessment.status == "blocked"
assert decision.risk_assessment.recommended_action == AgentAction.HOLD
execution_rounds = [round for round in decision.rounds if round.agenda == "execution_summary"]
assert execution_rounds
execution_notes = execution_rounds[0].notes
assert execution_notes.get("execution_status") == "risk_adjusted"
assert execution_rounds[0].outcome == AgentAction.HOLD.value
def test_backtest_engine_applies_risk_adjusted_execution(monkeypatch):
cfg = BtConfig(
id="risk-test",
name="risk-test",
start_date=date(2025, 1, 10),
end_date=date(2025, 1, 10),
universe=["000001.SZ"],
params={},
)
engine = BacktestEngine(cfg)
state = PortfolioState(cash=100_000.0)
result = BacktestResult()
context = _make_context({"risk_penalty": 0.95})
recommendation = RiskRecommendation(
status="blocked",
reason="risk_penalty_extreme",
recommended_action=AgentAction.HOLD,
notes={"risk_penalty": 0.95},
)
agents = [
_StubRiskAgent(recommendation) if isinstance(agent, RiskAgent) else agent
for agent in engine.agents
]
engine.agents = agents
engine.department_manager = None
decision = decide(
context,
engine.agents,
engine.weights,
department_manager=None,
)
engine._apply_portfolio_updates(
date(2025, 1, 10),
state,
[("000001.SZ", context, decision)],
result,
)
assert not state.holdings
assert not result.trades
assert result.nav_series[0]["nav"] == pytest.approx(100_000.0)

63
tests/test_risk_agent.py Normal file
View File

@ -0,0 +1,63 @@
"""Tests for RiskAgent assessment and risk evaluation pipeline."""
from __future__ import annotations
from app.agents.base import AgentAction, AgentContext
from app.agents.game import _evaluate_risk
from app.agents.risk import RiskAgent, RiskRecommendation
def _make_context(**features: float | bool) -> AgentContext:
return AgentContext(
ts_code="000001.SZ",
trade_date="2025-01-01",
features=features,
market_snapshot={},
raw={},
)
def test_risk_agent_ok_status() -> None:
agent = RiskAgent()
context = _make_context(risk_penalty=0.1)
recommendation = agent.assess(context, AgentAction.BUY_S, conflict_flag=False)
assert recommendation.status == "ok"
assert recommendation.reason == "clear"
def test_risk_agent_blocked_on_limit_up() -> None:
agent = RiskAgent()
context = _make_context(limit_up=True)
recommendation = agent.assess(context, AgentAction.BUY_M, conflict_flag=False)
assert recommendation.status == "blocked"
assert recommendation.reason == "limit_up"
assert recommendation.recommended_action == AgentAction.HOLD
def test_risk_agent_pending_on_conflict() -> None:
agent = RiskAgent()
context = _make_context()
recommendation = agent.assess(context, AgentAction.HOLD, conflict_flag=True)
assert recommendation.status == "pending_review"
assert recommendation.reason == "conflict_threshold"
def test_evaluate_risk_external_alerts() -> None:
agent = RiskAgent()
context = AgentContext(
ts_code="000002.SZ",
trade_date="2025-01-01",
features={"risk_penalty": 0.1},
market_snapshot={},
raw={"risk_alerts": ["sudden_news"]},
)
assessment = _evaluate_risk(
context=context,
action=AgentAction.HOLD,
department_votes={"buy": 0.6},
conflict_flag=False,
risk_agent=agent,
)
assert assessment.status == "pending_review"
assert assessment.reason == "external_alert"
assert assessment.recommended_action == AgentAction.HOLD
assert "external_alerts" in assessment.notes