This commit is contained in:
sam 2025-09-28 21:43:08 +08:00
parent ca7b249c2c
commit 199cb484b9
6 changed files with 212 additions and 37 deletions

View File

@ -35,11 +35,10 @@
1. **配置声明角色**:在 `config.json`/`DepartmentSettings` 中补充 `description`、`data_scope`,`department_prompt()` 拼接角色指令,实现职责以 Prompt 管理而非硬编码。
2. **统一数据层**:新增 `DataBroker`(或同类工具)封装常用查询,代理与部门通过声明式 JSON 请求所需表/字段/窗口,由服务端执行并返回特征。
3. **双阶段 LLM 工作流**:第一阶段让 LLM 输出结构化 `data_requests`,服务端取数后将摘要回填,第二阶段再生成最终行动与解释,形成闭环。
4. **函数式工具调用**:DeepSeek 等 OpenAI 兼容模型已通过 function calling 接口接入 `fetch_data` 工具,LLM 按 schema 返回字段/窗口请求,系统使用 `DataBroker` 校验并补数后回传 tool result,再继续对话生成最终意见避免字段错误与手写 JSON 校验。
5. **审计与前端联动**:把角色提示、数据请求与执行摘要写入 `agent_utils` 附加字段,使 Streamlit 能完整呈现“角色 → 请求 → 决策”的链条。
3. **函数式工具调用**:通过 DeepSeek/OpenAI 的 function calling 暴露 `fetch_data` 工具,LLM 根据 schema 声明字段与窗口请求,系统用 `DataBroker` 校验、补数并回传结果,形成“请求→取数→复议”闭环。
4. **审计与前端联动**:把角色提示、数据请求与执行摘要写入 `agent_utils` 附加字段,使 Streamlit 能完整呈现“角色 → 请求 → 决策”的链条。
目前部门 LLM 通过 function calling 暴露的 `fetch_data` 工具触发追加查询:模型按 schema 声明需要的 `table.column` 字段与窗口,系统使用 `DataBroker` 验证后补齐数据并作为工具响应回传,再带着查询结果进入下一轮提示,从而形成“请求 → 取数 → 复议”的闭环。
目前部门 LLM 通过 function calling 暴露的 `fetch_data` 工具触发追加查询:模型按 schema 声明需要的 `table.column` 字段与窗口,系统使用 `DataBroker` 验证后补齐数据并作为工具响应回传,再带着查询结果进入下一轮提示,从而形成闭环。
上述调整可在单个部门先行做 PoC,验证闭环能力后再推广至全部角色。

View File

