update
This commit is contained in:
parent
db3df13462
commit
478a9a64af
65
app/ui/shared.py
Normal file
65
app/ui/shared.py
Normal file
@ -0,0 +1,65 @@
|
||||
"""Shared utilities and constants for Streamlit UI views."""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from app.utils.db import db_session
|
||||
from app.utils.logging import get_logger
|
||||
|
||||
LOGGER = get_logger(__name__)
|
||||
LOG_EXTRA = {"stage": "ui"}
|
||||
|
||||
|
||||
def get_query_params() -> dict[str, list[str]]:
    """Return the current URL query parameters as a plain dict.

    Falls back to an empty dict when Streamlit's query-param API is
    unavailable (e.g. when called outside a running Streamlit session).
    """
    try:
        params = dict(st.query_params)
    except Exception:  # noqa: BLE001
        params = {}
    return params
|
||||
|
||||
|
||||
def set_query_params(**kwargs: object) -> None:
    """Update URL query parameters, ignoring failures in unsupported contexts.

    ``None`` values are dropped; nothing is written when no value remains.
    """
    filtered = {key: value for key, value in kwargs.items() if value is not None}
    if not filtered:
        return
    try:
        st.query_params.update(filtered)
    except Exception:  # noqa: BLE001
        # Best effort: query params are unavailable in some embedding contexts.
        pass
|
||||
|
||||
|
||||
def get_latest_trade_date() -> Optional[date]:
    """Fetch the most recent trade date from the database.

    Returns ``None`` when the query fails, no row exists, or the stored
    value cannot be parsed as either ``YYYYMMDD`` or an ISO date.
    """
    try:
        with db_session(read_only=True) as conn:
            row = conn.execute(
                "SELECT trade_date FROM daily ORDER BY trade_date DESC LIMIT 1"
            ).fetchone()
    except Exception:  # noqa: BLE001
        LOGGER.exception("查询最新交易日失败", extra=LOG_EXTRA)
        return None
    if not row:
        return None
    raw_value = row["trade_date"]
    if not raw_value:
        return None
    text = str(raw_value)
    # Try the compact Tushare format first, then fall back to ISO-8601.
    parsers = (
        lambda value: datetime.strptime(value, "%Y%m%d").date(),
        lambda value: datetime.fromisoformat(value).date(),
    )
    for parse in parsers:
        try:
            return parse(text)
        except ValueError:
            continue
    LOGGER.warning("无法解析交易日:%s", raw_value, extra=LOG_EXTRA)
    return None
|
||||
|
||||
|
||||
def default_backtest_range(window_days: int = 60) -> tuple[date, date]:
    """Return a sensible (start, end) date range for backtests.

    Fixes the previous docstring, which claimed "(end, start)" while the
    function has always returned ``(start, end)``.

    The end date is the latest trade date found in the database, falling
    back to today when the database is empty or unavailable; the start date
    is ``window_days`` calendar days earlier, clamped so it never exceeds
    the end date (relevant only for negative ``window_days``).

    Args:
        window_days: Size of the lookback window in calendar days.

    Returns:
        Tuple ``(start, end)`` with ``start <= end``.
    """
    latest = get_latest_trade_date() or date.today()
    start = latest - timedelta(days=window_days)
    if start > latest:  # only possible when window_days < 0
        start = latest
    return start, latest
