update

parent 6aece20816
commit ab6180646a

README.md: 13 changed lines. The remaining hunks below touch the department-agent, LLM-client, Streamlit-UI, and config modules.
@@ -24,7 +24,7 @@
 - **Unified logging and persistence**: SQLite stores quotes, backtests, and logs in one place; `DatabaseLogHandler` emits structured run traces from the UI and data-ingest flows for quick tracing and post-mortems.
 - **Cross-market data coverage**: `app/ingest/tushare.py` adds incremental pulls for indexes, ETFs/mutual funds, futures, FX, HK stocks, and US stocks, so the quote data needed by multi-asset factors and macro proxies is complete.
 - **Department-level multi-model collaboration**: `app/agents/departments.py` wraps department-level LLM dispatch, `app/llm/client.py` supports single/majority/leader strategies, department conclusions play the game in `app/agents/game.py` alongside the six base agents, and results persist to `agent_utils` for the UI.
-- **LLM Profile/Route management**: `app/utils/config.py` introduces reusable Profiles (endpoint definitions) and Routes (inference-strategy bundles); the Streamlit UI supports visual maintenance, and both global and department scopes reuse named routes for consistent configuration.
+- **LLM Provider management**: `app/utils/config.py` centrally maintains each vendor's URL, API key, available models, and default parameters; the Streamlit UI configures them visually, and global and department scopes set model, temperature, and prompt directly on top of a Provider.
 
 ## LLM + multi-agent best practices
 
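For orientation, the Provider entry persisted to `app/data/config.json` should look roughly as follows; a minimal sketch assuming the `LLMProvider` field names from the config hunk later in this commit, with invented sample values:

```python
# Hypothetical sketch of the provider-centric config shape described above.
# Field names follow the LLMProvider dataclass in this commit; the "deepseek"
# values are placeholders, not defaults shipped by the project.
import json

config = {
    "llm_providers": {
        "deepseek": {
            "title": "DeepSeek main account",
            "base_url": "https://api.deepseek.com",
            "api_key": "sk-...",               # injected via env in practice
            "models": ["deepseek-chat"],
            "default_model": "deepseek-chat",
            "default_temperature": 0.2,
            "default_timeout": 30.0,
            "prompt_template": "",
            "enabled": True,
            "mode": "openai",                  # OpenAI-compatible HTTP API
        }
    }
}
print(json.dumps(config, indent=2))
```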
@@ -60,11 +60,10 @@ export TUSHARE_TOKEN="<your-token>"
 
 ### LLM configuration and testing
 
-- Added a two-layer Profile/Route configuration: a Profile defines one endpoint (provider/model/host/API key) and a Route combines Profiles with an inference strategy (single/majority/leader). The global route switches with one click; departments reuse named routes or keep custom settings.
-- The Streamlit "Data & Settings" page manages Profiles, Routes, and the global route via forms; saving writes `app/data/config.json`, and the Route preview shows the live configuration redacted by `llm_config_snapshot()`.
-- Supports local Ollama and several OpenAI-compatible vendors (DeepSeek, ERNIE Bot, OpenAI, etc.); each Profile can set a default model, temperature, timeout, and enabled state.
-- The UI keeps TuShare token maintenance plus route/Profile add, delete, and disable actions; all updates take effect immediately and are logged.
-- When secrets are injected via environment variables (`TUSHARE_TOKEN`, `LLM_API_KEY`), they sync to the current route's primary Profile on load.
+- Providers manage vendor connection parameters (base URL, API key, model list, default temperature/timeout/prompt template); local Ollama or cloud services (DeepSeek, ERNIE Bot, OpenAI, etc.) can be added at any time.
+- Global and department configs pick a Provider directly and override model, temperature, prompt template, and voting strategy as needed; saves are written to `app/data/config.json` and load automatically on the next start.
+- The Streamlit "Data & Settings" page offers a three-pane editor (Provider/global/department); saves take effect immediately, and `llm_config_snapshot()` prints a redacted view for inspection.
+- Secrets can be injected via environment variables: `TUSHARE_TOKEN`, `LLM_API_KEY`.
 
 ## Quick start
 
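A minimal sketch of the environment-variable injection the README describes; the exact precedence lives in `app/utils/config.py`, and this only illustrates the "env value fills an empty config field" idea:

```python
# Sketch only: names TUSHARE_TOKEN and LLM_API_KEY come from the README;
# the precedence shown (stored config first, env second) is an assumption.
import os

tushare_token = os.environ.get("TUSHARE_TOKEN")   # TuShare data access
llm_api_key = os.environ.get("LLM_API_KEY")       # shared LLM credential

stored_api_key = None  # value loaded from app/data/config.json
api_key = stored_api_key or llm_api_key
if api_key is None:
    raise RuntimeError("no LLM API key configured")
```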
@@ -104,7 +103,7 @@ The Streamlit `自检测试` tab provides:
 ## Implementation steps
 
 1. **Config extension** (`app/utils/config.py` + `config.json`) ✅
-   - Introduce `llm_profiles`/`llm_routes` to manage endpoints and strategies in one place; departments can reuse routes or keep custom configs; Streamlit provides visual maintenance forms.
+   - Introduce `llm_providers` to centralize vendor parameters; global and department scopes bind a Provider directly and customize model/temperature/prompt; Streamlit provides visual maintenance forms.
 
 2. **Department controller** ✅
    - `app/agents/departments.py` provides `DepartmentAgent`/`DepartmentManager`, wrapping prompt construction, multi-model negotiation, and fallback on failure.
@@ -131,8 +131,6 @@ class DepartmentManager:
         return results
 
     def _resolve_llm(self, settings: DepartmentSettings) -> LLMConfig:
-        if settings.llm_route and settings.llm_route in self.config.llm_routes:
-            return self.config.llm_routes[settings.llm_route].resolve(self.config.llm_profiles)
         return settings.llm
 
 
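With routes gone, each department now carries its own complete `LLMConfig`, so resolution collapses to a plain attribute read; a sketch with simplified stand-in types:

```python
# Sketch of what the simplified _resolve_llm implies. DepartmentSettings and
# LLMConfig here are trimmed stand-ins, not the project's full dataclasses.
from dataclasses import dataclass, field


@dataclass
class LLMConfig:
    strategy: str = "single"


@dataclass
class DepartmentSettings:
    code: str
    llm: LLMConfig = field(default_factory=LLMConfig)


def resolve_llm(settings: DepartmentSettings) -> LLMConfig:
    # No route lookup or profile indirection remains; the stored config wins.
    return settings.llm


print(resolve_llm(DepartmentSettings(code="macro")).strategy)  # -> "single"
```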
@@ -102,24 +102,58 @@ def _request_openai(
 
 
 def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> str:
-    provider = (endpoint.provider or "ollama").lower()
-    base_url = endpoint.base_url or _default_base_url(provider)
-    model = endpoint.model or _default_model(provider)
-    temperature = max(0.0, min(endpoint.temperature, 2.0))
-    timeout = max(5.0, endpoint.timeout or 30.0)
+    cfg = get_config()
+    provider_key = (endpoint.provider or "ollama").lower()
+    provider_cfg = cfg.llm_providers.get(provider_key)
+
+    base_url = endpoint.base_url
+    api_key = endpoint.api_key
+    model = endpoint.model
+    temperature = endpoint.temperature
+    timeout = endpoint.timeout
+    prompt_template = endpoint.prompt_template
+
+    if provider_cfg:
+        if not provider_cfg.enabled:
+            raise LLMError(f"Provider {provider_key} 已被禁用")
+        base_url = base_url or provider_cfg.base_url or _default_base_url(provider_key)
+        api_key = api_key or provider_cfg.api_key
+        model = model or provider_cfg.default_model or (provider_cfg.models[0] if provider_cfg.models else _default_model(provider_key))
+        if temperature is None:
+            temperature = provider_cfg.default_temperature
+        if timeout is None:
+            timeout = provider_cfg.default_timeout
+        prompt_template = prompt_template or (provider_cfg.prompt_template or None)
+        mode = provider_cfg.mode or ("ollama" if provider_key == "ollama" else "openai")
+    else:
+        base_url = base_url or _default_base_url(provider_key)
+        model = model or _default_model(provider_key)
+        if temperature is None:
+            temperature = DEFAULT_LLM_TEMPERATURES.get(provider_key, 0.2)
+        if timeout is None:
+            timeout = DEFAULT_LLM_TIMEOUTS.get(provider_key, 30.0)
+        mode = "ollama" if provider_key == "ollama" else "openai"
+
+    temperature = max(0.0, min(float(temperature), 2.0))
+    timeout = max(5.0, float(timeout))
+
+    if prompt_template:
+        try:
+            prompt = prompt_template.format(prompt=prompt)
+        except Exception:  # noqa: BLE001
+            LOGGER.warning("Prompt 模板格式化失败,使用原始 prompt", extra=LOG_EXTRA)
+
     LOGGER.info(
         "触发 LLM 请求:provider=%s model=%s base=%s",
-        provider,
+        provider_key,
         model,
         base_url,
         extra=LOG_EXTRA,
     )
 
-    if provider in {"openai", "deepseek", "wenxin"}:
-        api_key = endpoint.api_key
+    if mode != "ollama":
         if not api_key:
-            raise LLMError(f"缺少 {provider} API Key (model={model})")
+            raise LLMError(f"缺少 {provider_key} API Key (model={model})")
         return _request_openai(
             model,
             prompt,
@@ -129,7 +163,7 @@ def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> str:
             timeout=timeout,
             system=system,
         )
-    if provider == "ollama":
+    if base_url:
         return _request_ollama(
             model,
             prompt,
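Every parameter in `_call_endpoint` now resolves through the same three-level fallback: explicit endpoint value, then provider default, then built-in table. A sketch of that precedence for temperature; the names mirror the diff, but the fallback dict is an illustrative stand-in:

```python
# Sketch of the resolution order _call_endpoint applies per parameter.
from typing import Optional

DEFAULT_LLM_TEMPERATURES = {"ollama": 0.2}  # stand-in for the module table


def resolve_temperature(endpoint_temp: Optional[float],
                        provider_default: Optional[float],
                        provider_key: str) -> float:
    if endpoint_temp is not None:          # 1. explicit per-endpoint override
        value = endpoint_temp
    elif provider_default is not None:     # 2. provider-level default
        value = provider_default
    else:                                  # 3. built-in fallback table
        value = DEFAULT_LLM_TEMPERATURES.get(provider_key, 0.2)
    return max(0.0, min(float(value), 2.0))  # clamp exactly as the diff does


assert resolve_temperature(None, 0.7, "deepseek") == 0.7
assert resolve_temperature(1.5, 0.7, "deepseek") == 1.5
```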
@@ -283,13 +317,17 @@ def llm_config_snapshot() -> Dict[str, object]:
         if record.get("api_key"):
             record["api_key"] = "***"
         ensemble.append(record)
-    route_name = cfg.llm_route
-    route_obj = cfg.llm_routes.get(route_name)
     return {
-        "route": route_name,
-        "route_detail": route_obj.to_dict() if route_obj else None,
         "strategy": settings.strategy,
         "majority_threshold": settings.majority_threshold,
         "primary": primary,
         "ensemble": ensemble,
+        "providers": {
+            key: {
+                "base_url": provider.base_url,
+                "default_model": provider.default_model,
+                "enabled": provider.enabled,
+            }
+            for key, provider in cfg.llm_providers.items()
+        },
     }
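The snapshot is what the UI prints after a save, with secrets masked before display. A sketch of the redaction idea; the shapes are assumptions based on the hunk above, not the module's exact output:

```python
# Sketch of llm_config_snapshot()-style redaction: mask keys, keep structure.
def redact(record: dict) -> dict:
    record = dict(record)
    if record.get("api_key"):
        record["api_key"] = "***"
    return record


snapshot = {
    "strategy": "majority",
    "majority_threshold": 3,
    "primary": redact({"provider": "deepseek", "api_key": "sk-secret"}),
    "providers": {
        "deepseek": {"base_url": "https://api.deepseek.com",
                     "default_model": "deepseek-chat", "enabled": True},
    },
}
print(snapshot["primary"]["api_key"])  # -> "***"
```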
@ -5,7 +5,7 @@ import sys
|
|||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from datetime import date, timedelta
|
from datetime import date, timedelta
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
ROOT = Path(__file__).resolve().parents[2]
|
ROOT = Path(__file__).resolve().parents[2]
|
||||||
if str(ROOT) not in sys.path:
|
if str(ROOT) not in sys.path:
|
||||||
@ -16,6 +16,8 @@ import json
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import plotly.express as px
|
import plotly.express as px
|
||||||
import plotly.graph_objects as go
|
import plotly.graph_objects as go
|
||||||
|
import requests
|
||||||
|
from requests.exceptions import RequestException
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
|
||||||
from app.backtest.engine import BtConfig, run_backtest
|
from app.backtest.engine import BtConfig, run_backtest
|
||||||
@ -29,8 +31,8 @@ from app.utils.config import (
|
|||||||
DEFAULT_LLM_MODEL_OPTIONS,
|
DEFAULT_LLM_MODEL_OPTIONS,
|
||||||
DEFAULT_LLM_MODELS,
|
DEFAULT_LLM_MODELS,
|
||||||
DepartmentSettings,
|
DepartmentSettings,
|
||||||
LLMProfile,
|
LLMEndpoint,
|
||||||
LLMRoute,
|
LLMProvider,
|
||||||
get_config,
|
get_config,
|
||||||
save_config,
|
save_config,
|
||||||
)
|
)
|
||||||
@@ -42,6 +44,50 @@ LOGGER = get_logger(__name__)
 LOG_EXTRA = {"stage": "ui"}
 
 
+def _discover_provider_models(provider: LLMProvider, base_override: str = "", api_override: Optional[str] = None) -> tuple[list[str], Optional[str]]:
+    """Attempt to query provider API and return available model ids."""
+
+    base_url = (base_override or provider.base_url or DEFAULT_LLM_BASE_URLS.get(provider.key, "")).strip()
+    if not base_url:
+        return [], "请先填写 Base URL"
+    timeout = float(provider.default_timeout or 30.0)
+    mode = provider.mode or ("ollama" if provider.key == "ollama" else "openai")
+
+    try:
+        if mode == "ollama":
+            url = base_url.rstrip('/') + "/api/tags"
+            response = requests.get(url, timeout=timeout)
+            response.raise_for_status()
+            data = response.json()
+            models = []
+            for item in data.get("models", []) or data.get("data", []):
+                name = item.get("name") or item.get("model") or item.get("tag")
+                if name:
+                    models.append(str(name).strip())
+            return sorted(set(models)), None
+
+        api_key = (api_override or provider.api_key or "").strip()
+        if not api_key:
+            return [], "缺少 API Key"
+        url = base_url.rstrip('/') + "/v1/models"
+        headers = {
+            "Authorization": f"Bearer {api_key}",
+            "Content-Type": "application/json",
+        }
+        response = requests.get(url, headers=headers, timeout=timeout)
+        response.raise_for_status()
+        payload = response.json()
+        models = [
+            str(item.get("id")).strip()
+            for item in payload.get("data", [])
+            if item.get("id")
+        ]
+        return sorted(set(models)), None
+    except RequestException as exc:  # noqa: BLE001
+        return [], f"HTTP 错误:{exc}"
+    except Exception as exc:  # noqa: BLE001
+        return [], f"解析失败:{exc}"
+
+
 def _load_stock_options(limit: int = 500) -> list[str]:
     try:
         with db_session(read_only=True) as conn:
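A hedged usage sketch for the discovery helper: probing an Ollama host for its installed models via the same `/api/tags` endpoint the diff queries. A local server at `localhost:11434` is assumed; the standalone function below is a stand-in, not the project's `_discover_provider_models` itself:

```python
# Sketch: query Ollama's tag listing the way the "ollama" branch above does.
import requests


def discover_ollama_models(base_url: str, timeout: float = 30.0) -> list[str]:
    response = requests.get(base_url.rstrip("/") + "/api/tags", timeout=timeout)
    response.raise_for_status()
    data = response.json()
    names = [item.get("name") for item in data.get("models", []) if item.get("name")]
    return sorted(set(names))


if __name__ == "__main__":
    try:
        print(discover_ollama_models("http://localhost:11434"))
    except requests.RequestException as exc:
        print(f"discovery failed: {exc}")
```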
@@ -350,255 +396,267 @@ def render_settings() -> None:
 
     st.divider()
     st.subheader("LLM 设置")
-    profiles = cfg.llm_profiles or {}
-    routes = cfg.llm_routes or {}
-    profile_keys = sorted(profiles.keys())
-    route_keys = sorted(routes.keys())
-    used_routes = {
-        dept.llm_route for dept in cfg.departments.values() if dept.llm_route
-    }
-    st.caption("Profile 定义单个模型终端,Route 负责组合 Profile 与推理策略。")
+    providers = cfg.llm_providers
+    provider_keys = sorted(providers.keys())
+    st.caption("先在 Provider 中维护基础连接(URL、Key、模型),再为全局与各部门设置个性化参数。")
 
-    route_select_col, route_manage_col = st.columns([3, 1])
-    if route_keys:
+    # Provider management -------------------------------------------------
+    provider_select_col, provider_manage_col = st.columns([3, 1])
+    if provider_keys:
         try:
-            active_index = route_keys.index(cfg.llm_route)
+            default_provider = cfg.llm.primary.provider or provider_keys[0]
+            provider_index = provider_keys.index(default_provider)
         except ValueError:
-            active_index = 0
-        selected_route = route_select_col.selectbox(
-            "全局路由",
-            route_keys,
-            index=active_index,
-            key="llm_route_select",
+            provider_index = 0
+        selected_provider = provider_select_col.selectbox(
+            "选择 Provider",
+            provider_keys,
+            index=provider_index,
+            key="llm_provider_select",
         )
     else:
-        selected_route = None
-        route_select_col.info("尚未配置路由,请先创建。")
+        selected_provider = None
+        provider_select_col.info("尚未配置 Provider,请先创建。")
 
-    new_route_name = route_manage_col.text_input("新增路由", key="new_route_name")
-    if route_manage_col.button("添加路由"):
-        key = (new_route_name or "").strip()
+    new_provider_name = provider_manage_col.text_input("新增 Provider", key="new_provider_name")
+    if provider_manage_col.button("创建 Provider", key="create_provider_btn"):
+        key = (new_provider_name or "").strip().lower()
         if not key:
-            st.warning("请输入有效的路由名称。")
-        elif key in routes:
-            st.warning(f"路由 {key} 已存在。")
+            st.warning("请输入有效的 Provider 名称。")
+        elif key in providers:
+            st.warning(f"Provider {key} 已存在。")
         else:
-            routes[key] = LLMRoute(name=key)
-            if not selected_route:
-                selected_route = key
-            cfg.llm_route = key
+            providers[key] = LLMProvider(key=key)
+            cfg.llm_providers = providers
             save_config()
-            st.success(f"已添加路由 {key},请继续配置。")
+            st.success(f"已创建 Provider {key}。")
             st.experimental_rerun()
 
-    if selected_route:
-        route_obj = routes.get(selected_route)
-        if route_obj is None:
-            route_obj = LLMRoute(name=selected_route)
-            routes[selected_route] = route_obj
-        strategy_choices = sorted(ALLOWED_LLM_STRATEGIES)
-        try:
-            strategy_index = strategy_choices.index(route_obj.strategy)
-        except ValueError:
-            strategy_index = 0
-        route_title = st.text_input(
-            "路由说明",
-            value=route_obj.title or "",
-            key=f"route_title_{selected_route}",
-        )
-        route_strategy = st.selectbox(
-            "推理策略",
-            strategy_choices,
-            index=strategy_index,
-            key=f"route_strategy_{selected_route}",
-        )
-        route_majority = st.number_input(
-            "多数投票门槛",
-            min_value=1,
-            max_value=10,
-            value=int(route_obj.majority_threshold or 1),
-            step=1,
-            key=f"route_majority_{selected_route}",
-        )
-        if not profile_keys:
-            st.warning("暂无可用 Profile,请先在下方创建。")
-        else:
-            try:
-                primary_index = profile_keys.index(route_obj.primary)
-            except ValueError:
-                primary_index = 0
-            primary_key = st.selectbox(
-                "主用 Profile",
-                profile_keys,
-                index=primary_index,
-                key=f"route_primary_{selected_route}",
-            )
-            default_ensemble = [
-                key for key in route_obj.ensemble if key in profile_keys and key != primary_key
-            ]
-            ensemble_keys = st.multiselect(
-                "协作 Profile (可多选)",
-                profile_keys,
-                default=default_ensemble,
-                key=f"route_ensemble_{selected_route}",
-            )
-        if st.button("保存路由设置", key=f"save_route_{selected_route}"):
-            route_obj.title = route_title.strip()
-            route_obj.strategy = route_strategy
-            route_obj.majority_threshold = int(route_majority)
-            route_obj.primary = primary_key
-            route_obj.ensemble = [key for key in ensemble_keys if key != primary_key]
-            cfg.llm_route = selected_route
-            cfg.sync_runtime_llm()
-            save_config()
-            LOGGER.info(
-                "路由 %s 配置更新:%s",
-                selected_route,
-                route_obj.to_dict(),
-                extra=LOG_EXTRA,
-            )
-            st.success("路由配置已保存。")
-            st.json({
-                "route": selected_route,
-                "route_detail": route_obj.to_dict(),
-                "resolved": llm_config_snapshot(),
-            })
-        route_in_use = selected_route in used_routes or selected_route == cfg.llm_route
-        if st.button(
-            "删除当前路由",
-            key=f"delete_route_{selected_route}",
-            disabled=route_in_use or len(routes) <= 1,
-        ):
-            routes.pop(selected_route, None)
-            if cfg.llm_route == selected_route:
-                cfg.llm_route = next((key for key in routes.keys()), "global")
-            cfg.sync_runtime_llm()
-            save_config()
-            st.success("路由已删除。")
-            st.experimental_rerun()
-
-    st.divider()
-    st.subheader("LLM Profile 管理")
-    profile_select_col, profile_manage_col = st.columns([3, 1])
-    if profile_keys:
-        selected_profile = profile_select_col.selectbox(
-            "选择 Profile",
-            profile_keys,
-            index=0,
-            key="profile_select",
-        )
-    else:
-        selected_profile = None
-        profile_select_col.info("尚未配置 Profile,请先创建。")
-
-    new_profile_name = profile_manage_col.text_input("新增 Profile", key="new_profile_name")
-    if profile_manage_col.button("创建 Profile"):
-        key = (new_profile_name or "").strip()
-        if not key:
-            st.warning("请输入有效的 Profile 名称。")
-        elif key in profiles:
-            st.warning(f"Profile {key} 已存在。")
-        else:
-            profiles[key] = LLMProfile(key=key)
-            save_config()
-            st.success(f"已创建 Profile {key}。")
-            st.experimental_rerun()
-
-    if selected_profile:
-        profile = profiles[selected_profile]
-        provider_choices = sorted(DEFAULT_LLM_MODEL_OPTIONS.keys())
-        try:
-            provider_index = provider_choices.index(profile.provider)
-        except ValueError:
-            provider_index = 0
-        with st.form(f"profile_form_{selected_profile}"):
-            provider_val = st.selectbox(
-                "Provider",
-                provider_choices,
-                index=provider_index,
-            )
-            model_default = DEFAULT_LLM_MODELS.get(provider_val, profile.model or "")
-            model_val = st.text_input(
-                "模型",
-                value=profile.model or model_default,
-            )
-            base_default = DEFAULT_LLM_BASE_URLS.get(provider_val, profile.base_url or "")
-            base_val = st.text_input(
-                "Base URL",
-                value=profile.base_url or base_default,
-            )
-            api_val = st.text_input(
-                "API Key",
-                value=profile.api_key or "",
-                type="password",
-            )
-            temp_val = st.slider(
-                "温度",
-                min_value=0.0,
-                max_value=2.0,
-                value=float(profile.temperature),
-                step=0.05,
-            )
-            timeout_val = st.number_input(
-                "超时(秒)",
-                min_value=5,
-                max_value=180,
-                value=int(profile.timeout or 30),
-                step=5,
-            )
-            title_val = st.text_input("备注", value=profile.title or "")
-            enabled_val = st.checkbox("启用", value=profile.enabled)
-            submitted = st.form_submit_button("保存 Profile")
-        if submitted:
-            profile.provider = provider_val
-            profile.model = model_val.strip() or DEFAULT_LLM_MODELS.get(provider_val)
-            profile.base_url = base_val.strip() or DEFAULT_LLM_BASE_URLS.get(provider_val)
-            profile.api_key = api_val.strip() or None
-            profile.temperature = temp_val
-            profile.timeout = timeout_val
-            profile.title = title_val.strip()
-            profile.enabled = enabled_val
-            profiles[selected_profile] = profile
-            cfg.sync_runtime_llm()
-            save_config()
-            st.success("Profile 已保存。")
-
-        profile_in_use = any(
-            selected_profile == route.primary or selected_profile in route.ensemble
-            for route in routes.values()
-        )
-        if st.button(
-            "删除该 Profile",
-            key=f"delete_profile_{selected_profile}",
-            disabled=profile_in_use or len(profiles) <= 1,
-        ):
-            profiles.pop(selected_profile, None)
-            fallback_key = next((key for key in profiles.keys()), None)
-            for route in routes.values():
-                if route.primary == selected_profile:
-                    route.primary = fallback_key or route.primary
-                route.ensemble = [key for key in route.ensemble if key != selected_profile]
-            cfg.sync_runtime_llm()
-            save_config()
-            st.success("Profile 已删除。")
-            st.experimental_rerun()
-
-    st.divider()
-    st.subheader("部门配置")
-
+    if selected_provider:
+        provider_cfg = providers.get(selected_provider, LLMProvider(key=selected_provider))
+        title_key = f"provider_title_{selected_provider}"
+        base_key = f"provider_base_{selected_provider}"
+        api_key_key = f"provider_api_{selected_provider}"
+        models_key = f"provider_models_{selected_provider}"
+        default_model_key = f"provider_default_model_{selected_provider}"
+        mode_key = f"provider_mode_{selected_provider}"
+        temp_key = f"provider_temp_{selected_provider}"
+        timeout_key = f"provider_timeout_{selected_provider}"
+        prompt_key = f"provider_prompt_{selected_provider}"
+        enabled_key = f"provider_enabled_{selected_provider}"
+
+        title_val = st.text_input("备注名称", value=provider_cfg.title or "", key=title_key)
+        base_val = st.text_input("Base URL", value=provider_cfg.base_url or "", key=base_key, help="调用地址,例如:https://api.openai.com")
+        api_val = st.text_input("API Key", value=provider_cfg.api_key or "", key=api_key_key, type="password")
+        models_val = st.text_area("可用模型(每行一个)", value="\n".join(provider_cfg.models), key=models_key, height=100)
+        default_model_val = st.text_input("默认模型", value=provider_cfg.default_model or "", key=default_model_key)
+        mode_val = st.selectbox("调用模式", ["openai", "ollama"], index=0 if provider_cfg.mode == "openai" else 1, key=mode_key)
+        temp_val = st.slider("默认温度", min_value=0.0, max_value=2.0, value=float(provider_cfg.default_temperature), step=0.05, key=temp_key)
+        timeout_val = st.number_input("默认超时(秒)", min_value=5, max_value=300, value=int(provider_cfg.default_timeout or 30), step=5, key=timeout_key)
+        prompt_template_val = st.text_area("默认 Prompt 模板(可选,使用 {prompt} 占位)", value=provider_cfg.prompt_template or "", key=prompt_key, height=120)
+        enabled_val = st.checkbox("启用", value=provider_cfg.enabled, key=enabled_key)
+
+        fetch_key = f"fetch_models_{selected_provider}"
+        if st.button("获取模型列表", key=fetch_key):
+            with st.spinner("正在获取模型列表..."):
+                models, error = _discover_provider_models(provider_cfg, base_val, api_val)
+            if error:
+                st.error(error)
+            else:
+                provider_cfg.models = models
+                if models and (not provider_cfg.default_model or provider_cfg.default_model not in models):
+                    provider_cfg.default_model = models[0]
+                providers[selected_provider] = provider_cfg
+                cfg.llm_providers = providers
+                cfg.sync_runtime_llm()
+                save_config()
+                st.success(f"共获取 {len(models)} 个模型。")
+                st.session_state[models_key] = "\n".join(models)
+                st.session_state[default_model_key] = provider_cfg.default_model or ""
+
+        if st.button("保存 Provider", key=f"save_provider_{selected_provider}"):
+            provider_cfg.title = title_val.strip()
+            provider_cfg.base_url = base_val.strip()
+            provider_cfg.api_key = api_val.strip() or None
+            provider_cfg.models = [line.strip() for line in models_val.splitlines() if line.strip()]
+            provider_cfg.default_model = default_model_val.strip() or (provider_cfg.models[0] if provider_cfg.models else provider_cfg.default_model)
+            provider_cfg.default_temperature = float(temp_val)
+            provider_cfg.default_timeout = float(timeout_val)
+            provider_cfg.prompt_template = prompt_template_val.strip()
+            provider_cfg.enabled = enabled_val
+            provider_cfg.mode = mode_val
+            providers[selected_provider] = provider_cfg
+            cfg.llm_providers = providers
+            cfg.sync_runtime_llm()
+            save_config()
+            st.success("Provider 已保存。")
+            st.session_state[title_key] = provider_cfg.title or ""
+            st.session_state[default_model_key] = provider_cfg.default_model or ""
+
+        provider_in_use = (cfg.llm.primary.provider == selected_provider) or any(
+            ep.provider == selected_provider for ep in cfg.llm.ensemble
+        )
+        if not provider_in_use:
+            for dept in cfg.departments.values():
+                if dept.llm.primary.provider == selected_provider or any(ep.provider == selected_provider for ep in dept.llm.ensemble):
+                    provider_in_use = True
+                    break
+        if st.button(
+            "删除 Provider",
+            key=f"delete_provider_{selected_provider}",
+            disabled=provider_in_use or len(providers) <= 1,
+        ):
+            providers.pop(selected_provider, None)
+            cfg.llm_providers = providers
+            cfg.sync_runtime_llm()
+            save_config()
+            st.success("Provider 已删除。")
+            st.experimental_rerun()
+
+    st.markdown("##### 全局推理配置")
+    if not provider_keys:
+        st.warning("请先配置至少一个 Provider。")
+    else:
+        global_cfg = cfg.llm
+        primary = global_cfg.primary
+        try:
+            provider_index = provider_keys.index(primary.provider or provider_keys[0])
+        except ValueError:
+            provider_index = 0
+        selected_global_provider = st.selectbox(
+            "主模型 Provider",
+            provider_keys,
+            index=provider_index,
+            key="global_provider_select",
+        )
+        provider_cfg = providers.get(selected_global_provider)
+        available_models = provider_cfg.models if provider_cfg else []
+        default_model = primary.model or (provider_cfg.default_model if provider_cfg else None)
+        if available_models:
+            options = available_models + ["自定义"]
+            try:
+                model_index = available_models.index(default_model)
+                model_choice = st.selectbox("主模型", options, index=model_index, key="global_model_choice")
+            except ValueError:
+                model_choice = st.selectbox("主模型", options, index=len(options) - 1, key="global_model_choice")
+            if model_choice == "自定义":
+                model_val = st.text_input("自定义模型", value=default_model or "", key="global_model_custom").strip()
+            else:
+                model_val = model_choice
+        else:
+            model_val = st.text_input("主模型", value=default_model or "", key="global_model_custom").strip()
+
+        temp_default = primary.temperature if primary.temperature is not None else (provider_cfg.default_temperature if provider_cfg else 0.2)
+        temp_val = st.slider("主模型温度", min_value=0.0, max_value=2.0, value=float(temp_default), step=0.05, key="global_temp")
+        timeout_default = primary.timeout if primary.timeout is not None else (provider_cfg.default_timeout if provider_cfg else 30.0)
+        timeout_val = st.number_input("主模型超时(秒)", min_value=5, max_value=300, value=int(timeout_default), step=5, key="global_timeout")
+        prompt_template_val = st.text_area(
+            "主模型 Prompt 模板(可选)",
+            value=primary.prompt_template or provider_cfg.prompt_template if provider_cfg else "",
+            height=120,
+            key="global_prompt_template",
+        )
+
+        strategy_val = st.selectbox("推理策略", sorted(ALLOWED_LLM_STRATEGIES), index=sorted(ALLOWED_LLM_STRATEGIES).index(global_cfg.strategy) if global_cfg.strategy in ALLOWED_LLM_STRATEGIES else 0, key="global_strategy")
+        show_ensemble = strategy_val != "single"
+        majority_threshold_val = st.number_input(
+            "多数投票门槛",
+            min_value=1,
+            max_value=10,
+            value=int(global_cfg.majority_threshold),
+            step=1,
+            key="global_majority",
+            disabled=not show_ensemble,
+        )
+        if not show_ensemble:
+            majority_threshold_val = 1
+
+        ensemble_rows: List[Dict[str, str]] = []
+        if show_ensemble:
+            ensemble_rows = [
+                {
+                    "provider": ep.provider,
+                    "model": ep.model or "",
+                    "temperature": "" if ep.temperature is None else f"{ep.temperature:.3f}",
+                    "timeout": "" if ep.timeout is None else str(int(ep.timeout)),
+                    "prompt_template": ep.prompt_template or "",
+                }
+                for ep in global_cfg.ensemble
+            ] or [{"provider": primary.provider or selected_global_provider, "model": "", "temperature": "", "timeout": "", "prompt_template": ""}]
+
+            ensemble_editor = st.data_editor(
+                ensemble_rows,
+                num_rows="dynamic",
+                key="global_ensemble_editor",
+                use_container_width=True,
+                hide_index=True,
+                column_config={
+                    "provider": st.column_config.SelectboxColumn("Provider", options=provider_keys),
+                    "model": st.column_config.TextColumn("模型"),
+                    "temperature": st.column_config.TextColumn("温度"),
+                    "timeout": st.column_config.TextColumn("超时(秒)"),
+                    "prompt_template": st.column_config.TextColumn("Prompt 模板"),
+                },
+            )
+            if hasattr(ensemble_editor, "to_dict"):
+                ensemble_rows = ensemble_editor.to_dict("records")
+            else:
+                ensemble_rows = ensemble_editor
+        else:
+            st.info("当前策略为单模型,未启用协作模型。")
+
+        if st.button("保存全局配置", key="save_global_llm"):
+            primary.provider = selected_global_provider
+            primary.model = model_val or None
+            primary.temperature = float(temp_val)
+            primary.timeout = float(timeout_val)
+            primary.prompt_template = prompt_template_val.strip() or None
+            primary.base_url = None
+            primary.api_key = None
+
+            new_ensemble: List[LLMEndpoint] = []
+            if show_ensemble:
+                for row in ensemble_rows:
+                    provider_val = (row.get("provider") or "").strip().lower()
+                    if not provider_val:
+                        continue
+                    model_raw = (row.get("model") or "").strip() or None
+                    temp_raw = (row.get("temperature") or "").strip()
+                    timeout_raw = (row.get("timeout") or "").strip()
+                    prompt_raw = (row.get("prompt_template") or "").strip()
+                    new_ensemble.append(
+                        LLMEndpoint(
+                            provider=provider_val,
+                            model=model_raw,
+                            temperature=float(temp_raw) if temp_raw else None,
+                            timeout=float(timeout_raw) if timeout_raw else None,
+                            prompt_template=prompt_raw or None,
+                        )
+                    )
+            cfg.llm.ensemble = new_ensemble
+            cfg.llm.strategy = strategy_val
+            cfg.llm.majority_threshold = int(majority_threshold_val)
+            cfg.sync_runtime_llm()
+            save_config()
+            st.success("全局 LLM 配置已保存。")
+            st.json(llm_config_snapshot())
+
+    # Department configuration -------------------------------------------
+    st.markdown("##### 部门配置")
     dept_settings = cfg.departments or {}
-    route_options_display = [""] + route_keys
     dept_rows = [
         {
             "code": code,
             "title": dept.title,
             "description": dept.description,
             "weight": float(dept.weight),
-            "llm_route": dept.llm_route or "",
             "strategy": dept.llm.strategy,
-            "primary_provider": (dept.llm.primary.provider or ""),
-            "primary_model": dept.llm.primary.model or "",
-            "ensemble_size": len(dept.llm.ensemble),
+            "majority_threshold": dept.llm.majority_threshold,
+            "provider": dept.llm.primary.provider or (provider_keys[0] if provider_keys else ""),
+            "model": dept.llm.primary.model or "",
+            "temperature": "" if dept.llm.primary.temperature is None else f"{dept.llm.primary.temperature:.3f}",
+            "timeout": "" if dept.llm.primary.timeout is None else str(int(dept.llm.primary.timeout)),
+            "prompt_template": dept.llm.primary.prompt_template or "",
         }
         for code, dept in sorted(dept_settings.items())
     ]
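One detail worth calling out in the hunk above: the `hasattr(ensemble_editor, "to_dict")` check exists because, depending on the input type and Streamlit version, `st.data_editor` may hand back a pandas DataFrame or the original list of dicts. A sketch of that normalization under those assumptions:

```python
# Sketch: normalize st.data_editor output to a list of dict rows.
import pandas as pd


def rows_from_editor(editor_result) -> list[dict]:
    if hasattr(editor_result, "to_dict"):          # DataFrame-like result
        return editor_result.to_dict("records")
    return list(editor_result)                     # already a list of dicts


print(rows_from_editor(pd.DataFrame([{"provider": "ollama", "model": ""}])))
print(rows_from_editor([{"provider": "ollama", "model": ""}]))
```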
@@ -618,26 +676,13 @@ def render_settings() -> None:
             "title": st.column_config.TextColumn("名称"),
             "description": st.column_config.TextColumn("说明"),
             "weight": st.column_config.NumberColumn("权重", min_value=0.0, max_value=10.0, step=0.1),
-            "llm_route": st.column_config.SelectboxColumn(
-                "路由",
-                options=route_options_display,
-                help="选择预定义路由;留空表示使用自定义配置",
-            ),
-            "strategy": st.column_config.SelectboxColumn(
-                "自定义策略",
-                options=sorted(ALLOWED_LLM_STRATEGIES),
-                help="仅当未选择路由时生效",
-            ),
-            "primary_provider": st.column_config.SelectboxColumn(
-                "自定义 Provider",
-                options=sorted(DEFAULT_LLM_MODEL_OPTIONS.keys()),
-            ),
-            "primary_model": st.column_config.TextColumn("自定义模型"),
-            "ensemble_size": st.column_config.NumberColumn(
-                "协作模型数量",
-                disabled=True,
-                help="路由模式下自动维护",
-            ),
+            "strategy": st.column_config.SelectboxColumn("策略", options=sorted(ALLOWED_LLM_STRATEGIES)),
+            "majority_threshold": st.column_config.NumberColumn("投票阈值", min_value=1, max_value=10, step=1),
+            "provider": st.column_config.SelectboxColumn("Provider", options=provider_keys or [""]),
+            "model": st.column_config.TextColumn("模型"),
+            "temperature": st.column_config.TextColumn("温度"),
+            "timeout": st.column_config.TextColumn("超时(秒)"),
+            "prompt_template": st.column_config.TextColumn("Prompt 模板"),
         },
     )
 
@@ -662,25 +707,34 @@ def render_settings() -> None:
         except (TypeError, ValueError):
             pass
 
-        route_name = (row.get("llm_route") or "").strip() or None
-        existing.llm_route = route_name
-        if route_name and route_name in routes:
-            existing.llm = routes[route_name].resolve(profiles)
-        else:
-            strategy_val = (row.get("strategy") or existing.llm.strategy).lower()
-            if strategy_val in ALLOWED_LLM_STRATEGIES:
-                existing.llm.strategy = strategy_val
-            provider_before = existing.llm.primary.provider or ""
-            provider_val = (row.get("primary_provider") or provider_before or "ollama").lower()
-            existing.llm.primary.provider = provider_val
-            model_val = (row.get("primary_model") or "").strip()
-            existing.llm.primary.model = (
-                model_val or DEFAULT_LLM_MODELS.get(provider_val, existing.llm.primary.model)
-            )
-            if provider_before != provider_val:
-                default_base = DEFAULT_LLM_BASE_URLS.get(provider_val)
-                existing.llm.primary.base_url = default_base or existing.llm.primary.base_url
-                existing.llm.primary.__post_init__()
+        strategy_val = (row.get("strategy") or existing.llm.strategy).lower()
+        if strategy_val in ALLOWED_LLM_STRATEGIES:
+            existing.llm.strategy = strategy_val
+        majority_raw = row.get("majority_threshold")
+        try:
+            majority_val = int(majority_raw)
+            if majority_val > 0:
+                existing.llm.majority_threshold = majority_val
+        except (TypeError, ValueError):
+            pass
+
+        provider_val = (row.get("provider") or existing.llm.primary.provider or (provider_keys[0] if provider_keys else "ollama")).strip().lower()
+        model_val = (row.get("model") or "").strip() or None
+        temp_raw = (row.get("temperature") or "").strip()
+        timeout_raw = (row.get("timeout") or "").strip()
+        prompt_raw = (row.get("prompt_template") or "").strip()
+
+        endpoint = existing.llm.primary or LLMEndpoint()
+        endpoint.provider = provider_val
+        endpoint.model = model_val
+        endpoint.temperature = float(temp_raw) if temp_raw else None
+        endpoint.timeout = float(timeout_raw) if timeout_raw else None
+        endpoint.prompt_template = prompt_raw or None
+        endpoint.base_url = None
+        endpoint.api_key = None
+        existing.llm.primary = endpoint
+        existing.llm.ensemble = []
 
         updated_departments[code] = existing
 
     if updated_departments:
@@ -700,8 +754,7 @@ def render_settings() -> None:
         st.success("已恢复默认部门配置。")
         st.experimental_rerun()
 
-    st.caption("选择路由可统一部门模型调用,自定义模式仍支持逐项配置。")
-    st.caption("部门协作模型(ensemble)请在 config.json 中手动编辑,UI 将在后续版本补充。")
+    st.caption("部门配置存储为独立 LLM 参数,执行时会自动套用对应 Provider 的连接信息。")
 
 
 def render_tests() -> None:
@ -5,7 +5,7 @@ from dataclasses import dataclass, field
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Mapping, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
def _default_root() -> Path:
|
def _default_root() -> Path:
|
||||||
@@ -99,27 +99,55 @@ ALLOWED_LLM_STRATEGIES = {"single", "majority", "leader"}
 LLM_STRATEGY_ALIASES = {"leader-follower": "leader"}
 
 
+@dataclass
+class LLMProvider:
+    """Provider level configuration shared across profiles and routes."""
+
+    key: str
+    title: str = ""
+    base_url: str = ""
+    api_key: Optional[str] = None
+    models: List[str] = field(default_factory=list)
+    default_model: Optional[str] = None
+    default_temperature: float = 0.2
+    default_timeout: float = 30.0
+    prompt_template: str = ""
+    enabled: bool = True
+    mode: str = "openai"  # "openai" or "ollama"
+
+    def to_dict(self) -> Dict[str, object]:
+        return {
+            "title": self.title,
+            "base_url": self.base_url,
+            "api_key": self.api_key,
+            "models": list(self.models),
+            "default_model": self.default_model,
+            "default_temperature": self.default_temperature,
+            "default_timeout": self.default_timeout,
+            "prompt_template": self.prompt_template,
+            "enabled": self.enabled,
+            "mode": self.mode,
+        }
+
+
 @dataclass
 class LLMEndpoint:
-    """Single LLM endpoint configuration."""
+    """Resolved endpoint payload used for actual LLM calls."""
 
     provider: str = "ollama"
     model: Optional[str] = None
     base_url: Optional[str] = None
     api_key: Optional[str] = None
-    temperature: float = 0.2
-    timeout: float = 30.0
+    temperature: Optional[float] = None
+    timeout: Optional[float] = None
+    prompt_template: Optional[str] = None
 
     def __post_init__(self) -> None:
         self.provider = (self.provider or "ollama").lower()
-        if not self.model:
-            self.model = DEFAULT_LLM_MODELS.get(self.provider, DEFAULT_LLM_MODELS["ollama"])
-        if not self.base_url:
-            self.base_url = DEFAULT_LLM_BASE_URLS.get(self.provider)
-        if self.temperature == 0.2 or self.temperature is None:
-            self.temperature = DEFAULT_LLM_TEMPERATURES.get(self.provider, 0.2)
-        if self.timeout == 30.0 or self.timeout is None:
-            self.timeout = DEFAULT_LLM_TIMEOUTS.get(self.provider, 30.0)
+        if self.temperature is not None:
+            self.temperature = float(self.temperature)
+        if self.timeout is not None:
+            self.timeout = float(self.timeout)
 
 
 @dataclass
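A round-trip sketch for the new `LLMProvider` dataclass: construct, serialize with `to_dict()`, and rebuild. The fields are copied from the hunk above; only the sample values are invented:

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class LLMProvider:
    key: str
    title: str = ""
    base_url: str = ""
    api_key: Optional[str] = None
    models: List[str] = field(default_factory=list)
    default_model: Optional[str] = None
    default_temperature: float = 0.2
    default_timeout: float = 30.0
    prompt_template: str = ""
    enabled: bool = True
    mode: str = "openai"

    def to_dict(self) -> Dict[str, object]:
        return {
            "title": self.title, "base_url": self.base_url,
            "api_key": self.api_key, "models": list(self.models),
            "default_model": self.default_model,
            "default_temperature": self.default_temperature,
            "default_timeout": self.default_timeout,
            "prompt_template": self.prompt_template,
            "enabled": self.enabled, "mode": self.mode,
        }


p = LLMProvider(key="ollama", base_url="http://localhost:11434", mode="ollama")
# to_dict() deliberately omits "key", matching how config.json nests entries
# under the provider key; pass it back in separately when rebuilding.
restored = LLMProvider(key="ollama", **p.to_dict())
assert restored.base_url == p.base_url
```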
@@ -132,133 +160,22 @@ class LLMConfig:
     majority_threshold: int = 3
 
 
-@dataclass
-class LLMProfile:
-    """Named LLM endpoint profile reusable across routes/departments."""
-
-    key: str
-    provider: str = "ollama"
-    model: Optional[str] = None
-    base_url: Optional[str] = None
-    api_key: Optional[str] = None
-    temperature: float = 0.2
-    timeout: float = 30.0
-    title: str = ""
-    enabled: bool = True
-
-    def to_endpoint(self) -> LLMEndpoint:
-        return LLMEndpoint(
-            provider=self.provider,
-            model=self.model,
-            base_url=self.base_url,
-            api_key=self.api_key,
-            temperature=self.temperature,
-            timeout=self.timeout,
-        )
-
-    def to_dict(self) -> Dict[str, object]:
-        return {
-            "provider": self.provider,
-            "model": self.model,
-            "base_url": self.base_url,
-            "api_key": self.api_key,
-            "temperature": self.temperature,
-            "timeout": self.timeout,
-            "title": self.title,
-            "enabled": self.enabled,
-        }
-
-    @classmethod
-    def from_endpoint(
-        cls,
-        key: str,
-        endpoint: LLMEndpoint,
-        *,
-        title: str = "",
-        enabled: bool = True,
-    ) -> "LLMProfile":
-        return cls(
-            key=key,
-            provider=endpoint.provider,
-            model=endpoint.model,
-            base_url=endpoint.base_url,
-            api_key=endpoint.api_key,
-            temperature=endpoint.temperature,
-            timeout=endpoint.timeout,
-            title=title,
-            enabled=enabled,
-        )
-
-
-@dataclass
-class LLMRoute:
-    """Declarative routing for selecting profiles and strategy."""
-
-    name: str
-    title: str = ""
-    strategy: str = "single"
-    majority_threshold: int = 3
-    primary: str = "ollama"
-    ensemble: List[str] = field(default_factory=list)
-
-    def resolve(self, profiles: Mapping[str, LLMProfile]) -> LLMConfig:
-        def _endpoint_from_key(key: str) -> LLMEndpoint:
-            profile = profiles.get(key)
-            if profile and profile.enabled:
-                return profile.to_endpoint()
-            fallback = profiles.get("ollama")
-            if not fallback or not fallback.enabled:
-                fallback = next(
-                    (item for item in profiles.values() if item.enabled),
-                    None,
-                )
-            endpoint = fallback.to_endpoint() if fallback else LLMEndpoint()
-            endpoint.provider = key or endpoint.provider
-            return endpoint
-
-        primary_endpoint = _endpoint_from_key(self.primary)
-        ensemble_endpoints = [
-            _endpoint_from_key(key)
-            for key in self.ensemble
-            if key in profiles and profiles[key].enabled
-        ]
-        config = LLMConfig(
-            primary=primary_endpoint,
-            ensemble=ensemble_endpoints,
-            strategy=self.strategy if self.strategy in ALLOWED_LLM_STRATEGIES else "single",
-            majority_threshold=max(1, self.majority_threshold or 1),
-        )
-        return config
-
-    def to_dict(self) -> Dict[str, object]:
-        return {
-            "title": self.title,
-            "strategy": self.strategy,
-            "majority_threshold": self.majority_threshold,
-            "primary": self.primary,
-            "ensemble": list(self.ensemble),
-        }
-
-
-def _default_llm_profiles() -> Dict[str, LLMProfile]:
-    return {
-        provider: LLMProfile(
-            key=provider,
-            provider=provider,
-            model=DEFAULT_LLM_MODELS.get(provider),
-            base_url=DEFAULT_LLM_BASE_URLS.get(provider),
-            temperature=DEFAULT_LLM_TEMPERATURES.get(provider, 0.2),
-            timeout=DEFAULT_LLM_TIMEOUTS.get(provider, 30.0),
-            title=f"默认 {provider}",
-        )
-        for provider in DEFAULT_LLM_MODEL_OPTIONS
-    }
-
-
-def _default_llm_routes() -> Dict[str, LLMRoute]:
-    return {
-        "global": LLMRoute(name="global", title="全局默认路由"),
-    }
+def _default_llm_providers() -> Dict[str, LLMProvider]:
+    providers: Dict[str, LLMProvider] = {}
+    for provider, meta in DEFAULT_LLM_MODEL_OPTIONS.items():
+        models = list(meta.get("models", []))
+        mode = "ollama" if provider == "ollama" else "openai"
+        providers[provider] = LLMProvider(
+            key=provider,
+            title=f"默认 {provider}",
+            base_url=str(meta.get("base_url", DEFAULT_LLM_BASE_URLS.get(provider, "")) or ""),
+            models=models,
+            default_model=models[0] if models else DEFAULT_LLM_MODELS.get(provider),
+            default_temperature=float(meta.get("temperature", DEFAULT_LLM_TEMPERATURES.get(provider, 0.2))),
+            default_timeout=float(meta.get("timeout", DEFAULT_LLM_TIMEOUTS.get(provider, 30.0))),
+            mode=mode,
+        )
+    return providers
 
 
 @dataclass
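A sketch of the seeding behavior in `_default_llm_providers()`: every entry in `DEFAULT_LLM_MODEL_OPTIONS` becomes a Provider, with `ollama` flipped to local mode. The metadata dict below is an invented stand-in for the module's real table:

```python
# Stand-in metadata; the real DEFAULT_LLM_MODEL_OPTIONS lives in config.py.
DEFAULT_LLM_MODEL_OPTIONS = {
    "ollama": {"models": ["llama3"], "base_url": "http://localhost:11434"},
    "deepseek": {"models": ["deepseek-chat"], "base_url": "https://api.deepseek.com"},
}

providers = {}
for provider, meta in DEFAULT_LLM_MODEL_OPTIONS.items():
    models = list(meta.get("models", []))
    providers[provider] = {
        "key": provider,
        "mode": "ollama" if provider == "ollama" else "openai",
        "base_url": str(meta.get("base_url", "")),
        "models": models,
        "default_model": models[0] if models else None,
    }

assert providers["ollama"]["mode"] == "ollama"
assert providers["deepseek"]["default_model"] == "deepseek-chat"
```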
@ -270,7 +187,6 @@ class DepartmentSettings:
|
|||||||
description: str = ""
|
description: str = ""
|
||||||
weight: float = 1.0
|
weight: float = 1.0
|
||||||
llm: LLMConfig = field(default_factory=LLMConfig)
|
llm: LLMConfig = field(default_factory=LLMConfig)
|
||||||
llm_route: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
def _default_departments() -> Dict[str, DepartmentSettings]:
|
def _default_departments() -> Dict[str, DepartmentSettings]:
|
||||||
@ -282,10 +198,7 @@ def _default_departments() -> Dict[str, DepartmentSettings]:
|
|||||||
("macro", "宏观研究部门"),
|
("macro", "宏观研究部门"),
|
||||||
("risk", "风险控制部门"),
|
("risk", "风险控制部门"),
|
||||||
]
|
]
|
||||||
return {
|
return {code: DepartmentSettings(code=code, title=title) for code, title in presets}
|
||||||
code: DepartmentSettings(code=code, title=title, llm_route="global")
|
|
||||||
for code, title in presets
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -298,17 +211,11 @@ class AppConfig:
|
|||||||
data_paths: DataPaths = field(default_factory=DataPaths)
|
data_paths: DataPaths = field(default_factory=DataPaths)
|
||||||
agent_weights: AgentWeights = field(default_factory=AgentWeights)
|
agent_weights: AgentWeights = field(default_factory=AgentWeights)
|
||||||
force_refresh: bool = False
|
force_refresh: bool = False
|
||||||
|
llm_providers: Dict[str, LLMProvider] = field(default_factory=_default_llm_providers)
|
||||||
llm: LLMConfig = field(default_factory=LLMConfig)
|
llm: LLMConfig = field(default_factory=LLMConfig)
|
||||||
llm_route: str = "global"
|
|
||||||
llm_profiles: Dict[str, LLMProfile] = field(default_factory=_default_llm_profiles)
|
|
||||||
llm_routes: Dict[str, LLMRoute] = field(default_factory=_default_llm_routes)
|
|
||||||
departments: Dict[str, DepartmentSettings] = field(default_factory=_default_departments)
|
departments: Dict[str, DepartmentSettings] = field(default_factory=_default_departments)
|
||||||
|
|
||||||
def resolve_llm(self, route: Optional[str] = None) -> LLMConfig:
|
def resolve_llm(self, route: Optional[str] = None) -> LLMConfig:
|
||||||
route_key = route or self.llm_route
|
|
||||||
route_cfg = self.llm_routes.get(route_key)
|
|
||||||
if route_cfg:
|
|
||||||
return route_cfg.resolve(self.llm_profiles)
|
|
||||||
return self.llm
|
return self.llm
|
||||||
|
|
||||||
def sync_runtime_llm(self) -> None:
|
def sync_runtime_llm(self) -> None:
|
||||||
@ -326,13 +233,22 @@ def _endpoint_to_dict(endpoint: LLMEndpoint) -> Dict[str, object]:
|
|||||||
"api_key": endpoint.api_key,
|
"api_key": endpoint.api_key,
|
||||||
"temperature": endpoint.temperature,
|
"temperature": endpoint.temperature,
|
||||||
"timeout": endpoint.timeout,
|
"timeout": endpoint.timeout,
|
||||||
|
"prompt_template": endpoint.prompt_template,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _dict_to_endpoint(data: Dict[str, object]) -> LLMEndpoint:
|
def _dict_to_endpoint(data: Dict[str, object]) -> LLMEndpoint:
|
||||||
payload = {
|
payload = {
|
||||||
key: data.get(key)
|
key: data.get(key)
|
||||||
for key in ("provider", "model", "base_url", "api_key", "temperature", "timeout")
|
for key in (
|
||||||
|
"provider",
|
||||||
|
"model",
|
||||||
|
"base_url",
|
||||||
|
"api_key",
|
||||||
|
"temperature",
|
||||||
|
"timeout",
|
||||||
|
"prompt_template",
|
||||||
|
)
|
||||||
if data.get(key) is not None
|
if data.get(key) is not None
|
||||||
}
|
}
|
||||||
return LLMEndpoint(**payload)
|
return LLMEndpoint(**payload)
|
||||||
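The pair above filters out `None` on save so dataclass defaults apply on load; a sketch of that round trip with a trimmed stand-in for `LLMEndpoint`:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class LLMEndpoint:
    provider: str = "ollama"
    model: Optional[str] = None
    temperature: Optional[float] = None
    prompt_template: Optional[str] = None


def dict_to_endpoint(data: dict) -> LLMEndpoint:
    keys = ("provider", "model", "temperature", "prompt_template")
    # None values are dropped, mirroring _dict_to_endpoint in the hunk above.
    payload = {k: data.get(k) for k in keys if data.get(k) is not None}
    return LLMEndpoint(**payload)


ep = dict_to_endpoint({"provider": "deepseek", "model": None, "temperature": 0.3})
assert ep.model is None and ep.temperature == 0.3
```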
@ -348,139 +264,148 @@ def _load_from_file(cfg: AppConfig) -> None:
|
|||||||
except (json.JSONDecodeError, OSError):
|
except (json.JSONDecodeError, OSError):
|
||||||
return
|
return
|
||||||
|
|
||||||
if isinstance(payload, dict):
|
if not isinstance(payload, dict):
|
||||||
if "tushare_token" in payload:
|
return
|
||||||
cfg.tushare_token = payload.get("tushare_token") or None
|
|
||||||
if "force_refresh" in payload:
|
|
||||||
cfg.force_refresh = bool(payload.get("force_refresh"))
|
|
||||||
if "decision_method" in payload:
|
|
||||||
cfg.decision_method = str(payload.get("decision_method") or cfg.decision_method)
|
|
||||||
|
|
||||||
routes_defined = False
|
if "tushare_token" in payload:
|
||||||
inline_primary_loaded = False
|
cfg.tushare_token = payload.get("tushare_token") or None
|
||||||
|
if "force_refresh" in payload:
|
||||||
|
cfg.force_refresh = bool(payload.get("force_refresh"))
|
||||||
|
if "decision_method" in payload:
|
||||||
|
cfg.decision_method = str(payload.get("decision_method") or cfg.decision_method)
|
||||||
|
|
||||||
profiles_payload = payload.get("llm_profiles")
|
legacy_profiles: Dict[str, Dict[str, object]] = {}
|
||||||
if isinstance(profiles_payload, dict):
|
legacy_routes: Dict[str, Dict[str, object]] = {}
|
||||||
profiles: Dict[str, LLMProfile] = {}
|
|
||||||
for key, data in profiles_payload.items():
|
|
||||||
if not isinstance(data, dict):
|
|
||||||
continue
|
|
||||||
provider = str(data.get("provider") or "ollama").lower()
|
|
||||||
profile = LLMProfile(
|
|
||||||
key=key,
|
|
||||||
provider=provider,
|
|
||||||
model=data.get("model"),
|
|
||||||
base_url=data.get("base_url"),
|
|
||||||
api_key=data.get("api_key"),
|
|
||||||
temperature=float(data.get("temperature", DEFAULT_LLM_TEMPERATURES.get(provider, 0.2))),
|
|
||||||
timeout=float(data.get("timeout", DEFAULT_LLM_TIMEOUTS.get(provider, 30.0))),
|
|
||||||
title=str(data.get("title") or ""),
|
|
||||||
enabled=bool(data.get("enabled", True)),
|
|
||||||
)
|
|
||||||
profiles[key] = profile
|
|
||||||
if profiles:
|
|
||||||
cfg.llm_profiles = profiles
|
|
||||||
|
|
||||||
routes_payload = payload.get("llm_routes")
|
providers_payload = payload.get("llm_providers")
|
||||||
if isinstance(routes_payload, dict):
|
if isinstance(providers_payload, dict):
|
||||||
routes: Dict[str, LLMRoute] = {}
|
providers: Dict[str, LLMProvider] = {}
|
||||||
for name, data in routes_payload.items():
|
for key, data in providers_payload.items():
|
||||||
if not isinstance(data, dict):
|
if not isinstance(data, dict):
|
||||||
continue
|
continue
|
||||||
strategy_raw = str(data.get("strategy") or "single").lower()
|
models_raw = data.get("models")
|
||||||
normalized = LLM_STRATEGY_ALIASES.get(strategy_raw, strategy_raw)
|
if isinstance(models_raw, str):
|
||||||
route = LLMRoute(
|
models = [item.strip() for item in models_raw.split(',') if item.strip()]
|
||||||
name=name,
|
elif isinstance(models_raw, list):
|
||||||
title=str(data.get("title") or ""),
|
models = [str(item).strip() for item in models_raw if str(item).strip()]
|
||||||
strategy=normalized if normalized in ALLOWED_LLM_STRATEGIES else "single",
|
else:
|
||||||
majority_threshold=max(1, int(data.get("majority_threshold", 3) or 3)),
|
models = []
|
||||||
primary=str(data.get("primary") or "global"),
|
provider = LLMProvider(
|
||||||
ensemble=[
|
key=str(key).lower(),
|
||||||
str(item)
|
title=str(data.get("title") or ""),
|
||||||
for item in data.get("ensemble", [])
|
base_url=str(data.get("base_url") or ""),
|
||||||
if isinstance(item, str)
|
api_key=data.get("api_key"),
|
||||||
],
|
models=models,
|
||||||
)
|
default_model=data.get("default_model") or (models[0] if models else None),
|
||||||
routes[name] = route
|
default_temperature=float(data.get("default_temperature", 0.2)),
|
||||||
if routes:
|
default_timeout=float(data.get("default_timeout", 30.0)),
|
||||||
cfg.llm_routes = routes
|
prompt_template=str(data.get("prompt_template") or ""),
|
||||||
routes_defined = True
|
enabled=bool(data.get("enabled", True)),
|
||||||
|
mode=str(data.get("mode") or ("ollama" if str(key).lower() == "ollama" else "openai")),
|
||||||
|
)
|
||||||
|
providers[provider.key] = provider
|
||||||
|
if providers:
|
||||||
|
cfg.llm_providers = providers
|
||||||
|
|
||||||
route_key = payload.get("llm_route")
|
profiles_payload = payload.get("llm_profiles")
|
||||||
if isinstance(route_key, str) and route_key:
|
if isinstance(profiles_payload, dict):
|
||||||
cfg.llm_route = route_key
|
for key, data in profiles_payload.items():
|
||||||
|
if isinstance(data, dict):
|
||||||
|
legacy_profiles[str(key)] = data
|
||||||
|
|
||||||
llm_payload = payload.get("llm")
|
routes_payload = payload.get("llm_routes")
|
||||||
if isinstance(llm_payload, dict):
|
if isinstance(routes_payload, dict):
|
||||||
route_value = llm_payload.get("route")
|
for name, data in routes_payload.items():
|
||||||
if isinstance(route_value, str) and route_value:
|
if isinstance(data, dict):
|
||||||
cfg.llm_route = route_value
|
legacy_routes[str(name)] = data
|
||||||
|
|
||||||
|
def _endpoint_from_payload(item: object) -> LLMEndpoint:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
return _dict_to_endpoint(item)
|
||||||
|
if isinstance(item, str):
|
||||||
|
profile_data = legacy_profiles.get(item)
|
||||||
|
if isinstance(profile_data, dict):
|
||||||
|
return _dict_to_endpoint(profile_data)
|
||||||
|
return LLMEndpoint(provider=item)
|
||||||
|
return LLMEndpoint()
|
||||||
|
|
||||||
|
def _resolve_route(route_name: str) -> Optional[LLMConfig]:
|
||||||
|
route_data = legacy_routes.get(route_name)
|
||||||
|
if not route_data:
|
||||||
|
return None
|
||||||
|
strategy_raw = str(route_data.get("strategy") or "single").lower()
|
||||||
|
strategy = LLM_STRATEGY_ALIASES.get(strategy_raw, strategy_raw)
|
||||||
|
primary_ref = route_data.get("primary")
|
||||||
|
primary_ep = _endpoint_from_payload(primary_ref)
|
||||||
|
ensemble_refs = route_data.get("ensemble", [])
|
||||||
|
ensemble_eps = [
|
||||||
|
_endpoint_from_payload(ref)
|
||||||
|
for ref in ensemble_refs
|
||||||
|
if isinstance(ref, (dict, str))
|
||||||
|
]
|
||||||
|
cfg_obj = LLMConfig(
|
||||||
|
primary=primary_ep,
|
||||||
|
ensemble=ensemble_eps,
|
||||||
|
strategy=strategy if strategy in ALLOWED_LLM_STRATEGIES else "single",
|
||||||
|
majority_threshold=max(1, int(route_data.get("majority_threshold", 3) or 3)),
|
||||||
|
)
|
||||||
|
return cfg_obj
|
||||||
|
|
+    llm_payload = payload.get("llm")
+    if isinstance(llm_payload, dict):
+        route_value = llm_payload.get("route")
+        resolved_cfg = None
+        if isinstance(route_value, str) and route_value:
+            resolved_cfg = _resolve_route(route_value)
+        if resolved_cfg is None:
+            resolved_cfg = LLMConfig()
         primary_data = llm_payload.get("primary")
         if isinstance(primary_data, dict):
-            cfg.llm.primary = _dict_to_endpoint(primary_data)
-            inline_primary_loaded = True
+            resolved_cfg.primary = _dict_to_endpoint(primary_data)

         ensemble_data = llm_payload.get("ensemble")
         if isinstance(ensemble_data, list):
-            cfg.llm.ensemble = [
+            resolved_cfg.ensemble = [
                 _dict_to_endpoint(item)
                 for item in ensemble_data
                 if isinstance(item, dict)
             ]

         strategy_raw = llm_payload.get("strategy")
         if isinstance(strategy_raw, str):
             normalized = LLM_STRATEGY_ALIASES.get(strategy_raw, strategy_raw)
             if normalized in ALLOWED_LLM_STRATEGIES:
-                cfg.llm.strategy = normalized
+                resolved_cfg.strategy = normalized

         majority = llm_payload.get("majority_threshold")
         if isinstance(majority, int) and majority > 0:
-            cfg.llm.majority_threshold = majority
+            resolved_cfg.majority_threshold = majority
+        cfg.llm = resolved_cfg
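The resolution order for the global `llm` block is therefore: a named legacy route first (if it still resolves), otherwise a fresh `LLMConfig`, with any inline `primary`/`ensemble`/`strategy`/`majority_threshold` fields overriding the result field by field. An illustrative payload exercising both paths (values are placeholders):

```python
# Illustrative "llm" section of config.json: the legacy "route" (if it still
# exists in llm_routes) seeds the config, then the inline fields override it.
llm_payload = {
    "route": "global",                                   # legacy, optional
    "primary": {"provider": "deepseek", "model": "deepseek-chat"},
    "ensemble": [],
    "strategy": "single",          # kept only if in ALLOWED_LLM_STRATEGIES
    "majority_threshold": 3,
}
```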
-    if inline_primary_loaded and not routes_defined:
-        primary_key = "inline_global_primary"
-        cfg.llm_profiles[primary_key] = LLMProfile.from_endpoint(
-            primary_key,
-            cfg.llm.primary,
-            title="全局主模型",
-        )
-        ensemble_keys: List[str] = []
-        for idx, endpoint in enumerate(cfg.llm.ensemble, start=1):
-            inline_key = f"inline_global_ensemble_{idx}"
-            cfg.llm_profiles[inline_key] = LLMProfile.from_endpoint(
-                inline_key,
-                endpoint,
-                title=f"全局协作#{idx}",
-            )
-            ensemble_keys.append(inline_key)
-        auto_route = cfg.llm_routes.get("global") or LLMRoute(name="global", title="全局默认路由")
-        auto_route.strategy = cfg.llm.strategy
-        auto_route.majority_threshold = cfg.llm.majority_threshold
-        auto_route.primary = primary_key
-        auto_route.ensemble = ensemble_keys
-        cfg.llm_routes["global"] = auto_route
-        cfg.llm_route = cfg.llm_route or "global"
-
-    departments_payload = payload.get("departments")
-    if isinstance(departments_payload, dict):
-        new_departments: Dict[str, DepartmentSettings] = {}
-        for code, data in departments_payload.items():
-            if not isinstance(data, dict):
-                continue
-            title = data.get("title") or code
-            description = data.get("description") or ""
-            weight = float(data.get("weight", 1.0))
+    departments_payload = payload.get("departments")
+    if isinstance(departments_payload, dict):
+        new_departments: Dict[str, DepartmentSettings] = {}
+        for code, data in departments_payload.items():
+            if not isinstance(data, dict):
+                continue
+            title = data.get("title") or code
+            description = data.get("description") or ""
+            weight = float(data.get("weight", 1.0))
+            llm_cfg = LLMConfig()
+            route_name = data.get("llm_route")
+            resolved_cfg = None
+            if isinstance(route_name, str) and route_name:
+                resolved_cfg = _resolve_route(route_name)
+            if resolved_cfg is None:
                 llm_data = data.get("llm")
-            llm_cfg = LLMConfig()
                 if isinstance(llm_data, dict):
-                if isinstance(llm_data.get("primary"), dict):
-                    llm_cfg.primary = _dict_to_endpoint(llm_data["primary"])
-                llm_cfg.ensemble = [
-                    _dict_to_endpoint(item)
-                    for item in llm_data.get("ensemble", [])
-                    if isinstance(item, dict)
-                ]
+                    primary_data = llm_data.get("primary")
+                    if isinstance(primary_data, dict):
+                        llm_cfg.primary = _dict_to_endpoint(primary_data)
+                    ensemble_data = llm_data.get("ensemble")
+                    if isinstance(ensemble_data, list):
+                        llm_cfg.ensemble = [
+                            _dict_to_endpoint(item)
+                            for item in ensemble_data
+                            if isinstance(item, dict)
+                        ]
                     strategy_raw = llm_data.get("strategy")
                     if isinstance(strategy_raw, str):
                         normalized = LLM_STRATEGY_ALIASES.get(strategy_raw, strategy_raw)

@@ -489,23 +414,18 @@ def _load_from_file(cfg: AppConfig) -> None:
                     majority_raw = llm_data.get("majority_threshold")
                     if isinstance(majority_raw, int) and majority_raw > 0:
                         llm_cfg.majority_threshold = majority_raw
-            route = data.get("llm_route")
-            route_name = str(route).strip() if isinstance(route, str) and route else None
-            resolved = llm_cfg
-            if route_name and route_name in cfg.llm_routes:
-                resolved = cfg.llm_routes[route_name].resolve(cfg.llm_profiles)
-            new_departments[code] = DepartmentSettings(
-                code=code,
-                title=title,
-                description=description,
-                weight=weight,
-                llm=resolved,
-                llm_route=route_name,
-            )
-        if new_departments:
-            cfg.departments = new_departments
+                resolved_cfg = llm_cfg
+            new_departments[code] = DepartmentSettings(
+                code=code,
+                title=title,
+                description=description,
+                weight=weight,
+                llm=resolved_cfg,
+            )
+        if new_departments:
+            cfg.departments = new_departments

     cfg.sync_runtime_llm()
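A department entry now carries its own `llm` block; `llm_route` survives purely as a legacy pointer that wins only when `_resolve_route` can still find it, and the resolved config is stored directly on `DepartmentSettings` (the `llm_route` attribute itself is gone). An illustrative entry, with placeholder names and values:

```python
# Illustrative "departments" entry accepted by the loader above.
department_entry = {
    "research": {
        "title": "Research",
        "description": "Fundamental analysis desk",
        "weight": 1.5,
        "llm_route": "global",      # legacy; used only if it still resolves
        "llm": {                    # inline fallback when the route is absent
            "primary": {"provider": "deepseek", "model": "deepseek-chat"},
            "strategy": "single",
            "majority_threshold": 3,
        },
    }
}
```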

@@ -516,28 +436,21 @@ def save_config(cfg: AppConfig | None = None) -> None:
"tushare_token": cfg.tushare_token,
|
"tushare_token": cfg.tushare_token,
|
||||||
"force_refresh": cfg.force_refresh,
|
"force_refresh": cfg.force_refresh,
|
||||||
"decision_method": cfg.decision_method,
|
"decision_method": cfg.decision_method,
|
||||||
"llm_route": cfg.llm_route,
|
|
||||||
"llm": {
|
"llm": {
|
||||||
"route": cfg.llm_route,
|
|
||||||
"strategy": cfg.llm.strategy if cfg.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
|
"strategy": cfg.llm.strategy if cfg.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
|
||||||
"majority_threshold": cfg.llm.majority_threshold,
|
"majority_threshold": cfg.llm.majority_threshold,
|
||||||
"primary": _endpoint_to_dict(cfg.llm.primary),
|
"primary": _endpoint_to_dict(cfg.llm.primary),
|
||||||
"ensemble": [_endpoint_to_dict(ep) for ep in cfg.llm.ensemble],
|
"ensemble": [_endpoint_to_dict(ep) for ep in cfg.llm.ensemble],
|
||||||
},
|
},
|
||||||
"llm_profiles": {
|
"llm_providers": {
|
||||||
key: profile.to_dict()
|
key: provider.to_dict()
|
||||||
for key, profile in cfg.llm_profiles.items()
|
for key, provider in cfg.llm_providers.items()
|
||||||
},
|
|
||||||
"llm_routes": {
|
|
||||||
name: route.to_dict()
|
|
||||||
for name, route in cfg.llm_routes.items()
|
|
||||||
},
|
},
|
||||||
"departments": {
|
"departments": {
|
||||||
code: {
|
code: {
|
||||||
"title": dept.title,
|
"title": dept.title,
|
||||||
"description": dept.description,
|
"description": dept.description,
|
||||||
"weight": dept.weight,
|
"weight": dept.weight,
|
||||||
"llm_route": dept.llm_route,
|
|
||||||
"llm": {
|
"llm": {
|
||||||
"strategy": dept.llm.strategy if dept.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
|
"strategy": dept.llm.strategy if dept.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
|
||||||
"majority_threshold": dept.llm.majority_threshold,
|
"majority_threshold": dept.llm.majority_threshold,
|
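After this change `save_config` no longer writes `llm_profiles` or `llm_routes`. A sketch of the persisted file's top-level shape, with nested values deliberately elided:

```python
# Sketch of config.json as written after this commit; nested values elided.
persisted_shape = {
    "tushare_token": "<token>",
    "force_refresh": False,
    "decision_method": "<method>",
    "llm": {
        "strategy": "single",
        "majority_threshold": 3,
        "primary": {},              # _endpoint_to_dict(cfg.llm.primary)
        "ensemble": [],             # one dict per ensemble endpoint
    },
    "llm_providers": {},            # provider.to_dict() per provider key
    "departments": {},              # title/description/weight/llm per code
}
```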

@@ -567,11 +480,9 @@ def _load_env_defaults(cfg: AppConfig) -> None:
     if api_key:
         sanitized = api_key.strip()
         cfg.llm.primary.api_key = sanitized
-        route = cfg.llm_routes.get(cfg.llm_route)
-        if route:
-            profile = cfg.llm_profiles.get(route.primary)
-            if profile:
-                profile.api_key = sanitized
+        provider_cfg = cfg.llm_providers.get(cfg.llm.primary.provider)
+        if provider_cfg:
+            provider_cfg.api_key = sanitized

     cfg.sync_runtime_llm()
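A dict-based sketch of the new env-injection behaviour, using plain dicts in place of the real `AppConfig` objects: the key from `LLM_API_KEY` now lands both on the primary endpoint and on the provider record it points at, so provider-level defaults stay consistent with the runtime config.

```python
import os

# Stand-in structures for cfg.llm.primary and cfg.llm_providers.
llm_primary = {"provider": "deepseek", "api_key": None}
llm_providers = {"deepseek": {"api_key": None}}

api_key = os.environ.get("LLM_API_KEY")
if api_key:
    sanitized = api_key.strip()
    llm_primary["api_key"] = sanitized
    # Patch the matching provider record too, mirroring _load_env_defaults.
    provider_cfg = llm_providers.get(llm_primary["provider"])
    if provider_cfg:
        provider_cfg["api_key"] = sanitized
```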
|
|||||||
Loading…
Reference in New Issue
Block a user