From ab6180646ae0f3aacafaca8d0e8397b5ed2f89e2 Mon Sep 17 00:00:00 2001
From: sam
Date: Sun, 28 Sep 2025 11:28:55 +0800
Subject: [PATCH] Replace LLM Profile/Route configuration with unified
 Provider management
---
 README.md                 |  13 +-
 app/agents/departments.py |   2 -
 app/llm/client.py         |  66 ++++-
 app/ui/streamlit_app.py   | 575 +++++++++++++++++++++-----------------
 app/utils/config.py       | 493 +++++++++++++------------------
 5 files changed, 574 insertions(+), 575 deletions(-)

diff --git a/README.md b/README.md
index 479b1d0..cb82fe5 100644
--- a/README.md
+++ b/README.md
@@ -24,7 +24,7 @@
 - **Unified logging and persistence**: SQLite stores market data, backtests, and logs in one place; `DatabaseLogHandler` emits structured run traces from the UI and ingestion flows for fast tracing and review.
 - **Cross-market data coverage**: `app/ingest/tushare.py` adds incremental fetching for indices, ETFs/mutual funds, futures, FX, and HK/US equities, so multi-asset factors and macro proxies have the market data they need.
 - **Department-level multi-model collaboration**: `app/agents/departments.py` wraps department-level LLM dispatch, `app/llm/client.py` supports the single/majority/leader strategies, department conclusions join the six base agents in the game in `app/agents/game.py`, and results persist to `agent_utils` for UI display.
-- **LLM Profile/Route management**: `app/utils/config.py` introduces reusable Profiles (endpoint definitions) and Routes (inference-strategy bundles); the Streamlit UI maintains them visually, and global and department settings can reuse named routes for configuration consistency.
+- **LLM Provider management**: `app/utils/config.py` centrally maintains each vendor's URL, API key, available models, and default parameters; the Streamlit UI offers visual configuration, and global and department settings set the model, temperature, and prompt directly on top of a Provider.
 
 ## LLM + Multi-Agent Best Practices
 
@@ -60,11 +60,10 @@ export TUSHARE_TOKEN=""
 
 ### LLM Configuration and Testing
 
-- Two-layer Profile/Route configuration: a Profile defines a single endpoint (provider/model/domain/API key); a Route combines Profiles and selects an inference strategy (single/majority/leader). The global route can be switched with one click, and departments can reuse named routes or keep custom settings.
-- The Streamlit "Data & Settings" page manages Profiles, Routes, and the global route through forms; saving writes to `app/data/config.json`, and the Route preview shows the live configuration redacted by `llm_config_snapshot()`.
-- Supports local Ollama and multiple OpenAI-compatible vendors (DeepSeek, ERNIE, OpenAI, etc.); each Profile can set a default model, temperature, timeout, and enabled state.
-- The UI keeps TuShare token maintenance plus adding, deleting, and disabling routes/Profiles; all updates take effect immediately and are logged.
-- When injecting secrets via environment variables, configure `TUSHARE_TOKEN` and `LLM_API_KEY`; on load they are synced to the main Profile of the current route.
+- Providers hold the vendor connection parameters (base URL, API key, model list, default temperature/timeout/prompt template); local Ollama and cloud services (DeepSeek, ERNIE, OpenAI, etc.) can be added at any time.
+- Global and department configs select a Provider directly and override the model, temperature, prompt template, and voting strategy as needed; saving writes to `app/data/config.json`, which is loaded automatically on the next start.
+- The Streamlit "Data & Settings" page provides a three-part editor (Provider / global / departments); changes take effect as soon as they are saved, and `llm_config_snapshot()` emits a redacted view for inspection.
+- Secrets can be injected via the environment variables `TUSHARE_TOKEN` and `LLM_API_KEY`.
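+
+For reference, a minimal sketch of registering a Provider from code instead of through the UI (field names follow `app/utils/config.py`; the DeepSeek URL and key below are placeholders):
+
+```python
+from app.utils.config import LLMProvider, get_config, save_config
+
+cfg = get_config()
+cfg.llm_providers["deepseek"] = LLMProvider(
+    key="deepseek",
+    base_url="https://api.deepseek.com",  # placeholder endpoint
+    api_key="sk-...",                     # prefer the LLM_API_KEY env var
+    models=["deepseek-chat"],
+    default_model="deepseek-chat",
+)
+cfg.sync_runtime_llm()  # push Provider defaults into the runtime endpoints
+save_config(cfg)        # persists to app/data/config.json
+```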
 
 ## Quick Start
 
@@ -104,7 +103,7 @@ Streamlit `Self-Test` tab provides:
 ## Implementation Steps
 
 1. **Configuration extension** (`app/utils/config.py` + `config.json`) ✅
-   - Introduce `llm_profiles`/`llm_routes` to manage endpoints and strategies centrally; departments can reuse routes or use custom configs; Streamlit provides visual maintenance forms.
+   - Introduce `llm_providers` to manage vendor parameters centrally; global and department configs bind a Provider directly and customize the model/temperature/prompt; Streamlit provides visual maintenance forms.
 2. **Department controller** ✅
    - `app/agents/departments.py` provides `DepartmentAgent`/`DepartmentManager`, wrapping prompt construction, multi-model negotiation, and failure fallback.
diff --git a/app/agents/departments.py b/app/agents/departments.py
index aee5284..1b65016 100644
--- a/app/agents/departments.py
+++ b/app/agents/departments.py
@@ -131,8 +131,6 @@ class DepartmentManager:
         return results
 
     def _resolve_llm(self, settings: DepartmentSettings) -> LLMConfig:
-        if settings.llm_route and settings.llm_route in self.config.llm_routes:
-            return self.config.llm_routes[settings.llm_route].resolve(self.config.llm_profiles)
         return settings.llm
diff --git a/app/llm/client.py b/app/llm/client.py
index be03e06..5b11f3c 100644
--- a/app/llm/client.py
+++ b/app/llm/client.py
@@ -102,24 +102,58 @@ def _request_openai(
 
 
 def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> str:
-    provider = (endpoint.provider or "ollama").lower()
-    base_url = endpoint.base_url or _default_base_url(provider)
-    model = endpoint.model or _default_model(provider)
-    temperature = max(0.0, min(endpoint.temperature, 2.0))
-    timeout = max(5.0, endpoint.timeout or 30.0)
+    cfg = get_config()
+    provider_key = (endpoint.provider or "ollama").lower()
+    provider_cfg = cfg.llm_providers.get(provider_key)
+
+    base_url = endpoint.base_url
+    api_key = endpoint.api_key
+    model = endpoint.model
+    temperature = endpoint.temperature
+    timeout = endpoint.timeout
+    prompt_template = endpoint.prompt_template
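+
+    # Resolution order: explicit endpoint fields win, then the Provider's
+    # defaults, then the built-in per-provider fallbacks.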
"route": route_name, - "route_detail": route_obj.to_dict() if route_obj else None, "strategy": settings.strategy, "majority_threshold": settings.majority_threshold, "primary": primary, "ensemble": ensemble, + "providers": { + key: { + "base_url": provider.base_url, + "default_model": provider.default_model, + "enabled": provider.enabled, + } + for key, provider in cfg.llm_providers.items() + }, } diff --git a/app/ui/streamlit_app.py b/app/ui/streamlit_app.py index 0dacebf..26bf3f0 100644 --- a/app/ui/streamlit_app.py +++ b/app/ui/streamlit_app.py @@ -5,7 +5,7 @@ import sys from dataclasses import asdict from datetime import date, timedelta from pathlib import Path -from typing import Dict, List +from typing import Dict, List, Optional ROOT = Path(__file__).resolve().parents[2] if str(ROOT) not in sys.path: @@ -16,6 +16,8 @@ import json import pandas as pd import plotly.express as px import plotly.graph_objects as go +import requests +from requests.exceptions import RequestException import streamlit as st from app.backtest.engine import BtConfig, run_backtest @@ -29,8 +31,8 @@ from app.utils.config import ( DEFAULT_LLM_MODEL_OPTIONS, DEFAULT_LLM_MODELS, DepartmentSettings, - LLMProfile, - LLMRoute, + LLMEndpoint, + LLMProvider, get_config, save_config, ) @@ -42,6 +44,50 @@ LOGGER = get_logger(__name__) LOG_EXTRA = {"stage": "ui"} +def _discover_provider_models(provider: LLMProvider, base_override: str = "", api_override: Optional[str] = None) -> tuple[list[str], Optional[str]]: + """Attempt to query provider API and return available model ids.""" + + base_url = (base_override or provider.base_url or DEFAULT_LLM_BASE_URLS.get(provider.key, "")).strip() + if not base_url: + return [], "请先填写 Base URL" + timeout = float(provider.default_timeout or 30.0) + mode = provider.mode or ("ollama" if provider.key == "ollama" else "openai") + + try: + if mode == "ollama": + url = base_url.rstrip('/') + "/api/tags" + response = requests.get(url, timeout=timeout) + response.raise_for_status() + data = response.json() + models = [] + for item in data.get("models", []) or data.get("data", []): + name = item.get("name") or item.get("model") or item.get("tag") + if name: + models.append(str(name).strip()) + return sorted(set(models)), None + + api_key = (api_override or provider.api_key or "").strip() + if not api_key: + return [], "缺少 API Key" + url = base_url.rstrip('/') + "/v1/models" + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + } + response = requests.get(url, headers=headers, timeout=timeout) + response.raise_for_status() + payload = response.json() + models = [ + str(item.get("id")).strip() + for item in payload.get("data", []) + if item.get("id") + ] + return sorted(set(models)), None + except RequestException as exc: # noqa: BLE001 + return [], f"HTTP 错误:{exc}" + except Exception as exc: # noqa: BLE001 + return [], f"解析失败:{exc}" + def _load_stock_options(limit: int = 500) -> list[str]: try: with db_session(read_only=True) as conn: @@ -350,255 +396,267 @@ def render_settings() -> None: st.divider() st.subheader("LLM 设置") - profiles = cfg.llm_profiles or {} - routes = cfg.llm_routes or {} - profile_keys = sorted(profiles.keys()) - route_keys = sorted(routes.keys()) - used_routes = { - dept.llm_route for dept in cfg.departments.values() if dept.llm_route - } - st.caption("Profile 定义单个模型终端,Route 负责组合 Profile 与推理策略。") + providers = cfg.llm_providers + provider_keys = sorted(providers.keys()) + st.caption("先在 Provider 中维护基础连接(URL、Key、模型),再为全局与各部门设置个性化参数。") - 
+    base_url = (base_override or provider.base_url or DEFAULT_LLM_BASE_URLS.get(provider.key, "")).strip()
+    if not base_url:
+        return [], "Please fill in the Base URL first"
+    timeout = float(provider.default_timeout or 30.0)
+    mode = provider.mode or ("ollama" if provider.key == "ollama" else "openai")
+
+    try:
+        if mode == "ollama":
+            url = base_url.rstrip('/') + "/api/tags"
+            response = requests.get(url, timeout=timeout)
+            response.raise_for_status()
+            data = response.json()
+            models = []
+            for item in data.get("models", []) or data.get("data", []):
+                name = item.get("name") or item.get("model") or item.get("tag")
+                if name:
+                    models.append(str(name).strip())
+            return sorted(set(models)), None
+
+        api_key = (api_override or provider.api_key or "").strip()
+        if not api_key:
+            return [], "Missing API key"
+        url = base_url.rstrip('/') + "/v1/models"
+        headers = {
+            "Authorization": f"Bearer {api_key}",
+            "Content-Type": "application/json",
+        }
+        response = requests.get(url, headers=headers, timeout=timeout)
+        response.raise_for_status()
+        payload = response.json()
+        models = [
+            str(item.get("id")).strip()
+            for item in payload.get("data", [])
+            if item.get("id")
+        ]
+        return sorted(set(models)), None
+    except RequestException as exc:
+        return [], f"HTTP error: {exc}"
+    except Exception as exc:  # noqa: BLE001
+        return [], f"Failed to parse the response: {exc}"
+
+
 def _load_stock_options(limit: int = 500) -> list[str]:
     try:
         with db_session(read_only=True) as conn:
@@ -350,255 +396,267 @@ def render_settings() -> None:
 
     st.divider()
     st.subheader("LLM Settings")
-    profiles = cfg.llm_profiles or {}
-    routes = cfg.llm_routes or {}
-    profile_keys = sorted(profiles.keys())
-    route_keys = sorted(routes.keys())
-    used_routes = {
-        dept.llm_route for dept in cfg.departments.values() if dept.llm_route
-    }
-    st.caption("A Profile defines a single model endpoint; a Route combines Profiles with an inference strategy.")
+    providers = cfg.llm_providers
+    provider_keys = sorted(providers.keys())
+    st.caption("Maintain the basic connection (URL, key, models) per Provider first, then set individual parameters globally and per department.")
 
-    route_select_col, route_manage_col = st.columns([3, 1])
-    if route_keys:
+    # Provider management -------------------------------------------------
+    provider_select_col, provider_manage_col = st.columns([3, 1])
+    if provider_keys:
         try:
-            active_index = route_keys.index(cfg.llm_route)
+            default_provider = cfg.llm.primary.provider or provider_keys[0]
+            provider_index = provider_keys.index(default_provider)
         except ValueError:
-            active_index = 0
-        selected_route = route_select_col.selectbox(
-            "Global route",
-            route_keys,
-            index=active_index,
-            key="llm_route_select",
+            provider_index = 0
+        selected_provider = provider_select_col.selectbox(
+            "Select Provider",
+            provider_keys,
+            index=provider_index,
+            key="llm_provider_select",
         )
     else:
-        selected_route = None
-        route_select_col.info("No route configured yet; please create one first.")
+        selected_provider = None
+        provider_select_col.info("No Provider configured yet; please create one first.")
 
-    new_route_name = route_manage_col.text_input("New route", key="new_route_name")
-    if route_manage_col.button("Add route"):
-        key = (new_route_name or "").strip()
+    new_provider_name = provider_manage_col.text_input("New Provider", key="new_provider_name")
+    if provider_manage_col.button("Create Provider", key="create_provider_btn"):
+        key = (new_provider_name or "").strip().lower()
         if not key:
-            st.warning("Please enter a valid route name.")
-        elif key in routes:
-            st.warning(f"Route {key} already exists.")
+            st.warning("Please enter a valid Provider name.")
+        elif key in providers:
+            st.warning(f"Provider {key} already exists.")
         else:
-            routes[key] = LLMRoute(name=key)
-            if not selected_route:
-                selected_route = key
-                cfg.llm_route = key
+            providers[key] = LLMProvider(key=key)
+            cfg.llm_providers = providers
             save_config()
-            st.success(f"Route {key} added; please continue configuring it.")
+            st.success(f"Provider {key} created.")
             st.experimental_rerun()
 
-    if selected_route:
-        route_obj = routes.get(selected_route)
-        if route_obj is None:
-            route_obj = LLMRoute(name=selected_route)
-            routes[selected_route] = route_obj
-        strategy_choices = sorted(ALLOWED_LLM_STRATEGIES)
+    if selected_provider:
+        provider_cfg = providers.get(selected_provider, LLMProvider(key=selected_provider))
+        title_key = f"provider_title_{selected_provider}"
+        base_key = f"provider_base_{selected_provider}"
+        api_key_key = f"provider_api_{selected_provider}"
+        models_key = f"provider_models_{selected_provider}"
+        default_model_key = f"provider_default_model_{selected_provider}"
+        mode_key = f"provider_mode_{selected_provider}"
+        temp_key = f"provider_temp_{selected_provider}"
+        timeout_key = f"provider_timeout_{selected_provider}"
+        prompt_key = f"provider_prompt_{selected_provider}"
+        enabled_key = f"provider_enabled_{selected_provider}"
+
+        title_val = st.text_input("Display name", value=provider_cfg.title or "", key=title_key)
+        base_val = st.text_input("Base URL", value=provider_cfg.base_url or "", key=base_key, help="Endpoint address, e.g. https://api.openai.com")
+        api_val = st.text_input("API Key", value=provider_cfg.api_key or "", key=api_key_key, type="password")
+        models_val = st.text_area("Available models (one per line)", value="\n".join(provider_cfg.models), key=models_key, height=100)
+        default_model_val = st.text_input("Default model", value=provider_cfg.default_model or "", key=default_model_key)
+        mode_val = st.selectbox("Call mode", ["openai", "ollama"], index=0 if provider_cfg.mode == "openai" else 1, key=mode_key)
+        temp_val = st.slider("Default temperature", min_value=0.0, max_value=2.0, value=float(provider_cfg.default_temperature), step=0.05, key=temp_key)
+        timeout_val = st.number_input("Default timeout (seconds)", min_value=5, max_value=300, value=int(provider_cfg.default_timeout or 30), step=5, key=timeout_key)
+        prompt_template_val = st.text_area("Default prompt template (optional; use the {prompt} placeholder)", value=provider_cfg.prompt_template or "", key=prompt_key, height=120)
+        enabled_val = st.checkbox("Enabled", value=provider_cfg.enabled, key=enabled_key)
+
+        fetch_key = f"fetch_models_{selected_provider}"
+        if st.button("Fetch model list", key=fetch_key):
+            with st.spinner("Fetching model list..."):
+                models, error = _discover_provider_models(provider_cfg, base_val, api_val)
+            if error:
+                st.error(error)
+            else:
+                provider_cfg.models = models
+                if models and (not provider_cfg.default_model or provider_cfg.default_model not in models):
+                    provider_cfg.default_model = models[0]
+                providers[selected_provider] = provider_cfg
+                cfg.llm_providers = providers
+                cfg.sync_runtime_llm()
+                save_config()
+                st.success(f"Fetched {len(models)} models.")
+                st.session_state[models_key] = "\n".join(models)
+                st.session_state[default_model_key] = provider_cfg.default_model or ""
+
+        if st.button("Save Provider", key=f"save_provider_{selected_provider}"):
+            provider_cfg.title = title_val.strip()
+            provider_cfg.base_url = base_val.strip()
+            provider_cfg.api_key = api_val.strip() or None
+            provider_cfg.models = [line.strip() for line in models_val.splitlines() if line.strip()]
+            provider_cfg.default_model = default_model_val.strip() or (provider_cfg.models[0] if provider_cfg.models else provider_cfg.default_model)
+            provider_cfg.default_temperature = float(temp_val)
+            provider_cfg.default_timeout = float(timeout_val)
+            provider_cfg.prompt_template = prompt_template_val.strip()
+            provider_cfg.enabled = enabled_val
+            provider_cfg.mode = mode_val
+            providers[selected_provider] = provider_cfg
+            cfg.llm_providers = providers
+            cfg.sync_runtime_llm()
+            save_config()
+            st.success("Provider saved.")
+            st.session_state[title_key] = provider_cfg.title or ""
+            st.session_state[default_model_key] = provider_cfg.default_model or ""
+
+        provider_in_use = (cfg.llm.primary.provider == selected_provider) or any(
+            ep.provider == selected_provider for ep in cfg.llm.ensemble
+        )
+        if not provider_in_use:
+            for dept in cfg.departments.values():
+                if dept.llm.primary.provider == selected_provider or any(ep.provider == selected_provider for ep in dept.llm.ensemble):
+                    provider_in_use = True
+                    break
+        if st.button(
+            "Delete Provider",
+            key=f"delete_provider_{selected_provider}",
+            disabled=provider_in_use or len(providers) <= 1,
+        ):
+            providers.pop(selected_provider, None)
+            cfg.llm_providers = providers
+            cfg.sync_runtime_llm()
+            save_config()
+            st.success("Provider deleted.")
+            st.experimental_rerun()
+
+    st.markdown("##### Global Inference Settings")
+    if not provider_keys:
+        st.warning("Please configure at least one Provider first.")
+    else:
+        global_cfg = cfg.llm
+        primary = global_cfg.primary
         try:
-            strategy_index = strategy_choices.index(route_obj.strategy)
+            provider_index = provider_keys.index(primary.provider or provider_keys[0])
         except ValueError:
-            strategy_index = 0
-        route_title = st.text_input(
-            "Route description",
-            value=route_obj.title or "",
-            key=f"route_title_{selected_route}",
+            provider_index = 0
+        selected_global_provider = st.selectbox(
+            "Primary model Provider",
+            provider_keys,
+            index=provider_index,
+            key="global_provider_select",
         )
-        route_strategy = st.selectbox(
-            "Inference strategy",
-            strategy_choices,
-            index=strategy_index,
-            key=f"route_strategy_{selected_route}",
+        provider_cfg = providers.get(selected_global_provider)
+        available_models = provider_cfg.models if provider_cfg else []
+        default_model = primary.model or (provider_cfg.default_model if provider_cfg else None)
+        if available_models:
+            options = available_models + ["Custom"]
+            try:
+                model_index = available_models.index(default_model)
+                model_choice = st.selectbox("Primary model", options, index=model_index, key="global_model_choice")
+            except ValueError:
+                model_choice = st.selectbox("Primary model", options, index=len(options) - 1, key="global_model_choice")
+            if model_choice == "Custom":
+                model_val = st.text_input("Custom model", value=default_model or "", key="global_model_custom").strip()
+            else:
+                model_val = model_choice
+        else:
+            model_val = st.text_input("Primary model", value=default_model or "", key="global_model_custom").strip()
+
+        temp_default = primary.temperature if primary.temperature is not None else (provider_cfg.default_temperature if provider_cfg else 0.2)
+        temp_val = st.slider("Primary model temperature", min_value=0.0, max_value=2.0, value=float(temp_default), step=0.05, key="global_temp")
+        timeout_default = primary.timeout if primary.timeout is not None else (provider_cfg.default_timeout if provider_cfg else 30.0)
+        timeout_val = st.number_input("Primary model timeout (seconds)", min_value=5, max_value=300, value=int(timeout_default), step=5, key="global_timeout")
+        prompt_template_val = st.text_area(
+            "Primary model prompt template (optional)",
+            value=primary.prompt_template or (provider_cfg.prompt_template if provider_cfg else ""),
+            height=120,
+            key="global_prompt_template",
        )
-        route_majority = st.number_input(
+
+        strategy_val = st.selectbox(
+            "Inference strategy",
+            sorted(ALLOWED_LLM_STRATEGIES),
+            index=sorted(ALLOWED_LLM_STRATEGIES).index(global_cfg.strategy) if global_cfg.strategy in ALLOWED_LLM_STRATEGIES else 0,
+            key="global_strategy",
+        )
+        show_ensemble = strategy_val != "single"
+        majority_threshold_val = st.number_input(
             "Majority vote threshold",
             min_value=1,
             max_value=10,
-            value=int(route_obj.majority_threshold or 1),
+            value=int(global_cfg.majority_threshold),
             step=1,
-            key=f"route_majority_{selected_route}",
+            key="global_majority",
+            disabled=not show_ensemble,
         )
-        if not profile_keys:
-            st.warning("No Profiles available yet; please create one below first.")
+        if not show_ensemble:
+            majority_threshold_val = 1
+
+        ensemble_rows: List[Dict[str, str]] = []
+        if show_ensemble:
+            ensemble_rows = [
+                {
+                    "provider": ep.provider,
+                    "model": ep.model or "",
+                    "temperature": "" if ep.temperature is None else f"{ep.temperature:.3f}",
+                    "timeout": "" if ep.timeout is None else str(int(ep.timeout)),
+                    "prompt_template": ep.prompt_template or "",
+                }
+                for ep in global_cfg.ensemble
+            ] or [{"provider": primary.provider or selected_global_provider, "model": "", "temperature": "", "timeout": "", "prompt_template": ""}]
+
+            ensemble_editor = st.data_editor(
+                ensemble_rows,
+                num_rows="dynamic",
+                key="global_ensemble_editor",
+                use_container_width=True,
+                hide_index=True,
+                column_config={
+                    "provider": st.column_config.SelectboxColumn("Provider", options=provider_keys),
+                    "model": st.column_config.TextColumn("Model"),
+                    "temperature": st.column_config.TextColumn("Temperature"),
+                    "timeout": st.column_config.TextColumn("Timeout (s)"),
+                    "prompt_template": st.column_config.TextColumn("Prompt template"),
+                },
+            )
+            if hasattr(ensemble_editor, "to_dict"):
+                ensemble_rows = ensemble_editor.to_dict("records")
+            else:
+                ensemble_rows = ensemble_editor
         else:
-            try:
-                primary_index = profile_keys.index(route_obj.primary)
-            except ValueError:
-                primary_index = 0
-            primary_key = st.selectbox(
-                "Primary Profile",
-                profile_keys,
-                index=primary_index,
-                key=f"route_primary_{selected_route}",
-            )
-            default_ensemble = [
-                key for key in route_obj.ensemble if key in profile_keys and key != primary_key
-            ]
-            ensemble_keys = st.multiselect(
-                "Ensemble Profiles (multi-select)",
-                profile_keys,
-                default=default_ensemble,
-                key=f"route_ensemble_{selected_route}",
-            )
-        if st.button("Save route settings", key=f"save_route_{selected_route}"):
-            route_obj.title = route_title.strip()
-            route_obj.strategy = route_strategy
-            route_obj.majority_threshold = int(route_majority)
-            route_obj.primary = primary_key
-            route_obj.ensemble = [key for key in ensemble_keys if key != primary_key]
-            cfg.llm_route = selected_route
-            cfg.sync_runtime_llm()
-            save_config()
-            LOGGER.info(
-                "Route %s configuration updated: %s",
-                selected_route,
-                route_obj.to_dict(),
-                extra=LOG_EXTRA,
-            )
-            st.success("Route configuration saved.")
-            st.json({
-                "route": selected_route,
-                "route_detail": route_obj.to_dict(),
-                "resolved": llm_config_snapshot(),
-            })
-        route_in_use = selected_route in used_routes or selected_route == cfg.llm_route
-        if st.button(
-            "Delete current route",
-            key=f"delete_route_{selected_route}",
-            disabled=route_in_use or len(routes) <= 1,
-        ):
-            routes.pop(selected_route, None)
-            if cfg.llm_route == selected_route:
-                cfg.llm_route = next((key for key in routes.keys()), "global")
+            st.info("The current strategy is single-model; no ensemble models are enabled.")
+
+        if st.button("Save global settings", key="save_global_llm"):
+            primary.provider = selected_global_provider
+            primary.model = model_val or None
+            primary.temperature = float(temp_val)
+            primary.timeout = float(timeout_val)
+            primary.prompt_template = prompt_template_val.strip() or None
+            primary.base_url = None
+            primary.api_key = None
+
+            new_ensemble: List[LLMEndpoint] = []
+            if show_ensemble:
+                for row in ensemble_rows:
+                    provider_val = (row.get("provider") or "").strip().lower()
+                    if not provider_val:
+                        continue
+                    model_raw = (row.get("model") or "").strip() or None
+                    temp_raw = (row.get("temperature") or "").strip()
+                    timeout_raw = (row.get("timeout") or "").strip()
+                    prompt_raw = (row.get("prompt_template") or "").strip()
+                    new_ensemble.append(
+                        LLMEndpoint(
+                            provider=provider_val,
+                            model=model_raw,
+                            temperature=float(temp_raw) if temp_raw else None,
+                            timeout=float(timeout_raw) if timeout_raw else None,
+                            prompt_template=prompt_raw or None,
+                        )
+                    )
+            cfg.llm.ensemble = new_ensemble
+            cfg.llm.strategy = strategy_val
+            cfg.llm.majority_threshold = int(majority_threshold_val)
             cfg.sync_runtime_llm()
             save_config()
-            st.success("Route deleted.")
-            st.experimental_rerun()
-
-    st.divider()
-    st.subheader("LLM Profile Management")
-    profile_select_col, profile_manage_col = st.columns([3, 1])
-    if profile_keys:
-        selected_profile = profile_select_col.selectbox(
-            "Select Profile",
-            profile_keys,
-            index=0,
-            key="profile_select",
-        )
-    else:
-        selected_profile = None
-        profile_select_col.info("No Profile configured yet; please create one first.")
-
-    new_profile_name = profile_manage_col.text_input("New Profile", key="new_profile_name")
-    if profile_manage_col.button("Create Profile"):
-        key = (new_profile_name or "").strip()
-        if not key:
-            st.warning("Please enter a valid Profile name.")
-        elif key in profiles:
-            st.warning(f"Profile {key} already exists.")
-        else:
-            profiles[key] = LLMProfile(key=key)
-            save_config()
-            st.success(f"Profile {key} created.")
-            st.experimental_rerun()
-
-    if selected_profile:
-        profile = profiles[selected_profile]
-        provider_choices = sorted(DEFAULT_LLM_MODEL_OPTIONS.keys())
-        try:
-            provider_index = provider_choices.index(profile.provider)
-        except ValueError:
-            provider_index = 0
-        with st.form(f"profile_form_{selected_profile}"):
-            provider_val = st.selectbox(
-                "Provider",
-                provider_choices,
-                index=provider_index,
-            )
-            model_default = DEFAULT_LLM_MODELS.get(provider_val, profile.model or "")
-            model_val = st.text_input(
-                "Model",
-                value=profile.model or model_default,
-            )
-            base_default = DEFAULT_LLM_BASE_URLS.get(provider_val, profile.base_url or "")
-            base_val = st.text_input(
-                "Base URL",
-                value=profile.base_url or base_default,
-            )
-            api_val = st.text_input(
-                "API Key",
-                value=profile.api_key or "",
-                type="password",
-            )
-            temp_val = st.slider(
-                "Temperature",
-                min_value=0.0,
-                max_value=2.0,
-                value=float(profile.temperature),
-                step=0.05,
-            )
-            timeout_val = st.number_input(
-                "Timeout (seconds)",
-                min_value=5,
-                max_value=180,
-                value=int(profile.timeout or 30),
-                step=5,
-            )
-            title_val = st.text_input("Notes", value=profile.title or "")
-            enabled_val = st.checkbox("Enabled", value=profile.enabled)
-            submitted = st.form_submit_button("Save Profile")
-        if submitted:
-            profile.provider = provider_val
-            profile.model = model_val.strip() or DEFAULT_LLM_MODELS.get(provider_val)
-            profile.base_url = base_val.strip() or DEFAULT_LLM_BASE_URLS.get(provider_val)
-            profile.api_key = api_val.strip() or None
-            profile.temperature = temp_val
-            profile.timeout = timeout_val
-            profile.title = title_val.strip()
-            profile.enabled = enabled_val
-            profiles[selected_profile] = profile
-            cfg.sync_runtime_llm()
-            save_config()
-            st.success("Profile saved.")
-
-        profile_in_use = any(
-            selected_profile == route.primary or selected_profile in route.ensemble
-            for route in routes.values()
-        )
-        if st.button(
-            "Delete this Profile",
-            key=f"delete_profile_{selected_profile}",
-            disabled=profile_in_use or len(profiles) <= 1,
-        ):
-            profiles.pop(selected_profile, None)
-            fallback_key = next((key for key in profiles.keys()), None)
-            for route in routes.values():
-                if route.primary == selected_profile:
-                    route.primary = fallback_key or route.primary
-                route.ensemble = [key for key in route.ensemble if key != selected_profile]
-            cfg.sync_runtime_llm()
-            save_config()
-            st.success("Profile deleted.")
-            st.experimental_rerun()
-
-    st.divider()
-    st.subheader("Department Settings")
+            st.success("Global LLM configuration saved.")
+            st.json(llm_config_snapshot())
 
+    # Department configuration -------------------------------------------
+    st.markdown("##### Department Settings")
     dept_settings = cfg.departments or {}
-    route_options_display = [""] + route_keys
     dept_rows = [
         {
             "code": code,
             "title": dept.title,
             "description": dept.description,
             "weight": float(dept.weight),
-            "llm_route": dept.llm_route or "",
             "strategy": dept.llm.strategy,
-            "primary_provider": (dept.llm.primary.provider or ""),
-            "primary_model": dept.llm.primary.model or "",
-            "ensemble_size": len(dept.llm.ensemble),
+            "majority_threshold": dept.llm.majority_threshold,
+            "provider": dept.llm.primary.provider or (provider_keys[0] if provider_keys else ""),
+            "model": dept.llm.primary.model or "",
+            "temperature": "" if dept.llm.primary.temperature is None else f"{dept.llm.primary.temperature:.3f}",
+            "timeout": "" if dept.llm.primary.timeout is None else str(int(dept.llm.primary.timeout)),
+            "prompt_template": dept.llm.primary.prompt_template or "",
         }
         for code, dept in sorted(dept_settings.items())
     ]
@@ -618,26 +676,13 @@ def render_settings() -> None:
             "title": st.column_config.TextColumn("Name"),
             "description": st.column_config.TextColumn("Description"),
             "weight": st.column_config.NumberColumn("Weight", min_value=0.0, max_value=10.0, step=0.1),
-            "llm_route": st.column_config.SelectboxColumn(
-                "Route",
-                options=route_options_display,
-                help="Pick a predefined route; leave empty to use a custom config",
-            ),
-            "strategy": st.column_config.SelectboxColumn(
-                "Custom strategy",
-                options=sorted(ALLOWED_LLM_STRATEGIES),
-                help="Only takes effect when no route is selected",
-            ),
-            "primary_provider": st.column_config.SelectboxColumn(
-                "Custom Provider",
-                options=sorted(DEFAULT_LLM_MODEL_OPTIONS.keys()),
-            ),
-            "primary_model": st.column_config.TextColumn("Custom model"),
-            "ensemble_size": st.column_config.NumberColumn(
-                "Ensemble size",
-                disabled=True,
-                help="Maintained automatically in route mode",
-            ),
+            "strategy": st.column_config.SelectboxColumn("Strategy", options=sorted(ALLOWED_LLM_STRATEGIES)),
+            "majority_threshold": st.column_config.NumberColumn("Vote threshold", min_value=1, max_value=10, step=1),
+            "provider": st.column_config.SelectboxColumn("Provider", options=provider_keys or [""]),
+            "model": st.column_config.TextColumn("Model"),
+            "temperature": st.column_config.TextColumn("Temperature"),
+            "timeout": st.column_config.TextColumn("Timeout (s)"),
+            "prompt_template": st.column_config.TextColumn("Prompt template"),
         },
     )
 
@@ -662,25 +707,34 @@ def render_settings() -> None:
             except (TypeError, ValueError):
                 pass
 
-            route_name = (row.get("llm_route") or "").strip() or None
-            existing.llm_route = route_name
-            if route_name and route_name in routes:
-                existing.llm = routes[route_name].resolve(profiles)
-            else:
-                strategy_val = (row.get("strategy") or existing.llm.strategy).lower()
-                if strategy_val in ALLOWED_LLM_STRATEGIES:
-                    existing.llm.strategy = strategy_val
-                provider_before = existing.llm.primary.provider or ""
-                provider_val = (row.get("primary_provider") or provider_before or "ollama").lower()
-                existing.llm.primary.provider = provider_val
-                model_val = (row.get("primary_model") or "").strip()
-                existing.llm.primary.model = (
-                    model_val or DEFAULT_LLM_MODELS.get(provider_val, existing.llm.primary.model)
-                )
-                if provider_before != provider_val:
-                    default_base = DEFAULT_LLM_BASE_URLS.get(provider_val)
-                    existing.llm.primary.base_url = default_base or existing.llm.primary.base_url
-                existing.llm.primary.__post_init__()
+            strategy_val = (row.get("strategy") or existing.llm.strategy).lower()
+            if strategy_val in ALLOWED_LLM_STRATEGIES:
+                existing.llm.strategy = strategy_val
+            majority_raw = row.get("majority_threshold")
+            try:
+                majority_val = int(majority_raw)
+                if majority_val > 0:
+                    existing.llm.majority_threshold = majority_val
+            except (TypeError, ValueError):
+                pass
+
+            provider_val = (row.get("provider") or existing.llm.primary.provider or (provider_keys[0] if provider_keys else "ollama")).strip().lower()
+            model_val = (row.get("model") or "").strip() or None
+            temp_raw = (row.get("temperature") or "").strip()
+            timeout_raw = (row.get("timeout") or "").strip()
+            prompt_raw = (row.get("prompt_template") or "").strip()
+
+            endpoint = existing.llm.primary or LLMEndpoint()
+            endpoint.provider = provider_val
+            endpoint.model = model_val
+            endpoint.temperature = float(temp_raw) if temp_raw else None
+            endpoint.timeout = float(timeout_raw) if timeout_raw else None
+            endpoint.prompt_template = prompt_raw or None
+            endpoint.base_url = None
+            endpoint.api_key = None
+            existing.llm.primary = endpoint
+            existing.llm.ensemble = []
+
             updated_departments[code] = existing
 
     if updated_departments:
@@ -700,8 +754,7 @@ def render_settings() -> None:
         st.success("Default department configuration restored.")
         st.experimental_rerun()
 
-    st.caption("Selecting a route unifies department model calls; custom mode still supports per-item configuration.")
-    st.caption("Edit department ensemble models manually in config.json for now; the UI will add support in a later version.")
+    st.caption("Department settings are stored as standalone LLM parameters; at run time the matching Provider's connection info is applied automatically.")
 
 
 def render_tests() -> None:
diff --git a/app/utils/config.py b/app/utils/config.py
index d0b6222..a1f51b7 100644
--- a/app/utils/config.py
+++ b/app/utils/config.py
@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
 import json
 import os
 from pathlib import Path
-from typing import Dict, List, Mapping, Optional
+from typing import Dict, List, Optional
 
 
 def _default_root() -> Path:
@@ -99,27 +99,55 @@ ALLOWED_LLM_STRATEGIES = {"single", "majority", "leader"}
 LLM_STRATEGY_ALIASES = {"leader-follower": "leader"}
 
+@dataclass
+class LLMProvider:
+    """Provider-level connection settings shared by all endpoints."""
+
+    key: str
+    title: str = ""
+    base_url: str = ""
+    api_key: Optional[str] = None
+    models: List[str] = field(default_factory=list)
+    default_model: Optional[str] = None
+    default_temperature: float = 0.2
+    default_timeout: float = 30.0
+    prompt_template: str = ""
+    enabled: bool = True
+    mode: str = "openai"  # "openai" or "ollama"
+
+    def to_dict(self) -> Dict[str, object]:
+        return {
+            "title": self.title,
+            "base_url": self.base_url,
+            "api_key": self.api_key,
+            "models": list(self.models),
+            "default_model": self.default_model,
+            "default_temperature": self.default_temperature,
+            "default_timeout": self.default_timeout,
+            "prompt_template": self.prompt_template,
+            "enabled": self.enabled,
+            "mode": self.mode,
+        }
+
+
 @dataclass
 class LLMEndpoint:
-    """Single LLM endpoint configuration."""
+    """Resolved endpoint payload used for actual LLM calls."""
 
     provider: str = "ollama"
     model: Optional[str] = None
     base_url: Optional[str] = None
     api_key: Optional[str] = None
-    temperature: float = 0.2
-    timeout: float = 30.0
+    temperature: Optional[float] = None
+    timeout: Optional[float] = None
+    prompt_template: Optional[str] = None
 
     def __post_init__(self) -> None:
         self.provider = (self.provider or "ollama").lower()
-        if not self.model:
-            self.model = DEFAULT_LLM_MODELS.get(self.provider, DEFAULT_LLM_MODELS["ollama"])
-        if not self.base_url:
-            self.base_url = DEFAULT_LLM_BASE_URLS.get(self.provider)
-        if self.temperature == 0.2 or self.temperature is None:
-            self.temperature = DEFAULT_LLM_TEMPERATURES.get(self.provider, 0.2)
-        if self.timeout == 30.0 or self.timeout is None:
-            self.timeout = DEFAULT_LLM_TIMEOUTS.get(self.provider, 30.0)
+        if self.temperature is not None:
+            self.temperature = float(self.temperature)
+        if self.timeout is not None:
+            self.timeout = float(self.timeout)
 
 
 @dataclass
@@ -132,133 +160,22 @@ class LLMConfig:
     majority_threshold: int = 3
 
 
-@dataclass
-class LLMProfile:
-    """Named LLM endpoint profile reusable across routes/departments."""
-
-    key: str
-    provider: str = "ollama"
-    model: Optional[str] = None
-    base_url: Optional[str] = None
-    api_key: Optional[str] = None
-    temperature: float = 0.2
-    timeout: float = 30.0
-    title: str = ""
-    enabled: bool = True
-
-    def to_endpoint(self) -> LLMEndpoint:
-        return LLMEndpoint(
-            provider=self.provider,
-            model=self.model,
-            base_url=self.base_url,
-            api_key=self.api_key,
-            temperature=self.temperature,
-            timeout=self.timeout,
-        )
-
-    def to_dict(self) -> Dict[str, object]:
-        return {
-            "provider": self.provider,
-            "model": self.model,
-            "base_url": self.base_url,
-            "api_key": self.api_key,
-            "temperature": self.temperature,
-            "timeout": self.timeout,
-            "title": self.title,
-            "enabled": self.enabled,
-        }
-
-    @classmethod
-    def from_endpoint(
-        cls,
-        key: str,
-        endpoint: LLMEndpoint,
-        *,
-        title: str = "",
-        enabled: bool = True,
-    ) -> "LLMProfile":
-        return cls(
-            key=key,
-            provider=endpoint.provider,
-            model=endpoint.model,
-            base_url=endpoint.base_url,
-            api_key=endpoint.api_key,
-            temperature=endpoint.temperature,
-            timeout=endpoint.timeout,
-            title=title,
-            enabled=enabled,
-        )
-
-
-@dataclass
-class LLMRoute:
-    """Declarative routing for selecting profiles and strategy."""
-
-    name: str
-    title: str = ""
-    strategy: str = "single"
-    majority_threshold: int = 3
-    primary: str = "ollama"
-    ensemble: List[str] = field(default_factory=list)
-
-    def resolve(self, profiles: Mapping[str, LLMProfile]) -> LLMConfig:
-        def _endpoint_from_key(key: str) -> LLMEndpoint:
-            profile = profiles.get(key)
-            if profile and profile.enabled:
-                return profile.to_endpoint()
-            fallback = profiles.get("ollama")
-            if not fallback or not fallback.enabled:
-                fallback = next(
-                    (item for item in profiles.values() if item.enabled),
-                    None,
-                )
-            endpoint = fallback.to_endpoint() if fallback else LLMEndpoint()
-            endpoint.provider = key or endpoint.provider
-            return endpoint
-
-        primary_endpoint = _endpoint_from_key(self.primary)
-        ensemble_endpoints = [
-            _endpoint_from_key(key)
-            for key in self.ensemble
-            if key in profiles and profiles[key].enabled
-        ]
-        config = LLMConfig(
-            primary=primary_endpoint,
-            ensemble=ensemble_endpoints,
-            strategy=self.strategy if self.strategy in ALLOWED_LLM_STRATEGIES else "single",
-            majority_threshold=max(1, self.majority_threshold or 1),
-        )
-        return config
-
-    def to_dict(self) -> Dict[str, object]:
-        return {
-            "title": self.title,
-            "strategy": self.strategy,
-            "majority_threshold": self.majority_threshold,
-            "primary": self.primary,
-            "ensemble": list(self.ensemble),
-        }
-
-
-def _default_llm_profiles() -> Dict[str, LLMProfile]:
-    return {
-        provider: LLMProfile(
+def _default_llm_providers() -> Dict[str, LLMProvider]:
+    providers: Dict[str, LLMProvider] = {}
+    for provider, meta in DEFAULT_LLM_MODEL_OPTIONS.items():
+        models = list(meta.get("models", []))
+        mode = "ollama" if provider == "ollama" else "openai"
+        providers[provider] = LLMProvider(
             key=provider,
-            provider=provider,
-            model=DEFAULT_LLM_MODELS.get(provider),
-            base_url=DEFAULT_LLM_BASE_URLS.get(provider),
-            temperature=DEFAULT_LLM_TEMPERATURES.get(provider, 0.2),
-            timeout=DEFAULT_LLM_TIMEOUTS.get(provider, 30.0),
             title=f"Default {provider}",
+            base_url=str(meta.get("base_url", DEFAULT_LLM_BASE_URLS.get(provider, "")) or ""),
+            models=models,
+            default_model=models[0] if models else DEFAULT_LLM_MODELS.get(provider),
+            default_temperature=float(meta.get("temperature", DEFAULT_LLM_TEMPERATURES.get(provider, 0.2))),
+            default_timeout=float(meta.get("timeout", DEFAULT_LLM_TIMEOUTS.get(provider, 30.0))),
+            mode=mode,
         )
-        for provider in DEFAULT_LLM_MODEL_OPTIONS
-    }
-
-
-def _default_llm_routes() -> Dict[str, LLMRoute]:
-    return {
-        "global": LLMRoute(name="global", title="Global default route"),
-    }
+    return providers
 
 
 @dataclass
@@ -270,7 +187,6 @@ class DepartmentSettings:
     description: str = ""
     weight: float = 1.0
     llm: LLMConfig = field(default_factory=LLMConfig)
-    llm_route: Optional[str] = None
 
 
 def _default_departments() -> Dict[str, DepartmentSettings]:
@@ -282,10 +198,7 @@ def _default_departments() -> Dict[str, DepartmentSettings]:
         ("macro", "Macro research department"),
         ("risk", "Risk control department"),
     ]
-    return {
-        code: DepartmentSettings(code=code, title=title, llm_route="global")
-        for code, title in presets
-    }
+    return {code: DepartmentSettings(code=code, title=title) for code, title in presets}
 
 
 @dataclass
@@ -298,17 +211,11 @@ class AppConfig:
     data_paths: DataPaths = field(default_factory=DataPaths)
     agent_weights: AgentWeights = field(default_factory=AgentWeights)
     force_refresh: bool = False
+    llm_providers: Dict[str, LLMProvider] = field(default_factory=_default_llm_providers)
     llm: LLMConfig = field(default_factory=LLMConfig)
-    llm_route: str = "global"
-    llm_profiles: Dict[str, LLMProfile] = field(default_factory=_default_llm_profiles)
-    llm_routes: Dict[str, LLMRoute] = field(default_factory=_default_llm_routes)
     departments: Dict[str, DepartmentSettings] = field(default_factory=_default_departments)
 
     def resolve_llm(self, route: Optional[str] = None) -> LLMConfig:
-        route_key = route or self.llm_route
-        route_cfg = self.llm_routes.get(route_key)
-        if route_cfg:
-            return route_cfg.resolve(self.llm_profiles)
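+        # Routes were removed; the unused ``route`` argument is kept only so
+        # existing call sites that still pass a route name keep working.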
         return self.llm
 
     def sync_runtime_llm(self) -> None:
@@ -326,13 +233,22 @@ def _endpoint_to_dict(endpoint: LLMEndpoint) -> Dict[str, object]:
         "api_key": endpoint.api_key,
         "temperature": endpoint.temperature,
         "timeout": endpoint.timeout,
+        "prompt_template": endpoint.prompt_template,
     }
 
 
 def _dict_to_endpoint(data: Dict[str, object]) -> LLMEndpoint:
     payload = {
         key: data.get(key)
-        for key in ("provider", "model", "base_url", "api_key", "temperature", "timeout")
+        for key in (
+            "provider",
+            "model",
+            "base_url",
+            "api_key",
+            "temperature",
+            "timeout",
+            "prompt_template",
+        )
         if data.get(key) is not None
     }
     return LLMEndpoint(**payload)
@@ -348,139 +264,148 @@ def _load_from_file(cfg: AppConfig) -> None:
     except (json.JSONDecodeError, OSError):
         return
 
-    if isinstance(payload, dict):
-        if "tushare_token" in payload:
-            cfg.tushare_token = payload.get("tushare_token") or None
-        if "force_refresh" in payload:
-            cfg.force_refresh = bool(payload.get("force_refresh"))
-        if "decision_method" in payload:
-            cfg.decision_method = str(payload.get("decision_method") or cfg.decision_method)
+    if not isinstance(payload, dict):
+        return
 
-        routes_defined = False
-        inline_primary_loaded = False
+    if "tushare_token" in payload:
+        cfg.tushare_token = payload.get("tushare_token") or None
+    if "force_refresh" in payload:
+        cfg.force_refresh = bool(payload.get("force_refresh"))
+    if "decision_method" in payload:
+        cfg.decision_method = str(payload.get("decision_method") or cfg.decision_method)
 
-        profiles_payload = payload.get("llm_profiles")
-        if isinstance(profiles_payload, dict):
-            profiles: Dict[str, LLMProfile] = {}
-            for key, data in profiles_payload.items():
-                if not isinstance(data, dict):
-                    continue
-                provider = str(data.get("provider") or "ollama").lower()
-                profile = LLMProfile(
-                    key=key,
-                    provider=provider,
-                    model=data.get("model"),
-                    base_url=data.get("base_url"),
-                    api_key=data.get("api_key"),
-                    temperature=float(data.get("temperature", DEFAULT_LLM_TEMPERATURES.get(provider, 0.2))),
-                    timeout=float(data.get("timeout", DEFAULT_LLM_TIMEOUTS.get(provider, 30.0))),
-                    title=str(data.get("title") or ""),
-                    enabled=bool(data.get("enabled", True)),
-                )
-                profiles[key] = profile
-            if profiles:
-                cfg.llm_profiles = profiles
+    legacy_profiles: Dict[str, Dict[str, object]] = {}
+    legacy_routes: Dict[str, Dict[str, object]] = {}
 
-        routes_payload = payload.get("llm_routes")
-        if isinstance(routes_payload, dict):
-            routes: Dict[str, LLMRoute] = {}
-            for name, data in routes_payload.items():
-                if not isinstance(data, dict):
-                    continue
-                strategy_raw = str(data.get("strategy") or "single").lower()
-                normalized = LLM_STRATEGY_ALIASES.get(strategy_raw, strategy_raw)
-                route = LLMRoute(
-                    name=name,
-                    title=str(data.get("title") or ""),
-                    strategy=normalized if normalized in ALLOWED_LLM_STRATEGIES else "single",
-                    majority_threshold=max(1, int(data.get("majority_threshold", 3) or 3)),
-                    primary=str(data.get("primary") or "global"),
-                    ensemble=[
-                        str(item)
-                        for item in data.get("ensemble", [])
-                        if isinstance(item, str)
-                    ],
-                )
-                routes[name] = route
-            if routes:
-                cfg.llm_routes = routes
-                routes_defined = True
+    providers_payload = payload.get("llm_providers")
+    if isinstance(providers_payload, dict):
+        providers: Dict[str, LLMProvider] = {}
+        for key, data in providers_payload.items():
+            if not isinstance(data, dict):
+                continue
+            models_raw = data.get("models")
+            if isinstance(models_raw, str):
+                models = [item.strip() for item in models_raw.split(',') if item.strip()]
+            elif isinstance(models_raw, list):
+                models = [str(item).strip() for item in models_raw if str(item).strip()]
+            else:
+                models = []
+            provider = LLMProvider(
+                key=str(key).lower(),
+                title=str(data.get("title") or ""),
+                base_url=str(data.get("base_url") or ""),
+                api_key=data.get("api_key"),
+                models=models,
+                default_model=data.get("default_model") or (models[0] if models else None),
+                default_temperature=float(data.get("default_temperature", 0.2)),
+                default_timeout=float(data.get("default_timeout", 30.0)),
+                prompt_template=str(data.get("prompt_template") or ""),
+                enabled=bool(data.get("enabled", True)),
+                mode=str(data.get("mode") or ("ollama" if str(key).lower() == "ollama" else "openai")),
+            )
+            providers[provider.key] = provider
+        if providers:
+            cfg.llm_providers = providers
 
-        route_key = payload.get("llm_route")
-        if isinstance(route_key, str) and route_key:
-            cfg.llm_route = route_key
+    profiles_payload = payload.get("llm_profiles")
+    if isinstance(profiles_payload, dict):
+        for key, data in profiles_payload.items():
+            if isinstance(data, dict):
+                legacy_profiles[str(key)] = data
 
-        llm_payload = payload.get("llm")
-        if isinstance(llm_payload, dict):
-            route_value = llm_payload.get("route")
-            if isinstance(route_value, str) and route_value:
-                cfg.llm_route = route_value
+    routes_payload = payload.get("llm_routes")
+    if isinstance(routes_payload, dict):
+        for name, data in routes_payload.items():
+            if isinstance(data, dict):
+                legacy_routes[str(name)] = data
+
+    def _endpoint_from_payload(item: object) -> LLMEndpoint:
+        if isinstance(item, dict):
+            return _dict_to_endpoint(item)
+        if isinstance(item, str):
+            profile_data = legacy_profiles.get(item)
+            if isinstance(profile_data, dict):
+                return _dict_to_endpoint(profile_data)
+            return LLMEndpoint(provider=item)
+        return LLMEndpoint()
+
+    def _resolve_route(route_name: str) -> Optional[LLMConfig]:
+        route_data = legacy_routes.get(route_name)
+        if not route_data:
+            return None
+        strategy_raw = str(route_data.get("strategy") or "single").lower()
+        strategy = LLM_STRATEGY_ALIASES.get(strategy_raw, strategy_raw)
+        primary_ref = route_data.get("primary")
+        primary_ep = _endpoint_from_payload(primary_ref)
+        ensemble_refs = route_data.get("ensemble", [])
+        ensemble_eps = [
+            _endpoint_from_payload(ref)
+            for ref in ensemble_refs
+            if isinstance(ref, (dict, str))
+        ]
+        cfg_obj = LLMConfig(
+            primary=primary_ep,
+            ensemble=ensemble_eps,
+            strategy=strategy if strategy in ALLOWED_LLM_STRATEGIES else "single",
+            majority_threshold=max(1, int(route_data.get("majority_threshold", 3) or 3)),
+        )
+        return cfg_obj
+
+    llm_payload = payload.get("llm")
+    if isinstance(llm_payload, dict):
+        route_value = llm_payload.get("route")
+        resolved_cfg = None
+        if isinstance(route_value, str) and route_value:
+            resolved_cfg = _resolve_route(route_value)
+        if resolved_cfg is None:
+            resolved_cfg = LLMConfig()
         primary_data = llm_payload.get("primary")
         if isinstance(primary_data, dict):
-            cfg.llm.primary = _dict_to_endpoint(primary_data)
-            inline_primary_loaded = True
-
+            resolved_cfg.primary = _dict_to_endpoint(primary_data)
         ensemble_data = llm_payload.get("ensemble")
         if isinstance(ensemble_data, list):
-            cfg.llm.ensemble = [
+            resolved_cfg.ensemble = [
                 _dict_to_endpoint(item)
                 for item in ensemble_data
                 if isinstance(item, dict)
             ]
-
         strategy_raw = llm_payload.get("strategy")
         if isinstance(strategy_raw, str):
             normalized = LLM_STRATEGY_ALIASES.get(strategy_raw, strategy_raw)
             if normalized in ALLOWED_LLM_STRATEGIES:
-                cfg.llm.strategy = normalized
-
+                resolved_cfg.strategy = normalized
         majority = llm_payload.get("majority_threshold")
         if isinstance(majority, int) and majority > 0:
-            cfg.llm.majority_threshold = majority
+            resolved_cfg.majority_threshold = majority
+        cfg.llm = resolved_cfg
 
-        if inline_primary_loaded and not routes_defined:
-            primary_key = "inline_global_primary"
-            cfg.llm_profiles[primary_key] = LLMProfile.from_endpoint(
-                primary_key,
-                cfg.llm.primary,
-                title="Global primary model",
-            )
-            ensemble_keys: List[str] = []
-            for idx, endpoint in enumerate(cfg.llm.ensemble, start=1):
-                inline_key = f"inline_global_ensemble_{idx}"
-                cfg.llm_profiles[inline_key] = LLMProfile.from_endpoint(
-                    inline_key,
-                    endpoint,
-                    title=f"Global ensemble #{idx}",
-                )
-                ensemble_keys.append(inline_key)
-            auto_route = cfg.llm_routes.get("global") or LLMRoute(name="global", title="Global default route")
-            auto_route.strategy = cfg.llm.strategy
-            auto_route.majority_threshold = cfg.llm.majority_threshold
-            auto_route.primary = primary_key
-            auto_route.ensemble = ensemble_keys
-            cfg.llm_routes["global"] = auto_route
-            cfg.llm_route = cfg.llm_route or "global"
-
-        departments_payload = payload.get("departments")
-        if isinstance(departments_payload, dict):
-            new_departments: Dict[str, DepartmentSettings] = {}
-            for code, data in departments_payload.items():
-                if not isinstance(data, dict):
-                    continue
-                title = data.get("title") or code
-                description = data.get("description") or ""
-                weight = float(data.get("weight", 1.0))
+    departments_payload = payload.get("departments")
+    if isinstance(departments_payload, dict):
+        new_departments: Dict[str, DepartmentSettings] = {}
+        for code, data in departments_payload.items():
+            if not isinstance(data, dict):
+                continue
+            title = data.get("title") or code
+            description = data.get("description") or ""
+            weight = float(data.get("weight", 1.0))
+            llm_cfg = LLMConfig()
+            route_name = data.get("llm_route")
+            resolved_cfg = None
+            if isinstance(route_name, str) and route_name:
+                resolved_cfg = _resolve_route(route_name)
+            if resolved_cfg is None:
                 llm_data = data.get("llm")
-                llm_cfg = LLMConfig()
                 if isinstance(llm_data, dict):
-                    if isinstance(llm_data.get("primary"), dict):
-                        llm_cfg.primary = _dict_to_endpoint(llm_data["primary"])
-                    llm_cfg.ensemble = [
-                        _dict_to_endpoint(item)
-                        for item in llm_data.get("ensemble", [])
-                        if isinstance(item, dict)
-                    ]
+                    primary_data = llm_data.get("primary")
+                    if isinstance(primary_data, dict):
+                        llm_cfg.primary = _dict_to_endpoint(primary_data)
+                    ensemble_data = llm_data.get("ensemble")
+                    if isinstance(ensemble_data, list):
+                        llm_cfg.ensemble = [
+                            _dict_to_endpoint(item)
+                            for item in ensemble_data
+                            if isinstance(item, dict)
+                        ]
                     strategy_raw = llm_data.get("strategy")
                     if isinstance(strategy_raw, str):
                         normalized = LLM_STRATEGY_ALIASES.get(strategy_raw, strategy_raw)
@@ -489,23 +414,18 @@
                     majority_raw = llm_data.get("majority_threshold")
                     if isinstance(majority_raw, int) and majority_raw > 0:
                         llm_cfg.majority_threshold = majority_raw
-            route = data.get("llm_route")
-            route_name = str(route).strip() if isinstance(route, str) and route else None
-            resolved = llm_cfg
-            if route_name and route_name in cfg.llm_routes:
-                resolved = cfg.llm_routes[route_name].resolve(cfg.llm_profiles)
-            new_departments[code] = DepartmentSettings(
-                code=code,
-                title=title,
-                description=description,
-                weight=weight,
-                llm=resolved,
-                llm_route=route_name,
-            )
-        if new_departments:
-            cfg.departments = new_departments
+                resolved_cfg = llm_cfg
+            new_departments[code] = DepartmentSettings(
+                code=code,
+                title=title,
+                description=description,
+                weight=weight,
+                llm=resolved_cfg,
+            )
+        if new_departments:
+            cfg.departments = new_departments
 
-        cfg.sync_runtime_llm()
+    cfg.sync_runtime_llm()
 
 
 def save_config(cfg: AppConfig | None = None) -> None:
@@ -516,28 +436,21 @@ def save_config(cfg: AppConfig | None = None) -> None:
     payload = {
         "tushare_token": cfg.tushare_token,
         "force_refresh": cfg.force_refresh,
         "decision_method": cfg.decision_method,
-        "llm_route": cfg.llm_route,
         "llm": {
-            "route": cfg.llm_route,
             "strategy": cfg.llm.strategy if cfg.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
             "majority_threshold": cfg.llm.majority_threshold,
             "primary": _endpoint_to_dict(cfg.llm.primary),
             "ensemble": [_endpoint_to_dict(ep) for ep in cfg.llm.ensemble],
         },
-        "llm_profiles": {
-            key: profile.to_dict()
-            for key, profile in cfg.llm_profiles.items()
-        },
-        "llm_routes": {
-            name: route.to_dict()
-            for name, route in cfg.llm_routes.items()
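+        # NOTE: provider.to_dict() serialises api_key verbatim, so config.json
+        # holds secrets in plain text; prefer injecting keys via LLM_API_KEY.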
+        "llm_providers": {
+            key: provider.to_dict()
+            for key, provider in cfg.llm_providers.items()
         },
         "departments": {
             code: {
                 "title": dept.title,
                 "description": dept.description,
                 "weight": dept.weight,
-                "llm_route": dept.llm_route,
                 "llm": {
                     "strategy": dept.llm.strategy if dept.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
                     "majority_threshold": dept.llm.majority_threshold,
@@ -567,11 +480,9 @@ def _load_env_defaults(cfg: AppConfig) -> None:
     if api_key:
         sanitized = api_key.strip()
         cfg.llm.primary.api_key = sanitized
-        route = cfg.llm_routes.get(cfg.llm_route)
-        if route:
-            profile = cfg.llm_profiles.get(route.primary)
-            if profile:
-                profile.api_key = sanitized
+        provider_cfg = cfg.llm_providers.get(cfg.llm.primary.provider)
+        if provider_cfg:
+            provider_cfg.api_key = sanitized
     cfg.sync_runtime_llm()