# Source path: llm-quant/app/agents/game.py
# File-viewer metadata: 556 lines, 20 KiB, Python

"""Multi-agent decision game implementation."""
from __future__ import annotations
from dataclasses import dataclass, field
from math import log
from typing import Dict, Iterable, List, Mapping, Optional, Tuple
from .base import Agent, AgentAction, AgentContext, UtilityMatrix
from .departments import DepartmentContext, DepartmentDecision, DepartmentManager
from .registry import weight_map
from .beliefs import BeliefRevisionResult, revise_beliefs
from .risk import RiskAgent, RiskRecommendation
from .protocols import (
DialogueMessage,
DialogueRole,
GameStructure,
MessageType,
ProtocolHost,
RoundSummary,
)
# Every action the game scores; the functions in this module iterate it in this order.
ACTIONS: Tuple[AgentAction, ...] = (
    AgentAction.SELL,
    AgentAction.HOLD,
    AgentAction.BUY_S,
    AgentAction.BUY_M,
    AgentAction.BUY_L,
)
def _clamp(value: float) -> float:
return max(0.0, min(1.0, value))
@dataclass
class BeliefUpdate:
    """Belief payload recorded for one participant of a decision."""

    # Free-form belief fields (e.g. ``decide`` stores "action" and "confidence" for departments).
    belief: Dict[str, object]
    # Optional human-readable justification; ``decide`` passes the department summary here.
    rationale: Optional[str] = None
@dataclass
class RiskAssessment:
    """Outcome of the risk check performed on a candidate action.

    ``status`` and ``reason`` are short string codes, ``recommended_action``
    is an optional replacement action and ``notes`` carries supporting data.
    """

    status: str
    reason: str
    recommended_action: Optional[AgentAction] = None
    notes: Dict[str, object] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, object]:
        """Return a plain-dict view of the assessment.

        ``notes`` is shallow-copied; ``recommended_action`` is emitted as its
        ``.value`` and only when one is set.
        """
        serialized: Dict[str, object] = {
            "status": self.status,
            "reason": self.reason,
            "notes": dict(self.notes),
        }
        suggestion = self.recommended_action
        if suggestion is None:
            return serialized
        serialized["recommended_action"] = suggestion.value
        return serialized
@dataclass
class Decision:
    """Result bundle returned by ``decide``."""

    # Action selected by the vote / bargaining step.
    action: AgentAction
    # Confidence value returned alongside the selected action.
    confidence: float
    # Target weight mapped from ``action``.
    target_weight: float
    # Actions that every agent reported as feasible.
    feasible_actions: List[AgentAction]
    # Score matrix: action -> agent name -> score.
    utilities: UtilityMatrix
    # Per-department decisions keyed by department code.
    department_decisions: Dict[str, DepartmentDecision] = field(default_factory=dict)
    # Accumulated ``weight * confidence`` per vote bucket ("buy" / "sell" / "hold").
    department_votes: Dict[str, float] = field(default_factory=dict)
    # True when the risk assessment status is anything other than "ok".
    requires_review: bool = False
    # Round summaries recorded while the decision was made.
    rounds: List[RoundSummary] = field(default_factory=list)
    # Risk assessment produced for the selected action, if any.
    risk_assessment: Optional[RiskAssessment] = None
    # Belief updates keyed by participant (department code, "risk_guard", "execution").
    belief_updates: Dict[str, BeliefUpdate] = field(default_factory=dict)
    # Result of ``revise_beliefs`` over the belief updates, if any.
    belief_revision: Optional[BeliefRevisionResult] = None
def compute_utilities(agents: Iterable[Agent], context: AgentContext) -> UtilityMatrix:
    """Build the action -> agent-name -> score matrix for ``context``.

    Args:
        agents: Agents to query; any iterable, including one-shot iterators.
        context: Context handed to each agent's ``score`` method.

    Returns:
        A mapping with one entry per action in ``ACTIONS``; each entry maps
        ``agent.name`` to that agent's score clamped to [0, 1].
    """
    # Materialise once: the agents are walked again for every action, so a
    # generator would be exhausted after the first action and leave the
    # remaining actions with empty score maps.
    roster = list(agents)
    return {
        action: {
            agent.name: _clamp(agent.score(context, action))
            for agent in roster
        }
        for action in ACTIONS
    }
def feasible_actions(agents: Iterable[Agent], context: AgentContext) -> List[AgentAction]:
    """Return the actions from ``ACTIONS`` that every agent accepts as feasible.

    Args:
        agents: Agents to query; any iterable, including one-shot iterators.
        context: Context handed to each agent's ``feasible`` method.

    Returns:
        Feasible actions, in ``ACTIONS`` order.
    """
    # Materialise once: with a one-shot iterator the second and later
    # ``all(...)`` calls would see no agents and be vacuously True, wrongly
    # marking those actions as feasible.
    roster = list(agents)
    return [
        action
        for action in ACTIONS
        if all(agent.feasible(context, action) for agent in roster)
    ]
def nash_bargain(utilities: UtilityMatrix, weights: Mapping[str, float], disagreement: Mapping[str, float]) -> Tuple[AgentAction, float]:
    """Select the action with the largest weighted log Nash product.

    For each action the score is ``sum(w * log(score - d))`` over the agents
    whose weight ``w`` is non-zero, where ``d`` is the agent's entry in
    ``disagreement`` (0.0 when absent). An action is discarded as soon as one
    weighted agent does not strictly improve on its disagreement value.

    Args:
        utilities: Action -> agent name -> score.
        weights: Agent name -> weight; agents missing here are ignored.
        disagreement: Agent name -> fallback utility.

    Returns:
        ``(action, confidence)`` where confidence is the weighted average
        score of the chosen action, or ``(AgentAction.HOLD, 0.0)`` when no
        action qualifies.
    """
    best_action = AgentAction.HOLD
    best_score = float("-inf")
    # The former ``if action not in utilities`` guard was removed: the loop
    # iterates ``utilities`` itself, so the guard could never trigger.
    for action, agent_scores in utilities.items():
        log_product = 0.0
        for agent_name, score in agent_scores.items():
            w = weights.get(agent_name, 0.0)
            if w == 0:
                continue
            gap = score - disagreement.get(agent_name, 0.0)
            if gap <= 0:
                # log() is undefined here; the action is not a valid bargain.
                break
            log_product += w * log(gap)
        else:
            # Reached only when no agent rejected the action.
            if log_product > best_score:
                best_score = log_product
                best_action = action
    if best_score == float("-inf"):
        return AgentAction.HOLD, 0.0
    confidence = _aggregate_confidence(utilities[best_action], weights)
    return best_action, confidence
def vote(utilities: UtilityMatrix, weights: Mapping[str, float]) -> Tuple[AgentAction, float]:
    """Select the action with the highest weighted sum of agent scores.

    Args:
        utilities: Action -> agent name -> score.
        weights: Agent name -> weight; agents missing here count as 0.0.

    Returns:
        ``(action, confidence)`` where confidence is the weighted average
        score of the winning action. An empty ``utilities`` yields
        ``(AgentAction.HOLD, 0.0)``, matching the ``nash_bargain`` fallback.
    """
    if not utilities:
        # ``max`` over an empty mapping would raise ValueError.
        return AgentAction.HOLD, 0.0
    scores: Dict[AgentAction, float] = {
        action: sum(weights.get(agent, 0.0) * score for agent, score in agent_scores.items())
        for action, agent_scores in utilities.items()
    }
    best_action = max(scores, key=scores.get)
    confidence = _aggregate_confidence(utilities[best_action], weights)
    return best_action, confidence
def _aggregate_confidence(agent_scores: Mapping[str, float], weights: Mapping[str, float]) -> float:
total = sum(weights.values())
if total <= 0:
return 0.0
weighted = sum(weights.get(agent, 0.0) * score for agent, score in agent_scores.items())
return weighted / total
def target_weight_for_action(action: AgentAction) -> float:
    """Look up the target weight associated with ``action``.

    Raises:
        KeyError: If ``action`` is not one of the five known actions.
    """
    weight_by_action = {
        AgentAction.BUY_L: 0.03,
        AgentAction.BUY_M: 0.02,
        AgentAction.BUY_S: 0.01,
        AgentAction.HOLD: 0.0,
        AgentAction.SELL: -1.0,
    }
    target = weight_by_action[action]
    return target
