refactor backtest engine with trading rules and progress tracking
This commit is contained in:
parent
8aa8efb651
commit
90fb2a9df6
@ -204,22 +204,29 @@ class DepartmentAgent:
|
||||
|
||||
rounds_executed = round_idx + 1
|
||||
|
||||
usage = response.get("usage") if isinstance(response, Mapping) else None
|
||||
if isinstance(usage, Mapping):
|
||||
usage_payload = {"round": round_idx + 1}
|
||||
usage_payload.update(dict(usage))
|
||||
usage_records.append(usage_payload)
|
||||
message, usage_payload, tool_calls = _normalize_llm_response(response)
|
||||
if usage_payload:
|
||||
payload_with_round = {"round": round_idx + 1}
|
||||
payload_with_round.update(usage_payload)
|
||||
usage_records.append(payload_with_round)
|
||||
|
||||
choice = (response.get("choices") or [{}])[0]
|
||||
message = choice.get("message", {})
|
||||
if not message:
|
||||
LOGGER.debug(
|
||||
"部门 %s 第 %s 轮响应缺少 message 字段:%s",
|
||||
self.settings.code,
|
||||
round_idx + 1,
|
||||
response,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
message = {"role": "assistant", "content": ""}
|
||||
transcript.append(_message_to_text(message))
|
||||
|
||||
assistant_record: Dict[str, Any] = {
|
||||
"role": "assistant",
|
||||
"role": message.get("role", "assistant"),
|
||||
"content": _extract_message_content(message),
|
||||
}
|
||||
if message.get("tool_calls"):
|
||||
assistant_record["tool_calls"] = message.get("tool_calls")
|
||||
if tool_calls:
|
||||
assistant_record["tool_calls"] = tool_calls
|
||||
messages.append(assistant_record)
|
||||
CONV_LOGGER.info(
|
||||
"dept=%s round=%s assistant=%s",
|
||||
@ -228,7 +235,6 @@ class DepartmentAgent:
|
||||
assistant_record,
|
||||
)
|
||||
|
||||
tool_calls = message.get("tool_calls") or []
|
||||
if tool_calls:
|
||||
for call in tool_calls:
|
||||
function_block = call.get("function") or {}
|
||||
@ -656,6 +662,8 @@ class DepartmentAgent:
|
||||
dialogue=[response],
|
||||
)
|
||||
return decision
|
||||
|
||||
|
||||
def _ensure_mutable_context(context: DepartmentContext) -> DepartmentContext:
|
||||
if not isinstance(context.features, dict):
|
||||
context.features = dict(context.features or {})
|
||||
@ -669,6 +677,77 @@ def _ensure_mutable_context(context: DepartmentContext) -> DepartmentContext:
|
||||
return context
|
||||
|
||||
|
||||
def _compose_usage_from_stats(payload: Mapping[str, Any]) -> Dict[str, Any]:
|
||||
usage: Dict[str, Any] = {}
|
||||
prompt_eval = payload.get("prompt_eval_count")
|
||||
completion_eval = payload.get("eval_count")
|
||||
if isinstance(prompt_eval, (int, float)):
|
||||
usage["prompt_tokens"] = int(prompt_eval)
|
||||
if isinstance(completion_eval, (int, float)):
|
||||
usage["completion_tokens"] = int(completion_eval)
|
||||
if usage:
|
||||
total = usage.get("prompt_tokens", 0) + usage.get("completion_tokens", 0)
|
||||
usage["total_tokens"] = total
|
||||
return usage
|
||||
|
||||
|
||||
def _normalize_llm_response(
|
||||
response: Mapping[str, Any]
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any], List[Dict[str, Any]]]:
|
||||
message: Dict[str, Any] = {}
|
||||
usage: Dict[str, Any] = {}
|
||||
tool_calls: List[Dict[str, Any]] = []
|
||||
|
||||
if not isinstance(response, Mapping):
|
||||
return message, usage, tool_calls
|
||||
|
||||
choices = response.get("choices")
|
||||
if isinstance(choices, list) and choices:
|
||||
choice = choices[0] or {}
|
||||
candidate = choice.get("message")
|
||||
if isinstance(candidate, Mapping):
|
||||
message = candidate
|
||||
raw_calls = candidate.get("tool_calls")
|
||||
if isinstance(raw_calls, list):
|
||||
tool_calls = list(raw_calls)
|
||||
raw_usage = response.get("usage")
|
||||
if isinstance(raw_usage, Mapping):
|
||||
usage = dict(raw_usage)
|
||||
else:
|
||||
raw_message = response.get("message")
|
||||
if isinstance(raw_message, Mapping):
|
||||
message = raw_message
|
||||
raw_calls = raw_message.get("tool_calls")
|
||||
if isinstance(raw_calls, list):
|
||||
tool_calls = list(raw_calls)
|
||||
elif isinstance(response.get("messages"), list):
|
||||
messages_list = response.get("messages") or []
|
||||
if messages_list:
|
||||
candidate = messages_list[-1]
|
||||
if isinstance(candidate, Mapping):
|
||||
message = candidate
|
||||
raw_calls = candidate.get("tool_calls")
|
||||
if isinstance(raw_calls, list):
|
||||
tool_calls = list(raw_calls)
|
||||
if not message:
|
||||
content = response.get("content")
|
||||
if isinstance(content, str):
|
||||
message = {"role": "assistant", "content": content}
|
||||
raw_usage = response.get("usage")
|
||||
if isinstance(raw_usage, Mapping):
|
||||
usage = dict(raw_usage)
|
||||
else:
|
||||
usage = _compose_usage_from_stats(response)
|
||||
|
||||
if not tool_calls:
|
||||
extra = message.get("additional_kwargs")
|
||||
if isinstance(extra, Mapping):
|
||||
extra_calls = extra.get("tool_calls")
|
||||
if isinstance(extra_calls, list):
|
||||
tool_calls = list(extra_calls)
|
||||
return message or {"role": "assistant", "content": ""}, usage, tool_calls
|
||||
|
||||
|
||||
def _parse_tool_arguments(payload: Any) -> Dict[str, Any]:
|
||||
if isinstance(payload, dict):
|
||||
return dict(payload)
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
"""Value and quality filtering agent."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Mapping
|
||||
|
||||
from .base import Agent, AgentAction, AgentContext
|
||||
|
||||
|
||||
@ -9,12 +11,19 @@ class ValueAgent(Agent):
|
||||
super().__init__(name="A_val")
|
||||
|
||||
def score(self, context: AgentContext, action: AgentAction) -> float:
|
||||
pe = context.features.get("pe_percentile", 0.5)
|
||||
pb = context.features.get("pb_percentile", 0.5)
|
||||
roe = context.features.get("roe_percentile", 0.5)
|
||||
# Lower valuation percentiles and higher quality percentiles add value.
|
||||
raw = max(0.0, (1 - pe) * 0.4 + (1 - pb) * 0.3 + roe * 0.3)
|
||||
raw = min(raw, 1.0)
|
||||
pe_score = context.features.get("valuation_pe_score", 0.0)
|
||||
pb_score = context.features.get("valuation_pb_score", 0.0)
|
||||
# 多因子组合尚未落地,这里兼容扩展因子(若存在则优先使用)
|
||||
scope_values = {}
|
||||
if isinstance(context.raw, Mapping):
|
||||
scope_values = context.raw.get("scope_values", {}) or {}
|
||||
multi_score = context.features.get("val_multiscore")
|
||||
if multi_score is None:
|
||||
multi_score = scope_values.get("factors.val_multiscore")
|
||||
if multi_score is not None:
|
||||
raw = float(max(0.0, min(1.0, multi_score)))
|
||||
else:
|
||||
raw = max(0.0, min(1.0, 0.6 * pe_score + 0.4 * pb_score))
|
||||
if action is AgentAction.SELL:
|
||||
return 1 - raw
|
||||
if action is AgentAction.HOLD:
|
||||
|
||||
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import date
|
||||
from datetime import date, datetime
|
||||
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple
|
||||
|
||||
from app.agents.base import AgentAction, AgentContext
|
||||
@ -16,7 +16,7 @@ from app.llm.metrics import record_decision as metrics_record_decision
|
||||
from app.agents.registry import default_agents
|
||||
from app.data.schema import initialize_database
|
||||
from app.utils.data_access import DataBroker
|
||||
from app.utils.config import get_config
|
||||
from app.utils.config import PortfolioSettings, get_config
|
||||
from app.utils.db import db_session
|
||||
from app.utils.logging import get_logger
|
||||
from app.utils import alerts
|
||||
@ -105,12 +105,26 @@ class BacktestEngine:
|
||||
)
|
||||
self.data_broker = DataBroker()
|
||||
params = cfg.params or {}
|
||||
portfolio_cfg = getattr(app_cfg, "portfolio", None) or PortfolioSettings()
|
||||
self.risk_params = {
|
||||
"max_position_weight": float(params.get("max_position_weight", 0.2)),
|
||||
"max_daily_turnover_ratio": float(params.get("max_daily_turnover_ratio", 0.25)),
|
||||
"fee_rate": float(params.get("fee_rate", 0.0005)),
|
||||
"slippage_bps": float(params.get("slippage_bps", 10.0)),
|
||||
}
|
||||
self.initial_cash = max(0.0, float(params.get("initial_capital", portfolio_cfg.initial_capital)))
|
||||
target_return = params.get("target", params.get("target_return", 0.0)) or 0.0
|
||||
stop_loss = params.get("stop", params.get("stop_loss", 0.0)) or 0.0
|
||||
hold_days_param = params.get("hold_days", params.get("max_hold_days", 0))
|
||||
try:
|
||||
max_hold_days = int(hold_days_param) if hold_days_param is not None else 0
|
||||
except (TypeError, ValueError):
|
||||
max_hold_days = 0
|
||||
self.trading_rules = {
|
||||
"target_return": float(target_return),
|
||||
"stop_loss": float(stop_loss),
|
||||
"max_hold_days": max(0, max_hold_days),
|
||||
}
|
||||
self._fee_rate = max(self.risk_params["fee_rate"], 0.0)
|
||||
self._slippage_rate = max(self.risk_params["slippage_bps"], 0.0) / 10_000.0
|
||||
self._turnover_cap = max(self.risk_params["max_daily_turnover_ratio"], 0.0)
|
||||
@ -314,9 +328,9 @@ class BacktestEngine:
|
||||
is_suspended = self.data_broker.fetch_flags(
|
||||
"suspend",
|
||||
ts_code,
|
||||
trade_date,
|
||||
"ts_code = ?",
|
||||
[ts_code],
|
||||
trade_date_str,
|
||||
"",
|
||||
[],
|
||||
auto_refresh=False, # 避免在回测中触发自动补数
|
||||
)
|
||||
|
||||
@ -650,6 +664,7 @@ class BacktestEngine:
|
||||
continue
|
||||
features = feature_cache.get(ts_code, {})
|
||||
current_qty = state.holdings.get(ts_code, 0.0)
|
||||
current_cost_basis = float(state.cost_basis.get(ts_code, 0.0) or 0.0)
|
||||
liquidity_score = float(features.get("liquidity_score") or 0.0)
|
||||
risk_penalty = float(features.get("risk_penalty") or 0.0)
|
||||
is_suspended = bool(features.get("is_suspended"))
|
||||
@ -679,6 +694,76 @@ class BacktestEngine:
|
||||
if risk.status == "blocked":
|
||||
continue
|
||||
|
||||
rule_override_action: Optional[AgentAction] = None
|
||||
rule_override_reason: Optional[str] = None
|
||||
gain_ratio: Optional[float] = None
|
||||
if current_qty > 0 and current_cost_basis:
|
||||
try:
|
||||
gain_ratio = (price / current_cost_basis) - 1.0
|
||||
except ZeroDivisionError:
|
||||
gain_ratio = None
|
||||
target_return = self.trading_rules.get("target_return", 0.0)
|
||||
if (
|
||||
rule_override_action is None
|
||||
and gain_ratio is not None
|
||||
and target_return
|
||||
and gain_ratio >= target_return
|
||||
):
|
||||
rule_override_action = AgentAction.SELL
|
||||
rule_override_reason = "target_reached"
|
||||
stop_loss = self.trading_rules.get("stop_loss", 0.0)
|
||||
if (
|
||||
rule_override_action is None
|
||||
and gain_ratio is not None
|
||||
and stop_loss
|
||||
and gain_ratio <= stop_loss
|
||||
):
|
||||
rule_override_action = AgentAction.SELL
|
||||
rule_override_reason = "stop_loss"
|
||||
max_hold_days = self.trading_rules.get("max_hold_days", 0)
|
||||
if (
|
||||
rule_override_action is None
|
||||
and max_hold_days
|
||||
and max_hold_days > 0
|
||||
and current_qty > 0
|
||||
):
|
||||
opened_str = state.opened_dates.get(ts_code)
|
||||
opened_dt: Optional[date] = None
|
||||
if opened_str:
|
||||
try:
|
||||
opened_dt = date.fromisoformat(str(opened_str))
|
||||
except ValueError:
|
||||
try:
|
||||
opened_dt = datetime.strptime(str(opened_str), "%Y%m%d").date()
|
||||
except ValueError:
|
||||
opened_dt = None
|
||||
LOGGER.debug(
|
||||
"无法解析持仓日期 ts_code=%s value=%s",
|
||||
ts_code,
|
||||
opened_str,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
if opened_dt:
|
||||
holding_days = (trade_date - opened_dt).days
|
||||
if holding_days >= max_hold_days:
|
||||
rule_override_action = AgentAction.SELL
|
||||
rule_override_reason = "holding_period"
|
||||
|
||||
if rule_override_action and rule_override_action is not effective_action:
|
||||
effective_action = rule_override_action
|
||||
effective_weight = target_weight_for_action(effective_action)
|
||||
_record_risk(
|
||||
ts_code,
|
||||
rule_override_reason or "rule_override",
|
||||
decision,
|
||||
extra={
|
||||
"rule_trigger": rule_override_reason,
|
||||
"gain_ratio": gain_ratio,
|
||||
},
|
||||
action_override=effective_action,
|
||||
target_weight_override=effective_weight,
|
||||
)
|
||||
|
||||
if is_suspended:
|
||||
_record_risk(ts_code, "suspended", decision)
|
||||
continue
|
||||
@ -738,8 +823,7 @@ class BacktestEngine:
|
||||
if total_cash_needed <= 0:
|
||||
_record_risk(ts_code, "invalid_trade", decision)
|
||||
continue
|
||||
|
||||
previous_cost = state.cost_basis.get(ts_code, 0.0) * current_qty
|
||||
previous_cost = current_cost_basis * current_qty
|
||||
new_qty = current_qty + delta
|
||||
state.cost_basis[ts_code] = (
|
||||
(previous_cost + trade_value + fee) / new_qty if new_qty > 0 else 0.0
|
||||
@ -777,8 +861,7 @@ class BacktestEngine:
|
||||
gross_value = sell_qty * trade_price
|
||||
fee = gross_value * self._fee_rate
|
||||
proceeds = gross_value - fee
|
||||
cost_basis = state.cost_basis.get(ts_code, 0.0)
|
||||
realized = (trade_price - cost_basis) * sell_qty - fee
|
||||
realized = (trade_price - current_cost_basis) * sell_qty - fee
|
||||
state.cash += proceeds
|
||||
state.realized_pnl += realized
|
||||
new_qty = current_qty - sell_qty
|
||||
@ -1023,7 +1106,7 @@ class BacktestEngine:
|
||||
"""Initialise a new incremental backtest session."""
|
||||
|
||||
return BacktestSession(
|
||||
state=PortfolioState(),
|
||||
state=PortfolioState(cash=self.initial_cash),
|
||||
result=BacktestResult(),
|
||||
current_date=self.cfg.start_date,
|
||||
)
|
||||
|
||||
@ -19,7 +19,11 @@ from app.features.value_risk_factors import ValueRiskFactors
|
||||
# 导入因子验证功能
|
||||
from app.features.validation import check_data_sufficiency, check_data_sufficiency_for_zero_window, detect_outliers
|
||||
# 导入UI进度状态管理
|
||||
from app.ui.progress_state import factor_progress
|
||||
try:
|
||||
from app.features.progress import get_progress_handler
|
||||
except ImportError: # pragma: no cover - optional dependency
|
||||
def get_progress_handler():
|
||||
return None
|
||||
|
||||
|
||||
LOGGER = get_logger(__name__)
|
||||
@ -176,14 +180,16 @@ def compute_factors(
|
||||
broker = DataBroker()
|
||||
results: List[FactorResult] = []
|
||||
rows_to_persist: List[tuple[str, Dict[str, float | None]]] = []
|
||||
total_batches = (len(universe) + batch_size - 1) // batch_size if universe else 0
|
||||
progress = get_progress_handler()
|
||||
if progress and universe:
|
||||
try:
|
||||
progress.start_calculation(len(universe), total_batches)
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.debug("Progress handler start_calculation 失败", extra=LOG_EXTRA)
|
||||
progress = None
|
||||
|
||||
try:
|
||||
# 启动UI进度状态(在异步线程中不直接访问factor_progress)
|
||||
# factor_progress.start_calculation(
|
||||
# total_securities=len(universe),
|
||||
# total_batches=(len(universe) + batch_size - 1) // batch_size
|
||||
# )
|
||||
|
||||
# 分批处理以优化性能
|
||||
for i in range(0, len(universe), batch_size):
|
||||
batch = universe[i:i+batch_size]
|
||||
@ -194,9 +200,10 @@ def compute_factors(
|
||||
specs,
|
||||
validation_stats,
|
||||
batch_index=i // batch_size,
|
||||
total_batches=(len(universe) + batch_size - 1) // batch_size,
|
||||
total_batches=total_batches or 1,
|
||||
processed_securities=i,
|
||||
total_securities=len(universe)
|
||||
total_securities=len(universe),
|
||||
progress=progress,
|
||||
)
|
||||
|
||||
for ts_code, values in batch_results:
|
||||
@ -222,9 +229,13 @@ def compute_factors(
|
||||
_persist_factor_rows(trade_date_str, rows_to_persist, specs)
|
||||
|
||||
# 更新UI进度状态为完成
|
||||
factor_progress.complete_calculation(
|
||||
message=f"因子计算完成: 总数量={len(universe)}, 成功={validation_stats['success']}, 失败={len(universe) - validation_stats['success']}"
|
||||
)
|
||||
if progress:
|
||||
try:
|
||||
progress.complete_calculation(
|
||||
message=f"因子计算完成: 总数量={len(universe)}, 成功={validation_stats['success']}, 失败={len(universe) - validation_stats['success']}"
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.debug("Progress handler complete_calculation 失败", extra=LOG_EXTRA)
|
||||
|
||||
LOGGER.info(
|
||||
"因子计算完成 总数量:%s 成功:%s 失败:%s",
|
||||
@ -239,7 +250,11 @@ def compute_factors(
|
||||
except Exception as exc:
|
||||
# 发生错误时更新UI状态
|
||||
error_message = f"因子计算过程中发生错误: {exc}"
|
||||
factor_progress.error_occurred(error_message)
|
||||
if progress:
|
||||
try:
|
||||
progress.error_occurred(error_message)
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.debug("Progress handler error_occurred 失败", extra=LOG_EXTRA)
|
||||
LOGGER.error(error_message, extra=LOG_EXTRA)
|
||||
raise
|
||||
|
||||
@ -380,6 +395,7 @@ def _compute_batch_factors(
|
||||
total_batches: int = 1,
|
||||
processed_securities: int = 0,
|
||||
total_securities: int = 0,
|
||||
progress: Optional[object] = None,
|
||||
) -> List[tuple[str, Dict[str, float | None]]]:
|
||||
"""批量计算多个证券的因子值,提高计算效率"""
|
||||
batch_results = []
|
||||
@ -388,13 +404,16 @@ def _compute_batch_factors(
|
||||
available_codes = _check_batch_data_availability(broker, ts_codes, trade_date, specs)
|
||||
|
||||
# 更新UI进度状态 - 开始处理批次
|
||||
if total_securities > 0:
|
||||
from app.ui.progress_state import factor_progress
|
||||
factor_progress.update_progress(
|
||||
current_securities=processed_securities,
|
||||
current_batch=batch_index + 1,
|
||||
message=f"开始处理批次 {batch_index + 1}/{total_batches}"
|
||||
)
|
||||
if progress and total_securities > 0:
|
||||
try:
|
||||
progress.update_progress(
|
||||
current_securities=processed_securities,
|
||||
current_batch=batch_index + 1,
|
||||
message=f"开始处理批次 {batch_index + 1}/{total_batches}",
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.debug("Progress handler update_progress 失败", extra=LOG_EXTRA)
|
||||
progress = None
|
||||
|
||||
for i, ts_code in enumerate(ts_codes):
|
||||
try:
|
||||
@ -434,15 +453,18 @@ def _compute_batch_factors(
|
||||
validation_stats["skipped"] += 1
|
||||
|
||||
# 每处理1个证券更新一次进度,确保实时性
|
||||
if total_securities > 0:
|
||||
if progress and total_securities > 0:
|
||||
current_progress = processed_securities + i + 1
|
||||
progress_percentage = (current_progress / total_securities) * 100
|
||||
from app.ui.progress_state import factor_progress
|
||||
factor_progress.update_progress(
|
||||
current_securities=current_progress,
|
||||
current_batch=batch_index + 1,
|
||||
message=f"处理批次 {batch_index + 1}/{total_batches} - 证券 {current_progress}/{total_securities} ({progress_percentage:.1f}%)"
|
||||
)
|
||||
try:
|
||||
progress.update_progress(
|
||||
current_securities=current_progress,
|
||||
current_batch=batch_index + 1,
|
||||
message=f"处理批次 {batch_index + 1}/{total_batches} - 证券 {current_progress}/{total_securities} ({progress_percentage:.1f}%)",
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.debug("Progress handler update_progress 失败", extra=LOG_EXTRA)
|
||||
progress = None
|
||||
except Exception as e:
|
||||
LOGGER.error(
|
||||
"计算因子失败 ts_code=%s err=%s",
|
||||
@ -453,15 +475,17 @@ def _compute_batch_factors(
|
||||
validation_stats["skipped"] += 1
|
||||
|
||||
# 批次处理完成,更新最终进度
|
||||
if total_securities > 0:
|
||||
if progress and total_securities > 0:
|
||||
final_progress = processed_securities + len(ts_codes)
|
||||
progress_percentage = (final_progress / total_securities) * 100
|
||||
from app.ui.progress_state import factor_progress
|
||||
factor_progress.update_progress(
|
||||
current_securities=final_progress,
|
||||
current_batch=batch_index + 1,
|
||||
message=f"批次 {batch_index + 1}/{total_batches} 处理完成 - 证券 {final_progress}/{total_securities} ({progress_percentage:.1f}%)"
|
||||
)
|
||||
try:
|
||||
progress.update_progress(
|
||||
current_securities=final_progress,
|
||||
current_batch=batch_index + 1,
|
||||
message=f"批次 {batch_index + 1}/{total_batches} 处理完成 - 证券 {final_progress}/{total_securities} ({progress_percentage:.1f}%)",
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.debug("Progress handler update_progress 失败", extra=LOG_EXTRA)
|
||||
|
||||
return batch_results
|
||||
|
||||
|
||||
38
app/features/progress.py
Normal file
38
app/features/progress.py
Normal file
@ -0,0 +1,38 @@
|
||||
"""Optional progress reporting for factor computation."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Protocol
|
||||
|
||||
|
||||
class FactorProgressProtocol(Protocol):
|
||||
"""Protocol describing the optional UI progress handler."""
|
||||
|
||||
def start_calculation(self, total_securities: int, total_batches: int) -> None: ...
|
||||
|
||||
def update_progress(
|
||||
self,
|
||||
current_securities: int,
|
||||
current_batch: int,
|
||||
message: str = "",
|
||||
) -> None: ...
|
||||
|
||||
def complete_calculation(self, message: str = "") -> None: ...
|
||||
|
||||
def error_occurred(self, error_message: str) -> None: ...
|
||||
|
||||
|
||||
_progress_handler: Optional[FactorProgressProtocol] = None
|
||||
|
||||
|
||||
def register_progress_handler(progress: FactorProgressProtocol | None) -> None:
|
||||
"""Register a progress handler (typically provided by the UI layer)."""
|
||||
|
||||
global _progress_handler
|
||||
_progress_handler = progress
|
||||
|
||||
|
||||
def get_progress_handler() -> Optional[FactorProgressProtocol]:
|
||||
"""Return the currently registered progress handler if any."""
|
||||
|
||||
return _progress_handler
|
||||
|
||||
@ -5,6 +5,12 @@ from typing import Optional, Dict, Any
|
||||
import time
|
||||
import streamlit as st
|
||||
|
||||
try:
|
||||
from app.features.progress import register_progress_handler
|
||||
except ImportError: # pragma: no cover - optional dependency
|
||||
def register_progress_handler(handler: object) -> None:
|
||||
return
|
||||
|
||||
|
||||
class FactorProgressState:
|
||||
"""因子计算进度状态管理类"""
|
||||
@ -151,6 +157,7 @@ class FactorProgressState:
|
||||
|
||||
# 全局进度状态实例
|
||||
factor_progress = FactorProgressState()
|
||||
register_progress_handler(factor_progress)
|
||||
|
||||
|
||||
def render_factor_progress() -> None:
|
||||
|
||||
@ -44,13 +44,14 @@ def render_backtest_review() -> None:
|
||||
app_cfg = get_config()
|
||||
default_start, default_end = default_backtest_range(window_days=60)
|
||||
LOGGER.debug(
|
||||
"回测默认参数:start=%s end=%s universe=%s target=%s stop=%s hold_days=%s",
|
||||
"回测默认参数:start=%s end=%s universe=%s target=%s stop=%s hold_days=%s initial_capital=%s",
|
||||
default_start,
|
||||
default_end,
|
||||
"000001.SZ",
|
||||
0.035,
|
||||
-0.015,
|
||||
10,
|
||||
get_config().portfolio.initial_capital,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
|
||||
@ -59,10 +60,24 @@ def render_backtest_review() -> None:
|
||||
start_date = col1.date_input("开始日期", value=default_start, key="bt_start_date")
|
||||
end_date = col2.date_input("结束日期", value=default_end, key="bt_end_date")
|
||||
universe_text = st.text_input("股票列表(逗号分隔)", value="000001.SZ", key="bt_universe")
|
||||
col_target, col_stop, col_hold = st.columns(3)
|
||||
col_target, col_stop, col_hold, col_cap = st.columns(4)
|
||||
target = col_target.number_input("目标收益(例:0.035 表示 3.5%)", value=0.035, step=0.005, format="%.3f", key="bt_target")
|
||||
stop = col_stop.number_input("止损收益(例:-0.015 表示 -1.5%)", value=-0.015, step=0.005, format="%.3f", key="bt_stop")
|
||||
hold_days = col_hold.number_input("持有期(交易日)", value=10, step=1, key="bt_hold_days")
|
||||
initial_capital = col_cap.number_input(
|
||||
"组合初始资金",
|
||||
value=float(app_cfg.portfolio.initial_capital),
|
||||
step=100000.0,
|
||||
format="%.0f",
|
||||
key="bt_initial_capital",
|
||||
)
|
||||
initial_capital = max(0.0, float(initial_capital))
|
||||
backtest_params = {
|
||||
"target": float(target),
|
||||
"stop": float(stop),
|
||||
"hold_days": int(hold_days),
|
||||
"initial_capital": initial_capital,
|
||||
}
|
||||
structure_options = [item.value for item in GameStructure]
|
||||
selected_structure_values = st.multiselect(
|
||||
"选择博弈框架",
|
||||
@ -74,13 +89,14 @@ def render_backtest_review() -> None:
|
||||
selected_structure_values = [GameStructure.REPEATED.value]
|
||||
selected_structures = [GameStructure(value) for value in selected_structure_values]
|
||||
LOGGER.debug(
|
||||
"当前回测表单输入:start=%s end=%s universe_text=%s target=%.3f stop=%.3f hold_days=%s",
|
||||
"当前回测表单输入:start=%s end=%s universe_text=%s target=%.3f stop=%.3f hold_days=%s initial_capital=%.2f",
|
||||
start_date,
|
||||
end_date,
|
||||
universe_text,
|
||||
target,
|
||||
stop,
|
||||
hold_days,
|
||||
initial_capital,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
|
||||
@ -134,13 +150,14 @@ def render_backtest_review() -> None:
|
||||
try:
|
||||
universe = [code.strip() for code in universe_text.split(',') if code.strip()]
|
||||
LOGGER.info(
|
||||
"回测参数:start=%s end=%s universe=%s target=%s stop=%s hold_days=%s",
|
||||
"回测参数:start=%s end=%s universe=%s target=%s stop=%s hold_days=%s initial_capital=%.2f",
|
||||
start_date,
|
||||
end_date,
|
||||
universe,
|
||||
target,
|
||||
stop,
|
||||
hold_days,
|
||||
initial_capital,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
backtest_cfg = BtConfig(
|
||||
@ -149,11 +166,7 @@ def render_backtest_review() -> None:
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
universe=universe,
|
||||
params={
|
||||
"target": target,
|
||||
"stop": stop,
|
||||
"hold_days": int(hold_days),
|
||||
},
|
||||
params=dict(backtest_params),
|
||||
game_structures=selected_structures,
|
||||
)
|
||||
result = run_backtest(backtest_cfg, decision_callback=_decision_callback)
|
||||
@ -665,11 +678,7 @@ def render_backtest_review() -> None:
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
universe=universe_env,
|
||||
params={
|
||||
"target": target,
|
||||
"stop": stop,
|
||||
"hold_days": int(hold_days),
|
||||
},
|
||||
params=dict(backtest_params),
|
||||
method=app_cfg.decision_method,
|
||||
game_structures=selected_structures,
|
||||
)
|
||||
@ -942,11 +951,7 @@ def render_backtest_review() -> None:
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
universe=universe_env,
|
||||
params={
|
||||
"target": target,
|
||||
"stop": stop,
|
||||
"hold_days": int(hold_days),
|
||||
},
|
||||
params=dict(backtest_params),
|
||||
method=app_cfg.decision_method,
|
||||
game_structures=selected_structures,
|
||||
)
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
"""投资池与仓位概览页面。"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, timedelta
|
||||
|
||||
import pandas as pd
|
||||
import plotly.express as px
|
||||
import streamlit as st
|
||||
@ -106,28 +108,44 @@ def render_pool_overview() -> None:
|
||||
|
||||
st.caption("数据来源:agent_utils、investment_pool、portfolio_positions、portfolio_trades、portfolio_snapshots。")
|
||||
|
||||
default_compare_a = latest_date or date.today()
|
||||
default_compare_b = default_compare_a - timedelta(days=1)
|
||||
if default_compare_b > default_compare_a:
|
||||
default_compare_b = default_compare_a
|
||||
compare_col1, compare_col2 = st.columns(2)
|
||||
compare_date1 = compare_col1.date_input(
|
||||
"日志对比日期 A",
|
||||
value=default_compare_a,
|
||||
key="pool_compare_date_a",
|
||||
)
|
||||
compare_date2 = compare_col2.date_input(
|
||||
"日志对比日期 B",
|
||||
value=default_compare_b,
|
||||
key="pool_compare_date_b",
|
||||
)
|
||||
|
||||
if st.button("执行对比", type="secondary"):
|
||||
with st.spinner("执行日志对比分析中..."):
|
||||
try:
|
||||
with db_session(read_only=True) as conn:
|
||||
query_date1 = f"{compare_date1.isoformat()}T00:00:00Z" # type: ignore[name-defined]
|
||||
query_date2 = f"{compare_date1.isoformat()}T23:59:59Z" # type: ignore[name-defined]
|
||||
query_date1 = f"{compare_date1.isoformat()}T00:00:00Z"
|
||||
query_date2 = f"{compare_date1.isoformat()}T23:59:59Z"
|
||||
logs1 = conn.execute(
|
||||
"SELECT level, COUNT(*) as count FROM run_log WHERE ts BETWEEN ? AND ? GROUP BY level",
|
||||
(query_date1, query_date2),
|
||||
).fetchall()
|
||||
|
||||
query_date3 = f"{compare_date2.isoformat()}T00:00:00Z" # type: ignore[name-defined]
|
||||
query_date4 = f"{compare_date2.isoformat()}T23:59:59Z" # type: ignore[name-defined]
|
||||
query_date3 = f"{compare_date2.isoformat()}T00:00:00Z"
|
||||
query_date4 = f"{compare_date2.isoformat()}T23:59:59Z"
|
||||
logs2 = conn.execute(
|
||||
"SELECT level, COUNT(*) as count FROM run_log WHERE ts BETWEEN ? AND ? GROUP BY level",
|
||||
(query_date3, query_date4),
|
||||
).fetchall()
|
||||
|
||||
df1 = pd.DataFrame(logs1, columns=["level", "count"])
|
||||
df1["date"] = compare_date1.strftime("%Y-%m-%d") # type: ignore[name-defined]
|
||||
df1["date"] = compare_date1.strftime("%Y-%m-%d")
|
||||
df2 = pd.DataFrame(logs2, columns=["level", "count"])
|
||||
df2["date"] = compare_date2.strftime("%Y-%m-%d") # type: ignore[name-defined]
|
||||
df2["date"] = compare_date2.strftime("%Y-%m-%d")
|
||||
|
||||
for df in (df1, df2):
|
||||
for col in df.columns:
|
||||
@ -141,13 +159,13 @@ def render_pool_overview() -> None:
|
||||
y="count",
|
||||
color="date",
|
||||
barmode="group",
|
||||
title=f"日志级别分布对比 ({compare_date1} vs {compare_date2})", # type: ignore[name-defined]
|
||||
title=f"日志级别分布对比 ({compare_date1} vs {compare_date2})",
|
||||
)
|
||||
st.plotly_chart(fig, width='stretch')
|
||||
|
||||
st.write("日志统计对比:")
|
||||
date1_str = compare_date1.strftime("%Y%m%d") # type: ignore[name-defined]
|
||||
date2_str = compare_date2.strftime("%Y%m%d") # type: ignore[name-defined]
|
||||
date1_str = compare_date1.strftime("%Y%m%d")
|
||||
date2_str = compare_date2.strftime("%Y%m%d")
|
||||
merged_df = df1.merge(
|
||||
df2,
|
||||
on="level",
|
||||
|
||||
@ -23,6 +23,107 @@ from app.ui.shared import (
|
||||
)
|
||||
|
||||
|
||||
def _fetch_agent_actions(trade_date: str, symbols: List[str]) -> Dict[str, Dict[str, Optional[str]]]:
|
||||
unique_symbols = list(dict.fromkeys(symbols))
|
||||
if not unique_symbols:
|
||||
return {}
|
||||
placeholder = ", ".join("?" for _ in unique_symbols)
|
||||
sql = (
|
||||
f"""
|
||||
SELECT ts_code, agent, action
|
||||
FROM agent_utils
|
||||
WHERE trade_date = ?
|
||||
AND ts_code IN ({placeholder})
|
||||
"""
|
||||
)
|
||||
params: List[object] = [trade_date, *unique_symbols]
|
||||
with db_session(read_only=True) as conn:
|
||||
rows = conn.execute(sql, params).fetchall()
|
||||
mapping: Dict[str, Dict[str, Optional[str]]] = {}
|
||||
for row in rows:
|
||||
ts_code = row["ts_code"]
|
||||
agent = row["agent"]
|
||||
action = row["action"]
|
||||
mapping.setdefault(ts_code, {})[agent] = action
|
||||
return mapping
|
||||
|
||||
|
||||
def _reevaluate_symbols(
|
||||
trade_date_obj: date,
|
||||
symbols: List[str],
|
||||
cfg_id: str,
|
||||
cfg_name: str,
|
||||
) -> List[Dict[str, object]]:
|
||||
unique_symbols = list(dict.fromkeys(symbols))
|
||||
if not unique_symbols:
|
||||
return []
|
||||
trade_date_str = trade_date_obj.isoformat()
|
||||
before_map = _fetch_agent_actions(trade_date_str, unique_symbols)
|
||||
engine_params: Dict[str, object] = {}
|
||||
target_val = st.session_state.get("bt_target")
|
||||
if target_val is not None:
|
||||
try:
|
||||
engine_params["target"] = float(target_val)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
stop_val = st.session_state.get("bt_stop")
|
||||
if stop_val is not None:
|
||||
try:
|
||||
engine_params["stop"] = float(stop_val)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
hold_val = st.session_state.get("bt_hold_days")
|
||||
if hold_val is not None:
|
||||
try:
|
||||
engine_params["hold_days"] = int(hold_val)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
capital_val = st.session_state.get("bt_initial_capital")
|
||||
if capital_val is not None:
|
||||
try:
|
||||
engine_params["initial_capital"] = float(capital_val)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
cfg = BtConfig(
|
||||
id=cfg_id,
|
||||
name=cfg_name,
|
||||
start_date=trade_date_obj,
|
||||
end_date=trade_date_obj,
|
||||
universe=unique_symbols,
|
||||
params=engine_params,
|
||||
)
|
||||
engine = BacktestEngine(cfg)
|
||||
state = PortfolioState(cash=engine.initial_cash)
|
||||
engine.simulate_day(trade_date_obj, state)
|
||||
after_map = _fetch_agent_actions(trade_date_str, unique_symbols)
|
||||
changes: List[Dict[str, object]] = []
|
||||
for code in unique_symbols:
|
||||
before_agents = before_map.get(code, {})
|
||||
after_agents = after_map.get(code, {})
|
||||
for agent, new_action in after_agents.items():
|
||||
old_action = before_agents.get(agent)
|
||||
if new_action != old_action:
|
||||
changes.append(
|
||||
{
|
||||
"代码": code,
|
||||
"代理": agent,
|
||||
"原动作": old_action,
|
||||
"新动作": new_action,
|
||||
}
|
||||
)
|
||||
for agent, old_action in before_agents.items():
|
||||
if agent not in after_agents:
|
||||
changes.append(
|
||||
{
|
||||
"代码": code,
|
||||
"代理": agent,
|
||||
"原动作": old_action,
|
||||
"新动作": None,
|
||||
}
|
||||
)
|
||||
return changes
|
||||
|
||||
|
||||
def render_today_plan() -> None:
|
||||
LOGGER.info("渲染今日计划页面", extra=LOG_EXTRA)
|
||||
st.header("今日计划")
|
||||
@ -176,56 +277,15 @@ def _render_today_plan_symbol_view(
|
||||
raise ValueError(f"无法解析交易日:{trade_date}")
|
||||
|
||||
progress = st.progress(0.0)
|
||||
changes_all: List[Dict[str, object]] = []
|
||||
success_count = 0
|
||||
error_count = 0
|
||||
|
||||
for idx, code in enumerate(symbols, start=1):
|
||||
try:
|
||||
with db_session(read_only=True) as conn:
|
||||
before_rows = conn.execute(
|
||||
"SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?",
|
||||
(trade_date, code),
|
||||
).fetchall()
|
||||
before_map = {row["agent"]: row["action"] for row in before_rows}
|
||||
|
||||
cfg = BtConfig(
|
||||
id="reeval_ui_all",
|
||||
name="UI All Re-eval",
|
||||
start_date=trade_date_obj,
|
||||
end_date=trade_date_obj,
|
||||
universe=[code],
|
||||
params={},
|
||||
)
|
||||
engine = BacktestEngine(cfg)
|
||||
state = PortfolioState()
|
||||
engine.simulate_day(trade_date_obj, state)
|
||||
|
||||
with db_session(read_only=True) as conn:
|
||||
after_rows = conn.execute(
|
||||
"SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?",
|
||||
(trade_date, code),
|
||||
).fetchall()
|
||||
for row in after_rows:
|
||||
agent = row["agent"]
|
||||
new_action = row["action"]
|
||||
old_action = before_map.get(agent)
|
||||
if new_action != old_action:
|
||||
changes_all.append(
|
||||
{"代码": code, "代理": agent, "原动作": old_action, "新动作": new_action}
|
||||
)
|
||||
success_count += 1
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.exception("重评估 %s 失败", code, extra=LOG_EXTRA)
|
||||
error_count += 1
|
||||
|
||||
progress.progress(idx / len(symbols))
|
||||
|
||||
if error_count > 0:
|
||||
st.error(f"一键重评估完成:成功 {success_count} 个,失败 {error_count} 个")
|
||||
else:
|
||||
st.success(f"一键重评估完成:所有 {success_count} 个标的重评估成功")
|
||||
|
||||
progress.progress(0.3 if symbols else 1.0)
|
||||
changes_all = _reevaluate_symbols(
|
||||
trade_date_obj,
|
||||
symbols,
|
||||
"reeval_ui_all",
|
||||
"UI All Re-eval",
|
||||
)
|
||||
progress.progress(1.0)
|
||||
st.success(f"一键重评估完成:共处理 {len(symbols)} 个标的")
|
||||
if changes_all:
|
||||
st.write("检测到以下动作变更:")
|
||||
st.dataframe(pd.DataFrame(changes_all), hide_index=True, width='stretch')
|
||||
@ -579,44 +639,20 @@ def _render_today_plan_symbol_view(
|
||||
pass
|
||||
if trade_date_obj is None:
|
||||
raise ValueError(f"无法解析交易日:{trade_date}")
|
||||
with db_session(read_only=True) as conn:
|
||||
before_rows = conn.execute(
|
||||
"""
|
||||
SELECT agent, action, utils FROM agent_utils
|
||||
WHERE trade_date = ? AND ts_code = ?
|
||||
""",
|
||||
(trade_date, ts_code),
|
||||
).fetchall()
|
||||
before_map = {row["agent"]: (row["action"], row["utils"]) for row in before_rows}
|
||||
cfg = BtConfig(
|
||||
id="reeval_ui",
|
||||
name="UI Re-evaluation",
|
||||
start_date=trade_date_obj,
|
||||
end_date=trade_date_obj,
|
||||
universe=[ts_code],
|
||||
params={},
|
||||
changes = _reevaluate_symbols(
|
||||
trade_date_obj,
|
||||
[ts_code],
|
||||
"reeval_ui",
|
||||
"UI Re-evaluation",
|
||||
)
|
||||
engine = BacktestEngine(cfg)
|
||||
state = PortfolioState()
|
||||
engine.simulate_day(trade_date_obj, state)
|
||||
with db_session(read_only=True) as conn:
|
||||
after_rows = conn.execute(
|
||||
"""
|
||||
SELECT agent, action, utils FROM agent_utils
|
||||
WHERE trade_date = ? AND ts_code = ?
|
||||
""",
|
||||
(trade_date, ts_code),
|
||||
).fetchall()
|
||||
changes = []
|
||||
for row in after_rows:
|
||||
agent = row["agent"]
|
||||
new_action = row["action"]
|
||||
old_action, _old_utils = before_map.get(agent, (None, None))
|
||||
if new_action != old_action:
|
||||
changes.append({"代理": agent, "原动作": old_action, "新动作": new_action})
|
||||
if changes:
|
||||
for change in changes:
|
||||
change.setdefault("代码", ts_code)
|
||||
df_changes = pd.DataFrame(changes)
|
||||
if "代码" in df_changes.columns:
|
||||
df_changes = df_changes[["代码", "代理", "原动作", "新动作"]]
|
||||
st.success("重评估完成,检测到动作变更:")
|
||||
st.dataframe(pd.DataFrame(changes), hide_index=True, width='stretch')
|
||||
st.dataframe(df_changes, hide_index=True, width='stretch')
|
||||
else:
|
||||
st.success("重评估完成,无动作变更。")
|
||||
st.rerun()
|
||||
@ -637,38 +673,15 @@ def _render_today_plan_symbol_view(
|
||||
if trade_date_obj is None:
|
||||
raise ValueError(f"无法解析交易日:{trade_date}")
|
||||
progress = st.progress(0.0)
|
||||
changes_all: List[Dict[str, object]] = []
|
||||
for idx, code in enumerate(batch_symbols, start=1):
|
||||
with db_session(read_only=True) as conn:
|
||||
before_rows = conn.execute(
|
||||
"SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?",
|
||||
(trade_date, code),
|
||||
).fetchall()
|
||||
before_map = {row["agent"]: row["action"] for row in before_rows}
|
||||
cfg = BtConfig(
|
||||
id="reeval_ui_batch",
|
||||
name="UI Batch Re-eval",
|
||||
start_date=trade_date_obj,
|
||||
end_date=trade_date_obj,
|
||||
universe=[code],
|
||||
params={},
|
||||
)
|
||||
engine = BacktestEngine(cfg)
|
||||
state = PortfolioState()
|
||||
engine.simulate_day(trade_date_obj, state)
|
||||
with db_session(read_only=True) as conn:
|
||||
after_rows = conn.execute(
|
||||
"SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?",
|
||||
(trade_date, code),
|
||||
).fetchall()
|
||||
for row in after_rows:
|
||||
agent = row["agent"]
|
||||
new_action = row["action"]
|
||||
old_action = before_map.get(agent)
|
||||
if new_action != old_action:
|
||||
changes_all.append({"代码": code, "代理": agent, "原动作": old_action, "新动作": new_action})
|
||||
progress.progress(idx / max(1, len(batch_symbols)))
|
||||
st.success("批量重评估完成。")
|
||||
progress.progress(0.3 if batch_symbols else 1.0)
|
||||
changes_all = _reevaluate_symbols(
|
||||
trade_date_obj,
|
||||
batch_symbols,
|
||||
"reeval_ui_batch",
|
||||
"UI Batch Re-eval",
|
||||
)
|
||||
progress.progress(1.0)
|
||||
st.success(f"批量重评估完成:共处理 {len(batch_symbols)} 个标的")
|
||||
if changes_all:
|
||||
st.dataframe(pd.DataFrame(changes_all), hide_index=True, width='stretch')
|
||||
st.rerun()
|
||||
|
||||
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
import re
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
@ -577,8 +578,7 @@ class DataBroker:
|
||||
Returns:
|
||||
新闻数据列表,包含sentiment、heat、entities等字段
|
||||
"""
|
||||
# 简化实现:返回模拟数据
|
||||
# 在实际应用中,这里应该查询新闻数据库
|
||||
# TODO: 使用真实新闻数据库替换随机生成的占位数据
|
||||
return [
|
||||
{
|
||||
"sentiment": np.random.uniform(-1, 1),
|
||||
@ -597,8 +597,7 @@ class DataBroker:
|
||||
Returns:
|
||||
行业代码或名称,找不到时返回None
|
||||
"""
|
||||
# 简化实现:返回模拟行业
|
||||
# 在实际应用中,这里应该查询股票行业信息
|
||||
# TODO: 替换为真实行业映射逻辑(当前仅为占位数据)
|
||||
industry_mapping = {
|
||||
"000001.SZ": "银行",
|
||||
"000002.SZ": "房地产",
|
||||
@ -617,8 +616,7 @@ class DataBroker:
|
||||
Returns:
|
||||
行业情绪得分,找不到时返回None
|
||||
"""
|
||||
# 简化实现:返回模拟情绪得分
|
||||
# 在实际应用中,这里应该基于行业新闻计算情绪
|
||||
# TODO: 接入行业情绪数据源,当前随机值仅用于占位显示
|
||||
return np.random.uniform(-1, 1)
|
||||
|
||||
def get_industry_stocks(self, industry: str) -> List[str]:
|
||||
@ -630,8 +628,7 @@ class DataBroker:
|
||||
Returns:
|
||||
同行业股票代码列表
|
||||
"""
|
||||
# 简化实现:返回模拟股票列表
|
||||
# 在实际应用中,这里应该查询行业股票列表
|
||||
# TODO: 使用实际行业成分数据替换占位列表
|
||||
industry_stocks = {
|
||||
"银行": ["000001.SZ", "002142.SZ", "600036.SH"],
|
||||
"房地产": ["000002.SZ", "000402.SZ", "600048.SH"],
|
||||
@ -653,10 +650,33 @@ class DataBroker:
|
||||
|
||||
if not _is_safe_identifier(table):
|
||||
return False
|
||||
query = (
|
||||
f"SELECT 1 FROM {table} WHERE ts_code = ? AND {where_clause} LIMIT 1"
|
||||
)
|
||||
bind_params = (ts_code, *params)
|
||||
parsed_date = _parse_trade_date(trade_date)
|
||||
trade_key = parsed_date.strftime("%Y%m%d") if parsed_date else str(trade_date)
|
||||
|
||||
if auto_refresh and parsed_date and self.check_data_availability(trade_key, {table}):
|
||||
self._trigger_background_refresh(trade_key)
|
||||
if hasattr(time, "sleep"):
|
||||
time.sleep(0.5)
|
||||
|
||||
if table == "suspend":
|
||||
query = (
|
||||
"SELECT 1 FROM suspend "
|
||||
"WHERE ts_code = ? "
|
||||
"AND suspend_date <= ? "
|
||||
"AND (resume_date IS NULL OR resume_date = '' OR resume_date > ?) "
|
||||
"LIMIT 1"
|
||||
)
|
||||
bind_params = (ts_code, trade_key, trade_key)
|
||||
else:
|
||||
clauses = ["ts_code = ?"]
|
||||
bind_params_list: List[object] = [ts_code]
|
||||
clause_text = (where_clause or "").strip()
|
||||
if clause_text:
|
||||
clauses.append(clause_text)
|
||||
bind_params_list.extend(params)
|
||||
query = f"SELECT 1 FROM {table} WHERE {' AND '.join(clauses)} LIMIT 1"
|
||||
bind_params = tuple(bind_params_list)
|
||||
|
||||
try:
|
||||
with db_session(read_only=True) as conn:
|
||||
try:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user