@ -222,37 +222,126 @@ class DepartmentAgent:
ts_code = context.ts_code
trade_date = self._normalize_trade_date(context.trade_date)
latest_fields: List[str] = []
latest_groups: Dict[str, List[str]] = {}
series_requests: List[Tuple[DataRequest, Tuple[str, str]]] = []
values_map, db_alias_map, series_map = _build_context_lookup(context)
for req in requests:
field = req.field.strip()
if not field:
continue
window = req.window
resolved: Optional[Tuple[str, str]] = None
if "." in field:
resolved = self._broker.resolve_field(field)
if not resolved:
lines.append(f"- {field}: 字段不存在或不可用")
continue
if req.window <= 1:
if field not in latest_fields:
latest_fields.append(field)
elif field in db_alias_map:
resolved = db_alias_map[field]
if resolved:
table, column = resolved
canonical = f"{table}.{column}"
if window <= 1:
latest_groups.setdefault(canonical, []).append(field)
delivered.add((field, 1))
continue
series_requests.append((req, resolved))
delivered.add((field, req.window))
if latest_fields:
latest_values = self._broker.fetch_latest(ts_code, trade_date, latest_fields)
for field in latest_fields:
value = latest_values.get(field)
if value is None:
lines.append(f"- {field}: (数据缺失)")
delivered.add((canonical, 1))
else:
lines.append(f"- {field}: {value}")
payload.append({"field": field, "window": 1, "values": value})
series_requests.append((req, resolved))
delivered.add((field, window))
delivered.add((canonical, window))
continue
for req, parsed in series_requests:
table, column = parsed
if field in values_map:
value = values_map[field]
if window <= 1:
payload.append(
{
"field": field,
"window": 1,
"source": "context",
"values": [
{
"trade_date": context.trade_date,
"value": value,
}
],
}
)
lines.append(f"- {field}: {value} (来自上下文)")
else:
series = series_map.get(field)
if series:
trimmed = series[: window]
payload.append(
{
"field": field,
"window": window,
"source": "context_series",
"values": [
{"trade_date": dt, "value": val}
for dt, val in trimmed
],
}
)
preview = ", ".join(
f"{dt}:{val:.4f}" for dt, val in trimmed[: min(len(trimmed), 5)]
)
lines.append(
f"- {field} (window={window} 来自上下文序列): {preview}"
)
else:
payload.append(
{
"field": field,
"window": window,
"source": "context",
"values": [
{
"trade_date": context.trade_date,
"value": value,
}
],
"warning": "仅提供当前值,缺少历史序列",
}
)
lines.append(
f"- {field} (window={window}): 仅有当前值 {value}, 无历史序列"
)
delivered.add((field, window))
if field in db_alias_map:
resolved = db_alias_map[field]
canonical = f"{resolved[0]}.{resolved[1]}"
delivered.add((canonical, window))
continue
lines.append(f"- {field}: 字段不存在或不可用")
if latest_groups:
latest_values = self._broker.fetch_latest(
ts_code, trade_date, list(latest_groups.keys())
)
for canonical, aliases in latest_groups.items():
value = latest_values.get(canonical)
if value is None:
lines.append(f"- {canonical}: (数据缺失)")
else:
lines.append(f"- {canonical}: {value}")
for alias in aliases:
payload.append(
{
"field": alias,
"window": 1,
"source": "database",
"values": [
{
"trade_date": trade_date,
"value": value,
}
],
}
)
for req, resolved in series_requests:
table, column = resolved
series = self._broker.fetch_series(
table,
column,
@ -276,6 +365,7 @@ class DepartmentAgent:
{
"field": req.field,
"window": req.window,
"source": "database",
"values": [
{"trade_date": dt, "value": val}
for dt, val in series
@ -546,6 +636,40 @@ def _extract_message_content(message: Mapping[str, Any]) -> str:
return json.dumps(message, ensure_ascii=False)
def _build_context_lookup(
context: DepartmentContext,
) -> Tuple[Dict[str, Any], Dict[str, Tuple[str, str]], Dict[str, List[Tuple[str, float]]]]:
values: Dict[str, Any] = {}
db_alias: Dict[str, Tuple[str, str]] = {}
series_map: Dict[str, List[Tuple[str, float]]] = {}
for source in (context.features or {}, context.market_snapshot or {}):
for key, value in source.items():
values[str(key)] = value
scope_values = context.raw.get("scope_values") or {}
for key, value in scope_values.items():
key_str = str(key)
values[key_str] = value
if "." in key_str:
table, column = key_str.split(".", 1)
db_alias.setdefault(column, (table, column))
db_alias.setdefault(key_str, (table, column))
values.setdefault(column, value)
close_series = context.raw.get("close_series") or []
if isinstance(close_series, list) and close_series:
series_map["close"] = close_series
series_map["daily.close"] = close_series
turnover_series = context.raw.get("turnover_series") or []
if isinstance(turnover_series, list) and turnover_series:
series_map["turnover_rate"] = turnover_series
series_map["daily_basic.turnover_rate"] = turnover_series
return values, db_alias, series_map
class DepartmentManager:
"""Orchestrates all departments defined in configuration."""

View File

@ -62,13 +62,8 @@ def department_prompt(
"risks": ["风险点", "..."]
}}
如需额外数据请在同一 JSON 中添加可选字段 `"data_requests"`其取值为数组例如
"data_requests": [
{{"field": "daily.close", "window": 5}},
{{"field": "daily_basic.pe"}}
]
其中 `field` 必须属于可用数据范围或明确说明新增需求`window` 表示希望返回的最近数据点数量省略时默认为 1
如果不需要更多数据请不要返回 `data_requests`
如需额外数据请调用可用工具 `fetch_data` 并在参数中提供 `requests` 数组元素包含 `field` 以及可选的 `window``field` 必须符合可用数据范围`window` 默认为 1
工具返回的数据会在后续消息中提供请在获取所有必要信息后再给出最终 JSON 答复
请严格返回单个 JSON 对象不要添加额外文本
"""

View File

@ -213,6 +213,7 @@ def render_today_plan() -> None:
global_info = None
dept_records: List[Dict[str, object]] = []
dept_details: Dict[str, Dict[str, object]] = {}
agent_records: List[Dict[str, object]] = []
for item in rows:
@ -231,6 +232,11 @@ def render_today_plan() -> None:
"target_weight": float(utils.get("_target_weight", 0.0)),
"department_votes": utils.get("_department_votes", {}),
"requires_review": bool(utils.get("_requires_review", False)),
"scope_values": utils.get("_scope_values", {}),
"close_series": utils.get("_close_series", []),
"turnover_series": utils.get("_turnover_series", []),
"department_supplements": utils.get("_department_supplements", {}),
"department_dialogue": utils.get("_department_dialogue", {}),
}
continue
@ -238,6 +244,8 @@ def render_today_plan() -> None:
code = agent_name.split("dept_", 1)[-1]
signals = utils.get("_signals", [])
risks = utils.get("_risks", [])
supplements = utils.get("_supplements", [])
dialogue = utils.get("_dialogue", [])
dept_records.append(
{
"部门": code,
@ -247,8 +255,16 @@ def render_today_plan() -> None:
"摘要": utils.get("_summary", ""),
"核心信号": "".join(signals) if isinstance(signals, list) else signals,
"风险提示": "".join(risks) if isinstance(risks, list) else risks,
"补充次数": len(supplements) if isinstance(supplements, list) else 0,
}
)
dept_details[code] = {
"supplements": supplements if isinstance(supplements, list) else [],
"dialogue": dialogue if isinstance(dialogue, list) else [],
"summary": utils.get("_summary", ""),
"signals": signals,
"risks": risks,
}
else:
score_map = {
key: float(val)
@ -281,6 +297,26 @@ def render_today_plan() -> None:
st.json(global_info["department_votes"])
if global_info["requires_review"]:
st.warning("部门分歧较大,已标记为需人工复核。")
with st.expander("基础上下文数据", expanded=False):
if global_info.get("scope_values"):
st.write("最新字段:")
st.json(global_info["scope_values"])
if global_info.get("close_series"):
st.write("收盘价时间序列 (最近窗口)")
st.json(global_info["close_series"])
if global_info.get("turnover_series"):
st.write("换手率时间序列 (最近窗口)")
st.json(global_info["turnover_series"])
dept_sup = global_info.get("department_supplements") or {}
dept_dialogue = global_info.get("department_dialogue") or {}
if dept_sup or dept_dialogue:
with st.expander("部门补数与对话记录", expanded=False):
if dept_sup:
st.write("补充数据:")
st.json(dept_sup)
if dept_dialogue:
st.write("对话片段:")
st.json(dept_dialogue)
else:
st.info("暂未写入全局策略摘要。")
@ -288,6 +324,21 @@ def render_today_plan() -> None:
if dept_records:
dept_df = pd.DataFrame(dept_records)
st.dataframe(dept_df, use_container_width=True, hide_index=True)
for code, details in dept_details.items():
with st.expander(f"{code} 补充详情", expanded=False):
supplements = details.get("supplements", [])
dialogue = details.get("dialogue", [])
if supplements:
st.write("补充数据:")
st.json(supplements)
else:
st.caption("无补充数据请求。")
if dialogue:
st.write("对话记录:")
for idx, line in enumerate(dialogue, start=1):
st.markdown(f"**回合 {idx}:** {line}")
else:
st.caption("无额外对话。")
else:
st.info("暂无部门记录。")

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
from typing import ClassVar, Dict, Iterable, List, Optional, Sequence, Tuple
from .db import db_session
from .logging import get_logger
@ -42,7 +42,7 @@ def parse_field_path(path: str) -> Tuple[str, str] | None:
class DataBroker:
"""Lightweight data access helper for agent/LLM consumption."""
FIELD_ALIASES: Dict[str, Dict[str, str]] = {
FIELD_ALIASES: ClassVar[Dict[str, Dict[str, str]]] = {
"daily": {
"volume": "vol",
"vol": "vol",
@ -57,13 +57,14 @@ class DataBroker:
"pb": "pb",
"ps": "ps",
"ps_ttm": "ps_ttm",
"dividend_yield": "dv_ratio",
},
"stk_limit": {
"up": "up_limit",
"down": "down_limit",
},
}
MAX_WINDOW: int = 120
MAX_WINDOW: ClassVar[int] = 120
def fetch_latest(
self,

View File

@ -120,7 +120,12 @@ def get_logger(name: Optional[str] = None) -> logging.Logger:
"""返回指定名称的 logger确保全局配置已就绪。"""
setup_logging()
return logging.getLogger(name)
logger = logging.getLogger(name)
# Quiet noisy third-party loggers when default level is DEBUG
logging.getLogger("urllib3").setLevel(logging.WARNING)
logging.getLogger("urllib3.connectionpool").setLevel(logging.WARNING)
logging.getLogger("requests.packages.urllib3").setLevel(logging.WARNING)
return logger
# 默认在模块导入时完成配置,适配现有调用方式。