update
This commit is contained in:
parent
1a99b72c60
commit
6aece20816
15
README.md
15
README.md
@ -24,6 +24,7 @@
|
||||
- **统一日志与持久化**:SQLite 统一存储行情、回测与日志,配合 `DatabaseLogHandler` 在 UI/抓数流程中输出结构化运行轨迹,支持快速追踪与复盘。
|
||||
- **跨市场数据扩展**:`app/ingest/tushare.py` 追加指数、ETF/公募基金、期货、外汇、港股与美股的增量拉取逻辑,确保多资产因子与宏观代理所需的行情基础数据齐备。
|
||||
- **部门化多模型协作**:`app/agents/departments.py` 封装部门级 LLM 调度,`app/llm/client.py` 支持 single/majority/leader 策略,部门结论在 `app/agents/game.py` 与六类基础代理共同博弈,并持久化至 `agent_utils` 供 UI 展示。
|
||||
- **LLM Profile/Route 管理**:`app/utils/config.py` 引入可复用的 Profile(终端定义)与 Route(推理策略组合),Streamlit UI 支持可视化维护,全局与部门均可复用命名路由提升配置一致性。
|
||||
|
||||
## LLM + 多智能体最佳实践
|
||||
|
||||
@ -59,13 +60,11 @@ export TUSHARE_TOKEN="<your-token>"
|
||||
|
||||
### LLM 配置与测试
|
||||
|
||||
- 支持本地 Ollama 与多家 OpenAI 兼容云端供应商(如 DeepSeek、文心一言、OpenAI 等),可在 “数据与设置” 页签切换 Provider 并自动加载该 Provider 的候选模型、推荐 Base URL、默认温度与超时时间,亦可切换为自定义值。所有修改会持久化到 `app/data/config.json`,下次启动自动加载。
|
||||
- 修改 Provider/模型/Base URL/API Key 后点击 “保存 LLM 设置”,更新内容仅在当前会话生效。
|
||||
- 在 “自检测试” 页新增 “LLM 接口测试”,可输入 Prompt 快速验证调用结果,日志会记录限频与错误信息便于排查。
|
||||
- 未来可对同一功能的智能体并行调用多个 LLM,采用多数投票等策略增强鲁棒性,当前代码结构已为此预留扩展空间。
|
||||
- 若使用环境变量自动注入配置,可设置:
|
||||
- `TUSHARE_TOKEN`
|
||||
- `LLM_API_KEY`
|
||||
- 新增 Profile/Route 双层配置:Profile 定义单个端点(含 Provider/模型/域名/API Key),Route 组合 Profile 并指定推理策略(single/majority/leader)。全局路由可一键切换,部门可复用命名路由或保留自定义设置。
|
||||
- Streamlit “数据与设置” 页通过表单管理 Profile、Route、全局路由,保存即写入 `app/data/config.json`;Route 预览会同步展示经 `llm_config_snapshot()` 脱敏后的实时配置。
|
||||
- 支持本地 Ollama 与多家 OpenAI 兼容供应商(DeepSeek、文心一言、OpenAI 等),可为不同 Profile 设置默认模型、温度、超时与启用状态。
|
||||
- UI 保留 TuShare Token 维护,以及路由/Profile 新增、删除、禁用等操作;所有更新即时生效并记入日志。
|
||||
- 使用环境变量注入敏感信息时,可配置:`TUSHARE_TOKEN`、`LLM_API_KEY`,加载后会同步至当前路由的主 Profile。
|
||||
|
||||
## 快速开始
|
||||
|
||||
@ -105,7 +104,7 @@ Streamlit `自检测试` 页签提供:
|
||||
## 实施步骤
|
||||
|
||||
1. **配置扩展** (`app/utils/config.py` + `config.json`) ✅
|
||||
- 部门支持 primary/ensemble、策略(single/majority/leader)、权重,并可在 Streamlit 中编辑主要字段。
|
||||
- 引入 `llm_profiles`/`llm_routes` 统一管理终端与策略,部门可复用路由或使用自定义配置;Streamlit 提供可视化维护表单。
|
||||
|
||||
2. **部门管控器** ✅
|
||||
- `app/agents/departments.py` 提供 `DepartmentAgent`/`DepartmentManager`,封装 Prompt 构建、多模型协商及异常回退。
|
||||
|
||||
@ -3,12 +3,12 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Mapping
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional
|
||||
|
||||
from app.agents.base import AgentAction
|
||||
from app.llm.client import run_llm_with_config
|
||||
from app.llm.prompts import department_prompt
|
||||
from app.utils.config import DepartmentSettings
|
||||
from app.utils.config import AppConfig, DepartmentSettings, LLMConfig
|
||||
from app.utils.logging import get_logger
|
||||
|
||||
LOGGER = get_logger(__name__)
|
||||
@ -53,16 +53,27 @@ class DepartmentDecision:
|
||||
class DepartmentAgent:
|
||||
"""Wraps LLM ensemble logic for a single analytical department."""
|
||||
|
||||
def __init__(self, settings: DepartmentSettings) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
settings: DepartmentSettings,
|
||||
resolver: Optional[Callable[[DepartmentSettings], LLMConfig]] = None,
|
||||
) -> None:
|
||||
self.settings = settings
|
||||
self._resolver = resolver
|
||||
|
||||
def _get_llm_config(self) -> LLMConfig:
|
||||
if self._resolver:
|
||||
return self._resolver(self.settings)
|
||||
return self.settings.llm
|
||||
|
||||
def analyze(self, context: DepartmentContext) -> DepartmentDecision:
|
||||
prompt = department_prompt(self.settings, context)
|
||||
system_prompt = (
|
||||
"你是一个多智能体量化投研系统中的分部决策者,需要根据提供的结构化信息给出买卖意见。"
|
||||
)
|
||||
llm_cfg = self._get_llm_config()
|
||||
try:
|
||||
response = run_llm_with_config(self.settings.llm, prompt, system=system_prompt)
|
||||
response = run_llm_with_config(llm_cfg, prompt, system=system_prompt)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
LOGGER.exception("部门 %s 调用 LLM 失败:%s", self.settings.code, exc, extra=LOG_EXTRA)
|
||||
return DepartmentDecision(
|
||||
@ -106,10 +117,11 @@ class DepartmentAgent:
|
||||
class DepartmentManager:
|
||||
"""Orchestrates all departments defined in configuration."""
|
||||
|
||||
def __init__(self, departments: Mapping[str, DepartmentSettings]) -> None:
|
||||
def __init__(self, config: AppConfig) -> None:
|
||||
self.config = config
|
||||
self.agents: Dict[str, DepartmentAgent] = {
|
||||
code: DepartmentAgent(settings)
|
||||
for code, settings in departments.items()
|
||||
code: DepartmentAgent(settings, self._resolve_llm)
|
||||
for code, settings in config.departments.items()
|
||||
}
|
||||
|
||||
def evaluate(self, context: DepartmentContext) -> Dict[str, DepartmentDecision]:
|
||||
@ -118,6 +130,11 @@ class DepartmentManager:
|
||||
results[code] = agent.analyze(context)
|
||||
return results
|
||||
|
||||
def _resolve_llm(self, settings: DepartmentSettings) -> LLMConfig:
|
||||
if settings.llm_route and settings.llm_route in self.config.llm_routes:
|
||||
return self.config.llm_routes[settings.llm_route].resolve(self.config.llm_profiles)
|
||||
return settings.llm
|
||||
|
||||
|
||||
def _parse_department_response(text: str) -> Dict[str, Any]:
|
||||
"""Extract a JSON object from the LLM response if possible."""
|
||||
|
||||
@ -55,7 +55,7 @@ class BacktestEngine:
|
||||
else:
|
||||
self.weights = {agent.name: 1.0 for agent in self.agents}
|
||||
self.department_manager = (
|
||||
DepartmentManager(app_cfg.departments) if app_cfg.departments else None
|
||||
DepartmentManager(app_cfg) if app_cfg.departments else None
|
||||
)
|
||||
|
||||
def load_market_data(self, trade_date: date) -> Mapping[str, Dict[str, float]]:
|
||||
|
||||
@ -272,7 +272,8 @@ def run_llm_with_config(
|
||||
def llm_config_snapshot() -> Dict[str, object]:
|
||||
"""Return a sanitized snapshot of current LLM configuration for debugging."""
|
||||
|
||||
settings = get_config().llm
|
||||
cfg = get_config()
|
||||
settings = cfg.llm
|
||||
primary = asdict(settings.primary)
|
||||
if primary.get("api_key"):
|
||||
primary["api_key"] = "***"
|
||||
@ -282,7 +283,11 @@ def llm_config_snapshot() -> Dict[str, object]:
|
||||
if record.get("api_key"):
|
||||
record["api_key"] = "***"
|
||||
ensemble.append(record)
|
||||
route_name = cfg.llm_route
|
||||
route_obj = cfg.llm_routes.get(route_name)
|
||||
return {
|
||||
"route": route_name,
|
||||
"route_detail": route_obj.to_dict() if route_obj else None,
|
||||
"strategy": settings.strategy,
|
||||
"majority_threshold": settings.majority_threshold,
|
||||
"primary": primary,
|
||||
|
||||
@ -29,7 +29,8 @@ from app.utils.config import (
|
||||
DEFAULT_LLM_MODEL_OPTIONS,
|
||||
DEFAULT_LLM_MODELS,
|
||||
DepartmentSettings,
|
||||
LLMEndpoint,
|
||||
LLMProfile,
|
||||
LLMRoute,
|
||||
get_config,
|
||||
save_config,
|
||||
)
|
||||
@ -349,208 +350,253 @@ def render_settings() -> None:
|
||||
|
||||
st.divider()
|
||||
st.subheader("LLM 设置")
|
||||
llm_cfg = cfg.llm
|
||||
primary = llm_cfg.primary
|
||||
providers = sorted(DEFAULT_LLM_MODELS.keys())
|
||||
profiles = cfg.llm_profiles or {}
|
||||
routes = cfg.llm_routes or {}
|
||||
profile_keys = sorted(profiles.keys())
|
||||
route_keys = sorted(routes.keys())
|
||||
used_routes = {
|
||||
dept.llm_route for dept in cfg.departments.values() if dept.llm_route
|
||||
}
|
||||
st.caption("Profile 定义单个模型终端,Route 负责组合 Profile 与推理策略。")
|
||||
|
||||
route_select_col, route_manage_col = st.columns([3, 1])
|
||||
if route_keys:
|
||||
try:
|
||||
provider_index = providers.index((primary.provider or "ollama").lower())
|
||||
active_index = route_keys.index(cfg.llm_route)
|
||||
except ValueError:
|
||||
provider_index = 0
|
||||
selected_provider = st.selectbox("LLM Provider", providers, index=provider_index)
|
||||
provider_info = DEFAULT_LLM_MODEL_OPTIONS.get(selected_provider, {})
|
||||
model_options = provider_info.get("models", [])
|
||||
custom_model_label = "自定义模型"
|
||||
default_model_hint = DEFAULT_LLM_MODELS.get(selected_provider, DEFAULT_LLM_MODELS["ollama"])
|
||||
|
||||
if model_options:
|
||||
options_with_custom = model_options + [custom_model_label]
|
||||
if primary.provider == selected_provider and primary.model in model_options:
|
||||
model_index = options_with_custom.index(primary.model)
|
||||
active_index = 0
|
||||
selected_route = route_select_col.selectbox(
|
||||
"全局路由",
|
||||
route_keys,
|
||||
index=active_index,
|
||||
key="llm_route_select",
|
||||
)
|
||||
else:
|
||||
model_index = 0
|
||||
selected_model_option = st.selectbox(
|
||||
"LLM 模型",
|
||||
options_with_custom,
|
||||
index=model_index,
|
||||
help=f"可选模型:{', '.join(model_options)}",
|
||||
)
|
||||
if selected_model_option == custom_model_label:
|
||||
custom_model_value = st.text_input(
|
||||
"自定义模型名称",
|
||||
value="" if primary.provider != selected_provider or primary.model in model_options else primary.model,
|
||||
)
|
||||
chosen_model = custom_model_value.strip() or default_model_hint
|
||||
else:
|
||||
chosen_model = selected_model_option
|
||||
else:
|
||||
chosen_model = st.text_input(
|
||||
"LLM 模型",
|
||||
value=primary.model or default_model_hint,
|
||||
help="未预设该 Provider 的模型列表,请手动填写",
|
||||
).strip() or default_model_hint
|
||||
default_base_hint = DEFAULT_LLM_BASE_URLS.get(selected_provider, "")
|
||||
provider_default_temp = float(provider_info.get("temperature", 0.2))
|
||||
provider_default_timeout = int(provider_info.get("timeout", 30.0))
|
||||
selected_route = None
|
||||
route_select_col.info("尚未配置路由,请先创建。")
|
||||
|
||||
if primary.provider == selected_provider:
|
||||
base_value = primary.base_url or default_base_hint or ""
|
||||
temp_value = float(primary.temperature)
|
||||
timeout_value = int(primary.timeout)
|
||||
new_route_name = route_manage_col.text_input("新增路由", key="new_route_name")
|
||||
if route_manage_col.button("添加路由"):
|
||||
key = (new_route_name or "").strip()
|
||||
if not key:
|
||||
st.warning("请输入有效的路由名称。")
|
||||
elif key in routes:
|
||||
st.warning(f"路由 {key} 已存在。")
|
||||
else:
|
||||
base_value = default_base_hint or ""
|
||||
temp_value = provider_default_temp
|
||||
timeout_value = provider_default_timeout
|
||||
routes[key] = LLMRoute(name=key)
|
||||
if not selected_route:
|
||||
selected_route = key
|
||||
cfg.llm_route = key
|
||||
save_config()
|
||||
st.success(f"已添加路由 {key},请继续配置。")
|
||||
st.experimental_rerun()
|
||||
|
||||
llm_base = st.text_input(
|
||||
"LLM Base URL",
|
||||
value=base_value,
|
||||
help=f"默认推荐:{default_base_hint or '按供应商要求填写'}",
|
||||
)
|
||||
llm_api_key = st.text_input(
|
||||
"LLM API Key",
|
||||
value=primary.api_key or "",
|
||||
type="password",
|
||||
help="点击右侧小图标可查看当前 Key,该值会写入 config.json(已被 gitignore 排除)",
|
||||
)
|
||||
llm_temperature = st.slider(
|
||||
"LLM 温度",
|
||||
min_value=0.0,
|
||||
max_value=2.0,
|
||||
value=temp_value,
|
||||
step=0.05,
|
||||
)
|
||||
llm_timeout = st.number_input(
|
||||
"请求超时时间 (秒)",
|
||||
min_value=5,
|
||||
max_value=120,
|
||||
value=timeout_value,
|
||||
step=5,
|
||||
)
|
||||
|
||||
strategy_options = ["single", "majority", "leader"]
|
||||
if selected_route:
|
||||
route_obj = routes.get(selected_route)
|
||||
if route_obj is None:
|
||||
route_obj = LLMRoute(name=selected_route)
|
||||
routes[selected_route] = route_obj
|
||||
strategy_choices = sorted(ALLOWED_LLM_STRATEGIES)
|
||||
try:
|
||||
strategy_index = strategy_options.index(llm_cfg.strategy)
|
||||
strategy_index = strategy_choices.index(route_obj.strategy)
|
||||
except ValueError:
|
||||
strategy_index = 0
|
||||
selected_strategy = st.selectbox("LLM 推理策略", strategy_options, index=strategy_index)
|
||||
majority_threshold = st.number_input(
|
||||
route_title = st.text_input(
|
||||
"路由说明",
|
||||
value=route_obj.title or "",
|
||||
key=f"route_title_{selected_route}",
|
||||
)
|
||||
route_strategy = st.selectbox(
|
||||
"推理策略",
|
||||
strategy_choices,
|
||||
index=strategy_index,
|
||||
key=f"route_strategy_{selected_route}",
|
||||
)
|
||||
route_majority = st.number_input(
|
||||
"多数投票门槛",
|
||||
min_value=1,
|
||||
max_value=10,
|
||||
value=int(llm_cfg.majority_threshold),
|
||||
value=int(route_obj.majority_threshold or 1),
|
||||
step=1,
|
||||
format="%d",
|
||||
key=f"route_majority_{selected_route}",
|
||||
)
|
||||
|
||||
existing_api_keys = {ep.provider: ep.api_key or None for ep in llm_cfg.ensemble}
|
||||
|
||||
available_providers = sorted(DEFAULT_LLM_MODEL_OPTIONS.keys())
|
||||
ensemble_rows = [
|
||||
{
|
||||
"provider": ep.provider or "",
|
||||
"model": ep.model or DEFAULT_LLM_MODELS.get(ep.provider, DEFAULT_LLM_MODELS["ollama"]),
|
||||
"base_url": ep.base_url or DEFAULT_LLM_BASE_URLS.get(ep.provider, ""),
|
||||
"api_key": "***" if ep.api_key else "",
|
||||
"temperature": float(ep.temperature),
|
||||
"timeout": float(ep.timeout),
|
||||
}
|
||||
for ep in llm_cfg.ensemble
|
||||
] or [
|
||||
{
|
||||
"provider": "",
|
||||
"model": "",
|
||||
"base_url": "",
|
||||
"api_key": "",
|
||||
"temperature": provider_default_temp,
|
||||
"timeout": provider_default_timeout,
|
||||
}
|
||||
if not profile_keys:
|
||||
st.warning("暂无可用 Profile,请先在下方创建。")
|
||||
else:
|
||||
try:
|
||||
primary_index = profile_keys.index(route_obj.primary)
|
||||
except ValueError:
|
||||
primary_index = 0
|
||||
primary_key = st.selectbox(
|
||||
"主用 Profile",
|
||||
profile_keys,
|
||||
index=primary_index,
|
||||
key=f"route_primary_{selected_route}",
|
||||
)
|
||||
default_ensemble = [
|
||||
key for key in route_obj.ensemble if key in profile_keys and key != primary_key
|
||||
]
|
||||
|
||||
edited = st.data_editor(
|
||||
ensemble_rows,
|
||||
num_rows="dynamic",
|
||||
key="llm_ensemble_editor",
|
||||
column_config={
|
||||
"provider": st.column_config.SelectboxColumn(
|
||||
"Provider",
|
||||
options=available_providers,
|
||||
help="选择 LLM 供应商"
|
||||
),
|
||||
"model": st.column_config.TextColumn("模型", help="留空时使用该 Provider 的默认模型"),
|
||||
"base_url": st.column_config.TextColumn("Base URL", help="留空时使用默认地址"),
|
||||
"api_key": st.column_config.TextColumn("API Key", help="留空表示使用环境变量或不配置"),
|
||||
"temperature": st.column_config.NumberColumn("温度", min_value=0.0, max_value=2.0, step=0.05),
|
||||
"timeout": st.column_config.NumberColumn("超时(秒)", min_value=5.0, max_value=120.0, step=5.0),
|
||||
},
|
||||
hide_index=True,
|
||||
use_container_width=True,
|
||||
ensemble_keys = st.multiselect(
|
||||
"协作 Profile (可多选)",
|
||||
profile_keys,
|
||||
default=default_ensemble,
|
||||
key=f"route_ensemble_{selected_route}",
|
||||
)
|
||||
if hasattr(edited, "to_dict"):
|
||||
ensemble_rows = edited.to_dict("records")
|
||||
else:
|
||||
ensemble_rows = edited
|
||||
|
||||
if st.button("保存 LLM 设置"):
|
||||
primary.provider = selected_provider
|
||||
primary.model = chosen_model
|
||||
primary.base_url = llm_base.strip() or DEFAULT_LLM_BASE_URLS.get(selected_provider)
|
||||
primary.temperature = llm_temperature
|
||||
primary.timeout = llm_timeout
|
||||
api_key_value = llm_api_key.strip()
|
||||
if api_key_value:
|
||||
primary.api_key = api_key_value
|
||||
|
||||
new_ensemble: List[LLMEndpoint] = []
|
||||
for row in ensemble_rows:
|
||||
provider = (row.get("provider") or "").strip().lower()
|
||||
if not provider:
|
||||
continue
|
||||
provider_defaults = DEFAULT_LLM_MODEL_OPTIONS.get(provider, {})
|
||||
default_model = DEFAULT_LLM_MODELS.get(provider, DEFAULT_LLM_MODELS["ollama"])
|
||||
default_base = DEFAULT_LLM_BASE_URLS.get(provider)
|
||||
temp_default = float(provider_defaults.get("temperature", 0.2))
|
||||
timeout_default = float(provider_defaults.get("timeout", 30.0))
|
||||
|
||||
model_val = (row.get("model") or "").strip() or default_model
|
||||
base_val = (row.get("base_url") or "").strip() or default_base
|
||||
api_raw = (row.get("api_key") or "").strip()
|
||||
if api_raw == "***":
|
||||
api_value = existing_api_keys.get(provider)
|
||||
else:
|
||||
api_value = api_raw or None
|
||||
|
||||
temp_val = row.get("temperature")
|
||||
timeout_val = row.get("timeout")
|
||||
endpoint = LLMEndpoint(
|
||||
provider=provider,
|
||||
model=model_val,
|
||||
base_url=base_val,
|
||||
api_key=api_value,
|
||||
temperature=float(temp_val) if temp_val is not None else temp_default,
|
||||
timeout=float(timeout_val) if timeout_val is not None else timeout_default,
|
||||
)
|
||||
new_ensemble.append(endpoint)
|
||||
|
||||
llm_cfg.ensemble = new_ensemble
|
||||
llm_cfg.strategy = selected_strategy
|
||||
llm_cfg.majority_threshold = int(majority_threshold)
|
||||
if st.button("保存路由设置", key=f"save_route_{selected_route}"):
|
||||
route_obj.title = route_title.strip()
|
||||
route_obj.strategy = route_strategy
|
||||
route_obj.majority_threshold = int(route_majority)
|
||||
route_obj.primary = primary_key
|
||||
route_obj.ensemble = [key for key in ensemble_keys if key != primary_key]
|
||||
cfg.llm_route = selected_route
|
||||
cfg.sync_runtime_llm()
|
||||
save_config()
|
||||
LOGGER.info("LLM 配置已更新:%s", llm_config_snapshot(), extra=LOG_EXTRA)
|
||||
st.success("LLM 设置已保存,仅在当前会话生效。")
|
||||
st.json(llm_config_snapshot())
|
||||
LOGGER.info(
|
||||
"路由 %s 配置更新:%s",
|
||||
selected_route,
|
||||
route_obj.to_dict(),
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
st.success("路由配置已保存。")
|
||||
st.json({
|
||||
"route": selected_route,
|
||||
"route_detail": route_obj.to_dict(),
|
||||
"resolved": llm_config_snapshot(),
|
||||
})
|
||||
route_in_use = selected_route in used_routes or selected_route == cfg.llm_route
|
||||
if st.button(
|
||||
"删除当前路由",
|
||||
key=f"delete_route_{selected_route}",
|
||||
disabled=route_in_use or len(routes) <= 1,
|
||||
):
|
||||
routes.pop(selected_route, None)
|
||||
if cfg.llm_route == selected_route:
|
||||
cfg.llm_route = next((key for key in routes.keys()), "global")
|
||||
cfg.sync_runtime_llm()
|
||||
save_config()
|
||||
st.success("路由已删除。")
|
||||
st.experimental_rerun()
|
||||
|
||||
st.divider()
|
||||
st.subheader("LLM Profile 管理")
|
||||
profile_select_col, profile_manage_col = st.columns([3, 1])
|
||||
if profile_keys:
|
||||
selected_profile = profile_select_col.selectbox(
|
||||
"选择 Profile",
|
||||
profile_keys,
|
||||
index=0,
|
||||
key="profile_select",
|
||||
)
|
||||
else:
|
||||
selected_profile = None
|
||||
profile_select_col.info("尚未配置 Profile,请先创建。")
|
||||
|
||||
new_profile_name = profile_manage_col.text_input("新增 Profile", key="new_profile_name")
|
||||
if profile_manage_col.button("创建 Profile"):
|
||||
key = (new_profile_name or "").strip()
|
||||
if not key:
|
||||
st.warning("请输入有效的 Profile 名称。")
|
||||
elif key in profiles:
|
||||
st.warning(f"Profile {key} 已存在。")
|
||||
else:
|
||||
profiles[key] = LLMProfile(key=key)
|
||||
save_config()
|
||||
st.success(f"已创建 Profile {key}。")
|
||||
st.experimental_rerun()
|
||||
|
||||
if selected_profile:
|
||||
profile = profiles[selected_profile]
|
||||
provider_choices = sorted(DEFAULT_LLM_MODEL_OPTIONS.keys())
|
||||
try:
|
||||
provider_index = provider_choices.index(profile.provider)
|
||||
except ValueError:
|
||||
provider_index = 0
|
||||
with st.form(f"profile_form_{selected_profile}"):
|
||||
provider_val = st.selectbox(
|
||||
"Provider",
|
||||
provider_choices,
|
||||
index=provider_index,
|
||||
)
|
||||
model_default = DEFAULT_LLM_MODELS.get(provider_val, profile.model or "")
|
||||
model_val = st.text_input(
|
||||
"模型",
|
||||
value=profile.model or model_default,
|
||||
)
|
||||
base_default = DEFAULT_LLM_BASE_URLS.get(provider_val, profile.base_url or "")
|
||||
base_val = st.text_input(
|
||||
"Base URL",
|
||||
value=profile.base_url or base_default,
|
||||
)
|
||||
api_val = st.text_input(
|
||||
"API Key",
|
||||
value=profile.api_key or "",
|
||||
type="password",
|
||||
)
|
||||
temp_val = st.slider(
|
||||
"温度",
|
||||
min_value=0.0,
|
||||
max_value=2.0,
|
||||
value=float(profile.temperature),
|
||||
step=0.05,
|
||||
)
|
||||
timeout_val = st.number_input(
|
||||
"超时(秒)",
|
||||
min_value=5,
|
||||
max_value=180,
|
||||
value=int(profile.timeout or 30),
|
||||
step=5,
|
||||
)
|
||||
title_val = st.text_input("备注", value=profile.title or "")
|
||||
enabled_val = st.checkbox("启用", value=profile.enabled)
|
||||
submitted = st.form_submit_button("保存 Profile")
|
||||
if submitted:
|
||||
profile.provider = provider_val
|
||||
profile.model = model_val.strip() or DEFAULT_LLM_MODELS.get(provider_val)
|
||||
profile.base_url = base_val.strip() or DEFAULT_LLM_BASE_URLS.get(provider_val)
|
||||
profile.api_key = api_val.strip() or None
|
||||
profile.temperature = temp_val
|
||||
profile.timeout = timeout_val
|
||||
profile.title = title_val.strip()
|
||||
profile.enabled = enabled_val
|
||||
profiles[selected_profile] = profile
|
||||
cfg.sync_runtime_llm()
|
||||
save_config()
|
||||
st.success("Profile 已保存。")
|
||||
|
||||
profile_in_use = any(
|
||||
selected_profile == route.primary or selected_profile in route.ensemble
|
||||
for route in routes.values()
|
||||
)
|
||||
if st.button(
|
||||
"删除该 Profile",
|
||||
key=f"delete_profile_{selected_profile}",
|
||||
disabled=profile_in_use or len(profiles) <= 1,
|
||||
):
|
||||
profiles.pop(selected_profile, None)
|
||||
fallback_key = next((key for key in profiles.keys()), None)
|
||||
for route in routes.values():
|
||||
if route.primary == selected_profile:
|
||||
route.primary = fallback_key or route.primary
|
||||
route.ensemble = [key for key in route.ensemble if key != selected_profile]
|
||||
cfg.sync_runtime_llm()
|
||||
save_config()
|
||||
st.success("Profile 已删除。")
|
||||
st.experimental_rerun()
|
||||
|
||||
st.divider()
|
||||
st.subheader("部门配置")
|
||||
|
||||
dept_settings = cfg.departments or {}
|
||||
route_options_display = [""] + route_keys
|
||||
dept_rows = [
|
||||
{
|
||||
"code": code,
|
||||
"title": dept.title,
|
||||
"description": dept.description,
|
||||
"weight": float(dept.weight),
|
||||
"llm_route": dept.llm_route or "",
|
||||
"strategy": dept.llm.strategy,
|
||||
"primary_provider": (dept.llm.primary.provider or "ollama"),
|
||||
"primary_provider": (dept.llm.primary.provider or ""),
|
||||
"primary_model": dept.llm.primary.model or "",
|
||||
"ensemble_size": len(dept.llm.ensemble),
|
||||
}
|
||||
@ -572,20 +618,25 @@ def render_settings() -> None:
|
||||
"title": st.column_config.TextColumn("名称"),
|
||||
"description": st.column_config.TextColumn("说明"),
|
||||
"weight": st.column_config.NumberColumn("权重", min_value=0.0, max_value=10.0, step=0.1),
|
||||
"llm_route": st.column_config.SelectboxColumn(
|
||||
"路由",
|
||||
options=route_options_display,
|
||||
help="选择预定义路由;留空表示使用自定义配置",
|
||||
),
|
||||
"strategy": st.column_config.SelectboxColumn(
|
||||
"策略",
|
||||
"自定义策略",
|
||||
options=sorted(ALLOWED_LLM_STRATEGIES),
|
||||
help="single=单模型, majority=多数投票, leader=顾问-决策者模式",
|
||||
help="仅当未选择路由时生效",
|
||||
),
|
||||
"primary_provider": st.column_config.SelectboxColumn(
|
||||
"主模型 Provider",
|
||||
"自定义 Provider",
|
||||
options=sorted(DEFAULT_LLM_MODEL_OPTIONS.keys()),
|
||||
),
|
||||
"primary_model": st.column_config.TextColumn("主模型名称"),
|
||||
"primary_model": st.column_config.TextColumn("自定义模型"),
|
||||
"ensemble_size": st.column_config.NumberColumn(
|
||||
"协作模型数量",
|
||||
disabled=True,
|
||||
help="在 config.json 中编辑 ensemble 详情",
|
||||
help="路由模式下自动维护",
|
||||
),
|
||||
},
|
||||
)
|
||||
@ -609,31 +660,32 @@ def render_settings() -> None:
|
||||
try:
|
||||
existing.weight = max(0.0, float(row.get("weight", existing.weight)))
|
||||
except (TypeError, ValueError):
|
||||
existing.weight = existing.weight
|
||||
pass
|
||||
|
||||
route_name = (row.get("llm_route") or "").strip() or None
|
||||
existing.llm_route = route_name
|
||||
if route_name and route_name in routes:
|
||||
existing.llm = routes[route_name].resolve(profiles)
|
||||
else:
|
||||
strategy_val = (row.get("strategy") or existing.llm.strategy).lower()
|
||||
if strategy_val in ALLOWED_LLM_STRATEGIES:
|
||||
existing.llm.strategy = strategy_val
|
||||
|
||||
provider_before = existing.llm.primary.provider or ""
|
||||
provider_val = (row.get("primary_provider") or provider_before or "ollama").lower()
|
||||
existing.llm.primary.provider = provider_val
|
||||
|
||||
model_val = (row.get("primary_model") or "").strip()
|
||||
if model_val:
|
||||
existing.llm.primary.model = model_val
|
||||
else:
|
||||
existing.llm.primary.model = DEFAULT_LLM_MODELS.get(provider_val, existing.llm.primary.model)
|
||||
|
||||
existing.llm.primary.model = (
|
||||
model_val or DEFAULT_LLM_MODELS.get(provider_val, existing.llm.primary.model)
|
||||
)
|
||||
if provider_before != provider_val:
|
||||
default_base = DEFAULT_LLM_BASE_URLS.get(provider_val)
|
||||
existing.llm.primary.base_url = default_base or existing.llm.primary.base_url
|
||||
|
||||
existing.llm.primary.__post_init__()
|
||||
updated_departments[code] = existing
|
||||
|
||||
if updated_departments:
|
||||
cfg.departments = updated_departments
|
||||
cfg.sync_runtime_llm()
|
||||
save_config()
|
||||
st.success("部门配置已更新。")
|
||||
else:
|
||||
@ -643,10 +695,12 @@ def render_settings() -> None:
|
||||
from app.utils.config import _default_departments
|
||||
|
||||
cfg.departments = _default_departments()
|
||||
cfg.sync_runtime_llm()
|
||||
save_config()
|
||||
st.success("已恢复默认部门配置。")
|
||||
st.experimental_rerun()
|
||||
|
||||
st.caption("选择路由可统一部门模型调用,自定义模式仍支持逐项配置。")
|
||||
st.caption("部门协作模型(ensemble)请在 config.json 中手动编辑,UI 将在后续版本补充。")
|
||||
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ from dataclasses import dataclass, field
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Mapping, Optional
|
||||
|
||||
|
||||
def _default_root() -> Path:
|
||||
@ -132,6 +132,135 @@ class LLMConfig:
|
||||
majority_threshold: int = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMProfile:
|
||||
"""Named LLM endpoint profile reusable across routes/departments."""
|
||||
|
||||
key: str
|
||||
provider: str = "ollama"
|
||||
model: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
temperature: float = 0.2
|
||||
timeout: float = 30.0
|
||||
title: str = ""
|
||||
enabled: bool = True
|
||||
|
||||
def to_endpoint(self) -> LLMEndpoint:
|
||||
return LLMEndpoint(
|
||||
provider=self.provider,
|
||||
model=self.model,
|
||||
base_url=self.base_url,
|
||||
api_key=self.api_key,
|
||||
temperature=self.temperature,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, object]:
|
||||
return {
|
||||
"provider": self.provider,
|
||||
"model": self.model,
|
||||
"base_url": self.base_url,
|
||||
"api_key": self.api_key,
|
||||
"temperature": self.temperature,
|
||||
"timeout": self.timeout,
|
||||
"title": self.title,
|
||||
"enabled": self.enabled,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_endpoint(
|
||||
cls,
|
||||
key: str,
|
||||
endpoint: LLMEndpoint,
|
||||
*,
|
||||
title: str = "",
|
||||
enabled: bool = True,
|
||||
) -> "LLMProfile":
|
||||
return cls(
|
||||
key=key,
|
||||
provider=endpoint.provider,
|
||||
model=endpoint.model,
|
||||
base_url=endpoint.base_url,
|
||||
api_key=endpoint.api_key,
|
||||
temperature=endpoint.temperature,
|
||||
timeout=endpoint.timeout,
|
||||
title=title,
|
||||
enabled=enabled,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMRoute:
|
||||
"""Declarative routing for selecting profiles and strategy."""
|
||||
|
||||
name: str
|
||||
title: str = ""
|
||||
strategy: str = "single"
|
||||
majority_threshold: int = 3
|
||||
primary: str = "ollama"
|
||||
ensemble: List[str] = field(default_factory=list)
|
||||
|
||||
def resolve(self, profiles: Mapping[str, LLMProfile]) -> LLMConfig:
|
||||
def _endpoint_from_key(key: str) -> LLMEndpoint:
|
||||
profile = profiles.get(key)
|
||||
if profile and profile.enabled:
|
||||
return profile.to_endpoint()
|
||||
fallback = profiles.get("ollama")
|
||||
if not fallback or not fallback.enabled:
|
||||
fallback = next(
|
||||
(item for item in profiles.values() if item.enabled),
|
||||
None,
|
||||
)
|
||||
endpoint = fallback.to_endpoint() if fallback else LLMEndpoint()
|
||||
endpoint.provider = key or endpoint.provider
|
||||
return endpoint
|
||||
|
||||
primary_endpoint = _endpoint_from_key(self.primary)
|
||||
ensemble_endpoints = [
|
||||
_endpoint_from_key(key)
|
||||
for key in self.ensemble
|
||||
if key in profiles and profiles[key].enabled
|
||||
]
|
||||
config = LLMConfig(
|
||||
primary=primary_endpoint,
|
||||
ensemble=ensemble_endpoints,
|
||||
strategy=self.strategy if self.strategy in ALLOWED_LLM_STRATEGIES else "single",
|
||||
majority_threshold=max(1, self.majority_threshold or 1),
|
||||
)
|
||||
return config
|
||||
|
||||
def to_dict(self) -> Dict[str, object]:
|
||||
return {
|
||||
"title": self.title,
|
||||
"strategy": self.strategy,
|
||||
"majority_threshold": self.majority_threshold,
|
||||
"primary": self.primary,
|
||||
"ensemble": list(self.ensemble),
|
||||
}
|
||||
|
||||
|
||||
def _default_llm_profiles() -> Dict[str, LLMProfile]:
|
||||
return {
|
||||
provider: LLMProfile(
|
||||
key=provider,
|
||||
provider=provider,
|
||||
model=DEFAULT_LLM_MODELS.get(provider),
|
||||
base_url=DEFAULT_LLM_BASE_URLS.get(provider),
|
||||
temperature=DEFAULT_LLM_TEMPERATURES.get(provider, 0.2),
|
||||
timeout=DEFAULT_LLM_TIMEOUTS.get(provider, 30.0),
|
||||
title=f"默认 {provider}",
|
||||
)
|
||||
for provider in DEFAULT_LLM_MODEL_OPTIONS
|
||||
}
|
||||
|
||||
|
||||
def _default_llm_routes() -> Dict[str, LLMRoute]:
|
||||
return {
|
||||
"global": LLMRoute(name="global", title="全局默认路由"),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class DepartmentSettings:
|
||||
"""Configuration for a single decision department."""
|
||||
@ -141,6 +270,7 @@ class DepartmentSettings:
|
||||
description: str = ""
|
||||
weight: float = 1.0
|
||||
llm: LLMConfig = field(default_factory=LLMConfig)
|
||||
llm_route: Optional[str] = None
|
||||
|
||||
|
||||
def _default_departments() -> Dict[str, DepartmentSettings]:
|
||||
@ -153,7 +283,7 @@ def _default_departments() -> Dict[str, DepartmentSettings]:
|
||||
("risk", "风险控制部门"),
|
||||
]
|
||||
return {
|
||||
code: DepartmentSettings(code=code, title=title)
|
||||
code: DepartmentSettings(code=code, title=title, llm_route="global")
|
||||
for code, title in presets
|
||||
}
|
||||
|
||||
@ -169,8 +299,21 @@ class AppConfig:
|
||||
agent_weights: AgentWeights = field(default_factory=AgentWeights)
|
||||
force_refresh: bool = False
|
||||
llm: LLMConfig = field(default_factory=LLMConfig)
|
||||
llm_route: str = "global"
|
||||
llm_profiles: Dict[str, LLMProfile] = field(default_factory=_default_llm_profiles)
|
||||
llm_routes: Dict[str, LLMRoute] = field(default_factory=_default_llm_routes)
|
||||
departments: Dict[str, DepartmentSettings] = field(default_factory=_default_departments)
|
||||
|
||||
def resolve_llm(self, route: Optional[str] = None) -> LLMConfig:
|
||||
route_key = route or self.llm_route
|
||||
route_cfg = self.llm_routes.get(route_key)
|
||||
if route_cfg:
|
||||
return route_cfg.resolve(self.llm_profiles)
|
||||
return self.llm
|
||||
|
||||
def sync_runtime_llm(self) -> None:
|
||||
self.llm = self.resolve_llm()
|
||||
|
||||
|
||||
CONFIG = AppConfig()
|
||||
|
||||
@ -213,11 +356,69 @@ def _load_from_file(cfg: AppConfig) -> None:
|
||||
if "decision_method" in payload:
|
||||
cfg.decision_method = str(payload.get("decision_method") or cfg.decision_method)
|
||||
|
||||
routes_defined = False
|
||||
inline_primary_loaded = False
|
||||
|
||||
profiles_payload = payload.get("llm_profiles")
|
||||
if isinstance(profiles_payload, dict):
|
||||
profiles: Dict[str, LLMProfile] = {}
|
||||
for key, data in profiles_payload.items():
|
||||
if not isinstance(data, dict):
|
||||
continue
|
||||
provider = str(data.get("provider") or "ollama").lower()
|
||||
profile = LLMProfile(
|
||||
key=key,
|
||||
provider=provider,
|
||||
model=data.get("model"),
|
||||
base_url=data.get("base_url"),
|
||||
api_key=data.get("api_key"),
|
||||
temperature=float(data.get("temperature", DEFAULT_LLM_TEMPERATURES.get(provider, 0.2))),
|
||||
timeout=float(data.get("timeout", DEFAULT_LLM_TIMEOUTS.get(provider, 30.0))),
|
||||
title=str(data.get("title") or ""),
|
||||
enabled=bool(data.get("enabled", True)),
|
||||
)
|
||||
profiles[key] = profile
|
||||
if profiles:
|
||||
cfg.llm_profiles = profiles
|
||||
|
||||
routes_payload = payload.get("llm_routes")
|
||||
if isinstance(routes_payload, dict):
|
||||
routes: Dict[str, LLMRoute] = {}
|
||||
for name, data in routes_payload.items():
|
||||
if not isinstance(data, dict):
|
||||
continue
|
||||
strategy_raw = str(data.get("strategy") or "single").lower()
|
||||
normalized = LLM_STRATEGY_ALIASES.get(strategy_raw, strategy_raw)
|
||||
route = LLMRoute(
|
||||
name=name,
|
||||
title=str(data.get("title") or ""),
|
||||
strategy=normalized if normalized in ALLOWED_LLM_STRATEGIES else "single",
|
||||
majority_threshold=max(1, int(data.get("majority_threshold", 3) or 3)),
|
||||
primary=str(data.get("primary") or "global"),
|
||||
ensemble=[
|
||||
str(item)
|
||||
for item in data.get("ensemble", [])
|
||||
if isinstance(item, str)
|
||||
],
|
||||
)
|
||||
routes[name] = route
|
||||
if routes:
|
||||
cfg.llm_routes = routes
|
||||
routes_defined = True
|
||||
|
||||
route_key = payload.get("llm_route")
|
||||
if isinstance(route_key, str) and route_key:
|
||||
cfg.llm_route = route_key
|
||||
|
||||
llm_payload = payload.get("llm")
|
||||
if isinstance(llm_payload, dict):
|
||||
route_value = llm_payload.get("route")
|
||||
if isinstance(route_value, str) and route_value:
|
||||
cfg.llm_route = route_value
|
||||
primary_data = llm_payload.get("primary")
|
||||
if isinstance(primary_data, dict):
|
||||
cfg.llm.primary = _dict_to_endpoint(primary_data)
|
||||
inline_primary_loaded = True
|
||||
|
||||
ensemble_data = llm_payload.get("ensemble")
|
||||
if isinstance(ensemble_data, list):
|
||||
@ -237,6 +438,30 @@ def _load_from_file(cfg: AppConfig) -> None:
|
||||
if isinstance(majority, int) and majority > 0:
|
||||
cfg.llm.majority_threshold = majority
|
||||
|
||||
if inline_primary_loaded and not routes_defined:
|
||||
primary_key = "inline_global_primary"
|
||||
cfg.llm_profiles[primary_key] = LLMProfile.from_endpoint(
|
||||
primary_key,
|
||||
cfg.llm.primary,
|
||||
title="全局主模型",
|
||||
)
|
||||
ensemble_keys: List[str] = []
|
||||
for idx, endpoint in enumerate(cfg.llm.ensemble, start=1):
|
||||
inline_key = f"inline_global_ensemble_{idx}"
|
||||
cfg.llm_profiles[inline_key] = LLMProfile.from_endpoint(
|
||||
inline_key,
|
||||
endpoint,
|
||||
title=f"全局协作#{idx}",
|
||||
)
|
||||
ensemble_keys.append(inline_key)
|
||||
auto_route = cfg.llm_routes.get("global") or LLMRoute(name="global", title="全局默认路由")
|
||||
auto_route.strategy = cfg.llm.strategy
|
||||
auto_route.majority_threshold = cfg.llm.majority_threshold
|
||||
auto_route.primary = primary_key
|
||||
auto_route.ensemble = ensemble_keys
|
||||
cfg.llm_routes["global"] = auto_route
|
||||
cfg.llm_route = cfg.llm_route or "global"
|
||||
|
||||
departments_payload = payload.get("departments")
|
||||
if isinstance(departments_payload, dict):
|
||||
new_departments: Dict[str, DepartmentSettings] = {}
|
||||
@ -264,35 +489,55 @@ def _load_from_file(cfg: AppConfig) -> None:
|
||||
majority_raw = llm_data.get("majority_threshold")
|
||||
if isinstance(majority_raw, int) and majority_raw > 0:
|
||||
llm_cfg.majority_threshold = majority_raw
|
||||
route = data.get("llm_route")
|
||||
route_name = str(route).strip() if isinstance(route, str) and route else None
|
||||
resolved = llm_cfg
|
||||
if route_name and route_name in cfg.llm_routes:
|
||||
resolved = cfg.llm_routes[route_name].resolve(cfg.llm_profiles)
|
||||
new_departments[code] = DepartmentSettings(
|
||||
code=code,
|
||||
title=title,
|
||||
description=description,
|
||||
weight=weight,
|
||||
llm=llm_cfg,
|
||||
llm=resolved,
|
||||
llm_route=route_name,
|
||||
)
|
||||
if new_departments:
|
||||
cfg.departments = new_departments
|
||||
|
||||
cfg.sync_runtime_llm()
|
||||
|
||||
|
||||
def save_config(cfg: AppConfig | None = None) -> None:
|
||||
cfg = cfg or CONFIG
|
||||
cfg.sync_runtime_llm()
|
||||
path = cfg.data_paths.config_file
|
||||
payload = {
|
||||
"tushare_token": cfg.tushare_token,
|
||||
"force_refresh": cfg.force_refresh,
|
||||
"decision_method": cfg.decision_method,
|
||||
"llm_route": cfg.llm_route,
|
||||
"llm": {
|
||||
"route": cfg.llm_route,
|
||||
"strategy": cfg.llm.strategy if cfg.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
|
||||
"majority_threshold": cfg.llm.majority_threshold,
|
||||
"primary": _endpoint_to_dict(cfg.llm.primary),
|
||||
"ensemble": [_endpoint_to_dict(ep) for ep in cfg.llm.ensemble],
|
||||
},
|
||||
"llm_profiles": {
|
||||
key: profile.to_dict()
|
||||
for key, profile in cfg.llm_profiles.items()
|
||||
},
|
||||
"llm_routes": {
|
||||
name: route.to_dict()
|
||||
for name, route in cfg.llm_routes.items()
|
||||
},
|
||||
"departments": {
|
||||
code: {
|
||||
"title": dept.title,
|
||||
"description": dept.description,
|
||||
"weight": dept.weight,
|
||||
"llm_route": dept.llm_route,
|
||||
"llm": {
|
||||
"strategy": dept.llm.strategy if dept.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
|
||||
"majority_threshold": dept.llm.majority_threshold,
|
||||
@ -320,7 +565,15 @@ def _load_env_defaults(cfg: AppConfig) -> None:
|
||||
|
||||
api_key = os.getenv("LLM_API_KEY")
|
||||
if api_key:
|
||||
cfg.llm.primary.api_key = api_key.strip()
|
||||
sanitized = api_key.strip()
|
||||
cfg.llm.primary.api_key = sanitized
|
||||
route = cfg.llm_routes.get(cfg.llm_route)
|
||||
if route:
|
||||
profile = cfg.llm_profiles.get(route.primary)
|
||||
if profile:
|
||||
profile.api_key = sanitized
|
||||
|
||||
cfg.sync_runtime_llm()
|
||||
|
||||
|
||||
_load_from_file(CONFIG)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user