This commit is contained in:
sam 2025-10-02 17:01:58 +08:00
parent 6a7c20db91
commit 4eb2b2d81e

View File

@ -58,6 +58,7 @@ from app.utils.portfolio import (
)
from app.agents.registry import default_agents
from app.utils.tuning import log_tuning_result
from app.backtest.engine import BacktestEngine, PortfolioState
LOGGER = get_logger(__name__)
@ -67,6 +68,9 @@ _DECISION_ENV_BATCH_RESULTS_KEY = "decision_env_batch_results"
_DASHBOARD_CONTAINERS: Optional[tuple[object, object]] = None
_DASHBOARD_ELEMENTS: Optional[Dict[str, object]] = None
_SIDEBAR_LISTENER_ATTACHED = False
# ADD: simple in-memory cache for provider model discovery
_MODEL_CACHE: Dict[str, Dict[str, object]] = {}
_CACHE_TTL_SECONDS = 300
def _sidebar_metrics_listener(metrics: Dict[str, object]) -> None:
@ -210,6 +214,16 @@ def _discover_provider_models(provider: LLMProvider, base_override: str = "", ap
timeout = float(provider.default_timeout or 30.0)
mode = provider.mode or ("ollama" if provider.key == "ollama" else "openai")
# ADD: simple cache by provider+base URL
cache_key = f"{provider.key}|{base_url}"
now = datetime.now()
cached = _MODEL_CACHE.get(cache_key)
if cached:
ts = cached.get("ts")
if isinstance(ts, float) and (now.timestamp() - ts) < _CACHE_TTL_SECONDS:
models = list(cached.get("models") or [])
return models, None
try:
if mode == "ollama":
url = base_url.rstrip('/') + "/api/tags"
@ -221,6 +235,7 @@ def _discover_provider_models(provider: LLMProvider, base_override: str = "", ap
name = item.get("name") or item.get("model") or item.get("tag")
if name:
models.append(str(name).strip())
_MODEL_CACHE[cache_key] = {"ts": now.timestamp(), "models": sorted(set(models))}
return sorted(set(models)), None
api_key = (api_override or provider.api_key or "").strip()
@ -239,6 +254,7 @@ def _discover_provider_models(provider: LLMProvider, base_override: str = "", ap
for item in payload.get("data", [])
if item.get("id")
]
_MODEL_CACHE[cache_key] = {"ts": now.timestamp(), "models": sorted(set(models))}
return sorted(set(models)), None
except RequestException as exc: # noqa: BLE001
return [], f"HTTP 错误:{exc}"
@ -345,7 +361,7 @@ def render_today_plan() -> None:
if latest_trade_date:
st.caption(f"最新交易日:{latest_trade_date.isoformat()}(统计数据请见左侧系统监控)")
else:
st.caption("统计与决策概览现已移至左侧“系统监控”侧栏。")
st.caption("统计与决策概览现已移至左侧'系统监控'侧栏。")
try:
with db_session(read_only=True) as conn:
date_rows = conn.execute(
@ -384,6 +400,8 @@ def render_today_plan() -> None:
return
ts_code = st.selectbox("标的", symbols, index=0)
# ADD: batch selection for re-evaluation
batch_symbols = st.multiselect("批量重评估(可多选)", symbols, default=[])
with db_session(read_only=True) as conn:
rows = conn.execute(
@ -523,7 +541,16 @@ def render_today_plan() -> None:
st.subheader("部门意见")
if dept_records:
dept_df = pd.DataFrame(dept_records)
# ADD: keyword filter for department summaries
keyword = st.text_input("筛选摘要/信号关键词", value="")
filtered = dept_records
if keyword.strip():
kw = keyword.strip()
filtered = [
item for item in dept_records
if kw in str(item.get("摘要", "")) or kw in str(item.get("核心信号", ""))
]
dept_df = pd.DataFrame(filtered)
st.dataframe(dept_df, width='stretch', hide_index=True)
for code, details in dept_details.items():
with st.expander(f"{code} 补充详情", expanded=False):
@ -636,6 +663,122 @@ def render_today_plan() -> None:
st.caption("数据来源agent_utils、investment_pool、portfolio_positions、portfolio_trades、portfolio_snapshots。")
st.divider()
st.subheader("策略重评估")
st.caption("对当前选中的交易日与标的,立即触发一次策略评估并回写 agent_utils。")
cols_re = st.columns([1,1])
if cols_re[0].button("对该标的重评估", key="reevaluate_current_symbol"):
with st.spinner("正在重评估..."):
try:
trade_date_obj = None
try:
trade_date_obj = date.fromisoformat(str(trade_date))
except Exception:
try:
trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date()
except Exception:
pass
if trade_date_obj is None:
raise ValueError(f"无法解析交易日:{trade_date}")
# snapshot before
with db_session(read_only=True) as conn:
before_rows = conn.execute(
"""
SELECT agent, action, utils FROM agent_utils
WHERE trade_date = ? AND ts_code = ?
""",
(trade_date, ts_code),
).fetchall()
before_map = {row["agent"]: (row["action"], row["utils"]) for row in before_rows}
cfg = BtConfig(
id="reeval_ui",
name="UI Re-evaluation",
start_date=trade_date_obj,
end_date=trade_date_obj,
universe=[ts_code],
params={},
)
engine = BacktestEngine(cfg)
state = PortfolioState()
_ = engine.simulate_day(trade_date_obj, state)
# compare after
with db_session(read_only=True) as conn:
after_rows = conn.execute(
"""
SELECT agent, action, utils FROM agent_utils
WHERE trade_date = ? AND ts_code = ?
""",
(trade_date, ts_code),
).fetchall()
changes = []
for row in after_rows:
agent = row["agent"]
new_action = row["action"]
old_action, _old_utils = before_map.get(agent, (None, None))
if new_action != old_action:
changes.append({"代理": agent, "原动作": old_action, "新动作": new_action})
if changes:
st.success("重评估完成,检测到动作变更:")
st.dataframe(pd.DataFrame(changes), hide_index=True, width='stretch')
else:
st.success("重评估完成,无动作变更。")
st.rerun()
except Exception as exc: # noqa: BLE001
LOGGER.exception("重评估失败", extra=LOG_EXTRA)
st.error(f"重评估失败:{exc}")
if cols_re[1].button("批量重评估(所选)", key="reevaluate_batch", disabled=not batch_symbols):
with st.spinner("批量重评估执行中..."):
try:
trade_date_obj = None
try:
trade_date_obj = date.fromisoformat(str(trade_date))
except Exception:
try:
trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date()
except Exception:
pass
if trade_date_obj is None:
raise ValueError(f"无法解析交易日:{trade_date}")
progress = st.progress(0.0)
changes_all: List[Dict[str, object]] = []
for idx, code in enumerate(batch_symbols, start=1):
with db_session(read_only=True) as conn:
before_rows = conn.execute(
"SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?",
(trade_date, code),
).fetchall()
before_map = {row["agent"]: row["action"] for row in before_rows}
cfg = BtConfig(
id="reeval_ui_batch",
name="UI Batch Re-eval",
start_date=trade_date_obj,
end_date=trade_date_obj,
universe=[code],
params={},
)
engine = BacktestEngine(cfg)
state = PortfolioState()
_ = engine.simulate_day(trade_date_obj, state)
with db_session(read_only=True) as conn:
after_rows = conn.execute(
"SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?",
(trade_date, code),
).fetchall()
for row in after_rows:
agent = row["agent"]
new_action = row["action"]
old_action = before_map.get(agent)
if new_action != old_action:
changes_all.append({"代码": code, "代理": agent, "原动作": old_action, "新动作": new_action})
progress.progress(idx / max(1, len(batch_symbols)))
st.success("批量重评估完成。")
if changes_all:
st.dataframe(pd.DataFrame(changes_all), hide_index=True, width='stretch')
st.rerun()
except Exception as exc: # noqa: BLE001
LOGGER.exception("批量重评估失败", extra=LOG_EXTRA)
st.error(f"批量重评估失败:{exc}")
def render_backtest() -> None:
LOGGER.info("渲染回测页面", extra=LOG_EXTRA)
@ -1237,6 +1380,84 @@ def render_backtest() -> None:
st.session_state.pop("decision_env_batch_select", None)
st.success("已清除批量调参结果缓存。")
# ADD: Comparison view for multiple backtest configurations
with st.expander("回测结果对比", expanded=False):
st.caption("从历史回测配置中选择多个进行净值曲线与指标对比。")
normalize_to_one = st.checkbox("归一化到 1 起点", value=True)
use_log_y = st.checkbox("对数坐标", value=False)
metric_options = ["总收益", "最大回撤", "交易数", "平均换手", "风险事件"]
selected_metrics = st.multiselect("显示指标列", metric_options, default=metric_options)
try:
with db_session(read_only=True) as conn:
cfg_rows = conn.execute(
"SELECT id, name FROM bt_config ORDER BY rowid DESC LIMIT 50"
).fetchall()
except Exception: # noqa: BLE001
LOGGER.exception("读取 bt_config 失败", extra=LOG_EXTRA)
cfg_rows = []
cfg_options = [f"{row['id']} | {row['name']}" for row in cfg_rows]
selected_labels = st.multiselect("选择配置", cfg_options, default=cfg_options[:2])
selected_ids = [label.split(" | ")[0].strip() for label in selected_labels]
if selected_ids:
try:
with db_session(read_only=True) as conn:
nav_df = pd.read_sql_query(
"SELECT cfg_id, trade_date, nav FROM bt_nav WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))),
conn,
params=tuple(selected_ids),
)
rpt_df = pd.read_sql_query(
"SELECT cfg_id, summary FROM bt_report WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))),
conn,
params=tuple(selected_ids),
)
except Exception: # noqa: BLE001
LOGGER.exception("读取回测结果失败", extra=LOG_EXTRA)
st.error("读取回测结果失败")
nav_df = pd.DataFrame()
rpt_df = pd.DataFrame()
if not nav_df.empty:
try:
nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"], errors="coerce")
pivot = nav_df.pivot_table(index="trade_date", columns="cfg_id", values="nav")
if normalize_to_one:
pivot = pivot.apply(lambda s: s / s.dropna().iloc[0] if s.dropna().size else s)
import plotly.graph_objects as go
fig = go.Figure()
for col in pivot.columns:
fig.add_trace(go.Scatter(x=pivot.index, y=pivot[col], mode="lines", name=str(col)))
fig.update_layout(height=300, margin=dict(l=10, r=10, t=30, b=10))
if use_log_y:
fig.update_yaxes(type="log")
st.plotly_chart(fig, use_container_width=True)
except Exception: # noqa: BLE001
LOGGER.debug("绘制对比曲线失败", extra=LOG_EXTRA)
if not rpt_df.empty:
try:
metrics_rows: List[Dict[str, object]] = []
for _, row in rpt_df.iterrows():
cfg_id = row["cfg_id"]
try:
summary = json.loads(row["summary"]) if isinstance(row["summary"], str) else (row["summary"] or {})
except json.JSONDecodeError:
summary = {}
record = {
"cfg_id": cfg_id,
"总收益": summary.get("total_return"),
"最大回撤": summary.get("max_drawdown"),
"交易数": summary.get("trade_count"),
"平均换手": summary.get("avg_turnover"),
"风险事件": summary.get("risk_events"),
}
metrics_rows.append({k: v for k, v in record.items() if (k == "cfg_id" or k in selected_metrics)})
if metrics_rows:
dfm = pd.DataFrame(metrics_rows)
st.dataframe(dfm, hide_index=True, width='stretch')
except Exception: # noqa: BLE001
LOGGER.debug("渲染指标表失败", extra=LOG_EXTRA)
else:
st.info("请选择至少一个配置进行对比。")
def render_settings() -> None:
LOGGER.info("渲染设置页面", extra=LOG_EXTRA)