This commit is contained in:
sam 2025-09-28 08:45:20 +08:00
parent 6a7ca88e63
commit 3f3af5404f

View File

@ -298,37 +298,36 @@ def render_settings() -> None:
existing_api_keys = {ep.provider: ep.api_key or None for ep in llm_cfg.ensemble} existing_api_keys = {ep.provider: ep.api_key or None for ep in llm_cfg.ensemble}
available_providers = sorted(DEFAULT_LLM_MODEL_OPTIONS.keys())
ensemble_rows = [ ensemble_rows = [
{ {
"provider": ep.provider, "provider": ep.provider or "",
"model": ep.model, "model": ep.model or DEFAULT_LLM_MODELS.get(ep.provider, DEFAULT_LLM_MODELS["ollama"]),
"base_url": ep.base_url or "", "base_url": ep.base_url or DEFAULT_LLM_BASE_URLS.get(ep.provider, ""),
"api_key": ep.api_key or "", "api_key": "***" if ep.api_key else "",
"temperature": float(ep.temperature), "temperature": float(ep.temperature),
"timeout": float(ep.timeout), "timeout": float(ep.timeout),
} }
for ep in llm_cfg.ensemble for ep in llm_cfg.ensemble
] or [
{
"provider": "",
"model": "",
"base_url": "",
"api_key": "",
"temperature": provider_default_temp,
"timeout": provider_default_timeout,
}
] ]
if not ensemble_rows:
ensemble_rows = [
{
"provider": "",
"model": "",
"base_url": "",
"api_key": "",
"temperature": provider_default_temp,
"timeout": provider_default_timeout,
}
]
ensemble_rows = st.data_editor( edited = st.data_editor(
ensemble_rows, ensemble_rows,
num_rows="dynamic", num_rows="dynamic",
key="llm_ensemble_editor", key="llm_ensemble_editor",
column_config={ column_config={
"provider": st.column_config.SelectboxColumn( "provider": st.column_config.SelectboxColumn(
"Provider", "Provider",
options=sorted(DEFAULT_LLM_MODEL_OPTIONS.keys()), options=available_providers,
help="选择 LLM 供应商" help="选择 LLM 供应商"
), ),
"model": st.column_config.TextColumn("模型", help="留空时使用该 Provider 的默认模型"), "model": st.column_config.TextColumn("模型", help="留空时使用该 Provider 的默认模型"),
@ -340,8 +339,10 @@ def render_settings() -> None:
hide_index=True, hide_index=True,
use_container_width=True, use_container_width=True,
) )
if hasattr(ensemble_rows, "to_dict"): if hasattr(edited, "to_dict"):
ensemble_rows = ensemble_rows.to_dict("records") ensemble_rows = edited.to_dict("records")
else:
ensemble_rows = edited
if st.button("保存 LLM 设置"): if st.button("保存 LLM 设置"):
primary.provider = selected_provider primary.provider = selected_provider
@ -367,13 +368,10 @@ def render_settings() -> None:
model_val = (row.get("model") or "").strip() or default_model model_val = (row.get("model") or "").strip() or default_model
base_val = (row.get("base_url") or "").strip() or default_base base_val = (row.get("base_url") or "").strip() or default_base
api_raw = (row.get("api_key") or "").strip() api_raw = (row.get("api_key") or "").strip()
api_value = None if api_raw == "***":
if api_raw and api_raw != "***": api_value = existing_api_keys.get(provider)
api_value = api_raw
else: else:
existing = existing_api_keys.get(provider) api_value = api_raw or None
if existing:
api_value = existing
temp_val = row.get("temperature") temp_val = row.get("temperature")
timeout_val = row.get("timeout") timeout_val = row.get("timeout")