|
||||
File diff suppressed because it is too large
Load Diff
24
app/ui/views/__init__.py
Normal file
24
app/ui/views/__init__.py
Normal file
@ -0,0 +1,24 @@
|
||||
"""View modules for Streamlit UI tabs."""
|
||||
|
||||
from .today import render_today_plan
|
||||
from .pool import render_pool_overview
|
||||
from .backtest import render_backtest_review
|
||||
from .market import render_market_visualization
|
||||
from .logs import render_log_viewer
|
||||
from .settings import render_config_overview, render_llm_settings, render_data_settings
|
||||
from .tests import render_tests
|
||||
from .dashboard import render_global_dashboard, update_dashboard_sidebar
|
||||
|
||||
# Explicit public API of the views package; keep in sync with the
# re-exports imported above.
__all__ = [
    "render_today_plan",
    "render_pool_overview",
    "render_backtest_review",
    "render_market_visualization",
    "render_log_viewer",
    "render_config_overview",
    "render_llm_settings",
    "render_data_settings",
    "render_tests",
    "render_global_dashboard",
    "update_dashboard_sidebar",
]
|
||||
790
app/ui/views/backtest.py
Normal file
790
app/ui/views/backtest.py
Normal file
@ -0,0 +1,790 @@
|
||||
"""回测与复盘相关视图。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from dataclasses import asdict
|
||||
from datetime import date
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import plotly.express as px
|
||||
import requests
|
||||
from requests.exceptions import RequestException
|
||||
import streamlit as st
|
||||
|
||||
from app.agents.base import AgentContext
|
||||
from app.agents.game import Decision
|
||||
from app.agents.registry import default_agents
|
||||
from app.backtest.decision_env import DecisionEnv, ParameterSpec
|
||||
from app.backtest.engine import BacktestEngine, PortfolioState, BtConfig, run_backtest
|
||||
from app.ingest.checker import run_boot_check
|
||||
from app.ingest.tushare import run_ingestion
|
||||
from app.llm.client import run_llm
|
||||
from app.llm.metrics import reset as reset_llm_metrics
|
||||
from app.llm.metrics import snapshot as snapshot_llm_metrics
|
||||
from app.utils import alerts
|
||||
from app.utils.config import get_config, save_config
|
||||
from app.utils.tuning import log_tuning_result
|
||||
|
||||
from app.ui.shared import LOGGER, LOG_EXTRA, default_backtest_range
|
||||
from app.ui.views.dashboard import update_dashboard_sidebar
|
||||
|
||||
# Session-state keys caching the most recent single/batch tuning results so
# they survive Streamlit reruns.
_DECISION_ENV_SINGLE_RESULT_KEY = "decision_env_single_result"
_DECISION_ENV_BATCH_RESULTS_KEY = "decision_env_batch_results"


def render_backtest_review() -> None:
    """Render the backtest execution, parameter-tuning and review page.

    Fixes relative to the previous revision:
    - ``cfg`` was referenced in the RL-tuning tab but never defined (only
      ``app_cfg`` exists) -> NameError on use; all references now use
      ``app_cfg``.
    - ``datetime`` was used for the default experiment id but only ``date``
      is imported at module level; imported locally here.
    - ``db_session`` was used by the comparison expander but never imported
      in this module; imported locally here.
    """
    # NOTE(review): these two imports are missing from the module header;
    # they are imported locally so this fix is self-contained.
    from datetime import datetime
    from app.utils.db import db_session

    st.header("回测与复盘")
    st.caption("1. 基于历史数据复盘当前策略;2. 借助强化学习/调参探索更优参数组合。")
    app_cfg = get_config()
    default_start, default_end = default_backtest_range(window_days=60)
    LOGGER.debug(
        "回测默认参数:start=%s end=%s universe=%s target=%s stop=%s hold_days=%s",
        default_start,
        default_end,
        "000001.SZ",
        0.035,
        -0.015,
        10,
        extra=LOG_EXTRA,
    )

    st.markdown("### 回测参数")
    col1, col2 = st.columns(2)
    start_date = col1.date_input("开始日期", value=default_start, key="bt_start_date")
    end_date = col2.date_input("结束日期", value=default_end, key="bt_end_date")
    universe_text = st.text_input("股票列表(逗号分隔)", value="000001.SZ", key="bt_universe")
    col_target, col_stop, col_hold = st.columns(3)
    target = col_target.number_input("目标收益(例:0.035 表示 3.5%)", value=0.035, step=0.005, format="%.3f", key="bt_target")
    stop = col_stop.number_input("止损收益(例:-0.015 表示 -1.5%)", value=-0.015, step=0.005, format="%.3f", key="bt_stop")
    hold_days = col_hold.number_input("持有期(交易日)", value=10, step=1, key="bt_hold_days")
    LOGGER.debug(
        "当前回测表单输入:start=%s end=%s universe_text=%s target=%.3f stop=%.3f hold_days=%s",
        start_date,
        end_date,
        universe_text,
        target,
        stop,
        hold_days,
        extra=LOG_EXTRA,
    )

    tab_backtest, tab_rl = st.tabs(["回测验证", "强化学习调参"])

    with tab_backtest:
        st.markdown("#### 回测执行")
        if st.button("运行回测", key="bt_run_button"):
            LOGGER.info("用户点击运行回测按钮", extra=LOG_EXTRA)
            decision_log_container = st.container()
            status_box = st.status("准备执行回测...", expanded=True)
            llm_stats_placeholder = st.empty()
            decision_entries: List[str] = []

            def _decision_callback(ts_code: str, trade_dt: date, ctx: AgentContext, decision: Decision) -> None:
                # Stream per-decision progress into the log container and
                # refresh the live LLM usage statistics.
                ts_label = trade_dt.isoformat()
                summary = ""
                for dept_decision in decision.department_decisions.values():
                    if getattr(dept_decision, "summary", ""):
                        summary = str(dept_decision.summary)
                        break
                entry_lines = [
                    f"**{ts_label} {ts_code}** → {decision.action.value} (信心 {decision.confidence:.2f})",
                ]
                if summary:
                    entry_lines.append(f"摘要:{summary}")
                dep_highlights = []
                for dept_code, dept_decision in decision.department_decisions.items():
                    dep_highlights.append(
                        f"{dept_code}:{dept_decision.action.value}({dept_decision.confidence:.2f})"
                    )
                if dep_highlights:
                    entry_lines.append("部门意见:" + ";".join(dep_highlights))
                decision_entries.append(" \n".join(entry_lines))
                # Only render the most recent 200 entries to keep reruns fast.
                decision_log_container.markdown("\n\n".join(decision_entries[-200:]))
                status_box.write(f"{ts_label} {ts_code} → {decision.action.value} (信心 {decision.confidence:.2f})")
                stats = snapshot_llm_metrics()
                llm_stats_placeholder.json(
                    {
                        "LLM 调用次数": stats.get("total_calls", 0),
                        "Prompt Tokens": stats.get("total_prompt_tokens", 0),
                        "Completion Tokens": stats.get("total_completion_tokens", 0),
                        "按 Provider": stats.get("provider_calls", {}),
                        "按模型": stats.get("model_calls", {}),
                    }
                )
                update_dashboard_sidebar(stats)

            reset_llm_metrics()
            status_box.update(label="执行回测中...", state="running")
            try:
                universe = [code.strip() for code in universe_text.split(',') if code.strip()]
                LOGGER.info(
                    "回测参数:start=%s end=%s universe=%s target=%s stop=%s hold_days=%s",
                    start_date,
                    end_date,
                    universe,
                    target,
                    stop,
                    hold_days,
                    extra=LOG_EXTRA,
                )
                backtest_cfg = BtConfig(
                    id="streamlit_demo",
                    name="Streamlit Demo Strategy",
                    start_date=start_date,
                    end_date=end_date,
                    universe=universe,
                    params={
                        "target": target,
                        "stop": stop,
                        "hold_days": int(hold_days),
                    },
                )
                result = run_backtest(backtest_cfg, decision_callback=_decision_callback)
                LOGGER.info(
                    "回测完成:nav_records=%s trades=%s",
                    len(result.nav_series),
                    len(result.trades),
                    extra=LOG_EXTRA,
                )
                status_box.update(label="回测执行完成", state="complete")
                st.success("回测执行完成,详见下方结果与统计。")
                metrics = snapshot_llm_metrics()
                llm_stats_placeholder.json(
                    {
                        "LLM 调用次数": metrics.get("total_calls", 0),
                        "Prompt Tokens": metrics.get("total_prompt_tokens", 0),
                        "Completion Tokens": metrics.get("total_completion_tokens", 0),
                        "按 Provider": metrics.get("provider_calls", {}),
                        "按模型": metrics.get("model_calls", {}),
                    }
                )
                update_dashboard_sidebar(metrics)
                st.session_state["backtest_last_result"] = {"nav_records": result.nav_series, "trades": result.trades}
                st.json(st.session_state["backtest_last_result"])
            except Exception as exc:  # noqa: BLE001
                LOGGER.exception("回测执行失败", extra=LOG_EXTRA)
                status_box.update(label="回测执行失败", state="error")
                st.error(f"回测执行失败:{exc}")

        last_result = st.session_state.get("backtest_last_result")
        if last_result:
            st.markdown("#### 最近回测输出")
            st.json(last_result)

        st.divider()
        # Comparison view for multiple historical backtest configurations.
        with st.expander("历史回测结果对比", expanded=False):
            st.caption("从历史回测配置中选择多个进行净值曲线与指标对比。")
            normalize_to_one = st.checkbox("归一化到 1 起点", value=True, key="bt_cmp_normalize")
            use_log_y = st.checkbox("对数坐标", value=False, key="bt_cmp_log_y")
            metric_options = ["总收益", "最大回撤", "交易数", "平均换手", "风险事件"]
            selected_metrics = st.multiselect("显示指标列", metric_options, default=metric_options, key="bt_cmp_metrics")
            try:
                with db_session(read_only=True) as conn:
                    cfg_rows = conn.execute(
                        "SELECT id, name FROM bt_config ORDER BY rowid DESC LIMIT 50"
                    ).fetchall()
            except Exception:  # noqa: BLE001
                LOGGER.exception("读取 bt_config 失败", extra=LOG_EXTRA)
                cfg_rows = []
            cfg_options = [f"{row['id']} | {row['name']}" for row in cfg_rows]
            selected_labels = st.multiselect("选择配置", cfg_options, default=cfg_options[:2], key="bt_cmp_configs")
            selected_ids = [label.split(" | ")[0].strip() for label in selected_labels]
            nav_df = pd.DataFrame()
            rpt_df = pd.DataFrame()
            if selected_ids:
                # Parameterized IN (...) clause — one "?" per selected id.
                placeholders = ",".join(["?"] * len(selected_ids))
                try:
                    with db_session(read_only=True) as conn:
                        nav_df = pd.read_sql_query(
                            f"SELECT cfg_id, trade_date, nav FROM bt_nav WHERE cfg_id IN ({placeholders})",
                            conn,
                            params=tuple(selected_ids),
                        )
                        rpt_df = pd.read_sql_query(
                            f"SELECT cfg_id, summary FROM bt_report WHERE cfg_id IN ({placeholders})",
                            conn,
                            params=tuple(selected_ids),
                        )
                except Exception:  # noqa: BLE001
                    LOGGER.exception("读取回测结果失败", extra=LOG_EXTRA)
                    st.error("读取回测结果失败")
                    nav_df = pd.DataFrame()
                    rpt_df = pd.DataFrame()
                if not nav_df.empty:
                    try:
                        nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"], errors="coerce")
                        # Date-window filter over the combined NAV history.
                        overall_min = pd.to_datetime(nav_df["trade_date"].min()).date()
                        overall_max = pd.to_datetime(nav_df["trade_date"].max()).date()
                        col_d1, col_d2 = st.columns(2)
                        start_filter = col_d1.date_input("起始日期", value=overall_min, key="bt_cmp_start")
                        end_filter = col_d2.date_input("结束日期", value=overall_max, key="bt_cmp_end")
                        if start_filter > end_filter:
                            start_filter, end_filter = end_filter, start_filter
                        mask = (nav_df["trade_date"].dt.date >= start_filter) & (nav_df["trade_date"].dt.date <= end_filter)
                        nav_df = nav_df.loc[mask]
                        pivot = nav_df.pivot_table(index="trade_date", columns="cfg_id", values="nav")
                        if normalize_to_one:
                            # Rebase each curve to 1.0 at its first valid point.
                            pivot = pivot.apply(lambda s: s / s.dropna().iloc[0] if s.dropna().size else s)
                        import plotly.graph_objects as go

                        fig = go.Figure()
                        for col in pivot.columns:
                            fig.add_trace(go.Scatter(x=pivot.index, y=pivot[col], mode="lines", name=str(col)))
                        fig.update_layout(height=300, margin=dict(l=10, r=10, t=30, b=10))
                        if use_log_y:
                            fig.update_yaxes(type="log")
                        st.plotly_chart(fig, width='stretch')
                        # Optional CSV export of the pivoted curves.
                        try:
                            csv_buf = pivot.reset_index()
                            csv_buf.columns = ["trade_date"] + [str(c) for c in pivot.columns]
                            st.download_button(
                                "下载曲线(CSV)",
                                data=csv_buf.to_csv(index=False),
                                file_name="bt_nav_compare.csv",
                                mime="text/csv",
                                key="dl_nav_compare",
                            )
                        except Exception:  # noqa: BLE001
                            pass
                    except Exception:  # noqa: BLE001
                        LOGGER.debug("绘制对比曲线失败", extra=LOG_EXTRA)
                if not rpt_df.empty:
                    try:
                        metrics_rows: List[Dict[str, object]] = []
                        for _, row in rpt_df.iterrows():
                            cfg_id = row["cfg_id"]
                            try:
                                summary = json.loads(row["summary"]) if isinstance(row["summary"], str) else (row["summary"] or {})
                            except json.JSONDecodeError:
                                summary = {}
                            record = {
                                "cfg_id": cfg_id,
                                "总收益": summary.get("total_return"),
                                "最大回撤": summary.get("max_drawdown"),
                                "交易数": summary.get("trade_count"),
                                "平均换手": summary.get("avg_turnover"),
                                "风险事件": summary.get("risk_events"),
                            }
                            # Keep cfg_id plus only the user-selected metric columns.
                            metrics_rows.append({k: v for k, v in record.items() if (k == "cfg_id" or k in selected_metrics)})
                        if metrics_rows:
                            dfm = pd.DataFrame(metrics_rows)
                            st.dataframe(dfm, hide_index=True, width='stretch')
                            try:
                                st.download_button(
                                    "下载指标(CSV)",
                                    data=dfm.to_csv(index=False),
                                    file_name="bt_metrics_compare.csv",
                                    mime="text/csv",
                                    key="dl_metrics_compare",
                                )
                            except Exception:  # noqa: BLE001
                                pass
                    except Exception:  # noqa: BLE001
                        LOGGER.debug("渲染指标表失败", extra=LOG_EXTRA)
            else:
                st.info("请选择至少一个配置进行对比。")

    with tab_rl:
        st.caption("使用 DecisionEnv 对代理权重进行强化学习调参,支持单次与批量实验。")
        with st.expander("离线调参实验 (DecisionEnv)", expanded=False):
            st.caption(
                "使用 DecisionEnv 对代理权重做离线调参。请选择需要优化的代理并设定权重范围,"
                "系统会运行一次回测并返回收益、回撤等指标。若 LLM 网络不可用,将返回失败标记。"
            )

            disable_departments = st.checkbox(
                "禁用部门 LLM(仅规则代理,适合离线快速评估)",
                value=True,
                help="关闭部门调用后不依赖外部 LLM 网络,仅根据规则代理权重模拟。",
            )

            default_experiment_id = f"streamlit_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
            experiment_id = st.text_input(
                "实验 ID",
                value=default_experiment_id,
                help="用于在 tuning_results 表中区分不同实验。",
            )
            strategy_label = st.text_input(
                "策略说明",
                value="DecisionEnv",
                help="可选:为本次调参记录一个策略名称或备注。",
            )

            agent_objects = default_agents()
            agent_names = [agent.name for agent in agent_objects]
            if not agent_names:
                st.info("暂无可调整的代理。")
            else:
                selected_agents = st.multiselect(
                    "选择调参的代理权重",
                    agent_names,
                    default=agent_names[:2],
                    key="decision_env_agents",
                )

                specs: List[ParameterSpec] = []
                action_values: List[float] = []
                range_valid = True
                for agent_name in selected_agents:
                    col_min, col_max, col_action = st.columns([1, 1, 2])
                    min_key = f"decision_env_min_{agent_name}"
                    max_key = f"decision_env_max_{agent_name}"
                    action_key = f"decision_env_action_{agent_name}"
                    min_val = col_min.number_input(
                        f"{agent_name} 最小权重",
                        min_value=0.0,
                        max_value=1.0,
                        value=0.0,
                        step=0.05,
                        key=min_key,
                    )
                    max_val = col_max.number_input(
                        f"{agent_name} 最大权重",
                        min_value=0.0,
                        max_value=1.0,
                        value=1.0,
                        step=0.05,
                        key=max_key,
                    )
                    if max_val <= min_val:
                        range_valid = False
                    action_val = col_action.slider(
                        f"{agent_name} 动作 (0-1)",
                        min_value=0.0,
                        max_value=1.0,
                        value=0.5,
                        step=0.01,
                        key=action_key,
                    )
                    specs.append(
                        ParameterSpec(
                            name=f"weight_{agent_name}",
                            target=f"agent_weights.{agent_name}",
                            minimum=min_val,
                            maximum=max_val,
                        )
                    )
                    action_values.append(action_val)

                run_decision_env = st.button("执行单次调参", key="run_decision_env_button")
                just_finished_single = False
                if run_decision_env:
                    if not selected_agents:
                        st.warning("请至少选择一个代理进行调参。")
                    elif not range_valid:
                        st.error("请确保所有代理的最大权重大于最小权重。")
                    else:
                        LOGGER.info(
                            "离线调参(单次)按钮点击,已选择代理=%s 动作=%s disable_departments=%s",
                            selected_agents,
                            action_values,
                            disable_departments,
                            extra=LOG_EXTRA,
                        )
                        # FIX: was `cfg.agent_weights` (undefined name).
                        baseline_weights = app_cfg.agent_weights.as_dict()
                        for agent in agent_objects:
                            baseline_weights.setdefault(agent.name, 1.0)

                        universe_env = [code.strip() for code in universe_text.split(',') if code.strip()]
                        if not universe_env:
                            st.error("请先指定至少一个股票代码。")
                        else:
                            bt_cfg_env = BtConfig(
                                id="decision_env_streamlit",
                                name="DecisionEnv Streamlit",
                                start_date=start_date,
                                end_date=end_date,
                                universe=universe_env,
                                params={
                                    "target": target,
                                    "stop": stop,
                                    "hold_days": int(hold_days),
                                },
                                method=app_cfg.decision_method,
                            )
                            env = DecisionEnv(
                                bt_config=bt_cfg_env,
                                parameter_specs=specs,
                                baseline_weights=baseline_weights,
                                disable_departments=disable_departments,
                            )
                            env.reset()
                            LOGGER.debug(
                                "离线调参(单次)启动 DecisionEnv:cfg=%s 参数维度=%s",
                                bt_cfg_env,
                                len(specs),
                                extra=LOG_EXTRA,
                            )
                            with st.spinner("正在执行离线调参……"):
                                try:
                                    observation, reward, done, info = env.step(action_values)
                                    LOGGER.info(
                                        "离线调参(单次)完成,obs=%s reward=%.4f done=%s",
                                        observation,
                                        reward,
                                        done,
                                        extra=LOG_EXTRA,
                                    )
                                except Exception as exc:  # noqa: BLE001
                                    LOGGER.exception("DecisionEnv 调用失败", extra=LOG_EXTRA)
                                    st.error(f"离线调参失败:{exc}")
                                    st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
                                else:
                                    if observation.get("failure"):
                                        st.error("调参失败:回测执行未完成,可能是 LLM 网络不可用或参数异常。")
                                        st.json(observation)
                                        st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
                                    else:
                                        resolved_experiment_id = experiment_id or str(uuid.uuid4())
                                        resolved_strategy = strategy_label or "DecisionEnv"
                                        action_payload = {
                                            name: value
                                            for name, value in zip(selected_agents, action_values)
                                        }
                                        metrics_payload = dict(observation)
                                        metrics_payload["reward"] = reward
                                        log_success = False
                                        try:
                                            log_tuning_result(
                                                experiment_id=resolved_experiment_id,
                                                strategy=resolved_strategy,
                                                action=action_payload,
                                                reward=reward,
                                                metrics=metrics_payload,
                                                weights=info.get("weights", {}),
                                            )
                                        except Exception:  # noqa: BLE001
                                            LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA)
                                        else:
                                            log_success = True
                                            LOGGER.info(
                                                "离线调参(单次)日志写入成功:experiment=%s strategy=%s",
                                                resolved_experiment_id,
                                                resolved_strategy,
                                                extra=LOG_EXTRA,
                                            )
                                        st.session_state[_DECISION_ENV_SINGLE_RESULT_KEY] = {
                                            "observation": dict(observation),
                                            "reward": float(reward),
                                            "weights": info.get("weights", {}),
                                            "nav_series": info.get("nav_series"),
                                            "trades": info.get("trades"),
                                            "portfolio_snapshots": info.get("portfolio_snapshots"),
                                            "portfolio_trades": info.get("portfolio_trades"),
                                            "risk_breakdown": info.get("risk_breakdown"),
                                            "selected_agents": list(selected_agents),
                                            "action_values": list(action_values),
                                            "experiment_id": resolved_experiment_id,
                                            "strategy_label": resolved_strategy,
                                            "logged": log_success,
                                        }
                                        just_finished_single = True

                single_result = st.session_state.get(_DECISION_ENV_SINGLE_RESULT_KEY)
                if single_result:
                    if just_finished_single:
                        st.success("离线调参完成")
                    else:
                        st.success("离线调参结果(最近一次运行)")
                    st.caption(
                        f"实验 ID:{single_result.get('experiment_id', '-') } | 策略:{single_result.get('strategy_label', 'DecisionEnv')}"
                    )
                    observation = single_result.get("observation", {})
                    reward = float(single_result.get("reward", 0.0))
                    col_metrics = st.columns(4)
                    col_metrics[0].metric("总收益", f"{observation.get('total_return', 0.0):+.2%}")
                    col_metrics[1].metric("最大回撤", f"{observation.get('max_drawdown', 0.0):+.2%}")
                    col_metrics[2].metric("波动率", f"{observation.get('volatility', 0.0):+.2%}")
                    col_metrics[3].metric("奖励", f"{reward:+.4f}")

                    turnover_ratio = float(observation.get("turnover", 0.0) or 0.0)
                    turnover_value = float(observation.get("turnover_value", 0.0) or 0.0)
                    risk_count = float(observation.get("risk_count", 0.0) or 0.0)
                    col_metrics_extra = st.columns(3)
                    col_metrics_extra[0].metric("平均换手率", f"{turnover_ratio:.2%}")
                    col_metrics_extra[1].metric("成交额", f"{turnover_value:,.0f}")
                    col_metrics_extra[2].metric("风险事件数", f"{int(risk_count)}")

                    weights_dict = single_result.get("weights") or {}
                    if weights_dict:
                        st.write("调参后权重:")
                        st.json(weights_dict)
                        if st.button("保存这些权重为默认配置", key="save_decision_env_weights_single"):
                            try:
                                # FIX: was `cfg` (undefined name).
                                app_cfg.agent_weights.update_from_dict(weights_dict)
                                save_config(app_cfg)
                            except Exception as exc:  # noqa: BLE001
                                LOGGER.exception("保存权重失败", extra={**LOG_EXTRA, "error": str(exc)})
                                st.error(f"写入配置失败:{exc}")
                            else:
                                st.success("代理权重已写入 config.json")

                    if single_result.get("logged"):
                        st.caption("调参结果已写入 tuning_results 表。")

                    nav_series = single_result.get("nav_series") or []
                    if nav_series:
                        try:
                            nav_df = pd.DataFrame(nav_series)
                            if {"trade_date", "nav"}.issubset(nav_df.columns):
                                nav_df = nav_df.sort_values("trade_date")
                                nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"])
                                st.line_chart(nav_df.set_index("trade_date")["nav"], height=220)
                        except Exception:  # noqa: BLE001
                            LOGGER.debug("导航曲线绘制失败", extra=LOG_EXTRA)

                    trades = single_result.get("trades") or []
                    if trades:
                        st.write("成交记录:")
                        st.dataframe(pd.DataFrame(trades), hide_index=True, width='stretch')

                    snapshots = single_result.get("portfolio_snapshots") or []
                    if snapshots:
                        with st.expander("投资组合快照", expanded=False):
                            st.dataframe(pd.DataFrame(snapshots), hide_index=True, width='stretch')

                    portfolio_trades = single_result.get("portfolio_trades") or []
                    if portfolio_trades:
                        with st.expander("组合成交明细", expanded=False):
                            st.dataframe(pd.DataFrame(portfolio_trades), hide_index=True, width='stretch')

                    risk_breakdown = single_result.get("risk_breakdown") or {}
                    if risk_breakdown:
                        with st.expander("风险事件统计", expanded=False):
                            st.json(risk_breakdown)

                    if st.button("清除单次调参结果", key="clear_decision_env_single"):
                        st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
                        st.success("已清除单次调参结果缓存。")

                st.divider()
                st.caption("批量调参:在下方输入多组动作,每行表示一组 0-1 之间的值,用逗号分隔。")
                default_grid = "\n".join(
                    [
                        ",".join(["0.2" for _ in specs]),
                        ",".join(["0.5" for _ in specs]),
                        ",".join(["0.8" for _ in specs]),
                    ]
                ) if specs else ""
                action_grid_raw = st.text_area(
                    "动作列表",
                    value=default_grid,
                    height=120,
                    key="decision_env_batch_actions",
                )
                run_batch = st.button("批量执行调参", key="run_decision_env_batch")
                batch_just_ran = False
                if run_batch:
                    if not selected_agents:
                        st.warning("请先选择调参代理。")
                    elif not range_valid:
                        st.error("请确保所有代理的最大权重大于最小权重。")
                    else:
                        LOGGER.info(
                            "离线调参(批量)按钮点击,已选择代理=%s disable_departments=%s",
                            selected_agents,
                            disable_departments,
                            extra=LOG_EXTRA,
                        )
                        lines = [line.strip() for line in action_grid_raw.splitlines() if line.strip()]
                        if not lines:
                            st.warning("请在文本框中输入至少一组动作。")
                        else:
                            LOGGER.debug(
                                "离线调参(批量)原始输入=%s",
                                lines,
                                extra=LOG_EXTRA,
                            )
                            parsed_actions: List[List[float]] = []
                            for line in lines:
                                try:
                                    values = [float(val.strip()) for val in line.split(',') if val.strip()]
                                except ValueError:
                                    st.error(f"无法解析动作行:{line}")
                                    parsed_actions = []
                                    break
                                if len(values) != len(specs):
                                    st.error(f"动作维度不匹配(期望 {len(specs)} 个值):{line}")
                                    parsed_actions = []
                                    break
                                parsed_actions.append(values)
                            if parsed_actions:
                                LOGGER.info(
                                    "离线调参(批量)解析动作成功,数量=%s",
                                    len(parsed_actions),
                                    extra=LOG_EXTRA,
                                )
                                # FIX: was `cfg.agent_weights` (undefined name).
                                baseline_weights = app_cfg.agent_weights.as_dict()
                                for agent in agent_objects:
                                    baseline_weights.setdefault(agent.name, 1.0)

                                universe_env = [code.strip() for code in universe_text.split(',') if code.strip()]
                                if not universe_env:
                                    st.error("请先指定至少一个股票代码。")
                                else:
                                    bt_cfg_env = BtConfig(
                                        id="decision_env_streamlit_batch",
                                        name="DecisionEnv Batch",
                                        start_date=start_date,
                                        end_date=end_date,
                                        universe=universe_env,
                                        params={
                                            "target": target,
                                            "stop": stop,
                                            "hold_days": int(hold_days),
                                        },
                                        method=app_cfg.decision_method,
                                    )
                                    env = DecisionEnv(
                                        bt_config=bt_cfg_env,
                                        parameter_specs=specs,
                                        baseline_weights=baseline_weights,
                                        disable_departments=disable_departments,
                                    )
                                    results: List[Dict[str, object]] = []
                                    resolved_experiment_id = experiment_id or str(uuid.uuid4())
                                    resolved_strategy = strategy_label or "DecisionEnv"
                                    LOGGER.debug(
                                        "离线调参(批量)启动 DecisionEnv:cfg=%s 动作组=%s",
                                        bt_cfg_env,
                                        len(parsed_actions),
                                        extra=LOG_EXTRA,
                                    )
                                    with st.spinner("正在批量执行调参……"):
                                        for idx, action_vals in enumerate(parsed_actions, start=1):
                                            env.reset()
                                            try:
                                                observation, reward, done, info = env.step(action_vals)
                                            except Exception as exc:  # noqa: BLE001
                                                LOGGER.exception("批量调参失败", extra=LOG_EXTRA)
                                                results.append(
                                                    {
                                                        "序号": idx,
                                                        "动作": action_vals,
                                                        "状态": "error",
                                                        "错误": str(exc),
                                                    }
                                                )
                                                continue
                                            if observation.get("failure"):
                                                results.append(
                                                    {
                                                        "序号": idx,
                                                        "动作": action_vals,
                                                        "状态": "failure",
                                                        "奖励": -1.0,
                                                    }
                                                )
                                            else:
                                                LOGGER.info(
                                                    "离线调参(批量)第 %s 组完成,reward=%.4f obs=%s",
                                                    idx,
                                                    reward,
                                                    observation,
                                                    extra=LOG_EXTRA,
                                                )
                                                action_payload = {
                                                    name: value
                                                    for name, value in zip(selected_agents, action_vals)
                                                }
                                                metrics_payload = dict(observation)
                                                metrics_payload["reward"] = reward
                                                weights_payload = info.get("weights", {})
                                                try:
                                                    log_tuning_result(
                                                        experiment_id=resolved_experiment_id,
                                                        strategy=resolved_strategy,
                                                        action=action_payload,
                                                        reward=reward,
                                                        metrics=metrics_payload,
                                                        weights=weights_payload,
                                                    )
                                                except Exception:  # noqa: BLE001
                                                    LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA)
                                                results.append(
                                                    {
                                                        "序号": idx,
                                                        "动作": action_vals,
                                                        "状态": "ok",
                                                        "总收益": observation.get("total_return", 0.0),
                                                        "最大回撤": observation.get("max_drawdown", 0.0),
                                                        "波动率": observation.get("volatility", 0.0),
                                                        "奖励": reward,
                                                        "权重": weights_payload,
                                                    }
                                                )
                                    st.session_state[_DECISION_ENV_BATCH_RESULTS_KEY] = {
                                        "results": results,
                                        "selectable": [
                                            row
                                            for row in results
                                            if row.get("状态") == "ok" and row.get("权重")
                                        ],
                                        "experiment_id": resolved_experiment_id,
                                        "strategy_label": resolved_strategy,
                                    }
                                    batch_just_ran = True
                                    LOGGER.info(
                                        "离线调参(批量)执行结束,总结果条数=%s",
                                        len(results),
                                        extra=LOG_EXTRA,
                                    )

                batch_state = st.session_state.get(_DECISION_ENV_BATCH_RESULTS_KEY)
                if batch_state:
                    results = batch_state.get("results") or []
                    if results:
                        if batch_just_ran:
                            st.success("批量调参完成")
                        else:
                            st.success("批量调参结果(最近一次运行)")
                        st.caption(
                            f"实验 ID:{batch_state.get('experiment_id', '-') } | 策略:{batch_state.get('strategy_label', 'DecisionEnv')}"
                        )
                        results_df = pd.DataFrame(results)
                        st.write("批量调参结果:")
                        st.dataframe(results_df, hide_index=True, width='stretch')
                        selectable = batch_state.get("selectable") or []
                        if selectable:
                            option_labels = [
                                f"序号 {row['序号']} | 奖励 {row.get('奖励', 0.0):+.4f}"
                                for row in selectable
                            ]
                            selected_label = st.selectbox(
                                "选择要保存的记录",
                                option_labels,
                                key="decision_env_batch_select",
                            )
                            selected_row = None
                            for label, row in zip(option_labels, selectable):
                                if label == selected_label:
                                    selected_row = row
                                    break
                            if selected_row and st.button(
                                "保存所选权重为默认配置",
                                key="save_decision_env_weights_batch",
                            ):
                                try:
                                    # FIX: was `cfg` (undefined name).
                                    app_cfg.agent_weights.update_from_dict(selected_row.get("权重", {}))
                                    save_config(app_cfg)
                                except Exception as exc:  # noqa: BLE001
                                    LOGGER.exception("批量保存权重失败", extra={**LOG_EXTRA, "error": str(exc)})
                                    st.error(f"写入配置失败:{exc}")
                                else:
                                    st.success(
                                        f"已将序号 {selected_row['序号']} 的权重写入 config.json"
                                    )
                        else:
                            st.caption("暂无成功的结果可供保存。")
                    else:
                        st.caption("批量调参在最近一次执行中未产生结果。")
                if st.button("清除批量调参结果", key="clear_decision_env_batch"):
                    st.session_state.pop(_DECISION_ENV_BATCH_RESULTS_KEY, None)
                    st.session_state.pop("decision_env_batch_select", None)
                    st.success("已清除批量调参结果缓存。")
|
||||
147
app/ui/views/dashboard.py
Normal file
147
app/ui/views/dashboard.py
Normal file
@ -0,0 +1,147 @@
|
||||
"""Sidebar dashboard for the Streamlit UI."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Optional
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from app.llm.metrics import (
|
||||
recent_decisions as llm_recent_decisions,
|
||||
register_listener as register_llm_metrics_listener,
|
||||
snapshot as snapshot_llm_metrics,
|
||||
)
|
||||
from app.utils import alerts
|
||||
|
||||
from app.ui.shared import LOGGER, LOG_EXTRA
|
||||
|
||||
# Module-level sidebar state, populated by render_global_dashboard() and
# consumed by update_dashboard_sidebar() on subsequent refreshes.
_DASHBOARD_CONTAINERS: Optional[tuple[object, object]] = None  # (metrics, decisions) sidebar containers
_DASHBOARD_ELEMENTS: Optional[Dict[str, object]] = None  # widget handles from _ensure_dashboard_elements()
_SIDEBAR_LISTENER_ATTACHED = False  # guard: register the LLM metrics listener only once per process
_WARNINGS_PLACEHOLDER = None  # st.empty() slot that hosts the data-warnings panel
|
||||
|
||||
|
||||
def _ensure_dashboard_elements(metrics_container: object, decisions_container: object) -> Dict[str, object]:
|
||||
elements = {
|
||||
"metrics_calls": metrics_container.metric,
|
||||
"metrics_prompt": metrics_container.metric,
|
||||
"metrics_completion": metrics_container.metric,
|
||||
"provider_distribution": metrics_container.empty(),
|
||||
"model_distribution": metrics_container.empty(),
|
||||
"decisions_list": decisions_container.empty(),
|
||||
}
|
||||
return elements
|
||||
|
||||
|
||||
def update_dashboard_sidebar(metrics: Optional[Dict[str, object]] = None) -> None:
    """Refresh sidebar metrics and warnings.

    Called directly after rendering and from the LLM metrics listener. When
    *metrics* is None a fresh snapshot is pulled via snapshot_llm_metrics().
    Silently no-ops until render_global_dashboard() has set up the containers.
    """
    global _DASHBOARD_CONTAINERS  # only read here; `global` kept for symmetry with siblings
    global _DASHBOARD_ELEMENTS
    global _WARNINGS_PLACEHOLDER

    containers = _DASHBOARD_CONTAINERS
    if not containers:
        # Sidebar has not been rendered in this session yet.
        return
    metrics_container, decisions_container = containers
    elements = _DASHBOARD_ELEMENTS
    if elements is None:
        # Lazily build widget handles if the render step did not do so.
        elements = _ensure_dashboard_elements(metrics_container, decisions_container)
        _DASHBOARD_ELEMENTS = elements

    if metrics is None:
        metrics = snapshot_llm_metrics()

    # NOTE(review): these handles are the container's ``metric`` callable, so
    # each refresh appends a new metric widget rather than updating one in
    # place — presumably fine because Streamlit reruns rebuild the sidebar,
    # but confirm before calling this repeatedly within a single run.
    elements["metrics_calls"]("LLM 调用", metrics.get("total_calls", 0))
    elements["metrics_prompt"]("Prompt Tokens", metrics.get("total_prompt_tokens", 0))
    elements["metrics_completion"]("Completion Tokens", metrics.get("total_completion_tokens", 0))

    provider_calls = metrics.get("provider_calls", {})
    model_calls = metrics.get("model_calls", {})

    # Provider call distribution: JSON dump, or an informational fallback.
    provider_placeholder = elements["provider_distribution"]
    provider_placeholder.empty()
    if provider_calls:
        provider_placeholder.json(provider_calls)
    else:
        provider_placeholder.info("暂无 Provider 分布数据。")

    # Model call distribution, same pattern.
    model_placeholder = elements["model_distribution"]
    model_placeholder.empty()
    if model_calls:
        model_placeholder.json(model_calls)
    else:
        model_placeholder.info("暂无模型分布数据。")

    # Recent decisions: prefer those embedded in the snapshot, otherwise query
    # the metrics module. Show at most the latest 10, newest first.
    decisions = metrics.get("recent_decisions") or llm_recent_decisions(10)
    decisions_placeholder = elements["decisions_list"]
    decisions_placeholder.empty()
    if decisions:
        lines = []
        for record in reversed(decisions[-10:]):
            ts_code = record.get("ts_code")
            trade_date = record.get("trade_date")
            action = record.get("action")
            confidence = record.get("confidence", 0.0)
            summary = record.get("summary")
            line = f"**{trade_date} {ts_code}** → {action} (置信度 {confidence:.2f})"
            if summary:
                # Summary is injected as raw HTML; requires unsafe_allow_html below.
                line += f"\n<small>{summary}</small>"
            lines.append(line)
        decisions_placeholder.markdown("\n\n".join(lines), unsafe_allow_html=True)
    else:
        decisions_placeholder.info("暂无决策记录。执行回测或实时评估后可在此查看。")

    # Data-warning panel: rebuilt wholesale inside its placeholder each refresh.
    if _WARNINGS_PLACEHOLDER is not None:
        _WARNINGS_PLACEHOLDER.empty()
        with _WARNINGS_PLACEHOLDER.container():
            st.subheader("数据告警")
            warn_list = alerts.get_warnings()
            if warn_list:
                lines = []
                for warning in warn_list[-10:]:
                    detail = warning.get("detail")
                    source = warning.get("source")
                    ts = warning.get("ts")
                    label = warning.get("label")
                    line = f"- **{source or '未知来源'}** {label or ''}"
                    if detail:
                        line += f":{detail}"
                    if ts:
                        line += f"({ts})"
                    lines.append(line)
                st.markdown("\n".join(lines))
            else:
                st.caption("暂无数据告警。")
