diff --git a/app/ui/streamlit_app.py b/app/ui/streamlit_app.py index 3784b39..e232bac 100644 --- a/app/ui/streamlit_app.py +++ b/app/ui/streamlit_app.py @@ -298,37 +298,36 @@ def render_settings() -> None: existing_api_keys = {ep.provider: ep.api_key or None for ep in llm_cfg.ensemble} + available_providers = sorted(DEFAULT_LLM_MODEL_OPTIONS.keys()) ensemble_rows = [ { - "provider": ep.provider, - "model": ep.model, - "base_url": ep.base_url or "", - "api_key": ep.api_key or "", + "provider": ep.provider or "", + "model": ep.model or DEFAULT_LLM_MODELS.get(ep.provider, DEFAULT_LLM_MODELS["ollama"]), + "base_url": ep.base_url or DEFAULT_LLM_BASE_URLS.get(ep.provider, ""), + "api_key": "***" if ep.api_key else "", "temperature": float(ep.temperature), "timeout": float(ep.timeout), } for ep in llm_cfg.ensemble + ] or [ + { + "provider": "", + "model": "", + "base_url": "", + "api_key": "", + "temperature": provider_default_temp, + "timeout": provider_default_timeout, + } ] - if not ensemble_rows: - ensemble_rows = [ - { - "provider": "", - "model": "", - "base_url": "", - "api_key": "", - "temperature": provider_default_temp, - "timeout": provider_default_timeout, - } - ] - ensemble_rows = st.data_editor( + edited = st.data_editor( ensemble_rows, num_rows="dynamic", key="llm_ensemble_editor", column_config={ "provider": st.column_config.SelectboxColumn( "Provider", - options=sorted(DEFAULT_LLM_MODEL_OPTIONS.keys()), + options=available_providers, help="选择 LLM 供应商" ), "model": st.column_config.TextColumn("模型", help="留空时使用该 Provider 的默认模型"), @@ -340,8 +339,10 @@ def render_settings() -> None: hide_index=True, use_container_width=True, ) - if hasattr(ensemble_rows, "to_dict"): - ensemble_rows = ensemble_rows.to_dict("records") + if hasattr(edited, "to_dict"): + ensemble_rows = edited.to_dict("records") + else: + ensemble_rows = edited if st.button("保存 LLM 设置"): primary.provider = selected_provider @@ -367,13 +368,10 @@ def render_settings() -> None: model_val = (row.get("model") or "").strip() or default_model base_val = (row.get("base_url") or "").strip() or default_base api_raw = (row.get("api_key") or "").strip() - api_value = None - if api_raw and api_raw != "***": - api_value = api_raw + if api_raw == "***": + api_value = existing_api_keys.get(provider) else: - existing = existing_api_keys.get(provider) - if existing: - api_value = existing + api_value = api_raw or None temp_val = row.get("temperature") timeout_val = row.get("timeout")