# llm-quant/app/agents/game.py
# 2025-09-28 09:39:48 +08:00
#
# 234 lines
# 8.3 KiB
# Python

"""Multi-agent decision game implementation."""
from __future__ import annotations
from dataclasses import dataclass, field
from math import log
from typing import Dict, Iterable, List, Mapping, Optional, Tuple
from .base import Agent, AgentAction, AgentContext, UtilityMatrix
from .departments import DepartmentContext, DepartmentDecision, DepartmentManager
from .registry import weight_map
# Every action the game considers. compute_utilities and feasible_actions
# iterate this tuple in the order listed here.
ACTIONS: Tuple[AgentAction, ...] = (
    AgentAction.SELL,
    AgentAction.HOLD,
    AgentAction.BUY_S,
    AgentAction.BUY_M,
    AgentAction.BUY_L,
)
def _clamp(value: float) -> float:
return max(0.0, min(1.0, value))
@dataclass
class Decision:
    """Outcome of one ``decide`` call."""

    # Action that was selected.
    action: AgentAction
    # Confidence reported by the resolution step (0.0 when nothing is feasible).
    confidence: float
    # Value of target_weight_for_action(action) (0.0 when nothing is feasible).
    target_weight: float
    # Actions that every agent reported as feasible.
    feasible_actions: List[AgentAction]
    # Full action -> {agent name: score} matrix, including infeasible actions.
    utilities: UtilityMatrix
    # Per-department decisions as returned by DepartmentManager.evaluate.
    department_decisions: Dict[str, DepartmentDecision] = field(default_factory=dict)
    # Accumulated weight * confidence per bucket ("sell" / "buy" / "hold").
    department_votes: Dict[str, float] = field(default_factory=dict)
    # Result of _department_conflict_flag over department_votes.
    requires_review: bool = False
def compute_utilities(agents: Iterable[Agent], context: AgentContext) -> UtilityMatrix:
    """Score every action in ``ACTIONS`` with every agent.

    Args:
        agents: Agents to consult. The iterable is materialized once up
            front, so a one-shot iterator (for example a generator) is still
            consulted for every action rather than only for the first one.
        context: Context passed unchanged to each ``agent.score`` call.

    Returns:
        A mapping ``action -> {agent.name: score}`` in which every score has
        been passed through ``_clamp``.
    """
    # Re-iterating ``agents`` once per action would silently exhaust an
    # iterator after the first action, leaving the remaining rows empty.
    agent_list = list(agents)
    utilities: UtilityMatrix = {}
    for action in ACTIONS:
        utilities[action] = {
            agent.name: _clamp(agent.score(context, action)) for agent in agent_list
        }
    return utilities
def feasible_actions(agents: Iterable[Agent], context: AgentContext) -> List[AgentAction]:
    """Return the actions from ``ACTIONS`` that every agent reports as feasible.

    Args:
        agents: Agents to consult. The iterable is materialized once up
            front; otherwise a one-shot iterator would be exhausted after the
            first action and ``all()`` over the empty remainder would mark
            every later action as feasible.
        context: Context passed unchanged to each ``agent.feasible`` call.

    Returns:
        The feasible actions, in ``ACTIONS`` order. With no agents at all,
        every action is returned (``all`` of an empty sequence is true).
    """
    agent_list = list(agents)
    return [
        action
        for action in ACTIONS
        if all(agent.feasible(context, action) for agent in agent_list)
    ]
def nash_bargain(utilities: UtilityMatrix, weights: Mapping[str, float], disagreement: Mapping[str, float]) -> Tuple[AgentAction, float]:
    """Select the action with the largest weighted Nash product.

    Each action is scored as ``sum(w * log(score - d))`` over the agents that
    have a non-zero weight, where ``d`` is the agent's disagreement utility.
    An action is rejected as soon as one weighted agent does not strictly
    improve on its disagreement utility. An action for which no agent has a
    non-zero weight is rejected too: its empty log-sum of 0.0 would otherwise
    outrank every genuinely supported action.

    Args:
        utilities: Mapping ``action -> {agent name: score}`` of candidates.
        weights: Agent name -> weight; agents missing here count as weight 0
            and are skipped.
        disagreement: Agent name -> utility of the fallback outcome; agents
            missing here use 0.0.

    Returns:
        ``(action, confidence)`` where ``confidence`` comes from
        ``_aggregate_confidence`` for the winning action, or
        ``(AgentAction.HOLD, 0.0)`` when no action qualifies.
    """
    best_action: Optional[AgentAction] = None
    best_score = float("-inf")
    for action, agent_scores in utilities.items():
        log_product = 0.0
        contributors = 0
        valid = True
        for agent_name, score in agent_scores.items():
            w = weights.get(agent_name, 0.0)
            if w == 0:
                continue
            gap = score - disagreement.get(agent_name, 0.0)
            if gap <= 0:
                # log() is undefined here: this agent gains nothing.
                valid = False
                break
            log_product += w * log(gap)
            contributors += 1
        if not valid or contributors == 0:
            continue
        if log_product > best_score:
            best_score = log_product
            best_action = action
    if best_action is None:
        return AgentAction.HOLD, 0.0
    confidence = _aggregate_confidence(utilities[best_action], weights)
    return best_action, confidence
def vote(utilities: UtilityMatrix, weights: Mapping[str, float]) -> Tuple[AgentAction, float]:
    """Select the action with the highest weighted sum of agent scores.

    Args:
        utilities: Mapping ``action -> {agent name: score}`` of candidates.
        weights: Agent name -> weight; agents missing here count as 0.0.

    Returns:
        ``(action, confidence)`` where ``confidence`` comes from
        ``_aggregate_confidence`` for the winning action. On a tie the action
        that appears first in ``utilities`` wins.

    Raises:
        ValueError: If ``utilities`` is empty.
    """
    tallies: Dict[AgentAction, float] = {
        candidate: sum(weights.get(name, 0.0) * value for name, value in per_agent.items())
        for candidate, per_agent in utilities.items()
    }
    winner = max(tallies, key=lambda candidate: tallies[candidate])
    return winner, _aggregate_confidence(utilities[winner], weights)
def _aggregate_confidence(agent_scores: Mapping[str, float], weights: Mapping[str, float]) -> float:
total = sum(weights.values())
if total <= 0:
return 0.0
weighted = sum(weights.get(agent, 0.0) * score for agent, score in agent_scores.items())
return weighted / total
def target_weight_for_action(action: AgentAction) -> float:
    """Return the numeric target weight associated with ``action``.

    Args:
        action: One of the five ``AgentAction`` members listed below.

    Returns:
        The weight paired with ``action``.

    Raises:
        KeyError: If ``action`` is not one of the listed members.
    """
    # NOTE(review): SELL maps to -1.0 while the buy sizes are small positive
    # fractions; presumably -1.0 is a sentinel for the consumer - confirm.
    sizing = (
        (AgentAction.SELL, -1.0),
        (AgentAction.HOLD, 0.0),
        (AgentAction.BUY_S, 0.01),
        (AgentAction.BUY_M, 0.02),
        (AgentAction.BUY_L, 0.03),
    )
    return dict(sizing)[action]
def decide(
    context: AgentContext,
    agents: Iterable[Agent],
    weights: Mapping[str, float],
    method: str = "nash",
    department_manager: Optional[DepartmentManager] = None,
    department_context: Optional[DepartmentContext] = None,
) -> Decision:
    """Run the decision game once and package the outcome as a ``Decision``.

    The agents score every action and report feasibility. When no action is
    feasible the function returns HOLD with zero confidence and zero target
    weight without consulting any department. Otherwise each department
    decision (if a manager is supplied) is added to the utility matrix as a
    pseudo-agent named ``dept_<code>``, the weights are normalized through
    ``weight_map``, and the action is resolved over the feasible actions.

    Args:
        context: Context handed to the agents.
        agents: Agents taking part in the game.
        weights: Agent name -> raw weight.
        method: ``"vote"`` resolves with ``vote``; any other value resolves
            with ``nash_bargain`` using the HOLD utilities as the
            disagreement point, falling back to ``vote`` when the bargained
            action is not feasible.
        department_manager: Optional manager whose ``evaluate`` result is
            merged into the game.
        department_context: Context for the departments; built from
            ``context`` when omitted.

    Returns:
        The resulting ``Decision``.
    """
    roster = list(agents)
    utilities = compute_utilities(roster, context)
    allowed = feasible_actions(roster, context)
    if not allowed:
        return Decision(
            action=AgentAction.HOLD,
            confidence=0.0,
            target_weight=0.0,
            feasible_actions=[],
            utilities=utilities,
        )

    merged_weights = dict(weights)
    department_decisions: Dict[str, DepartmentDecision] = {}
    department_votes: Dict[str, float] = {}

    if department_manager:
        if department_context is None:
            # Derive a department context from copies of the agent context.
            dept_context = DepartmentContext(
                ts_code=context.ts_code,
                trade_date=context.trade_date,
                features=dict(context.features),
                market_snapshot=dict(getattr(context, "market_snapshot", {}) or {}),
                raw=dict(getattr(context, "raw", {}) or {}),
            )
        else:
            dept_context = department_context
        department_decisions = department_manager.evaluate(dept_context)
        for code, dept_decision in department_decisions.items():
            pseudo_agent = f"dept_{code}"
            dept_agent = department_manager.agents.get(code)
            if dept_agent:
                dept_weight = dept_agent.settings.weight
            else:
                # No registered agent for this code: use a weight of 1.0.
                dept_weight = 1.0
            merged_weights[pseudo_agent] = dept_weight
            dept_scores = _department_scores(dept_decision)
            for candidate in ACTIONS:
                utilities.setdefault(candidate, {})[pseudo_agent] = dept_scores[candidate]
            bucket = _department_vote_bucket(dept_decision.action)
            if not bucket:
                continue
            department_votes[bucket] = (
                department_votes.get(bucket, 0.0) + dept_weight * dept_decision.confidence
            )

    candidates = {candidate: utilities[candidate] for candidate in allowed}
    disagreement = utilities.get(AgentAction.HOLD, {})
    normalized = weight_map(merged_weights)

    use_vote = method == "vote"
    if not use_vote:
        chosen, confidence = nash_bargain(candidates, normalized, disagreement)
        # The bargaining fallback (HOLD) may itself be infeasible.
        use_vote = chosen not in allowed
    if use_vote:
        chosen, confidence = vote(candidates, normalized)

    return Decision(
        action=chosen,
        confidence=confidence,
        target_weight=target_weight_for_action(chosen),
        feasible_actions=allowed,
        utilities=utilities,
        department_decisions=department_decisions,
        department_votes=department_votes,
        requires_review=_department_conflict_flag(department_votes),
    )
