This commit is contained in:
sam 2025-09-27 20:30:01 +08:00
parent c8e7955786
commit 7c51831615
4 changed files with 248 additions and 48 deletions

View File

@@ -58,9 +58,10 @@ export TUSHARE_TOKEN="<your-token>"
### LLM configuration and testing
-- Uses a local Ollama instance (`http://localhost:11434`) by default; you can switch to an OpenAI-compatible endpoint on the Streamlit "数据与设置" (Data & Settings) tab.
+- Supports a local Ollama instance (`http://localhost:11434`) as well as several OpenAI-compatible cloud providers (e.g. DeepSeek, ERNIE Bot, OpenAI). On the Streamlit "数据与设置" (Data & Settings) tab you can switch providers and configure the model, Base URL, and API Key. Default model mapping per provider: Ollama → `llama3`, OpenAI → `gpt-4o-mini`, DeepSeek → `deepseek-chat`, ERNIE Bot → `ERNIE-Speed`.
- After changing the provider, model, Base URL, or API Key, click "保存 LLM 设置" (Save LLM settings); the changes apply to the current session only.
- The "自检测试" (self-check) page adds an "LLM 接口测试" (LLM interface test) where you can enter a prompt to quickly verify the call; rate-limit and error details are logged for troubleshooting (see the sketch below).
+- In the future, agents covering the same task may call several LLMs in parallel and combine their answers with strategies such as majority voting for extra robustness; the current code structure already leaves room for this extension.
## Quick start
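A minimal sketch of the call path that the "LLM 接口测试" page exercises, assuming the module layout introduced in this commit (`app.llm.client`); the prompt and system text are only examples:

```python
# Illustrative only: the same entry points the Streamlit self-check page uses.
from app.llm.client import llm_config_snapshot, run_llm

print(llm_config_snapshot())  # sanitized view of the active provider/strategy settings
answer = run_llm(
    "Summarize today's A-share market sentiment in one sentence.",
    system="You are a cautious quantitative research assistant.",
)
print(answer)
```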

View File

@@ -2,18 +2,18 @@
from __future__ import annotations

import json
+from collections import Counter
from dataclasses import asdict
from typing import Dict, Iterable, List, Optional

import requests

-from app.utils.config import get_config
+from app.utils.config import DEFAULT_LLM_MODELS, LLMEndpoint, get_config
from app.utils.logging import get_logger

LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "llm"}


class LLMError(RuntimeError):
    """Raised when LLM provider returns an error response."""
@@ -21,9 +21,18 @@ class LLMError(RuntimeError):
def _default_base_url(provider: str) -> str:
    if provider == "ollama":
        return "http://localhost:11434"
+    if provider == "deepseek":
+        return "https://api.deepseek.com"
+    if provider == "wenxin":
+        return "https://aip.baidubce.com"
    return "https://api.openai.com"


+def _default_model(provider: str) -> str:
+    provider = (provider or "").lower()
+    return DEFAULT_LLM_MODELS.get(provider, DEFAULT_LLM_MODELS["ollama"])


def _build_messages(prompt: str, system: Optional[str] = None) -> List[Dict[str, str]]:
    messages: List[Dict[str, str]] = []
    if system:
@@ -32,7 +41,15 @@ def _build_messages(prompt: str, system: Optional[str] = None) -> List[Dict[str,
    return messages


-def _request_ollama(model: str, prompt: str, *, base_url: str, temperature: float, timeout: float, system: Optional[str]) -> str:
+def _request_ollama(
+    model: str,
+    prompt: str,
+    *,
+    base_url: str,
+    temperature: float,
+    timeout: float,
+    system: Optional[str],
+) -> str:
    url = f"{base_url.rstrip('/')}/api/chat"
    payload = {
        "model": model,
@@ -52,7 +69,16 @@ def _request_ollama(model: str, prompt: str, *, base_url: str, temperature: floa
    return str(content)


-def _request_openai(model: str, prompt: str, *, base_url: str, api_key: str, temperature: float, timeout: float, system: Optional[str]) -> str:
+def _request_openai(
+    model: str,
+    prompt: str,
+    *,
+    base_url: str,
+    api_key: str,
+    temperature: float,
+    timeout: float,
+    system: Optional[str],
+) -> str:
    url = f"{base_url.rstrip('/')}/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {api_key}",
@@ -74,28 +100,30 @@ def _request_openai(model: str, prompt: str, *, base_url: str, api_key: str, tem
        raise LLMError(f"OpenAI 响应解析失败: {json.dumps(data, ensure_ascii=False)}") from exc


-def run_llm(prompt: str, *, system: Optional[str] = None) -> str:
-    """Execute the configured LLM provider with the given prompt."""
-    cfg = get_config().llm
-    provider = (cfg.provider or "ollama").lower()
-    base_url = cfg.base_url or _default_base_url(provider)
-    model = cfg.model
-    temperature = max(0.0, min(cfg.temperature, 2.0))
-    timeout = max(5.0, cfg.timeout or 30.0)
+def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> str:
+    provider = (endpoint.provider or "ollama").lower()
+    base_url = endpoint.base_url or _default_base_url(provider)
+    model = endpoint.model or _default_model(provider)
+    temperature = max(0.0, min(endpoint.temperature, 2.0))
+    timeout = max(5.0, endpoint.timeout or 30.0)
    LOGGER.info(
-        "触发 LLM 请求：provider=%s model=%s base=%s", provider, model, base_url, extra=LOG_EXTRA
+        "触发 LLM 请求：provider=%s model=%s base=%s",
+        provider,
+        model,
+        base_url,
+        extra=LOG_EXTRA,
    )
-    if provider == "openai":
-        if not cfg.api_key:
-            raise LLMError("缺少 OpenAI 兼容 API Key")
+    if provider in {"openai", "deepseek", "wenxin"}:
+        api_key = endpoint.api_key
+        if not api_key:
+            raise LLMError(f"缺少 {provider} API Key (model={model})")
        return _request_openai(
            model,
            prompt,
            base_url=base_url,
-            api_key=cfg.api_key,
+            api_key=api_key,
            temperature=temperature,
            timeout=timeout,
            system=system,
@@ -109,14 +137,89 @@ def run_llm(prompt: str, *, system: Optional[str] = None) -> str:
            timeout=timeout,
            system=system,
        )
-    raise LLMError(f"不支持的 LLM provider: {cfg.provider}")
+    raise LLMError(f"不支持的 LLM provider: {endpoint.provider}")


+def _normalize_response(text: str) -> str:
+    return " ".join(text.strip().split())
+
+
+def run_llm(prompt: str, *, system: Optional[str] = None) -> str:
+    """Execute the configured LLM strategy with the given prompt."""
+    settings = get_config().llm
+    if settings.strategy == "majority":
+        return _run_majority_vote(settings, prompt, system)
+    return _call_endpoint(settings.primary, prompt, system)
+
+
+def _run_majority_vote(config, prompt: str, system: Optional[str]) -> str:
+    endpoints: List[LLMEndpoint] = [config.primary] + list(config.ensemble)
+    responses: List[Dict[str, str]] = []
+    failures: List[str] = []
+    for idx, endpoint in enumerate(endpoints, start=1):
+        try:
+            result = _call_endpoint(endpoint, prompt, system)
+            responses.append({
+                "provider": endpoint.provider,
+                "model": endpoint.model,
+                "raw": result,
+                "normalized": _normalize_response(result),
+            })
+        except Exception as exc:  # noqa: BLE001
+            summary = f"{endpoint.provider}:{endpoint.model} -> {exc}"
+            failures.append(summary)
+            LOGGER.warning("LLM 调用失败：%s", summary, extra=LOG_EXTRA)
+    if not responses:
+        raise LLMError("所有 LLM 调用均失败，无法返回结果。")
+    threshold = max(1, config.majority_threshold)
+    threshold = min(threshold, len(responses))
+    counter = Counter(item["normalized"] for item in responses)
+    top_value, top_count = counter.most_common(1)[0]
+    if top_count >= threshold:
+        chosen_raw = next(item["raw"] for item in responses if item["normalized"] == top_value)
+        LOGGER.info(
+            "LLM 多模型投票通过：value=%s votes=%s/%s threshold=%s",
+            top_value[:80],
+            top_count,
+            len(responses),
+            threshold,
+            extra=LOG_EXTRA,
+        )
+        return chosen_raw
+    LOGGER.info(
+        "LLM 多模型投票未达门槛：votes=%s/%s threshold=%s，返回首个结果",
+        top_count,
+        len(responses),
+        threshold,
+        extra=LOG_EXTRA,
+    )
+    if failures:
+        LOGGER.warning("LLM 调用失败列表：%s", failures, extra=LOG_EXTRA)
+    return responses[0]["raw"]


def llm_config_snapshot() -> Dict[str, object]:
    """Return a sanitized snapshot of current LLM configuration for debugging."""
-    cfg = get_config().llm
-    data = asdict(cfg)
-    if data.get("api_key"):
-        data["api_key"] = "***"
-    return data
+    settings = get_config().llm
+    primary = asdict(settings.primary)
+    if primary.get("api_key"):
+        primary["api_key"] = "***"
+    ensemble = []
+    for endpoint in settings.ensemble:
+        record = asdict(endpoint)
+        if record.get("api_key"):
+            record["api_key"] = "***"
+        ensemble.append(record)
+    return {
+        "strategy": settings.strategy,
+        "majority_threshold": settings.majority_threshold,
+        "primary": primary,
+        "ensemble": ensemble,
+    }
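To make the voting rule above concrete, here is a small worked illustration (not part of the commit) of how `_run_majority_vote` picks a winner: responses are compared only after whitespace normalization, and the configured threshold is clamped to the number of successful responses.

```python
# Worked illustration of the selection rule; the response strings are made up.
from collections import Counter

raw_responses = ["BUY", "  BUY ", "SELL"]                             # outputs from three endpoints
normalized = [" ".join(r.strip().split()) for r in raw_responses]     # mirrors _normalize_response
top_value, top_count = Counter(normalized).most_common(1)[0]          # ("BUY", 2)

majority_threshold = 3                                                # the LLMConfig default
threshold = min(max(1, majority_threshold), len(normalized))          # clamped to 3 here
print(top_count >= threshold)                                         # False -> falls back to raw_responses[0]
```

With `majority_threshold=2` the same inputs would pass the vote, and the first raw response whose normalized form equals "BUY" would be returned instead.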

View File

@@ -1,9 +1,12 @@
"""Streamlit UI scaffold for the investment assistant."""
from __future__ import annotations

+import json
import sys
+from dataclasses import asdict
from datetime import date, timedelta
from pathlib import Path
+from typing import List

ROOT = Path(__file__).resolve().parents[2]
if str(ROOT) not in sys.path:
@@ -20,7 +23,7 @@ from app.ingest.checker import run_boot_check
from app.ingest.tushare import FetchJob, run_ingestion
from app.llm.client import llm_config_snapshot, run_llm
from app.llm.explain import make_human_card
-from app.utils.config import get_config
+from app.utils.config import DEFAULT_LLM_MODELS, LLMEndpoint, get_config
from app.utils.db import db_session
from app.utils.logging import get_logger
@@ -194,28 +197,98 @@ def render_settings() -> None:
    st.divider()
    st.subheader("LLM 设置")
    llm_cfg = cfg.llm
+    primary = llm_cfg.primary
    providers = ["ollama", "openai"]
    try:
-        provider_index = providers.index((llm_cfg.provider or "ollama").lower())
+        provider_index = providers.index((primary.provider or "ollama").lower())
    except ValueError:
        provider_index = 0
    selected_provider = st.selectbox("LLM Provider", providers, index=provider_index)
-    llm_model = st.text_input("LLM 模型", value=llm_cfg.model)
-    llm_base = st.text_input("LLM Base URL (可选)", value=llm_cfg.base_url or "")
-    llm_api_key = st.text_input("LLM API Key (OpenAI 类需要)", value=llm_cfg.api_key or "", type="password")
-    llm_temperature = st.slider("LLM 温度", min_value=0.0, max_value=2.0, value=float(llm_cfg.temperature), step=0.05)
-    llm_timeout = st.number_input("请求超时时间 (秒)", min_value=5.0, max_value=120.0, value=float(llm_cfg.timeout), step=5.0)
+    default_model_hint = DEFAULT_LLM_MODELS.get(selected_provider, DEFAULT_LLM_MODELS["ollama"])
+    llm_model = st.text_input("LLM 模型", value=primary.model, help=f"默认推荐：{default_model_hint}")
+    base_hints = {
+        "ollama": "http://localhost:11434",
+        "openai": "https://api.openai.com",
+        "deepseek": "https://api.deepseek.com",
+        "wenxin": "https://aip.baidubce.com",
+    }
+    default_base_hint = base_hints.get(selected_provider, "")
+    llm_base = st.text_input("LLM Base URL (可选)", value=primary.base_url or "", help=f"默认推荐：{default_base_hint or '按供应商要求填写'}")
+    llm_api_key = st.text_input("LLM API Key (OpenAI 类需要)", value=primary.api_key or "", type="password")
+    llm_temperature = st.slider("LLM 温度", min_value=0.0, max_value=2.0, value=float(primary.temperature), step=0.05)
+    llm_timeout = st.number_input("请求超时时间 (秒)", min_value=5.0, max_value=120.0, value=float(primary.timeout), step=5.0, format="%d")
+    strategy_options = ["single", "majority"]
+    try:
+        strategy_index = strategy_options.index(llm_cfg.strategy)
+    except ValueError:
+        strategy_index = 0
+    selected_strategy = st.selectbox("LLM 推理策略", strategy_options, index=strategy_index)
+    majority_threshold = st.number_input(
+        "多数投票门槛",
+        min_value=1,
+        max_value=10,
+        value=int(llm_cfg.majority_threshold),
+        step=1,
+        format="%d",
+    )
+    ensemble_display = []
+    for endpoint in llm_cfg.ensemble:
+        data = asdict(endpoint)
+        if data.get("api_key"):
+            data["api_key"] = ""
+        ensemble_display.append(data)
+    ensemble_text = st.text_area(
+        "LLM 集群配置 (JSON 数组)",
+        value=json.dumps(ensemble_display or [], ensure_ascii=False, indent=2),
+        height=220,
+    )
    if st.button("保存 LLM 设置"):
-        llm_cfg.provider = selected_provider
-        llm_cfg.model = llm_model.strip() or llm_cfg.model
-        llm_cfg.base_url = llm_base.strip() or None
-        llm_cfg.api_key = llm_api_key.strip() or None
-        llm_cfg.temperature = llm_temperature
-        llm_cfg.timeout = llm_timeout
-        LOGGER.info("LLM 配置已更新：%s", llm_config_snapshot(), extra=LOG_EXTRA)
-        st.success("LLM 设置已保存，仅在当前会话生效。")
-        st.json(llm_config_snapshot())
+        original_provider = primary.provider
+        original_model = primary.model
+        primary.provider = selected_provider
+        model_input = llm_model.strip()
+        if not model_input:
+            primary.model = DEFAULT_LLM_MODELS.get(selected_provider, DEFAULT_LLM_MODELS["ollama"])
+        elif selected_provider != original_provider and model_input == original_model:
+            primary.model = DEFAULT_LLM_MODELS.get(selected_provider, DEFAULT_LLM_MODELS["ollama"])
+        else:
+            primary.model = model_input
+        primary.base_url = llm_base.strip() or None
+        primary.temperature = llm_temperature
+        primary.timeout = llm_timeout
+        api_key_value = llm_api_key.strip()
+        primary.api_key = api_key_value or None
+        try:
+            parsed = json.loads(ensemble_text or "[]")
+            if not isinstance(parsed, list):
+                raise ValueError("ensemble 配置必须是数组")
+        except Exception as exc:  # noqa: BLE001
+            LOGGER.exception("解析 LLM 集群配置失败", extra=LOG_EXTRA)
+            st.error(f"LLM 集群配置解析失败：{exc}")
+        else:
+            new_ensemble: List[LLMEndpoint] = []
+            invalid = False
+            for item in parsed:
+                if not isinstance(item, dict):
+                    st.error("LLM 集群配置中的每个元素都必须是对象")
+                    invalid = True
+                    break
+                fields = {key: item.get(key) for key in ("provider", "model", "base_url", "api_key", "temperature", "timeout")}
+                endpoint = LLMEndpoint(**{k: v for k, v in fields.items() if v not in (None, "")})
+                if not endpoint.provider:
+                    endpoint.provider = "ollama"
+                new_ensemble.append(endpoint)
+            if not invalid:
+                llm_cfg.ensemble = new_ensemble
+                llm_cfg.strategy = selected_strategy
+                llm_cfg.majority_threshold = int(majority_threshold)
+                LOGGER.info("LLM 配置已更新：%s", llm_config_snapshot(), extra=LOG_EXTRA)
+                st.success("LLM 设置已保存，仅在当前会话生效。")
+                st.json(llm_config_snapshot())


def render_tests() -> None:
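For reference, a hypothetical value for the "LLM 集群配置 (JSON 数组)" text area, shown as the Python structure that `json.loads` should produce. Only the six keys read by the save handler are used, and empty or null values are dropped before the `LLMEndpoint` is constructed; the key values below are placeholders.

```python
# Hypothetical ensemble entries; keys other than provider/model/base_url/api_key/
# temperature/timeout are ignored by the save handler shown above.
ensemble_entries = [
    {"provider": "deepseek", "model": "deepseek-chat", "api_key": "sk-..."},
    {
        "provider": "openai",
        "model": "gpt-4o-mini",
        "base_url": "https://api.openai.com",
        "api_key": "sk-...",
        "temperature": 0.2,
        "timeout": 30,
    },
]
```

An empty array (the default) leaves only the primary endpoint in the vote.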

View File

@@ -3,7 +3,7 @@ from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
-from typing import Dict, Optional
+from typing import Dict, List, Optional


def _default_root() -> Path:
@@ -44,17 +44,40 @@ class AgentWeights:
            "A_macro": self.macro,
        }


-@dataclass
-class LLMConfig:
-    """Configuration for LLM providers (Ollama / OpenAI-compatible)."""
-
-    provider: str = "ollama"  # Options: "ollama", "openai"
-    model: str = "llama3"
-    base_url: Optional[str] = None  # Defaults resolved per provider
+DEFAULT_LLM_MODELS: Dict[str, str] = {
+    "ollama": "llama3",
+    "openai": "gpt-4o-mini",
+    "deepseek": "deepseek-chat",
+    "wenxin": "ERNIE-Speed",
+}
+
+
+@dataclass
+class LLMEndpoint:
+    """Single LLM endpoint configuration."""
+
+    provider: str = "ollama"
+    model: Optional[str] = None
+    base_url: Optional[str] = None
    api_key: Optional[str] = None
    temperature: float = 0.2
    timeout: float = 30.0

+    def __post_init__(self) -> None:
+        self.provider = (self.provider or "ollama").lower()
+        if not self.model:
+            self.model = DEFAULT_LLM_MODELS.get(self.provider, DEFAULT_LLM_MODELS["ollama"])
+
+
+@dataclass
+class LLMConfig:
+    """LLM configuration allowing single or ensemble strategies."""
+
+    primary: LLMEndpoint = field(default_factory=LLMEndpoint)
+    ensemble: List[LLMEndpoint] = field(default_factory=list)
+    strategy: str = "single"  # Options: single, majority
+    majority_threshold: int = 3


@dataclass
class AppConfig:
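A minimal construction sketch for the new config objects, assuming they are imported from `app.utils.config` as in the hunks above; the API key values are placeholders.

```python
# Illustrative only: a primary endpoint plus a two-provider ensemble with majority voting.
from app.utils.config import LLMConfig, LLMEndpoint

llm = LLMConfig(
    primary=LLMEndpoint(provider="ollama"),                  # model falls back to "llama3"
    ensemble=[
        LLMEndpoint(provider="deepseek", api_key="sk-..."),  # model falls back to "deepseek-chat"
        LLMEndpoint(provider="openai", api_key="sk-..."),    # model falls back to "gpt-4o-mini"
    ],
    strategy="majority",
    majority_threshold=2,
)
# run_llm() would then query all three endpoints and return the first raw answer whose
# normalized form is shared by at least two of them, falling back to the first response otherwise.
```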