update
parent 20bbcbb898
commit 2e98e81715

@@ -1,12 +1,18 @@
 """Department-level LLM agents coordinating multi-model decisions."""
 from __future__ import annotations
 
+import hashlib
 import json
 from dataclasses import dataclass, field
 from typing import Any, Callable, ClassVar, Dict, List, Mapping, Optional, Sequence, Tuple
 
 from app.agents.base import AgentAction
-from app.llm.client import call_endpoint_with_messages, run_llm_with_config, LLMError
+from app.llm.client import (
+    call_endpoint_with_messages,
+    resolve_endpoint,
+    run_llm_with_config,
+    LLMError,
+)
 from app.llm.prompts import department_prompt
 from app.utils.config import AppConfig, DepartmentSettings, LLMConfig
 from app.utils.logging import get_logger, get_conversation_logger
@@ -48,6 +54,7 @@ class DepartmentDecision:
     risks: List[str] = field(default_factory=list)
     supplements: List[Dict[str, Any]] = field(default_factory=list)
     dialogue: List[str] = field(default_factory=list)
+    telemetry: Dict[str, Any] = field(default_factory=dict)
 
     def to_dict(self) -> Dict[str, Any]:
         return {
@@ -60,6 +67,7 @@ class DepartmentDecision:
             "raw_response": self.raw_response,
             "supplements": self.supplements,
             "dialogue": self.dialogue,
+            "telemetry": self.telemetry,
         }
 
 
@@ -102,18 +110,30 @@ class DepartmentAgent:
         messages: List[Dict[str, object]] = []
         if system_prompt:
             messages.append({"role": "system", "content": system_prompt})
-        messages.append(
-            {
-                "role": "user",
-                "content": department_prompt(self.settings, mutable_context),
-            }
-        )
+        prompt_body = department_prompt(self.settings, mutable_context)
+        prompt_checksum = hashlib.sha1(prompt_body.encode("utf-8")).hexdigest()
+        prompt_preview = prompt_body[:240]
+        messages.append({"role": "user", "content": prompt_body})
 
         transcript: List[str] = []
         delivered_requests: set[Tuple[str, int, str]] = set()
 
         primary_endpoint = llm_cfg.primary
+        try:
+            resolved_primary = resolve_endpoint(primary_endpoint)
+        except LLMError as exc:
+            LOGGER.warning(
+                "部门 %s 无法解析 LLM 端点,回退传统提示:%s",
+                self.settings.code,
+                exc,
+                extra=LOG_EXTRA,
+            )
+            return self._analyze_legacy(mutable_context, system_prompt)
 
         final_message: Optional[Dict[str, Any]] = None
+        usage_records: List[Dict[str, Any]] = []
+        tool_call_records: List[Dict[str, Any]] = []
+        rounds_executed = 0
         CONV_LOGGER.info(
             "dept=%s ts_code=%s trade_date=%s start",
             self.settings.code,
@@ -138,6 +158,14 @@ class DepartmentAgent:
                 )
                 return self._analyze_legacy(mutable_context, system_prompt)
 
+            rounds_executed = round_idx + 1
+
+            usage = response.get("usage") if isinstance(response, Mapping) else None
+            if isinstance(usage, Mapping):
+                usage_payload = {"round": round_idx + 1}
+                usage_payload.update(dict(usage))
+                usage_records.append(usage_payload)
+
             choice = (response.get("choices") or [{}])[0]
             message = choice.get("message", {})
             transcript.append(_message_to_text(message))
@@ -159,12 +187,36 @@ class DepartmentAgent:
             tool_calls = message.get("tool_calls") or []
             if tool_calls:
                 for call in tool_calls:
+                    function_block = call.get("function") or {}
                     tool_response, delivered = self._handle_tool_call(
                         mutable_context,
                         call,
                         delivered_requests,
                         round_idx,
                     )
+                    tables_summary: List[Dict[str, Any]] = []
+                    for item in tool_response.get("results") or []:
+                        if isinstance(item, Mapping):
+                            tables_summary.append(
+                                {
+                                    "table": item.get("table"),
+                                    "window": item.get("window"),
+                                    "trade_date": item.get("trade_date"),
+                                    "row_count": len(item.get("rows") or []),
+                                }
+                            )
+                    tool_call_records.append(
+                        {
+                            "round": round_idx + 1,
+                            "id": call.get("id"),
+                            "name": function_block.get("name"),
+                            "arguments": function_block.get("arguments"),
+                            "status": tool_response.get("status"),
+                            "results": len(tool_response.get("results") or []),
+                            "tables": tables_summary,
+                            "skipped": list(tool_response.get("skipped") or []),
+                        }
+                    )
                     transcript.append(
                         json.dumps({"tool_response": tool_response}, ensure_ascii=False)
                     )
@@ -216,6 +268,64 @@ class DepartmentAgent:
         if isinstance(risks, str):
             risks = [risks]
 
+        def _safe_int(value: Any) -> int:
+            try:
+                return int(value)
+            except (TypeError, ValueError):  # noqa: PERF203 - clarity
+                return 0
+
+        prompt_tokens_total = 0
+        completion_tokens_total = 0
+        total_tokens_reported = 0
+        for usage_payload in usage_records:
+            prompt_tokens_total += _safe_int(
+                usage_payload.get("prompt_tokens")
+                or usage_payload.get("prompt_tokens_total")
+            )
+            completion_tokens_total += _safe_int(
+                usage_payload.get("completion_tokens")
+                or usage_payload.get("completion_tokens_total")
+            )
+            reported_total = _safe_int(
+                usage_payload.get("total_tokens")
+                or usage_payload.get("total_tokens_total")
+            )
+            if reported_total:
+                total_tokens_reported += reported_total
+
+        total_tokens = (
+            total_tokens_reported
+            if total_tokens_reported
+            else prompt_tokens_total + completion_tokens_total
+        )
+
+        telemetry: Dict[str, Any] = {
+            "provider": resolved_primary.get("provider_key"),
+            "model": resolved_primary.get("model"),
+            "temperature": resolved_primary.get("temperature"),
+            "timeout": resolved_primary.get("timeout"),
+            "endpoint_prompt_template": resolved_primary.get("prompt_template"),
+            "rounds": rounds_executed,
+            "tool_call_count": len(tool_call_records),
+            "tool_trace": tool_call_records,
+            "usage_by_round": usage_records,
+            "tokens": {
+                "prompt": prompt_tokens_total,
+                "completion": completion_tokens_total,
+                "total": total_tokens,
+            },
+            "prompt": {
+                "checksum": prompt_checksum,
+                "length": len(prompt_body),
+                "preview": prompt_preview,
+                "role_description": self.settings.description,
+                "instruction": self.settings.prompt,
+                "system": system_prompt,
+            },
+            "messages_exchanged": len(messages),
+            "supplement_rounds": len(tool_call_records),
+        }
+
         decision = DepartmentDecision(
             department=self.settings.code,
             action=action,
@@ -226,6 +336,7 @@ class DepartmentAgent:
             raw_response=content_text,
             supplements=list(mutable_context.raw.get("supplement_data", [])),
             dialogue=list(transcript),
+            telemetry=telemetry,
         )
         LOGGER.debug(
             "部门 %s 决策:action=%s confidence=%.2f",
@@ -241,6 +352,11 @@ class DepartmentAgent:
             decision.confidence,
             summary or "",
         )
+        CONV_LOGGER.info(
+            "dept=%s telemetry=%s",
+            self.settings.code,
+            json.dumps(telemetry, ensure_ascii=False),
+        )
         return decision
 
     @staticmethod

@@ -324,6 +324,8 @@ class BacktestEngine:
                 metadata["_supplements"] = dept_decision.supplements
                 if dept_decision.dialogue:
                     metadata["_dialogue"] = dept_decision.dialogue
+                if dept_decision.telemetry:
+                    metadata["_telemetry"] = dept_decision.telemetry
                 payload_json = {**action_scores, **metadata}
                 rows.append(
                     (
@@ -355,6 +357,11 @@ class BacktestEngine:
                     for code, dept in decision.department_decisions.items()
                    if dept.dialogue
                },
+                "_department_telemetry": {
+                    code: dept.telemetry
+                    for code, dept in decision.department_decisions.items()
+                    if dept.telemetry
+                },
            }
            rows.append(
                (

@@ -105,7 +105,7 @@ def _request_openai_chat(
     return response.json()
 
 
-def _prepare_endpoint(endpoint: LLMEndpoint) -> Dict[str, object]:
+def resolve_endpoint(endpoint: LLMEndpoint) -> Dict[str, object]:
     cfg = get_config()
     provider_key = (endpoint.provider or "ollama").lower()
     provider_cfg = cfg.llm_providers.get(provider_key)
@@ -152,7 +152,7 @@ def _prepare_endpoint(endpoint: LLMEndpoint) -> Dict[str, object]:
 
 
 def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> str:
-    resolved = _prepare_endpoint(endpoint)
+    resolved = resolve_endpoint(endpoint)
     provider_key = resolved["provider_key"]
     mode = resolved["mode"]
     prompt_template = resolved["prompt_template"]
@@ -188,7 +188,7 @@ def call_endpoint_with_messages(
     tools: Optional[List[Dict[str, object]]] = None,
     tool_choice: Optional[object] = None,
 ) -> Dict[str, object]:
-    resolved = _prepare_endpoint(endpoint)
+    resolved = resolve_endpoint(endpoint)
     provider_key = resolved["provider_key"]
     mode = resolved["mode"]
     base_url = resolved["base_url"]

@@ -1,10 +1,11 @@
 """Simple runtime metrics collector for LLM calls."""
 from __future__ import annotations
 
+import logging
 from collections import deque
 from dataclasses import dataclass, field
 from threading import Lock
-from typing import Deque, Dict, List, Optional
+from typing import Callable, Deque, Dict, List, Optional
 
 
 @dataclass
@@ -20,6 +21,9 @@ class _Metrics:
 
 _METRICS = _Metrics()
 _LOCK = Lock()
+_LISTENERS: List[Callable[[Dict[str, object]], None]] = []
+
+LOGGER = logging.getLogger(__name__)
 
 
 def record_call(
@@ -45,6 +49,7 @@ def record_call(
            _METRICS.total_prompt_tokens += int(prompt_tokens)
        if completion_tokens:
            _METRICS.total_completion_tokens += int(completion_tokens)
+    _notify_listeners()
 
 
 def snapshot(reset: bool = False) -> Dict[str, object]:
@@ -75,6 +80,7 @@ def reset() -> None:
     """Reset all collected metrics."""
 
     snapshot(reset=True)
+    _notify_listeners()
 
 
 def record_decision(
@@ -103,6 +109,7 @@ def record_decision(
        _METRICS.decision_action_counts[action] = (
            _METRICS.decision_action_counts.get(action, 0) + 1
        )
+    _notify_listeners()
 
 
 def recent_decisions(limit: int = 50) -> List[Dict[str, object]]:
@@ -112,3 +119,42 @@ def recent_decisions(limit: int = 50) -> List[Dict[str, object]]:
     if limit <= 0:
         return []
     return list(_METRICS.decisions)[-limit:]
+
+
+def register_listener(callback: Callable[[Dict[str, object]], None]) -> None:
+    """Register a callback invoked whenever metrics change."""
+
+    if not callable(callback):
+        return
+    with _LOCK:
+        if callback in _LISTENERS:
+            should_invoke = False
+        else:
+            _LISTENERS.append(callback)
+            should_invoke = True
+    if should_invoke:
+        try:
+            callback(snapshot())
+        except Exception:  # noqa: BLE001
+            LOGGER.exception("Metrics listener failed on initial callback")
+
+
+def unregister_listener(callback: Callable[[Dict[str, object]], None]) -> None:
+    """Remove a previously registered metrics callback."""
+
+    with _LOCK:
+        if callback in _LISTENERS:
+            _LISTENERS.remove(callback)
+
+
+def _notify_listeners() -> None:
+    with _LOCK:
+        listeners = list(_LISTENERS)
+    if not listeners:
+        return
+    data = snapshot()
+    for callback in listeners:
+        try:
+            callback(data)
+        except Exception:  # noqa: BLE001
+            LOGGER.exception("Metrics listener execution failed")
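
The listener API added above can be exercised on its own; a minimal sketch (the `print_metrics` callback is illustrative, not part of the commit):

```python
from app.llm.metrics import register_listener, unregister_listener

def print_metrics(metrics: dict) -> None:
    # Invoked once at registration, then again after every
    # record_call / record_decision / reset.
    print(metrics.get("total_calls", 0), "LLM calls so far")

register_listener(print_metrics)
# ... LLM traffic happens elsewhere; each call re-fires the callback ...
unregister_listener(print_metrics)
```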

@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 import sys
+import time
 from dataclasses import asdict
 from datetime import date, datetime, timedelta
 from pathlib import Path
@@ -28,9 +29,10 @@ from app.ingest.checker import run_boot_check
 from app.ingest.tushare import FetchJob, run_ingestion
 from app.llm.client import llm_config_snapshot, run_llm
 from app.llm.metrics import (
+    recent_decisions as llm_recent_decisions,
+    register_listener as register_llm_metrics_listener,
     reset as reset_llm_metrics,
     snapshot as snapshot_llm_metrics,
-    recent_decisions as llm_recent_decisions,
 )
 from app.utils.config import (
     ALLOWED_LLM_STRATEGIES,
@@ -49,6 +51,11 @@ from app.utils.logging import get_logger
 
 LOGGER = get_logger(__name__)
 LOG_EXTRA = {"stage": "ui"}
+_SIDEBAR_THROTTLE_SECONDS = 0.75
+
+
+def _sidebar_metrics_listener(metrics: Dict[str, object]) -> None:
+    _update_dashboard_sidebar(metrics, throttled=True)
 
 
 def render_global_dashboard() -> None:
@@ -56,54 +63,118 @@ def render_global_dashboard() -> None:
 
     metrics_container = st.sidebar.container()
     decisions_container = st.sidebar.container()
-    st.session_state["dashboard_placeholders"] = (metrics_container, decisions_container)
+    st.session_state["dashboard_containers"] = (metrics_container, decisions_container)
+    _ensure_dashboard_elements(metrics_container, decisions_container)
+    if not st.session_state.get("dashboard_listener_registered"):
+        register_llm_metrics_listener(_sidebar_metrics_listener)
+        st.session_state["dashboard_listener_registered"] = True
     _update_dashboard_sidebar()
 
 
-def _update_dashboard_sidebar(metrics: Optional[Dict[str, object]] = None) -> None:
-    placeholders = st.session_state.get("dashboard_placeholders")
-    if not placeholders:
+def _update_dashboard_sidebar(
+    metrics: Optional[Dict[str, object]] = None,
+    *,
+    throttled: bool = False,
+) -> None:
+    containers = st.session_state.get("dashboard_containers")
+    if not containers:
         return
-    metrics_container, decisions_container = placeholders
+    metrics_container, decisions_container = containers
+    elements = st.session_state.get("dashboard_elements")
+    if elements is None:
+        elements = _ensure_dashboard_elements(metrics_container, decisions_container)
+
+    if throttled:
+        now = time.monotonic()
+        last_update = st.session_state.get("dashboard_last_update_ts", 0.0)
+        if now - last_update < _SIDEBAR_THROTTLE_SECONDS:
+            if metrics is not None:
+                st.session_state["dashboard_pending_metrics"] = metrics
+            return
+        st.session_state["dashboard_last_update_ts"] = now
+    else:
+        st.session_state["dashboard_last_update_ts"] = time.monotonic()
+
+    if metrics is None:
+        metrics = st.session_state.pop("dashboard_pending_metrics", None)
+        if metrics is None:
+            metrics = snapshot_llm_metrics()
+    else:
+        st.session_state.pop("dashboard_pending_metrics", None)
+
     metrics = metrics or snapshot_llm_metrics()
 
-    metrics_container.empty()
-    with metrics_container.container():
-        st.header("系统监控")
-        col_a, col_b, col_c = st.columns(3)
-        col_a.metric("LLM 调用", metrics.get("total_calls", 0))
-        col_b.metric("Prompt Tokens", metrics.get("total_prompt_tokens", 0))
-        col_c.metric("Completion Tokens", metrics.get("total_completion_tokens", 0))
+    elements["metrics_calls"].metric("LLM 调用", metrics.get("total_calls", 0))
+    elements["metrics_prompt"].metric("Prompt Tokens", metrics.get("total_prompt_tokens", 0))
+    elements["metrics_completion"].metric(
+        "Completion Tokens", metrics.get("total_completion_tokens", 0)
+    )
 
     provider_calls = metrics.get("provider_calls", {})
     model_calls = metrics.get("model_calls", {})
-    if provider_calls or model_calls:
-        with st.expander("调用分布", expanded=False):
-            if provider_calls:
-                st.write("按 Provider:")
-                st.json(provider_calls)
-            if model_calls:
-                st.write("按模型:")
-                st.json(model_calls)
+    provider_placeholder = elements["provider_distribution"]
+    provider_placeholder.empty()
+    if provider_calls:
+        provider_placeholder.json(provider_calls)
+    else:
+        provider_placeholder.info("暂无 Provider 分布数据。")
 
-    decisions_container.empty()
-    with decisions_container.container():
-        st.subheader("最新决策")
-        decisions = metrics.get("recent_decisions") or llm_recent_decisions(10)
-        if decisions:
-            for record in reversed(decisions[-10:]):
-                ts_code = record.get("ts_code")
-                trade_date = record.get("trade_date")
-                action = record.get("action")
-                confidence = record.get("confidence", 0.0)
-                summary = record.get("summary")
-                st.markdown(
-                    f"**{trade_date} {ts_code}** → {action} (置信度 {confidence:.2f})"
-                )
-                if summary:
-                    st.caption(summary)
-        else:
-            st.caption("暂无决策记录。执行回测或实时评估后可在此查看。")
+    model_placeholder = elements["model_distribution"]
+    model_placeholder.empty()
+    if model_calls:
+        model_placeholder.json(model_calls)
+    else:
+        model_placeholder.info("暂无模型分布数据。")
+
+    decisions = metrics.get("recent_decisions") or llm_recent_decisions(10)
+    if decisions:
+        lines = []
+        for record in reversed(decisions[-10:]):
+            ts_code = record.get("ts_code")
+            trade_date = record.get("trade_date")
+            action = record.get("action")
+            confidence = record.get("confidence", 0.0)
+            summary = record.get("summary")
+            line = f"**{trade_date} {ts_code}** → {action} (置信度 {confidence:.2f})"
+            if summary:
+                line += f"\n<small>{summary}</small>"
+            lines.append(line)
+        decisions_placeholder = elements["decisions_list"]
+        decisions_placeholder.empty()
+        decisions_placeholder.markdown("\n\n".join(lines), unsafe_allow_html=True)
+    else:
+        decisions_placeholder = elements["decisions_list"]
+        decisions_placeholder.empty()
+        decisions_placeholder.info("暂无决策记录。执行回测或实时评估后可在此查看。")
+
+
+def _ensure_dashboard_elements(metrics_container, decisions_container) -> Dict[str, object]:
+    elements = st.session_state.get("dashboard_elements")
+    if elements:
+        return elements
+
+    metrics_container.header("系统监控")
+    col_a, col_b, col_c = metrics_container.columns(3)
+    metrics_calls = col_a.empty()
+    metrics_prompt = col_b.empty()
+    metrics_completion = col_c.empty()
+    distribution_expander = metrics_container.expander("调用分布", expanded=False)
+    provider_distribution = distribution_expander.empty()
+    model_distribution = distribution_expander.empty()
+
+    decisions_container.subheader("最新决策")
+    decisions_list = decisions_container.empty()
+
+    elements = {
+        "metrics_calls": metrics_calls,
+        "metrics_prompt": metrics_prompt,
+        "metrics_completion": metrics_completion,
+        "provider_distribution": provider_distribution,
+        "model_distribution": model_distribution,
+        "decisions_list": decisions_list,
+    }
+    st.session_state["dashboard_elements"] = elements
+    return elements
+
 
 def _discover_provider_models(provider: LLMProvider, base_override: str = "", api_override: Optional[str] = None) -> tuple[list[str], Optional[str]]:
     """Attempt to query provider API and return available model ids."""
@@ -335,6 +406,7 @@ def render_today_plan() -> None:
                 "turnover_series": utils.get("_turnover_series", []),
                 "department_supplements": utils.get("_department_supplements", {}),
                 "department_dialogue": utils.get("_department_dialogue", {}),
+                "department_telemetry": utils.get("_department_telemetry", {}),
             }
             continue
 
@@ -344,6 +416,7 @@ def render_today_plan() -> None:
         risks = utils.get("_risks", [])
         supplements = utils.get("_supplements", [])
         dialogue = utils.get("_dialogue", [])
+        telemetry = utils.get("_telemetry", {})
         dept_records.append(
             {
                 "部门": code,
@@ -362,6 +435,7 @@ def render_today_plan() -> None:
                 "summary": utils.get("_summary", ""),
                 "signals": signals,
                 "risks": risks,
+                "telemetry": telemetry if isinstance(telemetry, dict) else {},
             }
         else:
             score_map = {
@@ -407,6 +481,7 @@ def render_today_plan() -> None:
             st.json(global_info["turnover_series"])
         dept_sup = global_info.get("department_supplements") or {}
         dept_dialogue = global_info.get("department_dialogue") or {}
+        dept_telemetry = global_info.get("department_telemetry") or {}
         if dept_sup or dept_dialogue:
             with st.expander("部门补数与对话记录", expanded=False):
                 if dept_sup:
@@ -415,6 +490,9 @@ def render_today_plan() -> None:
                 if dept_dialogue:
                     st.write("对话片段:")
                     st.json(dept_dialogue)
+        if dept_telemetry:
+            with st.expander("部门 LLM 元数据", expanded=False):
+                st.json(dept_telemetry)
     else:
         st.info("暂未写入全局策略摘要。")
 
@@ -437,6 +515,10 @@ def render_today_plan() -> None:
                 st.markdown(f"**回合 {idx}:** {line}")
             else:
                 st.caption("无额外对话。")
+            telemetry = details.get("telemetry") or {}
+            if telemetry:
+                st.write("LLM 元数据:")
+                st.json(telemetry)
     else:
         st.info("暂无部门记录。")
 

docs/decision_optimization_notes.md (new file, 36 lines)
@@ -0,0 +1,36 @@
+# Decision Optimization Discussion Notes
+
+## Core Goals
+- On top of the existing rule-based and department-LLM collaboration framework in `app/agents`, define a unified "maximize return + risk constraints" metric for judging prompts, temperatures, multi-agent game play, and invocation strategies.
+- Use the daily-frequency backtest environment in `app/backtest/engine.py` to build a repeatable simulation loop that supplies offline training and validation data for reinforcement learning or policy search.
+- Optimize the high-level policy (prompt templates, dialogue rounds, function-calling patterns) instead of directly intervening in low-level market prediction, so that RL focuses on "decision-flow scheduling".
+
+## Key Challenges
+- State space: each round of the game involves market factors, agent signals, historical behavior, and log state (`app/data/logs`); these must be compressed into compact vectors or summaries that RL training can consume.
+- Reward design: raw strategy returns yield sparse gradients; consider a multi-objective weighting of return, drawdown, execution constraints, and confidence consistency.
+- Data efficiency: live trading allows only limited exploration, so rely on offline simulation plus counterfactual evaluation, or introduce a model-based world model to mitigate distribution shift.
+- Multi-agent non-stationarity: the combined rule-agent + LLM policy drifts as parameters are tuned; stable Nash aggregation or opponent modeling is required.
+
+## Suggested Path
+- **Data recording**: extend `app/agent_utils` to write metadata such as prompt version, temperature, function-call counts, and department signals, forming trajectories usable by RL; align the log files (e.g. `app/data/logs/agent_20250929_0754.log`) with the same fields.
+- **Environment wrapper**: define a `DecisionEnv` around `app/backtest/engine.py` that maps "actions" to strategy configuration (prompt choice, temperature, voting weights); `step()` drives the existing engine through one trading day and returns the reward.
+- **Hierarchical policy**: first build a baseline with black-box optimizers (bandits, CMA-ES) over temperature and weights, then move to continuous-action RL such as PPO/SAC; prompts can be encoded as learnable embeddings or a finite candidate set.
+- **Game coordination**: for inter-department weights and veto policies, adopt centralized training, decentralized execution (CTDE): a shared critic evaluates the global reward while each actor manages a single department's parameters.
+- **Safety constraints**: handle position and risk-control constraints with penalty methods or a Lagrangian, so training respects the suspension and price-limit logic defined by `A_risk`.
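
A minimal sketch of the `DecisionEnv` wrapper described in the bullets above; `run_day` stands in for whatever per-day hook BacktestEngine actually exposes, and the reward weights are placeholders:

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Tuple


@dataclass
class DecisionEnv:
    """One step() = one simulated trading day under a given strategy config."""

    # Assumed hook: takes a strategy config, returns that day's statistics.
    run_day: Callable[[Dict[str, Any]], Dict[str, float]]
    drawdown_penalty: float = 0.5  # λ1 in the reward formula below

    def step(self, action: Dict[str, Any]) -> Tuple[Dict[str, float], float, bool]:
        # The "action" is a strategy configuration, not a trade order, e.g.
        # {"prompt_template": "v2", "temperature": 0.4, "vote_weights": {...}}
        stats = self.run_day(action)
        reward = stats["pnl"] - self.drawdown_penalty * stats["drawdown"]
        done = bool(stats.get("is_last_day", False))
        return stats, reward, done
```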
+
+## Reinforcement Learning Framework
+- State construction: concatenate market features (a dimension-reduced factor matrix), historical actions, and LLM confidence taken from the logs; encode with an LSTM/Transformer when necessary.
+- Action definition: continuous actions control `temperature`/`top_p`/voting weights; discrete actions choose the prompt template or collaboration mode (majority vs leader); composite actions can be decomposed into a parameterized policy.
+- Reward function: `net-value gain - λ1*drawdown - λ2*transaction cost - λ3*conflict count`, optionally adding LLM confidence consistency or function-call cost as positive/negative terms.
+- Training loop: iterate "sample a configuration → replay it through BacktestEngine → store in the replay buffer → update the policy"; adopt offline RL (CQL, IQL) when historical trajectories must be reused.
+- Evaluation: compare return distributions against the default prompt/temperature settings, measure policy stability and counterfactual return gaps, and add experiment-version switching and visualization to `app/ui/streamlit_app.py`.
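
The reward formula in the third bullet translates directly into code; a sketch with placeholder weights:

```python
def decision_reward(
    nav_gain: float,
    drawdown: float,
    trade_cost: float,
    conflict_count: int,
    lam1: float = 0.5,   # drawdown weight
    lam2: float = 0.1,   # transaction-cost weight
    lam3: float = 0.05,  # conflict weight
) -> float:
    """Reward = net-value gain - λ1*drawdown - λ2*cost - λ3*conflicts."""
    return nav_gain - lam1 * drawdown - lam2 * trade_cost - lam3 * conflict_count

# e.g. a day with +1.2% NAV, 3% drawdown, 0.1% cost and 2 vetoes:
# decision_reward(0.012, 0.03, 0.001, 2) -> -0.1031
```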
+
+## Next Steps
+- First complete the logging and database fields so prompt parameters and decision outcomes are recorded in full.
+- Build a lightweight bandit tuning experiment to verify how different prompt and temperature combinations affect backtest returns.
+- Design the RL environment interface, integrate it with the existing BacktestEngine, and plan the training and evaluation scripts.
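
For the lightweight bandit experiment in the second bullet, an ε-greedy loop over a finite set of prompt/temperature arms might look like this; `run_backtest` is an assumed hook returning one episode's reward:

```python
import random
from typing import Any, Callable, Dict, List

def epsilon_greedy_tuning(
    arms: List[Dict[str, Any]],  # e.g. [{"prompt": "v1", "temperature": 0.2}, ...]
    run_backtest: Callable[[Dict[str, Any]], float],
    episodes: int = 200,
    epsilon: float = 0.1,
) -> Dict[str, Any]:
    counts = [0] * len(arms)
    means = [0.0] * len(arms)
    for _ in range(episodes):
        if random.random() < epsilon:
            idx = random.randrange(len(arms))  # explore a random arm
        else:
            idx = max(range(len(arms)), key=lambda i: means[i])  # exploit the best
        reward = run_backtest(arms[idx])
        counts[idx] += 1
        means[idx] += (reward - means[idx]) / counts[idx]  # incremental mean
    return arms[max(range(len(arms)), key=lambda i: means[i])]
```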
+
+## Completed Logging Improvements
+- The `agent_utils` table gains `_telemetry` and `_department_telemetry` JSON fields (stored inside the `utils` column) recording each department's provider, model, temperature, round count, tool-call list, and token statistics; they can be inspected in the Streamlit department-opinion detail page.
+- `app/data/logs/agent_*.log` now appends `telemetry` lines that summarize each round's function calls, enabling offline analysis of how prompt versions and LLM configuration affect decisions.
+- The Streamlit sidebar listens for real-time `llm.metrics` events and refreshes the "系统监控" (system monitor) panel at a ~0.75 s throttle, so it updates promptly when logs arrive without the flicker of constant redraws.
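
Since the CONV_LOGGER change above writes lines of the form `dept=<code> telemetry=<json>`, an offline pass over the logs could look like this sketch (the log-line layout is the only assumption):

```python
import json
from pathlib import Path

def iter_telemetry(log_dir: str = "app/data/logs"):
    """Yield (dept, telemetry) pairs from 'dept=<code> telemetry=<json>' log lines."""
    marker = "telemetry="
    for log_file in Path(log_dir).glob("agent_*.log"):
        for line in log_file.read_text(encoding="utf-8").splitlines():
            pos = line.find(marker)
            if pos == -1:
                continue
            dept = line.split("dept=", 1)[-1].split()[0] if "dept=" in line else ""
            try:
                yield dept, json.loads(line[pos + len(marker):])
            except json.JSONDecodeError:
                continue

# e.g. average dialogue rounds per department:
# rounds = [t["rounds"] for _, t in iter_telemetry()]
```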