diff --git a/README.md b/README.md index bf809ff..ac949e1 100644 --- a/README.md +++ b/README.md @@ -58,9 +58,10 @@ export TUSHARE_TOKEN="" ### LLM 配置与测试 -- 默认使用本地 Ollama(`http://localhost:11434`),可在 Streamlit 的 “数据与设置” 页签切换到 OpenAI 兼容接口。 +- 支持本地 Ollama(`http://localhost:11434`)与多家 OpenAI 兼容云端供应商(如 DeepSeek、文心一言、OpenAI 等),可在 Streamlit 的 “数据与设置” 页签切换 Provider 并配置模型、Base URL、API Key。不同 Provider 默认映射的模型示例:Ollama → `llama3`,OpenAI → `gpt-4o-mini`,DeepSeek → `deepseek-chat`,文心一言 → `ERNIE-Speed`。 - 修改 Provider/模型/Base URL/API Key 后点击 “保存 LLM 设置”,更新内容仅在当前会话生效。 - 在 “自检测试” 页新增 “LLM 接口测试”,可输入 Prompt 快速验证调用结果,日志会记录限频与错误信息便于排查。 +- 未来可对同一功能的智能体并行调用多个 LLM,采用多数投票等策略增强鲁棒性,当前代码结构已为此预留扩展空间。 ## 快速开始 diff --git a/app/llm/client.py b/app/llm/client.py index 0114fdb..fef7277 100644 --- a/app/llm/client.py +++ b/app/llm/client.py @@ -2,18 +2,18 @@ from __future__ import annotations import json +from collections import Counter from dataclasses import asdict from typing import Dict, Iterable, List, Optional import requests -from app.utils.config import get_config +from app.utils.config import DEFAULT_LLM_MODELS, LLMEndpoint, get_config from app.utils.logging import get_logger LOGGER = get_logger(__name__) LOG_EXTRA = {"stage": "llm"} - class LLMError(RuntimeError): """Raised when LLM provider returns an error response.""" @@ -21,9 +21,18 @@ class LLMError(RuntimeError): def _default_base_url(provider: str) -> str: if provider == "ollama": return "http://localhost:11434" + if provider == "deepseek": + return "https://api.deepseek.com" + if provider == "wenxin": + return "https://aip.baidubce.com" return "https://api.openai.com" +def _default_model(provider: str) -> str: + provider = (provider or "").lower() + return DEFAULT_LLM_MODELS.get(provider, DEFAULT_LLM_MODELS["ollama"]) + + def _build_messages(prompt: str, system: Optional[str] = None) -> List[Dict[str, str]]: messages: List[Dict[str, str]] = [] if system: @@ -32,7 +41,15 @@ def _build_messages(prompt: str, system: Optional[str] = None) -> List[Dict[str, return messages -def _request_ollama(model: str, prompt: str, *, base_url: str, temperature: float, timeout: float, system: Optional[str]) -> str: +def _request_ollama( + model: str, + prompt: str, + *, + base_url: str, + temperature: float, + timeout: float, + system: Optional[str], +) -> str: url = f"{base_url.rstrip('/')}/api/chat" payload = { "model": model, @@ -52,7 +69,16 @@ def _request_ollama(model: str, prompt: str, *, base_url: str, temperature: floa return str(content) -def _request_openai(model: str, prompt: str, *, base_url: str, api_key: str, temperature: float, timeout: float, system: Optional[str]) -> str: +def _request_openai( + model: str, + prompt: str, + *, + base_url: str, + api_key: str, + temperature: float, + timeout: float, + system: Optional[str], +) -> str: url = f"{base_url.rstrip('/')}/v1/chat/completions" headers = { "Authorization": f"Bearer {api_key}", @@ -74,28 +100,30 @@ def _request_openai(model: str, prompt: str, *, base_url: str, api_key: str, tem raise LLMError(f"OpenAI 响应解析失败: {json.dumps(data, ensure_ascii=False)}") from exc -def run_llm(prompt: str, *, system: Optional[str] = None) -> str: - """Execute the configured LLM provider with the given prompt.""" - - cfg = get_config().llm - provider = (cfg.provider or "ollama").lower() - base_url = cfg.base_url or _default_base_url(provider) - model = cfg.model - temperature = max(0.0, min(cfg.temperature, 2.0)) - timeout = max(5.0, cfg.timeout or 30.0) +def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: 
Optional[str]) -> str: + provider = (endpoint.provider or "ollama").lower() + base_url = endpoint.base_url or _default_base_url(provider) + model = endpoint.model or _default_model(provider) + temperature = max(0.0, min(endpoint.temperature, 2.0)) + timeout = max(5.0, endpoint.timeout or 30.0) LOGGER.info( - "触发 LLM 请求:provider=%s model=%s base=%s", provider, model, base_url, extra=LOG_EXTRA + "触发 LLM 请求:provider=%s model=%s base=%s", + provider, + model, + base_url, + extra=LOG_EXTRA, ) - if provider == "openai": - if not cfg.api_key: - raise LLMError("缺少 OpenAI 兼容 API Key") + if provider in {"openai", "deepseek", "wenxin"}: + api_key = endpoint.api_key + if not api_key: + raise LLMError(f"缺少 {provider} API Key (model={model})") return _request_openai( model, prompt, base_url=base_url, - api_key=cfg.api_key, + api_key=api_key, temperature=temperature, timeout=timeout, system=system, @@ -109,14 +137,89 @@ def run_llm(prompt: str, *, system: Optional[str] = None) -> str: timeout=timeout, system=system, ) - raise LLMError(f"不支持的 LLM provider: {cfg.provider}") + raise LLMError(f"不支持的 LLM provider: {endpoint.provider}") + + +def _normalize_response(text: str) -> str: + return " ".join(text.strip().split()) + + +def run_llm(prompt: str, *, system: Optional[str] = None) -> str: + """Execute the configured LLM strategy with the given prompt.""" + + settings = get_config().llm + if settings.strategy == "majority": + return _run_majority_vote(settings, prompt, system) + return _call_endpoint(settings.primary, prompt, system) + + +def _run_majority_vote(config, prompt: str, system: Optional[str]) -> str: + endpoints: List[LLMEndpoint] = [config.primary] + list(config.ensemble) + responses: List[Dict[str, str]] = [] + failures: List[str] = [] + + for idx, endpoint in enumerate(endpoints, start=1): + try: + result = _call_endpoint(endpoint, prompt, system) + responses.append({ + "provider": endpoint.provider, + "model": endpoint.model, + "raw": result, + "normalized": _normalize_response(result), + }) + except Exception as exc: # noqa: BLE001 + summary = f"{endpoint.provider}:{endpoint.model} -> {exc}" + failures.append(summary) + LOGGER.warning("LLM 调用失败:%s", summary, extra=LOG_EXTRA) + + if not responses: + raise LLMError("所有 LLM 调用均失败,无法返回结果。") + + threshold = max(1, config.majority_threshold) + threshold = min(threshold, len(responses)) + + counter = Counter(item["normalized"] for item in responses) + top_value, top_count = counter.most_common(1)[0] + if top_count >= threshold: + chosen_raw = next(item["raw"] for item in responses if item["normalized"] == top_value) + LOGGER.info( + "LLM 多模型投票通过:value=%s votes=%s/%s threshold=%s", + top_value[:80], + top_count, + len(responses), + threshold, + extra=LOG_EXTRA, + ) + return chosen_raw + + LOGGER.info( + "LLM 多模型投票未达门槛:votes=%s/%s threshold=%s,返回首个结果", + top_count, + len(responses), + threshold, + extra=LOG_EXTRA, + ) + if failures: + LOGGER.warning("LLM 调用失败列表:%s", failures, extra=LOG_EXTRA) + return responses[0]["raw"] def llm_config_snapshot() -> Dict[str, object]: """Return a sanitized snapshot of current LLM configuration for debugging.""" - cfg = get_config().llm - data = asdict(cfg) - if data.get("api_key"): - data["api_key"] = "***" - return data + settings = get_config().llm + primary = asdict(settings.primary) + if primary.get("api_key"): + primary["api_key"] = "***" + ensemble = [] + for endpoint in settings.ensemble: + record = asdict(endpoint) + if record.get("api_key"): + record["api_key"] = "***" + ensemble.append(record) + return { + 
"strategy": settings.strategy, + "majority_threshold": settings.majority_threshold, + "primary": primary, + "ensemble": ensemble, + } diff --git a/app/ui/streamlit_app.py b/app/ui/streamlit_app.py index ed2451e..a4edec4 100644 --- a/app/ui/streamlit_app.py +++ b/app/ui/streamlit_app.py @@ -1,9 +1,12 @@ """Streamlit UI scaffold for the investment assistant.""" from __future__ import annotations +import json import sys +from dataclasses import asdict from datetime import date, timedelta from pathlib import Path +from typing import List ROOT = Path(__file__).resolve().parents[2] if str(ROOT) not in sys.path: @@ -20,7 +23,7 @@ from app.ingest.checker import run_boot_check from app.ingest.tushare import FetchJob, run_ingestion from app.llm.client import llm_config_snapshot, run_llm from app.llm.explain import make_human_card -from app.utils.config import get_config +from app.utils.config import DEFAULT_LLM_MODELS, LLMEndpoint, get_config from app.utils.db import db_session from app.utils.logging import get_logger @@ -194,28 +197,98 @@ def render_settings() -> None: st.divider() st.subheader("LLM 设置") llm_cfg = cfg.llm + primary = llm_cfg.primary providers = ["ollama", "openai"] try: - provider_index = providers.index((llm_cfg.provider or "ollama").lower()) + provider_index = providers.index((primary.provider or "ollama").lower()) except ValueError: provider_index = 0 selected_provider = st.selectbox("LLM Provider", providers, index=provider_index) - llm_model = st.text_input("LLM 模型", value=llm_cfg.model) - llm_base = st.text_input("LLM Base URL (可选)", value=llm_cfg.base_url or "") - llm_api_key = st.text_input("LLM API Key (OpenAI 类需要)", value=llm_cfg.api_key or "", type="password") - llm_temperature = st.slider("LLM 温度", min_value=0.0, max_value=2.0, value=float(llm_cfg.temperature), step=0.05) - llm_timeout = st.number_input("请求超时时间 (秒)", min_value=5.0, max_value=120.0, value=float(llm_cfg.timeout), step=5.0) + default_model_hint = DEFAULT_LLM_MODELS.get(selected_provider, DEFAULT_LLM_MODELS["ollama"]) + llm_model = st.text_input("LLM 模型", value=primary.model, help=f"默认推荐:{default_model_hint}") + base_hints = { + "ollama": "http://localhost:11434", + "openai": "https://api.openai.com", + "deepseek": "https://api.deepseek.com", + "wenxin": "https://aip.baidubce.com", + } + default_base_hint = base_hints.get(selected_provider, "") + llm_base = st.text_input("LLM Base URL (可选)", value=primary.base_url or "", help=f"默认推荐:{default_base_hint or '按供应商要求填写'}") + llm_api_key = st.text_input("LLM API Key (OpenAI 类需要)", value=primary.api_key or "", type="password") + llm_temperature = st.slider("LLM 温度", min_value=0.0, max_value=2.0, value=float(primary.temperature), step=0.05) + llm_timeout = st.number_input("请求超时时间 (秒)", min_value=5.0, max_value=120.0, value=float(primary.timeout), step=5.0, format="%d") + + strategy_options = ["single", "majority"] + try: + strategy_index = strategy_options.index(llm_cfg.strategy) + except ValueError: + strategy_index = 0 + selected_strategy = st.selectbox("LLM 推理策略", strategy_options, index=strategy_index) + majority_threshold = st.number_input( + "多数投票门槛", + min_value=1, + max_value=10, + value=int(llm_cfg.majority_threshold), + step=1, + format="%d", + ) + + ensemble_display = [] + for endpoint in llm_cfg.ensemble: + data = asdict(endpoint) + if data.get("api_key"): + data["api_key"] = "" + ensemble_display.append(data) + ensemble_text = st.text_area( + "LLM 集群配置 (JSON 数组)", + value=json.dumps(ensemble_display or [], ensure_ascii=False, indent=2), + height=220, + 
) if st.button("保存 LLM 设置"): - llm_cfg.provider = selected_provider - llm_cfg.model = llm_model.strip() or llm_cfg.model - llm_cfg.base_url = llm_base.strip() or None - llm_cfg.api_key = llm_api_key.strip() or None - llm_cfg.temperature = llm_temperature - llm_cfg.timeout = llm_timeout - LOGGER.info("LLM 配置已更新:%s", llm_config_snapshot(), extra=LOG_EXTRA) - st.success("LLM 设置已保存,仅在当前会话生效。") - st.json(llm_config_snapshot()) + original_provider = primary.provider + original_model = primary.model + primary.provider = selected_provider + model_input = llm_model.strip() + if not model_input: + primary.model = DEFAULT_LLM_MODELS.get(selected_provider, DEFAULT_LLM_MODELS["ollama"]) + elif selected_provider != original_provider and model_input == original_model: + primary.model = DEFAULT_LLM_MODELS.get(selected_provider, DEFAULT_LLM_MODELS["ollama"]) + else: + primary.model = model_input + primary.base_url = llm_base.strip() or None + primary.temperature = llm_temperature + primary.timeout = llm_timeout + api_key_value = llm_api_key.strip() + primary.api_key = api_key_value or None + + try: + parsed = json.loads(ensemble_text or "[]") + if not isinstance(parsed, list): + raise ValueError("ensemble 配置必须是数组") + except Exception as exc: # noqa: BLE001 + LOGGER.exception("解析 LLM 集群配置失败", extra=LOG_EXTRA) + st.error(f"LLM 集群配置解析失败:{exc}") + else: + new_ensemble: List[LLMEndpoint] = [] + invalid = False + for item in parsed: + if not isinstance(item, dict): + st.error("LLM 集群配置中的每个元素都必须是对象") + invalid = True + break + fields = {key: item.get(key) for key in ("provider", "model", "base_url", "api_key", "temperature", "timeout")} + endpoint = LLMEndpoint(**{k: v for k, v in fields.items() if v not in (None, "")}) + if not endpoint.provider: + endpoint.provider = "ollama" + new_ensemble.append(endpoint) + if not invalid: + llm_cfg.ensemble = new_ensemble + llm_cfg.strategy = selected_strategy + llm_cfg.majority_threshold = int(majority_threshold) + LOGGER.info("LLM 配置已更新:%s", llm_config_snapshot(), extra=LOG_EXTRA) + st.success("LLM 设置已保存,仅在当前会话生效。") + st.json(llm_config_snapshot()) def render_tests() -> None: diff --git a/app/utils/config.py b/app/utils/config.py index c603733..804cef0 100644 --- a/app/utils/config.py +++ b/app/utils/config.py @@ -3,7 +3,7 @@ from __future__ import annotations from dataclasses import dataclass, field from pathlib import Path -from typing import Dict, Optional +from typing import Dict, List, Optional def _default_root() -> Path: @@ -44,17 +44,40 @@ class AgentWeights: "A_macro": self.macro, } -@dataclass -class LLMConfig: - """Configuration for LLM providers (Ollama / OpenAI-compatible).""" +DEFAULT_LLM_MODELS: Dict[str, str] = { + "ollama": "llama3", + "openai": "gpt-4o-mini", + "deepseek": "deepseek-chat", + "wenxin": "ERNIE-Speed", +} - provider: str = "ollama" # Options: "ollama", "openai" - model: str = "llama3" - base_url: Optional[str] = None # Defaults resolved per provider + +@dataclass +class LLMEndpoint: + """Single LLM endpoint configuration.""" + + provider: str = "ollama" + model: Optional[str] = None + base_url: Optional[str] = None api_key: Optional[str] = None temperature: float = 0.2 timeout: float = 30.0 + def __post_init__(self) -> None: + self.provider = (self.provider or "ollama").lower() + if not self.model: + self.model = DEFAULT_LLM_MODELS.get(self.provider, DEFAULT_LLM_MODELS["ollama"]) + + +@dataclass +class LLMConfig: + """LLM configuration allowing single or ensemble strategies.""" + + primary: LLMEndpoint = field(default_factory=LLMEndpoint) + 
ensemble: List[LLMEndpoint] = field(default_factory=list) + strategy: str = "single" # Options: single, majority + majority_threshold: int = 3 + @dataclass class AppConfig:
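
For reviewers trying the change locally, here is a minimal sketch of how the new pieces fit together: `LLMEndpoint`/`LLMConfig` from `app/utils/config.py`, the `strategy`/`majority_threshold` knobs, and `run_llm` from `app/llm/client.py` (the Streamlit page drives the same objects through the "LLM 集群配置 (JSON 数组)" text area). This is an illustration under stated assumptions, not code from this patch: the endpoint reachability, the placeholder model names, and the `sk-...` key are all made up, and it presumes `get_config()` hands back the same session-wide config object that the UI mutates.

```python
"""Minimal sketch (not part of the diff): wire up the new ensemble settings
and run one prompt through the majority-vote strategy.

Assumptions: get_config() returns the session-wide config object that the
Streamlit page also mutates; the endpoints below are reachable; the model
names and the "sk-..." API key are placeholders, not values from this repo.
"""

import json
from dataclasses import asdict

from app.llm.client import llm_config_snapshot, run_llm
from app.utils.config import LLMEndpoint, get_config

llm = get_config().llm

# Primary endpoint: local Ollama. No model given, so __post_init__ falls back
# to DEFAULT_LLM_MODELS["ollama"] ("llama3").
llm.primary = LLMEndpoint(provider="ollama")

# Additional voters. The model names and api_key are placeholders.
llm.ensemble = [
    LLMEndpoint(provider="deepseek", model="deepseek-chat", api_key="sk-..."),
    LLMEndpoint(provider="ollama", model="qwen2"),
]

# "majority" makes run_llm() call primary + ensemble and return the answer
# whose whitespace-normalized form collects at least majority_threshold votes
# (clamped to the number of successful responses); otherwise it logs the tally
# and falls back to the first successful answer.
llm.strategy = "majority"
llm.majority_threshold = 2

# The same list, serialized the way the "LLM 集群配置 (JSON 数组)" text area
# in the Streamlit settings page expects it.
print(json.dumps([asdict(e) for e in llm.ensemble], ensure_ascii=False, indent=2))

# Sanitized view (api_key masked as "***"), then an actual call.
print(llm_config_snapshot())
print(run_llm("用一句话概括今天的市场情绪", system="你是谨慎的投资助手"))
```

Note that the vote compares whitespace-normalized strings only, so two semantically equivalent but differently worded answers still count as different votes; prompts that ask for a constrained output (a label, a JSON object, a single number) will benefit most from `strategy="majority"`.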