From 5b4bd51199a33e034bd9361ee6039111e094a416 Mon Sep 17 00:00:00 2001
From: sam
Date: Sat, 27 Sep 2025 21:03:04 +0800
Subject: [PATCH] Persist settings to config.json and add LLM provider defaults

---
 .gitignore              |   2 +
 README.md               |   5 +-
 app/llm/client.py       |  11 +--
 app/ui/streamlit_app.py | 104 +++++++++++++++++++++-----
 app/utils/config.py     | 157 +++++++++++++++++++++++++++++++++++++++-
 5 files changed, 247 insertions(+), 32 deletions(-)

diff --git a/.gitignore b/.gitignore
index 94fc8fe..a2cc304 100644
--- a/.gitignore
+++ b/.gitignore
@@ -19,6 +19,8 @@ env/
 app/data/*.db*
 app/data/backups/
 app/data/logs/
+app/data/*.json
+*.json
 *.log
 
 # Streamlit temporary files
diff --git a/README.md b/README.md
index ac949e1..082d751 100644
--- a/README.md
+++ b/README.md
@@ -58,10 +58,13 @@ export TUSHARE_TOKEN=""
 ```
 
 ### LLM 配置与测试
 
-- 支持本地 Ollama(`http://localhost:11434`)与多家 OpenAI 兼容云端供应商(如 DeepSeek、文心一言、OpenAI 等),可在 Streamlit 的 “数据与设置” 页签切换 Provider 并配置模型、Base URL、API Key。不同 Provider 默认映射的模型示例:Ollama → `llama3`,OpenAI → `gpt-4o-mini`,DeepSeek → `deepseek-chat`,文心一言 → `ERNIE-Speed`。
-- 修改 Provider/模型/Base URL/API Key 后点击 “保存 LLM 设置”,更新内容仅在当前会话生效。
+- 支持本地 Ollama 与多家 OpenAI 兼容云端供应商(如 DeepSeek、文心一言、OpenAI 等),可在 “数据与设置” 页签切换 Provider 并自动加载该 Provider 的候选模型、推荐 Base URL、默认温度与超时时间,亦可切换为自定义值。所有修改会持久化到 `app/data/config.json`,下次启动自动加载。
+- 修改 Provider/模型/Base URL/API Key 后点击 “保存 LLM 设置”,配置会写入 `app/data/config.json`,下次启动自动加载。
 - 在 “自检测试” 页新增 “LLM 接口测试”,可输入 Prompt 快速验证调用结果,日志会记录限频与错误信息便于排查。
 - 未来可对同一功能的智能体并行调用多个 LLM,采用多数投票等策略增强鲁棒性,当前代码结构已为此预留扩展空间。
+- 若使用环境变量自动注入配置,可设置:
+  - `TUSHARE_TOKEN`
+  - `LLM_API_KEY`
 
 ## 快速开始
diff --git a/app/llm/client.py b/app/llm/client.py
index fef7277..130e4f1 100644
--- a/app/llm/client.py
+++ b/app/llm/client.py
@@ -8,7 +8,7 @@ from typing import Dict, Iterable, List, Optional
 
 import requests
 
-from app.utils.config import DEFAULT_LLM_MODELS, LLMEndpoint, get_config
+from app.utils.config import DEFAULT_LLM_BASE_URLS, DEFAULT_LLM_MODELS, LLMEndpoint, get_config
 from app.utils.logging import get_logger
 
 LOGGER = get_logger(__name__)
@@ -19,13 +19,8 @@ class LLMError(RuntimeError):
 
 
 def _default_base_url(provider: str) -> str:
-    if provider == "ollama":
-        return "http://localhost:11434"
-    if provider == "deepseek":
-        return "https://api.deepseek.com"
-    if provider == "wenxin":
-        return "https://aip.baidubce.com"
-    return "https://api.openai.com"
+    provider = (provider or "openai").lower()
+    return DEFAULT_LLM_BASE_URLS.get(provider, DEFAULT_LLM_BASE_URLS["openai"])
 
 
 def _default_model(provider: str) -> str:
diff --git a/app/ui/streamlit_app.py b/app/ui/streamlit_app.py
index a4edec4..d756358 100644
--- a/app/ui/streamlit_app.py
+++ b/app/ui/streamlit_app.py
@@ -23,7 +23,14 @@ from app.ingest.checker import run_boot_check
 from app.ingest.tushare import FetchJob, run_ingestion
 from app.llm.client import llm_config_snapshot, run_llm
 from app.llm.explain import make_human_card
-from app.utils.config import DEFAULT_LLM_MODELS, LLMEndpoint, get_config
+from app.utils.config import (
+    DEFAULT_LLM_BASE_URLS,
+    DEFAULT_LLM_MODEL_OPTIONS,
+    DEFAULT_LLM_MODELS,
+    LLMEndpoint,
+    get_config,
+    save_config,
+)
 from app.utils.db import db_session
 from app.utils.logging import get_logger
 
@@ -190,6 +197,7 @@ def render_settings() -> None:
         LOGGER.info("保存设置按钮被点击", extra=LOG_EXTRA)
         cfg.tushare_token = token.strip() or None
         LOGGER.info("TuShare Token 更新,是否为空=%s", cfg.tushare_token is None, extra=LOG_EXTRA)
-        st.success("设置已保存,仅在当前会话生效。")
+        save_config()
+        st.success("设置已保存,并已写入 app/data/config.json。")
 
     st.write("新闻源开关与数据库备份将在此配置。")
 
@@ -198,25 +206,76 @@ def render_settings() -> None:
     st.subheader("LLM 设置")
     llm_cfg = cfg.llm
     primary = llm_cfg.primary
-    providers = ["ollama", "openai"]
+    providers = sorted(DEFAULT_LLM_MODELS.keys())
     try:
         provider_index = providers.index((primary.provider or "ollama").lower())
     except ValueError:
         provider_index = 0
     selected_provider = st.selectbox("LLM Provider", providers, index=provider_index)
+    provider_info = DEFAULT_LLM_MODEL_OPTIONS.get(selected_provider, {})
+    model_options = provider_info.get("models", [])
+    custom_model_label = "自定义模型"
     default_model_hint = DEFAULT_LLM_MODELS.get(selected_provider, DEFAULT_LLM_MODELS["ollama"])
-    llm_model = st.text_input("LLM 模型", value=primary.model, help=f"默认推荐:{default_model_hint}")
-    base_hints = {
-        "ollama": "http://localhost:11434",
-        "openai": "https://api.openai.com",
-        "deepseek": "https://api.deepseek.com",
-        "wenxin": "https://aip.baidubce.com",
-    }
-    default_base_hint = base_hints.get(selected_provider, "")
-    llm_base = st.text_input("LLM Base URL (可选)", value=primary.base_url or "", help=f"默认推荐:{default_base_hint or '按供应商要求填写'}")
+
+    if model_options:
+        options_with_custom = model_options + [custom_model_label]
+        if primary.model in model_options:
+            model_index = options_with_custom.index(primary.model)
+        else:
+            model_index = len(options_with_custom) - 1
+        selected_model_option = st.selectbox(
+            "LLM 模型",
+            options_with_custom,
+            index=model_index,
+            help=f"可选模型:{', '.join(model_options)}",
+        )
+        if selected_model_option == custom_model_label:
+            custom_model_value = st.text_input(
+                "自定义模型名称",
+                value=primary.model if primary.model not in model_options else "",
+            )
+        else:
+            custom_model_value = selected_model_option
+    else:
+        custom_model_value = st.text_input(
+            "LLM 模型",
+            value=primary.model or default_model_hint,
+            help="未预设该 Provider 的模型列表,请手动填写",
+        )
+        selected_model_option = custom_model_label
+    default_base_hint = DEFAULT_LLM_BASE_URLS.get(selected_provider, "")
+    provider_default_temp = float(provider_info.get("temperature", 0.2))
+    provider_default_timeout = int(provider_info.get("timeout", 30.0))
+
+    if primary.provider == selected_provider:
+        base_value = primary.base_url or default_base_hint or ""
+        temp_value = float(primary.temperature)
+        timeout_value = int(primary.timeout)
+    else:
+        base_value = default_base_hint or ""
+        temp_value = provider_default_temp
+        timeout_value = provider_default_timeout
+
+    llm_base = st.text_input(
+        "LLM Base URL (可选)",
+        value=base_value,
+        help=f"默认推荐:{default_base_hint or '按供应商要求填写'}",
+    )
     llm_api_key = st.text_input("LLM API Key (OpenAI 类需要)", value=primary.api_key or "", type="password")
-    llm_temperature = st.slider("LLM 温度", min_value=0.0, max_value=2.0, value=float(primary.temperature), step=0.05)
-    llm_timeout = st.number_input("请求超时时间 (秒)", min_value=5.0, max_value=120.0, value=float(primary.timeout), step=5.0, format="%d")
+    llm_temperature = st.slider(
+        "LLM 温度",
+        min_value=0.0,
+        max_value=2.0,
+        value=temp_value,
+        step=0.05,
+    )
+    llm_timeout = st.number_input(
+        "请求超时时间 (秒)",
+        min_value=5,
+        max_value=120,
+        value=timeout_value,
+        step=5,
+    )
 
     strategy_options = ["single", "majority"]
     try:
@@ -249,13 +308,18 @@ def render_settings() -> None:
         original_provider = primary.provider
         original_model = primary.model
         primary.provider = selected_provider
-        model_input = llm_model.strip()
-        if not model_input:
-            primary.model = DEFAULT_LLM_MODELS.get(selected_provider, DEFAULT_LLM_MODELS["ollama"])
-        elif selected_provider != original_provider and model_input == original_model:
-            primary.model = DEFAULT_LLM_MODELS.get(selected_provider, DEFAULT_LLM_MODELS["ollama"])
+        if model_options:
+            if selected_model_option == custom_model_label:
+                model_input = custom_model_value.strip()
+                primary.model = model_input or DEFAULT_LLM_MODELS.get(
+                    selected_provider, DEFAULT_LLM_MODELS["ollama"]
+                )
+            else:
+                primary.model = selected_model_option
         else:
-            primary.model = model_input
+            primary.model = custom_model_value.strip() or DEFAULT_LLM_MODELS.get(
+                selected_provider, DEFAULT_LLM_MODELS["ollama"]
+            )
         primary.base_url = llm_base.strip() or None
         primary.temperature = llm_temperature
         primary.timeout = llm_timeout
@@ -286,6 +350,7 @@ def render_settings() -> None:
             llm_cfg.ensemble = new_ensemble
         llm_cfg.strategy = selected_strategy
         llm_cfg.majority_threshold = int(majority_threshold)
+        save_config()
         LOGGER.info("LLM 配置已更新:%s", llm_config_snapshot(), extra=LOG_EXTRA)
-        st.success("LLM 设置已保存,仅在当前会话生效。")
+        st.success("LLM 设置已保存,并已写入 app/data/config.json。")
         st.json(llm_config_snapshot())
@@ -342,6 +407,7 @@ def render_tests() -> None:
     if force_refresh != cfg.force_refresh:
         cfg.force_refresh = force_refresh
         LOGGER.info("更新 force_refresh=%s", force_refresh, extra=LOG_EXTRA)
+        save_config()
 
     if st.button("执行开机检查"):
         LOGGER.info("点击执行开机检查按钮", extra=LOG_EXTRA)
diff --git a/app/utils/config.py b/app/utils/config.py
index 804cef0..989e12e 100644
--- a/app/utils/config.py
+++ b/app/utils/config.py
@@ -2,6 +2,8 @@ from __future__ import annotations
 
 from dataclasses import dataclass, field
+import json
+import os
 from pathlib import Path
 from typing import Dict, List, Optional
 
 
@@ -17,12 +19,14 @@ class DataPaths:
     root: Path = field(default_factory=_default_root)
     database: Path = field(init=False)
    backups: Path = field(init=False)
+    config_file: Path = field(init=False)
 
     def __post_init__(self) -> None:
         self.root.mkdir(parents=True, exist_ok=True)
         self.database = self.root / "llm_quant.db"
         self.backups = self.root / "backups"
         self.backups.mkdir(parents=True, exist_ok=True)
+        self.config_file = self.root / "config.json"
 
 
 @dataclass
@@ -44,11 +48,51 @@ class AgentWeights:
             "A_macro": self.macro,
         }
 
+DEFAULT_LLM_MODEL_OPTIONS: Dict[str, Dict[str, object]] = {
+    "ollama": {
+        "models": ["llama3", "phi3", "qwen2"],
+        "base_url": "http://localhost:11434",
+        "temperature": 0.2,
+        "timeout": 30.0,
+    },
+    "openai": {
+        "models": ["gpt-4o-mini", "gpt-4.1-mini", "gpt-3.5-turbo"],
+        "base_url": "https://api.openai.com",
+        "temperature": 0.2,
+        "timeout": 30.0,
+    },
+    "deepseek": {
+        "models": ["deepseek-chat", "deepseek-coder"],
+        "base_url": "https://api.deepseek.com",
+        "temperature": 0.2,
+        "timeout": 45.0,
+    },
+    "wenxin": {
+        "models": ["ERNIE-Speed", "ERNIE-Bot"],
+        "base_url": "https://aip.baidubce.com",
+        "temperature": 0.2,
+        "timeout": 60.0,
+    },
+}
+
 DEFAULT_LLM_MODELS: Dict[str, str] = {
-    "ollama": "llama3",
-    "openai": "gpt-4o-mini",
-    "deepseek": "deepseek-chat",
-    "wenxin": "ERNIE-Speed",
+    provider: info["models"][0]
+    for provider, info in DEFAULT_LLM_MODEL_OPTIONS.items()
+}
+
+DEFAULT_LLM_BASE_URLS: Dict[str, str] = {
+    provider: info["base_url"]
+    for provider, info in DEFAULT_LLM_MODEL_OPTIONS.items()
+}
+
+DEFAULT_LLM_TEMPERATURES: Dict[str, float] = {
+    provider: float(info.get("temperature", 0.2))
+    for provider, info in DEFAULT_LLM_MODEL_OPTIONS.items()
+}
+
+DEFAULT_LLM_TIMEOUTS: Dict[str, float] = {
+    provider: float(info.get("timeout", 30.0))
+    for provider, info in DEFAULT_LLM_MODEL_OPTIONS.items()
 }
 
 
@@ -67,6 +111,12 @@ class LLMEndpoint:
         self.provider = (self.provider or "ollama").lower()
         if not self.model:
             self.model = DEFAULT_LLM_MODELS.get(self.provider, DEFAULT_LLM_MODELS["ollama"])
+        if not self.base_url:
+            self.base_url = DEFAULT_LLM_BASE_URLS.get(self.provider)
+        if self.temperature == 0.2 or self.temperature is None:
+            self.temperature = DEFAULT_LLM_TEMPERATURES.get(self.provider, 0.2)
+        if self.timeout == 30.0 or self.timeout is None:
+            self.timeout = DEFAULT_LLM_TIMEOUTS.get(self.provider, 30.0)
 
 
 @dataclass
@@ -95,6 +145,105 @@ class AppConfig:
 
 
 CONFIG = AppConfig()
 
+
+def _endpoint_to_dict(endpoint: LLMEndpoint) -> Dict[str, object]:
+    return {
+        "provider": endpoint.provider,
+        "model": endpoint.model,
+        "base_url": endpoint.base_url,
+        "api_key": endpoint.api_key,
+        "temperature": endpoint.temperature,
+        "timeout": endpoint.timeout,
+    }
+
+
+def _dict_to_endpoint(data: Dict[str, object]) -> LLMEndpoint:
+    payload = {
+        key: data.get(key)
+        for key in ("provider", "model", "base_url", "api_key", "temperature", "timeout")
+        if data.get(key) is not None
+    }
+    return LLMEndpoint(**payload)
+
+
+def _load_from_file(cfg: AppConfig) -> None:
+    path = cfg.data_paths.config_file
+    if not path.exists():
+        return
+    try:
+        with path.open("r", encoding="utf-8") as fh:
+            payload = json.load(fh)
+    except (json.JSONDecodeError, OSError):
+        return
+
+    if isinstance(payload, dict):
+        if "tushare_token" in payload:
+            cfg.tushare_token = payload.get("tushare_token") or None
+        if "force_refresh" in payload:
+            cfg.force_refresh = bool(payload.get("force_refresh"))
+        if "decision_method" in payload:
+            cfg.decision_method = str(payload.get("decision_method") or cfg.decision_method)
+
+        llm_payload = payload.get("llm")
+        if isinstance(llm_payload, dict):
+            primary_data = llm_payload.get("primary")
+            if isinstance(primary_data, dict):
+                cfg.llm.primary = _dict_to_endpoint(primary_data)
+
+            ensemble_data = llm_payload.get("ensemble")
+            if isinstance(ensemble_data, list):
+                cfg.llm.ensemble = [
+                    _dict_to_endpoint(item)
+                    for item in ensemble_data
+                    if isinstance(item, dict)
+                ]
+
+            strategy = llm_payload.get("strategy")
+            if strategy in {"single", "majority"}:
+                cfg.llm.strategy = strategy
+
+            majority = llm_payload.get("majority_threshold")
+            if isinstance(majority, int) and majority > 0:
+                cfg.llm.majority_threshold = majority
+
+
+def save_config(cfg: AppConfig | None = None) -> None:
+    cfg = cfg or CONFIG
+    path = cfg.data_paths.config_file
+    payload = {
+        "tushare_token": cfg.tushare_token,
+        "force_refresh": cfg.force_refresh,
+        "decision_method": cfg.decision_method,
+        "llm": {
+            "strategy": cfg.llm.strategy,
+            "majority_threshold": cfg.llm.majority_threshold,
+            "primary": _endpoint_to_dict(cfg.llm.primary),
+            "ensemble": [_endpoint_to_dict(ep) for ep in cfg.llm.ensemble],
+        },
+    }
+    try:
+        path.parent.mkdir(parents=True, exist_ok=True)
+        with path.open("w", encoding="utf-8") as fh:
+            json.dump(payload, fh, ensure_ascii=False, indent=2)
+    except OSError:
+        pass
+
+
+def _load_env_defaults(cfg: AppConfig) -> None:
+    """Populate sensitive fields from environment variables if present."""
+
+    token = os.getenv("TUSHARE_TOKEN")
+    if token:
+        cfg.tushare_token = token.strip()
+
+    api_key = os.getenv("LLM_API_KEY")
+    if api_key:
+        cfg.llm.primary.api_key = api_key.strip()
+
+
+_load_from_file(CONFIG)
+_load_env_defaults(CONFIG)
+
+
 def get_config() -> AppConfig:
     """Return a mutable global configuration instance."""
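
A quick way to exercise the persistence and provider-default behaviour introduced above is a short round-trip script. This is only a sketch, not part of the patch: it assumes the repository root is on `PYTHONPATH`, that `LLMEndpoint`'s fields keep their existing defaults, and that writing `app/data/config.json` is acceptable in your environment.

```python
# Hypothetical smoke check for the config persistence added in this patch;
# run it from the repository root (it is not included in the diff above).
from app.utils.config import (
    DEFAULT_LLM_BASE_URLS,
    DEFAULT_LLM_MODELS,
    LLMEndpoint,
    get_config,
    save_config,
)

# Per-provider defaults are now derived from DEFAULT_LLM_MODEL_OPTIONS.
assert DEFAULT_LLM_MODELS["deepseek"] == "deepseek-chat"
assert DEFAULT_LLM_BASE_URLS["wenxin"] == "https://aip.baidubce.com"

# __post_init__ fills in the model and base_url, and swaps the generic
# temperature/timeout defaults for the provider-specific ones.
endpoint = LLMEndpoint(provider="deepseek")
print(endpoint.model, endpoint.base_url, endpoint.timeout)

# Persist the change; on the next start _load_from_file() reads it back from
# app/data/config.json, after which TUSHARE_TOKEN / LLM_API_KEY environment
# variables still override the stored values via _load_env_defaults().
cfg = get_config()
cfg.llm.primary = endpoint
save_config(cfg)
```

If the file is later deleted or corrupted, `_load_from_file()` simply returns and the built-in defaults apply, so the script can be re-run safely.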