This commit is contained in:
sam 2025-10-03 20:41:07 +08:00
parent 91e8eb5cb3
commit a6b28ff019
2 changed files with 65 additions and 1 deletions

View File

@ -1610,12 +1610,24 @@ def render_settings() -> None:
st.header("数据与设置")
cfg = get_config()
LOGGER.debug("当前 TuShare Token 是否已配置=%s", bool(cfg.tushare_token), extra=LOG_EXTRA)
# 基础配置
col1, col2 = st.columns([2, 1])
with col1:
token = st.text_input("TuShare Token", value=cfg.tushare_token or "", type="password")
with col2:
auto_update = st.checkbox(
"自动更新数据",
value=cfg.auto_update_data,
help="勾选后每次启动程序将自动执行Tushare和RSS数据拉取"
)
if st.button("保存设置"):
LOGGER.info("保存设置按钮被点击", extra=LOG_EXTRA)
cfg.tushare_token = token.strip() or None
cfg.auto_update_data = auto_update
LOGGER.info("TuShare Token 更新,是否为空=%s", cfg.tushare_token is None, extra=LOG_EXTRA)
LOGGER.info("自动更新数据设置=%s", cfg.auto_update_data, extra=LOG_EXTRA)
save_config()
st.success("设置已保存,仅在当前会话生效。")
@ -1674,6 +1686,19 @@ def render_settings() -> None:
title_val = st.text_input("备注名称", value=provider_cfg.title or "", key=title_key)
base_val = st.text_input("Base URL", value=provider_cfg.base_url or "", key=base_key, help="调用地址例如https://api.openai.com")
api_val = st.text_input("API Key", value=provider_cfg.api_key or "", key=api_key_key, type="password")
# 添加缺失的表单字段
default_model_val = st.selectbox(
"默认模型",
options=provider_cfg.models or [""],
index=0 if not provider_cfg.models else (provider_cfg.models.index(provider_cfg.default_model) if provider_cfg.default_model in provider_cfg.models else 0),
key=f"provider_default_model_{selected_provider}"
)
temp_val = st.number_input("默认温度", value=provider_cfg.default_temperature, min_value=0.0, max_value=2.0, step=0.1, key=temp_key)
timeout_val = st.number_input("默认超时(秒)", value=provider_cfg.default_timeout, min_value=1, max_value=300, step=1, key=timeout_key)
prompt_template_val = st.text_area("Prompt 模板", value=provider_cfg.prompt_template or "", key=prompt_key)
enabled_val = st.checkbox("启用", value=provider_cfg.enabled, key=enabled_key)
mode_val = st.selectbox("模式", options=["openai", "ollama"], index=0 if provider_cfg.mode == "openai" else 1, key=mode_key)
st.markdown("可用模型:")
if provider_cfg.models:
st.code("\n".join(provider_cfg.models), language="text")
@ -2299,6 +2324,41 @@ def render_tests() -> None:
def main() -> None:
LOGGER.info("初始化 Streamlit UI", extra=LOG_EXTRA)
st.set_page_config(page_title="多智能体个人投资助理", layout="wide")
# 检查是否需要自动更新数据
cfg = get_config()
if cfg.auto_update_data:
LOGGER.info("检测到自动更新数据选项已启用,开始执行数据拉取", extra=LOG_EXTRA)
try:
# 初始化数据库
from app.data.schema import initialize_database
initialize_database()
# 执行开机检查(包含数据拉取)
from app.ingest.checker import run_boot_check
with st.spinner("正在自动更新数据..."):
def progress_hook(message: str, progress: float) -> None:
st.write(f"📊 {message} ({progress:.1%})")
report = run_boot_check(
days=30, # 最近30天
auto_fetch=True,
progress_hook=progress_hook,
force_refresh=False
)
# 执行RSS新闻拉取
from app.ingest.rss import ingest_configured_rss
rss_count = ingest_configured_rss(hours_back=24, max_items_per_feed=50)
LOGGER.info("自动数据更新完成:日线数据覆盖%s-%sRSS新闻%s",
report.start, report.end, rss_count, extra=LOG_EXTRA)
st.success(f"✅ 自动数据更新完成获取RSS新闻 {rss_count}")
except Exception as exc:
LOGGER.exception("自动数据更新失败", extra=LOG_EXTRA)
st.error(f"❌ 自动数据更新失败:{exc}")
render_global_dashboard()
tabs = st.tabs(["今日计划", "回测与复盘", "数据与设置", "自检测试"])
LOGGER.debug("Tabs 初始化完成:%s", ["今日计划", "回测与复盘", "数据与设置", "自检测试"], extra=LOG_EXTRA)

View File

@ -339,6 +339,7 @@ class AppConfig:
data_paths: DataPaths = field(default_factory=DataPaths)
agent_weights: AgentWeights = field(default_factory=AgentWeights)
force_refresh: bool = False
auto_update_data: bool = False
llm_providers: Dict[str, LLMProvider] = field(default_factory=_default_llm_providers)
llm: LLMConfig = field(default_factory=LLMConfig)
departments: Dict[str, DepartmentSettings] = field(default_factory=_default_departments)
@ -399,6 +400,8 @@ def _load_from_file(cfg: AppConfig) -> None:
cfg.tushare_token = payload.get("tushare_token") or None
if "force_refresh" in payload:
cfg.force_refresh = bool(payload.get("force_refresh"))
if "auto_update_data" in payload:
cfg.auto_update_data = bool(payload.get("auto_update_data"))
if "decision_method" in payload:
cfg.decision_method = str(payload.get("decision_method") or cfg.decision_method)
@ -579,6 +582,7 @@ def save_config(cfg: AppConfig | None = None) -> None:
payload = {
"tushare_token": cfg.tushare_token,
"force_refresh": cfg.force_refresh,
"auto_update_data": cfg.auto_update_data,
"decision_method": cfg.decision_method,
"rss_sources": cfg.rss_sources,
"agent_weights": cfg.agent_weights.as_dict(),