add briefing rounds and enhance backtest comparison view

sam 2025-10-15 21:19:27 +08:00
parent f9f8ca887f
commit f6c11867d2
11 changed files with 866 additions and 181 deletions

View File

@@ -176,6 +176,13 @@ def decide(
ts_code=context.ts_code,
trade_date=context.trade_date,
)
briefing_round = host.start_round(
host_trace,
agenda="situation_briefing",
structure=GameStructure.SIGNALING,
)
host.handle_message(briefing_round, _host_briefing_message(context))
host.finalize_round(briefing_round)
department_round: Optional[RoundSummary] = None
risk_round: Optional[RoundSummary] = None
execution_round: Optional[RoundSummary] = None
@@ -224,6 +231,19 @@ def decide(
filtered_utilities = {action: utilities[action] for action in feas_actions}
hold_scores = utilities.get(AgentAction.HOLD, {})
norm_weights = weight_map(raw_weights)
prediction_round = host.start_round(
host_trace,
agenda="prediction_alignment",
structure=GameStructure.REPEATED,
)
prediction_message, prediction_summary = _prediction_summary_message(filtered_utilities, norm_weights)
host.handle_message(prediction_round, prediction_message)
host.finalize_round(prediction_round)
if prediction_summary:
belief_updates["prediction_summary"] = BeliefUpdate(
belief=prediction_summary,
rationale="Aggregated utilities shared during alignment round.",
)
if method == "vote": if method == "vote":
action, confidence = vote(filtered_utilities, norm_weights) action, confidence = vote(filtered_utilities, norm_weights)
@ -339,6 +359,22 @@ def decide(
department_votes, department_votes,
) )
belief_revision = revise_beliefs(belief_updates, exec_action) belief_revision = revise_beliefs(belief_updates, exec_action)
if belief_revision.conflicts:
risk_round = host.ensure_round(
host_trace,
agenda="conflict_resolution",
structure=GameStructure.CUSTOM,
)
conflict_message = DialogueMessage(
sender="protocol_host",
role=DialogueRole.HOST,
message_type=MessageType.COUNTER,
content="检测到关键冲突,需要后续回合复核。",
annotations={"conflicts": belief_revision.conflicts},
)
host.handle_message(risk_round, conflict_message)
risk_round.notes.setdefault("conflicts", belief_revision.conflicts)
host.finalize_round(risk_round)
execution_round.notes.setdefault("consensus_action", belief_revision.consensus_action.value) execution_round.notes.setdefault("consensus_action", belief_revision.consensus_action.value)
execution_round.notes.setdefault("consensus_confidence", belief_revision.consensus_confidence) execution_round.notes.setdefault("consensus_confidence", belief_revision.consensus_confidence)
if belief_revision.conflicts: if belief_revision.conflicts:
@ -413,6 +449,73 @@ def _department_conflict_flag(votes: Mapping[str, float]) -> bool:
return False return False
def _host_briefing_message(context: AgentContext) -> DialogueMessage:
features = getattr(context, "features", {}) or {}
close = features.get("close") or features.get("daily.close")
pct_chg = features.get("pct_chg") or features.get("daily.pct_chg")
snapshot = getattr(context, "market_snapshot", {}) or {}
index_brief = snapshot.get("index_change")
lines = [
f"标的 {context.ts_code}",
f"交易日 {context.trade_date}",
]
if close is not None:
lines.append(f"最新收盘价:{close}")
if pct_chg is not None:
lines.append(f"涨跌幅:{pct_chg}")
if index_brief:
lines.append(f"市场概览:{index_brief}")
content = "".join(str(line) for line in lines)
return DialogueMessage(
sender="protocol_host",
role=DialogueRole.HOST,
message_type=MessageType.META,
content=content,
)
def _prediction_summary_message(
utilities: Mapping[AgentAction, Mapping[str, float]],
weights: Mapping[str, float],
) -> Tuple[DialogueMessage, Dict[str, object]]:
if not utilities:
message = DialogueMessage(
sender="protocol_host",
role=DialogueRole.PREDICTION,
message_type=MessageType.META,
content="暂无可用的部门或代理评分,默认进入 HOLD 讨论。",
)
return message, {}
aggregates: Dict[AgentAction, float] = {}
for action, agent_scores in utilities.items():
aggregates[action] = sum(weights.get(agent, 0.0) * score for agent, score in agent_scores.items())
ranked = sorted(aggregates.items(), key=lambda item: item[1], reverse=True)
summary_lines = []
for action, score in ranked[:3]:
summary_lines.append(f"{action.value}: {score:.3f}")
content = "预测合意度:" + " ".join(summary_lines)
total_score = sum(max(score, 0.0) for _, score in ranked)
confidence = 0.0
if total_score > 0 and ranked:
confidence = max(ranked[0][1], 0.0) / total_score
annotations = {
"aggregates": {action.value: score for action, score in aggregates.items()},
}
message = DialogueMessage(
sender="protocol_host",
role=DialogueRole.PREDICTION,
message_type=MessageType.META,
content=content,
confidence=confidence,
annotations=annotations,
)
summary = {
"aggregates": {action.value: aggregates[action] for action in ACTIONS if action in aggregates},
"confidence": confidence,
}
return message, summary
def _department_message(code: str, decision: DepartmentDecision) -> DialogueMessage:
content = decision.summary or decision.raw_response or decision.action.value
references = decision.signals or []
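For readers skimming the diff: every round added in this commit follows the same host lifecycle of start_round, handle_message, and finalize_round. A minimal sketch of that contract, with StubHost/StubRound as hypothetical stand-ins for the project's real host and round types:

import json
from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class StubRound:
    agenda: str
    messages: List[str] = field(default_factory=list)
    notes: Dict[str, object] = field(default_factory=dict)
    closed: bool = False

class StubHost:
    def start_round(self, trace: List[StubRound], agenda: str) -> StubRound:
        round_ = StubRound(agenda=agenda)
        trace.append(round_)  # the host keeps every round on the trace
        return round_

    def handle_message(self, round_: StubRound, message: str) -> None:
        round_.messages.append(message)

    def finalize_round(self, round_: StubRound) -> None:
        round_.closed = True  # notes and summaries are frozen from here on

trace: List[StubRound] = []
host = StubHost()
briefing = host.start_round(trace, agenda="situation_briefing")
host.handle_message(briefing, "标的 600000.SH,交易日 20251015")
host.finalize_round(briefing)
assert briefing.closed and trace[0].agenda == "situation_briefing"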

View File

@@ -47,6 +47,8 @@ class EpisodeMetrics:
total_return: float
max_drawdown: float
volatility: float
sharpe_like: float
calmar_like: float
nav_series: List[Dict[str, float]]
trades: List[Dict[str, object]]
turnover: float
@@ -55,12 +57,6 @@
risk_count: int
risk_breakdown: Dict[str, int]
-@property
-def sharpe_like(self) -> float:
-if self.volatility <= 1e-9:
-return 0.0
-return self.total_return / self.volatility
class DecisionEnv:
"""Thin RL-friendly wrapper that evaluates parameter actions via backtest."""
@@ -123,6 +119,7 @@
"volatility": 0.0,
"turnover": 0.0,
"sharpe_like": 0.0,
"calmar_like": 0.0,
"trade_count": 0.0,
"risk_count": 0.0,
}
@@ -370,6 +367,8 @@
total_return=0.0,
max_drawdown=0.0,
volatility=0.0,
sharpe_like=0.0,
calmar_like=0.0,
nav_series=[],
trades=trades or [],
turnover=0.0,
@@ -403,6 +402,8 @@
volatility = math.sqrt(variance) / base_nav
else:
volatility = 0.0
sharpe_like = total_return / volatility if abs(volatility) > 1e-9 else 0.0
calmar_like = total_return / max_drawdown if max_drawdown > 1e-6 else total_return
turnover_value = 0.0
turnover_ratios: List[float] = []
@@ -433,6 +434,8 @@
total_return=float(total_return),
max_drawdown=float(max_drawdown),
volatility=volatility,
sharpe_like=float(sharpe_like),
calmar_like=float(calmar_like),
nav_series=nav_series,
trades=trades or [],
turnover=float(avg_turnover_ratio),
@@ -446,8 +449,9 @@
def _default_reward(metrics: EpisodeMetrics) -> float:
risk_penalty = 0.05 * metrics.risk_count
turnover_penalty = 0.1 * metrics.turnover
-penalty = 0.5 * metrics.max_drawdown + risk_penalty + turnover_penalty
-return metrics.total_return - penalty
drawdown_penalty = 0.5 * metrics.max_drawdown
bonus = 0.1 * metrics.sharpe_like + 0.05 * metrics.calmar_like
return metrics.total_return + bonus - (drawdown_penalty + risk_penalty + turnover_penalty)
def _build_observation(
self,
@@ -461,6 +465,7 @@
"max_drawdown": metrics.max_drawdown,
"volatility": metrics.volatility,
"sharpe_like": metrics.sharpe_like,
"calmar_like": metrics.calmar_like,
"turnover": metrics.turnover, "turnover": metrics.turnover,
"trade_count": float(metrics.trade_count), "trade_count": float(metrics.trade_count),
"risk_count": float(metrics.risk_count), "risk_count": float(metrics.risk_count),
@ -627,6 +632,8 @@ class DecisionEnv:
total_return=0.0, total_return=0.0,
max_drawdown=0.0, max_drawdown=0.0,
volatility=0.0, volatility=0.0,
sharpe_like=0.0,
calmar_like=0.0,
nav_series=nav_series,
trades=trades,
turnover=0.0,
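To sanity-check the reworked reward: with the figures exercised by the updated unit test later in this commit (total return 0.1, max drawdown 0.2, volatility 0.05, turnover 0.3, two risk events), the Sharpe/Calmar bonuses contribute 0.225 and the penalties 0.23, leaving 0.095. A standalone check of that arithmetic:

total_return, max_drawdown, volatility = 0.1, 0.2, 0.05
turnover, risk_count = 0.3, 2

sharpe_like = total_return / volatility    # 2.0
calmar_like = total_return / max_drawdown  # 0.5

bonus = 0.1 * sharpe_like + 0.05 * calmar_like                     # 0.225
penalty = 0.5 * max_drawdown + 0.05 * risk_count + 0.1 * turnover  # 0.23

reward = total_return + bonus - penalty
assert abs(reward - 0.095) < 1e-9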

View File

@@ -22,7 +22,7 @@ from app.utils.config import (
LLMEndpoint,
get_config,
)
-from app.llm.metrics import record_call
from app.llm.metrics import record_call, record_template_usage
from app.utils.logging import get_logger
LOGGER = get_logger(__name__)
@@ -332,10 +332,12 @@ def run_llm(
context = None
# Apply template if specified
applied_template_version: Optional[str] = None
if template_id:
template = TemplateRegistry.get(template_id)
if not template:
raise ValueError(f"Template {template_id} not found")
applied_template_version = TemplateRegistry.get_active_version(template_id)
vars_dict = template_vars or {}
if isinstance(prompt, str):
vars_dict["prompt"] = prompt
@@ -356,6 +358,11 @@
if context:
context.add_message(Message(role="assistant", content=response))
if template_id:
record_template_usage(
template_id,
version=applied_template_version,
)
return response

View File

@@ -1,6 +1,7 @@
"""Simple runtime metrics collector for LLM calls."""
from __future__ import annotations
import copy
import logging
from collections import deque
from dataclasses import dataclass, field
@@ -19,6 +20,7 @@ class _Metrics:
decision_action_counts: Dict[str, int] = field(default_factory=dict)
total_latency: float = 0.0
latency_samples: Deque[float] = field(default_factory=lambda: deque(maxlen=200))
template_usage: Dict[str, Dict[str, object]] = field(default_factory=dict)
_METRICS = _Metrics()
@@ -78,6 +80,7 @@ def snapshot(reset: bool = False) -> Dict[str, object]:
else 0.0
),
"latency_samples": list(_METRICS.latency_samples),
"template_usage": copy.deepcopy(_METRICS.template_usage),
}
if reset:
_METRICS.total_calls = 0
@@ -89,6 +92,7 @@
_METRICS.decisions.clear()
_METRICS.total_latency = 0.0
_METRICS.latency_samples.clear()
_METRICS.template_usage.clear()
return data
@@ -128,6 +132,38 @@ def record_decision(
_notify_listeners()
def record_template_usage(
template_id: str,
*,
version: Optional[str],
prompt_tokens: Optional[int] = None,
completion_tokens: Optional[int] = None,
) -> None:
"""Record usage statistics for a specific prompt template."""
if not template_id:
return
label = template_id.strip()
version_label = version or "active"
with _LOCK:
entry = _METRICS.template_usage.setdefault(
label,
{"total_calls": 0, "versions": {}},
)
entry["total_calls"] = int(entry.get("total_calls", 0)) + 1
versions = entry.setdefault("versions", {})
version_entry = versions.setdefault(
version_label,
{"calls": 0, "prompt_tokens": 0, "completion_tokens": 0},
)
version_entry["calls"] = int(version_entry.get("calls", 0)) + 1
if prompt_tokens:
version_entry["prompt_tokens"] = int(version_entry.get("prompt_tokens", 0)) + int(prompt_tokens)
if completion_tokens:
version_entry["completion_tokens"] = int(version_entry.get("completion_tokens", 0)) + int(completion_tokens)
_notify_listeners()
def recent_decisions(limit: int = 50) -> List[Dict[str, object]]:
"""Return the most recent decisions up to limit."""

View File

@@ -199,6 +199,14 @@ class TemplateRegistry:
collected[template_id] = active.template if active else template
return list(collected.values())
@classmethod
def list_template_ids(cls) -> List[str]:
"""Return all known template IDs in sorted order."""
ids = set(cls._templates.keys())
manager = cls._manager()
ids.update(manager.list_template_ids())
return sorted(ids)
@classmethod
def list_versions(cls, template_id: str) -> List[str]:
"""List available version labels for a template."""
@@ -206,6 +214,49 @@
manager = cls._manager()
return [ver.version for ver in manager.list_versions(template_id)]
@classmethod
def list_version_details(cls, template_id: str) -> List[Dict[str, Any]]:
"""Return detailed information for each version of a template."""
manager = cls._manager()
versions = manager.list_version_details(template_id)
details: List[Dict[str, Any]] = []
for entry in versions:
details.append(
{
"version": entry.version,
"created_at": entry.created_at,
"is_active": entry.is_active,
"metadata": dict(entry.metadata or {}),
}
)
if not details and template_id in cls._templates:
details.append(
{
"version": cls._default_version_label,
"created_at": None,
"is_active": True,
"metadata": {},
}
)
return details
@classmethod
def update_version_metadata(cls, template_id: str, version: str, metadata: Dict[str, Any]) -> None:
"""Update metadata for a template version."""
manager = cls._manager()
manager.update_metadata(template_id, version, metadata)
@classmethod
def export_versions(cls, template_id: str) -> Optional[str]:
"""Export template versions for backup."""
manager = cls._manager()
if not manager.list_versions(template_id):
return None
return manager.export_versions(template_id)
@classmethod
def load_from_json(cls, json_str: str) -> None:
"""Load templates from JSON string."""

View File

@@ -80,9 +80,13 @@ class TemplateVersionManager:
self._versions: Dict[str, Dict[str, TemplateVersion]] = {}
self._active_versions: Dict[str, str] = {}
-def add_version(self, template: PromptTemplate, version: str,
def add_version(
self,
template: PromptTemplate,
version: str,
metadata: Optional[Dict[str, Any]] = None,
-activate: bool = False) -> TemplateVersion:
activate: bool = False,
) -> TemplateVersion:
"""Add a new template version."""
if template.id not in self._versions:
self._versions[template.id] = {}
@@ -111,6 +115,14 @@
"""List all versions of a template."""
return list(self._versions.get(template_id, {}).values())
def list_template_ids(self) -> List[str]:
"""Return template IDs currently tracked by the version manager."""
return list(self._versions.keys())
def list_version_details(self, template_id: str) -> List[TemplateVersion]:
"""Alias for list_versions to expose structured version records."""
return self.list_versions(template_id)
def get_active_version(self, template_id: str) -> Optional[TemplateVersion]:
"""Get the active version of a template."""
active_version = self._active_versions.get(template_id)
@@ -179,3 +191,10 @@
active_version = data.get("active_version")
if active_version:
self.activate_version(template_id, active_version)
def update_metadata(self, template_id: str, version: str, metadata: Dict[str, Any]) -> None:
"""Update metadata for a specific template version."""
version_obj = self.get_version(template_id, version)
if version_obj is None:
raise ValueError(f"Version {version} not found for template {template_id}")
version_obj.metadata = metadata or {}
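A quick round-trip over the manager-level additions (the PromptTemplate constructor arguments are assumptions; only id matters for the bookkeeping shown here):

manager = TemplateVersionManager()
template = PromptTemplate(id="decision_prompt", name="决策提示", template="...")

manager.add_version(template, version="v1", activate=True)
manager.add_version(template, version="v2", metadata={"note": "tightened output schema"})

assert manager.list_template_ids() == ["decision_prompt"]
assert [v.version for v in manager.list_version_details("decision_prompt")] == ["v1", "v2"]

manager.update_metadata("decision_prompt", "v1", {"deprecated": True})
assert manager.get_version("decision_prompt", "v1").metadata == {"deprecated": True}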

View File

@@ -3,11 +3,12 @@ from __future__ import annotations
import json
from datetime import date, datetime
-from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple
import pandas as pd
import plotly.express as px
import streamlit as st
import numpy as np
from app.agents.base import AgentContext
from app.agents.game import Decision
@@ -41,6 +42,80 @@ _DECISION_ENV_BATCH_RESULTS_KEY = "decision_env_batch_results"
_DECISION_ENV_BANDIT_RESULTS_KEY = "decision_env_bandit_results"
_DECISION_ENV_PPO_RESULTS_KEY = "decision_env_ppo_results"
def _normalize_nav_records(records: List[Dict[str, object]]) -> pd.DataFrame:
"""Return nav dataframe with columns trade_date (datetime) and nav (float)."""
if not records:
return pd.DataFrame(columns=["trade_date", "nav"])
df = pd.DataFrame(records)
if "trade_date" not in df.columns:
if "date" in df.columns:
df = df.rename(columns={"date": "trade_date"})
elif "ts" in df.columns:
df = df.rename(columns={"ts": "trade_date"})
if "nav" not in df.columns:
# fallback: look for value column
candidates = [col for col in df.columns if col not in {"trade_date", "date", "ts"}]
if candidates:
df = df.rename(columns={candidates[0]: "nav"})
if "trade_date" not in df.columns or "nav" not in df.columns:
return pd.DataFrame(columns=["trade_date", "nav"])
df = df.copy()
df["trade_date"] = pd.to_datetime(df["trade_date"], errors="coerce")
df["nav"] = pd.to_numeric(df["nav"], errors="coerce")
df = df.dropna(subset=["trade_date", "nav"]).sort_values("trade_date")
return df[["trade_date", "nav"]]
def _normalize_trade_records(records: List[Dict[str, object]]) -> pd.DataFrame:
if not records:
return pd.DataFrame()
df = pd.DataFrame(records)
if "trade_date" not in df.columns:
for candidate in ("date", "ts", "timestamp"):
if candidate in df.columns:
df = df.rename(columns={candidate: "trade_date"})
break
if "trade_date" in df.columns:
df["trade_date"] = pd.to_datetime(df["trade_date"], errors="coerce")
return df
def _compute_nav_metrics(nav_df: pd.DataFrame, trades_df: pd.DataFrame) -> Dict[str, object]:
if nav_df.empty:
return {
"total_return": None,
"max_drawdown": None,
"trade_count": len(trades_df),
"avg_turnover": None,
"risk_events": None,
}
values = nav_df["nav"].astype(float).values
if values.size == 0:
return {
"total_return": None,
"max_drawdown": None,
"trade_count": len(trades_df),
"avg_turnover": None,
"risk_events": None,
}
total_return = float(values[-1] / values[0] - 1.0) if values[0] != 0 else None
cumulative_max = np.maximum.accumulate(values)
drawdowns = (values - cumulative_max) / cumulative_max
max_drawdown = float(drawdowns.min()) if drawdowns.size else None
summary = {
"total_return": total_return,
"max_drawdown": max_drawdown,
"trade_count": int(len(trades_df)),
"avg_turnover": trades_df["turnover"].mean() if "turnover" in trades_df.columns and not trades_df.empty else None,
"risk_events": None,
}
return summary
def _session_compare_store() -> Dict[str, Dict[str, object]]:
return st.session_state.setdefault("bt_compare_runs", {})
def render_backtest_review() -> None:
"""渲染回测执行、调参与结果复盘页面。"""
st.header("回测与复盘")
@@ -220,8 +295,24 @@ def render_backtest_review() -> None:
}
)
update_dashboard_sidebar(metrics)
-st.session_state["backtest_last_result"] = {"nav_records": result.nav_series, "trades": result.trades}
-st.json(st.session_state["backtest_last_result"])
nav_df = _normalize_nav_records(result.nav_series)
trades_df = _normalize_trade_records(result.trades)
summary = _compute_nav_metrics(nav_df, trades_df)
summary["risk_events"] = len(result.risk_events or [])
st.session_state["backtest_last_result"] = {
"nav_records": nav_df.to_dict(orient="records"),
"trades": trades_df.to_dict(orient="records"),
"risk_events": result.risk_events or [],
}
st.session_state["backtest_last_summary"] = summary
st.session_state["backtest_last_config"] = {
"start_date": start_date.isoformat(),
"end_date": end_date.isoformat(),
"universe": universe,
"params": dict(backtest_params),
"structures": selected_structures,
"name": backtest_cfg.name,
}
except Exception as exc:  # noqa: BLE001
LOGGER.exception("回测执行失败", extra=LOG_EXTRA)
status_box.update(label="回测执行失败", state="error")
@@ -229,16 +320,81 @@
last_result = st.session_state.get("backtest_last_result")
if last_result:
last_summary = st.session_state.get("backtest_last_summary", {})
last_config = st.session_state.get("backtest_last_config", {})
st.markdown("#### 最近回测输出") st.markdown("#### 最近回测输出")
st.json(last_result) nav_preview = _normalize_nav_records(last_result.get("nav_records", []))
if not nav_preview.empty:
import plotly.graph_objects as go
fig_last = go.Figure()
fig_last.add_trace(
go.Scatter(
x=nav_preview["trade_date"],
y=nav_preview["nav"],
mode="lines",
name="NAV",
)
)
fig_last.update_layout(height=260, margin=dict(l=10, r=10, t=30, b=10))
st.plotly_chart(fig_last, width="stretch")
metric_cols = st.columns(4)
metric_cols[0].metric("总收益", f"{(last_summary.get('total_return') or 0.0)*100:.2f}%", delta=None)
metric_cols[1].metric("最大回撤", f"{(last_summary.get('max_drawdown') or 0.0)*100:.2f}%")
metric_cols[2].metric("交易数", last_summary.get("trade_count", 0))
metric_cols[3].metric("风险事件", last_summary.get("risk_events", 0))
default_label = (
f"{last_config.get('name', '临时实验')} | {last_config.get('start_date', '')}~{last_config.get('end_date', '')}"
).strip(" |~")
save_col, button_col = st.columns([4, 1])
save_label = save_col.text_input(
"保存至实验对比(可编辑标签)",
value=default_label or f"实验_{datetime.now().strftime('%H%M%S')}",
key="bt_save_label",
)
if button_col.button("保存", key="bt_save_button"):
label = save_label.strip() or f"实验_{datetime.now().strftime('%H%M%S')}"
store = _session_compare_store()
store[label] = {
"cfg_id": f"session::{label}",
"nav": nav_preview.to_dict(orient="records"),
"summary": dict(last_summary),
"config": dict(last_config),
"risk_events": last_result.get("risk_events", []),
}
st.success(f"已保存实验:{label}")
with st.expander("最近回测详情", expanded=False):
st.json(
{
"config": last_config,
"summary": last_summary,
"trades": last_result.get("trades", [])[:50],
}
)
st.divider()
-# ADD: Comparison view for multiple backtest configurations
-st.caption("从历史回测配置中选择多个进行净值曲线与指标对比。")
st.caption("从历史回测配置或本页保存的实验中选择多个进行净值曲线与指标对比。")
normalize_to_one = st.checkbox("归一化到 1 起点", value=True, key="bt_cmp_normalize")
use_log_y = st.checkbox("对数坐标", value=False, key="bt_cmp_log_y")
metric_options = ["总收益", "最大回撤", "交易数", "平均换手", "风险事件"]
selected_metrics = st.multiselect("显示指标列", metric_options, default=metric_options, key="bt_cmp_metrics")
session_store = _session_compare_store()
if session_store:
with st.expander("会话实验管理", expanded=False):
st.write("会话实验仅保存在当前浏览器窗口中,可选择删除以保持列表精简。")
removal_choices = st.multiselect(
"选择要删除的会话实验",
list(session_store.keys()),
key="bt_cmp_remove_runs",
)
if st.button("删除选中实验", key="bt_cmp_remove_button") and removal_choices:
for label in removal_choices:
session_store.pop(label, None)
st.success("已删除选中的会话实验。")
try:
with db_session(read_only=True) as conn:
cfg_rows = conn.execute(
@@ -247,89 +403,71 @@
except Exception:  # noqa: BLE001
LOGGER.exception("读取 bt_config 失败", extra=LOG_EXTRA)
cfg_rows = []
-cfg_options = [f"{row['id']} | {row['name']}" for row in cfg_rows]
-selected_labels = st.multiselect("选择配置", cfg_options, default=cfg_options[:2], key="bt_cmp_configs")
-selected_ids = [label.split(" | ")[0].strip() for label in selected_labels]
-nav_df = pd.DataFrame()
-rpt_df = pd.DataFrame()
-risk_df = pd.DataFrame()
-if selected_ids:
option_map: Dict[str, Tuple[str, object]] = {}
option_labels: List[str] = []
for label in session_store.keys():
option_label = f"[会话] {label}"
option_labels.append(option_label)
option_map[option_label] = ("session", label)
for row in cfg_rows:
option_label = f"[DB] {row['id']} | {row['name']}"
option_labels.append(option_label)
option_map[option_label] = ("db", row["id"])
if not option_labels:
st.info("暂无可对比的回测实验,请先执行回测或保存实验。")
else:
default_selection = option_labels[:2]
selected_labels = st.multiselect(
"选择实验配置",
option_labels,
default=default_selection,
key="bt_cmp_configs",
)
selected_db_ids = [option_map[label][1] for label in selected_labels if option_map[label][0] == "db"]
selected_session_labels = [option_map[label][1] for label in selected_labels if option_map[label][0] == "session"]
nav_frames: List[pd.DataFrame] = []
metrics_rows: List[Dict[str, object]] = []
risk_frames: List[pd.DataFrame] = []
if selected_db_ids:
try:
with db_session(read_only=True) as conn:
-nav_df = pd.read_sql_query(
-"SELECT cfg_id, trade_date, nav FROM bt_nav WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))),
-conn,
-params=tuple(selected_ids),
-)
-rpt_df = pd.read_sql_query(
-"SELECT cfg_id, summary FROM bt_report WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))),
-conn,
-params=tuple(selected_ids),
-)
-risk_df = pd.read_sql_query(
-"SELECT cfg_id, trade_date, ts_code, reason, action, target_weight, confidence, metadata "
-"FROM bt_risk_events WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))),
-conn,
-params=tuple(selected_ids),
-)
-except Exception:  # noqa: BLE001
-LOGGER.exception("读取回测结果失败", extra=LOG_EXTRA)
-st.error("读取回测结果失败")
-nav_df = pd.DataFrame()
-rpt_df = pd.DataFrame()
-risk_df = pd.DataFrame()
-start_filter: Optional[date] = None
-end_filter: Optional[date] = None
-if not nav_df.empty:
-try:
-nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"], errors="coerce")
-# ADD: date window filter
-overall_min = pd.to_datetime(nav_df["trade_date"].min()).date()
-overall_max = pd.to_datetime(nav_df["trade_date"].max()).date()
-col_d1, col_d2 = st.columns(2)
-start_filter = col_d1.date_input("起始日期", value=overall_min, key="bt_cmp_start")
-end_filter = col_d2.date_input("结束日期", value=overall_max, key="bt_cmp_end")
-if start_filter > end_filter:
-start_filter, end_filter = end_filter, start_filter
-mask = (nav_df["trade_date"].dt.date >= start_filter) & (nav_df["trade_date"].dt.date <= end_filter)
-nav_df = nav_df.loc[mask]
-pivot = nav_df.pivot_table(index="trade_date", columns="cfg_id", values="nav")
-if normalize_to_one:
-pivot = pivot.apply(lambda s: s / s.dropna().iloc[0] if s.dropna().size else s)
-import plotly.graph_objects as go
-fig = go.Figure()
-for col in pivot.columns:
-fig.add_trace(go.Scatter(x=pivot.index, y=pivot[col], mode="lines", name=str(col)))
-fig.update_layout(height=300, margin=dict(l=10, r=10, t=30, b=10))
-if use_log_y:
-fig.update_yaxes(type="log")
-st.plotly_chart(fig, width='stretch')
-# ADD: export pivot
-try:
-csv_buf = pivot.reset_index()
-csv_buf.columns = ["trade_date"] + [str(c) for c in pivot.columns]
-st.download_button(
-"下载曲线(CSV)",
-data=csv_buf.to_csv(index=False),
-file_name="bt_nav_compare.csv",
-mime="text/csv",
-key="dl_nav_compare",
-)
-except Exception:
-pass
-except Exception:  # noqa: BLE001
-LOGGER.debug("绘制对比曲线失败", extra=LOG_EXTRA)
-if not rpt_df.empty:
-try:
-metrics_rows: List[Dict[str, object]] = []
-for _, row in rpt_df.iterrows():
-cfg_id = row["cfg_id"]
db_nav = pd.read_sql_query(
"SELECT cfg_id, trade_date, nav FROM bt_nav WHERE cfg_id IN (%s)" % (",".join(["?"] * len(selected_db_ids))),
conn,
params=tuple(selected_db_ids),
)
db_rpt = pd.read_sql_query(
"SELECT cfg_id, summary FROM bt_report WHERE cfg_id IN (%s)" % (",".join(["?"] * len(selected_db_ids))),
conn,
params=tuple(selected_db_ids),
)
db_risk = pd.read_sql_query(
"SELECT cfg_id, trade_date, ts_code, reason, action, target_weight, confidence, metadata "
"FROM bt_risk_events WHERE cfg_id IN (%s)" % (",".join(["?"] * len(selected_db_ids))),
conn,
params=tuple(selected_db_ids),
)
if not db_nav.empty:
db_nav["trade_date"] = pd.to_datetime(db_nav["trade_date"], errors="coerce")
nav_frames.append(db_nav)
if not db_risk.empty:
db_risk["trade_date"] = pd.to_datetime(db_risk["trade_date"], errors="coerce")
risk_frames.append(db_risk)
for _, row in db_rpt.iterrows():
try:
summary = json.loads(row["summary"]) if isinstance(row["summary"], str) else (row["summary"] or {})
except json.JSONDecodeError:
summary = {}
-record = {
-"cfg_id": cfg_id,
metrics_rows.append(
{
"cfg_id": row["cfg_id"],
"总收益": summary.get("total_return"),
"最大回撤": summary.get("max_drawdown"),
"交易数": summary.get("trade_count"),
@@ -344,89 +482,144 @@
"派生字段": json.dumps(summary.get("derived_field_counts"), ensure_ascii=False) "派生字段": json.dumps(summary.get("derived_field_counts"), ensure_ascii=False)
if summary.get("derived_field_counts") if summary.get("derived_field_counts")
else None, else None,
"参数": json.dumps(summary.get("config_params"), ensure_ascii=False)
if summary.get("config_params")
else None,
"备注": summary.get("note"),
} }
metrics_rows.append({k: v for k, v in record.items() if (k == "cfg_id" or k in selected_metrics)}) )
if metrics_rows: except Exception: # noqa: BLE001
dfm = pd.DataFrame(metrics_rows) LOGGER.exception("读取回测结果失败", extra=LOG_EXTRA)
st.dataframe(dfm, hide_index=True, width='stretch') st.error("读取数据库中的回测结果失败,详见日志。")
for label in selected_session_labels:
data = session_store.get(label)
if not data:
continue
nav_df_session = _normalize_nav_records(data.get("nav", []))
if not nav_df_session.empty:
nav_df_session = nav_df_session.assign(cfg_id=data.get("cfg_id"))
nav_frames.append(nav_df_session)
summary = data.get("summary", {})
metrics_rows.append(
{
"cfg_id": data.get("cfg_id"),
"总收益": summary.get("total_return"),
"最大回撤": summary.get("max_drawdown"),
"交易数": summary.get("trade_count"),
"平均换手": summary.get("avg_turnover"),
"风险事件": summary.get("risk_events"),
"参数": json.dumps(data.get("config", {}).get("params"), ensure_ascii=False)
if data.get("config", {}).get("params")
else None,
"备注": json.dumps(
{
"structures": data.get("config", {}).get("structures"),
"universe_size": len(data.get("config", {}).get("universe", [])),
},
ensure_ascii=False,
),
}
)
risk_events = data.get("risk_events") or []
if risk_events:
risk_df_session = pd.DataFrame(risk_events)
if not risk_df_session.empty:
if "trade_date" in risk_df_session.columns:
risk_df_session["trade_date"] = pd.to_datetime(risk_df_session["trade_date"], errors="coerce")
risk_df_session = risk_df_session.assign(cfg_id=data.get("cfg_id"))
risk_frames.append(risk_df_session)
if not selected_labels:
st.info("请选择至少一个实验进行对比。")
else:
nav_df = pd.concat(nav_frames, ignore_index=True) if nav_frames else pd.DataFrame()
if not nav_df.empty:
nav_df = nav_df.dropna(subset=["trade_date", "nav"])
if not nav_df.empty:
overall_min = nav_df["trade_date"].min().date()
overall_max = nav_df["trade_date"].max().date()
col_d1, col_d2 = st.columns(2)
start_filter = col_d1.date_input("起始日期", value=overall_min, key="bt_cmp_start")
end_filter = col_d2.date_input("结束日期", value=overall_max, key="bt_cmp_end")
if start_filter > end_filter:
start_filter, end_filter = end_filter, start_filter
mask = (nav_df["trade_date"].dt.date >= start_filter) & (nav_df["trade_date"].dt.date <= end_filter)
nav_df = nav_df.loc[mask]
pivot = nav_df.pivot_table(index="trade_date", columns="cfg_id", values="nav")
if normalize_to_one:
pivot = pivot.apply(lambda s: s / s.dropna().iloc[0] if s.dropna().size else s)
import plotly.graph_objects as go
fig = go.Figure()
for col in pivot.columns:
fig.add_trace(go.Scatter(x=pivot.index, y=pivot[col], mode="lines", name=str(col)))
fig.update_layout(height=320, margin=dict(l=10, r=10, t=30, b=10))
if use_log_y:
fig.update_yaxes(type="log")
st.plotly_chart(fig, width="stretch")
try:
csv_buf = pivot.reset_index()
csv_buf.columns = ["trade_date"] + [str(c) for c in pivot.columns]
st.download_button(
"下载曲线(CSV)",
data=csv_buf.to_csv(index=False),
file_name="bt_nav_compare.csv",
mime="text/csv",
key="dl_nav_compare",
)
except Exception:
pass
metric_df = pd.DataFrame(metrics_rows)
if not metric_df.empty:
display_cols = ["cfg_id"] + [col for col in selected_metrics if col in metric_df.columns]
additional_cols = [col for col in metric_df.columns if col not in display_cols]
metric_df = metric_df.loc[:, display_cols + additional_cols]
st.dataframe(metric_df, hide_index=True, width="stretch")
try:
st.download_button(
"下载指标(CSV)",
-data=dfm.to_csv(index=False),
data=metric_df.to_csv(index=False),
file_name="bt_metrics_compare.csv",
mime="text/csv",
key="dl_metrics_compare",
)
except Exception:
pass
-except Exception:  # noqa: BLE001
-LOGGER.debug("渲染指标表失败", extra=LOG_EXTRA)
risk_df = pd.concat(risk_frames, ignore_index=True) if risk_frames else pd.DataFrame()
if not risk_df.empty:
-try:
-risk_df["trade_date"] = pd.to_datetime(risk_df["trade_date"], errors="coerce")
st.caption("风险事件详情(按 cfg_id 聚合)。")
st.dataframe(risk_df, hide_index=True, width="stretch")
risk_df = risk_df.dropna(subset=["trade_date"])
if start_filter is None or end_filter is None:
start_filter = pd.to_datetime(risk_df["trade_date"].min()).date()
end_filter = pd.to_datetime(risk_df["trade_date"].max()).date()
risk_df = risk_df[
(risk_df["trade_date"].dt.date >= start_filter)
& (risk_df["trade_date"].dt.date <= end_filter)
]
parsed_cols: List[Dict[str, object]] = []
for _, row in risk_df.iterrows():
try:
metadata = json.loads(row["metadata"]) if isinstance(row["metadata"], str) else (row["metadata"] or {})
except json.JSONDecodeError:
metadata = {}
assessment = metadata.get("risk_assessment") or {}
parsed_cols.append(
{
"cfg_id": row["cfg_id"],
"trade_date": row["trade_date"].date().isoformat(),
"ts_code": row["ts_code"],
"reason": row["reason"],
"action": row["action"],
"target_weight": row["target_weight"],
"confidence": row["confidence"],
"risk_status": assessment.get("status"),
"recommended_action": assessment.get("recommended_action"),
"execution_status": metadata.get("execution_status"),
"metadata": metadata,
}
)
risk_detail_df = pd.DataFrame(parsed_cols)
with st.expander("风险事件明细", expanded=False):
st.dataframe(risk_detail_df.drop(columns=["metadata"], errors="ignore"), hide_index=True, width='stretch')
try:
st.download_button(
"下载风险事件(CSV)",
-data=risk_detail_df.to_csv(index=False),
data=risk_df.to_csv(index=False),
file_name="bt_risk_events.csv",
mime="text/csv",
key="dl_risk_events",
)
except Exception:
pass
-agg = risk_detail_df.groupby(["cfg_id", "reason", "risk_status"], dropna=False).size().reset_index(name="count")
-st.dataframe(agg, hide_index=True, width='stretch')
if {"cfg_id", "reason"}.issubset(risk_df.columns):
group_cols = [col for col in ["cfg_id", "reason", "risk_status"] if col in risk_df.columns]
agg = risk_df.groupby(group_cols, dropna=False).size().reset_index(name="count")
st.dataframe(agg, hide_index=True, width="stretch")
try:
if not agg.empty:
agg_fig = px.bar(
agg,
-x="reason",
x="reason" if "reason" in agg.columns else agg.columns[1],
y="count",
-color="risk_status",
color="risk_status" if "risk_status" in agg.columns else None,
-facet_col="cfg_id",
facet_col="cfg_id" if "cfg_id" in agg.columns else None,
title="风险事件分布",
)
agg_fig.update_layout(height=320, margin=dict(l=10, r=10, t=40, b=20))
st.plotly_chart(agg_fig, width="stretch")
except Exception:  # noqa: BLE001
LOGGER.debug("绘制风险事件分布失败", extra=LOG_EXTRA)
-except Exception:  # noqa: BLE001
-LOGGER.debug("渲染风险事件失败", extra=LOG_EXTRA)
-else:
-st.info("请选择至少一个配置进行对比。")
@@ -741,6 +934,8 @@ def render_backtest_review() -> None:
"总收益": episode.metrics.total_return,
"最大回撤": episode.metrics.max_drawdown,
"波动率": episode.metrics.volatility,
"Sharpe": episode.metrics.sharpe_like,
"Calmar": episode.metrics.calmar_like,
"权重": json.dumps(episode.weights or {}, ensure_ascii=False), "权重": json.dumps(episode.weights or {}, ensure_ascii=False),
"部门控制": json.dumps(episode.department_controls or {}, ensure_ascii=False), "部门控制": json.dumps(episode.department_controls or {}, ensure_ascii=False),
} }
@ -756,6 +951,12 @@ def render_backtest_review() -> None:
"action": best_episode.action if best_episode else None, "action": best_episode.action if best_episode else None,
"resolved_action": best_episode.resolved_action if best_episode else None, "resolved_action": best_episode.resolved_action if best_episode else None,
"weights": best_episode.weights if best_episode else None, "weights": best_episode.weights if best_episode else None,
"metrics": {
"total_return": best_episode.metrics.total_return if best_episode else None,
"sharpe_like": best_episode.metrics.sharpe_like if best_episode else None,
"calmar_like": best_episode.metrics.calmar_like if best_episode else None,
"max_drawdown": best_episode.metrics.max_drawdown if best_episode else None,
} if best_episode else None,
"department_controls": best_episode.department_controls if best_episode else None, "department_controls": best_episode.department_controls if best_episode else None,
}, },
"experiment_id": config.experiment_id, "experiment_id": config.experiment_id,
@ -781,6 +982,13 @@ def render_backtest_review() -> None:
col_best1.json(best_payload.get("action") or {}) col_best1.json(best_payload.get("action") or {})
col_best2.write("参数值:") col_best2.write("参数值:")
col_best2.json(best_payload.get("resolved_action") or {}) col_best2.json(best_payload.get("resolved_action") or {})
metrics_payload = best_payload.get("metrics") or {}
if metrics_payload:
col_m1, col_m2, col_m3 = st.columns(3)
col_m1.metric("总收益", f"{metrics_payload.get('total_return', 0.0):+.4f}")
col_m2.metric("Sharpe", f"{metrics_payload.get('sharpe_like', 0.0):.3f}")
col_m3.metric("Calmar", f"{metrics_payload.get('calmar_like', 0.0):.3f}")
st.caption(f"最大回撤:{metrics_payload.get('max_drawdown', 0.0):.3f}")
weights_payload = best_payload.get("weights") or {} weights_payload = best_payload.get("weights") or {}
if weights_payload: if weights_payload:
st.write("对应代理权重:") st.write("对应代理权重:")

View File

@@ -1,6 +1,7 @@
"""系统设置相关视图。"""
from __future__ import annotations
import json
from datetime import datetime, timedelta
from typing import Dict, List, Optional
@@ -10,6 +11,8 @@ from requests.exceptions import RequestException
import streamlit as st
from app.llm.client import llm_config_snapshot
from app.llm.metrics import snapshot as llm_metrics_snapshot
from app.llm.templates import TemplateRegistry
from app.utils.config import (
ALLOWED_LLM_STRATEGIES,
DEFAULT_LLM_BASE_URLS,
@@ -471,6 +474,239 @@ def render_llm_settings() -> None:
else:
dept_rows = dept_editor
st.divider()
st.markdown("##### 提示模板治理")
template_ids = TemplateRegistry.list_template_ids()
if not template_ids:
st.info("尚未注册任何提示模板。")
else:
template_id = st.selectbox(
"选择模板",
template_ids,
key="prompt_template_select",
)
version_details = TemplateRegistry.list_version_details(template_id)
raw_versions = TemplateRegistry.list_versions(template_id)
active_version = None
if version_details:
for detail in version_details:
if detail.get("is_active"):
active_version = detail["version"]
break
if active_version is None:
active_version = version_details[0]["version"]
usage_snapshot = llm_metrics_snapshot()
template_usage = usage_snapshot.get("template_usage", {}).get(template_id, {})
table_rows: List[Dict[str, object]] = []
for detail in version_details:
metadata_preview = detail.get("metadata") or {}
table_rows.append(
{
"版本": detail["version"],
"创建时间": detail.get("created_at") or "-",
"激活": "" if detail.get("is_active") else "",
"元数据": json.dumps(metadata_preview, ensure_ascii=False, default=str) if metadata_preview else "{}",
}
)
if table_rows:
st.dataframe(pd.DataFrame(table_rows), hide_index=True, width="stretch")
version_options = [row["版本"] for row in table_rows] if table_rows else []
if not version_options:
st.info("当前模板尚未创建版本,建议通过配置文件或 API 注册。")
else:
try:
default_idx = version_options.index(active_version or version_options[0])
except ValueError:
default_idx = 0
selected_version = st.selectbox(
"查看版本",
version_options,
index=default_idx,
key=f"{template_id}_version_select",
)
selected_detail = next(
(detail for detail in version_details if detail["version"] == selected_version),
{"metadata": {}},
)
usage_cols = st.columns(3)
usage_cols[0].metric("累计调用", int(template_usage.get("total_calls", 0)))
version_usage = (template_usage.get("versions") or {}).get(selected_version, {})
usage_cols[1].metric("版本调用", int(version_usage.get("calls", 0)))
usage_cols[2].metric(
"Prompt Tokens",
int(version_usage.get("prompt_tokens", 0)),
)
template_obj = TemplateRegistry.get(template_id, version=selected_version)
if template_obj:
with st.expander("模板内容预览", expanded=False):
st.write(f"名称:{template_obj.name}")
st.write(f"描述:{template_obj.description or '-'}")
st.write(f"变量:{', '.join(template_obj.variables) if template_obj.variables else ''}")
st.code(template_obj.template, language="markdown")
metadata_str = json.dumps(selected_detail.get("metadata") or {}, ensure_ascii=False, indent=2, default=str)
metadata_input = st.text_area(
"版本元数据JSON",
value=metadata_str,
height=200,
key=f"{template_id}_{selected_version}_metadata",
)
meta_buttons = st.columns(3)
enable_version_actions = bool(raw_versions)
if meta_buttons[0].button(
"保存元数据",
key=f"{template_id}_{selected_version}_save_metadata",
disabled=not enable_version_actions,
):
try:
new_metadata = json.loads(metadata_input or "{}")
except json.JSONDecodeError as exc:
st.error(f"元数据格式错误:{exc}")
else:
try:
TemplateRegistry.update_version_metadata(template_id, selected_version, new_metadata)
st.success("元数据已更新。")
st.rerun()
except Exception as exc: # noqa: BLE001
st.error(f"更新元数据失败:{exc}")
if meta_buttons[1].button(
"设为激活版本",
key=f"{template_id}_{selected_version}_activate",
disabled=(selected_version == active_version) or not enable_version_actions,
):
try:
TemplateRegistry.activate_version(template_id, selected_version)
st.success(f"{template_id} 已切换至版本 {selected_version}")
st.rerun()
except Exception as exc: # noqa: BLE001
st.error(f"切换版本失败:{exc}")
export_payload = TemplateRegistry.export_versions(template_id) if enable_version_actions else None
meta_buttons[2].download_button(
"导出版本 JSON",
data=export_payload or "",
file_name=f"{template_id}_versions.json",
mime="application/json",
key=f"{template_id}_download_versions",
disabled=not export_payload,
)
st.divider()
st.markdown("##### 部门遥测可视化")
telemetry_limit = st.slider(
"遥测查询条数",
min_value=50,
max_value=500,
value=200,
step=50,
help="限制查询 agent_utils 表中的最新记录数量。",
key="telemetry_limit",
)
telemetry_rows: List[Dict[str, object]] = []
try:
with db_session(read_only=True) as conn:
raw_rows = conn.execute(
"""
SELECT trade_date, ts_code, agent, utils
FROM agent_utils
ORDER BY trade_date DESC, ts_code
LIMIT ?
""",
(telemetry_limit,),
).fetchall()
except Exception: # noqa: BLE001
LOGGER.exception("读取 agent_utils 遥测失败", extra=LOG_EXTRA)
raw_rows = []
for row in raw_rows:
trade_date = row["trade_date"]
ts_code = row["ts_code"]
agent = row["agent"]
try:
utils_payload = json.loads(row["utils"] or "{}")
except json.JSONDecodeError:
utils_payload = {}
if agent == "global":
telemetry_map = utils_payload.get("_department_telemetry") or {}
for dept_code, payload in telemetry_map.items():
if not isinstance(payload, dict):
payload = {"value": payload}
record = {
"trade_date": trade_date,
"ts_code": ts_code,
"agent": agent,
"department": dept_code,
"source": "global",
"telemetry": json.dumps(payload, ensure_ascii=False, default=str),
}
for key, value in payload.items():
if isinstance(value, (int, float, bool, str)):
record.setdefault(key, value)
telemetry_rows.append(record)
elif agent.startswith("dept_"):
dept_code = agent.split("dept_", 1)[-1]
payload = utils_payload.get("_telemetry") or {}
if not isinstance(payload, dict):
payload = {"value": payload}
record = {
"trade_date": trade_date,
"ts_code": ts_code,
"agent": agent,
"department": dept_code,
"source": "department",
"telemetry": json.dumps(payload, ensure_ascii=False, default=str),
}
for key, value in payload.items():
if isinstance(value, (int, float, bool, str)):
record.setdefault(key, value)
telemetry_rows.append(record)
if not telemetry_rows:
st.info("未找到遥测记录,可先运行部门评估流程。")
else:
telemetry_df = pd.DataFrame(telemetry_rows)
telemetry_df["trade_date"] = telemetry_df["trade_date"].astype(str)
trade_dates = sorted(telemetry_df["trade_date"].unique(), reverse=True)
selected_trade_date = st.selectbox(
"交易日",
trade_dates,
index=0,
key="telemetry_trade_date",
)
filtered_df = telemetry_df[telemetry_df["trade_date"] == selected_trade_date]
departments = sorted(filtered_df["department"].dropna().unique())
selected_departments = st.multiselect(
"部门过滤",
departments,
default=departments,
key="telemetry_departments",
)
if selected_departments:
filtered_df = filtered_df[filtered_df["department"].isin(selected_departments)]
numeric_columns = [
col
for col in filtered_df.columns
if col not in {"trade_date", "ts_code", "agent", "department", "source", "telemetry"}
and pd.api.types.is_numeric_dtype(filtered_df[col])
]
metric_cols = st.columns(min(3, max(1, len(numeric_columns))))
for idx, column in enumerate(numeric_columns[: len(metric_cols)]):
column_series = filtered_df[column].dropna()
value = column_series.mean() if not column_series.empty else 0.0
metric_cols[idx].metric(f"{column} 均值", f"{value:.2f}")
st.dataframe(filtered_df, hide_index=True, width="stretch")
st.download_button(
"下载遥测 CSV",
data=filtered_df.to_csv(index=False),
file_name=f"telemetry_{selected_trade_date}.csv",
mime="text/csv",
key="telemetry_download",
)
col_reset, col_save = st.columns([1, 1])
if col_save.button("保存部门配置"):

View File

@@ -18,9 +18,9 @@
| 工作项 | 状态 | 说明 |
| --- | --- | --- |
-| DecisionEnv 扩展 | 🔄 | 已支持多步 episode 与部分动作维度,需继续覆盖提示版本、function 策略等。 |
| DecisionEnv 扩展 | 🔄 | Episode 指标新增 Sharpe/Calmar,奖励函数引入风险惩罚,继续覆盖提示版本、function 策略等。 |
| 强化学习基线 | ✅ | PPO/SAC 等连续动作算法已接入并形成实验基线。 |
-| 奖励与评估体系 | ⏳ | 需将 `portfolio_trades`/`portfolio_snapshots` 等指标纳入奖励与评估。 |
| 奖励与评估体系 | 🔄 | 决策环境奖励已纳入风险/Turnover/Sharpe-Calmar,待接入成交与资金曲线指标。 |
| 实时持仓链路 | ⏳ | 建立线上持仓/成交写入与离线调参与监控共享的数据源。 |
| 全局参数搜索 | ⏳ | 引入 Bandit、贝叶斯优化或 BOHB 提供权重/参数候选。 |
@@ -29,9 +29,9 @@
| 工作项 | 状态 | 说明 |
| --- | --- | --- |
| Provider 与 function 架构 | ✅ | Provider 管理、function-calling 降级与重试策略已收敛。 |
-| 提示模板治理 | ⏳ | 待建立模板版本管理、成本监控与性能指标。 |
| 提示模板治理 | 🔄 | LLM 设置新增模板版本治理与使用监控,后续补充成本/效果数据。 |
-| 部门遥测可视化 | ⏳ | `_telemetry` / `_department_telemetry` 字段需在 UI 中完整展示。 |
| 部门遥测可视化 | 🔄 | LLM 设置新增遥测面板,支持分页查看/导出部门 & 全局遥测。 |
-| 多轮逻辑博弈框架 | ⏳ | 需实现主持/预测/风险/执行分轮对话、信念修正与冲突解决。 |
| 多轮逻辑博弈框架 | 🔄 | 新增主持 briefing、预测对齐及冲突复核轮,持续完善信念修正策略。 |
| LLM 稳定性提升 | ⏳ | 持续优化限速、降级、成本控制与缓存策略。 |
## UI 与监控
@@ -39,7 +39,7 @@
| 工作项 | 状态 | 说明 |
| --- | --- | --- |
| 一键重评估入口 | ✅ | 今日计划页提供批量/全量重评估入口,待收集反馈再做优化。 |
-| 回测实验对比 | ⏳ | 提供提示/温度多版本实验管理与曲线对比。 |
| 回测实验对比 | 🔄 | 新增会话实验保存与曲线/指标对比,后续接入更多提示参数。 |
| 实时指标面板 | ✅ | Streamlit 监控页已具备核心实时指标。 |
| 异常日志钻取 | ⏳ | 待补充筛选、定位与历史对比能力。 |
| 仅监控模式 | ⏳ | 支持“监控不干预”场景的一键复评策略。 |

View File

@@ -31,10 +31,14 @@ class DummyEnv:
def step(self, action):
value = float(action[0])
reward = 1.0 - abs(value - 0.7)
sharpe_like = reward / 0.05
calmar_like = reward / 0.1
metrics = EpisodeMetrics(
total_return=reward,
max_drawdown=0.1,
volatility=0.05,
sharpe_like=sharpe_like,
calmar_like=calmar_like,
nav_series=[],
trades=[],
turnover=0.1,
@@ -49,6 +53,7 @@
"max_drawdown": 0.1,
"volatility": 0.05,
"sharpe_like": reward / 0.05,
"calmar_like": reward / 0.1,
"turnover": 0.1, "turnover": 0.1,
"turnover_value": 1000.0, "turnover_value": 1000.0,
"trade_count": 0.0, "trade_count": 0.0,

View File

@@ -139,10 +139,17 @@ def test_decision_env_returns_risk_metrics(monkeypatch):
def test_default_reward_penalizes_metrics():
total_return = 0.1
max_drawdown = 0.2
volatility = 0.05
sharpe_like = total_return / volatility
calmar_like = total_return / max_drawdown
metrics = EpisodeMetrics(
-total_return=0.1,
total_return=total_return,
-max_drawdown=0.2,
max_drawdown=max_drawdown,
-volatility=0.05,
volatility=volatility,
sharpe_like=sharpe_like,
calmar_like=calmar_like,
nav_series=[],
trades=[],
turnover=0.3,
@@ -152,7 +159,13 @@ def test_default_reward_penalizes_metrics():
risk_breakdown={"foo": 2},
)
reward = DecisionEnv._default_reward(metrics)
-assert reward == pytest.approx(0.1 - (0.5 * 0.2 + 0.05 * 2 + 0.1 * 0.3))
expected = (
total_return
+ 0.1 * sharpe_like
+ 0.05 * calmar_like
- (0.5 * max_drawdown + 0.05 * metrics.risk_count + 0.1 * metrics.turnover)
)
assert reward == pytest.approx(expected)
def test_decision_env_department_controls(monkeypatch):