add belief revision and data gaps tracking to game engine

This commit is contained in:
sam 2025-10-07 21:42:11 +08:00
parent 3c15d443d3
commit 80b96497fd
10 changed files with 348 additions and 6 deletions

60
app/agents/beliefs.py Normal file
View File

@ -0,0 +1,60 @@
"""Belief revision helpers for multi-round negotiation."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional
from .base import AgentAction
@dataclass
class BeliefRevisionResult:
consensus_action: Optional[AgentAction]
consensus_confidence: float
conflicts: List[str]
notes: Dict[str, object]
def revise_beliefs(belief_updates: Dict[str, "BeliefUpdate"], default_action: AgentAction) -> BeliefRevisionResult:
action_votes: Dict[AgentAction, int] = {}
reasons: Dict[str, object] = {}
for agent, update in belief_updates.items():
belief = getattr(update, "belief", {}) or {}
action_value = belief.get("action") if isinstance(belief, dict) else None
try:
action = AgentAction(action_value) if action_value else None
except ValueError:
action = None
if action:
action_votes[action] = action_votes.get(action, 0) + 1
reasons[agent] = belief
consensus_action = None
consensus_confidence = 0.0
conflicts: List[str] = []
if action_votes:
total_votes = sum(action_votes.values())
consensus_action = max(action_votes.items(), key=lambda kv: kv[1])[0]
consensus_confidence = action_votes[consensus_action] / total_votes if total_votes else 0.0
if len(action_votes) > 1:
conflicts = [action.name for action in action_votes.keys() if action is not consensus_action]
if consensus_action is None:
consensus_action = default_action
notes = {
"votes": {action.value: count for action, count in action_votes.items()},
"reasons": reasons,
}
return BeliefRevisionResult(
consensus_action=consensus_action,
consensus_confidence=consensus_confidence,
conflicts=conflicts,
notes=notes,
)
# avoid circular import typing
from typing import TYPE_CHECKING
if TYPE_CHECKING: # pragma: no cover
from .game import BeliefUpdate

View File

@ -8,6 +8,7 @@ from typing import Dict, Iterable, List, Mapping, Optional, Tuple
from .base import Agent, AgentAction, AgentContext, UtilityMatrix
from .departments import DepartmentContext, DepartmentDecision, DepartmentManager
from .registry import weight_map
from .beliefs import BeliefRevisionResult, revise_beliefs
from .risk import RiskAgent, RiskRecommendation
from .protocols import (
DialogueMessage,
@ -69,6 +70,7 @@ class Decision:
rounds: List[RoundSummary] = field(default_factory=list)
risk_assessment: Optional[RiskAssessment] = None
belief_updates: Dict[str, BeliefUpdate] = field(default_factory=dict)
belief_revision: Optional[BeliefRevisionResult] = None
def compute_utilities(agents: Iterable[Agent], context: AgentContext) -> UtilityMatrix:
@ -336,6 +338,13 @@ def decide(
action,
department_votes,
)
belief_revision = revise_beliefs(belief_updates, exec_action)
execution_round.notes.setdefault("consensus_action", belief_revision.consensus_action.value)
execution_round.notes.setdefault("consensus_confidence", belief_revision.consensus_confidence)
if belief_revision.conflicts:
execution_round.notes.setdefault("conflicts", belief_revision.conflicts)
if belief_revision.notes:
execution_round.notes.setdefault("belief_notes", belief_revision.notes)
return Decision(
action=action,
confidence=confidence,
@ -348,6 +357,7 @@ def decide(
rounds=rounds,
risk_assessment=risk_assessment,
belief_updates=belief_updates,
belief_revision=belief_revision,
)

51
app/agents/scopes.py Normal file
View File

@ -0,0 +1,51 @@
"""Scope mappings for different game structures."""
from __future__ import annotations
from typing import Dict, Iterable, Set
from .protocols import GameStructure
_GAME_SCOPE_MAP: Dict[GameStructure, Set[str]] = {
GameStructure.REPEATED: {
"daily.close",
"daily.open",
"daily.high",
"daily.low",
"daily_basic.turnover_rate",
"daily_basic.turnover_rate_f",
},
GameStructure.SIGNALING: {
"daily.close",
"daily.high",
"daily_basic.turnover_rate",
"daily_basic.volume_ratio",
"factors.sent_momentum",
"factors.sent_market",
},
GameStructure.BAYESIAN: {
"daily.close",
"daily_basic.turnover_rate",
"factors.mom_20",
"factors.mom_60",
"factors.val_multiscore",
"factors.sent_divergence",
},
GameStructure.CUSTOM: {
"factors.risk_penalty",
"factors.turn_20",
"factors.volat_20",
"daily_basic.turnover_rate",
},
}
def scope_for_structures(structures: Iterable[GameStructure]) -> Set[str]:
scope: Set[str] = set()
for structure in structures:
scope.update(_GAME_SCOPE_MAP.get(structure, set()))
return scope
def registered_structures() -> Dict[GameStructure, Set[str]]:
return {key: set(values) for key, values in _GAME_SCOPE_MAP.items()}

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import json
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import date
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple
@ -9,7 +10,8 @@ from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple
from app.agents.base import AgentAction, AgentContext
from app.agents.departments import DepartmentManager
from app.agents.game import Decision, decide, target_weight_for_action
from app.agents.protocols import round_to_dict
from app.agents.protocols import GameStructure, round_to_dict
from app.agents.scopes import scope_for_structures
from app.llm.metrics import record_decision as metrics_record_decision
from app.agents.registry import default_agents
from app.data.schema import initialize_database
@ -56,6 +58,7 @@ class BtConfig:
universe: List[str]
params: Dict[str, float]
method: str = "nash"
game_structures: Optional[List[GameStructure]] = None
@dataclass
@ -72,6 +75,7 @@ class BacktestResult:
nav_series: List[Dict[str, float]] = field(default_factory=list)
trades: List[Dict[str, str]] = field(default_factory=list)
risk_events: List[Dict[str, object]] = field(default_factory=list)
data_gaps: List[Dict[str, object]] = field(default_factory=list)
@dataclass
@ -137,7 +141,19 @@ class BacktestEngine:
"factors.volat_20",
"factors.turn_20",
}
self.required_fields = sorted(base_scope | department_scope)
selected_structures = (
cfg.game_structures
if cfg.game_structures
else [
GameStructure.REPEATED,
GameStructure.SIGNALING,
GameStructure.BAYESIAN,
GameStructure.CUSTOM,
]
)
self.game_structures = list(dict.fromkeys(selected_structures))
structure_scope = scope_for_structures(self.game_structures)
self.required_fields = sorted(base_scope | department_scope | structure_scope)
def load_market_data(self, trade_date: date) -> Mapping[str, Dict[str, Any]]:
"""Load per-stock feature vectors and context slices for the trade date."""
@ -152,6 +168,19 @@ class BacktestEngine:
self.required_fields,
auto_refresh=False # 避免回测时触发自动补数
)
missing_fields = [
field
for field in self.required_fields
if scope_values.get(field) is None
]
derived_fields: List[str] = []
if missing_fields:
LOGGER.debug(
"字段缺失,使用回退或派生数据 ts_code=%s fields=%s",
ts_code,
missing_fields,
extra=LOG_EXTRA,
)
closes = self.data_broker.fetch_series(
"daily",
@ -166,17 +195,21 @@ class BacktestEngine:
mom5 = scope_values.get("factors.mom_5")
if mom5 is None and len(close_values) >= 5:
mom5 = momentum(close_values, 5)
derived_fields.append("factors.mom_5")
mom20 = scope_values.get("factors.mom_20")
if mom20 is None and len(close_values) >= 20:
mom20 = momentum(close_values, 20)
derived_fields.append("factors.mom_20")
mom60 = scope_values.get("factors.mom_60")
if mom60 is None and len(close_values) >= 60:
mom60 = momentum(close_values, 60)
derived_fields.append("factors.mom_60")
volat20 = scope_values.get("factors.volat_20")
if volat20 is None and len(close_values) >= 2:
volat20 = volatility(close_values, 20)
derived_fields.append("factors.volat_20")
turnover_series = self.data_broker.fetch_series(
"daily_basic",
@ -191,9 +224,11 @@ class BacktestEngine:
turn20 = scope_values.get("factors.turn_20")
if turn20 is None and turnover_values:
turn20 = rolling_mean(turnover_values, 20)
derived_fields.append("factors.turn_20")
turn5 = scope_values.get("factors.turn_5")
if turn5 is None and len(turnover_values) >= 5:
turn5 = rolling_mean(turnover_values, 5)
derived_fields.append("factors.turn_5")
if mom20 is None:
mom20 = 0.0
@ -217,14 +252,24 @@ class BacktestEngine:
val_pe = scope_values.get("factors.val_pe_score")
if val_pe is None:
val_pe = _valuation_score(scope_values.get("daily_basic.pe"), scale=12.0)
derived_fields.append("factors.val_pe_score")
val_pb = scope_values.get("factors.val_pb_score")
if val_pb is None:
val_pb = _valuation_score(scope_values.get("daily_basic.pb"), scale=2.5)
derived_fields.append("factors.val_pb_score")
volume_ratio_score = scope_values.get("factors.volume_ratio_score")
if volume_ratio_score is None:
volume_ratio_score = _volume_ratio_score(scope_values.get("daily_basic.volume_ratio"))
derived_fields.append("factors.volume_ratio_score")
if derived_fields:
LOGGER.debug(
"字段派生完成 ts_code=%s derived=%s",
ts_code,
derived_fields,
extra=LOG_EXTRA,
)
sentiment_index = scope_values.get("news.sentiment_index", 0.0)
heat_score = scope_values.get("news.heat_score", 0.0)
@ -316,6 +361,8 @@ class BacktestEngine:
"close_series": closes,
"turnover_series": turnover_series,
"required_fields": self.required_fields,
"missing_fields": missing_fields,
"derived_fields": derived_fields,
}
feature_map[ts_code] = {
@ -512,6 +559,8 @@ class BacktestEngine:
price_map: Dict[str, float] = {}
decisions_map: Dict[str, Decision] = {}
feature_cache: Dict[str, Mapping[str, Any]] = {}
missing_counts: Dict[str, int] = defaultdict(int)
derived_counts: Dict[str, int] = defaultdict(int)
for ts_code, context, decision in records:
features = context.features or {}
if not isinstance(features, Mapping):
@ -520,6 +569,12 @@ class BacktestEngine:
scope_values = context.raw.get("scope_values") if context.raw else {}
if not isinstance(scope_values, Mapping):
scope_values = {}
raw_missing = context.raw.get("missing_fields") if context.raw else []
raw_derived = context.raw.get("derived_fields") if context.raw else []
for field in raw_missing or []:
missing_counts[field] += 1
for field in raw_derived or []:
derived_counts[field] += 1
price = scope_values.get("daily.close") or scope_values.get("close")
if price is None:
continue
@ -756,6 +811,15 @@ class BacktestEngine:
}
)
if missing_counts or derived_counts:
result.data_gaps.append(
{
"trade_date": trade_date_str,
"missing_fields": dict(sorted(missing_counts.items())),
"derived_fields": dict(sorted(derived_counts.items())),
}
)
market_value = 0.0
unrealized_pnl = 0.0
for ts_code, qty in state.holdings.items():
@ -1129,6 +1193,19 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None:
)
summary_payload["risk_breakdown"] = breakdown
if getattr(result, "data_gaps", None):
missing_total: Dict[str, int] = defaultdict(int)
derived_total: Dict[str, int] = defaultdict(int)
for gap in result.data_gaps:
for field, count in (gap.get("missing_fields") or {}).items():
missing_total[field] += int(count)
for field, count in (gap.get("derived_fields") or {}).items():
derived_total[field] += int(count)
if missing_total:
summary_payload["missing_field_counts"] = dict(missing_total)
if derived_total:
summary_payload["derived_field_counts"] = dict(derived_total)
cfg_payload = {
"id": cfg.id,
"name": cfg.name,

View File

@ -17,6 +17,7 @@ import streamlit as st
from app.agents.base import AgentContext
from app.agents.game import Decision
from app.agents.registry import default_agents
from app.agents.protocols import GameStructure
from app.backtest.decision_env import DecisionEnv, ParameterSpec
from app.backtest.optimizer import BanditConfig, EpsilonGreedyBandit
from app.rl import TORCH_AVAILABLE, DecisionEnvAdapter, PPOConfig, train_ppo
@ -67,6 +68,16 @@ def render_backtest_review() -> None:
target = col_target.number_input("目标收益0.035 表示 3.5%", value=0.035, step=0.005, format="%.3f", key="bt_target")
stop = col_stop.number_input("止损收益(例:-0.015 表示 -1.5%", value=-0.015, step=0.005, format="%.3f", key="bt_stop")
hold_days = col_hold.number_input("持有期(交易日)", value=10, step=1, key="bt_hold_days")
structure_options = [item.value for item in GameStructure]
selected_structure_values = st.multiselect(
"选择博弈框架",
structure_options,
default=structure_options,
key="bt_game_structures",
)
if not selected_structure_values:
selected_structure_values = [GameStructure.REPEATED.value]
selected_structures = [GameStructure(value) for value in selected_structure_values]
LOGGER.debug(
"当前回测表单输入start=%s end=%s universe_text=%s target=%.3f stop=%.3f hold_days=%s",
start_date,
@ -148,6 +159,7 @@ def render_backtest_review() -> None:
"stop": stop,
"hold_days": int(hold_days),
},
game_structures=selected_structures,
)
result = run_backtest(backtest_cfg, decision_callback=_decision_callback)
LOGGER.info(
@ -287,6 +299,12 @@ def render_backtest_review() -> None:
"风险分布": json.dumps(summary.get("risk_breakdown"), ensure_ascii=False)
if summary.get("risk_breakdown")
else None,
"缺失字段": json.dumps(summary.get("missing_field_counts"), ensure_ascii=False)
if summary.get("missing_field_counts")
else None,
"派生字段": json.dumps(summary.get("derived_field_counts"), ensure_ascii=False)
if summary.get("derived_field_counts")
else None,
}
metrics_rows.append({k: v for k, v in record.items() if (k == "cfg_id" or k in selected_metrics)})
if metrics_rows:
@ -658,6 +676,7 @@ def render_backtest_review() -> None:
"hold_days": int(hold_days),
},
method=app_cfg.decision_method,
game_structures=selected_structures,
)
env = DecisionEnv(
bt_config=bt_cfg_env,
@ -934,6 +953,7 @@ def render_backtest_review() -> None:
"hold_days": int(hold_days),
},
method=app_cfg.decision_method,
game_structures=selected_structures,
)
env = DecisionEnv(
bt_config=bt_cfg_env,

View File

@ -74,6 +74,12 @@ class _RefreshCoordinator:
normalized = parsed_date.strftime("%Y%m%d")
tables = self._collect_tables(fields)
if tables and self.broker.check_data_availability(normalized, tables):
LOGGER.debug(
"触发近端数据刷新 trade_date=%s tables=%s",
normalized,
sorted(tables),
extra=LOG_EXTRA,
)
self.broker._trigger_background_refresh(normalized)
def ensure_for_series(self, end_date: str, table: str) -> None:
@ -82,6 +88,12 @@ class _RefreshCoordinator:
return
normalized = parsed_date.strftime("%Y%m%d")
if self.broker.check_data_availability(normalized, {table}):
LOGGER.debug(
"触发序列刷新 trade_date=%s table=%s",
normalized,
table,
extra=LOG_EXTRA,
)
self.broker._trigger_background_refresh(normalized)
def _collect_tables(self, fields: Iterable[str]) -> Set[str]:

View File

@ -55,7 +55,7 @@
### 3.3 代码改造计划(多轮博弈适配)
1. 架构基线评估
- ⏳ 绘制代理/部门/回测调用图,补充日志字段(缺数告警、补数来源、议程标识)并形成诊断报告
- ✅ 绘制代理/部门/回测调用图并补充日志字段(见 docs/architecture_call_graph.md
- ✅ 定义多轮博弈上下文结构(消息历史、信念状态、引用证据),输出数据类与通信协议草稿。
- ✅ 在 `app/agents/protocols.py` 基础上补充主持/执行状态管理,实现 `DialogueTrace` 与部门上下文的对接路径。
- ✅ 扩展 `Decision.rounds``RoundSummary` 采集策略,用于串联部门结论与多轮议程结果。
@ -67,15 +67,16 @@
- ✅ 将风险建议与执行回合对齐,执行阶段识别 `risk_adjusted` 并记录原始动作。
2. 数据与因子重构
- ✅ 拆分 `DataBroker` 查询层(`BrokerQueryEngine`),补数逻辑独立于查询管道。
- ⏳ 按主题拆分因子模块,存储缺口/异常标签,改写 `load_market_data()` 为“缺失即说明”。
- ⏳ 维护博弈结构 → 数据 scope 映射,支持角色按结构加载差异化字段。
- ⏳ 按主题拆分因子模块,存储缺口/异常标签。
- ✅ `load_market_data()` 标注缺失字段并写入原始日志(`missing_fields`、`derived_fields`)。
- ✅ 维护博弈结构 → 数据 scope 映射(`app/agents/scopes.py``BacktestEngine.required_fields`)。
- ✅ 基于 `_RefreshCoordinator` 落地刷新队列与监控事件,拆分查询与补数路径。
- ✅ 暴露 `DataBroker.register_refresh_callback()` 钩子,结合监控系统记录补数进度与失败重试。
- ⏳ 统一补数回调日志格式(`LOG_EXTRA.stage=data_broker`),为后续指标预留数据源。
3. 多轮博弈框架
- ✅ 在 `app/agents/game.py` 抽象 `GameProtocol` 接口,扩展 `Decision` 记录多轮对话。
- ✅ 实现主持调度器驱动议程(信息→陈述→反驳→共识→执行),挂载风险复核机制。
- ⏳ 引入信念修正规则与论证框架,支持证据引用和冲突裁决
- ✅ 引入基础信念修正规则(`app/agents/beliefs.py`),汇总信念并记录冲突
4. 执行与回测集成
- ✅ 将回测循环改造成“每日多轮→执行摘要”,完成风控校验与冲突重议流程。
- ⏳ 擦合订单映射层,明确多轮结果对应目标仓位、执行节奏、异常回滚策略。

View File

@ -0,0 +1,24 @@
digraph LLMQuantCallGraph {
rankdir=LR;
node [shape=box, style=rounded];
BacktestEngine -> LoadMarketData;
LoadMarketData -> DataBrokerFetch;
DataBrokerFetch -> BrokerQuery [label="db_session"];
LoadMarketData -> FeatureAssembly;
BacktestEngine -> Decide;
Decide -> ProtocolHost;
ProtocolHost -> DepartmentRound;
ProtocolHost -> RiskReview;
ProtocolHost -> ExecutionSummary;
Decide -> BeliefRevision;
ExecutionSummary -> ApplyPortfolio;
ApplyPortfolio -> RiskEvents;
ApplyPortfolio -> Alerts;
ApplyPortfolio -> PersistResults;
PersistResults -> Reports;
PersistResults -> UI;
}

View File

@ -0,0 +1,43 @@
# 多轮博弈决策调用示意
本节概述 `llm_quant` 中多轮博弈执行链路,便于定位关键日志与扩展点。
```
BacktestEngine.simulate_day
└─ load_market_data
├─ DataBroker.fetch_latest
│ ├─ BrokerQueryEngine.fetch_latest
│ └─ 缺失字段 → derived_fields/missing_fields (写入 raw/missing_fields)
├─ DataBroker.fetch_series (同上)
└─ assemble feature_map (features / market_snapshot / raw)
└─ for each symbol → decide (agents.game)
├─ compute_utilities / feasible_actions
├─ DepartmentManager.evaluate (LLM 部门,可带回 risk/rationale)
├─ ProtocolHost (game protocols)
│ ├─ start_round("department_consensus")
│ ├─ risk_review (当 conflict / risk assessment 触发)
│ └─ execution_summary (记录 execution_status)
├─ revise_beliefs (beliefs.py) → consensus/conflict
└─ Decision
├─ rounds (RoundSummary 日志)
├─ risk_assessment (status/reason/recommended_action)
├─ belief_updates / belief_revision (供监控/重播)
└─ department_votes / utilities
└─ _apply_portfolio_updates
├─ 使用 Decision.risk_assessment 调节执行
├─ 执行失败/阻断 → risk_events & alerts.backtest_risk
└─ executed_trades / nav_series / risk_events → bt_* 表
```
## 关键日志
- `LOG_EXTRA = {"stage": "backtest"}`:缺失字段、派生字段、执行阻断。
- `LOG_EXTRA = {"stage": "data_broker"}`:自动补数触发、查询失败回退。
## 拉通数据
- `app/agents/scopes.py` 维护结构 → 字段映射。
- `Decision.raw``missing_fields/derived_fields` 可用于缺口诊断。
## 后续建议
1. 将 `belief_revision``risk_events` 接入监控告警。
2. 结合 `missing_fields` 统计生成数据质量简报。
3. 通过自动化脚本渲染上述流程图/时序图。

View File

@ -0,0 +1,44 @@
"""Render architecture call graph for multi-agent decision flow."""
from __future__ import annotations
from pathlib import Path
DOT_TEMPLATE = """
digraph LLMQuantCallGraph {
rankdir=LR;
node [shape=box, style=rounded];
BacktestEngine -> LoadMarketData;
LoadMarketData -> DataBrokerFetch;
DataBrokerFetch -> BrokerQuery [label="db_session"];
LoadMarketData -> FeatureAssembly;
BacktestEngine -> Decide;
Decide -> ProtocolHost;
ProtocolHost -> DepartmentRound;
ProtocolHost -> RiskReview;
ProtocolHost -> ExecutionSummary;
Decide -> BeliefRevision;
ExecutionSummary -> ApplyPortfolio;
ApplyPortfolio -> RiskEvents;
ApplyPortfolio -> Alerts;
ApplyPortfolio -> PersistResults;
PersistResults -> Reports;
PersistResults -> UI;
}
"""
def render(output: Path) -> None:
output.write_text(DOT_TEMPLATE.strip() + "\n", encoding="utf-8")
def main() -> None:
out_file = Path("docs/architecture_call_graph.dot")
render(out_file)
print(f"dot file written: {out_file}")
if __name__ == "__main__":
main()