From d81dfefc31ff36ce869ce1069b971cc0e0b97f4f Mon Sep 17 00:00:00 2001 From: sam Date: Sun, 28 Sep 2025 14:49:58 +0800 Subject: [PATCH] update --- README.md | 4 +-- app/ui/streamlit_app.py | 56 ++++++++++++++++++++++++++++------------- 2 files changed, 40 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index cb82fe5..ebf8cb2 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,7 @@ export TUSHARE_TOKEN="" ### LLM 配置与测试 -- 通过 Provider 管理供应商连接参数(Base URL、API Key、模型列表、默认温度/超时/Prompt 模板),可随时扩展本地 Ollama 或各类云端服务(DeepSeek、文心一言、OpenAI 等)。 +- 通过 Provider 管理供应商连接参数(Base URL、API Key、默认温度/超时/Prompt 模板),并支持在界面内一键调用 `client.models.list()` 拉取可用模型列表,便于扩展本地 Ollama 或各类云端服务(DeepSeek、文心一言、OpenAI 等)。 - 全局与部门配置直接选择 Provider,并根据需要覆盖模型、温度、Prompt 模板、投票策略;保存后写入 `app/data/config.json`,下次启动自动加载。 - Streamlit “数据与设置” 页提供 Provider/全局/部门三栏编辑界面,保存后即时生效,并通过 `llm_config_snapshot()` 输出脱敏检查信息。 - 支持使用环境变量注入敏感信息:`TUSHARE_TOKEN`、`LLM_API_KEY`。 @@ -103,7 +103,7 @@ Streamlit `自检测试` 页签提供: ## 实施步骤 1. **配置扩展** (`app/utils/config.py` + `config.json`) ✅ - - 引入 `llm_providers` 集中管理供应商参数,全局与部门直接绑定 Provider 并自定义模型/温度/Prompt;Streamlit 提供可视化维护表单。 + - 引入 `llm_providers` 集中管理供应商参数,全局与部门直接绑定 Provider 并自定义模型/温度/Prompt,Provider 页面提供模型列表自动获取;Streamlit 提供可视化维护表单。 2. **部门管控器** ✅ - `app/agents/departments.py` 提供 `DepartmentAgent`/`DepartmentManager`,封装 Prompt 构建、多模型协商及异常回退。 diff --git a/app/ui/streamlit_app.py b/app/ui/streamlit_app.py index 26bf3f0..4649ce5 100644 --- a/app/ui/streamlit_app.py +++ b/app/ui/streamlit_app.py @@ -430,14 +430,13 @@ def render_settings() -> None: cfg.llm_providers = providers save_config() st.success(f"已创建 Provider {key}。") - st.experimental_rerun() + st.rerun() if selected_provider: provider_cfg = providers.get(selected_provider, LLMProvider(key=selected_provider)) title_key = f"provider_title_{selected_provider}" base_key = f"provider_base_{selected_provider}" api_key_key = f"provider_api_{selected_provider}" - models_key = f"provider_models_{selected_provider}" default_model_key = f"provider_default_model_{selected_provider}" mode_key = f"provider_mode_{selected_provider}" temp_key = f"provider_temp_{selected_provider}" @@ -448,8 +447,23 @@ def render_settings() -> None: title_val = st.text_input("备注名称", value=provider_cfg.title or "", key=title_key) base_val = st.text_input("Base URL", value=provider_cfg.base_url or "", key=base_key, help="调用地址,例如:https://api.openai.com") api_val = st.text_input("API Key", value=provider_cfg.api_key or "", key=api_key_key, type="password") - models_val = st.text_area("可用模型(每行一个)", value="\n".join(provider_cfg.models), key=models_key, height=100) - default_model_val = st.text_input("默认模型", value=provider_cfg.default_model or "", key=default_model_key) + st.markdown("可用模型:") + if provider_cfg.models: + st.code("\n".join(provider_cfg.models), language="text") + else: + st.info("尚未获取模型列表,可点击下方按钮自动拉取。") + + model_choice_key = f"{default_model_key}_choice" + if provider_cfg.models: + options = provider_cfg.models + ["自定义"] + default_choice = provider_cfg.default_model if provider_cfg.default_model in provider_cfg.models else "自定义" + model_choice = st.selectbox("默认模型", options, index=options.index(default_choice), key=model_choice_key) + if model_choice == "自定义": + default_model_val = st.text_input("自定义默认模型", value=provider_cfg.default_model or "", key=default_model_key).strip() or None + else: + default_model_val = model_choice + else: + default_model_val = st.text_input("默认模型", value=provider_cfg.default_model or "", key=default_model_key).strip() or None mode_val = st.selectbox("调用模式", ["openai", "ollama"], index=0 if provider_cfg.mode == "openai" else 1, key=mode_key) temp_val = st.slider("默认温度", min_value=0.0, max_value=2.0, value=float(provider_cfg.default_temperature), step=0.05, key=temp_key) timeout_val = st.number_input("默认超时(秒)", min_value=5, max_value=300, value=int(provider_cfg.default_timeout or 30), step=5, key=timeout_key) @@ -471,15 +485,16 @@ def render_settings() -> None: cfg.sync_runtime_llm() save_config() st.success(f"共获取 {len(models)} 个模型。") - st.session_state[models_key] = "\n".join(models) - st.session_state[default_model_key] = provider_cfg.default_model or "" + st.rerun() if st.button("保存 Provider", key=f"save_provider_{selected_provider}"): provider_cfg.title = title_val.strip() provider_cfg.base_url = base_val.strip() provider_cfg.api_key = api_val.strip() or None - provider_cfg.models = [line.strip() for line in models_val.splitlines() if line.strip()] - provider_cfg.default_model = default_model_val.strip() or (provider_cfg.models[0] if provider_cfg.models else provider_cfg.default_model) + if provider_cfg.models and default_model_val in provider_cfg.models: + provider_cfg.default_model = default_model_val + else: + provider_cfg.default_model = default_model_val provider_cfg.default_temperature = float(temp_val) provider_cfg.default_timeout = float(timeout_val) provider_cfg.prompt_template = prompt_template_val.strip() @@ -511,7 +526,7 @@ def render_settings() -> None: cfg.sync_runtime_llm() save_config() st.success("Provider 已删除。") - st.experimental_rerun() + st.rerun() st.markdown("##### 全局推理配置") if not provider_keys: @@ -710,13 +725,17 @@ def render_settings() -> None: strategy_val = (row.get("strategy") or existing.llm.strategy).lower() if strategy_val in ALLOWED_LLM_STRATEGIES: existing.llm.strategy = strategy_val - majority_raw = row.get("majority_threshold") - try: - majority_val = int(majority_raw) - if majority_val > 0: - existing.llm.majority_threshold = majority_val - except (TypeError, ValueError): - pass + if existing.llm.strategy == "single": + existing.llm.majority_threshold = 1 + existing.llm.ensemble = [] + else: + majority_raw = row.get("majority_threshold") + try: + majority_val = int(majority_raw) + if majority_val > 0: + existing.llm.majority_threshold = majority_val + except (TypeError, ValueError): + pass provider_val = (row.get("provider") or existing.llm.primary.provider or (provider_keys[0] if provider_keys else "ollama")).strip().lower() model_val = (row.get("model") or "").strip() or None @@ -733,7 +752,8 @@ def render_settings() -> None: endpoint.base_url = None endpoint.api_key = None existing.llm.primary = endpoint - existing.llm.ensemble = [] + if existing.llm.strategy != "single": + existing.llm.ensemble = [] updated_departments[code] = existing @@ -752,7 +772,7 @@ def render_settings() -> None: cfg.sync_runtime_llm() save_config() st.success("已恢复默认部门配置。") - st.experimental_rerun() + st.rerun() st.caption("部门配置存储为独立 LLM 参数,执行时会自动套用对应 Provider 的连接信息。")