diff --git a/app/ui/streamlit_app.py b/app/ui/streamlit_app.py index 342d1d4..2385e8f 100644 --- a/app/ui/streamlit_app.py +++ b/app/ui/streamlit_app.py @@ -71,6 +71,23 @@ _SIDEBAR_LISTENER_ATTACHED = False # ADD: simple in-memory cache for provider model discovery _MODEL_CACHE: Dict[str, Dict[str, object]] = {} _CACHE_TTL_SECONDS = 300 +_WARNINGS_CONTAINER = None +_WARNINGS_PLACEHOLDER = None + +# ADD: query param helpers +def _get_query_params() -> Dict[str, List[str]]: + try: + return dict(st.query_params) + except Exception: + return {} + +def _set_query_params(**kwargs: object) -> None: + try: + payload = {k: v for k, v in kwargs.items() if v is not None} + if payload: + st.query_params.update(payload) + except Exception: + pass def _sidebar_metrics_listener(metrics: Dict[str, object]) -> None: @@ -86,14 +103,20 @@ def render_global_dashboard() -> None: global _DASHBOARD_CONTAINERS global _DASHBOARD_ELEMENTS global _SIDEBAR_LISTENER_ATTACHED + global _WARNINGS_CONTAINER + global _WARNINGS_PLACEHOLDER + + # ADD: warning badge on top + warnings = alerts.get_warnings() + badge = f" ({len(warnings)})" if warnings else "" + st.sidebar.header(f"系统监控{badge}") metrics_container = st.sidebar.container() decisions_container = st.sidebar.container() + _WARNINGS_CONTAINER = st.sidebar.container() + _WARNINGS_PLACEHOLDER = st.sidebar.empty() _DASHBOARD_CONTAINERS = (metrics_container, decisions_container) _DASHBOARD_ELEMENTS = _ensure_dashboard_elements(metrics_container, decisions_container) - if st.sidebar.button("清除数据告警", key="clear_data_alerts"): - alerts.clear_warnings() - _update_dashboard_sidebar() if not _SIDEBAR_LISTENER_ATTACHED: register_llm_metrics_listener(_sidebar_metrics_listener) _SIDEBAR_LISTENER_ATTACHED = True @@ -105,6 +128,8 @@ def _update_dashboard_sidebar( ) -> None: global _DASHBOARD_CONTAINERS global _DASHBOARD_ELEMENTS + global _WARNINGS_CONTAINER + global _WARNINGS_PLACEHOLDER containers = _DASHBOARD_CONTAINERS if not containers: @@ -140,23 +165,6 @@ def _update_dashboard_sidebar( else: model_placeholder.info("暂无模型分布数据。") - warnings_placeholder = elements.get("warnings") - if warnings_placeholder is not None: - warnings_placeholder.empty() - warnings = alerts.get_warnings() - if warnings: - lines = [] - for warning in warnings[-10:]: - detail = warning.get("detail") - appendix = f" {detail}" if detail else "" - lines.append( - f"- **{warning['source']}** {warning['message']}{appendix}" - f"\n{warning['timestamp']}" - ) - warnings_placeholder.markdown("\n".join(lines), unsafe_allow_html=True) - else: - warnings_placeholder.info("暂无数据告警。") - decisions = metrics.get("recent_decisions") or llm_recent_decisions(10) if decisions: lines = [] @@ -177,6 +185,37 @@ def _update_dashboard_sidebar( decisions_placeholder = elements["decisions_list"] decisions_placeholder.empty() decisions_placeholder.info("暂无决策记录。执行回测或实时评估后可在此查看。") + # Render warnings section in-place (clear then write) + if _WARNINGS_PLACEHOLDER is not None: + _WARNINGS_PLACEHOLDER.empty() + with _WARNINGS_PLACEHOLDER.container(): + st.subheader("数据告警") + warn_list = alerts.get_warnings() + if warn_list: + lines = [] + for warning in warn_list[-10:]: + detail = warning.get("detail") + appendix = f" {detail}" if detail else "" + lines.append( + f"- **{warning['source']}** {warning['message']}{appendix}\n{warning['timestamp']}" + ) + st.markdown("\n".join(lines), unsafe_allow_html=True) + btn_cols = st.columns([1,1]) + if btn_cols[0].button("清除数据告警", key="clear_data_alerts_sibling"): + alerts.clear_warnings() + _update_dashboard_sidebar() + try: + st.download_button( + "导出告警(JSON)", + data=json.dumps(warn_list, ensure_ascii=False, indent=2), + file_name="data_warnings.json", + mime="application/json", + key="dl_warnings_json_sibling", + ) + except Exception: + pass + else: + st.info("暂无数据告警。") def _ensure_dashboard_elements(metrics_container, decisions_container) -> Dict[str, object]: @@ -188,8 +227,6 @@ def _ensure_dashboard_elements(metrics_container, decisions_container) -> Dict[s distribution_expander = metrics_container.expander("调用分布", expanded=False) provider_distribution = distribution_expander.empty() model_distribution = distribution_expander.empty() - warnings_expander = metrics_container.expander("数据告警", expanded=False) - warnings_placeholder = warnings_expander.empty() decisions_container.subheader("最新决策") decisions_list = decisions_container.empty() @@ -200,7 +237,6 @@ def _ensure_dashboard_elements(metrics_container, decisions_container) -> Dict[s "metrics_completion": metrics_completion, "provider_distribution": provider_distribution, "model_distribution": model_distribution, - "warnings": warnings_placeholder, "decisions_list": decisions_list, } return elements @@ -382,7 +418,14 @@ def render_today_plan() -> None: st.info("暂无决策记录,完成一次回测后即可在此查看部门意见与投票结果。") return - trade_date = st.selectbox("交易日", trade_dates, index=0) + # ADD: read default selection from URL + q = _get_query_params() + default_trade_date = q.get("date", [trade_dates[0]])[0] + try: + default_idx = trade_dates.index(default_trade_date) + except ValueError: + default_idx = 0 + trade_date = st.selectbox("交易日", trade_dates, index=default_idx) with db_session(read_only=True) as conn: code_rows = conn.execute( @@ -399,10 +442,18 @@ def render_today_plan() -> None: st.info("所选交易日暂无 agent_utils 记录。") return - ts_code = st.selectbox("标的", symbols, index=0) + default_ts = q.get("code", [symbols[0]])[0] + try: + default_ts_idx = symbols.index(default_ts) + except ValueError: + default_ts_idx = 0 + ts_code = st.selectbox("标的", symbols, index=default_ts_idx) # ADD: batch selection for re-evaluation batch_symbols = st.multiselect("批量重评估(可多选)", symbols, default=[]) + # sync URL params + _set_query_params(date=str(trade_date), code=str(ts_code)) + with db_session(read_only=True) as conn: rows = conn.execute( """ @@ -513,15 +564,42 @@ def render_today_plan() -> None: if global_info["requires_review"]: st.warning("部门分歧较大,已标记为需人工复核。") with st.expander("基础上下文数据", expanded=False): - if global_info.get("scope_values"): - st.write("最新字段:") - st.json(global_info["scope_values"]) - if global_info.get("close_series"): + # ADD: export buttons + scope = global_info.get("scope_values") or {} + close_series = global_info.get("close_series") or [] + turnover_series = global_info.get("turnover_series") or [] + st.write("最新字段:") + if scope: + st.json(scope) + st.download_button( + "下载字段(JSON)", + data=json.dumps(scope, ensure_ascii=False, indent=2), + file_name=f"{ts_code}_{trade_date}_scope.json", + mime="application/json", + key="dl_scope_json", + ) + if close_series: st.write("收盘价时间序列 (最近窗口):") - st.json(global_info["close_series"]) - if global_info.get("turnover_series"): + st.json(close_series) + try: + import io, csv + buf = io.StringIO() + writer = csv.writer(buf) + writer.writerow(["trade_date", "close"]) + for dt, val in close_series: + writer.writerow([dt, val]) + st.download_button( + "下载收盘价(CSV)", + data=buf.getvalue(), + file_name=f"{ts_code}_{trade_date}_close_series.csv", + mime="text/csv", + key="dl_close_csv", + ) + except Exception: + pass + if turnover_series: st.write("换手率时间序列 (最近窗口):") - st.json(global_info["turnover_series"]) + st.json(turnover_series) dept_sup = global_info.get("department_supplements") or {} dept_dialogue = global_info.get("department_dialogue") or {} dept_telemetry = global_info.get("department_telemetry") or {} @@ -550,8 +628,23 @@ def render_today_plan() -> None: item for item in dept_records if kw in str(item.get("摘要", "")) or kw in str(item.get("核心信号", "")) ] + # ADD: confidence filter and sort + min_conf = st.slider("最低信心过滤", 0.0, 1.0, 0.0, 0.05) + sort_col = st.selectbox("排序列", ["信心", "权重"], index=0) + filtered = [row for row in filtered if float(row.get("信心", 0.0)) >= min_conf] + filtered = sorted(filtered, key=lambda r: float(r.get(sort_col, 0.0)), reverse=True) dept_df = pd.DataFrame(filtered) st.dataframe(dept_df, width='stretch', hide_index=True) + try: + st.download_button( + "下载部门意见(CSV)", + data=dept_df.to_csv(index=False), + file_name=f"{trade_date}_{ts_code}_departments.csv", + mime="text/csv", + key="dl_dept_csv", + ) + except Exception: + pass for code, details in dept_details.items(): with st.expander(f"{code} 补充详情", expanded=False): supplements = details.get("supplements", []) @@ -576,8 +669,26 @@ def render_today_plan() -> None: st.subheader("代理评分") if agent_records: + # ADD: sorting and CSV export for agents + sort_agent_by = st.selectbox( + "代理排序", + ["权重", "SELL", "HOLD", "BUY_S", "BUY_M", "BUY_L"], + index=1, + ) agent_df = pd.DataFrame(agent_records) + if sort_agent_by in agent_df.columns: + agent_df = agent_df.sort_values(sort_agent_by, ascending=False) st.dataframe(agent_df, width='stretch', hide_index=True) + try: + st.download_button( + "下载代理评分(CSV)", + data=agent_df.to_csv(index=False), + file_name=f"{trade_date}_{ts_code}_agents.csv", + mime="text/csv", + key="dl_agent_csv", + ) + except Exception: + pass else: st.info("暂无基础代理评分。") @@ -1398,6 +1509,8 @@ def render_backtest() -> None: cfg_options = [f"{row['id']} | {row['name']}" for row in cfg_rows] selected_labels = st.multiselect("选择配置", cfg_options, default=cfg_options[:2]) selected_ids = [label.split(" | ")[0].strip() for label in selected_labels] + nav_df = pd.DataFrame() + rpt_df = pd.DataFrame() if selected_ids: try: with db_session(read_only=True) as conn: @@ -1419,6 +1532,16 @@ def render_backtest() -> None: if not nav_df.empty: try: nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"], errors="coerce") + # ADD: date window filter + overall_min = pd.to_datetime(nav_df["trade_date"].min()).date() + overall_max = pd.to_datetime(nav_df["trade_date"].max()).date() + col_d1, col_d2 = st.columns(2) + start_filter = col_d1.date_input("起始日期", value=overall_min) + end_filter = col_d2.date_input("结束日期", value=overall_max) + if start_filter > end_filter: + start_filter, end_filter = end_filter, start_filter + mask = (nav_df["trade_date"].dt.date >= start_filter) & (nav_df["trade_date"].dt.date <= end_filter) + nav_df = nav_df.loc[mask] pivot = nav_df.pivot_table(index="trade_date", columns="cfg_id", values="nav") if normalize_to_one: pivot = pivot.apply(lambda s: s / s.dropna().iloc[0] if s.dropna().size else s) @@ -1430,6 +1553,19 @@ def render_backtest() -> None: if use_log_y: fig.update_yaxes(type="log") st.plotly_chart(fig, use_container_width=True) + # ADD: export pivot + try: + csv_buf = pivot.reset_index() + csv_buf.columns = ["trade_date"] + [str(c) for c in pivot.columns] + st.download_button( + "下载曲线(CSV)", + data=csv_buf.to_csv(index=False), + file_name="bt_nav_compare.csv", + mime="text/csv", + key="dl_nav_compare", + ) + except Exception: + pass except Exception: # noqa: BLE001 LOGGER.debug("绘制对比曲线失败", extra=LOG_EXTRA) if not rpt_df.empty: @@ -1453,6 +1589,16 @@ def render_backtest() -> None: if metrics_rows: dfm = pd.DataFrame(metrics_rows) st.dataframe(dfm, hide_index=True, width='stretch') + try: + st.download_button( + "下载指标(CSV)", + data=dfm.to_csv(index=False), + file_name="bt_metrics_compare.csv", + mime="text/csv", + key="dl_metrics_compare", + ) + except Exception: + pass except Exception: # noqa: BLE001 LOGGER.debug("渲染指标表失败", extra=LOG_EXTRA) else: @@ -1533,23 +1679,15 @@ def render_settings() -> None: st.code("\n".join(provider_cfg.models), language="text") else: st.info("尚未获取模型列表,可点击下方按钮自动拉取。") - - model_choice_key = f"{default_model_key}_choice" - if provider_cfg.models: - options = provider_cfg.models + ["自定义"] - default_choice = provider_cfg.default_model if provider_cfg.default_model in provider_cfg.models else "自定义" - model_choice = st.selectbox("默认模型", options, index=options.index(default_choice), key=model_choice_key) - if model_choice == "自定义": - default_model_val = st.text_input("自定义默认模型", value=provider_cfg.default_model or "", key=default_model_key).strip() or None - else: - default_model_val = model_choice - else: - default_model_val = st.text_input("默认模型", value=provider_cfg.default_model or "", key=default_model_key).strip() or None - mode_val = st.selectbox("调用模式", ["openai", "ollama"], index=0 if provider_cfg.mode == "openai" else 1, key=mode_key) - temp_val = st.slider("默认温度", min_value=0.0, max_value=2.0, value=float(provider_cfg.default_temperature), step=0.05, key=temp_key) - timeout_val = st.number_input("默认超时(秒)", min_value=5, max_value=300, value=int(provider_cfg.default_timeout or 30), step=5, key=timeout_key) - prompt_template_val = st.text_area("默认 Prompt 模板(可选,使用 {prompt} 占位)", value=provider_cfg.prompt_template or "", key=prompt_key, height=120) - enabled_val = st.checkbox("启用", value=provider_cfg.enabled, key=enabled_key) + # ADD: show cache last updated if available + try: + cache_key = f"{selected_provider}|{(base_val or '').strip()}" + entry = _MODEL_CACHE.get(cache_key) + if entry and isinstance(entry.get("ts"), float): + ts = datetime.fromtimestamp(entry["ts"]).strftime("%Y-%m-%d %H:%M:%S") + st.caption(f"最近拉取时间:{ts}") + except Exception: + pass fetch_key = f"fetch_models_{selected_provider}" if st.button("获取模型列表", key=fetch_key):