sam 2025-09-27 21:03:04 +08:00
parent 7c51831615
commit 5b4bd51199
5 changed files with 247 additions and 32 deletions

.gitignore vendored
View File

@@ -19,6 +19,8 @@ env/
 app/data/*.db*
 app/data/backups/
 app/data/logs/
+app/data/*.json
+.json
 *.log
 # Streamlit temporary files

View File

@@ -58,10 +58,13 @@ export TUSHARE_TOKEN="<your-token>"
 ### LLM Configuration and Testing
-- Supports local Ollama (`http://localhost:11434`) and several OpenAI-compatible cloud providers (e.g. DeepSeek, Wenxin/ERNIE, OpenAI). The Provider, model, Base URL, and API Key can be configured on the Streamlit "数据与设置" (Data & Settings) tab. Example default model mapping per Provider: Ollama → `llama3`, OpenAI → `gpt-4o-mini`, DeepSeek → `deepseek-chat`, Wenxin → `ERNIE-Speed`.
+- Supports local Ollama and several OpenAI-compatible cloud providers (e.g. DeepSeek, Wenxin/ERNIE, OpenAI). Switching the Provider on the "数据与设置" (Data & Settings) tab automatically loads that Provider's candidate models, recommended Base URL, default temperature, and timeout; custom values can still be entered. All changes are persisted to `app/data/config.json` and loaded automatically on the next startup.
 - After changing the Provider/model/Base URL/API Key, click "保存 LLM 设置" (Save LLM Settings); the update takes effect in the current session only.
 - The "自检测试" (Self-check Tests) page adds an "LLM 接口测试" (LLM API test): enter a Prompt to quickly verify the call; rate-limit and error details are written to the log for troubleshooting.
 - In the future, several LLMs could be called in parallel for the same agent task and combined via strategies such as majority voting to improve robustness; the current code structure already reserves room for this extension.
+- To inject configuration automatically via environment variables (see the sketch after this diff), set:
+  - `TUSHARE_TOKEN`
+  - `LLM_API_KEY`
 ## Quick Start
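For reference, a minimal sketch of how those two variables are expected to be picked up, based on the _load_env_defaults hook added to app/utils/config.py further down in this commit (the placeholder values are hypothetical):

import os

# Must be set before app.utils.config is first imported: that module reads the
# environment once at import time (_load_from_file, then _load_env_defaults).
os.environ.setdefault("TUSHARE_TOKEN", "<your-token>")  # hypothetical placeholder
os.environ.setdefault("LLM_API_KEY", "<your-key>")      # hypothetical placeholder

from app.utils.config import get_config

cfg = get_config()
print(cfg.tushare_token)        # reflects TUSHARE_TOKEN when it was set
print(cfg.llm.primary.api_key)  # reflects LLM_API_KEY when it was set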

View File

@@ -8,7 +8,7 @@ from typing import Dict, Iterable, List, Optional
 import requests

-from app.utils.config import DEFAULT_LLM_MODELS, LLMEndpoint, get_config
+from app.utils.config import DEFAULT_LLM_BASE_URLS, DEFAULT_LLM_MODELS, LLMEndpoint, get_config
 from app.utils.logging import get_logger

 LOGGER = get_logger(__name__)

@@ -19,13 +19,8 @@ class LLMError(RuntimeError):
 def _default_base_url(provider: str) -> str:
-    if provider == "ollama":
-        return "http://localhost:11434"
-    if provider == "deepseek":
-        return "https://api.deepseek.com"
-    if provider == "wenxin":
-        return "https://aip.baidubce.com"
-    return "https://api.openai.com"
+    provider = (provider or "openai").lower()
+    return DEFAULT_LLM_BASE_URLS.get(provider, DEFAULT_LLM_BASE_URLS["openai"])

 def _default_model(provider: str) -> str:
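The refactor replaces the provider if-chain with a table lookup. A small sketch of the resulting behavior, assuming the DEFAULT_LLM_BASE_URLS table introduced in app/utils/config.py later in this commit:

from app.utils.config import DEFAULT_LLM_BASE_URLS

def default_base_url(provider: str | None) -> str:
    # Normalize case and fall back to the OpenAI endpoint for unknown or
    # empty providers, mirroring the old chain of if-statements.
    provider = (provider or "openai").lower()
    return DEFAULT_LLM_BASE_URLS.get(provider, DEFAULT_LLM_BASE_URLS["openai"])

assert default_base_url("DeepSeek") == "https://api.deepseek.com"    # now case-insensitive
assert default_base_url(None) == "https://api.openai.com"            # missing provider
assert default_base_url("unknown-vendor") == "https://api.openai.com"

One behavioral note: the old code compared the raw string, so a capitalized "Ollama" previously fell through to the OpenAI URL; the new lookup lowercases first.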

View File

@@ -23,7 +23,14 @@ from app.ingest.checker import run_boot_check
 from app.ingest.tushare import FetchJob, run_ingestion
 from app.llm.client import llm_config_snapshot, run_llm
 from app.llm.explain import make_human_card
-from app.utils.config import DEFAULT_LLM_MODELS, LLMEndpoint, get_config
+from app.utils.config import (
+    DEFAULT_LLM_BASE_URLS,
+    DEFAULT_LLM_MODEL_OPTIONS,
+    DEFAULT_LLM_MODELS,
+    LLMEndpoint,
+    get_config,
+    save_config,
+)
 from app.utils.db import db_session
 from app.utils.logging import get_logger
@@ -190,6 +197,7 @@ def render_settings() -> None:
         LOGGER.info("保存设置按钮被点击", extra=LOG_EXTRA)
         cfg.tushare_token = token.strip() or None
         LOGGER.info("TuShare Token 更新，是否为空=%s", cfg.tushare_token is None, extra=LOG_EXTRA)
+        save_config()
         st.success("设置已保存，仅在当前会话生效。")
     st.write("新闻源开关与数据库备份将在此配置。")
@@ -198,25 +206,76 @@ def render_settings() -> None:
     st.subheader("LLM 设置")
     llm_cfg = cfg.llm
     primary = llm_cfg.primary
-    providers = ["ollama", "openai"]
+    providers = sorted(DEFAULT_LLM_MODELS.keys())
     try:
         provider_index = providers.index((primary.provider or "ollama").lower())
     except ValueError:
         provider_index = 0
     selected_provider = st.selectbox("LLM Provider", providers, index=provider_index)
+    provider_info = DEFAULT_LLM_MODEL_OPTIONS.get(selected_provider, {})
+    model_options = provider_info.get("models", [])
+    custom_model_label = "自定义模型"
     default_model_hint = DEFAULT_LLM_MODELS.get(selected_provider, DEFAULT_LLM_MODELS["ollama"])
-    llm_model = st.text_input("LLM 模型", value=primary.model, help=f"默认推荐：{default_model_hint}")
-    base_hints = {
-        "ollama": "http://localhost:11434",
-        "openai": "https://api.openai.com",
-        "deepseek": "https://api.deepseek.com",
-        "wenxin": "https://aip.baidubce.com",
-    }
-    default_base_hint = base_hints.get(selected_provider, "")
-    llm_base = st.text_input("LLM Base URL (可选)", value=primary.base_url or "", help=f"默认推荐：{default_base_hint or '按供应商要求填写'}")
+    if model_options:
+        options_with_custom = model_options + [custom_model_label]
+        if primary.model in model_options:
+            model_index = options_with_custom.index(primary.model)
+        else:
+            model_index = len(options_with_custom) - 1
+        selected_model_option = st.selectbox(
+            "LLM 模型",
+            options_with_custom,
+            index=model_index,
+            help=f"可选模型：{', '.join(model_options)}",
+        )
+        if selected_model_option == custom_model_label:
+            custom_model_value = st.text_input(
+                "自定义模型名称",
+                value=primary.model if primary.model not in model_options else "",
+            )
+        else:
+            custom_model_value = selected_model_option
+    else:
+        custom_model_value = st.text_input(
+            "LLM 模型",
+            value=primary.model or default_model_hint,
+            help="未预设该 Provider 的模型列表，请手动填写",
+        )
+        selected_model_option = custom_model_label
+    default_base_hint = DEFAULT_LLM_BASE_URLS.get(selected_provider, "")
+    provider_default_temp = float(provider_info.get("temperature", 0.2))
+    provider_default_timeout = int(provider_info.get("timeout", 30.0))
+    if primary.provider == selected_provider:
+        base_value = primary.base_url or default_base_hint or ""
+        temp_value = float(primary.temperature)
+        timeout_value = int(primary.timeout)
+    else:
+        base_value = default_base_hint or ""
+        temp_value = provider_default_temp
+        timeout_value = provider_default_timeout
+    llm_base = st.text_input(
+        "LLM Base URL (可选)",
+        value=base_value,
+        help=f"默认推荐：{default_base_hint or '按供应商要求填写'}",
+    )
     llm_api_key = st.text_input("LLM API Key (OpenAI 类需要)", value=primary.api_key or "", type="password")
-    llm_temperature = st.slider("LLM 温度", min_value=0.0, max_value=2.0, value=float(primary.temperature), step=0.05)
-    llm_timeout = st.number_input("请求超时时间 (秒)", min_value=5.0, max_value=120.0, value=float(primary.timeout), step=5.0, format="%d")
+    llm_temperature = st.slider(
+        "LLM 温度",
+        min_value=0.0,
+        max_value=2.0,
+        value=temp_value,
+        step=0.05,
+    )
+    llm_timeout = st.number_input(
+        "请求超时时间 (秒)",
+        min_value=5,
+        max_value=120,
+        value=timeout_value,
+        step=5,
+    )
     strategy_options = ["single", "majority"]
     try:
@@ -249,13 +308,18 @@ def render_settings() -> None:
         original_provider = primary.provider
         original_model = primary.model
         primary.provider = selected_provider
-        model_input = llm_model.strip()
-        if not model_input:
-            primary.model = DEFAULT_LLM_MODELS.get(selected_provider, DEFAULT_LLM_MODELS["ollama"])
-        elif selected_provider != original_provider and model_input == original_model:
-            primary.model = DEFAULT_LLM_MODELS.get(selected_provider, DEFAULT_LLM_MODELS["ollama"])
-        else:
-            primary.model = model_input
+        if model_options:
+            if selected_model_option == custom_model_label:
+                model_input = custom_model_value.strip()
+                primary.model = model_input or DEFAULT_LLM_MODELS.get(
+                    selected_provider, DEFAULT_LLM_MODELS["ollama"]
+                )
+            else:
+                primary.model = selected_model_option
+        else:
+            primary.model = custom_model_value.strip() or DEFAULT_LLM_MODELS.get(
+                selected_provider, DEFAULT_LLM_MODELS["ollama"]
+            )
         primary.base_url = llm_base.strip() or None
         primary.temperature = llm_temperature
         primary.timeout = llm_timeout
@@ -286,6 +350,7 @@ def render_settings() -> None:
         llm_cfg.ensemble = new_ensemble
         llm_cfg.strategy = selected_strategy
         llm_cfg.majority_threshold = int(majority_threshold)
+        save_config()
         LOGGER.info("LLM 配置已更新：%s", llm_config_snapshot(), extra=LOG_EXTRA)
         st.success("LLM 设置已保存，仅在当前会话生效。")
         st.json(llm_config_snapshot())
@@ -342,6 +407,7 @@ def render_tests() -> None:
     if force_refresh != cfg.force_refresh:
         cfg.force_refresh = force_refresh
         LOGGER.info("更新 force_refresh=%s", force_refresh, extra=LOG_EXTRA)
+        save_config()
     if st.button("执行开机检查"):
         LOGGER.info("点击执行开机检查按钮", extra=LOG_EXTRA)

View File

@@ -2,6 +2,8 @@
 from __future__ import annotations

 from dataclasses import dataclass, field
+import json
+import os
 from pathlib import Path
 from typing import Dict, List, Optional
@@ -17,12 +19,14 @@ class DataPaths:
     root: Path = field(default_factory=_default_root)
     database: Path = field(init=False)
     backups: Path = field(init=False)
+    config_file: Path = field(init=False)

     def __post_init__(self) -> None:
         self.root.mkdir(parents=True, exist_ok=True)
         self.database = self.root / "llm_quant.db"
         self.backups = self.root / "backups"
         self.backups.mkdir(parents=True, exist_ok=True)
+        self.config_file = self.root / "config.json"


 @dataclass
@@ -44,11 +48,51 @@ class AgentWeights:
             "A_macro": self.macro,
         }

+DEFAULT_LLM_MODEL_OPTIONS: Dict[str, Dict[str, object]] = {
+    "ollama": {
+        "models": ["llama3", "phi3", "qwen2"],
+        "base_url": "http://localhost:11434",
+        "temperature": 0.2,
+        "timeout": 30.0,
+    },
+    "openai": {
+        "models": ["gpt-4o-mini", "gpt-4.1-mini", "gpt-3.5-turbo"],
+        "base_url": "https://api.openai.com",
+        "temperature": 0.2,
+        "timeout": 30.0,
+    },
+    "deepseek": {
+        "models": ["deepseek-chat", "deepseek-coder"],
+        "base_url": "https://api.deepseek.com",
+        "temperature": 0.2,
+        "timeout": 45.0,
+    },
+    "wenxin": {
+        "models": ["ERNIE-Speed", "ERNIE-Bot"],
+        "base_url": "https://aip.baidubce.com",
+        "temperature": 0.2,
+        "timeout": 60.0,
+    },
+}
+
 DEFAULT_LLM_MODELS: Dict[str, str] = {
-    "ollama": "llama3",
-    "openai": "gpt-4o-mini",
-    "deepseek": "deepseek-chat",
-    "wenxin": "ERNIE-Speed",
+    provider: info["models"][0]
+    for provider, info in DEFAULT_LLM_MODEL_OPTIONS.items()
+}
+
+DEFAULT_LLM_BASE_URLS: Dict[str, str] = {
+    provider: info["base_url"]
+    for provider, info in DEFAULT_LLM_MODEL_OPTIONS.items()
+}
+
+DEFAULT_LLM_TEMPERATURES: Dict[str, float] = {
+    provider: float(info.get("temperature", 0.2))
+    for provider, info in DEFAULT_LLM_MODEL_OPTIONS.items()
+}
+
+DEFAULT_LLM_TIMEOUTS: Dict[str, float] = {
+    provider: float(info.get("timeout", 30.0))
+    for provider, info in DEFAULT_LLM_MODEL_OPTIONS.items()
 }
@@ -67,6 +111,12 @@ class LLMEndpoint:
         self.provider = (self.provider or "ollama").lower()
         if not self.model:
             self.model = DEFAULT_LLM_MODELS.get(self.provider, DEFAULT_LLM_MODELS["ollama"])
+        if not self.base_url:
+            self.base_url = DEFAULT_LLM_BASE_URLS.get(self.provider)
+        if self.temperature == 0.2 or self.temperature is None:
+            self.temperature = DEFAULT_LLM_TEMPERATURES.get(self.provider, 0.2)
+        if self.timeout == 30.0 or self.timeout is None:
+            self.timeout = DEFAULT_LLM_TIMEOUTS.get(self.provider, 30.0)


 @dataclass
@@ -95,6 +145,105 @@ class AppConfig:
 CONFIG = AppConfig()

+
+def _endpoint_to_dict(endpoint: LLMEndpoint) -> Dict[str, object]:
+    return {
+        "provider": endpoint.provider,
+        "model": endpoint.model,
+        "base_url": endpoint.base_url,
+        "api_key": endpoint.api_key,
+        "temperature": endpoint.temperature,
+        "timeout": endpoint.timeout,
+    }
+
+
+def _dict_to_endpoint(data: Dict[str, object]) -> LLMEndpoint:
+    payload = {
+        key: data.get(key)
+        for key in ("provider", "model", "base_url", "api_key", "temperature", "timeout")
+        if data.get(key) is not None
+    }
+    return LLMEndpoint(**payload)
+
+
+def _load_from_file(cfg: AppConfig) -> None:
+    path = cfg.data_paths.config_file
+    if not path.exists():
+        return
+    try:
+        with path.open("r", encoding="utf-8") as fh:
+            payload = json.load(fh)
+    except (json.JSONDecodeError, OSError):
+        return
+    if isinstance(payload, dict):
+        if "tushare_token" in payload:
+            cfg.tushare_token = payload.get("tushare_token") or None
+        if "force_refresh" in payload:
+            cfg.force_refresh = bool(payload.get("force_refresh"))
+        if "decision_method" in payload:
+            cfg.decision_method = str(payload.get("decision_method") or cfg.decision_method)
+        llm_payload = payload.get("llm")
+        if isinstance(llm_payload, dict):
+            primary_data = llm_payload.get("primary")
+            if isinstance(primary_data, dict):
+                cfg.llm.primary = _dict_to_endpoint(primary_data)
+            ensemble_data = llm_payload.get("ensemble")
+            if isinstance(ensemble_data, list):
+                cfg.llm.ensemble = [
+                    _dict_to_endpoint(item)
+                    for item in ensemble_data
+                    if isinstance(item, dict)
+                ]
+            strategy = llm_payload.get("strategy")
+            if strategy in {"single", "majority"}:
+                cfg.llm.strategy = strategy
+            majority = llm_payload.get("majority_threshold")
+            if isinstance(majority, int) and majority > 0:
+                cfg.llm.majority_threshold = majority
+
+
+def save_config(cfg: AppConfig | None = None) -> None:
+    cfg = cfg or CONFIG
+    path = cfg.data_paths.config_file
+    payload = {
+        "tushare_token": cfg.tushare_token,
+        "force_refresh": cfg.force_refresh,
+        "decision_method": cfg.decision_method,
+        "llm": {
+            "strategy": cfg.llm.strategy,
+            "majority_threshold": cfg.llm.majority_threshold,
+            "primary": _endpoint_to_dict(cfg.llm.primary),
+            "ensemble": [_endpoint_to_dict(ep) for ep in cfg.llm.ensemble],
+        },
+    }
+    try:
+        path.parent.mkdir(parents=True, exist_ok=True)
+        with path.open("w", encoding="utf-8") as fh:
+            json.dump(payload, fh, ensure_ascii=False, indent=2)
+    except OSError:
+        pass
+
+
+def _load_env_defaults(cfg: AppConfig) -> None:
+    """Populate sensitive fields from environment variables if present."""
+    token = os.getenv("TUSHARE_TOKEN")
+    if token:
+        cfg.tushare_token = token.strip()
+    api_key = os.getenv("LLM_API_KEY")
+    if api_key:
+        cfg.llm.primary.api_key = api_key.strip()
+
+
+_load_from_file(CONFIG)
+_load_env_defaults(CONFIG)
+
+
 def get_config() -> AppConfig:
     """Return a mutable global configuration instance."""