|
||||
|
||||
|
||||
def _sidebar_metrics_listener(metrics: Dict[str, object]) -> None:
    """Listener registered with app.llm.metrics; pushes fresh stats into the sidebar.

    Any rendering failure is swallowed (logged at DEBUG) so a UI hiccup can
    never propagate back into the metrics pipeline that invokes this callback.
    """
    try:
        update_dashboard_sidebar(metrics)
    except Exception:  # noqa: BLE001
        LOGGER.debug("侧边栏监听器刷新失败", exc_info=True, extra=LOG_EXTRA)
|
||||
|
||||
|
||||
def render_global_dashboard() -> None:
    """Render a persistent sidebar with realtime LLM stats and recent decisions."""

    global _DASHBOARD_CONTAINERS
    global _DASHBOARD_ELEMENTS
    global _SIDEBAR_LISTENER_ATTACHED
    global _WARNINGS_PLACEHOLDER

    # Header carries a live count of outstanding data warnings, e.g. "系统监控 (3)".
    active_warnings = alerts.get_warnings()
    suffix = f" ({len(active_warnings)})" if active_warnings else ""
    st.sidebar.header(f"系统监控{suffix}")

    # Allocate sidebar slots top-to-bottom: metrics, decisions, a legacy
    # spacing container, then the warnings placeholder.
    metrics_area = st.sidebar.container()
    decisions_area = st.sidebar.container()
    st.sidebar.container()  # legacy placeholder for layout spacing
    warnings_slot = st.sidebar.empty()

    # Publish the handles for update_dashboard_sidebar() and the listener.
    _DASHBOARD_CONTAINERS = (metrics_area, decisions_area)
    _DASHBOARD_ELEMENTS = _ensure_dashboard_elements(metrics_area, decisions_area)
    _WARNINGS_PLACEHOLDER = warnings_slot

    # Attach the metrics listener exactly once per process, then do a first paint.
    if not _SIDEBAR_LISTENER_ATTACHED:
        register_llm_metrics_listener(_sidebar_metrics_listener)
        _SIDEBAR_LISTENER_ATTACHED = True
    update_dashboard_sidebar()
