refactor trade date parsing and prevent auto-update rerun

This commit is contained in:
sam 2025-10-15 17:32:11 +08:00
parent 7a3cfc3980
commit cebc1aeb25
2 changed files with 55 additions and 61 deletions

View File

@ -41,7 +41,9 @@ def main() -> None:
initialize_database() initialize_database()
cfg = get_config() cfg = get_config()
if cfg.auto_update_data: # 仅在首次运行时执行自动数据更新,避免 Streamlit 每次重跑都触发该逻辑
AUTO_UPDATE_FLAG = "auto_update_has_run"
if cfg.auto_update_data and not st.session_state.get(AUTO_UPDATE_FLAG):
LOGGER.info("检测到自动更新数据选项已启用,开始执行数据拉取", extra=LOG_EXTRA) LOGGER.info("检测到自动更新数据选项已启用,开始执行数据拉取", extra=LOG_EXTRA)
try: try:
with st.spinner("正在自动更新数据..."): with st.spinner("正在自动更新数据..."):
@ -66,6 +68,8 @@ def main() -> None:
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
LOGGER.exception("自动数据更新失败", extra=LOG_EXTRA) LOGGER.exception("自动数据更新失败", extra=LOG_EXTRA)
st.error(f"❌ 自动数据更新失败:{exc}") st.error(f"❌ 自动数据更新失败:{exc}")
finally:
st.session_state[AUTO_UPDATE_FLAG] = True
render_global_dashboard() render_global_dashboard()

View File

