refactor backtest engine with trading rules and progress tracking

This commit is contained in:
Your Name 2025-10-11 09:27:55 +08:00
parent 8aa8efb651
commit 90fb2a9df6
10 changed files with 515 additions and 219 deletions

View File

@ -204,22 +204,29 @@ class DepartmentAgent:
rounds_executed = round_idx + 1 rounds_executed = round_idx + 1
usage = response.get("usage") if isinstance(response, Mapping) else None message, usage_payload, tool_calls = _normalize_llm_response(response)
if isinstance(usage, Mapping): if usage_payload:
usage_payload = {"round": round_idx + 1} payload_with_round = {"round": round_idx + 1}
usage_payload.update(dict(usage)) payload_with_round.update(usage_payload)
usage_records.append(usage_payload) usage_records.append(payload_with_round)
choice = (response.get("choices") or [{}])[0] if not message:
message = choice.get("message", {}) LOGGER.debug(
"部门 %s%s 轮响应缺少 message 字段:%s",
self.settings.code,
round_idx + 1,
response,
extra=LOG_EXTRA,
)
message = {"role": "assistant", "content": ""}
transcript.append(_message_to_text(message)) transcript.append(_message_to_text(message))
assistant_record: Dict[str, Any] = { assistant_record: Dict[str, Any] = {
"role": "assistant", "role": message.get("role", "assistant"),
"content": _extract_message_content(message), "content": _extract_message_content(message),
} }
if message.get("tool_calls"): if tool_calls:
assistant_record["tool_calls"] = message.get("tool_calls") assistant_record["tool_calls"] = tool_calls
messages.append(assistant_record) messages.append(assistant_record)
CONV_LOGGER.info( CONV_LOGGER.info(
"dept=%s round=%s assistant=%s", "dept=%s round=%s assistant=%s",
@ -228,7 +235,6 @@ class DepartmentAgent:
assistant_record, assistant_record,
) )
tool_calls = message.get("tool_calls") or []
if tool_calls: if tool_calls:
for call in tool_calls: for call in tool_calls:
function_block = call.get("function") or {} function_block = call.get("function") or {}
@ -656,6 +662,8 @@ class DepartmentAgent:
dialogue=[response], dialogue=[response],
) )
return decision return decision
def _ensure_mutable_context(context: DepartmentContext) -> DepartmentContext: def _ensure_mutable_context(context: DepartmentContext) -> DepartmentContext:
if not isinstance(context.features, dict): if not isinstance(context.features, dict):
context.features = dict(context.features or {}) context.features = dict(context.features or {})
@ -669,6 +677,77 @@ def _ensure_mutable_context(context: DepartmentContext) -> DepartmentContext:
return context return context
def _compose_usage_from_stats(payload: Mapping[str, Any]) -> Dict[str, Any]:
usage: Dict[str, Any] = {}
prompt_eval = payload.get("prompt_eval_count")
completion_eval = payload.get("eval_count")
if isinstance(prompt_eval, (int, float)):
usage["prompt_tokens"] = int(prompt_eval)
if isinstance(completion_eval, (int, float)):
usage["completion_tokens"] = int(completion_eval)
if usage:
total = usage.get("prompt_tokens", 0) + usage.get("completion_tokens", 0)
usage["total_tokens"] = total
return usage
def _normalize_llm_response(
response: Mapping[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any], List[Dict[str, Any]]]:
message: Dict[str, Any] = {}
usage: Dict[str, Any] = {}
tool_calls: List[Dict[str, Any]] = []
if not isinstance(response, Mapping):
return message, usage, tool_calls
choices = response.get("choices")
if isinstance(choices, list) and choices:
choice = choices[0] or {}
candidate = choice.get("message")
if isinstance(candidate, Mapping):
message = candidate
raw_calls = candidate.get("tool_calls")
if isinstance(raw_calls, list):
tool_calls = list(raw_calls)
raw_usage = response.get("usage")
if isinstance(raw_usage, Mapping):
usage = dict(raw_usage)
else:
raw_message = response.get("message")
if isinstance(raw_message, Mapping):
message = raw_message
raw_calls = raw_message.get("tool_calls")
if isinstance(raw_calls, list):
tool_calls = list(raw_calls)
elif isinstance(response.get("messages"), list):
messages_list = response.get("messages") or []
if messages_list:
candidate = messages_list[-1]
if isinstance(candidate, Mapping):
message = candidate
raw_calls = candidate.get("tool_calls")
if isinstance(raw_calls, list):
tool_calls = list(raw_calls)
if not message:
content = response.get("content")
if isinstance(content, str):
message = {"role": "assistant", "content": content}
raw_usage = response.get("usage")
if isinstance(raw_usage, Mapping):
usage = dict(raw_usage)
else:
usage = _compose_usage_from_stats(response)
if not tool_calls:
extra = message.get("additional_kwargs")
if isinstance(extra, Mapping):
extra_calls = extra.get("tool_calls")
if isinstance(extra_calls, list):
tool_calls = list(extra_calls)
return message or {"role": "assistant", "content": ""}, usage, tool_calls
def _parse_tool_arguments(payload: Any) -> Dict[str, Any]: def _parse_tool_arguments(payload: Any) -> Dict[str, Any]:
if isinstance(payload, dict): if isinstance(payload, dict):
return dict(payload) return dict(payload)

View File

