refactor backtest engine with trading rules and progress tracking
This commit is contained in:
parent
8aa8efb651
commit
90fb2a9df6
@ -204,22 +204,29 @@ class DepartmentAgent:
|
|||||||
|
|
||||||
rounds_executed = round_idx + 1
|
rounds_executed = round_idx + 1
|
||||||
|
|
||||||
usage = response.get("usage") if isinstance(response, Mapping) else None
|
message, usage_payload, tool_calls = _normalize_llm_response(response)
|
||||||
if isinstance(usage, Mapping):
|
if usage_payload:
|
||||||
usage_payload = {"round": round_idx + 1}
|
payload_with_round = {"round": round_idx + 1}
|
||||||
usage_payload.update(dict(usage))
|
payload_with_round.update(usage_payload)
|
||||||
usage_records.append(usage_payload)
|
usage_records.append(payload_with_round)
|
||||||
|
|
||||||
choice = (response.get("choices") or [{}])[0]
|
if not message:
|
||||||
message = choice.get("message", {})
|
LOGGER.debug(
|
||||||
|
"部门 %s 第 %s 轮响应缺少 message 字段:%s",
|
||||||
|
self.settings.code,
|
||||||
|
round_idx + 1,
|
||||||
|
response,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
|
message = {"role": "assistant", "content": ""}
|
||||||
transcript.append(_message_to_text(message))
|
transcript.append(_message_to_text(message))
|
||||||
|
|
||||||
assistant_record: Dict[str, Any] = {
|
assistant_record: Dict[str, Any] = {
|
||||||
"role": "assistant",
|
"role": message.get("role", "assistant"),
|
||||||
"content": _extract_message_content(message),
|
"content": _extract_message_content(message),
|
||||||
}
|
}
|
||||||
if message.get("tool_calls"):
|
if tool_calls:
|
||||||
assistant_record["tool_calls"] = message.get("tool_calls")
|
assistant_record["tool_calls"] = tool_calls
|
||||||
messages.append(assistant_record)
|
messages.append(assistant_record)
|
||||||
CONV_LOGGER.info(
|
CONV_LOGGER.info(
|
||||||
"dept=%s round=%s assistant=%s",
|
"dept=%s round=%s assistant=%s",
|
||||||
@ -228,7 +235,6 @@ class DepartmentAgent:
|
|||||||
assistant_record,
|
assistant_record,
|
||||||
)
|
)
|
||||||
|
|
||||||
tool_calls = message.get("tool_calls") or []
|
|
||||||
if tool_calls:
|
if tool_calls:
|
||||||
for call in tool_calls:
|
for call in tool_calls:
|
||||||
function_block = call.get("function") or {}
|
function_block = call.get("function") or {}
|
||||||
@ -656,6 +662,8 @@ class DepartmentAgent:
|
|||||||
dialogue=[response],
|
dialogue=[response],
|
||||||
)
|
)
|
||||||
return decision
|
return decision
|
||||||
|
|
||||||
|
|
||||||
def _ensure_mutable_context(context: DepartmentContext) -> DepartmentContext:
|
def _ensure_mutable_context(context: DepartmentContext) -> DepartmentContext:
|
||||||
if not isinstance(context.features, dict):
|
if not isinstance(context.features, dict):
|
||||||
context.features = dict(context.features or {})
|
context.features = dict(context.features or {})
|
||||||
@ -669,6 +677,77 @@ def _ensure_mutable_context(context: DepartmentContext) -> DepartmentContext:
|
|||||||
return context
|
return context
|
||||||
|
|
||||||
|
|
||||||
|
def _compose_usage_from_stats(payload: Mapping[str, Any]) -> Dict[str, Any]:
|
||||||
|
usage: Dict[str, Any] = {}
|
||||||
|
prompt_eval = payload.get("prompt_eval_count")
|
||||||
|
completion_eval = payload.get("eval_count")
|
||||||
|
if isinstance(prompt_eval, (int, float)):
|
||||||
|
usage["prompt_tokens"] = int(prompt_eval)
|
||||||
|
if isinstance(completion_eval, (int, float)):
|
||||||
|
usage["completion_tokens"] = int(completion_eval)
|
||||||
|
if usage:
|
||||||
|
total = usage.get("prompt_tokens", 0) + usage.get("completion_tokens", 0)
|
||||||
|
usage["total_tokens"] = total
|
||||||
|
return usage
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_llm_response(
|
||||||
|
response: Mapping[str, Any]
|
||||||
|
) -> Tuple[Dict[str, Any], Dict[str, Any], List[Dict[str, Any]]]:
|
||||||
|
message: Dict[str, Any] = {}
|
||||||
|
usage: Dict[str, Any] = {}
|
||||||
|
tool_calls: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
if not isinstance(response, Mapping):
|
||||||
|
return message, usage, tool_calls
|
||||||
|
|
||||||
|
choices = response.get("choices")
|
||||||
|
if isinstance(choices, list) and choices:
|
||||||
|
choice = choices[0] or {}
|
||||||
|
candidate = choice.get("message")
|
||||||
|
if isinstance(candidate, Mapping):
|
||||||
|
message = candidate
|
||||||
|
raw_calls = candidate.get("tool_calls")
|
||||||
|
if isinstance(raw_calls, list):
|
||||||
|
tool_calls = list(raw_calls)
|
||||||
|
raw_usage = response.get("usage")
|
||||||
|
if isinstance(raw_usage, Mapping):
|
||||||
|
usage = dict(raw_usage)
|
||||||
|
else:
|
||||||
|
raw_message = response.get("message")
|
||||||
|
if isinstance(raw_message, Mapping):
|
||||||
|
message = raw_message
|
||||||
|
raw_calls = raw_message.get("tool_calls")
|
||||||
|
if isinstance(raw_calls, list):
|
||||||
|
tool_calls = list(raw_calls)
|
||||||
|
elif isinstance(response.get("messages"), list):
|
||||||
|
messages_list = response.get("messages") or []
|
||||||
|
if messages_list:
|
||||||
|
candidate = messages_list[-1]
|
||||||
|
if isinstance(candidate, Mapping):
|
||||||
|
message = candidate
|
||||||
|
raw_calls = candidate.get("tool_calls")
|
||||||
|
if isinstance(raw_calls, list):
|
||||||
|
tool_calls = list(raw_calls)
|
||||||
|
if not message:
|
||||||
|
content = response.get("content")
|
||||||
|
if isinstance(content, str):
|
||||||
|
message = {"role": "assistant", "content": content}
|
||||||
|
raw_usage = response.get("usage")
|
||||||
|
if isinstance(raw_usage, Mapping):
|
||||||
|
usage = dict(raw_usage)
|
||||||
|
else:
|
||||||
|
usage = _compose_usage_from_stats(response)
|
||||||
|
|
||||||
|
if not tool_calls:
|
||||||
|
extra = message.get("additional_kwargs")
|
||||||
|
if isinstance(extra, Mapping):
|
||||||
|
extra_calls = extra.get("tool_calls")
|
||||||
|
if isinstance(extra_calls, list):
|
||||||
|
tool_calls = list(extra_calls)
|
||||||
|
return message or {"role": "assistant", "content": ""}, usage, tool_calls
|
||||||
|
|
||||||
|
|
||||||
def _parse_tool_arguments(payload: Any) -> Dict[str, Any]:
|
def _parse_tool_arguments(payload: Any) -> Dict[str, Any]:
|
||||||
if isinstance(payload, dict):
|
if isinstance(payload, dict):
|
||||||
return dict(payload)
|
return dict(payload)
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
"""Value and quality filtering agent."""
|
"""Value and quality filtering agent."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
from .base import Agent, AgentAction, AgentContext
|
from .base import Agent, AgentAction, AgentContext
|
||||||
|
|
||||||
|
|
||||||
@ -9,12 +11,19 @@ class ValueAgent(Agent):
|
|||||||
super().__init__(name="A_val")
|
super().__init__(name="A_val")
|
||||||
|
|
||||||
def score(self, context: AgentContext, action: AgentAction) -> float:
|
def score(self, context: AgentContext, action: AgentAction) -> float:
|
||||||
pe = context.features.get("pe_percentile", 0.5)
|
pe_score = context.features.get("valuation_pe_score", 0.0)
|
||||||
pb = context.features.get("pb_percentile", 0.5)
|
pb_score = context.features.get("valuation_pb_score", 0.0)
|
||||||
roe = context.features.get("roe_percentile", 0.5)
|
# 多因子组合尚未落地,这里兼容扩展因子(若存在则优先使用)
|
||||||
# Lower valuation percentiles and higher quality percentiles add value.
|
scope_values = {}
|
||||||
raw = max(0.0, (1 - pe) * 0.4 + (1 - pb) * 0.3 + roe * 0.3)
|
if isinstance(context.raw, Mapping):
|
||||||
raw = min(raw, 1.0)
|
scope_values = context.raw.get("scope_values", {}) or {}
|
||||||
|
multi_score = context.features.get("val_multiscore")
|
||||||
|
if multi_score is None:
|
||||||
|
multi_score = scope_values.get("factors.val_multiscore")
|
||||||
|
if multi_score is not None:
|
||||||
|
raw = float(max(0.0, min(1.0, multi_score)))
|
||||||
|
else:
|
||||||
|
raw = max(0.0, min(1.0, 0.6 * pe_score + 0.4 * pb_score))
|
||||||
if action is AgentAction.SELL:
|
if action is AgentAction.SELL:
|
||||||
return 1 - raw
|
return 1 - raw
|
||||||
if action is AgentAction.HOLD:
|
if action is AgentAction.HOLD:
|
||||||
|
|||||||
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import date
|
from datetime import date, datetime
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple
|
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple
|
||||||
|
|
||||||
from app.agents.base import AgentAction, AgentContext
|
from app.agents.base import AgentAction, AgentContext
|
||||||
@ -16,7 +16,7 @@ from app.llm.metrics import record_decision as metrics_record_decision
|
|||||||
from app.agents.registry import default_agents
|
from app.agents.registry import default_agents
|
||||||
from app.data.schema import initialize_database
|
from app.data.schema import initialize_database
|
||||||
from app.utils.data_access import DataBroker
|
from app.utils.data_access import DataBroker
|
||||||
from app.utils.config import get_config
|
from app.utils.config import PortfolioSettings, get_config
|
||||||
from app.utils.db import db_session
|
from app.utils.db import db_session
|
||||||
from app.utils.logging import get_logger
|
from app.utils.logging import get_logger
|
||||||
from app.utils import alerts
|
from app.utils import alerts
|
||||||
@ -105,12 +105,26 @@ class BacktestEngine:
|
|||||||
)
|
)
|
||||||
self.data_broker = DataBroker()
|
self.data_broker = DataBroker()
|
||||||
params = cfg.params or {}
|
params = cfg.params or {}
|
||||||
|
portfolio_cfg = getattr(app_cfg, "portfolio", None) or PortfolioSettings()
|
||||||
self.risk_params = {
|
self.risk_params = {
|
||||||
"max_position_weight": float(params.get("max_position_weight", 0.2)),
|
"max_position_weight": float(params.get("max_position_weight", 0.2)),
|
||||||
"max_daily_turnover_ratio": float(params.get("max_daily_turnover_ratio", 0.25)),
|
"max_daily_turnover_ratio": float(params.get("max_daily_turnover_ratio", 0.25)),
|
||||||
"fee_rate": float(params.get("fee_rate", 0.0005)),
|
"fee_rate": float(params.get("fee_rate", 0.0005)),
|
||||||
"slippage_bps": float(params.get("slippage_bps", 10.0)),
|
"slippage_bps": float(params.get("slippage_bps", 10.0)),
|
||||||
}
|
}
|
||||||
|
self.initial_cash = max(0.0, float(params.get("initial_capital", portfolio_cfg.initial_capital)))
|
||||||
|
target_return = params.get("target", params.get("target_return", 0.0)) or 0.0
|
||||||
|
stop_loss = params.get("stop", params.get("stop_loss", 0.0)) or 0.0
|
||||||
|
hold_days_param = params.get("hold_days", params.get("max_hold_days", 0))
|
||||||
|
try:
|
||||||
|
max_hold_days = int(hold_days_param) if hold_days_param is not None else 0
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
max_hold_days = 0
|
||||||
|
self.trading_rules = {
|
||||||
|
"target_return": float(target_return),
|
||||||
|
"stop_loss": float(stop_loss),
|
||||||
|
"max_hold_days": max(0, max_hold_days),
|
||||||
|
}
|
||||||
self._fee_rate = max(self.risk_params["fee_rate"], 0.0)
|
self._fee_rate = max(self.risk_params["fee_rate"], 0.0)
|
||||||
self._slippage_rate = max(self.risk_params["slippage_bps"], 0.0) / 10_000.0
|
self._slippage_rate = max(self.risk_params["slippage_bps"], 0.0) / 10_000.0
|
||||||
self._turnover_cap = max(self.risk_params["max_daily_turnover_ratio"], 0.0)
|
self._turnover_cap = max(self.risk_params["max_daily_turnover_ratio"], 0.0)
|
||||||
@ -314,9 +328,9 @@ class BacktestEngine:
|
|||||||
is_suspended = self.data_broker.fetch_flags(
|
is_suspended = self.data_broker.fetch_flags(
|
||||||
"suspend",
|
"suspend",
|
||||||
ts_code,
|
ts_code,
|
||||||
trade_date,
|
trade_date_str,
|
||||||
"ts_code = ?",
|
"",
|
||||||
[ts_code],
|
[],
|
||||||
auto_refresh=False, # 避免在回测中触发自动补数
|
auto_refresh=False, # 避免在回测中触发自动补数
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -650,6 +664,7 @@ class BacktestEngine:
|
|||||||
continue
|
continue
|
||||||
features = feature_cache.get(ts_code, {})
|
features = feature_cache.get(ts_code, {})
|
||||||
current_qty = state.holdings.get(ts_code, 0.0)
|
current_qty = state.holdings.get(ts_code, 0.0)
|
||||||
|
current_cost_basis = float(state.cost_basis.get(ts_code, 0.0) or 0.0)
|
||||||
liquidity_score = float(features.get("liquidity_score") or 0.0)
|
liquidity_score = float(features.get("liquidity_score") or 0.0)
|
||||||
risk_penalty = float(features.get("risk_penalty") or 0.0)
|
risk_penalty = float(features.get("risk_penalty") or 0.0)
|
||||||
is_suspended = bool(features.get("is_suspended"))
|
is_suspended = bool(features.get("is_suspended"))
|
||||||
@ -679,6 +694,76 @@ class BacktestEngine:
|
|||||||
if risk.status == "blocked":
|
if risk.status == "blocked":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
rule_override_action: Optional[AgentAction] = None
|
||||||
|
rule_override_reason: Optional[str] = None
|
||||||
|
gain_ratio: Optional[float] = None
|
||||||
|
if current_qty > 0 and current_cost_basis:
|
||||||
|
try:
|
||||||
|
gain_ratio = (price / current_cost_basis) - 1.0
|
||||||
|
except ZeroDivisionError:
|
||||||
|
gain_ratio = None
|
||||||
|
target_return = self.trading_rules.get("target_return", 0.0)
|
||||||
|
if (
|
||||||
|
rule_override_action is None
|
||||||
|
and gain_ratio is not None
|
||||||
|
and target_return
|
||||||
|
and gain_ratio >= target_return
|
||||||
|
):
|
||||||
|
rule_override_action = AgentAction.SELL
|
||||||
|
rule_override_reason = "target_reached"
|
||||||
|
stop_loss = self.trading_rules.get("stop_loss", 0.0)
|
||||||
|
if (
|
||||||
|
rule_override_action is None
|
||||||
|
and gain_ratio is not None
|
||||||
|
and stop_loss
|
||||||
|
and gain_ratio <= stop_loss
|
||||||
|
):
|
||||||
|
rule_override_action = AgentAction.SELL
|
||||||
|
rule_override_reason = "stop_loss"
|
||||||
|
max_hold_days = self.trading_rules.get("max_hold_days", 0)
|
||||||
|
if (
|
||||||
|
rule_override_action is None
|
||||||
|
and max_hold_days
|
||||||
|
and max_hold_days > 0
|
||||||
|
and current_qty > 0
|
||||||
|
):
|
||||||
|
opened_str = state.opened_dates.get(ts_code)
|
||||||
|
opened_dt: Optional[date] = None
|
||||||
|
if opened_str:
|
||||||
|
try:
|
||||||
|
opened_dt = date.fromisoformat(str(opened_str))
|
||||||
|
except ValueError:
|
||||||
|
try:
|
||||||
|
opened_dt = datetime.strptime(str(opened_str), "%Y%m%d").date()
|
||||||
|
except ValueError:
|
||||||
|
opened_dt = None
|
||||||
|
LOGGER.debug(
|
||||||
|
"无法解析持仓日期 ts_code=%s value=%s",
|
||||||
|
ts_code,
|
||||||
|
opened_str,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
|
if opened_dt:
|
||||||
|
holding_days = (trade_date - opened_dt).days
|
||||||
|
if holding_days >= max_hold_days:
|
||||||
|
rule_override_action = AgentAction.SELL
|
||||||
|
rule_override_reason = "holding_period"
|
||||||
|
|
||||||
|
if rule_override_action and rule_override_action is not effective_action:
|
||||||
|
effective_action = rule_override_action
|
||||||
|
effective_weight = target_weight_for_action(effective_action)
|
||||||
|
_record_risk(
|
||||||
|
ts_code,
|
||||||
|
rule_override_reason or "rule_override",
|
||||||
|
decision,
|
||||||
|
extra={
|
||||||
|
"rule_trigger": rule_override_reason,
|
||||||
|
"gain_ratio": gain_ratio,
|
||||||
|
},
|
||||||
|
action_override=effective_action,
|
||||||
|
target_weight_override=effective_weight,
|
||||||
|
)
|
||||||
|
|
||||||
if is_suspended:
|
if is_suspended:
|
||||||
_record_risk(ts_code, "suspended", decision)
|
_record_risk(ts_code, "suspended", decision)
|
||||||
continue
|
continue
|
||||||
@ -738,8 +823,7 @@ class BacktestEngine:
|
|||||||
if total_cash_needed <= 0:
|
if total_cash_needed <= 0:
|
||||||
_record_risk(ts_code, "invalid_trade", decision)
|
_record_risk(ts_code, "invalid_trade", decision)
|
||||||
continue
|
continue
|
||||||
|
previous_cost = current_cost_basis * current_qty
|
||||||
previous_cost = state.cost_basis.get(ts_code, 0.0) * current_qty
|
|
||||||
new_qty = current_qty + delta
|
new_qty = current_qty + delta
|
||||||
state.cost_basis[ts_code] = (
|
state.cost_basis[ts_code] = (
|
||||||
(previous_cost + trade_value + fee) / new_qty if new_qty > 0 else 0.0
|
(previous_cost + trade_value + fee) / new_qty if new_qty > 0 else 0.0
|
||||||
@ -777,8 +861,7 @@ class BacktestEngine:
|
|||||||
gross_value = sell_qty * trade_price
|
gross_value = sell_qty * trade_price
|
||||||
fee = gross_value * self._fee_rate
|
fee = gross_value * self._fee_rate
|
||||||
proceeds = gross_value - fee
|
proceeds = gross_value - fee
|
||||||
cost_basis = state.cost_basis.get(ts_code, 0.0)
|
realized = (trade_price - current_cost_basis) * sell_qty - fee
|
||||||
realized = (trade_price - cost_basis) * sell_qty - fee
|
|
||||||
state.cash += proceeds
|
state.cash += proceeds
|
||||||
state.realized_pnl += realized
|
state.realized_pnl += realized
|
||||||
new_qty = current_qty - sell_qty
|
new_qty = current_qty - sell_qty
|
||||||
@ -1023,7 +1106,7 @@ class BacktestEngine:
|
|||||||
"""Initialise a new incremental backtest session."""
|
"""Initialise a new incremental backtest session."""
|
||||||
|
|
||||||
return BacktestSession(
|
return BacktestSession(
|
||||||
state=PortfolioState(),
|
state=PortfolioState(cash=self.initial_cash),
|
||||||
result=BacktestResult(),
|
result=BacktestResult(),
|
||||||
current_date=self.cfg.start_date,
|
current_date=self.cfg.start_date,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -19,7 +19,11 @@ from app.features.value_risk_factors import ValueRiskFactors
|
|||||||
# 导入因子验证功能
|
# 导入因子验证功能
|
||||||
from app.features.validation import check_data_sufficiency, check_data_sufficiency_for_zero_window, detect_outliers
|
from app.features.validation import check_data_sufficiency, check_data_sufficiency_for_zero_window, detect_outliers
|
||||||
# 导入UI进度状态管理
|
# 导入UI进度状态管理
|
||||||
from app.ui.progress_state import factor_progress
|
try:
|
||||||
|
from app.features.progress import get_progress_handler
|
||||||
|
except ImportError: # pragma: no cover - optional dependency
|
||||||
|
def get_progress_handler():
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
LOGGER = get_logger(__name__)
|
LOGGER = get_logger(__name__)
|
||||||
@ -176,14 +180,16 @@ def compute_factors(
|
|||||||
broker = DataBroker()
|
broker = DataBroker()
|
||||||
results: List[FactorResult] = []
|
results: List[FactorResult] = []
|
||||||
rows_to_persist: List[tuple[str, Dict[str, float | None]]] = []
|
rows_to_persist: List[tuple[str, Dict[str, float | None]]] = []
|
||||||
|
total_batches = (len(universe) + batch_size - 1) // batch_size if universe else 0
|
||||||
|
progress = get_progress_handler()
|
||||||
|
if progress and universe:
|
||||||
|
try:
|
||||||
|
progress.start_calculation(len(universe), total_batches)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
LOGGER.debug("Progress handler start_calculation 失败", extra=LOG_EXTRA)
|
||||||
|
progress = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 启动UI进度状态(在异步线程中不直接访问factor_progress)
|
|
||||||
# factor_progress.start_calculation(
|
|
||||||
# total_securities=len(universe),
|
|
||||||
# total_batches=(len(universe) + batch_size - 1) // batch_size
|
|
||||||
# )
|
|
||||||
|
|
||||||
# 分批处理以优化性能
|
# 分批处理以优化性能
|
||||||
for i in range(0, len(universe), batch_size):
|
for i in range(0, len(universe), batch_size):
|
||||||
batch = universe[i:i+batch_size]
|
batch = universe[i:i+batch_size]
|
||||||
@ -194,9 +200,10 @@ def compute_factors(
|
|||||||
specs,
|
specs,
|
||||||
validation_stats,
|
validation_stats,
|
||||||
batch_index=i // batch_size,
|
batch_index=i // batch_size,
|
||||||
total_batches=(len(universe) + batch_size - 1) // batch_size,
|
total_batches=total_batches or 1,
|
||||||
processed_securities=i,
|
processed_securities=i,
|
||||||
total_securities=len(universe)
|
total_securities=len(universe),
|
||||||
|
progress=progress,
|
||||||
)
|
)
|
||||||
|
|
||||||
for ts_code, values in batch_results:
|
for ts_code, values in batch_results:
|
||||||
@ -222,9 +229,13 @@ def compute_factors(
|
|||||||
_persist_factor_rows(trade_date_str, rows_to_persist, specs)
|
_persist_factor_rows(trade_date_str, rows_to_persist, specs)
|
||||||
|
|
||||||
# 更新UI进度状态为完成
|
# 更新UI进度状态为完成
|
||||||
factor_progress.complete_calculation(
|
if progress:
|
||||||
|
try:
|
||||||
|
progress.complete_calculation(
|
||||||
message=f"因子计算完成: 总数量={len(universe)}, 成功={validation_stats['success']}, 失败={len(universe) - validation_stats['success']}"
|
message=f"因子计算完成: 总数量={len(universe)}, 成功={validation_stats['success']}, 失败={len(universe) - validation_stats['success']}"
|
||||||
)
|
)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
LOGGER.debug("Progress handler complete_calculation 失败", extra=LOG_EXTRA)
|
||||||
|
|
||||||
LOGGER.info(
|
LOGGER.info(
|
||||||
"因子计算完成 总数量:%s 成功:%s 失败:%s",
|
"因子计算完成 总数量:%s 成功:%s 失败:%s",
|
||||||
@ -239,7 +250,11 @@ def compute_factors(
|
|||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
# 发生错误时更新UI状态
|
# 发生错误时更新UI状态
|
||||||
error_message = f"因子计算过程中发生错误: {exc}"
|
error_message = f"因子计算过程中发生错误: {exc}"
|
||||||
factor_progress.error_occurred(error_message)
|
if progress:
|
||||||
|
try:
|
||||||
|
progress.error_occurred(error_message)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
LOGGER.debug("Progress handler error_occurred 失败", extra=LOG_EXTRA)
|
||||||
LOGGER.error(error_message, extra=LOG_EXTRA)
|
LOGGER.error(error_message, extra=LOG_EXTRA)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@ -380,6 +395,7 @@ def _compute_batch_factors(
|
|||||||
total_batches: int = 1,
|
total_batches: int = 1,
|
||||||
processed_securities: int = 0,
|
processed_securities: int = 0,
|
||||||
total_securities: int = 0,
|
total_securities: int = 0,
|
||||||
|
progress: Optional[object] = None,
|
||||||
) -> List[tuple[str, Dict[str, float | None]]]:
|
) -> List[tuple[str, Dict[str, float | None]]]:
|
||||||
"""批量计算多个证券的因子值,提高计算效率"""
|
"""批量计算多个证券的因子值,提高计算效率"""
|
||||||
batch_results = []
|
batch_results = []
|
||||||
@ -388,13 +404,16 @@ def _compute_batch_factors(
|
|||||||
available_codes = _check_batch_data_availability(broker, ts_codes, trade_date, specs)
|
available_codes = _check_batch_data_availability(broker, ts_codes, trade_date, specs)
|
||||||
|
|
||||||
# 更新UI进度状态 - 开始处理批次
|
# 更新UI进度状态 - 开始处理批次
|
||||||
if total_securities > 0:
|
if progress and total_securities > 0:
|
||||||
from app.ui.progress_state import factor_progress
|
try:
|
||||||
factor_progress.update_progress(
|
progress.update_progress(
|
||||||
current_securities=processed_securities,
|
current_securities=processed_securities,
|
||||||
current_batch=batch_index + 1,
|
current_batch=batch_index + 1,
|
||||||
message=f"开始处理批次 {batch_index + 1}/{total_batches}"
|
message=f"开始处理批次 {batch_index + 1}/{total_batches}",
|
||||||
)
|
)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
LOGGER.debug("Progress handler update_progress 失败", extra=LOG_EXTRA)
|
||||||
|
progress = None
|
||||||
|
|
||||||
for i, ts_code in enumerate(ts_codes):
|
for i, ts_code in enumerate(ts_codes):
|
||||||
try:
|
try:
|
||||||
@ -434,15 +453,18 @@ def _compute_batch_factors(
|
|||||||
validation_stats["skipped"] += 1
|
validation_stats["skipped"] += 1
|
||||||
|
|
||||||
# 每处理1个证券更新一次进度,确保实时性
|
# 每处理1个证券更新一次进度,确保实时性
|
||||||
if total_securities > 0:
|
if progress and total_securities > 0:
|
||||||
current_progress = processed_securities + i + 1
|
current_progress = processed_securities + i + 1
|
||||||
progress_percentage = (current_progress / total_securities) * 100
|
progress_percentage = (current_progress / total_securities) * 100
|
||||||
from app.ui.progress_state import factor_progress
|
try:
|
||||||
factor_progress.update_progress(
|
progress.update_progress(
|
||||||
current_securities=current_progress,
|
current_securities=current_progress,
|
||||||
current_batch=batch_index + 1,
|
current_batch=batch_index + 1,
|
||||||
message=f"处理批次 {batch_index + 1}/{total_batches} - 证券 {current_progress}/{total_securities} ({progress_percentage:.1f}%)"
|
message=f"处理批次 {batch_index + 1}/{total_batches} - 证券 {current_progress}/{total_securities} ({progress_percentage:.1f}%)",
|
||||||
)
|
)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
LOGGER.debug("Progress handler update_progress 失败", extra=LOG_EXTRA)
|
||||||
|
progress = None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOGGER.error(
|
LOGGER.error(
|
||||||
"计算因子失败 ts_code=%s err=%s",
|
"计算因子失败 ts_code=%s err=%s",
|
||||||
@ -453,15 +475,17 @@ def _compute_batch_factors(
|
|||||||
validation_stats["skipped"] += 1
|
validation_stats["skipped"] += 1
|
||||||
|
|
||||||
# 批次处理完成,更新最终进度
|
# 批次处理完成,更新最终进度
|
||||||
if total_securities > 0:
|
if progress and total_securities > 0:
|
||||||
final_progress = processed_securities + len(ts_codes)
|
final_progress = processed_securities + len(ts_codes)
|
||||||
progress_percentage = (final_progress / total_securities) * 100
|
progress_percentage = (final_progress / total_securities) * 100
|
||||||
from app.ui.progress_state import factor_progress
|
try:
|
||||||
factor_progress.update_progress(
|
progress.update_progress(
|
||||||
current_securities=final_progress,
|
current_securities=final_progress,
|
||||||
current_batch=batch_index + 1,
|
current_batch=batch_index + 1,
|
||||||
message=f"批次 {batch_index + 1}/{total_batches} 处理完成 - 证券 {final_progress}/{total_securities} ({progress_percentage:.1f}%)"
|
message=f"批次 {batch_index + 1}/{total_batches} 处理完成 - 证券 {final_progress}/{total_securities} ({progress_percentage:.1f}%)",
|
||||||
)
|
)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
LOGGER.debug("Progress handler update_progress 失败", extra=LOG_EXTRA)
|
||||||
|
|
||||||
return batch_results
|
return batch_results
|
||||||
|
|
||||||
|
|||||||
38
app/features/progress.py
Normal file
38
app/features/progress.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
"""Optional progress reporting for factor computation."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional, Protocol
|
||||||
|
|
||||||
|
|
||||||
|
class FactorProgressProtocol(Protocol):
|
||||||
|
"""Protocol describing the optional UI progress handler."""
|
||||||
|
|
||||||
|
def start_calculation(self, total_securities: int, total_batches: int) -> None: ...
|
||||||
|
|
||||||
|
def update_progress(
|
||||||
|
self,
|
||||||
|
current_securities: int,
|
||||||
|
current_batch: int,
|
||||||
|
message: str = "",
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
def complete_calculation(self, message: str = "") -> None: ...
|
||||||
|
|
||||||
|
def error_occurred(self, error_message: str) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
_progress_handler: Optional[FactorProgressProtocol] = None
|
||||||
|
|
||||||
|
|
||||||
|
def register_progress_handler(progress: FactorProgressProtocol | None) -> None:
|
||||||
|
"""Register a progress handler (typically provided by the UI layer)."""
|
||||||
|
|
||||||
|
global _progress_handler
|
||||||
|
_progress_handler = progress
|
||||||
|
|
||||||
|
|
||||||
|
def get_progress_handler() -> Optional[FactorProgressProtocol]:
|
||||||
|
"""Return the currently registered progress handler if any."""
|
||||||
|
|
||||||
|
return _progress_handler
|
||||||
|
|
||||||
@ -5,6 +5,12 @@ from typing import Optional, Dict, Any
|
|||||||
import time
|
import time
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
|
||||||
|
try:
|
||||||
|
from app.features.progress import register_progress_handler
|
||||||
|
except ImportError: # pragma: no cover - optional dependency
|
||||||
|
def register_progress_handler(handler: object) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
class FactorProgressState:
|
class FactorProgressState:
|
||||||
"""因子计算进度状态管理类"""
|
"""因子计算进度状态管理类"""
|
||||||
@ -151,6 +157,7 @@ class FactorProgressState:
|
|||||||
|
|
||||||
# 全局进度状态实例
|
# 全局进度状态实例
|
||||||
factor_progress = FactorProgressState()
|
factor_progress = FactorProgressState()
|
||||||
|
register_progress_handler(factor_progress)
|
||||||
|
|
||||||
|
|
||||||
def render_factor_progress() -> None:
|
def render_factor_progress() -> None:
|
||||||
|
|||||||
@ -44,13 +44,14 @@ def render_backtest_review() -> None:
|
|||||||
app_cfg = get_config()
|
app_cfg = get_config()
|
||||||
default_start, default_end = default_backtest_range(window_days=60)
|
default_start, default_end = default_backtest_range(window_days=60)
|
||||||
LOGGER.debug(
|
LOGGER.debug(
|
||||||
"回测默认参数:start=%s end=%s universe=%s target=%s stop=%s hold_days=%s",
|
"回测默认参数:start=%s end=%s universe=%s target=%s stop=%s hold_days=%s initial_capital=%s",
|
||||||
default_start,
|
default_start,
|
||||||
default_end,
|
default_end,
|
||||||
"000001.SZ",
|
"000001.SZ",
|
||||||
0.035,
|
0.035,
|
||||||
-0.015,
|
-0.015,
|
||||||
10,
|
10,
|
||||||
|
get_config().portfolio.initial_capital,
|
||||||
extra=LOG_EXTRA,
|
extra=LOG_EXTRA,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -59,10 +60,24 @@ def render_backtest_review() -> None:
|
|||||||
start_date = col1.date_input("开始日期", value=default_start, key="bt_start_date")
|
start_date = col1.date_input("开始日期", value=default_start, key="bt_start_date")
|
||||||
end_date = col2.date_input("结束日期", value=default_end, key="bt_end_date")
|
end_date = col2.date_input("结束日期", value=default_end, key="bt_end_date")
|
||||||
universe_text = st.text_input("股票列表(逗号分隔)", value="000001.SZ", key="bt_universe")
|
universe_text = st.text_input("股票列表(逗号分隔)", value="000001.SZ", key="bt_universe")
|
||||||
col_target, col_stop, col_hold = st.columns(3)
|
col_target, col_stop, col_hold, col_cap = st.columns(4)
|
||||||
target = col_target.number_input("目标收益(例:0.035 表示 3.5%)", value=0.035, step=0.005, format="%.3f", key="bt_target")
|
target = col_target.number_input("目标收益(例:0.035 表示 3.5%)", value=0.035, step=0.005, format="%.3f", key="bt_target")
|
||||||
stop = col_stop.number_input("止损收益(例:-0.015 表示 -1.5%)", value=-0.015, step=0.005, format="%.3f", key="bt_stop")
|
stop = col_stop.number_input("止损收益(例:-0.015 表示 -1.5%)", value=-0.015, step=0.005, format="%.3f", key="bt_stop")
|
||||||
hold_days = col_hold.number_input("持有期(交易日)", value=10, step=1, key="bt_hold_days")
|
hold_days = col_hold.number_input("持有期(交易日)", value=10, step=1, key="bt_hold_days")
|
||||||
|
initial_capital = col_cap.number_input(
|
||||||
|
"组合初始资金",
|
||||||
|
value=float(app_cfg.portfolio.initial_capital),
|
||||||
|
step=100000.0,
|
||||||
|
format="%.0f",
|
||||||
|
key="bt_initial_capital",
|
||||||
|
)
|
||||||
|
initial_capital = max(0.0, float(initial_capital))
|
||||||
|
backtest_params = {
|
||||||
|
"target": float(target),
|
||||||
|
"stop": float(stop),
|
||||||
|
"hold_days": int(hold_days),
|
||||||
|
"initial_capital": initial_capital,
|
||||||
|
}
|
||||||
structure_options = [item.value for item in GameStructure]
|
structure_options = [item.value for item in GameStructure]
|
||||||
selected_structure_values = st.multiselect(
|
selected_structure_values = st.multiselect(
|
||||||
"选择博弈框架",
|
"选择博弈框架",
|
||||||
@ -74,13 +89,14 @@ def render_backtest_review() -> None:
|
|||||||
selected_structure_values = [GameStructure.REPEATED.value]
|
selected_structure_values = [GameStructure.REPEATED.value]
|
||||||
selected_structures = [GameStructure(value) for value in selected_structure_values]
|
selected_structures = [GameStructure(value) for value in selected_structure_values]
|
||||||
LOGGER.debug(
|
LOGGER.debug(
|
||||||
"当前回测表单输入:start=%s end=%s universe_text=%s target=%.3f stop=%.3f hold_days=%s",
|
"当前回测表单输入:start=%s end=%s universe_text=%s target=%.3f stop=%.3f hold_days=%s initial_capital=%.2f",
|
||||||
start_date,
|
start_date,
|
||||||
end_date,
|
end_date,
|
||||||
universe_text,
|
universe_text,
|
||||||
target,
|
target,
|
||||||
stop,
|
stop,
|
||||||
hold_days,
|
hold_days,
|
||||||
|
initial_capital,
|
||||||
extra=LOG_EXTRA,
|
extra=LOG_EXTRA,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -134,13 +150,14 @@ def render_backtest_review() -> None:
|
|||||||
try:
|
try:
|
||||||
universe = [code.strip() for code in universe_text.split(',') if code.strip()]
|
universe = [code.strip() for code in universe_text.split(',') if code.strip()]
|
||||||
LOGGER.info(
|
LOGGER.info(
|
||||||
"回测参数:start=%s end=%s universe=%s target=%s stop=%s hold_days=%s",
|
"回测参数:start=%s end=%s universe=%s target=%s stop=%s hold_days=%s initial_capital=%.2f",
|
||||||
start_date,
|
start_date,
|
||||||
end_date,
|
end_date,
|
||||||
universe,
|
universe,
|
||||||
target,
|
target,
|
||||||
stop,
|
stop,
|
||||||
hold_days,
|
hold_days,
|
||||||
|
initial_capital,
|
||||||
extra=LOG_EXTRA,
|
extra=LOG_EXTRA,
|
||||||
)
|
)
|
||||||
backtest_cfg = BtConfig(
|
backtest_cfg = BtConfig(
|
||||||
@ -149,11 +166,7 @@ def render_backtest_review() -> None:
|
|||||||
start_date=start_date,
|
start_date=start_date,
|
||||||
end_date=end_date,
|
end_date=end_date,
|
||||||
universe=universe,
|
universe=universe,
|
||||||
params={
|
params=dict(backtest_params),
|
||||||
"target": target,
|
|
||||||
"stop": stop,
|
|
||||||
"hold_days": int(hold_days),
|
|
||||||
},
|
|
||||||
game_structures=selected_structures,
|
game_structures=selected_structures,
|
||||||
)
|
)
|
||||||
result = run_backtest(backtest_cfg, decision_callback=_decision_callback)
|
result = run_backtest(backtest_cfg, decision_callback=_decision_callback)
|
||||||
@ -665,11 +678,7 @@ def render_backtest_review() -> None:
|
|||||||
start_date=start_date,
|
start_date=start_date,
|
||||||
end_date=end_date,
|
end_date=end_date,
|
||||||
universe=universe_env,
|
universe=universe_env,
|
||||||
params={
|
params=dict(backtest_params),
|
||||||
"target": target,
|
|
||||||
"stop": stop,
|
|
||||||
"hold_days": int(hold_days),
|
|
||||||
},
|
|
||||||
method=app_cfg.decision_method,
|
method=app_cfg.decision_method,
|
||||||
game_structures=selected_structures,
|
game_structures=selected_structures,
|
||||||
)
|
)
|
||||||
@ -942,11 +951,7 @@ def render_backtest_review() -> None:
|
|||||||
start_date=start_date,
|
start_date=start_date,
|
||||||
end_date=end_date,
|
end_date=end_date,
|
||||||
universe=universe_env,
|
universe=universe_env,
|
||||||
params={
|
params=dict(backtest_params),
|
||||||
"target": target,
|
|
||||||
"stop": stop,
|
|
||||||
"hold_days": int(hold_days),
|
|
||||||
},
|
|
||||||
method=app_cfg.decision_method,
|
method=app_cfg.decision_method,
|
||||||
game_structures=selected_structures,
|
game_structures=selected_structures,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
"""投资池与仓位概览页面。"""
|
"""投资池与仓位概览页面。"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import date, timedelta
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import plotly.express as px
|
import plotly.express as px
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
@ -106,28 +108,44 @@ def render_pool_overview() -> None:
|
|||||||
|
|
||||||
st.caption("数据来源:agent_utils、investment_pool、portfolio_positions、portfolio_trades、portfolio_snapshots。")
|
st.caption("数据来源:agent_utils、investment_pool、portfolio_positions、portfolio_trades、portfolio_snapshots。")
|
||||||
|
|
||||||
|
default_compare_a = latest_date or date.today()
|
||||||
|
default_compare_b = default_compare_a - timedelta(days=1)
|
||||||
|
if default_compare_b > default_compare_a:
|
||||||
|
default_compare_b = default_compare_a
|
||||||
|
compare_col1, compare_col2 = st.columns(2)
|
||||||
|
compare_date1 = compare_col1.date_input(
|
||||||
|
"日志对比日期 A",
|
||||||
|
value=default_compare_a,
|
||||||
|
key="pool_compare_date_a",
|
||||||
|
)
|
||||||
|
compare_date2 = compare_col2.date_input(
|
||||||
|
"日志对比日期 B",
|
||||||
|
value=default_compare_b,
|
||||||
|
key="pool_compare_date_b",
|
||||||
|
)
|
||||||
|
|
||||||
if st.button("执行对比", type="secondary"):
|
if st.button("执行对比", type="secondary"):
|
||||||
with st.spinner("执行日志对比分析中..."):
|
with st.spinner("执行日志对比分析中..."):
|
||||||
try:
|
try:
|
||||||
with db_session(read_only=True) as conn:
|
with db_session(read_only=True) as conn:
|
||||||
query_date1 = f"{compare_date1.isoformat()}T00:00:00Z" # type: ignore[name-defined]
|
query_date1 = f"{compare_date1.isoformat()}T00:00:00Z"
|
||||||
query_date2 = f"{compare_date1.isoformat()}T23:59:59Z" # type: ignore[name-defined]
|
query_date2 = f"{compare_date1.isoformat()}T23:59:59Z"
|
||||||
logs1 = conn.execute(
|
logs1 = conn.execute(
|
||||||
"SELECT level, COUNT(*) as count FROM run_log WHERE ts BETWEEN ? AND ? GROUP BY level",
|
"SELECT level, COUNT(*) as count FROM run_log WHERE ts BETWEEN ? AND ? GROUP BY level",
|
||||||
(query_date1, query_date2),
|
(query_date1, query_date2),
|
||||||
).fetchall()
|
).fetchall()
|
||||||
|
|
||||||
query_date3 = f"{compare_date2.isoformat()}T00:00:00Z" # type: ignore[name-defined]
|
query_date3 = f"{compare_date2.isoformat()}T00:00:00Z"
|
||||||
query_date4 = f"{compare_date2.isoformat()}T23:59:59Z" # type: ignore[name-defined]
|
query_date4 = f"{compare_date2.isoformat()}T23:59:59Z"
|
||||||
logs2 = conn.execute(
|
logs2 = conn.execute(
|
||||||
"SELECT level, COUNT(*) as count FROM run_log WHERE ts BETWEEN ? AND ? GROUP BY level",
|
"SELECT level, COUNT(*) as count FROM run_log WHERE ts BETWEEN ? AND ? GROUP BY level",
|
||||||
(query_date3, query_date4),
|
(query_date3, query_date4),
|
||||||
).fetchall()
|
).fetchall()
|
||||||
|
|
||||||
df1 = pd.DataFrame(logs1, columns=["level", "count"])
|
df1 = pd.DataFrame(logs1, columns=["level", "count"])
|
||||||
df1["date"] = compare_date1.strftime("%Y-%m-%d") # type: ignore[name-defined]
|
df1["date"] = compare_date1.strftime("%Y-%m-%d")
|
||||||
df2 = pd.DataFrame(logs2, columns=["level", "count"])
|
df2 = pd.DataFrame(logs2, columns=["level", "count"])
|
||||||
df2["date"] = compare_date2.strftime("%Y-%m-%d") # type: ignore[name-defined]
|
df2["date"] = compare_date2.strftime("%Y-%m-%d")
|
||||||
|
|
||||||
for df in (df1, df2):
|
for df in (df1, df2):
|
||||||
for col in df.columns:
|
for col in df.columns:
|
||||||
@ -141,13 +159,13 @@ def render_pool_overview() -> None:
|
|||||||
y="count",
|
y="count",
|
||||||
color="date",
|
color="date",
|
||||||
barmode="group",
|
barmode="group",
|
||||||
title=f"日志级别分布对比 ({compare_date1} vs {compare_date2})", # type: ignore[name-defined]
|
title=f"日志级别分布对比 ({compare_date1} vs {compare_date2})",
|
||||||
)
|
)
|
||||||
st.plotly_chart(fig, width='stretch')
|
st.plotly_chart(fig, width='stretch')
|
||||||
|
|
||||||
st.write("日志统计对比:")
|
st.write("日志统计对比:")
|
||||||
date1_str = compare_date1.strftime("%Y%m%d") # type: ignore[name-defined]
|
date1_str = compare_date1.strftime("%Y%m%d")
|
||||||
date2_str = compare_date2.strftime("%Y%m%d") # type: ignore[name-defined]
|
date2_str = compare_date2.strftime("%Y%m%d")
|
||||||
merged_df = df1.merge(
|
merged_df = df1.merge(
|
||||||
df2,
|
df2,
|
||||||
on="level",
|
on="level",
|
||||||
|
|||||||
@ -23,6 +23,107 @@ from app.ui.shared import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_agent_actions(trade_date: str, symbols: List[str]) -> Dict[str, Dict[str, Optional[str]]]:
|
||||||
|
unique_symbols = list(dict.fromkeys(symbols))
|
||||||
|
if not unique_symbols:
|
||||||
|
return {}
|
||||||
|
placeholder = ", ".join("?" for _ in unique_symbols)
|
||||||
|
sql = (
|
||||||
|
f"""
|
||||||
|
SELECT ts_code, agent, action
|
||||||
|
FROM agent_utils
|
||||||
|
WHERE trade_date = ?
|
||||||
|
AND ts_code IN ({placeholder})
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
params: List[object] = [trade_date, *unique_symbols]
|
||||||
|
with db_session(read_only=True) as conn:
|
||||||
|
rows = conn.execute(sql, params).fetchall()
|
||||||
|
mapping: Dict[str, Dict[str, Optional[str]]] = {}
|
||||||
|
for row in rows:
|
||||||
|
ts_code = row["ts_code"]
|
||||||
|
agent = row["agent"]
|
||||||
|
action = row["action"]
|
||||||
|
mapping.setdefault(ts_code, {})[agent] = action
|
||||||
|
return mapping
|
||||||
|
|
||||||
|
|
||||||
|
def _reevaluate_symbols(
|
||||||
|
trade_date_obj: date,
|
||||||
|
symbols: List[str],
|
||||||
|
cfg_id: str,
|
||||||
|
cfg_name: str,
|
||||||
|
) -> List[Dict[str, object]]:
|
||||||
|
unique_symbols = list(dict.fromkeys(symbols))
|
||||||
|
if not unique_symbols:
|
||||||
|
return []
|
||||||
|
trade_date_str = trade_date_obj.isoformat()
|
||||||
|
before_map = _fetch_agent_actions(trade_date_str, unique_symbols)
|
||||||
|
engine_params: Dict[str, object] = {}
|
||||||
|
target_val = st.session_state.get("bt_target")
|
||||||
|
if target_val is not None:
|
||||||
|
try:
|
||||||
|
engine_params["target"] = float(target_val)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
pass
|
||||||
|
stop_val = st.session_state.get("bt_stop")
|
||||||
|
if stop_val is not None:
|
||||||
|
try:
|
||||||
|
engine_params["stop"] = float(stop_val)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
pass
|
||||||
|
hold_val = st.session_state.get("bt_hold_days")
|
||||||
|
if hold_val is not None:
|
||||||
|
try:
|
||||||
|
engine_params["hold_days"] = int(hold_val)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
pass
|
||||||
|
capital_val = st.session_state.get("bt_initial_capital")
|
||||||
|
if capital_val is not None:
|
||||||
|
try:
|
||||||
|
engine_params["initial_capital"] = float(capital_val)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
pass
|
||||||
|
cfg = BtConfig(
|
||||||
|
id=cfg_id,
|
||||||
|
name=cfg_name,
|
||||||
|
start_date=trade_date_obj,
|
||||||
|
end_date=trade_date_obj,
|
||||||
|
universe=unique_symbols,
|
||||||
|
params=engine_params,
|
||||||
|
)
|
||||||
|
engine = BacktestEngine(cfg)
|
||||||
|
state = PortfolioState(cash=engine.initial_cash)
|
||||||
|
engine.simulate_day(trade_date_obj, state)
|
||||||
|
after_map = _fetch_agent_actions(trade_date_str, unique_symbols)
|
||||||
|
changes: List[Dict[str, object]] = []
|
||||||
|
for code in unique_symbols:
|
||||||
|
before_agents = before_map.get(code, {})
|
||||||
|
after_agents = after_map.get(code, {})
|
||||||
|
for agent, new_action in after_agents.items():
|
||||||
|
old_action = before_agents.get(agent)
|
||||||
|
if new_action != old_action:
|
||||||
|
changes.append(
|
||||||
|
{
|
||||||
|
"代码": code,
|
||||||
|
"代理": agent,
|
||||||
|
"原动作": old_action,
|
||||||
|
"新动作": new_action,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
for agent, old_action in before_agents.items():
|
||||||
|
if agent not in after_agents:
|
||||||
|
changes.append(
|
||||||
|
{
|
||||||
|
"代码": code,
|
||||||
|
"代理": agent,
|
||||||
|
"原动作": old_action,
|
||||||
|
"新动作": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return changes
|
||||||
|
|
||||||
|
|
||||||
def render_today_plan() -> None:
|
def render_today_plan() -> None:
|
||||||
LOGGER.info("渲染今日计划页面", extra=LOG_EXTRA)
|
LOGGER.info("渲染今日计划页面", extra=LOG_EXTRA)
|
||||||
st.header("今日计划")
|
st.header("今日计划")
|
||||||
@ -176,56 +277,15 @@ def _render_today_plan_symbol_view(
|
|||||||
raise ValueError(f"无法解析交易日:{trade_date}")
|
raise ValueError(f"无法解析交易日:{trade_date}")
|
||||||
|
|
||||||
progress = st.progress(0.0)
|
progress = st.progress(0.0)
|
||||||
changes_all: List[Dict[str, object]] = []
|
progress.progress(0.3 if symbols else 1.0)
|
||||||
success_count = 0
|
changes_all = _reevaluate_symbols(
|
||||||
error_count = 0
|
trade_date_obj,
|
||||||
|
symbols,
|
||||||
for idx, code in enumerate(symbols, start=1):
|
"reeval_ui_all",
|
||||||
try:
|
"UI All Re-eval",
|
||||||
with db_session(read_only=True) as conn:
|
|
||||||
before_rows = conn.execute(
|
|
||||||
"SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?",
|
|
||||||
(trade_date, code),
|
|
||||||
).fetchall()
|
|
||||||
before_map = {row["agent"]: row["action"] for row in before_rows}
|
|
||||||
|
|
||||||
cfg = BtConfig(
|
|
||||||
id="reeval_ui_all",
|
|
||||||
name="UI All Re-eval",
|
|
||||||
start_date=trade_date_obj,
|
|
||||||
end_date=trade_date_obj,
|
|
||||||
universe=[code],
|
|
||||||
params={},
|
|
||||||
)
|
)
|
||||||
engine = BacktestEngine(cfg)
|
progress.progress(1.0)
|
||||||
state = PortfolioState()
|
st.success(f"一键重评估完成:共处理 {len(symbols)} 个标的")
|
||||||
engine.simulate_day(trade_date_obj, state)
|
|
||||||
|
|
||||||
with db_session(read_only=True) as conn:
|
|
||||||
after_rows = conn.execute(
|
|
||||||
"SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?",
|
|
||||||
(trade_date, code),
|
|
||||||
).fetchall()
|
|
||||||
for row in after_rows:
|
|
||||||
agent = row["agent"]
|
|
||||||
new_action = row["action"]
|
|
||||||
old_action = before_map.get(agent)
|
|
||||||
if new_action != old_action:
|
|
||||||
changes_all.append(
|
|
||||||
{"代码": code, "代理": agent, "原动作": old_action, "新动作": new_action}
|
|
||||||
)
|
|
||||||
success_count += 1
|
|
||||||
except Exception: # noqa: BLE001
|
|
||||||
LOGGER.exception("重评估 %s 失败", code, extra=LOG_EXTRA)
|
|
||||||
error_count += 1
|
|
||||||
|
|
||||||
progress.progress(idx / len(symbols))
|
|
||||||
|
|
||||||
if error_count > 0:
|
|
||||||
st.error(f"一键重评估完成:成功 {success_count} 个,失败 {error_count} 个")
|
|
||||||
else:
|
|
||||||
st.success(f"一键重评估完成:所有 {success_count} 个标的重评估成功")
|
|
||||||
|
|
||||||
if changes_all:
|
if changes_all:
|
||||||
st.write("检测到以下动作变更:")
|
st.write("检测到以下动作变更:")
|
||||||
st.dataframe(pd.DataFrame(changes_all), hide_index=True, width='stretch')
|
st.dataframe(pd.DataFrame(changes_all), hide_index=True, width='stretch')
|
||||||
@ -579,44 +639,20 @@ def _render_today_plan_symbol_view(
|
|||||||
pass
|
pass
|
||||||
if trade_date_obj is None:
|
if trade_date_obj is None:
|
||||||
raise ValueError(f"无法解析交易日:{trade_date}")
|
raise ValueError(f"无法解析交易日:{trade_date}")
|
||||||
with db_session(read_only=True) as conn:
|
changes = _reevaluate_symbols(
|
||||||
before_rows = conn.execute(
|
trade_date_obj,
|
||||||
"""
|
[ts_code],
|
||||||
SELECT agent, action, utils FROM agent_utils
|
"reeval_ui",
|
||||||
WHERE trade_date = ? AND ts_code = ?
|
"UI Re-evaluation",
|
||||||
""",
|
|
||||||
(trade_date, ts_code),
|
|
||||||
).fetchall()
|
|
||||||
before_map = {row["agent"]: (row["action"], row["utils"]) for row in before_rows}
|
|
||||||
cfg = BtConfig(
|
|
||||||
id="reeval_ui",
|
|
||||||
name="UI Re-evaluation",
|
|
||||||
start_date=trade_date_obj,
|
|
||||||
end_date=trade_date_obj,
|
|
||||||
universe=[ts_code],
|
|
||||||
params={},
|
|
||||||
)
|
)
|
||||||
engine = BacktestEngine(cfg)
|
|
||||||
state = PortfolioState()
|
|
||||||
engine.simulate_day(trade_date_obj, state)
|
|
||||||
with db_session(read_only=True) as conn:
|
|
||||||
after_rows = conn.execute(
|
|
||||||
"""
|
|
||||||
SELECT agent, action, utils FROM agent_utils
|
|
||||||
WHERE trade_date = ? AND ts_code = ?
|
|
||||||
""",
|
|
||||||
(trade_date, ts_code),
|
|
||||||
).fetchall()
|
|
||||||
changes = []
|
|
||||||
for row in after_rows:
|
|
||||||
agent = row["agent"]
|
|
||||||
new_action = row["action"]
|
|
||||||
old_action, _old_utils = before_map.get(agent, (None, None))
|
|
||||||
if new_action != old_action:
|
|
||||||
changes.append({"代理": agent, "原动作": old_action, "新动作": new_action})
|
|
||||||
if changes:
|
if changes:
|
||||||
|
for change in changes:
|
||||||
|
change.setdefault("代码", ts_code)
|
||||||
|
df_changes = pd.DataFrame(changes)
|
||||||
|
if "代码" in df_changes.columns:
|
||||||
|
df_changes = df_changes[["代码", "代理", "原动作", "新动作"]]
|
||||||
st.success("重评估完成,检测到动作变更:")
|
st.success("重评估完成,检测到动作变更:")
|
||||||
st.dataframe(pd.DataFrame(changes), hide_index=True, width='stretch')
|
st.dataframe(df_changes, hide_index=True, width='stretch')
|
||||||
else:
|
else:
|
||||||
st.success("重评估完成,无动作变更。")
|
st.success("重评估完成,无动作变更。")
|
||||||
st.rerun()
|
st.rerun()
|
||||||
@ -637,38 +673,15 @@ def _render_today_plan_symbol_view(
|
|||||||
if trade_date_obj is None:
|
if trade_date_obj is None:
|
||||||
raise ValueError(f"无法解析交易日:{trade_date}")
|
raise ValueError(f"无法解析交易日:{trade_date}")
|
||||||
progress = st.progress(0.0)
|
progress = st.progress(0.0)
|
||||||
changes_all: List[Dict[str, object]] = []
|
progress.progress(0.3 if batch_symbols else 1.0)
|
||||||
for idx, code in enumerate(batch_symbols, start=1):
|
changes_all = _reevaluate_symbols(
|
||||||
with db_session(read_only=True) as conn:
|
trade_date_obj,
|
||||||
before_rows = conn.execute(
|
batch_symbols,
|
||||||
"SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?",
|
"reeval_ui_batch",
|
||||||
(trade_date, code),
|
"UI Batch Re-eval",
|
||||||
).fetchall()
|
|
||||||
before_map = {row["agent"]: row["action"] for row in before_rows}
|
|
||||||
cfg = BtConfig(
|
|
||||||
id="reeval_ui_batch",
|
|
||||||
name="UI Batch Re-eval",
|
|
||||||
start_date=trade_date_obj,
|
|
||||||
end_date=trade_date_obj,
|
|
||||||
universe=[code],
|
|
||||||
params={},
|
|
||||||
)
|
)
|
||||||
engine = BacktestEngine(cfg)
|
progress.progress(1.0)
|
||||||
state = PortfolioState()
|
st.success(f"批量重评估完成:共处理 {len(batch_symbols)} 个标的")
|
||||||
engine.simulate_day(trade_date_obj, state)
|
|
||||||
with db_session(read_only=True) as conn:
|
|
||||||
after_rows = conn.execute(
|
|
||||||
"SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?",
|
|
||||||
(trade_date, code),
|
|
||||||
).fetchall()
|
|
||||||
for row in after_rows:
|
|
||||||
agent = row["agent"]
|
|
||||||
new_action = row["action"]
|
|
||||||
old_action = before_map.get(agent)
|
|
||||||
if new_action != old_action:
|
|
||||||
changes_all.append({"代码": code, "代理": agent, "原动作": old_action, "新动作": new_action})
|
|
||||||
progress.progress(idx / max(1, len(batch_symbols)))
|
|
||||||
st.success("批量重评估完成。")
|
|
||||||
if changes_all:
|
if changes_all:
|
||||||
st.dataframe(pd.DataFrame(changes_all), hide_index=True, width='stretch')
|
st.dataframe(pd.DataFrame(changes_all), hide_index=True, width='stretch')
|
||||||
st.rerun()
|
st.rerun()
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from __future__ import annotations
|
|||||||
import re
|
import re
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import threading
|
import threading
|
||||||
|
import time
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
@ -577,8 +578,7 @@ class DataBroker:
|
|||||||
Returns:
|
Returns:
|
||||||
新闻数据列表,包含sentiment、heat、entities等字段
|
新闻数据列表,包含sentiment、heat、entities等字段
|
||||||
"""
|
"""
|
||||||
# 简化实现:返回模拟数据
|
# TODO: 使用真实新闻数据库替换随机生成的占位数据
|
||||||
# 在实际应用中,这里应该查询新闻数据库
|
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
"sentiment": np.random.uniform(-1, 1),
|
"sentiment": np.random.uniform(-1, 1),
|
||||||
@ -597,8 +597,7 @@ class DataBroker:
|
|||||||
Returns:
|
Returns:
|
||||||
行业代码或名称,找不到时返回None
|
行业代码或名称,找不到时返回None
|
||||||
"""
|
"""
|
||||||
# 简化实现:返回模拟行业
|
# TODO: 替换为真实行业映射逻辑(当前仅为占位数据)
|
||||||
# 在实际应用中,这里应该查询股票行业信息
|
|
||||||
industry_mapping = {
|
industry_mapping = {
|
||||||
"000001.SZ": "银行",
|
"000001.SZ": "银行",
|
||||||
"000002.SZ": "房地产",
|
"000002.SZ": "房地产",
|
||||||
@ -617,8 +616,7 @@ class DataBroker:
|
|||||||
Returns:
|
Returns:
|
||||||
行业情绪得分,找不到时返回None
|
行业情绪得分,找不到时返回None
|
||||||
"""
|
"""
|
||||||
# 简化实现:返回模拟情绪得分
|
# TODO: 接入行业情绪数据源,当前随机值仅用于占位显示
|
||||||
# 在实际应用中,这里应该基于行业新闻计算情绪
|
|
||||||
return np.random.uniform(-1, 1)
|
return np.random.uniform(-1, 1)
|
||||||
|
|
||||||
def get_industry_stocks(self, industry: str) -> List[str]:
|
def get_industry_stocks(self, industry: str) -> List[str]:
|
||||||
@ -630,8 +628,7 @@ class DataBroker:
|
|||||||
Returns:
|
Returns:
|
||||||
同行业股票代码列表
|
同行业股票代码列表
|
||||||
"""
|
"""
|
||||||
# 简化实现:返回模拟股票列表
|
# TODO: 使用实际行业成分数据替换占位列表
|
||||||
# 在实际应用中,这里应该查询行业股票列表
|
|
||||||
industry_stocks = {
|
industry_stocks = {
|
||||||
"银行": ["000001.SZ", "002142.SZ", "600036.SH"],
|
"银行": ["000001.SZ", "002142.SZ", "600036.SH"],
|
||||||
"房地产": ["000002.SZ", "000402.SZ", "600048.SH"],
|
"房地产": ["000002.SZ", "000402.SZ", "600048.SH"],
|
||||||
@ -653,10 +650,33 @@ class DataBroker:
|
|||||||
|
|
||||||
if not _is_safe_identifier(table):
|
if not _is_safe_identifier(table):
|
||||||
return False
|
return False
|
||||||
|
parsed_date = _parse_trade_date(trade_date)
|
||||||
|
trade_key = parsed_date.strftime("%Y%m%d") if parsed_date else str(trade_date)
|
||||||
|
|
||||||
|
if auto_refresh and parsed_date and self.check_data_availability(trade_key, {table}):
|
||||||
|
self._trigger_background_refresh(trade_key)
|
||||||
|
if hasattr(time, "sleep"):
|
||||||
|
time.sleep(0.5)
|
||||||
|
|
||||||
|
if table == "suspend":
|
||||||
query = (
|
query = (
|
||||||
f"SELECT 1 FROM {table} WHERE ts_code = ? AND {where_clause} LIMIT 1"
|
"SELECT 1 FROM suspend "
|
||||||
|
"WHERE ts_code = ? "
|
||||||
|
"AND suspend_date <= ? "
|
||||||
|
"AND (resume_date IS NULL OR resume_date = '' OR resume_date > ?) "
|
||||||
|
"LIMIT 1"
|
||||||
)
|
)
|
||||||
bind_params = (ts_code, *params)
|
bind_params = (ts_code, trade_key, trade_key)
|
||||||
|
else:
|
||||||
|
clauses = ["ts_code = ?"]
|
||||||
|
bind_params_list: List[object] = [ts_code]
|
||||||
|
clause_text = (where_clause or "").strip()
|
||||||
|
if clause_text:
|
||||||
|
clauses.append(clause_text)
|
||||||
|
bind_params_list.extend(params)
|
||||||
|
query = f"SELECT 1 FROM {table} WHERE {' AND '.join(clauses)} LIMIT 1"
|
||||||
|
bind_params = tuple(bind_params_list)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with db_session(read_only=True) as conn:
|
with db_session(read_only=True) as conn:
|
||||||
try:
|
try:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user