diff --git a/app/ui/streamlit_app.py b/app/ui/streamlit_app.py index d756358..3784b39 100644 --- a/app/ui/streamlit_app.py +++ b/app/ui/streamlit_app.py @@ -1,7 +1,6 @@ """Streamlit UI scaffold for the investment assistant.""" from __future__ import annotations -import json import sys from dataclasses import asdict from datetime import date, timedelta @@ -219,10 +218,10 @@ def render_settings() -> None: if model_options: options_with_custom = model_options + [custom_model_label] - if primary.model in model_options: + if primary.provider == selected_provider and primary.model in model_options: model_index = options_with_custom.index(primary.model) else: - model_index = len(options_with_custom) - 1 + model_index = 0 selected_model_option = st.selectbox( "LLM 模型", options_with_custom, @@ -232,17 +231,17 @@ def render_settings() -> None: if selected_model_option == custom_model_label: custom_model_value = st.text_input( "自定义模型名称", - value=primary.model if primary.model not in model_options else "", + value="" if primary.provider != selected_provider or primary.model in model_options else primary.model, ) + chosen_model = custom_model_value.strip() or default_model_hint else: - custom_model_value = selected_model_option + chosen_model = selected_model_option else: - custom_model_value = st.text_input( + chosen_model = st.text_input( "LLM 模型", value=primary.model or default_model_hint, help="未预设该 Provider 的模型列表,请手动填写", - ) - selected_model_option = custom_model_label + ).strip() or default_model_hint default_base_hint = DEFAULT_LLM_BASE_URLS.get(selected_provider, "") provider_default_temp = float(provider_info.get("temperature", 0.2)) provider_default_timeout = int(provider_info.get("timeout", 30.0)) @@ -257,11 +256,16 @@ def render_settings() -> None: timeout_value = provider_default_timeout llm_base = st.text_input( - "LLM Base URL (可选)", + "LLM Base URL", value=base_value, help=f"默认推荐:{default_base_hint or '按供应商要求填写'}", ) - llm_api_key = st.text_input("LLM API Key (OpenAI 类需要)", value=primary.api_key or "", type="password") + llm_api_key = st.text_input( + "LLM API Key", + value=primary.api_key or "", + type="password", + help="点击右侧小图标可查看当前 Key,该值会写入 config.json(已被 gitignore 排除)", + ) llm_temperature = st.slider( "LLM 温度", min_value=0.0, @@ -292,68 +296,104 @@ def render_settings() -> None: format="%d", ) - ensemble_display = [] - for endpoint in llm_cfg.ensemble: - data = asdict(endpoint) - if data.get("api_key"): - data["api_key"] = "" - ensemble_display.append(data) - ensemble_text = st.text_area( - "LLM 集群配置 (JSON 数组)", - value=json.dumps(ensemble_display or [], ensure_ascii=False, indent=2), - height=220, + existing_api_keys = {ep.provider: ep.api_key or None for ep in llm_cfg.ensemble} + + ensemble_rows = [ + { + "provider": ep.provider, + "model": ep.model, + "base_url": ep.base_url or "", + "api_key": ep.api_key or "", + "temperature": float(ep.temperature), + "timeout": float(ep.timeout), + } + for ep in llm_cfg.ensemble + ] + if not ensemble_rows: + ensemble_rows = [ + { + "provider": "", + "model": "", + "base_url": "", + "api_key": "", + "temperature": provider_default_temp, + "timeout": provider_default_timeout, + } + ] + + ensemble_rows = st.data_editor( + ensemble_rows, + num_rows="dynamic", + key="llm_ensemble_editor", + column_config={ + "provider": st.column_config.SelectboxColumn( + "Provider", + options=sorted(DEFAULT_LLM_MODEL_OPTIONS.keys()), + help="选择 LLM 供应商" + ), + "model": st.column_config.TextColumn("模型", help="留空时使用该 Provider 的默认模型"), + "base_url": st.column_config.TextColumn("Base URL", help="留空时使用默认地址"), + "api_key": st.column_config.TextColumn("API Key", help="留空表示使用环境变量或不配置"), + "temperature": st.column_config.NumberColumn("温度", min_value=0.0, max_value=2.0, step=0.05), + "timeout": st.column_config.NumberColumn("超时(秒)", min_value=5.0, max_value=120.0, step=5.0), + }, + hide_index=True, + use_container_width=True, ) + if hasattr(ensemble_rows, "to_dict"): + ensemble_rows = ensemble_rows.to_dict("records") if st.button("保存 LLM 设置"): - original_provider = primary.provider - original_model = primary.model primary.provider = selected_provider - if model_options: - if selected_model_option == custom_model_label: - model_input = custom_model_value.strip() - primary.model = model_input or DEFAULT_LLM_MODELS.get( - selected_provider, DEFAULT_LLM_MODELS["ollama"] - ) - else: - primary.model = selected_model_option - else: - primary.model = custom_model_value.strip() or DEFAULT_LLM_MODELS.get( - selected_provider, DEFAULT_LLM_MODELS["ollama"] - ) - primary.base_url = llm_base.strip() or None + primary.model = chosen_model + primary.base_url = llm_base.strip() or DEFAULT_LLM_BASE_URLS.get(selected_provider) primary.temperature = llm_temperature primary.timeout = llm_timeout api_key_value = llm_api_key.strip() - primary.api_key = api_key_value or None + if api_key_value: + primary.api_key = api_key_value - try: - parsed = json.loads(ensemble_text or "[]") - if not isinstance(parsed, list): - raise ValueError("ensemble 配置必须是数组") - except Exception as exc: # noqa: BLE001 - LOGGER.exception("解析 LLM 集群配置失败", extra=LOG_EXTRA) - st.error(f"LLM 集群配置解析失败:{exc}") - else: - new_ensemble: List[LLMEndpoint] = [] - invalid = False - for item in parsed: - if not isinstance(item, dict): - st.error("LLM 集群配置中的每个元素都必须是对象") - invalid = True - break - fields = {key: item.get(key) for key in ("provider", "model", "base_url", "api_key", "temperature", "timeout")} - endpoint = LLMEndpoint(**{k: v for k, v in fields.items() if v not in (None, "")}) - if not endpoint.provider: - endpoint.provider = "ollama" - new_ensemble.append(endpoint) - if not invalid: - llm_cfg.ensemble = new_ensemble - llm_cfg.strategy = selected_strategy - llm_cfg.majority_threshold = int(majority_threshold) - save_config() - LOGGER.info("LLM 配置已更新:%s", llm_config_snapshot(), extra=LOG_EXTRA) - st.success("LLM 设置已保存,仅在当前会话生效。") - st.json(llm_config_snapshot()) + new_ensemble: List[LLMEndpoint] = [] + for row in ensemble_rows: + provider = (row.get("provider") or "").strip().lower() + if not provider: + continue + provider_defaults = DEFAULT_LLM_MODEL_OPTIONS.get(provider, {}) + default_model = DEFAULT_LLM_MODELS.get(provider, DEFAULT_LLM_MODELS["ollama"]) + default_base = DEFAULT_LLM_BASE_URLS.get(provider) + temp_default = float(provider_defaults.get("temperature", 0.2)) + timeout_default = float(provider_defaults.get("timeout", 30.0)) + + model_val = (row.get("model") or "").strip() or default_model + base_val = (row.get("base_url") or "").strip() or default_base + api_raw = (row.get("api_key") or "").strip() + api_value = None + if api_raw and api_raw != "***": + api_value = api_raw + else: + existing = existing_api_keys.get(provider) + if existing: + api_value = existing + + temp_val = row.get("temperature") + timeout_val = row.get("timeout") + endpoint = LLMEndpoint( + provider=provider, + model=model_val, + base_url=base_val, + api_key=api_value, + temperature=float(temp_val) if temp_val is not None else temp_default, + timeout=float(timeout_val) if timeout_val is not None else timeout_default, + ) + new_ensemble.append(endpoint) + + llm_cfg.ensemble = new_ensemble + llm_cfg.strategy = selected_strategy + llm_cfg.majority_threshold = int(majority_threshold) + save_config() + LOGGER.info("LLM 配置已更新:%s", llm_config_snapshot(), extra=LOG_EXTRA) + st.success("LLM 设置已保存,仅在当前会话生效。") + st.json(llm_config_snapshot()) def render_tests() -> None: