This commit is contained in:
sam 2025-09-28 11:28:55 +08:00
parent 6aece20816
commit ab6180646a
5 changed files with 574 additions and 575 deletions

View File

@ -24,7 +24,7 @@
- **统一日志与持久化**SQLite 统一存储行情、回测与日志,配合 `DatabaseLogHandler` 在 UI/抓数流程中输出结构化运行轨迹,支持快速追踪与复盘。
- **跨市场数据扩展**`app/ingest/tushare.py` 追加指数、ETF/公募基金、期货、外汇、港股与美股的增量拉取逻辑,确保多资产因子与宏观代理所需的行情基础数据齐备。
- **部门化多模型协作**`app/agents/departments.py` 封装部门级 LLM 调度,`app/llm/client.py` 支持 single/majority/leader 策略,部门结论在 `app/agents/game.py` 与六类基础代理共同博弈,并持久化至 `agent_utils` 供 UI 展示。
- **LLM Profile/Route 管理**`app/utils/config.py` 引入可复用的 Profile终端定义与 Route推理策略组合Streamlit UI 支持可视化维护,全局与部门均可复用命名路由提升配置一致性
- **LLM Provider 管理**`app/utils/config.py` 集中维护供应商的 URL、API Key、可用模型及默认参数Streamlit UI 可视化配置,全局与部门直接在 Provider 基础上设置模型、温度与 Prompt
## LLM + 多智能体最佳实践
@ -60,11 +60,10 @@ export TUSHARE_TOKEN="<your-token>"
### LLM 配置与测试
- 新增 Profile/Route 双层配置Profile 定义单个端点(含 Provider/模型/域名/API KeyRoute 组合 Profile 并指定推理策略single/majority/leader。全局路由可一键切换部门可复用命名路由或保留自定义设置。
- Streamlit “数据与设置” 页通过表单管理 Profile、Route、全局路由保存即写入 `app/data/config.json`Route 预览会同步展示经 `llm_config_snapshot()` 脱敏后的实时配置。
- 支持本地 Ollama 与多家 OpenAI 兼容供应商DeepSeek、文心一言、OpenAI 等),可为不同 Profile 设置默认模型、温度、超时与启用状态。
- UI 保留 TuShare Token 维护,以及路由/Profile 新增、删除、禁用等操作;所有更新即时生效并记入日志。
- 使用环境变量注入敏感信息时,可配置:`TUSHARE_TOKEN`、`LLM_API_KEY`,加载后会同步至当前路由的主 Profile。
- 通过 Provider 管理供应商连接参数Base URL、API Key、模型列表、默认温度/超时/Prompt 模板),可随时扩展本地 Ollama 或各类云端服务DeepSeek、文心一言、OpenAI 等)。
- 全局与部门配置直接选择 Provider并根据需要覆盖模型、温度、Prompt 模板、投票策略;保存后写入 `app/data/config.json`,下次启动自动加载。
- Streamlit “数据与设置” 页提供 Provider/全局/部门三栏编辑界面,保存后即时生效,并通过 `llm_config_snapshot()` 输出脱敏检查信息。
- 支持使用环境变量注入敏感信息:`TUSHARE_TOKEN`、`LLM_API_KEY`。
## 快速开始
@ -104,7 +103,7 @@ Streamlit `自检测试` 页签提供:
## 实施步骤
1. **配置扩展** (`app/utils/config.py` + `config.json`) ✅
- 引入 `llm_profiles`/`llm_routes` 统一管理终端与策略,部门可复用路由或使用自定义配置Streamlit 提供可视化维护表单。
- 引入 `llm_providers` 集中管理供应商参数,全局与部门直接绑定 Provider 并自定义模型/温度/PromptStreamlit 提供可视化维护表单。
2. **部门管控器**
- `app/agents/departments.py` 提供 `DepartmentAgent`/`DepartmentManager`,封装 Prompt 构建、多模型协商及异常回退。

View File

@ -131,8 +131,6 @@ class DepartmentManager:
return results
def _resolve_llm(self, settings: DepartmentSettings) -> LLMConfig:
if settings.llm_route and settings.llm_route in self.config.llm_routes:
return self.config.llm_routes[settings.llm_route].resolve(self.config.llm_profiles)
return settings.llm

View File