def _department_scores(decision: DepartmentDecision) -> Dict[AgentAction, float]:
    """Turn a department decision into one utility per action.

    The department's own action receives the highest score, growing with its
    clamped confidence; the scores of the other actions shrink as confidence
    grows (for a buy decision the other buy sizes also grow with confidence,
    but stay below the chosen size).

    Args:
        decision: Department decision providing ``action`` and ``confidence``.

    Returns:
        Mapping from every action in ``ACTIONS`` to a score passed through
        ``_clamp``.
    """
    conf = _clamp(decision.confidence)
    doubt = 1 - conf
    chosen = decision.action
    scores: Dict[AgentAction, float] = dict.fromkeys(ACTIONS, 0.2)

    if chosen is AgentAction.SELL:
        scores.update(
            {
                AgentAction.SELL: 0.7 + 0.3 * conf,
                AgentAction.HOLD: 0.4 * doubt,
                AgentAction.BUY_S: 0.2 * doubt,
                AgentAction.BUY_M: 0.15 * doubt,
                AgentAction.BUY_L: 0.1 * doubt,
            }
        )
    elif chosen in {AgentAction.BUY_S, AgentAction.BUY_M, AgentAction.BUY_L}:
        for buy in (AgentAction.BUY_S, AgentAction.BUY_M, AgentAction.BUY_L):
            scores[buy] = 0.6 + 0.4 * conf if buy is chosen else 0.3 + 0.3 * conf
        scores[AgentAction.HOLD] = 0.3 * doubt + 0.25
        scores[AgentAction.SELL] = 0.15 * doubt
    else:
        # HOLD or an unknown action
        residual = 0.3 * doubt
        scores[AgentAction.HOLD] = 0.6 + 0.4 * conf
        for other in (
            AgentAction.SELL,
            AgentAction.BUY_S,
            AgentAction.BUY_M,
            AgentAction.BUY_L,
        ):
            scores[other] = residual

    return {action: _clamp(value) for action, value in scores.items()}
def _department_vote_bucket(action: AgentAction) -> str:
    """Map an action to its vote bucket name.

    Args:
        action: Action chosen by a department.

    Returns:
        ``"sell"`` for SELL, ``"buy"`` for any of the three buy sizes,
        ``"hold"`` for HOLD, and an empty string for anything else.
    """
    bucket = ""
    if action is AgentAction.SELL:
        bucket = "sell"
    elif action in {AgentAction.BUY_S, AgentAction.BUY_M, AgentAction.BUY_L}:
        bucket = "buy"
    elif action is AgentAction.HOLD:
        bucket = "hold"
    return bucket
def _department_conflict_flag(votes: Mapping[str, float]) -> bool:
if not votes:
return False
total = sum(votes.values())
if total <= 0:
return True
top = max(votes.values())
if top < total * 0.45:
return True
if len(votes) > 1:
sorted_votes = sorted(votes.values(), reverse=True)
if len(sorted_votes) >= 2 and (sorted_votes[0] - sorted_votes[1]) < total * 0.1:
return True
return False