commit 6a7ca88e63
parent 5b4bd51199
Author: sam
Date:   2025-09-27 21:53:40 +08:00


@@ -1,7 +1,6 @@
 """Streamlit UI scaffold for the investment assistant."""
 from __future__ import annotations
-import json
 import sys
 from dataclasses import asdict
 from datetime import date, timedelta
@@ -219,10 +218,10 @@ def render_settings() -> None:
     if model_options:
         options_with_custom = model_options + [custom_model_label]
-        if primary.model in model_options:
+        if primary.provider == selected_provider and primary.model in model_options:
             model_index = options_with_custom.index(primary.model)
         else:
-            model_index = len(options_with_custom) - 1
+            model_index = 0
         selected_model_option = st.selectbox(
             "LLM 模型",
             options_with_custom,
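The rewritten condition only reuses the saved model when it belongs to the provider currently selected; after a provider switch, the selectbox now defaults to the first preset instead of the trailing custom entry (the old `len(options_with_custom) - 1` index). A minimal standalone sketch of the new rule, with hypothetical arguments standing in for the widget state in the diff:

```python
# Sketch only: how the selectbox default index is resolved after this change.
# All parameters are hypothetical stand-ins for the names used in the diff.
def resolve_model_index(saved_provider: str, selected_provider: str,
                        saved_model: str, model_options: list[str],
                        custom_model_label: str) -> int:
    options_with_custom = model_options + [custom_model_label]
    # Reuse the saved model only if it matches the provider now selected.
    if saved_provider == selected_provider and saved_model in model_options:
        return options_with_custom.index(saved_model)
    return 0  # fall back to the first preset, not the custom slot
```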
@@ -232,17 +231,17 @@ def render_settings() -> None:
         if selected_model_option == custom_model_label:
             custom_model_value = st.text_input(
                 "自定义模型名称",
-                value=primary.model if primary.model not in model_options else "",
+                value="" if primary.provider != selected_provider or primary.model in model_options else primary.model,
             )
+            chosen_model = custom_model_value.strip() or default_model_hint
         else:
-            custom_model_value = selected_model_option
+            chosen_model = selected_model_option
     else:
-        custom_model_value = st.text_input(
+        chosen_model = st.text_input(
             "LLM 模型",
             value=primary.model or default_model_hint,
             help="未预设该 Provider 的模型列表,请手动填写",
-        )
-        selected_model_option = custom_model_label
+        ).strip() or default_model_hint
     default_base_hint = DEFAULT_LLM_BASE_URLS.get(selected_provider, "")
     provider_default_temp = float(provider_info.get("temperature", 0.2))
     provider_default_timeout = int(provider_info.get("timeout", 30.0))
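All three input paths now funnel into a single `chosen_model` string, with `default_model_hint` as the fallback, so the save handler below no longer has to re-derive the model from widget state. Roughly, under the same assumptions as the sketch above:

```python
# Sketch of the consolidated resolution; arguments are hypothetical stand-ins.
def resolve_chosen_model(has_presets: bool, selection: str, custom_value: str,
                         free_text: str, custom_model_label: str,
                         default_model_hint: str) -> str:
    if has_presets:
        if selection == custom_model_label:
            return custom_value.strip() or default_model_hint  # custom entry chosen
        return selection                                       # preset chosen
    return free_text.strip() or default_model_hint             # no presets: free text
```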
@@ -257,11 +256,16 @@ def render_settings() -> None:
         timeout_value = provider_default_timeout
     llm_base = st.text_input(
-        "LLM Base URL (可选)",
+        "LLM Base URL",
         value=base_value,
         help=f"默认推荐:{default_base_hint or '按供应商要求填写'}",
     )
-    llm_api_key = st.text_input("LLM API Key (OpenAI 类需要)", value=primary.api_key or "", type="password")
+    llm_api_key = st.text_input(
+        "LLM API Key",
+        value=primary.api_key or "",
+        type="password",
+        help="点击右侧小图标可查看当前 Key(该值会写入 config.json,已被 gitignore 排除)",
+    )
     llm_temperature = st.slider(
         "LLM 温度",
         min_value=0.0,
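The API key field keeps `type="password"`, so Streamlit masks the value and shows a reveal icon, and the new help text documents where the key is persisted (config.json, which the help string says is excluded by .gitignore). The label drops the "(OpenAI 类需要)" / "required for OpenAI-style providers" qualifier, and the Base URL label drops "(可选)" / "(optional)", since both fields now apply to every provider row.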
@@ -292,68 +296,104 @@ def render_settings() -> None:
         format="%d",
     )
-    ensemble_display = []
-    for endpoint in llm_cfg.ensemble:
-        data = asdict(endpoint)
-        if data.get("api_key"):
-            data["api_key"] = ""
-        ensemble_display.append(data)
-    ensemble_text = st.text_area(
-        "LLM 集群配置 (JSON 数组)",
-        value=json.dumps(ensemble_display or [], ensure_ascii=False, indent=2),
-        height=220,
-    )
+    existing_api_keys = {ep.provider: ep.api_key or None for ep in llm_cfg.ensemble}
+
+    ensemble_rows = [
+        {
+            "provider": ep.provider,
+            "model": ep.model,
+            "base_url": ep.base_url or "",
+            "api_key": ep.api_key or "",
+            "temperature": float(ep.temperature),
+            "timeout": float(ep.timeout),
+        }
+        for ep in llm_cfg.ensemble
+    ]
+    if not ensemble_rows:
+        ensemble_rows = [
+            {
+                "provider": "",
+                "model": "",
+                "base_url": "",
+                "api_key": "",
+                "temperature": provider_default_temp,
+                "timeout": provider_default_timeout,
+            }
+        ]
+    ensemble_rows = st.data_editor(
+        ensemble_rows,
+        num_rows="dynamic",
+        key="llm_ensemble_editor",
+        column_config={
+            "provider": st.column_config.SelectboxColumn(
+                "Provider",
+                options=sorted(DEFAULT_LLM_MODEL_OPTIONS.keys()),
+                help="选择 LLM 供应商"
+            ),
+            "model": st.column_config.TextColumn("模型", help="留空时使用该 Provider 的默认模型"),
+            "base_url": st.column_config.TextColumn("Base URL", help="留空时使用默认地址"),
+            "api_key": st.column_config.TextColumn("API Key", help="留空表示使用环境变量或不配置"),
+            "temperature": st.column_config.NumberColumn("温度", min_value=0.0, max_value=2.0, step=0.05),
+            "timeout": st.column_config.NumberColumn("超时(秒)", min_value=5.0, max_value=120.0, step=5.0),
+        },
+        hide_index=True,
+        use_container_width=True,
+    )
+    if hasattr(ensemble_rows, "to_dict"):
+        ensemble_rows = ensemble_rows.to_dict("records")
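`st.data_editor` echoes the type it receives, so the list-of-dicts input above normally comes back as a list; the `hasattr(..., "to_dict")` guard defensively normalizes a pandas DataFrame into row dicts. A small sketch of that normalization, assuming pandas is installed:

```python
# Sketch only: normalize the editor's return value to a list of row dicts.
import pandas as pd

def as_records(rows):
    # DataFrame.to_dict("records") yields one dict per row, which is the
    # shape the save handler below iterates over.
    if hasattr(rows, "to_dict"):
        return rows.to_dict("records")
    return list(rows)

print(as_records(pd.DataFrame([{"provider": "ollama", "model": ""}])))
# -> [{'provider': 'ollama', 'model': ''}]
```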
if st.button("保存 LLM 设置"): if st.button("保存 LLM 设置"):
original_provider = primary.provider
original_model = primary.model
primary.provider = selected_provider primary.provider = selected_provider
if model_options: primary.model = chosen_model
if selected_model_option == custom_model_label: primary.base_url = llm_base.strip() or DEFAULT_LLM_BASE_URLS.get(selected_provider)
model_input = custom_model_value.strip()
primary.model = model_input or DEFAULT_LLM_MODELS.get(
selected_provider, DEFAULT_LLM_MODELS["ollama"]
)
else:
primary.model = selected_model_option
else:
primary.model = custom_model_value.strip() or DEFAULT_LLM_MODELS.get(
selected_provider, DEFAULT_LLM_MODELS["ollama"]
)
primary.base_url = llm_base.strip() or None
primary.temperature = llm_temperature primary.temperature = llm_temperature
primary.timeout = llm_timeout primary.timeout = llm_timeout
api_key_value = llm_api_key.strip() api_key_value = llm_api_key.strip()
primary.api_key = api_key_value or None if api_key_value:
primary.api_key = api_key_value
try: new_ensemble: List[LLMEndpoint] = []
parsed = json.loads(ensemble_text or "[]") for row in ensemble_rows:
if not isinstance(parsed, list): provider = (row.get("provider") or "").strip().lower()
raise ValueError("ensemble 配置必须是数组") if not provider:
except Exception as exc: # noqa: BLE001 continue
LOGGER.exception("解析 LLM 集群配置失败", extra=LOG_EXTRA) provider_defaults = DEFAULT_LLM_MODEL_OPTIONS.get(provider, {})
st.error(f"LLM 集群配置解析失败:{exc}") default_model = DEFAULT_LLM_MODELS.get(provider, DEFAULT_LLM_MODELS["ollama"])
else: default_base = DEFAULT_LLM_BASE_URLS.get(provider)
new_ensemble: List[LLMEndpoint] = [] temp_default = float(provider_defaults.get("temperature", 0.2))
invalid = False timeout_default = float(provider_defaults.get("timeout", 30.0))
for item in parsed:
if not isinstance(item, dict): model_val = (row.get("model") or "").strip() or default_model
st.error("LLM 集群配置中的每个元素都必须是对象") base_val = (row.get("base_url") or "").strip() or default_base
invalid = True api_raw = (row.get("api_key") or "").strip()
break api_value = None
fields = {key: item.get(key) for key in ("provider", "model", "base_url", "api_key", "temperature", "timeout")} if api_raw and api_raw != "***":
endpoint = LLMEndpoint(**{k: v for k, v in fields.items() if v not in (None, "")}) api_value = api_raw
if not endpoint.provider: else:
endpoint.provider = "ollama" existing = existing_api_keys.get(provider)
new_ensemble.append(endpoint) if existing:
if not invalid: api_value = existing
llm_cfg.ensemble = new_ensemble
llm_cfg.strategy = selected_strategy temp_val = row.get("temperature")
llm_cfg.majority_threshold = int(majority_threshold) timeout_val = row.get("timeout")
save_config() endpoint = LLMEndpoint(
LOGGER.info("LLM 配置已更新:%s", llm_config_snapshot(), extra=LOG_EXTRA) provider=provider,
st.success("LLM 设置已保存,仅在当前会话生效。") model=model_val,
st.json(llm_config_snapshot()) base_url=base_val,
api_key=api_value,
temperature=float(temp_val) if temp_val is not None else temp_default,
timeout=float(timeout_val) if timeout_val is not None else timeout_default,
)
new_ensemble.append(endpoint)
llm_cfg.ensemble = new_ensemble
llm_cfg.strategy = selected_strategy
llm_cfg.majority_threshold = int(majority_threshold)
save_config()
LOGGER.info("LLM 配置已更新:%s", llm_config_snapshot(), extra=LOG_EXTRA)
st.success("LLM 设置已保存,仅在当前会话生效。")
st.json(llm_config_snapshot())
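On save ("保存 LLM 设置"), a blank or `***` cell in the key column falls back to the key already stored for that provider, so editing other columns does not wipe credentials; note the lookup is keyed by provider alone, so two rows with the same provider would inherit the same stored key. A minimal sketch of that merge rule (the `***` placeholder comes from the diff; the function name is hypothetical):

```python
from typing import Optional

# Sketch of the per-row API-key merge applied on save.
def merge_api_key(entered: str, stored: Optional[str]) -> Optional[str]:
    entered = (entered or "").strip()
    if entered and entered != "***":
        return entered  # a freshly typed key always wins
    return stored       # blank or masked cell keeps the stored key

assert merge_api_key("***", "sk-old") == "sk-old"
assert merge_api_key("sk-new", "sk-old") == "sk-new"
assert merge_api_key("", None) is None
```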
 def render_tests() -> None: