update
This commit is contained in:
parent
ca7b249c2c
commit
199cb484b9
@ -35,11 +35,10 @@
|
|||||||
|
|
||||||
1. **配置声明角色**:在 `config.json`/`DepartmentSettings` 中补充 `description` 与 `data_scope`,`department_prompt()` 拼接角色指令,实现职责以 Prompt 管理而非硬编码。
|
1. **配置声明角色**:在 `config.json`/`DepartmentSettings` 中补充 `description` 与 `data_scope`,`department_prompt()` 拼接角色指令,实现职责以 Prompt 管理而非硬编码。
|
||||||
2. **统一数据层**:新增 `DataBroker`(或同类工具)封装常用查询,代理与部门通过声明式 JSON 请求所需表/字段/窗口,由服务端执行并返回特征。
|
2. **统一数据层**:新增 `DataBroker`(或同类工具)封装常用查询,代理与部门通过声明式 JSON 请求所需表/字段/窗口,由服务端执行并返回特征。
|
||||||
3. **双阶段 LLM 工作流**:第一阶段让 LLM 输出结构化 `data_requests`,服务端取数后将摘要回填,第二阶段再生成最终行动与解释,形成闭环。
|
3. **函数式工具调用**:通过 DeepSeek/OpenAI 的 function calling 暴露 `fetch_data` 工具,LLM 根据 schema 声明字段与窗口请求,系统用 `DataBroker` 校验、补数并回传结果,形成“请求→取数→复议”闭环。
|
||||||
4. **函数式工具调用**:DeepSeek 等 OpenAI 兼容模型已通过 function calling 接口接入 `fetch_data` 工具,LLM 按 schema 返回字段/窗口请求,系统使用 `DataBroker` 校验并补数后回传 tool result,再继续对话生成最终意见,避免字段错误与手写 JSON 校验。
|
4. **审计与前端联动**:把角色提示、数据请求与执行摘要写入 `agent_utils` 附加字段,使 Streamlit 能完整呈现“角色 → 请求 → 决策”的链条。
|
||||||
5. **审计与前端联动**:把角色提示、数据请求与执行摘要写入 `agent_utils` 附加字段,使 Streamlit 能完整呈现“角色 → 请求 → 决策”的链条。
|
|
||||||
|
|
||||||
目前部门 LLM 通过 function calling 暴露的 `fetch_data` 工具触发追加查询:模型按 schema 声明需要的 `table.column` 字段与窗口,系统使用 `DataBroker` 验证后补齐数据并作为工具响应回传,再带着查询结果进入下一轮提示,从而形成“请求 → 取数 → 复议”的闭环。
|
目前部门 LLM 通过 function calling 暴露的 `fetch_data` 工具触发追加查询:模型按 schema 声明需要的 `table.column` 字段与窗口,系统使用 `DataBroker` 验证后补齐数据并作为工具响应回传,再带着查询结果进入下一轮提示,从而形成闭环。
|
||||||
|
|
||||||
上述调整可在单个部门先行做 PoC,验证闭环能力后再推广至全部角色。
|
上述调整可在单个部门先行做 PoC,验证闭环能力后再推广至全部角色。
|
||||||
|
|
||||||
|
|||||||
@ -222,37 +222,126 @@ class DepartmentAgent:
|
|||||||
ts_code = context.ts_code
|
ts_code = context.ts_code
|
||||||
trade_date = self._normalize_trade_date(context.trade_date)
|
trade_date = self._normalize_trade_date(context.trade_date)
|
||||||
|
|
||||||
latest_fields: List[str] = []
|
latest_groups: Dict[str, List[str]] = {}
|
||||||
series_requests: List[Tuple[DataRequest, Tuple[str, str]]] = []
|
series_requests: List[Tuple[DataRequest, Tuple[str, str]]] = []
|
||||||
|
values_map, db_alias_map, series_map = _build_context_lookup(context)
|
||||||
|
|
||||||
for req in requests:
|
for req in requests:
|
||||||
field = req.field.strip()
|
field = req.field.strip()
|
||||||
if not field:
|
if not field:
|
||||||
continue
|
continue
|
||||||
resolved = self._broker.resolve_field(field)
|
window = req.window
|
||||||
if not resolved:
|
resolved: Optional[Tuple[str, str]] = None
|
||||||
lines.append(f"- {field}: 字段不存在或不可用")
|
if "." in field:
|
||||||
continue
|
resolved = self._broker.resolve_field(field)
|
||||||
if req.window <= 1:
|
elif field in db_alias_map:
|
||||||
if field not in latest_fields:
|
resolved = db_alias_map[field]
|
||||||
latest_fields.append(field)
|
|
||||||
delivered.add((field, 1))
|
|
||||||
continue
|
|
||||||
series_requests.append((req, resolved))
|
|
||||||
delivered.add((field, req.window))
|
|
||||||
|
|
||||||
if latest_fields:
|
if resolved:
|
||||||
latest_values = self._broker.fetch_latest(ts_code, trade_date, latest_fields)
|
table, column = resolved
|
||||||
for field in latest_fields:
|
canonical = f"{table}.{column}"
|
||||||
value = latest_values.get(field)
|
if window <= 1:
|
||||||
if value is None:
|
latest_groups.setdefault(canonical, []).append(field)
|
||||||
lines.append(f"- {field}: (数据缺失)")
|
delivered.add((field, 1))
|
||||||
|
delivered.add((canonical, 1))
|
||||||
else:
|
else:
|
||||||
lines.append(f"- {field}: {value}")
|
series_requests.append((req, resolved))
|
||||||
payload.append({"field": field, "window": 1, "values": value})
|
delivered.add((field, window))
|
||||||
|
delivered.add((canonical, window))
|
||||||
|
continue
|
||||||
|
|
||||||
for req, parsed in series_requests:
|
if field in values_map:
|
||||||
table, column = parsed
|
value = values_map[field]
|
||||||
|
if window <= 1:
|
||||||
|
payload.append(
|
||||||
|
{
|
||||||
|
"field": field,
|
||||||
|
"window": 1,
|
||||||
|
"source": "context",
|
||||||
|
"values": [
|
||||||
|
{
|
||||||
|
"trade_date": context.trade_date,
|
||||||
|
"value": value,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
lines.append(f"- {field}: {value} (来自上下文)")
|
||||||
|
else:
|
||||||
|
series = series_map.get(field)
|
||||||
|
if series:
|
||||||
|
trimmed = series[: window]
|
||||||
|
payload.append(
|
||||||
|
{
|
||||||
|
"field": field,
|
||||||
|
"window": window,
|
||||||
|
"source": "context_series",
|
||||||
|
"values": [
|
||||||
|
{"trade_date": dt, "value": val}
|
||||||
|
for dt, val in trimmed
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
preview = ", ".join(
|
||||||
|
f"{dt}:{val:.4f}" for dt, val in trimmed[: min(len(trimmed), 5)]
|
||||||
|
)
|
||||||
|
lines.append(
|
||||||
|
f"- {field} (window={window} 来自上下文序列): {preview}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
payload.append(
|
||||||
|
{
|
||||||
|
"field": field,
|
||||||
|
"window": window,
|
||||||
|
"source": "context",
|
||||||
|
"values": [
|
||||||
|
{
|
||||||
|
"trade_date": context.trade_date,
|
||||||
|
"value": value,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"warning": "仅提供当前值,缺少历史序列",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
lines.append(
|
||||||
|
f"- {field} (window={window}): 仅有当前值 {value}, 无历史序列"
|
||||||
|
)
|
||||||
|
delivered.add((field, window))
|
||||||
|
if field in db_alias_map:
|
||||||
|
resolved = db_alias_map[field]
|
||||||
|
canonical = f"{resolved[0]}.{resolved[1]}"
|
||||||
|
delivered.add((canonical, window))
|
||||||
|
continue
|
||||||
|
|
||||||
|
lines.append(f"- {field}: 字段不存在或不可用")
|
||||||
|
|
||||||
|
if latest_groups:
|
||||||
|
latest_values = self._broker.fetch_latest(
|
||||||
|
ts_code, trade_date, list(latest_groups.keys())
|
||||||
|
)
|
||||||
|
for canonical, aliases in latest_groups.items():
|
||||||
|
value = latest_values.get(canonical)
|
||||||
|
if value is None:
|
||||||
|
lines.append(f"- {canonical}: (数据缺失)")
|
||||||
|
else:
|
||||||
|
lines.append(f"- {canonical}: {value}")
|
||||||
|
for alias in aliases:
|
||||||
|
payload.append(
|
||||||
|
{
|
||||||
|
"field": alias,
|
||||||
|
"window": 1,
|
||||||
|
"source": "database",
|
||||||
|
"values": [
|
||||||
|
{
|
||||||
|
"trade_date": trade_date,
|
||||||
|
"value": value,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
for req, resolved in series_requests:
|
||||||
|
table, column = resolved
|
||||||
series = self._broker.fetch_series(
|
series = self._broker.fetch_series(
|
||||||
table,
|
table,
|
||||||
column,
|
column,
|
||||||
@ -276,6 +365,7 @@ class DepartmentAgent:
|
|||||||
{
|
{
|
||||||
"field": req.field,
|
"field": req.field,
|
||||||
"window": req.window,
|
"window": req.window,
|
||||||
|
"source": "database",
|
||||||
"values": [
|
"values": [
|
||||||
{"trade_date": dt, "value": val}
|
{"trade_date": dt, "value": val}
|
||||||
for dt, val in series
|
for dt, val in series
|
||||||
@ -546,6 +636,40 @@ def _extract_message_content(message: Mapping[str, Any]) -> str:
|
|||||||
return json.dumps(message, ensure_ascii=False)
|
return json.dumps(message, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_context_lookup(
|
||||||
|
context: DepartmentContext,
|
||||||
|
) -> Tuple[Dict[str, Any], Dict[str, Tuple[str, str]], Dict[str, List[Tuple[str, float]]]]:
|
||||||
|
values: Dict[str, Any] = {}
|
||||||
|
db_alias: Dict[str, Tuple[str, str]] = {}
|
||||||
|
series_map: Dict[str, List[Tuple[str, float]]] = {}
|
||||||
|
|
||||||
|
for source in (context.features or {}, context.market_snapshot or {}):
|
||||||
|
for key, value in source.items():
|
||||||
|
values[str(key)] = value
|
||||||
|
|
||||||
|
scope_values = context.raw.get("scope_values") or {}
|
||||||
|
for key, value in scope_values.items():
|
||||||
|
key_str = str(key)
|
||||||
|
values[key_str] = value
|
||||||
|
if "." in key_str:
|
||||||
|
table, column = key_str.split(".", 1)
|
||||||
|
db_alias.setdefault(column, (table, column))
|
||||||
|
db_alias.setdefault(key_str, (table, column))
|
||||||
|
values.setdefault(column, value)
|
||||||
|
|
||||||
|
close_series = context.raw.get("close_series") or []
|
||||||
|
if isinstance(close_series, list) and close_series:
|
||||||
|
series_map["close"] = close_series
|
||||||
|
series_map["daily.close"] = close_series
|
||||||
|
|
||||||
|
turnover_series = context.raw.get("turnover_series") or []
|
||||||
|
if isinstance(turnover_series, list) and turnover_series:
|
||||||
|
series_map["turnover_rate"] = turnover_series
|
||||||
|
series_map["daily_basic.turnover_rate"] = turnover_series
|
||||||
|
|
||||||
|
return values, db_alias, series_map
|
||||||
|
|
||||||
|
|
||||||
class DepartmentManager:
|
class DepartmentManager:
|
||||||
"""Orchestrates all departments defined in configuration."""
|
"""Orchestrates all departments defined in configuration."""
|
||||||
|
|
||||||
|
|||||||
@ -62,13 +62,8 @@ def department_prompt(
|
|||||||
"risks": ["风险点", "..."]
|
"risks": ["风险点", "..."]
|
||||||
}}
|
}}
|
||||||
|
|
||||||
如需额外数据,请在同一 JSON 中添加可选字段 `"data_requests"`,其取值为数组,例如:
|
如需额外数据,请调用可用工具 `fetch_data` 并在参数中提供 `requests` 数组(元素包含 `field` 以及可选的 `window`);`field` 必须符合【可用数据范围】,`window` 默认为 1。
|
||||||
"data_requests": [
|
工具返回的数据会在后续消息中提供,请在获取所有必要信息后再给出最终 JSON 答复。
|
||||||
{{"field": "daily.close", "window": 5}},
|
|
||||||
{{"field": "daily_basic.pe"}}
|
|
||||||
]
|
|
||||||
其中 `field` 必须属于【可用数据范围】或明确说明新增需求;`window` 表示希望返回的最近数据点数量,省略时默认为 1。
|
|
||||||
如果不需要更多数据,请不要返回 `data_requests`。
|
|
||||||
|
|
||||||
请严格返回单个 JSON 对象,不要添加额外文本。
|
请严格返回单个 JSON 对象,不要添加额外文本。
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -213,6 +213,7 @@ def render_today_plan() -> None:
|
|||||||
|
|
||||||
global_info = None
|
global_info = None
|
||||||
dept_records: List[Dict[str, object]] = []
|
dept_records: List[Dict[str, object]] = []
|
||||||
|
dept_details: Dict[str, Dict[str, object]] = {}
|
||||||
agent_records: List[Dict[str, object]] = []
|
agent_records: List[Dict[str, object]] = []
|
||||||
|
|
||||||
for item in rows:
|
for item in rows:
|
||||||
@ -231,6 +232,11 @@ def render_today_plan() -> None:
|
|||||||
"target_weight": float(utils.get("_target_weight", 0.0)),
|
"target_weight": float(utils.get("_target_weight", 0.0)),
|
||||||
"department_votes": utils.get("_department_votes", {}),
|
"department_votes": utils.get("_department_votes", {}),
|
||||||
"requires_review": bool(utils.get("_requires_review", False)),
|
"requires_review": bool(utils.get("_requires_review", False)),
|
||||||
|
"scope_values": utils.get("_scope_values", {}),
|
||||||
|
"close_series": utils.get("_close_series", []),
|
||||||
|
"turnover_series": utils.get("_turnover_series", []),
|
||||||
|
"department_supplements": utils.get("_department_supplements", {}),
|
||||||
|
"department_dialogue": utils.get("_department_dialogue", {}),
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -238,6 +244,8 @@ def render_today_plan() -> None:
|
|||||||
code = agent_name.split("dept_", 1)[-1]
|
code = agent_name.split("dept_", 1)[-1]
|
||||||
signals = utils.get("_signals", [])
|
signals = utils.get("_signals", [])
|
||||||
risks = utils.get("_risks", [])
|
risks = utils.get("_risks", [])
|
||||||
|
supplements = utils.get("_supplements", [])
|
||||||
|
dialogue = utils.get("_dialogue", [])
|
||||||
dept_records.append(
|
dept_records.append(
|
||||||
{
|
{
|
||||||
"部门": code,
|
"部门": code,
|
||||||
@ -247,8 +255,16 @@ def render_today_plan() -> None:
|
|||||||
"摘要": utils.get("_summary", ""),
|
"摘要": utils.get("_summary", ""),
|
||||||
"核心信号": ";".join(signals) if isinstance(signals, list) else signals,
|
"核心信号": ";".join(signals) if isinstance(signals, list) else signals,
|
||||||
"风险提示": ";".join(risks) if isinstance(risks, list) else risks,
|
"风险提示": ";".join(risks) if isinstance(risks, list) else risks,
|
||||||
|
"补充次数": len(supplements) if isinstance(supplements, list) else 0,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
dept_details[code] = {
|
||||||
|
"supplements": supplements if isinstance(supplements, list) else [],
|
||||||
|
"dialogue": dialogue if isinstance(dialogue, list) else [],
|
||||||
|
"summary": utils.get("_summary", ""),
|
||||||
|
"signals": signals,
|
||||||
|
"risks": risks,
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
score_map = {
|
score_map = {
|
||||||
key: float(val)
|
key: float(val)
|
||||||
@ -281,6 +297,26 @@ def render_today_plan() -> None:
|
|||||||
st.json(global_info["department_votes"])
|
st.json(global_info["department_votes"])
|
||||||
if global_info["requires_review"]:
|
if global_info["requires_review"]:
|
||||||
st.warning("部门分歧较大,已标记为需人工复核。")
|
st.warning("部门分歧较大,已标记为需人工复核。")
|
||||||
|
with st.expander("基础上下文数据", expanded=False):
|
||||||
|
if global_info.get("scope_values"):
|
||||||
|
st.write("最新字段:")
|
||||||
|
st.json(global_info["scope_values"])
|
||||||
|
if global_info.get("close_series"):
|
||||||
|
st.write("收盘价时间序列 (最近窗口):")
|
||||||
|
st.json(global_info["close_series"])
|
||||||
|
if global_info.get("turnover_series"):
|
||||||
|
st.write("换手率时间序列 (最近窗口):")
|
||||||
|
st.json(global_info["turnover_series"])
|
||||||
|
dept_sup = global_info.get("department_supplements") or {}
|
||||||
|
dept_dialogue = global_info.get("department_dialogue") or {}
|
||||||
|
if dept_sup or dept_dialogue:
|
||||||
|
with st.expander("部门补数与对话记录", expanded=False):
|
||||||
|
if dept_sup:
|
||||||
|
st.write("补充数据:")
|
||||||
|
st.json(dept_sup)
|
||||||
|
if dept_dialogue:
|
||||||
|
st.write("对话片段:")
|
||||||
|
st.json(dept_dialogue)
|
||||||
else:
|
else:
|
||||||
st.info("暂未写入全局策略摘要。")
|
st.info("暂未写入全局策略摘要。")
|
||||||
|
|
||||||
@ -288,6 +324,21 @@ def render_today_plan() -> None:
|
|||||||
if dept_records:
|
if dept_records:
|
||||||
dept_df = pd.DataFrame(dept_records)
|
dept_df = pd.DataFrame(dept_records)
|
||||||
st.dataframe(dept_df, use_container_width=True, hide_index=True)
|
st.dataframe(dept_df, use_container_width=True, hide_index=True)
|
||||||
|
for code, details in dept_details.items():
|
||||||
|
with st.expander(f"{code} 补充详情", expanded=False):
|
||||||
|
supplements = details.get("supplements", [])
|
||||||
|
dialogue = details.get("dialogue", [])
|
||||||
|
if supplements:
|
||||||
|
st.write("补充数据:")
|
||||||
|
st.json(supplements)
|
||||||
|
else:
|
||||||
|
st.caption("无补充数据请求。")
|
||||||
|
if dialogue:
|
||||||
|
st.write("对话记录:")
|
||||||
|
for idx, line in enumerate(dialogue, start=1):
|
||||||
|
st.markdown(f"**回合 {idx}:** {line}")
|
||||||
|
else:
|
||||||
|
st.caption("无额外对话。")
|
||||||
else:
|
else:
|
||||||
st.info("暂无部门记录。")
|
st.info("暂无部门记录。")
|
||||||
|
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
|
from typing import ClassVar, Dict, Iterable, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
from .db import db_session
|
from .db import db_session
|
||||||
from .logging import get_logger
|
from .logging import get_logger
|
||||||
@ -42,7 +42,7 @@ def parse_field_path(path: str) -> Tuple[str, str] | None:
|
|||||||
class DataBroker:
|
class DataBroker:
|
||||||
"""Lightweight data access helper for agent/LLM consumption."""
|
"""Lightweight data access helper for agent/LLM consumption."""
|
||||||
|
|
||||||
FIELD_ALIASES: Dict[str, Dict[str, str]] = {
|
FIELD_ALIASES: ClassVar[Dict[str, Dict[str, str]]] = {
|
||||||
"daily": {
|
"daily": {
|
||||||
"volume": "vol",
|
"volume": "vol",
|
||||||
"vol": "vol",
|
"vol": "vol",
|
||||||
@ -57,13 +57,14 @@ class DataBroker:
|
|||||||
"pb": "pb",
|
"pb": "pb",
|
||||||
"ps": "ps",
|
"ps": "ps",
|
||||||
"ps_ttm": "ps_ttm",
|
"ps_ttm": "ps_ttm",
|
||||||
|
"dividend_yield": "dv_ratio",
|
||||||
},
|
},
|
||||||
"stk_limit": {
|
"stk_limit": {
|
||||||
"up": "up_limit",
|
"up": "up_limit",
|
||||||
"down": "down_limit",
|
"down": "down_limit",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
MAX_WINDOW: int = 120
|
MAX_WINDOW: ClassVar[int] = 120
|
||||||
|
|
||||||
def fetch_latest(
|
def fetch_latest(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -120,7 +120,12 @@ def get_logger(name: Optional[str] = None) -> logging.Logger:
|
|||||||
"""返回指定名称的 logger,确保全局配置已就绪。"""
|
"""返回指定名称的 logger,确保全局配置已就绪。"""
|
||||||
|
|
||||||
setup_logging()
|
setup_logging()
|
||||||
return logging.getLogger(name)
|
logger = logging.getLogger(name)
|
||||||
|
# Quiet noisy third-party loggers when default level is DEBUG
|
||||||
|
logging.getLogger("urllib3").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("urllib3.connectionpool").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("requests.packages.urllib3").setLevel(logging.WARNING)
|
||||||
|
return logger
|
||||||
|
|
||||||
|
|
||||||
# 默认在模块导入时完成配置,适配现有调用方式。
|
# 默认在模块导入时完成配置,适配现有调用方式。
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user