diff --git a/README.md b/README.md
index f986fec..479b1d0 100644
--- a/README.md
+++ b/README.md
@@ -24,6 +24,7 @@
 - **统一日志与持久化**:SQLite 统一存储行情、回测与日志,配合 `DatabaseLogHandler` 在 UI/抓数流程中输出结构化运行轨迹,支持快速追踪与复盘。
 - **跨市场数据扩展**:`app/ingest/tushare.py` 追加指数、ETF/公募基金、期货、外汇、港股与美股的增量拉取逻辑,确保多资产因子与宏观代理所需的行情基础数据齐备。
 - **部门化多模型协作**:`app/agents/departments.py` 封装部门级 LLM 调度,`app/llm/client.py` 支持 single/majority/leader 策略,部门结论在 `app/agents/game.py` 与六类基础代理共同博弈,并持久化至 `agent_utils` 供 UI 展示。
+- **LLM Profile/Route 管理**:`app/utils/config.py` 引入可复用的 Profile(端点定义)与 Route(推理策略组合),Streamlit UI 支持可视化维护,全局与部门均可复用命名路由,提升配置一致性。
 
 ## LLM + 多智能体最佳实践
 
@@ -59,13 +60,11 @@ export TUSHARE_TOKEN=""
 
 ### LLM 配置与测试
 
-- 支持本地 Ollama 与多家 OpenAI 兼容云端供应商(如 DeepSeek、文心一言、OpenAI 等),可在 “数据与设置” 页签切换 Provider 并自动加载该 Provider 的候选模型、推荐 Base URL、默认温度与超时时间,亦可切换为自定义值。所有修改会持久化到 `app/data/config.json`,下次启动自动加载。
-- 修改 Provider/模型/Base URL/API Key 后点击 “保存 LLM 设置”,更新内容仅在当前会话生效。
-- 在 “自检测试” 页新增 “LLM 接口测试”,可输入 Prompt 快速验证调用结果,日志会记录限频与错误信息便于排查。
-- 未来可对同一功能的智能体并行调用多个 LLM,采用多数投票等策略增强鲁棒性,当前代码结构已为此预留扩展空间。
-- 若使用环境变量自动注入配置,可设置:
-  - `TUSHARE_TOKEN`
-  - `LLM_API_KEY`
+- 新增 Profile/Route 双层配置:Profile 定义单个端点(含 Provider/模型/Base URL/API Key),Route 组合 Profile 并指定推理策略(single/majority/leader)。全局路由可一键切换,部门可复用命名路由或保留自定义设置。
+- Streamlit “数据与设置” 页通过表单管理 Profile、Route 与全局路由,保存即写入 `app/data/config.json`;Route 预览会同步展示经 `llm_config_snapshot()` 脱敏后的实时配置。
+- 支持本地 Ollama 与多家 OpenAI 兼容供应商(DeepSeek、文心一言、OpenAI 等),可为不同 Profile 设置默认模型、温度、超时与启用状态。
+- UI 保留 TuShare Token 维护,并支持路由/Profile 的新增、删除、禁用等操作;所有更新即时生效并记入日志。
+- 使用环境变量注入敏感信息时,可配置 `TUSHARE_TOKEN` 与 `LLM_API_KEY`,加载后会同步至当前路由的主 Profile。
 
 ## 快速开始
 
@@ -105,7 +104,7 @@ Streamlit `自检测试` 页签提供:
 
 ## 实施步骤
 
 1. **配置扩展** (`app/utils/config.py` + `config.json`) ✅
-   - 部门支持 primary/ensemble、策略(single/majority/leader)、权重,并可在 Streamlit 中编辑主要字段。
+   - 引入 `llm_profiles`/`llm_routes` 统一管理端点与策略,部门可复用路由或使用自定义配置;Streamlit 提供可视化维护表单。
 2. **部门管控器** ✅
    - `app/agents/departments.py` 提供 `DepartmentAgent`/`DepartmentManager`,封装 Prompt 构建、多模型协商及异常回退。
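
上述 Profile/Route 的组合方式可以用一段最小示例说明(仅为示意,键名 `local`/`cloud` 与模型名均为假设值,并非仓库默认配置):

```python
# 最小示意:手工组合 Profile 与 Route,并解析为可执行的 LLMConfig。
from app.utils.config import LLMProfile, LLMRoute

profiles = {
    "local": LLMProfile(key="local", provider="ollama", model="qwen2.5:7b"),
    "cloud": LLMProfile(key="cloud", provider="deepseek", model="deepseek-chat", api_key="sk-..."),
}
route = LLMRoute(
    name="demo",
    strategy="majority",      # single/majority/leader 之一
    majority_threshold=2,
    primary="cloud",
    ensemble=["local"],
)

llm_cfg = route.resolve(profiles)
print(llm_cfg.strategy, llm_cfg.primary.provider, [ep.provider for ep in llm_cfg.ensemble])
# 预期输出:majority deepseek ['ollama']
```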
diff --git a/app/agents/departments.py b/app/agents/departments.py
index 7cd7b35..aee5284 100644
--- a/app/agents/departments.py
+++ b/app/agents/departments.py
@@ -3,12 +3,12 @@ from __future__ import annotations
 
 import json
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Mapping
+from typing import Any, Callable, Dict, List, Mapping, Optional
 
 from app.agents.base import AgentAction
 from app.llm.client import run_llm_with_config
 from app.llm.prompts import department_prompt
-from app.utils.config import DepartmentSettings
+from app.utils.config import AppConfig, DepartmentSettings, LLMConfig
 from app.utils.logging import get_logger
 
 LOGGER = get_logger(__name__)
@@ -53,16 +53,27 @@ class DepartmentDecision:
 class DepartmentAgent:
     """Wraps LLM ensemble logic for a single analytical department."""
 
-    def __init__(self, settings: DepartmentSettings) -> None:
+    def __init__(
+        self,
+        settings: DepartmentSettings,
+        resolver: Optional[Callable[[DepartmentSettings], LLMConfig]] = None,
+    ) -> None:
         self.settings = settings
+        self._resolver = resolver
+
+    def _get_llm_config(self) -> LLMConfig:
+        if self._resolver:
+            return self._resolver(self.settings)
+        return self.settings.llm
 
     def analyze(self, context: DepartmentContext) -> DepartmentDecision:
         prompt = department_prompt(self.settings, context)
         system_prompt = (
             "你是一个多智能体量化投研系统中的分部决策者,需要根据提供的结构化信息给出买卖意见。"
         )
+        llm_cfg = self._get_llm_config()
         try:
-            response = run_llm_with_config(self.settings.llm, prompt, system=system_prompt)
+            response = run_llm_with_config(llm_cfg, prompt, system=system_prompt)
         except Exception as exc:  # noqa: BLE001
             LOGGER.exception("部门 %s 调用 LLM 失败:%s", self.settings.code, exc, extra=LOG_EXTRA)
             return DepartmentDecision(
@@ -106,10 +117,11 @@ class DepartmentAgent:
 class DepartmentManager:
     """Orchestrates all departments defined in configuration."""
 
-    def __init__(self, departments: Mapping[str, DepartmentSettings]) -> None:
+    def __init__(self, config: AppConfig) -> None:
+        self.config = config
         self.agents: Dict[str, DepartmentAgent] = {
-            code: DepartmentAgent(settings)
-            for code, settings in departments.items()
+            code: DepartmentAgent(settings, self._resolve_llm)
+            for code, settings in config.departments.items()
         }
 
     def evaluate(self, context: DepartmentContext) -> Dict[str, DepartmentDecision]:
@@ -118,6 +130,11 @@
         results[code] = agent.analyze(context)
         return results
 
+    def _resolve_llm(self, settings: DepartmentSettings) -> LLMConfig:
+        if settings.llm_route and settings.llm_route in self.config.llm_routes:
+            return self.config.llm_routes[settings.llm_route].resolve(self.config.llm_profiles)
+        return settings.llm
+
 
 def _parse_department_response(text: str) -> Dict[str, Any]:
     """Extract a JSON object from the LLM response if possible."""
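
`DepartmentAgent` 的 `resolver` 参数允许在不改动部门配置的前提下替换 LLM 解析逻辑。下面是一段假设性的注入示例(部门代码 `risk` 取自默认配置,函数名 `force_global` 为自拟):

```python
# 最小示意:为单个部门注入自定义 resolver,统一改走 "global" 路由。
from app.agents.departments import DepartmentAgent
from app.utils.config import DepartmentSettings, LLMConfig, get_config

cfg = get_config()


def force_global(settings: DepartmentSettings) -> LLMConfig:
    # 命中 "global" 路由时按 Profile 解析,否则退回部门自带配置
    route = cfg.llm_routes.get("global")
    return route.resolve(cfg.llm_profiles) if route else settings.llm


agent = DepartmentAgent(cfg.departments["risk"], resolver=force_global)
# agent.analyze(context) 随后会经 force_global 取得 LLMConfig
```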
diff --git a/app/backtest/engine.py b/app/backtest/engine.py
index eff7b72..140d51d 100644
--- a/app/backtest/engine.py
+++ b/app/backtest/engine.py
@@ -55,7 +55,7 @@ class BacktestEngine:
         else:
             self.weights = {agent.name: 1.0 for agent in self.agents}
         self.department_manager = (
-            DepartmentManager(app_cfg.departments) if app_cfg.departments else None
+            DepartmentManager(app_cfg) if app_cfg.departments else None
         )
 
     def load_market_data(self, trade_date: date) -> Mapping[str, Dict[str, float]]:
diff --git a/app/llm/client.py b/app/llm/client.py
index 1e6a88b..be03e06 100644
--- a/app/llm/client.py
+++ b/app/llm/client.py
@@ -272,7 +272,8 @@ def run_llm_with_config(
 
 def llm_config_snapshot() -> Dict[str, object]:
     """Return a sanitized snapshot of current LLM configuration for debugging."""
-    settings = get_config().llm
+    cfg = get_config()
+    settings = cfg.llm
     primary = asdict(settings.primary)
     if primary.get("api_key"):
         primary["api_key"] = "***"
@@ -282,7 +283,11 @@
         if record.get("api_key"):
             record["api_key"] = "***"
         ensemble.append(record)
+    route_name = cfg.llm_route
+    route_obj = cfg.llm_routes.get(route_name)
     return {
+        "route": route_name,
+        "route_detail": route_obj.to_dict() if route_obj else None,
         "strategy": settings.strategy,
         "majority_threshold": settings.majority_threshold,
         "primary": primary,
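
诊断时可直接调用 `llm_config_snapshot()` 查看当前路由的解析结果;以下输出结构按上述实现示意,具体取值取决于 config.json:

```python
# 调用示意:快照包含路由信息,API Key 已脱敏为 "***"。
from app.llm.client import llm_config_snapshot

snapshot = llm_config_snapshot()
print(snapshot["route"])          # 例如 "global"
print(snapshot["route_detail"])   # LLMRoute.to_dict() 的结果;路由不存在时为 None
print(snapshot["primary"])        # 主端点字典,api_key 字段被替换为 "***"
```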
diff --git a/app/ui/streamlit_app.py b/app/ui/streamlit_app.py
index 330b622..0dacebf 100644
--- a/app/ui/streamlit_app.py
+++ b/app/ui/streamlit_app.py
@@ -29,7 +29,8 @@ from app.utils.config import (
     DEFAULT_LLM_MODEL_OPTIONS,
     DEFAULT_LLM_MODELS,
     DepartmentSettings,
-    LLMEndpoint,
+    LLMProfile,
+    LLMRoute,
     get_config,
     save_config,
 )
@@ -349,208 +350,253 @@ def render_settings() -> None:
 
     st.divider()
     st.subheader("LLM 设置")
-    llm_cfg = cfg.llm
-    primary = llm_cfg.primary
-    providers = sorted(DEFAULT_LLM_MODELS.keys())
-    try:
-        provider_index = providers.index((primary.provider or "ollama").lower())
-    except ValueError:
-        provider_index = 0
-    selected_provider = st.selectbox("LLM Provider", providers, index=provider_index)
-    provider_info = DEFAULT_LLM_MODEL_OPTIONS.get(selected_provider, {})
-    model_options = provider_info.get("models", [])
-    custom_model_label = "自定义模型"
-    default_model_hint = DEFAULT_LLM_MODELS.get(selected_provider, DEFAULT_LLM_MODELS["ollama"])
+    profiles = cfg.llm_profiles or {}
+    routes = cfg.llm_routes or {}
+    profile_keys = sorted(profiles.keys())
+    route_keys = sorted(routes.keys())
+    used_routes = {
+        dept.llm_route for dept in cfg.departments.values() if dept.llm_route
+    }
+    st.caption("Profile 定义单个模型端点,Route 负责组合 Profile 与推理策略。")
 
-    if model_options:
-        options_with_custom = model_options + [custom_model_label]
-        if primary.provider == selected_provider and primary.model in model_options:
-            model_index = options_with_custom.index(primary.model)
-        else:
-            model_index = 0
-        selected_model_option = st.selectbox(
-            "LLM 模型",
-            options_with_custom,
-            index=model_index,
-            help=f"可选模型:{', '.join(model_options)}",
+    route_select_col, route_manage_col = st.columns([3, 1])
+    if route_keys:
+        try:
+            active_index = route_keys.index(cfg.llm_route)
+        except ValueError:
+            active_index = 0
+        selected_route = route_select_col.selectbox(
+            "全局路由",
+            route_keys,
+            index=active_index,
+            key="llm_route_select",
         )
-        if selected_model_option == custom_model_label:
-            custom_model_value = st.text_input(
-                "自定义模型名称",
-                value="" if primary.provider != selected_provider or primary.model in model_options else primary.model,
-            )
-            chosen_model = custom_model_value.strip() or default_model_hint
+    else:
+        selected_route = None
+        route_select_col.info("尚未配置路由,请先创建。")
+
+    new_route_name = route_manage_col.text_input("新增路由", key="new_route_name")
+    if route_manage_col.button("添加路由"):
+        key = (new_route_name or "").strip()
+        if not key:
+            st.warning("请输入有效的路由名称。")
+        elif key in routes:
+            st.warning(f"路由 {key} 已存在。")
         else:
-            chosen_model = selected_model_option
-    else:
-        chosen_model = st.text_input(
-            "LLM 模型",
-            value=primary.model or default_model_hint,
-            help="未预设该 Provider 的模型列表,请手动填写",
-        ).strip() or default_model_hint
-    default_base_hint = DEFAULT_LLM_BASE_URLS.get(selected_provider, "")
-    provider_default_temp = float(provider_info.get("temperature", 0.2))
-    provider_default_timeout = int(provider_info.get("timeout", 30.0))
+            routes[key] = LLMRoute(name=key)
+            if not selected_route:
+                selected_route = key
+                cfg.llm_route = key
+            save_config()
+            st.success(f"已添加路由 {key},请继续配置。")
+            st.experimental_rerun()
 
-    if primary.provider == selected_provider:
-        base_value = primary.base_url or default_base_hint or ""
-        temp_value = float(primary.temperature)
-        timeout_value = int(primary.timeout)
-    else:
-        base_value = default_base_hint or ""
-        temp_value = provider_default_temp
-        timeout_value = provider_default_timeout
-
-    llm_base = st.text_input(
-        "LLM Base URL",
-        value=base_value,
-        help=f"默认推荐:{default_base_hint or '按供应商要求填写'}",
-    )
-    llm_api_key = st.text_input(
-        "LLM API Key",
-        value=primary.api_key or "",
-        type="password",
-        help="点击右侧小图标可查看当前 Key,该值会写入 config.json(已被 gitignore 排除)",
-    )
-    llm_temperature = st.slider(
-        "LLM 温度",
-        min_value=0.0,
-        max_value=2.0,
-        value=temp_value,
-        step=0.05,
-    )
-    llm_timeout = st.number_input(
-        "请求超时时间 (秒)",
-        min_value=5,
-        max_value=120,
-        value=timeout_value,
-        step=5,
-    )
-
-    strategy_options = ["single", "majority", "leader"]
-    try:
-        strategy_index = strategy_options.index(llm_cfg.strategy)
-    except ValueError:
-        strategy_index = 0
-    selected_strategy = st.selectbox("LLM 推理策略", strategy_options, index=strategy_index)
-    majority_threshold = st.number_input(
-        "多数投票门槛",
-        min_value=1,
-        max_value=10,
-        value=int(llm_cfg.majority_threshold),
-        step=1,
-        format="%d",
-    )
-
-    existing_api_keys = {ep.provider: ep.api_key or None for ep in llm_cfg.ensemble}
-
-    available_providers = sorted(DEFAULT_LLM_MODEL_OPTIONS.keys())
-    ensemble_rows = [
-        {
-            "provider": ep.provider or "",
-            "model": ep.model or DEFAULT_LLM_MODELS.get(ep.provider, DEFAULT_LLM_MODELS["ollama"]),
-            "base_url": ep.base_url or DEFAULT_LLM_BASE_URLS.get(ep.provider, ""),
-            "api_key": "***" if ep.api_key else "",
-            "temperature": float(ep.temperature),
-            "timeout": float(ep.timeout),
-        }
-        for ep in llm_cfg.ensemble
-    ] or [
-        {
-            "provider": "",
-            "model": "",
-            "base_url": "",
-            "api_key": "",
-            "temperature": provider_default_temp,
-            "timeout": provider_default_timeout,
-        }
-    ]
-
-    edited = st.data_editor(
-        ensemble_rows,
-        num_rows="dynamic",
-        key="llm_ensemble_editor",
-        column_config={
-            "provider": st.column_config.SelectboxColumn(
-                "Provider",
-                options=available_providers,
-                help="选择 LLM 供应商"
-            ),
-            "model": st.column_config.TextColumn("模型", help="留空时使用该 Provider 的默认模型"),
-            "base_url": st.column_config.TextColumn("Base URL", help="留空时使用默认地址"),
-            "api_key": st.column_config.TextColumn("API Key", help="留空表示使用环境变量或不配置"),
-            "temperature": st.column_config.NumberColumn("温度", min_value=0.0, max_value=2.0, step=0.05),
-            "timeout": st.column_config.NumberColumn("超时(秒)", min_value=5.0, max_value=120.0, step=5.0),
-        },
-        hide_index=True,
-        use_container_width=True,
-    )
-    if hasattr(edited, "to_dict"):
-        ensemble_rows = edited.to_dict("records")
-    else:
-        ensemble_rows = edited
-
-    if st.button("保存 LLM 设置"):
-        primary.provider = selected_provider
-        primary.model = chosen_model
-        primary.base_url = llm_base.strip() or DEFAULT_LLM_BASE_URLS.get(selected_provider)
-        primary.temperature = llm_temperature
-        primary.timeout = llm_timeout
-        api_key_value = llm_api_key.strip()
-        if api_key_value:
-            primary.api_key = api_key_value
-
-        new_ensemble: List[LLMEndpoint] = []
-        for row in ensemble_rows:
-            provider = (row.get("provider") or "").strip().lower()
-            if not provider:
-                continue
-            provider_defaults = DEFAULT_LLM_MODEL_OPTIONS.get(provider, {})
-            default_model = DEFAULT_LLM_MODELS.get(provider, DEFAULT_LLM_MODELS["ollama"])
-            default_base = DEFAULT_LLM_BASE_URLS.get(provider)
-            temp_default = float(provider_defaults.get("temperature", 0.2))
-            timeout_default = float(provider_defaults.get("timeout", 30.0))
-
-            model_val = (row.get("model") or "").strip() or default_model
-            base_val = (row.get("base_url") or "").strip() or default_base
-            api_raw = (row.get("api_key") or "").strip()
-            if api_raw == "***":
-                api_value = existing_api_keys.get(provider)
-            else:
-                api_value = api_raw or None
-
-            temp_val = row.get("temperature")
-            timeout_val = row.get("timeout")
-            endpoint = LLMEndpoint(
-                provider=provider,
-                model=model_val,
-                base_url=base_val,
-                api_key=api_value,
-                temperature=float(temp_val) if temp_val is not None else temp_default,
-                timeout=float(timeout_val) if timeout_val is not None else timeout_default,
+    if selected_route:
+        route_obj = routes.get(selected_route)
+        if route_obj is None:
+            route_obj = LLMRoute(name=selected_route)
+            routes[selected_route] = route_obj
+        strategy_choices = sorted(ALLOWED_LLM_STRATEGIES)
+        try:
+            strategy_index = strategy_choices.index(route_obj.strategy)
+        except ValueError:
+            strategy_index = 0
+        route_title = st.text_input(
+            "路由说明",
+            value=route_obj.title or "",
+            key=f"route_title_{selected_route}",
+        )
+        route_strategy = st.selectbox(
+            "推理策略",
+            strategy_choices,
+            index=strategy_index,
+            key=f"route_strategy_{selected_route}",
+        )
+        route_majority = st.number_input(
+            "多数投票门槛",
+            min_value=1,
+            max_value=10,
+            value=int(route_obj.majority_threshold or 1),
+            step=1,
+            key=f"route_majority_{selected_route}",
+        )
+        if not profile_keys:
+            st.warning("暂无可用 Profile,请先在下方创建。")
+        else:
+            try:
+                primary_index = profile_keys.index(route_obj.primary)
+            except ValueError:
+                primary_index = 0
+            primary_key = st.selectbox(
+                "主用 Profile",
+                profile_keys,
+                index=primary_index,
+                key=f"route_primary_{selected_route}",
             )
-            new_ensemble.append(endpoint)
+            default_ensemble = [
+                key for key in route_obj.ensemble if key in profile_keys and key != primary_key
+            ]
+            ensemble_keys = st.multiselect(
+                "协作 Profile (可多选)",
+                profile_keys,
+                default=default_ensemble,
+                key=f"route_ensemble_{selected_route}",
+            )
+            if st.button("保存路由设置", key=f"save_route_{selected_route}"):
+                route_obj.title = route_title.strip()
+                route_obj.strategy = route_strategy
+                route_obj.majority_threshold = int(route_majority)
+                route_obj.primary = primary_key
+                route_obj.ensemble = [key for key in ensemble_keys if key != primary_key]
+                cfg.llm_route = selected_route
+                cfg.sync_runtime_llm()
+                save_config()
+                LOGGER.info(
+                    "路由 %s 配置更新:%s",
+                    selected_route,
+                    route_obj.to_dict(),
+                    extra=LOG_EXTRA,
+                )
+                st.success("路由配置已保存。")
+                st.json({
+                    "route": selected_route,
+                    "route_detail": route_obj.to_dict(),
+                    "resolved": llm_config_snapshot(),
+                })
+        route_in_use = selected_route in used_routes or selected_route == cfg.llm_route
+        if st.button(
+            "删除当前路由",
+            key=f"delete_route_{selected_route}",
+            disabled=route_in_use or len(routes) <= 1,
+        ):
+            routes.pop(selected_route, None)
+            if cfg.llm_route == selected_route:
+                cfg.llm_route = next((key for key in routes.keys()), "global")
+            cfg.sync_runtime_llm()
+            save_config()
+            st.success("路由已删除。")
+            st.experimental_rerun()
 
-        llm_cfg.ensemble = new_ensemble
-        llm_cfg.strategy = selected_strategy
-        llm_cfg.majority_threshold = int(majority_threshold)
-        save_config()
-        LOGGER.info("LLM 配置已更新:%s", llm_config_snapshot(), extra=LOG_EXTRA)
-        st.success("LLM 设置已保存,仅在当前会话生效。")
-        st.json(llm_config_snapshot())
+    st.divider()
+    st.subheader("LLM Profile 管理")
+    profile_select_col, profile_manage_col = st.columns([3, 1])
+    if profile_keys:
+        selected_profile = profile_select_col.selectbox(
+            "选择 Profile",
+            profile_keys,
+            index=0,
+            key="profile_select",
+        )
+    else:
+        selected_profile = None
+        profile_select_col.info("尚未配置 Profile,请先创建。")
+
+    new_profile_name = profile_manage_col.text_input("新增 Profile", key="new_profile_name")
+    if profile_manage_col.button("创建 Profile"):
+        key = (new_profile_name or "").strip()
+        if not key:
+            st.warning("请输入有效的 Profile 名称。")
+        elif key in profiles:
+            st.warning(f"Profile {key} 已存在。")
+        else:
+            profiles[key] = LLMProfile(key=key)
+            save_config()
+            st.success(f"已创建 Profile {key}。")
+            st.experimental_rerun()
+
+    if selected_profile:
+        profile = profiles[selected_profile]
+        provider_choices = sorted(DEFAULT_LLM_MODEL_OPTIONS.keys())
+        try:
+            provider_index = provider_choices.index(profile.provider)
+        except ValueError:
+            provider_index = 0
+        with st.form(f"profile_form_{selected_profile}"):
+            provider_val = st.selectbox(
+                "Provider",
+                provider_choices,
+                index=provider_index,
+            )
+            model_default = DEFAULT_LLM_MODELS.get(provider_val, profile.model or "")
+            model_val = st.text_input(
+                "模型",
+                value=profile.model or model_default,
+            )
+            base_default = DEFAULT_LLM_BASE_URLS.get(provider_val, profile.base_url or "")
+            base_val = st.text_input(
+                "Base URL",
+                value=profile.base_url or base_default,
+            )
+            api_val = st.text_input(
+                "API Key",
+                value=profile.api_key or "",
+                type="password",
+            )
+            temp_val = st.slider(
+                "温度",
+                min_value=0.0,
+                max_value=2.0,
+                value=float(profile.temperature),
+                step=0.05,
+            )
+            timeout_val = st.number_input(
+                "超时(秒)",
+                min_value=5,
+                max_value=180,
+                value=int(profile.timeout or 30),
+                step=5,
+            )
+            title_val = st.text_input("备注", value=profile.title or "")
+            enabled_val = st.checkbox("启用", value=profile.enabled)
+            submitted = st.form_submit_button("保存 Profile")
+        if submitted:
+            profile.provider = provider_val
+            profile.model = model_val.strip() or DEFAULT_LLM_MODELS.get(provider_val)
+            profile.base_url = base_val.strip() or DEFAULT_LLM_BASE_URLS.get(provider_val)
+            profile.api_key = api_val.strip() or None
+            profile.temperature = temp_val
+            profile.timeout = timeout_val
+            profile.title = title_val.strip()
+            profile.enabled = enabled_val
+            profiles[selected_profile] = profile
+            cfg.sync_runtime_llm()
+            save_config()
+            st.success("Profile 已保存。")
+
+        profile_in_use = any(
+            selected_profile == route.primary or selected_profile in route.ensemble
+            for route in routes.values()
+        )
+        if st.button(
+            "删除该 Profile",
+            key=f"delete_profile_{selected_profile}",
+            disabled=profile_in_use or len(profiles) <= 1,
+        ):
+            profiles.pop(selected_profile, None)
+            fallback_key = next((key for key in profiles.keys()), None)
+            for route in routes.values():
+                if route.primary == selected_profile:
+                    route.primary = fallback_key or route.primary
+                route.ensemble = [key for key in route.ensemble if key != selected_profile]
+            cfg.sync_runtime_llm()
+            save_config()
+            st.success("Profile 已删除。")
+            st.experimental_rerun()
 
     st.divider()
     st.subheader("部门配置")
     dept_settings = cfg.departments or {}
+    route_options_display = [""] + route_keys
     dept_rows = [
         {
             "code": code,
             "title": dept.title,
             "description": dept.description,
             "weight": float(dept.weight),
+            "llm_route": dept.llm_route or "",
             "strategy": dept.llm.strategy,
-            "primary_provider": (dept.llm.primary.provider or "ollama"),
+            "primary_provider": (dept.llm.primary.provider or ""),
             "primary_model": dept.llm.primary.model or "",
             "ensemble_size": len(dept.llm.ensemble),
         }
@@ -572,20 +618,25 @@
             "title": st.column_config.TextColumn("名称"),
             "description": st.column_config.TextColumn("说明"),
             "weight": st.column_config.NumberColumn("权重", min_value=0.0, max_value=10.0, step=0.1),
+            "llm_route": st.column_config.SelectboxColumn(
+                "路由",
+                options=route_options_display,
+                help="选择预定义路由;留空表示使用自定义配置",
+            ),
             "strategy": st.column_config.SelectboxColumn(
-                "策略",
+                "自定义策略",
                 options=sorted(ALLOWED_LLM_STRATEGIES),
-                help="single=单模型, majority=多数投票, leader=顾问-决策者模式",
+                help="仅当未选择路由时生效",
             ),
             "primary_provider": st.column_config.SelectboxColumn(
-                "主模型 Provider",
+                "自定义 Provider",
                 options=sorted(DEFAULT_LLM_MODEL_OPTIONS.keys()),
             ),
-            "primary_model": st.column_config.TextColumn("主模型名称"),
+            "primary_model": st.column_config.TextColumn("自定义模型"),
             "ensemble_size": st.column_config.NumberColumn(
                 "协作模型数量",
                 disabled=True,
-                help="在 config.json 中编辑 ensemble 详情",
+                help="路由模式下自动维护",
            ),
         },
     )
@@ -609,31 +660,32 @@
         try:
             existing.weight = max(0.0, float(row.get("weight", existing.weight)))
         except (TypeError, ValueError):
-            existing.weight = existing.weight
+            pass
 
-        strategy_val = (row.get("strategy") or existing.llm.strategy).lower()
-        if strategy_val in ALLOWED_LLM_STRATEGIES:
-            existing.llm.strategy = strategy_val
-
-        provider_before = existing.llm.primary.provider or ""
-        provider_val = (row.get("primary_provider") or provider_before or "ollama").lower()
-        existing.llm.primary.provider = provider_val
-
-        model_val = (row.get("primary_model") or "").strip()
-        if model_val:
-            existing.llm.primary.model = model_val
+        route_name = (row.get("llm_route") or "").strip() or None
+        existing.llm_route = route_name
+        if route_name and route_name in routes:
+            existing.llm = routes[route_name].resolve(profiles)
         else:
-            existing.llm.primary.model = DEFAULT_LLM_MODELS.get(provider_val, existing.llm.primary.model)
-
-        if provider_before != provider_val:
-            default_base = DEFAULT_LLM_BASE_URLS.get(provider_val)
-            existing.llm.primary.base_url = default_base or existing.llm.primary.base_url
-
-        existing.llm.primary.__post_init__()
+            strategy_val = (row.get("strategy") or existing.llm.strategy).lower()
+            if strategy_val in ALLOWED_LLM_STRATEGIES:
+                existing.llm.strategy = strategy_val
+            provider_before = existing.llm.primary.provider or ""
+            provider_val = (row.get("primary_provider") or provider_before or "ollama").lower()
+            existing.llm.primary.provider = provider_val
+            model_val = (row.get("primary_model") or "").strip()
+            existing.llm.primary.model = (
+                model_val or DEFAULT_LLM_MODELS.get(provider_val, existing.llm.primary.model)
+            )
+            if provider_before != provider_val:
+                default_base = DEFAULT_LLM_BASE_URLS.get(provider_val)
+                existing.llm.primary.base_url = default_base or existing.llm.primary.base_url
+            existing.llm.primary.__post_init__()
         updated_departments[code] = existing
 
     if updated_departments:
         cfg.departments = updated_departments
+        cfg.sync_runtime_llm()
         save_config()
         st.success("部门配置已更新。")
     else:
@@ -643,10 +695,12 @@
         from app.utils.config import _default_departments
 
         cfg.departments = _default_departments()
+        cfg.sync_runtime_llm()
         save_config()
         st.success("已恢复默认部门配置。")
         st.experimental_rerun()
 
+    st.caption("选择路由可统一部门模型调用,自定义模式仍支持逐项配置。")
     st.caption("部门协作模型(ensemble)请在 config.json 中手动编辑,UI 将在后续版本补充。")
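
「保存路由设置」按钮背后的流程也可以脚本化复现,便于在无 UI 环境下维护配置。下面的草图假设默认路由 `global` 与默认 Profile `ollama` 存在:

```python
# 最小示意:以代码方式完成“保存路由设置”按钮的等价操作(不经 Streamlit)。
from app.utils.config import get_config, save_config

cfg = get_config()
route = cfg.llm_routes["global"]
route.strategy = "leader"        # 切换为顾问-决策者模式
route.primary = "ollama"
route.ensemble = []
cfg.llm_route = "global"
cfg.sync_runtime_llm()           # 将路由解析结果同步到 cfg.llm,供旧调用路径使用
save_config()                    # 持久化 llm_route / llm_profiles / llm_routes
```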
diff --git a/app/utils/config.py b/app/utils/config.py
index 6aa835e..d0b6222 100644
--- a/app/utils/config.py
+++ b/app/utils/config.py
@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
 import json
 import os
 from pathlib import Path
-from typing import Dict, List, Optional
+from typing import Dict, List, Mapping, Optional
 
 
 def _default_root() -> Path:
@@ -132,6 +132,135 @@ class LLMConfig:
     majority_threshold: int = 3
 
 
+@dataclass
+class LLMProfile:
+    """Named LLM endpoint profile reusable across routes/departments."""
+
+    key: str
+    provider: str = "ollama"
+    model: Optional[str] = None
+    base_url: Optional[str] = None
+    api_key: Optional[str] = None
+    temperature: float = 0.2
+    timeout: float = 30.0
+    title: str = ""
+    enabled: bool = True
+
+    def to_endpoint(self) -> LLMEndpoint:
+        return LLMEndpoint(
+            provider=self.provider,
+            model=self.model,
+            base_url=self.base_url,
+            api_key=self.api_key,
+            temperature=self.temperature,
+            timeout=self.timeout,
+        )
+
+    def to_dict(self) -> Dict[str, object]:
+        return {
+            "provider": self.provider,
+            "model": self.model,
+            "base_url": self.base_url,
+            "api_key": self.api_key,
+            "temperature": self.temperature,
+            "timeout": self.timeout,
+            "title": self.title,
+            "enabled": self.enabled,
+        }
+
+    @classmethod
+    def from_endpoint(
+        cls,
+        key: str,
+        endpoint: LLMEndpoint,
+        *,
+        title: str = "",
+        enabled: bool = True,
+    ) -> "LLMProfile":
+        return cls(
+            key=key,
+            provider=endpoint.provider,
+            model=endpoint.model,
+            base_url=endpoint.base_url,
+            api_key=endpoint.api_key,
+            temperature=endpoint.temperature,
+            timeout=endpoint.timeout,
+            title=title,
+            enabled=enabled,
+        )
+
+
+@dataclass
+class LLMRoute:
+    """Declarative routing for selecting profiles and strategy."""
+
+    name: str
+    title: str = ""
+    strategy: str = "single"
+    majority_threshold: int = 3
+    primary: str = "ollama"
+    ensemble: List[str] = field(default_factory=list)
+
+    def resolve(self, profiles: Mapping[str, LLMProfile]) -> LLMConfig:
+        def _endpoint_from_key(key: str) -> LLMEndpoint:
+            profile = profiles.get(key)
+            if profile and profile.enabled:
+                return profile.to_endpoint()
+            fallback = profiles.get("ollama")
+            if not fallback or not fallback.enabled:
+                fallback = next(
+                    (item for item in profiles.values() if item.enabled),
+                    None,
+                )
+            endpoint = fallback.to_endpoint() if fallback else LLMEndpoint()
+            endpoint.provider = key or endpoint.provider
+            return endpoint
+
+        primary_endpoint = _endpoint_from_key(self.primary)
+        ensemble_endpoints = [
+            _endpoint_from_key(key)
+            for key in self.ensemble
+            if key in profiles and profiles[key].enabled
+        ]
+        config = LLMConfig(
+            primary=primary_endpoint,
+            ensemble=ensemble_endpoints,
+            strategy=self.strategy if self.strategy in ALLOWED_LLM_STRATEGIES else "single",
+            majority_threshold=max(1, self.majority_threshold or 1),
+        )
+        return config
+
+    def to_dict(self) -> Dict[str, object]:
+        return {
+            "title": self.title,
+            "strategy": self.strategy,
+            "majority_threshold": self.majority_threshold,
+            "primary": self.primary,
+            "ensemble": list(self.ensemble),
+        }
+
+
+def _default_llm_profiles() -> Dict[str, LLMProfile]:
+    return {
+        provider: LLMProfile(
+            key=provider,
+            provider=provider,
+            model=DEFAULT_LLM_MODELS.get(provider),
+            base_url=DEFAULT_LLM_BASE_URLS.get(provider),
+            temperature=DEFAULT_LLM_TEMPERATURES.get(provider, 0.2),
+            timeout=DEFAULT_LLM_TIMEOUTS.get(provider, 30.0),
+            title=f"默认 {provider}",
+        )
+        for provider in DEFAULT_LLM_MODEL_OPTIONS
+    }
+
+
+def _default_llm_routes() -> Dict[str, LLMRoute]:
+    return {
+        "global": LLMRoute(name="global", title="全局默认路由"),
+    }
+
+
 @dataclass
 class DepartmentSettings:
     """Configuration for a single decision department."""
@@ -141,6 +270,7 @@
     description: str = ""
     weight: float = 1.0
     llm: LLMConfig = field(default_factory=LLMConfig)
+    llm_route: Optional[str] = None
 
 
 def _default_departments() -> Dict[str, DepartmentSettings]:
@@ -153,7 +283,7 @@
         ("risk", "风险控制部门"),
     ]
     return {
-        code: DepartmentSettings(code=code, title=title)
+        code: DepartmentSettings(code=code, title=title, llm_route="global")
         for code, title in presets
     }
@@ -169,8 +299,21 @@ class AppConfig:
     agent_weights: AgentWeights = field(default_factory=AgentWeights)
     force_refresh: bool = False
     llm: LLMConfig = field(default_factory=LLMConfig)
+    llm_route: str = "global"
+    llm_profiles: Dict[str, LLMProfile] = field(default_factory=_default_llm_profiles)
+    llm_routes: Dict[str, LLMRoute] = field(default_factory=_default_llm_routes)
     departments: Dict[str, DepartmentSettings] = field(default_factory=_default_departments)
 
+    def resolve_llm(self, route: Optional[str] = None) -> LLMConfig:
+        route_key = route or self.llm_route
+        route_cfg = self.llm_routes.get(route_key)
+        if route_cfg:
+            return route_cfg.resolve(self.llm_profiles)
+        return self.llm
+
+    def sync_runtime_llm(self) -> None:
+        self.llm = self.resolve_llm()
+
 
 CONFIG = AppConfig()
@@ -213,11 +356,69 @@
     if "decision_method" in payload:
         cfg.decision_method = str(payload.get("decision_method") or cfg.decision_method)
 
+    routes_defined = False
+    inline_primary_loaded = False
+
+    profiles_payload = payload.get("llm_profiles")
+    if isinstance(profiles_payload, dict):
+        profiles: Dict[str, LLMProfile] = {}
+        for key, data in profiles_payload.items():
+            if not isinstance(data, dict):
+                continue
+            provider = str(data.get("provider") or "ollama").lower()
+            profile = LLMProfile(
+                key=key,
+                provider=provider,
+                model=data.get("model"),
+                base_url=data.get("base_url"),
+                api_key=data.get("api_key"),
+                temperature=float(data.get("temperature", DEFAULT_LLM_TEMPERATURES.get(provider, 0.2))),
+                timeout=float(data.get("timeout", DEFAULT_LLM_TIMEOUTS.get(provider, 30.0))),
+                title=str(data.get("title") or ""),
+                enabled=bool(data.get("enabled", True)),
+            )
+            profiles[key] = profile
+        if profiles:
+            cfg.llm_profiles = profiles
+
+    routes_payload = payload.get("llm_routes")
+    if isinstance(routes_payload, dict):
+        routes: Dict[str, LLMRoute] = {}
+        for name, data in routes_payload.items():
+            if not isinstance(data, dict):
+                continue
+            strategy_raw = str(data.get("strategy") or "single").lower()
+            normalized = LLM_STRATEGY_ALIASES.get(strategy_raw, strategy_raw)
+            route = LLMRoute(
+                name=name,
+                title=str(data.get("title") or ""),
+                strategy=normalized if normalized in ALLOWED_LLM_STRATEGIES else "single",
+                majority_threshold=max(1, int(data.get("majority_threshold", 3) or 3)),
+                primary=str(data.get("primary") or "global"),
+                ensemble=[
+                    str(item)
+                    for item in data.get("ensemble", [])
+                    if isinstance(item, str)
+                ],
+            )
+            routes[name] = route
+        if routes:
+            cfg.llm_routes = routes
+            routes_defined = True
+
+    route_key = payload.get("llm_route")
+    if isinstance(route_key, str) and route_key:
+        cfg.llm_route = route_key
+
     llm_payload = payload.get("llm")
     if isinstance(llm_payload, dict):
+        route_value = llm_payload.get("route")
+        if isinstance(route_value, str) and route_value:
+            cfg.llm_route = route_value
         primary_data = llm_payload.get("primary")
         if isinstance(primary_data, dict):
             cfg.llm.primary = _dict_to_endpoint(primary_data)
+            inline_primary_loaded = True
 
         ensemble_data = llm_payload.get("ensemble")
         if isinstance(ensemble_data, list):
@@ -237,6 +438,30 @@
     if isinstance(majority, int) and majority > 0:
         cfg.llm.majority_threshold = majority
 
+    if inline_primary_loaded and not routes_defined:
+        primary_key = "inline_global_primary"
+        cfg.llm_profiles[primary_key] = LLMProfile.from_endpoint(
+            primary_key,
+            cfg.llm.primary,
+            title="全局主模型",
+        )
+        ensemble_keys: List[str] = []
+        for idx, endpoint in enumerate(cfg.llm.ensemble, start=1):
+            inline_key = f"inline_global_ensemble_{idx}"
+            cfg.llm_profiles[inline_key] = LLMProfile.from_endpoint(
+                inline_key,
+                endpoint,
+                title=f"全局协作#{idx}",
+            )
+            ensemble_keys.append(inline_key)
+        auto_route = cfg.llm_routes.get("global") or LLMRoute(name="global", title="全局默认路由")
+        auto_route.strategy = cfg.llm.strategy
+        auto_route.majority_threshold = cfg.llm.majority_threshold
+        auto_route.primary = primary_key
+        auto_route.ensemble = ensemble_keys
+        cfg.llm_routes["global"] = auto_route
+        cfg.llm_route = cfg.llm_route or "global"
+
     departments_payload = payload.get("departments")
     if isinstance(departments_payload, dict):
         new_departments: Dict[str, DepartmentSettings] = {}
@@ -264,35 +489,55 @@
             majority_raw = llm_data.get("majority_threshold")
             if isinstance(majority_raw, int) and majority_raw > 0:
                 llm_cfg.majority_threshold = majority_raw
+            route = data.get("llm_route")
+            route_name = str(route).strip() if isinstance(route, str) and route else None
+            resolved = llm_cfg
+            if route_name and route_name in cfg.llm_routes:
+                resolved = cfg.llm_routes[route_name].resolve(cfg.llm_profiles)
             new_departments[code] = DepartmentSettings(
                 code=code,
                 title=title,
                 description=description,
                 weight=weight,
-                llm=llm_cfg,
+                llm=resolved,
+                llm_route=route_name,
             )
         if new_departments:
             cfg.departments = new_departments
 
+    cfg.sync_runtime_llm()
+
 
 def save_config(cfg: AppConfig | None = None) -> None:
     cfg = cfg or CONFIG
+    cfg.sync_runtime_llm()
     path = cfg.data_paths.config_file
     payload = {
         "tushare_token": cfg.tushare_token,
         "force_refresh": cfg.force_refresh,
         "decision_method": cfg.decision_method,
+        "llm_route": cfg.llm_route,
         "llm": {
+            "route": cfg.llm_route,
             "strategy": cfg.llm.strategy if cfg.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
            "majority_threshold": cfg.llm.majority_threshold,
             "primary": _endpoint_to_dict(cfg.llm.primary),
             "ensemble": [_endpoint_to_dict(ep) for ep in cfg.llm.ensemble],
         },
+        "llm_profiles": {
+            key: profile.to_dict()
+            for key, profile in cfg.llm_profiles.items()
+        },
+        "llm_routes": {
+            name: route.to_dict()
+            for name, route in cfg.llm_routes.items()
+        },
         "departments": {
             code: {
                 "title": dept.title,
                 "description": dept.description,
                 "weight": dept.weight,
+                "llm_route": dept.llm_route,
                 "llm": {
                     "strategy": dept.llm.strategy if dept.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
                     "majority_threshold": dept.llm.majority_threshold,
@@ -320,7 +565,15 @@ def _load_env_defaults(cfg: AppConfig) -> None:
 
     api_key = os.getenv("LLM_API_KEY")
     if api_key:
-        cfg.llm.primary.api_key = api_key.strip()
+        sanitized = api_key.strip()
+        cfg.llm.primary.api_key = sanitized
+        route = cfg.llm_routes.get(cfg.llm_route)
+        if route:
+            profile = cfg.llm_profiles.get(route.primary)
+            if profile:
+                profile.api_key = sanitized
+
+    cfg.sync_runtime_llm()
 
 
_load_from_file(CONFIG)
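
按上述 `save_config()` 的实现,config.json 的顶层结构大致如下(以 Python 字面量示意,取值仅作说明,实际内容与字段细节以运行环境为准):

```python
# 示意:save_config() 写入 app/data/config.json 的顶层结构。
payload = {
    "tushare_token": "...",
    "force_refresh": False,
    "decision_method": "...",          # 示例占位
    "llm_route": "global",
    "llm": {                            # 兼容字段:当前路由解析后的运行时配置
        "route": "global",
        "strategy": "single",
        "majority_threshold": 3,
        "primary": {"provider": "ollama", "model": "...", "api_key": None},
        "ensemble": [],
    },
    "llm_profiles": {
        "ollama": {"provider": "ollama", "model": "...", "enabled": True},
    },
    "llm_routes": {
        "global": {
            "title": "全局默认路由",
            "strategy": "single",
            "majority_threshold": 3,
            "primary": "ollama",
            "ensemble": [],
        },
    },
    "departments": {
        "risk": {
            "title": "风险控制部门",
            "description": "",
            "weight": 1.0,
            "llm_route": "global",
            "llm": {"strategy": "single", "majority_threshold": 3},
        },
    },
}
```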