Commit 2e98e81715 ("update"), parent 20bbcbb898.
@@ -1,12 +1,18 @@
 """Department-level LLM agents coordinating multi-model decisions."""
 from __future__ import annotations

+import hashlib
 import json
 from dataclasses import dataclass, field
 from typing import Any, Callable, ClassVar, Dict, List, Mapping, Optional, Sequence, Tuple

 from app.agents.base import AgentAction
-from app.llm.client import call_endpoint_with_messages, run_llm_with_config, LLMError
+from app.llm.client import (
+    call_endpoint_with_messages,
+    resolve_endpoint,
+    run_llm_with_config,
+    LLMError,
+)
 from app.llm.prompts import department_prompt
 from app.utils.config import AppConfig, DepartmentSettings, LLMConfig
 from app.utils.logging import get_logger, get_conversation_logger
@@ -48,6 +54,7 @@ class DepartmentDecision:
     risks: List[str] = field(default_factory=list)
     supplements: List[Dict[str, Any]] = field(default_factory=list)
     dialogue: List[str] = field(default_factory=list)
+    telemetry: Dict[str, Any] = field(default_factory=dict)

     def to_dict(self) -> Dict[str, Any]:
         return {
@@ -60,6 +67,7 @@ class DepartmentDecision:
             "raw_response": self.raw_response,
             "supplements": self.supplements,
             "dialogue": self.dialogue,
+            "telemetry": self.telemetry,
         }


@@ -102,18 +110,30 @@ class DepartmentAgent:
         messages: List[Dict[str, object]] = []
         if system_prompt:
             messages.append({"role": "system", "content": system_prompt})
-        messages.append(
-            {
-                "role": "user",
-                "content": department_prompt(self.settings, mutable_context),
-            }
-        )
+        prompt_body = department_prompt(self.settings, mutable_context)
+        prompt_checksum = hashlib.sha1(prompt_body.encode("utf-8")).hexdigest()
+        prompt_preview = prompt_body[:240]
+        messages.append({"role": "user", "content": prompt_body})

         transcript: List[str] = []
         delivered_requests: set[Tuple[str, int, str]] = set()

+        primary_endpoint = llm_cfg.primary
+        try:
+            resolved_primary = resolve_endpoint(primary_endpoint)
+        except LLMError as exc:
+            LOGGER.warning(
+                "部门 %s 无法解析 LLM 端点,回退传统提示:%s",
+                self.settings.code,
+                exc,
+                extra=LOG_EXTRA,
+            )
+            return self._analyze_legacy(mutable_context, system_prompt)
+
         final_message: Optional[Dict[str, Any]] = None
         usage_records: List[Dict[str, Any]] = []
         tool_call_records: List[Dict[str, Any]] = []
+        rounds_executed = 0
+        CONV_LOGGER.info(
+            "dept=%s ts_code=%s trade_date=%s start",
             self.settings.code,
@@ -138,6 +158,14 @@ class DepartmentAgent:
                 )
                 return self._analyze_legacy(mutable_context, system_prompt)

+            rounds_executed = round_idx + 1
+
+            usage = response.get("usage") if isinstance(response, Mapping) else None
+            if isinstance(usage, Mapping):
+                usage_payload = {"round": round_idx + 1}
+                usage_payload.update(dict(usage))
+                usage_records.append(usage_payload)
+
             choice = (response.get("choices") or [{}])[0]
             message = choice.get("message", {})
             transcript.append(_message_to_text(message))
@@ -159,12 +187,36 @@ class DepartmentAgent:
             tool_calls = message.get("tool_calls") or []
             if tool_calls:
                 for call in tool_calls:
+                    function_block = call.get("function") or {}
                     tool_response, delivered = self._handle_tool_call(
                         mutable_context,
                         call,
                         delivered_requests,
                         round_idx,
                     )
+                    tables_summary: List[Dict[str, Any]] = []
+                    for item in tool_response.get("results") or []:
+                        if isinstance(item, Mapping):
+                            tables_summary.append(
+                                {
+                                    "table": item.get("table"),
+                                    "window": item.get("window"),
+                                    "trade_date": item.get("trade_date"),
+                                    "row_count": len(item.get("rows") or []),
+                                }
+                            )
+                    tool_call_records.append(
+                        {
+                            "round": round_idx + 1,
+                            "id": call.get("id"),
+                            "name": function_block.get("name"),
+                            "arguments": function_block.get("arguments"),
+                            "status": tool_response.get("status"),
+                            "results": len(tool_response.get("results") or []),
+                            "tables": tables_summary,
+                            "skipped": list(tool_response.get("skipped") or []),
+                        }
+                    )
                     transcript.append(
                         json.dumps({"tool_response": tool_response}, ensure_ascii=False)
                     )
@@ -216,6 +268,64 @@
         if isinstance(risks, str):
             risks = [risks]

+        def _safe_int(value: Any) -> int:
+            try:
+                return int(value)
+            except (TypeError, ValueError):  # noqa: PERF203 - clarity
+                return 0
+
+        prompt_tokens_total = 0
+        completion_tokens_total = 0
+        total_tokens_reported = 0
+        for usage_payload in usage_records:
+            prompt_tokens_total += _safe_int(
+                usage_payload.get("prompt_tokens")
+                or usage_payload.get("prompt_tokens_total")
+            )
+            completion_tokens_total += _safe_int(
+                usage_payload.get("completion_tokens")
+                or usage_payload.get("completion_tokens_total")
+            )
+            reported_total = _safe_int(
+                usage_payload.get("total_tokens")
+                or usage_payload.get("total_tokens_total")
+            )
+            if reported_total:
+                total_tokens_reported += reported_total
+
+        total_tokens = (
+            total_tokens_reported
+            if total_tokens_reported
+            else prompt_tokens_total + completion_tokens_total
+        )
+
+        telemetry: Dict[str, Any] = {
+            "provider": resolved_primary.get("provider_key"),
+            "model": resolved_primary.get("model"),
+            "temperature": resolved_primary.get("temperature"),
+            "timeout": resolved_primary.get("timeout"),
+            "endpoint_prompt_template": resolved_primary.get("prompt_template"),
+            "rounds": rounds_executed,
+            "tool_call_count": len(tool_call_records),
+            "tool_trace": tool_call_records,
+            "usage_by_round": usage_records,
+            "tokens": {
+                "prompt": prompt_tokens_total,
+                "completion": completion_tokens_total,
+                "total": total_tokens,
+            },
+            "prompt": {
+                "checksum": prompt_checksum,
+                "length": len(prompt_body),
+                "preview": prompt_preview,
+                "role_description": self.settings.description,
+                "instruction": self.settings.prompt,
+                "system": system_prompt,
+            },
+            "messages_exchanged": len(messages),
+            "supplement_rounds": len(tool_call_records),
+        }
+
         decision = DepartmentDecision(
             department=self.settings.code,
             action=action,
@@ -226,6 +336,7 @@
             raw_response=content_text,
             supplements=list(mutable_context.raw.get("supplement_data", [])),
             dialogue=list(transcript),
+            telemetry=telemetry,
         )
         LOGGER.debug(
             "部门 %s 决策:action=%s confidence=%.2f",
@@ -241,6 +352,11 @@
             decision.confidence,
             summary or "",
         )
+        CONV_LOGGER.info(
+            "dept=%s telemetry=%s",
+            self.settings.code,
+            json.dumps(telemetry, ensure_ascii=False),
+        )
         return decision

     @staticmethod

@@ -324,6 +324,8 @@ class BacktestEngine:
                     metadata["_supplements"] = dept_decision.supplements
                 if dept_decision.dialogue:
                     metadata["_dialogue"] = dept_decision.dialogue
+                if dept_decision.telemetry:
+                    metadata["_telemetry"] = dept_decision.telemetry
                 payload_json = {**action_scores, **metadata}
                 rows.append(
                     (
@@ -355,6 +357,11 @@ class BacktestEngine:
                     for code, dept in decision.department_decisions.items()
                     if dept.dialogue
                 },
+                "_department_telemetry": {
+                    code: dept.telemetry
+                    for code, dept in decision.department_decisions.items()
+                    if dept.telemetry
+                },
             }
             rows.append(
                 (

@@ -105,7 +105,7 @@ def _request_openai_chat(
     return response.json()


-def _prepare_endpoint(endpoint: LLMEndpoint) -> Dict[str, object]:
+def resolve_endpoint(endpoint: LLMEndpoint) -> Dict[str, object]:
     cfg = get_config()
     provider_key = (endpoint.provider or "ollama").lower()
     provider_cfg = cfg.llm_providers.get(provider_key)
@@ -152,7 +152,7 @@ def _prepare_endpoint(endpoint: LLMEndpoint) -> Dict[str, object]:


 def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> str:
-    resolved = _prepare_endpoint(endpoint)
+    resolved = resolve_endpoint(endpoint)
     provider_key = resolved["provider_key"]
     mode = resolved["mode"]
     prompt_template = resolved["prompt_template"]
@@ -188,7 +188,7 @@ def call_endpoint_with_messages(
     tools: Optional[List[Dict[str, object]]] = None,
     tool_choice: Optional[object] = None,
 ) -> Dict[str, object]:
-    resolved = _prepare_endpoint(endpoint)
+    resolved = resolve_endpoint(endpoint)
     provider_key = resolved["provider_key"]
     mode = resolved["mode"]
     base_url = resolved["base_url"]

@@ -1,10 +1,11 @@
 """Simple runtime metrics collector for LLM calls."""
 from __future__ import annotations

+import logging
 from collections import deque
 from dataclasses import dataclass, field
 from threading import Lock
-from typing import Deque, Dict, List, Optional
+from typing import Callable, Deque, Dict, List, Optional


 @dataclass
@@ -20,6 +21,9 @@ class _Metrics:

 _METRICS = _Metrics()
 _LOCK = Lock()
+_LISTENERS: List[Callable[[Dict[str, object]], None]] = []
+
+LOGGER = logging.getLogger(__name__)


 def record_call(
@@ -45,6 +49,7 @@ def record_call(
         _METRICS.total_prompt_tokens += int(prompt_tokens)
     if completion_tokens:
         _METRICS.total_completion_tokens += int(completion_tokens)
+    _notify_listeners()


 def snapshot(reset: bool = False) -> Dict[str, object]:
@@ -75,6 +80,7 @@ def reset() -> None:
     """Reset all collected metrics."""

     snapshot(reset=True)
+    _notify_listeners()


 def record_decision(
@@ -103,6 +109,7 @@ def record_decision(
     _METRICS.decision_action_counts[action] = (
         _METRICS.decision_action_counts.get(action, 0) + 1
     )
+    _notify_listeners()


 def recent_decisions(limit: int = 50) -> List[Dict[str, object]]:
@@ -112,3 +119,42 @@ def recent_decisions(limit: int = 50) -> List[Dict[str, object]]:
     if limit <= 0:
         return []
     return list(_METRICS.decisions)[-limit:]
+
+
+def register_listener(callback: Callable[[Dict[str, object]], None]) -> None:
+    """Register a callback invoked whenever metrics change."""
+
+    if not callable(callback):
+        return
+    with _LOCK:
+        if callback in _LISTENERS:
+            should_invoke = False
+        else:
+            _LISTENERS.append(callback)
+            should_invoke = True
+    if should_invoke:
+        try:
+            callback(snapshot())
+        except Exception:  # noqa: BLE001
+            LOGGER.exception("Metrics listener failed on initial callback")
+
+
+def unregister_listener(callback: Callable[[Dict[str, object]], None]) -> None:
+    """Remove a previously registered metrics callback."""
+
+    with _LOCK:
+        if callback in _LISTENERS:
+            _LISTENERS.remove(callback)
+
+
+def _notify_listeners() -> None:
+    with _LOCK:
+        listeners = list(_LISTENERS)
+    if not listeners:
+        return
+    data = snapshot()
+    for callback in listeners:
+        try:
+            callback(data)
+        except Exception:  # noqa: BLE001
+            LOGGER.exception("Metrics listener execution failed")

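The hooks above give other modules a push-based view of the collector (the Streamlit sidebar below is the first consumer). A minimal usage sketch, with `app.llm.metrics` as the import path used elsewhere in this commit:

```python
from app.llm import metrics


def on_metrics_change(snapshot: dict) -> None:
    # Called once at registration time, then after every metrics update.
    print("LLM calls so far:", snapshot.get("total_calls", 0))


metrics.register_listener(on_metrics_change)
# ... LLM traffic runs; record_call/record_decision fire the listeners ...
metrics.unregister_listener(on_metrics_change)
```
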
@@ -2,6 +2,7 @@
 from __future__ import annotations

 import sys
+import time
 from dataclasses import asdict
 from datetime import date, datetime, timedelta
 from pathlib import Path
@@ -28,9 +29,10 @@ from app.ingest.checker import run_boot_check
 from app.ingest.tushare import FetchJob, run_ingestion
 from app.llm.client import llm_config_snapshot, run_llm
 from app.llm.metrics import (
+    recent_decisions as llm_recent_decisions,
+    register_listener as register_llm_metrics_listener,
     reset as reset_llm_metrics,
     snapshot as snapshot_llm_metrics,
-    recent_decisions as llm_recent_decisions,
 )
 from app.utils.config import (
     ALLOWED_LLM_STRATEGIES,
@@ -49,6 +51,11 @@ from app.utils.logging import get_logger

 LOGGER = get_logger(__name__)
 LOG_EXTRA = {"stage": "ui"}
+_SIDEBAR_THROTTLE_SECONDS = 0.75
+
+
+def _sidebar_metrics_listener(metrics: Dict[str, object]) -> None:
+    _update_dashboard_sidebar(metrics, throttled=True)


 def render_global_dashboard() -> None:
@@ -56,54 +63,118 @@ def render_global_dashboard() -> None:

     metrics_container = st.sidebar.container()
     decisions_container = st.sidebar.container()
-    st.session_state["dashboard_placeholders"] = (metrics_container, decisions_container)
+    st.session_state["dashboard_containers"] = (metrics_container, decisions_container)
+    _ensure_dashboard_elements(metrics_container, decisions_container)
+    if not st.session_state.get("dashboard_listener_registered"):
+        register_llm_metrics_listener(_sidebar_metrics_listener)
+        st.session_state["dashboard_listener_registered"] = True
     _update_dashboard_sidebar()


-def _update_dashboard_sidebar(metrics: Optional[Dict[str, object]] = None) -> None:
-    placeholders = st.session_state.get("dashboard_placeholders")
-    if not placeholders:
+def _update_dashboard_sidebar(
+    metrics: Optional[Dict[str, object]] = None,
+    *,
+    throttled: bool = False,
+) -> None:
+    containers = st.session_state.get("dashboard_containers")
+    if not containers:
         return
-    metrics_container, decisions_container = placeholders
+    metrics_container, decisions_container = containers
+    elements = st.session_state.get("dashboard_elements")
+    if elements is None:
+        elements = _ensure_dashboard_elements(metrics_container, decisions_container)
+
+    if throttled:
+        now = time.monotonic()
+        last_update = st.session_state.get("dashboard_last_update_ts", 0.0)
+        if now - last_update < _SIDEBAR_THROTTLE_SECONDS:
+            if metrics is not None:
+                st.session_state["dashboard_pending_metrics"] = metrics
+            return
+        st.session_state["dashboard_last_update_ts"] = now
+    else:
+        st.session_state["dashboard_last_update_ts"] = time.monotonic()
+
+    if metrics is None:
+        metrics = st.session_state.pop("dashboard_pending_metrics", None)
+        if metrics is None:
+            metrics = snapshot_llm_metrics()
+    else:
+        st.session_state.pop("dashboard_pending_metrics", None)

     metrics = metrics or snapshot_llm_metrics()

-    metrics_container.empty()
-    with metrics_container.container():
-        st.header("系统监控")
-        col_a, col_b, col_c = st.columns(3)
-        col_a.metric("LLM 调用", metrics.get("total_calls", 0))
-        col_b.metric("Prompt Tokens", metrics.get("total_prompt_tokens", 0))
-        col_c.metric("Completion Tokens", metrics.get("total_completion_tokens", 0))
+    elements["metrics_calls"].metric("LLM 调用", metrics.get("total_calls", 0))
+    elements["metrics_prompt"].metric("Prompt Tokens", metrics.get("total_prompt_tokens", 0))
+    elements["metrics_completion"].metric(
+        "Completion Tokens", metrics.get("total_completion_tokens", 0)
+    )

     provider_calls = metrics.get("provider_calls", {})
     model_calls = metrics.get("model_calls", {})
-    if provider_calls or model_calls:
-        with st.expander("调用分布", expanded=False):
-            if provider_calls:
-                st.write("按 Provider:")
-                st.json(provider_calls)
-            if model_calls:
-                st.write("按模型:")
-                st.json(model_calls)
+    provider_placeholder = elements["provider_distribution"]
+    provider_placeholder.empty()
+    if provider_calls:
+        provider_placeholder.json(provider_calls)
+    else:
+        provider_placeholder.info("暂无 Provider 分布数据。")
+
+    model_placeholder = elements["model_distribution"]
+    model_placeholder.empty()
+    if model_calls:
+        model_placeholder.json(model_calls)
+    else:
+        model_placeholder.info("暂无模型分布数据。")

-    decisions_container.empty()
-    with decisions_container.container():
-        st.subheader("最新决策")
     decisions = metrics.get("recent_decisions") or llm_recent_decisions(10)
     if decisions:
+        lines = []
         for record in reversed(decisions[-10:]):
             ts_code = record.get("ts_code")
             trade_date = record.get("trade_date")
             action = record.get("action")
             confidence = record.get("confidence", 0.0)
             summary = record.get("summary")
-            st.markdown(
-                f"**{trade_date} {ts_code}** → {action} (置信度 {confidence:.2f})"
-            )
+            line = f"**{trade_date} {ts_code}** → {action} (置信度 {confidence:.2f})"
             if summary:
-                st.caption(summary)
+                line += f"\n<small>{summary}</small>"
+            lines.append(line)
+        decisions_placeholder = elements["decisions_list"]
+        decisions_placeholder.empty()
+        decisions_placeholder.markdown("\n\n".join(lines), unsafe_allow_html=True)
     else:
-        st.caption("暂无决策记录。执行回测或实时评估后可在此查看。")
+        decisions_placeholder = elements["decisions_list"]
+        decisions_placeholder.empty()
+        decisions_placeholder.info("暂无决策记录。执行回测或实时评估后可在此查看。")
+
+
+def _ensure_dashboard_elements(metrics_container, decisions_container) -> Dict[str, object]:
+    elements = st.session_state.get("dashboard_elements")
+    if elements:
+        return elements
+
+    metrics_container.header("系统监控")
+    col_a, col_b, col_c = metrics_container.columns(3)
+    metrics_calls = col_a.empty()
+    metrics_prompt = col_b.empty()
+    metrics_completion = col_c.empty()
+    distribution_expander = metrics_container.expander("调用分布", expanded=False)
+    provider_distribution = distribution_expander.empty()
+    model_distribution = distribution_expander.empty()
+
+    decisions_container.subheader("最新决策")
+    decisions_list = decisions_container.empty()
+
+    elements = {
+        "metrics_calls": metrics_calls,
+        "metrics_prompt": metrics_prompt,
+        "metrics_completion": metrics_completion,
+        "provider_distribution": provider_distribution,
+        "model_distribution": model_distribution,
+        "decisions_list": decisions_list,
+    }
+    st.session_state["dashboard_elements"] = elements
+    return elements


 def _discover_provider_models(provider: LLMProvider, base_override: str = "", api_override: Optional[str] = None) -> tuple[list[str], Optional[str]]:
     """Attempt to query provider API and return available model ids."""
@@ -335,6 +406,7 @@ def render_today_plan() -> None:
                 "turnover_series": utils.get("_turnover_series", []),
                 "department_supplements": utils.get("_department_supplements", {}),
                 "department_dialogue": utils.get("_department_dialogue", {}),
+                "department_telemetry": utils.get("_department_telemetry", {}),
             }
             continue

@@ -344,6 +416,7 @@ def render_today_plan() -> None:
         risks = utils.get("_risks", [])
         supplements = utils.get("_supplements", [])
         dialogue = utils.get("_dialogue", [])
+        telemetry = utils.get("_telemetry", {})
         dept_records.append(
             {
                 "部门": code,
@@ -362,6 +435,7 @@ def render_today_plan() -> None:
             "summary": utils.get("_summary", ""),
             "signals": signals,
             "risks": risks,
+            "telemetry": telemetry if isinstance(telemetry, dict) else {},
         }
     else:
         score_map = {
@@ -407,6 +481,7 @@ def render_today_plan() -> None:
         st.json(global_info["turnover_series"])
     dept_sup = global_info.get("department_supplements") or {}
     dept_dialogue = global_info.get("department_dialogue") or {}
+    dept_telemetry = global_info.get("department_telemetry") or {}
     if dept_sup or dept_dialogue:
         with st.expander("部门补数与对话记录", expanded=False):
             if dept_sup:
@@ -415,6 +490,9 @@ def render_today_plan() -> None:
             if dept_dialogue:
                 st.write("对话片段:")
                 st.json(dept_dialogue)
+        if dept_telemetry:
+            with st.expander("部门 LLM 元数据", expanded=False):
+                st.json(dept_telemetry)
     else:
         st.info("暂未写入全局策略摘要。")

@@ -437,6 +515,10 @@ def render_today_plan() -> None:
                     st.markdown(f"**回合 {idx}:** {line}")
             else:
                 st.caption("无额外对话。")
+            telemetry = details.get("telemetry") or {}
+            if telemetry:
+                st.write("LLM 元数据:")
+                st.json(telemetry)
     else:
         st.info("暂无部门记录。")

docs/decision_optimization_notes.md (new file, 36 lines)
@@ -0,0 +1,36 @@
# Decision Optimization Discussion Notes

## Core Goals
- On top of the existing rule-based and department-LLM collaboration framework in `app/agents`, define a unified "maximize return under risk constraints" metric for comparing prompts, temperatures, multi-agent game setups, and invocation strategies.
- Use the daily backtest environment in `app/backtest/engine.py` to build a repeatable simulation loop that provides offline training and validation data for reinforcement learning or strategy search.
- Optimize the high-level policy (prompt templates, dialogue rounds, function-calling patterns) instead of directly intervening in low-level market prediction, so that RL concentrates on "decision-process scheduling".

## Key Challenges
- State space: each round of the game involves market factors, agent signals, historical behavior, and log state (`app/data/logs`); these must be compressed into compact vectors or summaries usable for RL training.
- Reward design: raw strategy returns alone yield sparse learning signals; consider a weighted multi-objective mix of return, drawdown, execution constraints, and confidence consistency.
- Data efficiency: live trading allows only a limited number of exploratory trials, so offline simulation plus counterfactual evaluation is essential, possibly backed by a model-based world model to mitigate distribution shift.
- Multi-agent non-stationarity: the combined policy of rule agents and LLMs drifts as parameters are tuned, requiring stable Nash aggregation or opponent modeling.

## Suggested Path
- **Data logging**: extend `app/agent_utils` to record prompt version, temperature, function-call counts, department signals, and other metadata as RL-ready trajectories; align the log files (e.g., `app/data/logs/agent_20250929_0754.log`) with the same fields.
- **Environment wrapper**: define a `DecisionEnv` around `app/backtest/engine.py` whose "actions" map to strategy configuration (prompt choice, temperature, voting weights); `step()` drives the existing engine through one trading day of the game and returns the reward (see the sketch after this list).
- **Hierarchical policy**: start with black-box optimizers such as bandits or CMA-ES to tune temperature and weights as a baseline, then move to continuous-action RL such as PPO/SAC; prompts can be encoded as learnable embeddings or a finite candidate set.
- **Game-theoretic coordination**: for inter-department weights and veto policies, adopt centralized training, decentralized execution (CTDE): a shared critic evaluates the global reward while each actor owns a single department's parameters.
- **Safety constraints**: enforce position and risk-control limits with penalty or Lagrangian methods, so that training respects the halt and limit-up/limit-down logic defined by `A_risk`.

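To make the environment wrapper concrete, here is a minimal sketch of the proposed `DecisionEnv`. Everything beyond the class name is an assumption: `engine.reset()`, `engine.observe()`, `engine.run_day(...)`, and the stats fields stand in for whatever interface `BacktestEngine` actually exposes.

```python
from dataclasses import dataclass
from typing import Any, Dict, Tuple


@dataclass
class StrategyConfig:
    """One 'action': a high-level decision-process configuration."""

    prompt_template: str            # which prompt candidate to use
    temperature: float              # LLM sampling temperature
    vote_weights: Dict[str, float]  # per-department voting weight


class DecisionEnv:
    """Wraps the daily backtest loop as an RL-style environment (hypothetical API)."""

    def __init__(self, engine: Any, risk_penalty: float = 0.5) -> None:
        self.engine = engine
        self.risk_penalty = risk_penalty

    def reset(self) -> Dict[str, Any]:
        self.engine.reset()
        return self.engine.observe()  # compact state summary for the policy

    def step(self, action: StrategyConfig) -> Tuple[Dict[str, Any], float, bool]:
        # Apply the configuration, then simulate one trading day.
        stats = self.engine.run_day(
            prompt_template=action.prompt_template,
            temperature=action.temperature,
            vote_weights=action.vote_weights,
        )
        reward = stats["pnl"] - self.risk_penalty * stats["drawdown"]
        done = stats["is_last_day"]
        return self.engine.observe(), reward, done
```

A policy would then interact with it gym-style: `obs = env.reset()`, followed by repeated `obs, reward, done = env.step(config)` until the backtest window is exhausted.
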
## Reinforcement Learning Framework
- State construction: concatenate market features (dimension-reduced factor matrices), historical actions, and the LLM confidence values found in the logs; encode with an LSTM/Transformer when necessary.
- Action definition: continuous actions control `temperature`/`top_p`/voting weights; discrete actions pick the prompt template or collaboration mode (majority vs leader); composite actions can be decomposed into a parameterized policy.
- Reward function: `nav_gain - λ1*drawdown - λ2*transaction_cost - λ3*conflict_count`, optionally adding LLM confidence consistency or function-call cost as positive/negative terms (a sketch follows this list).
- Training loop: repeat "sample a configuration → replay it through BacktestEngine → append to the replay buffer → update the policy"; adopt offline RL (CQL, IQL) where historical trajectories must be reused.
- Evaluation: compare the return distribution against the default prompt/temperature settings, measure policy stability and counterfactual return gaps, and add experiment-version switching and visualization to `app/ui/streamlit_app.py`.

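As a worked example of the reward above, a minimal sketch; the λ defaults and the daily-stats field names (`nav_gain`, `drawdown`, `trade_cost`, `conflicts`) are illustrative assumptions, not existing fields:

```python
from typing import Mapping


def decision_reward(
    stats: Mapping[str, float],
    lambda_drawdown: float = 1.0,
    lambda_cost: float = 0.5,
    lambda_conflict: float = 0.1,
) -> float:
    """Multi-objective reward: net-value gain minus weighted penalties."""

    # stats is assumed to hold one day's aggregates, e.g.
    # {"nav_gain": 0.012, "drawdown": 0.004, "trade_cost": 0.001, "conflicts": 2}
    return (
        stats["nav_gain"]
        - lambda_drawdown * stats["drawdown"]
        - lambda_cost * stats["trade_cost"]
        - lambda_conflict * stats["conflicts"]
    )
```
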
## Next Steps
- Complete the log and database fields first, so that prompt parameters and decision outcomes are recorded end to end.
- Set up a lightweight bandit tuning experiment to verify how different prompt/temperature combinations affect backtest returns (a sketch follows this list).
- Design the RL environment interface, integrate it with the existing BacktestEngine, and plan the training and evaluation scripts.

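The bandit baseline could start as the ε-greedy sketch below, where `configs` is a finite list of prompt/temperature candidates and `run_backtest` is a placeholder callable that replays one configuration through BacktestEngine and returns its realized reward:

```python
import random
from collections import defaultdict
from typing import Callable, Sequence


def epsilon_greedy_tuning(
    configs: Sequence[object],
    run_backtest: Callable[[object], float],
    rounds: int = 200,
    epsilon: float = 0.1,
) -> object:
    """Pick the prompt/temperature config with the best estimated mean reward."""

    totals: defaultdict[int, float] = defaultdict(float)  # cumulative reward per config
    counts: defaultdict[int, int] = defaultdict(int)      # pull count per config

    for _ in range(rounds):
        if random.random() < epsilon or not counts:
            idx = random.randrange(len(configs))                    # explore
        else:
            idx = max(counts, key=lambda i: totals[i] / counts[i])  # exploit
        reward = run_backtest(configs[idx])
        totals[idx] += reward
        counts[idx] += 1

    return configs[max(counts, key=lambda i: totals[i] / counts[i])]
```
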
## Completed Logging Improvements
- The `agent_utils` table gains `_telemetry` and `_department_telemetry` JSON fields (stored inside the `utils` column) that record each department's provider, model, temperature, round count, tool-call list, and token statistics; they can be expanded in the Streamlit "department opinions" detail page (an example payload follows this list).
- `app/data/logs/agent_*.log` now appends `telemetry` lines summarizing each round's function calls, making it easy to analyze offline how prompt versions and LLM configuration affect decisions.
- The Streamlit sidebar listens for live events from `llm.metrics` and refreshes the "system monitor" with ~0.75 s throttling, so it updates promptly when logs arrive without flickering the UI.
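For reference, one department's `_telemetry` entry has roughly the shape below; the keys mirror what `DepartmentAgent` now records, while every value (provider, model name, token counts, checksum) is invented for illustration:

```python
# Hypothetical example of a single department's "_telemetry" payload.
example_telemetry = {
    "provider": "ollama",               # assumed provider key
    "model": "qwen2.5:14b",             # assumed model name
    "temperature": 0.2,
    "timeout": 120,
    "endpoint_prompt_template": None,
    "rounds": 2,
    "tool_call_count": 1,
    "tool_trace": [
        {"round": 1, "name": "fetch_data", "status": "ok", "results": 1},
    ],
    "usage_by_round": [
        {"round": 1, "prompt_tokens": 1850, "completion_tokens": 240},
    ],
    "tokens": {"prompt": 1850, "completion": 240, "total": 2090},
    "prompt": {"checksum": "3c9d...", "length": 5321, "preview": "..."},
    "messages_exchanged": 4,
    "supplement_rounds": 1,
}
```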