|
||||
164
app/ui/views/logs.py
Normal file
164
app/ui/views/logs.py
Normal file
@ -0,0 +1,164 @@
|
||||
"""日志钻取视图。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, timedelta
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
|
||||
from app.utils.db import db_session
|
||||
|
||||
from app.ui.shared import LOGGER, LOG_EXTRA
|
||||
|
||||
def render_log_viewer() -> None:
    """Render the log drill-down and history-comparison page.

    Offers date/level/keyword/stage filtering over the ``run_log`` table,
    CSV/JSON export of the filtered result, and a side-by-side comparison of
    the logs from two selected dates.
    """
    LOGGER.info("渲染日志视图页面", extra=LOG_EXTRA)
    st.header("日志钻取与历史对比")
    st.write("查看系统运行日志,支持时间范围筛选、关键词搜索和历史对比功能。")

    # --- Filter controls -------------------------------------------------
    col1, col2 = st.columns(2)
    with col1:
        start_date = st.date_input("开始日期", value=date.today() - timedelta(days=7))
    with col2:
        end_date = st.date_input("结束日期", value=date.today())

    log_levels = ["ALL", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
    selected_level = st.selectbox("日志级别", log_levels, index=1)  # default: DEBUG

    search_query = st.text_input("搜索关键词")

    # Distinct execution stages for the stage filter. NOTE(review): this query
    # is not wrapped in try/except like the ones below — a DB failure here
    # aborts the whole page; confirm whether that is intended.
    with db_session(read_only=True) as conn:
        stages = [row["stage"] for row in conn.execute("SELECT DISTINCT stage FROM run_log").fetchall()]
    stages = [s for s in stages if s]  # drop NULL/empty stages
    stages.insert(0, "ALL")
    selected_stage = st.selectbox("执行阶段", stages)

    # --- Filtered log table + export --------------------------------------
    with st.spinner("加载日志数据中..."):
        try:
            with db_session(read_only=True) as conn:
                # Build the WHERE clause incrementally; "WHERE 1=1" lets every
                # later condition be appended uniformly with "AND ...".
                query_parts = ["SELECT ts, stage, level, msg FROM run_log WHERE 1=1"]
                params: list[object] = []

                # Timestamps are stored as UTC ISO strings, so a lexicographic
                # BETWEEN over day boundaries is a correct range filter.
                start_ts = f"{start_date.isoformat()}T00:00:00Z"
                end_ts = f"{end_date.isoformat()}T23:59:59Z"
                query_parts.append("AND ts BETWEEN ? AND ?")
                params.extend([start_ts, end_ts])

                if selected_level != "ALL":
                    query_parts.append("AND level = ?")
                    params.append(selected_level)

                if search_query:
                    query_parts.append("AND msg LIKE ?")
                    params.append(f"%{search_query}%")

                if selected_stage != "ALL":
                    query_parts.append("AND stage = ?")
                    params.append(selected_stage)

                query_parts.append("ORDER BY ts DESC")

                query = " ".join(query_parts)
                rows = conn.execute(query, params).fetchall()

            if rows:
                # Convert Row objects to plain dicts before building the frame.
                rows_dict = [{key: row[key] for key in row.keys()} for row in rows]
                log_df = pd.DataFrame(rows_dict)
                log_df["ts"] = pd.to_datetime(log_df["ts"]).dt.strftime("%Y-%m-%d %H:%M:%S")
                # Force everything to str so st.dataframe renders uniformly.
                for col in log_df.columns:
                    log_df[col] = log_df[col].astype(str)
            else:
                # Empty frame with the expected columns keeps the UI stable.
                log_df = pd.DataFrame(columns=["ts", "stage", "level", "msg"])

            st.dataframe(
                log_df,
                hide_index=True,
                width="stretch",
                column_config={
                    "ts": st.column_config.TextColumn("时间"),
                    "stage": st.column_config.TextColumn("执行阶段"),
                    "level": st.column_config.TextColumn("日志级别"),
                    "msg": st.column_config.TextColumn("日志消息", width="large"),
                },
            )

            # Export buttons are only offered when there is data to download.
            if not log_df.empty:
                csv_data = log_df.to_csv(index=False).encode("utf-8")
                st.download_button(
                    label="下载日志CSV",
                    data=csv_data,
                    file_name=f"logs_{start_date}_{end_date}.csv",
                    mime="text/csv",
                    key="download_logs",
                )

                json_data = log_df.to_json(orient="records", force_ascii=False, indent=2)
                st.download_button(
                    label="下载日志JSON",
                    data=json_data,
                    file_name=f"logs_{start_date}_{end_date}.json",
                    mime="application/json",
                    key="download_logs_json",
                )
        except Exception as exc:  # noqa: BLE001
            LOGGER.exception("加载日志失败", extra=LOG_EXTRA)
            st.error(f"加载日志数据失败:{exc}")

    # --- History comparison ------------------------------------------------
    st.subheader("历史对比")
    st.write("选择两个时间点的日志进行对比分析。")

    col3, col4 = st.columns(2)
    with col3:
        compare_date1 = st.date_input("对比日期1", value=date.today() - timedelta(days=1))
    with col4:
        compare_date2 = st.date_input("对比日期2", value=date.today())

    comparison_stage = st.selectbox("对比阶段", stages, key="compare_stage")
    st.write("选择需要比较的日志数量。")
    compare_limit = st.slider("对比日志数量", min_value=10, max_value=200, value=50, step=10)

    if st.button("生成历史对比报告"):
        with st.spinner("生成对比报告中..."):
            try:
                with db_session(read_only=True) as conn:
                    # Closure over conn / comparison_stage / compare_limit;
                    # fetches the newest `compare_limit` rows for one day.
                    def load_logs(d: date) -> pd.DataFrame:
                        start_ts = f"{d.isoformat()}T00:00:00Z"
                        end_ts = f"{d.isoformat()}T23:59:59Z"
                        query = ["SELECT ts, level, msg FROM run_log WHERE ts BETWEEN ? AND ?"]
                        params: list[object] = [start_ts, end_ts]
                        if comparison_stage != "ALL":
                            query.append("AND stage = ?")
                            params.append(comparison_stage)
                        query.append("ORDER BY ts DESC LIMIT ?")
                        params.append(compare_limit)
                        sql = " ".join(query)
                        rows = conn.execute(sql, params).fetchall()
                        if not rows:
                            return pd.DataFrame(columns=["ts", "level", "msg"])
                        df = pd.DataFrame([{k: row[k] for k in row.keys()} for row in rows])
                        df["ts"] = pd.to_datetime(df["ts"]).dt.strftime("%Y-%m-%d %H:%M:%S")
                        return df

                    df1 = load_logs(compare_date1)
                    df2 = load_logs(compare_date2)

                if df1.empty and df2.empty:
                    st.info("选定日期暂无日志可对比。")
                else:
                    st.write("### 对比结果")
                    col_a, col_b = st.columns(2)
                    with col_a:
                        st.write(f"{compare_date1} 日日志")
                        st.dataframe(df1, hide_index=True, width="stretch")
                    with col_b:
                        st.write(f"{compare_date2} 日日志")
                        st.dataframe(df2, hide_index=True, width="stretch")

                    # Simple count-based summary; "新增" is clamped at zero.
                    summary = {
                        "日期1日志条数": int(len(df1)),
                        "日期2日志条数": int(len(df2)),
                        "新增日志条数": max(len(df2) - len(df1), 0),
                    }
                    st.write("摘要:", summary)
            except Exception as exc:  # noqa: BLE001
                LOGGER.exception("历史对比生成失败", extra=LOG_EXTRA)
                st.error(f"生成历史对比失败:{exc}")
|
||||
115
app/ui/views/market.py
Normal file
115
app/ui/views/market.py
Normal file
@ -0,0 +1,115 @@
|
||||
"""行情可视化页面。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, timedelta
|
||||
|
||||
import pandas as pd
|
||||
import plotly.express as px
|
||||
import plotly.graph_objects as go
|
||||
import streamlit as st
|
||||
|
||||
from app.utils.db import db_session
|
||||
|
||||
from app.ui.shared import LOGGER, LOG_EXTRA
|
||||
|
||||
|
||||
def _load_stock_options(limit: int = 500) -> list[str]:
    """Return up to *limit* ts_code values, most recently traded first.

    Any database error is logged and degrades to an empty list so the caller
    can render a friendly warning instead of crashing.
    """
    try:
        with db_session(read_only=True) as conn:
            # BUGFIX: the previous query used
            #   SELECT DISTINCT ts_code ... ORDER BY trade_date DESC
            # which SQLite rejects ("ORDER BY expressions must appear in
            # select list" for DISTINCT queries), so this function always hit
            # the except branch and returned []. GROUP BY + MAX(trade_date)
            # yields one row per code, ordered by its latest trading day.
            rows = conn.execute(
                """
                SELECT ts_code
                FROM daily
                GROUP BY ts_code
                ORDER BY MAX(trade_date) DESC
                LIMIT ?
                """,
                (limit,),
            ).fetchall()
    except Exception:  # noqa: BLE001
        LOGGER.exception("加载股票列表失败", extra=LOG_EXTRA)
        return []
    return [row["ts_code"] for row in rows]
|
||||
|
||||
|
||||
def _parse_ts_code(selection: str) -> str:
|
||||
return selection.split(" ", 1)[0]
|
||||
|
||||
|
||||
def _load_daily_frame(ts_code: str, start: date, end: date) -> pd.DataFrame:
    """Load daily OHLCV rows for *ts_code* within [start, end], oldest first.

    ``trade_date`` is converted to datetime64 unless the result is empty.
    """
    sql = """
            SELECT trade_date, open, high, low, close, vol, amount
            FROM daily
            WHERE ts_code = ? AND trade_date BETWEEN ? AND ?
            ORDER BY trade_date
            """
    # Dates are stored as YYYYMMDD strings, so format the bounds to match.
    bounds = (ts_code, start.strftime("%Y%m%d"), end.strftime("%Y%m%d"))
    with db_session(read_only=True) as conn:
        frame = pd.read_sql_query(sql, conn, params=bounds)
    if not frame.empty:
        frame["trade_date"] = pd.to_datetime(frame["trade_date"], format="%Y%m%d")
    return frame
|
||||
|
||||
|
||||
def render_market_visualization() -> None:
    """Render the market visualisation page: candlestick, volume and MA charts."""
    st.header("行情可视化")
    st.caption("按标的查看 K 线、成交量以及常用指标。")

    symbol_options = _load_stock_options()
    if not symbol_options:
        st.warning("暂未加载到可用的行情标的,请先执行数据同步。")
        return

    picked = st.selectbox("选择标的", symbol_options, index=0)
    ts_code = _parse_ts_code(picked)
    left, right = st.columns(2)
    with left:
        start_date = st.date_input("开始日期", value=date.today() - timedelta(days=120))
    with right:
        end_date = st.date_input("结束日期", value=date.today())

    # Reject an inverted date range before touching the database.
    if start_date > end_date:
        st.error("开始日期不能晚于结束日期。")
        return

    try:
        frame = _load_daily_frame(ts_code, start_date, end_date)
    except Exception as exc:  # noqa: BLE001
        LOGGER.exception("加载行情数据失败", extra=LOG_EXTRA)
        st.error(f"加载行情数据失败:{exc}")
        return

    if frame.empty:
        st.info("所选区间内无行情数据。")
        return

    st.metric("最新收盘价", f"{frame['close'].iloc[-1]:.2f}")
    candle = go.Candlestick(
        x=frame["trade_date"],
        open=frame["open"],
        high=frame["high"],
        low=frame["low"],
        close=frame["close"],
        name="K线",
    )
    candle_fig = go.Figure(data=[candle])
    candle_fig.update_layout(title=f"{ts_code} K线图", xaxis_title="日期", yaxis_title="价格")
    st.plotly_chart(candle_fig, use_container_width=True)

    volume_fig = px.bar(frame, x="trade_date", y="vol", title="成交量")
    st.plotly_chart(volume_fig, use_container_width=True)

    # Moving averages are computed on a copy so the raw frame shown at the
    # bottom of the page stays untouched.
    ma_frame = frame.copy()
    for window in (5, 20, 60):
        ma_frame[f"MA{window}"] = ma_frame["close"].rolling(window=window).mean()

    ma_fig = px.line(ma_frame, x="trade_date", y=["close", "MA5", "MA20", "MA60"], title="均线对比")
    st.plotly_chart(ma_fig, use_container_width=True)

    st.dataframe(frame, hide_index=True, width='stretch')
|
||||
162
app/ui/views/pool.py
Normal file
162
app/ui/views/pool.py
Normal file
@ -0,0 +1,162 @@
|
||||
"""投资池与仓位概览页面。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import pandas as pd
|
||||
import plotly.express as px
|
||||
import streamlit as st
|
||||
|
||||
from app.utils.db import db_session
|
||||
from app.utils.portfolio import (
|
||||
get_latest_snapshot,
|
||||
list_investment_pool,
|
||||
list_positions,
|
||||
list_recent_trades,
|
||||
)
|
||||
|
||||
from app.ui.shared import LOGGER, LOG_EXTRA, get_latest_trade_date
|
||||
|
||||
|
||||
def render_pool_overview() -> None:
    """Standalone investment-pool and position overview page (extracted from the daily plan view).

    Shows, in order: the latest portfolio snapshot metrics, the candidate
    investment pool for the most recent trade date, all portfolio positions,
    and the 20 most recent trades.
    """
    LOGGER.info("渲染投资池与仓位概览页面", extra=LOG_EXTRA)
    st.header("投资池与仓位概览")

    # --- Portfolio snapshot metrics (each rendered only when present) ------
    snapshot = get_latest_snapshot()
    if snapshot:
        col_a, col_b, col_c = st.columns(3)
        if snapshot.total_value is not None:
            col_a.metric("组合净值", f"{snapshot.total_value:,.2f}")
        if snapshot.cash is not None:
            col_b.metric("现金余额", f"{snapshot.cash:,.2f}")
        if snapshot.invested_value is not None:
            col_c.metric("持仓市值", f"{snapshot.invested_value:,.2f}")
        detail_cols = st.columns(4)
        if snapshot.unrealized_pnl is not None:
            detail_cols[0].metric("浮盈", f"{snapshot.unrealized_pnl:,.2f}")
        if snapshot.realized_pnl is not None:
            detail_cols[1].metric("已实现盈亏", f"{snapshot.realized_pnl:,.2f}")
        if snapshot.net_flow is not None:
            detail_cols[2].metric("净流入", f"{snapshot.net_flow:,.2f}")
        if snapshot.exposure is not None:
            detail_cols[3].metric("风险敞口", f"{snapshot.exposure:.2%}")
        if snapshot.notes:
            st.caption(f"备注:{snapshot.notes}")
    else:
        st.info("暂无组合快照,请在执行回测或实盘同步后写入 portfolio_snapshots。")

    # --- Candidate investment pool for the most recent trade date ----------
    try:
        latest_date = get_latest_trade_date()
        candidates = list_investment_pool(trade_date=latest_date)
    except Exception:  # noqa: BLE001
        # Degrade to an empty pool rather than breaking the page.
        LOGGER.exception("加载候选池失败", extra=LOG_EXTRA)
        candidates = []

    if candidates:
        candidate_df = pd.DataFrame(
            [
                {
                    "交易日": item.trade_date,
                    "代码": item.ts_code,
                    "评分": item.score,
                    "状态": item.status,
                    "标签": "、".join(item.tags) if item.tags else "-",
                    "理由": item.rationale or "",
                }
                for item in candidates
            ]
        )
        st.write("候选投资池:")
        st.dataframe(candidate_df, width='stretch', hide_index=True)
    else:
        st.caption("候选投资池暂无数据。")

    # --- All positions (open and closed) -----------------------------------
    positions = list_positions(active_only=False)
    if positions:
        position_df = pd.DataFrame(
            [
                {
                    "ID": pos.id,
                    "代码": pos.ts_code,
                    "开仓日": pos.opened_date,
                    "平仓日": pos.closed_date or "-",
                    "状态": pos.status,
                    "数量": pos.quantity,
                    "成本": pos.cost_price,
                    "现价": pos.market_price,
                    "市值": pos.market_value,
                    "浮盈": pos.unrealized_pnl,
                    "已实现": pos.realized_pnl,
                    "目标权重": pos.target_weight,
                }
                for pos in positions
            ]
        )
        st.write("组合持仓:")
        st.dataframe(position_df, width='stretch', hide_index=True)
    else:
        st.caption("组合持仓暂无记录。")

    # --- Recent trades -------------------------------------------------------
    trades = list_recent_trades(limit=20)
    if trades:
        trades_df = pd.DataFrame(trades)
        st.write("近期成交:")
        st.dataframe(trades_df, width='stretch', hide_index=True)
    else:
        st.caption("近期成交暂无记录。")

    st.caption("数据来源:agent_utils、investment_pool、portfolio_positions、portfolio_trades、portfolio_snapshots。")
|
||||
|
||||
    # NOTE(review): this comparison block references ``compare_date1`` /
    # ``compare_date2``, which are never defined in this module (hence the
    # ``type: ignore[name-defined]`` markers). It appears to be leftover code
    # from the log viewer's history comparison and will raise NameError if the
    # button is ever pressed. Confirm intent — either wire up date inputs here
    # or remove the block.
    if st.button("执行对比", type="secondary"):
        with st.spinner("执行日志对比分析中..."):
            try:
                with db_session(read_only=True) as conn:
                    # Day-1 boundaries (UTC-suffixed ISO strings matched
                    # lexicographically against run_log.ts).
                    query_date1 = f"{compare_date1.isoformat()}T00:00:00Z"  # type: ignore[name-defined]
                    query_date2 = f"{compare_date1.isoformat()}T23:59:59Z"  # type: ignore[name-defined]
                    logs1 = conn.execute(
                        "SELECT level, COUNT(*) as count FROM run_log WHERE ts BETWEEN ? AND ? GROUP BY level",
                        (query_date1, query_date2),
                    ).fetchall()

                    # Day-2 boundaries.
                    query_date3 = f"{compare_date2.isoformat()}T00:00:00Z"  # type: ignore[name-defined]
                    query_date4 = f"{compare_date2.isoformat()}T23:59:59Z"  # type: ignore[name-defined]
                    logs2 = conn.execute(
                        "SELECT level, COUNT(*) as count FROM run_log WHERE ts BETWEEN ? AND ? GROUP BY level",
                        (query_date3, query_date4),
                    ).fetchall()

                df1 = pd.DataFrame(logs1, columns=["level", "count"])
                df1["date"] = compare_date1.strftime("%Y-%m-%d")  # type: ignore[name-defined]
                df2 = pd.DataFrame(logs2, columns=["level", "count"])
                df2["date"] = compare_date2.strftime("%Y-%m-%d")  # type: ignore[name-defined]

                # Cast non-level columns to object so concat/merge below keep
                # uniform dtypes for display.
                for df in (df1, df2):
                    for col in df.columns:
                        if col != "level":
                            df[col] = df[col].astype(object)

                compare_df = pd.concat([df1, df2])
                fig = px.bar(
                    compare_df,
                    x="level",
                    y="count",
                    color="date",
                    barmode="group",
                    title=f"日志级别分布对比 ({compare_date1} vs {compare_date2})",  # type: ignore[name-defined]
                )
                st.plotly_chart(fig, width='stretch')

                st.write("日志统计对比:")
                date1_str = compare_date1.strftime("%Y%m%d")  # type: ignore[name-defined]
                date2_str = compare_date2.strftime("%Y%m%d")  # type: ignore[name-defined]
                # Outer merge keeps levels present on only one of the two days.
                merged_df = df1.merge(
                    df2,
                    on="level",
                    suffixes=(f"_{date1_str}", f"_{date2_str}"),
                    how="outer",
                ).fillna(0)
                st.dataframe(merged_df, hide_index=True, width="stretch")
            except Exception as exc:  # noqa: BLE001
                LOGGER.exception("日志对比失败", extra=LOG_EXTRA)
                st.error(f"日志对比分析失败:{exc}")

    return
|
||||
603
app/ui/views/settings.py
Normal file
603
app/ui/views/settings.py
Normal file
@ -0,0 +1,603 @@
|
||||
"""系统设置相关视图。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import pandas as pd
|
||||
import requests
|
||||
from requests.exceptions import RequestException
|
||||
import streamlit as st
|
||||
|
||||
from app.llm.client import llm_config_snapshot
|
||||
from app.utils.config import (
|
||||
ALLOWED_LLM_STRATEGIES,
|
||||
DEFAULT_LLM_BASE_URLS,
|
||||
DepartmentSettings,
|
||||
LLMEndpoint,
|
||||
LLMProvider,
|
||||
get_config,
|
||||
save_config,
|
||||
)
|
||||
from app.utils.db import db_session
|
||||
|
||||
from app.ui.shared import LOGGER, LOG_EXTRA
|
||||
|
||||
# Cache of discovered model lists, keyed by "provider_key|base_url".
# Each entry holds {"ts": <epoch float>, "models": [sorted model ids]}.
_MODEL_CACHE: Dict[str, Dict[str, object]] = {}
# Cache entries older than this many seconds are refreshed on the next lookup.
_CACHE_TTL_SECONDS = 300
|
||||
|
||||
def _discover_provider_models(provider: LLMProvider, base_override: str = "", api_override: Optional[str] = None) -> tuple[list[str], Optional[str]]:
    """Attempt to query provider API and return available model ids.

    Returns ``(models, error)``: a sorted, de-duplicated list of model ids on
    success (error is None), or ``([], message)`` on failure. Successful
    lookups are cached per provider+base URL for ``_CACHE_TTL_SECONDS``;
    errors are never cached.
    """

    # Effective base URL: explicit override > provider config > built-in default.
    base_url = (base_override or provider.base_url or DEFAULT_LLM_BASE_URLS.get(provider.key, "")).strip()
    if not base_url:
        return [], "请先填写 Base URL"
    timeout = float(provider.default_timeout or 30.0)
    # Ollama uses its native /api/tags endpoint; everything else is treated as OpenAI-compatible.
    mode = provider.mode or ("ollama" if provider.key == "ollama" else "openai")

    # Simple TTL cache keyed by provider + base URL.
    cache_key = f"{provider.key}|{base_url}"
    now = datetime.now()
    cached = _MODEL_CACHE.get(cache_key)
    if cached:
        ts = cached.get("ts")
        if isinstance(ts, float) and (now.timestamp() - ts) < _CACHE_TTL_SECONDS:
            models = list(cached.get("models") or [])
            return models, None

    try:
        if mode == "ollama":
            # Ollama: GET /api/tags, no authentication required.
            url = base_url.rstrip('/') + "/api/tags"
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()
            data = response.json()
            models = []
            # Some deployments answer {"models": [...]}, others {"data": [...]}.
            for item in data.get("models", []) or data.get("data", []):
                name = item.get("name") or item.get("model") or item.get("tag")
                if name:
                    models.append(str(name).strip())
            _MODEL_CACHE[cache_key] = {"ts": now.timestamp(), "models": sorted(set(models))}
            return sorted(set(models)), None

        # OpenAI-compatible: GET /v1/models with bearer authentication.
        api_key = (api_override or provider.api_key or "").strip()
        if not api_key:
            return [], "缺少 API Key"
        url = base_url.rstrip('/') + "/v1/models"
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
        response = requests.get(url, headers=headers, timeout=timeout)
        response.raise_for_status()
        payload = response.json()
        models = [
            str(item.get("id")).strip()
            for item in payload.get("data", [])
            if item.get("id")
        ]
        _MODEL_CACHE[cache_key] = {"ts": now.timestamp(), "models": sorted(set(models))}
        return sorted(set(models)), None
    except RequestException as exc:  # noqa: BLE001
        # Connection failures and non-2xx responses (raise_for_status).
        return [], f"HTTP 错误:{exc}"
    except Exception as exc:  # noqa: BLE001
        # JSON decoding errors or unexpected payload shapes.
        return [], f"解析失败:{exc}"