@ -1,6 +1,8 @@
"""Value and quality filtering agent.""" """Value and quality filtering agent."""
from __future__ import annotations from __future__ import annotations
from typing import Mapping
from .base import Agent, AgentAction, AgentContext from .base import Agent, AgentAction, AgentContext
@ -9,12 +11,19 @@ class ValueAgent(Agent):
super().__init__(name="A_val") super().__init__(name="A_val")
def score(self, context: AgentContext, action: AgentAction) -> float: def score(self, context: AgentContext, action: AgentAction) -> float:
pe = context.features.get("pe_percentile", 0.5) pe_score = context.features.get("valuation_pe_score", 0.0)
pb = context.features.get("pb_percentile", 0.5) pb_score = context.features.get("valuation_pb_score", 0.0)
roe = context.features.get("roe_percentile", 0.5) # 多因子组合尚未落地,这里兼容扩展因子(若存在则优先使用)
# Lower valuation percentiles and higher quality percentiles add value. scope_values = {}
raw = max(0.0, (1 - pe) * 0.4 + (1 - pb) * 0.3 + roe * 0.3) if isinstance(context.raw, Mapping):
raw = min(raw, 1.0) scope_values = context.raw.get("scope_values", {}) or {}
multi_score = context.features.get("val_multiscore")
if multi_score is None:
multi_score = scope_values.get("factors.val_multiscore")
if multi_score is not None:
raw = float(max(0.0, min(1.0, multi_score)))
else:
raw = max(0.0, min(1.0, 0.6 * pe_score + 0.4 * pb_score))
if action is AgentAction.SELL: if action is AgentAction.SELL:
return 1 - raw return 1 - raw
if action is AgentAction.HOLD: if action is AgentAction.HOLD:

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import json import json
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import date from datetime import date, datetime
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple
from app.agents.base import AgentAction, AgentContext from app.agents.base import AgentAction, AgentContext
@ -16,7 +16,7 @@ from app.llm.metrics import record_decision as metrics_record_decision
from app.agents.registry import default_agents from app.agents.registry import default_agents
from app.data.schema import initialize_database from app.data.schema import initialize_database
from app.utils.data_access import DataBroker from app.utils.data_access import DataBroker
from app.utils.config import get_config from app.utils.config import PortfolioSettings, get_config
from app.utils.db import db_session from app.utils.db import db_session
from app.utils.logging import get_logger from app.utils.logging import get_logger
from app.utils import alerts from app.utils import alerts
@ -105,12 +105,26 @@ class BacktestEngine:
) )
self.data_broker = DataBroker() self.data_broker = DataBroker()
params = cfg.params or {} params = cfg.params or {}
portfolio_cfg = getattr(app_cfg, "portfolio", None) or PortfolioSettings()
self.risk_params = { self.risk_params = {
"max_position_weight": float(params.get("max_position_weight", 0.2)), "max_position_weight": float(params.get("max_position_weight", 0.2)),
"max_daily_turnover_ratio": float(params.get("max_daily_turnover_ratio", 0.25)), "max_daily_turnover_ratio": float(params.get("max_daily_turnover_ratio", 0.25)),
"fee_rate": float(params.get("fee_rate", 0.0005)), "fee_rate": float(params.get("fee_rate", 0.0005)),
"slippage_bps": float(params.get("slippage_bps", 10.0)), "slippage_bps": float(params.get("slippage_bps", 10.0)),
} }
self.initial_cash = max(0.0, float(params.get("initial_capital", portfolio_cfg.initial_capital)))
target_return = params.get("target", params.get("target_return", 0.0)) or 0.0
stop_loss = params.get("stop", params.get("stop_loss", 0.0)) or 0.0
hold_days_param = params.get("hold_days", params.get("max_hold_days", 0))
try:
max_hold_days = int(hold_days_param) if hold_days_param is not None else 0
except (TypeError, ValueError):
max_hold_days = 0
self.trading_rules = {
"target_return": float(target_return),
"stop_loss": float(stop_loss),
"max_hold_days": max(0, max_hold_days),
}
self._fee_rate = max(self.risk_params["fee_rate"], 0.0) self._fee_rate = max(self.risk_params["fee_rate"], 0.0)
self._slippage_rate = max(self.risk_params["slippage_bps"], 0.0) / 10_000.0 self._slippage_rate = max(self.risk_params["slippage_bps"], 0.0) / 10_000.0
self._turnover_cap = max(self.risk_params["max_daily_turnover_ratio"], 0.0) self._turnover_cap = max(self.risk_params["max_daily_turnover_ratio"], 0.0)
@ -314,9 +328,9 @@ class BacktestEngine:
is_suspended = self.data_broker.fetch_flags( is_suspended = self.data_broker.fetch_flags(
"suspend", "suspend",
ts_code, ts_code,
trade_date, trade_date_str,
"ts_code = ?", "",
[ts_code], [],
auto_refresh=False, # 避免在回测中触发自动补数 auto_refresh=False, # 避免在回测中触发自动补数
) )
@ -650,6 +664,7 @@ class BacktestEngine:
continue continue
features = feature_cache.get(ts_code, {}) features = feature_cache.get(ts_code, {})
current_qty = state.holdings.get(ts_code, 0.0) current_qty = state.holdings.get(ts_code, 0.0)
current_cost_basis = float(state.cost_basis.get(ts_code, 0.0) or 0.0)
liquidity_score = float(features.get("liquidity_score") or 0.0) liquidity_score = float(features.get("liquidity_score") or 0.0)
risk_penalty = float(features.get("risk_penalty") or 0.0) risk_penalty = float(features.get("risk_penalty") or 0.0)
is_suspended = bool(features.get("is_suspended")) is_suspended = bool(features.get("is_suspended"))
@ -679,6 +694,76 @@ class BacktestEngine:
if risk.status == "blocked": if risk.status == "blocked":
continue continue
rule_override_action: Optional[AgentAction] = None
rule_override_reason: Optional[str] = None
gain_ratio: Optional[float] = None
if current_qty > 0 and current_cost_basis:
try:
gain_ratio = (price / current_cost_basis) - 1.0
except ZeroDivisionError:
gain_ratio = None
target_return = self.trading_rules.get("target_return", 0.0)
if (
rule_override_action is None
and gain_ratio is not None
and target_return
and gain_ratio >= target_return
):
rule_override_action = AgentAction.SELL
rule_override_reason = "target_reached"
stop_loss = self.trading_rules.get("stop_loss", 0.0)
if (
rule_override_action is None
and gain_ratio is not None
and stop_loss
and gain_ratio <= stop_loss
):
rule_override_action = AgentAction.SELL
rule_override_reason = "stop_loss"
max_hold_days = self.trading_rules.get("max_hold_days", 0)
if (
rule_override_action is None
and max_hold_days
and max_hold_days > 0
and current_qty > 0
):
opened_str = state.opened_dates.get(ts_code)
opened_dt: Optional[date] = None
if opened_str:
try:
opened_dt = date.fromisoformat(str(opened_str))
except ValueError:
try:
opened_dt = datetime.strptime(str(opened_str), "%Y%m%d").date()
except ValueError:
opened_dt = None
LOGGER.debug(
"无法解析持仓日期 ts_code=%s value=%s",
ts_code,
opened_str,
extra=LOG_EXTRA,
)
if opened_dt:
holding_days = (trade_date - opened_dt).days
if holding_days >= max_hold_days:
rule_override_action = AgentAction.SELL
rule_override_reason = "holding_period"
if rule_override_action and rule_override_action is not effective_action:
effective_action = rule_override_action
effective_weight = target_weight_for_action(effective_action)
_record_risk(
ts_code,
rule_override_reason or "rule_override",
decision,
extra={
"rule_trigger": rule_override_reason,
"gain_ratio": gain_ratio,
},
action_override=effective_action,
target_weight_override=effective_weight,
)
if is_suspended: if is_suspended:
_record_risk(ts_code, "suspended", decision) _record_risk(ts_code, "suspended", decision)
continue continue
@ -738,8 +823,7 @@ class BacktestEngine:
if total_cash_needed <= 0: if total_cash_needed <= 0:
_record_risk(ts_code, "invalid_trade", decision) _record_risk(ts_code, "invalid_trade", decision)
continue continue
previous_cost = current_cost_basis * current_qty
previous_cost = state.cost_basis.get(ts_code, 0.0) * current_qty
new_qty = current_qty + delta new_qty = current_qty + delta
state.cost_basis[ts_code] = ( state.cost_basis[ts_code] = (
(previous_cost + trade_value + fee) / new_qty if new_qty > 0 else 0.0 (previous_cost + trade_value + fee) / new_qty if new_qty > 0 else 0.0
@ -777,8 +861,7 @@ class BacktestEngine:
gross_value = sell_qty * trade_price gross_value = sell_qty * trade_price
fee = gross_value * self._fee_rate fee = gross_value * self._fee_rate
proceeds = gross_value - fee proceeds = gross_value - fee
cost_basis = state.cost_basis.get(ts_code, 0.0) realized = (trade_price - current_cost_basis) * sell_qty - fee
realized = (trade_price - cost_basis) * sell_qty - fee
state.cash += proceeds state.cash += proceeds
state.realized_pnl += realized state.realized_pnl += realized
new_qty = current_qty - sell_qty new_qty = current_qty - sell_qty
@ -1023,7 +1106,7 @@ class BacktestEngine:
"""Initialise a new incremental backtest session.""" """Initialise a new incremental backtest session."""
return BacktestSession( return BacktestSession(
state=PortfolioState(), state=PortfolioState(cash=self.initial_cash),
result=BacktestResult(), result=BacktestResult(),
current_date=self.cfg.start_date, current_date=self.cfg.start_date,
) )

View File

@ -19,7 +19,11 @@ from app.features.value_risk_factors import ValueRiskFactors
# 导入因子验证功能 # 导入因子验证功能
from app.features.validation import check_data_sufficiency, check_data_sufficiency_for_zero_window, detect_outliers from app.features.validation import check_data_sufficiency, check_data_sufficiency_for_zero_window, detect_outliers
# 导入UI进度状态管理 # 导入UI进度状态管理
from app.ui.progress_state import factor_progress try:
from app.features.progress import get_progress_handler
except ImportError: # pragma: no cover - optional dependency
def get_progress_handler():
return None
LOGGER = get_logger(__name__) LOGGER = get_logger(__name__)
@ -176,14 +180,16 @@ def compute_factors(
broker = DataBroker() broker = DataBroker()
results: List[FactorResult] = [] results: List[FactorResult] = []
rows_to_persist: List[tuple[str, Dict[str, float | None]]] = [] rows_to_persist: List[tuple[str, Dict[str, float | None]]] = []
total_batches = (len(universe) + batch_size - 1) // batch_size if universe else 0
progress = get_progress_handler()
if progress and universe:
try:
progress.start_calculation(len(universe), total_batches)
except Exception: # noqa: BLE001
LOGGER.debug("Progress handler start_calculation 失败", extra=LOG_EXTRA)
progress = None
try: try:
# 启动UI进度状态在异步线程中不直接访问factor_progress
# factor_progress.start_calculation(
# total_securities=len(universe),
# total_batches=(len(universe) + batch_size - 1) // batch_size
# )
# 分批处理以优化性能 # 分批处理以优化性能
for i in range(0, len(universe), batch_size): for i in range(0, len(universe), batch_size):
batch = universe[i:i+batch_size] batch = universe[i:i+batch_size]
@ -194,9 +200,10 @@ def compute_factors(
specs, specs,
validation_stats, validation_stats,
batch_index=i // batch_size, batch_index=i // batch_size,
total_batches=(len(universe) + batch_size - 1) // batch_size, total_batches=total_batches or 1,
processed_securities=i, processed_securities=i,
total_securities=len(universe) total_securities=len(universe),
progress=progress,
) )
for ts_code, values in batch_results: for ts_code, values in batch_results:
@ -222,9 +229,13 @@ def compute_factors(
_persist_factor_rows(trade_date_str, rows_to_persist, specs) _persist_factor_rows(trade_date_str, rows_to_persist, specs)
# 更新UI进度状态为完成 # 更新UI进度状态为完成
factor_progress.complete_calculation( if progress:
message=f"因子计算完成: 总数量={len(universe)}, 成功={validation_stats['success']}, 失败={len(universe) - validation_stats['success']}" try:
) progress.complete_calculation(
message=f"因子计算完成: 总数量={len(universe)}, 成功={validation_stats['success']}, 失败={len(universe) - validation_stats['success']}"
)
except Exception: # noqa: BLE001
LOGGER.debug("Progress handler complete_calculation 失败", extra=LOG_EXTRA)
LOGGER.info( LOGGER.info(
"因子计算完成 总数量:%s 成功:%s 失败:%s", "因子计算完成 总数量:%s 成功:%s 失败:%s",
@ -239,7 +250,11 @@ def compute_factors(
except Exception as exc: except Exception as exc:
# 发生错误时更新UI状态 # 发生错误时更新UI状态
error_message = f"因子计算过程中发生错误: {exc}" error_message = f"因子计算过程中发生错误: {exc}"
factor_progress.error_occurred(error_message) if progress:
try:
progress.error_occurred(error_message)
except Exception: # noqa: BLE001
LOGGER.debug("Progress handler error_occurred 失败", extra=LOG_EXTRA)
LOGGER.error(error_message, extra=LOG_EXTRA) LOGGER.error(error_message, extra=LOG_EXTRA)
raise raise
@ -380,6 +395,7 @@ def _compute_batch_factors(
total_batches: int = 1, total_batches: int = 1,
processed_securities: int = 0, processed_securities: int = 0,
total_securities: int = 0, total_securities: int = 0,
progress: Optional[object] = None,
) -> List[tuple[str, Dict[str, float | None]]]: ) -> List[tuple[str, Dict[str, float | None]]]:
"""批量计算多个证券的因子值,提高计算效率""" """批量计算多个证券的因子值,提高计算效率"""
batch_results = [] batch_results = []
@ -388,13 +404,16 @@ def _compute_batch_factors(
available_codes = _check_batch_data_availability(broker, ts_codes, trade_date, specs) available_codes = _check_batch_data_availability(broker, ts_codes, trade_date, specs)
# 更新UI进度状态 - 开始处理批次 # 更新UI进度状态 - 开始处理批次
if total_securities > 0: if progress and total_securities > 0:
from app.ui.progress_state import factor_progress try:
factor_progress.update_progress( progress.update_progress(
current_securities=processed_securities, current_securities=processed_securities,
current_batch=batch_index + 1, current_batch=batch_index + 1,
message=f"开始处理批次 {batch_index + 1}/{total_batches}" message=f"开始处理批次 {batch_index + 1}/{total_batches}",
) )
except Exception: # noqa: BLE001
LOGGER.debug("Progress handler update_progress 失败", extra=LOG_EXTRA)
progress = None
for i, ts_code in enumerate(ts_codes): for i, ts_code in enumerate(ts_codes):
try: try:
@ -434,15 +453,18 @@ def _compute_batch_factors(
validation_stats["skipped"] += 1 validation_stats["skipped"] += 1
# 每处理1个证券更新一次进度确保实时性 # 每处理1个证券更新一次进度确保实时性
if total_securities > 0: if progress and total_securities > 0:
current_progress = processed_securities + i + 1 current_progress = processed_securities + i + 1
progress_percentage = (current_progress / total_securities) * 100 progress_percentage = (current_progress / total_securities) * 100
from app.ui.progress_state import factor_progress try:
factor_progress.update_progress( progress.update_progress(
current_securities=current_progress, current_securities=current_progress,
current_batch=batch_index + 1, current_batch=batch_index + 1,
message=f"处理批次 {batch_index + 1}/{total_batches} - 证券 {current_progress}/{total_securities} ({progress_percentage:.1f}%)" message=f"处理批次 {batch_index + 1}/{total_batches} - 证券 {current_progress}/{total_securities} ({progress_percentage:.1f}%)",
) )
except Exception: # noqa: BLE001
LOGGER.debug("Progress handler update_progress 失败", extra=LOG_EXTRA)
progress = None
except Exception as e: except Exception as e:
LOGGER.error( LOGGER.error(
"计算因子失败 ts_code=%s err=%s", "计算因子失败 ts_code=%s err=%s",
@ -453,16 +475,18 @@ def _compute_batch_factors(
validation_stats["skipped"] += 1 validation_stats["skipped"] += 1
# 批次处理完成,更新最终进度 # 批次处理完成,更新最终进度
if total_securities > 0: if progress and total_securities > 0:
final_progress = processed_securities + len(ts_codes) final_progress = processed_securities + len(ts_codes)
progress_percentage = (final_progress / total_securities) * 100 progress_percentage = (final_progress / total_securities) * 100
from app.ui.progress_state import factor_progress try:
factor_progress.update_progress( progress.update_progress(
current_securities=final_progress, current_securities=final_progress,
current_batch=batch_index + 1, current_batch=batch_index + 1,
message=f"批次 {batch_index + 1}/{total_batches} 处理完成 - 证券 {final_progress}/{total_securities} ({progress_percentage:.1f}%)" message=f"批次 {batch_index + 1}/{total_batches} 处理完成 - 证券 {final_progress}/{total_securities} ({progress_percentage:.1f}%)",
) )
except Exception: # noqa: BLE001
LOGGER.debug("Progress handler update_progress 失败", extra=LOG_EXTRA)
return batch_results return batch_results

38
app/features/progress.py Normal file
View File

@ -0,0 +1,38 @@
"""Optional progress reporting for factor computation."""
from __future__ import annotations
from typing import Optional, Protocol
class FactorProgressProtocol(Protocol):
"""Protocol describing the optional UI progress handler."""
def start_calculation(self, total_securities: int, total_batches: int) -> None: ...
def update_progress(
self,
current_securities: int,
current_batch: int,
message: str = "",
) -> None: ...
def complete_calculation(self, message: str = "") -> None: ...
def error_occurred(self, error_message: str) -> None: ...
_progress_handler: Optional[FactorProgressProtocol] = None
def register_progress_handler(progress: FactorProgressProtocol | None) -> None:
"""Register a progress handler (typically provided by the UI layer)."""
global _progress_handler
_progress_handler = progress
def get_progress_handler() -> Optional[FactorProgressProtocol]:
"""Return the currently registered progress handler if any."""
return _progress_handler

View File

@ -5,6 +5,12 @@ from typing import Optional, Dict, Any
import time import time
import streamlit as st import streamlit as st
try:
from app.features.progress import register_progress_handler
except ImportError: # pragma: no cover - optional dependency
def register_progress_handler(handler: object) -> None:
return
class FactorProgressState: class FactorProgressState:
"""因子计算进度状态管理类""" """因子计算进度状态管理类"""
@ -151,6 +157,7 @@ class FactorProgressState:
# 全局进度状态实例 # 全局进度状态实例
factor_progress = FactorProgressState() factor_progress = FactorProgressState()
register_progress_handler(factor_progress)
def render_factor_progress() -> None: def render_factor_progress() -> None:

View File

@ -44,13 +44,14 @@ def render_backtest_review() -> None:
app_cfg = get_config() app_cfg = get_config()
default_start, default_end = default_backtest_range(window_days=60) default_start, default_end = default_backtest_range(window_days=60)
LOGGER.debug( LOGGER.debug(
"回测默认参数start=%s end=%s universe=%s target=%s stop=%s hold_days=%s", "回测默认参数start=%s end=%s universe=%s target=%s stop=%s hold_days=%s initial_capital=%s",
default_start, default_start,
default_end, default_end,
"000001.SZ", "000001.SZ",
0.035, 0.035,
-0.015, -0.015,
10, 10,
get_config().portfolio.initial_capital,
extra=LOG_EXTRA, extra=LOG_EXTRA,
) )
@ -59,10 +60,24 @@ def render_backtest_review() -> None:
start_date = col1.date_input("开始日期", value=default_start, key="bt_start_date") start_date = col1.date_input("开始日期", value=default_start, key="bt_start_date")
end_date = col2.date_input("结束日期", value=default_end, key="bt_end_date") end_date = col2.date_input("结束日期", value=default_end, key="bt_end_date")
universe_text = st.text_input("股票列表(逗号分隔)", value="000001.SZ", key="bt_universe") universe_text = st.text_input("股票列表(逗号分隔)", value="000001.SZ", key="bt_universe")
col_target, col_stop, col_hold = st.columns(3) col_target, col_stop, col_hold, col_cap = st.columns(4)
target = col_target.number_input("目标收益0.035 表示 3.5%", value=0.035, step=0.005, format="%.3f", key="bt_target") target = col_target.number_input("目标收益0.035 表示 3.5%", value=0.035, step=0.005, format="%.3f", key="bt_target")
stop = col_stop.number_input("止损收益(例:-0.015 表示 -1.5%", value=-0.015, step=0.005, format="%.3f", key="bt_stop") stop = col_stop.number_input("止损收益(例:-0.015 表示 -1.5%", value=-0.015, step=0.005, format="%.3f", key="bt_stop")
hold_days = col_hold.number_input("持有期(交易日)", value=10, step=1, key="bt_hold_days") hold_days = col_hold.number_input("持有期(交易日)", value=10, step=1, key="bt_hold_days")
initial_capital = col_cap.number_input(
"组合初始资金",
value=float(app_cfg.portfolio.initial_capital),
step=100000.0,
format="%.0f",
key="bt_initial_capital",
)
initial_capital = max(0.0, float(initial_capital))
backtest_params = {
"target": float(target),
"stop": float(stop),
"hold_days": int(hold_days),
"initial_capital": initial_capital,
}
structure_options = [item.value for item in GameStructure] structure_options = [item.value for item in GameStructure]
selected_structure_values = st.multiselect( selected_structure_values = st.multiselect(
"选择博弈框架", "选择博弈框架",
@ -74,13 +89,14 @@ def render_backtest_review() -> None:
selected_structure_values = [GameStructure.REPEATED.value] selected_structure_values = [GameStructure.REPEATED.value]
selected_structures = [GameStructure(value) for value in selected_structure_values] selected_structures = [GameStructure(value) for value in selected_structure_values]
LOGGER.debug( LOGGER.debug(
"当前回测表单输入start=%s end=%s universe_text=%s target=%.3f stop=%.3f hold_days=%s", "当前回测表单输入start=%s end=%s universe_text=%s target=%.3f stop=%.3f hold_days=%s initial_capital=%.2f",
start_date, start_date,
end_date, end_date,
universe_text, universe_text,
target, target,
stop, stop,
hold_days, hold_days,
initial_capital,
extra=LOG_EXTRA, extra=LOG_EXTRA,
) )
@ -134,13 +150,14 @@ def render_backtest_review() -> None:
try: try:
universe = [code.strip() for code in universe_text.split(',') if code.strip()] universe = [code.strip() for code in universe_text.split(',') if code.strip()]
LOGGER.info( LOGGER.info(
"回测参数start=%s end=%s universe=%s target=%s stop=%s hold_days=%s", "回测参数start=%s end=%s universe=%s target=%s stop=%s hold_days=%s initial_capital=%.2f",
start_date, start_date,
end_date, end_date,
universe, universe,
target, target,
stop, stop,
hold_days, hold_days,
initial_capital,
extra=LOG_EXTRA, extra=LOG_EXTRA,
) )
backtest_cfg = BtConfig( backtest_cfg = BtConfig(
@ -149,11 +166,7 @@ def render_backtest_review() -> None:
start_date=start_date, start_date=start_date,
end_date=end_date, end_date=end_date,
universe=universe, universe=universe,
params={ params=dict(backtest_params),
"target": target,
"stop": stop,
"hold_days": int(hold_days),
},
game_structures=selected_structures, game_structures=selected_structures,
) )
result = run_backtest(backtest_cfg, decision_callback=_decision_callback) result = run_backtest(backtest_cfg, decision_callback=_decision_callback)
@ -665,11 +678,7 @@ def render_backtest_review() -> None:
start_date=start_date, start_date=start_date,
end_date=end_date, end_date=end_date,
universe=universe_env, universe=universe_env,
params={ params=dict(backtest_params),
"target": target,
"stop": stop,
"hold_days": int(hold_days),
},
method=app_cfg.decision_method, method=app_cfg.decision_method,
game_structures=selected_structures, game_structures=selected_structures,
) )
@ -942,11 +951,7 @@ def render_backtest_review() -> None:
start_date=start_date, start_date=start_date,
end_date=end_date, end_date=end_date,
universe=universe_env, universe=universe_env,
params={ params=dict(backtest_params),
"target": target,
"stop": stop,
"hold_days": int(hold_days),
},
method=app_cfg.decision_method, method=app_cfg.decision_method,
game_structures=selected_structures, game_structures=selected_structures,
) )

View File

@ -1,6 +1,8 @@
"""投资池与仓位概览页面。""" """投资池与仓位概览页面。"""
from __future__ import annotations from __future__ import annotations
from datetime import date, timedelta
import pandas as pd import pandas as pd
import plotly.express as px import plotly.express as px
import streamlit as st import streamlit as st
@ -106,28 +108,44 @@ def render_pool_overview() -> None:
st.caption("数据来源agent_utils、investment_pool、portfolio_positions、portfolio_trades、portfolio_snapshots。") st.caption("数据来源agent_utils、investment_pool、portfolio_positions、portfolio_trades、portfolio_snapshots。")
default_compare_a = latest_date or date.today()
default_compare_b = default_compare_a - timedelta(days=1)
if default_compare_b > default_compare_a:
default_compare_b = default_compare_a
compare_col1, compare_col2 = st.columns(2)
compare_date1 = compare_col1.date_input(
"日志对比日期 A",
value=default_compare_a,
key="pool_compare_date_a",
)
compare_date2 = compare_col2.date_input(
"日志对比日期 B",
value=default_compare_b,
key="pool_compare_date_b",
)
if st.button("执行对比", type="secondary"): if st.button("执行对比", type="secondary"):
with st.spinner("执行日志对比分析中..."): with st.spinner("执行日志对比分析中..."):
try: try:
with db_session(read_only=True) as conn: with db_session(read_only=True) as conn:
query_date1 = f"{compare_date1.isoformat()}T00:00:00Z" # type: ignore[name-defined] query_date1 = f"{compare_date1.isoformat()}T00:00:00Z"
query_date2 = f"{compare_date1.isoformat()}T23:59:59Z" # type: ignore[name-defined] query_date2 = f"{compare_date1.isoformat()}T23:59:59Z"
logs1 = conn.execute( logs1 = conn.execute(
"SELECT level, COUNT(*) as count FROM run_log WHERE ts BETWEEN ? AND ? GROUP BY level", "SELECT level, COUNT(*) as count FROM run_log WHERE ts BETWEEN ? AND ? GROUP BY level",
(query_date1, query_date2), (query_date1, query_date2),
).fetchall() ).fetchall()
query_date3 = f"{compare_date2.isoformat()}T00:00:00Z" # type: ignore[name-defined] query_date3 = f"{compare_date2.isoformat()}T00:00:00Z"
query_date4 = f"{compare_date2.isoformat()}T23:59:59Z" # type: ignore[name-defined] query_date4 = f"{compare_date2.isoformat()}T23:59:59Z"
logs2 = conn.execute( logs2 = conn.execute(
"SELECT level, COUNT(*) as count FROM run_log WHERE ts BETWEEN ? AND ? GROUP BY level", "SELECT level, COUNT(*) as count FROM run_log WHERE ts BETWEEN ? AND ? GROUP BY level",
(query_date3, query_date4), (query_date3, query_date4),
).fetchall() ).fetchall()
df1 = pd.DataFrame(logs1, columns=["level", "count"]) df1 = pd.DataFrame(logs1, columns=["level", "count"])
df1["date"] = compare_date1.strftime("%Y-%m-%d") # type: ignore[name-defined] df1["date"] = compare_date1.strftime("%Y-%m-%d")
df2 = pd.DataFrame(logs2, columns=["level", "count"]) df2 = pd.DataFrame(logs2, columns=["level", "count"])
df2["date"] = compare_date2.strftime("%Y-%m-%d") # type: ignore[name-defined] df2["date"] = compare_date2.strftime("%Y-%m-%d")
for df in (df1, df2): for df in (df1, df2):
for col in df.columns: for col in df.columns:
@ -141,13 +159,13 @@ def render_pool_overview() -> None:
y="count", y="count",
color="date", color="date",
barmode="group", barmode="group",
title=f"日志级别分布对比 ({compare_date1} vs {compare_date2})", # type: ignore[name-defined] title=f"日志级别分布对比 ({compare_date1} vs {compare_date2})",
) )
st.plotly_chart(fig, width='stretch') st.plotly_chart(fig, width='stretch')
st.write("日志统计对比:") st.write("日志统计对比:")
date1_str = compare_date1.strftime("%Y%m%d") # type: ignore[name-defined] date1_str = compare_date1.strftime("%Y%m%d")
date2_str = compare_date2.strftime("%Y%m%d") # type: ignore[name-defined] date2_str = compare_date2.strftime("%Y%m%d")
merged_df = df1.merge( merged_df = df1.merge(
df2, df2,
on="level", on="level",

View File

@ -23,6 +23,107 @@ from app.ui.shared import (
) )
def _fetch_agent_actions(trade_date: str, symbols: List[str]) -> Dict[str, Dict[str, Optional[str]]]:
unique_symbols = list(dict.fromkeys(symbols))
if not unique_symbols:
return {}
placeholder = ", ".join("?" for _ in unique_symbols)
sql = (
f"""
SELECT ts_code, agent, action
FROM agent_utils
WHERE trade_date = ?
AND ts_code IN ({placeholder})
"""
)
params: List[object] = [trade_date, *unique_symbols]
with db_session(read_only=True) as conn:
rows = conn.execute(sql, params).fetchall()
mapping: Dict[str, Dict[str, Optional[str]]] = {}
for row in rows:
ts_code = row["ts_code"]
agent = row["agent"]
action = row["action"]
mapping.setdefault(ts_code, {})[agent] = action
return mapping
def _reevaluate_symbols(
trade_date_obj: date,
symbols: List[str],
cfg_id: str,
cfg_name: str,
) -> List[Dict[str, object]]:
unique_symbols = list(dict.fromkeys(symbols))
if not unique_symbols:
return []
trade_date_str = trade_date_obj.isoformat()
before_map = _fetch_agent_actions(trade_date_str, unique_symbols)
engine_params: Dict[str, object] = {}
target_val = st.session_state.get("bt_target")
if target_val is not None:
try:
engine_params["target"] = float(target_val)
except (TypeError, ValueError):
pass
stop_val = st.session_state.get("bt_stop")
if stop_val is not None:
try:
engine_params["stop"] = float(stop_val)
except (TypeError, ValueError):
pass
hold_val = st.session_state.get("bt_hold_days")
if hold_val is not None:
try:
engine_params["hold_days"] = int(hold_val)
except (TypeError, ValueError):
pass
capital_val = st.session_state.get("bt_initial_capital")
if capital_val is not None:
try:
engine_params["initial_capital"] = float(capital_val)
except (TypeError, ValueError):
pass
cfg = BtConfig(
id=cfg_id,
name=cfg_name,
start_date=trade_date_obj,
end_date=trade_date_obj,
universe=unique_symbols,
params=engine_params,
)
engine = BacktestEngine(cfg)
state = PortfolioState(cash=engine.initial_cash)
engine.simulate_day(trade_date_obj, state)
after_map = _fetch_agent_actions(trade_date_str, unique_symbols)
changes: List[Dict[str, object]] = []
for code in unique_symbols:
before_agents = before_map.get(code, {})
after_agents = after_map.get(code, {})
for agent, new_action in after_agents.items():
old_action = before_agents.get(agent)
if new_action != old_action:
changes.append(
{
"代码": code,
"代理": agent,
"原动作": old_action,
"新动作": new_action,
}
)
for agent, old_action in before_agents.items():
if agent not in after_agents:
changes.append(
{
"代码": code,
"代理": agent,
"原动作": old_action,
"新动作": None,
}
)
return changes
def render_today_plan() -> None: def render_today_plan() -> None:
LOGGER.info("渲染今日计划页面", extra=LOG_EXTRA) LOGGER.info("渲染今日计划页面", extra=LOG_EXTRA)
st.header("今日计划") st.header("今日计划")
@ -176,56 +277,15 @@ def _render_today_plan_symbol_view(
raise ValueError(f"无法解析交易日:{trade_date}") raise ValueError(f"无法解析交易日:{trade_date}")
progress = st.progress(0.0) progress = st.progress(0.0)
changes_all: List[Dict[str, object]] = [] progress.progress(0.3 if symbols else 1.0)
success_count = 0 changes_all = _reevaluate_symbols(
error_count = 0 trade_date_obj,
symbols,
for idx, code in enumerate(symbols, start=1): "reeval_ui_all",
try: "UI All Re-eval",
with db_session(read_only=True) as conn: )
before_rows = conn.execute( progress.progress(1.0)
"SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?", st.success(f"一键重评估完成:共处理 {len(symbols)} 个标的")
(trade_date, code),
).fetchall()
before_map = {row["agent"]: row["action"] for row in before_rows}
cfg = BtConfig(
id="reeval_ui_all",
name="UI All Re-eval",
start_date=trade_date_obj,
end_date=trade_date_obj,
universe=[code],
params={},
)
engine = BacktestEngine(cfg)
state = PortfolioState()
engine.simulate_day(trade_date_obj, state)
with db_session(read_only=True) as conn:
after_rows = conn.execute(
"SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?",
(trade_date, code),
).fetchall()
for row in after_rows:
agent = row["agent"]
new_action = row["action"]
old_action = before_map.get(agent)
if new_action != old_action:
changes_all.append(
{"代码": code, "代理": agent, "原动作": old_action, "新动作": new_action}
)
success_count += 1
except Exception: # noqa: BLE001
LOGGER.exception("重评估 %s 失败", code, extra=LOG_EXTRA)
error_count += 1
progress.progress(idx / len(symbols))
if error_count > 0:
st.error(f"一键重评估完成:成功 {success_count} 个,失败 {error_count}")
else:
st.success(f"一键重评估完成:所有 {success_count} 个标的重评估成功")
if changes_all: if changes_all:
st.write("检测到以下动作变更:") st.write("检测到以下动作变更:")
st.dataframe(pd.DataFrame(changes_all), hide_index=True, width='stretch') st.dataframe(pd.DataFrame(changes_all), hide_index=True, width='stretch')
@ -579,44 +639,20 @@ def _render_today_plan_symbol_view(
pass pass
if trade_date_obj is None: if trade_date_obj is None:
raise ValueError(f"无法解析交易日:{trade_date}") raise ValueError(f"无法解析交易日:{trade_date}")
with db_session(read_only=True) as conn: changes = _reevaluate_symbols(
before_rows = conn.execute( trade_date_obj,
""" [ts_code],
SELECT agent, action, utils FROM agent_utils "reeval_ui",
WHERE trade_date = ? AND ts_code = ? "UI Re-evaluation",
""",
(trade_date, ts_code),
).fetchall()
before_map = {row["agent"]: (row["action"], row["utils"]) for row in before_rows}
cfg = BtConfig(
id="reeval_ui",
name="UI Re-evaluation",
start_date=trade_date_obj,
end_date=trade_date_obj,
universe=[ts_code],
params={},
) )
engine = BacktestEngine(cfg)
state = PortfolioState()
engine.simulate_day(trade_date_obj, state)
with db_session(read_only=True) as conn:
after_rows = conn.execute(
"""
SELECT agent, action, utils FROM agent_utils
WHERE trade_date = ? AND ts_code = ?
""",
(trade_date, ts_code),
).fetchall()
changes = []
for row in after_rows:
agent = row["agent"]
new_action = row["action"]
old_action, _old_utils = before_map.get(agent, (None, None))
if new_action != old_action:
changes.append({"代理": agent, "原动作": old_action, "新动作": new_action})
if changes: if changes:
for change in changes:
change.setdefault("代码", ts_code)
df_changes = pd.DataFrame(changes)
if "代码" in df_changes.columns:
df_changes = df_changes[["代码", "代理", "原动作", "新动作"]]
st.success("重评估完成,检测到动作变更:") st.success("重评估完成,检测到动作变更:")
st.dataframe(pd.DataFrame(changes), hide_index=True, width='stretch') st.dataframe(df_changes, hide_index=True, width='stretch')
else: else:
st.success("重评估完成,无动作变更。") st.success("重评估完成,无动作变更。")
st.rerun() st.rerun()
@ -637,38 +673,15 @@ def _render_today_plan_symbol_view(
if trade_date_obj is None: if trade_date_obj is None:
raise ValueError(f"无法解析交易日:{trade_date}") raise ValueError(f"无法解析交易日:{trade_date}")
progress = st.progress(0.0) progress = st.progress(0.0)
changes_all: List[Dict[str, object]] = [] progress.progress(0.3 if batch_symbols else 1.0)
for idx, code in enumerate(batch_symbols, start=1): changes_all = _reevaluate_symbols(
with db_session(read_only=True) as conn: trade_date_obj,
before_rows = conn.execute( batch_symbols,
"SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?", "reeval_ui_batch",
(trade_date, code), "UI Batch Re-eval",
).fetchall() )
before_map = {row["agent"]: row["action"] for row in before_rows} progress.progress(1.0)
cfg = BtConfig( st.success(f"批量重评估完成:共处理 {len(batch_symbols)} 个标的")
id="reeval_ui_batch",
name="UI Batch Re-eval",
start_date=trade_date_obj,
end_date=trade_date_obj,
universe=[code],
params={},
)
engine = BacktestEngine(cfg)
state = PortfolioState()
engine.simulate_day(trade_date_obj, state)
with db_session(read_only=True) as conn:
after_rows = conn.execute(
"SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?",
(trade_date, code),
).fetchall()
for row in after_rows:
agent = row["agent"]
new_action = row["action"]
old_action = before_map.get(agent)
if new_action != old_action:
changes_all.append({"代码": code, "代理": agent, "原动作": old_action, "新动作": new_action})
progress.progress(idx / max(1, len(batch_symbols)))
st.success("批量重评估完成。")
if changes_all: if changes_all:
st.dataframe(pd.DataFrame(changes_all), hide_index=True, width='stretch') st.dataframe(pd.DataFrame(changes_all), hide_index=True, width='stretch')
st.rerun() st.rerun()

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import re import re
import sqlite3 import sqlite3
import threading import threading
import time
from collections import OrderedDict from collections import OrderedDict
from copy import deepcopy from copy import deepcopy
from dataclasses import dataclass, field from dataclasses import dataclass, field
@ -577,8 +578,7 @@ class DataBroker:
Returns: Returns:
新闻数据列表包含sentimentheatentities等字段 新闻数据列表包含sentimentheatentities等字段
""" """
# 简化实现:返回模拟数据 # TODO: 使用真实新闻数据库替换随机生成的占位数据
# 在实际应用中,这里应该查询新闻数据库
return [ return [
{ {
"sentiment": np.random.uniform(-1, 1), "sentiment": np.random.uniform(-1, 1),
@ -597,8 +597,7 @@ class DataBroker:
Returns: Returns:
行业代码或名称找不到时返回None 行业代码或名称找不到时返回None
""" """
# 简化实现:返回模拟行业 # TODO: 替换为真实行业映射逻辑(当前仅为占位数据)
# 在实际应用中,这里应该查询股票行业信息
industry_mapping = { industry_mapping = {
"000001.SZ": "银行", "000001.SZ": "银行",
"000002.SZ": "房地产", "000002.SZ": "房地产",
@ -617,8 +616,7 @@ class DataBroker:
Returns: Returns:
行业情绪得分找不到时返回None 行业情绪得分找不到时返回None
""" """
# 简化实现:返回模拟情绪得分 # TODO: 接入行业情绪数据源,当前随机值仅用于占位显示
# 在实际应用中,这里应该基于行业新闻计算情绪
return np.random.uniform(-1, 1) return np.random.uniform(-1, 1)
def get_industry_stocks(self, industry: str) -> List[str]: def get_industry_stocks(self, industry: str) -> List[str]:
@ -630,8 +628,7 @@ class DataBroker:
Returns: Returns:
同行业股票代码列表 同行业股票代码列表
""" """
# 简化实现:返回模拟股票列表 # TODO: 使用实际行业成分数据替换占位列表
# 在实际应用中,这里应该查询行业股票列表
industry_stocks = { industry_stocks = {
"银行": ["000001.SZ", "002142.SZ", "600036.SH"], "银行": ["000001.SZ", "002142.SZ", "600036.SH"],
"房地产": ["000002.SZ", "000402.SZ", "600048.SH"], "房地产": ["000002.SZ", "000402.SZ", "600048.SH"],
@ -653,10 +650,33 @@ class DataBroker:
if not _is_safe_identifier(table): if not _is_safe_identifier(table):
return False return False
query = ( parsed_date = _parse_trade_date(trade_date)
f"SELECT 1 FROM {table} WHERE ts_code = ? AND {where_clause} LIMIT 1" trade_key = parsed_date.strftime("%Y%m%d") if parsed_date else str(trade_date)
)
bind_params = (ts_code, *params) if auto_refresh and parsed_date and self.check_data_availability(trade_key, {table}):
self._trigger_background_refresh(trade_key)
if hasattr(time, "sleep"):
time.sleep(0.5)
if table == "suspend":
query = (
"SELECT 1 FROM suspend "
"WHERE ts_code = ? "
"AND suspend_date <= ? "
"AND (resume_date IS NULL OR resume_date = '' OR resume_date > ?) "
"LIMIT 1"
)
bind_params = (ts_code, trade_key, trade_key)
else:
clauses = ["ts_code = ?"]
bind_params_list: List[object] = [ts_code]
clause_text = (where_clause or "").strip()
if clause_text:
clauses.append(clause_text)
bind_params_list.extend(params)
query = f"SELECT 1 FROM {table} WHERE {' AND '.join(clauses)} LIMIT 1"
bind_params = tuple(bind_params_list)
try: try:
with db_session(read_only=True) as conn: with db_session(read_only=True) as conn:
try: try: