update
This commit is contained in:
parent
91e8eb5cb3
commit
a6b28ff019
@ -1610,12 +1610,24 @@ def render_settings() -> None:
|
||||
st.header("数据与设置")
|
||||
cfg = get_config()
|
||||
LOGGER.debug("当前 TuShare Token 是否已配置=%s", bool(cfg.tushare_token), extra=LOG_EXTRA)
|
||||
|
||||
# 基础配置
|
||||
col1, col2 = st.columns([2, 1])
|
||||
with col1:
|
||||
token = st.text_input("TuShare Token", value=cfg.tushare_token or "", type="password")
|
||||
with col2:
|
||||
auto_update = st.checkbox(
|
||||
"自动更新数据",
|
||||
value=cfg.auto_update_data,
|
||||
help="勾选后,每次启动程序将自动执行Tushare和RSS数据拉取"
|
||||
)
|
||||
|
||||
if st.button("保存设置"):
|
||||
LOGGER.info("保存设置按钮被点击", extra=LOG_EXTRA)
|
||||
cfg.tushare_token = token.strip() or None
|
||||
cfg.auto_update_data = auto_update
|
||||
LOGGER.info("TuShare Token 更新,是否为空=%s", cfg.tushare_token is None, extra=LOG_EXTRA)
|
||||
LOGGER.info("自动更新数据设置=%s", cfg.auto_update_data, extra=LOG_EXTRA)
|
||||
save_config()
|
||||
st.success("设置已保存,仅在当前会话生效。")
|
||||
|
||||
@ -1674,6 +1686,19 @@ def render_settings() -> None:
|
||||
title_val = st.text_input("备注名称", value=provider_cfg.title or "", key=title_key)
|
||||
base_val = st.text_input("Base URL", value=provider_cfg.base_url or "", key=base_key, help="调用地址,例如:https://api.openai.com")
|
||||
api_val = st.text_input("API Key", value=provider_cfg.api_key or "", key=api_key_key, type="password")
|
||||
|
||||
# 添加缺失的表单字段
|
||||
default_model_val = st.selectbox(
|
||||
"默认模型",
|
||||
options=provider_cfg.models or [""],
|
||||
index=0 if not provider_cfg.models else (provider_cfg.models.index(provider_cfg.default_model) if provider_cfg.default_model in provider_cfg.models else 0),
|
||||
key=f"provider_default_model_{selected_provider}"
|
||||
)
|
||||
temp_val = st.number_input("默认温度", value=provider_cfg.default_temperature, min_value=0.0, max_value=2.0, step=0.1, key=temp_key)
|
||||
timeout_val = st.number_input("默认超时(秒)", value=provider_cfg.default_timeout, min_value=1, max_value=300, step=1, key=timeout_key)
|
||||
prompt_template_val = st.text_area("Prompt 模板", value=provider_cfg.prompt_template or "", key=prompt_key)
|
||||
enabled_val = st.checkbox("启用", value=provider_cfg.enabled, key=enabled_key)
|
||||
mode_val = st.selectbox("模式", options=["openai", "ollama"], index=0 if provider_cfg.mode == "openai" else 1, key=mode_key)
|
||||
st.markdown("可用模型:")
|
||||
if provider_cfg.models:
|
||||
st.code("\n".join(provider_cfg.models), language="text")
|
||||
@ -2299,6 +2324,41 @@ def render_tests() -> None:
|
||||
def main() -> None:
|
||||
LOGGER.info("初始化 Streamlit UI", extra=LOG_EXTRA)
|
||||
st.set_page_config(page_title="多智能体个人投资助理", layout="wide")
|
||||
|
||||
# 检查是否需要自动更新数据
|
||||
cfg = get_config()
|
||||
if cfg.auto_update_data:
|
||||
LOGGER.info("检测到自动更新数据选项已启用,开始执行数据拉取", extra=LOG_EXTRA)
|
||||
try:
|
||||
# 初始化数据库
|
||||
from app.data.schema import initialize_database
|
||||
initialize_database()
|
||||
|
||||
# 执行开机检查(包含数据拉取)
|
||||
from app.ingest.checker import run_boot_check
|
||||
with st.spinner("正在自动更新数据..."):
|
||||
def progress_hook(message: str, progress: float) -> None:
|
||||
st.write(f"📊 {message} ({progress:.1%})")
|
||||
|
||||
report = run_boot_check(
|
||||
days=30, # 最近30天
|
||||
auto_fetch=True,
|
||||
progress_hook=progress_hook,
|
||||
force_refresh=False
|
||||
)
|
||||
|
||||
# 执行RSS新闻拉取
|
||||
from app.ingest.rss import ingest_configured_rss
|
||||
rss_count = ingest_configured_rss(hours_back=24, max_items_per_feed=50)
|
||||
|
||||
LOGGER.info("自动数据更新完成:日线数据覆盖%s-%s,RSS新闻%s条",
|
||||
report.start, report.end, rss_count, extra=LOG_EXTRA)
|
||||
st.success(f"✅ 自动数据更新完成:获取RSS新闻 {rss_count} 条")
|
||||
|
||||
except Exception as exc:
|
||||
LOGGER.exception("自动数据更新失败", extra=LOG_EXTRA)
|
||||
st.error(f"❌ 自动数据更新失败:{exc}")
|
||||
|
||||
render_global_dashboard()
|
||||
tabs = st.tabs(["今日计划", "回测与复盘", "数据与设置", "自检测试"])
|
||||
LOGGER.debug("Tabs 初始化完成:%s", ["今日计划", "回测与复盘", "数据与设置", "自检测试"], extra=LOG_EXTRA)
|
||||
|
||||
@ -339,6 +339,7 @@ class AppConfig:
|
||||
data_paths: DataPaths = field(default_factory=DataPaths)
|
||||
agent_weights: AgentWeights = field(default_factory=AgentWeights)
|
||||
force_refresh: bool = False
|
||||
auto_update_data: bool = False
|
||||
llm_providers: Dict[str, LLMProvider] = field(default_factory=_default_llm_providers)
|
||||
llm: LLMConfig = field(default_factory=LLMConfig)
|
||||
departments: Dict[str, DepartmentSettings] = field(default_factory=_default_departments)
|
||||
@ -399,6 +400,8 @@ def _load_from_file(cfg: AppConfig) -> None:
|
||||
cfg.tushare_token = payload.get("tushare_token") or None
|
||||
if "force_refresh" in payload:
|
||||
cfg.force_refresh = bool(payload.get("force_refresh"))
|
||||
if "auto_update_data" in payload:
|
||||
cfg.auto_update_data = bool(payload.get("auto_update_data"))
|
||||
if "decision_method" in payload:
|
||||
cfg.decision_method = str(payload.get("decision_method") or cfg.decision_method)
|
||||
|
||||
@ -579,6 +582,7 @@ def save_config(cfg: AppConfig | None = None) -> None:
|
||||
payload = {
|
||||
"tushare_token": cfg.tushare_token,
|
||||
"force_refresh": cfg.force_refresh,
|
||||
"auto_update_data": cfg.auto_update_data,
|
||||
"decision_method": cfg.decision_method,
|
||||
"rss_sources": cfg.rss_sources,
|
||||
"agent_weights": cfg.agent_weights.as_dict(),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user