|
||||
|
||||
def render_config_overview() -> None:
    """Render a concise overview of persisted configuration values."""

    LOGGER.info("渲染配置概览页", extra=LOG_EXTRA)
    cfg = get_config()

    # --- Core switches and paths -------------------------------------------
    st.subheader("核心配置概览")
    col1, col2, col3 = st.columns(3)
    col1.metric("决策方式", cfg.decision_method.upper())
    col2.metric("自动更新数据", "启用" if cfg.auto_update_data else "关闭")
    col3.metric("数据更新间隔(天)", cfg.data_update_interval)

    col4, col5, col6 = st.columns(3)
    col4.metric("强制刷新", "开启" if cfg.force_refresh else "关闭")
    col5.metric("TuShare Token", "已配置" if cfg.tushare_token else "未配置")
    col6.metric("配置文件", cfg.data_paths.config_file.name)
    st.caption(f"配置文件路径:{cfg.data_paths.config_file}")

    # --- RSS feed status -----------------------------------------------------
    st.divider()
    st.subheader("RSS 数据源状态")
    rss_sources = cfg.rss_sources or {}
    if rss_sources:
        rows: List[Dict[str, object]] = []
        for name, payload in rss_sources.items():
            # Two persisted shapes are supported: a full dict per source, or a
            # bare bool meaning "enabled flag only".
            if isinstance(payload, dict):
                rows.append(
                    {
                        "名称": name,
                        "启用": "是" if payload.get("enabled", True) else "否",
                        "URL": payload.get("url", "-"),
                        "关键词数": len(payload.get("keywords", []) or []),
                    }
                )
            elif isinstance(payload, bool):
                rows.append(
                    {
                        "名称": name,
                        "启用": "是" if payload else "否",
                        "URL": "-",
                        "关键词数": 0,
                    }
                )
        if rows:
            st.dataframe(pd.DataFrame(rows), hide_index=True, width="stretch")
        else:
            # rss_sources present but every entry had an unrecognized shape.
            st.info("未配置 RSS 数据源。")
    else:
        st.info("未在配置文件中找到 RSS 数据源。")

    # --- Department configuration --------------------------------------------
    st.divider()
    st.subheader("部门配置")
    dept_rows: List[Dict[str, object]] = []
    for code, dept in cfg.departments.items():
        dept_rows.append(
            {
                "部门": dept.title or code,
                "代码": code,
                "权重": dept.weight,
                "LLM 策略": dept.llm.strategy,
                "模板": dept.prompt_template_id or f"{code}_dept",
                "模板版本": dept.prompt_template_version or "(激活版本)",
            }
        )
    if dept_rows:
        st.dataframe(pd.DataFrame(dept_rows), hide_index=True, width="stretch")
    else:
        st.info("尚未配置任何部门。")

    # --- LLM cost controls ----------------------------------------------------
    st.divider()
    st.subheader("LLM 成本控制")
    cost = cfg.llm_cost
    col_a, col_b, col_c, col_d = st.columns(4)
    col_a.metric("成本控制", "启用" if cost.enabled else "关闭")
    col_b.metric("小时预算($)", f"{cost.hourly_budget:.2f}")
    col_c.metric("日预算($)", f"{cost.daily_budget:.2f}")
    col_d.metric("月预算($)", f"{cost.monthly_budget:.2f}")

    if cost.model_weights:
        # Per-model share caps rendered as percentages.
        weight_rows = (
            pd.DataFrame(
                [
                    {"模型": model, "占比上限": f"{limit * 100:.0f}%"}
                    for model, limit in cost.model_weights.items()
                ]
            )
        )
        st.dataframe(weight_rows, hide_index=True, width="stretch")
    else:
        st.caption("未配置模型占比限制。")

    st.divider()
    st.caption("提示:数据源、LLM 及投资组合设置可在对应标签页中调整。")
|
||||
|
||||
def render_llm_settings() -> None:
    """Render the LLM settings page.

    Three sections, in render order:
      1. Provider maintenance — create/select/edit/delete connection profiles
         (base URL, API key, mode, discovered model list).
      2. Global inference config — primary endpoint, strategy, and optional
         ensemble endpoints edited via ``st.data_editor``.
      3. Department config — per-department LLM parameters in a fixed-row
         table, plus a reset-to-defaults action.

    All persistence goes through ``save_config()`` after mutating the shared
    ``cfg`` object; ``cfg.sync_runtime_llm()`` propagates changes to runtime.
    """
    cfg = get_config()
    st.subheader("LLM 设置")
    providers = cfg.llm_providers
    provider_keys = sorted(providers.keys())
    st.caption("先在 Provider 中维护基础连接(URL、Key、模型),再为全局与各部门设置个性化参数。")

    # --- Section 1: provider selection / creation -------------------------
    provider_select_col, provider_manage_col = st.columns([3, 1])
    if provider_keys:
        try:
            # Default to the provider used by the global primary endpoint.
            default_provider = cfg.llm.primary.provider or provider_keys[0]
            provider_index = provider_keys.index(default_provider)
        except ValueError:
            provider_index = 0
        selected_provider = provider_select_col.selectbox(
            "选择 Provider",
            provider_keys,
            index=provider_index,
            key="llm_provider_select",
        )
    else:
        selected_provider = None
        provider_select_col.info("尚未配置 Provider,请先创建。")

    new_provider_name = provider_manage_col.text_input("新增 Provider", key="new_provider_name")
    if provider_manage_col.button("创建 Provider", key="create_provider_btn"):
        # Provider keys are normalized to lowercase.
        key = (new_provider_name or "").strip().lower()
        if not key:
            st.warning("请输入有效的 Provider 名称。")
        elif key in providers:
            st.warning(f"Provider {key} 已存在。")
        else:
            providers[key] = LLMProvider(key=key)
            cfg.llm_providers = providers
            save_config()
            st.success(f"已创建 Provider {key}。")
            st.rerun()

    # --- Section 1b: edit the selected provider ---------------------------
    if selected_provider:
        provider_cfg = providers.get(selected_provider, LLMProvider(key=selected_provider))
        # Per-provider widget keys so switching providers refreshes the form.
        title_key = f"provider_title_{selected_provider}"
        base_key = f"provider_base_{selected_provider}"
        api_key_key = f"provider_api_{selected_provider}"
        mode_key = f"provider_mode_{selected_provider}"
        enabled_key = f"provider_enabled_{selected_provider}"

        title_val = st.text_input("备注名称", value=provider_cfg.title or "", key=title_key)
        base_val = st.text_input("Base URL", value=provider_cfg.base_url or "", key=base_key, help="调用地址,例如:https://api.openai.com")
        api_val = st.text_input("API Key", value=provider_cfg.api_key or "", key=api_key_key, type="password")

        enabled_val = st.checkbox("启用", value=provider_cfg.enabled, key=enabled_key)
        mode_val = st.selectbox("模式", options=["openai", "ollama"], index=0 if provider_cfg.mode == "openai" else 1, key=mode_key)
        st.markdown("可用模型:")
        if provider_cfg.models:
            st.code("\n".join(provider_cfg.models), language="text")
        else:
            st.info("尚未获取模型列表,可点击下方按钮自动拉取。")
        try:
            # Show when the model list was last fetched for this
            # provider+base-URL combination (best effort only).
            cache_key = f"{selected_provider}|{(base_val or '').strip()}"
            entry = _MODEL_CACHE.get(cache_key)
            if entry and isinstance(entry.get("ts"), float):
                ts = datetime.fromtimestamp(entry["ts"]).strftime("%Y-%m-%d %H:%M:%S")
                st.caption(f"最近拉取时间:{ts}")
        except Exception:
            # Cache display is cosmetic; never let it break the page.
            pass

        fetch_key = f"fetch_models_{selected_provider}"
        if st.button("获取模型列表", key=fetch_key):
            with st.spinner("正在获取模型列表..."):
                models, error = _discover_provider_models(provider_cfg, base_val, api_val)
            if error:
                st.error(error)
            else:
                provider_cfg.models = models
                # Keep default_model valid: fall back to first discovered model.
                if models and (not provider_cfg.default_model or provider_cfg.default_model not in models):
                    provider_cfg.default_model = models[0]
                providers[selected_provider] = provider_cfg
                cfg.llm_providers = providers
                cfg.sync_runtime_llm()
                save_config()
                st.success(f"共获取 {len(models)} 个模型。")
                st.rerun()

        if st.button("保存 Provider", key=f"save_provider_{selected_provider}"):
            provider_cfg.title = title_val.strip()
            provider_cfg.base_url = base_val.strip()
            provider_cfg.api_key = api_val.strip() or None
            provider_cfg.enabled = enabled_val
            provider_cfg.mode = mode_val
            providers[selected_provider] = provider_cfg
            cfg.llm_providers = providers
            cfg.sync_runtime_llm()
            save_config()
            st.success("Provider 已保存。")
            st.session_state[title_key] = provider_cfg.title or ""

        # A provider referenced by the global or any department endpoint
        # must not be deletable; nor may the last remaining provider.
        provider_in_use = (cfg.llm.primary.provider == selected_provider) or any(
            ep.provider == selected_provider for ep in cfg.llm.ensemble
        )
        if not provider_in_use:
            for dept in cfg.departments.values():
                if dept.llm.primary.provider == selected_provider or any(ep.provider == selected_provider for ep in dept.llm.ensemble):
                    provider_in_use = True
                    break
        if st.button(
            "删除 Provider",
            key=f"delete_provider_{selected_provider}",
            disabled=provider_in_use or len(providers) <= 1,
        ):
            providers.pop(selected_provider, None)
            cfg.llm_providers = providers
            cfg.sync_runtime_llm()
            save_config()
            st.success("Provider 已删除。")
            st.rerun()

    # --- Section 2: global inference configuration ------------------------
    st.markdown("##### 全局推理配置")
    if not provider_keys:
        st.warning("请先配置至少一个 Provider。")
    else:
        global_cfg = cfg.llm
        primary = global_cfg.primary
        try:
            provider_index = provider_keys.index(primary.provider or provider_keys[0])
        except ValueError:
            provider_index = 0
        selected_global_provider = st.selectbox(
            "主模型 Provider",
            provider_keys,
            index=provider_index,
            key="global_provider_select",
        )
        provider_cfg = providers.get(selected_global_provider)
        available_models = provider_cfg.models if provider_cfg else []
        default_model = primary.model or (provider_cfg.default_model if provider_cfg else None)
        if available_models:
            # Offer the discovered models plus a free-text escape hatch.
            options = available_models + ["自定义"]
            try:
                model_index = available_models.index(default_model)
                model_choice = st.selectbox("主模型", options, index=model_index, key="global_model_choice")
            except ValueError:
                # Current model not in the list -> preselect "自定义".
                model_choice = st.selectbox("主模型", options, index=len(options) - 1, key="global_model_choice")
            if model_choice == "自定义":
                model_val = st.text_input("自定义模型", value=default_model or "", key="global_model_custom").strip()
            else:
                model_val = model_choice
        else:
            model_val = st.text_input("主模型", value=default_model or "", key="global_model_custom").strip()

        # Endpoint-level overrides fall back to provider defaults, then literals.
        temp_default = primary.temperature if primary.temperature is not None else (provider_cfg.default_temperature if provider_cfg else 0.2)
        temp_val = st.slider("主模型温度", min_value=0.0, max_value=2.0, value=float(temp_default), step=0.05, key="global_temp")
        timeout_default = primary.timeout if primary.timeout is not None else (provider_cfg.default_timeout if provider_cfg else 30.0)
        timeout_val = st.number_input("主模型超时(秒)", min_value=5, max_value=300, value=int(timeout_default), step=5, key="global_timeout")
        prompt_template_val = st.text_area(
            "主模型 Prompt 模板(可选)",
            # NOTE(review): precedence makes this parse as
            # ``(primary.prompt_template or provider_cfg.prompt_template) if provider_cfg else ""``,
            # so a set primary template is dropped when provider_cfg is None,
            # and the value may be None when both templates are unset — confirm intent.
            value=primary.prompt_template or provider_cfg.prompt_template if provider_cfg else "",
            height=120,
            key="global_prompt_template",
        )

        strategy_val = st.selectbox("推理策略", sorted(ALLOWED_LLM_STRATEGIES), index=sorted(ALLOWED_LLM_STRATEGIES).index(global_cfg.strategy) if global_cfg.strategy in ALLOWED_LLM_STRATEGIES else 0, key="global_strategy")
        show_ensemble = strategy_val != "single"
        majority_threshold_val = st.number_input(
            "多数投票门槛",
            min_value=1,
            max_value=10,
            value=int(global_cfg.majority_threshold),
            step=1,
            key="global_majority",
            disabled=not show_ensemble,
        )
        if not show_ensemble:
            # Single-model strategy: voting threshold is meaningless; pin to 1.
            majority_threshold_val = 1

        ensemble_rows: List[Dict[str, str]] = []
        if show_ensemble:
            # Seed the editor with configured ensemble endpoints, or one
            # blank row so the table is never empty.
            ensemble_rows = [
                {
                    "provider": ep.provider,
                    "model": ep.model or "",
                    "temperature": "" if ep.temperature is None else f"{ep.temperature:.3f}",
                    "timeout": "" if ep.timeout is None else str(int(ep.timeout)),
                    "prompt_template": ep.prompt_template or "",
                }
                for ep in global_cfg.ensemble
            ] or [{"provider": primary.provider or selected_global_provider, "model": "", "temperature": "", "timeout": "", "prompt_template": ""}]

            ensemble_editor = st.data_editor(
                ensemble_rows,
                num_rows="dynamic",
                key="global_ensemble_editor",
                width='stretch',
                hide_index=True,
                column_config={
                    "provider": st.column_config.SelectboxColumn("Provider", options=provider_keys),
                    "model": st.column_config.TextColumn("模型"),
                    "temperature": st.column_config.TextColumn("温度"),
                    "timeout": st.column_config.TextColumn("超时(秒)"),
                    "prompt_template": st.column_config.TextColumn("Prompt 模板"),
                },
            )
            # data_editor may return a DataFrame or a plain list depending on
            # the input type; normalize to a list of dict records.
            if hasattr(ensemble_editor, "to_dict"):
                ensemble_rows = ensemble_editor.to_dict("records")
            else:
                ensemble_rows = ensemble_editor
        else:
            st.info("当前策略为单模型,未启用协作模型。")

        if st.button("保存全局配置", key="save_global_llm"):
            primary.provider = selected_global_provider
            primary.model = model_val or None
            primary.temperature = float(temp_val)
            primary.timeout = float(timeout_val)
            primary.prompt_template = prompt_template_val.strip() or None
            # Connection details live on the provider, not the endpoint.
            primary.base_url = None
            primary.api_key = None

            new_ensemble: List[LLMEndpoint] = []
            if show_ensemble:
                for row in ensemble_rows:
                    provider_val = (row.get("provider") or "").strip().lower()
                    if not provider_val:
                        # Rows without a provider are treated as blanks.
                        continue
                    model_raw = (row.get("model") or "").strip() or None
                    temp_raw = (row.get("temperature") or "").strip()
                    timeout_raw = (row.get("timeout") or "").strip()
                    prompt_raw = (row.get("prompt_template") or "").strip()
                    new_ensemble.append(
                        LLMEndpoint(
                            provider=provider_val,
                            model=model_raw,
                            temperature=float(temp_raw) if temp_raw else None,
                            timeout=float(timeout_raw) if timeout_raw else None,
                            prompt_template=prompt_raw or None,
                        )
                    )
            cfg.llm.ensemble = new_ensemble
            cfg.llm.strategy = strategy_val
            cfg.llm.majority_threshold = int(majority_threshold_val)
            cfg.sync_runtime_llm()
            save_config()
            st.success("全局 LLM 配置已保存。")
            st.json(llm_config_snapshot())

    # --- Section 3: per-department configuration --------------------------
    st.markdown("##### 部门配置")
    dept_settings = cfg.departments or {}
    dept_rows = [
        {
            "code": code,
            "title": dept.title,
            "description": dept.description,
            "weight": float(dept.weight),
            "strategy": dept.llm.strategy,
            "majority_threshold": dept.llm.majority_threshold,
            "provider": dept.llm.primary.provider or (provider_keys[0] if provider_keys else ""),
            "model": dept.llm.primary.model or "",
            "temperature": "" if dept.llm.primary.temperature is None else f"{dept.llm.primary.temperature:.3f}",
            "timeout": "" if dept.llm.primary.timeout is None else str(int(dept.llm.primary.timeout)),
            "prompt_template": dept.llm.primary.prompt_template or "",
        }
        for code, dept in sorted(dept_settings.items())
    ]

    if not dept_rows:
        st.info("当前未配置部门,可在 config.json 中添加。")
        dept_rows = []

    dept_editor = st.data_editor(
        dept_rows,
        num_rows="fixed",
        key="department_editor",
        width='stretch',
        hide_index=True,
        column_config={
            "code": st.column_config.TextColumn("编码", disabled=True),
            "title": st.column_config.TextColumn("名称"),
            "description": st.column_config.TextColumn("说明"),
            "weight": st.column_config.NumberColumn("权重", min_value=0.0, max_value=10.0, step=0.1),
            "strategy": st.column_config.SelectboxColumn("策略", options=sorted(ALLOWED_LLM_STRATEGIES)),
            "majority_threshold": st.column_config.NumberColumn("投票阈值", min_value=1, max_value=10, step=1),
            "provider": st.column_config.SelectboxColumn("Provider", options=provider_keys or [""]),
            "model": st.column_config.TextColumn("模型"),
            "temperature": st.column_config.TextColumn("温度"),
            "timeout": st.column_config.TextColumn("超时(秒)"),
            "prompt_template": st.column_config.TextColumn("Prompt 模板"),
        },
    )

    # Same DataFrame/list normalization as for the ensemble editor above.
    if hasattr(dept_editor, "to_dict"):
        dept_rows = dept_editor.to_dict("records")
    else:
        dept_rows = dept_editor

    col_reset, col_save = st.columns([1, 1])

    if col_save.button("保存部门配置"):
        updated_departments: Dict[str, DepartmentSettings] = {}
        for row in dept_rows:
            code = row.get("code")
            if not code:
                continue
            existing = dept_settings.get(code) or DepartmentSettings(code=code, title=code)
            existing.title = row.get("title") or existing.title
            existing.description = row.get("description") or ""
            try:
                # Weight is clamped non-negative; invalid input keeps old value.
                existing.weight = max(0.0, float(row.get("weight", existing.weight)))
            except (TypeError, ValueError):
                pass

            strategy_val = (row.get("strategy") or existing.llm.strategy).lower()
            if strategy_val in ALLOWED_LLM_STRATEGIES:
                existing.llm.strategy = strategy_val
            if existing.llm.strategy == "single":
                existing.llm.majority_threshold = 1
                existing.llm.ensemble = []
            else:
                majority_raw = row.get("majority_threshold")
                try:
                    majority_val = int(majority_raw)
                    if majority_val > 0:
                        existing.llm.majority_threshold = majority_val
                except (TypeError, ValueError):
                    # Keep the previous threshold on unparsable input.
                    pass

            provider_val = (row.get("provider") or existing.llm.primary.provider or (provider_keys[0] if provider_keys else "ollama")).strip().lower()
            model_val = (row.get("model") or "").strip() or None
            temp_raw = (row.get("temperature") or "").strip()
            timeout_raw = (row.get("timeout") or "").strip()
            prompt_raw = (row.get("prompt_template") or "").strip()

            endpoint = existing.llm.primary or LLMEndpoint()
            endpoint.provider = provider_val
            endpoint.model = model_val
            endpoint.temperature = float(temp_raw) if temp_raw else None
            endpoint.timeout = float(timeout_raw) if timeout_raw else None
            endpoint.prompt_template = prompt_raw or None
            # Connection details come from the provider at runtime.
            endpoint.base_url = None
            endpoint.api_key = None
            existing.llm.primary = endpoint
            if existing.llm.strategy != "single":
                # Department ensembles are not editable here; clear them.
                existing.llm.ensemble = []

            updated_departments[code] = existing

        if updated_departments:
            cfg.departments = updated_departments
            cfg.sync_runtime_llm()
            save_config()
            st.success("部门配置已更新。")
        else:
            st.warning("未能解析部门配置输入。")

    if col_reset.button("恢复默认部门"):
        # Local import avoids exposing the private helper at module scope.
        from app.utils.config import _default_departments

        cfg.departments = _default_departments()
        cfg.sync_runtime_llm()
        save_config()
        st.success("已恢复默认部门配置。")
        st.rerun()

    st.caption("部门配置存储为独立 LLM 参数,执行时会自动套用对应 Provider 的连接信息。")
