update
This commit is contained in:
parent
1a99b72c60
commit
6aece20816
15
README.md
15
README.md
@ -24,6 +24,7 @@
|
|||||||
- **统一日志与持久化**:SQLite 统一存储行情、回测与日志,配合 `DatabaseLogHandler` 在 UI/抓数流程中输出结构化运行轨迹,支持快速追踪与复盘。
|
- **统一日志与持久化**:SQLite 统一存储行情、回测与日志,配合 `DatabaseLogHandler` 在 UI/抓数流程中输出结构化运行轨迹,支持快速追踪与复盘。
|
||||||
- **跨市场数据扩展**:`app/ingest/tushare.py` 追加指数、ETF/公募基金、期货、外汇、港股与美股的增量拉取逻辑,确保多资产因子与宏观代理所需的行情基础数据齐备。
|
- **跨市场数据扩展**:`app/ingest/tushare.py` 追加指数、ETF/公募基金、期货、外汇、港股与美股的增量拉取逻辑,确保多资产因子与宏观代理所需的行情基础数据齐备。
|
||||||
- **部门化多模型协作**:`app/agents/departments.py` 封装部门级 LLM 调度,`app/llm/client.py` 支持 single/majority/leader 策略,部门结论在 `app/agents/game.py` 与六类基础代理共同博弈,并持久化至 `agent_utils` 供 UI 展示。
|
- **部门化多模型协作**:`app/agents/departments.py` 封装部门级 LLM 调度,`app/llm/client.py` 支持 single/majority/leader 策略,部门结论在 `app/agents/game.py` 与六类基础代理共同博弈,并持久化至 `agent_utils` 供 UI 展示。
|
||||||
|
- **LLM Profile/Route 管理**:`app/utils/config.py` 引入可复用的 Profile(终端定义)与 Route(推理策略组合),Streamlit UI 支持可视化维护,全局与部门均可复用命名路由提升配置一致性。
|
||||||
|
|
||||||
## LLM + 多智能体最佳实践
|
## LLM + 多智能体最佳实践
|
||||||
|
|
||||||
@ -59,13 +60,11 @@ export TUSHARE_TOKEN="<your-token>"
|
|||||||
|
|
||||||
### LLM 配置与测试
|
### LLM 配置与测试
|
||||||
|
|
||||||
- 支持本地 Ollama 与多家 OpenAI 兼容云端供应商(如 DeepSeek、文心一言、OpenAI 等),可在 “数据与设置” 页签切换 Provider 并自动加载该 Provider 的候选模型、推荐 Base URL、默认温度与超时时间,亦可切换为自定义值。所有修改会持久化到 `app/data/config.json`,下次启动自动加载。
|
- 新增 Profile/Route 双层配置:Profile 定义单个端点(含 Provider/模型/域名/API Key),Route 组合 Profile 并指定推理策略(single/majority/leader)。全局路由可一键切换,部门可复用命名路由或保留自定义设置。
|
||||||
- 修改 Provider/模型/Base URL/API Key 后点击 “保存 LLM 设置”,更新内容仅在当前会话生效。
|
- Streamlit “数据与设置” 页通过表单管理 Profile、Route、全局路由,保存即写入 `app/data/config.json`;Route 预览会同步展示经 `llm_config_snapshot()` 脱敏后的实时配置。
|
||||||
- 在 “自检测试” 页新增 “LLM 接口测试”,可输入 Prompt 快速验证调用结果,日志会记录限频与错误信息便于排查。
|
- 支持本地 Ollama 与多家 OpenAI 兼容供应商(DeepSeek、文心一言、OpenAI 等),可为不同 Profile 设置默认模型、温度、超时与启用状态。
|
||||||
- 未来可对同一功能的智能体并行调用多个 LLM,采用多数投票等策略增强鲁棒性,当前代码结构已为此预留扩展空间。
|
- UI 保留 TuShare Token 维护,以及路由/Profile 新增、删除、禁用等操作;所有更新即时生效并记入日志。
|
||||||
- 若使用环境变量自动注入配置,可设置:
|
- 使用环境变量注入敏感信息时,可配置:`TUSHARE_TOKEN`、`LLM_API_KEY`,加载后会同步至当前路由的主 Profile。
|
||||||
- `TUSHARE_TOKEN`
|
|
||||||
- `LLM_API_KEY`
|
|
||||||
|
|
||||||
## 快速开始
|
## 快速开始
|
||||||
|
|
||||||
@ -105,7 +104,7 @@ Streamlit `自检测试` 页签提供:
|
|||||||
## 实施步骤
|
## 实施步骤
|
||||||
|
|
||||||
1. **配置扩展** (`app/utils/config.py` + `config.json`) ✅
|
1. **配置扩展** (`app/utils/config.py` + `config.json`) ✅
|
||||||
- 部门支持 primary/ensemble、策略(single/majority/leader)、权重,并可在 Streamlit 中编辑主要字段。
|
- 引入 `llm_profiles`/`llm_routes` 统一管理终端与策略,部门可复用路由或使用自定义配置;Streamlit 提供可视化维护表单。
|
||||||
|
|
||||||
2. **部门管控器** ✅
|
2. **部门管控器** ✅
|
||||||
- `app/agents/departments.py` 提供 `DepartmentAgent`/`DepartmentManager`,封装 Prompt 构建、多模型协商及异常回退。
|
- `app/agents/departments.py` 提供 `DepartmentAgent`/`DepartmentManager`,封装 Prompt 构建、多模型协商及异常回退。
|
||||||
|
|||||||
@ -3,12 +3,12 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Dict, List, Mapping
|
from typing import Any, Callable, Dict, List, Mapping, Optional
|
||||||
|
|
||||||
from app.agents.base import AgentAction
|
from app.agents.base import AgentAction
|
||||||
from app.llm.client import run_llm_with_config
|
from app.llm.client import run_llm_with_config
|
||||||
from app.llm.prompts import department_prompt
|
from app.llm.prompts import department_prompt
|
||||||
from app.utils.config import DepartmentSettings
|
from app.utils.config import AppConfig, DepartmentSettings, LLMConfig
|
||||||
from app.utils.logging import get_logger
|
from app.utils.logging import get_logger
|
||||||
|
|
||||||
LOGGER = get_logger(__name__)
|
LOGGER = get_logger(__name__)
|
||||||
@ -53,16 +53,27 @@ class DepartmentDecision:
|
|||||||
class DepartmentAgent:
|
class DepartmentAgent:
|
||||||
"""Wraps LLM ensemble logic for a single analytical department."""
|
"""Wraps LLM ensemble logic for a single analytical department."""
|
||||||
|
|
||||||
def __init__(self, settings: DepartmentSettings) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
settings: DepartmentSettings,
|
||||||
|
resolver: Optional[Callable[[DepartmentSettings], LLMConfig]] = None,
|
||||||
|
) -> None:
|
||||||
self.settings = settings
|
self.settings = settings
|
||||||
|
self._resolver = resolver
|
||||||
|
|
||||||
|
def _get_llm_config(self) -> LLMConfig:
|
||||||
|
if self._resolver:
|
||||||
|
return self._resolver(self.settings)
|
||||||
|
return self.settings.llm
|
||||||
|
|
||||||
def analyze(self, context: DepartmentContext) -> DepartmentDecision:
|
def analyze(self, context: DepartmentContext) -> DepartmentDecision:
|
||||||
prompt = department_prompt(self.settings, context)
|
prompt = department_prompt(self.settings, context)
|
||||||
system_prompt = (
|
system_prompt = (
|
||||||
"你是一个多智能体量化投研系统中的分部决策者,需要根据提供的结构化信息给出买卖意见。"
|
"你是一个多智能体量化投研系统中的分部决策者,需要根据提供的结构化信息给出买卖意见。"
|
||||||
)
|
)
|
||||||
|
llm_cfg = self._get_llm_config()
|
||||||
try:
|
try:
|
||||||
response = run_llm_with_config(self.settings.llm, prompt, system=system_prompt)
|
response = run_llm_with_config(llm_cfg, prompt, system=system_prompt)
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
LOGGER.exception("部门 %s 调用 LLM 失败:%s", self.settings.code, exc, extra=LOG_EXTRA)
|
LOGGER.exception("部门 %s 调用 LLM 失败:%s", self.settings.code, exc, extra=LOG_EXTRA)
|
||||||
return DepartmentDecision(
|
return DepartmentDecision(
|
||||||
@ -106,10 +117,11 @@ class DepartmentAgent:
|
|||||||
class DepartmentManager:
|
class DepartmentManager:
|
||||||
"""Orchestrates all departments defined in configuration."""
|
"""Orchestrates all departments defined in configuration."""
|
||||||
|
|
||||||
def __init__(self, departments: Mapping[str, DepartmentSettings]) -> None:
|
def __init__(self, config: AppConfig) -> None:
|
||||||
|
self.config = config
|
||||||
self.agents: Dict[str, DepartmentAgent] = {
|
self.agents: Dict[str, DepartmentAgent] = {
|
||||||
code: DepartmentAgent(settings)
|
code: DepartmentAgent(settings, self._resolve_llm)
|
||||||
for code, settings in departments.items()
|
for code, settings in config.departments.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
def evaluate(self, context: DepartmentContext) -> Dict[str, DepartmentDecision]:
|
def evaluate(self, context: DepartmentContext) -> Dict[str, DepartmentDecision]:
|
||||||
@ -118,6 +130,11 @@ class DepartmentManager:
|
|||||||
results[code] = agent.analyze(context)
|
results[code] = agent.analyze(context)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
def _resolve_llm(self, settings: DepartmentSettings) -> LLMConfig:
|
||||||
|
if settings.llm_route and settings.llm_route in self.config.llm_routes:
|
||||||
|
return self.config.llm_routes[settings.llm_route].resolve(self.config.llm_profiles)
|
||||||
|
return settings.llm
|
||||||
|
|
||||||
|
|
||||||
def _parse_department_response(text: str) -> Dict[str, Any]:
|
def _parse_department_response(text: str) -> Dict[str, Any]:
|
||||||
"""Extract a JSON object from the LLM response if possible."""
|
"""Extract a JSON object from the LLM response if possible."""
|
||||||
|
|||||||
@ -55,7 +55,7 @@ class BacktestEngine:
|
|||||||
else:
|
else:
|
||||||
self.weights = {agent.name: 1.0 for agent in self.agents}
|
self.weights = {agent.name: 1.0 for agent in self.agents}
|
||||||
self.department_manager = (
|
self.department_manager = (
|
||||||
DepartmentManager(app_cfg.departments) if app_cfg.departments else None
|
DepartmentManager(app_cfg) if app_cfg.departments else None
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_market_data(self, trade_date: date) -> Mapping[str, Dict[str, float]]:
|
def load_market_data(self, trade_date: date) -> Mapping[str, Dict[str, float]]:
|
||||||
|
|||||||
@ -272,7 +272,8 @@ def run_llm_with_config(
|
|||||||
def llm_config_snapshot() -> Dict[str, object]:
|
def llm_config_snapshot() -> Dict[str, object]:
|
||||||
"""Return a sanitized snapshot of current LLM configuration for debugging."""
|
"""Return a sanitized snapshot of current LLM configuration for debugging."""
|
||||||
|
|
||||||
settings = get_config().llm
|
cfg = get_config()
|
||||||
|
settings = cfg.llm
|
||||||
primary = asdict(settings.primary)
|
primary = asdict(settings.primary)
|
||||||
if primary.get("api_key"):
|
if primary.get("api_key"):
|
||||||
primary["api_key"] = "***"
|
primary["api_key"] = "***"
|
||||||
@ -282,7 +283,11 @@ def llm_config_snapshot() -> Dict[str, object]:
|
|||||||
if record.get("api_key"):
|
if record.get("api_key"):
|
||||||
record["api_key"] = "***"
|
record["api_key"] = "***"
|
||||||
ensemble.append(record)
|
ensemble.append(record)
|
||||||
|
route_name = cfg.llm_route
|
||||||
|
route_obj = cfg.llm_routes.get(route_name)
|
||||||
return {
|
return {
|
||||||
|
"route": route_name,
|
||||||
|
"route_detail": route_obj.to_dict() if route_obj else None,
|
||||||
"strategy": settings.strategy,
|
"strategy": settings.strategy,
|
||||||
"majority_threshold": settings.majority_threshold,
|
"majority_threshold": settings.majority_threshold,
|
||||||
"primary": primary,
|
"primary": primary,
|
||||||
|
|||||||
@ -29,7 +29,8 @@ from app.utils.config import (
|
|||||||
DEFAULT_LLM_MODEL_OPTIONS,
|
DEFAULT_LLM_MODEL_OPTIONS,
|
||||||
DEFAULT_LLM_MODELS,
|
DEFAULT_LLM_MODELS,
|
||||||
DepartmentSettings,
|
DepartmentSettings,
|
||||||
LLMEndpoint,
|
LLMProfile,
|
||||||
|
LLMRoute,
|
||||||
get_config,
|
get_config,
|
||||||
save_config,
|
save_config,
|
||||||
)
|
)
|
||||||
@ -349,208 +350,253 @@ def render_settings() -> None:
|
|||||||
|
|
||||||
st.divider()
|
st.divider()
|
||||||
st.subheader("LLM 设置")
|
st.subheader("LLM 设置")
|
||||||
llm_cfg = cfg.llm
|
profiles = cfg.llm_profiles or {}
|
||||||
primary = llm_cfg.primary
|
routes = cfg.llm_routes or {}
|
||||||
providers = sorted(DEFAULT_LLM_MODELS.keys())
|
profile_keys = sorted(profiles.keys())
|
||||||
try:
|
route_keys = sorted(routes.keys())
|
||||||
provider_index = providers.index((primary.provider or "ollama").lower())
|
used_routes = {
|
||||||
except ValueError:
|
dept.llm_route for dept in cfg.departments.values() if dept.llm_route
|
||||||
provider_index = 0
|
}
|
||||||
selected_provider = st.selectbox("LLM Provider", providers, index=provider_index)
|
st.caption("Profile 定义单个模型终端,Route 负责组合 Profile 与推理策略。")
|
||||||
provider_info = DEFAULT_LLM_MODEL_OPTIONS.get(selected_provider, {})
|
|
||||||
model_options = provider_info.get("models", [])
|
|
||||||
custom_model_label = "自定义模型"
|
|
||||||
default_model_hint = DEFAULT_LLM_MODELS.get(selected_provider, DEFAULT_LLM_MODELS["ollama"])
|
|
||||||
|
|
||||||
if model_options:
|
route_select_col, route_manage_col = st.columns([3, 1])
|
||||||
options_with_custom = model_options + [custom_model_label]
|
if route_keys:
|
||||||
if primary.provider == selected_provider and primary.model in model_options:
|
try:
|
||||||
model_index = options_with_custom.index(primary.model)
|
active_index = route_keys.index(cfg.llm_route)
|
||||||
else:
|
except ValueError:
|
||||||
model_index = 0
|
active_index = 0
|
||||||
selected_model_option = st.selectbox(
|
selected_route = route_select_col.selectbox(
|
||||||
"LLM 模型",
|
"全局路由",
|
||||||
options_with_custom,
|
route_keys,
|
||||||
index=model_index,
|
index=active_index,
|
||||||
help=f"可选模型:{', '.join(model_options)}",
|
key="llm_route_select",
|
||||||
)
|
)
|
||||||
if selected_model_option == custom_model_label:
|
else:
|
||||||
custom_model_value = st.text_input(
|
selected_route = None
|
||||||
"自定义模型名称",
|
route_select_col.info("尚未配置路由,请先创建。")
|
||||||
value="" if primary.provider != selected_provider or primary.model in model_options else primary.model,
|
|
||||||
)
|
new_route_name = route_manage_col.text_input("新增路由", key="new_route_name")
|
||||||
chosen_model = custom_model_value.strip() or default_model_hint
|
if route_manage_col.button("添加路由"):
|
||||||
|
key = (new_route_name or "").strip()
|
||||||
|
if not key:
|
||||||
|
st.warning("请输入有效的路由名称。")
|
||||||
|
elif key in routes:
|
||||||
|
st.warning(f"路由 {key} 已存在。")
|
||||||
else:
|
else:
|
||||||
chosen_model = selected_model_option
|
routes[key] = LLMRoute(name=key)
|
||||||
else:
|
if not selected_route:
|
||||||
chosen_model = st.text_input(
|
selected_route = key
|
||||||
"LLM 模型",
|
cfg.llm_route = key
|
||||||
value=primary.model or default_model_hint,
|
save_config()
|
||||||
help="未预设该 Provider 的模型列表,请手动填写",
|
st.success(f"已添加路由 {key},请继续配置。")
|
||||||
).strip() or default_model_hint
|
st.experimental_rerun()
|
||||||
default_base_hint = DEFAULT_LLM_BASE_URLS.get(selected_provider, "")
|
|
||||||
provider_default_temp = float(provider_info.get("temperature", 0.2))
|
|
||||||
provider_default_timeout = int(provider_info.get("timeout", 30.0))
|
|
||||||
|
|
||||||
if primary.provider == selected_provider:
|
if selected_route:
|
||||||
base_value = primary.base_url or default_base_hint or ""
|
route_obj = routes.get(selected_route)
|
||||||
temp_value = float(primary.temperature)
|
if route_obj is None:
|
||||||
timeout_value = int(primary.timeout)
|
route_obj = LLMRoute(name=selected_route)
|
||||||
else:
|
routes[selected_route] = route_obj
|
||||||
base_value = default_base_hint or ""
|
strategy_choices = sorted(ALLOWED_LLM_STRATEGIES)
|
||||||
temp_value = provider_default_temp
|
try:
|
||||||
timeout_value = provider_default_timeout
|
strategy_index = strategy_choices.index(route_obj.strategy)
|
||||||
|
except ValueError:
|
||||||
llm_base = st.text_input(
|
strategy_index = 0
|
||||||
"LLM Base URL",
|
route_title = st.text_input(
|
||||||
value=base_value,
|
"路由说明",
|
||||||
help=f"默认推荐:{default_base_hint or '按供应商要求填写'}",
|
value=route_obj.title or "",
|
||||||
)
|
key=f"route_title_{selected_route}",
|
||||||
llm_api_key = st.text_input(
|
)
|
||||||
"LLM API Key",
|
route_strategy = st.selectbox(
|
||||||
value=primary.api_key or "",
|
"推理策略",
|
||||||
type="password",
|
strategy_choices,
|
||||||
help="点击右侧小图标可查看当前 Key,该值会写入 config.json(已被 gitignore 排除)",
|
index=strategy_index,
|
||||||
)
|
key=f"route_strategy_{selected_route}",
|
||||||
llm_temperature = st.slider(
|
)
|
||||||
"LLM 温度",
|
route_majority = st.number_input(
|
||||||
min_value=0.0,
|
"多数投票门槛",
|
||||||
max_value=2.0,
|
min_value=1,
|
||||||
value=temp_value,
|
max_value=10,
|
||||||
step=0.05,
|
value=int(route_obj.majority_threshold or 1),
|
||||||
)
|
step=1,
|
||||||
llm_timeout = st.number_input(
|
key=f"route_majority_{selected_route}",
|
||||||
"请求超时时间 (秒)",
|
)
|
||||||
min_value=5,
|
if not profile_keys:
|
||||||
max_value=120,
|
st.warning("暂无可用 Profile,请先在下方创建。")
|
||||||
value=timeout_value,
|
else:
|
||||||
step=5,
|
try:
|
||||||
)
|
primary_index = profile_keys.index(route_obj.primary)
|
||||||
|
except ValueError:
|
||||||
strategy_options = ["single", "majority", "leader"]
|
primary_index = 0
|
||||||
try:
|
primary_key = st.selectbox(
|
||||||
strategy_index = strategy_options.index(llm_cfg.strategy)
|
"主用 Profile",
|
||||||
except ValueError:
|
profile_keys,
|
||||||
strategy_index = 0
|
index=primary_index,
|
||||||
selected_strategy = st.selectbox("LLM 推理策略", strategy_options, index=strategy_index)
|
key=f"route_primary_{selected_route}",
|
||||||
majority_threshold = st.number_input(
|
|
||||||
"多数投票门槛",
|
|
||||||
min_value=1,
|
|
||||||
max_value=10,
|
|
||||||
value=int(llm_cfg.majority_threshold),
|
|
||||||
step=1,
|
|
||||||
format="%d",
|
|
||||||
)
|
|
||||||
|
|
||||||
existing_api_keys = {ep.provider: ep.api_key or None for ep in llm_cfg.ensemble}
|
|
||||||
|
|
||||||
available_providers = sorted(DEFAULT_LLM_MODEL_OPTIONS.keys())
|
|
||||||
ensemble_rows = [
|
|
||||||
{
|
|
||||||
"provider": ep.provider or "",
|
|
||||||
"model": ep.model or DEFAULT_LLM_MODELS.get(ep.provider, DEFAULT_LLM_MODELS["ollama"]),
|
|
||||||
"base_url": ep.base_url or DEFAULT_LLM_BASE_URLS.get(ep.provider, ""),
|
|
||||||
"api_key": "***" if ep.api_key else "",
|
|
||||||
"temperature": float(ep.temperature),
|
|
||||||
"timeout": float(ep.timeout),
|
|
||||||
}
|
|
||||||
for ep in llm_cfg.ensemble
|
|
||||||
] or [
|
|
||||||
{
|
|
||||||
"provider": "",
|
|
||||||
"model": "",
|
|
||||||
"base_url": "",
|
|
||||||
"api_key": "",
|
|
||||||
"temperature": provider_default_temp,
|
|
||||||
"timeout": provider_default_timeout,
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
edited = st.data_editor(
|
|
||||||
ensemble_rows,
|
|
||||||
num_rows="dynamic",
|
|
||||||
key="llm_ensemble_editor",
|
|
||||||
column_config={
|
|
||||||
"provider": st.column_config.SelectboxColumn(
|
|
||||||
"Provider",
|
|
||||||
options=available_providers,
|
|
||||||
help="选择 LLM 供应商"
|
|
||||||
),
|
|
||||||
"model": st.column_config.TextColumn("模型", help="留空时使用该 Provider 的默认模型"),
|
|
||||||
"base_url": st.column_config.TextColumn("Base URL", help="留空时使用默认地址"),
|
|
||||||
"api_key": st.column_config.TextColumn("API Key", help="留空表示使用环境变量或不配置"),
|
|
||||||
"temperature": st.column_config.NumberColumn("温度", min_value=0.0, max_value=2.0, step=0.05),
|
|
||||||
"timeout": st.column_config.NumberColumn("超时(秒)", min_value=5.0, max_value=120.0, step=5.0),
|
|
||||||
},
|
|
||||||
hide_index=True,
|
|
||||||
use_container_width=True,
|
|
||||||
)
|
|
||||||
if hasattr(edited, "to_dict"):
|
|
||||||
ensemble_rows = edited.to_dict("records")
|
|
||||||
else:
|
|
||||||
ensemble_rows = edited
|
|
||||||
|
|
||||||
if st.button("保存 LLM 设置"):
|
|
||||||
primary.provider = selected_provider
|
|
||||||
primary.model = chosen_model
|
|
||||||
primary.base_url = llm_base.strip() or DEFAULT_LLM_BASE_URLS.get(selected_provider)
|
|
||||||
primary.temperature = llm_temperature
|
|
||||||
primary.timeout = llm_timeout
|
|
||||||
api_key_value = llm_api_key.strip()
|
|
||||||
if api_key_value:
|
|
||||||
primary.api_key = api_key_value
|
|
||||||
|
|
||||||
new_ensemble: List[LLMEndpoint] = []
|
|
||||||
for row in ensemble_rows:
|
|
||||||
provider = (row.get("provider") or "").strip().lower()
|
|
||||||
if not provider:
|
|
||||||
continue
|
|
||||||
provider_defaults = DEFAULT_LLM_MODEL_OPTIONS.get(provider, {})
|
|
||||||
default_model = DEFAULT_LLM_MODELS.get(provider, DEFAULT_LLM_MODELS["ollama"])
|
|
||||||
default_base = DEFAULT_LLM_BASE_URLS.get(provider)
|
|
||||||
temp_default = float(provider_defaults.get("temperature", 0.2))
|
|
||||||
timeout_default = float(provider_defaults.get("timeout", 30.0))
|
|
||||||
|
|
||||||
model_val = (row.get("model") or "").strip() or default_model
|
|
||||||
base_val = (row.get("base_url") or "").strip() or default_base
|
|
||||||
api_raw = (row.get("api_key") or "").strip()
|
|
||||||
if api_raw == "***":
|
|
||||||
api_value = existing_api_keys.get(provider)
|
|
||||||
else:
|
|
||||||
api_value = api_raw or None
|
|
||||||
|
|
||||||
temp_val = row.get("temperature")
|
|
||||||
timeout_val = row.get("timeout")
|
|
||||||
endpoint = LLMEndpoint(
|
|
||||||
provider=provider,
|
|
||||||
model=model_val,
|
|
||||||
base_url=base_val,
|
|
||||||
api_key=api_value,
|
|
||||||
temperature=float(temp_val) if temp_val is not None else temp_default,
|
|
||||||
timeout=float(timeout_val) if timeout_val is not None else timeout_default,
|
|
||||||
)
|
)
|
||||||
new_ensemble.append(endpoint)
|
default_ensemble = [
|
||||||
|
key for key in route_obj.ensemble if key in profile_keys and key != primary_key
|
||||||
|
]
|
||||||
|
ensemble_keys = st.multiselect(
|
||||||
|
"协作 Profile (可多选)",
|
||||||
|
profile_keys,
|
||||||
|
default=default_ensemble,
|
||||||
|
key=f"route_ensemble_{selected_route}",
|
||||||
|
)
|
||||||
|
if st.button("保存路由设置", key=f"save_route_{selected_route}"):
|
||||||
|
route_obj.title = route_title.strip()
|
||||||
|
route_obj.strategy = route_strategy
|
||||||
|
route_obj.majority_threshold = int(route_majority)
|
||||||
|
route_obj.primary = primary_key
|
||||||
|
route_obj.ensemble = [key for key in ensemble_keys if key != primary_key]
|
||||||
|
cfg.llm_route = selected_route
|
||||||
|
cfg.sync_runtime_llm()
|
||||||
|
save_config()
|
||||||
|
LOGGER.info(
|
||||||
|
"路由 %s 配置更新:%s",
|
||||||
|
selected_route,
|
||||||
|
route_obj.to_dict(),
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
|
st.success("路由配置已保存。")
|
||||||
|
st.json({
|
||||||
|
"route": selected_route,
|
||||||
|
"route_detail": route_obj.to_dict(),
|
||||||
|
"resolved": llm_config_snapshot(),
|
||||||
|
})
|
||||||
|
route_in_use = selected_route in used_routes or selected_route == cfg.llm_route
|
||||||
|
if st.button(
|
||||||
|
"删除当前路由",
|
||||||
|
key=f"delete_route_{selected_route}",
|
||||||
|
disabled=route_in_use or len(routes) <= 1,
|
||||||
|
):
|
||||||
|
routes.pop(selected_route, None)
|
||||||
|
if cfg.llm_route == selected_route:
|
||||||
|
cfg.llm_route = next((key for key in routes.keys()), "global")
|
||||||
|
cfg.sync_runtime_llm()
|
||||||
|
save_config()
|
||||||
|
st.success("路由已删除。")
|
||||||
|
st.experimental_rerun()
|
||||||
|
|
||||||
llm_cfg.ensemble = new_ensemble
|
st.divider()
|
||||||
llm_cfg.strategy = selected_strategy
|
st.subheader("LLM Profile 管理")
|
||||||
llm_cfg.majority_threshold = int(majority_threshold)
|
profile_select_col, profile_manage_col = st.columns([3, 1])
|
||||||
save_config()
|
if profile_keys:
|
||||||
LOGGER.info("LLM 配置已更新:%s", llm_config_snapshot(), extra=LOG_EXTRA)
|
selected_profile = profile_select_col.selectbox(
|
||||||
st.success("LLM 设置已保存,仅在当前会话生效。")
|
"选择 Profile",
|
||||||
st.json(llm_config_snapshot())
|
profile_keys,
|
||||||
|
index=0,
|
||||||
|
key="profile_select",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
selected_profile = None
|
||||||
|
profile_select_col.info("尚未配置 Profile,请先创建。")
|
||||||
|
|
||||||
|
new_profile_name = profile_manage_col.text_input("新增 Profile", key="new_profile_name")
|
||||||
|
if profile_manage_col.button("创建 Profile"):
|
||||||
|
key = (new_profile_name or "").strip()
|
||||||
|
if not key:
|
||||||
|
st.warning("请输入有效的 Profile 名称。")
|
||||||
|
elif key in profiles:
|
||||||
|
st.warning(f"Profile {key} 已存在。")
|
||||||
|
else:
|
||||||
|
profiles[key] = LLMProfile(key=key)
|
||||||
|
save_config()
|
||||||
|
st.success(f"已创建 Profile {key}。")
|
||||||
|
st.experimental_rerun()
|
||||||
|
|
||||||
|
if selected_profile:
|
||||||
|
profile = profiles[selected_profile]
|
||||||
|
provider_choices = sorted(DEFAULT_LLM_MODEL_OPTIONS.keys())
|
||||||
|
try:
|
||||||
|
provider_index = provider_choices.index(profile.provider)
|
||||||
|
except ValueError:
|
||||||
|
provider_index = 0
|
||||||
|
with st.form(f"profile_form_{selected_profile}"):
|
||||||
|
provider_val = st.selectbox(
|
||||||
|
"Provider",
|
||||||
|
provider_choices,
|
||||||
|
index=provider_index,
|
||||||
|
)
|
||||||
|
model_default = DEFAULT_LLM_MODELS.get(provider_val, profile.model or "")
|
||||||
|
model_val = st.text_input(
|
||||||
|
"模型",
|
||||||
|
value=profile.model or model_default,
|
||||||
|
)
|
||||||
|
base_default = DEFAULT_LLM_BASE_URLS.get(provider_val, profile.base_url or "")
|
||||||
|
base_val = st.text_input(
|
||||||
|
"Base URL",
|
||||||
|
value=profile.base_url or base_default,
|
||||||
|
)
|
||||||
|
api_val = st.text_input(
|
||||||
|
"API Key",
|
||||||
|
value=profile.api_key or "",
|
||||||
|
type="password",
|
||||||
|
)
|
||||||
|
temp_val = st.slider(
|
||||||
|
"温度",
|
||||||
|
min_value=0.0,
|
||||||
|
max_value=2.0,
|
||||||
|
value=float(profile.temperature),
|
||||||
|
step=0.05,
|
||||||
|
)
|
||||||
|
timeout_val = st.number_input(
|
||||||
|
"超时(秒)",
|
||||||
|
min_value=5,
|
||||||
|
max_value=180,
|
||||||
|
value=int(profile.timeout or 30),
|
||||||
|
step=5,
|
||||||
|
)
|
||||||
|
title_val = st.text_input("备注", value=profile.title or "")
|
||||||
|
enabled_val = st.checkbox("启用", value=profile.enabled)
|
||||||
|
submitted = st.form_submit_button("保存 Profile")
|
||||||
|
if submitted:
|
||||||
|
profile.provider = provider_val
|
||||||
|
profile.model = model_val.strip() or DEFAULT_LLM_MODELS.get(provider_val)
|
||||||
|
profile.base_url = base_val.strip() or DEFAULT_LLM_BASE_URLS.get(provider_val)
|
||||||
|
profile.api_key = api_val.strip() or None
|
||||||
|
profile.temperature = temp_val
|
||||||
|
profile.timeout = timeout_val
|
||||||
|
profile.title = title_val.strip()
|
||||||
|
profile.enabled = enabled_val
|
||||||
|
profiles[selected_profile] = profile
|
||||||
|
cfg.sync_runtime_llm()
|
||||||
|
save_config()
|
||||||
|
st.success("Profile 已保存。")
|
||||||
|
|
||||||
|
profile_in_use = any(
|
||||||
|
selected_profile == route.primary or selected_profile in route.ensemble
|
||||||
|
for route in routes.values()
|
||||||
|
)
|
||||||
|
if st.button(
|
||||||
|
"删除该 Profile",
|
||||||
|
key=f"delete_profile_{selected_profile}",
|
||||||
|
disabled=profile_in_use or len(profiles) <= 1,
|
||||||
|
):
|
||||||
|
profiles.pop(selected_profile, None)
|
||||||
|
fallback_key = next((key for key in profiles.keys()), None)
|
||||||
|
for route in routes.values():
|
||||||
|
if route.primary == selected_profile:
|
||||||
|
route.primary = fallback_key or route.primary
|
||||||
|
route.ensemble = [key for key in route.ensemble if key != selected_profile]
|
||||||
|
cfg.sync_runtime_llm()
|
||||||
|
save_config()
|
||||||
|
st.success("Profile 已删除。")
|
||||||
|
st.experimental_rerun()
|
||||||
|
|
||||||
st.divider()
|
st.divider()
|
||||||
st.subheader("部门配置")
|
st.subheader("部门配置")
|
||||||
|
|
||||||
dept_settings = cfg.departments or {}
|
dept_settings = cfg.departments or {}
|
||||||
|
route_options_display = [""] + route_keys
|
||||||
dept_rows = [
|
dept_rows = [
|
||||||
{
|
{
|
||||||
"code": code,
|
"code": code,
|
||||||
"title": dept.title,
|
"title": dept.title,
|
||||||
"description": dept.description,
|
"description": dept.description,
|
||||||
"weight": float(dept.weight),
|
"weight": float(dept.weight),
|
||||||
|
"llm_route": dept.llm_route or "",
|
||||||
"strategy": dept.llm.strategy,
|
"strategy": dept.llm.strategy,
|
||||||
"primary_provider": (dept.llm.primary.provider or "ollama"),
|
"primary_provider": (dept.llm.primary.provider or ""),
|
||||||
"primary_model": dept.llm.primary.model or "",
|
"primary_model": dept.llm.primary.model or "",
|
||||||
"ensemble_size": len(dept.llm.ensemble),
|
"ensemble_size": len(dept.llm.ensemble),
|
||||||
}
|
}
|
||||||
@ -572,20 +618,25 @@ def render_settings() -> None:
|
|||||||
"title": st.column_config.TextColumn("名称"),
|
"title": st.column_config.TextColumn("名称"),
|
||||||
"description": st.column_config.TextColumn("说明"),
|
"description": st.column_config.TextColumn("说明"),
|
||||||
"weight": st.column_config.NumberColumn("权重", min_value=0.0, max_value=10.0, step=0.1),
|
"weight": st.column_config.NumberColumn("权重", min_value=0.0, max_value=10.0, step=0.1),
|
||||||
|
"llm_route": st.column_config.SelectboxColumn(
|
||||||
|
"路由",
|
||||||
|
options=route_options_display,
|
||||||
|
help="选择预定义路由;留空表示使用自定义配置",
|
||||||
|
),
|
||||||
"strategy": st.column_config.SelectboxColumn(
|
"strategy": st.column_config.SelectboxColumn(
|
||||||
"策略",
|
"自定义策略",
|
||||||
options=sorted(ALLOWED_LLM_STRATEGIES),
|
options=sorted(ALLOWED_LLM_STRATEGIES),
|
||||||
help="single=单模型, majority=多数投票, leader=顾问-决策者模式",
|
help="仅当未选择路由时生效",
|
||||||
),
|
),
|
||||||
"primary_provider": st.column_config.SelectboxColumn(
|
"primary_provider": st.column_config.SelectboxColumn(
|
||||||
"主模型 Provider",
|
"自定义 Provider",
|
||||||
options=sorted(DEFAULT_LLM_MODEL_OPTIONS.keys()),
|
options=sorted(DEFAULT_LLM_MODEL_OPTIONS.keys()),
|
||||||
),
|
),
|
||||||
"primary_model": st.column_config.TextColumn("主模型名称"),
|
"primary_model": st.column_config.TextColumn("自定义模型"),
|
||||||
"ensemble_size": st.column_config.NumberColumn(
|
"ensemble_size": st.column_config.NumberColumn(
|
||||||
"协作模型数量",
|
"协作模型数量",
|
||||||
disabled=True,
|
disabled=True,
|
||||||
help="在 config.json 中编辑 ensemble 详情",
|
help="路由模式下自动维护",
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@ -609,31 +660,32 @@ def render_settings() -> None:
|
|||||||
try:
|
try:
|
||||||
existing.weight = max(0.0, float(row.get("weight", existing.weight)))
|
existing.weight = max(0.0, float(row.get("weight", existing.weight)))
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
existing.weight = existing.weight
|
pass
|
||||||
|
|
||||||
strategy_val = (row.get("strategy") or existing.llm.strategy).lower()
|
route_name = (row.get("llm_route") or "").strip() or None
|
||||||
if strategy_val in ALLOWED_LLM_STRATEGIES:
|
existing.llm_route = route_name
|
||||||
existing.llm.strategy = strategy_val
|
if route_name and route_name in routes:
|
||||||
|
existing.llm = routes[route_name].resolve(profiles)
|
||||||
provider_before = existing.llm.primary.provider or ""
|
|
||||||
provider_val = (row.get("primary_provider") or provider_before or "ollama").lower()
|
|
||||||
existing.llm.primary.provider = provider_val
|
|
||||||
|
|
||||||
model_val = (row.get("primary_model") or "").strip()
|
|
||||||
if model_val:
|
|
||||||
existing.llm.primary.model = model_val
|
|
||||||
else:
|
else:
|
||||||
existing.llm.primary.model = DEFAULT_LLM_MODELS.get(provider_val, existing.llm.primary.model)
|
strategy_val = (row.get("strategy") or existing.llm.strategy).lower()
|
||||||
|
if strategy_val in ALLOWED_LLM_STRATEGIES:
|
||||||
if provider_before != provider_val:
|
existing.llm.strategy = strategy_val
|
||||||
default_base = DEFAULT_LLM_BASE_URLS.get(provider_val)
|
provider_before = existing.llm.primary.provider or ""
|
||||||
existing.llm.primary.base_url = default_base or existing.llm.primary.base_url
|
provider_val = (row.get("primary_provider") or provider_before or "ollama").lower()
|
||||||
|
existing.llm.primary.provider = provider_val
|
||||||
existing.llm.primary.__post_init__()
|
model_val = (row.get("primary_model") or "").strip()
|
||||||
|
existing.llm.primary.model = (
|
||||||
|
model_val or DEFAULT_LLM_MODELS.get(provider_val, existing.llm.primary.model)
|
||||||
|
)
|
||||||
|
if provider_before != provider_val:
|
||||||
|
default_base = DEFAULT_LLM_BASE_URLS.get(provider_val)
|
||||||
|
existing.llm.primary.base_url = default_base or existing.llm.primary.base_url
|
||||||
|
existing.llm.primary.__post_init__()
|
||||||
updated_departments[code] = existing
|
updated_departments[code] = existing
|
||||||
|
|
||||||
if updated_departments:
|
if updated_departments:
|
||||||
cfg.departments = updated_departments
|
cfg.departments = updated_departments
|
||||||
|
cfg.sync_runtime_llm()
|
||||||
save_config()
|
save_config()
|
||||||
st.success("部门配置已更新。")
|
st.success("部门配置已更新。")
|
||||||
else:
|
else:
|
||||||
@ -643,10 +695,12 @@ def render_settings() -> None:
|
|||||||
from app.utils.config import _default_departments
|
from app.utils.config import _default_departments
|
||||||
|
|
||||||
cfg.departments = _default_departments()
|
cfg.departments = _default_departments()
|
||||||
|
cfg.sync_runtime_llm()
|
||||||
save_config()
|
save_config()
|
||||||
st.success("已恢复默认部门配置。")
|
st.success("已恢复默认部门配置。")
|
||||||
st.experimental_rerun()
|
st.experimental_rerun()
|
||||||
|
|
||||||
|
st.caption("选择路由可统一部门模型调用,自定义模式仍支持逐项配置。")
|
||||||
st.caption("部门协作模型(ensemble)请在 config.json 中手动编辑,UI 将在后续版本补充。")
|
st.caption("部门协作模型(ensemble)请在 config.json 中手动编辑,UI 将在后续版本补充。")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -5,7 +5,7 @@ from dataclasses import dataclass, field
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Mapping, Optional
|
||||||
|
|
||||||
|
|
||||||
def _default_root() -> Path:
|
def _default_root() -> Path:
|
||||||
@ -132,6 +132,135 @@ class LLMConfig:
|
|||||||
majority_threshold: int = 3
|
majority_threshold: int = 3
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LLMProfile:
|
||||||
|
"""Named LLM endpoint profile reusable across routes/departments."""
|
||||||
|
|
||||||
|
key: str
|
||||||
|
provider: str = "ollama"
|
||||||
|
model: Optional[str] = None
|
||||||
|
base_url: Optional[str] = None
|
||||||
|
api_key: Optional[str] = None
|
||||||
|
temperature: float = 0.2
|
||||||
|
timeout: float = 30.0
|
||||||
|
title: str = ""
|
||||||
|
enabled: bool = True
|
||||||
|
|
||||||
|
def to_endpoint(self) -> LLMEndpoint:
|
||||||
|
return LLMEndpoint(
|
||||||
|
provider=self.provider,
|
||||||
|
model=self.model,
|
||||||
|
base_url=self.base_url,
|
||||||
|
api_key=self.api_key,
|
||||||
|
temperature=self.temperature,
|
||||||
|
timeout=self.timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, object]:
|
||||||
|
return {
|
||||||
|
"provider": self.provider,
|
||||||
|
"model": self.model,
|
||||||
|
"base_url": self.base_url,
|
||||||
|
"api_key": self.api_key,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
"timeout": self.timeout,
|
||||||
|
"title": self.title,
|
||||||
|
"enabled": self.enabled,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_endpoint(
|
||||||
|
cls,
|
||||||
|
key: str,
|
||||||
|
endpoint: LLMEndpoint,
|
||||||
|
*,
|
||||||
|
title: str = "",
|
||||||
|
enabled: bool = True,
|
||||||
|
) -> "LLMProfile":
|
||||||
|
return cls(
|
||||||
|
key=key,
|
||||||
|
provider=endpoint.provider,
|
||||||
|
model=endpoint.model,
|
||||||
|
base_url=endpoint.base_url,
|
||||||
|
api_key=endpoint.api_key,
|
||||||
|
temperature=endpoint.temperature,
|
||||||
|
timeout=endpoint.timeout,
|
||||||
|
title=title,
|
||||||
|
enabled=enabled,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LLMRoute:
|
||||||
|
"""Declarative routing for selecting profiles and strategy."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
title: str = ""
|
||||||
|
strategy: str = "single"
|
||||||
|
majority_threshold: int = 3
|
||||||
|
primary: str = "ollama"
|
||||||
|
ensemble: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
def resolve(self, profiles: Mapping[str, LLMProfile]) -> LLMConfig:
|
||||||
|
def _endpoint_from_key(key: str) -> LLMEndpoint:
|
||||||
|
profile = profiles.get(key)
|
||||||
|
if profile and profile.enabled:
|
||||||
|
return profile.to_endpoint()
|
||||||
|
fallback = profiles.get("ollama")
|
||||||
|
if not fallback or not fallback.enabled:
|
||||||
|
fallback = next(
|
||||||
|
(item for item in profiles.values() if item.enabled),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
endpoint = fallback.to_endpoint() if fallback else LLMEndpoint()
|
||||||
|
endpoint.provider = key or endpoint.provider
|
||||||
|
return endpoint
|
||||||
|
|
||||||
|
primary_endpoint = _endpoint_from_key(self.primary)
|
||||||
|
ensemble_endpoints = [
|
||||||
|
_endpoint_from_key(key)
|
||||||
|
for key in self.ensemble
|
||||||
|
if key in profiles and profiles[key].enabled
|
||||||
|
]
|
||||||
|
config = LLMConfig(
|
||||||
|
primary=primary_endpoint,
|
||||||
|
ensemble=ensemble_endpoints,
|
||||||
|
strategy=self.strategy if self.strategy in ALLOWED_LLM_STRATEGIES else "single",
|
||||||
|
majority_threshold=max(1, self.majority_threshold or 1),
|
||||||
|
)
|
||||||
|
return config
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, object]:
|
||||||
|
return {
|
||||||
|
"title": self.title,
|
||||||
|
"strategy": self.strategy,
|
||||||
|
"majority_threshold": self.majority_threshold,
|
||||||
|
"primary": self.primary,
|
||||||
|
"ensemble": list(self.ensemble),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _default_llm_profiles() -> Dict[str, LLMProfile]:
|
||||||
|
return {
|
||||||
|
provider: LLMProfile(
|
||||||
|
key=provider,
|
||||||
|
provider=provider,
|
||||||
|
model=DEFAULT_LLM_MODELS.get(provider),
|
||||||
|
base_url=DEFAULT_LLM_BASE_URLS.get(provider),
|
||||||
|
temperature=DEFAULT_LLM_TEMPERATURES.get(provider, 0.2),
|
||||||
|
timeout=DEFAULT_LLM_TIMEOUTS.get(provider, 30.0),
|
||||||
|
title=f"默认 {provider}",
|
||||||
|
)
|
||||||
|
for provider in DEFAULT_LLM_MODEL_OPTIONS
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _default_llm_routes() -> Dict[str, LLMRoute]:
|
||||||
|
return {
|
||||||
|
"global": LLMRoute(name="global", title="全局默认路由"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DepartmentSettings:
|
class DepartmentSettings:
|
||||||
"""Configuration for a single decision department."""
|
"""Configuration for a single decision department."""
|
||||||
@ -141,6 +270,7 @@ class DepartmentSettings:
|
|||||||
description: str = ""
|
description: str = ""
|
||||||
weight: float = 1.0
|
weight: float = 1.0
|
||||||
llm: LLMConfig = field(default_factory=LLMConfig)
|
llm: LLMConfig = field(default_factory=LLMConfig)
|
||||||
|
llm_route: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
def _default_departments() -> Dict[str, DepartmentSettings]:
|
def _default_departments() -> Dict[str, DepartmentSettings]:
|
||||||
@ -153,7 +283,7 @@ def _default_departments() -> Dict[str, DepartmentSettings]:
|
|||||||
("risk", "风险控制部门"),
|
("risk", "风险控制部门"),
|
||||||
]
|
]
|
||||||
return {
|
return {
|
||||||
code: DepartmentSettings(code=code, title=title)
|
code: DepartmentSettings(code=code, title=title, llm_route="global")
|
||||||
for code, title in presets
|
for code, title in presets
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -169,8 +299,21 @@ class AppConfig:
|
|||||||
agent_weights: AgentWeights = field(default_factory=AgentWeights)
|
agent_weights: AgentWeights = field(default_factory=AgentWeights)
|
||||||
force_refresh: bool = False
|
force_refresh: bool = False
|
||||||
llm: LLMConfig = field(default_factory=LLMConfig)
|
llm: LLMConfig = field(default_factory=LLMConfig)
|
||||||
|
llm_route: str = "global"
|
||||||
|
llm_profiles: Dict[str, LLMProfile] = field(default_factory=_default_llm_profiles)
|
||||||
|
llm_routes: Dict[str, LLMRoute] = field(default_factory=_default_llm_routes)
|
||||||
departments: Dict[str, DepartmentSettings] = field(default_factory=_default_departments)
|
departments: Dict[str, DepartmentSettings] = field(default_factory=_default_departments)
|
||||||
|
|
||||||
|
def resolve_llm(self, route: Optional[str] = None) -> LLMConfig:
|
||||||
|
route_key = route or self.llm_route
|
||||||
|
route_cfg = self.llm_routes.get(route_key)
|
||||||
|
if route_cfg:
|
||||||
|
return route_cfg.resolve(self.llm_profiles)
|
||||||
|
return self.llm
|
||||||
|
|
||||||
|
def sync_runtime_llm(self) -> None:
|
||||||
|
self.llm = self.resolve_llm()
|
||||||
|
|
||||||
|
|
||||||
CONFIG = AppConfig()
|
CONFIG = AppConfig()
|
||||||
|
|
||||||
@ -213,11 +356,69 @@ def _load_from_file(cfg: AppConfig) -> None:
|
|||||||
if "decision_method" in payload:
|
if "decision_method" in payload:
|
||||||
cfg.decision_method = str(payload.get("decision_method") or cfg.decision_method)
|
cfg.decision_method = str(payload.get("decision_method") or cfg.decision_method)
|
||||||
|
|
||||||
|
routes_defined = False
|
||||||
|
inline_primary_loaded = False
|
||||||
|
|
||||||
|
profiles_payload = payload.get("llm_profiles")
|
||||||
|
if isinstance(profiles_payload, dict):
|
||||||
|
profiles: Dict[str, LLMProfile] = {}
|
||||||
|
for key, data in profiles_payload.items():
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
continue
|
||||||
|
provider = str(data.get("provider") or "ollama").lower()
|
||||||
|
profile = LLMProfile(
|
||||||
|
key=key,
|
||||||
|
provider=provider,
|
||||||
|
model=data.get("model"),
|
||||||
|
base_url=data.get("base_url"),
|
||||||
|
api_key=data.get("api_key"),
|
||||||
|
temperature=float(data.get("temperature", DEFAULT_LLM_TEMPERATURES.get(provider, 0.2))),
|
||||||
|
timeout=float(data.get("timeout", DEFAULT_LLM_TIMEOUTS.get(provider, 30.0))),
|
||||||
|
title=str(data.get("title") or ""),
|
||||||
|
enabled=bool(data.get("enabled", True)),
|
||||||
|
)
|
||||||
|
profiles[key] = profile
|
||||||
|
if profiles:
|
||||||
|
cfg.llm_profiles = profiles
|
||||||
|
|
||||||
|
routes_payload = payload.get("llm_routes")
|
||||||
|
if isinstance(routes_payload, dict):
|
||||||
|
routes: Dict[str, LLMRoute] = {}
|
||||||
|
for name, data in routes_payload.items():
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
continue
|
||||||
|
strategy_raw = str(data.get("strategy") or "single").lower()
|
||||||
|
normalized = LLM_STRATEGY_ALIASES.get(strategy_raw, strategy_raw)
|
||||||
|
route = LLMRoute(
|
||||||
|
name=name,
|
||||||
|
title=str(data.get("title") or ""),
|
||||||
|
strategy=normalized if normalized in ALLOWED_LLM_STRATEGIES else "single",
|
||||||
|
majority_threshold=max(1, int(data.get("majority_threshold", 3) or 3)),
|
||||||
|
primary=str(data.get("primary") or "global"),
|
||||||
|
ensemble=[
|
||||||
|
str(item)
|
||||||
|
for item in data.get("ensemble", [])
|
||||||
|
if isinstance(item, str)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
routes[name] = route
|
||||||
|
if routes:
|
||||||
|
cfg.llm_routes = routes
|
||||||
|
routes_defined = True
|
||||||
|
|
||||||
|
route_key = payload.get("llm_route")
|
||||||
|
if isinstance(route_key, str) and route_key:
|
||||||
|
cfg.llm_route = route_key
|
||||||
|
|
||||||
llm_payload = payload.get("llm")
|
llm_payload = payload.get("llm")
|
||||||
if isinstance(llm_payload, dict):
|
if isinstance(llm_payload, dict):
|
||||||
|
route_value = llm_payload.get("route")
|
||||||
|
if isinstance(route_value, str) and route_value:
|
||||||
|
cfg.llm_route = route_value
|
||||||
primary_data = llm_payload.get("primary")
|
primary_data = llm_payload.get("primary")
|
||||||
if isinstance(primary_data, dict):
|
if isinstance(primary_data, dict):
|
||||||
cfg.llm.primary = _dict_to_endpoint(primary_data)
|
cfg.llm.primary = _dict_to_endpoint(primary_data)
|
||||||
|
inline_primary_loaded = True
|
||||||
|
|
||||||
ensemble_data = llm_payload.get("ensemble")
|
ensemble_data = llm_payload.get("ensemble")
|
||||||
if isinstance(ensemble_data, list):
|
if isinstance(ensemble_data, list):
|
||||||
@ -237,6 +438,30 @@ def _load_from_file(cfg: AppConfig) -> None:
|
|||||||
if isinstance(majority, int) and majority > 0:
|
if isinstance(majority, int) and majority > 0:
|
||||||
cfg.llm.majority_threshold = majority
|
cfg.llm.majority_threshold = majority
|
||||||
|
|
||||||
|
if inline_primary_loaded and not routes_defined:
|
||||||
|
primary_key = "inline_global_primary"
|
||||||
|
cfg.llm_profiles[primary_key] = LLMProfile.from_endpoint(
|
||||||
|
primary_key,
|
||||||
|
cfg.llm.primary,
|
||||||
|
title="全局主模型",
|
||||||
|
)
|
||||||
|
ensemble_keys: List[str] = []
|
||||||
|
for idx, endpoint in enumerate(cfg.llm.ensemble, start=1):
|
||||||
|
inline_key = f"inline_global_ensemble_{idx}"
|
||||||
|
cfg.llm_profiles[inline_key] = LLMProfile.from_endpoint(
|
||||||
|
inline_key,
|
||||||
|
endpoint,
|
||||||
|
title=f"全局协作#{idx}",
|
||||||
|
)
|
||||||
|
ensemble_keys.append(inline_key)
|
||||||
|
auto_route = cfg.llm_routes.get("global") or LLMRoute(name="global", title="全局默认路由")
|
||||||
|
auto_route.strategy = cfg.llm.strategy
|
||||||
|
auto_route.majority_threshold = cfg.llm.majority_threshold
|
||||||
|
auto_route.primary = primary_key
|
||||||
|
auto_route.ensemble = ensemble_keys
|
||||||
|
cfg.llm_routes["global"] = auto_route
|
||||||
|
cfg.llm_route = cfg.llm_route or "global"
|
||||||
|
|
||||||
departments_payload = payload.get("departments")
|
departments_payload = payload.get("departments")
|
||||||
if isinstance(departments_payload, dict):
|
if isinstance(departments_payload, dict):
|
||||||
new_departments: Dict[str, DepartmentSettings] = {}
|
new_departments: Dict[str, DepartmentSettings] = {}
|
||||||
@ -264,35 +489,55 @@ def _load_from_file(cfg: AppConfig) -> None:
|
|||||||
majority_raw = llm_data.get("majority_threshold")
|
majority_raw = llm_data.get("majority_threshold")
|
||||||
if isinstance(majority_raw, int) and majority_raw > 0:
|
if isinstance(majority_raw, int) and majority_raw > 0:
|
||||||
llm_cfg.majority_threshold = majority_raw
|
llm_cfg.majority_threshold = majority_raw
|
||||||
|
route = data.get("llm_route")
|
||||||
|
route_name = str(route).strip() if isinstance(route, str) and route else None
|
||||||
|
resolved = llm_cfg
|
||||||
|
if route_name and route_name in cfg.llm_routes:
|
||||||
|
resolved = cfg.llm_routes[route_name].resolve(cfg.llm_profiles)
|
||||||
new_departments[code] = DepartmentSettings(
|
new_departments[code] = DepartmentSettings(
|
||||||
code=code,
|
code=code,
|
||||||
title=title,
|
title=title,
|
||||||
description=description,
|
description=description,
|
||||||
weight=weight,
|
weight=weight,
|
||||||
llm=llm_cfg,
|
llm=resolved,
|
||||||
|
llm_route=route_name,
|
||||||
)
|
)
|
||||||
if new_departments:
|
if new_departments:
|
||||||
cfg.departments = new_departments
|
cfg.departments = new_departments
|
||||||
|
|
||||||
|
cfg.sync_runtime_llm()
|
||||||
|
|
||||||
|
|
||||||
def save_config(cfg: AppConfig | None = None) -> None:
|
def save_config(cfg: AppConfig | None = None) -> None:
|
||||||
cfg = cfg or CONFIG
|
cfg = cfg or CONFIG
|
||||||
|
cfg.sync_runtime_llm()
|
||||||
path = cfg.data_paths.config_file
|
path = cfg.data_paths.config_file
|
||||||
payload = {
|
payload = {
|
||||||
"tushare_token": cfg.tushare_token,
|
"tushare_token": cfg.tushare_token,
|
||||||
"force_refresh": cfg.force_refresh,
|
"force_refresh": cfg.force_refresh,
|
||||||
"decision_method": cfg.decision_method,
|
"decision_method": cfg.decision_method,
|
||||||
|
"llm_route": cfg.llm_route,
|
||||||
"llm": {
|
"llm": {
|
||||||
|
"route": cfg.llm_route,
|
||||||
"strategy": cfg.llm.strategy if cfg.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
|
"strategy": cfg.llm.strategy if cfg.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
|
||||||
"majority_threshold": cfg.llm.majority_threshold,
|
"majority_threshold": cfg.llm.majority_threshold,
|
||||||
"primary": _endpoint_to_dict(cfg.llm.primary),
|
"primary": _endpoint_to_dict(cfg.llm.primary),
|
||||||
"ensemble": [_endpoint_to_dict(ep) for ep in cfg.llm.ensemble],
|
"ensemble": [_endpoint_to_dict(ep) for ep in cfg.llm.ensemble],
|
||||||
},
|
},
|
||||||
|
"llm_profiles": {
|
||||||
|
key: profile.to_dict()
|
||||||
|
for key, profile in cfg.llm_profiles.items()
|
||||||
|
},
|
||||||
|
"llm_routes": {
|
||||||
|
name: route.to_dict()
|
||||||
|
for name, route in cfg.llm_routes.items()
|
||||||
|
},
|
||||||
"departments": {
|
"departments": {
|
||||||
code: {
|
code: {
|
||||||
"title": dept.title,
|
"title": dept.title,
|
||||||
"description": dept.description,
|
"description": dept.description,
|
||||||
"weight": dept.weight,
|
"weight": dept.weight,
|
||||||
|
"llm_route": dept.llm_route,
|
||||||
"llm": {
|
"llm": {
|
||||||
"strategy": dept.llm.strategy if dept.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
|
"strategy": dept.llm.strategy if dept.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
|
||||||
"majority_threshold": dept.llm.majority_threshold,
|
"majority_threshold": dept.llm.majority_threshold,
|
||||||
@ -320,7 +565,15 @@ def _load_env_defaults(cfg: AppConfig) -> None:
|
|||||||
|
|
||||||
api_key = os.getenv("LLM_API_KEY")
|
api_key = os.getenv("LLM_API_KEY")
|
||||||
if api_key:
|
if api_key:
|
||||||
cfg.llm.primary.api_key = api_key.strip()
|
sanitized = api_key.strip()
|
||||||
|
cfg.llm.primary.api_key = sanitized
|
||||||
|
route = cfg.llm_routes.get(cfg.llm_route)
|
||||||
|
if route:
|
||||||
|
profile = cfg.llm_profiles.get(route.primary)
|
||||||
|
if profile:
|
||||||
|
profile.api_key = sanitized
|
||||||
|
|
||||||
|
cfg.sync_runtime_llm()
|
||||||
|
|
||||||
|
|
||||||
_load_from_file(CONFIG)
|
_load_from_file(CONFIG)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user