add risk assessment and protocol tracking to game engine
This commit is contained in:
parent
721d59d1cf
commit
3c15d443d3
@ -8,6 +8,15 @@ from typing import Dict, Iterable, List, Mapping, Optional, Tuple
|
||||
from .base import Agent, AgentAction, AgentContext, UtilityMatrix
|
||||
from .departments import DepartmentContext, DepartmentDecision, DepartmentManager
|
||||
from .registry import weight_map
|
||||
from .risk import RiskAgent, RiskRecommendation
|
||||
from .protocols import (
|
||||
DialogueMessage,
|
||||
DialogueRole,
|
||||
GameStructure,
|
||||
MessageType,
|
||||
ProtocolHost,
|
||||
RoundSummary,
|
||||
)
|
||||
|
||||
|
||||
ACTIONS: Tuple[AgentAction, ...] = (
|
||||
@ -23,6 +32,30 @@ def _clamp(value: float) -> float:
|
||||
return max(0.0, min(1.0, value))
|
||||
|
||||
|
||||
@dataclass
|
||||
class BeliefUpdate:
|
||||
belief: Dict[str, object]
|
||||
rationale: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RiskAssessment:
|
||||
status: str
|
||||
reason: str
|
||||
recommended_action: Optional[AgentAction] = None
|
||||
notes: Dict[str, object] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, object]:
|
||||
payload: Dict[str, object] = {
|
||||
"status": self.status,
|
||||
"reason": self.reason,
|
||||
"notes": dict(self.notes),
|
||||
}
|
||||
if self.recommended_action is not None:
|
||||
payload["recommended_action"] = self.recommended_action.value
|
||||
return payload
|
||||
|
||||
|
||||
@dataclass
|
||||
class Decision:
|
||||
action: AgentAction
|
||||
@ -33,6 +66,9 @@ class Decision:
|
||||
department_decisions: Dict[str, DepartmentDecision] = field(default_factory=dict)
|
||||
department_votes: Dict[str, float] = field(default_factory=dict)
|
||||
requires_review: bool = False
|
||||
rounds: List[RoundSummary] = field(default_factory=list)
|
||||
risk_assessment: Optional[RiskAssessment] = None
|
||||
belief_updates: Dict[str, BeliefUpdate] = field(default_factory=dict)
|
||||
|
||||
|
||||
def compute_utilities(agents: Iterable[Agent], context: AgentContext) -> UtilityMatrix:
|
||||
@ -132,6 +168,16 @@ def decide(
|
||||
raw_weights = dict(weights)
|
||||
department_decisions: Dict[str, DepartmentDecision] = {}
|
||||
department_votes: Dict[str, float] = {}
|
||||
host = ProtocolHost()
|
||||
host_trace = host.bootstrap_trace(
|
||||
session_id=f"{context.ts_code}:{context.trade_date}",
|
||||
ts_code=context.ts_code,
|
||||
trade_date=context.trade_date,
|
||||
)
|
||||
department_round: Optional[RoundSummary] = None
|
||||
risk_round: Optional[RoundSummary] = None
|
||||
execution_round: Optional[RoundSummary] = None
|
||||
belief_updates: Dict[str, BeliefUpdate] = {}
|
||||
|
||||
if department_manager:
|
||||
dept_context = department_context
|
||||
@ -144,6 +190,12 @@ def decide(
|
||||
raw=dict(getattr(context, "raw", {}) or {}),
|
||||
)
|
||||
department_decisions = department_manager.evaluate(dept_context)
|
||||
if department_decisions:
|
||||
department_round = host.start_round(
|
||||
host_trace,
|
||||
agenda="department_consensus",
|
||||
structure=GameStructure.REPEATED,
|
||||
)
|
||||
for code, decision in department_decisions.items():
|
||||
agent_key = f"dept_{code}"
|
||||
dept_agent = department_manager.agents.get(code)
|
||||
@ -155,6 +207,17 @@ def decide(
|
||||
bucket = _department_vote_bucket(decision.action)
|
||||
if bucket:
|
||||
department_votes[bucket] = department_votes.get(bucket, 0.0) + weight * decision.confidence
|
||||
if department_round:
|
||||
message = _department_message(code, decision)
|
||||
host.handle_message(department_round, message)
|
||||
belief_updates[code] = BeliefUpdate(
|
||||
belief={
|
||||
"action": decision.action.value,
|
||||
"confidence": decision.confidence,
|
||||
"signals": decision.signals,
|
||||
},
|
||||
rationale=decision.summary,
|
||||
)
|
||||
|
||||
filtered_utilities = {action: utilities[action] for action in feas_actions}
|
||||
hold_scores = utilities.get(AgentAction.HOLD, {})
|
||||
@ -168,7 +231,111 @@ def decide(
|
||||
action, confidence = vote(filtered_utilities, norm_weights)
|
||||
|
||||
weight = target_weight_for_action(action)
|
||||
requires_review = _department_conflict_flag(department_votes)
|
||||
conflict_flag = _department_conflict_flag(department_votes)
|
||||
|
||||
risk_agent = _find_risk_agent(agent_list)
|
||||
risk_assessment = _evaluate_risk(
|
||||
context,
|
||||
action,
|
||||
department_votes,
|
||||
conflict_flag,
|
||||
risk_agent,
|
||||
)
|
||||
requires_review = risk_assessment.status != "ok"
|
||||
|
||||
if department_round:
|
||||
department_round.notes.setdefault("department_votes", dict(department_votes))
|
||||
department_round.outcome = action.value
|
||||
host.finalize_round(department_round)
|
||||
|
||||
if requires_review:
|
||||
risk_round = host.ensure_round(
|
||||
host_trace,
|
||||
agenda="risk_review",
|
||||
structure=GameStructure.CUSTOM,
|
||||
)
|
||||
review_message = DialogueMessage(
|
||||
sender="risk_guard",
|
||||
role=DialogueRole.RISK,
|
||||
message_type=MessageType.COUNTER,
|
||||
content=_risk_review_message(risk_assessment.reason),
|
||||
confidence=1.0,
|
||||
references=list(department_votes.keys()),
|
||||
annotations={
|
||||
"department_votes": dict(department_votes),
|
||||
"risk_reason": risk_assessment.reason,
|
||||
"recommended_action": (
|
||||
risk_assessment.recommended_action.value
|
||||
if risk_assessment.recommended_action
|
||||
else None
|
||||
),
|
||||
"notes": dict(risk_assessment.notes),
|
||||
},
|
||||
)
|
||||
host.handle_message(risk_round, review_message)
|
||||
risk_round.notes.setdefault("status", risk_assessment.status)
|
||||
risk_round.notes.setdefault("reason", risk_assessment.reason)
|
||||
if risk_assessment.recommended_action:
|
||||
risk_round.notes.setdefault(
|
||||
"recommended_action",
|
||||
risk_assessment.recommended_action.value,
|
||||
)
|
||||
risk_round.outcome = "REVIEW"
|
||||
host.finalize_round(risk_round)
|
||||
belief_updates["risk_guard"] = BeliefUpdate(
|
||||
belief={
|
||||
"status": risk_assessment.status,
|
||||
"reason": risk_assessment.reason,
|
||||
"recommended_action": (
|
||||
risk_assessment.recommended_action.value
|
||||
if risk_assessment.recommended_action
|
||||
else None
|
||||
),
|
||||
},
|
||||
)
|
||||
execution_round = host.ensure_round(
|
||||
host_trace,
|
||||
agenda="execution_summary",
|
||||
structure=GameStructure.REPEATED,
|
||||
)
|
||||
exec_action = action
|
||||
exec_weight = weight
|
||||
exec_status = "normal"
|
||||
if requires_review and risk_assessment.recommended_action:
|
||||
exec_action = risk_assessment.recommended_action
|
||||
exec_status = "risk_adjusted"
|
||||
exec_weight = target_weight_for_action(exec_action)
|
||||
execution_message = DialogueMessage(
|
||||
sender="execution_engine",
|
||||
role=DialogueRole.EXECUTION,
|
||||
message_type=MessageType.DIRECTIVE,
|
||||
content=f"执行操作 {exec_action.value}",
|
||||
confidence=1.0,
|
||||
annotations={
|
||||
"target_weight": exec_weight,
|
||||
"requires_review": requires_review,
|
||||
"execution_status": exec_status,
|
||||
},
|
||||
)
|
||||
host.handle_message(execution_round, execution_message)
|
||||
execution_round.outcome = exec_action.value
|
||||
execution_round.notes.setdefault("execution_status", exec_status)
|
||||
if exec_action is not action:
|
||||
execution_round.notes.setdefault("original_action", action.value)
|
||||
belief_updates["execution"] = BeliefUpdate(
|
||||
belief={
|
||||
"execution_status": exec_status,
|
||||
"action": exec_action.value,
|
||||
"target_weight": exec_weight,
|
||||
},
|
||||
)
|
||||
host.finalize_round(execution_round)
|
||||
host.close(host_trace)
|
||||
rounds = host_trace.rounds if host_trace.rounds else _build_round_summaries(
|
||||
department_decisions,
|
||||
action,
|
||||
department_votes,
|
||||
)
|
||||
return Decision(
|
||||
action=action,
|
||||
confidence=confidence,
|
||||
@ -178,6 +345,9 @@ def decide(
|
||||
department_decisions=department_decisions,
|
||||
department_votes=department_votes,
|
||||
requires_review=requires_review,
|
||||
rounds=rounds,
|
||||
risk_assessment=risk_assessment,
|
||||
belief_updates=belief_updates,
|
||||
)
|
||||
|
||||
|
||||
@ -231,3 +401,145 @@ def _department_conflict_flag(votes: Mapping[str, float]) -> bool:
|
||||
if len(sorted_votes) >= 2 and (sorted_votes[0] - sorted_votes[1]) < total * 0.1:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _department_message(code: str, decision: DepartmentDecision) -> DialogueMessage:
|
||||
content = decision.summary or decision.raw_response or decision.action.value
|
||||
references = decision.signals or []
|
||||
annotations: Dict[str, object] = {
|
||||
"risks": decision.risks,
|
||||
"supplements": decision.supplements,
|
||||
}
|
||||
if decision.dialogue:
|
||||
annotations["dialogue"] = decision.dialogue
|
||||
if decision.telemetry:
|
||||
annotations["telemetry"] = decision.telemetry
|
||||
return DialogueMessage(
|
||||
sender=code,
|
||||
role=DialogueRole.PREDICTION,
|
||||
message_type=MessageType.DECISION,
|
||||
content=content,
|
||||
confidence=decision.confidence,
|
||||
references=references,
|
||||
annotations=annotations,
|
||||
)
|
||||
|
||||
|
||||
def _evaluate_risk(
|
||||
context: AgentContext,
|
||||
action: AgentAction,
|
||||
department_votes: Mapping[str, float],
|
||||
conflict_flag: bool,
|
||||
risk_agent: Optional[RiskAgent],
|
||||
) -> RiskAssessment:
|
||||
external_alerts = []
|
||||
if getattr(context, "raw", None):
|
||||
alerts = context.raw.get("risk_alerts", [])
|
||||
if alerts:
|
||||
external_alerts = list(alerts)
|
||||
|
||||
if risk_agent:
|
||||
recommendation = risk_agent.assess(context, action, conflict_flag)
|
||||
notes = dict(recommendation.notes)
|
||||
notes.setdefault("department_votes", dict(department_votes))
|
||||
if external_alerts:
|
||||
notes.setdefault("external_alerts", external_alerts)
|
||||
if recommendation.status == "ok":
|
||||
recommendation = RiskRecommendation(
|
||||
status="pending_review",
|
||||
reason="external_alert",
|
||||
recommended_action=recommendation.recommended_action or AgentAction.HOLD,
|
||||
notes=notes,
|
||||
)
|
||||
else:
|
||||
recommendation.notes = notes
|
||||
return RiskAssessment(
|
||||
status=recommendation.status,
|
||||
reason=recommendation.reason,
|
||||
recommended_action=recommendation.recommended_action,
|
||||
notes=recommendation.notes,
|
||||
)
|
||||
|
||||
notes: Dict[str, object] = {
|
||||
"conflict": conflict_flag,
|
||||
"department_votes": dict(department_votes),
|
||||
}
|
||||
if external_alerts:
|
||||
notes["external_alerts"] = external_alerts
|
||||
return RiskAssessment(
|
||||
status="pending_review",
|
||||
reason="external_alert",
|
||||
recommended_action=AgentAction.HOLD,
|
||||
notes=notes,
|
||||
)
|
||||
if conflict_flag:
|
||||
return RiskAssessment(
|
||||
status="pending_review",
|
||||
reason="conflict_threshold",
|
||||
notes=notes,
|
||||
)
|
||||
return RiskAssessment(status="ok", reason="clear", notes=notes)
|
||||
|
||||
|
||||
def _find_risk_agent(agents: Iterable[Agent]) -> Optional[RiskAgent]:
|
||||
for agent in agents:
|
||||
if isinstance(agent, RiskAgent):
|
||||
return agent
|
||||
return None
|
||||
|
||||
|
||||
def _risk_review_message(reason: str) -> str:
|
||||
mapping = {
|
||||
"conflict_threshold": "部门意见分歧,触发风险复核",
|
||||
"suspended": "标的停牌,需冻结执行",
|
||||
"limit_up": "标的涨停,执行需调整",
|
||||
"position_limit": "仓位限制已触发,需调整目标",
|
||||
"risk_penalty_extreme": "风险评分极高,建议暂停加仓",
|
||||
"risk_penalty_high": "风险评分偏高,建议复核",
|
||||
"external_alert": "外部风险告警触发复核",
|
||||
}
|
||||
return mapping.get(reason, "触发风险复核,需人工确认")
|
||||
|
||||
|
||||
def _build_round_summaries(
|
||||
department_decisions: Mapping[str, DepartmentDecision],
|
||||
final_action: AgentAction,
|
||||
department_votes: Mapping[str, float],
|
||||
) -> List[RoundSummary]:
|
||||
if not department_decisions:
|
||||
return []
|
||||
messages: List[DialogueMessage] = []
|
||||
for code, decision in department_decisions.items():
|
||||
content = decision.summary or decision.raw_response or decision.action.value
|
||||
references = decision.signals or []
|
||||
annotations: Dict[str, object] = {
|
||||
"risks": decision.risks,
|
||||
"supplements": decision.supplements,
|
||||
}
|
||||
if decision.dialogue:
|
||||
annotations["dialogue"] = decision.dialogue
|
||||
if decision.telemetry:
|
||||
annotations["telemetry"] = decision.telemetry
|
||||
message = DialogueMessage(
|
||||
sender=code,
|
||||
role=DialogueRole.PREDICTION,
|
||||
message_type=MessageType.DECISION,
|
||||
content=content,
|
||||
confidence=decision.confidence,
|
||||
references=references,
|
||||
annotations=annotations,
|
||||
)
|
||||
messages.append(message)
|
||||
notes: Dict[str, object] = {
|
||||
"department_votes": dict(department_votes),
|
||||
}
|
||||
summary = RoundSummary(
|
||||
index=0,
|
||||
agenda="department_consensus",
|
||||
structure=GameStructure.REPEATED,
|
||||
resolved=True,
|
||||
outcome=final_action.value,
|
||||
messages=messages,
|
||||
notes=notes,
|
||||
)
|
||||
return [summary]
|
||||
|
||||
270
app/agents/protocols.py
Normal file
270
app/agents/protocols.py
Normal file
@ -0,0 +1,270 @@
|
||||
"""Protocols and data structures for multi-round multi-agent games."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Protocol
|
||||
|
||||
|
||||
class GameStructure(str, Enum):
|
||||
"""Supported multi-agent game topologies."""
|
||||
|
||||
REPEATED = "repeated"
|
||||
SIGNALING = "signaling"
|
||||
BAYESIAN = "bayesian"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
class DialogueRole(str, Enum):
|
||||
"""Roles participating in the negotiation agenda."""
|
||||
|
||||
HOST = "host"
|
||||
PREDICTION = "prediction"
|
||||
RISK = "risk"
|
||||
EXECUTION = "execution"
|
||||
OBSERVER = "observer"
|
||||
|
||||
|
||||
class MessageType(str, Enum):
|
||||
"""High-level classification of dialogue intents."""
|
||||
|
||||
HYPOTHESIS = "hypothesis"
|
||||
EVIDENCE = "evidence"
|
||||
COUNTER = "counter"
|
||||
DECISION = "decision"
|
||||
DIRECTIVE = "directive"
|
||||
META = "meta"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DialogueMessage:
|
||||
"""Single utterance in the multi-round dialogue."""
|
||||
|
||||
sender: str
|
||||
role: DialogueRole
|
||||
message_type: MessageType
|
||||
content: str
|
||||
confidence: float = 0.0
|
||||
references: List[str] = field(default_factory=list)
|
||||
timestamp: Optional[str] = None
|
||||
annotations: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"sender": self.sender,
|
||||
"role": self.role.value,
|
||||
"message_type": self.message_type.value,
|
||||
"content": self.content,
|
||||
"confidence": self.confidence,
|
||||
"references": list(self.references),
|
||||
"timestamp": self.timestamp,
|
||||
"annotations": dict(self.annotations),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class BeliefSnapshot:
|
||||
"""Belief state emitted by an agent before or after revision."""
|
||||
|
||||
agent: str
|
||||
role: DialogueRole
|
||||
belief: Dict[str, Any]
|
||||
confidence: float
|
||||
rationale: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"agent": self.agent,
|
||||
"role": self.role.value,
|
||||
"belief": dict(self.belief),
|
||||
"confidence": self.confidence,
|
||||
"rationale": self.rationale,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class BeliefRevision:
|
||||
"""Tracks belief updates triggered during a round."""
|
||||
|
||||
before: BeliefSnapshot
|
||||
after: BeliefSnapshot
|
||||
justification: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"before": self.before.to_dict(),
|
||||
"after": self.after.to_dict(),
|
||||
"justification": self.justification,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class RoundSummary:
|
||||
"""Aggregated view of a single negotiation round."""
|
||||
|
||||
index: int
|
||||
agenda: str
|
||||
structure: GameStructure
|
||||
resolved: bool
|
||||
outcome: Optional[str] = None
|
||||
messages: List[DialogueMessage] = field(default_factory=list)
|
||||
revisions: List[BeliefRevision] = field(default_factory=list)
|
||||
constraints_triggered: List[str] = field(default_factory=list)
|
||||
notes: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"index": self.index,
|
||||
"agenda": self.agenda,
|
||||
"structure": self.structure.value,
|
||||
"resolved": self.resolved,
|
||||
"outcome": self.outcome,
|
||||
"messages": [message.to_dict() for message in self.messages],
|
||||
"revisions": [revision.to_dict() for revision in self.revisions],
|
||||
"constraints_triggered": list(self.constraints_triggered),
|
||||
"notes": dict(self.notes),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class DialogueTrace:
|
||||
"""Ordered collection of round summaries for auditing."""
|
||||
|
||||
session_id: str
|
||||
ts_code: str
|
||||
trade_date: str
|
||||
rounds: List[RoundSummary] = field(default_factory=list)
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"session_id": self.session_id,
|
||||
"ts_code": self.ts_code,
|
||||
"trade_date": self.trade_date,
|
||||
"rounds": [summary.to_dict() for summary in self.rounds],
|
||||
"metadata": dict(self.metadata),
|
||||
}
|
||||
|
||||
|
||||
class GameProtocol(Protocol):
|
||||
"""Extension point for hosting multi-round agent negotiations."""
|
||||
|
||||
def bootstrap(self, trace: DialogueTrace) -> None:
|
||||
"""Prepare the protocol-specific state before rounds begin."""
|
||||
|
||||
def start_round(self, trace: DialogueTrace, agenda: str, structure: GameStructure) -> RoundSummary:
|
||||
"""Open a new round with the given agenda and structure descriptor."""
|
||||
|
||||
def handle_message(self, summary: RoundSummary, message: DialogueMessage) -> None:
|
||||
"""Process a single dialogue message emitted by an agent."""
|
||||
|
||||
def apply_revision(self, summary: RoundSummary, revision: BeliefRevision) -> None:
|
||||
"""Register a belief revision triggered by debate or new evidence."""
|
||||
|
||||
def finalize_round(self, summary: RoundSummary) -> None:
|
||||
"""Mark the round as resolved and perform protocol-specific bookkeeping."""
|
||||
|
||||
def close(self, trace: DialogueTrace) -> None:
|
||||
"""Finish the negotiation session and emit protocol artifacts."""
|
||||
|
||||
|
||||
class ProtocolHost(GameProtocol):
|
||||
"""Base implementation for agenda-driven negotiation protocols."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._current_round: Optional[RoundSummary] = None
|
||||
self._trace: Optional[DialogueTrace] = None
|
||||
self._round_index: int = 0
|
||||
|
||||
def bootstrap(self, trace: DialogueTrace) -> None:
|
||||
trace.metadata.setdefault("host_started", True)
|
||||
self._trace = trace
|
||||
self._round_index = len(trace.rounds)
|
||||
|
||||
def bootstrap_trace(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
ts_code: str,
|
||||
trade_date: str,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> DialogueTrace:
|
||||
trace = DialogueTrace(
|
||||
session_id=session_id,
|
||||
ts_code=ts_code,
|
||||
trade_date=trade_date,
|
||||
metadata=dict(metadata or {}),
|
||||
)
|
||||
self.bootstrap(trace)
|
||||
return trace
|
||||
|
||||
def start_round(
|
||||
self,
|
||||
trace: DialogueTrace,
|
||||
agenda: str,
|
||||
structure: GameStructure,
|
||||
) -> RoundSummary:
|
||||
index = self._round_index
|
||||
summary = RoundSummary(
|
||||
index=index,
|
||||
agenda=agenda,
|
||||
structure=structure,
|
||||
resolved=False,
|
||||
)
|
||||
trace.rounds.append(summary)
|
||||
self._current_round = summary
|
||||
self._round_index += 1
|
||||
return summary
|
||||
|
||||
def handle_message(self, summary: RoundSummary, message: DialogueMessage) -> None:
|
||||
summary.messages.append(message)
|
||||
|
||||
def apply_revision(self, summary: RoundSummary, revision: BeliefRevision) -> None:
|
||||
summary.revisions.append(revision)
|
||||
|
||||
def finalize_round(self, summary: RoundSummary) -> None:
|
||||
summary.resolved = True
|
||||
summary.notes.setdefault("message_count", len(summary.messages))
|
||||
self._current_round = None
|
||||
|
||||
def close(self, trace: DialogueTrace) -> None:
|
||||
trace.metadata.setdefault("host_finished", True)
|
||||
self._trace = trace
|
||||
|
||||
def current_round(self) -> Optional[RoundSummary]:
|
||||
return self._current_round
|
||||
|
||||
@property
|
||||
def trace(self) -> Optional[DialogueTrace]:
|
||||
return self._trace
|
||||
|
||||
def ensure_round(
|
||||
self,
|
||||
trace: DialogueTrace,
|
||||
agenda: str,
|
||||
structure: GameStructure,
|
||||
) -> RoundSummary:
|
||||
if self._current_round and not self._current_round.resolved:
|
||||
return self._current_round
|
||||
return self.start_round(trace, agenda, structure)
|
||||
|
||||
|
||||
def round_to_dict(summary: RoundSummary) -> Dict[str, Any]:
|
||||
"""Serialize a round summary for persistence layers."""
|
||||
|
||||
return summary.to_dict()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"GameStructure",
|
||||
"DialogueRole",
|
||||
"MessageType",
|
||||
"DialogueMessage",
|
||||
"BeliefSnapshot",
|
||||
"BeliefRevision",
|
||||
"RoundSummary",
|
||||
"DialogueTrace",
|
||||
"GameProtocol",
|
||||
"ProtocolHost",
|
||||
"round_to_dict",
|
||||
]
|
||||
@ -4,6 +4,35 @@ from __future__ import annotations
|
||||
from .base import Agent, AgentAction, AgentContext
|
||||
|
||||
|
||||
class RiskRecommendation:
|
||||
"""Represents structured recommendation from the risk agent."""
|
||||
|
||||
__slots__ = ("status", "reason", "recommended_action", "notes")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
status: str,
|
||||
reason: str,
|
||||
recommended_action: AgentAction | None = None,
|
||||
notes: dict | None = None,
|
||||
) -> None:
|
||||
self.status = status
|
||||
self.reason = reason
|
||||
self.recommended_action = recommended_action
|
||||
self.notes = notes or {}
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
payload = {
|
||||
"status": self.status,
|
||||
"reason": self.reason,
|
||||
"notes": dict(self.notes),
|
||||
}
|
||||
if self.recommended_action is not None:
|
||||
payload["recommended_action"] = self.recommended_action.value
|
||||
return payload
|
||||
|
||||
|
||||
class RiskAgent(Agent):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(name="A_risk")
|
||||
@ -27,3 +56,74 @@ class RiskAgent(Agent):
|
||||
if context.features.get("position_limit", False) and action in (AgentAction.BUY_M, AgentAction.BUY_L):
|
||||
return False
|
||||
return True
|
||||
|
||||
def assess(
|
||||
self,
|
||||
context: AgentContext,
|
||||
decision_action: AgentAction,
|
||||
conflict_flag: bool,
|
||||
) -> RiskRecommendation:
|
||||
features = dict(context.features or {})
|
||||
risk_penalty = float(features.get("risk_penalty") or 0.0)
|
||||
|
||||
if bool(features.get("is_suspended")):
|
||||
return RiskRecommendation(
|
||||
status="blocked",
|
||||
reason="suspended",
|
||||
recommended_action=AgentAction.HOLD,
|
||||
notes={"trigger": "is_suspended"},
|
||||
)
|
||||
|
||||
if bool(features.get("limit_up")) and decision_action in {
|
||||
AgentAction.BUY_S,
|
||||
AgentAction.BUY_M,
|
||||
AgentAction.BUY_L,
|
||||
}:
|
||||
return RiskRecommendation(
|
||||
status="blocked",
|
||||
reason="limit_up",
|
||||
recommended_action=AgentAction.HOLD,
|
||||
notes={"trigger": "limit_up"},
|
||||
)
|
||||
|
||||
if bool(features.get("position_limit")) and decision_action in {
|
||||
AgentAction.BUY_M,
|
||||
AgentAction.BUY_L,
|
||||
}:
|
||||
return RiskRecommendation(
|
||||
status="pending_review",
|
||||
reason="position_limit",
|
||||
recommended_action=AgentAction.BUY_S,
|
||||
notes={"trigger": "position_limit"},
|
||||
)
|
||||
|
||||
if risk_penalty >= 0.9 and decision_action in {
|
||||
AgentAction.BUY_S,
|
||||
AgentAction.BUY_M,
|
||||
AgentAction.BUY_L,
|
||||
}:
|
||||
return RiskRecommendation(
|
||||
status="blocked",
|
||||
reason="risk_penalty_extreme",
|
||||
recommended_action=AgentAction.HOLD,
|
||||
notes={"risk_penalty": risk_penalty},
|
||||
)
|
||||
if risk_penalty >= 0.7 and decision_action in {
|
||||
AgentAction.BUY_S,
|
||||
AgentAction.BUY_M,
|
||||
AgentAction.BUY_L,
|
||||
}:
|
||||
return RiskRecommendation(
|
||||
status="pending_review",
|
||||
reason="risk_penalty_high",
|
||||
recommended_action=AgentAction.HOLD,
|
||||
notes={"risk_penalty": risk_penalty},
|
||||
)
|
||||
|
||||
if conflict_flag:
|
||||
return RiskRecommendation(
|
||||
status="pending_review",
|
||||
reason="conflict_threshold",
|
||||
)
|
||||
|
||||
return RiskRecommendation(status="ok", reason="clear")
|
||||
|
||||
@ -8,7 +8,8 @@ from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple
|
||||
|
||||
from app.agents.base import AgentAction, AgentContext
|
||||
from app.agents.departments import DepartmentManager
|
||||
from app.agents.game import Decision, decide
|
||||
from app.agents.game import Decision, decide, target_weight_for_action
|
||||
from app.agents.protocols import round_to_dict
|
||||
from app.llm.metrics import record_decision as metrics_record_decision
|
||||
from app.agents.registry import default_agents
|
||||
from app.data.schema import initialize_database
|
||||
@ -16,6 +17,7 @@ from app.utils.data_access import DataBroker
|
||||
from app.utils.config import get_config
|
||||
from app.utils.db import db_session
|
||||
from app.utils.logging import get_logger
|
||||
from app.utils import alerts
|
||||
from app.core.indicators import momentum, normalize, rolling_mean, volatility
|
||||
|
||||
|
||||
@ -436,6 +438,7 @@ class BacktestEngine:
|
||||
)
|
||||
)
|
||||
|
||||
round_payload = [round_to_dict(summary) for summary in decision.rounds]
|
||||
global_payload = {
|
||||
"_confidence": decision.confidence,
|
||||
"_target_weight": decision.target_weight,
|
||||
@ -459,6 +462,12 @@ class BacktestEngine:
|
||||
for code, dept in decision.department_decisions.items()
|
||||
if dept.telemetry
|
||||
},
|
||||
"_rounds": round_payload,
|
||||
"_risk_assessment": (
|
||||
decision.risk_assessment.to_dict()
|
||||
if decision.risk_assessment
|
||||
else None
|
||||
),
|
||||
}
|
||||
rows.append(
|
||||
(
|
||||
@ -543,18 +552,42 @@ class BacktestEngine:
|
||||
executed_trades: List[Dict[str, Any]] = []
|
||||
risk_events: List[Dict[str, Any]] = []
|
||||
|
||||
def _record_risk(ts_code: str, reason: str, decision: Decision, extra: Optional[Dict[str, Any]] = None) -> None:
|
||||
def _record_risk(
|
||||
ts_code: str,
|
||||
reason: str,
|
||||
decision: Decision,
|
||||
extra: Optional[Dict[str, Any]] = None,
|
||||
action_override: Optional[AgentAction] = None,
|
||||
target_weight_override: Optional[float] = None,
|
||||
) -> None:
|
||||
payload = {
|
||||
"trade_date": trade_date_str,
|
||||
"ts_code": ts_code,
|
||||
"action": decision.action.value,
|
||||
"target_weight": decision.target_weight,
|
||||
"action": (action_override or decision.action).value,
|
||||
"target_weight": (
|
||||
target_weight_override
|
||||
if target_weight_override is not None
|
||||
else decision.target_weight
|
||||
),
|
||||
"confidence": decision.confidence,
|
||||
"reason": reason,
|
||||
}
|
||||
if extra:
|
||||
payload.update(extra)
|
||||
risk_events.append(payload)
|
||||
risk_meta = payload.get("risk_assessment") if isinstance(payload.get("risk_assessment"), dict) else extra.get("risk_assessment") if extra else None
|
||||
status = None
|
||||
if isinstance(risk_meta, dict):
|
||||
status = risk_meta.get("status")
|
||||
if status == "blocked":
|
||||
try:
|
||||
alerts.add_warning(
|
||||
"backtest_risk",
|
||||
f"{ts_code} 风险阻断: {reason}",
|
||||
detail=json.dumps(payload, ensure_ascii=False),
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.debug("记录风险告警失败", extra=LOG_EXTRA)
|
||||
|
||||
for ts_code, decision in decisions_map.items():
|
||||
price = price_map.get(ts_code)
|
||||
@ -569,35 +602,54 @@ class BacktestEngine:
|
||||
limit_down = bool(features.get("limit_down"))
|
||||
position_limit = bool(features.get("position_limit"))
|
||||
|
||||
risk = decision.risk_assessment
|
||||
effective_action = decision.action
|
||||
effective_weight = decision.target_weight
|
||||
if risk:
|
||||
risk_payload = risk.to_dict()
|
||||
risk_payload.setdefault("applied_action", effective_action.value)
|
||||
if risk.recommended_action:
|
||||
effective_action = risk.recommended_action
|
||||
risk_payload["applied_action"] = effective_action.value
|
||||
effective_weight = target_weight_for_action(effective_action)
|
||||
if risk.status != "ok":
|
||||
_record_risk(
|
||||
ts_code,
|
||||
risk.reason,
|
||||
decision,
|
||||
extra={"risk_assessment": risk_payload},
|
||||
action_override=effective_action,
|
||||
target_weight_override=effective_weight,
|
||||
)
|
||||
if risk.status == "blocked":
|
||||
continue
|
||||
|
||||
if is_suspended:
|
||||
_record_risk(ts_code, "suspended", decision)
|
||||
continue
|
||||
if decision.action in self._buy_actions:
|
||||
if effective_action in self._buy_actions:
|
||||
if limit_up:
|
||||
_record_risk(ts_code, "limit_up", decision)
|
||||
_record_risk(ts_code, "limit_up", decision, action_override=effective_action)
|
||||
continue
|
||||
if position_limit:
|
||||
_record_risk(ts_code, "position_limit", decision)
|
||||
_record_risk(ts_code, "position_limit", decision, action_override=effective_action)
|
||||
continue
|
||||
if risk_penalty >= 0.95:
|
||||
_record_risk(ts_code, "risk_penalty", decision, {"risk_penalty": risk_penalty})
|
||||
continue
|
||||
if decision.action in self._sell_actions and limit_down:
|
||||
_record_risk(ts_code, "limit_down", decision)
|
||||
if effective_action in self._sell_actions and limit_down:
|
||||
_record_risk(ts_code, "limit_down", decision, action_override=effective_action)
|
||||
continue
|
||||
|
||||
effective_weight = max(decision.target_weight, 0.0)
|
||||
if decision.action in self._buy_actions:
|
||||
effective_weight_value = max(effective_weight, 0.0)
|
||||
if effective_action in self._buy_actions:
|
||||
capped_weight = min(effective_weight, self.risk_params["max_position_weight"])
|
||||
effective_weight = capped_weight * max(0.0, 1.0 - risk_penalty)
|
||||
elif decision.action in self._sell_actions:
|
||||
effective_weight = 0.0
|
||||
effective_weight_value = capped_weight * max(0.0, 1.0 - risk_penalty)
|
||||
elif effective_action in self._sell_actions:
|
||||
effective_weight_value = 0.0
|
||||
|
||||
desired_qty = current_qty
|
||||
if decision.action in self._sell_actions:
|
||||
if effective_action in self._sell_actions:
|
||||
desired_qty = 0.0
|
||||
elif decision.action in self._buy_actions or effective_weight >= 0.0:
|
||||
desired_value = max(effective_weight, 0.0) * portfolio_value_before
|
||||
elif effective_action in self._buy_actions or effective_weight_value > 0.0:
|
||||
desired_value = max(effective_weight_value, 0.0) * portfolio_value_before
|
||||
desired_qty = desired_value / price if price > 0 else current_qty
|
||||
|
||||
delta = desired_qty - current_qty
|
||||
@ -654,7 +706,8 @@ class BacktestEngine:
|
||||
"slippage": trade_price - price,
|
||||
"confidence": decision.confidence,
|
||||
"target_weight": decision.target_weight,
|
||||
"effective_weight": effective_weight,
|
||||
"effective_weight": effective_weight_value,
|
||||
"effective_action": effective_action.value,
|
||||
"risk_penalty": risk_penalty,
|
||||
"liquidity_score": liquidity_score,
|
||||
"status": "executed",
|
||||
@ -694,7 +747,8 @@ class BacktestEngine:
|
||||
"slippage": price - trade_price,
|
||||
"confidence": decision.confidence,
|
||||
"target_weight": decision.target_weight,
|
||||
"effective_weight": effective_weight,
|
||||
"effective_weight": effective_weight_value,
|
||||
"effective_action": effective_action.value,
|
||||
"risk_penalty": risk_penalty,
|
||||
"liquidity_score": liquidity_score,
|
||||
"realized_pnl": realized,
|
||||
|
||||
@ -201,6 +201,7 @@ def render_backtest_review() -> None:
|
||||
selected_ids = [label.split(" | ")[0].strip() for label in selected_labels]
|
||||
nav_df = pd.DataFrame()
|
||||
rpt_df = pd.DataFrame()
|
||||
risk_df = pd.DataFrame()
|
||||
if selected_ids:
|
||||
try:
|
||||
with db_session(read_only=True) as conn:
|
||||
@ -214,11 +215,20 @@ def render_backtest_review() -> None:
|
||||
conn,
|
||||
params=tuple(selected_ids),
|
||||
)
|
||||
risk_df = pd.read_sql_query(
|
||||
"SELECT cfg_id, trade_date, ts_code, reason, action, target_weight, confidence, metadata "
|
||||
"FROM bt_risk_events WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))),
|
||||
conn,
|
||||
params=tuple(selected_ids),
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.exception("读取回测结果失败", extra=LOG_EXTRA)
|
||||
st.error("读取回测结果失败")
|
||||
nav_df = pd.DataFrame()
|
||||
rpt_df = pd.DataFrame()
|
||||
risk_df = pd.DataFrame()
|
||||
start_filter: Optional[date] = None
|
||||
end_filter: Optional[date] = None
|
||||
if not nav_df.empty:
|
||||
try:
|
||||
nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"], errors="coerce")
|
||||
@ -274,6 +284,9 @@ def render_backtest_review() -> None:
|
||||
"交易数": summary.get("trade_count"),
|
||||
"平均换手": summary.get("avg_turnover"),
|
||||
"风险事件": summary.get("risk_events"),
|
||||
"风险分布": json.dumps(summary.get("risk_breakdown"), ensure_ascii=False)
|
||||
if summary.get("risk_breakdown")
|
||||
else None,
|
||||
}
|
||||
metrics_rows.append({k: v for k, v in record.items() if (k == "cfg_id" or k in selected_metrics)})
|
||||
if metrics_rows:
|
||||
@ -291,6 +304,70 @@ def render_backtest_review() -> None:
|
||||
pass
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.debug("渲染指标表失败", extra=LOG_EXTRA)
|
||||
if not risk_df.empty:
|
||||
try:
|
||||
risk_df["trade_date"] = pd.to_datetime(risk_df["trade_date"], errors="coerce")
|
||||
risk_df = risk_df.dropna(subset=["trade_date"])
|
||||
if start_filter is None or end_filter is None:
|
||||
start_filter = pd.to_datetime(risk_df["trade_date"].min()).date()
|
||||
end_filter = pd.to_datetime(risk_df["trade_date"].max()).date()
|
||||
risk_df = risk_df[
|
||||
(risk_df["trade_date"].dt.date >= start_filter)
|
||||
& (risk_df["trade_date"].dt.date <= end_filter)
|
||||
]
|
||||
parsed_cols: List[Dict[str, object]] = []
|
||||
for _, row in risk_df.iterrows():
|
||||
try:
|
||||
metadata = json.loads(row["metadata"]) if isinstance(row["metadata"], str) else (row["metadata"] or {})
|
||||
except json.JSONDecodeError:
|
||||
metadata = {}
|
||||
assessment = metadata.get("risk_assessment") or {}
|
||||
parsed_cols.append(
|
||||
{
|
||||
"cfg_id": row["cfg_id"],
|
||||
"trade_date": row["trade_date"].date().isoformat(),
|
||||
"ts_code": row["ts_code"],
|
||||
"reason": row["reason"],
|
||||
"action": row["action"],
|
||||
"target_weight": row["target_weight"],
|
||||
"confidence": row["confidence"],
|
||||
"risk_status": assessment.get("status"),
|
||||
"recommended_action": assessment.get("recommended_action"),
|
||||
"execution_status": metadata.get("execution_status"),
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
risk_detail_df = pd.DataFrame(parsed_cols)
|
||||
with st.expander("风险事件明细", expanded=False):
|
||||
st.dataframe(risk_detail_df.drop(columns=["metadata"], errors="ignore"), hide_index=True, width='stretch')
|
||||
try:
|
||||
st.download_button(
|
||||
"下载风险事件(CSV)",
|
||||
data=risk_detail_df.to_csv(index=False),
|
||||
file_name="bt_risk_events.csv",
|
||||
mime="text/csv",
|
||||
key="dl_risk_events",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
agg = risk_detail_df.groupby(["cfg_id", "reason", "risk_status"], dropna=False).size().reset_index(name="count")
|
||||
st.dataframe(agg, hide_index=True, width='stretch')
|
||||
try:
|
||||
if not agg.empty:
|
||||
agg_fig = px.bar(
|
||||
agg,
|
||||
x="reason",
|
||||
y="count",
|
||||
color="risk_status",
|
||||
facet_col="cfg_id",
|
||||
title="风险事件分布",
|
||||
)
|
||||
agg_fig.update_layout(height=320, margin=dict(l=10, r=10, t=40, b=20))
|
||||
st.plotly_chart(agg_fig, use_container_width=True)
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.debug("绘制风险事件分布失败", extra=LOG_EXTRA)
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.debug("渲染风险事件失败", extra=LOG_EXTRA)
|
||||
else:
|
||||
st.info("请选择至少一个配置进行对比。")
|
||||
|
||||
|
||||
@ -14,6 +14,7 @@ from .config import get_config
|
||||
from .db import db_session
|
||||
from .logging import get_logger
|
||||
from app.core.indicators import momentum, normalize, rolling_mean, volatility
|
||||
from app.utils.db_query import BrokerQueryEngine
|
||||
|
||||
# 延迟导入,避免循环依赖
|
||||
collect_data_coverage = None
|
||||
@ -60,6 +61,39 @@ def _safe_split(path: str) -> Tuple[str, str] | None:
|
||||
return table, column
|
||||
|
||||
|
||||
@dataclass
|
||||
class _RefreshCoordinator:
|
||||
"""Orchestrates background refresh requests for the broker."""
|
||||
|
||||
broker: "DataBroker"
|
||||
|
||||
def ensure_for_latest(self, trade_date: str, fields: Iterable[str]) -> None:
|
||||
parsed_date = _parse_trade_date(trade_date)
|
||||
if not parsed_date:
|
||||
return
|
||||
normalized = parsed_date.strftime("%Y%m%d")
|
||||
tables = self._collect_tables(fields)
|
||||
if tables and self.broker.check_data_availability(normalized, tables):
|
||||
self.broker._trigger_background_refresh(normalized)
|
||||
|
||||
def ensure_for_series(self, end_date: str, table: str) -> None:
|
||||
parsed_date = _parse_trade_date(end_date)
|
||||
if not parsed_date:
|
||||
return
|
||||
normalized = parsed_date.strftime("%Y%m%d")
|
||||
if self.broker.check_data_availability(normalized, {table}):
|
||||
self.broker._trigger_background_refresh(normalized)
|
||||
|
||||
def _collect_tables(self, fields: Iterable[str]) -> Set[str]:
|
||||
tables: Set[str] = set()
|
||||
for field_name in fields:
|
||||
resolved = self.broker.resolve_field(field_name)
|
||||
if resolved:
|
||||
table, _ = resolved
|
||||
tables.add(table)
|
||||
return tables
|
||||
|
||||
|
||||
def parse_field_path(path: str) -> Tuple[str, str] | None:
|
||||
"""Validate and split a `table.column` field expression."""
|
||||
|
||||
@ -87,6 +121,17 @@ def _end_of_day(dt: datetime) -> str:
|
||||
return dt.strftime("%Y-%m-%d 23:59:59")
|
||||
|
||||
|
||||
def _coerce_date(value: object) -> Optional[date]:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, date):
|
||||
return value
|
||||
parsed = _parse_trade_date(value)
|
||||
if parsed:
|
||||
return parsed.date()
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataBroker:
|
||||
"""Lightweight data access helper with automated data fetching capabilities."""
|
||||
@ -130,6 +175,8 @@ class DataBroker:
|
||||
_refresh_in_progress: Dict[str, bool] = field(init=False, repr=False)
|
||||
_refresh_callbacks: Dict[str, List[Callable]] = field(init=False, repr=False)
|
||||
_coverage_cache: Dict[str, Dict] = field(init=False, repr=False)
|
||||
_refresh: _RefreshCoordinator = field(init=False, repr=False)
|
||||
_query_engine: BrokerQueryEngine = field(init=False, repr=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self._latest_cache = OrderedDict()
|
||||
@ -139,6 +186,8 @@ class DataBroker:
|
||||
self._refresh_in_progress = {}
|
||||
self._refresh_callbacks = {}
|
||||
self._coverage_cache = {}
|
||||
self._refresh = _RefreshCoordinator(self)
|
||||
self._query_engine = BrokerQueryEngine(db_session)
|
||||
if initialize_database is not None:
|
||||
initialize_database() # 确保数据库已初始化
|
||||
else:
|
||||
@ -169,22 +218,7 @@ class DataBroker:
|
||||
|
||||
# 检查是否需要自动补数
|
||||
if auto_refresh:
|
||||
# 解析交易日以确定是否需要补数
|
||||
parsed_date = _parse_trade_date(trade_date)
|
||||
if parsed_date:
|
||||
# 检查最近交易日的数据是否存在
|
||||
recent_trade_date = parsed_date.strftime('%Y%m%d')
|
||||
# 对涉及的表进行数据可用性检查
|
||||
tables = set()
|
||||
for field_name in field_list:
|
||||
resolved = self.resolve_field(field_name)
|
||||
if resolved:
|
||||
table, _ = resolved
|
||||
tables.add(table)
|
||||
|
||||
if tables and self.check_data_availability(recent_trade_date, tables):
|
||||
# 数据不足,触发后台补数
|
||||
self._trigger_background_refresh(recent_trade_date)
|
||||
self._refresh.ensure_for_latest(trade_date, field_list)
|
||||
|
||||
grouped: Dict[str, List[str]] = {}
|
||||
field_map: Dict[Tuple[str, str], List[str]] = {}
|
||||
@ -208,59 +242,41 @@ class DataBroker:
|
||||
grouped[table].append(column)
|
||||
field_map.setdefault((table, column), []).append(field_name)
|
||||
|
||||
if not grouped:
|
||||
if cache_key is not None and results:
|
||||
self._cache_store(
|
||||
self._latest_cache,
|
||||
cache_key,
|
||||
deepcopy(results),
|
||||
self.latest_cache_size,
|
||||
)
|
||||
return results
|
||||
|
||||
try:
|
||||
with db_session(read_only=True) as conn:
|
||||
for table, columns in grouped.items():
|
||||
joined_cols = ", ".join(columns)
|
||||
query = (
|
||||
f"SELECT trade_date, {joined_cols} FROM {table} "
|
||||
"WHERE ts_code = ? AND trade_date <= ? "
|
||||
"ORDER BY trade_date DESC LIMIT 1"
|
||||
)
|
||||
try:
|
||||
row = conn.execute(query, (ts_code, trade_date)).fetchone()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
LOGGER.debug(
|
||||
"查询失败 table=%s fields=%s err=%s",
|
||||
table,
|
||||
columns,
|
||||
exc,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
continue
|
||||
if not row:
|
||||
continue
|
||||
for column in columns:
|
||||
value = row[column]
|
||||
if value is None:
|
||||
continue
|
||||
for original in field_map.get((table, column), [f"{table}.{column}"]):
|
||||
try:
|
||||
results[original] = float(value)
|
||||
except (TypeError, ValueError):
|
||||
results[original] = value
|
||||
except sqlite3.OperationalError as exc:
|
||||
LOGGER.debug("数据库只读连接失败:%s", exc, extra=LOG_EXTRA)
|
||||
if cache_key is not None:
|
||||
cached = self._cache_lookup(self._latest_cache, cache_key)
|
||||
if cached is not None:
|
||||
if grouped:
|
||||
for table, columns in grouped.items():
|
||||
try:
|
||||
row = self._query_engine.fetch_latest(table, ts_code, trade_date, columns)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
LOGGER.debug(
|
||||
"使用缓存结果 ts_code=%s trade_date=%s",
|
||||
ts_code,
|
||||
trade_date,
|
||||
"查询失败 table=%s fields=%s err=%s",
|
||||
table,
|
||||
columns,
|
||||
exc,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
return deepcopy(cached)
|
||||
continue
|
||||
if not row:
|
||||
continue
|
||||
for column in columns:
|
||||
value = row[column]
|
||||
if value is None:
|
||||
continue
|
||||
for original in field_map.get((table, column), [f"{table}.{column}"]):
|
||||
try:
|
||||
results[original] = float(value)
|
||||
except (TypeError, ValueError):
|
||||
results[original] = value
|
||||
|
||||
if cache_key is not None and not results:
|
||||
cached = self._cache_lookup(self._latest_cache, cache_key)
|
||||
if cached is not None:
|
||||
LOGGER.debug(
|
||||
"使用缓存结果 ts_code=%s trade_date=%s",
|
||||
ts_code,
|
||||
trade_date,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
return deepcopy(cached)
|
||||
if cache_key is not None and results:
|
||||
self._cache_store(
|
||||
self._latest_cache,
|
||||
@ -306,9 +322,7 @@ class DataBroker:
|
||||
|
||||
# 检查是否需要自动补数
|
||||
if auto_refresh:
|
||||
parsed_date = _parse_trade_date(end_date)
|
||||
if parsed_date and self.check_data_availability(end_date, {table}):
|
||||
self._trigger_background_refresh(end_date)
|
||||
self._refresh.ensure_for_series(end_date, table)
|
||||
|
||||
cache_key: Optional[Tuple[Any, ...]] = None
|
||||
if self.enable_cache:
|
||||
@ -323,21 +337,10 @@ class DataBroker:
|
||||
"ORDER BY trade_date DESC LIMIT ?"
|
||||
)
|
||||
try:
|
||||
with db_session(read_only=True) as conn:
|
||||
try:
|
||||
rows = conn.execute(query, (ts_code, end_date, window)).fetchall()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
LOGGER.debug(
|
||||
"时间序列查询失败 table=%s column=%s err=%s",
|
||||
table,
|
||||
column,
|
||||
exc,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
return []
|
||||
except sqlite3.OperationalError as exc:
|
||||
rows = self._query_engine.fetch_series(table, resolved, ts_code, end_date, window)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
LOGGER.debug(
|
||||
"时间序列连接失败 table=%s column=%s err=%s",
|
||||
"时间序列查询失败 table=%s column=%s err=%s",
|
||||
table,
|
||||
column,
|
||||
exc,
|
||||
@ -358,9 +361,13 @@ class DataBroker:
|
||||
series: List[Tuple[str, float]] = []
|
||||
for row in rows:
|
||||
value = row[resolved]
|
||||
if value is None:
|
||||
trade_dt = row["trade_date"]
|
||||
if value is None or trade_dt is None:
|
||||
continue
|
||||
try:
|
||||
series.append((trade_dt, float(value)))
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
series.append((row["trade_date"], float(value)))
|
||||
if cache_key is not None and series:
|
||||
self._cache_store(
|
||||
self._series_cache,
|
||||
@ -370,6 +377,32 @@ class DataBroker:
|
||||
)
|
||||
return series
|
||||
|
||||
def register_refresh_callback(
|
||||
self,
|
||||
start: date | str,
|
||||
end: date | str,
|
||||
callback: Callable[[], None],
|
||||
) -> None:
|
||||
"""Register a hook invoked after background refresh completes for the window."""
|
||||
|
||||
if callback is None:
|
||||
return
|
||||
start_date = _coerce_date(start)
|
||||
end_date = _coerce_date(end)
|
||||
if not start_date or not end_date:
|
||||
LOGGER.debug(
|
||||
"忽略无效补数回调窗口 start=%s end=%s",
|
||||
start,
|
||||
end,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
return
|
||||
key = f"{start_date}_{end_date}"
|
||||
with self._refresh_lock:
|
||||
bucket = self._refresh_callbacks.setdefault(key, [])
|
||||
if callback not in bucket:
|
||||
bucket.append(callback)
|
||||
|
||||
def fetch_flags(
|
||||
self,
|
||||
table: str,
|
||||
@ -436,48 +469,19 @@ class DataBroker:
|
||||
LOGGER.debug("表不存在或无字段 table=%s", table, extra=LOG_EXTRA)
|
||||
return []
|
||||
|
||||
column_list = ", ".join(columns)
|
||||
has_trade_date = "trade_date" in columns
|
||||
if has_trade_date:
|
||||
query = (
|
||||
f"SELECT {column_list} FROM {table} "
|
||||
"WHERE ts_code = ? AND trade_date <= ? "
|
||||
"ORDER BY trade_date DESC LIMIT ?"
|
||||
)
|
||||
params: Tuple[object, ...] = (ts_code, trade_date, window)
|
||||
else:
|
||||
query = (
|
||||
f"SELECT {column_list} FROM {table} "
|
||||
"WHERE ts_code = ? ORDER BY rowid DESC LIMIT ?"
|
||||
)
|
||||
params = (ts_code, window)
|
||||
|
||||
results: List[Dict[str, object]] = []
|
||||
try:
|
||||
with db_session(read_only=True) as conn:
|
||||
try:
|
||||
rows = conn.execute(query, params).fetchall()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
LOGGER.debug(
|
||||
"表查询失败 table=%s err=%s",
|
||||
table,
|
||||
exc,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
return []
|
||||
except sqlite3.OperationalError as exc:
|
||||
LOGGER.debug(
|
||||
"表连接失败 table=%s err=%s",
|
||||
rows = self._query_engine.fetch_table(
|
||||
table,
|
||||
exc,
|
||||
extra=LOG_EXTRA,
|
||||
columns,
|
||||
ts_code,
|
||||
trade_date if "trade_date" in columns else None,
|
||||
window,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
LOGGER.debug("表查询失败 table=%s err=%s", table, exc, extra=LOG_EXTRA)
|
||||
return []
|
||||
|
||||
for row in rows:
|
||||
record = {col: row[col] for col in columns}
|
||||
results.append(record)
|
||||
return results
|
||||
return [{col: row[col] for col in columns} for row in rows]
|
||||
|
||||
def _resolve_derived_field(
|
||||
self,
|
||||
@ -924,13 +928,20 @@ class DataBroker:
|
||||
with self._refresh_lock:
|
||||
callbacks = self._refresh_callbacks.pop(refresh_key, [])
|
||||
self._refresh_in_progress[refresh_key] = False
|
||||
|
||||
|
||||
if callbacks:
|
||||
LOGGER.info(
|
||||
"执行补数回调 count=%s key=%s",
|
||||
len(callbacks),
|
||||
refresh_key,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
for callback in callbacks:
|
||||
try:
|
||||
callback()
|
||||
except Exception as exc:
|
||||
LOGGER.exception("补数回调执行失败: %s", exc, extra=LOG_EXTRA)
|
||||
|
||||
|
||||
except Exception as exc:
|
||||
LOGGER.exception("后台数据补数失败: %s", exc, extra=LOG_EXTRA)
|
||||
with self._refresh_lock:
|
||||
|
||||
73
app/utils/db_query.py
Normal file
73
app/utils/db_query.py
Normal file
@ -0,0 +1,73 @@
|
||||
"""Shared read-only query helpers for database access."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Iterable, List, Mapping, Optional, Sequence
|
||||
|
||||
|
||||
@dataclass
|
||||
class BrokerQueryEngine:
|
||||
"""Lightweight wrapper around standard query patterns."""
|
||||
|
||||
session_factory: Callable[..., object]
|
||||
|
||||
def fetch_latest(
|
||||
self,
|
||||
table: str,
|
||||
ts_code: str,
|
||||
trade_date: str,
|
||||
columns: Sequence[str],
|
||||
) -> Optional[Mapping[str, object]]:
|
||||
if not columns:
|
||||
return None
|
||||
joined_cols = ", ".join(columns)
|
||||
query = (
|
||||
f"SELECT trade_date, {joined_cols} FROM {table} "
|
||||
"WHERE ts_code = ? AND trade_date <= ? "
|
||||
"ORDER BY trade_date DESC LIMIT 1"
|
||||
)
|
||||
with self.session_factory(read_only=True) as conn:
|
||||
return conn.execute(query, (ts_code, trade_date)).fetchone()
|
||||
|
||||
def fetch_series(
|
||||
self,
|
||||
table: str,
|
||||
column: str,
|
||||
ts_code: str,
|
||||
end_date: str,
|
||||
limit: int,
|
||||
) -> List[Mapping[str, object]]:
|
||||
query = (
|
||||
f"SELECT trade_date, {column} FROM {table} "
|
||||
"WHERE ts_code = ? AND trade_date <= ? "
|
||||
"ORDER BY trade_date DESC LIMIT ?"
|
||||
)
|
||||
with self.session_factory(read_only=True) as conn:
|
||||
rows = conn.execute(query, (ts_code, end_date, limit)).fetchall()
|
||||
return list(rows)
|
||||
|
||||
def fetch_table(
|
||||
self,
|
||||
table: str,
|
||||
columns: Iterable[str],
|
||||
ts_code: str,
|
||||
trade_date: Optional[str],
|
||||
limit: int,
|
||||
) -> List[Mapping[str, object]]:
|
||||
cols = ", ".join(columns)
|
||||
if trade_date is None:
|
||||
query = (
|
||||
f"SELECT {cols} FROM {table} "
|
||||
"WHERE ts_code = ? ORDER BY rowid DESC LIMIT ?"
|
||||
)
|
||||
params: Sequence[object] = (ts_code, limit)
|
||||
else:
|
||||
query = (
|
||||
f"SELECT {cols} FROM {table} "
|
||||
"WHERE ts_code = ? AND trade_date <= ? "
|
||||
"ORDER BY trade_date DESC LIMIT ?"
|
||||
)
|
||||
params = (ts_code, trade_date, limit)
|
||||
with self.session_factory(read_only=True) as conn:
|
||||
rows = conn.execute(query, params).fetchall()
|
||||
return list(rows)
|
||||
41
docs/RISK_AGENT_PLAN.md
Normal file
41
docs/RISK_AGENT_PLAN.md
Normal file
@ -0,0 +1,41 @@
|
||||
# 风险代理集成规划
|
||||
|
||||
## 目标
|
||||
- 将 `risk_guard` 回合由占位消息升级为可执行决策:低置信度的部门共识需自动进入复核流程,根据风险因子、仓位约束、外部告警调整最终指令。
|
||||
- 兼容历史回测与实时监控,确保补数、验证和告警闭环共享风险上下文。
|
||||
|
||||
## 数据与信号需求
|
||||
- **输入特征**:`risk_penalty`、`position_limit`、`is_suspended` 等现有布尔/得分信号;新增日内波动、VaR、行业集中度、外部事件标签(停牌预警、合规黑名单)。
|
||||
- **实时事件**:补数触发的异常、执行失败回报、仓位越界日志;需要通过 `DataBroker.register_refresh_callback()` 和交易执行层的回调管理。
|
||||
- **上下文结构**:在 `AgentContext.raw` 中携带 `risk_flags`、`compliance_notes`,供主持器和前端展示。
|
||||
|
||||
## 决策流程改造
|
||||
1. **风险评估阶段**
|
||||
- 在部门回合后触发 `risk_round`,由 `RiskAgent` 复写 `Decision` 的 `requires_review` 与 `target_weight`,必要时生成“回滚/减仓”建议。
|
||||
- 引入 `RiskAssessment` 数据类,存储风险来源、建议操作、置信度,序列化到 `Decision.rounds`。
|
||||
2. **执行协调阶段**
|
||||
- 若风险回合给出回滚指令,则 `execution_round` 应记录“冻结执行”或“调整仓位”而非直接落地。
|
||||
- 将风险回合结论写入 `risk_events` 表(或新建),供 UI 与监控使用。
|
||||
3. **日志与监控**
|
||||
- 在 `risk_round` 中附加 `annotations` 字段:`{"breach_metrics": {...}, "actions": [...]}`。
|
||||
- 通过补数回调和执行回调,将风险事件推送到监控指标,如“复核触发率”“回滚成功率”。
|
||||
|
||||
## 实施里程碑
|
||||
1. **原型阶段**
|
||||
- ✅ 重构 `RiskAgent` 以返回策略建议(持仓调整、止损触发)。
|
||||
- ✅ 扩展 `Decision` 结构,增加 `risk_assessment` 字段。
|
||||
- ✅ 更新 `ProtocolHost` 将风险建议纳入 `risk_round.notes`。
|
||||
2. **集成验证**
|
||||
- 在回测环境构造冲突样例,验证“冲突→风险建议→执行调整”链路。
|
||||
- 新增测试:模拟停牌、仓位超限、黑名单事件,确认风险代理逻辑。
|
||||
3. **实盘准备**
|
||||
- 接入实时告警渠道(如风控系统/合规接口)。
|
||||
- 监控接入:统计复核频率、回滚动作、失败告警。
|
||||
4. **上线迭代**
|
||||
- 影子运行记录风险建议与实际执行差异。
|
||||
- 评估模型表现,迭代规则或引入强化学习风险控制策略。
|
||||
|
||||
## 未决事项
|
||||
- 风险代理与执行模块的数据同步接口(数据库/消息队列)。
|
||||
- 与合规团队确认风险阈值与回滚条件。
|
||||
- 定义UI展示:风险回合的建议、证据引用、执行结果。
|
||||
69
docs/TODO.md
69
docs/TODO.md
@ -27,6 +27,75 @@
|
||||
- 构建实时持仓/成交数据写入链路,使线上监控与离线调参共用同一数据源。
|
||||
- 借鉴 TradingAgents-CN 的做法:拆分环境与策略、提供训练脚本/配置,并输出丰富的评估指标(如 Sharpe、Sortino、基准对比)。
|
||||
- 完善 `BacktestEngine` 的成交撮合、风险阈值与指标输出,让回测信号直接对接执行端,形成无人值守的自动闭环。
|
||||
- 新增多智能体多轮逻辑博弈线路:把投资场景建成重复博弈,设定主持、预测、风险等角色分轮交换信号、提出假设、相互反驳,再由执行代理落地操作;通过信念修正规则与合规约束驱动策略共识,而非依赖优化搜索。
|
||||
|
||||
### 3.2 多智能体多轮逻辑博弈实施方案
|
||||
1. 角色与知识库建模
|
||||
- 定义主持(议程控制)、预测(市场观点)、风险(合规阈值)、执行(指令生成)等核心代理,并划分共享/私有信息。
|
||||
- 为每个代理建立信号源与知识库接口(行情、新闻、风险限额),设计可信度权重与更新频率。
|
||||
- 预设重复博弈、信号博弈、贝叶斯博弈等多种结构模板,支持根据市场场景切换或组合,保障架构拓展性。
|
||||
2. 对话协议与消息格式
|
||||
- 设计多轮流程:议程发布→观点陈述→证据提交→反驳与驳回→共识决议→执行提示。
|
||||
- 约定消息 schema(动作类型、置信度、引用证据)及主持代理的轮次控制逻辑,确保所有代理均可追踪历史发言。
|
||||
3. 信念修正与逻辑推理引擎
|
||||
- 实现迭代信念修正模块,结合可信度权重与显式逻辑规则更新各代理的市场信念。
|
||||
- 引入论证框架(Argumentation Framework)或模态逻辑规则库,支持观点检验、冲突解决与风险约束自动否决。
|
||||
4. 模拟与测试环境
|
||||
- 搭建历史回测驱动的多轮博弈仿真环境,让代理在真实行情序列上互动。
|
||||
- 编写回放脚本记录每轮对话、信念轨迹与最终行动,产出可视化报告用于策略审查。
|
||||
5. 执行映射与风控集成
|
||||
- 将共识决议映射到具体操作(换仓、对冲、仓位调节),并与 `BacktestEngine`/实盘执行接口打通。
|
||||
- 风险代理负责实时校验 VaR、仓位集中度等阈值,违规时触发主持代理重新议程或引入驳回指令。
|
||||
6. 监控与评估指标
|
||||
- 设计协作质量指标(轮次收敛时间、冲突率、执行合规度)和业绩指标(收益、回撤、超额收益稳定性)。
|
||||
- 将指标纳入监控面板,与现有回测、实盘监控体系统一展示。
|
||||
7. 增量迭代计划
|
||||
- 从线下模块联调 → 回测闭环 → 影子运行 → 小资金试点的节奏推进,每阶段设定成功退出标准。
|
||||
- 收集各阶段的失败案例与风险事件,迭代规则库、信念修正规则与角色职责。
|
||||
|
||||
### 3.3 代码改造计划(多轮博弈适配)
|
||||
1. 架构基线评估
|
||||
- ⏳ 绘制代理/部门/回测调用图,补充日志字段(缺数告警、补数来源、议程标识)并形成诊断报告。
|
||||
- ✅ 定义多轮博弈上下文结构(消息历史、信念状态、引用证据),输出数据类与通信协议草稿。
|
||||
- ✅ 在 `app/agents/protocols.py` 基础上补充主持/执行状态管理,实现 `DialogueTrace` 与部门上下文的对接路径。
|
||||
- ✅ 扩展 `Decision.rounds` 与 `RoundSummary` 采集策略,用于串联部门结论与多轮议程结果。
|
||||
- ✅ 基于 `ProtocolHost` 设计主持驱动的议程模板,明确多结构策略扩展方式。
|
||||
- ✅ 扩展主持对风险复核、执行总结等议程的支持,规划风控代理与回滚逻辑。
|
||||
- ✅ 落地风险议程触发条件(冲突阈值、风控代理信号、外部警报),同步更新回测策略。
|
||||
- ✅ 结合 `docs/RISK_AGENT_PLAN.md` 明确 RiskAgent 升级里程碑、数据接口及前端展示需求。
|
||||
- ✅ 定义 `risk_assessment` 数据结构(状态、原因、建议动作),在决策输出与存储层保持一致。
|
||||
- ✅ 将风险建议与执行回合对齐,执行阶段识别 `risk_adjusted` 并记录原始动作。
|
||||
2. 数据与因子重构
|
||||
- ✅ 拆分 `DataBroker` 查询层(`BrokerQueryEngine`),补数逻辑独立于查询管道。
|
||||
- ⏳ 按主题拆分因子模块,存储缺口/异常标签,改写 `load_market_data()` 为“缺失即说明”。
|
||||
- ⏳ 维护博弈结构 → 数据 scope 映射,支持角色按结构加载差异化字段。
|
||||
- ✅ 基于 `_RefreshCoordinator` 落地刷新队列与监控事件,拆分查询与补数路径。
|
||||
- ✅ 暴露 `DataBroker.register_refresh_callback()` 钩子,结合监控系统记录补数进度与失败重试。
|
||||
- ⏳ 统一补数回调日志格式(`LOG_EXTRA.stage=data_broker`),为后续指标预留数据源。
|
||||
3. 多轮博弈框架
|
||||
- ✅ 在 `app/agents/game.py` 抽象 `GameProtocol` 接口,扩展 `Decision` 记录多轮对话。
|
||||
- ✅ 实现主持调度器驱动议程(信息→陈述→反驳→共识→执行),挂载风险复核机制。
|
||||
- ⏳ 引入信念修正规则与论证框架,支持证据引用和冲突裁决。
|
||||
4. 执行与回测集成
|
||||
- ✅ 将回测循环改造成“每日多轮→执行摘要”,完成风控校验与冲突重议流程。
|
||||
- ⏳ 擦合订单映射层,明确多轮结果对应目标仓位、执行节奏、异常回滚策略。
|
||||
- ✅ 回测执行根据 `risk_assessment` 调整动作与目标权重,记录执行状态。
|
||||
5. 监控与可视化
|
||||
- ⏳ 在数据库/UI 增补轮次日志、信念轨迹、冲突原因与执行结果的可视化。
|
||||
- ⏳ 定义多结构绩效指标并纳入监控面板(回测视图已提供风险事件统计与分布图)。
|
||||
- ⏳ 补充补数回调监控告警策略(成功率、排队时长、重试次数)。
|
||||
- ✅ 回测视图展示 `_risk_assessment` 与 `execution_status`,后续扩展到实时监控面板。
|
||||
- ⏳ 完善风险告警(`alerts.backtest_risk`)阈值与通知策略并扩展至实盘。
|
||||
6. 测试与迭代治理
|
||||
- ⏳ 构建历史行情驱动的多轮仿真回放,覆盖结构切换与异常场景。
|
||||
- 新增单元/集成测试:议程流程、信念修正、补数竞态、执行映射。
|
||||
- ✅ 风险执行集成测试完成(`tests/test_decision_risk_integration.py`),验证风险调节与执行日志。
|
||||
- 规划上线节奏:线下 PoC → 回测闭环 → 影子运行 → 小资金试点,记录失败案例反馈改进。
|
||||
- 在测试计划中涵盖风险/执行议程的触发条件与回滚路径,确保主持扩展的行为可验证。
|
||||
- 新增针对风险复核和执行议程的集成测试场景,覆盖“冲突→复核→执行/回滚”链路。
|
||||
- 对照 `docs/RISK_AGENT_PLAN.md` 拆解阶段性验收项:数据校验、监控指标、影子运行报告。
|
||||
- 为 `risk_assessment` 输出设计单元测试(冲突触发、无冲突、外部告警模拟),确保状态与原因正确覆盖。
|
||||
- 编写执行阶段测试,验证 `risk_adjusted` 状态下的目标仓位与日志记录。
|
||||
|
||||
### 3.1 实施步骤(建议顺序)
|
||||
1. 环境重构:扩展 `DecisionEnv` 支持逐日状态/动作/奖励,完善 `BacktestEngine` 的状态保存与恢复接口,并补充必要的数据库读写钩子。
|
||||
|
||||
119
tests/test_decision_risk_integration.py
Normal file
119
tests/test_decision_risk_integration.py
Normal file
@ -0,0 +1,119 @@
|
||||
"""Integration-style tests for risk-aware decision execution."""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.base import AgentAction, AgentContext
|
||||
from app.agents.game import decide
|
||||
from app.agents.registry import default_agents
|
||||
from app.agents.risk import RiskAgent, RiskRecommendation
|
||||
from app.backtest.engine import BacktestEngine, BacktestResult, BtConfig, PortfolioState
|
||||
|
||||
|
||||
def _make_context(features: dict, *, alerts: list[str] | None = None) -> AgentContext:
|
||||
raw = {
|
||||
"scope_values": {"daily.close": 100.0},
|
||||
}
|
||||
if alerts:
|
||||
raw["risk_alerts"] = alerts
|
||||
return AgentContext(
|
||||
ts_code="000001.SZ",
|
||||
trade_date="2025-01-10",
|
||||
features=features,
|
||||
market_snapshot={},
|
||||
raw=raw,
|
||||
)
|
||||
|
||||
|
||||
class _StubRiskAgent(RiskAgent):
|
||||
def __init__(self, recommendation: RiskRecommendation) -> None:
|
||||
super().__init__()
|
||||
self._recommendation = recommendation
|
||||
|
||||
def assess(
|
||||
self,
|
||||
context: AgentContext,
|
||||
decision_action: AgentAction,
|
||||
conflict_flag: bool,
|
||||
) -> RiskRecommendation:
|
||||
return self._recommendation
|
||||
|
||||
|
||||
def test_decide_adjusts_execution_on_risk_recommendation(monkeypatch):
|
||||
agents = default_agents()
|
||||
recommendation = RiskRecommendation(
|
||||
status="blocked",
|
||||
reason="risk_penalty_extreme",
|
||||
recommended_action=AgentAction.HOLD,
|
||||
notes={"risk_penalty": 0.95},
|
||||
)
|
||||
stub_risk = _StubRiskAgent(recommendation)
|
||||
agents = [stub_risk if isinstance(agent, RiskAgent) else agent for agent in agents]
|
||||
|
||||
context = _make_context({"risk_penalty": 0.95})
|
||||
|
||||
decision = decide(
|
||||
context,
|
||||
agents,
|
||||
weights={agent.name: 1.0 for agent in agents},
|
||||
department_manager=None,
|
||||
)
|
||||
assert decision.requires_review is True
|
||||
assert decision.risk_assessment
|
||||
assert decision.risk_assessment.status == "blocked"
|
||||
assert decision.risk_assessment.recommended_action == AgentAction.HOLD
|
||||
|
||||
execution_rounds = [round for round in decision.rounds if round.agenda == "execution_summary"]
|
||||
assert execution_rounds
|
||||
execution_notes = execution_rounds[0].notes
|
||||
assert execution_notes.get("execution_status") == "risk_adjusted"
|
||||
assert execution_rounds[0].outcome == AgentAction.HOLD.value
|
||||
|
||||
|
||||
def test_backtest_engine_applies_risk_adjusted_execution(monkeypatch):
|
||||
cfg = BtConfig(
|
||||
id="risk-test",
|
||||
name="risk-test",
|
||||
start_date=date(2025, 1, 10),
|
||||
end_date=date(2025, 1, 10),
|
||||
universe=["000001.SZ"],
|
||||
params={},
|
||||
)
|
||||
engine = BacktestEngine(cfg)
|
||||
state = PortfolioState(cash=100_000.0)
|
||||
result = BacktestResult()
|
||||
|
||||
context = _make_context({"risk_penalty": 0.95})
|
||||
recommendation = RiskRecommendation(
|
||||
status="blocked",
|
||||
reason="risk_penalty_extreme",
|
||||
recommended_action=AgentAction.HOLD,
|
||||
notes={"risk_penalty": 0.95},
|
||||
)
|
||||
|
||||
agents = [
|
||||
_StubRiskAgent(recommendation) if isinstance(agent, RiskAgent) else agent
|
||||
for agent in engine.agents
|
||||
]
|
||||
engine.agents = agents
|
||||
engine.department_manager = None
|
||||
|
||||
decision = decide(
|
||||
context,
|
||||
engine.agents,
|
||||
engine.weights,
|
||||
department_manager=None,
|
||||
)
|
||||
|
||||
engine._apply_portfolio_updates(
|
||||
date(2025, 1, 10),
|
||||
state,
|
||||
[("000001.SZ", context, decision)],
|
||||
result,
|
||||
)
|
||||
|
||||
assert not state.holdings
|
||||
assert not result.trades
|
||||
assert result.nav_series[0]["nav"] == pytest.approx(100_000.0)
|
||||
63
tests/test_risk_agent.py
Normal file
63
tests/test_risk_agent.py
Normal file
@ -0,0 +1,63 @@
|
||||
"""Tests for RiskAgent assessment and risk evaluation pipeline."""
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.base import AgentAction, AgentContext
|
||||
from app.agents.game import _evaluate_risk
|
||||
from app.agents.risk import RiskAgent, RiskRecommendation
|
||||
|
||||
|
||||
def _make_context(**features: float | bool) -> AgentContext:
|
||||
return AgentContext(
|
||||
ts_code="000001.SZ",
|
||||
trade_date="2025-01-01",
|
||||
features=features,
|
||||
market_snapshot={},
|
||||
raw={},
|
||||
)
|
||||
|
||||
|
||||
def test_risk_agent_ok_status() -> None:
|
||||
agent = RiskAgent()
|
||||
context = _make_context(risk_penalty=0.1)
|
||||
recommendation = agent.assess(context, AgentAction.BUY_S, conflict_flag=False)
|
||||
assert recommendation.status == "ok"
|
||||
assert recommendation.reason == "clear"
|
||||
|
||||
|
||||
def test_risk_agent_blocked_on_limit_up() -> None:
|
||||
agent = RiskAgent()
|
||||
context = _make_context(limit_up=True)
|
||||
recommendation = agent.assess(context, AgentAction.BUY_M, conflict_flag=False)
|
||||
assert recommendation.status == "blocked"
|
||||
assert recommendation.reason == "limit_up"
|
||||
assert recommendation.recommended_action == AgentAction.HOLD
|
||||
|
||||
|
||||
def test_risk_agent_pending_on_conflict() -> None:
|
||||
agent = RiskAgent()
|
||||
context = _make_context()
|
||||
recommendation = agent.assess(context, AgentAction.HOLD, conflict_flag=True)
|
||||
assert recommendation.status == "pending_review"
|
||||
assert recommendation.reason == "conflict_threshold"
|
||||
|
||||
|
||||
def test_evaluate_risk_external_alerts() -> None:
|
||||
agent = RiskAgent()
|
||||
context = AgentContext(
|
||||
ts_code="000002.SZ",
|
||||
trade_date="2025-01-01",
|
||||
features={"risk_penalty": 0.1},
|
||||
market_snapshot={},
|
||||
raw={"risk_alerts": ["sudden_news"]},
|
||||
)
|
||||
assessment = _evaluate_risk(
|
||||
context=context,
|
||||
action=AgentAction.HOLD,
|
||||
department_votes={"buy": 0.6},
|
||||
conflict_flag=False,
|
||||
risk_agent=agent,
|
||||
)
|
||||
assert assessment.status == "pending_review"
|
||||
assert assessment.reason == "external_alert"
|
||||
assert assessment.recommended_action == AgentAction.HOLD
|
||||
assert "external_alerts" in assessment.notes
|
||||
Loading…
Reference in New Issue
Block a user