This commit is contained in:
sam 2025-09-28 10:10:52 +08:00
parent 1a99b72c60
commit 6aece20816
6 changed files with 558 additions and 230 deletions

View File

@ -24,6 +24,7 @@
- **统一日志与持久化**SQLite 统一存储行情、回测与日志,配合 `DatabaseLogHandler` 在 UI/抓数流程中输出结构化运行轨迹,支持快速追踪与复盘。
- **跨市场数据扩展**`app/ingest/tushare.py` 追加指数、ETF/公募基金、期货、外汇、港股与美股的增量拉取逻辑,确保多资产因子与宏观代理所需的行情基础数据齐备。
- **部门化多模型协作**`app/agents/departments.py` 封装部门级 LLM 调度,`app/llm/client.py` 支持 single/majority/leader 策略,部门结论在 `app/agents/game.py` 与六类基础代理共同博弈,并持久化至 `agent_utils` 供 UI 展示。
- **LLM Profile/Route 管理**`app/utils/config.py` 引入可复用的 Profile终端定义与 Route推理策略组合Streamlit UI 支持可视化维护,全局与部门均可复用命名路由提升配置一致性。
## LLM + 多智能体最佳实践
@ -59,13 +60,11 @@ export TUSHARE_TOKEN="<your-token>"
### LLM 配置与测试
- 支持本地 Ollama 与多家 OpenAI 兼容云端供应商(如 DeepSeek、文心一言、OpenAI 等),可在 “数据与设置” 页签切换 Provider 并自动加载该 Provider 的候选模型、推荐 Base URL、默认温度与超时时间亦可切换为自定义值。所有修改会持久化到 `app/data/config.json`,下次启动自动加载。
- 修改 Provider/模型/Base URL/API Key 后点击 “保存 LLM 设置”,更新内容仅在当前会话生效。
- 在 “自检测试” 页新增 “LLM 接口测试”,可输入 Prompt 快速验证调用结果,日志会记录限频与错误信息便于排查。
- 未来可对同一功能的智能体并行调用多个 LLM采用多数投票等策略增强鲁棒性当前代码结构已为此预留扩展空间。
- 若使用环境变量自动注入配置,可设置:
- `TUSHARE_TOKEN`
- `LLM_API_KEY`
- 新增 Profile/Route 双层配置Profile 定义单个端点(含 Provider/模型/域名/API KeyRoute 组合 Profile 并指定推理策略single/majority/leader。全局路由可一键切换部门可复用命名路由或保留自定义设置。
- Streamlit “数据与设置” 页通过表单管理 Profile、Route、全局路由保存即写入 `app/data/config.json`Route 预览会同步展示经 `llm_config_snapshot()` 脱敏后的实时配置。
- 支持本地 Ollama 与多家 OpenAI 兼容供应商DeepSeek、文心一言、OpenAI 等),可为不同 Profile 设置默认模型、温度、超时与启用状态。
- UI 保留 TuShare Token 维护,以及路由/Profile 新增、删除、禁用等操作;所有更新即时生效并记入日志。
- 使用环境变量注入敏感信息时,可配置:`TUSHARE_TOKEN`、`LLM_API_KEY`,加载后会同步至当前路由的主 Profile。
## 快速开始
@ -105,7 +104,7 @@ Streamlit `自检测试` 页签提供:
## 实施步骤
1. **配置扩展** (`app/utils/config.py` + `config.json`) ✅
- 部门支持 primary/ensemble、策略single/majority/leader、权重并可在 Streamlit 中编辑主要字段
- 引入 `llm_profiles`/`llm_routes` 统一管理终端与策略部门可复用路由或使用自定义配置Streamlit 提供可视化维护表单
2. **部门管控器**
- `app/agents/departments.py` 提供 `DepartmentAgent`/`DepartmentManager`,封装 Prompt 构建、多模型协商及异常回退。

View File

@ -3,12 +3,12 @@ from __future__ import annotations
import json
from dataclasses import dataclass, field
from typing import Any, Dict, List, Mapping
from typing import Any, Callable, Dict, List, Mapping, Optional
from app.agents.base import AgentAction
from app.llm.client import run_llm_with_config
from app.llm.prompts import department_prompt
from app.utils.config import DepartmentSettings
from app.utils.config import AppConfig, DepartmentSettings, LLMConfig
from app.utils.logging import get_logger
LOGGER = get_logger(__name__)
@ -53,16 +53,27 @@ class DepartmentDecision:
class DepartmentAgent:
"""Wraps LLM ensemble logic for a single analytical department."""
def __init__(self, settings: DepartmentSettings) -> None:
def __init__(
self,
settings: DepartmentSettings,
resolver: Optional[Callable[[DepartmentSettings], LLMConfig]] = None,
) -> None:
self.settings = settings
self._resolver = resolver
def _get_llm_config(self) -> LLMConfig:
if self._resolver:
return self._resolver(self.settings)
return self.settings.llm
def analyze(self, context: DepartmentContext) -> DepartmentDecision:
prompt = department_prompt(self.settings, context)
system_prompt = (
"你是一个多智能体量化投研系统中的分部决策者,需要根据提供的结构化信息给出买卖意见。"
)
llm_cfg = self._get_llm_config()
try:
response = run_llm_with_config(self.settings.llm, prompt, system=system_prompt)
response = run_llm_with_config(llm_cfg, prompt, system=system_prompt)
except Exception as exc: # noqa: BLE001
LOGGER.exception("部门 %s 调用 LLM 失败:%s", self.settings.code, exc, extra=LOG_EXTRA)
return DepartmentDecision(
@ -106,10 +117,11 @@ class DepartmentAgent:
class DepartmentManager:
"""Orchestrates all departments defined in configuration."""
def __init__(self, departments: Mapping[str, DepartmentSettings]) -> None:
def __init__(self, config: AppConfig) -> None:
self.config = config
self.agents: Dict[str, DepartmentAgent] = {
code: DepartmentAgent(settings)
for code, settings in departments.items()
code: DepartmentAgent(settings, self._resolve_llm)
for code, settings in config.departments.items()
}
def evaluate(self, context: DepartmentContext) -> Dict[str, DepartmentDecision]:
@ -118,6 +130,11 @@ class DepartmentManager:
results[code] = agent.analyze(context)
return results
def _resolve_llm(self, settings: DepartmentSettings) -> LLMConfig:
if settings.llm_route and settings.llm_route in self.config.llm_routes:
return self.config.llm_routes[settings.llm_route].resolve(self.config.llm_profiles)
return settings.llm
def _parse_department_response(text: str) -> Dict[str, Any]:
"""Extract a JSON object from the LLM response if possible."""

View File

@ -55,7 +55,7 @@ class BacktestEngine:
else:
self.weights = {agent.name: 1.0 for agent in self.agents}
self.department_manager = (
DepartmentManager(app_cfg.departments) if app_cfg.departments else None
DepartmentManager(app_cfg) if app_cfg.departments else None
)
def load_market_data(self, trade_date: date) -> Mapping[str, Dict[str, float]]:

View File

@ -272,7 +272,8 @@ def run_llm_with_config(
def llm_config_snapshot() -> Dict[str, object]:
"""Return a sanitized snapshot of current LLM configuration for debugging."""
settings = get_config().llm
cfg = get_config()
settings = cfg.llm
primary = asdict(settings.primary)
if primary.get("api_key"):
primary["api_key"] = "***"
@ -282,7 +283,11 @@ def llm_config_snapshot() -> Dict[str, object]:
if record.get("api_key"):
record["api_key"] = "***"
ensemble.append(record)
route_name = cfg.llm_route
route_obj = cfg.llm_routes.get(route_name)
return {
"route": route_name,
"route_detail": route_obj.to_dict() if route_obj else None,
"strategy": settings.strategy,
"majority_threshold": settings.majority_threshold,
"primary": primary,

View File

@ -29,7 +29,8 @@ from app.utils.config import (
DEFAULT_LLM_MODEL_OPTIONS,
DEFAULT_LLM_MODELS,
DepartmentSettings,
LLMEndpoint,
LLMProfile,
LLMRoute,
get_config,
save_config,
)
@ -349,208 +350,253 @@ def render_settings() -> None:
st.divider()
st.subheader("LLM 设置")
llm_cfg = cfg.llm
primary = llm_cfg.primary
providers = sorted(DEFAULT_LLM_MODELS.keys())
profiles = cfg.llm_profiles or {}
routes = cfg.llm_routes or {}
profile_keys = sorted(profiles.keys())
route_keys = sorted(routes.keys())
used_routes = {
dept.llm_route for dept in cfg.departments.values() if dept.llm_route
}
st.caption("Profile 定义单个模型终端Route 负责组合 Profile 与推理策略。")
route_select_col, route_manage_col = st.columns([3, 1])
if route_keys:
try:
provider_index = providers.index((primary.provider or "ollama").lower())
active_index = route_keys.index(cfg.llm_route)
except ValueError:
provider_index = 0
selected_provider = st.selectbox("LLM Provider", providers, index=provider_index)
provider_info = DEFAULT_LLM_MODEL_OPTIONS.get(selected_provider, {})
model_options = provider_info.get("models", [])
custom_model_label = "自定义模型"
default_model_hint = DEFAULT_LLM_MODELS.get(selected_provider, DEFAULT_LLM_MODELS["ollama"])
if model_options:
options_with_custom = model_options + [custom_model_label]
if primary.provider == selected_provider and primary.model in model_options:
model_index = options_with_custom.index(primary.model)
active_index = 0
selected_route = route_select_col.selectbox(
"全局路由",
route_keys,
index=active_index,
key="llm_route_select",
)
else:
model_index = 0
selected_model_option = st.selectbox(
"LLM 模型",
options_with_custom,
index=model_index,
help=f"可选模型:{', '.join(model_options)}",
)
if selected_model_option == custom_model_label:
custom_model_value = st.text_input(
"自定义模型名称",
value="" if primary.provider != selected_provider or primary.model in model_options else primary.model,
)
chosen_model = custom_model_value.strip() or default_model_hint
else:
chosen_model = selected_model_option
else:
chosen_model = st.text_input(
"LLM 模型",
value=primary.model or default_model_hint,
help="未预设该 Provider 的模型列表,请手动填写",
).strip() or default_model_hint
default_base_hint = DEFAULT_LLM_BASE_URLS.get(selected_provider, "")
provider_default_temp = float(provider_info.get("temperature", 0.2))
provider_default_timeout = int(provider_info.get("timeout", 30.0))
selected_route = None
route_select_col.info("尚未配置路由,请先创建。")
if primary.provider == selected_provider:
base_value = primary.base_url or default_base_hint or ""
temp_value = float(primary.temperature)
timeout_value = int(primary.timeout)
new_route_name = route_manage_col.text_input("新增路由", key="new_route_name")
if route_manage_col.button("添加路由"):
key = (new_route_name or "").strip()
if not key:
st.warning("请输入有效的路由名称。")
elif key in routes:
st.warning(f"路由 {key} 已存在。")
else:
base_value = default_base_hint or ""
temp_value = provider_default_temp
timeout_value = provider_default_timeout
routes[key] = LLMRoute(name=key)
if not selected_route:
selected_route = key
cfg.llm_route = key
save_config()
st.success(f"已添加路由 {key},请继续配置。")
st.experimental_rerun()
llm_base = st.text_input(
"LLM Base URL",
value=base_value,
help=f"默认推荐:{default_base_hint or '按供应商要求填写'}",
)
llm_api_key = st.text_input(
"LLM API Key",
value=primary.api_key or "",
type="password",
help="点击右侧小图标可查看当前 Key该值会写入 config.json已被 gitignore 排除)",
)
llm_temperature = st.slider(
"LLM 温度",
min_value=0.0,
max_value=2.0,
value=temp_value,
step=0.05,
)
llm_timeout = st.number_input(
"请求超时时间 (秒)",
min_value=5,
max_value=120,
value=timeout_value,
step=5,
)
strategy_options = ["single", "majority", "leader"]
if selected_route:
route_obj = routes.get(selected_route)
if route_obj is None:
route_obj = LLMRoute(name=selected_route)
routes[selected_route] = route_obj
strategy_choices = sorted(ALLOWED_LLM_STRATEGIES)
try:
strategy_index = strategy_options.index(llm_cfg.strategy)
strategy_index = strategy_choices.index(route_obj.strategy)
except ValueError:
strategy_index = 0
selected_strategy = st.selectbox("LLM 推理策略", strategy_options, index=strategy_index)
majority_threshold = st.number_input(
route_title = st.text_input(
"路由说明",
value=route_obj.title or "",
key=f"route_title_{selected_route}",
)
route_strategy = st.selectbox(
"推理策略",
strategy_choices,
index=strategy_index,
key=f"route_strategy_{selected_route}",
)
route_majority = st.number_input(
"多数投票门槛",
min_value=1,
max_value=10,
value=int(llm_cfg.majority_threshold),
value=int(route_obj.majority_threshold or 1),
step=1,
format="%d",
key=f"route_majority_{selected_route}",
)
existing_api_keys = {ep.provider: ep.api_key or None for ep in llm_cfg.ensemble}
available_providers = sorted(DEFAULT_LLM_MODEL_OPTIONS.keys())
ensemble_rows = [
{
"provider": ep.provider or "",
"model": ep.model or DEFAULT_LLM_MODELS.get(ep.provider, DEFAULT_LLM_MODELS["ollama"]),
"base_url": ep.base_url or DEFAULT_LLM_BASE_URLS.get(ep.provider, ""),
"api_key": "***" if ep.api_key else "",
"temperature": float(ep.temperature),
"timeout": float(ep.timeout),
}
for ep in llm_cfg.ensemble
] or [
{
"provider": "",
"model": "",
"base_url": "",
"api_key": "",
"temperature": provider_default_temp,
"timeout": provider_default_timeout,
}
if not profile_keys:
st.warning("暂无可用 Profile请先在下方创建。")
else:
try:
primary_index = profile_keys.index(route_obj.primary)
except ValueError:
primary_index = 0
primary_key = st.selectbox(
"主用 Profile",
profile_keys,
index=primary_index,
key=f"route_primary_{selected_route}",
)
default_ensemble = [
key for key in route_obj.ensemble if key in profile_keys and key != primary_key
]
edited = st.data_editor(
ensemble_rows,
num_rows="dynamic",
key="llm_ensemble_editor",
column_config={
"provider": st.column_config.SelectboxColumn(
"Provider",
options=available_providers,
help="选择 LLM 供应商"
),
"model": st.column_config.TextColumn("模型", help="留空时使用该 Provider 的默认模型"),
"base_url": st.column_config.TextColumn("Base URL", help="留空时使用默认地址"),
"api_key": st.column_config.TextColumn("API Key", help="留空表示使用环境变量或不配置"),
"temperature": st.column_config.NumberColumn("温度", min_value=0.0, max_value=2.0, step=0.05),
"timeout": st.column_config.NumberColumn("超时(秒)", min_value=5.0, max_value=120.0, step=5.0),
},
hide_index=True,
use_container_width=True,
ensemble_keys = st.multiselect(
"协作 Profile (可多选)",
profile_keys,
default=default_ensemble,
key=f"route_ensemble_{selected_route}",
)
if hasattr(edited, "to_dict"):
ensemble_rows = edited.to_dict("records")
else:
ensemble_rows = edited
if st.button("保存 LLM 设置"):
primary.provider = selected_provider
primary.model = chosen_model
primary.base_url = llm_base.strip() or DEFAULT_LLM_BASE_URLS.get(selected_provider)
primary.temperature = llm_temperature
primary.timeout = llm_timeout
api_key_value = llm_api_key.strip()
if api_key_value:
primary.api_key = api_key_value
new_ensemble: List[LLMEndpoint] = []
for row in ensemble_rows:
provider = (row.get("provider") or "").strip().lower()
if not provider:
continue
provider_defaults = DEFAULT_LLM_MODEL_OPTIONS.get(provider, {})
default_model = DEFAULT_LLM_MODELS.get(provider, DEFAULT_LLM_MODELS["ollama"])
default_base = DEFAULT_LLM_BASE_URLS.get(provider)
temp_default = float(provider_defaults.get("temperature", 0.2))
timeout_default = float(provider_defaults.get("timeout", 30.0))
model_val = (row.get("model") or "").strip() or default_model
base_val = (row.get("base_url") or "").strip() or default_base
api_raw = (row.get("api_key") or "").strip()
if api_raw == "***":
api_value = existing_api_keys.get(provider)
else:
api_value = api_raw or None
temp_val = row.get("temperature")
timeout_val = row.get("timeout")
endpoint = LLMEndpoint(
provider=provider,
model=model_val,
base_url=base_val,
api_key=api_value,
temperature=float(temp_val) if temp_val is not None else temp_default,
timeout=float(timeout_val) if timeout_val is not None else timeout_default,
)
new_ensemble.append(endpoint)
llm_cfg.ensemble = new_ensemble
llm_cfg.strategy = selected_strategy
llm_cfg.majority_threshold = int(majority_threshold)
if st.button("保存路由设置", key=f"save_route_{selected_route}"):
route_obj.title = route_title.strip()
route_obj.strategy = route_strategy
route_obj.majority_threshold = int(route_majority)
route_obj.primary = primary_key
route_obj.ensemble = [key for key in ensemble_keys if key != primary_key]
cfg.llm_route = selected_route
cfg.sync_runtime_llm()
save_config()
LOGGER.info("LLM 配置已更新:%s", llm_config_snapshot(), extra=LOG_EXTRA)
st.success("LLM 设置已保存,仅在当前会话生效。")
st.json(llm_config_snapshot())
LOGGER.info(
"路由 %s 配置更新:%s",
selected_route,
route_obj.to_dict(),
extra=LOG_EXTRA,
)
st.success("路由配置已保存。")
st.json({
"route": selected_route,
"route_detail": route_obj.to_dict(),
"resolved": llm_config_snapshot(),
})
route_in_use = selected_route in used_routes or selected_route == cfg.llm_route
if st.button(
"删除当前路由",
key=f"delete_route_{selected_route}",
disabled=route_in_use or len(routes) <= 1,
):
routes.pop(selected_route, None)
if cfg.llm_route == selected_route:
cfg.llm_route = next((key for key in routes.keys()), "global")
cfg.sync_runtime_llm()
save_config()
st.success("路由已删除。")
st.experimental_rerun()
st.divider()
st.subheader("LLM Profile 管理")
profile_select_col, profile_manage_col = st.columns([3, 1])
if profile_keys:
selected_profile = profile_select_col.selectbox(
"选择 Profile",
profile_keys,
index=0,
key="profile_select",
)
else:
selected_profile = None
profile_select_col.info("尚未配置 Profile请先创建。")
new_profile_name = profile_manage_col.text_input("新增 Profile", key="new_profile_name")
if profile_manage_col.button("创建 Profile"):
key = (new_profile_name or "").strip()
if not key:
st.warning("请输入有效的 Profile 名称。")
elif key in profiles:
st.warning(f"Profile {key} 已存在。")
else:
profiles[key] = LLMProfile(key=key)
save_config()
st.success(f"已创建 Profile {key}")
st.experimental_rerun()
if selected_profile:
profile = profiles[selected_profile]
provider_choices = sorted(DEFAULT_LLM_MODEL_OPTIONS.keys())
try:
provider_index = provider_choices.index(profile.provider)
except ValueError:
provider_index = 0
with st.form(f"profile_form_{selected_profile}"):
provider_val = st.selectbox(
"Provider",
provider_choices,
index=provider_index,
)
model_default = DEFAULT_LLM_MODELS.get(provider_val, profile.model or "")
model_val = st.text_input(
"模型",
value=profile.model or model_default,
)
base_default = DEFAULT_LLM_BASE_URLS.get(provider_val, profile.base_url or "")
base_val = st.text_input(
"Base URL",
value=profile.base_url or base_default,
)
api_val = st.text_input(
"API Key",
value=profile.api_key or "",
type="password",
)
temp_val = st.slider(
"温度",
min_value=0.0,
max_value=2.0,
value=float(profile.temperature),
step=0.05,
)
timeout_val = st.number_input(
"超时(秒)",
min_value=5,
max_value=180,
value=int(profile.timeout or 30),
step=5,
)
title_val = st.text_input("备注", value=profile.title or "")
enabled_val = st.checkbox("启用", value=profile.enabled)
submitted = st.form_submit_button("保存 Profile")
if submitted:
profile.provider = provider_val
profile.model = model_val.strip() or DEFAULT_LLM_MODELS.get(provider_val)
profile.base_url = base_val.strip() or DEFAULT_LLM_BASE_URLS.get(provider_val)
profile.api_key = api_val.strip() or None
profile.temperature = temp_val
profile.timeout = timeout_val
profile.title = title_val.strip()
profile.enabled = enabled_val
profiles[selected_profile] = profile
cfg.sync_runtime_llm()
save_config()
st.success("Profile 已保存。")
profile_in_use = any(
selected_profile == route.primary or selected_profile in route.ensemble
for route in routes.values()
)
if st.button(
"删除该 Profile",
key=f"delete_profile_{selected_profile}",
disabled=profile_in_use or len(profiles) <= 1,
):
profiles.pop(selected_profile, None)
fallback_key = next((key for key in profiles.keys()), None)
for route in routes.values():
if route.primary == selected_profile:
route.primary = fallback_key or route.primary
route.ensemble = [key for key in route.ensemble if key != selected_profile]
cfg.sync_runtime_llm()
save_config()
st.success("Profile 已删除。")
st.experimental_rerun()
st.divider()
st.subheader("部门配置")
dept_settings = cfg.departments or {}
route_options_display = [""] + route_keys
dept_rows = [
{
"code": code,
"title": dept.title,
"description": dept.description,
"weight": float(dept.weight),
"llm_route": dept.llm_route or "",
"strategy": dept.llm.strategy,
"primary_provider": (dept.llm.primary.provider or "ollama"),
"primary_provider": (dept.llm.primary.provider or ""),
"primary_model": dept.llm.primary.model or "",
"ensemble_size": len(dept.llm.ensemble),
}
@ -572,20 +618,25 @@ def render_settings() -> None:
"title": st.column_config.TextColumn("名称"),
"description": st.column_config.TextColumn("说明"),
"weight": st.column_config.NumberColumn("权重", min_value=0.0, max_value=10.0, step=0.1),
"llm_route": st.column_config.SelectboxColumn(
"路由",
options=route_options_display,
help="选择预定义路由;留空表示使用自定义配置",
),
"strategy": st.column_config.SelectboxColumn(
"策略",
"自定义策略",
options=sorted(ALLOWED_LLM_STRATEGIES),
help="single=单模型, majority=多数投票, leader=顾问-决策者模式",
help="仅当未选择路由时生效",
),
"primary_provider": st.column_config.SelectboxColumn(
"主模型 Provider",
"自定义 Provider",
options=sorted(DEFAULT_LLM_MODEL_OPTIONS.keys()),
),
"primary_model": st.column_config.TextColumn("主模型名称"),
"primary_model": st.column_config.TextColumn("自定义模型"),
"ensemble_size": st.column_config.NumberColumn(
"协作模型数量",
disabled=True,
help="在 config.json 中编辑 ensemble 详情",
help="路由模式下自动维护",
),
},
)
@ -609,31 +660,32 @@ def render_settings() -> None:
try:
existing.weight = max(0.0, float(row.get("weight", existing.weight)))
except (TypeError, ValueError):
existing.weight = existing.weight
pass
route_name = (row.get("llm_route") or "").strip() or None
existing.llm_route = route_name
if route_name and route_name in routes:
existing.llm = routes[route_name].resolve(profiles)
else:
strategy_val = (row.get("strategy") or existing.llm.strategy).lower()
if strategy_val in ALLOWED_LLM_STRATEGIES:
existing.llm.strategy = strategy_val
provider_before = existing.llm.primary.provider or ""
provider_val = (row.get("primary_provider") or provider_before or "ollama").lower()
existing.llm.primary.provider = provider_val
model_val = (row.get("primary_model") or "").strip()
if model_val:
existing.llm.primary.model = model_val
else:
existing.llm.primary.model = DEFAULT_LLM_MODELS.get(provider_val, existing.llm.primary.model)
existing.llm.primary.model = (
model_val or DEFAULT_LLM_MODELS.get(provider_val, existing.llm.primary.model)
)
if provider_before != provider_val:
default_base = DEFAULT_LLM_BASE_URLS.get(provider_val)
existing.llm.primary.base_url = default_base or existing.llm.primary.base_url
existing.llm.primary.__post_init__()
updated_departments[code] = existing
if updated_departments:
cfg.departments = updated_departments
cfg.sync_runtime_llm()
save_config()
st.success("部门配置已更新。")
else:
@ -643,10 +695,12 @@ def render_settings() -> None:
from app.utils.config import _default_departments
cfg.departments = _default_departments()
cfg.sync_runtime_llm()
save_config()
st.success("已恢复默认部门配置。")
st.experimental_rerun()
st.caption("选择路由可统一部门模型调用,自定义模式仍支持逐项配置。")
st.caption("部门协作模型ensemble请在 config.json 中手动编辑UI 将在后续版本补充。")

View File

@ -5,7 +5,7 @@ from dataclasses import dataclass, field
import json
import os
from pathlib import Path
from typing import Dict, List, Optional
from typing import Dict, List, Mapping, Optional
def _default_root() -> Path:
@ -132,6 +132,135 @@ class LLMConfig:
majority_threshold: int = 3
@dataclass
class LLMProfile:
"""Named LLM endpoint profile reusable across routes/departments."""
key: str
provider: str = "ollama"
model: Optional[str] = None
base_url: Optional[str] = None
api_key: Optional[str] = None
temperature: float = 0.2
timeout: float = 30.0
title: str = ""
enabled: bool = True
def to_endpoint(self) -> LLMEndpoint:
return LLMEndpoint(
provider=self.provider,
model=self.model,
base_url=self.base_url,
api_key=self.api_key,
temperature=self.temperature,
timeout=self.timeout,
)
def to_dict(self) -> Dict[str, object]:
return {
"provider": self.provider,
"model": self.model,
"base_url": self.base_url,
"api_key": self.api_key,
"temperature": self.temperature,
"timeout": self.timeout,
"title": self.title,
"enabled": self.enabled,
}
@classmethod
def from_endpoint(
cls,
key: str,
endpoint: LLMEndpoint,
*,
title: str = "",
enabled: bool = True,
) -> "LLMProfile":
return cls(
key=key,
provider=endpoint.provider,
model=endpoint.model,
base_url=endpoint.base_url,
api_key=endpoint.api_key,
temperature=endpoint.temperature,
timeout=endpoint.timeout,
title=title,
enabled=enabled,
)
@dataclass
class LLMRoute:
"""Declarative routing for selecting profiles and strategy."""
name: str
title: str = ""
strategy: str = "single"
majority_threshold: int = 3
primary: str = "ollama"
ensemble: List[str] = field(default_factory=list)
def resolve(self, profiles: Mapping[str, LLMProfile]) -> LLMConfig:
def _endpoint_from_key(key: str) -> LLMEndpoint:
profile = profiles.get(key)
if profile and profile.enabled:
return profile.to_endpoint()
fallback = profiles.get("ollama")
if not fallback or not fallback.enabled:
fallback = next(
(item for item in profiles.values() if item.enabled),
None,
)
endpoint = fallback.to_endpoint() if fallback else LLMEndpoint()
endpoint.provider = key or endpoint.provider
return endpoint
primary_endpoint = _endpoint_from_key(self.primary)
ensemble_endpoints = [
_endpoint_from_key(key)
for key in self.ensemble
if key in profiles and profiles[key].enabled
]
config = LLMConfig(
primary=primary_endpoint,
ensemble=ensemble_endpoints,
strategy=self.strategy if self.strategy in ALLOWED_LLM_STRATEGIES else "single",
majority_threshold=max(1, self.majority_threshold or 1),
)
return config
def to_dict(self) -> Dict[str, object]:
return {
"title": self.title,
"strategy": self.strategy,
"majority_threshold": self.majority_threshold,
"primary": self.primary,
"ensemble": list(self.ensemble),
}
def _default_llm_profiles() -> Dict[str, LLMProfile]:
return {
provider: LLMProfile(
key=provider,
provider=provider,
model=DEFAULT_LLM_MODELS.get(provider),
base_url=DEFAULT_LLM_BASE_URLS.get(provider),
temperature=DEFAULT_LLM_TEMPERATURES.get(provider, 0.2),
timeout=DEFAULT_LLM_TIMEOUTS.get(provider, 30.0),
title=f"默认 {provider}",
)
for provider in DEFAULT_LLM_MODEL_OPTIONS
}
def _default_llm_routes() -> Dict[str, LLMRoute]:
return {
"global": LLMRoute(name="global", title="全局默认路由"),
}
@dataclass
class DepartmentSettings:
"""Configuration for a single decision department."""
@ -141,6 +270,7 @@ class DepartmentSettings:
description: str = ""
weight: float = 1.0
llm: LLMConfig = field(default_factory=LLMConfig)
llm_route: Optional[str] = None
def _default_departments() -> Dict[str, DepartmentSettings]:
@ -153,7 +283,7 @@ def _default_departments() -> Dict[str, DepartmentSettings]:
("risk", "风险控制部门"),
]
return {
code: DepartmentSettings(code=code, title=title)
code: DepartmentSettings(code=code, title=title, llm_route="global")
for code, title in presets
}
@ -169,8 +299,21 @@ class AppConfig:
agent_weights: AgentWeights = field(default_factory=AgentWeights)
force_refresh: bool = False
llm: LLMConfig = field(default_factory=LLMConfig)
llm_route: str = "global"
llm_profiles: Dict[str, LLMProfile] = field(default_factory=_default_llm_profiles)
llm_routes: Dict[str, LLMRoute] = field(default_factory=_default_llm_routes)
departments: Dict[str, DepartmentSettings] = field(default_factory=_default_departments)
def resolve_llm(self, route: Optional[str] = None) -> LLMConfig:
route_key = route or self.llm_route
route_cfg = self.llm_routes.get(route_key)
if route_cfg:
return route_cfg.resolve(self.llm_profiles)
return self.llm
def sync_runtime_llm(self) -> None:
self.llm = self.resolve_llm()
CONFIG = AppConfig()
@ -213,11 +356,69 @@ def _load_from_file(cfg: AppConfig) -> None:
if "decision_method" in payload:
cfg.decision_method = str(payload.get("decision_method") or cfg.decision_method)
routes_defined = False
inline_primary_loaded = False
profiles_payload = payload.get("llm_profiles")
if isinstance(profiles_payload, dict):
profiles: Dict[str, LLMProfile] = {}
for key, data in profiles_payload.items():
if not isinstance(data, dict):
continue
provider = str(data.get("provider") or "ollama").lower()
profile = LLMProfile(
key=key,
provider=provider,
model=data.get("model"),
base_url=data.get("base_url"),
api_key=data.get("api_key"),
temperature=float(data.get("temperature", DEFAULT_LLM_TEMPERATURES.get(provider, 0.2))),
timeout=float(data.get("timeout", DEFAULT_LLM_TIMEOUTS.get(provider, 30.0))),
title=str(data.get("title") or ""),
enabled=bool(data.get("enabled", True)),
)
profiles[key] = profile
if profiles:
cfg.llm_profiles = profiles
routes_payload = payload.get("llm_routes")
if isinstance(routes_payload, dict):
routes: Dict[str, LLMRoute] = {}
for name, data in routes_payload.items():
if not isinstance(data, dict):
continue
strategy_raw = str(data.get("strategy") or "single").lower()
normalized = LLM_STRATEGY_ALIASES.get(strategy_raw, strategy_raw)
route = LLMRoute(
name=name,
title=str(data.get("title") or ""),
strategy=normalized if normalized in ALLOWED_LLM_STRATEGIES else "single",
majority_threshold=max(1, int(data.get("majority_threshold", 3) or 3)),
primary=str(data.get("primary") or "global"),
ensemble=[
str(item)
for item in data.get("ensemble", [])
if isinstance(item, str)
],
)
routes[name] = route
if routes:
cfg.llm_routes = routes
routes_defined = True
route_key = payload.get("llm_route")
if isinstance(route_key, str) and route_key:
cfg.llm_route = route_key
llm_payload = payload.get("llm")
if isinstance(llm_payload, dict):
route_value = llm_payload.get("route")
if isinstance(route_value, str) and route_value:
cfg.llm_route = route_value
primary_data = llm_payload.get("primary")
if isinstance(primary_data, dict):
cfg.llm.primary = _dict_to_endpoint(primary_data)
inline_primary_loaded = True
ensemble_data = llm_payload.get("ensemble")
if isinstance(ensemble_data, list):
@ -237,6 +438,30 @@ def _load_from_file(cfg: AppConfig) -> None:
if isinstance(majority, int) and majority > 0:
cfg.llm.majority_threshold = majority
if inline_primary_loaded and not routes_defined:
primary_key = "inline_global_primary"
cfg.llm_profiles[primary_key] = LLMProfile.from_endpoint(
primary_key,
cfg.llm.primary,
title="全局主模型",
)
ensemble_keys: List[str] = []
for idx, endpoint in enumerate(cfg.llm.ensemble, start=1):
inline_key = f"inline_global_ensemble_{idx}"
cfg.llm_profiles[inline_key] = LLMProfile.from_endpoint(
inline_key,
endpoint,
title=f"全局协作#{idx}",
)
ensemble_keys.append(inline_key)
auto_route = cfg.llm_routes.get("global") or LLMRoute(name="global", title="全局默认路由")
auto_route.strategy = cfg.llm.strategy
auto_route.majority_threshold = cfg.llm.majority_threshold
auto_route.primary = primary_key
auto_route.ensemble = ensemble_keys
cfg.llm_routes["global"] = auto_route
cfg.llm_route = cfg.llm_route or "global"
departments_payload = payload.get("departments")
if isinstance(departments_payload, dict):
new_departments: Dict[str, DepartmentSettings] = {}
@ -264,35 +489,55 @@ def _load_from_file(cfg: AppConfig) -> None:
majority_raw = llm_data.get("majority_threshold")
if isinstance(majority_raw, int) and majority_raw > 0:
llm_cfg.majority_threshold = majority_raw
route = data.get("llm_route")
route_name = str(route).strip() if isinstance(route, str) and route else None
resolved = llm_cfg
if route_name and route_name in cfg.llm_routes:
resolved = cfg.llm_routes[route_name].resolve(cfg.llm_profiles)
new_departments[code] = DepartmentSettings(
code=code,
title=title,
description=description,
weight=weight,
llm=llm_cfg,
llm=resolved,
llm_route=route_name,
)
if new_departments:
cfg.departments = new_departments
cfg.sync_runtime_llm()
def save_config(cfg: AppConfig | None = None) -> None:
cfg = cfg or CONFIG
cfg.sync_runtime_llm()
path = cfg.data_paths.config_file
payload = {
"tushare_token": cfg.tushare_token,
"force_refresh": cfg.force_refresh,
"decision_method": cfg.decision_method,
"llm_route": cfg.llm_route,
"llm": {
"route": cfg.llm_route,
"strategy": cfg.llm.strategy if cfg.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
"majority_threshold": cfg.llm.majority_threshold,
"primary": _endpoint_to_dict(cfg.llm.primary),
"ensemble": [_endpoint_to_dict(ep) for ep in cfg.llm.ensemble],
},
"llm_profiles": {
key: profile.to_dict()
for key, profile in cfg.llm_profiles.items()
},
"llm_routes": {
name: route.to_dict()
for name, route in cfg.llm_routes.items()
},
"departments": {
code: {
"title": dept.title,
"description": dept.description,
"weight": dept.weight,
"llm_route": dept.llm_route,
"llm": {
"strategy": dept.llm.strategy if dept.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
"majority_threshold": dept.llm.majority_threshold,
@ -320,7 +565,15 @@ def _load_env_defaults(cfg: AppConfig) -> None:
api_key = os.getenv("LLM_API_KEY")
if api_key:
cfg.llm.primary.api_key = api_key.strip()
sanitized = api_key.strip()
cfg.llm.primary.api_key = sanitized
route = cfg.llm_routes.get(cfg.llm_route)
if route:
profile = cfg.llm_profiles.get(route.primary)
if profile:
profile.api_key = sanitized
cfg.sync_runtime_llm()
_load_from_file(CONFIG)