This commit is contained in:
sam 2025-09-28 14:49:58 +08:00
parent ab6180646a
commit d81dfefc31
2 changed files with 40 additions and 20 deletions

View File

@ -60,7 +60,7 @@ export TUSHARE_TOKEN="<your-token>"
### LLM 配置与测试 ### LLM 配置与测试
- 通过 Provider 管理供应商连接参数Base URL、API Key、模型列表、默认温度/超时/Prompt 模板),可随时扩展本地 Ollama 或各类云端服务DeepSeek、文心一言、OpenAI 等)。 - 通过 Provider 管理供应商连接参数Base URL、API Key、默认温度/超时/Prompt 模板),并支持在界面内一键调用 `client.models.list()` 拉取可用模型列表,便于扩展本地 Ollama 或各类云端服务DeepSeek、文心一言、OpenAI 等)。
- 全局与部门配置直接选择 Provider并根据需要覆盖模型、温度、Prompt 模板、投票策略;保存后写入 `app/data/config.json`,下次启动自动加载。 - 全局与部门配置直接选择 Provider并根据需要覆盖模型、温度、Prompt 模板、投票策略;保存后写入 `app/data/config.json`,下次启动自动加载。
- Streamlit “数据与设置” 页提供 Provider/全局/部门三栏编辑界面,保存后即时生效,并通过 `llm_config_snapshot()` 输出脱敏检查信息。 - Streamlit “数据与设置” 页提供 Provider/全局/部门三栏编辑界面,保存后即时生效,并通过 `llm_config_snapshot()` 输出脱敏检查信息。
- 支持使用环境变量注入敏感信息:`TUSHARE_TOKEN`、`LLM_API_KEY`。 - 支持使用环境变量注入敏感信息:`TUSHARE_TOKEN`、`LLM_API_KEY`。
@ -103,7 +103,7 @@ Streamlit `自检测试` 页签提供:
## 实施步骤 ## 实施步骤
1. **配置扩展** (`app/utils/config.py` + `config.json`) ✅ 1. **配置扩展** (`app/utils/config.py` + `config.json`) ✅
- 引入 `llm_providers` 集中管理供应商参数,全局与部门直接绑定 Provider 并自定义模型/温度/PromptStreamlit 提供可视化维护表单。 - 引入 `llm_providers` 集中管理供应商参数,全局与部门直接绑定 Provider 并自定义模型/温度/PromptProvider 页面提供模型列表自动获取Streamlit 提供可视化维护表单。
2. **部门管控器** 2. **部门管控器**
- `app/agents/departments.py` 提供 `DepartmentAgent`/`DepartmentManager`,封装 Prompt 构建、多模型协商及异常回退。 - `app/agents/departments.py` 提供 `DepartmentAgent`/`DepartmentManager`,封装 Prompt 构建、多模型协商及异常回退。

View File

