This commit is contained in:
sam 2025-10-02 22:18:17 +08:00
parent 4eb2b2d81e
commit 91e8eb5cb3

View File

@ -71,6 +71,23 @@ _SIDEBAR_LISTENER_ATTACHED = False
# ADD: simple in-memory cache for provider model discovery # ADD: simple in-memory cache for provider model discovery
_MODEL_CACHE: Dict[str, Dict[str, object]] = {} _MODEL_CACHE: Dict[str, Dict[str, object]] = {}
_CACHE_TTL_SECONDS = 300 _CACHE_TTL_SECONDS = 300
_WARNINGS_CONTAINER = None
_WARNINGS_PLACEHOLDER = None
# ADD: query param helpers
def _get_query_params() -> Dict[str, List[str]]:
try:
return dict(st.query_params)
except Exception:
return {}
def _set_query_params(**kwargs: object) -> None:
try:
payload = {k: v for k, v in kwargs.items() if v is not None}
if payload:
st.query_params.update(payload)
except Exception:
pass
def _sidebar_metrics_listener(metrics: Dict[str, object]) -> None: def _sidebar_metrics_listener(metrics: Dict[str, object]) -> None:
@ -86,14 +103,20 @@ def render_global_dashboard() -> None:
global _DASHBOARD_CONTAINERS global _DASHBOARD_CONTAINERS
global _DASHBOARD_ELEMENTS global _DASHBOARD_ELEMENTS
global _SIDEBAR_LISTENER_ATTACHED global _SIDEBAR_LISTENER_ATTACHED
global _WARNINGS_CONTAINER
global _WARNINGS_PLACEHOLDER
# ADD: warning badge on top
warnings = alerts.get_warnings()
badge = f" ({len(warnings)})" if warnings else ""
st.sidebar.header(f"系统监控{badge}")
metrics_container = st.sidebar.container() metrics_container = st.sidebar.container()
decisions_container = st.sidebar.container() decisions_container = st.sidebar.container()
_WARNINGS_CONTAINER = st.sidebar.container()
_WARNINGS_PLACEHOLDER = st.sidebar.empty()
_DASHBOARD_CONTAINERS = (metrics_container, decisions_container) _DASHBOARD_CONTAINERS = (metrics_container, decisions_container)
_DASHBOARD_ELEMENTS = _ensure_dashboard_elements(metrics_container, decisions_container) _DASHBOARD_ELEMENTS = _ensure_dashboard_elements(metrics_container, decisions_container)
if st.sidebar.button("清除数据告警", key="clear_data_alerts"):
alerts.clear_warnings()
_update_dashboard_sidebar()
if not _SIDEBAR_LISTENER_ATTACHED: if not _SIDEBAR_LISTENER_ATTACHED:
register_llm_metrics_listener(_sidebar_metrics_listener) register_llm_metrics_listener(_sidebar_metrics_listener)
_SIDEBAR_LISTENER_ATTACHED = True _SIDEBAR_LISTENER_ATTACHED = True
@ -105,6 +128,8 @@ def _update_dashboard_sidebar(
) -> None: ) -> None:
global _DASHBOARD_CONTAINERS global _DASHBOARD_CONTAINERS
global _DASHBOARD_ELEMENTS global _DASHBOARD_ELEMENTS
global _WARNINGS_CONTAINER
global _WARNINGS_PLACEHOLDER
containers = _DASHBOARD_CONTAINERS containers = _DASHBOARD_CONTAINERS
if not containers: if not containers:
@ -140,23 +165,6 @@ def _update_dashboard_sidebar(
else: else:
model_placeholder.info("暂无模型分布数据。") model_placeholder.info("暂无模型分布数据。")
warnings_placeholder = elements.get("warnings")
if warnings_placeholder is not None:
warnings_placeholder.empty()
warnings = alerts.get_warnings()
if warnings:
lines = []
for warning in warnings[-10:]:
detail = warning.get("detail")
appendix = f" {detail}" if detail else ""
lines.append(
f"- **{warning['source']}** {warning['message']}{appendix}"
f"\n<small>{warning['timestamp']}</small>"
)
warnings_placeholder.markdown("\n".join(lines), unsafe_allow_html=True)
else:
warnings_placeholder.info("暂无数据告警。")
decisions = metrics.get("recent_decisions") or llm_recent_decisions(10) decisions = metrics.get("recent_decisions") or llm_recent_decisions(10)
if decisions: if decisions:
lines = [] lines = []
@ -177,6 +185,37 @@ def _update_dashboard_sidebar(
decisions_placeholder = elements["decisions_list"] decisions_placeholder = elements["decisions_list"]
decisions_placeholder.empty() decisions_placeholder.empty()
decisions_placeholder.info("暂无决策记录。执行回测或实时评估后可在此查看。") decisions_placeholder.info("暂无决策记录。执行回测或实时评估后可在此查看。")
# Render warnings section in-place (clear then write)
if _WARNINGS_PLACEHOLDER is not None:
_WARNINGS_PLACEHOLDER.empty()
with _WARNINGS_PLACEHOLDER.container():
st.subheader("数据告警")
warn_list = alerts.get_warnings()
if warn_list:
lines = []
for warning in warn_list[-10:]:
detail = warning.get("detail")
appendix = f" {detail}" if detail else ""
lines.append(
f"- **{warning['source']}** {warning['message']}{appendix}\n<small>{warning['timestamp']}</small>"
)
st.markdown("\n".join(lines), unsafe_allow_html=True)
btn_cols = st.columns([1,1])
if btn_cols[0].button("清除数据告警", key="clear_data_alerts_sibling"):
alerts.clear_warnings()
_update_dashboard_sidebar()
try:
st.download_button(
"导出告警(JSON)",
data=json.dumps(warn_list, ensure_ascii=False, indent=2),
file_name="data_warnings.json",
mime="application/json",
key="dl_warnings_json_sibling",
)
except Exception:
pass
else:
st.info("暂无数据告警。")
def _ensure_dashboard_elements(metrics_container, decisions_container) -> Dict[str, object]: def _ensure_dashboard_elements(metrics_container, decisions_container) -> Dict[str, object]:
@ -188,8 +227,6 @@ def _ensure_dashboard_elements(metrics_container, decisions_container) -> Dict[s
distribution_expander = metrics_container.expander("调用分布", expanded=False) distribution_expander = metrics_container.expander("调用分布", expanded=False)
provider_distribution = distribution_expander.empty() provider_distribution = distribution_expander.empty()
model_distribution = distribution_expander.empty() model_distribution = distribution_expander.empty()
warnings_expander = metrics_container.expander("数据告警", expanded=False)
warnings_placeholder = warnings_expander.empty()
decisions_container.subheader("最新决策") decisions_container.subheader("最新决策")
decisions_list = decisions_container.empty() decisions_list = decisions_container.empty()
@ -200,7 +237,6 @@ def _ensure_dashboard_elements(metrics_container, decisions_container) -> Dict[s
"metrics_completion": metrics_completion, "metrics_completion": metrics_completion,
"provider_distribution": provider_distribution, "provider_distribution": provider_distribution,
"model_distribution": model_distribution, "model_distribution": model_distribution,
"warnings": warnings_placeholder,
"decisions_list": decisions_list, "decisions_list": decisions_list,
} }
return elements return elements
@ -382,7 +418,14 @@ def render_today_plan() -> None:
st.info("暂无决策记录,完成一次回测后即可在此查看部门意见与投票结果。") st.info("暂无决策记录,完成一次回测后即可在此查看部门意见与投票结果。")
return return
trade_date = st.selectbox("交易日", trade_dates, index=0) # ADD: read default selection from URL
q = _get_query_params()
default_trade_date = q.get("date", [trade_dates[0]])[0]
try:
default_idx = trade_dates.index(default_trade_date)
except ValueError:
default_idx = 0
trade_date = st.selectbox("交易日", trade_dates, index=default_idx)
with db_session(read_only=True) as conn: with db_session(read_only=True) as conn:
code_rows = conn.execute( code_rows = conn.execute(
@ -399,10 +442,18 @@ def render_today_plan() -> None:
st.info("所选交易日暂无 agent_utils 记录。") st.info("所选交易日暂无 agent_utils 记录。")
return return
ts_code = st.selectbox("标的", symbols, index=0) default_ts = q.get("code", [symbols[0]])[0]
try:
default_ts_idx = symbols.index(default_ts)
except ValueError:
default_ts_idx = 0
ts_code = st.selectbox("标的", symbols, index=default_ts_idx)
# ADD: batch selection for re-evaluation # ADD: batch selection for re-evaluation
batch_symbols = st.multiselect("批量重评估(可多选)", symbols, default=[]) batch_symbols = st.multiselect("批量重评估(可多选)", symbols, default=[])
# sync URL params
_set_query_params(date=str(trade_date), code=str(ts_code))
with db_session(read_only=True) as conn: with db_session(read_only=True) as conn:
rows = conn.execute( rows = conn.execute(
""" """
@ -513,15 +564,42 @@ def render_today_plan() -> None:
if global_info["requires_review"]: if global_info["requires_review"]:
st.warning("部门分歧较大,已标记为需人工复核。") st.warning("部门分歧较大,已标记为需人工复核。")
with st.expander("基础上下文数据", expanded=False): with st.expander("基础上下文数据", expanded=False):
if global_info.get("scope_values"): # ADD: export buttons
st.write("最新字段:") scope = global_info.get("scope_values") or {}
st.json(global_info["scope_values"]) close_series = global_info.get("close_series") or []
if global_info.get("close_series"): turnover_series = global_info.get("turnover_series") or []
st.write("最新字段:")
if scope:
st.json(scope)
st.download_button(
"下载字段(JSON)",
data=json.dumps(scope, ensure_ascii=False, indent=2),
file_name=f"{ts_code}_{trade_date}_scope.json",
mime="application/json",
key="dl_scope_json",
)
if close_series:
st.write("收盘价时间序列 (最近窗口)") st.write("收盘价时间序列 (最近窗口)")
st.json(global_info["close_series"]) st.json(close_series)
if global_info.get("turnover_series"): try:
import io, csv
buf = io.StringIO()
writer = csv.writer(buf)
writer.writerow(["trade_date", "close"])
for dt, val in close_series:
writer.writerow([dt, val])
st.download_button(
"下载收盘价(CSV)",
data=buf.getvalue(),
file_name=f"{ts_code}_{trade_date}_close_series.csv",
mime="text/csv",
key="dl_close_csv",
)
except Exception:
pass
if turnover_series:
st.write("换手率时间序列 (最近窗口)") st.write("换手率时间序列 (最近窗口)")
st.json(global_info["turnover_series"]) st.json(turnover_series)
dept_sup = global_info.get("department_supplements") or {} dept_sup = global_info.get("department_supplements") or {}
dept_dialogue = global_info.get("department_dialogue") or {} dept_dialogue = global_info.get("department_dialogue") or {}
dept_telemetry = global_info.get("department_telemetry") or {} dept_telemetry = global_info.get("department_telemetry") or {}
@ -550,8 +628,23 @@ def render_today_plan() -> None:
item for item in dept_records item for item in dept_records
if kw in str(item.get("摘要", "")) or kw in str(item.get("核心信号", "")) if kw in str(item.get("摘要", "")) or kw in str(item.get("核心信号", ""))
] ]
# ADD: confidence filter and sort
min_conf = st.slider("最低信心过滤", 0.0, 1.0, 0.0, 0.05)
sort_col = st.selectbox("排序列", ["信心", "权重"], index=0)
filtered = [row for row in filtered if float(row.get("信心", 0.0)) >= min_conf]
filtered = sorted(filtered, key=lambda r: float(r.get(sort_col, 0.0)), reverse=True)
dept_df = pd.DataFrame(filtered) dept_df = pd.DataFrame(filtered)
st.dataframe(dept_df, width='stretch', hide_index=True) st.dataframe(dept_df, width='stretch', hide_index=True)
try:
st.download_button(
"下载部门意见(CSV)",
data=dept_df.to_csv(index=False),
file_name=f"{trade_date}_{ts_code}_departments.csv",
mime="text/csv",
key="dl_dept_csv",
)
except Exception:
pass
for code, details in dept_details.items(): for code, details in dept_details.items():
with st.expander(f"{code} 补充详情", expanded=False): with st.expander(f"{code} 补充详情", expanded=False):
supplements = details.get("supplements", []) supplements = details.get("supplements", [])
@ -576,8 +669,26 @@ def render_today_plan() -> None:
st.subheader("代理评分") st.subheader("代理评分")
if agent_records: if agent_records:
# ADD: sorting and CSV export for agents
sort_agent_by = st.selectbox(
"代理排序",
["权重", "SELL", "HOLD", "BUY_S", "BUY_M", "BUY_L"],
index=1,
)
agent_df = pd.DataFrame(agent_records) agent_df = pd.DataFrame(agent_records)
if sort_agent_by in agent_df.columns:
agent_df = agent_df.sort_values(sort_agent_by, ascending=False)
st.dataframe(agent_df, width='stretch', hide_index=True) st.dataframe(agent_df, width='stretch', hide_index=True)
try:
st.download_button(
"下载代理评分(CSV)",
data=agent_df.to_csv(index=False),
file_name=f"{trade_date}_{ts_code}_agents.csv",
mime="text/csv",
key="dl_agent_csv",
)
except Exception:
pass
else: else:
st.info("暂无基础代理评分。") st.info("暂无基础代理评分。")
@ -1398,6 +1509,8 @@ def render_backtest() -> None:
cfg_options = [f"{row['id']} | {row['name']}" for row in cfg_rows] cfg_options = [f"{row['id']} | {row['name']}" for row in cfg_rows]
selected_labels = st.multiselect("选择配置", cfg_options, default=cfg_options[:2]) selected_labels = st.multiselect("选择配置", cfg_options, default=cfg_options[:2])
selected_ids = [label.split(" | ")[0].strip() for label in selected_labels] selected_ids = [label.split(" | ")[0].strip() for label in selected_labels]
nav_df = pd.DataFrame()
rpt_df = pd.DataFrame()
if selected_ids: if selected_ids:
try: try:
with db_session(read_only=True) as conn: with db_session(read_only=True) as conn:
@ -1419,6 +1532,16 @@ def render_backtest() -> None:
if not nav_df.empty: if not nav_df.empty:
try: try:
nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"], errors="coerce") nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"], errors="coerce")
# ADD: date window filter
overall_min = pd.to_datetime(nav_df["trade_date"].min()).date()
overall_max = pd.to_datetime(nav_df["trade_date"].max()).date()
col_d1, col_d2 = st.columns(2)
start_filter = col_d1.date_input("起始日期", value=overall_min)
end_filter = col_d2.date_input("结束日期", value=overall_max)
if start_filter > end_filter:
start_filter, end_filter = end_filter, start_filter
mask = (nav_df["trade_date"].dt.date >= start_filter) & (nav_df["trade_date"].dt.date <= end_filter)
nav_df = nav_df.loc[mask]
pivot = nav_df.pivot_table(index="trade_date", columns="cfg_id", values="nav") pivot = nav_df.pivot_table(index="trade_date", columns="cfg_id", values="nav")
if normalize_to_one: if normalize_to_one:
pivot = pivot.apply(lambda s: s / s.dropna().iloc[0] if s.dropna().size else s) pivot = pivot.apply(lambda s: s / s.dropna().iloc[0] if s.dropna().size else s)
@ -1430,6 +1553,19 @@ def render_backtest() -> None:
if use_log_y: if use_log_y:
fig.update_yaxes(type="log") fig.update_yaxes(type="log")
st.plotly_chart(fig, use_container_width=True) st.plotly_chart(fig, use_container_width=True)
# ADD: export pivot
try:
csv_buf = pivot.reset_index()
csv_buf.columns = ["trade_date"] + [str(c) for c in pivot.columns]
st.download_button(
"下载曲线(CSV)",
data=csv_buf.to_csv(index=False),
file_name="bt_nav_compare.csv",
mime="text/csv",
key="dl_nav_compare",
)
except Exception:
pass
except Exception: # noqa: BLE001 except Exception: # noqa: BLE001
LOGGER.debug("绘制对比曲线失败", extra=LOG_EXTRA) LOGGER.debug("绘制对比曲线失败", extra=LOG_EXTRA)
if not rpt_df.empty: if not rpt_df.empty:
@ -1453,6 +1589,16 @@ def render_backtest() -> None:
if metrics_rows: if metrics_rows:
dfm = pd.DataFrame(metrics_rows) dfm = pd.DataFrame(metrics_rows)
st.dataframe(dfm, hide_index=True, width='stretch') st.dataframe(dfm, hide_index=True, width='stretch')
try:
st.download_button(
"下载指标(CSV)",
data=dfm.to_csv(index=False),
file_name="bt_metrics_compare.csv",
mime="text/csv",
key="dl_metrics_compare",
)
except Exception:
pass
except Exception: # noqa: BLE001 except Exception: # noqa: BLE001
LOGGER.debug("渲染指标表失败", extra=LOG_EXTRA) LOGGER.debug("渲染指标表失败", extra=LOG_EXTRA)
else: else:
@ -1533,23 +1679,15 @@ def render_settings() -> None:
st.code("\n".join(provider_cfg.models), language="text") st.code("\n".join(provider_cfg.models), language="text")
else: else:
st.info("尚未获取模型列表,可点击下方按钮自动拉取。") st.info("尚未获取模型列表,可点击下方按钮自动拉取。")
# ADD: show cache last updated if available
model_choice_key = f"{default_model_key}_choice" try:
if provider_cfg.models: cache_key = f"{selected_provider}|{(base_val or '').strip()}"
options = provider_cfg.models + ["自定义"] entry = _MODEL_CACHE.get(cache_key)
default_choice = provider_cfg.default_model if provider_cfg.default_model in provider_cfg.models else "自定义" if entry and isinstance(entry.get("ts"), float):
model_choice = st.selectbox("默认模型", options, index=options.index(default_choice), key=model_choice_key) ts = datetime.fromtimestamp(entry["ts"]).strftime("%Y-%m-%d %H:%M:%S")
if model_choice == "自定义": st.caption(f"最近拉取时间:{ts}")
default_model_val = st.text_input("自定义默认模型", value=provider_cfg.default_model or "", key=default_model_key).strip() or None except Exception:
else: pass
default_model_val = model_choice
else:
default_model_val = st.text_input("默认模型", value=provider_cfg.default_model or "", key=default_model_key).strip() or None
mode_val = st.selectbox("调用模式", ["openai", "ollama"], index=0 if provider_cfg.mode == "openai" else 1, key=mode_key)
temp_val = st.slider("默认温度", min_value=0.0, max_value=2.0, value=float(provider_cfg.default_temperature), step=0.05, key=temp_key)
timeout_val = st.number_input("默认超时(秒)", min_value=5, max_value=300, value=int(provider_cfg.default_timeout or 30), step=5, key=timeout_key)
prompt_template_val = st.text_area("默认 Prompt 模板(可选,使用 {prompt} 占位)", value=provider_cfg.prompt_template or "", key=prompt_key, height=120)
enabled_val = st.checkbox("启用", value=provider_cfg.enabled, key=enabled_key)
fetch_key = f"fetch_models_{selected_provider}" fetch_key = f"fetch_models_{selected_provider}"
if st.button("获取模型列表", key=fetch_key): if st.button("获取模型列表", key=fetch_key):