update
This commit is contained in:
parent
6a7c20db91
commit
4eb2b2d81e
@ -58,6 +58,7 @@ from app.utils.portfolio import (
|
||||
)
|
||||
from app.agents.registry import default_agents
|
||||
from app.utils.tuning import log_tuning_result
|
||||
from app.backtest.engine import BacktestEngine, PortfolioState
|
||||
|
||||
|
||||
LOGGER = get_logger(__name__)
|
||||
@ -67,6 +68,9 @@ _DECISION_ENV_BATCH_RESULTS_KEY = "decision_env_batch_results"
|
||||
_DASHBOARD_CONTAINERS: Optional[tuple[object, object]] = None
|
||||
_DASHBOARD_ELEMENTS: Optional[Dict[str, object]] = None
|
||||
_SIDEBAR_LISTENER_ATTACHED = False
|
||||
# ADD: simple in-memory cache for provider model discovery
|
||||
_MODEL_CACHE: Dict[str, Dict[str, object]] = {}
|
||||
_CACHE_TTL_SECONDS = 300
|
||||
|
||||
|
||||
def _sidebar_metrics_listener(metrics: Dict[str, object]) -> None:
|
||||
@ -210,6 +214,16 @@ def _discover_provider_models(provider: LLMProvider, base_override: str = "", ap
|
||||
timeout = float(provider.default_timeout or 30.0)
|
||||
mode = provider.mode or ("ollama" if provider.key == "ollama" else "openai")
|
||||
|
||||
# ADD: simple cache by provider+base URL
|
||||
cache_key = f"{provider.key}|{base_url}"
|
||||
now = datetime.now()
|
||||
cached = _MODEL_CACHE.get(cache_key)
|
||||
if cached:
|
||||
ts = cached.get("ts")
|
||||
if isinstance(ts, float) and (now.timestamp() - ts) < _CACHE_TTL_SECONDS:
|
||||
models = list(cached.get("models") or [])
|
||||
return models, None
|
||||
|
||||
try:
|
||||
if mode == "ollama":
|
||||
url = base_url.rstrip('/') + "/api/tags"
|
||||
@ -221,6 +235,7 @@ def _discover_provider_models(provider: LLMProvider, base_override: str = "", ap
|
||||
name = item.get("name") or item.get("model") or item.get("tag")
|
||||
if name:
|
||||
models.append(str(name).strip())
|
||||
_MODEL_CACHE[cache_key] = {"ts": now.timestamp(), "models": sorted(set(models))}
|
||||
return sorted(set(models)), None
|
||||
|
||||
api_key = (api_override or provider.api_key or "").strip()
|
||||
@ -239,6 +254,7 @@ def _discover_provider_models(provider: LLMProvider, base_override: str = "", ap
|
||||
for item in payload.get("data", [])
|
||||
if item.get("id")
|
||||
]
|
||||
_MODEL_CACHE[cache_key] = {"ts": now.timestamp(), "models": sorted(set(models))}
|
||||
return sorted(set(models)), None
|
||||
except RequestException as exc: # noqa: BLE001
|
||||
return [], f"HTTP 错误:{exc}"
|
||||
@ -345,7 +361,7 @@ def render_today_plan() -> None:
|
||||
if latest_trade_date:
|
||||
st.caption(f"最新交易日:{latest_trade_date.isoformat()}(统计数据请见左侧系统监控)")
|
||||
else:
|
||||
st.caption("统计与决策概览现已移至左侧“系统监控”侧栏。")
|
||||
st.caption("统计与决策概览现已移至左侧'系统监控'侧栏。")
|
||||
try:
|
||||
with db_session(read_only=True) as conn:
|
||||
date_rows = conn.execute(
|
||||
@ -384,6 +400,8 @@ def render_today_plan() -> None:
|
||||
return
|
||||
|
||||
ts_code = st.selectbox("标的", symbols, index=0)
|
||||
# ADD: batch selection for re-evaluation
|
||||
batch_symbols = st.multiselect("批量重评估(可多选)", symbols, default=[])
|
||||
|
||||
with db_session(read_only=True) as conn:
|
||||
rows = conn.execute(
|
||||
@ -523,7 +541,16 @@ def render_today_plan() -> None:
|
||||
|
||||
st.subheader("部门意见")
|
||||
if dept_records:
|
||||
dept_df = pd.DataFrame(dept_records)
|
||||
# ADD: keyword filter for department summaries
|
||||
keyword = st.text_input("筛选摘要/信号关键词", value="")
|
||||
filtered = dept_records
|
||||
if keyword.strip():
|
||||
kw = keyword.strip()
|
||||
filtered = [
|
||||
item for item in dept_records
|
||||
if kw in str(item.get("摘要", "")) or kw in str(item.get("核心信号", ""))
|
||||
]
|
||||
dept_df = pd.DataFrame(filtered)
|
||||
st.dataframe(dept_df, width='stretch', hide_index=True)
|
||||
for code, details in dept_details.items():
|
||||
with st.expander(f"{code} 补充详情", expanded=False):
|
||||
@ -636,6 +663,122 @@ def render_today_plan() -> None:
|
||||
|
||||
st.caption("数据来源:agent_utils、investment_pool、portfolio_positions、portfolio_trades、portfolio_snapshots。")
|
||||
|
||||
st.divider()
|
||||
st.subheader("策略重评估")
|
||||
st.caption("对当前选中的交易日与标的,立即触发一次策略评估并回写 agent_utils。")
|
||||
cols_re = st.columns([1,1])
|
||||
if cols_re[0].button("对该标的重评估", key="reevaluate_current_symbol"):
|
||||
with st.spinner("正在重评估..."):
|
||||
try:
|
||||
trade_date_obj = None
|
||||
try:
|
||||
trade_date_obj = date.fromisoformat(str(trade_date))
|
||||
except Exception:
|
||||
try:
|
||||
trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date()
|
||||
except Exception:
|
||||
pass
|
||||
if trade_date_obj is None:
|
||||
raise ValueError(f"无法解析交易日:{trade_date}")
|
||||
# snapshot before
|
||||
with db_session(read_only=True) as conn:
|
||||
before_rows = conn.execute(
|
||||
"""
|
||||
SELECT agent, action, utils FROM agent_utils
|
||||
WHERE trade_date = ? AND ts_code = ?
|
||||
""",
|
||||
(trade_date, ts_code),
|
||||
).fetchall()
|
||||
before_map = {row["agent"]: (row["action"], row["utils"]) for row in before_rows}
|
||||
cfg = BtConfig(
|
||||
id="reeval_ui",
|
||||
name="UI Re-evaluation",
|
||||
start_date=trade_date_obj,
|
||||
end_date=trade_date_obj,
|
||||
universe=[ts_code],
|
||||
params={},
|
||||
)
|
||||
engine = BacktestEngine(cfg)
|
||||
state = PortfolioState()
|
||||
_ = engine.simulate_day(trade_date_obj, state)
|
||||
# compare after
|
||||
with db_session(read_only=True) as conn:
|
||||
after_rows = conn.execute(
|
||||
"""
|
||||
SELECT agent, action, utils FROM agent_utils
|
||||
WHERE trade_date = ? AND ts_code = ?
|
||||
""",
|
||||
(trade_date, ts_code),
|
||||
).fetchall()
|
||||
changes = []
|
||||
for row in after_rows:
|
||||
agent = row["agent"]
|
||||
new_action = row["action"]
|
||||
old_action, _old_utils = before_map.get(agent, (None, None))
|
||||
if new_action != old_action:
|
||||
changes.append({"代理": agent, "原动作": old_action, "新动作": new_action})
|
||||
if changes:
|
||||
st.success("重评估完成,检测到动作变更:")
|
||||
st.dataframe(pd.DataFrame(changes), hide_index=True, width='stretch')
|
||||
else:
|
||||
st.success("重评估完成,无动作变更。")
|
||||
st.rerun()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
LOGGER.exception("重评估失败", extra=LOG_EXTRA)
|
||||
st.error(f"重评估失败:{exc}")
|
||||
if cols_re[1].button("批量重评估(所选)", key="reevaluate_batch", disabled=not batch_symbols):
|
||||
with st.spinner("批量重评估执行中..."):
|
||||
try:
|
||||
trade_date_obj = None
|
||||
try:
|
||||
trade_date_obj = date.fromisoformat(str(trade_date))
|
||||
except Exception:
|
||||
try:
|
||||
trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date()
|
||||
except Exception:
|
||||
pass
|
||||
if trade_date_obj is None:
|
||||
raise ValueError(f"无法解析交易日:{trade_date}")
|
||||
progress = st.progress(0.0)
|
||||
changes_all: List[Dict[str, object]] = []
|
||||
for idx, code in enumerate(batch_symbols, start=1):
|
||||
with db_session(read_only=True) as conn:
|
||||
before_rows = conn.execute(
|
||||
"SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?",
|
||||
(trade_date, code),
|
||||
).fetchall()
|
||||
before_map = {row["agent"]: row["action"] for row in before_rows}
|
||||
cfg = BtConfig(
|
||||
id="reeval_ui_batch",
|
||||
name="UI Batch Re-eval",
|
||||
start_date=trade_date_obj,
|
||||
end_date=trade_date_obj,
|
||||
universe=[code],
|
||||
params={},
|
||||
)
|
||||
engine = BacktestEngine(cfg)
|
||||
state = PortfolioState()
|
||||
_ = engine.simulate_day(trade_date_obj, state)
|
||||
with db_session(read_only=True) as conn:
|
||||
after_rows = conn.execute(
|
||||
"SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?",
|
||||
(trade_date, code),
|
||||
).fetchall()
|
||||
for row in after_rows:
|
||||
agent = row["agent"]
|
||||
new_action = row["action"]
|
||||
old_action = before_map.get(agent)
|
||||
if new_action != old_action:
|
||||
changes_all.append({"代码": code, "代理": agent, "原动作": old_action, "新动作": new_action})
|
||||
progress.progress(idx / max(1, len(batch_symbols)))
|
||||
st.success("批量重评估完成。")
|
||||
if changes_all:
|
||||
st.dataframe(pd.DataFrame(changes_all), hide_index=True, width='stretch')
|
||||
st.rerun()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
LOGGER.exception("批量重评估失败", extra=LOG_EXTRA)
|
||||
st.error(f"批量重评估失败:{exc}")
|
||||
|
||||
|
||||
def render_backtest() -> None:
|
||||
LOGGER.info("渲染回测页面", extra=LOG_EXTRA)
|
||||
@ -1237,6 +1380,84 @@ def render_backtest() -> None:
|
||||
st.session_state.pop("decision_env_batch_select", None)
|
||||
st.success("已清除批量调参结果缓存。")
|
||||
|
||||
# ADD: Comparison view for multiple backtest configurations
|
||||
with st.expander("回测结果对比", expanded=False):
|
||||
st.caption("从历史回测配置中选择多个进行净值曲线与指标对比。")
|
||||
normalize_to_one = st.checkbox("归一化到 1 起点", value=True)
|
||||
use_log_y = st.checkbox("对数坐标", value=False)
|
||||
metric_options = ["总收益", "最大回撤", "交易数", "平均换手", "风险事件"]
|
||||
selected_metrics = st.multiselect("显示指标列", metric_options, default=metric_options)
|
||||
try:
|
||||
with db_session(read_only=True) as conn:
|
||||
cfg_rows = conn.execute(
|
||||
"SELECT id, name FROM bt_config ORDER BY rowid DESC LIMIT 50"
|
||||
).fetchall()
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.exception("读取 bt_config 失败", extra=LOG_EXTRA)
|
||||
cfg_rows = []
|
||||
cfg_options = [f"{row['id']} | {row['name']}" for row in cfg_rows]
|
||||
selected_labels = st.multiselect("选择配置", cfg_options, default=cfg_options[:2])
|
||||
selected_ids = [label.split(" | ")[0].strip() for label in selected_labels]
|
||||
if selected_ids:
|
||||
try:
|
||||
with db_session(read_only=True) as conn:
|
||||
nav_df = pd.read_sql_query(
|
||||
"SELECT cfg_id, trade_date, nav FROM bt_nav WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))),
|
||||
conn,
|
||||
params=tuple(selected_ids),
|
||||
)
|
||||
rpt_df = pd.read_sql_query(
|
||||
"SELECT cfg_id, summary FROM bt_report WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))),
|
||||
conn,
|
||||
params=tuple(selected_ids),
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.exception("读取回测结果失败", extra=LOG_EXTRA)
|
||||
st.error("读取回测结果失败")
|
||||
nav_df = pd.DataFrame()
|
||||
rpt_df = pd.DataFrame()
|
||||
if not nav_df.empty:
|
||||
try:
|
||||
nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"], errors="coerce")
|
||||
pivot = nav_df.pivot_table(index="trade_date", columns="cfg_id", values="nav")
|
||||
if normalize_to_one:
|
||||
pivot = pivot.apply(lambda s: s / s.dropna().iloc[0] if s.dropna().size else s)
|
||||
import plotly.graph_objects as go
|
||||
fig = go.Figure()
|
||||
for col in pivot.columns:
|
||||
fig.add_trace(go.Scatter(x=pivot.index, y=pivot[col], mode="lines", name=str(col)))
|
||||
fig.update_layout(height=300, margin=dict(l=10, r=10, t=30, b=10))
|
||||
if use_log_y:
|
||||
fig.update_yaxes(type="log")
|
||||
st.plotly_chart(fig, use_container_width=True)
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.debug("绘制对比曲线失败", extra=LOG_EXTRA)
|
||||
if not rpt_df.empty:
|
||||
try:
|
||||
metrics_rows: List[Dict[str, object]] = []
|
||||
for _, row in rpt_df.iterrows():
|
||||
cfg_id = row["cfg_id"]
|
||||
try:
|
||||
summary = json.loads(row["summary"]) if isinstance(row["summary"], str) else (row["summary"] or {})
|
||||
except json.JSONDecodeError:
|
||||
summary = {}
|
||||
record = {
|
||||
"cfg_id": cfg_id,
|
||||
"总收益": summary.get("total_return"),
|
||||
"最大回撤": summary.get("max_drawdown"),
|
||||
"交易数": summary.get("trade_count"),
|
||||
"平均换手": summary.get("avg_turnover"),
|
||||
"风险事件": summary.get("risk_events"),
|
||||
}
|
||||
metrics_rows.append({k: v for k, v in record.items() if (k == "cfg_id" or k in selected_metrics)})
|
||||
if metrics_rows:
|
||||
dfm = pd.DataFrame(metrics_rows)
|
||||
st.dataframe(dfm, hide_index=True, width='stretch')
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.debug("渲染指标表失败", extra=LOG_EXTRA)
|
||||
else:
|
||||
st.info("请选择至少一个配置进行对比。")
|
||||
|
||||
|
||||
def render_settings() -> None:
|
||||
LOGGER.info("渲染设置页面", extra=LOG_EXTRA)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user