def decide(
    context: AgentContext,
    agents: Iterable[Agent],
    weights: Mapping[str, float],
    method: str = "nash",
    department_manager: Optional[DepartmentManager] = None,
    department_context: Optional[DepartmentContext] = None,
) -> Decision:
    """Run one decision game for ``context`` and return the resulting ``Decision``.

    Steps, in order:
      1. Score every action per agent and collect the feasible actions; with
         no feasible action, return HOLD with zero confidence immediately.
      2. If a department manager is given, evaluate the departments, add their
         scores to the utility matrix under ``dept_<code>`` keys and accumulate
         their votes per bucket.
      3. Pick an action among the feasible ones via ``vote`` (``method ==
         "vote"``) or ``nash_bargain`` (any other value), using the HOLD
         utilities as the disagreement point; fall back to ``vote`` if the
         picked action is not feasible.
      4. Assess risk; any status other than "ok" sets ``requires_review``.
      5. Record department, risk and execution rounds through a
         ``ProtocolHost`` and revise beliefs.

    Args:
        context: Context for the instrument/date being decided.
        agents: Agents to score and check feasibility with.
        weights: Raw agent weights keyed by agent name.
        method: ``"vote"`` for weighted voting; anything else uses bargaining.
        department_manager: Optional manager whose departments also take part.
        department_context: Optional pre-built department context; built from
            ``context`` when omitted.

    Returns:
        The populated ``Decision``. Its ``action`` and ``target_weight`` are
        the values chosen in step 3; a risk-recommended replacement is only
        recorded in the execution round and the "execution" belief update.
    """
    # NOTE(review): block nesting below was reconstructed from a source whose
    # indentation had been lost — confirm against the canonical file.
    agent_list = list(agents)
    utilities = compute_utilities(agent_list, context)
    feas_actions = feasible_actions(agent_list, context)
    if not feas_actions:
        # Nothing is feasible: return a neutral decision without running the protocol.
        return Decision(
            action=AgentAction.HOLD,
            confidence=0.0,
            target_weight=0.0,
            feasible_actions=[],
            utilities=utilities,
        )
    # Copy so department weights can be added without touching the caller's mapping.
    raw_weights = dict(weights)
    department_decisions: Dict[str, DepartmentDecision] = {}
    department_votes: Dict[str, float] = {}
    host = ProtocolHost()
    host_trace = host.bootstrap_trace(
        session_id=f"{context.ts_code}:{context.trade_date}",
        ts_code=context.ts_code,
        trade_date=context.trade_date,
    )
    department_round: Optional[RoundSummary] = None
    risk_round: Optional[RoundSummary] = None
    execution_round: Optional[RoundSummary] = None
    belief_updates: Dict[str, BeliefUpdate] = {}
    if department_manager:
        dept_context = department_context
        if dept_context is None:
            # Derive a department context from the agent context; optional
            # attributes fall back to empty dicts.
            dept_context = DepartmentContext(
                ts_code=context.ts_code,
                trade_date=context.trade_date,
                features=dict(context.features),
                market_snapshot=dict(getattr(context, "market_snapshot", {}) or {}),
                raw=dict(getattr(context, "raw", {}) or {}),
            )
        department_decisions = department_manager.evaluate(dept_context)
        if department_decisions:
            department_round = host.start_round(
                host_trace,
                agenda="department_consensus",
                structure=GameStructure.REPEATED,
            )
        for code, decision in department_decisions.items():
            # Each department joins the game as a pseudo-agent named "dept_<code>".
            agent_key = f"dept_{code}"
            dept_agent = department_manager.agents.get(code)
            weight = dept_agent.settings.weight if dept_agent else 1.0
            raw_weights[agent_key] = weight
            scores = _department_scores(decision)
            for action in ACTIONS:
                utilities.setdefault(action, {})[agent_key] = scores[action]
            bucket = _department_vote_bucket(decision.action)
            if bucket:
                department_votes[bucket] = department_votes.get(bucket, 0.0) + weight * decision.confidence
            if department_round:
                message = _department_message(code, decision)
                host.handle_message(department_round, message)
            belief_updates[code] = BeliefUpdate(
                belief={
                    "action": decision.action.value,
                    "confidence": decision.confidence,
                    "signals": decision.signals,
                },
                rationale=decision.summary,
            )
    # Only feasible actions compete; HOLD utilities act as the disagreement point.
    filtered_utilities = {action: utilities[action] for action in feas_actions}
    hold_scores = utilities.get(AgentAction.HOLD, {})
    norm_weights = weight_map(raw_weights)
    if method == "vote":
        action, confidence = vote(filtered_utilities, norm_weights)
    else:
        action, confidence = nash_bargain(filtered_utilities, norm_weights, hold_scores)
    if action not in feas_actions:
        # nash_bargain defaults to HOLD, which may itself be infeasible.
        action, confidence = vote(filtered_utilities, norm_weights)
    weight = target_weight_for_action(action)
    conflict_flag = _department_conflict_flag(department_votes)
    risk_agent = _find_risk_agent(agent_list)
    risk_assessment = _evaluate_risk(
        context,
        action,
        department_votes,
        conflict_flag,
        risk_agent,
    )
    requires_review = risk_assessment.status != "ok"
    if department_round:
        department_round.notes.setdefault("department_votes", dict(department_votes))
        department_round.outcome = action.value
        host.finalize_round(department_round)
    if requires_review:
        # Record a dedicated risk-review round describing why review is needed.
        risk_round = host.ensure_round(
            host_trace,
            agenda="risk_review",
            structure=GameStructure.CUSTOM,
        )
        review_message = DialogueMessage(
            sender="risk_guard",
            role=DialogueRole.RISK,
            message_type=MessageType.COUNTER,
            content=_risk_review_message(risk_assessment.reason),
            confidence=1.0,
            references=list(department_votes.keys()),
            annotations={
                "department_votes": dict(department_votes),
                "risk_reason": risk_assessment.reason,
                "recommended_action": (
                    risk_assessment.recommended_action.value
                    if risk_assessment.recommended_action
                    else None
                ),
                "notes": dict(risk_assessment.notes),
            },
        )
        host.handle_message(risk_round, review_message)
        risk_round.notes.setdefault("status", risk_assessment.status)
        risk_round.notes.setdefault("reason", risk_assessment.reason)
        if risk_assessment.recommended_action:
            risk_round.notes.setdefault(
                "recommended_action",
                risk_assessment.recommended_action.value,
            )
        risk_round.outcome = "REVIEW"
        host.finalize_round(risk_round)
        belief_updates["risk_guard"] = BeliefUpdate(
            belief={
                "status": risk_assessment.status,
                "reason": risk_assessment.reason,
                "recommended_action": (
                    risk_assessment.recommended_action.value
                    if risk_assessment.recommended_action
                    else None
                ),
            },
        )
    execution_round = host.ensure_round(
        host_trace,
        agenda="execution_summary",
        structure=GameStructure.REPEATED,
    )
    # The executed action starts as the chosen one and is replaced by the risk
    # recommendation when a review is required and a recommendation exists.
    exec_action = action
    exec_weight = weight
    exec_status = "normal"
    if requires_review and risk_assessment.recommended_action:
        exec_action = risk_assessment.recommended_action
        exec_status = "risk_adjusted"
        exec_weight = target_weight_for_action(exec_action)
    execution_message = DialogueMessage(
        sender="execution_engine",
        role=DialogueRole.EXECUTION,
        message_type=MessageType.DIRECTIVE,
        content=f"执行操作 {exec_action.value}",
        confidence=1.0,
        annotations={
            "target_weight": exec_weight,
            "requires_review": requires_review,
            "execution_status": exec_status,
        },
    )
    host.handle_message(execution_round, execution_message)
    execution_round.outcome = exec_action.value
    execution_round.notes.setdefault("execution_status", exec_status)
    if exec_action is not action:
        execution_round.notes.setdefault("original_action", action.value)
    belief_updates["execution"] = BeliefUpdate(
        belief={
            "execution_status": exec_status,
            "action": exec_action.value,
            "target_weight": exec_weight,
        },
    )
    host.finalize_round(execution_round)
    host.close(host_trace)
    # Prefer the rounds recorded on the trace; otherwise synthesise a summary.
    rounds = host_trace.rounds if host_trace.rounds else _build_round_summaries(
        department_decisions,
        action,
        department_votes,
    )
    belief_revision = revise_beliefs(belief_updates, exec_action)
    # NOTE(review): these notes are written after the round was finalized and
    # the trace closed — confirm the host tolerates late note updates.
    execution_round.notes.setdefault("consensus_action", belief_revision.consensus_action.value)
    execution_round.notes.setdefault("consensus_confidence", belief_revision.consensus_confidence)
    if belief_revision.conflicts:
        execution_round.notes.setdefault("conflicts", belief_revision.conflicts)
    if belief_revision.notes:
        execution_round.notes.setdefault("belief_notes", belief_revision.notes)
    return Decision(
        action=action,
        confidence=confidence,
        target_weight=weight,
        feasible_actions=feas_actions,
        utilities=utilities,
        department_decisions=department_decisions,
        department_votes=department_votes,
        requires_review=requires_review,
        rounds=rounds,
        risk_assessment=risk_assessment,
        belief_updates=belief_updates,
        belief_revision=belief_revision,
    )
