diff --git a/app/agents/beliefs.py b/app/agents/beliefs.py new file mode 100644 index 0000000..dad3f1c --- /dev/null +++ b/app/agents/beliefs.py @@ -0,0 +1,60 @@ +"""Belief revision helpers for multi-round negotiation.""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, Iterable, List, Optional + +from .base import AgentAction + + +@dataclass +class BeliefRevisionResult: + consensus_action: Optional[AgentAction] + consensus_confidence: float + conflicts: List[str] + notes: Dict[str, object] + + +def revise_beliefs(belief_updates: Dict[str, "BeliefUpdate"], default_action: AgentAction) -> BeliefRevisionResult: + action_votes: Dict[AgentAction, int] = {} + reasons: Dict[str, object] = {} + for agent, update in belief_updates.items(): + belief = getattr(update, "belief", {}) or {} + action_value = belief.get("action") if isinstance(belief, dict) else None + try: + action = AgentAction(action_value) if action_value else None + except ValueError: + action = None + if action: + action_votes[action] = action_votes.get(action, 0) + 1 + reasons[agent] = belief + + consensus_action = None + consensus_confidence = 0.0 + conflicts: List[str] = [] + if action_votes: + total_votes = sum(action_votes.values()) + consensus_action = max(action_votes.items(), key=lambda kv: kv[1])[0] + consensus_confidence = action_votes[consensus_action] / total_votes if total_votes else 0.0 + if len(action_votes) > 1: + conflicts = [action.name for action in action_votes.keys() if action is not consensus_action] + + if consensus_action is None: + consensus_action = default_action + + notes = { + "votes": {action.value: count for action, count in action_votes.items()}, + "reasons": reasons, + } + return BeliefRevisionResult( + consensus_action=consensus_action, + consensus_confidence=consensus_confidence, + conflicts=conflicts, + notes=notes, + ) + + +# avoid circular import typing +from typing import TYPE_CHECKING +if TYPE_CHECKING: # pragma: no cover + from .game import BeliefUpdate diff --git a/app/agents/game.py b/app/agents/game.py index 4a9cdfc..13ef1d1 100644 --- a/app/agents/game.py +++ b/app/agents/game.py @@ -8,6 +8,7 @@ from typing import Dict, Iterable, List, Mapping, Optional, Tuple from .base import Agent, AgentAction, AgentContext, UtilityMatrix from .departments import DepartmentContext, DepartmentDecision, DepartmentManager from .registry import weight_map +from .beliefs import BeliefRevisionResult, revise_beliefs from .risk import RiskAgent, RiskRecommendation from .protocols import ( DialogueMessage, @@ -69,6 +70,7 @@ class Decision: rounds: List[RoundSummary] = field(default_factory=list) risk_assessment: Optional[RiskAssessment] = None belief_updates: Dict[str, BeliefUpdate] = field(default_factory=dict) + belief_revision: Optional[BeliefRevisionResult] = None def compute_utilities(agents: Iterable[Agent], context: AgentContext) -> UtilityMatrix: @@ -336,6 +338,13 @@ def decide( action, department_votes, ) + belief_revision = revise_beliefs(belief_updates, exec_action) + execution_round.notes.setdefault("consensus_action", belief_revision.consensus_action.value) + execution_round.notes.setdefault("consensus_confidence", belief_revision.consensus_confidence) + if belief_revision.conflicts: + execution_round.notes.setdefault("conflicts", belief_revision.conflicts) + if belief_revision.notes: + execution_round.notes.setdefault("belief_notes", belief_revision.notes) return Decision( action=action, confidence=confidence, @@ -348,6 +357,7 @@ def decide( rounds=rounds, risk_assessment=risk_assessment, belief_updates=belief_updates, + belief_revision=belief_revision, ) diff --git a/app/agents/scopes.py b/app/agents/scopes.py new file mode 100644 index 0000000..f513424 --- /dev/null +++ b/app/agents/scopes.py @@ -0,0 +1,51 @@ +"""Scope mappings for different game structures.""" +from __future__ import annotations + +from typing import Dict, Iterable, Set + +from .protocols import GameStructure + + +_GAME_SCOPE_MAP: Dict[GameStructure, Set[str]] = { + GameStructure.REPEATED: { + "daily.close", + "daily.open", + "daily.high", + "daily.low", + "daily_basic.turnover_rate", + "daily_basic.turnover_rate_f", + }, + GameStructure.SIGNALING: { + "daily.close", + "daily.high", + "daily_basic.turnover_rate", + "daily_basic.volume_ratio", + "factors.sent_momentum", + "factors.sent_market", + }, + GameStructure.BAYESIAN: { + "daily.close", + "daily_basic.turnover_rate", + "factors.mom_20", + "factors.mom_60", + "factors.val_multiscore", + "factors.sent_divergence", + }, + GameStructure.CUSTOM: { + "factors.risk_penalty", + "factors.turn_20", + "factors.volat_20", + "daily_basic.turnover_rate", + }, +} + + +def scope_for_structures(structures: Iterable[GameStructure]) -> Set[str]: + scope: Set[str] = set() + for structure in structures: + scope.update(_GAME_SCOPE_MAP.get(structure, set())) + return scope + + +def registered_structures() -> Dict[GameStructure, Set[str]]: + return {key: set(values) for key, values in _GAME_SCOPE_MAP.items()} diff --git a/app/backtest/engine.py b/app/backtest/engine.py index ddcb98a..3fa240a 100644 --- a/app/backtest/engine.py +++ b/app/backtest/engine.py @@ -2,6 +2,7 @@ from __future__ import annotations import json +from collections import defaultdict from dataclasses import dataclass, field from datetime import date from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple @@ -9,7 +10,8 @@ from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple from app.agents.base import AgentAction, AgentContext from app.agents.departments import DepartmentManager from app.agents.game import Decision, decide, target_weight_for_action -from app.agents.protocols import round_to_dict +from app.agents.protocols import GameStructure, round_to_dict +from app.agents.scopes import scope_for_structures from app.llm.metrics import record_decision as metrics_record_decision from app.agents.registry import default_agents from app.data.schema import initialize_database @@ -56,6 +58,7 @@ class BtConfig: universe: List[str] params: Dict[str, float] method: str = "nash" + game_structures: Optional[List[GameStructure]] = None @dataclass @@ -72,6 +75,7 @@ class BacktestResult: nav_series: List[Dict[str, float]] = field(default_factory=list) trades: List[Dict[str, str]] = field(default_factory=list) risk_events: List[Dict[str, object]] = field(default_factory=list) + data_gaps: List[Dict[str, object]] = field(default_factory=list) @dataclass @@ -137,7 +141,19 @@ class BacktestEngine: "factors.volat_20", "factors.turn_20", } - self.required_fields = sorted(base_scope | department_scope) + selected_structures = ( + cfg.game_structures + if cfg.game_structures + else [ + GameStructure.REPEATED, + GameStructure.SIGNALING, + GameStructure.BAYESIAN, + GameStructure.CUSTOM, + ] + ) + self.game_structures = list(dict.fromkeys(selected_structures)) + structure_scope = scope_for_structures(self.game_structures) + self.required_fields = sorted(base_scope | department_scope | structure_scope) def load_market_data(self, trade_date: date) -> Mapping[str, Dict[str, Any]]: """Load per-stock feature vectors and context slices for the trade date.""" @@ -152,6 +168,19 @@ class BacktestEngine: self.required_fields, auto_refresh=False # 避免回测时触发自动补数 ) + missing_fields = [ + field + for field in self.required_fields + if scope_values.get(field) is None + ] + derived_fields: List[str] = [] + if missing_fields: + LOGGER.debug( + "字段缺失,使用回退或派生数据 ts_code=%s fields=%s", + ts_code, + missing_fields, + extra=LOG_EXTRA, + ) closes = self.data_broker.fetch_series( "daily", @@ -166,17 +195,21 @@ class BacktestEngine: mom5 = scope_values.get("factors.mom_5") if mom5 is None and len(close_values) >= 5: mom5 = momentum(close_values, 5) + derived_fields.append("factors.mom_5") mom20 = scope_values.get("factors.mom_20") if mom20 is None and len(close_values) >= 20: mom20 = momentum(close_values, 20) + derived_fields.append("factors.mom_20") mom60 = scope_values.get("factors.mom_60") if mom60 is None and len(close_values) >= 60: mom60 = momentum(close_values, 60) + derived_fields.append("factors.mom_60") volat20 = scope_values.get("factors.volat_20") if volat20 is None and len(close_values) >= 2: volat20 = volatility(close_values, 20) + derived_fields.append("factors.volat_20") turnover_series = self.data_broker.fetch_series( "daily_basic", @@ -191,9 +224,11 @@ class BacktestEngine: turn20 = scope_values.get("factors.turn_20") if turn20 is None and turnover_values: turn20 = rolling_mean(turnover_values, 20) + derived_fields.append("factors.turn_20") turn5 = scope_values.get("factors.turn_5") if turn5 is None and len(turnover_values) >= 5: turn5 = rolling_mean(turnover_values, 5) + derived_fields.append("factors.turn_5") if mom20 is None: mom20 = 0.0 @@ -217,14 +252,24 @@ class BacktestEngine: val_pe = scope_values.get("factors.val_pe_score") if val_pe is None: val_pe = _valuation_score(scope_values.get("daily_basic.pe"), scale=12.0) + derived_fields.append("factors.val_pe_score") val_pb = scope_values.get("factors.val_pb_score") if val_pb is None: val_pb = _valuation_score(scope_values.get("daily_basic.pb"), scale=2.5) + derived_fields.append("factors.val_pb_score") volume_ratio_score = scope_values.get("factors.volume_ratio_score") if volume_ratio_score is None: volume_ratio_score = _volume_ratio_score(scope_values.get("daily_basic.volume_ratio")) + derived_fields.append("factors.volume_ratio_score") + if derived_fields: + LOGGER.debug( + "字段派生完成 ts_code=%s derived=%s", + ts_code, + derived_fields, + extra=LOG_EXTRA, + ) sentiment_index = scope_values.get("news.sentiment_index", 0.0) heat_score = scope_values.get("news.heat_score", 0.0) @@ -316,6 +361,8 @@ class BacktestEngine: "close_series": closes, "turnover_series": turnover_series, "required_fields": self.required_fields, + "missing_fields": missing_fields, + "derived_fields": derived_fields, } feature_map[ts_code] = { @@ -512,6 +559,8 @@ class BacktestEngine: price_map: Dict[str, float] = {} decisions_map: Dict[str, Decision] = {} feature_cache: Dict[str, Mapping[str, Any]] = {} + missing_counts: Dict[str, int] = defaultdict(int) + derived_counts: Dict[str, int] = defaultdict(int) for ts_code, context, decision in records: features = context.features or {} if not isinstance(features, Mapping): @@ -520,6 +569,12 @@ class BacktestEngine: scope_values = context.raw.get("scope_values") if context.raw else {} if not isinstance(scope_values, Mapping): scope_values = {} + raw_missing = context.raw.get("missing_fields") if context.raw else [] + raw_derived = context.raw.get("derived_fields") if context.raw else [] + for field in raw_missing or []: + missing_counts[field] += 1 + for field in raw_derived or []: + derived_counts[field] += 1 price = scope_values.get("daily.close") or scope_values.get("close") if price is None: continue @@ -756,6 +811,15 @@ class BacktestEngine: } ) + if missing_counts or derived_counts: + result.data_gaps.append( + { + "trade_date": trade_date_str, + "missing_fields": dict(sorted(missing_counts.items())), + "derived_fields": dict(sorted(derived_counts.items())), + } + ) + market_value = 0.0 unrealized_pnl = 0.0 for ts_code, qty in state.holdings.items(): @@ -1129,6 +1193,19 @@ def _persist_backtest_results(cfg: BtConfig, result: BacktestResult) -> None: ) summary_payload["risk_breakdown"] = breakdown + if getattr(result, "data_gaps", None): + missing_total: Dict[str, int] = defaultdict(int) + derived_total: Dict[str, int] = defaultdict(int) + for gap in result.data_gaps: + for field, count in (gap.get("missing_fields") or {}).items(): + missing_total[field] += int(count) + for field, count in (gap.get("derived_fields") or {}).items(): + derived_total[field] += int(count) + if missing_total: + summary_payload["missing_field_counts"] = dict(missing_total) + if derived_total: + summary_payload["derived_field_counts"] = dict(derived_total) + cfg_payload = { "id": cfg.id, "name": cfg.name, diff --git a/app/ui/views/backtest.py b/app/ui/views/backtest.py index 1636b00..8b27923 100644 --- a/app/ui/views/backtest.py +++ b/app/ui/views/backtest.py @@ -17,6 +17,7 @@ import streamlit as st from app.agents.base import AgentContext from app.agents.game import Decision from app.agents.registry import default_agents +from app.agents.protocols import GameStructure from app.backtest.decision_env import DecisionEnv, ParameterSpec from app.backtest.optimizer import BanditConfig, EpsilonGreedyBandit from app.rl import TORCH_AVAILABLE, DecisionEnvAdapter, PPOConfig, train_ppo @@ -67,6 +68,16 @@ def render_backtest_review() -> None: target = col_target.number_input("目标收益(例:0.035 表示 3.5%)", value=0.035, step=0.005, format="%.3f", key="bt_target") stop = col_stop.number_input("止损收益(例:-0.015 表示 -1.5%)", value=-0.015, step=0.005, format="%.3f", key="bt_stop") hold_days = col_hold.number_input("持有期(交易日)", value=10, step=1, key="bt_hold_days") + structure_options = [item.value for item in GameStructure] + selected_structure_values = st.multiselect( + "选择博弈框架", + structure_options, + default=structure_options, + key="bt_game_structures", + ) + if not selected_structure_values: + selected_structure_values = [GameStructure.REPEATED.value] + selected_structures = [GameStructure(value) for value in selected_structure_values] LOGGER.debug( "当前回测表单输入:start=%s end=%s universe_text=%s target=%.3f stop=%.3f hold_days=%s", start_date, @@ -148,6 +159,7 @@ def render_backtest_review() -> None: "stop": stop, "hold_days": int(hold_days), }, + game_structures=selected_structures, ) result = run_backtest(backtest_cfg, decision_callback=_decision_callback) LOGGER.info( @@ -287,6 +299,12 @@ def render_backtest_review() -> None: "风险分布": json.dumps(summary.get("risk_breakdown"), ensure_ascii=False) if summary.get("risk_breakdown") else None, + "缺失字段": json.dumps(summary.get("missing_field_counts"), ensure_ascii=False) + if summary.get("missing_field_counts") + else None, + "派生字段": json.dumps(summary.get("derived_field_counts"), ensure_ascii=False) + if summary.get("derived_field_counts") + else None, } metrics_rows.append({k: v for k, v in record.items() if (k == "cfg_id" or k in selected_metrics)}) if metrics_rows: @@ -658,6 +676,7 @@ def render_backtest_review() -> None: "hold_days": int(hold_days), }, method=app_cfg.decision_method, + game_structures=selected_structures, ) env = DecisionEnv( bt_config=bt_cfg_env, @@ -934,6 +953,7 @@ def render_backtest_review() -> None: "hold_days": int(hold_days), }, method=app_cfg.decision_method, + game_structures=selected_structures, ) env = DecisionEnv( bt_config=bt_cfg_env, diff --git a/app/utils/data_access.py b/app/utils/data_access.py index eebf09d..2594363 100644 --- a/app/utils/data_access.py +++ b/app/utils/data_access.py @@ -74,6 +74,12 @@ class _RefreshCoordinator: normalized = parsed_date.strftime("%Y%m%d") tables = self._collect_tables(fields) if tables and self.broker.check_data_availability(normalized, tables): + LOGGER.debug( + "触发近端数据刷新 trade_date=%s tables=%s", + normalized, + sorted(tables), + extra=LOG_EXTRA, + ) self.broker._trigger_background_refresh(normalized) def ensure_for_series(self, end_date: str, table: str) -> None: @@ -82,6 +88,12 @@ class _RefreshCoordinator: return normalized = parsed_date.strftime("%Y%m%d") if self.broker.check_data_availability(normalized, {table}): + LOGGER.debug( + "触发序列刷新 trade_date=%s table=%s", + normalized, + table, + extra=LOG_EXTRA, + ) self.broker._trigger_background_refresh(normalized) def _collect_tables(self, fields: Iterable[str]) -> Set[str]: diff --git a/docs/TODO.md b/docs/TODO.md index 7f36eec..5430a05 100644 --- a/docs/TODO.md +++ b/docs/TODO.md @@ -55,7 +55,7 @@ ### 3.3 代码改造计划(多轮博弈适配) 1. 架构基线评估 - - ⏳ 绘制代理/部门/回测调用图,补充日志字段(缺数告警、补数来源、议程标识)并形成诊断报告。 + - ✅ 绘制代理/部门/回测调用图并补充日志字段(见 docs/architecture_call_graph.md)。 - ✅ 定义多轮博弈上下文结构(消息历史、信念状态、引用证据),输出数据类与通信协议草稿。 - ✅ 在 `app/agents/protocols.py` 基础上补充主持/执行状态管理,实现 `DialogueTrace` 与部门上下文的对接路径。 - ✅ 扩展 `Decision.rounds` 与 `RoundSummary` 采集策略,用于串联部门结论与多轮议程结果。 @@ -67,15 +67,16 @@ - ✅ 将风险建议与执行回合对齐,执行阶段识别 `risk_adjusted` 并记录原始动作。 2. 数据与因子重构 - ✅ 拆分 `DataBroker` 查询层(`BrokerQueryEngine`),补数逻辑独立于查询管道。 - - ⏳ 按主题拆分因子模块,存储缺口/异常标签,改写 `load_market_data()` 为“缺失即说明”。 - - ⏳ 维护博弈结构 → 数据 scope 映射,支持角色按结构加载差异化字段。 + - ⏳ 按主题拆分因子模块,存储缺口/异常标签。 + - ✅ `load_market_data()` 标注缺失字段并写入原始日志(`missing_fields`、`derived_fields`)。 + - ✅ 维护博弈结构 → 数据 scope 映射(`app/agents/scopes.py`,`BacktestEngine.required_fields`)。 - ✅ 基于 `_RefreshCoordinator` 落地刷新队列与监控事件,拆分查询与补数路径。 - ✅ 暴露 `DataBroker.register_refresh_callback()` 钩子,结合监控系统记录补数进度与失败重试。 - ⏳ 统一补数回调日志格式(`LOG_EXTRA.stage=data_broker`),为后续指标预留数据源。 3. 多轮博弈框架 - ✅ 在 `app/agents/game.py` 抽象 `GameProtocol` 接口,扩展 `Decision` 记录多轮对话。 - ✅ 实现主持调度器驱动议程(信息→陈述→反驳→共识→执行),挂载风险复核机制。 - - ⏳ 引入信念修正规则与论证框架,支持证据引用和冲突裁决。 + - ✅ 引入基础信念修正规则(`app/agents/beliefs.py`),汇总信念并记录冲突。 4. 执行与回测集成 - ✅ 将回测循环改造成“每日多轮→执行摘要”,完成风控校验与冲突重议流程。 - ⏳ 擦合订单映射层,明确多轮结果对应目标仓位、执行节奏、异常回滚策略。 diff --git a/docs/architecture_call_graph.dot b/docs/architecture_call_graph.dot new file mode 100644 index 0000000..27ba84a --- /dev/null +++ b/docs/architecture_call_graph.dot @@ -0,0 +1,24 @@ +digraph LLMQuantCallGraph { + rankdir=LR; + node [shape=box, style=rounded]; + + BacktestEngine -> LoadMarketData; + LoadMarketData -> DataBrokerFetch; + DataBrokerFetch -> BrokerQuery [label="db_session"]; + LoadMarketData -> FeatureAssembly; + + BacktestEngine -> Decide; + Decide -> ProtocolHost; + ProtocolHost -> DepartmentRound; + ProtocolHost -> RiskReview; + ProtocolHost -> ExecutionSummary; + Decide -> BeliefRevision; + + ExecutionSummary -> ApplyPortfolio; + ApplyPortfolio -> RiskEvents; + ApplyPortfolio -> Alerts; + ApplyPortfolio -> PersistResults; + + PersistResults -> Reports; + PersistResults -> UI; +} diff --git a/docs/architecture_call_graph.md b/docs/architecture_call_graph.md new file mode 100644 index 0000000..8c7851e --- /dev/null +++ b/docs/architecture_call_graph.md @@ -0,0 +1,43 @@ +# 多轮博弈决策调用示意 + +本节概述 `llm_quant` 中多轮博弈执行链路,便于定位关键日志与扩展点。 + +``` +BacktestEngine.simulate_day + └─ load_market_data + ├─ DataBroker.fetch_latest + │ ├─ BrokerQueryEngine.fetch_latest + │ └─ 缺失字段 → derived_fields/missing_fields (写入 raw/missing_fields) + ├─ DataBroker.fetch_series (同上) + └─ assemble feature_map (features / market_snapshot / raw) + └─ for each symbol → decide (agents.game) + ├─ compute_utilities / feasible_actions + ├─ DepartmentManager.evaluate (LLM 部门,可带回 risk/rationale) + ├─ ProtocolHost (game protocols) + │ ├─ start_round("department_consensus") + │ ├─ risk_review (当 conflict / risk assessment 触发) + │ └─ execution_summary (记录 execution_status) + ├─ revise_beliefs (beliefs.py) → consensus/conflict + └─ Decision + ├─ rounds (RoundSummary 日志) + ├─ risk_assessment (status/reason/recommended_action) + ├─ belief_updates / belief_revision (供监控/重播) + └─ department_votes / utilities + └─ _apply_portfolio_updates + ├─ 使用 Decision.risk_assessment 调节执行 + ├─ 执行失败/阻断 → risk_events & alerts.backtest_risk + └─ executed_trades / nav_series / risk_events → bt_* 表 +``` + +## 关键日志 +- `LOG_EXTRA = {"stage": "backtest"}`:缺失字段、派生字段、执行阻断。 +- `LOG_EXTRA = {"stage": "data_broker"}`:自动补数触发、查询失败回退。 + +## 拉通数据 +- `app/agents/scopes.py` 维护结构 → 字段映射。 +- `Decision.raw` 中 `missing_fields/derived_fields` 可用于缺口诊断。 + +## 后续建议 +1. 将 `belief_revision` 与 `risk_events` 接入监控告警。 +2. 结合 `missing_fields` 统计生成数据质量简报。 +3. 通过自动化脚本渲染上述流程图/时序图。 diff --git a/scripts/render_architecture_diagram.py b/scripts/render_architecture_diagram.py new file mode 100644 index 0000000..1b9e83c --- /dev/null +++ b/scripts/render_architecture_diagram.py @@ -0,0 +1,44 @@ +"""Render architecture call graph for multi-agent decision flow.""" +from __future__ import annotations + +from pathlib import Path + +DOT_TEMPLATE = """ +digraph LLMQuantCallGraph { + rankdir=LR; + node [shape=box, style=rounded]; + + BacktestEngine -> LoadMarketData; + LoadMarketData -> DataBrokerFetch; + DataBrokerFetch -> BrokerQuery [label="db_session"]; + LoadMarketData -> FeatureAssembly; + + BacktestEngine -> Decide; + Decide -> ProtocolHost; + ProtocolHost -> DepartmentRound; + ProtocolHost -> RiskReview; + ProtocolHost -> ExecutionSummary; + Decide -> BeliefRevision; + + ExecutionSummary -> ApplyPortfolio; + ApplyPortfolio -> RiskEvents; + ApplyPortfolio -> Alerts; + ApplyPortfolio -> PersistResults; + + PersistResults -> Reports; + PersistResults -> UI; +} +""" + +def render(output: Path) -> None: + output.write_text(DOT_TEMPLATE.strip() + "\n", encoding="utf-8") + + +def main() -> None: + out_file = Path("docs/architecture_call_graph.dot") + render(out_file) + print(f"dot file written: {out_file}") + + +if __name__ == "__main__": + main()