From a6b28ff019a5e42243207c397097921bb022b353 Mon Sep 17 00:00:00 2001 From: sam Date: Fri, 3 Oct 2025 20:41:07 +0800 Subject: [PATCH] update --- app/ui/streamlit_app.py | 62 ++++++++++++++++++++++++++++++++++++++++- app/utils/config.py | 4 +++ 2 files changed, 65 insertions(+), 1 deletion(-) diff --git a/app/ui/streamlit_app.py b/app/ui/streamlit_app.py index 2385e8f..a519186 100644 --- a/app/ui/streamlit_app.py +++ b/app/ui/streamlit_app.py @@ -1610,12 +1610,24 @@ def render_settings() -> None: st.header("数据与设置") cfg = get_config() LOGGER.debug("当前 TuShare Token 是否已配置=%s", bool(cfg.tushare_token), extra=LOG_EXTRA) - token = st.text_input("TuShare Token", value=cfg.tushare_token or "", type="password") + + # 基础配置 + col1, col2 = st.columns([2, 1]) + with col1: + token = st.text_input("TuShare Token", value=cfg.tushare_token or "", type="password") + with col2: + auto_update = st.checkbox( + "自动更新数据", + value=cfg.auto_update_data, + help="勾选后,每次启动程序将自动执行Tushare和RSS数据拉取" + ) if st.button("保存设置"): LOGGER.info("保存设置按钮被点击", extra=LOG_EXTRA) cfg.tushare_token = token.strip() or None + cfg.auto_update_data = auto_update LOGGER.info("TuShare Token 更新,是否为空=%s", cfg.tushare_token is None, extra=LOG_EXTRA) + LOGGER.info("自动更新数据设置=%s", cfg.auto_update_data, extra=LOG_EXTRA) save_config() st.success("设置已保存,仅在当前会话生效。") @@ -1674,6 +1686,19 @@ def render_settings() -> None: title_val = st.text_input("备注名称", value=provider_cfg.title or "", key=title_key) base_val = st.text_input("Base URL", value=provider_cfg.base_url or "", key=base_key, help="调用地址,例如:https://api.openai.com") api_val = st.text_input("API Key", value=provider_cfg.api_key or "", key=api_key_key, type="password") + + # 添加缺失的表单字段 + default_model_val = st.selectbox( + "默认模型", + options=provider_cfg.models or [""], + index=0 if not provider_cfg.models else (provider_cfg.models.index(provider_cfg.default_model) if provider_cfg.default_model in provider_cfg.models else 0), + key=f"provider_default_model_{selected_provider}" + ) + temp_val = st.number_input("默认温度", value=provider_cfg.default_temperature, min_value=0.0, max_value=2.0, step=0.1, key=temp_key) + timeout_val = st.number_input("默认超时(秒)", value=provider_cfg.default_timeout, min_value=1, max_value=300, step=1, key=timeout_key) + prompt_template_val = st.text_area("Prompt 模板", value=provider_cfg.prompt_template or "", key=prompt_key) + enabled_val = st.checkbox("启用", value=provider_cfg.enabled, key=enabled_key) + mode_val = st.selectbox("模式", options=["openai", "ollama"], index=0 if provider_cfg.mode == "openai" else 1, key=mode_key) st.markdown("可用模型:") if provider_cfg.models: st.code("\n".join(provider_cfg.models), language="text") @@ -2299,6 +2324,41 @@ def render_tests() -> None: def main() -> None: LOGGER.info("初始化 Streamlit UI", extra=LOG_EXTRA) st.set_page_config(page_title="多智能体个人投资助理", layout="wide") + + # 检查是否需要自动更新数据 + cfg = get_config() + if cfg.auto_update_data: + LOGGER.info("检测到自动更新数据选项已启用,开始执行数据拉取", extra=LOG_EXTRA) + try: + # 初始化数据库 + from app.data.schema import initialize_database + initialize_database() + + # 执行开机检查(包含数据拉取) + from app.ingest.checker import run_boot_check + with st.spinner("正在自动更新数据..."): + def progress_hook(message: str, progress: float) -> None: + st.write(f"📊 {message} ({progress:.1%})") + + report = run_boot_check( + days=30, # 最近30天 + auto_fetch=True, + progress_hook=progress_hook, + force_refresh=False + ) + + # 执行RSS新闻拉取 + from app.ingest.rss import ingest_configured_rss + rss_count = ingest_configured_rss(hours_back=24, max_items_per_feed=50) + + LOGGER.info("自动数据更新完成:日线数据覆盖%s-%s,RSS新闻%s条", + report.start, report.end, rss_count, extra=LOG_EXTRA) + st.success(f"✅ 自动数据更新完成:获取RSS新闻 {rss_count} 条") + + except Exception as exc: + LOGGER.exception("自动数据更新失败", extra=LOG_EXTRA) + st.error(f"❌ 自动数据更新失败:{exc}") + render_global_dashboard() tabs = st.tabs(["今日计划", "回测与复盘", "数据与设置", "自检测试"]) LOGGER.debug("Tabs 初始化完成:%s", ["今日计划", "回测与复盘", "数据与设置", "自检测试"], extra=LOG_EXTRA) diff --git a/app/utils/config.py b/app/utils/config.py index 9d6963d..c4c4749 100644 --- a/app/utils/config.py +++ b/app/utils/config.py @@ -339,6 +339,7 @@ class AppConfig: data_paths: DataPaths = field(default_factory=DataPaths) agent_weights: AgentWeights = field(default_factory=AgentWeights) force_refresh: bool = False + auto_update_data: bool = False llm_providers: Dict[str, LLMProvider] = field(default_factory=_default_llm_providers) llm: LLMConfig = field(default_factory=LLMConfig) departments: Dict[str, DepartmentSettings] = field(default_factory=_default_departments) @@ -399,6 +400,8 @@ def _load_from_file(cfg: AppConfig) -> None: cfg.tushare_token = payload.get("tushare_token") or None if "force_refresh" in payload: cfg.force_refresh = bool(payload.get("force_refresh")) + if "auto_update_data" in payload: + cfg.auto_update_data = bool(payload.get("auto_update_data")) if "decision_method" in payload: cfg.decision_method = str(payload.get("decision_method") or cfg.decision_method) @@ -579,6 +582,7 @@ def save_config(cfg: AppConfig | None = None) -> None: payload = { "tushare_token": cfg.tushare_token, "force_refresh": cfg.force_refresh, + "auto_update_data": cfg.auto_update_data, "decision_method": cfg.decision_method, "rss_sources": cfg.rss_sources, "agent_weights": cfg.agent_weights.as_dict(),