commit ab6180646a (parent 6aece20816): update

Changed files include README.md (13 additions, 13 deletions).
@@ -24,7 +24,7 @@
 - **Unified logging and persistence**: SQLite stores market data, backtests, and logs in one place, and `DatabaseLogHandler` emits structured run traces throughout the UI and data-ingestion flows for fast tracing and review.
 - **Cross-market data coverage**: `app/ingest/tushare.py` adds incremental fetchers for indices, ETFs/mutual funds, futures, FX, Hong Kong equities, and US equities, so the price data needed by multi-asset factors and macro proxies is complete.
 - **Department-level multi-model collaboration**: `app/agents/departments.py` wraps department-level LLM orchestration, `app/llm/client.py` supports the single/majority/leader strategies, department conclusions join the six base agents in the game played in `app/agents/game.py`, and results are persisted to `agent_utils` for display in the UI.
-- **LLM Profile/Route management**: `app/utils/config.py` introduces reusable Profiles (endpoint definitions) and Routes (inference-strategy bundles); the Streamlit UI supports visual maintenance, and both global and department scopes can reuse named routes for configuration consistency.
+- **LLM Provider management**: `app/utils/config.py` centrally maintains each vendor's URL, API key, available models, and default parameters; the Streamlit UI provides visual configuration, and global and department scopes set model, temperature, and prompt directly on top of a Provider.

 ## LLM + multi-agent best practices

@@ -60,11 +60,10 @@ export TUSHARE_TOKEN="<your-token>"

 ### LLM configuration and testing

-- Adds the two-tier Profile/Route configuration: a Profile defines a single endpoint (provider/model/host/API key), while a Route combines Profiles and selects an inference strategy (single/majority/leader). The global route can be switched in one click; departments can reuse named routes or keep custom settings.
-- The Streamlit "Data & Settings" page manages Profiles, Routes, and the global route through forms; saving writes to `app/data/config.json`, and the Route preview shows the live configuration as sanitized by `llm_config_snapshot()`.
-- Supports local Ollama and several OpenAI-compatible vendors (DeepSeek, ERNIE Bot, OpenAI, etc.); each Profile can carry its own default model, temperature, timeout, and enabled flag.
-- The UI keeps TuShare token maintenance as well as adding, deleting, and disabling Routes/Profiles; all updates take effect immediately and are logged.
-- When injecting secrets through environment variables, set `TUSHARE_TOKEN` and `LLM_API_KEY`; after loading they are synced to the primary Profile of the current route.
+- Providers maintain vendor connection parameters (base URL, API key, model list, default temperature/timeout/prompt template) and can be extended at any time with local Ollama or cloud services (DeepSeek, ERNIE Bot, OpenAI, etc.).
+- Global and department configurations pick a Provider directly and override model, temperature, prompt template, and voting strategy as needed; saving writes to `app/data/config.json`, which is loaded automatically on the next start.
+- The Streamlit "Data & Settings" page offers a three-pane editor for Provider/global/department settings; changes take effect on save, and `llm_config_snapshot()` emits sanitized output for inspection.
+- Secrets can be injected through environment variables: `TUSHARE_TOKEN` and `LLM_API_KEY`.

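A minimal sketch of the environment-variable path (assuming `get_config()` runs `_load_env_defaults` on first load, which the config module changed in this commit suggests):

```python
import os

# Inject secrets before the config singleton is first created.
os.environ["TUSHARE_TOKEN"] = "<your-token>"
os.environ["LLM_API_KEY"] = "<your-llm-key>"

from app.utils.config import get_config

cfg = get_config()
# The key is synced onto the primary endpoint and its provider entry.
provider_cfg = cfg.llm_providers.get(cfg.llm.primary.provider)
print(cfg.llm.primary.provider, provider_cfg is not None and provider_cfg.api_key is not None)
```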
 ## Quick start

@@ -104,7 +103,7 @@ The Streamlit `Self-test` tab provides:

 ## Implementation steps

 1. **Configuration extensions** (`app/utils/config.py` + `config.json`) ✅
-   - Introduce `llm_profiles`/`llm_routes` to manage endpoints and strategies in one place; departments can reuse routes or use custom settings; Streamlit provides a visual maintenance form.
+   - Introduce `llm_providers` to centralize vendor parameters; global and department scopes bind a Provider directly and customize model/temperature/prompt; Streamlit provides a visual maintenance form.

 2. **Department controller** ✅
    - `app/agents/departments.py` provides `DepartmentAgent`/`DepartmentManager`, wrapping prompt construction, multi-model negotiation, and failure fallback.

@@ -131,8 +131,6 @@ class DepartmentManager:
         return results

     def _resolve_llm(self, settings: DepartmentSettings) -> LLMConfig:
-        if settings.llm_route and settings.llm_route in self.config.llm_routes:
-            return self.config.llm_routes[settings.llm_route].resolve(self.config.llm_profiles)
         return settings.llm

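With the route indirection removed, `_resolve_llm` is a plain pass-through to the department's own `LLMConfig`. A toy illustration of the new behavior (the constructor call and department key are hypothetical; only `_resolve_llm` itself is taken from the diff above):

```python
# Hypothetical wiring; DepartmentManager's real constructor may differ.
manager = DepartmentManager(config=get_config())
dept = manager.config.departments["macro"]  # "macro" is one of the default presets

# After this commit the method simply returns dept.llm: provider defaults are
# applied later, inside the LLM client (_call_endpoint), not here.
assert manager._resolve_llm(dept) is dept.llm
```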
@@ -102,24 +102,58 @@ def _request_openai(


 def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> str:
-    provider = (endpoint.provider or "ollama").lower()
-    base_url = endpoint.base_url or _default_base_url(provider)
-    model = endpoint.model or _default_model(provider)
-    temperature = max(0.0, min(endpoint.temperature, 2.0))
-    timeout = max(5.0, endpoint.timeout or 30.0)
+    cfg = get_config()
+    provider_key = (endpoint.provider or "ollama").lower()
+    provider_cfg = cfg.llm_providers.get(provider_key)
+
+    base_url = endpoint.base_url
+    api_key = endpoint.api_key
+    model = endpoint.model
+    temperature = endpoint.temperature
+    timeout = endpoint.timeout
+    prompt_template = endpoint.prompt_template
+
+    if provider_cfg:
+        if not provider_cfg.enabled:
+            raise LLMError(f"Provider {provider_key} is disabled")
+        base_url = base_url or provider_cfg.base_url or _default_base_url(provider_key)
+        api_key = api_key or provider_cfg.api_key
+        model = model or provider_cfg.default_model or (provider_cfg.models[0] if provider_cfg.models else _default_model(provider_key))
+        if temperature is None:
+            temperature = provider_cfg.default_temperature
+        if timeout is None:
+            timeout = provider_cfg.default_timeout
+        prompt_template = prompt_template or (provider_cfg.prompt_template or None)
+        mode = provider_cfg.mode or ("ollama" if provider_key == "ollama" else "openai")
+    else:
+        base_url = base_url or _default_base_url(provider_key)
+        model = model or _default_model(provider_key)
+        if temperature is None:
+            temperature = DEFAULT_LLM_TEMPERATURES.get(provider_key, 0.2)
+        if timeout is None:
+            timeout = DEFAULT_LLM_TIMEOUTS.get(provider_key, 30.0)
+        mode = "ollama" if provider_key == "ollama" else "openai"
+
+    temperature = max(0.0, min(float(temperature), 2.0))
+    timeout = max(5.0, float(timeout))
+
+    if prompt_template:
+        try:
+            prompt = prompt_template.format(prompt=prompt)
+        except Exception:  # noqa: BLE001
+            LOGGER.warning("Prompt template formatting failed; using the raw prompt", extra=LOG_EXTRA)

     LOGGER.info(
         "LLM request triggered: provider=%s model=%s base=%s",
-        provider,
+        provider_key,
         model,
         base_url,
         extra=LOG_EXTRA,
     )

-    if provider in {"openai", "deepseek", "wenxin"}:
-        api_key = endpoint.api_key
+    if mode != "ollama":
         if not api_key:
-            raise LLMError(f"Missing {provider} API key (model={model})")
+            raise LLMError(f"Missing {provider_key} API key (model={model})")
         return _request_openai(
             model,
             prompt,
@@ -129,7 +163,7 @@ def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> str:
             timeout=timeout,
             system=system,
         )
-    if provider == "ollama":
-        if base_url:
+    if base_url:
         return _request_ollama(
             model,
             prompt,
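The resolution order above is: explicit endpoint field, then provider-level default, then built-in fallback. A compact sketch of what that buys (endpoint values hypothetical; note `_call_endpoint` performs a real HTTP request when invoked):

```python
# Only the provider is pinned; model/temperature/timeout stay None and are
# filled in from cfg.llm_providers["deepseek"] at call time.
endpoint = LLMEndpoint(provider="deepseek")
reply = _call_endpoint(endpoint, "Summarize today's market.", None)

# Explicit endpoint fields still win over provider defaults:
pinned = LLMEndpoint(provider="deepseek", model="deepseek-chat", temperature=0.0)
```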
@@ -283,13 +317,17 @@ def llm_config_snapshot() -> Dict[str, object]:
         if record.get("api_key"):
             record["api_key"] = "***"
         ensemble.append(record)
-    route_name = cfg.llm_route
-    route_obj = cfg.llm_routes.get(route_name)
     return {
-        "route": route_name,
-        "route_detail": route_obj.to_dict() if route_obj else None,
         "strategy": settings.strategy,
         "majority_threshold": settings.majority_threshold,
         "primary": primary,
         "ensemble": ensemble,
+        "providers": {
+            key: {
+                "base_url": provider.base_url,
+                "default_model": provider.default_model,
+                "enabled": provider.enabled,
+            }
+            for key, provider in cfg.llm_providers.items()
+        },
     }

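For reference, the sanitized snapshot now carries a `providers` section; its shape is roughly (keys from the code above, values illustrative):

```python
snapshot = llm_config_snapshot()
# Illustrative result:
# {
#     "strategy": "single",
#     "majority_threshold": 3,
#     "primary": {"provider": "deepseek", "model": "deepseek-chat", "api_key": "***"},
#     "ensemble": [],
#     "providers": {
#         "deepseek": {"base_url": "https://api.deepseek.com", "default_model": "deepseek-chat", "enabled": True},
#     },
# }
```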
@@ -5,7 +5,7 @@ import sys
 from dataclasses import asdict
 from datetime import date, timedelta
 from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Optional

 ROOT = Path(__file__).resolve().parents[2]
 if str(ROOT) not in sys.path:
@@ -16,6 +16,8 @@ import json
 import pandas as pd
 import plotly.express as px
 import plotly.graph_objects as go
+import requests
+from requests.exceptions import RequestException
 import streamlit as st

 from app.backtest.engine import BtConfig, run_backtest
@@ -29,8 +31,8 @@ from app.utils.config import (
     DEFAULT_LLM_MODEL_OPTIONS,
     DEFAULT_LLM_MODELS,
     DepartmentSettings,
-    LLMProfile,
-    LLMRoute,
+    LLMEndpoint,
+    LLMProvider,
     get_config,
     save_config,
 )
@@ -42,6 +44,50 @@ LOGGER = get_logger(__name__)
 LOG_EXTRA = {"stage": "ui"}


+def _discover_provider_models(provider: LLMProvider, base_override: str = "", api_override: Optional[str] = None) -> tuple[list[str], Optional[str]]:
+    """Attempt to query provider API and return available model ids."""
+
+    base_url = (base_override or provider.base_url or DEFAULT_LLM_BASE_URLS.get(provider.key, "")).strip()
+    if not base_url:
+        return [], "Please fill in the Base URL first"
+    timeout = float(provider.default_timeout or 30.0)
+    mode = provider.mode or ("ollama" if provider.key == "ollama" else "openai")
+
+    try:
+        if mode == "ollama":
+            url = base_url.rstrip('/') + "/api/tags"
+            response = requests.get(url, timeout=timeout)
+            response.raise_for_status()
+            data = response.json()
+            models = []
+            for item in data.get("models", []) or data.get("data", []):
+                name = item.get("name") or item.get("model") or item.get("tag")
+                if name:
+                    models.append(str(name).strip())
+            return sorted(set(models)), None
+
+        api_key = (api_override or provider.api_key or "").strip()
+        if not api_key:
+            return [], "Missing API key"
+        url = base_url.rstrip('/') + "/v1/models"
+        headers = {
+            "Authorization": f"Bearer {api_key}",
+            "Content-Type": "application/json",
+        }
+        response = requests.get(url, headers=headers, timeout=timeout)
+        response.raise_for_status()
+        payload = response.json()
+        models = [
+            str(item.get("id")).strip()
+            for item in payload.get("data", [])
+            if item.get("id")
+        ]
+        return sorted(set(models)), None
+    except RequestException as exc:  # noqa: BLE001
+        return [], f"HTTP error: {exc}"
+    except Exception as exc:  # noqa: BLE001
+        return [], f"Parse failure: {exc}"
+
+
 def _load_stock_options(limit: int = 500) -> list[str]:
     try:
         with db_session(read_only=True) as conn:
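A quick way to exercise the discovery helper outside Streamlit (a local Ollama daemon is assumed on the default port):

```python
# Probe a local Ollama daemon; mode="ollama" routes the call to /api/tags.
probe = LLMProvider(key="ollama", base_url="http://localhost:11434", mode="ollama")
models, error = _discover_provider_models(probe)
if error:
    print("discovery failed:", error)
else:
    print(f"{len(models)} models found:", models[:5])
```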
@@ -350,255 +396,267 @@ def render_settings() -> None:

     st.divider()
     st.subheader("LLM Settings")
-    profiles = cfg.llm_profiles or {}
-    routes = cfg.llm_routes or {}
-    profile_keys = sorted(profiles.keys())
-    route_keys = sorted(routes.keys())
-    used_routes = {
-        dept.llm_route for dept in cfg.departments.values() if dept.llm_route
-    }
-    st.caption("A Profile defines a single model endpoint; a Route combines Profiles with an inference strategy.")
+    providers = cfg.llm_providers
+    provider_keys = sorted(providers.keys())
+    st.caption("Maintain the base connection (URL, key, models) on the Provider first, then set per-scope parameters globally and per department.")

-    route_select_col, route_manage_col = st.columns([3, 1])
-    if route_keys:
+    # Provider management -------------------------------------------------
+    provider_select_col, provider_manage_col = st.columns([3, 1])
+    if provider_keys:
         try:
-            active_index = route_keys.index(cfg.llm_route)
+            default_provider = cfg.llm.primary.provider or provider_keys[0]
+            provider_index = provider_keys.index(default_provider)
         except ValueError:
-            active_index = 0
-        selected_route = route_select_col.selectbox(
-            "Global route",
-            route_keys,
-            index=active_index,
-            key="llm_route_select",
+            provider_index = 0
+        selected_provider = provider_select_col.selectbox(
+            "Select Provider",
+            provider_keys,
+            index=provider_index,
+            key="llm_provider_select",
         )
     else:
-        selected_route = None
-        route_select_col.info("No routes configured yet; create one first.")
+        selected_provider = None
+        provider_select_col.info("No Provider configured yet; create one first.")

-    new_route_name = route_manage_col.text_input("New route", key="new_route_name")
-    if route_manage_col.button("Add route"):
-        key = (new_route_name or "").strip()
+    new_provider_name = provider_manage_col.text_input("New Provider", key="new_provider_name")
+    if provider_manage_col.button("Create Provider", key="create_provider_btn"):
+        key = (new_provider_name or "").strip().lower()
         if not key:
-            st.warning("Please enter a valid route name.")
-        elif key in routes:
-            st.warning(f"Route {key} already exists.")
+            st.warning("Please enter a valid Provider name.")
+        elif key in providers:
+            st.warning(f"Provider {key} already exists.")
         else:
-            routes[key] = LLMRoute(name=key)
-            if not selected_route:
-                selected_route = key
-                cfg.llm_route = key
+            providers[key] = LLMProvider(key=key)
+            cfg.llm_providers = providers
             save_config()
-            st.success(f"Route {key} added; continue configuring it.")
+            st.success(f"Provider {key} created.")
             st.experimental_rerun()

-    if selected_route:
-        route_obj = routes.get(selected_route)
-        if route_obj is None:
-            route_obj = LLMRoute(name=selected_route)
-            routes[selected_route] = route_obj
-        strategy_choices = sorted(ALLOWED_LLM_STRATEGIES)
+    if selected_provider:
+        provider_cfg = providers.get(selected_provider, LLMProvider(key=selected_provider))
+        title_key = f"provider_title_{selected_provider}"
+        base_key = f"provider_base_{selected_provider}"
+        api_key_key = f"provider_api_{selected_provider}"
+        models_key = f"provider_models_{selected_provider}"
+        default_model_key = f"provider_default_model_{selected_provider}"
+        mode_key = f"provider_mode_{selected_provider}"
+        temp_key = f"provider_temp_{selected_provider}"
+        timeout_key = f"provider_timeout_{selected_provider}"
+        prompt_key = f"provider_prompt_{selected_provider}"
+        enabled_key = f"provider_enabled_{selected_provider}"
+
+        title_val = st.text_input("Display name", value=provider_cfg.title or "", key=title_key)
+        base_val = st.text_input("Base URL", value=provider_cfg.base_url or "", key=base_key, help="Endpoint address, e.g. https://api.openai.com")
+        api_val = st.text_input("API Key", value=provider_cfg.api_key or "", key=api_key_key, type="password")
+        models_val = st.text_area("Available models (one per line)", value="\n".join(provider_cfg.models), key=models_key, height=100)
+        default_model_val = st.text_input("Default model", value=provider_cfg.default_model or "", key=default_model_key)
+        mode_val = st.selectbox("Call mode", ["openai", "ollama"], index=0 if provider_cfg.mode == "openai" else 1, key=mode_key)
+        temp_val = st.slider("Default temperature", min_value=0.0, max_value=2.0, value=float(provider_cfg.default_temperature), step=0.05, key=temp_key)
+        timeout_val = st.number_input("Default timeout (seconds)", min_value=5, max_value=300, value=int(provider_cfg.default_timeout or 30), step=5, key=timeout_key)
+        prompt_template_val = st.text_area("Default prompt template (optional, with a {prompt} placeholder)", value=provider_cfg.prompt_template or "", key=prompt_key, height=120)
+        enabled_val = st.checkbox("Enabled", value=provider_cfg.enabled, key=enabled_key)
+
+        fetch_key = f"fetch_models_{selected_provider}"
+        if st.button("Fetch model list", key=fetch_key):
+            with st.spinner("Fetching model list..."):
+                models, error = _discover_provider_models(provider_cfg, base_val, api_val)
+            if error:
+                st.error(error)
+            else:
+                provider_cfg.models = models
+                if models and (not provider_cfg.default_model or provider_cfg.default_model not in models):
+                    provider_cfg.default_model = models[0]
+                providers[selected_provider] = provider_cfg
+                cfg.llm_providers = providers
+                cfg.sync_runtime_llm()
+                save_config()
+                st.success(f"Fetched {len(models)} models.")
+                st.session_state[models_key] = "\n".join(models)
+                st.session_state[default_model_key] = provider_cfg.default_model or ""
+
+        if st.button("Save Provider", key=f"save_provider_{selected_provider}"):
+            provider_cfg.title = title_val.strip()
+            provider_cfg.base_url = base_val.strip()
+            provider_cfg.api_key = api_val.strip() or None
+            provider_cfg.models = [line.strip() for line in models_val.splitlines() if line.strip()]
+            provider_cfg.default_model = default_model_val.strip() or (provider_cfg.models[0] if provider_cfg.models else provider_cfg.default_model)
+            provider_cfg.default_temperature = float(temp_val)
+            provider_cfg.default_timeout = float(timeout_val)
+            provider_cfg.prompt_template = prompt_template_val.strip()
+            provider_cfg.enabled = enabled_val
+            provider_cfg.mode = mode_val
+            providers[selected_provider] = provider_cfg
+            cfg.llm_providers = providers
+            cfg.sync_runtime_llm()
+            save_config()
+            st.success("Provider saved.")
+            st.session_state[title_key] = provider_cfg.title or ""
+            st.session_state[default_model_key] = provider_cfg.default_model or ""
+
+        provider_in_use = (cfg.llm.primary.provider == selected_provider) or any(
+            ep.provider == selected_provider for ep in cfg.llm.ensemble
+        )
+        if not provider_in_use:
+            for dept in cfg.departments.values():
+                if dept.llm.primary.provider == selected_provider or any(ep.provider == selected_provider for ep in dept.llm.ensemble):
+                    provider_in_use = True
+                    break
+        if st.button(
+            "Delete Provider",
+            key=f"delete_provider_{selected_provider}",
+            disabled=provider_in_use or len(providers) <= 1,
+        ):
+            providers.pop(selected_provider, None)
+            cfg.llm_providers = providers
+            cfg.sync_runtime_llm()
+            save_config()
+            st.success("Provider deleted.")
+            st.experimental_rerun()
+
+    st.markdown("##### Global inference configuration")
+    if not provider_keys:
+        st.warning("Please configure at least one Provider first.")
+    else:
+        global_cfg = cfg.llm
+        primary = global_cfg.primary
         try:
-            strategy_index = strategy_choices.index(route_obj.strategy)
+            provider_index = provider_keys.index(primary.provider or provider_keys[0])
         except ValueError:
-            strategy_index = 0
-        route_title = st.text_input(
-            "Route description",
-            value=route_obj.title or "",
-            key=f"route_title_{selected_route}",
+            provider_index = 0
+        selected_global_provider = st.selectbox(
+            "Primary model Provider",
+            provider_keys,
+            index=provider_index,
+            key="global_provider_select",
         )
-        route_strategy = st.selectbox(
-            "Inference strategy",
-            strategy_choices,
-            index=strategy_index,
-            key=f"route_strategy_{selected_route}",
+        provider_cfg = providers.get(selected_global_provider)
+        available_models = provider_cfg.models if provider_cfg else []
+        default_model = primary.model or (provider_cfg.default_model if provider_cfg else None)
+        if available_models:
+            options = available_models + ["Custom"]
+            try:
+                model_index = available_models.index(default_model)
+                model_choice = st.selectbox("Primary model", options, index=model_index, key="global_model_choice")
+            except ValueError:
+                model_choice = st.selectbox("Primary model", options, index=len(options) - 1, key="global_model_choice")
+            if model_choice == "Custom":
+                model_val = st.text_input("Custom model", value=default_model or "", key="global_model_custom").strip()
+            else:
+                model_val = model_choice
+        else:
+            model_val = st.text_input("Primary model", value=default_model or "", key="global_model_custom").strip()
+
+        temp_default = primary.temperature if primary.temperature is not None else (provider_cfg.default_temperature if provider_cfg else 0.2)
+        temp_val = st.slider("Primary model temperature", min_value=0.0, max_value=2.0, value=float(temp_default), step=0.05, key="global_temp")
+        timeout_default = primary.timeout if primary.timeout is not None else (provider_cfg.default_timeout if provider_cfg else 30.0)
+        timeout_val = st.number_input("Primary model timeout (seconds)", min_value=5, max_value=300, value=int(timeout_default), step=5, key="global_timeout")
+        prompt_template_val = st.text_area(
+            "Primary model prompt template (optional)",
+            value=primary.prompt_template or provider_cfg.prompt_template if provider_cfg else "",
+            height=120,
+            key="global_prompt_template",
         )
-        route_majority = st.number_input(
+
+        strategy_val = st.selectbox("Inference strategy", sorted(ALLOWED_LLM_STRATEGIES), index=sorted(ALLOWED_LLM_STRATEGIES).index(global_cfg.strategy) if global_cfg.strategy in ALLOWED_LLM_STRATEGIES else 0, key="global_strategy")
+        show_ensemble = strategy_val != "single"
+        majority_threshold_val = st.number_input(
             "Majority vote threshold",
             min_value=1,
             max_value=10,
-            value=int(route_obj.majority_threshold or 1),
+            value=int(global_cfg.majority_threshold),
             step=1,
-            key=f"route_majority_{selected_route}",
+            key="global_majority",
+            disabled=not show_ensemble,
         )
-        if not profile_keys:
-            st.warning("No Profiles available yet; create one below first.")
+        if not show_ensemble:
+            majority_threshold_val = 1
+
+        ensemble_rows: List[Dict[str, str]] = []
+        if show_ensemble:
+            ensemble_rows = [
+                {
+                    "provider": ep.provider,
+                    "model": ep.model or "",
+                    "temperature": "" if ep.temperature is None else f"{ep.temperature:.3f}",
+                    "timeout": "" if ep.timeout is None else str(int(ep.timeout)),
+                    "prompt_template": ep.prompt_template or "",
+                }
+                for ep in global_cfg.ensemble
+            ] or [{"provider": primary.provider or selected_global_provider, "model": "", "temperature": "", "timeout": "", "prompt_template": ""}]
+
+            ensemble_editor = st.data_editor(
+                ensemble_rows,
+                num_rows="dynamic",
+                key="global_ensemble_editor",
+                use_container_width=True,
+                hide_index=True,
+                column_config={
+                    "provider": st.column_config.SelectboxColumn("Provider", options=provider_keys),
+                    "model": st.column_config.TextColumn("Model"),
+                    "temperature": st.column_config.TextColumn("Temperature"),
+                    "timeout": st.column_config.TextColumn("Timeout (s)"),
+                    "prompt_template": st.column_config.TextColumn("Prompt template"),
+                },
+            )
+            if hasattr(ensemble_editor, "to_dict"):
+                ensemble_rows = ensemble_editor.to_dict("records")
+            else:
+                ensemble_rows = ensemble_editor
+        else:
+            st.info("The current strategy is single-model, so no collaborating models are enabled.")

-        try:
-            primary_index = profile_keys.index(route_obj.primary)
-        except ValueError:
-            primary_index = 0
-        primary_key = st.selectbox(
-            "Primary Profile",
-            profile_keys,
-            index=primary_index,
-            key=f"route_primary_{selected_route}",
-        )
-        default_ensemble = [
-            key for key in route_obj.ensemble if key in profile_keys and key != primary_key
-        ]
-        ensemble_keys = st.multiselect(
-            "Collaborating Profiles (multi-select)",
-            profile_keys,
-            default=default_ensemble,
-            key=f"route_ensemble_{selected_route}",
-        )
-        if st.button("Save route settings", key=f"save_route_{selected_route}"):
-            route_obj.title = route_title.strip()
-            route_obj.strategy = route_strategy
-            route_obj.majority_threshold = int(route_majority)
-            route_obj.primary = primary_key
-            route_obj.ensemble = [key for key in ensemble_keys if key != primary_key]
-            cfg.llm_route = selected_route
-            cfg.sync_runtime_llm()
-            save_config()
-            LOGGER.info(
-                "Route %s configuration updated: %s",
-                selected_route,
-                route_obj.to_dict(),
-                extra=LOG_EXTRA,
-            )
-            st.success("Route configuration saved.")
-            st.json({
-                "route": selected_route,
-                "route_detail": route_obj.to_dict(),
-                "resolved": llm_config_snapshot(),
-            })
-        route_in_use = selected_route in used_routes or selected_route == cfg.llm_route
-        if st.button(
-            "Delete current route",
-            key=f"delete_route_{selected_route}",
-            disabled=route_in_use or len(routes) <= 1,
-        ):
-            routes.pop(selected_route, None)
-            if cfg.llm_route == selected_route:
-                cfg.llm_route = next((key for key in routes.keys()), "global")
-            cfg.sync_runtime_llm()
-            save_config()
-            st.success("Route deleted.")
-            st.experimental_rerun()
-
-    st.divider()
-    st.subheader("LLM Profile management")
-    profile_select_col, profile_manage_col = st.columns([3, 1])
-    if profile_keys:
-        selected_profile = profile_select_col.selectbox(
-            "Select Profile",
-            profile_keys,
-            index=0,
-            key="profile_select",
-        )
-    else:
-        selected_profile = None
-        profile_select_col.info("No Profiles configured yet; create one first.")
-
-    new_profile_name = profile_manage_col.text_input("New Profile", key="new_profile_name")
-    if profile_manage_col.button("Create Profile"):
-        key = (new_profile_name or "").strip()
-        if not key:
-            st.warning("Please enter a valid Profile name.")
-        elif key in profiles:
-            st.warning(f"Profile {key} already exists.")
-        else:
-            profiles[key] = LLMProfile(key=key)
-            save_config()
-            st.success(f"Profile {key} created.")
-            st.experimental_rerun()
-
-    if selected_profile:
-        profile = profiles[selected_profile]
-        provider_choices = sorted(DEFAULT_LLM_MODEL_OPTIONS.keys())
-        try:
-            provider_index = provider_choices.index(profile.provider)
-        except ValueError:
-            provider_index = 0
-        with st.form(f"profile_form_{selected_profile}"):
-            provider_val = st.selectbox(
-                "Provider",
-                provider_choices,
-                index=provider_index,
-            )
-            model_default = DEFAULT_LLM_MODELS.get(provider_val, profile.model or "")
-            model_val = st.text_input(
-                "Model",
-                value=profile.model or model_default,
-            )
-            base_default = DEFAULT_LLM_BASE_URLS.get(provider_val, profile.base_url or "")
-            base_val = st.text_input(
-                "Base URL",
-                value=profile.base_url or base_default,
-            )
-            api_val = st.text_input(
-                "API Key",
-                value=profile.api_key or "",
-                type="password",
-            )
-            temp_val = st.slider(
-                "Temperature",
-                min_value=0.0,
-                max_value=2.0,
-                value=float(profile.temperature),
-                step=0.05,
-            )
-            timeout_val = st.number_input(
-                "Timeout (seconds)",
-                min_value=5,
-                max_value=180,
-                value=int(profile.timeout or 30),
-                step=5,
-            )
-            title_val = st.text_input("Note", value=profile.title or "")
-            enabled_val = st.checkbox("Enabled", value=profile.enabled)
-            submitted = st.form_submit_button("Save Profile")
-        if submitted:
-            profile.provider = provider_val
-            profile.model = model_val.strip() or DEFAULT_LLM_MODELS.get(provider_val)
-            profile.base_url = base_val.strip() or DEFAULT_LLM_BASE_URLS.get(provider_val)
-            profile.api_key = api_val.strip() or None
-            profile.temperature = temp_val
-            profile.timeout = timeout_val
-            profile.title = title_val.strip()
-            profile.enabled = enabled_val
-            profiles[selected_profile] = profile
-            cfg.sync_runtime_llm()
-            save_config()
-            st.success("Profile saved.")
-
-        profile_in_use = any(
-            selected_profile == route.primary or selected_profile in route.ensemble
-            for route in routes.values()
-        )
-        if st.button(
-            "Delete this Profile",
-            key=f"delete_profile_{selected_profile}",
-            disabled=profile_in_use or len(profiles) <= 1,
-        ):
-            profiles.pop(selected_profile, None)
-            fallback_key = next((key for key in profiles.keys()), None)
-            for route in routes.values():
-                if route.primary == selected_profile:
-                    route.primary = fallback_key or route.primary
-                route.ensemble = [key for key in route.ensemble if key != selected_profile]
-            cfg.sync_runtime_llm()
-            save_config()
-            st.success("Profile deleted.")
-            st.experimental_rerun()
-
-    st.divider()
-    st.subheader("Department configuration")
+        if st.button("Save global configuration", key="save_global_llm"):
+            primary.provider = selected_global_provider
+            primary.model = model_val or None
+            primary.temperature = float(temp_val)
+            primary.timeout = float(timeout_val)
+            primary.prompt_template = prompt_template_val.strip() or None
+            primary.base_url = None
+            primary.api_key = None
+
+            new_ensemble: List[LLMEndpoint] = []
+            if show_ensemble:
+                for row in ensemble_rows:
+                    provider_val = (row.get("provider") or "").strip().lower()
+                    if not provider_val:
+                        continue
+                    model_raw = (row.get("model") or "").strip() or None
+                    temp_raw = (row.get("temperature") or "").strip()
+                    timeout_raw = (row.get("timeout") or "").strip()
+                    prompt_raw = (row.get("prompt_template") or "").strip()
+                    new_ensemble.append(
+                        LLMEndpoint(
+                            provider=provider_val,
+                            model=model_raw,
+                            temperature=float(temp_raw) if temp_raw else None,
+                            timeout=float(timeout_raw) if timeout_raw else None,
+                            prompt_template=prompt_raw or None,
+                        )
+                    )
+            cfg.llm.ensemble = new_ensemble
+            cfg.llm.strategy = strategy_val
+            cfg.llm.majority_threshold = int(majority_threshold_val)
+            cfg.sync_runtime_llm()
+            save_config()
+            st.success("Global LLM configuration saved.")
+            st.json(llm_config_snapshot())
+
+    # Department configuration -------------------------------------------
+    st.markdown("##### Department configuration")
     dept_settings = cfg.departments or {}
-    route_options_display = [""] + route_keys
     dept_rows = [
         {
             "code": code,
             "title": dept.title,
             "description": dept.description,
             "weight": float(dept.weight),
-            "llm_route": dept.llm_route or "",
             "strategy": dept.llm.strategy,
-            "primary_provider": (dept.llm.primary.provider or ""),
-            "primary_model": dept.llm.primary.model or "",
-            "ensemble_size": len(dept.llm.ensemble),
             "majority_threshold": dept.llm.majority_threshold,
+            "provider": dept.llm.primary.provider or (provider_keys[0] if provider_keys else ""),
+            "model": dept.llm.primary.model or "",
+            "temperature": "" if dept.llm.primary.temperature is None else f"{dept.llm.primary.temperature:.3f}",
+            "timeout": "" if dept.llm.primary.timeout is None else str(int(dept.llm.primary.timeout)),
+            "prompt_template": dept.llm.primary.prompt_template or "",
         }
         for code, dept in sorted(dept_settings.items())
     ]
@@ -618,26 +676,13 @@ def render_settings() -> None:
             "title": st.column_config.TextColumn("Name"),
             "description": st.column_config.TextColumn("Description"),
             "weight": st.column_config.NumberColumn("Weight", min_value=0.0, max_value=10.0, step=0.1),
-            "llm_route": st.column_config.SelectboxColumn(
-                "Route",
-                options=route_options_display,
-                help="Pick a predefined route; leave empty to use a custom configuration",
-            ),
-            "strategy": st.column_config.SelectboxColumn(
-                "Custom strategy",
-                options=sorted(ALLOWED_LLM_STRATEGIES),
-                help="Only takes effect when no route is selected",
-            ),
-            "primary_provider": st.column_config.SelectboxColumn(
-                "Custom Provider",
-                options=sorted(DEFAULT_LLM_MODEL_OPTIONS.keys()),
-            ),
-            "primary_model": st.column_config.TextColumn("Custom model"),
-            "ensemble_size": st.column_config.NumberColumn(
-                "Collaborating model count",
-                disabled=True,
-                help="Maintained automatically in route mode",
-            ),
+            "strategy": st.column_config.SelectboxColumn("Strategy", options=sorted(ALLOWED_LLM_STRATEGIES)),
             "majority_threshold": st.column_config.NumberColumn("Vote threshold", min_value=1, max_value=10, step=1),
+            "provider": st.column_config.SelectboxColumn("Provider", options=provider_keys or [""]),
+            "model": st.column_config.TextColumn("Model"),
+            "temperature": st.column_config.TextColumn("Temperature"),
+            "timeout": st.column_config.TextColumn("Timeout (s)"),
+            "prompt_template": st.column_config.TextColumn("Prompt template"),
         },
     )

@@ -662,25 +707,34 @@ def render_settings() -> None:
         except (TypeError, ValueError):
             pass

-        route_name = (row.get("llm_route") or "").strip() or None
-        existing.llm_route = route_name
-        if route_name and route_name in routes:
-            existing.llm = routes[route_name].resolve(profiles)
-        else:
-            strategy_val = (row.get("strategy") or existing.llm.strategy).lower()
-            if strategy_val in ALLOWED_LLM_STRATEGIES:
-                existing.llm.strategy = strategy_val
-            provider_before = existing.llm.primary.provider or ""
-            provider_val = (row.get("primary_provider") or provider_before or "ollama").lower()
-            existing.llm.primary.provider = provider_val
-            model_val = (row.get("primary_model") or "").strip()
-            existing.llm.primary.model = (
-                model_val or DEFAULT_LLM_MODELS.get(provider_val, existing.llm.primary.model)
-            )
-            if provider_before != provider_val:
-                default_base = DEFAULT_LLM_BASE_URLS.get(provider_val)
-                existing.llm.primary.base_url = default_base or existing.llm.primary.base_url
-                existing.llm.primary.__post_init__()
+        strategy_val = (row.get("strategy") or existing.llm.strategy).lower()
+        if strategy_val in ALLOWED_LLM_STRATEGIES:
+            existing.llm.strategy = strategy_val
         majority_raw = row.get("majority_threshold")
         try:
             majority_val = int(majority_raw)
             if majority_val > 0:
                 existing.llm.majority_threshold = majority_val
         except (TypeError, ValueError):
             pass

+        provider_val = (row.get("provider") or existing.llm.primary.provider or (provider_keys[0] if provider_keys else "ollama")).strip().lower()
+        model_val = (row.get("model") or "").strip() or None
+        temp_raw = (row.get("temperature") or "").strip()
+        timeout_raw = (row.get("timeout") or "").strip()
+        prompt_raw = (row.get("prompt_template") or "").strip()
+
+        endpoint = existing.llm.primary or LLMEndpoint()
+        endpoint.provider = provider_val
+        endpoint.model = model_val
+        endpoint.temperature = float(temp_raw) if temp_raw else None
+        endpoint.timeout = float(timeout_raw) if timeout_raw else None
+        endpoint.prompt_template = prompt_raw or None
+        endpoint.base_url = None
+        endpoint.api_key = None
+        existing.llm.primary = endpoint
+        existing.llm.ensemble = []
+
         updated_departments[code] = existing

     if updated_departments:
@@ -700,8 +754,7 @@ def render_settings() -> None:
         st.success("Default department configuration restored.")
         st.experimental_rerun()

-    st.caption("Selecting a route unifies department model calls; custom mode still supports per-item configuration.")
-    st.caption("Department ensemble models must be edited manually in config.json for now; UI support will follow in a later version.")
+    st.caption("Department configurations are stored as independent LLM parameters; the matching Provider's connection info is applied automatically at run time.")


 def render_tests() -> None:

@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
 import json
 import os
 from pathlib import Path
-from typing import Dict, List, Mapping, Optional
+from typing import Dict, List, Optional


 def _default_root() -> Path:
@@ -99,27 +99,55 @@ ALLOWED_LLM_STRATEGIES = {"single", "majority", "leader"}
 LLM_STRATEGY_ALIASES = {"leader-follower": "leader"}


+@dataclass
+class LLMProvider:
+    """Provider level configuration shared across profiles and routes."""
+
+    key: str
+    title: str = ""
+    base_url: str = ""
+    api_key: Optional[str] = None
+    models: List[str] = field(default_factory=list)
+    default_model: Optional[str] = None
+    default_temperature: float = 0.2
+    default_timeout: float = 30.0
+    prompt_template: str = ""
+    enabled: bool = True
+    mode: str = "openai"  # "openai" or "ollama"
+
+    def to_dict(self) -> Dict[str, object]:
+        return {
+            "title": self.title,
+            "base_url": self.base_url,
+            "api_key": self.api_key,
+            "models": list(self.models),
+            "default_model": self.default_model,
+            "default_temperature": self.default_temperature,
+            "default_timeout": self.default_timeout,
+            "prompt_template": self.prompt_template,
+            "enabled": self.enabled,
+            "mode": self.mode,
+        }
+
+
 @dataclass
 class LLMEndpoint:
-    """Single LLM endpoint configuration."""
+    """Resolved endpoint payload used for actual LLM calls."""

     provider: str = "ollama"
     model: Optional[str] = None
     base_url: Optional[str] = None
     api_key: Optional[str] = None
-    temperature: float = 0.2
-    timeout: float = 30.0
+    temperature: Optional[float] = None
+    timeout: Optional[float] = None
+    prompt_template: Optional[str] = None

     def __post_init__(self) -> None:
         self.provider = (self.provider or "ollama").lower()
-        if not self.model:
-            self.model = DEFAULT_LLM_MODELS.get(self.provider, DEFAULT_LLM_MODELS["ollama"])
-        if not self.base_url:
-            self.base_url = DEFAULT_LLM_BASE_URLS.get(self.provider)
-        if self.temperature == 0.2 or self.temperature is None:
-            self.temperature = DEFAULT_LLM_TEMPERATURES.get(self.provider, 0.2)
-        if self.timeout == 30.0 or self.timeout is None:
-            self.timeout = DEFAULT_LLM_TIMEOUTS.get(self.provider, 30.0)
+        if self.temperature is not None:
+            self.temperature = float(self.temperature)
+        if self.timeout is not None:
+            self.timeout = float(self.timeout)


 @dataclass
@@ -132,133 +160,22 @@ class LLMConfig:
     majority_threshold: int = 3


-@dataclass
-class LLMProfile:
-    """Named LLM endpoint profile reusable across routes/departments."""
-
-    key: str
-    provider: str = "ollama"
-    model: Optional[str] = None
-    base_url: Optional[str] = None
-    api_key: Optional[str] = None
-    temperature: float = 0.2
-    timeout: float = 30.0
-    title: str = ""
-    enabled: bool = True
-
-    def to_endpoint(self) -> LLMEndpoint:
-        return LLMEndpoint(
-            provider=self.provider,
-            model=self.model,
-            base_url=self.base_url,
-            api_key=self.api_key,
-            temperature=self.temperature,
-            timeout=self.timeout,
-        )
-
-    def to_dict(self) -> Dict[str, object]:
-        return {
-            "provider": self.provider,
-            "model": self.model,
-            "base_url": self.base_url,
-            "api_key": self.api_key,
-            "temperature": self.temperature,
-            "timeout": self.timeout,
-            "title": self.title,
-            "enabled": self.enabled,
-        }
-
-    @classmethod
-    def from_endpoint(
-        cls,
-        key: str,
-        endpoint: LLMEndpoint,
-        *,
-        title: str = "",
-        enabled: bool = True,
-    ) -> "LLMProfile":
-        return cls(
-            key=key,
-            provider=endpoint.provider,
-            model=endpoint.model,
-            base_url=endpoint.base_url,
-            api_key=endpoint.api_key,
-            temperature=endpoint.temperature,
-            timeout=endpoint.timeout,
-            title=title,
-            enabled=enabled,
-        )
-
-
-@dataclass
-class LLMRoute:
-    """Declarative routing for selecting profiles and strategy."""
-
-    name: str
-    title: str = ""
-    strategy: str = "single"
-    majority_threshold: int = 3
-    primary: str = "ollama"
-    ensemble: List[str] = field(default_factory=list)
-
-    def resolve(self, profiles: Mapping[str, LLMProfile]) -> LLMConfig:
-        def _endpoint_from_key(key: str) -> LLMEndpoint:
-            profile = profiles.get(key)
-            if profile and profile.enabled:
-                return profile.to_endpoint()
-            fallback = profiles.get("ollama")
-            if not fallback or not fallback.enabled:
-                fallback = next(
-                    (item for item in profiles.values() if item.enabled),
-                    None,
-                )
-            endpoint = fallback.to_endpoint() if fallback else LLMEndpoint()
-            endpoint.provider = key or endpoint.provider
-            return endpoint
-
-        primary_endpoint = _endpoint_from_key(self.primary)
-        ensemble_endpoints = [
-            _endpoint_from_key(key)
-            for key in self.ensemble
-            if key in profiles and profiles[key].enabled
-        ]
-        config = LLMConfig(
-            primary=primary_endpoint,
-            ensemble=ensemble_endpoints,
-            strategy=self.strategy if self.strategy in ALLOWED_LLM_STRATEGIES else "single",
-            majority_threshold=max(1, self.majority_threshold or 1),
-        )
-        return config
-
-    def to_dict(self) -> Dict[str, object]:
-        return {
-            "title": self.title,
-            "strategy": self.strategy,
-            "majority_threshold": self.majority_threshold,
-            "primary": self.primary,
-            "ensemble": list(self.ensemble),
-        }
-
-
-def _default_llm_profiles() -> Dict[str, LLMProfile]:
-    return {
-        provider: LLMProfile(
+def _default_llm_providers() -> Dict[str, LLMProvider]:
+    providers: Dict[str, LLMProvider] = {}
+    for provider, meta in DEFAULT_LLM_MODEL_OPTIONS.items():
+        models = list(meta.get("models", []))
+        mode = "ollama" if provider == "ollama" else "openai"
+        providers[provider] = LLMProvider(
             key=provider,
-            provider=provider,
-            model=DEFAULT_LLM_MODELS.get(provider),
-            base_url=DEFAULT_LLM_BASE_URLS.get(provider),
-            temperature=DEFAULT_LLM_TEMPERATURES.get(provider, 0.2),
-            timeout=DEFAULT_LLM_TIMEOUTS.get(provider, 30.0),
-            title=f"Default {provider}",
+            base_url=str(meta.get("base_url", DEFAULT_LLM_BASE_URLS.get(provider, "")) or ""),
+            models=models,
+            default_model=models[0] if models else DEFAULT_LLM_MODELS.get(provider),
+            default_temperature=float(meta.get("temperature", DEFAULT_LLM_TEMPERATURES.get(provider, 0.2))),
+            default_timeout=float(meta.get("timeout", DEFAULT_LLM_TIMEOUTS.get(provider, 30.0))),
+            mode=mode,
         )
-        for provider in DEFAULT_LLM_MODEL_OPTIONS
-    }
-
-
-def _default_llm_routes() -> Dict[str, LLMRoute]:
-    return {
-        "global": LLMRoute(name="global", title="Global default route"),
-    }
+    return providers

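The net effect of replacing Profiles/Routes with Providers is a two-layer lookup: endpoint fields are tri-state (`None` means inherit), and the provider entry supplies the defaults. A small sketch of the layering (values hypothetical):

```python
provider = LLMProvider(
    key="deepseek",
    base_url="https://api.deepseek.com",
    default_model="deepseek-chat",
    default_temperature=0.3,
)
endpoint = LLMEndpoint(provider="deepseek")  # model/temperature/timeout all None

# Mirrors the precedence implemented in _call_endpoint:
model = endpoint.model or provider.default_model          # -> "deepseek-chat"
temp = endpoint.temperature if endpoint.temperature is not None else provider.default_temperature  # -> 0.3
```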
@@ -270,7 +187,6 @@ class DepartmentSettings:
     description: str = ""
     weight: float = 1.0
     llm: LLMConfig = field(default_factory=LLMConfig)
-    llm_route: Optional[str] = None


 def _default_departments() -> Dict[str, DepartmentSettings]:
@@ -282,10 +198,7 @@ def _default_departments() -> Dict[str, DepartmentSettings]:
         ("macro", "Macro research department"),
         ("risk", "Risk control department"),
     ]
-    return {
-        code: DepartmentSettings(code=code, title=title, llm_route="global")
-        for code, title in presets
-    }
+    return {code: DepartmentSettings(code=code, title=title) for code, title in presets}


 @dataclass
@@ -298,17 +211,11 @@ class AppConfig:
     data_paths: DataPaths = field(default_factory=DataPaths)
     agent_weights: AgentWeights = field(default_factory=AgentWeights)
     force_refresh: bool = False
+    llm_providers: Dict[str, LLMProvider] = field(default_factory=_default_llm_providers)
     llm: LLMConfig = field(default_factory=LLMConfig)
     llm_route: str = "global"
-    llm_profiles: Dict[str, LLMProfile] = field(default_factory=_default_llm_profiles)
-    llm_routes: Dict[str, LLMRoute] = field(default_factory=_default_llm_routes)
     departments: Dict[str, DepartmentSettings] = field(default_factory=_default_departments)

-    def resolve_llm(self, route: Optional[str] = None) -> LLMConfig:
-        route_key = route or self.llm_route
-        route_cfg = self.llm_routes.get(route_key)
-        if route_cfg:
-            return route_cfg.resolve(self.llm_profiles)
-        return self.llm
-
     def sync_runtime_llm(self) -> None:
@@ -326,13 +233,22 @@ def _endpoint_to_dict(endpoint: LLMEndpoint) -> Dict[str, object]:
         "api_key": endpoint.api_key,
         "temperature": endpoint.temperature,
         "timeout": endpoint.timeout,
+        "prompt_template": endpoint.prompt_template,
     }


 def _dict_to_endpoint(data: Dict[str, object]) -> LLMEndpoint:
     payload = {
         key: data.get(key)
-        for key in ("provider", "model", "base_url", "api_key", "temperature", "timeout")
+        for key in (
+            "provider",
+            "model",
+            "base_url",
+            "api_key",
+            "temperature",
+            "timeout",
+            "prompt_template",
+        )
         if data.get(key) is not None
     }
     return LLMEndpoint(**payload)
@@ -348,7 +264,9 @@ def _load_from_file(cfg: AppConfig) -> None:
     except (json.JSONDecodeError, OSError):
         return

-    if isinstance(payload, dict):
+    if not isinstance(payload, dict):
+        return
+
     if "tushare_token" in payload:
         cfg.tushare_token = payload.get("tushare_token") or None
     if "force_refresh" in payload:
@@ -356,111 +274,110 @@ def _load_from_file(cfg: AppConfig) -> None:
     if "decision_method" in payload:
         cfg.decision_method = str(payload.get("decision_method") or cfg.decision_method)

-    routes_defined = False
-    inline_primary_loaded = False
+    legacy_profiles: Dict[str, Dict[str, object]] = {}
+    legacy_routes: Dict[str, Dict[str, object]] = {}
+
+    providers_payload = payload.get("llm_providers")
+    if isinstance(providers_payload, dict):
+        providers: Dict[str, LLMProvider] = {}
+        for key, data in providers_payload.items():
+            if not isinstance(data, dict):
+                continue
+            models_raw = data.get("models")
+            if isinstance(models_raw, str):
+                models = [item.strip() for item in models_raw.split(',') if item.strip()]
+            elif isinstance(models_raw, list):
+                models = [str(item).strip() for item in models_raw if str(item).strip()]
+            else:
+                models = []
+            provider = LLMProvider(
+                key=str(key).lower(),
+                title=str(data.get("title") or ""),
+                base_url=str(data.get("base_url") or ""),
+                api_key=data.get("api_key"),
+                models=models,
+                default_model=data.get("default_model") or (models[0] if models else None),
+                default_temperature=float(data.get("default_temperature", 0.2)),
+                default_timeout=float(data.get("default_timeout", 30.0)),
+                prompt_template=str(data.get("prompt_template") or ""),
+                enabled=bool(data.get("enabled", True)),
+                mode=str(data.get("mode") or ("ollama" if str(key).lower() == "ollama" else "openai")),
+            )
+            providers[provider.key] = provider
+        if providers:
+            cfg.llm_providers = providers

     profiles_payload = payload.get("llm_profiles")
     if isinstance(profiles_payload, dict):
-        profiles: Dict[str, LLMProfile] = {}
         for key, data in profiles_payload.items():
-            if not isinstance(data, dict):
-                continue
-            provider = str(data.get("provider") or "ollama").lower()
-            profile = LLMProfile(
-                key=key,
-                provider=provider,
-                model=data.get("model"),
-                base_url=data.get("base_url"),
-                api_key=data.get("api_key"),
-                temperature=float(data.get("temperature", DEFAULT_LLM_TEMPERATURES.get(provider, 0.2))),
-                timeout=float(data.get("timeout", DEFAULT_LLM_TIMEOUTS.get(provider, 30.0))),
-                title=str(data.get("title") or ""),
-                enabled=bool(data.get("enabled", True)),
-            )
-            profiles[key] = profile
-        if profiles:
-            cfg.llm_profiles = profiles
+            if isinstance(data, dict):
+                legacy_profiles[str(key)] = data

     routes_payload = payload.get("llm_routes")
     if isinstance(routes_payload, dict):
-        routes: Dict[str, LLMRoute] = {}
         for name, data in routes_payload.items():
-            if not isinstance(data, dict):
-                continue
-            strategy_raw = str(data.get("strategy") or "single").lower()
-            normalized = LLM_STRATEGY_ALIASES.get(strategy_raw, strategy_raw)
-            route = LLMRoute(
-                name=name,
-                title=str(data.get("title") or ""),
-                strategy=normalized if normalized in ALLOWED_LLM_STRATEGIES else "single",
-                majority_threshold=max(1, int(data.get("majority_threshold", 3) or 3)),
-                primary=str(data.get("primary") or "global"),
-                ensemble=[
-                    str(item)
-                    for item in data.get("ensemble", [])
-                    if isinstance(item, str)
-                ],
-            )
-            routes[name] = route
-        if routes:
-            cfg.llm_routes = routes
-            routes_defined = True
+            if isinstance(data, dict):
+                legacy_routes[str(name)] = data

-    route_key = payload.get("llm_route")
-    if isinstance(route_key, str) and route_key:
-        cfg.llm_route = route_key
+    def _endpoint_from_payload(item: object) -> LLMEndpoint:
+        if isinstance(item, dict):
+            return _dict_to_endpoint(item)
+        if isinstance(item, str):
+            profile_data = legacy_profiles.get(item)
+            if isinstance(profile_data, dict):
+                return _dict_to_endpoint(profile_data)
+            return LLMEndpoint(provider=item)
+        return LLMEndpoint()
+
+    def _resolve_route(route_name: str) -> Optional[LLMConfig]:
+        route_data = legacy_routes.get(route_name)
+        if not route_data:
+            return None
+        strategy_raw = str(route_data.get("strategy") or "single").lower()
+        strategy = LLM_STRATEGY_ALIASES.get(strategy_raw, strategy_raw)
+        primary_ref = route_data.get("primary")
+        primary_ep = _endpoint_from_payload(primary_ref)
+        ensemble_refs = route_data.get("ensemble", [])
+        ensemble_eps = [
+            _endpoint_from_payload(ref)
+            for ref in ensemble_refs
+            if isinstance(ref, (dict, str))
+        ]
+        cfg_obj = LLMConfig(
+            primary=primary_ep,
+            ensemble=ensemble_eps,
+            strategy=strategy if strategy in ALLOWED_LLM_STRATEGIES else "single",
+            majority_threshold=max(1, int(route_data.get("majority_threshold", 3) or 3)),
+        )
+        return cfg_obj

     llm_payload = payload.get("llm")
     if isinstance(llm_payload, dict):
+        route_value = llm_payload.get("route")
+        resolved_cfg = None
+        if isinstance(route_value, str) and route_value:
+            cfg.llm_route = route_value
+            resolved_cfg = _resolve_route(route_value)
+        if resolved_cfg is None:
+            resolved_cfg = LLMConfig()
         primary_data = llm_payload.get("primary")
         if isinstance(primary_data, dict):
-            cfg.llm.primary = _dict_to_endpoint(primary_data)
-            inline_primary_loaded = True
+            resolved_cfg.primary = _dict_to_endpoint(primary_data)
         ensemble_data = llm_payload.get("ensemble")
         if isinstance(ensemble_data, list):
-            cfg.llm.ensemble = [
+            resolved_cfg.ensemble = [
                 _dict_to_endpoint(item)
                 for item in ensemble_data
                 if isinstance(item, dict)
             ]
         strategy_raw = llm_payload.get("strategy")
         if isinstance(strategy_raw, str):
             normalized = LLM_STRATEGY_ALIASES.get(strategy_raw, strategy_raw)
             if normalized in ALLOWED_LLM_STRATEGIES:
-                cfg.llm.strategy = normalized
+                resolved_cfg.strategy = normalized
         majority = llm_payload.get("majority_threshold")
         if isinstance(majority, int) and majority > 0:
-            cfg.llm.majority_threshold = majority
-
-    if inline_primary_loaded and not routes_defined:
-        primary_key = "inline_global_primary"
-        cfg.llm_profiles[primary_key] = LLMProfile.from_endpoint(
-            primary_key,
-            cfg.llm.primary,
-            title="Global primary model",
-        )
-        ensemble_keys: List[str] = []
-        for idx, endpoint in enumerate(cfg.llm.ensemble, start=1):
-            inline_key = f"inline_global_ensemble_{idx}"
-            cfg.llm_profiles[inline_key] = LLMProfile.from_endpoint(
-                inline_key,
-                endpoint,
-                title=f"Global ensemble #{idx}",
-            )
-            ensemble_keys.append(inline_key)
-        auto_route = cfg.llm_routes.get("global") or LLMRoute(name="global", title="Global default route")
-        auto_route.strategy = cfg.llm.strategy
-        auto_route.majority_threshold = cfg.llm.majority_threshold
-        auto_route.primary = primary_key
-        auto_route.ensemble = ensemble_keys
-        cfg.llm_routes["global"] = auto_route
-        cfg.llm_route = cfg.llm_route or "global"
+            resolved_cfg.majority_threshold = majority
+        cfg.llm = resolved_cfg

     departments_payload = payload.get("departments")
     if isinstance(departments_payload, dict):
@@ -471,14 +388,22 @@ def _load_from_file(cfg: AppConfig) -> None:
             title = data.get("title") or code
             description = data.get("description") or ""
             weight = float(data.get("weight", 1.0))
-            llm_data = data.get("llm")
             llm_cfg = LLMConfig()
+            route_name = data.get("llm_route")
+            resolved_cfg = None
+            if isinstance(route_name, str) and route_name:
+                resolved_cfg = _resolve_route(route_name)
+            if resolved_cfg is None:
+                llm_data = data.get("llm")
                 if isinstance(llm_data, dict):
-                    if isinstance(llm_data.get("primary"), dict):
-                        llm_cfg.primary = _dict_to_endpoint(llm_data["primary"])
+                    primary_data = llm_data.get("primary")
+                    if isinstance(primary_data, dict):
+                        llm_cfg.primary = _dict_to_endpoint(primary_data)
                     ensemble_data = llm_data.get("ensemble")
                     if isinstance(ensemble_data, list):
                         llm_cfg.ensemble = [
                             _dict_to_endpoint(item)
-                            for item in llm_data.get("ensemble", [])
+                            for item in ensemble_data
                             if isinstance(item, dict)
                         ]
                     strategy_raw = llm_data.get("strategy")
@@ -489,18 +414,13 @@ def _load_from_file(cfg: AppConfig) -> None:
                     majority_raw = llm_data.get("majority_threshold")
                     if isinstance(majority_raw, int) and majority_raw > 0:
                         llm_cfg.majority_threshold = majority_raw
-            route = data.get("llm_route")
-            route_name = str(route).strip() if isinstance(route, str) and route else None
-            resolved = llm_cfg
-            if route_name and route_name in cfg.llm_routes:
-                resolved = cfg.llm_routes[route_name].resolve(cfg.llm_profiles)
+                resolved_cfg = llm_cfg
             new_departments[code] = DepartmentSettings(
                 code=code,
                 title=title,
                 description=description,
                 weight=weight,
-                llm=resolved,
-                llm_route=route_name,
+                llm=resolved_cfg,
             )
         if new_departments:
             cfg.departments = new_departments
@@ -516,28 +436,21 @@ def save_config(cfg: AppConfig | None = None) -> None:
         "tushare_token": cfg.tushare_token,
         "force_refresh": cfg.force_refresh,
         "decision_method": cfg.decision_method,
-        "llm_route": cfg.llm_route,
         "llm": {
+            "route": cfg.llm_route,
             "strategy": cfg.llm.strategy if cfg.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
             "majority_threshold": cfg.llm.majority_threshold,
             "primary": _endpoint_to_dict(cfg.llm.primary),
             "ensemble": [_endpoint_to_dict(ep) for ep in cfg.llm.ensemble],
         },
-        "llm_profiles": {
-            key: profile.to_dict()
-            for key, profile in cfg.llm_profiles.items()
-        },
-        "llm_routes": {
-            name: route.to_dict()
-            for name, route in cfg.llm_routes.items()
+        "llm_providers": {
+            key: provider.to_dict()
+            for key, provider in cfg.llm_providers.items()
         },
         "departments": {
             code: {
                 "title": dept.title,
                 "description": dept.description,
                 "weight": dept.weight,
-                "llm_route": dept.llm_route,
                 "llm": {
                     "strategy": dept.llm.strategy if dept.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
                     "majority_threshold": dept.llm.majority_threshold,
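After this change a saved `app/data/config.json` has roughly the following shape (keys from `save_config()` above; values illustrative):

```python
example_payload = {
    "tushare_token": "***",
    "llm": {
        "route": "global",
        "strategy": "majority",
        "majority_threshold": 2,
        "primary": {"provider": "deepseek", "model": "deepseek-chat"},
        "ensemble": [{"provider": "ollama", "model": "qwen2.5"}],
    },
    "llm_providers": {
        "deepseek": {
            "base_url": "https://api.deepseek.com",
            "models": ["deepseek-chat"],
            "default_model": "deepseek-chat",
            "enabled": True,
            "mode": "openai",
        },
    },
    "departments": {"macro": {"title": "Macro research department", "llm": {"strategy": "single"}}},
}
```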
@@ -567,11 +480,9 @@ def _load_env_defaults(cfg: AppConfig) -> None:
     if api_key:
         sanitized = api_key.strip()
         cfg.llm.primary.api_key = sanitized
-        route = cfg.llm_routes.get(cfg.llm_route)
-        if route:
-            profile = cfg.llm_profiles.get(route.primary)
-            if profile:
-                profile.api_key = sanitized
+        provider_cfg = cfg.llm_providers.get(cfg.llm.primary.provider)
+        if provider_cfg:
+            provider_cfg.api_key = sanitized

     cfg.sync_runtime_llm()