@ -27,6 +27,21 @@ from app.ui.shared import (
) )
def _parse_trade_date(trade_date: str | int | date) -> date:
"""Parse trade date inputs that may come in multiple string formats."""
if isinstance(trade_date, date):
return trade_date
value = str(trade_date)
try:
return date.fromisoformat(value)
except ValueError:
pass
try:
return datetime.strptime(value, "%Y%m%d").date()
except ValueError as exc:
raise ValueError(f"无法解析交易日:{trade_date}") from exc
def _fetch_agent_actions(trade_date: str, symbols: List[str]) -> Dict[str, Dict[str, Optional[str]]]: def _fetch_agent_actions(trade_date: str, symbols: List[str]) -> Dict[str, Dict[str, Optional[str]]]:
unique_symbols = list(dict.fromkeys(symbols)) unique_symbols = list(dict.fromkeys(symbols))
if not unique_symbols: if not unique_symbols:
@ -176,6 +191,8 @@ def render_today_plan() -> None:
default_idx = 0 default_idx = 0
# 确保日期格式统一为 YYYYMMDD # 确保日期格式统一为 YYYYMMDD
formatted_trade_dates = [str(td) for td in trade_dates] formatted_trade_dates = [str(td) for td in trade_dates]
selector_col, actions_col = st.columns([3, 2])
with selector_col:
trade_date = st.selectbox("交易日", formatted_trade_dates, index=default_idx) trade_date = st.selectbox("交易日", formatted_trade_dates, index=default_idx)
with db_session(read_only=True) as conn: with db_session(read_only=True) as conn:
@ -207,6 +224,33 @@ def render_today_plan() -> None:
else: else:
st.caption("所选日期暂无候选池数据,仍可查看代理决策记录。") st.caption("所选日期暂无候选池数据,仍可查看代理决策记录。")
with actions_col:
metrics_cols = st.columns(2)
metrics_cols[0].metric("标的数量", len(symbols))
metrics_cols[1].metric("候选池标的", len(candidate_records))
st.caption("一键触发策略重评估(包含当前交易日的所有标的)。")
if st.button("一键重评估全部", type="primary", use_container_width=True):
with st.spinner("正在对所有标的进行重评估,请稍候..."):
try:
trade_date_obj = _parse_trade_date(trade_date)
progress = st.progress(0.0)
progress.progress(0.3 if symbols else 1.0)
changes_all = _reevaluate_symbols(
trade_date_obj,
symbols,
"reeval_ui_all",
"UI All Re-eval",
)
progress.progress(1.0)
st.success(f"一键重评估完成:共处理 {len(symbols)} 个标的")
if changes_all:
st.write("检测到以下动作变更:")
st.dataframe(pd.DataFrame(changes_all), hide_index=True, width='stretch')
st.rerun()
except Exception as exc: # noqa: BLE001
LOGGER.exception("一键重评估失败", extra=LOG_EXTRA)
st.error(f"一键重评估执行过程中发生错误:{exc}")
detail_tab, assistant_tab = st.tabs(["标的详情", "投资助理模式"]) detail_tab, assistant_tab = st.tabs(["标的详情", "投资助理模式"])
with assistant_tab: with assistant_tab:
_render_today_plan_assistant_view(trade_date, candidate_records) _render_today_plan_assistant_view(trade_date, candidate_records)
@ -306,39 +350,6 @@ def _render_today_plan_symbol_view(
default_batch = [code for code in symbols if code in candidate_code_set] default_batch = [code for code in symbols if code in candidate_code_set]
batch_symbols = st.multiselect("批量重评估(可多选)", symbols, default=default_batch[:10]) batch_symbols = st.multiselect("批量重评估(可多选)", symbols, default=default_batch[:10])
if st.button("一键重评估所有标的", type="primary", width='stretch'):
with st.spinner("正在对所有标的进行重评估,请稍候..."):
try:
trade_date_obj: Optional[date] = None
try:
trade_date_obj = date.fromisoformat(str(trade_date))
except Exception:
try:
trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date()
except Exception:
pass
if trade_date_obj is None:
raise ValueError(f"无法解析交易日:{trade_date}")
progress = st.progress(0.0)
progress.progress(0.3 if symbols else 1.0)
changes_all = _reevaluate_symbols(
trade_date_obj,
symbols,
"reeval_ui_all",
"UI All Re-eval",
)
progress.progress(1.0)
st.success(f"一键重评估完成:共处理 {len(symbols)} 个标的")
if changes_all:
st.write("检测到以下动作变更:")
st.dataframe(pd.DataFrame(changes_all), hide_index=True, width='stretch')
st.rerun()
except Exception as exc: # noqa: BLE001
LOGGER.exception("一键重评估失败", extra=LOG_EXTRA)
st.error(f"一键重评估执行过程中发生错误:{exc}")
set_query_params(date=str(trade_date), code=str(ts_code)) set_query_params(date=str(trade_date), code=str(ts_code))
with db_session(read_only=True) as conn: with db_session(read_only=True) as conn:
@ -594,12 +605,9 @@ def _render_today_plan_symbol_view(
try: try:
with db_session(read_only=True) as conn: with db_session(read_only=True) as conn:
try: try:
trade_date_obj = date.fromisoformat(str(trade_date)) trade_date_obj = _parse_trade_date(trade_date)
except Exception: except ValueError:
try: trade_date_obj = date.today()
trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date()
except Exception:
trade_date_obj = date.today() - timedelta(days=7)
news_query = """ news_query = """
SELECT id, title, source, pub_time, sentiment, heat, entities SELECT id, title, source, pub_time, sentiment, heat, entities
@ -684,16 +692,7 @@ def _render_today_plan_symbol_view(
if cols_re[0].button("对该标的重评估", key="reevaluate_current_symbol"): if cols_re[0].button("对该标的重评估", key="reevaluate_current_symbol"):
with st.spinner("正在重评估..."): with st.spinner("正在重评估..."):
try: try:
trade_date_obj: Optional[date] = None trade_date_obj = _parse_trade_date(trade_date)
try:
trade_date_obj = date.fromisoformat(str(trade_date))
except Exception:
try:
trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date()
except Exception:
pass
if trade_date_obj is None:
raise ValueError(f"无法解析交易日:{trade_date}")
changes = _reevaluate_symbols( changes = _reevaluate_symbols(
trade_date_obj, trade_date_obj,
[ts_code], [ts_code],
@ -717,16 +716,7 @@ def _render_today_plan_symbol_view(
if cols_re[1].button("批量重评估(所选)", key="reevaluate_batch", disabled=not batch_symbols): if cols_re[1].button("批量重评估(所选)", key="reevaluate_batch", disabled=not batch_symbols):
with st.spinner("批量重评估执行中..."): with st.spinner("批量重评估执行中..."):
try: try:
trade_date_obj: Optional[date] = None trade_date_obj = _parse_trade_date(trade_date)
try:
trade_date_obj = date.fromisoformat(str(trade_date))
except Exception:
try:
trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date()
except Exception:
pass
if trade_date_obj is None:
raise ValueError(f"无法解析交易日:{trade_date}")
progress = st.progress(0.0) progress = st.progress(0.0)
progress.progress(0.3 if batch_symbols else 1.0) progress.progress(0.3 if batch_symbols else 1.0)
changes_all = _reevaluate_symbols( changes_all = _reevaluate_symbols(