234 lines
8.3 KiB
Python
234 lines
8.3 KiB
Python
"""Multi-agent decision game implementation."""
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass, field
|
|
from math import log
|
|
from typing import Dict, Iterable, List, Mapping, Optional, Tuple
|
|
|
|
from .base import Agent, AgentAction, AgentContext, UtilityMatrix
|
|
from .departments import DepartmentContext, DepartmentDecision, DepartmentManager
|
|
from .registry import weight_map
|
|
|
|
|
|
# Every action the game can choose between. The order matters: the scoring
# helpers in this module iterate it as-is, and both `vote` and `nash_bargain`
# keep the first action they see when scores tie.
ACTIONS: Tuple[AgentAction, ...] = (
    AgentAction.SELL,
    AgentAction.HOLD,
    AgentAction.BUY_S,
    AgentAction.BUY_M,
    AgentAction.BUY_L,
)
|
|
|
|
|
|
def _clamp(value: float) -> float:
|
|
return max(0.0, min(1.0, value))
|
|
|
|
|
|
@dataclass
class Decision:
    """Result record produced by ``decide`` for one evaluation."""

    # Action that was selected.
    action: AgentAction
    # Confidence reported by the selection method (``vote`` / ``nash_bargain``).
    confidence: float
    # Portfolio weight associated with ``action``.
    target_weight: float
    # Actions every agent accepted as feasible; empty when none were.
    feasible_actions: List[AgentAction]
    # Full score matrix, ``{action: {agent_name: score}}``.
    utilities: UtilityMatrix
    # Per-department results keyed by department code (empty when no
    # department manager took part).
    department_decisions: Dict[str, DepartmentDecision] = field(
        default_factory=dict
    )
    # Accumulated department vote strength per bucket name.
    department_votes: Dict[str, float] = field(default_factory=dict)
    # True when the department votes were flagged as conflicting.
    requires_review: bool = False
|
|
|
|
|
|
def compute_utilities(agents: Iterable[Agent], context: AgentContext) -> UtilityMatrix:
    """Score every action in ``ACTIONS`` with every agent.

    Args:
        agents: Agents to consult. Any iterable is accepted; it is
            materialised once so that a one-shot iterator (for example a
            generator) still contributes a score to every action instead of
            being exhausted after the first one.
        context: Passed unchanged to each ``agent.score`` call.

    Returns:
        ``{action: {agent.name: score}}`` with every score clamped to
        [0, 1]. Agents sharing a name overwrite each other; the last wins.
    """
    roster = list(agents)
    utilities: UtilityMatrix = {}
    for action in ACTIONS:
        row = {}
        for agent in roster:
            score = _clamp(agent.score(context, action))
            row[agent.name] = score
        utilities[action] = row
    return utilities
|
|
|
|
|
|
def feasible_actions(agents: Iterable[Agent], context: AgentContext) -> List[AgentAction]:
    """Return the actions that every agent accepts as feasible.

    Args:
        agents: Agents to consult. The iterable is materialised once: with a
            one-shot iterator the checks after the first action would
            otherwise run over an exhausted iterator, and ``all()`` of
            nothing is True, which would mark those actions feasible
            without asking anyone.
        context: Passed unchanged to each ``agent.feasible`` call.

    Returns:
        The feasible actions, in ``ACTIONS`` order. With no agents at all
        every action is feasible.
    """
    roster = list(agents)
    return [
        action
        for action in ACTIONS
        if all(agent.feasible(context, action) for agent in roster)
    ]
|
|
|
|
|
|
def nash_bargain(utilities: UtilityMatrix, weights: Mapping[str, float], disagreement: Mapping[str, float]) -> Tuple[AgentAction, float]:
    """Pick the action with the largest weighted Nash product.

    For every action the objective is ``sum(w * log(score - d))`` over the
    agents whose weight is non-zero, where ``d`` is that agent's entry in
    ``disagreement`` (0.0 when absent).

    An action is rejected when
      * any weighted agent gains nothing over its disagreement value
        (``score - d <= 0``), or
      * no weighted agent scored it at all. Such an action would otherwise
        get an objective of 0.0 (an empty product), which is the highest
        value reachable when gaps are at most 1, and would win by default.

    Args:
        utilities: ``{action: {agent_name: score}}`` for the candidates.
        weights: Bargaining weight per agent name; missing or zero weight
            removes the agent from the product.
        disagreement: Per-agent utility of the disagreement outcome.

    Returns:
        ``(action, confidence)``. Ties keep the action seen first. When no
        action qualifies the result is ``(AgentAction.HOLD, 0.0)``.
    """
    best_action = AgentAction.HOLD
    best_score = float("-inf")
    for action, agent_scores in utilities.items():
        log_product = 0.0
        participants = 0
        for agent_name, score in agent_scores.items():
            w = weights.get(agent_name, 0.0)
            if w == 0:
                continue
            gap = score - disagreement.get(agent_name, 0.0)
            if gap <= 0:
                # One weighted agent is no better off: reject the action.
                break
            log_product += w * log(gap)
            participants += 1
        else:
            # Only reached when no agent vetoed the action above.
            if participants and log_product > best_score:
                best_score = log_product
                best_action = action
    if best_score == float("-inf"):
        return AgentAction.HOLD, 0.0
    confidence = _aggregate_confidence(utilities[best_action], weights)
    return best_action, confidence
|
|
|
|
|
|
def vote(utilities: UtilityMatrix, weights: Mapping[str, float]) -> Tuple[AgentAction, float]:
    """Pick the action with the highest weighted sum of agent scores.

    Args:
        utilities: ``{action: {agent_name: score}}`` for the candidates.
        weights: Voting weight per agent name; agents without an entry
            count with weight 0.0.

    Returns:
        ``(action, confidence)``. Ties keep the action seen first. An empty
        ``utilities`` yields ``(AgentAction.HOLD, 0.0)``, the same fallback
        ``nash_bargain`` uses, rather than letting ``max()`` raise
        ``ValueError`` on an empty mapping.
    """
    if not utilities:
        return AgentAction.HOLD, 0.0
    scores: Dict[AgentAction, float] = {
        action: sum(weights.get(agent, 0.0) * score for agent, score in agent_scores.items())
        for action, agent_scores in utilities.items()
    }
    best_action = max(scores, key=scores.get)
    confidence = _aggregate_confidence(utilities[best_action], weights)
    return best_action, confidence
|
|
|
|
|
|
def _aggregate_confidence(agent_scores: Mapping[str, float], weights: Mapping[str, float]) -> float:
|
|
total = sum(weights.values())
|
|
if total <= 0:
|
|
return 0.0
|
|
weighted = sum(weights.get(agent, 0.0) * score for agent, score in agent_scores.items())
|
|
return weighted / total
|
|
|
|
|
|
def target_weight_for_action(action: AgentAction) -> float:
    """Look up the target weight assigned to ``action``.

    Raises:
        KeyError: If ``action`` is not one of the five listed actions.
    """
    weight_table = dict(
        (
            (AgentAction.SELL, -1.0),
            (AgentAction.HOLD, 0.0),
            (AgentAction.BUY_S, 0.01),
            (AgentAction.BUY_M, 0.02),
            (AgentAction.BUY_L, 0.03),
        )
    )
    return weight_table[action]
|
|
|
|
|
|
def decide(
    context: AgentContext,
    agents: Iterable[Agent],
    weights: Mapping[str, float],
    method: str = "nash",
    department_manager: Optional[DepartmentManager] = None,
    department_context: Optional[DepartmentContext] = None,
) -> Decision:
    """Combine agent (and optional department) opinions into one Decision.

    Flow:
      1. Score all actions with the agents and collect the feasible ones;
         with none feasible, return a zero-confidence HOLD immediately.
      2. If a department manager is given, evaluate it and add each
         department as a pseudo-agent named ``dept_<code>`` to both the
         weights and the utility matrix, tallying its vote by bucket.
      3. Normalise the weights with ``weight_map`` and choose among the
         feasible actions by ``vote`` (``method == "vote"``) or by
         ``nash_bargain`` (any other value).

    Args:
        context: Evaluation context handed to the agents.
        agents: Agents to consult; copied into a list up front.
        weights: Raw weight per agent name; not modified.
        method: ``"vote"`` for weighted voting, anything else for Nash
            bargaining.
        department_manager: Optional department layer.
        department_context: Context for the departments; built from
            ``context`` when omitted.

    Returns:
        The populated ``Decision``.
    """
    roster = list(agents)
    utilities = compute_utilities(roster, context)
    allowed = feasible_actions(roster, context)
    if not allowed:
        # Nothing is permitted, so departments are not consulted at all.
        return Decision(
            action=AgentAction.HOLD,
            confidence=0.0,
            target_weight=0.0,
            feasible_actions=[],
            utilities=utilities,
        )

    combined_weights = dict(weights)
    dept_results: Dict[str, DepartmentDecision] = {}
    dept_votes: Dict[str, float] = {}

    if department_manager:
        if department_context is not None:
            dept_ctx = department_context
        else:
            # No explicit department context: derive one from the agent
            # context, copying the mappings and tolerating missing or None
            # ``market_snapshot`` / ``raw`` attributes.
            dept_ctx = DepartmentContext(
                ts_code=context.ts_code,
                trade_date=context.trade_date,
                features=dict(context.features),
                market_snapshot=dict(getattr(context, "market_snapshot", {}) or {}),
                raw=dict(getattr(context, "raw", {}) or {}),
            )
        dept_results = department_manager.evaluate(dept_ctx)
        for code, dept_decision in dept_results.items():
            pseudo_agent = f"dept_{code}"
            dept_agent = department_manager.agents.get(code)
            # Departments without a registered agent count with weight 1.0.
            dept_weight = dept_agent.settings.weight if dept_agent else 1.0
            combined_weights[pseudo_agent] = dept_weight
            dept_scores = _department_scores(dept_decision)
            for candidate in ACTIONS:
                utilities.setdefault(candidate, {})[pseudo_agent] = dept_scores[candidate]
            bucket = _department_vote_bucket(dept_decision.action)
            if bucket:
                # The tally uses the department's confidence as reported,
                # without clamping.
                previous = dept_votes.get(bucket, 0.0)
                dept_votes[bucket] = previous + dept_weight * dept_decision.confidence

    candidate_utilities = {candidate: utilities[candidate] for candidate in allowed}
    # HOLD scores act as the disagreement point for the bargaining step.
    disagreement = utilities.get(AgentAction.HOLD, {})
    normalised = weight_map(combined_weights)

    if method == "vote":
        chosen, confidence = vote(candidate_utilities, normalised)
    else:
        chosen, confidence = nash_bargain(candidate_utilities, normalised, disagreement)
        if chosen not in allowed:
            # The bargaining fallback (HOLD) may itself be infeasible;
            # settle it with a weighted vote over the feasible actions.
            chosen, confidence = vote(candidate_utilities, normalised)

    return Decision(
        action=chosen,
        confidence=confidence,
        target_weight=target_weight_for_action(chosen),
        feasible_actions=allowed,
        utilities=utilities,
        department_decisions=dept_results,
        department_votes=dept_votes,
        requires_review=_department_conflict_flag(dept_votes),
    )
|
|
|
|
|
|
def _department_scores(decision: DepartmentDecision) -> Dict[AgentAction, float]:
    """Turn one department recommendation into a score for every action.

    The recommended action gets the highest score, rising with the clamped
    confidence; the alternatives get scores that shrink as confidence
    grows. Every value is clamped to [0, 1] before it is returned.

    Returns:
        ``{action: score}`` keyed in ``ACTIONS`` order.
    """
    conf = _clamp(decision.confidence)
    doubt = 1 - conf
    recommended = decision.action

    # Baseline for any action the branches below do not set explicitly.
    row: Dict[AgentAction, float] = dict.fromkeys(ACTIONS, 0.2)

    if recommended is AgentAction.SELL:
        row.update(
            {
                AgentAction.SELL: 0.7 + 0.3 * conf,
                AgentAction.HOLD: 0.4 * doubt,
                AgentAction.BUY_S: 0.2 * doubt,
                AgentAction.BUY_M: 0.15 * doubt,
                AgentAction.BUY_L: 0.1 * doubt,
            }
        )
    elif recommended in {AgentAction.BUY_S, AgentAction.BUY_M, AgentAction.BUY_L}:
        # The exact buy size scores best; the other sizes still benefit.
        for buy in (AgentAction.BUY_S, AgentAction.BUY_M, AgentAction.BUY_L):
            row[buy] = (0.6 + 0.4 * conf) if buy is recommended else (0.3 + 0.3 * conf)
        row[AgentAction.HOLD] = 0.3 * doubt + 0.25
        row[AgentAction.SELL] = 0.15 * doubt
    else:
        # HOLD or an unrecognised action.
        row[AgentAction.HOLD] = 0.6 + 0.4 * conf
        for other in (
            AgentAction.SELL,
            AgentAction.BUY_S,
            AgentAction.BUY_M,
            AgentAction.BUY_L,
        ):
            row[other] = 0.3 * doubt

    return {action: _clamp(value) for action, value in row.items()}
|
|
|
|
|
|
def _department_vote_bucket(action: AgentAction) -> str:
    """Name the vote bucket an action belongs to.

    Returns:
        ``"sell"``, ``"buy"`` (any of the three buy sizes) or ``"hold"``;
        an empty string for anything else.
    """
    if action is AgentAction.SELL:
        bucket = "sell"
    elif action in {AgentAction.BUY_S, AgentAction.BUY_M, AgentAction.BUY_L}:
        bucket = "buy"
    elif action is AgentAction.HOLD:
        bucket = "hold"
    else:
        bucket = ""
    return bucket
|
|
|
|
|
|
def _department_conflict_flag(votes: Mapping[str, float]) -> bool:
|
|
if not votes:
|
|
return False
|
|
total = sum(votes.values())
|
|
if total <= 0:
|
|
return True
|
|
top = max(votes.values())
|
|
if top < total * 0.45:
|
|
return True
|
|
if len(votes) > 1:
|
|
sorted_votes = sorted(votes.values(), reverse=True)
|
|
if len(sorted_votes) >= 2 and (sorted_votes[0] - sorted_votes[1]) < total * 0.1:
|
|
return True
|
|
return False
|