add briefing rounds and enhance backtest comparison view
This commit is contained in:
parent
f9f8ca887f
commit
f6c11867d2
@ -176,6 +176,13 @@ def decide(
|
||||
ts_code=context.ts_code,
|
||||
trade_date=context.trade_date,
|
||||
)
|
||||
briefing_round = host.start_round(
|
||||
host_trace,
|
||||
agenda="situation_briefing",
|
||||
structure=GameStructure.SIGNALING,
|
||||
)
|
||||
host.handle_message(briefing_round, _host_briefing_message(context))
|
||||
host.finalize_round(briefing_round)
|
||||
department_round: Optional[RoundSummary] = None
|
||||
risk_round: Optional[RoundSummary] = None
|
||||
execution_round: Optional[RoundSummary] = None
|
||||
@ -224,6 +231,19 @@ def decide(
|
||||
filtered_utilities = {action: utilities[action] for action in feas_actions}
|
||||
hold_scores = utilities.get(AgentAction.HOLD, {})
|
||||
norm_weights = weight_map(raw_weights)
|
||||
prediction_round = host.start_round(
|
||||
host_trace,
|
||||
agenda="prediction_alignment",
|
||||
structure=GameStructure.REPEATED,
|
||||
)
|
||||
prediction_message, prediction_summary = _prediction_summary_message(filtered_utilities, norm_weights)
|
||||
host.handle_message(prediction_round, prediction_message)
|
||||
host.finalize_round(prediction_round)
|
||||
if prediction_summary:
|
||||
belief_updates["prediction_summary"] = BeliefUpdate(
|
||||
belief=prediction_summary,
|
||||
rationale="Aggregated utilities shared during alignment round.",
|
||||
)
|
||||
|
||||
if method == "vote":
|
||||
action, confidence = vote(filtered_utilities, norm_weights)
|
||||
@ -339,6 +359,22 @@ def decide(
|
||||
department_votes,
|
||||
)
|
||||
belief_revision = revise_beliefs(belief_updates, exec_action)
|
||||
if belief_revision.conflicts:
|
||||
risk_round = host.ensure_round(
|
||||
host_trace,
|
||||
agenda="conflict_resolution",
|
||||
structure=GameStructure.CUSTOM,
|
||||
)
|
||||
conflict_message = DialogueMessage(
|
||||
sender="protocol_host",
|
||||
role=DialogueRole.HOST,
|
||||
message_type=MessageType.COUNTER,
|
||||
content="检测到关键冲突,需要后续回合复核。",
|
||||
annotations={"conflicts": belief_revision.conflicts},
|
||||
)
|
||||
host.handle_message(risk_round, conflict_message)
|
||||
risk_round.notes.setdefault("conflicts", belief_revision.conflicts)
|
||||
host.finalize_round(risk_round)
|
||||
execution_round.notes.setdefault("consensus_action", belief_revision.consensus_action.value)
|
||||
execution_round.notes.setdefault("consensus_confidence", belief_revision.consensus_confidence)
|
||||
if belief_revision.conflicts:
|
||||
@ -413,6 +449,73 @@ def _department_conflict_flag(votes: Mapping[str, float]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _host_briefing_message(context: AgentContext) -> DialogueMessage:
|
||||
features = getattr(context, "features", {}) or {}
|
||||
close = features.get("close") or features.get("daily.close")
|
||||
pct_chg = features.get("pct_chg") or features.get("daily.pct_chg")
|
||||
snapshot = getattr(context, "market_snapshot", {}) or {}
|
||||
index_brief = snapshot.get("index_change")
|
||||
lines = [
|
||||
f"标的 {context.ts_code}",
|
||||
f"交易日 {context.trade_date}",
|
||||
]
|
||||
if close is not None:
|
||||
lines.append(f"最新收盘价:{close}")
|
||||
if pct_chg is not None:
|
||||
lines.append(f"涨跌幅:{pct_chg}")
|
||||
if index_brief:
|
||||
lines.append(f"市场概览:{index_brief}")
|
||||
content = ";".join(str(line) for line in lines)
|
||||
return DialogueMessage(
|
||||
sender="protocol_host",
|
||||
role=DialogueRole.HOST,
|
||||
message_type=MessageType.META,
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
def _prediction_summary_message(
|
||||
utilities: Mapping[AgentAction, Mapping[str, float]],
|
||||
weights: Mapping[str, float],
|
||||
) -> Tuple[DialogueMessage, Dict[str, object]]:
|
||||
if not utilities:
|
||||
message = DialogueMessage(
|
||||
sender="protocol_host",
|
||||
role=DialogueRole.PREDICTION,
|
||||
message_type=MessageType.META,
|
||||
content="暂无可用的部门或代理评分,默认进入 HOLD 讨论。",
|
||||
)
|
||||
return message, {}
|
||||
aggregates: Dict[AgentAction, float] = {}
|
||||
for action, agent_scores in utilities.items():
|
||||
aggregates[action] = sum(weights.get(agent, 0.0) * score for agent, score in agent_scores.items())
|
||||
ranked = sorted(aggregates.items(), key=lambda item: item[1], reverse=True)
|
||||
summary_lines = []
|
||||
for action, score in ranked[:3]:
|
||||
summary_lines.append(f"{action.value}: {score:.3f}")
|
||||
content = "预测合意度:" + " | ".join(summary_lines)
|
||||
total_score = sum(max(score, 0.0) for _, score in ranked)
|
||||
confidence = 0.0
|
||||
if total_score > 0 and ranked:
|
||||
confidence = max(ranked[0][1], 0.0) / total_score
|
||||
annotations = {
|
||||
"aggregates": {action.value: score for action, score in aggregates.items()},
|
||||
}
|
||||
message = DialogueMessage(
|
||||
sender="protocol_host",
|
||||
role=DialogueRole.PREDICTION,
|
||||
message_type=MessageType.META,
|
||||
content=content,
|
||||
confidence=confidence,
|
||||
annotations=annotations,
|
||||
)
|
||||
summary = {
|
||||
"aggregates": {action.value: aggregates[action] for action in ACTIONS if action in aggregates},
|
||||
"confidence": confidence,
|
||||
}
|
||||
return message, summary
|
||||
|
||||
|
||||
def _department_message(code: str, decision: DepartmentDecision) -> DialogueMessage:
|
||||
content = decision.summary or decision.raw_response or decision.action.value
|
||||
references = decision.signals or []
|
||||
|
||||
@ -47,6 +47,8 @@ class EpisodeMetrics:
|
||||
total_return: float
|
||||
max_drawdown: float
|
||||
volatility: float
|
||||
sharpe_like: float
|
||||
calmar_like: float
|
||||
nav_series: List[Dict[str, float]]
|
||||
trades: List[Dict[str, object]]
|
||||
turnover: float
|
||||
@ -55,12 +57,6 @@ class EpisodeMetrics:
|
||||
risk_count: int
|
||||
risk_breakdown: Dict[str, int]
|
||||
|
||||
@property
|
||||
def sharpe_like(self) -> float:
|
||||
if self.volatility <= 1e-9:
|
||||
return 0.0
|
||||
return self.total_return / self.volatility
|
||||
|
||||
|
||||
class DecisionEnv:
|
||||
"""Thin RL-friendly wrapper that evaluates parameter actions via backtest."""
|
||||
@ -123,6 +119,7 @@ class DecisionEnv:
|
||||
"volatility": 0.0,
|
||||
"turnover": 0.0,
|
||||
"sharpe_like": 0.0,
|
||||
"calmar_like": 0.0,
|
||||
"trade_count": 0.0,
|
||||
"risk_count": 0.0,
|
||||
}
|
||||
@ -370,6 +367,8 @@ class DecisionEnv:
|
||||
total_return=0.0,
|
||||
max_drawdown=0.0,
|
||||
volatility=0.0,
|
||||
sharpe_like=0.0,
|
||||
calmar_like=0.0,
|
||||
nav_series=[],
|
||||
trades=trades or [],
|
||||
turnover=0.0,
|
||||
@ -403,6 +402,8 @@ class DecisionEnv:
|
||||
volatility = math.sqrt(variance) / base_nav
|
||||
else:
|
||||
volatility = 0.0
|
||||
sharpe_like = total_return / volatility if abs(volatility) > 1e-9 else 0.0
|
||||
calmar_like = total_return / max_drawdown if max_drawdown > 1e-6 else total_return
|
||||
|
||||
turnover_value = 0.0
|
||||
turnover_ratios: List[float] = []
|
||||
@ -433,6 +434,8 @@ class DecisionEnv:
|
||||
total_return=float(total_return),
|
||||
max_drawdown=float(max_drawdown),
|
||||
volatility=volatility,
|
||||
sharpe_like=float(sharpe_like),
|
||||
calmar_like=float(calmar_like),
|
||||
nav_series=nav_series,
|
||||
trades=trades or [],
|
||||
turnover=float(avg_turnover_ratio),
|
||||
@ -446,8 +449,9 @@ class DecisionEnv:
|
||||
def _default_reward(metrics: EpisodeMetrics) -> float:
|
||||
risk_penalty = 0.05 * metrics.risk_count
|
||||
turnover_penalty = 0.1 * metrics.turnover
|
||||
penalty = 0.5 * metrics.max_drawdown + risk_penalty + turnover_penalty
|
||||
return metrics.total_return - penalty
|
||||
drawdown_penalty = 0.5 * metrics.max_drawdown
|
||||
bonus = 0.1 * metrics.sharpe_like + 0.05 * metrics.calmar_like
|
||||
return metrics.total_return + bonus - (drawdown_penalty + risk_penalty + turnover_penalty)
|
||||
|
||||
def _build_observation(
|
||||
self,
|
||||
@ -461,6 +465,7 @@ class DecisionEnv:
|
||||
"max_drawdown": metrics.max_drawdown,
|
||||
"volatility": metrics.volatility,
|
||||
"sharpe_like": metrics.sharpe_like,
|
||||
"calmar_like": metrics.calmar_like,
|
||||
"turnover": metrics.turnover,
|
||||
"trade_count": float(metrics.trade_count),
|
||||
"risk_count": float(metrics.risk_count),
|
||||
@ -627,6 +632,8 @@ class DecisionEnv:
|
||||
total_return=0.0,
|
||||
max_drawdown=0.0,
|
||||
volatility=0.0,
|
||||
sharpe_like=0.0,
|
||||
calmar_like=0.0,
|
||||
nav_series=nav_series,
|
||||
trades=trades,
|
||||
turnover=0.0,
|
||||
|
||||
@ -22,7 +22,7 @@ from app.utils.config import (
|
||||
LLMEndpoint,
|
||||
get_config,
|
||||
)
|
||||
from app.llm.metrics import record_call
|
||||
from app.llm.metrics import record_call, record_template_usage
|
||||
from app.utils.logging import get_logger
|
||||
|
||||
LOGGER = get_logger(__name__)
|
||||
@ -332,10 +332,12 @@ def run_llm(
|
||||
context = None
|
||||
|
||||
# Apply template if specified
|
||||
applied_template_version: Optional[str] = None
|
||||
if template_id:
|
||||
template = TemplateRegistry.get(template_id)
|
||||
if not template:
|
||||
raise ValueError(f"Template {template_id} not found")
|
||||
applied_template_version = TemplateRegistry.get_active_version(template_id)
|
||||
vars_dict = template_vars or {}
|
||||
if isinstance(prompt, str):
|
||||
vars_dict["prompt"] = prompt
|
||||
@ -356,6 +358,11 @@ def run_llm(
|
||||
if context:
|
||||
context.add_message(Message(role="assistant", content=response))
|
||||
|
||||
if template_id:
|
||||
record_template_usage(
|
||||
template_id,
|
||||
version=applied_template_version,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
"""Simple runtime metrics collector for LLM calls."""
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import logging
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
@ -19,6 +20,7 @@ class _Metrics:
|
||||
decision_action_counts: Dict[str, int] = field(default_factory=dict)
|
||||
total_latency: float = 0.0
|
||||
latency_samples: Deque[float] = field(default_factory=lambda: deque(maxlen=200))
|
||||
template_usage: Dict[str, Dict[str, object]] = field(default_factory=dict)
|
||||
|
||||
|
||||
_METRICS = _Metrics()
|
||||
@ -78,6 +80,7 @@ def snapshot(reset: bool = False) -> Dict[str, object]:
|
||||
else 0.0
|
||||
),
|
||||
"latency_samples": list(_METRICS.latency_samples),
|
||||
"template_usage": copy.deepcopy(_METRICS.template_usage),
|
||||
}
|
||||
if reset:
|
||||
_METRICS.total_calls = 0
|
||||
@ -89,6 +92,7 @@ def snapshot(reset: bool = False) -> Dict[str, object]:
|
||||
_METRICS.decisions.clear()
|
||||
_METRICS.total_latency = 0.0
|
||||
_METRICS.latency_samples.clear()
|
||||
_METRICS.template_usage.clear()
|
||||
return data
|
||||
|
||||
|
||||
@ -128,6 +132,38 @@ def record_decision(
|
||||
_notify_listeners()
|
||||
|
||||
|
||||
def record_template_usage(
|
||||
template_id: str,
|
||||
*,
|
||||
version: Optional[str],
|
||||
prompt_tokens: Optional[int] = None,
|
||||
completion_tokens: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Record usage statistics for a specific prompt template."""
|
||||
|
||||
if not template_id:
|
||||
return
|
||||
label = template_id.strip()
|
||||
version_label = version or "active"
|
||||
with _LOCK:
|
||||
entry = _METRICS.template_usage.setdefault(
|
||||
label,
|
||||
{"total_calls": 0, "versions": {}},
|
||||
)
|
||||
entry["total_calls"] = int(entry.get("total_calls", 0)) + 1
|
||||
versions = entry.setdefault("versions", {})
|
||||
version_entry = versions.setdefault(
|
||||
version_label,
|
||||
{"calls": 0, "prompt_tokens": 0, "completion_tokens": 0},
|
||||
)
|
||||
version_entry["calls"] = int(version_entry.get("calls", 0)) + 1
|
||||
if prompt_tokens:
|
||||
version_entry["prompt_tokens"] = int(version_entry.get("prompt_tokens", 0)) + int(prompt_tokens)
|
||||
if completion_tokens:
|
||||
version_entry["completion_tokens"] = int(version_entry.get("completion_tokens", 0)) + int(completion_tokens)
|
||||
_notify_listeners()
|
||||
|
||||
|
||||
def recent_decisions(limit: int = 50) -> List[Dict[str, object]]:
|
||||
"""Return the most recent decisions up to limit."""
|
||||
|
||||
|
||||
@ -199,6 +199,14 @@ class TemplateRegistry:
|
||||
collected[template_id] = active.template if active else template
|
||||
return list(collected.values())
|
||||
|
||||
@classmethod
|
||||
def list_template_ids(cls) -> List[str]:
|
||||
"""Return all known template IDs in sorted order."""
|
||||
ids = set(cls._templates.keys())
|
||||
manager = cls._manager()
|
||||
ids.update(manager.list_template_ids())
|
||||
return sorted(ids)
|
||||
|
||||
@classmethod
|
||||
def list_versions(cls, template_id: str) -> List[str]:
|
||||
"""List available version labels for a template."""
|
||||
@ -206,6 +214,49 @@ class TemplateRegistry:
|
||||
manager = cls._manager()
|
||||
return [ver.version for ver in manager.list_versions(template_id)]
|
||||
|
||||
@classmethod
|
||||
def list_version_details(cls, template_id: str) -> List[Dict[str, Any]]:
|
||||
"""Return detailed information for each version of a template."""
|
||||
|
||||
manager = cls._manager()
|
||||
versions = manager.list_version_details(template_id)
|
||||
details: List[Dict[str, Any]] = []
|
||||
for entry in versions:
|
||||
details.append(
|
||||
{
|
||||
"version": entry.version,
|
||||
"created_at": entry.created_at,
|
||||
"is_active": entry.is_active,
|
||||
"metadata": dict(entry.metadata or {}),
|
||||
}
|
||||
)
|
||||
if not details and template_id in cls._templates:
|
||||
details.append(
|
||||
{
|
||||
"version": cls._default_version_label,
|
||||
"created_at": None,
|
||||
"is_active": True,
|
||||
"metadata": {},
|
||||
}
|
||||
)
|
||||
return details
|
||||
|
||||
@classmethod
|
||||
def update_version_metadata(cls, template_id: str, version: str, metadata: Dict[str, Any]) -> None:
|
||||
"""Update metadata for a template version."""
|
||||
|
||||
manager = cls._manager()
|
||||
manager.update_metadata(template_id, version, metadata)
|
||||
|
||||
@classmethod
|
||||
def export_versions(cls, template_id: str) -> Optional[str]:
|
||||
"""Export template versions for backup."""
|
||||
|
||||
manager = cls._manager()
|
||||
if not manager.list_versions(template_id):
|
||||
return None
|
||||
return manager.export_versions(template_id)
|
||||
|
||||
@classmethod
|
||||
def load_from_json(cls, json_str: str) -> None:
|
||||
"""Load templates from JSON string."""
|
||||
|
||||
@ -80,9 +80,13 @@ class TemplateVersionManager:
|
||||
self._versions: Dict[str, Dict[str, TemplateVersion]] = {}
|
||||
self._active_versions: Dict[str, str] = {}
|
||||
|
||||
def add_version(self, template: PromptTemplate, version: str,
|
||||
def add_version(
|
||||
self,
|
||||
template: PromptTemplate,
|
||||
version: str,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
activate: bool = False) -> TemplateVersion:
|
||||
activate: bool = False,
|
||||
) -> TemplateVersion:
|
||||
"""Add a new template version."""
|
||||
if template.id not in self._versions:
|
||||
self._versions[template.id] = {}
|
||||
@ -111,6 +115,14 @@ class TemplateVersionManager:
|
||||
"""List all versions of a template."""
|
||||
return list(self._versions.get(template_id, {}).values())
|
||||
|
||||
def list_template_ids(self) -> List[str]:
|
||||
"""Return template IDs currently tracked by the version manager."""
|
||||
return list(self._versions.keys())
|
||||
|
||||
def list_version_details(self, template_id: str) -> List[TemplateVersion]:
|
||||
"""Alias for list_versions to expose structured version records."""
|
||||
return self.list_versions(template_id)
|
||||
|
||||
def get_active_version(self, template_id: str) -> Optional[TemplateVersion]:
|
||||
"""Get the active version of a template."""
|
||||
active_version = self._active_versions.get(template_id)
|
||||
@ -179,3 +191,10 @@ class TemplateVersionManager:
|
||||
active_version = data.get("active_version")
|
||||
if active_version:
|
||||
self.activate_version(template_id, active_version)
|
||||
|
||||
def update_metadata(self, template_id: str, version: str, metadata: Dict[str, Any]) -> None:
|
||||
"""Update metadata for a specific template version."""
|
||||
version_obj = self.get_version(template_id, version)
|
||||
if version_obj is None:
|
||||
raise ValueError(f"Version {version} not found for template {template_id}")
|
||||
version_obj.metadata = metadata or {}
|
||||
|
||||
@ -3,11 +3,12 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import date, datetime
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import pandas as pd
|
||||
import plotly.express as px
|
||||
import streamlit as st
|
||||
import numpy as np
|
||||
|
||||
from app.agents.base import AgentContext
|
||||
from app.agents.game import Decision
|
||||
@ -41,6 +42,80 @@ _DECISION_ENV_BATCH_RESULTS_KEY = "decision_env_batch_results"
|
||||
_DECISION_ENV_BANDIT_RESULTS_KEY = "decision_env_bandit_results"
|
||||
_DECISION_ENV_PPO_RESULTS_KEY = "decision_env_ppo_results"
|
||||
|
||||
|
||||
def _normalize_nav_records(records: List[Dict[str, object]]) -> pd.DataFrame:
|
||||
"""Return nav dataframe with columns trade_date (datetime) and nav (float)."""
|
||||
if not records:
|
||||
return pd.DataFrame(columns=["trade_date", "nav"])
|
||||
df = pd.DataFrame(records)
|
||||
if "trade_date" not in df.columns:
|
||||
if "date" in df.columns:
|
||||
df = df.rename(columns={"date": "trade_date"})
|
||||
elif "ts" in df.columns:
|
||||
df = df.rename(columns={"ts": "trade_date"})
|
||||
if "nav" not in df.columns:
|
||||
# fallback: look for value column
|
||||
candidates = [col for col in df.columns if col not in {"trade_date", "date", "ts"}]
|
||||
if candidates:
|
||||
df = df.rename(columns={candidates[0]: "nav"})
|
||||
if "trade_date" not in df.columns or "nav" not in df.columns:
|
||||
return pd.DataFrame(columns=["trade_date", "nav"])
|
||||
df = df.copy()
|
||||
df["trade_date"] = pd.to_datetime(df["trade_date"], errors="coerce")
|
||||
df["nav"] = pd.to_numeric(df["nav"], errors="coerce")
|
||||
df = df.dropna(subset=["trade_date", "nav"]).sort_values("trade_date")
|
||||
return df[["trade_date", "nav"]]
|
||||
|
||||
|
||||
def _normalize_trade_records(records: List[Dict[str, object]]) -> pd.DataFrame:
|
||||
if not records:
|
||||
return pd.DataFrame()
|
||||
df = pd.DataFrame(records)
|
||||
if "trade_date" not in df.columns:
|
||||
for candidate in ("date", "ts", "timestamp"):
|
||||
if candidate in df.columns:
|
||||
df = df.rename(columns={candidate: "trade_date"})
|
||||
break
|
||||
if "trade_date" in df.columns:
|
||||
df["trade_date"] = pd.to_datetime(df["trade_date"], errors="coerce")
|
||||
return df
|
||||
|
||||
|
||||
def _compute_nav_metrics(nav_df: pd.DataFrame, trades_df: pd.DataFrame) -> Dict[str, object]:
|
||||
if nav_df.empty:
|
||||
return {
|
||||
"total_return": None,
|
||||
"max_drawdown": None,
|
||||
"trade_count": len(trades_df),
|
||||
"avg_turnover": None,
|
||||
"risk_events": None,
|
||||
}
|
||||
values = nav_df["nav"].astype(float).values
|
||||
if values.size == 0:
|
||||
return {
|
||||
"total_return": None,
|
||||
"max_drawdown": None,
|
||||
"trade_count": len(trades_df),
|
||||
"avg_turnover": None,
|
||||
"risk_events": None,
|
||||
}
|
||||
total_return = float(values[-1] / values[0] - 1.0) if values[0] != 0 else None
|
||||
cumulative_max = np.maximum.accumulate(values)
|
||||
drawdowns = (values - cumulative_max) / cumulative_max
|
||||
max_drawdown = float(drawdowns.min()) if drawdowns.size else None
|
||||
summary = {
|
||||
"total_return": total_return,
|
||||
"max_drawdown": max_drawdown,
|
||||
"trade_count": int(len(trades_df)),
|
||||
"avg_turnover": trades_df["turnover"].mean() if "turnover" in trades_df.columns and not trades_df.empty else None,
|
||||
"risk_events": None,
|
||||
}
|
||||
return summary
|
||||
|
||||
|
||||
def _session_compare_store() -> Dict[str, Dict[str, object]]:
|
||||
return st.session_state.setdefault("bt_compare_runs", {})
|
||||
|
||||
def render_backtest_review() -> None:
|
||||
"""渲染回测执行、调参与结果复盘页面。"""
|
||||
st.header("回测与复盘")
|
||||
@ -220,8 +295,24 @@ def render_backtest_review() -> None:
|
||||
}
|
||||
)
|
||||
update_dashboard_sidebar(metrics)
|
||||
st.session_state["backtest_last_result"] = {"nav_records": result.nav_series, "trades": result.trades}
|
||||
st.json(st.session_state["backtest_last_result"])
|
||||
nav_df = _normalize_nav_records(result.nav_series)
|
||||
trades_df = _normalize_trade_records(result.trades)
|
||||
summary = _compute_nav_metrics(nav_df, trades_df)
|
||||
summary["risk_events"] = len(result.risk_events or [])
|
||||
st.session_state["backtest_last_result"] = {
|
||||
"nav_records": nav_df.to_dict(orient="records"),
|
||||
"trades": trades_df.to_dict(orient="records"),
|
||||
"risk_events": result.risk_events or [],
|
||||
}
|
||||
st.session_state["backtest_last_summary"] = summary
|
||||
st.session_state["backtest_last_config"] = {
|
||||
"start_date": start_date.isoformat(),
|
||||
"end_date": end_date.isoformat(),
|
||||
"universe": universe,
|
||||
"params": dict(backtest_params),
|
||||
"structures": selected_structures,
|
||||
"name": backtest_cfg.name,
|
||||
}
|
||||
except Exception as exc: # noqa: BLE001
|
||||
LOGGER.exception("回测执行失败", extra=LOG_EXTRA)
|
||||
status_box.update(label="回测执行失败", state="error")
|
||||
@ -229,16 +320,81 @@ def render_backtest_review() -> None:
|
||||
|
||||
last_result = st.session_state.get("backtest_last_result")
|
||||
if last_result:
|
||||
last_summary = st.session_state.get("backtest_last_summary", {})
|
||||
last_config = st.session_state.get("backtest_last_config", {})
|
||||
st.markdown("#### 最近回测输出")
|
||||
st.json(last_result)
|
||||
nav_preview = _normalize_nav_records(last_result.get("nav_records", []))
|
||||
if not nav_preview.empty:
|
||||
import plotly.graph_objects as go
|
||||
|
||||
fig_last = go.Figure()
|
||||
fig_last.add_trace(
|
||||
go.Scatter(
|
||||
x=nav_preview["trade_date"],
|
||||
y=nav_preview["nav"],
|
||||
mode="lines",
|
||||
name="NAV",
|
||||
)
|
||||
)
|
||||
fig_last.update_layout(height=260, margin=dict(l=10, r=10, t=30, b=10))
|
||||
st.plotly_chart(fig_last, width="stretch")
|
||||
|
||||
metric_cols = st.columns(4)
|
||||
metric_cols[0].metric("总收益", f"{(last_summary.get('total_return') or 0.0)*100:.2f}%", delta=None)
|
||||
metric_cols[1].metric("最大回撤", f"{(last_summary.get('max_drawdown') or 0.0)*100:.2f}%")
|
||||
metric_cols[2].metric("交易数", last_summary.get("trade_count", 0))
|
||||
metric_cols[3].metric("风险事件", last_summary.get("risk_events", 0))
|
||||
|
||||
default_label = (
|
||||
f"{last_config.get('name', '临时实验')} | {last_config.get('start_date', '')}~{last_config.get('end_date', '')}"
|
||||
).strip(" |~")
|
||||
save_col, button_col = st.columns([4, 1])
|
||||
save_label = save_col.text_input(
|
||||
"保存至实验对比(可编辑标签)",
|
||||
value=default_label or f"实验_{datetime.now().strftime('%H%M%S')}",
|
||||
key="bt_save_label",
|
||||
)
|
||||
if button_col.button("保存", key="bt_save_button"):
|
||||
label = save_label.strip() or f"实验_{datetime.now().strftime('%H%M%S')}"
|
||||
store = _session_compare_store()
|
||||
store[label] = {
|
||||
"cfg_id": f"session::{label}",
|
||||
"nav": nav_preview.to_dict(orient="records"),
|
||||
"summary": dict(last_summary),
|
||||
"config": dict(last_config),
|
||||
"risk_events": last_result.get("risk_events", []),
|
||||
}
|
||||
st.success(f"已保存实验:{label}")
|
||||
with st.expander("最近回测详情", expanded=False):
|
||||
st.json(
|
||||
{
|
||||
"config": last_config,
|
||||
"summary": last_summary,
|
||||
"trades": last_result.get("trades", [])[:50],
|
||||
}
|
||||
)
|
||||
|
||||
st.divider()
|
||||
# ADD: Comparison view for multiple backtest configurations
|
||||
st.caption("从历史回测配置中选择多个进行净值曲线与指标对比。")
|
||||
st.caption("从历史回测配置或本页保存的实验中选择多个进行净值曲线与指标对比。")
|
||||
normalize_to_one = st.checkbox("归一化到 1 起点", value=True, key="bt_cmp_normalize")
|
||||
use_log_y = st.checkbox("对数坐标", value=False, key="bt_cmp_log_y")
|
||||
metric_options = ["总收益", "最大回撤", "交易数", "平均换手", "风险事件"]
|
||||
selected_metrics = st.multiselect("显示指标列", metric_options, default=metric_options, key="bt_cmp_metrics")
|
||||
|
||||
session_store = _session_compare_store()
|
||||
if session_store:
|
||||
with st.expander("会话实验管理", expanded=False):
|
||||
st.write("会话实验仅保存在当前浏览器窗口中,可选择删除以保持列表精简。")
|
||||
removal_choices = st.multiselect(
|
||||
"选择要删除的会话实验",
|
||||
list(session_store.keys()),
|
||||
key="bt_cmp_remove_runs",
|
||||
)
|
||||
if st.button("删除选中实验", key="bt_cmp_remove_button") and removal_choices:
|
||||
for label in removal_choices:
|
||||
session_store.pop(label, None)
|
||||
st.success("已删除选中的会话实验。")
|
||||
|
||||
try:
|
||||
with db_session(read_only=True) as conn:
|
||||
cfg_rows = conn.execute(
|
||||
@ -247,89 +403,71 @@ def render_backtest_review() -> None:
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.exception("读取 bt_config 失败", extra=LOG_EXTRA)
|
||||
cfg_rows = []
|
||||
cfg_options = [f"{row['id']} | {row['name']}" for row in cfg_rows]
|
||||
selected_labels = st.multiselect("选择配置", cfg_options, default=cfg_options[:2], key="bt_cmp_configs")
|
||||
selected_ids = [label.split(" | ")[0].strip() for label in selected_labels]
|
||||
nav_df = pd.DataFrame()
|
||||
rpt_df = pd.DataFrame()
|
||||
risk_df = pd.DataFrame()
|
||||
if selected_ids:
|
||||
|
||||
option_map: Dict[str, Tuple[str, object]] = {}
|
||||
option_labels: List[str] = []
|
||||
|
||||
for label in session_store.keys():
|
||||
option_label = f"[会话] {label}"
|
||||
option_labels.append(option_label)
|
||||
option_map[option_label] = ("session", label)
|
||||
|
||||
for row in cfg_rows:
|
||||
option_label = f"[DB] {row['id']} | {row['name']}"
|
||||
option_labels.append(option_label)
|
||||
option_map[option_label] = ("db", row["id"])
|
||||
|
||||
if not option_labels:
|
||||
st.info("暂无可对比的回测实验,请先执行回测或保存实验。")
|
||||
else:
|
||||
default_selection = option_labels[:2]
|
||||
selected_labels = st.multiselect(
|
||||
"选择实验配置",
|
||||
option_labels,
|
||||
default=default_selection,
|
||||
key="bt_cmp_configs",
|
||||
)
|
||||
|
||||
selected_db_ids = [option_map[label][1] for label in selected_labels if option_map[label][0] == "db"]
|
||||
selected_session_labels = [option_map[label][1] for label in selected_labels if option_map[label][0] == "session"]
|
||||
|
||||
nav_frames: List[pd.DataFrame] = []
|
||||
metrics_rows: List[Dict[str, object]] = []
|
||||
risk_frames: List[pd.DataFrame] = []
|
||||
|
||||
if selected_db_ids:
|
||||
try:
|
||||
with db_session(read_only=True) as conn:
|
||||
nav_df = pd.read_sql_query(
|
||||
"SELECT cfg_id, trade_date, nav FROM bt_nav WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))),
|
||||
db_nav = pd.read_sql_query(
|
||||
"SELECT cfg_id, trade_date, nav FROM bt_nav WHERE cfg_id IN (%s)" % (",".join(["?"] * len(selected_db_ids))),
|
||||
conn,
|
||||
params=tuple(selected_ids),
|
||||
params=tuple(selected_db_ids),
|
||||
)
|
||||
rpt_df = pd.read_sql_query(
|
||||
"SELECT cfg_id, summary FROM bt_report WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))),
|
||||
db_rpt = pd.read_sql_query(
|
||||
"SELECT cfg_id, summary FROM bt_report WHERE cfg_id IN (%s)" % (",".join(["?"] * len(selected_db_ids))),
|
||||
conn,
|
||||
params=tuple(selected_ids),
|
||||
params=tuple(selected_db_ids),
|
||||
)
|
||||
risk_df = pd.read_sql_query(
|
||||
db_risk = pd.read_sql_query(
|
||||
"SELECT cfg_id, trade_date, ts_code, reason, action, target_weight, confidence, metadata "
|
||||
"FROM bt_risk_events WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))),
|
||||
"FROM bt_risk_events WHERE cfg_id IN (%s)" % (",".join(["?"] * len(selected_db_ids))),
|
||||
conn,
|
||||
params=tuple(selected_ids),
|
||||
params=tuple(selected_db_ids),
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.exception("读取回测结果失败", extra=LOG_EXTRA)
|
||||
st.error("读取回测结果失败")
|
||||
nav_df = pd.DataFrame()
|
||||
rpt_df = pd.DataFrame()
|
||||
risk_df = pd.DataFrame()
|
||||
start_filter: Optional[date] = None
|
||||
end_filter: Optional[date] = None
|
||||
if not nav_df.empty:
|
||||
try:
|
||||
nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"], errors="coerce")
|
||||
# ADD: date window filter
|
||||
overall_min = pd.to_datetime(nav_df["trade_date"].min()).date()
|
||||
overall_max = pd.to_datetime(nav_df["trade_date"].max()).date()
|
||||
col_d1, col_d2 = st.columns(2)
|
||||
start_filter = col_d1.date_input("起始日期", value=overall_min, key="bt_cmp_start")
|
||||
end_filter = col_d2.date_input("结束日期", value=overall_max, key="bt_cmp_end")
|
||||
if start_filter > end_filter:
|
||||
start_filter, end_filter = end_filter, start_filter
|
||||
mask = (nav_df["trade_date"].dt.date >= start_filter) & (nav_df["trade_date"].dt.date <= end_filter)
|
||||
nav_df = nav_df.loc[mask]
|
||||
pivot = nav_df.pivot_table(index="trade_date", columns="cfg_id", values="nav")
|
||||
if normalize_to_one:
|
||||
pivot = pivot.apply(lambda s: s / s.dropna().iloc[0] if s.dropna().size else s)
|
||||
import plotly.graph_objects as go
|
||||
fig = go.Figure()
|
||||
for col in pivot.columns:
|
||||
fig.add_trace(go.Scatter(x=pivot.index, y=pivot[col], mode="lines", name=str(col)))
|
||||
fig.update_layout(height=300, margin=dict(l=10, r=10, t=30, b=10))
|
||||
if use_log_y:
|
||||
fig.update_yaxes(type="log")
|
||||
st.plotly_chart(fig, width='stretch')
|
||||
# ADD: export pivot
|
||||
try:
|
||||
csv_buf = pivot.reset_index()
|
||||
csv_buf.columns = ["trade_date"] + [str(c) for c in pivot.columns]
|
||||
st.download_button(
|
||||
"下载曲线(CSV)",
|
||||
data=csv_buf.to_csv(index=False),
|
||||
file_name="bt_nav_compare.csv",
|
||||
mime="text/csv",
|
||||
key="dl_nav_compare",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.debug("绘制对比曲线失败", extra=LOG_EXTRA)
|
||||
if not rpt_df.empty:
|
||||
try:
|
||||
metrics_rows: List[Dict[str, object]] = []
|
||||
for _, row in rpt_df.iterrows():
|
||||
cfg_id = row["cfg_id"]
|
||||
if not db_nav.empty:
|
||||
db_nav["trade_date"] = pd.to_datetime(db_nav["trade_date"], errors="coerce")
|
||||
nav_frames.append(db_nav)
|
||||
if not db_risk.empty:
|
||||
db_risk["trade_date"] = pd.to_datetime(db_risk["trade_date"], errors="coerce")
|
||||
risk_frames.append(db_risk)
|
||||
for _, row in db_rpt.iterrows():
|
||||
try:
|
||||
summary = json.loads(row["summary"]) if isinstance(row["summary"], str) else (row["summary"] or {})
|
||||
except json.JSONDecodeError:
|
||||
summary = {}
|
||||
record = {
|
||||
"cfg_id": cfg_id,
|
||||
metrics_rows.append(
|
||||
{
|
||||
"cfg_id": row["cfg_id"],
|
||||
"总收益": summary.get("total_return"),
|
||||
"最大回撤": summary.get("max_drawdown"),
|
||||
"交易数": summary.get("trade_count"),
|
||||
@ -344,89 +482,144 @@ def render_backtest_review() -> None:
|
||||
"派生字段": json.dumps(summary.get("derived_field_counts"), ensure_ascii=False)
|
||||
if summary.get("derived_field_counts")
|
||||
else None,
|
||||
"参数": json.dumps(summary.get("config_params"), ensure_ascii=False)
|
||||
if summary.get("config_params")
|
||||
else None,
|
||||
"备注": summary.get("note"),
|
||||
}
|
||||
metrics_rows.append({k: v for k, v in record.items() if (k == "cfg_id" or k in selected_metrics)})
|
||||
if metrics_rows:
|
||||
dfm = pd.DataFrame(metrics_rows)
|
||||
st.dataframe(dfm, hide_index=True, width='stretch')
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.exception("读取回测结果失败", extra=LOG_EXTRA)
|
||||
st.error("读取数据库中的回测结果失败,详见日志。")
|
||||
|
||||
for label in selected_session_labels:
|
||||
data = session_store.get(label)
|
||||
if not data:
|
||||
continue
|
||||
nav_df_session = _normalize_nav_records(data.get("nav", []))
|
||||
if not nav_df_session.empty:
|
||||
nav_df_session = nav_df_session.assign(cfg_id=data.get("cfg_id"))
|
||||
nav_frames.append(nav_df_session)
|
||||
summary = data.get("summary", {})
|
||||
metrics_rows.append(
|
||||
{
|
||||
"cfg_id": data.get("cfg_id"),
|
||||
"总收益": summary.get("total_return"),
|
||||
"最大回撤": summary.get("max_drawdown"),
|
||||
"交易数": summary.get("trade_count"),
|
||||
"平均换手": summary.get("avg_turnover"),
|
||||
"风险事件": summary.get("risk_events"),
|
||||
"参数": json.dumps(data.get("config", {}).get("params"), ensure_ascii=False)
|
||||
if data.get("config", {}).get("params")
|
||||
else None,
|
||||
"备注": json.dumps(
|
||||
{
|
||||
"structures": data.get("config", {}).get("structures"),
|
||||
"universe_size": len(data.get("config", {}).get("universe", [])),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
}
|
||||
)
|
||||
risk_events = data.get("risk_events") or []
|
||||
if risk_events:
|
||||
risk_df_session = pd.DataFrame(risk_events)
|
||||
if not risk_df_session.empty:
|
||||
if "trade_date" in risk_df_session.columns:
|
||||
risk_df_session["trade_date"] = pd.to_datetime(risk_df_session["trade_date"], errors="coerce")
|
||||
risk_df_session = risk_df_session.assign(cfg_id=data.get("cfg_id"))
|
||||
risk_frames.append(risk_df_session)
|
||||
|
||||
if not selected_labels:
|
||||
st.info("请选择至少一个实验进行对比。")
|
||||
else:
|
||||
nav_df = pd.concat(nav_frames, ignore_index=True) if nav_frames else pd.DataFrame()
|
||||
if not nav_df.empty:
|
||||
nav_df = nav_df.dropna(subset=["trade_date", "nav"])
|
||||
if not nav_df.empty:
|
||||
overall_min = nav_df["trade_date"].min().date()
|
||||
overall_max = nav_df["trade_date"].max().date()
|
||||
col_d1, col_d2 = st.columns(2)
|
||||
start_filter = col_d1.date_input("起始日期", value=overall_min, key="bt_cmp_start")
|
||||
end_filter = col_d2.date_input("结束日期", value=overall_max, key="bt_cmp_end")
|
||||
if start_filter > end_filter:
|
||||
start_filter, end_filter = end_filter, start_filter
|
||||
mask = (nav_df["trade_date"].dt.date >= start_filter) & (nav_df["trade_date"].dt.date <= end_filter)
|
||||
nav_df = nav_df.loc[mask]
|
||||
pivot = nav_df.pivot_table(index="trade_date", columns="cfg_id", values="nav")
|
||||
if normalize_to_one:
|
||||
pivot = pivot.apply(lambda s: s / s.dropna().iloc[0] if s.dropna().size else s)
|
||||
import plotly.graph_objects as go
|
||||
|
||||
fig = go.Figure()
|
||||
for col in pivot.columns:
|
||||
fig.add_trace(go.Scatter(x=pivot.index, y=pivot[col], mode="lines", name=str(col)))
|
||||
fig.update_layout(height=320, margin=dict(l=10, r=10, t=30, b=10))
|
||||
if use_log_y:
|
||||
fig.update_yaxes(type="log")
|
||||
st.plotly_chart(fig, width="stretch")
|
||||
try:
|
||||
csv_buf = pivot.reset_index()
|
||||
csv_buf.columns = ["trade_date"] + [str(c) for c in pivot.columns]
|
||||
st.download_button(
|
||||
"下载曲线(CSV)",
|
||||
data=csv_buf.to_csv(index=False),
|
||||
file_name="bt_nav_compare.csv",
|
||||
mime="text/csv",
|
||||
key="dl_nav_compare",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
metric_df = pd.DataFrame(metrics_rows)
|
||||
if not metric_df.empty:
|
||||
display_cols = ["cfg_id"] + [col for col in selected_metrics if col in metric_df.columns]
|
||||
additional_cols = [col for col in metric_df.columns if col not in display_cols]
|
||||
metric_df = metric_df.loc[:, display_cols + additional_cols]
|
||||
st.dataframe(metric_df, hide_index=True, width="stretch")
|
||||
try:
|
||||
st.download_button(
|
||||
"下载指标(CSV)",
|
||||
data=dfm.to_csv(index=False),
|
||||
data=metric_df.to_csv(index=False),
|
||||
file_name="bt_metrics_compare.csv",
|
||||
mime="text/csv",
|
||||
key="dl_metrics_compare",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.debug("渲染指标表失败", extra=LOG_EXTRA)
|
||||
|
||||
risk_df = pd.concat(risk_frames, ignore_index=True) if risk_frames else pd.DataFrame()
|
||||
if not risk_df.empty:
|
||||
try:
|
||||
risk_df["trade_date"] = pd.to_datetime(risk_df["trade_date"], errors="coerce")
|
||||
risk_df = risk_df.dropna(subset=["trade_date"])
|
||||
if start_filter is None or end_filter is None:
|
||||
start_filter = pd.to_datetime(risk_df["trade_date"].min()).date()
|
||||
end_filter = pd.to_datetime(risk_df["trade_date"].max()).date()
|
||||
risk_df = risk_df[
|
||||
(risk_df["trade_date"].dt.date >= start_filter)
|
||||
& (risk_df["trade_date"].dt.date <= end_filter)
|
||||
]
|
||||
parsed_cols: List[Dict[str, object]] = []
|
||||
for _, row in risk_df.iterrows():
|
||||
try:
|
||||
metadata = json.loads(row["metadata"]) if isinstance(row["metadata"], str) else (row["metadata"] or {})
|
||||
except json.JSONDecodeError:
|
||||
metadata = {}
|
||||
assessment = metadata.get("risk_assessment") or {}
|
||||
parsed_cols.append(
|
||||
{
|
||||
"cfg_id": row["cfg_id"],
|
||||
"trade_date": row["trade_date"].date().isoformat(),
|
||||
"ts_code": row["ts_code"],
|
||||
"reason": row["reason"],
|
||||
"action": row["action"],
|
||||
"target_weight": row["target_weight"],
|
||||
"confidence": row["confidence"],
|
||||
"risk_status": assessment.get("status"),
|
||||
"recommended_action": assessment.get("recommended_action"),
|
||||
"execution_status": metadata.get("execution_status"),
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
risk_detail_df = pd.DataFrame(parsed_cols)
|
||||
with st.expander("风险事件明细", expanded=False):
|
||||
st.dataframe(risk_detail_df.drop(columns=["metadata"], errors="ignore"), hide_index=True, width='stretch')
|
||||
st.caption("风险事件详情(按 cfg_id 聚合)。")
|
||||
st.dataframe(risk_df, hide_index=True, width="stretch")
|
||||
try:
|
||||
st.download_button(
|
||||
"下载风险事件(CSV)",
|
||||
data=risk_detail_df.to_csv(index=False),
|
||||
data=risk_df.to_csv(index=False),
|
||||
file_name="bt_risk_events.csv",
|
||||
mime="text/csv",
|
||||
key="dl_risk_events",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
agg = risk_detail_df.groupby(["cfg_id", "reason", "risk_status"], dropna=False).size().reset_index(name="count")
|
||||
st.dataframe(agg, hide_index=True, width='stretch')
|
||||
if {"cfg_id", "reason"}.issubset(risk_df.columns):
|
||||
group_cols = [col for col in ["cfg_id", "reason", "risk_status"] if col in risk_df.columns]
|
||||
agg = risk_df.groupby(group_cols, dropna=False).size().reset_index(name="count")
|
||||
st.dataframe(agg, hide_index=True, width="stretch")
|
||||
try:
|
||||
if not agg.empty:
|
||||
agg_fig = px.bar(
|
||||
agg,
|
||||
x="reason",
|
||||
x="reason" if "reason" in agg.columns else agg.columns[1],
|
||||
y="count",
|
||||
color="risk_status",
|
||||
facet_col="cfg_id",
|
||||
color="risk_status" if "risk_status" in agg.columns else None,
|
||||
facet_col="cfg_id" if "cfg_id" in agg.columns else None,
|
||||
title="风险事件分布",
|
||||
)
|
||||
agg_fig.update_layout(height=320, margin=dict(l=10, r=10, t=40, b=20))
|
||||
st.plotly_chart(agg_fig, width="stretch")
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.debug("绘制风险事件分布失败", extra=LOG_EXTRA)
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.debug("渲染风险事件失败", extra=LOG_EXTRA)
|
||||
else:
|
||||
st.info("请选择至少一个配置进行对比。")
|
||||
|
||||
|
||||
|
||||
@ -741,6 +934,8 @@ def render_backtest_review() -> None:
|
||||
"总收益": episode.metrics.total_return,
|
||||
"最大回撤": episode.metrics.max_drawdown,
|
||||
"波动率": episode.metrics.volatility,
|
||||
"Sharpe": episode.metrics.sharpe_like,
|
||||
"Calmar": episode.metrics.calmar_like,
|
||||
"权重": json.dumps(episode.weights or {}, ensure_ascii=False),
|
||||
"部门控制": json.dumps(episode.department_controls or {}, ensure_ascii=False),
|
||||
}
|
||||
@ -756,6 +951,12 @@ def render_backtest_review() -> None:
|
||||
"action": best_episode.action if best_episode else None,
|
||||
"resolved_action": best_episode.resolved_action if best_episode else None,
|
||||
"weights": best_episode.weights if best_episode else None,
|
||||
"metrics": {
|
||||
"total_return": best_episode.metrics.total_return if best_episode else None,
|
||||
"sharpe_like": best_episode.metrics.sharpe_like if best_episode else None,
|
||||
"calmar_like": best_episode.metrics.calmar_like if best_episode else None,
|
||||
"max_drawdown": best_episode.metrics.max_drawdown if best_episode else None,
|
||||
} if best_episode else None,
|
||||
"department_controls": best_episode.department_controls if best_episode else None,
|
||||
},
|
||||
"experiment_id": config.experiment_id,
|
||||
@ -781,6 +982,13 @@ def render_backtest_review() -> None:
|
||||
col_best1.json(best_payload.get("action") or {})
|
||||
col_best2.write("参数值:")
|
||||
col_best2.json(best_payload.get("resolved_action") or {})
|
||||
metrics_payload = best_payload.get("metrics") or {}
|
||||
if metrics_payload:
|
||||
col_m1, col_m2, col_m3 = st.columns(3)
|
||||
col_m1.metric("总收益", f"{metrics_payload.get('total_return', 0.0):+.4f}")
|
||||
col_m2.metric("Sharpe", f"{metrics_payload.get('sharpe_like', 0.0):.3f}")
|
||||
col_m3.metric("Calmar", f"{metrics_payload.get('calmar_like', 0.0):.3f}")
|
||||
st.caption(f"最大回撤:{metrics_payload.get('max_drawdown', 0.0):.3f}")
|
||||
weights_payload = best_payload.get("weights") or {}
|
||||
if weights_payload:
|
||||
st.write("对应代理权重:")
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
"""系统设置相关视图。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
@ -10,6 +11,8 @@ from requests.exceptions import RequestException
|
||||
import streamlit as st
|
||||
|
||||
from app.llm.client import llm_config_snapshot
|
||||
from app.llm.metrics import snapshot as llm_metrics_snapshot
|
||||
from app.llm.templates import TemplateRegistry
|
||||
from app.utils.config import (
|
||||
ALLOWED_LLM_STRATEGIES,
|
||||
DEFAULT_LLM_BASE_URLS,
|
||||
@ -471,6 +474,239 @@ def render_llm_settings() -> None:
|
||||
else:
|
||||
dept_rows = dept_editor
|
||||
|
||||
st.divider()
|
||||
st.markdown("##### 提示模板治理")
|
||||
template_ids = TemplateRegistry.list_template_ids()
|
||||
if not template_ids:
|
||||
st.info("尚未注册任何提示模板。")
|
||||
else:
|
||||
template_id = st.selectbox(
|
||||
"选择模板",
|
||||
template_ids,
|
||||
key="prompt_template_select",
|
||||
)
|
||||
version_details = TemplateRegistry.list_version_details(template_id)
|
||||
raw_versions = TemplateRegistry.list_versions(template_id)
|
||||
active_version = None
|
||||
if version_details:
|
||||
for detail in version_details:
|
||||
if detail.get("is_active"):
|
||||
active_version = detail["version"]
|
||||
break
|
||||
if active_version is None:
|
||||
active_version = version_details[0]["version"]
|
||||
usage_snapshot = llm_metrics_snapshot()
|
||||
template_usage = usage_snapshot.get("template_usage", {}).get(template_id, {})
|
||||
|
||||
table_rows: List[Dict[str, object]] = []
|
||||
for detail in version_details:
|
||||
metadata_preview = detail.get("metadata") or {}
|
||||
table_rows.append(
|
||||
{
|
||||
"版本": detail["version"],
|
||||
"创建时间": detail.get("created_at") or "-",
|
||||
"激活": "是" if detail.get("is_active") else "否",
|
||||
"元数据": json.dumps(metadata_preview, ensure_ascii=False, default=str) if metadata_preview else "{}",
|
||||
}
|
||||
)
|
||||
if table_rows:
|
||||
st.dataframe(pd.DataFrame(table_rows), hide_index=True, width="stretch")
|
||||
|
||||
version_options = [row["版本"] for row in table_rows] if table_rows else []
|
||||
if not version_options:
|
||||
st.info("当前模板尚未创建版本,建议通过配置文件或 API 注册。")
|
||||
else:
|
||||
try:
|
||||
default_idx = version_options.index(active_version or version_options[0])
|
||||
except ValueError:
|
||||
default_idx = 0
|
||||
selected_version = st.selectbox(
|
||||
"查看版本",
|
||||
version_options,
|
||||
index=default_idx,
|
||||
key=f"{template_id}_version_select",
|
||||
)
|
||||
selected_detail = next(
|
||||
(detail for detail in version_details if detail["version"] == selected_version),
|
||||
{"metadata": {}},
|
||||
)
|
||||
usage_cols = st.columns(3)
|
||||
usage_cols[0].metric("累计调用", int(template_usage.get("total_calls", 0)))
|
||||
version_usage = (template_usage.get("versions") or {}).get(selected_version, {})
|
||||
usage_cols[1].metric("版本调用", int(version_usage.get("calls", 0)))
|
||||
usage_cols[2].metric(
|
||||
"Prompt Tokens",
|
||||
int(version_usage.get("prompt_tokens", 0)),
|
||||
)
|
||||
|
||||
template_obj = TemplateRegistry.get(template_id, version=selected_version)
|
||||
if template_obj:
|
||||
with st.expander("模板内容预览", expanded=False):
|
||||
st.write(f"名称:{template_obj.name}")
|
||||
st.write(f"描述:{template_obj.description or '-'}")
|
||||
st.write(f"变量:{', '.join(template_obj.variables) if template_obj.variables else '无'}")
|
||||
st.code(template_obj.template, language="markdown")
|
||||
|
||||
metadata_str = json.dumps(selected_detail.get("metadata") or {}, ensure_ascii=False, indent=2, default=str)
|
||||
metadata_input = st.text_area(
|
||||
"版本元数据(JSON)",
|
||||
value=metadata_str,
|
||||
height=200,
|
||||
key=f"{template_id}_{selected_version}_metadata",
|
||||
)
|
||||
meta_buttons = st.columns(3)
|
||||
enable_version_actions = bool(raw_versions)
|
||||
if meta_buttons[0].button(
|
||||
"保存元数据",
|
||||
key=f"{template_id}_{selected_version}_save_metadata",
|
||||
disabled=not enable_version_actions,
|
||||
):
|
||||
try:
|
||||
new_metadata = json.loads(metadata_input or "{}")
|
||||
except json.JSONDecodeError as exc:
|
||||
st.error(f"元数据格式错误:{exc}")
|
||||
else:
|
||||
try:
|
||||
TemplateRegistry.update_version_metadata(template_id, selected_version, new_metadata)
|
||||
st.success("元数据已更新。")
|
||||
st.rerun()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
st.error(f"更新元数据失败:{exc}")
|
||||
|
||||
if meta_buttons[1].button(
|
||||
"设为激活版本",
|
||||
key=f"{template_id}_{selected_version}_activate",
|
||||
disabled=(selected_version == active_version) or not enable_version_actions,
|
||||
):
|
||||
try:
|
||||
TemplateRegistry.activate_version(template_id, selected_version)
|
||||
st.success(f"{template_id} 已切换至版本 {selected_version}。")
|
||||
st.rerun()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
st.error(f"切换版本失败:{exc}")
|
||||
|
||||
export_payload = TemplateRegistry.export_versions(template_id) if enable_version_actions else None
|
||||
meta_buttons[2].download_button(
|
||||
"导出版本 JSON",
|
||||
data=export_payload or "",
|
||||
file_name=f"{template_id}_versions.json",
|
||||
mime="application/json",
|
||||
key=f"{template_id}_download_versions",
|
||||
disabled=not export_payload,
|
||||
)
|
||||
|
||||
st.divider()
|
||||
st.markdown("##### 部门遥测可视化")
|
||||
telemetry_limit = st.slider(
|
||||
"遥测查询条数",
|
||||
min_value=50,
|
||||
max_value=500,
|
||||
value=200,
|
||||
step=50,
|
||||
help="限制查询 agent_utils 表中的最新记录数量。",
|
||||
key="telemetry_limit",
|
||||
)
|
||||
telemetry_rows: List[Dict[str, object]] = []
|
||||
try:
|
||||
with db_session(read_only=True) as conn:
|
||||
raw_rows = conn.execute(
|
||||
"""
|
||||
SELECT trade_date, ts_code, agent, utils
|
||||
FROM agent_utils
|
||||
ORDER BY trade_date DESC, ts_code
|
||||
LIMIT ?
|
||||
""",
|
||||
(telemetry_limit,),
|
||||
).fetchall()
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.exception("读取 agent_utils 遥测失败", extra=LOG_EXTRA)
|
||||
raw_rows = []
|
||||
|
||||
for row in raw_rows:
|
||||
trade_date = row["trade_date"]
|
||||
ts_code = row["ts_code"]
|
||||
agent = row["agent"]
|
||||
try:
|
||||
utils_payload = json.loads(row["utils"] or "{}")
|
||||
except json.JSONDecodeError:
|
||||
utils_payload = {}
|
||||
|
||||
if agent == "global":
|
||||
telemetry_map = utils_payload.get("_department_telemetry") or {}
|
||||
for dept_code, payload in telemetry_map.items():
|
||||
if not isinstance(payload, dict):
|
||||
payload = {"value": payload}
|
||||
record = {
|
||||
"trade_date": trade_date,
|
||||
"ts_code": ts_code,
|
||||
"agent": agent,
|
||||
"department": dept_code,
|
||||
"source": "global",
|
||||
"telemetry": json.dumps(payload, ensure_ascii=False, default=str),
|
||||
}
|
||||
for key, value in payload.items():
|
||||
if isinstance(value, (int, float, bool, str)):
|
||||
record.setdefault(key, value)
|
||||
telemetry_rows.append(record)
|
||||
elif agent.startswith("dept_"):
|
||||
dept_code = agent.split("dept_", 1)[-1]
|
||||
payload = utils_payload.get("_telemetry") or {}
|
||||
if not isinstance(payload, dict):
|
||||
payload = {"value": payload}
|
||||
record = {
|
||||
"trade_date": trade_date,
|
||||
"ts_code": ts_code,
|
||||
"agent": agent,
|
||||
"department": dept_code,
|
||||
"source": "department",
|
||||
"telemetry": json.dumps(payload, ensure_ascii=False, default=str),
|
||||
}
|
||||
for key, value in payload.items():
|
||||
if isinstance(value, (int, float, bool, str)):
|
||||
record.setdefault(key, value)
|
||||
telemetry_rows.append(record)
|
||||
|
||||
if not telemetry_rows:
|
||||
st.info("未找到遥测记录,可先运行部门评估流程。")
|
||||
else:
|
||||
telemetry_df = pd.DataFrame(telemetry_rows)
|
||||
telemetry_df["trade_date"] = telemetry_df["trade_date"].astype(str)
|
||||
trade_dates = sorted(telemetry_df["trade_date"].unique(), reverse=True)
|
||||
selected_trade_date = st.selectbox(
|
||||
"交易日",
|
||||
trade_dates,
|
||||
index=0,
|
||||
key="telemetry_trade_date",
|
||||
)
|
||||
filtered_df = telemetry_df[telemetry_df["trade_date"] == selected_trade_date]
|
||||
departments = sorted(filtered_df["department"].dropna().unique())
|
||||
selected_departments = st.multiselect(
|
||||
"部门过滤",
|
||||
departments,
|
||||
default=departments,
|
||||
key="telemetry_departments",
|
||||
)
|
||||
if selected_departments:
|
||||
filtered_df = filtered_df[filtered_df["department"].isin(selected_departments)]
|
||||
numeric_columns = [
|
||||
col
|
||||
for col in filtered_df.columns
|
||||
if col not in {"trade_date", "ts_code", "agent", "department", "source", "telemetry"}
|
||||
and pd.api.types.is_numeric_dtype(filtered_df[col])
|
||||
]
|
||||
metric_cols = st.columns(min(3, max(1, len(numeric_columns))))
|
||||
for idx, column in enumerate(numeric_columns[: len(metric_cols)]):
|
||||
column_series = filtered_df[column].dropna()
|
||||
value = column_series.mean() if not column_series.empty else 0.0
|
||||
metric_cols[idx].metric(f"{column} 均值", f"{value:.2f}")
|
||||
st.dataframe(filtered_df, hide_index=True, width="stretch")
|
||||
st.download_button(
|
||||
"下载遥测 CSV",
|
||||
data=filtered_df.to_csv(index=False),
|
||||
file_name=f"telemetry_{selected_trade_date}.csv",
|
||||
mime="text/csv",
|
||||
key="telemetry_download",
|
||||
)
|
||||
col_reset, col_save = st.columns([1, 1])
|
||||
|
||||
if col_save.button("保存部门配置"):
|
||||
|
||||
12
docs/TODO.md
12
docs/TODO.md
@ -18,9 +18,9 @@
|
||||
|
||||
| 工作项 | 状态 | 说明 |
|
||||
| --- | --- | --- |
|
||||
| DecisionEnv 扩展 | 🔄 | 已支持多步 episode 与部分动作维度,需继续覆盖提示版本、function 策略等。 |
|
||||
| DecisionEnv 扩展 | 🔄 | Episode 指标新增 Sharpe/Calmar,奖励函数引入风险惩罚;继续覆盖提示版本、function 策略等。 |
|
||||
| 强化学习基线 | ✅ | PPO/SAC 等连续动作算法已接入并形成实验基线。 |
|
||||
| 奖励与评估体系 | ⏳ | 需将 `portfolio_trades`/`portfolio_snapshots` 等指标纳入奖励与评估。 |
|
||||
| 奖励与评估体系 | 🔄 | 决策环境奖励已纳入风险/Turnover/Sharpe-Calmar,待接入成交与资金曲线指标。 |
|
||||
| 实时持仓链路 | ⏳ | 建立线上持仓/成交写入与离线调参与监控共享的数据源。 |
|
||||
| 全局参数搜索 | ⏳ | 引入 Bandit、贝叶斯优化或 BOHB 提供权重/参数候选。 |
|
||||
|
||||
@ -29,9 +29,9 @@
|
||||
| 工作项 | 状态 | 说明 |
|
||||
| --- | --- | --- |
|
||||
| Provider 与 function 架构 | ✅ | Provider 管理、function-calling 降级与重试策略已收敛。 |
|
||||
| 提示模板治理 | ⏳ | 待建立模板版本管理、成本监控与性能指标。 |
|
||||
| 部门遥测可视化 | ⏳ | `_telemetry` / `_department_telemetry` 字段需在 UI 中完整展示。 |
|
||||
| 多轮逻辑博弈框架 | ⏳ | 需实现主持/预测/风险/执行分轮对话、信念修正与冲突解决。 |
|
||||
| 提示模板治理 | 🔄 | LLM 设置新增模板版本治理与使用监控,后续补充成本/效果数据。 |
|
||||
| 部门遥测可视化 | 🔄 | LLM 设置新增遥测面板,支持分页查看/导出部门 & 全局遥测。 |
|
||||
| 多轮逻辑博弈框架 | 🔄 | 新增主持 briefing、预测对齐及冲突复核轮,持续完善信念修正策略。 |
|
||||
| LLM 稳定性提升 | ⏳ | 持续优化限速、降级、成本控制与缓存策略。 |
|
||||
|
||||
## UI 与监控
|
||||
@ -39,7 +39,7 @@
|
||||
| 工作项 | 状态 | 说明 |
|
||||
| --- | --- | --- |
|
||||
| 一键重评估入口 | ✅ | 今日计划页提供批量/全量重评估入口,待收集反馈再做优化。 |
|
||||
| 回测实验对比 | ⏳ | 提供提示/温度多版本实验管理与曲线对比。 |
|
||||
| 回测实验对比 | 🔄 | 新增会话实验保存与曲线/指标对比,后续接入更多提示参数。 |
|
||||
| 实时指标面板 | ✅ | Streamlit 监控页已具备核心实时指标。 |
|
||||
| 异常日志钻取 | ⏳ | 待补充筛选、定位与历史对比能力。 |
|
||||
| 仅监控模式 | ⏳ | 支持“监控不干预”场景的一键复评策略。 |
|
||||
|
||||
@ -31,10 +31,14 @@ class DummyEnv:
|
||||
def step(self, action):
|
||||
value = float(action[0])
|
||||
reward = 1.0 - abs(value - 0.7)
|
||||
sharpe_like = reward / 0.05 if 0.05 else 0.0
|
||||
calmar_like = reward / 0.1 if 0.1 else reward
|
||||
metrics = EpisodeMetrics(
|
||||
total_return=reward,
|
||||
max_drawdown=0.1,
|
||||
volatility=0.05,
|
||||
sharpe_like=sharpe_like,
|
||||
calmar_like=calmar_like,
|
||||
nav_series=[],
|
||||
trades=[],
|
||||
turnover=0.1,
|
||||
@ -49,6 +53,7 @@ class DummyEnv:
|
||||
"max_drawdown": 0.1,
|
||||
"volatility": 0.05,
|
||||
"sharpe_like": reward / 0.05,
|
||||
"calmar_like": reward / 0.1,
|
||||
"turnover": 0.1,
|
||||
"turnover_value": 1000.0,
|
||||
"trade_count": 0.0,
|
||||
|
||||
@ -139,10 +139,17 @@ def test_decision_env_returns_risk_metrics(monkeypatch):
|
||||
|
||||
|
||||
def test_default_reward_penalizes_metrics():
|
||||
total_return = 0.1
|
||||
max_drawdown = 0.2
|
||||
volatility = 0.05
|
||||
sharpe_like = total_return / volatility
|
||||
calmar_like = total_return / max_drawdown
|
||||
metrics = EpisodeMetrics(
|
||||
total_return=0.1,
|
||||
max_drawdown=0.2,
|
||||
volatility=0.05,
|
||||
total_return=total_return,
|
||||
max_drawdown=max_drawdown,
|
||||
volatility=volatility,
|
||||
sharpe_like=sharpe_like,
|
||||
calmar_like=calmar_like,
|
||||
nav_series=[],
|
||||
trades=[],
|
||||
turnover=0.3,
|
||||
@ -152,7 +159,13 @@ def test_default_reward_penalizes_metrics():
|
||||
risk_breakdown={"foo": 2},
|
||||
)
|
||||
reward = DecisionEnv._default_reward(metrics)
|
||||
assert reward == pytest.approx(0.1 - (0.5 * 0.2 + 0.05 * 2 + 0.1 * 0.3))
|
||||
expected = (
|
||||
total_return
|
||||
+ 0.1 * sharpe_like
|
||||
+ 0.05 * calmar_like
|
||||
- (0.5 * max_drawdown + 0.05 * metrics.risk_count + 0.1 * metrics.turnover)
|
||||
)
|
||||
assert reward == pytest.approx(expected)
|
||||
|
||||
|
||||
def test_decision_env_department_controls(monkeypatch):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user