def _department_scores(decision: DepartmentDecision) -> Dict[AgentAction, float]:
    """Translate a department decision into one clamped score per action.

    The department's own action scores highest and grows with its confidence;
    the competing actions shrink as confidence grows.
    """
    conf = _clamp(decision.confidence)
    doubt = 1 - conf
    table: Dict[AgentAction, float] = dict.fromkeys(ACTIONS, 0.2)
    chosen = decision.action
    buy_tiers = (AgentAction.BUY_S, AgentAction.BUY_M, AgentAction.BUY_L)

    if chosen is AgentAction.SELL:
        table.update(
            {
                AgentAction.SELL: 0.7 + 0.3 * conf,
                AgentAction.HOLD: 0.4 * doubt,
                AgentAction.BUY_S: 0.2 * doubt,
                AgentAction.BUY_M: 0.15 * doubt,
                AgentAction.BUY_L: 0.1 * doubt,
            }
        )
    elif chosen in set(buy_tiers):
        # The chosen buy size is favoured; the other sizes still score well.
        for tier in buy_tiers:
            table[tier] = (0.6 + 0.4 * conf) if tier is chosen else (0.3 + 0.3 * conf)
        table[AgentAction.HOLD] = 0.3 * doubt + 0.25
        table[AgentAction.SELL] = 0.15 * doubt
    else:  # HOLD or unknown
        table[AgentAction.HOLD] = 0.6 + 0.4 * conf
        for other in (AgentAction.SELL,) + buy_tiers:
            table[other] = 0.3 * doubt

    return {action: _clamp(raw) for action, raw in table.items()}
def _department_vote_bucket(action: AgentAction) -> str:
    """Name the vote bucket for ``action``: "sell", "buy", "hold", or "" if unknown."""
    if action is AgentAction.SELL:
        label = "sell"
    elif action in {AgentAction.BUY_S, AgentAction.BUY_M, AgentAction.BUY_L}:
        # All buy sizes share a single bucket.
        label = "buy"
    elif action is AgentAction.HOLD:
        label = "hold"
    else:
        label = ""
    return label
def _department_conflict_flag(votes: Mapping[str, float]) -> bool:
if not votes:
return False
total = sum(votes.values())
if total <= 0:
return True
top = max(votes.values())
if top < total * 0.45:
return True
if len(votes) > 1:
sorted_votes = sorted(votes.values(), reverse=True)
if len(sorted_votes) >= 2 and (sorted_votes[0] - sorted_votes[1]) < total * 0.1:
return True
return False
def _department_message(code: str, decision: DepartmentDecision) -> DialogueMessage:
    """Wrap a department decision in a ``DialogueMessage`` sent by ``code``.

    The message text is the first non-empty of summary, raw response and the
    action value. Dialogue and telemetry are attached only when present.
    """
    text = decision.summary or decision.raw_response or decision.action.value
    cited = decision.signals or []
    extras: Dict[str, object] = {
        "risks": decision.risks,
        "supplements": decision.supplements,
    }
    for optional_key in ("dialogue", "telemetry"):
        payload = getattr(decision, optional_key)
        if payload:
            extras[optional_key] = payload
    return DialogueMessage(
        sender=code,
        role=DialogueRole.PREDICTION,
        message_type=MessageType.DECISION,
        content=text,
        confidence=decision.confidence,
        references=cited,
        annotations=extras,
    )
def _evaluate_risk(
    context: AgentContext,
    action: AgentAction,
    department_votes: Mapping[str, float],
    conflict_flag: bool,
    risk_agent: Optional[RiskAgent],
) -> RiskAssessment:
    """Assess the risk of executing ``action`` and return a ``RiskAssessment``.

    External alerts are read from ``context.raw["risk_alerts"]`` when present.

    With a risk agent, its recommendation is used; an "ok" recommendation is
    escalated to "pending_review" / "external_alert" when alerts exist
    (recommending the agent's action, or HOLD if it gave none). Without a risk
    agent, alerts lead to a HOLD recommendation, a department conflict leads
    to a review without recommendation, and otherwise the result is "ok".

    Args:
        context: Decision context; ``raw`` is optional.
        action: Candidate action being assessed.
        department_votes: Vote totals per bucket, copied into the notes.
        conflict_flag: Whether department votes were judged conflicting.
        risk_agent: Optional agent providing the recommendation.

    Returns:
        The assessment; its notes always include the department votes and,
        when present, the external alerts.
    """
    external_alerts = []
    if getattr(context, "raw", None):
        alerts = context.raw.get("risk_alerts", [])
        if alerts:
            external_alerts = list(alerts)

    if risk_agent:
        recommendation = risk_agent.assess(context, action, conflict_flag)
        # Work on a copy so the agent's recommendation object is never
        # mutated, and return the merged notes on every path.
        notes: Dict[str, object] = dict(recommendation.notes or {})
        notes.setdefault("department_votes", dict(department_votes))
        status = recommendation.status
        reason = recommendation.reason
        recommended_action = recommendation.recommended_action
        if external_alerts:
            notes.setdefault("external_alerts", external_alerts)
            if status == "ok":
                # Alerts override an otherwise clean recommendation.
                status = "pending_review"
                reason = "external_alert"
                recommended_action = recommended_action or AgentAction.HOLD
        return RiskAssessment(
            status=status,
            reason=reason,
            recommended_action=recommended_action,
            notes=notes,
        )

    notes = {
        "conflict": conflict_flag,
        "department_votes": dict(department_votes),
    }
    if external_alerts:
        notes["external_alerts"] = external_alerts
        return RiskAssessment(
            status="pending_review",
            reason="external_alert",
            recommended_action=AgentAction.HOLD,
            notes=notes,
        )
    if conflict_flag:
        return RiskAssessment(
            status="pending_review",
            reason="conflict_threshold",
            notes=notes,
        )
    return RiskAssessment(status="ok", reason="clear", notes=notes)
def _find_risk_agent(agents: Iterable[Agent]) -> Optional[RiskAgent]:
for agent in agents:
if isinstance(agent, RiskAgent):
return agent
return None
def _risk_review_message(reason: str) -> str:
mapping = {
"conflict_threshold": "部门意见分歧,触发风险复核",
"suspended": "标的停牌,需冻结执行",
"limit_up": "标的涨停,执行需调整",
"position_limit": "仓位限制已触发,需调整目标",
"risk_penalty_extreme": "风险评分极高,建议暂停加仓",
"risk_penalty_high": "风险评分偏高,建议复核",
"external_alert": "外部风险告警触发复核",
}
return mapping.get(reason, "触发风险复核,需人工确认")
def _build_round_summaries(
    department_decisions: Mapping[str, DepartmentDecision],
    final_action: AgentAction,
    department_votes: Mapping[str, float],
) -> List[RoundSummary]:
    """Build a single resolved "department_consensus" round from the decisions.

    Args:
        department_decisions: Decisions keyed by department code.
        final_action: Action recorded as the round outcome.
        department_votes: Vote totals per bucket, copied into the round notes.

    Returns:
        A one-element list holding the summary, or an empty list when there
        are no department decisions.
    """
    if not department_decisions:
        return []
    # Reuse the shared message builder rather than duplicating its logic here.
    messages: List[DialogueMessage] = [
        _department_message(code, decision)
        for code, decision in department_decisions.items()
    ]
    notes: Dict[str, object] = {
        "department_votes": dict(department_votes),
    }
    summary = RoundSummary(
        index=0,
        agenda="department_consensus",
        structure=GameStructure.REPEATED,
        resolved=True,
        outcome=final_action.value,
        messages=messages,
        notes=notes,
    )
    return [summary]