sam 2025-09-27 21:53:40 +08:00
parent 5b4bd51199
commit 6a7ca88e63

@@ -1,7 +1,6 @@
"""Streamlit UI scaffold for the investment assistant."""
from __future__ import annotations
import json
import sys
from dataclasses import asdict
from datetime import date, timedelta
@@ -219,10 +218,10 @@ def render_settings() -> None:
if model_options:
options_with_custom = model_options + [custom_model_label]
if primary.model in model_options:
if primary.provider == selected_provider and primary.model in model_options:
model_index = options_with_custom.index(primary.model)
else:
model_index = len(options_with_custom) - 1
model_index = 0
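Note the behavioral change here: the saved model is reused only when the saved provider matches the one currently selected, and a provider switch now defaults the selectbox to the first preset (index 0) rather than the trailing custom entry. A minimal sketch of that rule as a pure function; the names and the custom-entry sentinel are illustrative, not from the codebase:

def default_model_index(saved_provider: str, saved_model: str,
                        selected_provider: str, model_options: list[str]) -> int:
    # Mirror of the branch above: reuse the saved model only for the same provider.
    options_with_custom = model_options + ["custom"]  # assumed sentinel label
    if saved_provider == selected_provider and saved_model in model_options:
        return options_with_custom.index(saved_model)
    return 0  # provider changed (or unknown model): snap back to the first preset

assert default_model_index("openai", "gpt-4o-mini", "openai", ["gpt-4o", "gpt-4o-mini"]) == 1
assert default_model_index("openai", "gpt-4o-mini", "ollama", ["llama3"]) == 0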
selected_model_option = st.selectbox(
"LLM 模型",
options_with_custom,
@@ -232,17 +231,17 @@ def render_settings() -> None:
if selected_model_option == custom_model_label:
custom_model_value = st.text_input(
"自定义模型名称",
value=primary.model if primary.model not in model_options else "",
value="" if primary.provider != selected_provider or primary.model in model_options else primary.model,
)
chosen_model = custom_model_value.strip() or default_model_hint
else:
custom_model_value = selected_model_option
chosen_model = selected_model_option
else:
custom_model_value = st.text_input(
chosen_model = st.text_input(
"LLM 模型",
value=primary.model or default_model_hint,
help="未预设该 Provider 的模型列表,请手动填写",
)
selected_model_option = custom_model_label
).strip() or default_model_hint
default_base_hint = DEFAULT_LLM_BASE_URLS.get(selected_provider, "")
provider_default_temp = float(provider_info.get("temperature", 0.2))
provider_default_timeout = int(provider_info.get("timeout", 30.0))
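With no preset list, the free-text path now feeds the widget value straight into chosen_model, falling back to default_model_hint when the input is blank. The same strip-or-default chain isolated as a hypothetical helper:

def resolve_model(raw_input: str, default_hint: str) -> str:
    # Whitespace-only input counts as empty and falls back to the hint.
    return raw_input.strip() or default_hint

assert resolve_model("  qwen2.5:14b ", "llama3") == "qwen2.5:14b"
assert resolve_model("   ", "llama3") == "llama3"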
@@ -257,11 +256,16 @@ def render_settings() -> None:
timeout_value = provider_default_timeout
llm_base = st.text_input(
"LLM Base URL (可选)",
"LLM Base URL",
value=base_value,
help=f"默认推荐:{default_base_hint or '按供应商要求填写'}",
)
llm_api_key = st.text_input("LLM API Key (OpenAI 类需要)", value=primary.api_key or "", type="password")
llm_api_key = st.text_input(
"LLM API Key",
value=primary.api_key or "",
type="password",
help="点击右侧小图标可查看当前 Key该值会写入 config.json已被 gitignore 排除)",
)
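The new help text points out that the key is persisted to config.json, which the repository is said to gitignore. A minimal sketch of such a round-trip; the path, schema, and helper name are assumptions rather than the project's actual persistence code:

from __future__ import annotations
import json
from pathlib import Path

CONFIG_PATH = Path("config.json")  # assumed location, excluded from git via .gitignore

def save_llm_section(provider: str, model: str, api_key: str | None) -> None:
    # Merge into any existing config so unrelated sections survive the write.
    config = json.loads(CONFIG_PATH.read_text("utf-8")) if CONFIG_PATH.exists() else {}
    config["llm"] = {"provider": provider, "model": model, "api_key": api_key}
    CONFIG_PATH.write_text(json.dumps(config, ensure_ascii=False, indent=2), "utf-8")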
llm_temperature = st.slider(
"LLM 温度",
min_value=0.0,
@@ -292,61 +296,97 @@ def render_settings() -> None:
format="%d",
)
ensemble_display = []
for endpoint in llm_cfg.ensemble:
data = asdict(endpoint)
if data.get("api_key"):
data["api_key"] = ""
ensemble_display.append(data)
ensemble_text = st.text_area(
"LLM 集群配置 (JSON 数组)",
value=json.dumps(ensemble_display or [], ensure_ascii=False, indent=2),
height=220,
existing_api_keys = {ep.provider: ep.api_key or None for ep in llm_cfg.ensemble}
ensemble_rows = [
{
"provider": ep.provider,
"model": ep.model,
"base_url": ep.base_url or "",
"api_key": ep.api_key or "",
"temperature": float(ep.temperature),
"timeout": float(ep.timeout),
}
for ep in llm_cfg.ensemble
]
if not ensemble_rows:
ensemble_rows = [
{
"provider": "",
"model": "",
"base_url": "",
"api_key": "",
"temperature": provider_default_temp,
"timeout": provider_default_timeout,
}
]
ensemble_rows = st.data_editor(
ensemble_rows,
num_rows="dynamic",
key="llm_ensemble_editor",
column_config={
"provider": st.column_config.SelectboxColumn(
"Provider",
options=sorted(DEFAULT_LLM_MODEL_OPTIONS.keys()),
help="选择 LLM 供应商"
),
"model": st.column_config.TextColumn("模型", help="留空时使用该 Provider 的默认模型"),
"base_url": st.column_config.TextColumn("Base URL", help="留空时使用默认地址"),
"api_key": st.column_config.TextColumn("API Key", help="留空表示使用环境变量或不配置"),
"temperature": st.column_config.NumberColumn("温度", min_value=0.0, max_value=2.0, step=0.05),
"timeout": st.column_config.NumberColumn("超时(秒)", min_value=5.0, max_value=120.0, step=5.0),
},
hide_index=True,
use_container_width=True,
)
if hasattr(ensemble_rows, "to_dict"):
ensemble_rows = ensemble_rows.to_dict("records")
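Depending on the Streamlit version and the type handed in, st.data_editor can return a pandas DataFrame instead of the list of dicts it was given; the hasattr probe above normalizes both shapes. Isolated as a sketch:

from typing import Any

def as_records(edited: Any) -> list[dict]:
    # pandas.DataFrame exposes to_dict("records"); a plain list of dicts does not.
    if hasattr(edited, "to_dict"):
        return edited.to_dict("records")
    return list(edited)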
if st.button("保存 LLM 设置"):
original_provider = primary.provider
original_model = primary.model
primary.provider = selected_provider
if model_options:
if selected_model_option == custom_model_label:
model_input = custom_model_value.strip()
primary.model = model_input or DEFAULT_LLM_MODELS.get(
selected_provider, DEFAULT_LLM_MODELS["ollama"]
)
else:
primary.model = selected_model_option
else:
primary.model = custom_model_value.strip() or DEFAULT_LLM_MODELS.get(
selected_provider, DEFAULT_LLM_MODELS["ollama"]
)
primary.base_url = llm_base.strip() or None
primary.model = chosen_model
primary.base_url = llm_base.strip() or DEFAULT_LLM_BASE_URLS.get(selected_provider)
primary.temperature = llm_temperature
primary.timeout = llm_timeout
api_key_value = llm_api_key.strip()
primary.api_key = api_key_value or None
if api_key_value:
primary.api_key = api_key_value
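Unlike the old unconditional assignment above it, the key is now only replaced when the field is non-empty, so submitting a blank input no longer wipes a stored credential. The guard as a standalone sketch (names are hypothetical):

from __future__ import annotations

def merge_secret(current: str | None, submitted: str) -> str | None:
    # Keep the existing secret unless the user actually typed a replacement.
    return submitted.strip() or current

assert merge_secret("sk-old", "") == "sk-old"
assert merge_secret("sk-old", " sk-new ") == "sk-new"
assert merge_secret(None, "") is None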
try:
parsed = json.loads(ensemble_text or "[]")
if not isinstance(parsed, list):
raise ValueError("ensemble 配置必须是数组")
except Exception as exc: # noqa: BLE001
LOGGER.exception("解析 LLM 集群配置失败", extra=LOG_EXTRA)
st.error(f"LLM 集群配置解析失败:{exc}")
else:
new_ensemble: List[LLMEndpoint] = []
invalid = False
for item in parsed:
if not isinstance(item, dict):
st.error("LLM 集群配置中的每个元素都必须是对象")
invalid = True
break
fields = {key: item.get(key) for key in ("provider", "model", "base_url", "api_key", "temperature", "timeout")}
endpoint = LLMEndpoint(**{k: v for k, v in fields.items() if v not in (None, "")})
if not endpoint.provider:
endpoint.provider = "ollama"
for row in ensemble_rows:
provider = (row.get("provider") or "").strip().lower()
if not provider:
continue
provider_defaults = DEFAULT_LLM_MODEL_OPTIONS.get(provider, {})
default_model = DEFAULT_LLM_MODELS.get(provider, DEFAULT_LLM_MODELS["ollama"])
default_base = DEFAULT_LLM_BASE_URLS.get(provider)
temp_default = float(provider_defaults.get("temperature", 0.2))
timeout_default = float(provider_defaults.get("timeout", 30.0))
model_val = (row.get("model") or "").strip() or default_model
base_val = (row.get("base_url") or "").strip() or default_base
api_raw = (row.get("api_key") or "").strip()
api_value = None
if api_raw and api_raw != "***":
api_value = api_raw
else:
existing = existing_api_keys.get(provider)
if existing:
api_value = existing
temp_val = row.get("temperature")
timeout_val = row.get("timeout")
endpoint = LLMEndpoint(
provider=provider,
model=model_val,
base_url=base_val,
api_key=api_value,
temperature=float(temp_val) if temp_val is not None else temp_default,
timeout=float(timeout_val) if timeout_val is not None else timeout_default,
)
new_ensemble.append(endpoint)
if not invalid:
llm_cfg.ensemble = new_ensemble
llm_cfg.strategy = selected_strategy
llm_cfg.majority_threshold = int(majority_threshold)
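How strategy and majority_threshold are consumed is outside this diff, but a majority vote across per-endpoint answers could look roughly like the following; this is purely illustrative, not the project's aggregation code:

from __future__ import annotations
from collections import Counter

def majority_vote(answers: list[str], threshold: int) -> str | None:
    # Return the most common answer only if it reaches the configured threshold.
    if not answers:
        return None
    answer, count = Counter(answers).most_common(1)[0]
    return answer if count >= threshold else None

assert majority_vote(["BUY", "BUY", "HOLD"], threshold=2) == "BUY"
assert majority_vote(["BUY", "SELL", "HOLD"], threshold=2) is None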