From 0a2742b869ba9a967af464856d23ce3ca7791ff2 Mon Sep 17 00:00:00 2001 From: sam Date: Sun, 28 Sep 2025 18:57:05 +0800 Subject: [PATCH] update --- README.md | 6 +- app/agents/departments.py | 230 +++++++++++++++++++++++++++++++++++--- app/backtest/engine.py | 28 ++++- app/llm/prompts.py | 18 ++- app/utils/data_access.py | 6 + 5 files changed, 268 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index 91b9264..00e4454 100644 --- a/README.md +++ b/README.md @@ -6,9 +6,9 @@ ## 架构总览 -- **数据与存储层**:`app/ingest` 封装 TuShare/RSS 拉数与限频处理,`app/data/schema.py` 初始化 SQLite 表结构,所有模块通过 `app/utils/db.py` 的 `db_session` 访问 `app/data/llm_quant.db`。 +- **数据与存储层**:`app/ingest` 封装 TuShare/RSS 拉数与限频处理,`app/data/schema.py` 初始化 SQLite 表结构,所有模块通过 `app/utils/db.py` 的 `db_session` 访问 `app/data/llm_quant.db`,数据抽象层由 `app/utils/data_access.py` 的 `DataBroker` 统一提供字段查询与时间序列切片。 - **工具与配置层**:`app/utils` 聚合配置、日志、交易日历及 provider 管理,`app/utils/config.py` 定义 LLM/部门/代理权重等全局设置。 -- **特征与策略层**:`app/features` 负责信号构建(当前为占位实现),`app/agents` 实现六类规则型代理与部门级 LLM 协同,`app/backtest/engine.py` 运行多智能体博弈并将结果写入 `agent_utils`。 +- **特征与策略层**:`app/features` 负责信号构建(当前为占位实现),`app/agents` 实现六类规则型代理与部门级 LLM 协同,`app/backtest/engine.py` 通过 `DataBroker` 装配特征/上下文后运行多智能体博弈并将结果写入 `agent_utils`。 - **LLM 与协作层**:`app/llm` 提供统一的模型调用与 Prompt 构建,支持 single/majority/leader 策略,部门输出再与规则代理共同决策。 - **可视化层**:`app/ui/streamlit_app.py` 提供今日计划、回测、设置、自检四大页签,实时读取 `agent_utils`、`run_log` 追踪决策链路。 @@ -38,6 +38,8 @@ 3. **双阶段 LLM 工作流**:第一阶段让 LLM 输出结构化 `data_requests`,服务端取数后将摘要回填,第二阶段再生成最终行动与解释,形成闭环。 4. **审计与前端联动**:把角色提示、数据请求与执行摘要写入 `agent_utils` 附加字段,使 Streamlit 能完整呈现“角色 → 请求 → 决策”的链条。 +目前部门 LLM 已支持通过在 JSON 中返回 `data_requests` 触发追加查询:系统会使用 `DataBroker` 验证字段后补充最近数据窗口,再带着查询结果进入下一轮提示,从而形成“请求→取数→复议”的闭环。 + 上述调整可在单个部门先行做 PoC,验证闭环能力后再推广至全部角色。 ## 核心技术原理 diff --git a/app/agents/departments.py b/app/agents/departments.py index 1b65016..1ad11b9 100644 --- a/app/agents/departments.py +++ b/app/agents/departments.py @@ -3,18 +3,25 @@ from __future__ import annotations import json from dataclasses import dataclass, field -from typing import Any, Callable, Dict, List, Mapping, Optional +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple from app.agents.base import AgentAction from app.llm.client import run_llm_with_config from app.llm.prompts import department_prompt from app.utils.config import AppConfig, DepartmentSettings, LLMConfig from app.utils.logging import get_logger +from app.utils.data_access import DataBroker, parse_field_path LOGGER = get_logger(__name__) LOG_EXTRA = {"stage": "department"} +@dataclass +class DataRequest: + field: str + window: int = 1 + + @dataclass class DepartmentContext: """Structured data fed into a department for decision making.""" @@ -37,6 +44,8 @@ class DepartmentDecision: raw_response: str signals: List[str] = field(default_factory=list) risks: List[str] = field(default_factory=list) + supplements: List[Dict[str, Any]] = field(default_factory=list) + dialogue: List[str] = field(default_factory=list) def to_dict(self) -> Dict[str, Any]: return { @@ -47,6 +56,8 @@ class DepartmentDecision: "signals": self.signals, "risks": self.risks, "raw_response": self.raw_response, + "supplements": self.supplements, + "dialogue": self.dialogue, } @@ -60,6 +71,8 @@ class DepartmentAgent: ) -> None: self.settings = settings self._resolver = resolver + self._broker = DataBroker() + self._max_rounds = 3 def _get_llm_config(self) -> LLMConfig: if self._resolver: @@ -67,24 +80,79 @@ class DepartmentAgent: return self.settings.llm def analyze(self, context: DepartmentContext) -> DepartmentDecision: - prompt = department_prompt(self.settings, context) + mutable_context = _ensure_mutable_context(context) system_prompt = ( "你是一个多智能体量化投研系统中的分部决策者,需要根据提供的结构化信息给出买卖意见。" ) llm_cfg = self._get_llm_config() - try: - response = run_llm_with_config(llm_cfg, prompt, system=system_prompt) - except Exception as exc: # noqa: BLE001 - LOGGER.exception("部门 %s 调用 LLM 失败:%s", self.settings.code, exc, extra=LOG_EXTRA) - return DepartmentDecision( - department=self.settings.code, - action=AgentAction.HOLD, - confidence=0.0, - summary=f"LLM 调用失败:{exc}", - raw_response=str(exc), - ) + supplement_chunks: List[str] = [] + transcript: List[str] = [] + delivered_requests = { + (field, 1) + for field in (mutable_context.raw.get("scope_values") or {}).keys() + } + + response = "" + decision_data: Dict[str, Any] = {} + for round_idx in range(self._max_rounds): + supplement_text = "\n\n".join(chunk for chunk in supplement_chunks if chunk) + prompt = department_prompt(self.settings, mutable_context, supplements=supplement_text) + try: + response = run_llm_with_config(llm_cfg, prompt, system=system_prompt) + except Exception as exc: # noqa: BLE001 + LOGGER.exception( + "部门 %s 调用 LLM 失败:%s", + self.settings.code, + exc, + extra=LOG_EXTRA, + ) + return DepartmentDecision( + department=self.settings.code, + action=AgentAction.HOLD, + confidence=0.0, + summary=f"LLM 调用失败:{exc}", + raw_response=str(exc), + ) + + transcript.append(response) + decision_data = _parse_department_response(response) + data_requests = _parse_data_requests(decision_data) + filtered_requests = [ + req + for req in data_requests + if (req.field, req.window) not in delivered_requests + ] + + if filtered_requests and round_idx < self._max_rounds - 1: + lines, payload, delivered = self._fulfill_data_requests( + mutable_context, filtered_requests + ) + if payload: + supplement_chunks.append( + f"回合 {round_idx + 1} 追加数据:\n" + "\n".join(lines) + ) + mutable_context.raw.setdefault("supplement_data", []).extend(payload) + mutable_context.raw.setdefault("supplement_rounds", []).append( + { + "round": round_idx + 1, + "requests": [req.__dict__ for req in filtered_requests], + "data": payload, + } + ) + delivered_requests.update(delivered) + decision_data.pop("data_requests", None) + continue + LOGGER.debug( + "部门 %s 数据请求无结果:%s", + self.settings.code, + filtered_requests, + extra=LOG_EXTRA, + ) + decision_data.pop("data_requests", None) + break + + mutable_context.raw["supplement_transcript"] = list(transcript) - decision_data = _parse_department_response(response) action = _normalize_action(decision_data.get("action")) confidence = _clamp_float(decision_data.get("confidence"), default=0.5) summary = decision_data.get("summary") or decision_data.get("reason") or "" @@ -102,7 +170,9 @@ class DepartmentAgent: summary=summary or "未提供摘要", signals=[str(sig) for sig in signals if sig], risks=[str(risk) for risk in risks if risk], - raw_response=response, + raw_response="\n\n".join(transcript) if transcript else response, + supplements=list(mutable_context.raw.get("supplement_data", [])), + dialogue=list(transcript), ) LOGGER.debug( "部门 %s 决策:action=%s confidence=%.2f", @@ -113,6 +183,124 @@ class DepartmentAgent: ) return decision + @staticmethod + def _normalize_trade_date(value: str) -> str: + if not isinstance(value, str): + return str(value) + return value.replace("-", "") + + def _fulfill_data_requests( + self, + context: DepartmentContext, + requests: Sequence[DataRequest], + ) -> Tuple[List[str], List[Dict[str, Any]], set[Tuple[str, int]]]: + lines: List[str] = [] + payload: List[Dict[str, Any]] = [] + delivered: set[Tuple[str, int]] = set() + + ts_code = context.ts_code + trade_date = self._normalize_trade_date(context.trade_date) + + latest_fields: List[str] = [] + series_requests: List[Tuple[DataRequest, Tuple[str, str]]] = [] + + for req in requests: + field = req.field.strip() + if not field: + continue + if req.window <= 1: + if field not in latest_fields: + latest_fields.append(field) + delivered.add((field, 1)) + continue + parsed = parse_field_path(field) + if not parsed: + lines.append(f"- {field}: 字段不合法,已忽略") + continue + series_requests.append((req, parsed)) + delivered.add((field, req.window)) + + if latest_fields: + latest_values = self._broker.fetch_latest(ts_code, trade_date, latest_fields) + for field in latest_fields: + value = latest_values.get(field) + if value is None: + lines.append(f"- {field}: (数据缺失)") + else: + lines.append(f"- {field}: {value}") + payload.append({"field": field, "window": 1, "values": value}) + + for req, parsed in series_requests: + table, column = parsed + series = self._broker.fetch_series( + table, + column, + ts_code, + trade_date, + window=req.window, + ) + if series: + preview = ", ".join( + f"{dt}:{val:.4f}" + for dt, val in series[: min(len(series), 5)] + ) + lines.append( + f"- {req.field} (window={req.window}): {preview}" + ) + else: + lines.append( + f"- {req.field} (window={req.window}): (数据缺失)" + ) + payload.append({"field": req.field, "window": req.window, "values": series}) + + return lines, payload, delivered + + +def _ensure_mutable_context(context: DepartmentContext) -> DepartmentContext: + if not isinstance(context.features, dict): + context.features = dict(context.features or {}) + if not isinstance(context.market_snapshot, dict): + context.market_snapshot = dict(context.market_snapshot or {}) + raw = dict(context.raw or {}) + scope_values = raw.get("scope_values") + if scope_values is not None and not isinstance(scope_values, dict): + raw["scope_values"] = dict(scope_values) + context.raw = raw + return context + + +def _parse_data_requests(payload: Mapping[str, Any]) -> List[DataRequest]: + raw_requests = payload.get("data_requests") + requests: List[DataRequest] = [] + if not isinstance(raw_requests, list): + return requests + seen: set[Tuple[str, int]] = set() + for item in raw_requests: + field = "" + window = 1 + if isinstance(item, str): + field = item.strip() + elif isinstance(item, Mapping): + candidate = item.get("field") + if candidate is None: + continue + field = str(candidate).strip() + try: + window = int(item.get("window", 1)) + except (TypeError, ValueError): + window = 1 + else: + continue + if not field: + continue + window = max(1, window) + key = (field, window) + if key in seen: + continue + seen.add(key) + requests.append(DataRequest(field=field, window=window)) + return requests + class DepartmentManager: """Orchestrates all departments defined in configuration.""" @@ -127,7 +315,17 @@ class DepartmentManager: def evaluate(self, context: DepartmentContext) -> Dict[str, DepartmentDecision]: results: Dict[str, DepartmentDecision] = {} for code, agent in self.agents.items(): - results[code] = agent.analyze(context) + raw_base = dict(context.raw or {}) + if "scope_values" in raw_base: + raw_base["scope_values"] = dict(raw_base.get("scope_values") or {}) + dept_context = DepartmentContext( + ts_code=context.ts_code, + trade_date=context.trade_date, + features=dict(context.features or {}), + market_snapshot=dict(context.market_snapshot or {}), + raw=raw_base, + ) + results[code] = agent.analyze(dept_context) return results def _resolve_llm(self, settings: DepartmentSettings) -> LLMConfig: diff --git a/app/backtest/engine.py b/app/backtest/engine.py index 8f37013..912c6cf 100644 --- a/app/backtest/engine.py +++ b/app/backtest/engine.py @@ -170,6 +170,14 @@ class BacktestEngine: if down_limit and latest_close: limit_down = latest_close <= down_limit * 1.001 + is_suspended = self.data_broker.fetch_flags( + "suspend", + ts_code, + trade_date_str, + "suspend_date <= ? AND (resume_date IS NULL OR resume_date > ?)", + (trade_date_str, trade_date_str), + ) + features = { "mom_20": mom20, "mom_60": mom60, @@ -185,7 +193,7 @@ class BacktestEngine: scope_values.get("index.performance_peers", 0.0), ), "risk_penalty": min(1.0, volat20 * 5.0), - "is_suspended": False, + "is_suspended": is_suspended, "limit_up": limit_up, "limit_down": limit_down, "position_limit": False, @@ -205,6 +213,7 @@ class BacktestEngine: "scope_values": scope_values, "close_series": closes, "turnover_series": turnover_series, + "required_fields": self.required_fields, } feature_map[ts_code] = { @@ -285,6 +294,10 @@ class BacktestEngine: "_risks": dept_decision.risks, "_confidence": dept_decision.confidence, } + if dept_decision.supplements: + metadata["_supplements"] = dept_decision.supplements + if dept_decision.dialogue: + metadata["_dialogue"] = dept_decision.dialogue payload_json = {**action_scores, **metadata} rows.append( ( @@ -303,6 +316,19 @@ class BacktestEngine: "_target_weight": decision.target_weight, "_department_votes": decision.department_votes, "_requires_review": decision.requires_review, + "_scope_values": context.raw.get("scope_values", {}), + "_close_series": context.raw.get("close_series", []), + "_turnover_series": context.raw.get("turnover_series", []), + "_department_supplements": { + code: dept.supplements + for code, dept in decision.department_decisions.items() + if dept.supplements + }, + "_department_dialogue": { + code: dept.dialogue + for code, dept in decision.department_decisions.items() + if dept.dialogue + }, } rows.append( ( diff --git a/app/llm/prompts.py b/app/llm/prompts.py index 48a08b5..8dcdd4a 100644 --- a/app/llm/prompts.py +++ b/app/llm/prompts.py @@ -15,7 +15,11 @@ def plan_prompt(data: Dict) -> str: return "你是一个投资助理,请根据提供的数据给出三条要点和两条风险提示。" -def department_prompt(settings: "DepartmentSettings", context: "DepartmentContext") -> str: +def department_prompt( + settings: "DepartmentSettings", + context: "DepartmentContext", + supplements: str = "", +) -> str: """Compose a structured prompt for department-level LLM ensemble.""" feature_lines = "\n".join( @@ -27,6 +31,7 @@ def department_prompt(settings: "DepartmentSettings", context: "DepartmentContex scope_lines = "\n".join(f"- {item}" for item in settings.data_scope) role_description = settings.description.strip() role_instruction = settings.prompt.strip() + supplement_block = supplements.strip() instructions = f""" 部门名称:{settings.title} @@ -45,6 +50,9 @@ def department_prompt(settings: "DepartmentSettings", context: "DepartmentContex 【市场背景】 {market_lines or '- (无)'} +【追加数据】 +{supplement_block or '- 当前无追加数据'} + 请基于以上数据给出该部门对当前股票的操作建议。输出必须是 JSON,字段如下: {{ "action": "BUY|BUY_S|BUY_M|BUY_L|SELL|HOLD", @@ -54,6 +62,14 @@ def department_prompt(settings: "DepartmentSettings", context: "DepartmentContex "risks": ["风险点", "..."] }} +如需额外数据,请在同一 JSON 中添加可选字段 `"data_requests"`,其取值为数组,例如: +"data_requests": [ + {{"field": "daily.close", "window": 5}}, + {{"field": "daily_basic.pe"}} +] +其中 `field` 必须属于【可用数据范围】或明确说明新增需求;`window` 表示希望返回的最近数据点数量,省略时默认为 1。 +如果不需要更多数据,请不要返回 `data_requests`。 + 请严格返回单个 JSON 对象,不要添加额外文本。 """ return instructions.strip() diff --git a/app/utils/data_access.py b/app/utils/data_access.py index d5b9f73..c63d345 100644 --- a/app/utils/data_access.py +++ b/app/utils/data_access.py @@ -32,6 +32,12 @@ def _safe_split(path: str) -> Tuple[str, str] | None: return table, column +def parse_field_path(path: str) -> Tuple[str, str] | None: + """Validate and split a `table.column` field expression.""" + + return _safe_split(path) + + @dataclass class DataBroker: """Lightweight data access helper for agent/LLM consumption."""