From 90fb2a9df62539547e57f30eb70ff7e1b0cbcb6e Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 11 Oct 2025 09:27:55 +0800 Subject: [PATCH] refactor backtest engine with trading rules and progress tracking --- app/agents/departments.py | 101 ++++++++++++++-- app/agents/value.py | 21 +++- app/backtest/engine.py | 103 ++++++++++++++-- app/features/factors.py | 94 +++++++++------ app/features/progress.py | 38 ++++++ app/ui/progress_state.py | 7 ++ app/ui/views/backtest.py | 43 ++++--- app/ui/views/pool.py | 36 ++++-- app/ui/views/today.py | 247 ++++++++++++++++++++------------------ app/utils/data_access.py | 44 +++++-- 10 files changed, 515 insertions(+), 219 deletions(-) create mode 100644 app/features/progress.py diff --git a/app/agents/departments.py b/app/agents/departments.py index c341afe..c866e3d 100644 --- a/app/agents/departments.py +++ b/app/agents/departments.py @@ -204,22 +204,29 @@ class DepartmentAgent: rounds_executed = round_idx + 1 - usage = response.get("usage") if isinstance(response, Mapping) else None - if isinstance(usage, Mapping): - usage_payload = {"round": round_idx + 1} - usage_payload.update(dict(usage)) - usage_records.append(usage_payload) + message, usage_payload, tool_calls = _normalize_llm_response(response) + if usage_payload: + payload_with_round = {"round": round_idx + 1} + payload_with_round.update(usage_payload) + usage_records.append(payload_with_round) - choice = (response.get("choices") or [{}])[0] - message = choice.get("message", {}) + if not message: + LOGGER.debug( + "部门 %s 第 %s 轮响应缺少 message 字段:%s", + self.settings.code, + round_idx + 1, + response, + extra=LOG_EXTRA, + ) + message = {"role": "assistant", "content": ""} transcript.append(_message_to_text(message)) assistant_record: Dict[str, Any] = { - "role": "assistant", + "role": message.get("role", "assistant"), "content": _extract_message_content(message), } - if message.get("tool_calls"): - assistant_record["tool_calls"] = message.get("tool_calls") + if tool_calls: + assistant_record["tool_calls"] = tool_calls messages.append(assistant_record) CONV_LOGGER.info( "dept=%s round=%s assistant=%s", @@ -228,7 +235,6 @@ class DepartmentAgent: assistant_record, ) - tool_calls = message.get("tool_calls") or [] if tool_calls: for call in tool_calls: function_block = call.get("function") or {} @@ -656,6 +662,8 @@ class DepartmentAgent: dialogue=[response], ) return decision + + def _ensure_mutable_context(context: DepartmentContext) -> DepartmentContext: if not isinstance(context.features, dict): context.features = dict(context.features or {}) @@ -669,6 +677,77 @@ def _ensure_mutable_context(context: DepartmentContext) -> DepartmentContext: return context +def _compose_usage_from_stats(payload: Mapping[str, Any]) -> Dict[str, Any]: + usage: Dict[str, Any] = {} + prompt_eval = payload.get("prompt_eval_count") + completion_eval = payload.get("eval_count") + if isinstance(prompt_eval, (int, float)): + usage["prompt_tokens"] = int(prompt_eval) + if isinstance(completion_eval, (int, float)): + usage["completion_tokens"] = int(completion_eval) + if usage: + total = usage.get("prompt_tokens", 0) + usage.get("completion_tokens", 0) + usage["total_tokens"] = total + return usage + + +def _normalize_llm_response( + response: Mapping[str, Any] +) -> Tuple[Dict[str, Any], Dict[str, Any], List[Dict[str, Any]]]: + message: Dict[str, Any] = {} + usage: Dict[str, Any] = {} + tool_calls: List[Dict[str, Any]] = [] + + if not isinstance(response, Mapping): + return message, usage, tool_calls + + choices = response.get("choices") + if isinstance(choices, list) and choices: + choice = choices[0] or {} + candidate = choice.get("message") + if isinstance(candidate, Mapping): + message = candidate + raw_calls = candidate.get("tool_calls") + if isinstance(raw_calls, list): + tool_calls = list(raw_calls) + raw_usage = response.get("usage") + if isinstance(raw_usage, Mapping): + usage = dict(raw_usage) + else: + raw_message = response.get("message") + if isinstance(raw_message, Mapping): + message = raw_message + raw_calls = raw_message.get("tool_calls") + if isinstance(raw_calls, list): + tool_calls = list(raw_calls) + elif isinstance(response.get("messages"), list): + messages_list = response.get("messages") or [] + if messages_list: + candidate = messages_list[-1] + if isinstance(candidate, Mapping): + message = candidate + raw_calls = candidate.get("tool_calls") + if isinstance(raw_calls, list): + tool_calls = list(raw_calls) + if not message: + content = response.get("content") + if isinstance(content, str): + message = {"role": "assistant", "content": content} + raw_usage = response.get("usage") + if isinstance(raw_usage, Mapping): + usage = dict(raw_usage) + else: + usage = _compose_usage_from_stats(response) + + if not tool_calls: + extra = message.get("additional_kwargs") + if isinstance(extra, Mapping): + extra_calls = extra.get("tool_calls") + if isinstance(extra_calls, list): + tool_calls = list(extra_calls) + return message or {"role": "assistant", "content": ""}, usage, tool_calls + + def _parse_tool_arguments(payload: Any) -> Dict[str, Any]: if isinstance(payload, dict): return dict(payload) diff --git a/app/agents/value.py b/app/agents/value.py index ff135e3..d501314 100644 --- a/app/agents/value.py +++ b/app/agents/value.py @@ -1,6 +1,8 @@ """Value and quality filtering agent.""" from __future__ import annotations +from typing import Mapping + from .base import Agent, AgentAction, AgentContext @@ -9,12 +11,19 @@ class ValueAgent(Agent): super().__init__(name="A_val") def score(self, context: AgentContext, action: AgentAction) -> float: - pe = context.features.get("pe_percentile", 0.5) - pb = context.features.get("pb_percentile", 0.5) - roe = context.features.get("roe_percentile", 0.5) - # Lower valuation percentiles and higher quality percentiles add value. - raw = max(0.0, (1 - pe) * 0.4 + (1 - pb) * 0.3 + roe * 0.3) - raw = min(raw, 1.0) + pe_score = context.features.get("valuation_pe_score", 0.0) + pb_score = context.features.get("valuation_pb_score", 0.0) + # 多因子组合尚未落地,这里兼容扩展因子(若存在则优先使用) + scope_values = {} + if isinstance(context.raw, Mapping): + scope_values = context.raw.get("scope_values", {}) or {} + multi_score = context.features.get("val_multiscore") + if multi_score is None: + multi_score = scope_values.get("factors.val_multiscore") + if multi_score is not None: + raw = float(max(0.0, min(1.0, multi_score))) + else: + raw = max(0.0, min(1.0, 0.6 * pe_score + 0.4 * pb_score)) if action is AgentAction.SELL: return 1 - raw if action is AgentAction.HOLD: diff --git a/app/backtest/engine.py b/app/backtest/engine.py index 3fa240a..36ea420 100644 --- a/app/backtest/engine.py +++ b/app/backtest/engine.py @@ -4,7 +4,7 @@ from __future__ import annotations import json from collections import defaultdict from dataclasses import dataclass, field -from datetime import date +from datetime import date, datetime from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple from app.agents.base import AgentAction, AgentContext @@ -16,7 +16,7 @@ from app.llm.metrics import record_decision as metrics_record_decision from app.agents.registry import default_agents from app.data.schema import initialize_database from app.utils.data_access import DataBroker -from app.utils.config import get_config +from app.utils.config import PortfolioSettings, get_config from app.utils.db import db_session from app.utils.logging import get_logger from app.utils import alerts @@ -105,12 +105,26 @@ class BacktestEngine: ) self.data_broker = DataBroker() params = cfg.params or {} + portfolio_cfg = getattr(app_cfg, "portfolio", None) or PortfolioSettings() self.risk_params = { "max_position_weight": float(params.get("max_position_weight", 0.2)), "max_daily_turnover_ratio": float(params.get("max_daily_turnover_ratio", 0.25)), "fee_rate": float(params.get("fee_rate", 0.0005)), "slippage_bps": float(params.get("slippage_bps", 10.0)), } + self.initial_cash = max(0.0, float(params.get("initial_capital", portfolio_cfg.initial_capital))) + target_return = params.get("target", params.get("target_return", 0.0)) or 0.0 + stop_loss = params.get("stop", params.get("stop_loss", 0.0)) or 0.0 + hold_days_param = params.get("hold_days", params.get("max_hold_days", 0)) + try: + max_hold_days = int(hold_days_param) if hold_days_param is not None else 0 + except (TypeError, ValueError): + max_hold_days = 0 + self.trading_rules = { + "target_return": float(target_return), + "stop_loss": float(stop_loss), + "max_hold_days": max(0, max_hold_days), + } self._fee_rate = max(self.risk_params["fee_rate"], 0.0) self._slippage_rate = max(self.risk_params["slippage_bps"], 0.0) / 10_000.0 self._turnover_cap = max(self.risk_params["max_daily_turnover_ratio"], 0.0) @@ -314,9 +328,9 @@ class BacktestEngine: is_suspended = self.data_broker.fetch_flags( "suspend", ts_code, - trade_date, - "ts_code = ?", - [ts_code], + trade_date_str, + "", + [], auto_refresh=False, # 避免在回测中触发自动补数 ) @@ -650,6 +664,7 @@ class BacktestEngine: continue features = feature_cache.get(ts_code, {}) current_qty = state.holdings.get(ts_code, 0.0) + current_cost_basis = float(state.cost_basis.get(ts_code, 0.0) or 0.0) liquidity_score = float(features.get("liquidity_score") or 0.0) risk_penalty = float(features.get("risk_penalty") or 0.0) is_suspended = bool(features.get("is_suspended")) @@ -679,6 +694,76 @@ class BacktestEngine: if risk.status == "blocked": continue + rule_override_action: Optional[AgentAction] = None + rule_override_reason: Optional[str] = None + gain_ratio: Optional[float] = None + if current_qty > 0 and current_cost_basis: + try: + gain_ratio = (price / current_cost_basis) - 1.0 + except ZeroDivisionError: + gain_ratio = None + target_return = self.trading_rules.get("target_return", 0.0) + if ( + rule_override_action is None + and gain_ratio is not None + and target_return + and gain_ratio >= target_return + ): + rule_override_action = AgentAction.SELL + rule_override_reason = "target_reached" + stop_loss = self.trading_rules.get("stop_loss", 0.0) + if ( + rule_override_action is None + and gain_ratio is not None + and stop_loss + and gain_ratio <= stop_loss + ): + rule_override_action = AgentAction.SELL + rule_override_reason = "stop_loss" + max_hold_days = self.trading_rules.get("max_hold_days", 0) + if ( + rule_override_action is None + and max_hold_days + and max_hold_days > 0 + and current_qty > 0 + ): + opened_str = state.opened_dates.get(ts_code) + opened_dt: Optional[date] = None + if opened_str: + try: + opened_dt = date.fromisoformat(str(opened_str)) + except ValueError: + try: + opened_dt = datetime.strptime(str(opened_str), "%Y%m%d").date() + except ValueError: + opened_dt = None + LOGGER.debug( + "无法解析持仓日期 ts_code=%s value=%s", + ts_code, + opened_str, + extra=LOG_EXTRA, + ) + if opened_dt: + holding_days = (trade_date - opened_dt).days + if holding_days >= max_hold_days: + rule_override_action = AgentAction.SELL + rule_override_reason = "holding_period" + + if rule_override_action and rule_override_action is not effective_action: + effective_action = rule_override_action + effective_weight = target_weight_for_action(effective_action) + _record_risk( + ts_code, + rule_override_reason or "rule_override", + decision, + extra={ + "rule_trigger": rule_override_reason, + "gain_ratio": gain_ratio, + }, + action_override=effective_action, + target_weight_override=effective_weight, + ) + if is_suspended: _record_risk(ts_code, "suspended", decision) continue @@ -738,8 +823,7 @@ class BacktestEngine: if total_cash_needed <= 0: _record_risk(ts_code, "invalid_trade", decision) continue - - previous_cost = state.cost_basis.get(ts_code, 0.0) * current_qty + previous_cost = current_cost_basis * current_qty new_qty = current_qty + delta state.cost_basis[ts_code] = ( (previous_cost + trade_value + fee) / new_qty if new_qty > 0 else 0.0 @@ -777,8 +861,7 @@ class BacktestEngine: gross_value = sell_qty * trade_price fee = gross_value * self._fee_rate proceeds = gross_value - fee - cost_basis = state.cost_basis.get(ts_code, 0.0) - realized = (trade_price - cost_basis) * sell_qty - fee + realized = (trade_price - current_cost_basis) * sell_qty - fee state.cash += proceeds state.realized_pnl += realized new_qty = current_qty - sell_qty @@ -1023,7 +1106,7 @@ class BacktestEngine: """Initialise a new incremental backtest session.""" return BacktestSession( - state=PortfolioState(), + state=PortfolioState(cash=self.initial_cash), result=BacktestResult(), current_date=self.cfg.start_date, ) diff --git a/app/features/factors.py b/app/features/factors.py index e7f9389..e6ea4f0 100644 --- a/app/features/factors.py +++ b/app/features/factors.py @@ -19,7 +19,11 @@ from app.features.value_risk_factors import ValueRiskFactors # 导入因子验证功能 from app.features.validation import check_data_sufficiency, check_data_sufficiency_for_zero_window, detect_outliers # 导入UI进度状态管理 -from app.ui.progress_state import factor_progress +try: + from app.features.progress import get_progress_handler +except ImportError: # pragma: no cover - optional dependency + def get_progress_handler(): + return None LOGGER = get_logger(__name__) @@ -176,14 +180,16 @@ def compute_factors( broker = DataBroker() results: List[FactorResult] = [] rows_to_persist: List[tuple[str, Dict[str, float | None]]] = [] + total_batches = (len(universe) + batch_size - 1) // batch_size if universe else 0 + progress = get_progress_handler() + if progress and universe: + try: + progress.start_calculation(len(universe), total_batches) + except Exception: # noqa: BLE001 + LOGGER.debug("Progress handler start_calculation 失败", extra=LOG_EXTRA) + progress = None try: - # 启动UI进度状态(在异步线程中不直接访问factor_progress) - # factor_progress.start_calculation( - # total_securities=len(universe), - # total_batches=(len(universe) + batch_size - 1) // batch_size - # ) - # 分批处理以优化性能 for i in range(0, len(universe), batch_size): batch = universe[i:i+batch_size] @@ -194,9 +200,10 @@ def compute_factors( specs, validation_stats, batch_index=i // batch_size, - total_batches=(len(universe) + batch_size - 1) // batch_size, + total_batches=total_batches or 1, processed_securities=i, - total_securities=len(universe) + total_securities=len(universe), + progress=progress, ) for ts_code, values in batch_results: @@ -222,9 +229,13 @@ def compute_factors( _persist_factor_rows(trade_date_str, rows_to_persist, specs) # 更新UI进度状态为完成 - factor_progress.complete_calculation( - message=f"因子计算完成: 总数量={len(universe)}, 成功={validation_stats['success']}, 失败={len(universe) - validation_stats['success']}" - ) + if progress: + try: + progress.complete_calculation( + message=f"因子计算完成: 总数量={len(universe)}, 成功={validation_stats['success']}, 失败={len(universe) - validation_stats['success']}" + ) + except Exception: # noqa: BLE001 + LOGGER.debug("Progress handler complete_calculation 失败", extra=LOG_EXTRA) LOGGER.info( "因子计算完成 总数量:%s 成功:%s 失败:%s", @@ -239,7 +250,11 @@ def compute_factors( except Exception as exc: # 发生错误时更新UI状态 error_message = f"因子计算过程中发生错误: {exc}" - factor_progress.error_occurred(error_message) + if progress: + try: + progress.error_occurred(error_message) + except Exception: # noqa: BLE001 + LOGGER.debug("Progress handler error_occurred 失败", extra=LOG_EXTRA) LOGGER.error(error_message, extra=LOG_EXTRA) raise @@ -380,6 +395,7 @@ def _compute_batch_factors( total_batches: int = 1, processed_securities: int = 0, total_securities: int = 0, + progress: Optional[object] = None, ) -> List[tuple[str, Dict[str, float | None]]]: """批量计算多个证券的因子值,提高计算效率""" batch_results = [] @@ -388,13 +404,16 @@ def _compute_batch_factors( available_codes = _check_batch_data_availability(broker, ts_codes, trade_date, specs) # 更新UI进度状态 - 开始处理批次 - if total_securities > 0: - from app.ui.progress_state import factor_progress - factor_progress.update_progress( - current_securities=processed_securities, - current_batch=batch_index + 1, - message=f"开始处理批次 {batch_index + 1}/{total_batches}" - ) + if progress and total_securities > 0: + try: + progress.update_progress( + current_securities=processed_securities, + current_batch=batch_index + 1, + message=f"开始处理批次 {batch_index + 1}/{total_batches}", + ) + except Exception: # noqa: BLE001 + LOGGER.debug("Progress handler update_progress 失败", extra=LOG_EXTRA) + progress = None for i, ts_code in enumerate(ts_codes): try: @@ -434,15 +453,18 @@ def _compute_batch_factors( validation_stats["skipped"] += 1 # 每处理1个证券更新一次进度,确保实时性 - if total_securities > 0: + if progress and total_securities > 0: current_progress = processed_securities + i + 1 progress_percentage = (current_progress / total_securities) * 100 - from app.ui.progress_state import factor_progress - factor_progress.update_progress( - current_securities=current_progress, - current_batch=batch_index + 1, - message=f"处理批次 {batch_index + 1}/{total_batches} - 证券 {current_progress}/{total_securities} ({progress_percentage:.1f}%)" - ) + try: + progress.update_progress( + current_securities=current_progress, + current_batch=batch_index + 1, + message=f"处理批次 {batch_index + 1}/{total_batches} - 证券 {current_progress}/{total_securities} ({progress_percentage:.1f}%)", + ) + except Exception: # noqa: BLE001 + LOGGER.debug("Progress handler update_progress 失败", extra=LOG_EXTRA) + progress = None except Exception as e: LOGGER.error( "计算因子失败 ts_code=%s err=%s", @@ -453,16 +475,18 @@ def _compute_batch_factors( validation_stats["skipped"] += 1 # 批次处理完成,更新最终进度 - if total_securities > 0: + if progress and total_securities > 0: final_progress = processed_securities + len(ts_codes) progress_percentage = (final_progress / total_securities) * 100 - from app.ui.progress_state import factor_progress - factor_progress.update_progress( - current_securities=final_progress, - current_batch=batch_index + 1, - message=f"批次 {batch_index + 1}/{total_batches} 处理完成 - 证券 {final_progress}/{total_securities} ({progress_percentage:.1f}%)" - ) - + try: + progress.update_progress( + current_securities=final_progress, + current_batch=batch_index + 1, + message=f"批次 {batch_index + 1}/{total_batches} 处理完成 - 证券 {final_progress}/{total_securities} ({progress_percentage:.1f}%)", + ) + except Exception: # noqa: BLE001 + LOGGER.debug("Progress handler update_progress 失败", extra=LOG_EXTRA) + return batch_results diff --git a/app/features/progress.py b/app/features/progress.py new file mode 100644 index 0000000..e30eb17 --- /dev/null +++ b/app/features/progress.py @@ -0,0 +1,38 @@ +"""Optional progress reporting for factor computation.""" +from __future__ import annotations + +from typing import Optional, Protocol + + +class FactorProgressProtocol(Protocol): + """Protocol describing the optional UI progress handler.""" + + def start_calculation(self, total_securities: int, total_batches: int) -> None: ... + + def update_progress( + self, + current_securities: int, + current_batch: int, + message: str = "", + ) -> None: ... + + def complete_calculation(self, message: str = "") -> None: ... + + def error_occurred(self, error_message: str) -> None: ... + + +_progress_handler: Optional[FactorProgressProtocol] = None + + +def register_progress_handler(progress: FactorProgressProtocol | None) -> None: + """Register a progress handler (typically provided by the UI layer).""" + + global _progress_handler + _progress_handler = progress + + +def get_progress_handler() -> Optional[FactorProgressProtocol]: + """Return the currently registered progress handler if any.""" + + return _progress_handler + diff --git a/app/ui/progress_state.py b/app/ui/progress_state.py index 230512a..6b2155d 100644 --- a/app/ui/progress_state.py +++ b/app/ui/progress_state.py @@ -5,6 +5,12 @@ from typing import Optional, Dict, Any import time import streamlit as st +try: + from app.features.progress import register_progress_handler +except ImportError: # pragma: no cover - optional dependency + def register_progress_handler(handler: object) -> None: + return + class FactorProgressState: """因子计算进度状态管理类""" @@ -151,6 +157,7 @@ class FactorProgressState: # 全局进度状态实例 factor_progress = FactorProgressState() +register_progress_handler(factor_progress) def render_factor_progress() -> None: diff --git a/app/ui/views/backtest.py b/app/ui/views/backtest.py index fe474f2..9c38c42 100644 --- a/app/ui/views/backtest.py +++ b/app/ui/views/backtest.py @@ -44,13 +44,14 @@ def render_backtest_review() -> None: app_cfg = get_config() default_start, default_end = default_backtest_range(window_days=60) LOGGER.debug( - "回测默认参数:start=%s end=%s universe=%s target=%s stop=%s hold_days=%s", + "回测默认参数:start=%s end=%s universe=%s target=%s stop=%s hold_days=%s initial_capital=%s", default_start, default_end, "000001.SZ", 0.035, -0.015, 10, + get_config().portfolio.initial_capital, extra=LOG_EXTRA, ) @@ -59,10 +60,24 @@ def render_backtest_review() -> None: start_date = col1.date_input("开始日期", value=default_start, key="bt_start_date") end_date = col2.date_input("结束日期", value=default_end, key="bt_end_date") universe_text = st.text_input("股票列表(逗号分隔)", value="000001.SZ", key="bt_universe") - col_target, col_stop, col_hold = st.columns(3) + col_target, col_stop, col_hold, col_cap = st.columns(4) target = col_target.number_input("目标收益(例:0.035 表示 3.5%)", value=0.035, step=0.005, format="%.3f", key="bt_target") stop = col_stop.number_input("止损收益(例:-0.015 表示 -1.5%)", value=-0.015, step=0.005, format="%.3f", key="bt_stop") hold_days = col_hold.number_input("持有期(交易日)", value=10, step=1, key="bt_hold_days") + initial_capital = col_cap.number_input( + "组合初始资金", + value=float(app_cfg.portfolio.initial_capital), + step=100000.0, + format="%.0f", + key="bt_initial_capital", + ) + initial_capital = max(0.0, float(initial_capital)) + backtest_params = { + "target": float(target), + "stop": float(stop), + "hold_days": int(hold_days), + "initial_capital": initial_capital, + } structure_options = [item.value for item in GameStructure] selected_structure_values = st.multiselect( "选择博弈框架", @@ -74,13 +89,14 @@ def render_backtest_review() -> None: selected_structure_values = [GameStructure.REPEATED.value] selected_structures = [GameStructure(value) for value in selected_structure_values] LOGGER.debug( - "当前回测表单输入:start=%s end=%s universe_text=%s target=%.3f stop=%.3f hold_days=%s", + "当前回测表单输入:start=%s end=%s universe_text=%s target=%.3f stop=%.3f hold_days=%s initial_capital=%.2f", start_date, end_date, universe_text, target, stop, hold_days, + initial_capital, extra=LOG_EXTRA, ) @@ -134,13 +150,14 @@ def render_backtest_review() -> None: try: universe = [code.strip() for code in universe_text.split(',') if code.strip()] LOGGER.info( - "回测参数:start=%s end=%s universe=%s target=%s stop=%s hold_days=%s", + "回测参数:start=%s end=%s universe=%s target=%s stop=%s hold_days=%s initial_capital=%.2f", start_date, end_date, universe, target, stop, hold_days, + initial_capital, extra=LOG_EXTRA, ) backtest_cfg = BtConfig( @@ -149,11 +166,7 @@ def render_backtest_review() -> None: start_date=start_date, end_date=end_date, universe=universe, - params={ - "target": target, - "stop": stop, - "hold_days": int(hold_days), - }, + params=dict(backtest_params), game_structures=selected_structures, ) result = run_backtest(backtest_cfg, decision_callback=_decision_callback) @@ -665,11 +678,7 @@ def render_backtest_review() -> None: start_date=start_date, end_date=end_date, universe=universe_env, - params={ - "target": target, - "stop": stop, - "hold_days": int(hold_days), - }, + params=dict(backtest_params), method=app_cfg.decision_method, game_structures=selected_structures, ) @@ -942,11 +951,7 @@ def render_backtest_review() -> None: start_date=start_date, end_date=end_date, universe=universe_env, - params={ - "target": target, - "stop": stop, - "hold_days": int(hold_days), - }, + params=dict(backtest_params), method=app_cfg.decision_method, game_structures=selected_structures, ) diff --git a/app/ui/views/pool.py b/app/ui/views/pool.py index 5956704..675f7ba 100644 --- a/app/ui/views/pool.py +++ b/app/ui/views/pool.py @@ -1,6 +1,8 @@ """投资池与仓位概览页面。""" from __future__ import annotations +from datetime import date, timedelta + import pandas as pd import plotly.express as px import streamlit as st @@ -106,28 +108,44 @@ def render_pool_overview() -> None: st.caption("数据来源:agent_utils、investment_pool、portfolio_positions、portfolio_trades、portfolio_snapshots。") + default_compare_a = latest_date or date.today() + default_compare_b = default_compare_a - timedelta(days=1) + if default_compare_b > default_compare_a: + default_compare_b = default_compare_a + compare_col1, compare_col2 = st.columns(2) + compare_date1 = compare_col1.date_input( + "日志对比日期 A", + value=default_compare_a, + key="pool_compare_date_a", + ) + compare_date2 = compare_col2.date_input( + "日志对比日期 B", + value=default_compare_b, + key="pool_compare_date_b", + ) + if st.button("执行对比", type="secondary"): with st.spinner("执行日志对比分析中..."): try: with db_session(read_only=True) as conn: - query_date1 = f"{compare_date1.isoformat()}T00:00:00Z" # type: ignore[name-defined] - query_date2 = f"{compare_date1.isoformat()}T23:59:59Z" # type: ignore[name-defined] + query_date1 = f"{compare_date1.isoformat()}T00:00:00Z" + query_date2 = f"{compare_date1.isoformat()}T23:59:59Z" logs1 = conn.execute( "SELECT level, COUNT(*) as count FROM run_log WHERE ts BETWEEN ? AND ? GROUP BY level", (query_date1, query_date2), ).fetchall() - query_date3 = f"{compare_date2.isoformat()}T00:00:00Z" # type: ignore[name-defined] - query_date4 = f"{compare_date2.isoformat()}T23:59:59Z" # type: ignore[name-defined] + query_date3 = f"{compare_date2.isoformat()}T00:00:00Z" + query_date4 = f"{compare_date2.isoformat()}T23:59:59Z" logs2 = conn.execute( "SELECT level, COUNT(*) as count FROM run_log WHERE ts BETWEEN ? AND ? GROUP BY level", (query_date3, query_date4), ).fetchall() df1 = pd.DataFrame(logs1, columns=["level", "count"]) - df1["date"] = compare_date1.strftime("%Y-%m-%d") # type: ignore[name-defined] + df1["date"] = compare_date1.strftime("%Y-%m-%d") df2 = pd.DataFrame(logs2, columns=["level", "count"]) - df2["date"] = compare_date2.strftime("%Y-%m-%d") # type: ignore[name-defined] + df2["date"] = compare_date2.strftime("%Y-%m-%d") for df in (df1, df2): for col in df.columns: @@ -141,13 +159,13 @@ def render_pool_overview() -> None: y="count", color="date", barmode="group", - title=f"日志级别分布对比 ({compare_date1} vs {compare_date2})", # type: ignore[name-defined] + title=f"日志级别分布对比 ({compare_date1} vs {compare_date2})", ) st.plotly_chart(fig, width='stretch') st.write("日志统计对比:") - date1_str = compare_date1.strftime("%Y%m%d") # type: ignore[name-defined] - date2_str = compare_date2.strftime("%Y%m%d") # type: ignore[name-defined] + date1_str = compare_date1.strftime("%Y%m%d") + date2_str = compare_date2.strftime("%Y%m%d") merged_df = df1.merge( df2, on="level", diff --git a/app/ui/views/today.py b/app/ui/views/today.py index 808a5a8..856eac7 100644 --- a/app/ui/views/today.py +++ b/app/ui/views/today.py @@ -23,6 +23,107 @@ from app.ui.shared import ( ) +def _fetch_agent_actions(trade_date: str, symbols: List[str]) -> Dict[str, Dict[str, Optional[str]]]: + unique_symbols = list(dict.fromkeys(symbols)) + if not unique_symbols: + return {} + placeholder = ", ".join("?" for _ in unique_symbols) + sql = ( + f""" + SELECT ts_code, agent, action + FROM agent_utils + WHERE trade_date = ? + AND ts_code IN ({placeholder}) + """ + ) + params: List[object] = [trade_date, *unique_symbols] + with db_session(read_only=True) as conn: + rows = conn.execute(sql, params).fetchall() + mapping: Dict[str, Dict[str, Optional[str]]] = {} + for row in rows: + ts_code = row["ts_code"] + agent = row["agent"] + action = row["action"] + mapping.setdefault(ts_code, {})[agent] = action + return mapping + + +def _reevaluate_symbols( + trade_date_obj: date, + symbols: List[str], + cfg_id: str, + cfg_name: str, +) -> List[Dict[str, object]]: + unique_symbols = list(dict.fromkeys(symbols)) + if not unique_symbols: + return [] + trade_date_str = trade_date_obj.isoformat() + before_map = _fetch_agent_actions(trade_date_str, unique_symbols) + engine_params: Dict[str, object] = {} + target_val = st.session_state.get("bt_target") + if target_val is not None: + try: + engine_params["target"] = float(target_val) + except (TypeError, ValueError): + pass + stop_val = st.session_state.get("bt_stop") + if stop_val is not None: + try: + engine_params["stop"] = float(stop_val) + except (TypeError, ValueError): + pass + hold_val = st.session_state.get("bt_hold_days") + if hold_val is not None: + try: + engine_params["hold_days"] = int(hold_val) + except (TypeError, ValueError): + pass + capital_val = st.session_state.get("bt_initial_capital") + if capital_val is not None: + try: + engine_params["initial_capital"] = float(capital_val) + except (TypeError, ValueError): + pass + cfg = BtConfig( + id=cfg_id, + name=cfg_name, + start_date=trade_date_obj, + end_date=trade_date_obj, + universe=unique_symbols, + params=engine_params, + ) + engine = BacktestEngine(cfg) + state = PortfolioState(cash=engine.initial_cash) + engine.simulate_day(trade_date_obj, state) + after_map = _fetch_agent_actions(trade_date_str, unique_symbols) + changes: List[Dict[str, object]] = [] + for code in unique_symbols: + before_agents = before_map.get(code, {}) + after_agents = after_map.get(code, {}) + for agent, new_action in after_agents.items(): + old_action = before_agents.get(agent) + if new_action != old_action: + changes.append( + { + "代码": code, + "代理": agent, + "原动作": old_action, + "新动作": new_action, + } + ) + for agent, old_action in before_agents.items(): + if agent not in after_agents: + changes.append( + { + "代码": code, + "代理": agent, + "原动作": old_action, + "新动作": None, + } + ) + return changes + + def render_today_plan() -> None: LOGGER.info("渲染今日计划页面", extra=LOG_EXTRA) st.header("今日计划") @@ -176,56 +277,15 @@ def _render_today_plan_symbol_view( raise ValueError(f"无法解析交易日:{trade_date}") progress = st.progress(0.0) - changes_all: List[Dict[str, object]] = [] - success_count = 0 - error_count = 0 - - for idx, code in enumerate(symbols, start=1): - try: - with db_session(read_only=True) as conn: - before_rows = conn.execute( - "SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?", - (trade_date, code), - ).fetchall() - before_map = {row["agent"]: row["action"] for row in before_rows} - - cfg = BtConfig( - id="reeval_ui_all", - name="UI All Re-eval", - start_date=trade_date_obj, - end_date=trade_date_obj, - universe=[code], - params={}, - ) - engine = BacktestEngine(cfg) - state = PortfolioState() - engine.simulate_day(trade_date_obj, state) - - with db_session(read_only=True) as conn: - after_rows = conn.execute( - "SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?", - (trade_date, code), - ).fetchall() - for row in after_rows: - agent = row["agent"] - new_action = row["action"] - old_action = before_map.get(agent) - if new_action != old_action: - changes_all.append( - {"代码": code, "代理": agent, "原动作": old_action, "新动作": new_action} - ) - success_count += 1 - except Exception: # noqa: BLE001 - LOGGER.exception("重评估 %s 失败", code, extra=LOG_EXTRA) - error_count += 1 - - progress.progress(idx / len(symbols)) - - if error_count > 0: - st.error(f"一键重评估完成:成功 {success_count} 个,失败 {error_count} 个") - else: - st.success(f"一键重评估完成:所有 {success_count} 个标的重评估成功") - + progress.progress(0.3 if symbols else 1.0) + changes_all = _reevaluate_symbols( + trade_date_obj, + symbols, + "reeval_ui_all", + "UI All Re-eval", + ) + progress.progress(1.0) + st.success(f"一键重评估完成:共处理 {len(symbols)} 个标的") if changes_all: st.write("检测到以下动作变更:") st.dataframe(pd.DataFrame(changes_all), hide_index=True, width='stretch') @@ -579,44 +639,20 @@ def _render_today_plan_symbol_view( pass if trade_date_obj is None: raise ValueError(f"无法解析交易日:{trade_date}") - with db_session(read_only=True) as conn: - before_rows = conn.execute( - """ - SELECT agent, action, utils FROM agent_utils - WHERE trade_date = ? AND ts_code = ? - """, - (trade_date, ts_code), - ).fetchall() - before_map = {row["agent"]: (row["action"], row["utils"]) for row in before_rows} - cfg = BtConfig( - id="reeval_ui", - name="UI Re-evaluation", - start_date=trade_date_obj, - end_date=trade_date_obj, - universe=[ts_code], - params={}, + changes = _reevaluate_symbols( + trade_date_obj, + [ts_code], + "reeval_ui", + "UI Re-evaluation", ) - engine = BacktestEngine(cfg) - state = PortfolioState() - engine.simulate_day(trade_date_obj, state) - with db_session(read_only=True) as conn: - after_rows = conn.execute( - """ - SELECT agent, action, utils FROM agent_utils - WHERE trade_date = ? AND ts_code = ? - """, - (trade_date, ts_code), - ).fetchall() - changes = [] - for row in after_rows: - agent = row["agent"] - new_action = row["action"] - old_action, _old_utils = before_map.get(agent, (None, None)) - if new_action != old_action: - changes.append({"代理": agent, "原动作": old_action, "新动作": new_action}) if changes: + for change in changes: + change.setdefault("代码", ts_code) + df_changes = pd.DataFrame(changes) + if "代码" in df_changes.columns: + df_changes = df_changes[["代码", "代理", "原动作", "新动作"]] st.success("重评估完成,检测到动作变更:") - st.dataframe(pd.DataFrame(changes), hide_index=True, width='stretch') + st.dataframe(df_changes, hide_index=True, width='stretch') else: st.success("重评估完成,无动作变更。") st.rerun() @@ -637,38 +673,15 @@ def _render_today_plan_symbol_view( if trade_date_obj is None: raise ValueError(f"无法解析交易日:{trade_date}") progress = st.progress(0.0) - changes_all: List[Dict[str, object]] = [] - for idx, code in enumerate(batch_symbols, start=1): - with db_session(read_only=True) as conn: - before_rows = conn.execute( - "SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?", - (trade_date, code), - ).fetchall() - before_map = {row["agent"]: row["action"] for row in before_rows} - cfg = BtConfig( - id="reeval_ui_batch", - name="UI Batch Re-eval", - start_date=trade_date_obj, - end_date=trade_date_obj, - universe=[code], - params={}, - ) - engine = BacktestEngine(cfg) - state = PortfolioState() - engine.simulate_day(trade_date_obj, state) - with db_session(read_only=True) as conn: - after_rows = conn.execute( - "SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?", - (trade_date, code), - ).fetchall() - for row in after_rows: - agent = row["agent"] - new_action = row["action"] - old_action = before_map.get(agent) - if new_action != old_action: - changes_all.append({"代码": code, "代理": agent, "原动作": old_action, "新动作": new_action}) - progress.progress(idx / max(1, len(batch_symbols))) - st.success("批量重评估完成。") + progress.progress(0.3 if batch_symbols else 1.0) + changes_all = _reevaluate_symbols( + trade_date_obj, + batch_symbols, + "reeval_ui_batch", + "UI Batch Re-eval", + ) + progress.progress(1.0) + st.success(f"批量重评估完成:共处理 {len(batch_symbols)} 个标的") if changes_all: st.dataframe(pd.DataFrame(changes_all), hide_index=True, width='stretch') st.rerun() diff --git a/app/utils/data_access.py b/app/utils/data_access.py index c7102b5..1f750be 100644 --- a/app/utils/data_access.py +++ b/app/utils/data_access.py @@ -4,6 +4,7 @@ from __future__ import annotations import re import sqlite3 import threading +import time from collections import OrderedDict from copy import deepcopy from dataclasses import dataclass, field @@ -577,8 +578,7 @@ class DataBroker: Returns: 新闻数据列表,包含sentiment、heat、entities等字段 """ - # 简化实现:返回模拟数据 - # 在实际应用中,这里应该查询新闻数据库 + # TODO: 使用真实新闻数据库替换随机生成的占位数据 return [ { "sentiment": np.random.uniform(-1, 1), @@ -597,8 +597,7 @@ class DataBroker: Returns: 行业代码或名称,找不到时返回None """ - # 简化实现:返回模拟行业 - # 在实际应用中,这里应该查询股票行业信息 + # TODO: 替换为真实行业映射逻辑(当前仅为占位数据) industry_mapping = { "000001.SZ": "银行", "000002.SZ": "房地产", @@ -617,8 +616,7 @@ class DataBroker: Returns: 行业情绪得分,找不到时返回None """ - # 简化实现:返回模拟情绪得分 - # 在实际应用中,这里应该基于行业新闻计算情绪 + # TODO: 接入行业情绪数据源,当前随机值仅用于占位显示 return np.random.uniform(-1, 1) def get_industry_stocks(self, industry: str) -> List[str]: @@ -630,8 +628,7 @@ class DataBroker: Returns: 同行业股票代码列表 """ - # 简化实现:返回模拟股票列表 - # 在实际应用中,这里应该查询行业股票列表 + # TODO: 使用实际行业成分数据替换占位列表 industry_stocks = { "银行": ["000001.SZ", "002142.SZ", "600036.SH"], "房地产": ["000002.SZ", "000402.SZ", "600048.SH"], @@ -653,10 +650,33 @@ class DataBroker: if not _is_safe_identifier(table): return False - query = ( - f"SELECT 1 FROM {table} WHERE ts_code = ? AND {where_clause} LIMIT 1" - ) - bind_params = (ts_code, *params) + parsed_date = _parse_trade_date(trade_date) + trade_key = parsed_date.strftime("%Y%m%d") if parsed_date else str(trade_date) + + if auto_refresh and parsed_date and self.check_data_availability(trade_key, {table}): + self._trigger_background_refresh(trade_key) + if hasattr(time, "sleep"): + time.sleep(0.5) + + if table == "suspend": + query = ( + "SELECT 1 FROM suspend " + "WHERE ts_code = ? " + "AND suspend_date <= ? " + "AND (resume_date IS NULL OR resume_date = '' OR resume_date > ?) " + "LIMIT 1" + ) + bind_params = (ts_code, trade_key, trade_key) + else: + clauses = ["ts_code = ?"] + bind_params_list: List[object] = [ts_code] + clause_text = (where_clause or "").strip() + if clause_text: + clauses.append(clause_text) + bind_params_list.extend(params) + query = f"SELECT 1 FROM {table} WHERE {' AND '.join(clauses)} LIMIT 1" + bind_params = tuple(bind_params_list) + try: with db_session(read_only=True) as conn: try: