diff --git a/README.md b/README.md index 4136666..35ac27f 100644 --- a/README.md +++ b/README.md @@ -35,11 +35,10 @@ 1. **配置声明角色**:在 `config.json`/`DepartmentSettings` 中补充 `description` 与 `data_scope`,`department_prompt()` 拼接角色指令,实现职责以 Prompt 管理而非硬编码。 2. **统一数据层**:新增 `DataBroker`(或同类工具)封装常用查询,代理与部门通过声明式 JSON 请求所需表/字段/窗口,由服务端执行并返回特征。 -3. **双阶段 LLM 工作流**:第一阶段让 LLM 输出结构化 `data_requests`,服务端取数后将摘要回填,第二阶段再生成最终行动与解释,形成闭环。 -4. **函数式工具调用**:DeepSeek 等 OpenAI 兼容模型已通过 function calling 接口接入 `fetch_data` 工具,LLM 按 schema 返回字段/窗口请求,系统使用 `DataBroker` 校验并补数后回传 tool result,再继续对话生成最终意见,避免字段错误与手写 JSON 校验。 -5. **审计与前端联动**:把角色提示、数据请求与执行摘要写入 `agent_utils` 附加字段,使 Streamlit 能完整呈现“角色 → 请求 → 决策”的链条。 +3. **函数式工具调用**:通过 DeepSeek/OpenAI 的 function calling 暴露 `fetch_data` 工具,LLM 根据 schema 声明字段与窗口请求,系统用 `DataBroker` 校验、补数并回传结果,形成“请求→取数→复议”闭环。 +4. **审计与前端联动**:把角色提示、数据请求与执行摘要写入 `agent_utils` 附加字段,使 Streamlit 能完整呈现“角色 → 请求 → 决策”的链条。 -目前部门 LLM 通过 function calling 暴露的 `fetch_data` 工具触发追加查询:模型按 schema 声明需要的 `table.column` 字段与窗口,系统使用 `DataBroker` 验证后补齐数据并作为工具响应回传,再带着查询结果进入下一轮提示,从而形成“请求 → 取数 → 复议”的闭环。 +目前部门 LLM 通过 function calling 暴露的 `fetch_data` 工具触发追加查询:模型按 schema 声明需要的 `table.column` 字段与窗口,系统使用 `DataBroker` 验证后补齐数据并作为工具响应回传,再带着查询结果进入下一轮提示,从而形成闭环。 上述调整可在单个部门先行做 PoC,验证闭环能力后再推广至全部角色。 diff --git a/app/agents/departments.py b/app/agents/departments.py index 2c13276..d0789d0 100644 --- a/app/agents/departments.py +++ b/app/agents/departments.py @@ -222,37 +222,126 @@ class DepartmentAgent: ts_code = context.ts_code trade_date = self._normalize_trade_date(context.trade_date) - latest_fields: List[str] = [] + latest_groups: Dict[str, List[str]] = {} series_requests: List[Tuple[DataRequest, Tuple[str, str]]] = [] + values_map, db_alias_map, series_map = _build_context_lookup(context) for req in requests: field = req.field.strip() if not field: continue - resolved = self._broker.resolve_field(field) - if not resolved: - lines.append(f"- {field}: 字段不存在或不可用") - continue - if req.window <= 1: - if field not in latest_fields: - latest_fields.append(field) - delivered.add((field, 1)) - continue - series_requests.append((req, resolved)) - delivered.add((field, req.window)) + window = req.window + resolved: Optional[Tuple[str, str]] = None + if "." in field: + resolved = self._broker.resolve_field(field) + elif field in db_alias_map: + resolved = db_alias_map[field] - if latest_fields: - latest_values = self._broker.fetch_latest(ts_code, trade_date, latest_fields) - for field in latest_fields: - value = latest_values.get(field) - if value is None: - lines.append(f"- {field}: (数据缺失)") + if resolved: + table, column = resolved + canonical = f"{table}.{column}" + if window <= 1: + latest_groups.setdefault(canonical, []).append(field) + delivered.add((field, 1)) + delivered.add((canonical, 1)) else: - lines.append(f"- {field}: {value}") - payload.append({"field": field, "window": 1, "values": value}) + series_requests.append((req, resolved)) + delivered.add((field, window)) + delivered.add((canonical, window)) + continue - for req, parsed in series_requests: - table, column = parsed + if field in values_map: + value = values_map[field] + if window <= 1: + payload.append( + { + "field": field, + "window": 1, + "source": "context", + "values": [ + { + "trade_date": context.trade_date, + "value": value, + } + ], + } + ) + lines.append(f"- {field}: {value} (来自上下文)") + else: + series = series_map.get(field) + if series: + trimmed = series[: window] + payload.append( + { + "field": field, + "window": window, + "source": "context_series", + "values": [ + {"trade_date": dt, "value": val} + for dt, val in trimmed + ], + } + ) + preview = ", ".join( + f"{dt}:{val:.4f}" for dt, val in trimmed[: min(len(trimmed), 5)] + ) + lines.append( + f"- {field} (window={window} 来自上下文序列): {preview}" + ) + else: + payload.append( + { + "field": field, + "window": window, + "source": "context", + "values": [ + { + "trade_date": context.trade_date, + "value": value, + } + ], + "warning": "仅提供当前值,缺少历史序列", + } + ) + lines.append( + f"- {field} (window={window}): 仅有当前值 {value}, 无历史序列" + ) + delivered.add((field, window)) + if field in db_alias_map: + resolved = db_alias_map[field] + canonical = f"{resolved[0]}.{resolved[1]}" + delivered.add((canonical, window)) + continue + + lines.append(f"- {field}: 字段不存在或不可用") + + if latest_groups: + latest_values = self._broker.fetch_latest( + ts_code, trade_date, list(latest_groups.keys()) + ) + for canonical, aliases in latest_groups.items(): + value = latest_values.get(canonical) + if value is None: + lines.append(f"- {canonical}: (数据缺失)") + else: + lines.append(f"- {canonical}: {value}") + for alias in aliases: + payload.append( + { + "field": alias, + "window": 1, + "source": "database", + "values": [ + { + "trade_date": trade_date, + "value": value, + } + ], + } + ) + + for req, resolved in series_requests: + table, column = resolved series = self._broker.fetch_series( table, column, @@ -276,6 +365,7 @@ class DepartmentAgent: { "field": req.field, "window": req.window, + "source": "database", "values": [ {"trade_date": dt, "value": val} for dt, val in series @@ -546,6 +636,40 @@ def _extract_message_content(message: Mapping[str, Any]) -> str: return json.dumps(message, ensure_ascii=False) +def _build_context_lookup( + context: DepartmentContext, +) -> Tuple[Dict[str, Any], Dict[str, Tuple[str, str]], Dict[str, List[Tuple[str, float]]]]: + values: Dict[str, Any] = {} + db_alias: Dict[str, Tuple[str, str]] = {} + series_map: Dict[str, List[Tuple[str, float]]] = {} + + for source in (context.features or {}, context.market_snapshot or {}): + for key, value in source.items(): + values[str(key)] = value + + scope_values = context.raw.get("scope_values") or {} + for key, value in scope_values.items(): + key_str = str(key) + values[key_str] = value + if "." in key_str: + table, column = key_str.split(".", 1) + db_alias.setdefault(column, (table, column)) + db_alias.setdefault(key_str, (table, column)) + values.setdefault(column, value) + + close_series = context.raw.get("close_series") or [] + if isinstance(close_series, list) and close_series: + series_map["close"] = close_series + series_map["daily.close"] = close_series + + turnover_series = context.raw.get("turnover_series") or [] + if isinstance(turnover_series, list) and turnover_series: + series_map["turnover_rate"] = turnover_series + series_map["daily_basic.turnover_rate"] = turnover_series + + return values, db_alias, series_map + + class DepartmentManager: """Orchestrates all departments defined in configuration.""" diff --git a/app/llm/prompts.py b/app/llm/prompts.py index 8dcdd4a..86340af 100644 --- a/app/llm/prompts.py +++ b/app/llm/prompts.py @@ -62,13 +62,8 @@ def department_prompt( "risks": ["风险点", "..."] }} -如需额外数据,请在同一 JSON 中添加可选字段 `"data_requests"`,其取值为数组,例如: -"data_requests": [ - {{"field": "daily.close", "window": 5}}, - {{"field": "daily_basic.pe"}} -] -其中 `field` 必须属于【可用数据范围】或明确说明新增需求;`window` 表示希望返回的最近数据点数量,省略时默认为 1。 -如果不需要更多数据,请不要返回 `data_requests`。 +如需额外数据,请调用可用工具 `fetch_data` 并在参数中提供 `requests` 数组(元素包含 `field` 以及可选的 `window`);`field` 必须符合【可用数据范围】,`window` 默认为 1。 +工具返回的数据会在后续消息中提供,请在获取所有必要信息后再给出最终 JSON 答复。 请严格返回单个 JSON 对象,不要添加额外文本。 """ diff --git a/app/ui/streamlit_app.py b/app/ui/streamlit_app.py index 4649ce5..bb6c340 100644 --- a/app/ui/streamlit_app.py +++ b/app/ui/streamlit_app.py @@ -213,6 +213,7 @@ def render_today_plan() -> None: global_info = None dept_records: List[Dict[str, object]] = [] + dept_details: Dict[str, Dict[str, object]] = {} agent_records: List[Dict[str, object]] = [] for item in rows: @@ -231,6 +232,11 @@ def render_today_plan() -> None: "target_weight": float(utils.get("_target_weight", 0.0)), "department_votes": utils.get("_department_votes", {}), "requires_review": bool(utils.get("_requires_review", False)), + "scope_values": utils.get("_scope_values", {}), + "close_series": utils.get("_close_series", []), + "turnover_series": utils.get("_turnover_series", []), + "department_supplements": utils.get("_department_supplements", {}), + "department_dialogue": utils.get("_department_dialogue", {}), } continue @@ -238,6 +244,8 @@ def render_today_plan() -> None: code = agent_name.split("dept_", 1)[-1] signals = utils.get("_signals", []) risks = utils.get("_risks", []) + supplements = utils.get("_supplements", []) + dialogue = utils.get("_dialogue", []) dept_records.append( { "部门": code, @@ -247,8 +255,16 @@ def render_today_plan() -> None: "摘要": utils.get("_summary", ""), "核心信号": ";".join(signals) if isinstance(signals, list) else signals, "风险提示": ";".join(risks) if isinstance(risks, list) else risks, + "补充次数": len(supplements) if isinstance(supplements, list) else 0, } ) + dept_details[code] = { + "supplements": supplements if isinstance(supplements, list) else [], + "dialogue": dialogue if isinstance(dialogue, list) else [], + "summary": utils.get("_summary", ""), + "signals": signals, + "risks": risks, + } else: score_map = { key: float(val) @@ -281,6 +297,26 @@ def render_today_plan() -> None: st.json(global_info["department_votes"]) if global_info["requires_review"]: st.warning("部门分歧较大,已标记为需人工复核。") + with st.expander("基础上下文数据", expanded=False): + if global_info.get("scope_values"): + st.write("最新字段:") + st.json(global_info["scope_values"]) + if global_info.get("close_series"): + st.write("收盘价时间序列 (最近窗口):") + st.json(global_info["close_series"]) + if global_info.get("turnover_series"): + st.write("换手率时间序列 (最近窗口):") + st.json(global_info["turnover_series"]) + dept_sup = global_info.get("department_supplements") or {} + dept_dialogue = global_info.get("department_dialogue") or {} + if dept_sup or dept_dialogue: + with st.expander("部门补数与对话记录", expanded=False): + if dept_sup: + st.write("补充数据:") + st.json(dept_sup) + if dept_dialogue: + st.write("对话片段:") + st.json(dept_dialogue) else: st.info("暂未写入全局策略摘要。") @@ -288,6 +324,21 @@ def render_today_plan() -> None: if dept_records: dept_df = pd.DataFrame(dept_records) st.dataframe(dept_df, use_container_width=True, hide_index=True) + for code, details in dept_details.items(): + with st.expander(f"{code} 补充详情", expanded=False): + supplements = details.get("supplements", []) + dialogue = details.get("dialogue", []) + if supplements: + st.write("补充数据:") + st.json(supplements) + else: + st.caption("无补充数据请求。") + if dialogue: + st.write("对话记录:") + for idx, line in enumerate(dialogue, start=1): + st.markdown(f"**回合 {idx}:** {line}") + else: + st.caption("无额外对话。") else: st.info("暂无部门记录。") diff --git a/app/utils/data_access.py b/app/utils/data_access.py index a12930f..baf9fb5 100644 --- a/app/utils/data_access.py +++ b/app/utils/data_access.py @@ -3,7 +3,7 @@ from __future__ import annotations import re from dataclasses import dataclass -from typing import Dict, Iterable, List, Optional, Sequence, Tuple +from typing import ClassVar, Dict, Iterable, List, Optional, Sequence, Tuple from .db import db_session from .logging import get_logger @@ -42,7 +42,7 @@ def parse_field_path(path: str) -> Tuple[str, str] | None: class DataBroker: """Lightweight data access helper for agent/LLM consumption.""" - FIELD_ALIASES: Dict[str, Dict[str, str]] = { + FIELD_ALIASES: ClassVar[Dict[str, Dict[str, str]]] = { "daily": { "volume": "vol", "vol": "vol", @@ -57,13 +57,14 @@ class DataBroker: "pb": "pb", "ps": "ps", "ps_ttm": "ps_ttm", + "dividend_yield": "dv_ratio", }, "stk_limit": { "up": "up_limit", "down": "down_limit", }, } - MAX_WINDOW: int = 120 + MAX_WINDOW: ClassVar[int] = 120 def fetch_latest( self, diff --git a/app/utils/logging.py b/app/utils/logging.py index 590d5e3..4d5985e 100644 --- a/app/utils/logging.py +++ b/app/utils/logging.py @@ -120,7 +120,12 @@ def get_logger(name: Optional[str] = None) -> logging.Logger: """返回指定名称的 logger,确保全局配置已就绪。""" setup_logging() - return logging.getLogger(name) + logger = logging.getLogger(name) + # Quiet noisy third-party loggers when default level is DEBUG + logging.getLogger("urllib3").setLevel(logging.WARNING) + logging.getLogger("urllib3.connectionpool").setLevel(logging.WARNING) + logging.getLogger("requests.packages.urllib3").setLevel(logging.WARNING) + return logger # 默认在模块导入时完成配置,适配现有调用方式。