@ -430,14 +430,13 @@ def render_settings() -> None:
cfg.llm_providers = providers cfg.llm_providers = providers
save_config() save_config()
st.success(f"已创建 Provider {key}") st.success(f"已创建 Provider {key}")
st.experimental_rerun() st.rerun()
if selected_provider: if selected_provider:
provider_cfg = providers.get(selected_provider, LLMProvider(key=selected_provider)) provider_cfg = providers.get(selected_provider, LLMProvider(key=selected_provider))
title_key = f"provider_title_{selected_provider}" title_key = f"provider_title_{selected_provider}"
base_key = f"provider_base_{selected_provider}" base_key = f"provider_base_{selected_provider}"
api_key_key = f"provider_api_{selected_provider}" api_key_key = f"provider_api_{selected_provider}"
models_key = f"provider_models_{selected_provider}"
default_model_key = f"provider_default_model_{selected_provider}" default_model_key = f"provider_default_model_{selected_provider}"
mode_key = f"provider_mode_{selected_provider}" mode_key = f"provider_mode_{selected_provider}"
temp_key = f"provider_temp_{selected_provider}" temp_key = f"provider_temp_{selected_provider}"
@ -448,8 +447,23 @@ def render_settings() -> None:
title_val = st.text_input("备注名称", value=provider_cfg.title or "", key=title_key) title_val = st.text_input("备注名称", value=provider_cfg.title or "", key=title_key)
base_val = st.text_input("Base URL", value=provider_cfg.base_url or "", key=base_key, help="调用地址例如https://api.openai.com") base_val = st.text_input("Base URL", value=provider_cfg.base_url or "", key=base_key, help="调用地址例如https://api.openai.com")
api_val = st.text_input("API Key", value=provider_cfg.api_key or "", key=api_key_key, type="password") api_val = st.text_input("API Key", value=provider_cfg.api_key or "", key=api_key_key, type="password")
models_val = st.text_area("可用模型(每行一个)", value="\n".join(provider_cfg.models), key=models_key, height=100) st.markdown("可用模型:")
default_model_val = st.text_input("默认模型", value=provider_cfg.default_model or "", key=default_model_key) if provider_cfg.models:
st.code("\n".join(provider_cfg.models), language="text")
else:
st.info("尚未获取模型列表,可点击下方按钮自动拉取。")
model_choice_key = f"{default_model_key}_choice"
if provider_cfg.models:
options = provider_cfg.models + ["自定义"]
default_choice = provider_cfg.default_model if provider_cfg.default_model in provider_cfg.models else "自定义"
model_choice = st.selectbox("默认模型", options, index=options.index(default_choice), key=model_choice_key)
if model_choice == "自定义":
default_model_val = st.text_input("自定义默认模型", value=provider_cfg.default_model or "", key=default_model_key).strip() or None
else:
default_model_val = model_choice
else:
default_model_val = st.text_input("默认模型", value=provider_cfg.default_model or "", key=default_model_key).strip() or None
mode_val = st.selectbox("调用模式", ["openai", "ollama"], index=0 if provider_cfg.mode == "openai" else 1, key=mode_key) mode_val = st.selectbox("调用模式", ["openai", "ollama"], index=0 if provider_cfg.mode == "openai" else 1, key=mode_key)
temp_val = st.slider("默认温度", min_value=0.0, max_value=2.0, value=float(provider_cfg.default_temperature), step=0.05, key=temp_key) temp_val = st.slider("默认温度", min_value=0.0, max_value=2.0, value=float(provider_cfg.default_temperature), step=0.05, key=temp_key)
timeout_val = st.number_input("默认超时(秒)", min_value=5, max_value=300, value=int(provider_cfg.default_timeout or 30), step=5, key=timeout_key) timeout_val = st.number_input("默认超时(秒)", min_value=5, max_value=300, value=int(provider_cfg.default_timeout or 30), step=5, key=timeout_key)
@ -471,15 +485,16 @@ def render_settings() -> None:
cfg.sync_runtime_llm() cfg.sync_runtime_llm()
save_config() save_config()
st.success(f"共获取 {len(models)} 个模型。") st.success(f"共获取 {len(models)} 个模型。")
st.session_state[models_key] = "\n".join(models) st.rerun()
st.session_state[default_model_key] = provider_cfg.default_model or ""
if st.button("保存 Provider", key=f"save_provider_{selected_provider}"): if st.button("保存 Provider", key=f"save_provider_{selected_provider}"):
provider_cfg.title = title_val.strip() provider_cfg.title = title_val.strip()
provider_cfg.base_url = base_val.strip() provider_cfg.base_url = base_val.strip()
provider_cfg.api_key = api_val.strip() or None provider_cfg.api_key = api_val.strip() or None
provider_cfg.models = [line.strip() for line in models_val.splitlines() if line.strip()] if provider_cfg.models and default_model_val in provider_cfg.models:
provider_cfg.default_model = default_model_val.strip() or (provider_cfg.models[0] if provider_cfg.models else provider_cfg.default_model) provider_cfg.default_model = default_model_val
else:
provider_cfg.default_model = default_model_val
provider_cfg.default_temperature = float(temp_val) provider_cfg.default_temperature = float(temp_val)
provider_cfg.default_timeout = float(timeout_val) provider_cfg.default_timeout = float(timeout_val)
provider_cfg.prompt_template = prompt_template_val.strip() provider_cfg.prompt_template = prompt_template_val.strip()
@ -511,7 +526,7 @@ def render_settings() -> None:
cfg.sync_runtime_llm() cfg.sync_runtime_llm()
save_config() save_config()
st.success("Provider 已删除。") st.success("Provider 已删除。")
st.experimental_rerun() st.rerun()
st.markdown("##### 全局推理配置") st.markdown("##### 全局推理配置")
if not provider_keys: if not provider_keys:
@ -710,13 +725,17 @@ def render_settings() -> None:
strategy_val = (row.get("strategy") or existing.llm.strategy).lower() strategy_val = (row.get("strategy") or existing.llm.strategy).lower()
if strategy_val in ALLOWED_LLM_STRATEGIES: if strategy_val in ALLOWED_LLM_STRATEGIES:
existing.llm.strategy = strategy_val existing.llm.strategy = strategy_val
majority_raw = row.get("majority_threshold") if existing.llm.strategy == "single":
try: existing.llm.majority_threshold = 1
majority_val = int(majority_raw) existing.llm.ensemble = []
if majority_val > 0: else:
existing.llm.majority_threshold = majority_val majority_raw = row.get("majority_threshold")
except (TypeError, ValueError): try:
pass majority_val = int(majority_raw)
if majority_val > 0:
existing.llm.majority_threshold = majority_val
except (TypeError, ValueError):
pass
provider_val = (row.get("provider") or existing.llm.primary.provider or (provider_keys[0] if provider_keys else "ollama")).strip().lower() provider_val = (row.get("provider") or existing.llm.primary.provider or (provider_keys[0] if provider_keys else "ollama")).strip().lower()
model_val = (row.get("model") or "").strip() or None model_val = (row.get("model") or "").strip() or None
@ -733,7 +752,8 @@ def render_settings() -> None:
endpoint.base_url = None endpoint.base_url = None
endpoint.api_key = None endpoint.api_key = None
existing.llm.primary = endpoint existing.llm.primary = endpoint
existing.llm.ensemble = [] if existing.llm.strategy != "single":
existing.llm.ensemble = []
updated_departments[code] = existing updated_departments[code] = existing
@ -752,7 +772,7 @@ def render_settings() -> None:
cfg.sync_runtime_llm() cfg.sync_runtime_llm()
save_config() save_config()
st.success("已恢复默认部门配置。") st.success("已恢复默认部门配置。")
st.experimental_rerun() st.rerun()
st.caption("部门配置存储为独立 LLM 参数,执行时会自动套用对应 Provider 的连接信息。") st.caption("部门配置存储为独立 LLM 参数,执行时会自动套用对应 Provider 的连接信息。")