Commit ab6180646a by sam, 2025-09-28 11:28:55 +08:00 (parent 6aece20816)
5 changed files with 574 additions and 575 deletions

View File

@@ -24,7 +24,7 @@
 - **Unified logging and persistence**: SQLite stores market data, backtests, and logs in one place; `DatabaseLogHandler` emits structured run traces from the UI and data-ingestion flows for quick tracing and post-mortems.
 - **Cross-market data coverage**: `app/ingest/tushare.py` adds incremental fetch logic for indices, ETFs/mutual funds, futures, FX, Hong Kong and US equities, so the price data required by multi-asset factors and macro proxies is available.
 - **Departmental multi-model collaboration**: `app/agents/departments.py` wraps department-level LLM orchestration, `app/llm/client.py` supports the single/majority/leader strategies, department conclusions join the six base agents in the game played in `app/agents/game.py`, and results are persisted to `agent_utils` for display in the UI.
-- **LLM Profile/Route management**: `app/utils/config.py` introduced reusable Profiles (endpoint definitions) and Routes (inference-strategy combinations); the Streamlit UI offered visual maintenance, and global and department scopes could reuse named routes for configuration consistency.
+- **LLM Provider management**: `app/utils/config.py` centrally maintains each vendor's URL, API key, available models, and default parameters; the Streamlit UI provides visual configuration, and global and department scopes set model, temperature, and prompt directly on top of a Provider.

 ## LLM + multi-agent best practices
@@ -60,11 +60,10 @@ export TUSHARE_TOKEN="<your-token>"
 ### LLM configuration and testing
-- A two-layer Profile/Route configuration: a Profile defines a single endpoint (provider, model, base URL, API key), while a Route combines Profiles and picks an inference strategy (single/majority/leader). The global route switches in one click; departments can reuse named routes or keep custom settings.
-- The Streamlit "Data & Settings" page manages Profiles, Routes, and the global route through forms; saving writes `app/data/config.json`, and the Route preview shows the live configuration sanitized by `llm_config_snapshot()`.
-- Supports local Ollama and several OpenAI-compatible vendors (DeepSeek, 文心一言, OpenAI, etc.); each Profile can carry a default model, temperature, timeout, and enabled flag.
-- The UI keeps TuShare token maintenance plus creating, deleting, and disabling routes/Profiles; every update takes effect immediately and is logged.
-- When secrets are injected via environment variables (`TUSHARE_TOKEN`, `LLM_API_KEY`), the loaded key is synced to the current route's primary Profile.
+- Providers hold each vendor's connection parameters (base URL, API key, model list, default temperature/timeout/prompt template) and can be extended at any time with local Ollama or cloud services (DeepSeek, 文心一言, OpenAI, etc.).
+- Global and department configurations pick a Provider directly and override model, temperature, prompt template, and voting strategy as needed; saving writes `app/data/config.json`, which is loaded automatically on the next start (a hedged code sketch follows below).
+- The Streamlit "Data & Settings" page offers a three-pane editor (Provider/global/department); changes take effect on save, and `llm_config_snapshot()` prints a sanitized view of the result for inspection.
+- Secrets can be injected via environment variables: `TUSHARE_TOKEN`, `LLM_API_KEY`.
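For orientation, here is a minimal sketch of registering a Provider from code with the API this commit introduces (`LLMProvider`, `get_config`, `sync_runtime_llm`, `save_config`); the DeepSeek values are illustrative placeholders, not shipped defaults:

```python
from app.utils.config import LLMProvider, get_config, save_config

cfg = get_config()
# Describe an OpenAI-compatible vendor; fields mirror the LLMProvider dataclass.
cfg.llm_providers["deepseek"] = LLMProvider(
    key="deepseek",
    title="DeepSeek (placeholder)",
    base_url="https://api.deepseek.com",  # illustrative URL
    api_key=None,                         # prefer injecting via LLM_API_KEY
    models=["deepseek-chat"],
    default_model="deepseek-chat",
    default_temperature=0.2,
    default_timeout=30.0,
    mode="openai",                        # "openai" or "ollama"
)
cfg.sync_runtime_llm()  # re-resolve runtime endpoints against the new provider
save_config()           # persists to app/data/config.json
```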
 ## Quick start
@@ -104,7 +103,7 @@ The Streamlit `自检测试` (self-check) tab provides:
 ## Implementation steps
 1. **Configuration extension** (`app/utils/config.py` + `config.json`) ✅
-   - Introduced `llm_profiles`/`llm_routes` to manage endpoints and strategies in one place; departments could reuse routes or fall back to custom settings; Streamlit offered a visual maintenance form.
+   - Introduces `llm_providers` to centrally manage vendor parameters; global and department scopes bind a Provider directly and customize model/temperature/prompt; Streamlit provides a visual maintenance form.
 2. **Department controller**
    - `app/agents/departments.py` provides `DepartmentAgent`/`DepartmentManager`, wrapping prompt construction, multi-model negotiation, and exception fallback.

View File

@@ -131,8 +131,6 @@ class DepartmentManager:
         return results

     def _resolve_llm(self, settings: DepartmentSettings) -> LLMConfig:
-        if settings.llm_route and settings.llm_route in self.config.llm_routes:
-            return self.config.llm_routes[settings.llm_route].resolve(self.config.llm_profiles)
         return settings.llm

View File

@@ -102,24 +102,58 @@ def _request_openai(
 def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> str:
-    provider = (endpoint.provider or "ollama").lower()
-    base_url = endpoint.base_url or _default_base_url(provider)
-    model = endpoint.model or _default_model(provider)
-    temperature = max(0.0, min(endpoint.temperature, 2.0))
-    timeout = max(5.0, endpoint.timeout or 30.0)
+    cfg = get_config()
+    provider_key = (endpoint.provider or "ollama").lower()
+    provider_cfg = cfg.llm_providers.get(provider_key)
+
+    base_url = endpoint.base_url
+    api_key = endpoint.api_key
+    model = endpoint.model
+    temperature = endpoint.temperature
+    timeout = endpoint.timeout
+    prompt_template = endpoint.prompt_template
+    if provider_cfg:
+        if not provider_cfg.enabled:
+            raise LLMError(f"Provider {provider_key} 已被禁用")
+        base_url = base_url or provider_cfg.base_url or _default_base_url(provider_key)
+        api_key = api_key or provider_cfg.api_key
+        model = model or provider_cfg.default_model or (provider_cfg.models[0] if provider_cfg.models else _default_model(provider_key))
+        if temperature is None:
+            temperature = provider_cfg.default_temperature
+        if timeout is None:
+            timeout = provider_cfg.default_timeout
+        prompt_template = prompt_template or (provider_cfg.prompt_template or None)
+        mode = provider_cfg.mode or ("ollama" if provider_key == "ollama" else "openai")
+    else:
+        base_url = base_url or _default_base_url(provider_key)
+        model = model or _default_model(provider_key)
+        if temperature is None:
+            temperature = DEFAULT_LLM_TEMPERATURES.get(provider_key, 0.2)
+        if timeout is None:
+            timeout = DEFAULT_LLM_TIMEOUTS.get(provider_key, 30.0)
+        mode = "ollama" if provider_key == "ollama" else "openai"
+
+    temperature = max(0.0, min(float(temperature), 2.0))
+    timeout = max(5.0, float(timeout))
+    if prompt_template:
+        try:
+            prompt = prompt_template.format(prompt=prompt)
+        except Exception:  # noqa: BLE001
+            LOGGER.warning("Prompt 模板格式化失败,使用原始 prompt", extra=LOG_EXTRA)
+
     LOGGER.info(
         "触发 LLM 请求:provider=%s model=%s base=%s",
-        provider,
+        provider_key,
         model,
         base_url,
         extra=LOG_EXTRA,
     )
-    if provider in {"openai", "deepseek", "wenxin"}:
-        api_key = endpoint.api_key
+    if mode != "ollama":
         if not api_key:
-            raise LLMError(f"缺少 {provider} API Key (model={model})")
+            raise LLMError(f"缺少 {provider_key} API Key (model={model})")
         return _request_openai(
             model,
             prompt,
@@ -129,7 +163,7 @@ def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> str:
             timeout=timeout,
             system=system,
         )
-    if provider == "ollama":
+    if base_url:
         return _request_ollama(
             model,
             prompt,
@@ -283,13 +317,17 @@ def llm_config_snapshot() -> Dict[str, object]:
         if record.get("api_key"):
             record["api_key"] = "***"
         ensemble.append(record)
-    route_name = cfg.llm_route
-    route_obj = cfg.llm_routes.get(route_name)
     return {
-        "route": route_name,
-        "route_detail": route_obj.to_dict() if route_obj else None,
         "strategy": settings.strategy,
         "majority_threshold": settings.majority_threshold,
         "primary": primary,
         "ensemble": ensemble,
+        "providers": {
+            key: {
+                "base_url": provider.base_url,
+                "default_model": provider.default_model,
+                "enabled": provider.enabled,
+            }
+            for key, provider in cfg.llm_providers.items()
+        },
     }
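The resolution order in `_call_endpoint` above is worth calling out: explicit endpoint fields win, then provider defaults, then built-in defaults. A standalone sketch of that precedence (the helper name `resolve_value` is ours, not part of the commit):

```python
from typing import Optional, TypeVar

T = TypeVar("T")

def resolve_value(endpoint_value: Optional[T],
                  provider_default: Optional[T],
                  builtin_default: T) -> T:
    """Endpoint override -> provider default -> built-in default."""
    if endpoint_value is not None:
        return endpoint_value
    if provider_default is not None:
        return provider_default
    return builtin_default

# An endpoint that leaves temperature unset inherits the provider default;
# an explicit 0.0 survives because the check is `is None`, not truthiness.
assert resolve_value(None, 0.1, 0.2) == 0.1
assert resolve_value(0.0, 0.1, 0.2) == 0.0
assert resolve_value(None, None, 0.2) == 0.2
```

Note the deliberate asymmetry in the real code: string fields (`base_url`, `api_key`, `model`) chain with `or`, so empty strings fall through to the provider, while the numeric fields test `is None`, so an explicit `temperature=0.0` is preserved.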

View File

@@ -5,7 +5,7 @@ import sys
 from dataclasses import asdict
 from datetime import date, timedelta
 from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Optional

 ROOT = Path(__file__).resolve().parents[2]
 if str(ROOT) not in sys.path:
@@ -16,6 +16,8 @@ import json
 import pandas as pd
 import plotly.express as px
 import plotly.graph_objects as go
+import requests
+from requests.exceptions import RequestException
 import streamlit as st

 from app.backtest.engine import BtConfig, run_backtest
@@ -29,8 +31,8 @@ from app.utils.config import (
     DEFAULT_LLM_MODEL_OPTIONS,
     DEFAULT_LLM_MODELS,
     DepartmentSettings,
-    LLMProfile,
-    LLMRoute,
+    LLMEndpoint,
+    LLMProvider,
     get_config,
     save_config,
 )
@@ -42,6 +44,50 @@ LOGGER = get_logger(__name__)
 LOG_EXTRA = {"stage": "ui"}


+def _discover_provider_models(provider: LLMProvider, base_override: str = "", api_override: Optional[str] = None) -> tuple[list[str], Optional[str]]:
+    """Attempt to query provider API and return available model ids."""
+    base_url = (base_override or provider.base_url or DEFAULT_LLM_BASE_URLS.get(provider.key, "")).strip()
+    if not base_url:
+        return [], "请先填写 Base URL"
+    timeout = float(provider.default_timeout or 30.0)
+    mode = provider.mode or ("ollama" if provider.key == "ollama" else "openai")
+    try:
+        if mode == "ollama":
+            url = base_url.rstrip('/') + "/api/tags"
+            response = requests.get(url, timeout=timeout)
+            response.raise_for_status()
+            data = response.json()
+            models = []
+            for item in data.get("models", []) or data.get("data", []):
+                name = item.get("name") or item.get("model") or item.get("tag")
+                if name:
+                    models.append(str(name).strip())
+            return sorted(set(models)), None
+        api_key = (api_override or provider.api_key or "").strip()
+        if not api_key:
+            return [], "缺少 API Key"
+        url = base_url.rstrip('/') + "/v1/models"
+        headers = {
+            "Authorization": f"Bearer {api_key}",
+            "Content-Type": "application/json",
+        }
+        response = requests.get(url, headers=headers, timeout=timeout)
+        response.raise_for_status()
+        payload = response.json()
+        models = [
+            str(item.get("id")).strip()
+            for item in payload.get("data", [])
+            if item.get("id")
+        ]
+        return sorted(set(models)), None
+    except RequestException as exc:  # noqa: BLE001
+        return [], f"HTTP 错误:{exc}"
+    except Exception as exc:  # noqa: BLE001
+        return [], f"解析失败:{exc}"


 def _load_stock_options(limit: int = 500) -> list[str]:
     try:
         with db_session(read_only=True) as conn:
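A hedged usage sketch for `_discover_provider_models` above; it returns `(models, error)` where `error` is an optional message. This assumes an Ollama daemon on its usual local port:

```python
# Hypothetical interactive use; assumes Ollama is serving on localhost:11434.
provider = LLMProvider(key="ollama", base_url="http://localhost:11434", mode="ollama")
models, error = _discover_provider_models(provider)
if error:
    print(f"discovery failed: {error}")
else:
    print(f"{len(models)} models available:", models)
```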
@@ -350,255 +396,267 @@ def render_settings() -> None:
     st.divider()
     st.subheader("LLM 设置")
-    profiles = cfg.llm_profiles or {}
-    routes = cfg.llm_routes or {}
-    profile_keys = sorted(profiles.keys())
-    route_keys = sorted(routes.keys())
-    used_routes = {
-        dept.llm_route for dept in cfg.departments.values() if dept.llm_route
-    }
-    st.caption("Profile 定义单个模型终端,Route 负责组合 Profile 与推理策略。")
-
-    route_select_col, route_manage_col = st.columns([3, 1])
-    if route_keys:
-        try:
-            active_index = route_keys.index(cfg.llm_route)
-        except ValueError:
-            active_index = 0
-        selected_route = route_select_col.selectbox(
-            "全局路由",
-            route_keys,
-            index=active_index,
-            key="llm_route_select",
-        )
-    else:
-        selected_route = None
-        route_select_col.info("尚未配置路由,请先创建。")
-    new_route_name = route_manage_col.text_input("新增路由", key="new_route_name")
-    if route_manage_col.button("添加路由"):
-        key = (new_route_name or "").strip()
-        if not key:
-            st.warning("请输入有效的路由名称。")
-        elif key in routes:
-            st.warning(f"路由 {key} 已存在。")
-        else:
-            routes[key] = LLMRoute(name=key)
-            if not selected_route:
-                selected_route = key
-                cfg.llm_route = key
-            save_config()
-            st.success(f"添加路由 {key},请继续配置")
-            st.experimental_rerun()
-
-    if selected_route:
-        route_obj = routes.get(selected_route)
-        if route_obj is None:
-            route_obj = LLMRoute(name=selected_route)
-            routes[selected_route] = route_obj
-        strategy_choices = sorted(ALLOWED_LLM_STRATEGIES)
-        try:
-            strategy_index = strategy_choices.index(route_obj.strategy)
-        except ValueError:
-            strategy_index = 0
-        route_title = st.text_input(
-            "路由说明",
-            value=route_obj.title or "",
-            key=f"route_title_{selected_route}",
-        )
-        route_strategy = st.selectbox(
-            "推理策略",
-            strategy_choices,
-            index=strategy_index,
-            key=f"route_strategy_{selected_route}",
-        )
-        route_majority = st.number_input(
-            "多数投票门槛",
-            min_value=1,
-            max_value=10,
-            value=int(route_obj.majority_threshold or 1),
-            step=1,
-            key=f"route_majority_{selected_route}",
-        )
-        if not profile_keys:
-            st.warning("暂无可用 Profile,请先在下方创建。")
-        else:
-            try:
-                primary_index = profile_keys.index(route_obj.primary)
-            except ValueError:
-                primary_index = 0
-            primary_key = st.selectbox(
-                "主用 Profile",
-                profile_keys,
-                index=primary_index,
-                key=f"route_primary_{selected_route}",
-            )
-            default_ensemble = [
-                key for key in route_obj.ensemble if key in profile_keys and key != primary_key
-            ]
-            ensemble_keys = st.multiselect(
-                "协作 Profile (可多选)",
-                profile_keys,
-                default=default_ensemble,
-                key=f"route_ensemble_{selected_route}",
-            )
-            if st.button("保存路由设置", key=f"save_route_{selected_route}"):
-                route_obj.title = route_title.strip()
-                route_obj.strategy = route_strategy
-                route_obj.majority_threshold = int(route_majority)
-                route_obj.primary = primary_key
-                route_obj.ensemble = [key for key in ensemble_keys if key != primary_key]
-                cfg.llm_route = selected_route
-                cfg.sync_runtime_llm()
-                save_config()
-                LOGGER.info(
-                    "路由 %s 配置更新:%s",
-                    selected_route,
-                    route_obj.to_dict(),
-                    extra=LOG_EXTRA,
-                )
-                st.success("路由配置已保存。")
-                st.json({
-                    "route": selected_route,
-                    "route_detail": route_obj.to_dict(),
-                    "resolved": llm_config_snapshot(),
-                })
-        route_in_use = selected_route in used_routes or selected_route == cfg.llm_route
-        if st.button(
-            "删除当前路由",
-            key=f"delete_route_{selected_route}",
-            disabled=route_in_use or len(routes) <= 1,
-        ):
-            routes.pop(selected_route, None)
-            if cfg.llm_route == selected_route:
-                cfg.llm_route = next((key for key in routes.keys()), "global")
-            cfg.sync_runtime_llm()
-            save_config()
-            st.success("路由已删除。")
-            st.experimental_rerun()
-
-    st.divider()
-    st.subheader("LLM Profile 管理")
-    profile_select_col, profile_manage_col = st.columns([3, 1])
-    if profile_keys:
-        selected_profile = profile_select_col.selectbox(
-            "选择 Profile",
-            profile_keys,
-            index=0,
-            key="profile_select",
-        )
-    else:
-        selected_profile = None
-        profile_select_col.info("尚未配置 Profile,请先创建。")
-    new_profile_name = profile_manage_col.text_input("新增 Profile", key="new_profile_name")
-    if profile_manage_col.button("创建 Profile"):
-        key = (new_profile_name or "").strip()
-        if not key:
-            st.warning("请输入有效的 Profile 名称。")
-        elif key in profiles:
-            st.warning(f"Profile {key} 已存在。")
-        else:
-            profiles[key] = LLMProfile(key=key)
-            save_config()
-            st.success(f"已创建 Profile {key}")
-            st.experimental_rerun()
-
-    if selected_profile:
-        profile = profiles[selected_profile]
-        provider_choices = sorted(DEFAULT_LLM_MODEL_OPTIONS.keys())
-        try:
-            provider_index = provider_choices.index(profile.provider)
-        except ValueError:
-            provider_index = 0
-        with st.form(f"profile_form_{selected_profile}"):
-            provider_val = st.selectbox(
-                "Provider",
-                provider_choices,
-                index=provider_index,
-            )
-            model_default = DEFAULT_LLM_MODELS.get(provider_val, profile.model or "")
-            model_val = st.text_input(
-                "模型",
-                value=profile.model or model_default,
-            )
-            base_default = DEFAULT_LLM_BASE_URLS.get(provider_val, profile.base_url or "")
-            base_val = st.text_input(
-                "Base URL",
-                value=profile.base_url or base_default,
-            )
-            api_val = st.text_input(
-                "API Key",
-                value=profile.api_key or "",
-                type="password",
-            )
-            temp_val = st.slider(
-                "温度",
-                min_value=0.0,
-                max_value=2.0,
-                value=float(profile.temperature),
-                step=0.05,
-            )
-            timeout_val = st.number_input(
-                "超时(秒)",
-                min_value=5,
-                max_value=180,
-                value=int(profile.timeout or 30),
-                step=5,
-            )
-            title_val = st.text_input("备注", value=profile.title or "")
-            enabled_val = st.checkbox("启用", value=profile.enabled)
-            submitted = st.form_submit_button("保存 Profile")
-            if submitted:
-                profile.provider = provider_val
-                profile.model = model_val.strip() or DEFAULT_LLM_MODELS.get(provider_val)
-                profile.base_url = base_val.strip() or DEFAULT_LLM_BASE_URLS.get(provider_val)
-                profile.api_key = api_val.strip() or None
-                profile.temperature = temp_val
-                profile.timeout = timeout_val
-                profile.title = title_val.strip()
-                profile.enabled = enabled_val
-                profiles[selected_profile] = profile
-                cfg.sync_runtime_llm()
-                save_config()
-                st.success("Profile 已保存。")
-        profile_in_use = any(
-            selected_profile == route.primary or selected_profile in route.ensemble
-            for route in routes.values()
-        )
-        if st.button(
-            "删除该 Profile",
-            key=f"delete_profile_{selected_profile}",
-            disabled=profile_in_use or len(profiles) <= 1,
-        ):
-            profiles.pop(selected_profile, None)
-            fallback_key = next((key for key in profiles.keys()), None)
-            for route in routes.values():
-                if route.primary == selected_profile:
-                    route.primary = fallback_key or route.primary
-                route.ensemble = [key for key in route.ensemble if key != selected_profile]
-            cfg.sync_runtime_llm()
-            save_config()
-            st.success("Profile 已删除。")
-            st.experimental_rerun()
-
-    st.divider()
-    st.subheader("部门配置")
+    providers = cfg.llm_providers
+    provider_keys = sorted(providers.keys())
+    st.caption("先在 Provider 中维护基础连接(URL、Key、模型),再为全局与各部门设置个性化参数。")
+
+    # Provider management -------------------------------------------------
+    provider_select_col, provider_manage_col = st.columns([3, 1])
+    if provider_keys:
+        try:
+            default_provider = cfg.llm.primary.provider or provider_keys[0]
+            provider_index = provider_keys.index(default_provider)
+        except ValueError:
+            provider_index = 0
+        selected_provider = provider_select_col.selectbox(
+            "选择 Provider",
+            provider_keys,
+            index=provider_index,
+            key="llm_provider_select",
+        )
+    else:
+        selected_provider = None
+        provider_select_col.info("尚未配置 Provider,请先创建。")
+    new_provider_name = provider_manage_col.text_input("新增 Provider", key="new_provider_name")
+    if provider_manage_col.button("创建 Provider", key="create_provider_btn"):
+        key = (new_provider_name or "").strip().lower()
+        if not key:
+            st.warning("请输入有效的 Provider 名称。")
+        elif key in providers:
+            st.warning(f"Provider {key} 已存在。")
+        else:
+            providers[key] = LLMProvider(key=key)
+            cfg.llm_providers = providers
+            save_config()
+            st.success(f"已创建 Provider {key}")
+            st.experimental_rerun()
+
+    if selected_provider:
+        provider_cfg = providers.get(selected_provider, LLMProvider(key=selected_provider))
+        title_key = f"provider_title_{selected_provider}"
+        base_key = f"provider_base_{selected_provider}"
+        api_key_key = f"provider_api_{selected_provider}"
+        models_key = f"provider_models_{selected_provider}"
+        default_model_key = f"provider_default_model_{selected_provider}"
+        mode_key = f"provider_mode_{selected_provider}"
+        temp_key = f"provider_temp_{selected_provider}"
+        timeout_key = f"provider_timeout_{selected_provider}"
+        prompt_key = f"provider_prompt_{selected_provider}"
+        enabled_key = f"provider_enabled_{selected_provider}"
+        title_val = st.text_input("备注名称", value=provider_cfg.title or "", key=title_key)
+        base_val = st.text_input("Base URL", value=provider_cfg.base_url or "", key=base_key, help="调用地址,例如 https://api.openai.com")
+        api_val = st.text_input("API Key", value=provider_cfg.api_key or "", key=api_key_key, type="password")
+        models_val = st.text_area("可用模型(每行一个)", value="\n".join(provider_cfg.models), key=models_key, height=100)
+        default_model_val = st.text_input("默认模型", value=provider_cfg.default_model or "", key=default_model_key)
+        mode_val = st.selectbox("调用模式", ["openai", "ollama"], index=0 if provider_cfg.mode == "openai" else 1, key=mode_key)
+        temp_val = st.slider("默认温度", min_value=0.0, max_value=2.0, value=float(provider_cfg.default_temperature), step=0.05, key=temp_key)
+        timeout_val = st.number_input("默认超时(秒)", min_value=5, max_value=300, value=int(provider_cfg.default_timeout or 30), step=5, key=timeout_key)
+        prompt_template_val = st.text_area("默认 Prompt 模板(可选,使用 {prompt} 占位)", value=provider_cfg.prompt_template or "", key=prompt_key, height=120)
+        enabled_val = st.checkbox("启用", value=provider_cfg.enabled, key=enabled_key)
+
+        fetch_key = f"fetch_models_{selected_provider}"
+        if st.button("获取模型列表", key=fetch_key):
+            with st.spinner("正在获取模型列表..."):
+                models, error = _discover_provider_models(provider_cfg, base_val, api_val)
+            if error:
+                st.error(error)
+            else:
+                provider_cfg.models = models
+                if models and (not provider_cfg.default_model or provider_cfg.default_model not in models):
+                    provider_cfg.default_model = models[0]
+                providers[selected_provider] = provider_cfg
+                cfg.llm_providers = providers
+                cfg.sync_runtime_llm()
+                save_config()
+                st.success(f"共获取 {len(models)} 个模型。")
+                st.session_state[models_key] = "\n".join(models)
+                st.session_state[default_model_key] = provider_cfg.default_model or ""
+
+        if st.button("保存 Provider", key=f"save_provider_{selected_provider}"):
+            provider_cfg.title = title_val.strip()
+            provider_cfg.base_url = base_val.strip()
+            provider_cfg.api_key = api_val.strip() or None
+            provider_cfg.models = [line.strip() for line in models_val.splitlines() if line.strip()]
+            provider_cfg.default_model = default_model_val.strip() or (provider_cfg.models[0] if provider_cfg.models else provider_cfg.default_model)
+            provider_cfg.default_temperature = float(temp_val)
+            provider_cfg.default_timeout = float(timeout_val)
+            provider_cfg.prompt_template = prompt_template_val.strip()
+            provider_cfg.enabled = enabled_val
+            provider_cfg.mode = mode_val
+            providers[selected_provider] = provider_cfg
+            cfg.llm_providers = providers
+            cfg.sync_runtime_llm()
+            save_config()
+            st.success("Provider 已保存。")
+            st.session_state[title_key] = provider_cfg.title or ""
+            st.session_state[default_model_key] = provider_cfg.default_model or ""
+
+        provider_in_use = (cfg.llm.primary.provider == selected_provider) or any(
+            ep.provider == selected_provider for ep in cfg.llm.ensemble
+        )
+        if not provider_in_use:
+            for dept in cfg.departments.values():
+                if dept.llm.primary.provider == selected_provider or any(ep.provider == selected_provider for ep in dept.llm.ensemble):
+                    provider_in_use = True
+                    break
+        if st.button(
+            "删除 Provider",
+            key=f"delete_provider_{selected_provider}",
+            disabled=provider_in_use or len(providers) <= 1,
+        ):
+            providers.pop(selected_provider, None)
+            cfg.llm_providers = providers
+            cfg.sync_runtime_llm()
+            save_config()
+            st.success("Provider 已删除。")
+            st.experimental_rerun()
+
+    st.markdown("##### 全局推理配置")
+    if not provider_keys:
+        st.warning("请先配置至少一个 Provider。")
+    else:
+        global_cfg = cfg.llm
+        primary = global_cfg.primary
+        try:
+            provider_index = provider_keys.index(primary.provider or provider_keys[0])
+        except ValueError:
+            provider_index = 0
+        selected_global_provider = st.selectbox(
+            "主模型 Provider",
+            provider_keys,
+            index=provider_index,
+            key="global_provider_select",
+        )
+        provider_cfg = providers.get(selected_global_provider)
+        available_models = provider_cfg.models if provider_cfg else []
+        default_model = primary.model or (provider_cfg.default_model if provider_cfg else None)
+        if available_models:
+            options = available_models + ["自定义"]
+            try:
+                model_index = available_models.index(default_model)
+                model_choice = st.selectbox("主模型", options, index=model_index, key="global_model_choice")
+            except ValueError:
+                model_choice = st.selectbox("主模型", options, index=len(options) - 1, key="global_model_choice")
+            if model_choice == "自定义":
+                model_val = st.text_input("自定义模型", value=default_model or "", key="global_model_custom").strip()
+            else:
+                model_val = model_choice
+        else:
+            model_val = st.text_input("主模型", value=default_model or "", key="global_model_custom").strip()
+        temp_default = primary.temperature if primary.temperature is not None else (provider_cfg.default_temperature if provider_cfg else 0.2)
+        temp_val = st.slider("主模型温度", min_value=0.0, max_value=2.0, value=float(temp_default), step=0.05, key="global_temp")
+        timeout_default = primary.timeout if primary.timeout is not None else (provider_cfg.default_timeout if provider_cfg else 30.0)
+        timeout_val = st.number_input("主模型超时(秒)", min_value=5, max_value=300, value=int(timeout_default), step=5, key="global_timeout")
+        prompt_template_val = st.text_area(
+            "主模型 Prompt 模板(可选)",
+            value=primary.prompt_template or (provider_cfg.prompt_template if provider_cfg else ""),
+            height=120,
+            key="global_prompt_template",
+        )
+
+        strategy_val = st.selectbox("推理策略", sorted(ALLOWED_LLM_STRATEGIES), index=sorted(ALLOWED_LLM_STRATEGIES).index(global_cfg.strategy) if global_cfg.strategy in ALLOWED_LLM_STRATEGIES else 0, key="global_strategy")
+        show_ensemble = strategy_val != "single"
+        majority_threshold_val = st.number_input(
+            "多数投票门槛",
+            min_value=1,
+            max_value=10,
+            value=int(global_cfg.majority_threshold),
+            step=1,
+            key="global_majority",
+            disabled=not show_ensemble,
+        )
+        if not show_ensemble:
+            majority_threshold_val = 1
+
+        ensemble_rows: List[Dict[str, str]] = []
+        if show_ensemble:
+            ensemble_rows = [
+                {
+                    "provider": ep.provider,
+                    "model": ep.model or "",
+                    "temperature": "" if ep.temperature is None else f"{ep.temperature:.3f}",
+                    "timeout": "" if ep.timeout is None else str(int(ep.timeout)),
+                    "prompt_template": ep.prompt_template or "",
+                }
+                for ep in global_cfg.ensemble
+            ] or [{"provider": primary.provider or selected_global_provider, "model": "", "temperature": "", "timeout": "", "prompt_template": ""}]
+            ensemble_editor = st.data_editor(
+                ensemble_rows,
+                num_rows="dynamic",
+                key="global_ensemble_editor",
+                use_container_width=True,
+                hide_index=True,
+                column_config={
+                    "provider": st.column_config.SelectboxColumn("Provider", options=provider_keys),
+                    "model": st.column_config.TextColumn("模型"),
+                    "temperature": st.column_config.TextColumn("温度"),
+                    "timeout": st.column_config.TextColumn("超时(秒)"),
+                    "prompt_template": st.column_config.TextColumn("Prompt 模板"),
+                },
+            )
+            if hasattr(ensemble_editor, "to_dict"):
+                ensemble_rows = ensemble_editor.to_dict("records")
+            else:
+                ensemble_rows = ensemble_editor
+        else:
+            st.info("当前策略为单模型,未启用协作模型。")
+
+        if st.button("保存全局配置", key="save_global_llm"):
+            primary.provider = selected_global_provider
+            primary.model = model_val or None
+            primary.temperature = float(temp_val)
+            primary.timeout = float(timeout_val)
+            primary.prompt_template = prompt_template_val.strip() or None
+            primary.base_url = None
+            primary.api_key = None
+
+            new_ensemble: List[LLMEndpoint] = []
+            if show_ensemble:
+                for row in ensemble_rows:
+                    provider_val = (row.get("provider") or "").strip().lower()
+                    if not provider_val:
+                        continue
+                    model_raw = (row.get("model") or "").strip() or None
+                    temp_raw = (row.get("temperature") or "").strip()
+                    timeout_raw = (row.get("timeout") or "").strip()
+                    prompt_raw = (row.get("prompt_template") or "").strip()
+                    new_ensemble.append(
+                        LLMEndpoint(
+                            provider=provider_val,
+                            model=model_raw,
+                            temperature=float(temp_raw) if temp_raw else None,
+                            timeout=float(timeout_raw) if timeout_raw else None,
+                            prompt_template=prompt_raw or None,
+                        )
+                    )
+            cfg.llm.ensemble = new_ensemble
+            cfg.llm.strategy = strategy_val
+            cfg.llm.majority_threshold = int(majority_threshold_val)
+            cfg.sync_runtime_llm()
+            save_config()
+            st.success("全局 LLM 配置已保存。")
+            st.json(llm_config_snapshot())
+
+    # Department configuration -------------------------------------------
+    st.markdown("##### 部门配置")
     dept_settings = cfg.departments or {}
-    route_options_display = [""] + route_keys
     dept_rows = [
         {
             "code": code,
             "title": dept.title,
             "description": dept.description,
             "weight": float(dept.weight),
-            "llm_route": dept.llm_route or "",
             "strategy": dept.llm.strategy,
-            "primary_provider": (dept.llm.primary.provider or ""),
-            "primary_model": dept.llm.primary.model or "",
-            "ensemble_size": len(dept.llm.ensemble),
+            "majority_threshold": dept.llm.majority_threshold,
+            "provider": dept.llm.primary.provider or (provider_keys[0] if provider_keys else ""),
+            "model": dept.llm.primary.model or "",
+            "temperature": "" if dept.llm.primary.temperature is None else f"{dept.llm.primary.temperature:.3f}",
+            "timeout": "" if dept.llm.primary.timeout is None else str(int(dept.llm.primary.timeout)),
+            "prompt_template": dept.llm.primary.prompt_template or "",
         }
         for code, dept in sorted(dept_settings.items())
     ]
@@ -618,26 +676,13 @@ def render_settings() -> None:
             "title": st.column_config.TextColumn("名称"),
             "description": st.column_config.TextColumn("说明"),
             "weight": st.column_config.NumberColumn("权重", min_value=0.0, max_value=10.0, step=0.1),
-            "llm_route": st.column_config.SelectboxColumn(
-                "路由",
-                options=route_options_display,
-                help="选择预定义路由;留空表示使用自定义配置",
-            ),
-            "strategy": st.column_config.SelectboxColumn(
-                "自定义策略",
-                options=sorted(ALLOWED_LLM_STRATEGIES),
-                help="仅当未选择路由时生效",
-            ),
-            "primary_provider": st.column_config.SelectboxColumn(
-                "自定义 Provider",
-                options=sorted(DEFAULT_LLM_MODEL_OPTIONS.keys()),
-            ),
-            "primary_model": st.column_config.TextColumn("自定义模型"),
-            "ensemble_size": st.column_config.NumberColumn(
-                "协作模型数量",
-                disabled=True,
-                help="路由模式下自动维护",
-            ),
+            "strategy": st.column_config.SelectboxColumn("策略", options=sorted(ALLOWED_LLM_STRATEGIES)),
+            "majority_threshold": st.column_config.NumberColumn("投票阈值", min_value=1, max_value=10, step=1),
+            "provider": st.column_config.SelectboxColumn("Provider", options=provider_keys or [""]),
+            "model": st.column_config.TextColumn("模型"),
+            "temperature": st.column_config.TextColumn("温度"),
+            "timeout": st.column_config.TextColumn("超时(秒)"),
+            "prompt_template": st.column_config.TextColumn("Prompt 模板"),
         },
     )
@@ -662,25 +707,34 @@ def render_settings() -> None:
         except (TypeError, ValueError):
             pass

-        route_name = (row.get("llm_route") or "").strip() or None
-        existing.llm_route = route_name
-        if route_name and route_name in routes:
-            existing.llm = routes[route_name].resolve(profiles)
-        else:
-            strategy_val = (row.get("strategy") or existing.llm.strategy).lower()
-            if strategy_val in ALLOWED_LLM_STRATEGIES:
-                existing.llm.strategy = strategy_val
-            provider_before = existing.llm.primary.provider or ""
-            provider_val = (row.get("primary_provider") or provider_before or "ollama").lower()
-            existing.llm.primary.provider = provider_val
-            model_val = (row.get("primary_model") or "").strip()
-            existing.llm.primary.model = (
-                model_val or DEFAULT_LLM_MODELS.get(provider_val, existing.llm.primary.model)
-            )
-            if provider_before != provider_val:
-                default_base = DEFAULT_LLM_BASE_URLS.get(provider_val)
-                existing.llm.primary.base_url = default_base or existing.llm.primary.base_url
-                existing.llm.primary.__post_init__()
+        strategy_val = (row.get("strategy") or existing.llm.strategy).lower()
+        if strategy_val in ALLOWED_LLM_STRATEGIES:
+            existing.llm.strategy = strategy_val
+        majority_raw = row.get("majority_threshold")
+        try:
+            majority_val = int(majority_raw)
+            if majority_val > 0:
+                existing.llm.majority_threshold = majority_val
+        except (TypeError, ValueError):
+            pass
+
+        provider_val = (row.get("provider") or existing.llm.primary.provider or (provider_keys[0] if provider_keys else "ollama")).strip().lower()
+        model_val = (row.get("model") or "").strip() or None
+        temp_raw = (row.get("temperature") or "").strip()
+        timeout_raw = (row.get("timeout") or "").strip()
+        prompt_raw = (row.get("prompt_template") or "").strip()
+
+        endpoint = existing.llm.primary or LLMEndpoint()
+        endpoint.provider = provider_val
+        endpoint.model = model_val
+        endpoint.temperature = float(temp_raw) if temp_raw else None
+        endpoint.timeout = float(timeout_raw) if timeout_raw else None
+        endpoint.prompt_template = prompt_raw or None
+        endpoint.base_url = None
+        endpoint.api_key = None
+        existing.llm.primary = endpoint
+        existing.llm.ensemble = []
         updated_departments[code] = existing

     if updated_departments:
@@ -700,8 +754,7 @@ def render_settings() -> None:
         st.success("已恢复默认部门配置。")
         st.experimental_rerun()

-    st.caption("选择路由可统一部门模型调用,自定义模式仍支持逐项配置。")
-    st.caption("部门协作模型(ensemble)请在 config.json 中手动编辑,UI 将在后续版本补充。")
+    st.caption("部门配置存储为独立 LLM 参数,执行时会自动套用对应 Provider 的连接信息。")


 def render_tests() -> None:
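A side note on the ensemble editor earlier in this file: `st.data_editor` echoes the container type it was given, so a DataFrame input comes back as a DataFrame, which is why the code guards with `hasattr(ensemble_editor, "to_dict")`. The same normalization in isolation (a sketch; `rows_from_editor` is our name, not part of the commit):

```python
import pandas as pd

def rows_from_editor(result) -> list[dict]:
    # DataFrame results expose to_dict("records"); list inputs come back as lists.
    if hasattr(result, "to_dict"):
        return result.to_dict("records")
    return list(result)

assert rows_from_editor(pd.DataFrame([{"provider": "ollama"}])) == [{"provider": "ollama"}]
assert rows_from_editor([{"provider": "ollama"}]) == [{"provider": "ollama"}]
```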

View File

@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
 import json
 import os
 from pathlib import Path
-from typing import Dict, List, Mapping, Optional
+from typing import Dict, List, Optional


 def _default_root() -> Path:
@@ -99,27 +99,55 @@ ALLOWED_LLM_STRATEGIES = {"single", "majority", "leader"}
 LLM_STRATEGY_ALIASES = {"leader-follower": "leader"}


+@dataclass
+class LLMProvider:
+    """Provider-level connection settings shared by all endpoints."""
+
+    key: str
+    title: str = ""
+    base_url: str = ""
+    api_key: Optional[str] = None
+    models: List[str] = field(default_factory=list)
+    default_model: Optional[str] = None
+    default_temperature: float = 0.2
+    default_timeout: float = 30.0
+    prompt_template: str = ""
+    enabled: bool = True
+    mode: str = "openai"  # "openai" or "ollama"
+
+    def to_dict(self) -> Dict[str, object]:
+        return {
+            "title": self.title,
+            "base_url": self.base_url,
+            "api_key": self.api_key,
+            "models": list(self.models),
+            "default_model": self.default_model,
+            "default_temperature": self.default_temperature,
+            "default_timeout": self.default_timeout,
+            "prompt_template": self.prompt_template,
+            "enabled": self.enabled,
+            "mode": self.mode,
+        }
+
+
 @dataclass
 class LLMEndpoint:
-    """Single LLM endpoint configuration."""
+    """Resolved endpoint payload used for actual LLM calls."""

     provider: str = "ollama"
     model: Optional[str] = None
     base_url: Optional[str] = None
     api_key: Optional[str] = None
-    temperature: float = 0.2
-    timeout: float = 30.0
+    temperature: Optional[float] = None
+    timeout: Optional[float] = None
+    prompt_template: Optional[str] = None

     def __post_init__(self) -> None:
         self.provider = (self.provider or "ollama").lower()
-        if not self.model:
-            self.model = DEFAULT_LLM_MODELS.get(self.provider, DEFAULT_LLM_MODELS["ollama"])
-        if not self.base_url:
-            self.base_url = DEFAULT_LLM_BASE_URLS.get(self.provider)
-        if self.temperature == 0.2 or self.temperature is None:
-            self.temperature = DEFAULT_LLM_TEMPERATURES.get(self.provider, 0.2)
-        if self.timeout == 30.0 or self.timeout is None:
-            self.timeout = DEFAULT_LLM_TIMEOUTS.get(self.provider, 30.0)
+        if self.temperature is not None:
+            self.temperature = float(self.temperature)
+        if self.timeout is not None:
+            self.timeout = float(self.timeout)


 @dataclass
@@ -132,133 +160,22 @@ class LLMConfig:
     majority_threshold: int = 3


-@dataclass
-class LLMProfile:
-    """Named LLM endpoint profile reusable across routes/departments."""
-
-    key: str
-    provider: str = "ollama"
-    model: Optional[str] = None
-    base_url: Optional[str] = None
-    api_key: Optional[str] = None
-    temperature: float = 0.2
-    timeout: float = 30.0
-    title: str = ""
-    enabled: bool = True
-
-    def to_endpoint(self) -> LLMEndpoint:
-        return LLMEndpoint(
-            provider=self.provider,
-            model=self.model,
-            base_url=self.base_url,
-            api_key=self.api_key,
-            temperature=self.temperature,
-            timeout=self.timeout,
-        )
-
-    def to_dict(self) -> Dict[str, object]:
-        return {
-            "provider": self.provider,
-            "model": self.model,
-            "base_url": self.base_url,
-            "api_key": self.api_key,
-            "temperature": self.temperature,
-            "timeout": self.timeout,
-            "title": self.title,
-            "enabled": self.enabled,
-        }
-
-    @classmethod
-    def from_endpoint(
-        cls,
-        key: str,
-        endpoint: LLMEndpoint,
-        *,
-        title: str = "",
-        enabled: bool = True,
-    ) -> "LLMProfile":
-        return cls(
-            key=key,
-            provider=endpoint.provider,
-            model=endpoint.model,
-            base_url=endpoint.base_url,
-            api_key=endpoint.api_key,
-            temperature=endpoint.temperature,
-            timeout=endpoint.timeout,
-            title=title,
-            enabled=enabled,
-        )
-
-
-@dataclass
-class LLMRoute:
-    """Declarative routing for selecting profiles and strategy."""
-
-    name: str
-    title: str = ""
-    strategy: str = "single"
-    majority_threshold: int = 3
-    primary: str = "ollama"
-    ensemble: List[str] = field(default_factory=list)
-
-    def resolve(self, profiles: Mapping[str, LLMProfile]) -> LLMConfig:
-        def _endpoint_from_key(key: str) -> LLMEndpoint:
-            profile = profiles.get(key)
-            if profile and profile.enabled:
-                return profile.to_endpoint()
-            fallback = profiles.get("ollama")
-            if not fallback or not fallback.enabled:
-                fallback = next(
-                    (item for item in profiles.values() if item.enabled),
-                    None,
-                )
-            endpoint = fallback.to_endpoint() if fallback else LLMEndpoint()
-            endpoint.provider = key or endpoint.provider
-            return endpoint
-
-        primary_endpoint = _endpoint_from_key(self.primary)
-        ensemble_endpoints = [
-            _endpoint_from_key(key)
-            for key in self.ensemble
-            if key in profiles and profiles[key].enabled
-        ]
-        config = LLMConfig(
-            primary=primary_endpoint,
-            ensemble=ensemble_endpoints,
-            strategy=self.strategy if self.strategy in ALLOWED_LLM_STRATEGIES else "single",
-            majority_threshold=max(1, self.majority_threshold or 1),
-        )
-        return config
-
-    def to_dict(self) -> Dict[str, object]:
-        return {
-            "title": self.title,
-            "strategy": self.strategy,
-            "majority_threshold": self.majority_threshold,
-            "primary": self.primary,
-            "ensemble": list(self.ensemble),
-        }
-
-
-def _default_llm_profiles() -> Dict[str, LLMProfile]:
-    return {
-        provider: LLMProfile(
-            key=provider,
-            provider=provider,
-            model=DEFAULT_LLM_MODELS.get(provider),
-            base_url=DEFAULT_LLM_BASE_URLS.get(provider),
-            temperature=DEFAULT_LLM_TEMPERATURES.get(provider, 0.2),
-            timeout=DEFAULT_LLM_TIMEOUTS.get(provider, 30.0),
-            title=f"默认 {provider}",
-        )
-        for provider in DEFAULT_LLM_MODEL_OPTIONS
-    }
-
-
-def _default_llm_routes() -> Dict[str, LLMRoute]:
-    return {
-        "global": LLMRoute(name="global", title="全局默认路由"),
-    }
+def _default_llm_providers() -> Dict[str, LLMProvider]:
+    providers: Dict[str, LLMProvider] = {}
+    for provider, meta in DEFAULT_LLM_MODEL_OPTIONS.items():
+        models = list(meta.get("models", []))
+        mode = "ollama" if provider == "ollama" else "openai"
+        providers[provider] = LLMProvider(
+            key=provider,
+            title=f"默认 {provider}",
+            base_url=str(meta.get("base_url", DEFAULT_LLM_BASE_URLS.get(provider, "")) or ""),
+            models=models,
+            default_model=models[0] if models else DEFAULT_LLM_MODELS.get(provider),
+            default_temperature=float(meta.get("temperature", DEFAULT_LLM_TEMPERATURES.get(provider, 0.2))),
+            default_timeout=float(meta.get("timeout", DEFAULT_LLM_TIMEOUTS.get(provider, 30.0))),
+            mode=mode,
+        )
+    return providers


 @dataclass
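A note on the loop above: it implies `DEFAULT_LLM_MODEL_OPTIONS` maps each provider key to a metadata dict. The entry shape below is inferred from the `meta.get(...)` keys read in `_default_llm_providers`; the concrete values are illustrative only:

```python
# Inferred shape of one DEFAULT_LLM_MODEL_OPTIONS entry (values are placeholders).
DEFAULT_LLM_MODEL_OPTIONS_SKETCH = {
    "ollama": {
        "models": ["llama3.1", "qwen2.5"],     # candidate model ids
        "base_url": "http://localhost:11434",  # Ollama's usual local port
        "temperature": 0.2,
        "timeout": 30.0,
    },
}
```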
@@ -270,7 +187,6 @@ class DepartmentSettings:
     description: str = ""
     weight: float = 1.0
     llm: LLMConfig = field(default_factory=LLMConfig)
-    llm_route: Optional[str] = None


 def _default_departments() -> Dict[str, DepartmentSettings]:
@@ -282,10 +198,7 @@ def _default_departments() -> Dict[str, DepartmentSettings]:
         ("macro", "宏观研究部门"),
         ("risk", "风险控制部门"),
     ]
-    return {
-        code: DepartmentSettings(code=code, title=title, llm_route="global")
-        for code, title in presets
-    }
+    return {code: DepartmentSettings(code=code, title=title) for code, title in presets}


 @dataclass
@@ -298,17 +211,11 @@ class AppConfig:
     data_paths: DataPaths = field(default_factory=DataPaths)
     agent_weights: AgentWeights = field(default_factory=AgentWeights)
     force_refresh: bool = False
+    llm_providers: Dict[str, LLMProvider] = field(default_factory=_default_llm_providers)
     llm: LLMConfig = field(default_factory=LLMConfig)
-    llm_route: str = "global"
-    llm_profiles: Dict[str, LLMProfile] = field(default_factory=_default_llm_profiles)
-    llm_routes: Dict[str, LLMRoute] = field(default_factory=_default_llm_routes)
     departments: Dict[str, DepartmentSettings] = field(default_factory=_default_departments)

     def resolve_llm(self, route: Optional[str] = None) -> LLMConfig:
-        route_key = route or self.llm_route
-        route_cfg = self.llm_routes.get(route_key)
-        if route_cfg:
-            return route_cfg.resolve(self.llm_profiles)
         return self.llm

     def sync_runtime_llm(self) -> None:
@@ -326,13 +233,22 @@ def _endpoint_to_dict(endpoint: LLMEndpoint) -> Dict[str, object]:
         "api_key": endpoint.api_key,
         "temperature": endpoint.temperature,
         "timeout": endpoint.timeout,
+        "prompt_template": endpoint.prompt_template,
     }


 def _dict_to_endpoint(data: Dict[str, object]) -> LLMEndpoint:
     payload = {
         key: data.get(key)
-        for key in ("provider", "model", "base_url", "api_key", "temperature", "timeout")
+        for key in (
+            "provider",
+            "model",
+            "base_url",
+            "api_key",
+            "temperature",
+            "timeout",
+            "prompt_template",
+        )
         if data.get(key) is not None
     }
     return LLMEndpoint(**payload)
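Since `prompt_template` now travels through serialization, a quick round-trip check of the two helpers (a sketch run inside `app/utils/config`; nothing is lost because `None` fields are filtered out before reconstruction):

```python
ep = LLMEndpoint(provider="deepseek", model="deepseek-chat", prompt_template="Q: {prompt}")
clone = _dict_to_endpoint(_endpoint_to_dict(ep))
assert (clone.provider, clone.model, clone.prompt_template) == (
    "deepseek", "deepseek-chat", "Q: {prompt}",
)
assert clone.temperature is None  # unset numerics stay None and defer to the provider
```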
@@ -348,139 +264,148 @@ def _load_from_file(cfg: AppConfig) -> None:
     except (json.JSONDecodeError, OSError):
         return

-    if isinstance(payload, dict):
-        if "tushare_token" in payload:
-            cfg.tushare_token = payload.get("tushare_token") or None
-        if "force_refresh" in payload:
-            cfg.force_refresh = bool(payload.get("force_refresh"))
-        if "decision_method" in payload:
-            cfg.decision_method = str(payload.get("decision_method") or cfg.decision_method)
-
-        routes_defined = False
-        inline_primary_loaded = False
-
-        profiles_payload = payload.get("llm_profiles")
-        if isinstance(profiles_payload, dict):
-            profiles: Dict[str, LLMProfile] = {}
-            for key, data in profiles_payload.items():
-                if not isinstance(data, dict):
-                    continue
-                provider = str(data.get("provider") or "ollama").lower()
-                profile = LLMProfile(
-                    key=key,
-                    provider=provider,
-                    model=data.get("model"),
-                    base_url=data.get("base_url"),
-                    api_key=data.get("api_key"),
-                    temperature=float(data.get("temperature", DEFAULT_LLM_TEMPERATURES.get(provider, 0.2))),
-                    timeout=float(data.get("timeout", DEFAULT_LLM_TIMEOUTS.get(provider, 30.0))),
-                    title=str(data.get("title") or ""),
-                    enabled=bool(data.get("enabled", True)),
-                )
-                profiles[key] = profile
-            if profiles:
-                cfg.llm_profiles = profiles
-
-        routes_payload = payload.get("llm_routes")
-        if isinstance(routes_payload, dict):
-            routes: Dict[str, LLMRoute] = {}
-            for name, data in routes_payload.items():
-                if not isinstance(data, dict):
-                    continue
-                strategy_raw = str(data.get("strategy") or "single").lower()
-                normalized = LLM_STRATEGY_ALIASES.get(strategy_raw, strategy_raw)
-                route = LLMRoute(
-                    name=name,
-                    title=str(data.get("title") or ""),
-                    strategy=normalized if normalized in ALLOWED_LLM_STRATEGIES else "single",
-                    majority_threshold=max(1, int(data.get("majority_threshold", 3) or 3)),
-                    primary=str(data.get("primary") or "global"),
-                    ensemble=[
-                        str(item)
-                        for item in data.get("ensemble", [])
-                        if isinstance(item, str)
-                    ],
-                )
-                routes[name] = route
-            if routes:
-                cfg.llm_routes = routes
-                routes_defined = True
-
-        route_key = payload.get("llm_route")
-        if isinstance(route_key, str) and route_key:
-            cfg.llm_route = route_key
-
-        llm_payload = payload.get("llm")
-        if isinstance(llm_payload, dict):
-            route_value = llm_payload.get("route")
-            if isinstance(route_value, str) and route_value:
-                cfg.llm_route = route_value
-            primary_data = llm_payload.get("primary")
-            if isinstance(primary_data, dict):
-                cfg.llm.primary = _dict_to_endpoint(primary_data)
-                inline_primary_loaded = True
-            ensemble_data = llm_payload.get("ensemble")
-            if isinstance(ensemble_data, list):
-                cfg.llm.ensemble = [
-                    _dict_to_endpoint(item)
-                    for item in ensemble_data
-                    if isinstance(item, dict)
-                ]
-            strategy_raw = llm_payload.get("strategy")
-            if isinstance(strategy_raw, str):
-                normalized = LLM_STRATEGY_ALIASES.get(strategy_raw, strategy_raw)
-                if normalized in ALLOWED_LLM_STRATEGIES:
-                    cfg.llm.strategy = normalized
-            majority = llm_payload.get("majority_threshold")
-            if isinstance(majority, int) and majority > 0:
-                cfg.llm.majority_threshold = majority
-
-        if inline_primary_loaded and not routes_defined:
-            primary_key = "inline_global_primary"
-            cfg.llm_profiles[primary_key] = LLMProfile.from_endpoint(
-                primary_key,
-                cfg.llm.primary,
-                title="全局主模型",
-            )
-            ensemble_keys: List[str] = []
-            for idx, endpoint in enumerate(cfg.llm.ensemble, start=1):
-                inline_key = f"inline_global_ensemble_{idx}"
-                cfg.llm_profiles[inline_key] = LLMProfile.from_endpoint(
-                    inline_key,
-                    endpoint,
-                    title=f"全局协作#{idx}",
-                )
-                ensemble_keys.append(inline_key)
-            auto_route = cfg.llm_routes.get("global") or LLMRoute(name="global", title="全局默认路由")
-            auto_route.strategy = cfg.llm.strategy
-            auto_route.majority_threshold = cfg.llm.majority_threshold
-            auto_route.primary = primary_key
-            auto_route.ensemble = ensemble_keys
-            cfg.llm_routes["global"] = auto_route
-            cfg.llm_route = cfg.llm_route or "global"
-
-        departments_payload = payload.get("departments")
-        if isinstance(departments_payload, dict):
-            new_departments: Dict[str, DepartmentSettings] = {}
-            for code, data in departments_payload.items():
-                if not isinstance(data, dict):
-                    continue
-                title = data.get("title") or code
-                description = data.get("description") or ""
-                weight = float(data.get("weight", 1.0))
+    if not isinstance(payload, dict):
+        return
+
+    if "tushare_token" in payload:
+        cfg.tushare_token = payload.get("tushare_token") or None
+    if "force_refresh" in payload:
+        cfg.force_refresh = bool(payload.get("force_refresh"))
+    if "decision_method" in payload:
+        cfg.decision_method = str(payload.get("decision_method") or cfg.decision_method)
+
+    legacy_profiles: Dict[str, Dict[str, object]] = {}
+    legacy_routes: Dict[str, Dict[str, object]] = {}
+
+    providers_payload = payload.get("llm_providers")
+    if isinstance(providers_payload, dict):
+        providers: Dict[str, LLMProvider] = {}
+        for key, data in providers_payload.items():
+            if not isinstance(data, dict):
+                continue
+            models_raw = data.get("models")
+            if isinstance(models_raw, str):
+                models = [item.strip() for item in models_raw.split(',') if item.strip()]
+            elif isinstance(models_raw, list):
+                models = [str(item).strip() for item in models_raw if str(item).strip()]
+            else:
+                models = []
+            provider = LLMProvider(
+                key=str(key).lower(),
+                title=str(data.get("title") or ""),
+                base_url=str(data.get("base_url") or ""),
+                api_key=data.get("api_key"),
+                models=models,
+                default_model=data.get("default_model") or (models[0] if models else None),
+                default_temperature=float(data.get("default_temperature", 0.2)),
+                default_timeout=float(data.get("default_timeout", 30.0)),
+                prompt_template=str(data.get("prompt_template") or ""),
+                enabled=bool(data.get("enabled", True)),
+                mode=str(data.get("mode") or ("ollama" if str(key).lower() == "ollama" else "openai")),
+            )
+            providers[provider.key] = provider
+        if providers:
+            cfg.llm_providers = providers
+
+    profiles_payload = payload.get("llm_profiles")
+    if isinstance(profiles_payload, dict):
+        for key, data in profiles_payload.items():
+            if isinstance(data, dict):
+                legacy_profiles[str(key)] = data
+
+    routes_payload = payload.get("llm_routes")
+    if isinstance(routes_payload, dict):
+        for name, data in routes_payload.items():
+            if isinstance(data, dict):
+                legacy_routes[str(name)] = data
+
+    def _endpoint_from_payload(item: object) -> LLMEndpoint:
+        if isinstance(item, dict):
+            return _dict_to_endpoint(item)
+        if isinstance(item, str):
+            profile_data = legacy_profiles.get(item)
+            if isinstance(profile_data, dict):
+                return _dict_to_endpoint(profile_data)
+            return LLMEndpoint(provider=item)
+        return LLMEndpoint()
+
+    def _resolve_route(route_name: str) -> Optional[LLMConfig]:
+        route_data = legacy_routes.get(route_name)
+        if not route_data:
+            return None
+        strategy_raw = str(route_data.get("strategy") or "single").lower()
+        strategy = LLM_STRATEGY_ALIASES.get(strategy_raw, strategy_raw)
+        primary_ref = route_data.get("primary")
+        primary_ep = _endpoint_from_payload(primary_ref)
+        ensemble_refs = route_data.get("ensemble", [])
+        ensemble_eps = [
+            _endpoint_from_payload(ref)
+            for ref in ensemble_refs
+            if isinstance(ref, (dict, str))
+        ]
+        cfg_obj = LLMConfig(
+            primary=primary_ep,
+            ensemble=ensemble_eps,
+            strategy=strategy if strategy in ALLOWED_LLM_STRATEGIES else "single",
+            majority_threshold=max(1, int(route_data.get("majority_threshold", 3) or 3)),
+        )
+        return cfg_obj
+
+    llm_payload = payload.get("llm")
+    if isinstance(llm_payload, dict):
+        route_value = llm_payload.get("route")
+        resolved_cfg = None
+        if isinstance(route_value, str) and route_value:
+            resolved_cfg = _resolve_route(route_value)
+        if resolved_cfg is None:
+            resolved_cfg = LLMConfig()
+        primary_data = llm_payload.get("primary")
+        if isinstance(primary_data, dict):
+            resolved_cfg.primary = _dict_to_endpoint(primary_data)
+        ensemble_data = llm_payload.get("ensemble")
+        if isinstance(ensemble_data, list):
+            resolved_cfg.ensemble = [
+                _dict_to_endpoint(item)
+                for item in ensemble_data
+                if isinstance(item, dict)
+            ]
+        strategy_raw = llm_payload.get("strategy")
+        if isinstance(strategy_raw, str):
+            normalized = LLM_STRATEGY_ALIASES.get(strategy_raw, strategy_raw)
+            if normalized in ALLOWED_LLM_STRATEGIES:
+                resolved_cfg.strategy = normalized
+        majority = llm_payload.get("majority_threshold")
+        if isinstance(majority, int) and majority > 0:
+            resolved_cfg.majority_threshold = majority
+        cfg.llm = resolved_cfg
+
+    departments_payload = payload.get("departments")
+    if isinstance(departments_payload, dict):
+        new_departments: Dict[str, DepartmentSettings] = {}
+        for code, data in departments_payload.items():
+            if not isinstance(data, dict):
+                continue
+            title = data.get("title") or code
+            description = data.get("description") or ""
+            weight = float(data.get("weight", 1.0))
+            llm_cfg = LLMConfig()
+            route_name = data.get("llm_route")
+            resolved_cfg = None
+            if isinstance(route_name, str) and route_name:
+                resolved_cfg = _resolve_route(route_name)
+            if resolved_cfg is None:
                 llm_data = data.get("llm")
-                llm_cfg = LLMConfig()
                 if isinstance(llm_data, dict):
-                    if isinstance(llm_data.get("primary"), dict):
-                        llm_cfg.primary = _dict_to_endpoint(llm_data["primary"])
-                    llm_cfg.ensemble = [
-                        _dict_to_endpoint(item)
-                        for item in llm_data.get("ensemble", [])
-                        if isinstance(item, dict)
-                    ]
+                    primary_data = llm_data.get("primary")
+                    if isinstance(primary_data, dict):
+                        llm_cfg.primary = _dict_to_endpoint(primary_data)
+                    ensemble_data = llm_data.get("ensemble")
+                    if isinstance(ensemble_data, list):
+                        llm_cfg.ensemble = [
+                            _dict_to_endpoint(item)
+                            for item in ensemble_data
+                            if isinstance(item, dict)
+                        ]
                     strategy_raw = llm_data.get("strategy")
                     if isinstance(strategy_raw, str):
                         normalized = LLM_STRATEGY_ALIASES.get(strategy_raw, strategy_raw)
@@ -489,23 +414,18 @@ def _load_from_file(cfg: AppConfig) -> None:
                     majority_raw = llm_data.get("majority_threshold")
                     if isinstance(majority_raw, int) and majority_raw > 0:
                         llm_cfg.majority_threshold = majority_raw
-                route = data.get("llm_route")
-                route_name = str(route).strip() if isinstance(route, str) and route else None
-                resolved = llm_cfg
-                if route_name and route_name in cfg.llm_routes:
-                    resolved = cfg.llm_routes[route_name].resolve(cfg.llm_profiles)
-                new_departments[code] = DepartmentSettings(
-                    code=code,
-                    title=title,
-                    description=description,
-                    weight=weight,
-                    llm=resolved,
-                    llm_route=route_name,
-                )
-            if new_departments:
-                cfg.departments = new_departments
+                resolved_cfg = llm_cfg
+            new_departments[code] = DepartmentSettings(
+                code=code,
+                title=title,
+                description=description,
+                weight=weight,
+                llm=resolved_cfg,
+            )
+        if new_departments:
+            cfg.departments = new_departments

     cfg.sync_runtime_llm()
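The legacy branch above means an existing `config.json` that still carries `llm_profiles`/`llm_routes` keeps working. A sketch of the migration path (payload values illustrative):

```python
# A pre-Provider config fragment: a route referencing a named profile.
legacy_payload = {
    "llm_profiles": {
        "fast": {"provider": "deepseek", "model": "deepseek-chat", "temperature": 0.1},
    },
    "llm_routes": {
        "global": {"strategy": "single", "primary": "fast", "ensemble": []},
    },
    "llm": {"route": "global"},
}
# _load_from_file resolves the route via _resolve_route/_endpoint_from_payload:
# the string "fast" is looked up in legacy_profiles and inlined as an LLMEndpoint,
# so cfg.llm.primary ends up with provider="deepseek", model="deepseek-chat".
```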
def save_config(cfg: AppConfig | None = None) -> None: def save_config(cfg: AppConfig | None = None) -> None:
@@ -516,28 +436,21 @@ def save_config(cfg: AppConfig | None = None) -> None:
         "tushare_token": cfg.tushare_token,
         "force_refresh": cfg.force_refresh,
         "decision_method": cfg.decision_method,
-        "llm_route": cfg.llm_route,
         "llm": {
-            "route": cfg.llm_route,
             "strategy": cfg.llm.strategy if cfg.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
             "majority_threshold": cfg.llm.majority_threshold,
             "primary": _endpoint_to_dict(cfg.llm.primary),
             "ensemble": [_endpoint_to_dict(ep) for ep in cfg.llm.ensemble],
         },
-        "llm_profiles": {
-            key: profile.to_dict()
-            for key, profile in cfg.llm_profiles.items()
-        },
-        "llm_routes": {
-            name: route.to_dict()
-            for name, route in cfg.llm_routes.items()
+        "llm_providers": {
+            key: provider.to_dict()
+            for key, provider in cfg.llm_providers.items()
         },
         "departments": {
             code: {
                 "title": dept.title,
                 "description": dept.description,
                 "weight": dept.weight,
-                "llm_route": dept.llm_route,
                 "llm": {
                     "strategy": dept.llm.strategy if dept.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
                     "majority_threshold": dept.llm.majority_threshold,
@@ -567,11 +480,9 @@ def _load_env_defaults(cfg: AppConfig) -> None:
     if api_key:
         sanitized = api_key.strip()
         cfg.llm.primary.api_key = sanitized
-        route = cfg.llm_routes.get(cfg.llm_route)
-        if route:
-            profile = cfg.llm_profiles.get(route.primary)
-            if profile:
-                profile.api_key = sanitized
+        provider_cfg = cfg.llm_providers.get(cfg.llm.primary.provider)
+        if provider_cfg:
+            provider_cfg.api_key = sanitized
     cfg.sync_runtime_llm()
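And the environment override above can be exercised like this (a sketch; it assumes the config singleton is first built after the variables are set):

```python
import os

# Placeholders; export these before the app constructs its config.
os.environ["TUSHARE_TOKEN"] = "<your-token>"
os.environ["LLM_API_KEY"] = "sk-placeholder"

from app.utils.config import get_config

cfg = get_config()
# _load_env_defaults copies the key onto the primary endpoint and, per the diff
# above, onto the matching provider so every endpoint of that vendor inherits it.
print(cfg.llm.primary.api_key)  # -> "sk-placeholder"
```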