|
||||
|
||||
def render_data_settings() -> None:
    """Render the data-source settings page.

    Lets the user edit the Tushare token, the auto-update toggle, and the
    update interval, then shows the 50 most recent ``fetch_jobs`` rows with a
    computed duration column.

    Fix: ``pd.read_sql_query`` over SQLite returns TEXT timestamp columns as
    plain strings, so subtracting ``updated_at - created_at`` directly raises
    ``TypeError``. Both columns are now coerced with ``pd.to_datetime`` (NaT
    on unparsable values) before the duration is computed.
    """
    st.subheader("Tushare 数据源")
    cfg = get_config()

    col1, col2 = st.columns(2)
    with col1:
        tushare_token = st.text_input(
            "Tushare Token",
            value=cfg.tushare_token or "",
            type="password",
            help="从 tushare.pro 获取的 API token"
        )

    with col2:
        auto_update = st.checkbox(
            "启动时自动更新数据",
            value=cfg.auto_update_data,
            help="启动应用时自动检查并更新数据"
        )

    update_interval = st.slider(
        "数据更新间隔(天)",
        min_value=1,
        max_value=30,
        value=cfg.data_update_interval,
        help="自动更新时检查的数据时间范围"
    )

    if st.button("保存数据源配置"):
        cfg.tushare_token = tushare_token
        cfg.auto_update_data = auto_update
        cfg.data_update_interval = update_interval
        # NOTE(review): other call sites in this module use save_config() with
        # no argument — confirm save_config accepts an explicit cfg.
        save_config(cfg)
        st.success("数据源配置已更新!")

    st.divider()
    st.subheader("数据更新记录")

    with db_session() as session:
        df = pd.read_sql_query(
            """
            SELECT job_type, status, created_at, updated_at, error_msg
            FROM fetch_jobs
            ORDER BY created_at DESC
            LIMIT 50
            """,
            session
        )

    if not df.empty:
        # SQLite hands timestamps back as strings; coerce before arithmetic.
        df["created_at"] = pd.to_datetime(df["created_at"], errors="coerce")
        df["updated_at"] = pd.to_datetime(df["updated_at"], errors="coerce")
        df["duration"] = (df["updated_at"] - df["created_at"]).dt.total_seconds().round(2)
        df = df.drop(columns=["updated_at"])
        df = df.rename(columns={
            "job_type": "数据类型",
            "status": "状态",
            "created_at": "开始时间",
            "error_msg": "错误信息",
            "duration": "耗时(秒)"
        })
        st.dataframe(df, width='stretch')
    else:
        st.info("暂无数据更新记录")