@ -102,24 +102,58 @@ def _request_openai(
def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> str:
provider = (endpoint.provider or "ollama").lower()
base_url = endpoint.base_url or _default_base_url(provider)
model = endpoint.model or _default_model(provider)
temperature = max(0.0, min(endpoint.temperature, 2.0))
timeout = max(5.0, endpoint.timeout or 30.0)
cfg = get_config()
provider_key = (endpoint.provider or "ollama").lower()
provider_cfg = cfg.llm_providers.get(provider_key)
base_url = endpoint.base_url
api_key = endpoint.api_key
model = endpoint.model
temperature = endpoint.temperature
timeout = endpoint.timeout
prompt_template = endpoint.prompt_template
if provider_cfg:
if not provider_cfg.enabled:
raise LLMError(f"Provider {provider_key} 已被禁用")
base_url = base_url or provider_cfg.base_url or _default_base_url(provider_key)
api_key = api_key or provider_cfg.api_key
model = model or provider_cfg.default_model or (provider_cfg.models[0] if provider_cfg.models else _default_model(provider_key))
if temperature is None:
temperature = provider_cfg.default_temperature
if timeout is None:
timeout = provider_cfg.default_timeout
prompt_template = prompt_template or (provider_cfg.prompt_template or None)
mode = provider_cfg.mode or ("ollama" if provider_key == "ollama" else "openai")
else:
base_url = base_url or _default_base_url(provider_key)
model = model or _default_model(provider_key)
if temperature is None:
temperature = DEFAULT_LLM_TEMPERATURES.get(provider_key, 0.2)
if timeout is None:
timeout = DEFAULT_LLM_TIMEOUTS.get(provider_key, 30.0)
mode = "ollama" if provider_key == "ollama" else "openai"
temperature = max(0.0, min(float(temperature), 2.0))
timeout = max(5.0, float(timeout))
if prompt_template:
try:
prompt = prompt_template.format(prompt=prompt)
except Exception: # noqa: BLE001
LOGGER.warning("Prompt 模板格式化失败,使用原始 prompt", extra=LOG_EXTRA)
LOGGER.info(
"触发 LLM 请求provider=%s model=%s base=%s",
provider,
provider_key,
model,
base_url,
extra=LOG_EXTRA,
)
if provider in {"openai", "deepseek", "wenxin"}:
api_key = endpoint.api_key
if mode != "ollama":
if not api_key:
raise LLMError(f"缺少 {provider} API Key (model={model})")
raise LLMError(f"缺少 {provider_key} API Key (model={model})")
return _request_openai(
model,
prompt,
@ -129,7 +163,7 @@ def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) ->
timeout=timeout,
system=system,
)
if provider == "ollama":
if base_url:
return _request_ollama(
model,
prompt,
@ -283,13 +317,17 @@ def llm_config_snapshot() -> Dict[str, object]:
if record.get("api_key"):
record["api_key"] = "***"
ensemble.append(record)
route_name = cfg.llm_route
route_obj = cfg.llm_routes.get(route_name)
return {
"route": route_name,
"route_detail": route_obj.to_dict() if route_obj else None,
"strategy": settings.strategy,
"majority_threshold": settings.majority_threshold,
"primary": primary,
"ensemble": ensemble,
"providers": {
key: {
"base_url": provider.base_url,
"default_model": provider.default_model,
"enabled": provider.enabled,
}
for key, provider in cfg.llm_providers.items()
},
}

View File

@ -5,7 +5,7 @@ import sys
from dataclasses import asdict
from datetime import date, timedelta
from pathlib import Path
from typing import Dict, List
from typing import Dict, List, Optional
ROOT = Path(__file__).resolve().parents[2]
if str(ROOT) not in sys.path:
@ -16,6 +16,8 @@ import json
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import requests
from requests.exceptions import RequestException
import streamlit as st
from app.backtest.engine import BtConfig, run_backtest
@ -29,8 +31,8 @@ from app.utils.config import (
DEFAULT_LLM_MODEL_OPTIONS,
DEFAULT_LLM_MODELS,
DepartmentSettings,
LLMProfile,
LLMRoute,
LLMEndpoint,
LLMProvider,
get_config,
save_config,
)
@ -42,6 +44,50 @@ LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "ui"}
def _discover_provider_models(provider: LLMProvider, base_override: str = "", api_override: Optional[str] = None) -> tuple[list[str], Optional[str]]:
"""Attempt to query provider API and return available model ids."""
base_url = (base_override or provider.base_url or DEFAULT_LLM_BASE_URLS.get(provider.key, "")).strip()
if not base_url:
return [], "请先填写 Base URL"
timeout = float(provider.default_timeout or 30.0)
mode = provider.mode or ("ollama" if provider.key == "ollama" else "openai")
try:
if mode == "ollama":
url = base_url.rstrip('/') + "/api/tags"
response = requests.get(url, timeout=timeout)
response.raise_for_status()
data = response.json()
models = []
for item in data.get("models", []) or data.get("data", []):
name = item.get("name") or item.get("model") or item.get("tag")
if name:
models.append(str(name).strip())
return sorted(set(models)), None
api_key = (api_override or provider.api_key or "").strip()
if not api_key:
return [], "缺少 API Key"
url = base_url.rstrip('/') + "/v1/models"
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
response = requests.get(url, headers=headers, timeout=timeout)
response.raise_for_status()
payload = response.json()
models = [
str(item.get("id")).strip()
for item in payload.get("data", [])
if item.get("id")
]
return sorted(set(models)), None
except RequestException as exc: # noqa: BLE001
return [], f"HTTP 错误:{exc}"
except Exception as exc: # noqa: BLE001
return [], f"解析失败:{exc}"
def _load_stock_options(limit: int = 500) -> list[str]:
try:
with db_session(read_only=True) as conn:
@ -350,255 +396,267 @@ def render_settings() -> None:
st.divider()
st.subheader("LLM 设置")
profiles = cfg.llm_profiles or {}
routes = cfg.llm_routes or {}
profile_keys = sorted(profiles.keys())
route_keys = sorted(routes.keys())
used_routes = {
dept.llm_route for dept in cfg.departments.values() if dept.llm_route
}
st.caption("Profile 定义单个模型终端Route 负责组合 Profile 与推理策略。")
providers = cfg.llm_providers
provider_keys = sorted(providers.keys())
st.caption("先在 Provider 中维护基础连接URL、Key、模型再为全局与各部门设置个性化参数。")
route_select_col, route_manage_col = st.columns([3, 1])
if route_keys:
# Provider management -------------------------------------------------
provider_select_col, provider_manage_col = st.columns([3, 1])
if provider_keys:
try:
active_index = route_keys.index(cfg.llm_route)
default_provider = cfg.llm.primary.provider or provider_keys[0]
provider_index = provider_keys.index(default_provider)
except ValueError:
active_index = 0
selected_route = route_select_col.selectbox(
"全局路由",
route_keys,
index=active_index,
key="llm_route_select",
provider_index = 0
selected_provider = provider_select_col.selectbox(
"选择 Provider",
provider_keys,
index=provider_index,
key="llm_provider_select",
)
else:
selected_route = None
route_select_col.info("尚未配置路由,请先创建。")
selected_provider = None
provider_select_col.info("尚未配置 Provider,请先创建。")
new_route_name = route_manage_col.text_input("新增路由", key="new_route_name")
if route_manage_col.button("添加路由"):
key = (new_route_name or "").strip()
new_provider_name = provider_manage_col.text_input("新增 Provider", key="new_provider_name")
if provider_manage_col.button("创建 Provider", key="create_provider_btn"):
key = (new_provider_name or "").strip().lower()
if not key:
st.warning("请输入有效的路由名称。")
elif key in routes:
st.warning(f"路由 {key} 已存在。")
st.warning("请输入有效的 Provider 名称。")
elif key in providers:
st.warning(f"Provider {key} 已存在。")
else:
routes[key] = LLMRoute(name=key)
if not selected_route:
selected_route = key
cfg.llm_route = key
providers[key] = LLMProvider(key=key)
cfg.llm_providers = providers
save_config()
st.success(f"添加路由 {key},请继续配置")
st.success(f"已创建 Provider {key}")
st.experimental_rerun()
if selected_route:
route_obj = routes.get(selected_route)
if route_obj is None:
route_obj = LLMRoute(name=selected_route)
routes[selected_route] = route_obj
strategy_choices = sorted(ALLOWED_LLM_STRATEGIES)
if selected_provider:
provider_cfg = providers.get(selected_provider, LLMProvider(key=selected_provider))
title_key = f"provider_title_{selected_provider}"
base_key = f"provider_base_{selected_provider}"
api_key_key = f"provider_api_{selected_provider}"
models_key = f"provider_models_{selected_provider}"
default_model_key = f"provider_default_model_{selected_provider}"
mode_key = f"provider_mode_{selected_provider}"
temp_key = f"provider_temp_{selected_provider}"
timeout_key = f"provider_timeout_{selected_provider}"
prompt_key = f"provider_prompt_{selected_provider}"
enabled_key = f"provider_enabled_{selected_provider}"
title_val = st.text_input("备注名称", value=provider_cfg.title or "", key=title_key)
base_val = st.text_input("Base URL", value=provider_cfg.base_url or "", key=base_key, help="调用地址例如https://api.openai.com")
api_val = st.text_input("API Key", value=provider_cfg.api_key or "", key=api_key_key, type="password")
models_val = st.text_area("可用模型(每行一个)", value="\n".join(provider_cfg.models), key=models_key, height=100)
default_model_val = st.text_input("默认模型", value=provider_cfg.default_model or "", key=default_model_key)
mode_val = st.selectbox("调用模式", ["openai", "ollama"], index=0 if provider_cfg.mode == "openai" else 1, key=mode_key)
temp_val = st.slider("默认温度", min_value=0.0, max_value=2.0, value=float(provider_cfg.default_temperature), step=0.05, key=temp_key)
timeout_val = st.number_input("默认超时(秒)", min_value=5, max_value=300, value=int(provider_cfg.default_timeout or 30), step=5, key=timeout_key)
prompt_template_val = st.text_area("默认 Prompt 模板(可选,使用 {prompt} 占位)", value=provider_cfg.prompt_template or "", key=prompt_key, height=120)
enabled_val = st.checkbox("启用", value=provider_cfg.enabled, key=enabled_key)
fetch_key = f"fetch_models_{selected_provider}"
if st.button("获取模型列表", key=fetch_key):
with st.spinner("正在获取模型列表..."):
models, error = _discover_provider_models(provider_cfg, base_val, api_val)
if error:
st.error(error)
else:
provider_cfg.models = models
if models and (not provider_cfg.default_model or provider_cfg.default_model not in models):
provider_cfg.default_model = models[0]
providers[selected_provider] = provider_cfg
cfg.llm_providers = providers
cfg.sync_runtime_llm()
save_config()
st.success(f"共获取 {len(models)} 个模型。")
st.session_state[models_key] = "\n".join(models)
st.session_state[default_model_key] = provider_cfg.default_model or ""
if st.button("保存 Provider", key=f"save_provider_{selected_provider}"):
provider_cfg.title = title_val.strip()
provider_cfg.base_url = base_val.strip()
provider_cfg.api_key = api_val.strip() or None
provider_cfg.models = [line.strip() for line in models_val.splitlines() if line.strip()]
provider_cfg.default_model = default_model_val.strip() or (provider_cfg.models[0] if provider_cfg.models else provider_cfg.default_model)
provider_cfg.default_temperature = float(temp_val)
provider_cfg.default_timeout = float(timeout_val)
provider_cfg.prompt_template = prompt_template_val.strip()
provider_cfg.enabled = enabled_val
provider_cfg.mode = mode_val
providers[selected_provider] = provider_cfg
cfg.llm_providers = providers
cfg.sync_runtime_llm()
save_config()
st.success("Provider 已保存。")
st.session_state[title_key] = provider_cfg.title or ""
st.session_state[default_model_key] = provider_cfg.default_model or ""
provider_in_use = (cfg.llm.primary.provider == selected_provider) or any(
ep.provider == selected_provider for ep in cfg.llm.ensemble
)
if not provider_in_use:
for dept in cfg.departments.values():
if dept.llm.primary.provider == selected_provider or any(ep.provider == selected_provider for ep in dept.llm.ensemble):
provider_in_use = True
break
if st.button(
"删除 Provider",
key=f"delete_provider_{selected_provider}",
disabled=provider_in_use or len(providers) <= 1,
):
providers.pop(selected_provider, None)
cfg.llm_providers = providers
cfg.sync_runtime_llm()
save_config()
st.success("Provider 已删除。")
st.experimental_rerun()
st.markdown("##### 全局推理配置")
if not provider_keys:
st.warning("请先配置至少一个 Provider。")
else:
global_cfg = cfg.llm
primary = global_cfg.primary
try:
strategy_index = strategy_choices.index(route_obj.strategy)
provider_index = provider_keys.index(primary.provider or provider_keys[0])
except ValueError:
strategy_index = 0
route_title = st.text_input(
"路由说明",
value=route_obj.title or "",
key=f"route_title_{selected_route}",
provider_index = 0
selected_global_provider = st.selectbox(
"主模型 Provider",
provider_keys,
index=provider_index,
key="global_provider_select",
)
route_strategy = st.selectbox(
"推理策略",
strategy_choices,
index=strategy_index,
key=f"route_strategy_{selected_route}",
provider_cfg = providers.get(selected_global_provider)
available_models = provider_cfg.models if provider_cfg else []
default_model = primary.model or (provider_cfg.default_model if provider_cfg else None)
if available_models:
options = available_models + ["自定义"]
try:
model_index = available_models.index(default_model)
model_choice = st.selectbox("主模型", options, index=model_index, key="global_model_choice")
except ValueError:
model_choice = st.selectbox("主模型", options, index=len(options) - 1, key="global_model_choice")
if model_choice == "自定义":
model_val = st.text_input("自定义模型", value=default_model or "", key="global_model_custom").strip()
else:
model_val = model_choice
else:
model_val = st.text_input("主模型", value=default_model or "", key="global_model_custom").strip()
temp_default = primary.temperature if primary.temperature is not None else (provider_cfg.default_temperature if provider_cfg else 0.2)
temp_val = st.slider("主模型温度", min_value=0.0, max_value=2.0, value=float(temp_default), step=0.05, key="global_temp")
timeout_default = primary.timeout if primary.timeout is not None else (provider_cfg.default_timeout if provider_cfg else 30.0)
timeout_val = st.number_input("主模型超时(秒)", min_value=5, max_value=300, value=int(timeout_default), step=5, key="global_timeout")
prompt_template_val = st.text_area(
"主模型 Prompt 模板(可选)",
value=primary.prompt_template or provider_cfg.prompt_template if provider_cfg else "",
height=120,
key="global_prompt_template",
)
route_majority = st.number_input(
strategy_val = st.selectbox("推理策略", sorted(ALLOWED_LLM_STRATEGIES), index=sorted(ALLOWED_LLM_STRATEGIES).index(global_cfg.strategy) if global_cfg.strategy in ALLOWED_LLM_STRATEGIES else 0, key="global_strategy")
show_ensemble = strategy_val != "single"
majority_threshold_val = st.number_input(
"多数投票门槛",
min_value=1,
max_value=10,
value=int(route_obj.majority_threshold or 1),
value=int(global_cfg.majority_threshold),
step=1,
key=f"route_majority_{selected_route}",
key="global_majority",
disabled=not show_ensemble,
)
if not profile_keys:
st.warning("暂无可用 Profile请先在下方创建。")
if not show_ensemble:
majority_threshold_val = 1
ensemble_rows: List[Dict[str, str]] = []
if show_ensemble:
ensemble_rows = [
{
"provider": ep.provider,
"model": ep.model or "",
"temperature": "" if ep.temperature is None else f"{ep.temperature:.3f}",
"timeout": "" if ep.timeout is None else str(int(ep.timeout)),
"prompt_template": ep.prompt_template or "",
}
for ep in global_cfg.ensemble
] or [{"provider": primary.provider or selected_global_provider, "model": "", "temperature": "", "timeout": "", "prompt_template": ""}]
ensemble_editor = st.data_editor(
ensemble_rows,
num_rows="dynamic",
key="global_ensemble_editor",
use_container_width=True,
hide_index=True,
column_config={
"provider": st.column_config.SelectboxColumn("Provider", options=provider_keys),
"model": st.column_config.TextColumn("模型"),
"temperature": st.column_config.TextColumn("温度"),
"timeout": st.column_config.TextColumn("超时(秒)"),
"prompt_template": st.column_config.TextColumn("Prompt 模板"),
},
)
if hasattr(ensemble_editor, "to_dict"):
ensemble_rows = ensemble_editor.to_dict("records")
else:
try:
primary_index = profile_keys.index(route_obj.primary)
except ValueError:
primary_index = 0
primary_key = st.selectbox(
"主用 Profile",
profile_keys,
index=primary_index,
key=f"route_primary_{selected_route}",
)
default_ensemble = [
key for key in route_obj.ensemble if key in profile_keys and key != primary_key
]
ensemble_keys = st.multiselect(
"协作 Profile (可多选)",
profile_keys,
default=default_ensemble,
key=f"route_ensemble_{selected_route}",
)
if st.button("保存路由设置", key=f"save_route_{selected_route}"):
route_obj.title = route_title.strip()
route_obj.strategy = route_strategy
route_obj.majority_threshold = int(route_majority)
route_obj.primary = primary_key
route_obj.ensemble = [key for key in ensemble_keys if key != primary_key]
cfg.llm_route = selected_route
cfg.sync_runtime_llm()
save_config()
LOGGER.info(
"路由 %s 配置更新:%s",
selected_route,
route_obj.to_dict(),
extra=LOG_EXTRA,
)
st.success("路由配置已保存。")
st.json({
"route": selected_route,
"route_detail": route_obj.to_dict(),
"resolved": llm_config_snapshot(),
})
route_in_use = selected_route in used_routes or selected_route == cfg.llm_route
if st.button(
"删除当前路由",
key=f"delete_route_{selected_route}",
disabled=route_in_use or len(routes) <= 1,
):
routes.pop(selected_route, None)
if cfg.llm_route == selected_route:
cfg.llm_route = next((key for key in routes.keys()), "global")
cfg.sync_runtime_llm()
save_config()
st.success("路由已删除。")
st.experimental_rerun()
st.divider()
st.subheader("LLM Profile 管理")
profile_select_col, profile_manage_col = st.columns([3, 1])
if profile_keys:
selected_profile = profile_select_col.selectbox(
"选择 Profile",
profile_keys,
index=0,
key="profile_select",
)
ensemble_rows = ensemble_editor
else:
selected_profile = None
profile_select_col.info("尚未配置 Profile请先创建。")
st.info("当前策略为单模型,未启用协作模型。")
new_profile_name = profile_manage_col.text_input("新增 Profile", key="new_profile_name")
if profile_manage_col.button("创建 Profile"):
key = (new_profile_name or "").strip()
if not key:
st.warning("请输入有效的 Profile 名称。")
elif key in profiles:
st.warning(f"Profile {key} 已存在。")
else:
profiles[key] = LLMProfile(key=key)
save_config()
st.success(f"已创建 Profile {key}")
st.experimental_rerun()
if st.button("保存全局配置", key="save_global_llm"):
primary.provider = selected_global_provider
primary.model = model_val or None
primary.temperature = float(temp_val)
primary.timeout = float(timeout_val)
primary.prompt_template = prompt_template_val.strip() or None
primary.base_url = None
primary.api_key = None
if selected_profile:
profile = profiles[selected_profile]
provider_choices = sorted(DEFAULT_LLM_MODEL_OPTIONS.keys())
try:
provider_index = provider_choices.index(profile.provider)
except ValueError:
provider_index = 0
with st.form(f"profile_form_{selected_profile}"):
provider_val = st.selectbox(
"Provider",
provider_choices,
index=provider_index,
new_ensemble: List[LLMEndpoint] = []
if show_ensemble:
for row in ensemble_rows:
provider_val = (row.get("provider") or "").strip().lower()
if not provider_val:
continue
model_raw = (row.get("model") or "").strip() or None
temp_raw = (row.get("temperature") or "").strip()
timeout_raw = (row.get("timeout") or "").strip()
prompt_raw = (row.get("prompt_template") or "").strip()
new_ensemble.append(
LLMEndpoint(
provider=provider_val,
model=model_raw,
temperature=float(temp_raw) if temp_raw else None,
timeout=float(timeout_raw) if timeout_raw else None,
prompt_template=prompt_raw or None,
)
model_default = DEFAULT_LLM_MODELS.get(provider_val, profile.model or "")
model_val = st.text_input(
"模型",
value=profile.model or model_default,
)
base_default = DEFAULT_LLM_BASE_URLS.get(provider_val, profile.base_url or "")
base_val = st.text_input(
"Base URL",
value=profile.base_url or base_default,
)
api_val = st.text_input(
"API Key",
value=profile.api_key or "",
type="password",
)
temp_val = st.slider(
"温度",
min_value=0.0,
max_value=2.0,
value=float(profile.temperature),
step=0.05,
)
timeout_val = st.number_input(
"超时(秒)",
min_value=5,
max_value=180,
value=int(profile.timeout or 30),
step=5,
)
title_val = st.text_input("备注", value=profile.title or "")
enabled_val = st.checkbox("启用", value=profile.enabled)
submitted = st.form_submit_button("保存 Profile")
if submitted:
profile.provider = provider_val
profile.model = model_val.strip() or DEFAULT_LLM_MODELS.get(provider_val)
profile.base_url = base_val.strip() or DEFAULT_LLM_BASE_URLS.get(provider_val)
profile.api_key = api_val.strip() or None
profile.temperature = temp_val
profile.timeout = timeout_val
profile.title = title_val.strip()
profile.enabled = enabled_val
profiles[selected_profile] = profile
cfg.llm.ensemble = new_ensemble
cfg.llm.strategy = strategy_val
cfg.llm.majority_threshold = int(majority_threshold_val)
cfg.sync_runtime_llm()
save_config()
st.success("Profile 已保存。")
profile_in_use = any(
selected_profile == route.primary or selected_profile in route.ensemble
for route in routes.values()
)
if st.button(
"删除该 Profile",
key=f"delete_profile_{selected_profile}",
disabled=profile_in_use or len(profiles) <= 1,
):
profiles.pop(selected_profile, None)
fallback_key = next((key for key in profiles.keys()), None)
for route in routes.values():
if route.primary == selected_profile:
route.primary = fallback_key or route.primary
route.ensemble = [key for key in route.ensemble if key != selected_profile]
cfg.sync_runtime_llm()
save_config()
st.success("Profile 已删除。")
st.experimental_rerun()
st.divider()
st.subheader("部门配置")
st.success("全局 LLM 配置已保存。")
st.json(llm_config_snapshot())
# Department configuration -------------------------------------------
st.markdown("##### 部门配置")
dept_settings = cfg.departments or {}
route_options_display = [""] + route_keys
dept_rows = [
{
"code": code,
"title": dept.title,
"description": dept.description,
"weight": float(dept.weight),
"llm_route": dept.llm_route or "",
"strategy": dept.llm.strategy,
"primary_provider": (dept.llm.primary.provider or ""),
"primary_model": dept.llm.primary.model or "",
"ensemble_size": len(dept.llm.ensemble),
"majority_threshold": dept.llm.majority_threshold,
"provider": dept.llm.primary.provider or (provider_keys[0] if provider_keys else ""),
"model": dept.llm.primary.model or "",
"temperature": "" if dept.llm.primary.temperature is None else f"{dept.llm.primary.temperature:.3f}",
"timeout": "" if dept.llm.primary.timeout is None else str(int(dept.llm.primary.timeout)),
"prompt_template": dept.llm.primary.prompt_template or "",
}
for code, dept in sorted(dept_settings.items())
]
@ -618,26 +676,13 @@ def render_settings() -> None:
"title": st.column_config.TextColumn("名称"),
"description": st.column_config.TextColumn("说明"),
"weight": st.column_config.NumberColumn("权重", min_value=0.0, max_value=10.0, step=0.1),
"llm_route": st.column_config.SelectboxColumn(
"路由",
options=route_options_display,
help="选择预定义路由;留空表示使用自定义配置",
),
"strategy": st.column_config.SelectboxColumn(
"自定义策略",
options=sorted(ALLOWED_LLM_STRATEGIES),
help="仅当未选择路由时生效",
),
"primary_provider": st.column_config.SelectboxColumn(
"自定义 Provider",
options=sorted(DEFAULT_LLM_MODEL_OPTIONS.keys()),
),
"primary_model": st.column_config.TextColumn("自定义模型"),
"ensemble_size": st.column_config.NumberColumn(
"协作模型数量",
disabled=True,
help="路由模式下自动维护",
),
"strategy": st.column_config.SelectboxColumn("策略", options=sorted(ALLOWED_LLM_STRATEGIES)),
"majority_threshold": st.column_config.NumberColumn("投票阈值", min_value=1, max_value=10, step=1),
"provider": st.column_config.SelectboxColumn("Provider", options=provider_keys or [""]),
"model": st.column_config.TextColumn("模型"),
"temperature": st.column_config.TextColumn("温度"),
"timeout": st.column_config.TextColumn("超时(秒)"),
"prompt_template": st.column_config.TextColumn("Prompt 模板"),
},
)
@ -662,25 +707,34 @@ def render_settings() -> None:
except (TypeError, ValueError):
pass
route_name = (row.get("llm_route") or "").strip() or None
existing.llm_route = route_name
if route_name and route_name in routes:
existing.llm = routes[route_name].resolve(profiles)
else:
strategy_val = (row.get("strategy") or existing.llm.strategy).lower()
if strategy_val in ALLOWED_LLM_STRATEGIES:
existing.llm.strategy = strategy_val
provider_before = existing.llm.primary.provider or ""
provider_val = (row.get("primary_provider") or provider_before or "ollama").lower()
existing.llm.primary.provider = provider_val
model_val = (row.get("primary_model") or "").strip()
existing.llm.primary.model = (
model_val or DEFAULT_LLM_MODELS.get(provider_val, existing.llm.primary.model)
)
if provider_before != provider_val:
default_base = DEFAULT_LLM_BASE_URLS.get(provider_val)
existing.llm.primary.base_url = default_base or existing.llm.primary.base_url
existing.llm.primary.__post_init__()
majority_raw = row.get("majority_threshold")
try:
majority_val = int(majority_raw)
if majority_val > 0:
existing.llm.majority_threshold = majority_val
except (TypeError, ValueError):
pass
provider_val = (row.get("provider") or existing.llm.primary.provider or (provider_keys[0] if provider_keys else "ollama")).strip().lower()
model_val = (row.get("model") or "").strip() or None
temp_raw = (row.get("temperature") or "").strip()
timeout_raw = (row.get("timeout") or "").strip()
prompt_raw = (row.get("prompt_template") or "").strip()
endpoint = existing.llm.primary or LLMEndpoint()
endpoint.provider = provider_val
endpoint.model = model_val
endpoint.temperature = float(temp_raw) if temp_raw else None
endpoint.timeout = float(timeout_raw) if timeout_raw else None
endpoint.prompt_template = prompt_raw or None
endpoint.base_url = None
endpoint.api_key = None
existing.llm.primary = endpoint
existing.llm.ensemble = []
updated_departments[code] = existing
if updated_departments:
@ -700,8 +754,7 @@ def render_settings() -> None:
st.success("已恢复默认部门配置。")
st.experimental_rerun()
st.caption("选择路由可统一部门模型调用,自定义模式仍支持逐项配置。")
st.caption("部门协作模型ensemble请在 config.json 中手动编辑UI 将在后续版本补充。")
st.caption("部门配置存储为独立 LLM 参数,执行时会自动套用对应 Provider 的连接信息。")
def render_tests() -> None:

View File

@ -5,7 +5,7 @@ from dataclasses import dataclass, field
import json
import os
from pathlib import Path
from typing import Dict, List, Mapping, Optional
from typing import Dict, List, Optional
def _default_root() -> Path:
@ -99,27 +99,55 @@ ALLOWED_LLM_STRATEGIES = {"single", "majority", "leader"}
LLM_STRATEGY_ALIASES = {"leader-follower": "leader"}
@dataclass
class LLMProvider:
"""Provider level configuration shared across profiles and routes."""
key: str
title: str = ""
base_url: str = ""
api_key: Optional[str] = None
models: List[str] = field(default_factory=list)
default_model: Optional[str] = None
default_temperature: float = 0.2
default_timeout: float = 30.0
prompt_template: str = ""
enabled: bool = True
mode: str = "openai" # openai 或 ollama
def to_dict(self) -> Dict[str, object]:
return {
"title": self.title,
"base_url": self.base_url,
"api_key": self.api_key,
"models": list(self.models),
"default_model": self.default_model,
"default_temperature": self.default_temperature,
"default_timeout": self.default_timeout,
"prompt_template": self.prompt_template,
"enabled": self.enabled,
"mode": self.mode,
}
@dataclass
class LLMEndpoint:
"""Single LLM endpoint configuration."""
"""Resolved endpoint payload used for actual LLM calls."""
provider: str = "ollama"
model: Optional[str] = None
base_url: Optional[str] = None
api_key: Optional[str] = None
temperature: float = 0.2
timeout: float = 30.0
temperature: Optional[float] = None
timeout: Optional[float] = None
prompt_template: Optional[str] = None
def __post_init__(self) -> None:
self.provider = (self.provider or "ollama").lower()
if not self.model:
self.model = DEFAULT_LLM_MODELS.get(self.provider, DEFAULT_LLM_MODELS["ollama"])
if not self.base_url:
self.base_url = DEFAULT_LLM_BASE_URLS.get(self.provider)
if self.temperature == 0.2 or self.temperature is None:
self.temperature = DEFAULT_LLM_TEMPERATURES.get(self.provider, 0.2)
if self.timeout == 30.0 or self.timeout is None:
self.timeout = DEFAULT_LLM_TIMEOUTS.get(self.provider, 30.0)
if self.temperature is not None:
self.temperature = float(self.temperature)
if self.timeout is not None:
self.timeout = float(self.timeout)
@dataclass
@ -132,133 +160,22 @@ class LLMConfig:
majority_threshold: int = 3
@dataclass
class LLMProfile:
"""Named LLM endpoint profile reusable across routes/departments."""
key: str
provider: str = "ollama"
model: Optional[str] = None
base_url: Optional[str] = None
api_key: Optional[str] = None
temperature: float = 0.2
timeout: float = 30.0
title: str = ""
enabled: bool = True
def to_endpoint(self) -> LLMEndpoint:
return LLMEndpoint(
provider=self.provider,
model=self.model,
base_url=self.base_url,
api_key=self.api_key,
temperature=self.temperature,
timeout=self.timeout,
)
def to_dict(self) -> Dict[str, object]:
return {
"provider": self.provider,
"model": self.model,
"base_url": self.base_url,
"api_key": self.api_key,
"temperature": self.temperature,
"timeout": self.timeout,
"title": self.title,
"enabled": self.enabled,
}
@classmethod
def from_endpoint(
cls,
key: str,
endpoint: LLMEndpoint,
*,
title: str = "",
enabled: bool = True,
) -> "LLMProfile":
return cls(
key=key,
provider=endpoint.provider,
model=endpoint.model,
base_url=endpoint.base_url,
api_key=endpoint.api_key,
temperature=endpoint.temperature,
timeout=endpoint.timeout,
title=title,
enabled=enabled,
)
@dataclass
class LLMRoute:
"""Declarative routing for selecting profiles and strategy."""
name: str
title: str = ""
strategy: str = "single"
majority_threshold: int = 3
primary: str = "ollama"
ensemble: List[str] = field(default_factory=list)
def resolve(self, profiles: Mapping[str, LLMProfile]) -> LLMConfig:
def _endpoint_from_key(key: str) -> LLMEndpoint:
profile = profiles.get(key)
if profile and profile.enabled:
return profile.to_endpoint()
fallback = profiles.get("ollama")
if not fallback or not fallback.enabled:
fallback = next(
(item for item in profiles.values() if item.enabled),
None,
)
endpoint = fallback.to_endpoint() if fallback else LLMEndpoint()
endpoint.provider = key or endpoint.provider
return endpoint
primary_endpoint = _endpoint_from_key(self.primary)
ensemble_endpoints = [
_endpoint_from_key(key)
for key in self.ensemble
if key in profiles and profiles[key].enabled
]
config = LLMConfig(
primary=primary_endpoint,
ensemble=ensemble_endpoints,
strategy=self.strategy if self.strategy in ALLOWED_LLM_STRATEGIES else "single",
majority_threshold=max(1, self.majority_threshold or 1),
)
return config
def to_dict(self) -> Dict[str, object]:
return {
"title": self.title,
"strategy": self.strategy,
"majority_threshold": self.majority_threshold,
"primary": self.primary,
"ensemble": list(self.ensemble),
}
def _default_llm_profiles() -> Dict[str, LLMProfile]:
return {
provider: LLMProfile(
def _default_llm_providers() -> Dict[str, LLMProvider]:
providers: Dict[str, LLMProvider] = {}
for provider, meta in DEFAULT_LLM_MODEL_OPTIONS.items():
models = list(meta.get("models", []))
mode = "ollama" if provider == "ollama" else "openai"
providers[provider] = LLMProvider(
key=provider,
provider=provider,
model=DEFAULT_LLM_MODELS.get(provider),
base_url=DEFAULT_LLM_BASE_URLS.get(provider),
temperature=DEFAULT_LLM_TEMPERATURES.get(provider, 0.2),
timeout=DEFAULT_LLM_TIMEOUTS.get(provider, 30.0),
title=f"默认 {provider}",
base_url=str(meta.get("base_url", DEFAULT_LLM_BASE_URLS.get(provider, "")) or ""),
models=models,
default_model=models[0] if models else DEFAULT_LLM_MODELS.get(provider),
default_temperature=float(meta.get("temperature", DEFAULT_LLM_TEMPERATURES.get(provider, 0.2))),
default_timeout=float(meta.get("timeout", DEFAULT_LLM_TIMEOUTS.get(provider, 30.0))),
mode=mode,
)
for provider in DEFAULT_LLM_MODEL_OPTIONS
}
def _default_llm_routes() -> Dict[str, LLMRoute]:
return {
"global": LLMRoute(name="global", title="全局默认路由"),
}
return providers
@dataclass
@ -270,7 +187,6 @@ class DepartmentSettings:
description: str = ""
weight: float = 1.0
llm: LLMConfig = field(default_factory=LLMConfig)
llm_route: Optional[str] = None
def _default_departments() -> Dict[str, DepartmentSettings]:
@ -282,10 +198,7 @@ def _default_departments() -> Dict[str, DepartmentSettings]:
("macro", "宏观研究部门"),
("risk", "风险控制部门"),
]
return {
code: DepartmentSettings(code=code, title=title, llm_route="global")
for code, title in presets
}
return {code: DepartmentSettings(code=code, title=title) for code, title in presets}
@dataclass
@ -298,17 +211,11 @@ class AppConfig:
data_paths: DataPaths = field(default_factory=DataPaths)
agent_weights: AgentWeights = field(default_factory=AgentWeights)
force_refresh: bool = False
llm_providers: Dict[str, LLMProvider] = field(default_factory=_default_llm_providers)
llm: LLMConfig = field(default_factory=LLMConfig)
llm_route: str = "global"
llm_profiles: Dict[str, LLMProfile] = field(default_factory=_default_llm_profiles)
llm_routes: Dict[str, LLMRoute] = field(default_factory=_default_llm_routes)
departments: Dict[str, DepartmentSettings] = field(default_factory=_default_departments)
def resolve_llm(self, route: Optional[str] = None) -> LLMConfig:
route_key = route or self.llm_route
route_cfg = self.llm_routes.get(route_key)
if route_cfg:
return route_cfg.resolve(self.llm_profiles)
return self.llm
def sync_runtime_llm(self) -> None:
@ -326,13 +233,22 @@ def _endpoint_to_dict(endpoint: LLMEndpoint) -> Dict[str, object]:
"api_key": endpoint.api_key,
"temperature": endpoint.temperature,
"timeout": endpoint.timeout,
"prompt_template": endpoint.prompt_template,
}
def _dict_to_endpoint(data: Dict[str, object]) -> LLMEndpoint:
payload = {
key: data.get(key)
for key in ("provider", "model", "base_url", "api_key", "temperature", "timeout")
for key in (
"provider",
"model",
"base_url",
"api_key",
"temperature",
"timeout",
"prompt_template",
)
if data.get(key) is not None
}
return LLMEndpoint(**payload)
@ -348,7 +264,9 @@ def _load_from_file(cfg: AppConfig) -> None:
except (json.JSONDecodeError, OSError):
return
if isinstance(payload, dict):
if not isinstance(payload, dict):
return
if "tushare_token" in payload:
cfg.tushare_token = payload.get("tushare_token") or None
if "force_refresh" in payload:
@ -356,111 +274,110 @@ def _load_from_file(cfg: AppConfig) -> None:
if "decision_method" in payload:
cfg.decision_method = str(payload.get("decision_method") or cfg.decision_method)
routes_defined = False
inline_primary_loaded = False
legacy_profiles: Dict[str, Dict[str, object]] = {}
legacy_routes: Dict[str, Dict[str, object]] = {}
providers_payload = payload.get("llm_providers")
if isinstance(providers_payload, dict):
providers: Dict[str, LLMProvider] = {}
for key, data in providers_payload.items():
if not isinstance(data, dict):
continue
models_raw = data.get("models")
if isinstance(models_raw, str):
models = [item.strip() for item in models_raw.split(',') if item.strip()]
elif isinstance(models_raw, list):
models = [str(item).strip() for item in models_raw if str(item).strip()]
else:
models = []
provider = LLMProvider(
key=str(key).lower(),
title=str(data.get("title") or ""),
base_url=str(data.get("base_url") or ""),
api_key=data.get("api_key"),
models=models,
default_model=data.get("default_model") or (models[0] if models else None),
default_temperature=float(data.get("default_temperature", 0.2)),
default_timeout=float(data.get("default_timeout", 30.0)),
prompt_template=str(data.get("prompt_template") or ""),
enabled=bool(data.get("enabled", True)),
mode=str(data.get("mode") or ("ollama" if str(key).lower() == "ollama" else "openai")),
)
providers[provider.key] = provider
if providers:
cfg.llm_providers = providers
profiles_payload = payload.get("llm_profiles")
if isinstance(profiles_payload, dict):
profiles: Dict[str, LLMProfile] = {}
for key, data in profiles_payload.items():
if not isinstance(data, dict):
continue
provider = str(data.get("provider") or "ollama").lower()
profile = LLMProfile(
key=key,
provider=provider,
model=data.get("model"),
base_url=data.get("base_url"),
api_key=data.get("api_key"),
temperature=float(data.get("temperature", DEFAULT_LLM_TEMPERATURES.get(provider, 0.2))),
timeout=float(data.get("timeout", DEFAULT_LLM_TIMEOUTS.get(provider, 30.0))),
title=str(data.get("title") or ""),
enabled=bool(data.get("enabled", True)),
)
profiles[key] = profile
if profiles:
cfg.llm_profiles = profiles
if isinstance(data, dict):
legacy_profiles[str(key)] = data
routes_payload = payload.get("llm_routes")
if isinstance(routes_payload, dict):
routes: Dict[str, LLMRoute] = {}
for name, data in routes_payload.items():
if not isinstance(data, dict):
continue
strategy_raw = str(data.get("strategy") or "single").lower()
normalized = LLM_STRATEGY_ALIASES.get(strategy_raw, strategy_raw)
route = LLMRoute(
name=name,
title=str(data.get("title") or ""),
strategy=normalized if normalized in ALLOWED_LLM_STRATEGIES else "single",
majority_threshold=max(1, int(data.get("majority_threshold", 3) or 3)),
primary=str(data.get("primary") or "global"),
ensemble=[
str(item)
for item in data.get("ensemble", [])
if isinstance(item, str)
],
)
routes[name] = route
if routes:
cfg.llm_routes = routes
routes_defined = True
if isinstance(data, dict):
legacy_routes[str(name)] = data
route_key = payload.get("llm_route")
if isinstance(route_key, str) and route_key:
cfg.llm_route = route_key
def _endpoint_from_payload(item: object) -> LLMEndpoint:
if isinstance(item, dict):
return _dict_to_endpoint(item)
if isinstance(item, str):
profile_data = legacy_profiles.get(item)
if isinstance(profile_data, dict):
return _dict_to_endpoint(profile_data)
return LLMEndpoint(provider=item)
return LLMEndpoint()
def _resolve_route(route_name: str) -> Optional[LLMConfig]:
route_data = legacy_routes.get(route_name)
if not route_data:
return None
strategy_raw = str(route_data.get("strategy") or "single").lower()
strategy = LLM_STRATEGY_ALIASES.get(strategy_raw, strategy_raw)
primary_ref = route_data.get("primary")
primary_ep = _endpoint_from_payload(primary_ref)
ensemble_refs = route_data.get("ensemble", [])
ensemble_eps = [
_endpoint_from_payload(ref)
for ref in ensemble_refs
if isinstance(ref, (dict, str))
]
cfg_obj = LLMConfig(
primary=primary_ep,
ensemble=ensemble_eps,
strategy=strategy if strategy in ALLOWED_LLM_STRATEGIES else "single",
majority_threshold=max(1, int(route_data.get("majority_threshold", 3) or 3)),
)
return cfg_obj
llm_payload = payload.get("llm")
if isinstance(llm_payload, dict):
route_value = llm_payload.get("route")
resolved_cfg = None
if isinstance(route_value, str) and route_value:
cfg.llm_route = route_value
resolved_cfg = _resolve_route(route_value)
if resolved_cfg is None:
resolved_cfg = LLMConfig()
primary_data = llm_payload.get("primary")
if isinstance(primary_data, dict):
cfg.llm.primary = _dict_to_endpoint(primary_data)
inline_primary_loaded = True
resolved_cfg.primary = _dict_to_endpoint(primary_data)
ensemble_data = llm_payload.get("ensemble")
if isinstance(ensemble_data, list):
cfg.llm.ensemble = [
resolved_cfg.ensemble = [
_dict_to_endpoint(item)
for item in ensemble_data
if isinstance(item, dict)
]
strategy_raw = llm_payload.get("strategy")
if isinstance(strategy_raw, str):
normalized = LLM_STRATEGY_ALIASES.get(strategy_raw, strategy_raw)
if normalized in ALLOWED_LLM_STRATEGIES:
cfg.llm.strategy = normalized
resolved_cfg.strategy = normalized
majority = llm_payload.get("majority_threshold")
if isinstance(majority, int) and majority > 0:
cfg.llm.majority_threshold = majority
if inline_primary_loaded and not routes_defined:
primary_key = "inline_global_primary"
cfg.llm_profiles[primary_key] = LLMProfile.from_endpoint(
primary_key,
cfg.llm.primary,
title="全局主模型",
)
ensemble_keys: List[str] = []
for idx, endpoint in enumerate(cfg.llm.ensemble, start=1):
inline_key = f"inline_global_ensemble_{idx}"
cfg.llm_profiles[inline_key] = LLMProfile.from_endpoint(
inline_key,
endpoint,
title=f"全局协作#{idx}",
)
ensemble_keys.append(inline_key)
auto_route = cfg.llm_routes.get("global") or LLMRoute(name="global", title="全局默认路由")
auto_route.strategy = cfg.llm.strategy
auto_route.majority_threshold = cfg.llm.majority_threshold
auto_route.primary = primary_key
auto_route.ensemble = ensemble_keys
cfg.llm_routes["global"] = auto_route
cfg.llm_route = cfg.llm_route or "global"
resolved_cfg.majority_threshold = majority
cfg.llm = resolved_cfg
departments_payload = payload.get("departments")
if isinstance(departments_payload, dict):
@ -471,14 +388,22 @@ def _load_from_file(cfg: AppConfig) -> None:
title = data.get("title") or code
description = data.get("description") or ""
weight = float(data.get("weight", 1.0))
llm_data = data.get("llm")
llm_cfg = LLMConfig()
route_name = data.get("llm_route")
resolved_cfg = None
if isinstance(route_name, str) and route_name:
resolved_cfg = _resolve_route(route_name)
if resolved_cfg is None:
llm_data = data.get("llm")
if isinstance(llm_data, dict):
if isinstance(llm_data.get("primary"), dict):
llm_cfg.primary = _dict_to_endpoint(llm_data["primary"])
primary_data = llm_data.get("primary")
if isinstance(primary_data, dict):
llm_cfg.primary = _dict_to_endpoint(primary_data)
ensemble_data = llm_data.get("ensemble")
if isinstance(ensemble_data, list):
llm_cfg.ensemble = [
_dict_to_endpoint(item)
for item in llm_data.get("ensemble", [])
for item in ensemble_data
if isinstance(item, dict)
]
strategy_raw = llm_data.get("strategy")
@ -489,18 +414,13 @@ def _load_from_file(cfg: AppConfig) -> None:
majority_raw = llm_data.get("majority_threshold")
if isinstance(majority_raw, int) and majority_raw > 0:
llm_cfg.majority_threshold = majority_raw
route = data.get("llm_route")
route_name = str(route).strip() if isinstance(route, str) and route else None
resolved = llm_cfg
if route_name and route_name in cfg.llm_routes:
resolved = cfg.llm_routes[route_name].resolve(cfg.llm_profiles)
resolved_cfg = llm_cfg
new_departments[code] = DepartmentSettings(
code=code,
title=title,
description=description,
weight=weight,
llm=resolved,
llm_route=route_name,
llm=resolved_cfg,
)
if new_departments:
cfg.departments = new_departments
@ -516,28 +436,21 @@ def save_config(cfg: AppConfig | None = None) -> None:
"tushare_token": cfg.tushare_token,
"force_refresh": cfg.force_refresh,
"decision_method": cfg.decision_method,
"llm_route": cfg.llm_route,
"llm": {
"route": cfg.llm_route,
"strategy": cfg.llm.strategy if cfg.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
"majority_threshold": cfg.llm.majority_threshold,
"primary": _endpoint_to_dict(cfg.llm.primary),
"ensemble": [_endpoint_to_dict(ep) for ep in cfg.llm.ensemble],
},
"llm_profiles": {
key: profile.to_dict()
for key, profile in cfg.llm_profiles.items()
},
"llm_routes": {
name: route.to_dict()
for name, route in cfg.llm_routes.items()
"llm_providers": {
key: provider.to_dict()
for key, provider in cfg.llm_providers.items()
},
"departments": {
code: {
"title": dept.title,
"description": dept.description,
"weight": dept.weight,
"llm_route": dept.llm_route,
"llm": {
"strategy": dept.llm.strategy if dept.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
"majority_threshold": dept.llm.majority_threshold,
@ -567,11 +480,9 @@ def _load_env_defaults(cfg: AppConfig) -> None:
if api_key:
sanitized = api_key.strip()
cfg.llm.primary.api_key = sanitized
route = cfg.llm_routes.get(cfg.llm_route)
if route:
profile = cfg.llm_profiles.get(route.primary)
if profile:
profile.api_key = sanitized
provider_cfg = cfg.llm_providers.get(cfg.llm.primary.provider)
if provider_cfg:
provider_cfg.api_key = sanitized
cfg.sync_runtime_llm()