diff --git a/app/agents/game.py b/app/agents/game.py index d8b8869..4a9cdfc 100644 --- a/app/agents/game.py +++ b/app/agents/game.py @@ -8,6 +8,15 @@ from typing import Dict, Iterable, List, Mapping, Optional, Tuple from .base import Agent, AgentAction, AgentContext, UtilityMatrix from .departments import DepartmentContext, DepartmentDecision, DepartmentManager from .registry import weight_map +from .risk import RiskAgent, RiskRecommendation +from .protocols import ( + DialogueMessage, + DialogueRole, + GameStructure, + MessageType, + ProtocolHost, + RoundSummary, +) ACTIONS: Tuple[AgentAction, ...] = ( @@ -23,6 +32,30 @@ def _clamp(value: float) -> float: return max(0.0, min(1.0, value)) +@dataclass +class BeliefUpdate: + belief: Dict[str, object] + rationale: Optional[str] = None + + +@dataclass +class RiskAssessment: + status: str + reason: str + recommended_action: Optional[AgentAction] = None + notes: Dict[str, object] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, object]: + payload: Dict[str, object] = { + "status": self.status, + "reason": self.reason, + "notes": dict(self.notes), + } + if self.recommended_action is not None: + payload["recommended_action"] = self.recommended_action.value + return payload + + @dataclass class Decision: action: AgentAction @@ -33,6 +66,9 @@ class Decision: department_decisions: Dict[str, DepartmentDecision] = field(default_factory=dict) department_votes: Dict[str, float] = field(default_factory=dict) requires_review: bool = False + rounds: List[RoundSummary] = field(default_factory=list) + risk_assessment: Optional[RiskAssessment] = None + belief_updates: Dict[str, BeliefUpdate] = field(default_factory=dict) def compute_utilities(agents: Iterable[Agent], context: AgentContext) -> UtilityMatrix: @@ -132,6 +168,16 @@ def decide( raw_weights = dict(weights) department_decisions: Dict[str, DepartmentDecision] = {} department_votes: Dict[str, float] = {} + host = ProtocolHost() + host_trace = host.bootstrap_trace( + session_id=f"{context.ts_code}:{context.trade_date}", + ts_code=context.ts_code, + trade_date=context.trade_date, + ) + department_round: Optional[RoundSummary] = None + risk_round: Optional[RoundSummary] = None + execution_round: Optional[RoundSummary] = None + belief_updates: Dict[str, BeliefUpdate] = {} if department_manager: dept_context = department_context @@ -144,6 +190,12 @@ def decide( raw=dict(getattr(context, "raw", {}) or {}), ) department_decisions = department_manager.evaluate(dept_context) + if department_decisions: + department_round = host.start_round( + host_trace, + agenda="department_consensus", + structure=GameStructure.REPEATED, + ) for code, decision in department_decisions.items(): agent_key = f"dept_{code}" dept_agent = department_manager.agents.get(code) @@ -155,6 +207,17 @@ def decide( bucket = _department_vote_bucket(decision.action) if bucket: department_votes[bucket] = department_votes.get(bucket, 0.0) + weight * decision.confidence + if department_round: + message = _department_message(code, decision) + host.handle_message(department_round, message) + belief_updates[code] = BeliefUpdate( + belief={ + "action": decision.action.value, + "confidence": decision.confidence, + "signals": decision.signals, + }, + rationale=decision.summary, + ) filtered_utilities = {action: utilities[action] for action in feas_actions} hold_scores = utilities.get(AgentAction.HOLD, {}) @@ -168,7 +231,111 @@ def decide( action, confidence = vote(filtered_utilities, norm_weights) weight = target_weight_for_action(action) - requires_review = _department_conflict_flag(department_votes) + conflict_flag = _department_conflict_flag(department_votes) + + risk_agent = _find_risk_agent(agent_list) + risk_assessment = _evaluate_risk( + context, + action, + department_votes, + conflict_flag, + risk_agent, + ) + requires_review = risk_assessment.status != "ok" + + if department_round: + department_round.notes.setdefault("department_votes", dict(department_votes)) + department_round.outcome = action.value + host.finalize_round(department_round) + + if requires_review: + risk_round = host.ensure_round( + host_trace, + agenda="risk_review", + structure=GameStructure.CUSTOM, + ) + review_message = DialogueMessage( + sender="risk_guard", + role=DialogueRole.RISK, + message_type=MessageType.COUNTER, + content=_risk_review_message(risk_assessment.reason), + confidence=1.0, + references=list(department_votes.keys()), + annotations={ + "department_votes": dict(department_votes), + "risk_reason": risk_assessment.reason, + "recommended_action": ( + risk_assessment.recommended_action.value + if risk_assessment.recommended_action + else None + ), + "notes": dict(risk_assessment.notes), + }, + ) + host.handle_message(risk_round, review_message) + risk_round.notes.setdefault("status", risk_assessment.status) + risk_round.notes.setdefault("reason", risk_assessment.reason) + if risk_assessment.recommended_action: + risk_round.notes.setdefault( + "recommended_action", + risk_assessment.recommended_action.value, + ) + risk_round.outcome = "REVIEW" + host.finalize_round(risk_round) + belief_updates["risk_guard"] = BeliefUpdate( + belief={ + "status": risk_assessment.status, + "reason": risk_assessment.reason, + "recommended_action": ( + risk_assessment.recommended_action.value + if risk_assessment.recommended_action + else None + ), + }, + ) + execution_round = host.ensure_round( + host_trace, + agenda="execution_summary", + structure=GameStructure.REPEATED, + ) + exec_action = action + exec_weight = weight + exec_status = "normal" + if requires_review and risk_assessment.recommended_action: + exec_action = risk_assessment.recommended_action + exec_status = "risk_adjusted" + exec_weight = target_weight_for_action(exec_action) + execution_message = DialogueMessage( + sender="execution_engine", + role=DialogueRole.EXECUTION, + message_type=MessageType.DIRECTIVE, + content=f"执行操作 {exec_action.value}", + confidence=1.0, + annotations={ + "target_weight": exec_weight, + "requires_review": requires_review, + "execution_status": exec_status, + }, + ) + host.handle_message(execution_round, execution_message) + execution_round.outcome = exec_action.value + execution_round.notes.setdefault("execution_status", exec_status) + if exec_action is not action: + execution_round.notes.setdefault("original_action", action.value) + belief_updates["execution"] = BeliefUpdate( + belief={ + "execution_status": exec_status, + "action": exec_action.value, + "target_weight": exec_weight, + }, + ) + host.finalize_round(execution_round) + host.close(host_trace) + rounds = host_trace.rounds if host_trace.rounds else _build_round_summaries( + department_decisions, + action, + department_votes, + ) return Decision( action=action, confidence=confidence, @@ -178,6 +345,9 @@ def decide( department_decisions=department_decisions, department_votes=department_votes, requires_review=requires_review, + rounds=rounds, + risk_assessment=risk_assessment, + belief_updates=belief_updates, ) @@ -231,3 +401,145 @@ def _department_conflict_flag(votes: Mapping[str, float]) -> bool: if len(sorted_votes) >= 2 and (sorted_votes[0] - sorted_votes[1]) < total * 0.1: return True return False + + +def _department_message(code: str, decision: DepartmentDecision) -> DialogueMessage: + content = decision.summary or decision.raw_response or decision.action.value + references = decision.signals or [] + annotations: Dict[str, object] = { + "risks": decision.risks, + "supplements": decision.supplements, + } + if decision.dialogue: + annotations["dialogue"] = decision.dialogue + if decision.telemetry: + annotations["telemetry"] = decision.telemetry + return DialogueMessage( + sender=code, + role=DialogueRole.PREDICTION, + message_type=MessageType.DECISION, + content=content, + confidence=decision.confidence, + references=references, + annotations=annotations, + ) + + +def _evaluate_risk( + context: AgentContext, + action: AgentAction, + department_votes: Mapping[str, float], + conflict_flag: bool, + risk_agent: Optional[RiskAgent], +) -> RiskAssessment: + external_alerts = [] + if getattr(context, "raw", None): + alerts = context.raw.get("risk_alerts", []) + if alerts: + external_alerts = list(alerts) + + if risk_agent: + recommendation = risk_agent.assess(context, action, conflict_flag) + notes = dict(recommendation.notes) + notes.setdefault("department_votes", dict(department_votes)) + if external_alerts: + notes.setdefault("external_alerts", external_alerts) + if recommendation.status == "ok": + recommendation = RiskRecommendation( + status="pending_review", + reason="external_alert", + recommended_action=recommendation.recommended_action or AgentAction.HOLD, + notes=notes, + ) + else: + recommendation.notes = notes + return RiskAssessment( + status=recommendation.status, + reason=recommendation.reason, + recommended_action=recommendation.recommended_action, + notes=recommendation.notes, + ) + + notes: Dict[str, object] = { + "conflict": conflict_flag, + "department_votes": dict(department_votes), + } + if external_alerts: + notes["external_alerts"] = external_alerts + return RiskAssessment( + status="pending_review", + reason="external_alert", + recommended_action=AgentAction.HOLD, + notes=notes, + ) + if conflict_flag: + return RiskAssessment( + status="pending_review", + reason="conflict_threshold", + notes=notes, + ) + return RiskAssessment(status="ok", reason="clear", notes=notes) + + +def _find_risk_agent(agents: Iterable[Agent]) -> Optional[RiskAgent]: + for agent in agents: + if isinstance(agent, RiskAgent): + return agent + return None + + +def _risk_review_message(reason: str) -> str: + mapping = { + "conflict_threshold": "部门意见分歧,触发风险复核", + "suspended": "标的停牌,需冻结执行", + "limit_up": "标的涨停,执行需调整", + "position_limit": "仓位限制已触发,需调整目标", + "risk_penalty_extreme": "风险评分极高,建议暂停加仓", + "risk_penalty_high": "风险评分偏高,建议复核", + "external_alert": "外部风险告警触发复核", + } + return mapping.get(reason, "触发风险复核,需人工确认") + + +def _build_round_summaries( + department_decisions: Mapping[str, DepartmentDecision], + final_action: AgentAction, + department_votes: Mapping[str, float], +) -> List[RoundSummary]: + if not department_decisions: + return [] + messages: List[DialogueMessage] = [] + for code, decision in department_decisions.items(): + content = decision.summary or decision.raw_response or decision.action.value + references = decision.signals or [] + annotations: Dict[str, object] = { + "risks": decision.risks, + "supplements": decision.supplements, + } + if decision.dialogue: + annotations["dialogue"] = decision.dialogue + if decision.telemetry: + annotations["telemetry"] = decision.telemetry + message = DialogueMessage( + sender=code, + role=DialogueRole.PREDICTION, + message_type=MessageType.DECISION, + content=content, + confidence=decision.confidence, + references=references, + annotations=annotations, + ) + messages.append(message) + notes: Dict[str, object] = { + "department_votes": dict(department_votes), + } + summary = RoundSummary( + index=0, + agenda="department_consensus", + structure=GameStructure.REPEATED, + resolved=True, + outcome=final_action.value, + messages=messages, + notes=notes, + ) + return [summary] diff --git a/app/agents/protocols.py b/app/agents/protocols.py new file mode 100644 index 0000000..f6f9b47 --- /dev/null +++ b/app/agents/protocols.py @@ -0,0 +1,270 @@ +"""Protocols and data structures for multi-round multi-agent games.""" +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Optional, Protocol + + +class GameStructure(str, Enum): + """Supported multi-agent game topologies.""" + + REPEATED = "repeated" + SIGNALING = "signaling" + BAYESIAN = "bayesian" + CUSTOM = "custom" + + +class DialogueRole(str, Enum): + """Roles participating in the negotiation agenda.""" + + HOST = "host" + PREDICTION = "prediction" + RISK = "risk" + EXECUTION = "execution" + OBSERVER = "observer" + + +class MessageType(str, Enum): + """High-level classification of dialogue intents.""" + + HYPOTHESIS = "hypothesis" + EVIDENCE = "evidence" + COUNTER = "counter" + DECISION = "decision" + DIRECTIVE = "directive" + META = "meta" + + +@dataclass +class DialogueMessage: + """Single utterance in the multi-round dialogue.""" + + sender: str + role: DialogueRole + message_type: MessageType + content: str + confidence: float = 0.0 + references: List[str] = field(default_factory=list) + timestamp: Optional[str] = None + annotations: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + return { + "sender": self.sender, + "role": self.role.value, + "message_type": self.message_type.value, + "content": self.content, + "confidence": self.confidence, + "references": list(self.references), + "timestamp": self.timestamp, + "annotations": dict(self.annotations), + } + + +@dataclass +class BeliefSnapshot: + """Belief state emitted by an agent before or after revision.""" + + agent: str + role: DialogueRole + belief: Dict[str, Any] + confidence: float + rationale: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + return { + "agent": self.agent, + "role": self.role.value, + "belief": dict(self.belief), + "confidence": self.confidence, + "rationale": self.rationale, + } + + +@dataclass +class BeliefRevision: + """Tracks belief updates triggered during a round.""" + + before: BeliefSnapshot + after: BeliefSnapshot + justification: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + return { + "before": self.before.to_dict(), + "after": self.after.to_dict(), + "justification": self.justification, + } + + +@dataclass +class RoundSummary: + """Aggregated view of a single negotiation round.""" + + index: int + agenda: str + structure: GameStructure + resolved: bool + outcome: Optional[str] = None + messages: List[DialogueMessage] = field(default_factory=list) + revisions: List[BeliefRevision] = field(default_factory=list) + constraints_triggered: List[str] = field(default_factory=list) + notes: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + return { + "index": self.index, + "agenda": self.agenda, + "structure": self.structure.value, + "resolved": self.resolved, + "outcome": self.outcome, + "messages": [message.to_dict() for message in self.messages], + "revisions": [revision.to_dict() for revision in self.revisions], + "constraints_triggered": list(self.constraints_triggered), + "notes": dict(self.notes), + } + + +@dataclass +class DialogueTrace: + """Ordered collection of round summaries for auditing.""" + + session_id: str + ts_code: str + trade_date: str + rounds: List[RoundSummary] = field(default_factory=list) + metadata: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + return { + "session_id": self.session_id, + "ts_code": self.ts_code, + "trade_date": self.trade_date, + "rounds": [summary.to_dict() for summary in self.rounds], + "metadata": dict(self.metadata), + } + + +class GameProtocol(Protocol): + """Extension point for hosting multi-round agent negotiations.""" + + def bootstrap(self, trace: DialogueTrace) -> None: + """Prepare the protocol-specific state before rounds begin.""" + + def start_round(self, trace: DialogueTrace, agenda: str, structure: GameStructure) -> RoundSummary: + """Open a new round with the given agenda and structure descriptor.""" + + def handle_message(self, summary: RoundSummary, message: DialogueMessage) -> None: + """Process a single dialogue message emitted by an agent.""" + + def apply_revision(self, summary: RoundSummary, revision: BeliefRevision) -> None: + """Register a belief revision triggered by debate or new evidence.""" + + def finalize_round(self, summary: RoundSummary) -> None: + """Mark the round as resolved and perform protocol-specific bookkeeping.""" + + def close(self, trace: DialogueTrace) -> None: + """Finish the negotiation session and emit protocol artifacts.""" + + +class ProtocolHost(GameProtocol): + """Base implementation for agenda-driven negotiation protocols.""" + + def __init__(self) -> None: + self._current_round: Optional[RoundSummary] = None + self._trace: Optional[DialogueTrace] = None + self._round_index: int = 0 + + def bootstrap(self, trace: DialogueTrace) -> None: + trace.metadata.setdefault("host_started", True) + self._trace = trace + self._round_index = len(trace.rounds) + + def bootstrap_trace( + self, + *, + session_id: str, + ts_code: str, + trade_date: str, + metadata: Optional[Dict[str, Any]] = None, + ) -> DialogueTrace: + trace = DialogueTrace( + session_id=session_id, + ts_code=ts_code, + trade_date=trade_date, + metadata=dict(metadata or {}), + ) + self.bootstrap(trace) + return trace + + def start_round( + self, + trace: DialogueTrace, + agenda: str, + structure: GameStructure, + ) -> RoundSummary: + index = self._round_index + summary = RoundSummary( + index=index, + agenda=agenda, + structure=structure, + resolved=False, + ) + trace.rounds.append(summary) + self._current_round = summary + self._round_index += 1 + return summary + + def handle_message(self, summary: RoundSummary, message: DialogueMessage) -> None: + summary.messages.append(message) + + def apply_revision(self, summary: RoundSummary, revision: BeliefRevision) -> None: + summary.revisions.append(revision) + + def finalize_round(self, summary: RoundSummary) -> None: + summary.resolved = True + summary.notes.setdefault("message_count", len(summary.messages)) + self._current_round = None + + def close(self, trace: DialogueTrace) -> None: + trace.metadata.setdefault("host_finished", True) + self._trace = trace + + def current_round(self) -> Optional[RoundSummary]: + return self._current_round + + @property + def trace(self) -> Optional[DialogueTrace]: + return self._trace + + def ensure_round( + self, + trace: DialogueTrace, + agenda: str, + structure: GameStructure, + ) -> RoundSummary: + if self._current_round and not self._current_round.resolved: + return self._current_round + return self.start_round(trace, agenda, structure) + + +def round_to_dict(summary: RoundSummary) -> Dict[str, Any]: + """Serialize a round summary for persistence layers.""" + + return summary.to_dict() + + +__all__ = [ + "GameStructure", + "DialogueRole", + "MessageType", + "DialogueMessage", + "BeliefSnapshot", + "BeliefRevision", + "RoundSummary", + "DialogueTrace", + "GameProtocol", + "ProtocolHost", + "round_to_dict", +] diff --git a/app/agents/risk.py b/app/agents/risk.py index 57497c9..e78f71d 100644 --- a/app/agents/risk.py +++ b/app/agents/risk.py @@ -4,6 +4,35 @@ from __future__ import annotations from .base import Agent, AgentAction, AgentContext +class RiskRecommendation: + """Represents structured recommendation from the risk agent.""" + + __slots__ = ("status", "reason", "recommended_action", "notes") + + def __init__( + self, + *, + status: str, + reason: str, + recommended_action: AgentAction | None = None, + notes: dict | None = None, + ) -> None: + self.status = status + self.reason = reason + self.recommended_action = recommended_action + self.notes = notes or {} + + def to_dict(self) -> dict: + payload = { + "status": self.status, + "reason": self.reason, + "notes": dict(self.notes), + } + if self.recommended_action is not None: + payload["recommended_action"] = self.recommended_action.value + return payload + + class RiskAgent(Agent): def __init__(self) -> None: super().__init__(name="A_risk") @@ -27,3 +56,74 @@ class RiskAgent(Agent): if context.features.get("position_limit", False) and action in (AgentAction.BUY_M, AgentAction.BUY_L): return False return True + + def assess( + self, + context: AgentContext, + decision_action: AgentAction, + conflict_flag: bool, + ) -> RiskRecommendation: + features = dict(context.features or {}) + risk_penalty = float(features.get("risk_penalty") or 0.0) + + if bool(features.get("is_suspended")): + return RiskRecommendation( + status="blocked", + reason="suspended", + recommended_action=AgentAction.HOLD, + notes={"trigger": "is_suspended"}, + ) + + if bool(features.get("limit_up")) and decision_action in { + AgentAction.BUY_S, + AgentAction.BUY_M, + AgentAction.BUY_L, + }: + return RiskRecommendation( + status="blocked", + reason="limit_up", + recommended_action=AgentAction.HOLD, + notes={"trigger": "limit_up"}, + ) + + if bool(features.get("position_limit")) and decision_action in { + AgentAction.BUY_M, + AgentAction.BUY_L, + }: + return RiskRecommendation( + status="pending_review", + reason="position_limit", + recommended_action=AgentAction.BUY_S, + notes={"trigger": "position_limit"}, + ) + + if risk_penalty >= 0.9 and decision_action in { + AgentAction.BUY_S, + AgentAction.BUY_M, + AgentAction.BUY_L, + }: + return RiskRecommendation( + status="blocked", + reason="risk_penalty_extreme", + recommended_action=AgentAction.HOLD, + notes={"risk_penalty": risk_penalty}, + ) + if risk_penalty >= 0.7 and decision_action in { + AgentAction.BUY_S, + AgentAction.BUY_M, + AgentAction.BUY_L, + }: + return RiskRecommendation( + status="pending_review", + reason="risk_penalty_high", + recommended_action=AgentAction.HOLD, + notes={"risk_penalty": risk_penalty}, + ) + + if conflict_flag: + return RiskRecommendation( + status="pending_review", + reason="conflict_threshold", + ) + + return RiskRecommendation(status="ok", reason="clear") diff --git a/app/backtest/engine.py b/app/backtest/engine.py index d9271e4..ddcb98a 100644 --- a/app/backtest/engine.py +++ b/app/backtest/engine.py @@ -8,7 +8,8 @@ from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple from app.agents.base import AgentAction, AgentContext from app.agents.departments import DepartmentManager -from app.agents.game import Decision, decide +from app.agents.game import Decision, decide, target_weight_for_action +from app.agents.protocols import round_to_dict from app.llm.metrics import record_decision as metrics_record_decision from app.agents.registry import default_agents from app.data.schema import initialize_database @@ -16,6 +17,7 @@ from app.utils.data_access import DataBroker from app.utils.config import get_config from app.utils.db import db_session from app.utils.logging import get_logger +from app.utils import alerts from app.core.indicators import momentum, normalize, rolling_mean, volatility @@ -436,6 +438,7 @@ class BacktestEngine: ) ) + round_payload = [round_to_dict(summary) for summary in decision.rounds] global_payload = { "_confidence": decision.confidence, "_target_weight": decision.target_weight, @@ -459,6 +462,12 @@ class BacktestEngine: for code, dept in decision.department_decisions.items() if dept.telemetry }, + "_rounds": round_payload, + "_risk_assessment": ( + decision.risk_assessment.to_dict() + if decision.risk_assessment + else None + ), } rows.append( ( @@ -543,18 +552,42 @@ class BacktestEngine: executed_trades: List[Dict[str, Any]] = [] risk_events: List[Dict[str, Any]] = [] - def _record_risk(ts_code: str, reason: str, decision: Decision, extra: Optional[Dict[str, Any]] = None) -> None: + def _record_risk( + ts_code: str, + reason: str, + decision: Decision, + extra: Optional[Dict[str, Any]] = None, + action_override: Optional[AgentAction] = None, + target_weight_override: Optional[float] = None, + ) -> None: payload = { "trade_date": trade_date_str, "ts_code": ts_code, - "action": decision.action.value, - "target_weight": decision.target_weight, + "action": (action_override or decision.action).value, + "target_weight": ( + target_weight_override + if target_weight_override is not None + else decision.target_weight + ), "confidence": decision.confidence, "reason": reason, } if extra: payload.update(extra) risk_events.append(payload) + risk_meta = payload.get("risk_assessment") if isinstance(payload.get("risk_assessment"), dict) else extra.get("risk_assessment") if extra else None + status = None + if isinstance(risk_meta, dict): + status = risk_meta.get("status") + if status == "blocked": + try: + alerts.add_warning( + "backtest_risk", + f"{ts_code} 风险阻断: {reason}", + detail=json.dumps(payload, ensure_ascii=False), + ) + except Exception: # noqa: BLE001 + LOGGER.debug("记录风险告警失败", extra=LOG_EXTRA) for ts_code, decision in decisions_map.items(): price = price_map.get(ts_code) @@ -569,35 +602,54 @@ class BacktestEngine: limit_down = bool(features.get("limit_down")) position_limit = bool(features.get("position_limit")) + risk = decision.risk_assessment + effective_action = decision.action + effective_weight = decision.target_weight + if risk: + risk_payload = risk.to_dict() + risk_payload.setdefault("applied_action", effective_action.value) + if risk.recommended_action: + effective_action = risk.recommended_action + risk_payload["applied_action"] = effective_action.value + effective_weight = target_weight_for_action(effective_action) + if risk.status != "ok": + _record_risk( + ts_code, + risk.reason, + decision, + extra={"risk_assessment": risk_payload}, + action_override=effective_action, + target_weight_override=effective_weight, + ) + if risk.status == "blocked": + continue + if is_suspended: _record_risk(ts_code, "suspended", decision) continue - if decision.action in self._buy_actions: + if effective_action in self._buy_actions: if limit_up: - _record_risk(ts_code, "limit_up", decision) + _record_risk(ts_code, "limit_up", decision, action_override=effective_action) continue if position_limit: - _record_risk(ts_code, "position_limit", decision) + _record_risk(ts_code, "position_limit", decision, action_override=effective_action) continue - if risk_penalty >= 0.95: - _record_risk(ts_code, "risk_penalty", decision, {"risk_penalty": risk_penalty}) - continue - if decision.action in self._sell_actions and limit_down: - _record_risk(ts_code, "limit_down", decision) + if effective_action in self._sell_actions and limit_down: + _record_risk(ts_code, "limit_down", decision, action_override=effective_action) continue - effective_weight = max(decision.target_weight, 0.0) - if decision.action in self._buy_actions: + effective_weight_value = max(effective_weight, 0.0) + if effective_action in self._buy_actions: capped_weight = min(effective_weight, self.risk_params["max_position_weight"]) - effective_weight = capped_weight * max(0.0, 1.0 - risk_penalty) - elif decision.action in self._sell_actions: - effective_weight = 0.0 + effective_weight_value = capped_weight * max(0.0, 1.0 - risk_penalty) + elif effective_action in self._sell_actions: + effective_weight_value = 0.0 desired_qty = current_qty - if decision.action in self._sell_actions: + if effective_action in self._sell_actions: desired_qty = 0.0 - elif decision.action in self._buy_actions or effective_weight >= 0.0: - desired_value = max(effective_weight, 0.0) * portfolio_value_before + elif effective_action in self._buy_actions or effective_weight_value > 0.0: + desired_value = max(effective_weight_value, 0.0) * portfolio_value_before desired_qty = desired_value / price if price > 0 else current_qty delta = desired_qty - current_qty @@ -654,7 +706,8 @@ class BacktestEngine: "slippage": trade_price - price, "confidence": decision.confidence, "target_weight": decision.target_weight, - "effective_weight": effective_weight, + "effective_weight": effective_weight_value, + "effective_action": effective_action.value, "risk_penalty": risk_penalty, "liquidity_score": liquidity_score, "status": "executed", @@ -694,7 +747,8 @@ class BacktestEngine: "slippage": price - trade_price, "confidence": decision.confidence, "target_weight": decision.target_weight, - "effective_weight": effective_weight, + "effective_weight": effective_weight_value, + "effective_action": effective_action.value, "risk_penalty": risk_penalty, "liquidity_score": liquidity_score, "realized_pnl": realized, diff --git a/app/ui/views/backtest.py b/app/ui/views/backtest.py index 8290aac..1636b00 100644 --- a/app/ui/views/backtest.py +++ b/app/ui/views/backtest.py @@ -201,6 +201,7 @@ def render_backtest_review() -> None: selected_ids = [label.split(" | ")[0].strip() for label in selected_labels] nav_df = pd.DataFrame() rpt_df = pd.DataFrame() + risk_df = pd.DataFrame() if selected_ids: try: with db_session(read_only=True) as conn: @@ -214,11 +215,20 @@ def render_backtest_review() -> None: conn, params=tuple(selected_ids), ) + risk_df = pd.read_sql_query( + "SELECT cfg_id, trade_date, ts_code, reason, action, target_weight, confidence, metadata " + "FROM bt_risk_events WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))), + conn, + params=tuple(selected_ids), + ) except Exception: # noqa: BLE001 LOGGER.exception("读取回测结果失败", extra=LOG_EXTRA) st.error("读取回测结果失败") nav_df = pd.DataFrame() rpt_df = pd.DataFrame() + risk_df = pd.DataFrame() + start_filter: Optional[date] = None + end_filter: Optional[date] = None if not nav_df.empty: try: nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"], errors="coerce") @@ -274,6 +284,9 @@ def render_backtest_review() -> None: "交易数": summary.get("trade_count"), "平均换手": summary.get("avg_turnover"), "风险事件": summary.get("risk_events"), + "风险分布": json.dumps(summary.get("risk_breakdown"), ensure_ascii=False) + if summary.get("risk_breakdown") + else None, } metrics_rows.append({k: v for k, v in record.items() if (k == "cfg_id" or k in selected_metrics)}) if metrics_rows: @@ -291,6 +304,70 @@ def render_backtest_review() -> None: pass except Exception: # noqa: BLE001 LOGGER.debug("渲染指标表失败", extra=LOG_EXTRA) + if not risk_df.empty: + try: + risk_df["trade_date"] = pd.to_datetime(risk_df["trade_date"], errors="coerce") + risk_df = risk_df.dropna(subset=["trade_date"]) + if start_filter is None or end_filter is None: + start_filter = pd.to_datetime(risk_df["trade_date"].min()).date() + end_filter = pd.to_datetime(risk_df["trade_date"].max()).date() + risk_df = risk_df[ + (risk_df["trade_date"].dt.date >= start_filter) + & (risk_df["trade_date"].dt.date <= end_filter) + ] + parsed_cols: List[Dict[str, object]] = [] + for _, row in risk_df.iterrows(): + try: + metadata = json.loads(row["metadata"]) if isinstance(row["metadata"], str) else (row["metadata"] or {}) + except json.JSONDecodeError: + metadata = {} + assessment = metadata.get("risk_assessment") or {} + parsed_cols.append( + { + "cfg_id": row["cfg_id"], + "trade_date": row["trade_date"].date().isoformat(), + "ts_code": row["ts_code"], + "reason": row["reason"], + "action": row["action"], + "target_weight": row["target_weight"], + "confidence": row["confidence"], + "risk_status": assessment.get("status"), + "recommended_action": assessment.get("recommended_action"), + "execution_status": metadata.get("execution_status"), + "metadata": metadata, + } + ) + risk_detail_df = pd.DataFrame(parsed_cols) + with st.expander("风险事件明细", expanded=False): + st.dataframe(risk_detail_df.drop(columns=["metadata"], errors="ignore"), hide_index=True, width='stretch') + try: + st.download_button( + "下载风险事件(CSV)", + data=risk_detail_df.to_csv(index=False), + file_name="bt_risk_events.csv", + mime="text/csv", + key="dl_risk_events", + ) + except Exception: + pass + agg = risk_detail_df.groupby(["cfg_id", "reason", "risk_status"], dropna=False).size().reset_index(name="count") + st.dataframe(agg, hide_index=True, width='stretch') + try: + if not agg.empty: + agg_fig = px.bar( + agg, + x="reason", + y="count", + color="risk_status", + facet_col="cfg_id", + title="风险事件分布", + ) + agg_fig.update_layout(height=320, margin=dict(l=10, r=10, t=40, b=20)) + st.plotly_chart(agg_fig, use_container_width=True) + except Exception: # noqa: BLE001 + LOGGER.debug("绘制风险事件分布失败", extra=LOG_EXTRA) + except Exception: # noqa: BLE001 + LOGGER.debug("渲染风险事件失败", extra=LOG_EXTRA) else: st.info("请选择至少一个配置进行对比。") diff --git a/app/utils/data_access.py b/app/utils/data_access.py index cae6cc1..eebf09d 100644 --- a/app/utils/data_access.py +++ b/app/utils/data_access.py @@ -14,6 +14,7 @@ from .config import get_config from .db import db_session from .logging import get_logger from app.core.indicators import momentum, normalize, rolling_mean, volatility +from app.utils.db_query import BrokerQueryEngine # 延迟导入,避免循环依赖 collect_data_coverage = None @@ -60,6 +61,39 @@ def _safe_split(path: str) -> Tuple[str, str] | None: return table, column +@dataclass +class _RefreshCoordinator: + """Orchestrates background refresh requests for the broker.""" + + broker: "DataBroker" + + def ensure_for_latest(self, trade_date: str, fields: Iterable[str]) -> None: + parsed_date = _parse_trade_date(trade_date) + if not parsed_date: + return + normalized = parsed_date.strftime("%Y%m%d") + tables = self._collect_tables(fields) + if tables and self.broker.check_data_availability(normalized, tables): + self.broker._trigger_background_refresh(normalized) + + def ensure_for_series(self, end_date: str, table: str) -> None: + parsed_date = _parse_trade_date(end_date) + if not parsed_date: + return + normalized = parsed_date.strftime("%Y%m%d") + if self.broker.check_data_availability(normalized, {table}): + self.broker._trigger_background_refresh(normalized) + + def _collect_tables(self, fields: Iterable[str]) -> Set[str]: + tables: Set[str] = set() + for field_name in fields: + resolved = self.broker.resolve_field(field_name) + if resolved: + table, _ = resolved + tables.add(table) + return tables + + def parse_field_path(path: str) -> Tuple[str, str] | None: """Validate and split a `table.column` field expression.""" @@ -87,6 +121,17 @@ def _end_of_day(dt: datetime) -> str: return dt.strftime("%Y-%m-%d 23:59:59") +def _coerce_date(value: object) -> Optional[date]: + if value is None: + return None + if isinstance(value, date): + return value + parsed = _parse_trade_date(value) + if parsed: + return parsed.date() + return None + + @dataclass class DataBroker: """Lightweight data access helper with automated data fetching capabilities.""" @@ -130,6 +175,8 @@ class DataBroker: _refresh_in_progress: Dict[str, bool] = field(init=False, repr=False) _refresh_callbacks: Dict[str, List[Callable]] = field(init=False, repr=False) _coverage_cache: Dict[str, Dict] = field(init=False, repr=False) + _refresh: _RefreshCoordinator = field(init=False, repr=False) + _query_engine: BrokerQueryEngine = field(init=False, repr=False) def __post_init__(self) -> None: self._latest_cache = OrderedDict() @@ -139,6 +186,8 @@ class DataBroker: self._refresh_in_progress = {} self._refresh_callbacks = {} self._coverage_cache = {} + self._refresh = _RefreshCoordinator(self) + self._query_engine = BrokerQueryEngine(db_session) if initialize_database is not None: initialize_database() # 确保数据库已初始化 else: @@ -169,22 +218,7 @@ class DataBroker: # 检查是否需要自动补数 if auto_refresh: - # 解析交易日以确定是否需要补数 - parsed_date = _parse_trade_date(trade_date) - if parsed_date: - # 检查最近交易日的数据是否存在 - recent_trade_date = parsed_date.strftime('%Y%m%d') - # 对涉及的表进行数据可用性检查 - tables = set() - for field_name in field_list: - resolved = self.resolve_field(field_name) - if resolved: - table, _ = resolved - tables.add(table) - - if tables and self.check_data_availability(recent_trade_date, tables): - # 数据不足,触发后台补数 - self._trigger_background_refresh(recent_trade_date) + self._refresh.ensure_for_latest(trade_date, field_list) grouped: Dict[str, List[str]] = {} field_map: Dict[Tuple[str, str], List[str]] = {} @@ -208,59 +242,41 @@ class DataBroker: grouped[table].append(column) field_map.setdefault((table, column), []).append(field_name) - if not grouped: - if cache_key is not None and results: - self._cache_store( - self._latest_cache, - cache_key, - deepcopy(results), - self.latest_cache_size, - ) - return results - - try: - with db_session(read_only=True) as conn: - for table, columns in grouped.items(): - joined_cols = ", ".join(columns) - query = ( - f"SELECT trade_date, {joined_cols} FROM {table} " - "WHERE ts_code = ? AND trade_date <= ? " - "ORDER BY trade_date DESC LIMIT 1" - ) - try: - row = conn.execute(query, (ts_code, trade_date)).fetchone() - except Exception as exc: # noqa: BLE001 - LOGGER.debug( - "查询失败 table=%s fields=%s err=%s", - table, - columns, - exc, - extra=LOG_EXTRA, - ) - continue - if not row: - continue - for column in columns: - value = row[column] - if value is None: - continue - for original in field_map.get((table, column), [f"{table}.{column}"]): - try: - results[original] = float(value) - except (TypeError, ValueError): - results[original] = value - except sqlite3.OperationalError as exc: - LOGGER.debug("数据库只读连接失败:%s", exc, extra=LOG_EXTRA) - if cache_key is not None: - cached = self._cache_lookup(self._latest_cache, cache_key) - if cached is not None: + if grouped: + for table, columns in grouped.items(): + try: + row = self._query_engine.fetch_latest(table, ts_code, trade_date, columns) + except Exception as exc: # noqa: BLE001 LOGGER.debug( - "使用缓存结果 ts_code=%s trade_date=%s", - ts_code, - trade_date, + "查询失败 table=%s fields=%s err=%s", + table, + columns, + exc, extra=LOG_EXTRA, ) - return deepcopy(cached) + continue + if not row: + continue + for column in columns: + value = row[column] + if value is None: + continue + for original in field_map.get((table, column), [f"{table}.{column}"]): + try: + results[original] = float(value) + except (TypeError, ValueError): + results[original] = value + + if cache_key is not None and not results: + cached = self._cache_lookup(self._latest_cache, cache_key) + if cached is not None: + LOGGER.debug( + "使用缓存结果 ts_code=%s trade_date=%s", + ts_code, + trade_date, + extra=LOG_EXTRA, + ) + return deepcopy(cached) if cache_key is not None and results: self._cache_store( self._latest_cache, @@ -306,9 +322,7 @@ class DataBroker: # 检查是否需要自动补数 if auto_refresh: - parsed_date = _parse_trade_date(end_date) - if parsed_date and self.check_data_availability(end_date, {table}): - self._trigger_background_refresh(end_date) + self._refresh.ensure_for_series(end_date, table) cache_key: Optional[Tuple[Any, ...]] = None if self.enable_cache: @@ -323,21 +337,10 @@ class DataBroker: "ORDER BY trade_date DESC LIMIT ?" ) try: - with db_session(read_only=True) as conn: - try: - rows = conn.execute(query, (ts_code, end_date, window)).fetchall() - except Exception as exc: # noqa: BLE001 - LOGGER.debug( - "时间序列查询失败 table=%s column=%s err=%s", - table, - column, - exc, - extra=LOG_EXTRA, - ) - return [] - except sqlite3.OperationalError as exc: + rows = self._query_engine.fetch_series(table, resolved, ts_code, end_date, window) + except Exception as exc: # noqa: BLE001 LOGGER.debug( - "时间序列连接失败 table=%s column=%s err=%s", + "时间序列查询失败 table=%s column=%s err=%s", table, column, exc, @@ -358,9 +361,13 @@ class DataBroker: series: List[Tuple[str, float]] = [] for row in rows: value = row[resolved] - if value is None: + trade_dt = row["trade_date"] + if value is None or trade_dt is None: + continue + try: + series.append((trade_dt, float(value))) + except (TypeError, ValueError): continue - series.append((row["trade_date"], float(value))) if cache_key is not None and series: self._cache_store( self._series_cache, @@ -370,6 +377,32 @@ class DataBroker: ) return series + def register_refresh_callback( + self, + start: date | str, + end: date | str, + callback: Callable[[], None], + ) -> None: + """Register a hook invoked after background refresh completes for the window.""" + + if callback is None: + return + start_date = _coerce_date(start) + end_date = _coerce_date(end) + if not start_date or not end_date: + LOGGER.debug( + "忽略无效补数回调窗口 start=%s end=%s", + start, + end, + extra=LOG_EXTRA, + ) + return + key = f"{start_date}_{end_date}" + with self._refresh_lock: + bucket = self._refresh_callbacks.setdefault(key, []) + if callback not in bucket: + bucket.append(callback) + def fetch_flags( self, table: str, @@ -436,48 +469,19 @@ class DataBroker: LOGGER.debug("表不存在或无字段 table=%s", table, extra=LOG_EXTRA) return [] - column_list = ", ".join(columns) - has_trade_date = "trade_date" in columns - if has_trade_date: - query = ( - f"SELECT {column_list} FROM {table} " - "WHERE ts_code = ? AND trade_date <= ? " - "ORDER BY trade_date DESC LIMIT ?" - ) - params: Tuple[object, ...] = (ts_code, trade_date, window) - else: - query = ( - f"SELECT {column_list} FROM {table} " - "WHERE ts_code = ? ORDER BY rowid DESC LIMIT ?" - ) - params = (ts_code, window) - - results: List[Dict[str, object]] = [] try: - with db_session(read_only=True) as conn: - try: - rows = conn.execute(query, params).fetchall() - except Exception as exc: # noqa: BLE001 - LOGGER.debug( - "表查询失败 table=%s err=%s", - table, - exc, - extra=LOG_EXTRA, - ) - return [] - except sqlite3.OperationalError as exc: - LOGGER.debug( - "表连接失败 table=%s err=%s", + rows = self._query_engine.fetch_table( table, - exc, - extra=LOG_EXTRA, + columns, + ts_code, + trade_date if "trade_date" in columns else None, + window, ) + except Exception as exc: # noqa: BLE001 + LOGGER.debug("表查询失败 table=%s err=%s", table, exc, extra=LOG_EXTRA) return [] - for row in rows: - record = {col: row[col] for col in columns} - results.append(record) - return results + return [{col: row[col] for col in columns} for row in rows] def _resolve_derived_field( self, @@ -924,13 +928,20 @@ class DataBroker: with self._refresh_lock: callbacks = self._refresh_callbacks.pop(refresh_key, []) self._refresh_in_progress[refresh_key] = False - + + if callbacks: + LOGGER.info( + "执行补数回调 count=%s key=%s", + len(callbacks), + refresh_key, + extra=LOG_EXTRA, + ) for callback in callbacks: try: callback() except Exception as exc: LOGGER.exception("补数回调执行失败: %s", exc, extra=LOG_EXTRA) - + except Exception as exc: LOGGER.exception("后台数据补数失败: %s", exc, extra=LOG_EXTRA) with self._refresh_lock: diff --git a/app/utils/db_query.py b/app/utils/db_query.py new file mode 100644 index 0000000..860c291 --- /dev/null +++ b/app/utils/db_query.py @@ -0,0 +1,73 @@ +"""Shared read-only query helpers for database access.""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, Iterable, List, Mapping, Optional, Sequence + + +@dataclass +class BrokerQueryEngine: + """Lightweight wrapper around standard query patterns.""" + + session_factory: Callable[..., object] + + def fetch_latest( + self, + table: str, + ts_code: str, + trade_date: str, + columns: Sequence[str], + ) -> Optional[Mapping[str, object]]: + if not columns: + return None + joined_cols = ", ".join(columns) + query = ( + f"SELECT trade_date, {joined_cols} FROM {table} " + "WHERE ts_code = ? AND trade_date <= ? " + "ORDER BY trade_date DESC LIMIT 1" + ) + with self.session_factory(read_only=True) as conn: + return conn.execute(query, (ts_code, trade_date)).fetchone() + + def fetch_series( + self, + table: str, + column: str, + ts_code: str, + end_date: str, + limit: int, + ) -> List[Mapping[str, object]]: + query = ( + f"SELECT trade_date, {column} FROM {table} " + "WHERE ts_code = ? AND trade_date <= ? " + "ORDER BY trade_date DESC LIMIT ?" + ) + with self.session_factory(read_only=True) as conn: + rows = conn.execute(query, (ts_code, end_date, limit)).fetchall() + return list(rows) + + def fetch_table( + self, + table: str, + columns: Iterable[str], + ts_code: str, + trade_date: Optional[str], + limit: int, + ) -> List[Mapping[str, object]]: + cols = ", ".join(columns) + if trade_date is None: + query = ( + f"SELECT {cols} FROM {table} " + "WHERE ts_code = ? ORDER BY rowid DESC LIMIT ?" + ) + params: Sequence[object] = (ts_code, limit) + else: + query = ( + f"SELECT {cols} FROM {table} " + "WHERE ts_code = ? AND trade_date <= ? " + "ORDER BY trade_date DESC LIMIT ?" + ) + params = (ts_code, trade_date, limit) + with self.session_factory(read_only=True) as conn: + rows = conn.execute(query, params).fetchall() + return list(rows) diff --git a/docs/RISK_AGENT_PLAN.md b/docs/RISK_AGENT_PLAN.md new file mode 100644 index 0000000..83c6f21 --- /dev/null +++ b/docs/RISK_AGENT_PLAN.md @@ -0,0 +1,41 @@ +# 风险代理集成规划 + +## 目标 +- 将 `risk_guard` 回合由占位消息升级为可执行决策:低置信度的部门共识需自动进入复核流程,根据风险因子、仓位约束、外部告警调整最终指令。 +- 兼容历史回测与实时监控,确保补数、验证和告警闭环共享风险上下文。 + +## 数据与信号需求 +- **输入特征**:`risk_penalty`、`position_limit`、`is_suspended` 等现有布尔/得分信号;新增日内波动、VaR、行业集中度、外部事件标签(停牌预警、合规黑名单)。 +- **实时事件**:补数触发的异常、执行失败回报、仓位越界日志;需要通过 `DataBroker.register_refresh_callback()` 和交易执行层的回调管理。 +- **上下文结构**:在 `AgentContext.raw` 中携带 `risk_flags`、`compliance_notes`,供主持器和前端展示。 + +## 决策流程改造 +1. **风险评估阶段** + - 在部门回合后触发 `risk_round`,由 `RiskAgent` 复写 `Decision` 的 `requires_review` 与 `target_weight`,必要时生成“回滚/减仓”建议。 + - 引入 `RiskAssessment` 数据类,存储风险来源、建议操作、置信度,序列化到 `Decision.rounds`。 +2. **执行协调阶段** + - 若风险回合给出回滚指令,则 `execution_round` 应记录“冻结执行”或“调整仓位”而非直接落地。 + - 将风险回合结论写入 `risk_events` 表(或新建),供 UI 与监控使用。 +3. **日志与监控** + - 在 `risk_round` 中附加 `annotations` 字段:`{"breach_metrics": {...}, "actions": [...]}`。 + - 通过补数回调和执行回调,将风险事件推送到监控指标,如“复核触发率”“回滚成功率”。 + +## 实施里程碑 +1. **原型阶段** + - ✅ 重构 `RiskAgent` 以返回策略建议(持仓调整、止损触发)。 + - ✅ 扩展 `Decision` 结构,增加 `risk_assessment` 字段。 + - ✅ 更新 `ProtocolHost` 将风险建议纳入 `risk_round.notes`。 +2. **集成验证** + - 在回测环境构造冲突样例,验证“冲突→风险建议→执行调整”链路。 + - 新增测试:模拟停牌、仓位超限、黑名单事件,确认风险代理逻辑。 +3. **实盘准备** + - 接入实时告警渠道(如风控系统/合规接口)。 + - 监控接入:统计复核频率、回滚动作、失败告警。 +4. **上线迭代** + - 影子运行记录风险建议与实际执行差异。 + - 评估模型表现,迭代规则或引入强化学习风险控制策略。 + +## 未决事项 +- 风险代理与执行模块的数据同步接口(数据库/消息队列)。 +- 与合规团队确认风险阈值与回滚条件。 +- 定义UI展示:风险回合的建议、证据引用、执行结果。 diff --git a/docs/TODO.md b/docs/TODO.md index bebfb6e..7f36eec 100644 --- a/docs/TODO.md +++ b/docs/TODO.md @@ -27,6 +27,75 @@ - 构建实时持仓/成交数据写入链路,使线上监控与离线调参共用同一数据源。 - 借鉴 TradingAgents-CN 的做法:拆分环境与策略、提供训练脚本/配置,并输出丰富的评估指标(如 Sharpe、Sortino、基准对比)。 - 完善 `BacktestEngine` 的成交撮合、风险阈值与指标输出,让回测信号直接对接执行端,形成无人值守的自动闭环。 +- 新增多智能体多轮逻辑博弈线路:把投资场景建成重复博弈,设定主持、预测、风险等角色分轮交换信号、提出假设、相互反驳,再由执行代理落地操作;通过信念修正规则与合规约束驱动策略共识,而非依赖优化搜索。 + +### 3.2 多智能体多轮逻辑博弈实施方案 +1. 角色与知识库建模 + - 定义主持(议程控制)、预测(市场观点)、风险(合规阈值)、执行(指令生成)等核心代理,并划分共享/私有信息。 + - 为每个代理建立信号源与知识库接口(行情、新闻、风险限额),设计可信度权重与更新频率。 + - 预设重复博弈、信号博弈、贝叶斯博弈等多种结构模板,支持根据市场场景切换或组合,保障架构拓展性。 +2. 对话协议与消息格式 + - 设计多轮流程:议程发布→观点陈述→证据提交→反驳与驳回→共识决议→执行提示。 + - 约定消息 schema(动作类型、置信度、引用证据)及主持代理的轮次控制逻辑,确保所有代理均可追踪历史发言。 +3. 信念修正与逻辑推理引擎 + - 实现迭代信念修正模块,结合可信度权重与显式逻辑规则更新各代理的市场信念。 + - 引入论证框架(Argumentation Framework)或模态逻辑规则库,支持观点检验、冲突解决与风险约束自动否决。 +4. 模拟与测试环境 + - 搭建历史回测驱动的多轮博弈仿真环境,让代理在真实行情序列上互动。 + - 编写回放脚本记录每轮对话、信念轨迹与最终行动,产出可视化报告用于策略审查。 +5. 执行映射与风控集成 + - 将共识决议映射到具体操作(换仓、对冲、仓位调节),并与 `BacktestEngine`/实盘执行接口打通。 + - 风险代理负责实时校验 VaR、仓位集中度等阈值,违规时触发主持代理重新议程或引入驳回指令。 +6. 监控与评估指标 + - 设计协作质量指标(轮次收敛时间、冲突率、执行合规度)和业绩指标(收益、回撤、超额收益稳定性)。 + - 将指标纳入监控面板,与现有回测、实盘监控体系统一展示。 +7. 增量迭代计划 + - 从线下模块联调 → 回测闭环 → 影子运行 → 小资金试点的节奏推进,每阶段设定成功退出标准。 + - 收集各阶段的失败案例与风险事件,迭代规则库、信念修正规则与角色职责。 + +### 3.3 代码改造计划(多轮博弈适配) +1. 架构基线评估 + - ⏳ 绘制代理/部门/回测调用图,补充日志字段(缺数告警、补数来源、议程标识)并形成诊断报告。 + - ✅ 定义多轮博弈上下文结构(消息历史、信念状态、引用证据),输出数据类与通信协议草稿。 + - ✅ 在 `app/agents/protocols.py` 基础上补充主持/执行状态管理,实现 `DialogueTrace` 与部门上下文的对接路径。 + - ✅ 扩展 `Decision.rounds` 与 `RoundSummary` 采集策略,用于串联部门结论与多轮议程结果。 + - ✅ 基于 `ProtocolHost` 设计主持驱动的议程模板,明确多结构策略扩展方式。 + - ✅ 扩展主持对风险复核、执行总结等议程的支持,规划风控代理与回滚逻辑。 + - ✅ 落地风险议程触发条件(冲突阈值、风控代理信号、外部警报),同步更新回测策略。 + - ✅ 结合 `docs/RISK_AGENT_PLAN.md` 明确 RiskAgent 升级里程碑、数据接口及前端展示需求。 + - ✅ 定义 `risk_assessment` 数据结构(状态、原因、建议动作),在决策输出与存储层保持一致。 + - ✅ 将风险建议与执行回合对齐,执行阶段识别 `risk_adjusted` 并记录原始动作。 +2. 数据与因子重构 + - ✅ 拆分 `DataBroker` 查询层(`BrokerQueryEngine`),补数逻辑独立于查询管道。 + - ⏳ 按主题拆分因子模块,存储缺口/异常标签,改写 `load_market_data()` 为“缺失即说明”。 + - ⏳ 维护博弈结构 → 数据 scope 映射,支持角色按结构加载差异化字段。 + - ✅ 基于 `_RefreshCoordinator` 落地刷新队列与监控事件,拆分查询与补数路径。 + - ✅ 暴露 `DataBroker.register_refresh_callback()` 钩子,结合监控系统记录补数进度与失败重试。 + - ⏳ 统一补数回调日志格式(`LOG_EXTRA.stage=data_broker`),为后续指标预留数据源。 +3. 多轮博弈框架 + - ✅ 在 `app/agents/game.py` 抽象 `GameProtocol` 接口,扩展 `Decision` 记录多轮对话。 + - ✅ 实现主持调度器驱动议程(信息→陈述→反驳→共识→执行),挂载风险复核机制。 + - ⏳ 引入信念修正规则与论证框架,支持证据引用和冲突裁决。 +4. 执行与回测集成 + - ✅ 将回测循环改造成“每日多轮→执行摘要”,完成风控校验与冲突重议流程。 + - ⏳ 擦合订单映射层,明确多轮结果对应目标仓位、执行节奏、异常回滚策略。 + - ✅ 回测执行根据 `risk_assessment` 调整动作与目标权重,记录执行状态。 +5. 监控与可视化 + - ⏳ 在数据库/UI 增补轮次日志、信念轨迹、冲突原因与执行结果的可视化。 + - ⏳ 定义多结构绩效指标并纳入监控面板(回测视图已提供风险事件统计与分布图)。 + - ⏳ 补充补数回调监控告警策略(成功率、排队时长、重试次数)。 + - ✅ 回测视图展示 `_risk_assessment` 与 `execution_status`,后续扩展到实时监控面板。 + - ⏳ 完善风险告警(`alerts.backtest_risk`)阈值与通知策略并扩展至实盘。 +6. 测试与迭代治理 + - ⏳ 构建历史行情驱动的多轮仿真回放,覆盖结构切换与异常场景。 + - 新增单元/集成测试:议程流程、信念修正、补数竞态、执行映射。 + - ✅ 风险执行集成测试完成(`tests/test_decision_risk_integration.py`),验证风险调节与执行日志。 + - 规划上线节奏:线下 PoC → 回测闭环 → 影子运行 → 小资金试点,记录失败案例反馈改进。 + - 在测试计划中涵盖风险/执行议程的触发条件与回滚路径,确保主持扩展的行为可验证。 + - 新增针对风险复核和执行议程的集成测试场景,覆盖“冲突→复核→执行/回滚”链路。 + - 对照 `docs/RISK_AGENT_PLAN.md` 拆解阶段性验收项:数据校验、监控指标、影子运行报告。 + - 为 `risk_assessment` 输出设计单元测试(冲突触发、无冲突、外部告警模拟),确保状态与原因正确覆盖。 + - 编写执行阶段测试,验证 `risk_adjusted` 状态下的目标仓位与日志记录。 ### 3.1 实施步骤(建议顺序) 1. 环境重构:扩展 `DecisionEnv` 支持逐日状态/动作/奖励,完善 `BacktestEngine` 的状态保存与恢复接口,并补充必要的数据库读写钩子。 diff --git a/tests/test_decision_risk_integration.py b/tests/test_decision_risk_integration.py new file mode 100644 index 0000000..ed12a5f --- /dev/null +++ b/tests/test_decision_risk_integration.py @@ -0,0 +1,119 @@ +"""Integration-style tests for risk-aware decision execution.""" +from __future__ import annotations + +from datetime import date + +import pytest + +from app.agents.base import AgentAction, AgentContext +from app.agents.game import decide +from app.agents.registry import default_agents +from app.agents.risk import RiskAgent, RiskRecommendation +from app.backtest.engine import BacktestEngine, BacktestResult, BtConfig, PortfolioState + + +def _make_context(features: dict, *, alerts: list[str] | None = None) -> AgentContext: + raw = { + "scope_values": {"daily.close": 100.0}, + } + if alerts: + raw["risk_alerts"] = alerts + return AgentContext( + ts_code="000001.SZ", + trade_date="2025-01-10", + features=features, + market_snapshot={}, + raw=raw, + ) + + +class _StubRiskAgent(RiskAgent): + def __init__(self, recommendation: RiskRecommendation) -> None: + super().__init__() + self._recommendation = recommendation + + def assess( + self, + context: AgentContext, + decision_action: AgentAction, + conflict_flag: bool, + ) -> RiskRecommendation: + return self._recommendation + + +def test_decide_adjusts_execution_on_risk_recommendation(monkeypatch): + agents = default_agents() + recommendation = RiskRecommendation( + status="blocked", + reason="risk_penalty_extreme", + recommended_action=AgentAction.HOLD, + notes={"risk_penalty": 0.95}, + ) + stub_risk = _StubRiskAgent(recommendation) + agents = [stub_risk if isinstance(agent, RiskAgent) else agent for agent in agents] + + context = _make_context({"risk_penalty": 0.95}) + + decision = decide( + context, + agents, + weights={agent.name: 1.0 for agent in agents}, + department_manager=None, + ) + assert decision.requires_review is True + assert decision.risk_assessment + assert decision.risk_assessment.status == "blocked" + assert decision.risk_assessment.recommended_action == AgentAction.HOLD + + execution_rounds = [round for round in decision.rounds if round.agenda == "execution_summary"] + assert execution_rounds + execution_notes = execution_rounds[0].notes + assert execution_notes.get("execution_status") == "risk_adjusted" + assert execution_rounds[0].outcome == AgentAction.HOLD.value + + +def test_backtest_engine_applies_risk_adjusted_execution(monkeypatch): + cfg = BtConfig( + id="risk-test", + name="risk-test", + start_date=date(2025, 1, 10), + end_date=date(2025, 1, 10), + universe=["000001.SZ"], + params={}, + ) + engine = BacktestEngine(cfg) + state = PortfolioState(cash=100_000.0) + result = BacktestResult() + + context = _make_context({"risk_penalty": 0.95}) + recommendation = RiskRecommendation( + status="blocked", + reason="risk_penalty_extreme", + recommended_action=AgentAction.HOLD, + notes={"risk_penalty": 0.95}, + ) + + agents = [ + _StubRiskAgent(recommendation) if isinstance(agent, RiskAgent) else agent + for agent in engine.agents + ] + engine.agents = agents + engine.department_manager = None + + decision = decide( + context, + engine.agents, + engine.weights, + department_manager=None, + ) + + engine._apply_portfolio_updates( + date(2025, 1, 10), + state, + [("000001.SZ", context, decision)], + result, + ) + + assert not state.holdings + assert not result.trades + assert result.nav_series[0]["nav"] == pytest.approx(100_000.0) diff --git a/tests/test_risk_agent.py b/tests/test_risk_agent.py new file mode 100644 index 0000000..019aa78 --- /dev/null +++ b/tests/test_risk_agent.py @@ -0,0 +1,63 @@ +"""Tests for RiskAgent assessment and risk evaluation pipeline.""" +from __future__ import annotations + +from app.agents.base import AgentAction, AgentContext +from app.agents.game import _evaluate_risk +from app.agents.risk import RiskAgent, RiskRecommendation + + +def _make_context(**features: float | bool) -> AgentContext: + return AgentContext( + ts_code="000001.SZ", + trade_date="2025-01-01", + features=features, + market_snapshot={}, + raw={}, + ) + + +def test_risk_agent_ok_status() -> None: + agent = RiskAgent() + context = _make_context(risk_penalty=0.1) + recommendation = agent.assess(context, AgentAction.BUY_S, conflict_flag=False) + assert recommendation.status == "ok" + assert recommendation.reason == "clear" + + +def test_risk_agent_blocked_on_limit_up() -> None: + agent = RiskAgent() + context = _make_context(limit_up=True) + recommendation = agent.assess(context, AgentAction.BUY_M, conflict_flag=False) + assert recommendation.status == "blocked" + assert recommendation.reason == "limit_up" + assert recommendation.recommended_action == AgentAction.HOLD + + +def test_risk_agent_pending_on_conflict() -> None: + agent = RiskAgent() + context = _make_context() + recommendation = agent.assess(context, AgentAction.HOLD, conflict_flag=True) + assert recommendation.status == "pending_review" + assert recommendation.reason == "conflict_threshold" + + +def test_evaluate_risk_external_alerts() -> None: + agent = RiskAgent() + context = AgentContext( + ts_code="000002.SZ", + trade_date="2025-01-01", + features={"risk_penalty": 0.1}, + market_snapshot={}, + raw={"risk_alerts": ["sudden_news"]}, + ) + assessment = _evaluate_risk( + context=context, + action=AgentAction.HOLD, + department_votes={"buy": 0.6}, + conflict_flag=False, + risk_agent=agent, + ) + assert assessment.status == "pending_review" + assert assessment.reason == "external_alert" + assert assessment.recommended_action == AgentAction.HOLD + assert "external_alerts" in assessment.notes