From cebc1aeb254342d4655bfdc5b4887abbcb541539 Mon Sep 17 00:00:00 2001 From: sam Date: Wed, 15 Oct 2025 17:32:11 +0800 Subject: [PATCH] refactor trade date parsing and prevent auto-update rerun --- app/ui/streamlit_app.py | 6 ++- app/ui/views/today.py | 110 ++++++++++++++++++---------------------- 2 files changed, 55 insertions(+), 61 deletions(-) diff --git a/app/ui/streamlit_app.py b/app/ui/streamlit_app.py index fa4c320..b96935c 100644 --- a/app/ui/streamlit_app.py +++ b/app/ui/streamlit_app.py @@ -41,7 +41,9 @@ def main() -> None: initialize_database() cfg = get_config() - if cfg.auto_update_data: + # 仅在首次运行时执行自动数据更新,避免 Streamlit 每次重跑都触发该逻辑 + AUTO_UPDATE_FLAG = "auto_update_has_run" + if cfg.auto_update_data and not st.session_state.get(AUTO_UPDATE_FLAG): LOGGER.info("检测到自动更新数据选项已启用,开始执行数据拉取", extra=LOG_EXTRA) try: with st.spinner("正在自动更新数据..."): @@ -66,6 +68,8 @@ def main() -> None: except Exception as exc: # noqa: BLE001 LOGGER.exception("自动数据更新失败", extra=LOG_EXTRA) st.error(f"❌ 自动数据更新失败:{exc}") + finally: + st.session_state[AUTO_UPDATE_FLAG] = True render_global_dashboard() diff --git a/app/ui/views/today.py b/app/ui/views/today.py index 0454acb..3602b82 100644 --- a/app/ui/views/today.py +++ b/app/ui/views/today.py @@ -27,6 +27,21 @@ from app.ui.shared import ( ) +def _parse_trade_date(trade_date: str | int | date) -> date: + """Parse trade date inputs that may come in multiple string formats.""" + if isinstance(trade_date, date): + return trade_date + value = str(trade_date) + try: + return date.fromisoformat(value) + except ValueError: + pass + try: + return datetime.strptime(value, "%Y%m%d").date() + except ValueError as exc: + raise ValueError(f"无法解析交易日:{trade_date}") from exc + + def _fetch_agent_actions(trade_date: str, symbols: List[str]) -> Dict[str, Dict[str, Optional[str]]]: unique_symbols = list(dict.fromkeys(symbols)) if not unique_symbols: @@ -176,7 +191,9 @@ def render_today_plan() -> None: default_idx = 0 # 确保日期格式统一为 YYYYMMDD formatted_trade_dates = [str(td) for td in trade_dates] - trade_date = st.selectbox("交易日", formatted_trade_dates, index=default_idx) + selector_col, actions_col = st.columns([3, 2]) + with selector_col: + trade_date = st.selectbox("交易日", formatted_trade_dates, index=default_idx) with db_session(read_only=True) as conn: code_rows = conn.execute( @@ -207,6 +224,33 @@ def render_today_plan() -> None: else: st.caption("所选日期暂无候选池数据,仍可查看代理决策记录。") + with actions_col: + metrics_cols = st.columns(2) + metrics_cols[0].metric("标的数量", len(symbols)) + metrics_cols[1].metric("候选池标的", len(candidate_records)) + st.caption("一键触发策略重评估(包含当前交易日的所有标的)。") + if st.button("一键重评估全部", type="primary", use_container_width=True): + with st.spinner("正在对所有标的进行重评估,请稍候..."): + try: + trade_date_obj = _parse_trade_date(trade_date) + progress = st.progress(0.0) + progress.progress(0.3 if symbols else 1.0) + changes_all = _reevaluate_symbols( + trade_date_obj, + symbols, + "reeval_ui_all", + "UI All Re-eval", + ) + progress.progress(1.0) + st.success(f"一键重评估完成:共处理 {len(symbols)} 个标的") + if changes_all: + st.write("检测到以下动作变更:") + st.dataframe(pd.DataFrame(changes_all), hide_index=True, width='stretch') + st.rerun() + except Exception as exc: # noqa: BLE001 + LOGGER.exception("一键重评估失败", extra=LOG_EXTRA) + st.error(f"一键重评估执行过程中发生错误:{exc}") + detail_tab, assistant_tab = st.tabs(["标的详情", "投资助理模式"]) with assistant_tab: _render_today_plan_assistant_view(trade_date, candidate_records) @@ -306,39 +350,6 @@ def _render_today_plan_symbol_view( default_batch = [code for code in symbols if code in candidate_code_set] batch_symbols = st.multiselect("批量重评估(可多选)", symbols, default=default_batch[:10]) - if st.button("一键重评估所有标的", type="primary", width='stretch'): - with st.spinner("正在对所有标的进行重评估,请稍候..."): - try: - trade_date_obj: Optional[date] = None - try: - trade_date_obj = date.fromisoformat(str(trade_date)) - except Exception: - try: - trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date() - except Exception: - pass - if trade_date_obj is None: - raise ValueError(f"无法解析交易日:{trade_date}") - - progress = st.progress(0.0) - progress.progress(0.3 if symbols else 1.0) - changes_all = _reevaluate_symbols( - trade_date_obj, - symbols, - "reeval_ui_all", - "UI All Re-eval", - ) - progress.progress(1.0) - st.success(f"一键重评估完成:共处理 {len(symbols)} 个标的") - if changes_all: - st.write("检测到以下动作变更:") - st.dataframe(pd.DataFrame(changes_all), hide_index=True, width='stretch') - - st.rerun() - except Exception as exc: # noqa: BLE001 - LOGGER.exception("一键重评估失败", extra=LOG_EXTRA) - st.error(f"一键重评估执行过程中发生错误:{exc}") - set_query_params(date=str(trade_date), code=str(ts_code)) with db_session(read_only=True) as conn: @@ -594,12 +605,9 @@ def _render_today_plan_symbol_view( try: with db_session(read_only=True) as conn: try: - trade_date_obj = date.fromisoformat(str(trade_date)) - except Exception: - try: - trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date() - except Exception: - trade_date_obj = date.today() - timedelta(days=7) + trade_date_obj = _parse_trade_date(trade_date) + except ValueError: + trade_date_obj = date.today() news_query = """ SELECT id, title, source, pub_time, sentiment, heat, entities @@ -684,16 +692,7 @@ def _render_today_plan_symbol_view( if cols_re[0].button("对该标的重评估", key="reevaluate_current_symbol"): with st.spinner("正在重评估..."): try: - trade_date_obj: Optional[date] = None - try: - trade_date_obj = date.fromisoformat(str(trade_date)) - except Exception: - try: - trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date() - except Exception: - pass - if trade_date_obj is None: - raise ValueError(f"无法解析交易日:{trade_date}") + trade_date_obj = _parse_trade_date(trade_date) changes = _reevaluate_symbols( trade_date_obj, [ts_code], @@ -717,16 +716,7 @@ def _render_today_plan_symbol_view( if cols_re[1].button("批量重评估(所选)", key="reevaluate_batch", disabled=not batch_symbols): with st.spinner("批量重评估执行中..."): try: - trade_date_obj: Optional[date] = None - try: - trade_date_obj = date.fromisoformat(str(trade_date)) - except Exception: - try: - trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date() - except Exception: - pass - if trade_date_obj is None: - raise ValueError(f"无法解析交易日:{trade_date}") + trade_date_obj = _parse_trade_date(trade_date) progress = st.progress(0.0) progress.progress(0.3 if batch_symbols else 1.0) changes_all = _reevaluate_symbols(