update
This commit is contained in:
parent
6a7c20db91
commit
4eb2b2d81e
@ -58,6 +58,7 @@ from app.utils.portfolio import (
|
|||||||
)
|
)
|
||||||
from app.agents.registry import default_agents
|
from app.agents.registry import default_agents
|
||||||
from app.utils.tuning import log_tuning_result
|
from app.utils.tuning import log_tuning_result
|
||||||
|
from app.backtest.engine import BacktestEngine, PortfolioState
|
||||||
|
|
||||||
|
|
||||||
LOGGER = get_logger(__name__)
|
LOGGER = get_logger(__name__)
|
||||||
@ -67,6 +68,9 @@ _DECISION_ENV_BATCH_RESULTS_KEY = "decision_env_batch_results"
|
|||||||
_DASHBOARD_CONTAINERS: Optional[tuple[object, object]] = None
|
_DASHBOARD_CONTAINERS: Optional[tuple[object, object]] = None
|
||||||
_DASHBOARD_ELEMENTS: Optional[Dict[str, object]] = None
|
_DASHBOARD_ELEMENTS: Optional[Dict[str, object]] = None
|
||||||
_SIDEBAR_LISTENER_ATTACHED = False
|
_SIDEBAR_LISTENER_ATTACHED = False
|
||||||
|
# ADD: simple in-memory cache for provider model discovery
|
||||||
|
_MODEL_CACHE: Dict[str, Dict[str, object]] = {}
|
||||||
|
_CACHE_TTL_SECONDS = 300
|
||||||
|
|
||||||
|
|
||||||
def _sidebar_metrics_listener(metrics: Dict[str, object]) -> None:
|
def _sidebar_metrics_listener(metrics: Dict[str, object]) -> None:
|
||||||
@ -210,6 +214,16 @@ def _discover_provider_models(provider: LLMProvider, base_override: str = "", ap
|
|||||||
timeout = float(provider.default_timeout or 30.0)
|
timeout = float(provider.default_timeout or 30.0)
|
||||||
mode = provider.mode or ("ollama" if provider.key == "ollama" else "openai")
|
mode = provider.mode or ("ollama" if provider.key == "ollama" else "openai")
|
||||||
|
|
||||||
|
# ADD: simple cache by provider+base URL
|
||||||
|
cache_key = f"{provider.key}|{base_url}"
|
||||||
|
now = datetime.now()
|
||||||
|
cached = _MODEL_CACHE.get(cache_key)
|
||||||
|
if cached:
|
||||||
|
ts = cached.get("ts")
|
||||||
|
if isinstance(ts, float) and (now.timestamp() - ts) < _CACHE_TTL_SECONDS:
|
||||||
|
models = list(cached.get("models") or [])
|
||||||
|
return models, None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if mode == "ollama":
|
if mode == "ollama":
|
||||||
url = base_url.rstrip('/') + "/api/tags"
|
url = base_url.rstrip('/') + "/api/tags"
|
||||||
@ -221,6 +235,7 @@ def _discover_provider_models(provider: LLMProvider, base_override: str = "", ap
|
|||||||
name = item.get("name") or item.get("model") or item.get("tag")
|
name = item.get("name") or item.get("model") or item.get("tag")
|
||||||
if name:
|
if name:
|
||||||
models.append(str(name).strip())
|
models.append(str(name).strip())
|
||||||
|
_MODEL_CACHE[cache_key] = {"ts": now.timestamp(), "models": sorted(set(models))}
|
||||||
return sorted(set(models)), None
|
return sorted(set(models)), None
|
||||||
|
|
||||||
api_key = (api_override or provider.api_key or "").strip()
|
api_key = (api_override or provider.api_key or "").strip()
|
||||||
@ -239,6 +254,7 @@ def _discover_provider_models(provider: LLMProvider, base_override: str = "", ap
|
|||||||
for item in payload.get("data", [])
|
for item in payload.get("data", [])
|
||||||
if item.get("id")
|
if item.get("id")
|
||||||
]
|
]
|
||||||
|
_MODEL_CACHE[cache_key] = {"ts": now.timestamp(), "models": sorted(set(models))}
|
||||||
return sorted(set(models)), None
|
return sorted(set(models)), None
|
||||||
except RequestException as exc: # noqa: BLE001
|
except RequestException as exc: # noqa: BLE001
|
||||||
return [], f"HTTP 错误:{exc}"
|
return [], f"HTTP 错误:{exc}"
|
||||||
@ -345,7 +361,7 @@ def render_today_plan() -> None:
|
|||||||
if latest_trade_date:
|
if latest_trade_date:
|
||||||
st.caption(f"最新交易日:{latest_trade_date.isoformat()}(统计数据请见左侧系统监控)")
|
st.caption(f"最新交易日:{latest_trade_date.isoformat()}(统计数据请见左侧系统监控)")
|
||||||
else:
|
else:
|
||||||
st.caption("统计与决策概览现已移至左侧“系统监控”侧栏。")
|
st.caption("统计与决策概览现已移至左侧'系统监控'侧栏。")
|
||||||
try:
|
try:
|
||||||
with db_session(read_only=True) as conn:
|
with db_session(read_only=True) as conn:
|
||||||
date_rows = conn.execute(
|
date_rows = conn.execute(
|
||||||
@ -384,6 +400,8 @@ def render_today_plan() -> None:
|
|||||||
return
|
return
|
||||||
|
|
||||||
ts_code = st.selectbox("标的", symbols, index=0)
|
ts_code = st.selectbox("标的", symbols, index=0)
|
||||||
|
# ADD: batch selection for re-evaluation
|
||||||
|
batch_symbols = st.multiselect("批量重评估(可多选)", symbols, default=[])
|
||||||
|
|
||||||
with db_session(read_only=True) as conn:
|
with db_session(read_only=True) as conn:
|
||||||
rows = conn.execute(
|
rows = conn.execute(
|
||||||
@ -523,7 +541,16 @@ def render_today_plan() -> None:
|
|||||||
|
|
||||||
st.subheader("部门意见")
|
st.subheader("部门意见")
|
||||||
if dept_records:
|
if dept_records:
|
||||||
dept_df = pd.DataFrame(dept_records)
|
# ADD: keyword filter for department summaries
|
||||||
|
keyword = st.text_input("筛选摘要/信号关键词", value="")
|
||||||
|
filtered = dept_records
|
||||||
|
if keyword.strip():
|
||||||
|
kw = keyword.strip()
|
||||||
|
filtered = [
|
||||||
|
item for item in dept_records
|
||||||
|
if kw in str(item.get("摘要", "")) or kw in str(item.get("核心信号", ""))
|
||||||
|
]
|
||||||
|
dept_df = pd.DataFrame(filtered)
|
||||||
st.dataframe(dept_df, width='stretch', hide_index=True)
|
st.dataframe(dept_df, width='stretch', hide_index=True)
|
||||||
for code, details in dept_details.items():
|
for code, details in dept_details.items():
|
||||||
with st.expander(f"{code} 补充详情", expanded=False):
|
with st.expander(f"{code} 补充详情", expanded=False):
|
||||||
@ -636,6 +663,122 @@ def render_today_plan() -> None:
|
|||||||
|
|
||||||
st.caption("数据来源:agent_utils、investment_pool、portfolio_positions、portfolio_trades、portfolio_snapshots。")
|
st.caption("数据来源:agent_utils、investment_pool、portfolio_positions、portfolio_trades、portfolio_snapshots。")
|
||||||
|
|
||||||
|
st.divider()
|
||||||
|
st.subheader("策略重评估")
|
||||||
|
st.caption("对当前选中的交易日与标的,立即触发一次策略评估并回写 agent_utils。")
|
||||||
|
cols_re = st.columns([1,1])
|
||||||
|
if cols_re[0].button("对该标的重评估", key="reevaluate_current_symbol"):
|
||||||
|
with st.spinner("正在重评估..."):
|
||||||
|
try:
|
||||||
|
trade_date_obj = None
|
||||||
|
try:
|
||||||
|
trade_date_obj = date.fromisoformat(str(trade_date))
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if trade_date_obj is None:
|
||||||
|
raise ValueError(f"无法解析交易日:{trade_date}")
|
||||||
|
# snapshot before
|
||||||
|
with db_session(read_only=True) as conn:
|
||||||
|
before_rows = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT agent, action, utils FROM agent_utils
|
||||||
|
WHERE trade_date = ? AND ts_code = ?
|
||||||
|
""",
|
||||||
|
(trade_date, ts_code),
|
||||||
|
).fetchall()
|
||||||
|
before_map = {row["agent"]: (row["action"], row["utils"]) for row in before_rows}
|
||||||
|
cfg = BtConfig(
|
||||||
|
id="reeval_ui",
|
||||||
|
name="UI Re-evaluation",
|
||||||
|
start_date=trade_date_obj,
|
||||||
|
end_date=trade_date_obj,
|
||||||
|
universe=[ts_code],
|
||||||
|
params={},
|
||||||
|
)
|
||||||
|
engine = BacktestEngine(cfg)
|
||||||
|
state = PortfolioState()
|
||||||
|
_ = engine.simulate_day(trade_date_obj, state)
|
||||||
|
# compare after
|
||||||
|
with db_session(read_only=True) as conn:
|
||||||
|
after_rows = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT agent, action, utils FROM agent_utils
|
||||||
|
WHERE trade_date = ? AND ts_code = ?
|
||||||
|
""",
|
||||||
|
(trade_date, ts_code),
|
||||||
|
).fetchall()
|
||||||
|
changes = []
|
||||||
|
for row in after_rows:
|
||||||
|
agent = row["agent"]
|
||||||
|
new_action = row["action"]
|
||||||
|
old_action, _old_utils = before_map.get(agent, (None, None))
|
||||||
|
if new_action != old_action:
|
||||||
|
changes.append({"代理": agent, "原动作": old_action, "新动作": new_action})
|
||||||
|
if changes:
|
||||||
|
st.success("重评估完成,检测到动作变更:")
|
||||||
|
st.dataframe(pd.DataFrame(changes), hide_index=True, width='stretch')
|
||||||
|
else:
|
||||||
|
st.success("重评估完成,无动作变更。")
|
||||||
|
st.rerun()
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
LOGGER.exception("重评估失败", extra=LOG_EXTRA)
|
||||||
|
st.error(f"重评估失败:{exc}")
|
||||||
|
if cols_re[1].button("批量重评估(所选)", key="reevaluate_batch", disabled=not batch_symbols):
|
||||||
|
with st.spinner("批量重评估执行中..."):
|
||||||
|
try:
|
||||||
|
trade_date_obj = None
|
||||||
|
try:
|
||||||
|
trade_date_obj = date.fromisoformat(str(trade_date))
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if trade_date_obj is None:
|
||||||
|
raise ValueError(f"无法解析交易日:{trade_date}")
|
||||||
|
progress = st.progress(0.0)
|
||||||
|
changes_all: List[Dict[str, object]] = []
|
||||||
|
for idx, code in enumerate(batch_symbols, start=1):
|
||||||
|
with db_session(read_only=True) as conn:
|
||||||
|
before_rows = conn.execute(
|
||||||
|
"SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?",
|
||||||
|
(trade_date, code),
|
||||||
|
).fetchall()
|
||||||
|
before_map = {row["agent"]: row["action"] for row in before_rows}
|
||||||
|
cfg = BtConfig(
|
||||||
|
id="reeval_ui_batch",
|
||||||
|
name="UI Batch Re-eval",
|
||||||
|
start_date=trade_date_obj,
|
||||||
|
end_date=trade_date_obj,
|
||||||
|
universe=[code],
|
||||||
|
params={},
|
||||||
|
)
|
||||||
|
engine = BacktestEngine(cfg)
|
||||||
|
state = PortfolioState()
|
||||||
|
_ = engine.simulate_day(trade_date_obj, state)
|
||||||
|
with db_session(read_only=True) as conn:
|
||||||
|
after_rows = conn.execute(
|
||||||
|
"SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?",
|
||||||
|
(trade_date, code),
|
||||||
|
).fetchall()
|
||||||
|
for row in after_rows:
|
||||||
|
agent = row["agent"]
|
||||||
|
new_action = row["action"]
|
||||||
|
old_action = before_map.get(agent)
|
||||||
|
if new_action != old_action:
|
||||||
|
changes_all.append({"代码": code, "代理": agent, "原动作": old_action, "新动作": new_action})
|
||||||
|
progress.progress(idx / max(1, len(batch_symbols)))
|
||||||
|
st.success("批量重评估完成。")
|
||||||
|
if changes_all:
|
||||||
|
st.dataframe(pd.DataFrame(changes_all), hide_index=True, width='stretch')
|
||||||
|
st.rerun()
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
LOGGER.exception("批量重评估失败", extra=LOG_EXTRA)
|
||||||
|
st.error(f"批量重评估失败:{exc}")
|
||||||
|
|
||||||
|
|
||||||
def render_backtest() -> None:
|
def render_backtest() -> None:
|
||||||
LOGGER.info("渲染回测页面", extra=LOG_EXTRA)
|
LOGGER.info("渲染回测页面", extra=LOG_EXTRA)
|
||||||
@ -1237,6 +1380,84 @@ def render_backtest() -> None:
|
|||||||
st.session_state.pop("decision_env_batch_select", None)
|
st.session_state.pop("decision_env_batch_select", None)
|
||||||
st.success("已清除批量调参结果缓存。")
|
st.success("已清除批量调参结果缓存。")
|
||||||
|
|
||||||
|
# ADD: Comparison view for multiple backtest configurations
|
||||||
|
with st.expander("回测结果对比", expanded=False):
|
||||||
|
st.caption("从历史回测配置中选择多个进行净值曲线与指标对比。")
|
||||||
|
normalize_to_one = st.checkbox("归一化到 1 起点", value=True)
|
||||||
|
use_log_y = st.checkbox("对数坐标", value=False)
|
||||||
|
metric_options = ["总收益", "最大回撤", "交易数", "平均换手", "风险事件"]
|
||||||
|
selected_metrics = st.multiselect("显示指标列", metric_options, default=metric_options)
|
||||||
|
try:
|
||||||
|
with db_session(read_only=True) as conn:
|
||||||
|
cfg_rows = conn.execute(
|
||||||
|
"SELECT id, name FROM bt_config ORDER BY rowid DESC LIMIT 50"
|
||||||
|
).fetchall()
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
LOGGER.exception("读取 bt_config 失败", extra=LOG_EXTRA)
|
||||||
|
cfg_rows = []
|
||||||
|
cfg_options = [f"{row['id']} | {row['name']}" for row in cfg_rows]
|
||||||
|
selected_labels = st.multiselect("选择配置", cfg_options, default=cfg_options[:2])
|
||||||
|
selected_ids = [label.split(" | ")[0].strip() for label in selected_labels]
|
||||||
|
if selected_ids:
|
||||||
|
try:
|
||||||
|
with db_session(read_only=True) as conn:
|
||||||
|
nav_df = pd.read_sql_query(
|
||||||
|
"SELECT cfg_id, trade_date, nav FROM bt_nav WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))),
|
||||||
|
conn,
|
||||||
|
params=tuple(selected_ids),
|
||||||
|
)
|
||||||
|
rpt_df = pd.read_sql_query(
|
||||||
|
"SELECT cfg_id, summary FROM bt_report WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))),
|
||||||
|
conn,
|
||||||
|
params=tuple(selected_ids),
|
||||||
|
)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
LOGGER.exception("读取回测结果失败", extra=LOG_EXTRA)
|
||||||
|
st.error("读取回测结果失败")
|
||||||
|
nav_df = pd.DataFrame()
|
||||||
|
rpt_df = pd.DataFrame()
|
||||||
|
if not nav_df.empty:
|
||||||
|
try:
|
||||||
|
nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"], errors="coerce")
|
||||||
|
pivot = nav_df.pivot_table(index="trade_date", columns="cfg_id", values="nav")
|
||||||
|
if normalize_to_one:
|
||||||
|
pivot = pivot.apply(lambda s: s / s.dropna().iloc[0] if s.dropna().size else s)
|
||||||
|
import plotly.graph_objects as go
|
||||||
|
fig = go.Figure()
|
||||||
|
for col in pivot.columns:
|
||||||
|
fig.add_trace(go.Scatter(x=pivot.index, y=pivot[col], mode="lines", name=str(col)))
|
||||||
|
fig.update_layout(height=300, margin=dict(l=10, r=10, t=30, b=10))
|
||||||
|
if use_log_y:
|
||||||
|
fig.update_yaxes(type="log")
|
||||||
|
st.plotly_chart(fig, use_container_width=True)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
LOGGER.debug("绘制对比曲线失败", extra=LOG_EXTRA)
|
||||||
|
if not rpt_df.empty:
|
||||||
|
try:
|
||||||
|
metrics_rows: List[Dict[str, object]] = []
|
||||||
|
for _, row in rpt_df.iterrows():
|
||||||
|
cfg_id = row["cfg_id"]
|
||||||
|
try:
|
||||||
|
summary = json.loads(row["summary"]) if isinstance(row["summary"], str) else (row["summary"] or {})
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
summary = {}
|
||||||
|
record = {
|
||||||
|
"cfg_id": cfg_id,
|
||||||
|
"总收益": summary.get("total_return"),
|
||||||
|
"最大回撤": summary.get("max_drawdown"),
|
||||||
|
"交易数": summary.get("trade_count"),
|
||||||
|
"平均换手": summary.get("avg_turnover"),
|
||||||
|
"风险事件": summary.get("risk_events"),
|
||||||
|
}
|
||||||
|
metrics_rows.append({k: v for k, v in record.items() if (k == "cfg_id" or k in selected_metrics)})
|
||||||
|
if metrics_rows:
|
||||||
|
dfm = pd.DataFrame(metrics_rows)
|
||||||
|
st.dataframe(dfm, hide_index=True, width='stretch')
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
LOGGER.debug("渲染指标表失败", extra=LOG_EXTRA)
|
||||||
|
else:
|
||||||
|
st.info("请选择至少一个配置进行对比。")
|
||||||
|
|
||||||
|
|
||||||
def render_settings() -> None:
|
def render_settings() -> None:
|
||||||
LOGGER.info("渲染设置页面", extra=LOG_EXTRA)
|
LOGGER.info("渲染设置页面", extra=LOG_EXTRA)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user