update
This commit is contained in:
parent
5b4bd51199
commit
6a7ca88e63
@ -1,7 +1,6 @@
|
|||||||
"""Streamlit UI scaffold for the investment assistant."""
|
"""Streamlit UI scaffold for the investment assistant."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
|
||||||
import sys
|
import sys
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from datetime import date, timedelta
|
from datetime import date, timedelta
|
||||||
@ -219,10 +218,10 @@ def render_settings() -> None:
|
|||||||
|
|
||||||
if model_options:
|
if model_options:
|
||||||
options_with_custom = model_options + [custom_model_label]
|
options_with_custom = model_options + [custom_model_label]
|
||||||
if primary.model in model_options:
|
if primary.provider == selected_provider and primary.model in model_options:
|
||||||
model_index = options_with_custom.index(primary.model)
|
model_index = options_with_custom.index(primary.model)
|
||||||
else:
|
else:
|
||||||
model_index = len(options_with_custom) - 1
|
model_index = 0
|
||||||
selected_model_option = st.selectbox(
|
selected_model_option = st.selectbox(
|
||||||
"LLM 模型",
|
"LLM 模型",
|
||||||
options_with_custom,
|
options_with_custom,
|
||||||
@ -232,17 +231,17 @@ def render_settings() -> None:
|
|||||||
if selected_model_option == custom_model_label:
|
if selected_model_option == custom_model_label:
|
||||||
custom_model_value = st.text_input(
|
custom_model_value = st.text_input(
|
||||||
"自定义模型名称",
|
"自定义模型名称",
|
||||||
value=primary.model if primary.model not in model_options else "",
|
value="" if primary.provider != selected_provider or primary.model in model_options else primary.model,
|
||||||
)
|
)
|
||||||
|
chosen_model = custom_model_value.strip() or default_model_hint
|
||||||
else:
|
else:
|
||||||
custom_model_value = selected_model_option
|
chosen_model = selected_model_option
|
||||||
else:
|
else:
|
||||||
custom_model_value = st.text_input(
|
chosen_model = st.text_input(
|
||||||
"LLM 模型",
|
"LLM 模型",
|
||||||
value=primary.model or default_model_hint,
|
value=primary.model or default_model_hint,
|
||||||
help="未预设该 Provider 的模型列表,请手动填写",
|
help="未预设该 Provider 的模型列表,请手动填写",
|
||||||
)
|
).strip() or default_model_hint
|
||||||
selected_model_option = custom_model_label
|
|
||||||
default_base_hint = DEFAULT_LLM_BASE_URLS.get(selected_provider, "")
|
default_base_hint = DEFAULT_LLM_BASE_URLS.get(selected_provider, "")
|
||||||
provider_default_temp = float(provider_info.get("temperature", 0.2))
|
provider_default_temp = float(provider_info.get("temperature", 0.2))
|
||||||
provider_default_timeout = int(provider_info.get("timeout", 30.0))
|
provider_default_timeout = int(provider_info.get("timeout", 30.0))
|
||||||
@ -257,11 +256,16 @@ def render_settings() -> None:
|
|||||||
timeout_value = provider_default_timeout
|
timeout_value = provider_default_timeout
|
||||||
|
|
||||||
llm_base = st.text_input(
|
llm_base = st.text_input(
|
||||||
"LLM Base URL (可选)",
|
"LLM Base URL",
|
||||||
value=base_value,
|
value=base_value,
|
||||||
help=f"默认推荐:{default_base_hint or '按供应商要求填写'}",
|
help=f"默认推荐:{default_base_hint or '按供应商要求填写'}",
|
||||||
)
|
)
|
||||||
llm_api_key = st.text_input("LLM API Key (OpenAI 类需要)", value=primary.api_key or "", type="password")
|
llm_api_key = st.text_input(
|
||||||
|
"LLM API Key",
|
||||||
|
value=primary.api_key or "",
|
||||||
|
type="password",
|
||||||
|
help="点击右侧小图标可查看当前 Key,该值会写入 config.json(已被 gitignore 排除)",
|
||||||
|
)
|
||||||
llm_temperature = st.slider(
|
llm_temperature = st.slider(
|
||||||
"LLM 温度",
|
"LLM 温度",
|
||||||
min_value=0.0,
|
min_value=0.0,
|
||||||
@ -292,68 +296,104 @@ def render_settings() -> None:
|
|||||||
format="%d",
|
format="%d",
|
||||||
)
|
)
|
||||||
|
|
||||||
ensemble_display = []
|
existing_api_keys = {ep.provider: ep.api_key or None for ep in llm_cfg.ensemble}
|
||||||
for endpoint in llm_cfg.ensemble:
|
|
||||||
data = asdict(endpoint)
|
ensemble_rows = [
|
||||||
if data.get("api_key"):
|
{
|
||||||
data["api_key"] = ""
|
"provider": ep.provider,
|
||||||
ensemble_display.append(data)
|
"model": ep.model,
|
||||||
ensemble_text = st.text_area(
|
"base_url": ep.base_url or "",
|
||||||
"LLM 集群配置 (JSON 数组)",
|
"api_key": ep.api_key or "",
|
||||||
value=json.dumps(ensemble_display or [], ensure_ascii=False, indent=2),
|
"temperature": float(ep.temperature),
|
||||||
height=220,
|
"timeout": float(ep.timeout),
|
||||||
|
}
|
||||||
|
for ep in llm_cfg.ensemble
|
||||||
|
]
|
||||||
|
if not ensemble_rows:
|
||||||
|
ensemble_rows = [
|
||||||
|
{
|
||||||
|
"provider": "",
|
||||||
|
"model": "",
|
||||||
|
"base_url": "",
|
||||||
|
"api_key": "",
|
||||||
|
"temperature": provider_default_temp,
|
||||||
|
"timeout": provider_default_timeout,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
ensemble_rows = st.data_editor(
|
||||||
|
ensemble_rows,
|
||||||
|
num_rows="dynamic",
|
||||||
|
key="llm_ensemble_editor",
|
||||||
|
column_config={
|
||||||
|
"provider": st.column_config.SelectboxColumn(
|
||||||
|
"Provider",
|
||||||
|
options=sorted(DEFAULT_LLM_MODEL_OPTIONS.keys()),
|
||||||
|
help="选择 LLM 供应商"
|
||||||
|
),
|
||||||
|
"model": st.column_config.TextColumn("模型", help="留空时使用该 Provider 的默认模型"),
|
||||||
|
"base_url": st.column_config.TextColumn("Base URL", help="留空时使用默认地址"),
|
||||||
|
"api_key": st.column_config.TextColumn("API Key", help="留空表示使用环境变量或不配置"),
|
||||||
|
"temperature": st.column_config.NumberColumn("温度", min_value=0.0, max_value=2.0, step=0.05),
|
||||||
|
"timeout": st.column_config.NumberColumn("超时(秒)", min_value=5.0, max_value=120.0, step=5.0),
|
||||||
|
},
|
||||||
|
hide_index=True,
|
||||||
|
use_container_width=True,
|
||||||
)
|
)
|
||||||
|
if hasattr(ensemble_rows, "to_dict"):
|
||||||
|
ensemble_rows = ensemble_rows.to_dict("records")
|
||||||
|
|
||||||
if st.button("保存 LLM 设置"):
|
if st.button("保存 LLM 设置"):
|
||||||
original_provider = primary.provider
|
|
||||||
original_model = primary.model
|
|
||||||
primary.provider = selected_provider
|
primary.provider = selected_provider
|
||||||
if model_options:
|
primary.model = chosen_model
|
||||||
if selected_model_option == custom_model_label:
|
primary.base_url = llm_base.strip() or DEFAULT_LLM_BASE_URLS.get(selected_provider)
|
||||||
model_input = custom_model_value.strip()
|
|
||||||
primary.model = model_input or DEFAULT_LLM_MODELS.get(
|
|
||||||
selected_provider, DEFAULT_LLM_MODELS["ollama"]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
primary.model = selected_model_option
|
|
||||||
else:
|
|
||||||
primary.model = custom_model_value.strip() or DEFAULT_LLM_MODELS.get(
|
|
||||||
selected_provider, DEFAULT_LLM_MODELS["ollama"]
|
|
||||||
)
|
|
||||||
primary.base_url = llm_base.strip() or None
|
|
||||||
primary.temperature = llm_temperature
|
primary.temperature = llm_temperature
|
||||||
primary.timeout = llm_timeout
|
primary.timeout = llm_timeout
|
||||||
api_key_value = llm_api_key.strip()
|
api_key_value = llm_api_key.strip()
|
||||||
primary.api_key = api_key_value or None
|
if api_key_value:
|
||||||
|
primary.api_key = api_key_value
|
||||||
|
|
||||||
try:
|
new_ensemble: List[LLMEndpoint] = []
|
||||||
parsed = json.loads(ensemble_text or "[]")
|
for row in ensemble_rows:
|
||||||
if not isinstance(parsed, list):
|
provider = (row.get("provider") or "").strip().lower()
|
||||||
raise ValueError("ensemble 配置必须是数组")
|
if not provider:
|
||||||
except Exception as exc: # noqa: BLE001
|
continue
|
||||||
LOGGER.exception("解析 LLM 集群配置失败", extra=LOG_EXTRA)
|
provider_defaults = DEFAULT_LLM_MODEL_OPTIONS.get(provider, {})
|
||||||
st.error(f"LLM 集群配置解析失败:{exc}")
|
default_model = DEFAULT_LLM_MODELS.get(provider, DEFAULT_LLM_MODELS["ollama"])
|
||||||
else:
|
default_base = DEFAULT_LLM_BASE_URLS.get(provider)
|
||||||
new_ensemble: List[LLMEndpoint] = []
|
temp_default = float(provider_defaults.get("temperature", 0.2))
|
||||||
invalid = False
|
timeout_default = float(provider_defaults.get("timeout", 30.0))
|
||||||
for item in parsed:
|
|
||||||
if not isinstance(item, dict):
|
model_val = (row.get("model") or "").strip() or default_model
|
||||||
st.error("LLM 集群配置中的每个元素都必须是对象")
|
base_val = (row.get("base_url") or "").strip() or default_base
|
||||||
invalid = True
|
api_raw = (row.get("api_key") or "").strip()
|
||||||
break
|
api_value = None
|
||||||
fields = {key: item.get(key) for key in ("provider", "model", "base_url", "api_key", "temperature", "timeout")}
|
if api_raw and api_raw != "***":
|
||||||
endpoint = LLMEndpoint(**{k: v for k, v in fields.items() if v not in (None, "")})
|
api_value = api_raw
|
||||||
if not endpoint.provider:
|
else:
|
||||||
endpoint.provider = "ollama"
|
existing = existing_api_keys.get(provider)
|
||||||
new_ensemble.append(endpoint)
|
if existing:
|
||||||
if not invalid:
|
api_value = existing
|
||||||
llm_cfg.ensemble = new_ensemble
|
|
||||||
llm_cfg.strategy = selected_strategy
|
temp_val = row.get("temperature")
|
||||||
llm_cfg.majority_threshold = int(majority_threshold)
|
timeout_val = row.get("timeout")
|
||||||
save_config()
|
endpoint = LLMEndpoint(
|
||||||
LOGGER.info("LLM 配置已更新:%s", llm_config_snapshot(), extra=LOG_EXTRA)
|
provider=provider,
|
||||||
st.success("LLM 设置已保存,仅在当前会话生效。")
|
model=model_val,
|
||||||
st.json(llm_config_snapshot())
|
base_url=base_val,
|
||||||
|
api_key=api_value,
|
||||||
|
temperature=float(temp_val) if temp_val is not None else temp_default,
|
||||||
|
timeout=float(timeout_val) if timeout_val is not None else timeout_default,
|
||||||
|
)
|
||||||
|
new_ensemble.append(endpoint)
|
||||||
|
|
||||||
|
llm_cfg.ensemble = new_ensemble
|
||||||
|
llm_cfg.strategy = selected_strategy
|
||||||
|
llm_cfg.majority_threshold = int(majority_threshold)
|
||||||
|
save_config()
|
||||||
|
LOGGER.info("LLM 配置已更新:%s", llm_config_snapshot(), extra=LOG_EXTRA)
|
||||||
|
st.success("LLM 设置已保存,仅在当前会话生效。")
|
||||||
|
st.json(llm_config_snapshot())
|
||||||
|
|
||||||
|
|
||||||
def render_tests() -> None:
|
def render_tests() -> None:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user