|
||||
194
app/ui/views/tests.py
Normal file
194
app/ui/views/tests.py
Normal file
@ -0,0 +1,194 @@
|
||||
"""自检测试视图。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from app.data.schema import initialize_database
|
||||
from app.ingest.checker import run_boot_check
|
||||
from app.ingest.tushare import FetchJob, run_ingestion
|
||||
from app.llm.client import llm_config_snapshot, run_llm
|
||||
from app.utils import alerts
|
||||
from app.utils.config import get_config, save_config
|
||||
|
||||
from app.ui.shared import LOGGER, LOG_EXTRA
|
||||
from app.ui.views.dashboard import update_dashboard_sidebar
|
||||
|
||||
def render_tests() -> None:
    """Render the self-check page.

    Sequential diagnostic sections: database initialization, a sample TuShare
    pull, an RSS fetch test, a manual data sync with progress reporting, and
    an LLM round-trip test. Each section is best-effort: failures surface as
    ``st.error`` plus an entry in the alerts sidebar, never as an exception
    escaping the page.
    """
    LOGGER.info("渲染自检页面", extra=LOG_EXTRA)
    st.header("自检测试")
    st.write("用于快速检查数据库与数据拉取是否正常工作。")

    # --- Database schema check --------------------------------------------
    if st.button("测试数据库初始化"):
        LOGGER.info("点击测试数据库初始化按钮", extra=LOG_EXTRA)
        with st.spinner("正在检查数据库..."):
            result = initialize_database()
        if result.skipped:
            LOGGER.info("数据库已存在,无需初始化", extra=LOG_EXTRA)
            st.success("数据库已存在,检查通过。")
        else:
            LOGGER.info("数据库初始化完成,执行语句数=%s", result.executed, extra=LOG_EXTRA)
            st.success(f"数据库初始化完成,共执行 {result.executed} 条语句。")

    st.divider()

    # --- Sample TuShare ingestion (fixed 3-day window, one symbol) --------
    if st.button("测试 TuShare 拉取(示例 2024-01-01 至 2024-01-03)"):
        LOGGER.info("点击示例 TuShare 拉取按钮", extra=LOG_EXTRA)
        with st.spinner("正在调用 TuShare 接口..."):
            try:
                run_ingestion(
                    FetchJob(
                        name="streamlit_self_test",
                        start=date(2024, 1, 1),
                        end=date(2024, 1, 3),
                        ts_codes=("000001.SZ",),
                    ),
                    include_limits=False,
                )
                LOGGER.info("示例 TuShare 拉取成功", extra=LOG_EXTRA)
                st.success("TuShare 示例拉取完成,数据已写入数据库。")
            except Exception as exc:  # noqa: BLE001
                LOGGER.exception("示例 TuShare 拉取失败", extra=LOG_EXTRA)
                st.error(f"拉取失败:{exc}")
                alerts.add_warning("TuShare", "示例拉取失败", str(exc))
                update_dashboard_sidebar()

    st.info("注意:TuShare 拉取依赖网络与 Token,若环境未配置将出现错误提示。")

    st.divider()

    # --- RSS fetch test ----------------------------------------------------
    st.subheader("RSS 数据测试")
    st.write("用于验证 RSS 配置是否能够正常抓取新闻并写入数据库。")
    rss_url = st.text_input(
        "测试 RSS 地址",
        value="https://rsshub.app/cls/depth/1000",
        help="留空则使用默认配置的全部 RSS 来源。",
    ).strip()
    rss_hours = int(
        st.number_input(
            "回溯窗口(小时)",
            min_value=1,
            max_value=168,
            value=24,
            step=6,
            help="仅抓取最近指定小时内的新闻。",
        )
    )
    rss_limit = int(
        st.number_input(
            "单源抓取条数",
            min_value=1,
            max_value=200,
            value=50,
            step=10,
        )
    )
    if st.button("运行 RSS 测试"):
        # Imported lazily so the page loads even if the RSS stack is broken.
        from app.ingest import rss as rss_ingest

        LOGGER.info(
            "点击 RSS 测试按钮 rss_url=%s hours=%s limit=%s",
            rss_url,
            rss_hours,
            rss_limit,
            extra=LOG_EXTRA,
        )
        with st.spinner("正在抓取 RSS 新闻..."):
            try:
                if rss_url:
                    # Explicit URL: fetch just that feed, then persist.
                    items = rss_ingest.fetch_rss_feed(
                        rss_url,
                        hours_back=rss_hours,
                        max_items=rss_limit,
                    )
                    count = rss_ingest.save_news_items(items)
                else:
                    # No URL given: run every configured RSS source.
                    count = rss_ingest.ingest_configured_rss(
                        hours_back=rss_hours,
                        max_items_per_feed=rss_limit,
                    )
                st.success(f"RSS 测试完成,新增 {count} 条新闻记录。")
            except Exception as exc:  # noqa: BLE001
                LOGGER.exception("RSS 测试失败", extra=LOG_EXTRA)
                st.error(f"RSS 测试失败:{exc}")
                alerts.add_warning("RSS", "RSS 测试执行失败", str(exc))
                update_dashboard_sidebar()

    st.divider()
    # --- Manual data sync ---------------------------------------------------
    days = int(
        st.number_input(
            "检查窗口(天数)",
            min_value=30,
            max_value=10950,
            value=365,
            step=30,
        )
    )
    LOGGER.debug("检查窗口天数=%s", days, extra=LOG_EXTRA)
    cfg = get_config()
    force_refresh = st.checkbox(
        "强制刷新数据(关闭增量跳过)",
        value=cfg.force_refresh,
        help="勾选后将重新拉取所选区间全部数据",
    )
    if force_refresh != cfg.force_refresh:
        # Persist the toggle immediately so it survives reruns.
        cfg.force_refresh = force_refresh
        LOGGER.info("更新 force_refresh=%s", force_refresh, extra=LOG_EXTRA)
        save_config()

    if st.button("执行手动数据同步"):
        LOGGER.info("点击执行手动数据同步按钮", extra=LOG_EXTRA)
        progress_bar = st.progress(0.0)
        status_placeholder = st.empty()
        log_placeholder = st.empty()
        messages: list[str] = []

        def hook(message: str, value: float) -> None:
            # Progress callback passed to run_boot_check; value clamped to [0, 1].
            progress_bar.progress(min(max(value, 0.0), 1.0))
            status_placeholder.write(message)
            messages.append(message)
            LOGGER.debug("手动数据同步进度:%s -> %.2f", message, value, extra=LOG_EXTRA)

        with st.spinner("正在执行手动数据同步..."):
            try:
                report = run_boot_check(
                    days=days,
                    progress_hook=hook,
                    force_refresh=force_refresh,
                )
                LOGGER.info("手动数据同步成功", extra=LOG_EXTRA)
                st.success("手动数据同步完成,以下为数据覆盖摘要。")
                st.json(report.to_dict())
                if messages:
                    log_placeholder.markdown("\n".join(f"- {msg}" for msg in messages))
            except Exception as exc:  # noqa: BLE001
                LOGGER.exception("手动数据同步失败", extra=LOG_EXTRA)
                st.error(f"手动数据同步失败:{exc}")
                alerts.add_warning("数据同步", "手动数据同步失败", str(exc))
                update_dashboard_sidebar()
                if messages:
                    log_placeholder.markdown("\n".join(f"- {msg}" for msg in messages))
            finally:
                # Always close the bar, even when the sync failed midway.
                progress_bar.progress(1.0)

    st.divider()
    # --- LLM round-trip test ------------------------------------------------
    st.subheader("LLM 接口测试")
    st.json(llm_config_snapshot())
    llm_prompt = st.text_area("测试 Prompt", value="请概述今天的市场重点。", height=160)
    system_prompt = st.text_area(
        "System Prompt (可选)",
        value="你是一名量化策略研究助手,用简洁中文回答。",
        height=100,
    )
    if st.button("执行 LLM 测试"):
        with st.spinner("正在调用 LLM..."):
            try:
                response = run_llm(llm_prompt, system=system_prompt or None)
            except Exception as exc:  # noqa: BLE001
                LOGGER.exception("LLM 测试失败", extra=LOG_EXTRA)
                st.error(f"LLM 调用失败:{exc}")
            else:
                LOGGER.info("LLM 测试成功", extra=LOG_EXTRA)
                st.success("LLM 调用成功,以下为返回内容:")
                st.write(response)
|
||||
677
app/ui/views/today.py
Normal file
677
app/ui/views/today.py
Normal file
@ -0,0 +1,677 @@
|
||||
"""今日计划页面视图。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections import Counter
|
||||
from datetime import date, datetime, timedelta
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
|
||||
from app.backtest.engine import BacktestEngine, PortfolioState, BtConfig
|
||||
from app.utils.portfolio import list_investment_pool
|
||||
from app.utils.db import db_session
|
||||
|
||||
from app.ui.shared import (
|
||||
LOGGER,
|
||||
LOG_EXTRA,
|
||||
get_latest_trade_date,
|
||||
get_query_params,
|
||||
set_query_params,
|
||||
)
|
||||
|
||||
|
||||
def render_today_plan() -> None:
    """Render the "today's plan" page.

    Loads recent trade dates and symbols from the ``agent_utils`` table, lets
    the user pick a trade date (pre-selected from the ``date`` URL query
    parameter), and shows two tabs: per-symbol decision details and an
    anonymized portfolio-level "assistant" view.
    """
    LOGGER.info("渲染今日计划页面", extra=LOG_EXTRA)
    st.header("今日计划")
    latest_trade_date = get_latest_trade_date()
    if latest_trade_date:
        st.caption(f"最新交易日:{latest_trade_date.isoformat()}(统计数据请见左侧系统监控)")
    else:
        st.caption("统计与决策概览现已移至左侧'系统监控'侧栏。")
    try:
        with db_session(read_only=True) as conn:
            # Most recent 30 distinct trade dates that have agent decisions.
            date_rows = conn.execute(
                """
                SELECT DISTINCT trade_date
                FROM agent_utils
                ORDER BY trade_date DESC
                LIMIT 30
                """
            ).fetchall()
    except Exception:  # noqa: BLE001
        # Table may not exist before the first backtest run.
        LOGGER.exception("加载 agent_utils 失败", extra=LOG_EXTRA)
        st.warning("暂未写入部门/代理决策,请先运行回测或策略评估流程。")
        return

    trade_dates = [row["trade_date"] for row in date_rows]
    if not trade_dates:
        st.info("暂无决策记录,完成一次回测后即可在此查看部门意见与投票结果。")
        return

    # Pre-select the trade date from the URL, falling back to the newest.
    # NOTE(review): ``get_query_params`` is typed dict[str, list[str]], so
    # ``[0]`` takes the first list element; if the underlying Streamlit API
    # returns plain strings this would take the first character — confirm.
    query = get_query_params()
    default_trade_date = query.get("date", [trade_dates[0]])[0]
    try:
        default_idx = trade_dates.index(default_trade_date)
    except ValueError:
        default_idx = 0
    trade_date = st.selectbox("交易日", trade_dates, index=default_idx)

    with db_session(read_only=True) as conn:
        code_rows = conn.execute(
            """
            SELECT DISTINCT ts_code
            FROM agent_utils
            WHERE trade_date = ?
            ORDER BY ts_code
            """,
            (trade_date,),
        ).fetchall()
    symbols = [row["ts_code"] for row in code_rows]

    detail_tab, assistant_tab = st.tabs(["标的详情", "投资助理模式"])
    with assistant_tab:
        _render_today_plan_assistant_view(trade_date)

    with detail_tab:
        if not symbols:
            st.info("所选交易日暂无 agent_utils 记录。")
        else:
            _render_today_plan_symbol_view(trade_date, symbols, query)
|
||||
|
||||
def _render_today_plan_assistant_view(trade_date: str | int | date) -> None:
    """Render the anonymized, portfolio-level advisory tab.

    Aggregates the candidate investment pool for *trade_date* into summary
    metrics (count, mean/median score), status and tag distributions, and up
    to three de-duplicated rationale excerpts, then derives a suggested cash
    allocation percentage. No individual ticker codes are shown.
    """
    st.info("已开启投资助理模式:以下内容为组合级(去标的)建议,不包含任何具体标的代码。")
    try:
        pool = list_investment_pool(trade_date=trade_date)
        if not pool:
            st.info("所选交易日暂无候选投资池数据。")
            return

        # Collect per-candidate fields; tags/rationale are optional attributes.
        score_list = [float(entry.score or 0.0) for entry in pool]
        status_counter = Counter(entry.status or "UNKNOWN" for entry in pool)
        collected_tags: List[str] = []
        reason_texts: List[str] = []
        for entry in pool:
            if getattr(entry, "tags", None):
                collected_tags.extend(entry.tags)
            if getattr(entry, "rationale", None):
                reason_texts.append(str(entry.rationale))
        tag_counter = Counter(collected_tags)

        st.subheader("候选池聚合概览(已匿名化)")
        col_a, col_b, col_c = st.columns(3)
        col_a.metric("候选数", f"{len(pool)}")
        col_b.metric("平均评分", f"{np.mean(score_list):.3f}" if score_list else "-")
        col_c.metric("中位评分", f"{np.median(score_list):.3f}" if score_list else "-")

        st.write("状态分布:")
        st.json(dict(status_counter))

        if tag_counter:
            st.write("常见标签(示例):")
            st.json(dict(tag_counter.most_common(10)))

        if reason_texts:
            st.write("汇总理由(节选,不含代码):")
            # Keep at most three distinct, non-empty rationale excerpts.
            already_seen: set = set()
            picked: List[str] = []
            for raw_text in reason_texts:
                cleaned = raw_text.strip()
                if cleaned and cleaned not in already_seen:
                    already_seen.add(cleaned)
                    picked.append(cleaned)
                if len(picked) >= 3:
                    break
            for number, snippet in enumerate(picked, start=1):
                st.markdown(f"**理由 {number}:** {snippet}")

        # Map the average score to a suggested buy allocation in [0%, 30%],
        # centered at 10% for an average score of 0.5.
        mean_score = float(np.mean(score_list)) if score_list else 0.0
        buy_fraction = max(0.0, min(0.3, 0.10 + (mean_score - 0.5) * 0.2))
        st.subheader("组合级建议(不指定标的)")
        st.write(
            f"基于候选池平均评分 {mean_score:.3f},建议今日用于新增买入的现金比例约为 {buy_fraction:.0%}。"
        )
        st.write(
            "建议分配思路:在候选池中挑选若干得分较高的标的按目标权重等比例分配,或以分批买入的方式分摊入场时点。"
        )
        if st.button("生成组合级操作建议(仅输出,不执行)"):
            st.success("已生成组合级建议(仅供参考)。")
            st.write({
                "候选数": len(pool),
                "平均评分": mean_score,
                "建议新增买入比例": f"{buy_fraction:.0%}",
            })
    except Exception:  # noqa: BLE001
        LOGGER.exception("加载候选池聚合信息失败", extra=LOG_EXTRA)
        st.error("加载候选池数据时发生错误。")
||||
|
||||
|
||||
def _render_today_plan_symbol_view(
|
||||
trade_date: str | int | date,
|
||||
symbols: List[str],
|
||||
query_params: Dict[str, List[str]],
|
||||
) -> None:
|
||||
default_ts = query_params.get("code", [symbols[0]])[0]
|
||||
try:
|
||||
default_ts_idx = symbols.index(default_ts)
|
||||
except ValueError:
|
||||
default_ts_idx = 0
|
||||
ts_code = st.selectbox("标的", symbols, index=default_ts_idx)
|
||||
batch_symbols = st.multiselect("批量重评估(可多选)", symbols, default=[])
|
||||
|
||||
if st.button("一键重评估所有标的", type="primary", width='stretch'):
|
||||
with st.spinner("正在对所有标的进行重评估,请稍候..."):
|
||||
try:
|
||||
trade_date_obj: Optional[date] = None
|
||||
try:
|
||||
trade_date_obj = date.fromisoformat(str(trade_date))
|
||||
except Exception:
|
||||
try:
|
||||
trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date()
|
||||
except Exception:
|
||||
pass
|
||||
if trade_date_obj is None:
|
||||
raise ValueError(f"无法解析交易日:{trade_date}")
|
||||
|
||||
progress = st.progress(0.0)
|
||||
changes_all: List[Dict[str, object]] = []
|
||||
success_count = 0
|
||||
error_count = 0
|
||||
|
||||
for idx, code in enumerate(symbols, start=1):
|
||||
try:
|
||||
with db_session(read_only=True) as conn:
|
||||
before_rows = conn.execute(
|
||||
"SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?",
|
||||
(trade_date, code),
|
||||
).fetchall()
|
||||
before_map = {row["agent"]: row["action"] for row in before_rows}
|
||||
|
||||
cfg = BtConfig(
|
||||
id="reeval_ui_all",
|
||||
name="UI All Re-eval",
|
||||
start_date=trade_date_obj,
|
||||
end_date=trade_date_obj,
|
||||
universe=[code],
|
||||
params={},
|
||||
)
|
||||
engine = BacktestEngine(cfg)
|
||||
state = PortfolioState()
|
||||
engine.simulate_day(trade_date_obj, state)
|
||||
|
||||
with db_session(read_only=True) as conn:
|
||||
after_rows = conn.execute(
|
||||
"SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?",
|
||||
(trade_date, code),
|
||||
).fetchall()
|
||||
for row in after_rows:
|
||||
agent = row["agent"]
|
||||
new_action = row["action"]
|
||||
old_action = before_map.get(agent)
|
||||
if new_action != old_action:
|
||||
changes_all.append(
|
||||
{"代码": code, "代理": agent, "原动作": old_action, "新动作": new_action}
|
||||
)
|
||||
success_count += 1
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.exception("重评估 %s 失败", code, extra=LOG_EXTRA)
|
||||
error_count += 1
|
||||
|
||||
progress.progress(idx / len(symbols))
|
||||
|
||||
if error_count > 0:
|
||||
st.error(f"一键重评估完成:成功 {success_count} 个,失败 {error_count} 个")
|
||||
else:
|
||||
st.success(f"一键重评估完成:所有 {success_count} 个标的重评估成功")
|
||||
|
||||
if changes_all:
|
||||
st.write("检测到以下动作变更:")
|
||||
st.dataframe(pd.DataFrame(changes_all), hide_index=True, width='stretch')
|
||||
|
||||
st.rerun()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
LOGGER.exception("一键重评估失败", extra=LOG_EXTRA)
|
||||
st.error(f"一键重评估执行过程中发生错误:{exc}")
|
||||
|
||||
set_query_params(date=str(trade_date), code=str(ts_code))
|
||||
|
||||
with db_session(read_only=True) as conn:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT agent, action, utils, feasible, weight
|
||||
FROM agent_utils
|
||||
WHERE trade_date = ? AND ts_code = ?
|
||||
ORDER BY CASE WHEN agent = 'global' THEN 1 ELSE 0 END, agent
|
||||
""",
|
||||
(trade_date, ts_code),
|
||||
).fetchall()
|
||||
|
||||
if not rows:
|
||||
st.info("未查询到详细决策记录,稍后再试。")
|
||||
return
|
||||
|
||||
try:
|
||||
feasible_actions = json.loads(rows[0]["feasible"] or "[]")
|
||||
except (KeyError, TypeError, json.JSONDecodeError):
|
||||
feasible_actions = []
|
||||
|
||||
global_info = None
|
||||
dept_records: List[Dict[str, object]] = []
|
||||
dept_details: Dict[str, Dict[str, object]] = {}
|
||||
agent_records: List[Dict[str, object]] = []
|
||||
|
||||
for item in rows:
|
||||
agent_name = item["agent"]
|
||||
action = item["action"]
|
||||
weight = float(item["weight"] or 0.0)
|
||||
try:
|
||||
utils = json.loads(item["utils"] or "{}")
|
||||
except json.JSONDecodeError:
|
||||
utils = {}
|
||||
|
||||
if agent_name == "global":
|
||||
global_info = {
|
||||
"action": action,
|
||||
"confidence": float(utils.get("_confidence", 0.0)),
|
||||
"target_weight": float(utils.get("_target_weight", 0.0)),
|
||||
"department_votes": utils.get("_department_votes", {}),
|
||||
"requires_review": bool(utils.get("_requires_review", False)),
|
||||
"scope_values": utils.get("_scope_values", {}),
|
||||
"close_series": utils.get("_close_series", []),
|
||||
"turnover_series": utils.get("_turnover_series", []),
|
||||
"department_supplements": utils.get("_department_supplements", {}),
|
||||
"department_dialogue": utils.get("_department_dialogue", {}),
|
||||
"department_telemetry": utils.get("_department_telemetry", {}),
|
||||
}
|
||||
continue
|
||||
|
||||
if agent_name.startswith("dept_"):
|
||||
code = agent_name.split("dept_", 1)[-1]
|
||||
signals = utils.get("_signals", [])
|
||||
risks = utils.get("_risks", [])
|
||||
supplements = utils.get("_supplements", [])
|
||||
dialogue = utils.get("_dialogue", [])
|
||||
telemetry = utils.get("_telemetry", {})
|
||||
dept_records.append(
|
||||
{
|
||||
"部门": code,
|
||||
"行动": action,
|
||||
"信心": float(utils.get("_confidence", 0.0)),
|
||||
"权重": weight,
|
||||
"摘要": utils.get("_summary", ""),
|
||||
"核心信号": ";".join(signals) if isinstance(signals, list) else signals,
|
||||
"风险提示": ";".join(risks) if isinstance(risks, list) else risks,
|
||||
"补充次数": len(supplements) if isinstance(supplements, list) else 0,
|
||||
}
|
||||
)
|
||||
dept_details[code] = {
|
||||
"supplements": supplements if isinstance(supplements, list) else [],
|
||||
"dialogue": dialogue if isinstance(dialogue, list) else [],
|
||||
"summary": utils.get("_summary", ""),
|
||||
"signals": signals,
|
||||
"risks": risks,
|
||||
"telemetry": telemetry if isinstance(telemetry, dict) else {},
|
||||
}
|
||||
else:
|
||||
score_map = {
|
||||
key: float(val)
|
||||
for key, val in utils.items()
|
||||
if not str(key).startswith("_")
|
||||
}
|
||||
agent_records.append(
|
||||
{
|
||||
"代理": agent_name,
|
||||
"建议动作": action,
|
||||
"权重": weight,
|
||||
"SELL": score_map.get("SELL", 0.0),
|
||||
"HOLD": score_map.get("HOLD", 0.0),
|
||||
"BUY_S": score_map.get("BUY_S", 0.0),
|
||||
"BUY_M": score_map.get("BUY_M", 0.0),
|
||||
"BUY_L": score_map.get("BUY_L", 0.0),
|
||||
}
|
||||
)
|
||||
|
||||
if feasible_actions:
|
||||
st.caption(f"可行操作集合:{', '.join(feasible_actions)}")
|
||||
|
||||
st.subheader("全局策略")
|
||||
if global_info:
|
||||
col1, col2, col3 = st.columns(3)
|
||||
col1.metric("最终行动", global_info["action"])
|
||||
col2.metric("信心", f"{global_info['confidence']:.2f}")
|
||||
col3.metric("目标权重", f"{global_info['target_weight']:+.2%}")
|
||||
if global_info["department_votes"]:
|
||||
st.json(global_info["department_votes"])
|
||||
if global_info["requires_review"]:
|
||||
st.warning("部门分歧较大,已标记为需人工复核。")
|
||||
with st.expander("基础上下文数据", expanded=False):
|
||||
scope = global_info.get("scope_values") or {}
|
||||
close_series = global_info.get("close_series") or []
|
||||
turnover_series = global_info.get("turnover_series") or []
|
||||
st.write("最新字段:")
|
||||
if scope:
|
||||
st.json(scope)
|
||||
st.download_button(
|
||||
"下载字段(JSON)",
|
||||
data=json.dumps(scope, ensure_ascii=False, indent=2),
|
||||
file_name=f"{ts_code}_{trade_date}_scope.json",
|
||||
mime="application/json",
|
||||
key="dl_scope_json",
|
||||
)
|
||||
if close_series:
|
||||
st.write("收盘价时间序列 (最近窗口):")
|
||||
st.json(close_series)
|
||||
try:
|
||||
import csv
|
||||
import io
|
||||
|
||||
buf = io.StringIO()
|
||||
writer = csv.writer(buf)
|
||||
writer.writerow(["trade_date", "close"])
|
||||
for dt_val, val in close_series:
|
||||
writer.writerow([dt_val, val])
|
||||
st.download_button(
|
||||
"下载收盘价(CSV)",
|
||||
data=buf.getvalue(),
|
||||
file_name=f"{ts_code}_{trade_date}_close_series.csv",
|
||||
mime="text/csv",
|
||||
key="dl_close_csv",
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
if turnover_series:
|
||||
st.write("换手率时间序列 (最近窗口):")
|
||||
st.json(turnover_series)
|
||||
dept_sup = global_info.get("department_supplements") or {}
|
||||
dept_dialogue = global_info.get("department_dialogue") or {}
|
||||
dept_telemetry = global_info.get("department_telemetry") or {}
|
||||
if dept_sup or dept_dialogue:
|
||||
with st.expander("部门补数与对话记录", expanded=False):
|
||||
if dept_sup:
|
||||
st.write("补充数据:")
|
||||
st.json(dept_sup)
|
||||
if dept_dialogue:
|
||||
st.write("对话片段:")
|
||||
st.json(dept_dialogue)
|
||||
if dept_telemetry:
|
||||
with st.expander("部门 LLM 元数据", expanded=False):
|
||||
st.json(dept_telemetry)
|
||||
else:
|
||||
st.info("暂未写入全局策略摘要。")
|
||||
|
||||
st.subheader("部门意见")
|
||||
if dept_records:
|
||||
keyword = st.text_input("筛选摘要/信号关键词", value="")
|
||||
filtered = dept_records
|
||||
if keyword.strip():
|
||||
kw = keyword.strip()
|
||||
filtered = [
|
||||
item
|
||||
for item in dept_records
|
||||
if kw in str(item.get("摘要", "")) or kw in str(item.get("核心信号", ""))
|
||||
]
|
||||
min_conf = st.slider("最低信心过滤", 0.0, 1.0, 0.0, 0.05)
|
||||
sort_col = st.selectbox("排序列", ["信心", "权重"], index=0)
|
||||
filtered = [row for row in filtered if float(row.get("信心", 0.0)) >= min_conf]
|
||||
filtered = sorted(filtered, key=lambda r: float(r.get(sort_col, 0.0)), reverse=True)
|
||||
dept_df = pd.DataFrame(filtered)
|
||||
st.dataframe(dept_df, width='stretch', hide_index=True)
|
||||
try:
|
||||
st.download_button(
|
||||
"下载部门意见(CSV)",
|
||||
data=dept_df.to_csv(index=False),
|
||||
file_name=f"{trade_date}_{ts_code}_departments.csv",
|
||||
mime="text/csv",
|
||||
key="dl_dept_csv",
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
for code, details in dept_details.items():
|
||||
with st.expander(f"{code} 补充详情", expanded=False):
|
||||
supplements = details.get("supplements", [])
|
||||
dialogue = details.get("dialogue", [])
|
||||
if supplements:
|
||||
st.write("补充数据:")
|
||||
st.json(supplements)
|
||||
else:
|
||||
st.caption("无补充数据请求。")
|
||||
if dialogue:
|
||||
st.write("对话记录:")
|
||||
for idx, line in enumerate(dialogue, start=1):
|
||||
st.markdown(f"**回合 {idx}:** {line}")
|
||||
else:
|
||||
st.caption("无额外对话。")
|
||||
telemetry = details.get("telemetry") or {}
|
||||
if telemetry:
|
||||
st.write("LLM 元数据:")
|
||||
st.json(telemetry)
|
||||
else:
|
||||
st.info("暂无部门记录。")
|
||||
|
||||
st.subheader("代理评分")
|
||||
if agent_records:
|
||||
sort_agent_by = st.selectbox(
|
||||
"代理排序",
|
||||
["权重", "SELL", "HOLD", "BUY_S", "BUY_M", "BUY_L"],
|
||||
index=1,
|
||||
)
|
||||
agent_df = pd.DataFrame(agent_records)
|
||||
if sort_agent_by in agent_df.columns:
|
||||
agent_df = agent_df.sort_values(sort_agent_by, ascending=False)
|
||||
st.dataframe(agent_df, width='stretch', hide_index=True)
|
||||
try:
|
||||
st.download_button(
|
||||
"下载代理评分(CSV)",
|
||||
data=agent_df.to_csv(index=False),
|
||||
file_name=f"{trade_date}_{ts_code}_agents.csv",
|
||||
mime="text/csv",
|
||||
key="dl_agent_csv",
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
else:
|
||||
st.info("暂无基础代理评分。")
|
||||
|
||||
st.divider()
|
||||
st.subheader("相关新闻")
|
||||
try:
|
||||
with db_session(read_only=True) as conn:
|
||||
try:
|
||||
trade_date_obj = date.fromisoformat(str(trade_date))
|
||||
except Exception:
|
||||
try:
|
||||
trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date()
|
||||
except Exception:
|
||||
trade_date_obj = date.today() - timedelta(days=7)
|
||||
|
||||
news_query = """
|
||||
SELECT id, title, source, pub_time, sentiment, heat, entities
|
||||
FROM news
|
||||
WHERE ts_code = ? AND pub_time >= ?
|
||||
ORDER BY pub_time DESC
|
||||
LIMIT 10
|
||||
"""
|
||||
seven_days_ago = (trade_date_obj - timedelta(days=7)).strftime("%Y-%m-%d")
|
||||
news_rows = conn.execute(news_query, (ts_code, seven_days_ago)).fetchall()
|
||||
|
||||
if news_rows:
|
||||
news_data = []
|
||||
for row in news_rows:
|
||||
entities_info = {}
|
||||
try:
|
||||
if row["entities"]:
|
||||
entities_info = json.loads(row["entities"])
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
news_item = {
|
||||
"标题": row["title"],
|
||||
"来源": row["source"],
|
||||
"发布时间": row["pub_time"],
|
||||
"情感指数": f"{row['sentiment']:.2f}" if row["sentiment"] is not None else "-",
|
||||
"热度评分": f"{row['heat']:.2f}" if row["heat"] is not None else "-",
|
||||
}
|
||||
|
||||
industries = entities_info.get("industries", [])
|
||||
if industries:
|
||||
news_item["相关行业"] = "、".join(industries[:3])
|
||||
|
||||
news_data.append(news_item)
|
||||
|
||||
news_df = pd.DataFrame(news_data)
|
||||
for col in news_df.columns:
|
||||
news_df[col] = news_df[col].astype(str)
|
||||
st.dataframe(news_df, width='stretch', hide_index=True)
|
||||
|
||||
st.write("详细新闻内容:")
|
||||
for idx, row in enumerate(news_rows):
|
||||
with st.expander(f"{idx+1}. {row['title']}", expanded=False):
|
||||
st.write(f"**来源:** {row['source']}")
|
||||
st.write(f"**发布时间:** {row['pub_time']}")
|
||||
|
||||
entities_info = {}
|
||||
try:
|
||||
if row["entities"]:
|
||||
entities_info = json.loads(row["entities"])
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
sentiment_display = f"{row['sentiment']:.2f}" if row["sentiment"] is not None else "-"
|
||||
heat_display = f"{row['heat']:.2f}" if row["heat"] is not None else "-"
|
||||
st.write(f"**情感指数:** {sentiment_display} | **热度评分:** {heat_display}")
|
||||
|
||||
industries = entities_info.get("industries", [])
|
||||
if industries:
|
||||
st.write(f"**相关行业:** {'、'.join(industries)}")
|
||||
|
||||
important_keywords = entities_info.get("important_keywords", [])
|
||||
if important_keywords:
|
||||
st.write(f"**重要关键词:** {'、'.join(important_keywords)}")
|
||||
|
||||
url = entities_info.get("source_url", "")
|
||||
if url:
|
||||
st.markdown(f"[查看原文]({url})", unsafe_allow_html=True)
|
||||
else:
|
||||
st.info(f"近7天内暂无关于 {ts_code} 的新闻。")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
LOGGER.exception("获取新闻数据失败", extra=LOG_EXTRA)
|
||||
st.error(f"获取新闻数据时发生错误:{exc}")
|
||||
|
||||
st.divider()
|
||||
st.info("投资池与仓位概览已移至单独页面。请在侧边或页面导航中选择“投资池/仓位”以查看详细信息。")
|
||||
|
||||
st.divider()
|
||||
st.subheader("策略重评估")
|
||||
st.caption("对当前选中的交易日与标的,立即触发一次策略评估并回写 agent_utils。")
|
||||
cols_re = st.columns([1, 1])
|
||||
if cols_re[0].button("对该标的重评估", key="reevaluate_current_symbol"):
|
||||
with st.spinner("正在重评估..."):
|
||||
try:
|
||||
trade_date_obj: Optional[date] = None
|
||||
try:
|
||||
trade_date_obj = date.fromisoformat(str(trade_date))
|
||||
except Exception:
|
||||
try:
|
||||
trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date()
|
||||
except Exception:
|
||||
pass
|
||||
if trade_date_obj is None:
|
||||
raise ValueError(f"无法解析交易日:{trade_date}")
|
||||
with db_session(read_only=True) as conn:
|
||||
before_rows = conn.execute(
|
||||
"""
|
||||
SELECT agent, action, utils FROM agent_utils
|
||||
WHERE trade_date = ? AND ts_code = ?
|
||||
""",
|
||||
(trade_date, ts_code),
|
||||
).fetchall()
|
||||
before_map = {row["agent"]: (row["action"], row["utils"]) for row in before_rows}
|
||||
cfg = BtConfig(
|
||||
id="reeval_ui",
|
||||
name="UI Re-evaluation",
|
||||
start_date=trade_date_obj,
|
||||
end_date=trade_date_obj,
|
||||
universe=[ts_code],
|
||||
params={},
|
||||
)
|
||||
engine = BacktestEngine(cfg)
|
||||
state = PortfolioState()
|
||||
engine.simulate_day(trade_date_obj, state)
|
||||
with db_session(read_only=True) as conn:
|
||||
after_rows = conn.execute(
|
||||
"""
|
||||
SELECT agent, action, utils FROM agent_utils
|
||||
WHERE trade_date = ? AND ts_code = ?
|
||||
""",
|
||||
(trade_date, ts_code),
|
||||
).fetchall()
|
||||
changes = []
|
||||
for row in after_rows:
|
||||
agent = row["agent"]
|
||||
new_action = row["action"]
|
||||
old_action, _old_utils = before_map.get(agent, (None, None))
|
||||
if new_action != old_action:
|
||||
changes.append({"代理": agent, "原动作": old_action, "新动作": new_action})
|
||||
if changes:
|
||||
st.success("重评估完成,检测到动作变更:")
|
||||
st.dataframe(pd.DataFrame(changes), hide_index=True, width='stretch')
|
||||
else:
|
||||
st.success("重评估完成,无动作变更。")
|
||||
st.rerun()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
LOGGER.exception("重评估失败", extra=LOG_EXTRA)
|
||||
st.error(f"重评估失败:{exc}")
|
||||
if cols_re[1].button("批量重评估(所选)", key="reevaluate_batch", disabled=not batch_symbols):
|
||||
with st.spinner("批量重评估执行中..."):
|
||||
try:
|
||||
trade_date_obj: Optional[date] = None
|
||||
try:
|
||||
trade_date_obj = date.fromisoformat(str(trade_date))
|
||||
except Exception:
|
||||
try:
|
||||
trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date()
|
||||
except Exception:
|
||||
pass
|
||||
if trade_date_obj is None:
|
||||
raise ValueError(f"无法解析交易日:{trade_date}")
|
||||
progress = st.progress(0.0)
|
||||
changes_all: List[Dict[str, object]] = []
|
||||
for idx, code in enumerate(batch_symbols, start=1):
|
||||
with db_session(read_only=True) as conn:
|
||||
before_rows = conn.execute(
|
||||
"SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?",
|
||||
(trade_date, code),
|
||||
).fetchall()
|
||||
before_map = {row["agent"]: row["action"] for row in before_rows}
|
||||
cfg = BtConfig(
|
||||
id="reeval_ui_batch",
|
||||
name="UI Batch Re-eval",
|
||||
start_date=trade_date_obj,
|
||||
end_date=trade_date_obj,
|
||||
universe=[code],
|
||||
params={},
|
||||
)
|
||||
engine = BacktestEngine(cfg)
|
||||
state = PortfolioState()
|
||||
engine.simulate_day(trade_date_obj, state)
|
||||
with db_session(read_only=True) as conn:
|
||||
after_rows = conn.execute(
|
||||
"SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?",
|
||||
(trade_date, code),
|
||||
).fetchall()
|
||||
for row in after_rows:
|
||||
agent = row["agent"]
|
||||
new_action = row["action"]
|
||||
old_action = before_map.get(agent)
|
||||
if new_action != old_action:
|
||||
changes_all.append({"代码": code, "代理": agent, "原动作": old_action, "新动作": new_action})
|
||||
progress.progress(idx / max(1, len(batch_symbols)))
|
||||
st.success("批量重评估完成。")
|
||||
if changes_all:
|
||||
st.dataframe(pd.DataFrame(changes_all), hide_index=True, width='stretch')
|
||||
st.rerun()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
LOGGER.exception("批量重评估失败", extra=LOG_EXTRA)
|
||||
st.error(f"批量重评估失败:{exc}")
|
||||
Loading…
Reference in New Issue
Block a user