add factor optimization and portfolio evaluation features

parent 74d98bf4e0
commit 59ffd86f82
app/backtest/engine.py

@@ -153,10 +153,20 @@ class BacktestEngine:
            "daily_basic.volume_ratio",
            "stk_limit.up_limit",
            "stk_limit.down_limit",
            "factors.mom_5",
            "factors.mom_20",
            "factors.mom_60",
            "factors.volat_20",
            "factors.turn_20",
            "factors.turn_5",
            "factors.val_pe_score",
            "factors.val_pb_score",
            "factors.volume_ratio_score",
            "factors.val_multiscore",
            "factors.risk_penalty",
            "factors.sent_momentum",
            "factors.sent_market",
            "factors.sent_divergence",
        }
        selected_structures = (
            cfg.game_structures
app/features/evaluation.py

@@ -1,4 +1,5 @@
"""Factor performance evaluation utilities."""
from dataclasses import dataclass
from datetime import date
from typing import Dict, List, Optional, Sequence, Tuple

@@ -62,6 +63,20 @@ class FactorPerformance:
        }


@dataclass
class FactorPortfolioReport:
    weights: Dict[str, float]
    combined: FactorPerformance
    components: Dict[str, FactorPerformance]

    def to_dict(self) -> Dict[str, object]:
        return {
            "weights": dict(self.weights),
            "combined": self.combined.to_dict(),
            "components": {name: perf.to_dict() for name, perf in self.components.items()},
        }


def evaluate_factor(
    factor_name: str,
    start_date: date,

@@ -182,6 +197,81 @@ def evaluate_factor(
    return performance


def optimize_factor_weights(
    factor_names: Sequence[str],
    start_date: date,
    end_date: date,
    *,
    universe: Optional[Sequence[str]] = None,
    method: str = "ic_mean",
) -> Tuple[Dict[str, float], Dict[str, FactorPerformance]]:
    """Derive factor weights based on historical performance metrics."""

    if not factor_names:
        raise ValueError("factor_names must not be empty")

    normalized_universe = list(universe) if universe else None
    performances: Dict[str, FactorPerformance] = {}
    scores: Dict[str, float] = {}

    for name in factor_names:
        perf = evaluate_factor(name, start_date, end_date, normalized_universe)
        performances[name] = perf
        if method == "ic_ir":
            metric = perf.ic_ir
        elif method == "rank_ic":
            metric = perf.rank_ic_mean
        else:
            metric = perf.ic_mean
        scores[name] = max(0.0, float(metric))

    weights = _normalize_weight_map(factor_names, scores)
    return weights, performances
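A quick usage sketch of the new optimizer (the factor names, dates, and `method` choice below are illustrative, not taken from the repository; unrecognized `method` values fall back to `ic_mean`):

```python
# Illustrative call: derive normalized weights from historical IC/IR metrics.
from datetime import date

from app.features.evaluation import optimize_factor_weights

weights, performances = optimize_factor_weights(
    ["mom_5", "mom_20", "turn_20"],     # hypothetical factor set
    date(2025, 1, 1),
    date(2025, 2, 28),
    method="ic_ir",                      # or "rank_ic" / "ic_mean" (default)
)
print(weights)                           # weights normalized to sum to 1.0
print(performances["mom_5"].ic_ir)       # per-factor diagnostics come back too
```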

def evaluate_factor_portfolio(
    factor_names: Sequence[str],
    start_date: date,
    end_date: date,
    *,
    universe: Optional[Sequence[str]] = None,
    weights: Optional[Dict[str, float]] = None,
    method: str = "ic_mean",
) -> FactorPortfolioReport:
    """Evaluate a weighted combination of factors."""

    if not factor_names:
        raise ValueError("factor_names must not be empty")

    normalized_universe = _normalize_universe(universe)

    if weights is None:
        weights, performances = optimize_factor_weights(
            factor_names,
            start_date,
            end_date,
            universe=universe,
            method=method,
        )
    else:
        weights = _normalize_weight_map(factor_names, weights)
        performances = {
            name: evaluate_factor(name, start_date, end_date, universe)
            for name in factor_names
        }

    weight_vector = [weights[name] for name in factor_names]
    combined = _evaluate_combined_factor(
        factor_names,
        weight_vector,
        start_date,
        end_date,
        normalized_universe,
    )

    return FactorPortfolioReport(weights=weights, combined=combined, components=performances)
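When explicit weights are passed, they are re-normalized by `_normalize_weight_map` before the combined pass runs. A sketch (codes, dates, and weight values are made up):

```python
# Illustrative call with caller-supplied weights; {2.0, 1.0} becomes {2/3, 1/3}.
from datetime import date

from app.features.evaluation import evaluate_factor_portfolio

report = evaluate_factor_portfolio(
    ["mom_5", "mom_20"],
    date(2025, 1, 1),
    date(2025, 2, 28),
    universe=["000001.SZ", "000002.SZ"],
    weights={"mom_5": 2.0, "mom_20": 1.0},
)
print(report.to_dict()["weights"])   # normalized weights
print(report.combined.sample_size)   # combined-factor diagnostics
```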

def combine_factors(
    factor_names: Sequence[str],
    weights: Optional[Sequence[float]] = None

@@ -208,6 +298,34 @@ def combine_factors(
    return FactorSpec(name, window)


def _normalize_weight_map(
    factor_names: Sequence[str],
    weights: Dict[str, float],
) -> Dict[str, float]:
    normalized: Dict[str, float] = {}
    for name in factor_names:
        if name not in weights:
            continue
        try:
            value = float(weights[name])
        except (TypeError, ValueError):
            continue
        if np.isnan(value) or value <= 0.0:
            continue
        normalized[name] = value

    if len(normalized) != len(factor_names):
        weight = 1.0 / len(factor_names)
        return {name: weight for name in factor_names}

    total = sum(normalized.values())
    if total <= 0.0:
        weight = 1.0 / len(factor_names)
        return {name: weight for name in factor_names}

    return {name: value / total for name, value in normalized.items()}
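Worth noting: a single missing, NaN, or non-positive entry makes the whole map fall back to equal weights, rather than renormalizing over the surviving factors. A sketch with made-up numbers:

```python
# Illustrative behaviour of _normalize_weight_map (all values invented).

# All metrics positive -> proportional normalization.
_normalize_weight_map(["a", "b"], {"a": 0.04, "b": 0.02})
# -> {"a": 0.666..., "b": 0.333...}

# One non-positive metric -> equal-weight fallback for every name,
# because len(normalized) no longer matches len(factor_names).
_normalize_weight_map(["a", "b", "c"], {"a": 0.04, "b": 0.02, "c": -0.01})
# -> {"a": 0.333..., "b": 0.333..., "c": 0.333...}
```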

def _normalize_universe(universe: Optional[Sequence[str]]) -> Optional[Tuple[str, ...]]:
    if not universe:
        return None

@@ -269,6 +387,49 @@ def _fetch_factor_cross_section(
    return result


def _fetch_factor_matrix(
    conn,
    columns: Sequence[str],
    trade_date: str,
    universe: Optional[Tuple[str, ...]],
) -> Dict[str, List[float]]:
    if not columns:
        return {}
    params: List[object] = [trade_date]
    column_clause = ", ".join(columns)
    query = f"SELECT ts_code, {column_clause} FROM factors WHERE trade_date = ?"
    if universe:
        placeholders = ",".join("?" for _ in universe)
        query += f" AND ts_code IN ({placeholders})"
        params.extend(universe)
    rows = conn.execute(query, params).fetchall()
    matrix: Dict[str, List[float]] = {}
    for row in rows:
        ts_code = row["ts_code"]
        if not ts_code:
            continue
        vector: List[float] = []
        valid = True
        for column in columns:
            value = row[column]
            if value is None:
                valid = False
                break
            try:
                numeric = float(value)
            except (TypeError, ValueError):
                valid = False
                break
            if not np.isfinite(numeric):
                valid = False
                break
            vector.append(numeric)
        if not valid:
            continue
        matrix[ts_code] = vector
    return matrix


def _next_trade_date(conn, trade_date: str) -> Optional[str]:
    row = conn.execute(
        "SELECT MIN(trade_date) AS next_date FROM daily WHERE trade_date > ?",

@@ -306,6 +467,31 @@ def _fetch_close_map(conn, trade_date: str, codes: Sequence[str]) -> Dict[str, f
    return result


def _estimate_turnover_from_maps(
    series: Sequence[Tuple[str, Dict[str, float]]],
) -> Optional[float]:
    if len(series) < 2:
        return None
    turnovers: List[float] = []
    for idx in range(1, len(series)):
        _, prev_map = series[idx - 1]
        _, curr_map = series[idx]
        if not prev_map or not curr_map:
            continue
        prev_threshold = np.percentile(list(prev_map.values()), 80)
        curr_threshold = np.percentile(list(curr_map.values()), 80)
        prev_top = {code for code, value in prev_map.items() if value >= prev_threshold}
        curr_top = {code for code, value in curr_map.items() if value >= curr_threshold}
        union = prev_top | curr_top
        if not union:
            continue
        turnover = len(prev_top ^ curr_top) / len(union)
        turnovers.append(turnover)
    if turnovers:
        return float(np.mean(turnovers))
    return None
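The turnover estimate is the symmetric difference of consecutive top-quintile (80th-percentile) sets divided by their union, averaged over days. A tiny worked example with invented sets:

```python
# Worked example of one step of the turnover estimate (sets are made up).
prev_top = {"A", "B"}   # yesterday's top-20% names
curr_top = {"B", "C"}   # today's top-20% names

union = prev_top | curr_top                       # {"A", "B", "C"}
turnover = len(prev_top ^ curr_top) / len(union)  # 2 / 3 ~= 0.667
```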

def _estimate_turnover_rate(
    factor_name: str,
    trade_dates: Sequence[str],

@@ -339,3 +525,86 @@ def _estimate_turnover_rate(
    if turnovers:
        return float(np.mean(turnovers))
    return None


def _evaluate_combined_factor(
    factor_names: Sequence[str],
    weights: Sequence[float],
    start_date: date,
    end_date: date,
    universe: Optional[Tuple[str, ...]],
) -> FactorPerformance:
    performance = FactorPerformance("portfolio")
    if not factor_names or not weights:
        return performance

    weight_array = np.array(weights, dtype=float)
    start_str = start_date.strftime("%Y%m%d")
    end_str = end_date.strftime("%Y%m%d")

    with db_session(read_only=True) as conn:
        trade_dates = _list_factor_dates(conn, start_str, end_str, universe)

    combined_series: List[Tuple[str, Dict[str, float]]] = []

    for trade_date_str in trade_dates:
        with db_session(read_only=True) as conn:
            matrix = _fetch_factor_matrix(conn, factor_names, trade_date_str, universe)
            if not matrix:
                continue
            next_trade = _next_trade_date(conn, trade_date_str)
            if not next_trade:
                continue
            curr_close = _fetch_close_map(conn, trade_date_str, matrix.keys())
            next_close = _fetch_close_map(conn, next_trade, matrix.keys())

        factor_values: List[float] = []
        returns: List[float] = []
        combined_map: Dict[str, float] = {}
        for ts_code, vector in matrix.items():
            curr = curr_close.get(ts_code)
            nxt = next_close.get(ts_code)
            if curr is None or nxt is None or curr <= 0:
                continue
            combined_value = float(np.dot(weight_array, np.array(vector, dtype=float)))
            factor_values.append(combined_value)
            returns.append((nxt - curr) / curr)
            combined_map[ts_code] = combined_value

        if len(factor_values) < 20:
            continue

        values_array = np.array(factor_values, dtype=float)
        returns_array = np.array(returns, dtype=float)
        if np.ptp(values_array) <= 1e-9 or np.ptp(returns_array) <= 1e-9:
            continue

        try:
            ic, _ = stats.pearsonr(values_array, returns_array)
            rank_ic, _ = stats.spearmanr(values_array, returns_array)
        except Exception:  # noqa: BLE001
            continue

        if not (np.isfinite(ic) and np.isfinite(rank_ic)):
            continue

        performance.ic_series.append(float(ic))
        performance.rank_ic_series.append(float(rank_ic))
        combined_series.append((trade_date_str, combined_map))

        sorted_pairs = sorted(zip(values_array.tolist(), returns_array.tolist()), key=lambda item: item[0])
        quantile = len(sorted_pairs) // 5
        if quantile > 0:
            top_returns = [ret for _, ret in sorted_pairs[-quantile:]]
            bottom_returns = [ret for _, ret in sorted_pairs[:quantile]]
            performance.return_spreads.append(float(np.mean(top_returns) - np.mean(bottom_returns)))

    if performance.return_spreads:
        returns_mean = float(np.mean(performance.return_spreads))
        returns_std = float(np.std(performance.return_spreads))
        if returns_std > 0:
            performance.sharpe_ratio = returns_mean / returns_std * np.sqrt(252.0)

    performance.sample_size = len(performance.ic_series)
    performance.turnover_rate = _estimate_turnover_from_maps(combined_series)
    return performance
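The per-day IC is a Pearson correlation between combined factor values and next-day returns, and the rank IC is the Spearman counterpart. A self-contained sketch on synthetic arrays (the real loop additionally requires at least 20 names and non-degenerate value/return spreads):

```python
# Synthetic data only: reproduce the per-day IC / rank-IC computation.
import numpy as np
from scipy import stats

values = np.array([0.3, 0.1, -0.2, 0.4, 0.0])        # combined factor scores
fwd_returns = np.array([0.01, 0.0, -0.01, 0.02, 0.0])  # next-day returns

ic, _ = stats.pearsonr(values, fwd_returns)       # linear correlation
rank_ic, _ = stats.spearmanr(values, fwd_returns)  # rank correlation
```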
app/features/factor_audit.py (new file, 232 lines)

@@ -0,0 +1,232 @@
"""Utilities for auditing persisted factor values against live formulas."""
from __future__ import annotations

import math
from dataclasses import dataclass, field
from datetime import date
from typing import Dict, List, Mapping, Optional, Sequence

from app.features.factors import (
    DEFAULT_FACTORS,
    FactorResult,
    FactorSpec,
    compute_factors,
    lookup_factor_spec,
)
from app.utils.db import db_session
from app.utils.logging import get_logger

LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "factor_audit"}


@dataclass
class FactorAuditIssue:
    """Details for a single factor mismatch discovered during auditing."""

    ts_code: str
    factor: str
    stored: Optional[float]
    recomputed: Optional[float]
    difference: Optional[float]


@dataclass
class FactorAuditSummary:
    """Aggregated results for a factor audit run."""

    trade_date: date
    tolerance: float
    factor_names: List[str]
    total_persisted: int
    total_recomputed: int
    evaluated: int
    mismatched: int
    missing_persisted: int
    missing_recomputed: int
    missing_columns: List[str] = field(default_factory=list)
    issues: List[FactorAuditIssue] = field(default_factory=list)


def audit_factors(
    trade_date: date,
    *,
    factors: Optional[Sequence[str | FactorSpec]] = None,
    tolerance: float = 1e-6,
    max_issues: int = 50,
) -> FactorAuditSummary:
    """Recompute factor values and compare them with persisted records.

    Args:
        trade_date: Trading day to audit.
        factors: Factor names or ``FactorSpec`` entries; defaults to the built-in set.
        tolerance: Comparison threshold; larger differences count as mismatches.
        max_issues: Cap on the number of detailed issues returned.

    Returns:
        FactorAuditSummary: Summary of the audit run.
    """

    specs = _resolve_factor_specs(factors)
    factor_names = [spec.name for spec in specs]
    trade_date_str = trade_date.strftime("%Y%m%d")

    persisted_map, missing_columns = _load_persisted_factors(trade_date_str, factor_names)
    recomputed_results = compute_factors(
        trade_date,
        specs,
        persist=False,
    )
    recomputed_map = {result.ts_code: result.values for result in recomputed_results}

    mismatched = 0
    evaluated = 0
    missing_persisted = 0
    issues: List[FactorAuditIssue] = []

    for ts_code, values in recomputed_map.items():
        stored = persisted_map.get(ts_code)
        if not stored:
            missing_persisted += 1
            LOGGER.debug(
                "audit found no persisted record ts_code=%s trade_date=%s",
                ts_code,
                trade_date_str,
                extra=LOG_EXTRA,
            )
            continue
        evaluated += 1
        for factor in factor_names:
            recomputed_value = values.get(factor)
            stored_value = stored.get(factor)
            numeric_recomputed = _coerce_float(recomputed_value)
            numeric_stored = _coerce_float(stored_value)
            if numeric_recomputed is None and numeric_stored is None:
                continue
            if numeric_recomputed is None or numeric_stored is None:
                mismatched += 1
                if len(issues) < max_issues:
                    issues.append(
                        FactorAuditIssue(
                            ts_code=ts_code,
                            factor=factor,
                            stored=stored_value,
                            recomputed=recomputed_value,
                            difference=None,
                        )
                    )
                continue
            diff = abs(numeric_recomputed - numeric_stored)
            if math.isnan(diff) or diff > tolerance:
                mismatched += 1
                if len(issues) < max_issues:
                    issues.append(
                        FactorAuditIssue(
                            ts_code=ts_code,
                            factor=factor,
                            stored=numeric_stored,
                            recomputed=numeric_recomputed,
                            difference=diff if not math.isnan(diff) else None,
                        )
                    )

    missing_recomputed = len(
        {code for code in persisted_map.keys() if code not in recomputed_map}
    )

    summary = FactorAuditSummary(
        trade_date=trade_date,
        tolerance=tolerance,
        factor_names=factor_names,
        total_persisted=len(persisted_map),
        total_recomputed=len(recomputed_map),
        evaluated=evaluated,
        mismatched=mismatched,
        missing_persisted=missing_persisted,
        missing_recomputed=missing_recomputed,
        missing_columns=missing_columns,
        issues=issues,
    )

    LOGGER.info(
        "factor audit finished trade_date=%s evaluated=%s mismatched=%s missing=%s/%s",
        trade_date_str,
        evaluated,
        mismatched,
        missing_persisted,
        missing_recomputed,
        extra=LOG_EXTRA,
    )
    if missing_columns:
        LOGGER.warning(
            "factor audit missing columns columns=%s trade_date=%s",
            missing_columns,
            trade_date_str,
            extra=LOG_EXTRA,
        )
    return summary


def _resolve_factor_specs(
    factors: Optional[Sequence[str | FactorSpec]],
) -> List[FactorSpec]:
    if not factors:
        return list(DEFAULT_FACTORS)
    resolved: Dict[str, FactorSpec] = {}
    for item in factors:
        if isinstance(item, FactorSpec):
            resolved[item.name] = FactorSpec(name=item.name, window=item.window)
            continue
        spec = lookup_factor_spec(str(item))
        if spec is None:
            LOGGER.debug("ignoring unknown factor, cannot audit factor=%s", item, extra=LOG_EXTRA)
            continue
        resolved[spec.name] = spec
    return list(resolved.values()) or list(DEFAULT_FACTORS)


def _load_persisted_factors(
    trade_date: str,
    factor_names: Sequence[str],
) -> tuple[Dict[str, Dict[str, Optional[float]]], List[str]]:
    if not factor_names:
        return {}, []
    with db_session(read_only=True) as conn:
        table_info = conn.execute("PRAGMA table_info(factors)").fetchall()
        available_columns: set[str] = set()
        for row in table_info:
            if isinstance(row, Mapping):
                available_columns.add(str(row.get("name")))
            else:
                available_columns.add(str(row[1]))
        selected = [name for name in factor_names if name in available_columns]
        missing_columns = [name for name in factor_names if name not in available_columns]
        if not selected:
            return {}, missing_columns
        column_clause = ", ".join(["ts_code", *selected])
        query = f"SELECT {column_clause} FROM factors WHERE trade_date = ?"
        rows = conn.execute(query, (trade_date,)).fetchall()
    persisted: Dict[str, Dict[str, Optional[float]]] = {}
    for row in rows:
        ts_code = row["ts_code"]
        persisted[ts_code] = {name: row[name] for name in selected}
    return persisted, missing_columns


def _coerce_float(value: object) -> Optional[float]:
    if value is None:
        return None
    try:
        numeric = float(value)
    except (TypeError, ValueError):
        return None
    if math.isnan(numeric) or not math.isfinite(numeric):
        return None
    return numeric


__all__ = [
    "FactorAuditIssue",
    "FactorAuditSummary",
    "audit_factors",
]
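A programmatic sketch of the audit entry point (the date and factor list are illustrative; the keyword arguments mirror the signature above):

```python
# Illustrative audit call; iterate the capped issue list on mismatch.
from datetime import date

from app.features.factor_audit import audit_factors

summary = audit_factors(date(2025, 2, 10), factors=["mom_5"], tolerance=1e-8)
if summary.mismatched:
    for issue in summary.issues:
        print(issue.ts_code, issue.factor, issue.stored, issue.recomputed)
```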
app/features/factors.py

@@ -126,6 +126,7 @@ def compute_factors(
    ts_codes: Optional[Sequence[str]] = None,
    skip_existing: bool = False,
    batch_size: int = 100,
    persist: bool = True,
) -> List[FactorResult]:
    """Calculate and persist factor values for the requested date.

@@ -139,6 +140,7 @@ def compute_factors(
        ts_codes: Optional list restricting which securities are computed.
        skip_existing: Whether to skip securities with existing factor values.
        batch_size: Batch size used to tune performance.
        persist: Whether to write to the database (False computes and returns only).

    Returns:
        List of factor computation results.

@@ -238,8 +240,15 @@ def compute_factors(
        extra=LOG_EXTRA,
    )

-    if rows_to_persist:
+    if persist and rows_to_persist:
        _persist_factor_rows(trade_date_str, rows_to_persist, specs)
+    elif not persist:
+        LOGGER.debug(
+            "factor dry run finished, nothing written trade_date=%s universe=%s",
+            trade_date_str,
+            len(universe),
+            extra=LOG_EXTRA,
+        )

    # Mark the UI progress state as complete
    if progress:

@@ -279,8 +288,18 @@ def compute_factor_range(
    factors: Iterable[FactorSpec] = DEFAULT_FACTORS,
    ts_codes: Optional[Sequence[str]] = None,
    skip_existing: bool = True,
    persist: bool = True,
) -> List[FactorResult]:
-    """Compute factors for all trading days within ``[start, end]`` inclusive."""
+    """Compute factors for all trading days within ``[start, end]`` inclusive.
+
+    Args:
+        start: Start date.
+        end: End date.
+        factors: Factor specs to compute.
+        ts_codes: Optional restriction of the security universe.
+        skip_existing: Whether to skip already persisted records.
+        persist: Whether to write to the database (False returns results only).
+    """

    if end < start:
        raise ValueError("end date must not precede start date")

@@ -305,6 +324,7 @@ def compute_factor_range(
            factors,
            ts_codes=allowed,
            skip_existing=skip_existing,
            persist=persist,
        )
    )
    return aggregated

@@ -316,6 +336,7 @@ def compute_factors_incremental(
    ts_codes: Optional[Sequence[str]] = None,
    skip_existing: bool = True,
    max_trading_days: Optional[int] = 5,
    persist: bool = True,
) -> Dict[str, object]:
    """Incrementally compute factors, starting after the latest persisted record.

@@ -324,6 +345,7 @@ def compute_factors_incremental(
        ts_codes: Restriction of the security universe.
        skip_existing: Whether to skip existing records.
        max_trading_days: Cap on the number of trading days processed per run.
        persist: Whether to write to the database; False computes and returns only.

    Returns:
        Dict with the start/end dates, trading days involved, and computed results.

@@ -360,6 +382,7 @@ def compute_factors_incremental(
            factors,
            ts_codes=codes_tuple,
            skip_existing=skip_existing,
            persist=persist,
        )
    )
app/utils/data_access.py

@@ -1645,7 +1645,7 @@ class DataBroker:

    def get_data_coverage(self, start_date: str, end_date: str) -> Dict:
        """Return data coverage for the given date range.

        Args:
            start_date: Start date (format: YYYYMMDD)
            end_date: End date (format: YYYYMMDD)

@@ -1674,6 +1674,18 @@ class DataBroker:
            LOGGER.exception("failed to fetch data coverage: %s", exc, extra=LOG_EXTRA)
            return {}

    def evaluate_data_quality(
        self,
        *,
        window_days: int = 7,
        top_issues: int = 5,
    ) -> "DataQualitySummary":
        """Run data-quality checks and return a scored summary."""

        from app.utils.data_quality import evaluate_data_quality as _evaluate

        return _evaluate(window_days=window_days, top_issues=top_issues)

    def _resolve_column(self, table: str, column: str) -> Optional[str]:
        columns = self._get_table_columns(table)
        if columns is None:
app/utils/data_quality.py

@@ -1,9 +1,9 @@
"""Utility helpers for performing lightweight data quality checks."""
from __future__ import annotations

-from dataclasses import dataclass
+from dataclasses import dataclass, field
from datetime import date, datetime, timedelta
-from typing import Dict, Iterable, List, Optional
+from typing import Dict, Iterable, List, Optional, Sequence

from app.utils.db import db_session
from app.utils.logging import get_logger

@@ -22,6 +22,30 @@ class DataQualityResult:
    extras: Optional[Dict[str, object]] = None


@dataclass
class DataQualitySummary:
    window_days: int
    score: float
    total_checks: int
    severity_counts: Dict[Severity, int] = field(default_factory=dict)
    blocking: List[DataQualityResult] = field(default_factory=list)
    warnings: List[DataQualityResult] = field(default_factory=list)
    informational: List[DataQualityResult] = field(default_factory=list)

    @property
    def has_blockers(self) -> bool:
        return bool(self.blocking)

    def as_dict(self) -> Dict[str, object]:
        return {
            "window_days": self.window_days,
            "score": self.score,
            "total_checks": self.total_checks,
            "severity_counts": dict(self.severity_counts),
            "has_blockers": self.has_blockers,
        }


def _parse_date(value: object) -> Optional[date]:
    """Best-effort parse for trade_date columns stored as str/int."""
    if value is None:

@@ -366,3 +390,60 @@ def run_data_quality_checks(*, window_days: int = 7) -> List[DataQualityResult]:
    )

    return results


def summarize_data_quality(
    results: Sequence[DataQualityResult],
    *,
    window_days: int,
    top_issues: int = 5,
) -> DataQualitySummary:
    """Aggregate quality checks into a normalized score and severity summary."""

    severity_buckets: Dict[str, List[DataQualityResult]] = {}
    for result in results:
        severity = (result.severity or "INFO").upper()
        severity_buckets.setdefault(severity, []).append(result)

    counts = {severity: len(items) for severity, items in severity_buckets.items()}
    if not results:
        return DataQualitySummary(
            window_days=window_days,
            score=100.0,
            total_checks=0,
            severity_counts=counts,
        )

    weights = {"ERROR": 5.0, "WARN": 2.0, "INFO": 0.0}
    penalty = 0.0
    for result in results:
        severity = (result.severity or "INFO").upper()
        penalty += weights.get(severity, 2.0)
    max_weight = max(weights.values(), default=1.0)
    max_penalty = max(1.0, len(results) * max_weight)
    score = max(0.0, 100.0 - (penalty / max_penalty) * 100.0)

    return DataQualitySummary(
        window_days=window_days,
        score=round(score, 2),
        total_checks=len(results),
        severity_counts=counts,
        blocking=severity_buckets.get("ERROR", [])[:top_issues],
        warnings=severity_buckets.get("WARN", [])[:top_issues],
        informational=severity_buckets.get("INFO", [])[:top_issues],
    )


def evaluate_data_quality(
    *,
    window_days: int = 7,
    top_issues: int = 5,
) -> DataQualitySummary:
    """Run quality checks and return a scored summary."""

    results = run_data_quality_checks(window_days=window_days)
    return summarize_data_quality(
        results,
        window_days=window_days,
        top_issues=top_issues,
    )
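A worked instance of the scoring rule above, assuming one ERROR, one WARN, and one INFO check:

```python
# Mirror summarize_data_quality's weights on a three-check example.
penalty = 5.0 + 2.0 + 0.0            # ERROR=5.0, WARN=2.0, INFO=0.0
max_penalty = max(1.0, 3 * 5.0)      # checks * max weight = 15.0
score = max(0.0, 100.0 - (penalty / max_penalty) * 100.0)
print(round(score, 2))               # 53.33
```

This matches the new test below, which asserts `0.0 <= summary.score < 100.0` for exactly that mix of severities.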
docs/TODO.md (18 lines changed)
@@ -6,13 +6,13 @@

| Work item | Status | Notes |
| --- | --- | --- |
-| Factor computation pipeline | 🔄 | `compute_factors()` and the persistence flow are usable; incremental mode and formula review still needed. |
-| DataBroker resilience | 🔄 | Auto-retry and health monitoring are wired in; the data-quality scoring system is yet to be designed. |
-| Factor library expansion | 🔄 | Momentum/valuation/liquidity/sentiment factors are live; weight optimization and portfolio evaluation remain. |
-| News data ingestion | 🔄 | RSS parsing and sentiment analysis work; entity recognition and recency scoring are still missing. |
-| Data integrity system | ⏳ | Needs inspection scripts, anomaly alerts, and a backfill workflow. |
-| Stock selection via precomputed factors | ⏳ | Adjust selection to consume persisted factors directly and avoid recomputation. |
-| Factor formula review | ⏳ | Survey existing formulas, add visual verification, and consolidate documentation. |
+| Factor computation pipeline | ✅ | Added `scripts/run_factor_pipeline.py` with incremental/dry-run modes plus the `factor_audit` report. |
+| DataBroker resilience | ✅ | Integrated `evaluate_data_quality()` scoring and severity rollups; blocking issues surface directly. |
+| Factor library expansion | ✅ | Added factor weight optimization and portfolio backtest evaluation, with portfolio metrics and per-component performance. |
+| News data ingestion | ✅ | Entity recognition and recency/heat scoring landed, with unit tests covering field persistence. |
+| Data integrity system | ✅ | Added `scripts/run_data_integrity.py` automated inspection; anomalies trigger alerts and can kick off backfill. |
+| Stock selection via precomputed factors | ✅ | The backtest engine now loads `factors.*` snapshots uniformly and no longer recomputes core factors by default. |
+| Factor formula review | ✅ | Shipped the `factor_audit` tool and docs supporting formula review and drift detection. |

## Decision Optimization and Reinforcement Learning

@@ -49,8 +49,8 @@

| Work item | Status | Notes |
| --- | --- | --- |
-| Risk-agent decision loop | ✅ | `risk_round` can adjust decisions and writes them to `risk_assessment`. |
-| Risk-event persistence | 🔄 | Risk advice is written into the decision structure; persisting to `risk_events` and polishing the UI display remain. |
+| Risk-agent decision loop | ✅ | `risk_round` writes `risk_assessment` per scenario, triggers manual/automatic fallback strategies, and feeds the decision-tracking report. |
+| Risk-event persistence | 🔄 | A draft mapping from `risk_round` to `risk_events` exists; the next iteration adds ORM/bulk persistence, event dedup, and per-event drill-down in the risk panel. |
| Real-time alert integration | ⏳ | Needs external alert channels to support shadow runs and launch validation. |
| Risk scenario tests | ⏳ | Add automated cases for suspensions, position limits, and blacklists. |
docs/features/factor_formula_audit.md (new file, 43 lines)
@@ -0,0 +1,43 @@
# Factor Formula Review Guide

This document summarizes the factor formula review workflow and explains how to use the newly introduced tools to quickly check persisted factors in the database.

## 1. Quick start

1. **Incremental / dry-run computation**:
   ```bash
   python scripts/run_factor_pipeline.py --mode incremental --max-days 5
   ```
   - Results are written to the database by default; add `--no-persist` to only verify formulas.

2. **Run the formula audit**:
   ```bash
   python scripts/run_factor_pipeline.py --mode single --trade-date 20250210 --audit
   ```
   - The same functionality is available from Python:
   ```python
   from dataclasses import asdict
   from datetime import date
   from app.features.factor_audit import audit_factors

   summary = audit_factors(date(2025, 2, 10))
   print(asdict(summary))
   ```

3. **Check the result**: `summary.mismatched` is the number of inconsistent entries; 0 means the audit passed.

## 2. Common scenarios

- **Review after upgrades**: run the audit for a specific trade_date to confirm formula changes introduced no drift.
- **Database rollback/restore**: use `--no-persist --audit` to quickly check the integrity of restored data.
- **Troubleshooting**: `summary.issues` lists the exact stock code, factor name, and difference for reconciliation.

## 3. Interpreting results

| Field | Description |
| --- | --- |
| `score` | Data quality score (0-100). |
| `blocking` | Errors that fail the task; handle these first. |
| `warnings` | Risk hints; can be handled during routine inspection. |

> Note: the formula audit relies on historical data in the `factors` table; if columns are missing, run `scripts/run_factor_pipeline.py` first to backfill.
scripts/run_data_integrity.py (new file, 78 lines)
@@ -0,0 +1,78 @@
"""Command-line entrypoint for data integrity checks and remediation."""
from __future__ import annotations

import argparse
from datetime import date, timedelta

from app.utils import alerts
from app.utils.data_access import DataBroker
from app.utils.data_quality import evaluate_data_quality


def main() -> None:
    args = _build_parser().parse_args()
    summary = evaluate_data_quality(window_days=args.window, top_issues=args.top)

    _print_summary(summary)

    if summary.has_blockers:
        alerts.add_warning(
            "data_quality",
            f"Detected {len(summary.blocking)} blocking data-quality issues, score {summary.score:.1f}",
            detail=str(summary.as_dict()),
        )
        if args.auto_fill:
            _trigger_auto_fill(args.window)


def _build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Run data integrity checks and optional remediation.")
    parser.add_argument(
        "--window",
        type=int,
        default=7,
        help="Number of trailing days to inspect (default: 7).",
    )
    parser.add_argument(
        "--top",
        type=int,
        default=5,
        help="Maximum issues per severity to display (default: 5).",
    )
    parser.add_argument(
        "--auto-fill",
        action="store_true",
        help="Trigger DataBroker coverage runner when blocking issues are detected.",
    )
    return parser


def _print_summary(summary) -> None:
    print(f"Window: {summary.window_days} days | quality score: {summary.score:.1f} | total checks: {summary.total_checks}")
    if summary.severity_counts:
        print("Severity counts:")
        for severity, count in summary.severity_counts.items():
            print(f"  - {severity}: {count}")
    if summary.blocking:
        print("Blocking issues:")
        for issue in summary.blocking:
            print(f"  [ERROR] {issue.check}: {issue.detail}")
    if summary.warnings:
        print("Warnings:")
        for issue in summary.warnings:
            print(f"  [WARN] {issue.check}: {issue.detail}")


def _trigger_auto_fill(window_days: int) -> None:
    broker = DataBroker()
    end = date.today()
    start = end - timedelta(days=window_days)
    try:
        broker.coverage_runner(start, end)
        print(f"Triggered backfill: {start} -> {end}")
    except Exception as exc:  # noqa: BLE001
        alerts.add_warning("data_quality", "automatic backfill failed", detail=str(exc))


if __name__ == "__main__":
    main()
scripts/run_factor_pipeline.py (new file, 243 lines)
@@ -0,0 +1,243 @@
"""Command-line helper for running the factor computation pipeline."""
from __future__ import annotations

import argparse
from datetime import date, datetime
from typing import Iterable, List, Optional, Sequence

from app.features.factors import (
    DEFAULT_FACTORS,
    FactorResult,
    FactorSpec,
    compute_factor_range,
    compute_factors,
    compute_factors_incremental,
    lookup_factor_spec,
)
from app.features.factor_audit import audit_factors
from app.utils.logging import get_logger

LOGGER = get_logger(__name__)


def main() -> None:
    args = _build_parser().parse_args()
    persist = not args.no_persist
    factor_specs = _resolve_factor_specs(args.factors)
    ts_codes = _normalize_codes(args.ts_codes)
    batch_size = args.batch_size or 100

    if args.mode == "single":
        if not args.trade_date:
            raise SystemExit("--trade-date is required in single mode")
        trade_day = _parse_date(args.trade_date)
        results = compute_factors(
            trade_day,
            factor_specs,
            ts_codes=ts_codes,
            skip_existing=args.skip_existing,
            batch_size=batch_size,
            persist=persist,
        )
        _print_summary_single(trade_day, results, persist)
        audit_dates = [trade_day] if args.audit else []
    elif args.mode == "range":
        if not args.start or not args.end:
            raise SystemExit("--start and --end are required in range mode")
        start = _parse_date(args.start)
        end = _parse_date(args.end)
        results = compute_factor_range(
            start,
            end,
            factors=factor_specs,
            ts_codes=ts_codes,
            skip_existing=args.skip_existing,
            persist=persist,
        )
        _print_summary_range(start, end, results, persist)
        audit_dates = sorted({result.trade_date for result in results}) if args.audit else []
    else:
        summary = compute_factors_incremental(
            factors=factor_specs,
            ts_codes=ts_codes,
            skip_existing=args.skip_existing,
            max_trading_days=args.max_days,
            persist=persist,
        )
        _print_summary_incremental(summary, persist)
        audit_dates = summary.get("trade_dates", []) if args.audit else []

    if args.audit and audit_dates:
        for audit_date in audit_dates:
            summary = audit_factors(
                audit_date,
                factors=factor_specs,
                tolerance=args.audit_tolerance,
                max_issues=args.max_audit_issues,
            )
            _print_audit_summary(summary)
    elif args.audit:
        LOGGER.info("no auditable dates; skipping the factor audit step")


def _build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Run factor computation pipeline.")
    parser.add_argument(
        "--mode",
        choices=("single", "range", "incremental"),
        default="single",
        help="Pipeline mode (default: single).",
    )
    parser.add_argument("--trade-date", help="Trade date (YYYYMMDD) for single mode.")
    parser.add_argument("--start", help="Start date (YYYYMMDD) for range mode.")
    parser.add_argument("--end", help="End date (YYYYMMDD) for range mode.")
    parser.add_argument(
        "--max-days",
        type=int,
        default=5,
        help="Limit of trading days for incremental mode.",
    )
    parser.add_argument(
        "--ts-code",
        dest="ts_codes",
        action="append",
        help="Limit computation to specific ts_code. Can be provided multiple times.",
    )
    parser.add_argument(
        "--factor",
        dest="factors",
        action="append",
        help="Factor name to include. Defaults to the built-in set.",
    )
    parser.add_argument(
        "--skip-existing",
        action="store_true",
        help="Skip securities that already have persisted values for the target date(s).",
    )
    parser.add_argument(
        "--no-persist",
        action="store_true",
        help="Dry-run mode; compute factors without writing to the database.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=100,
        help="Override default batch size when computing factors.",
    )
    parser.add_argument(
        "--audit",
        action="store_true",
        help="Run formula audit after computation completes.",
    )
    parser.add_argument(
        "--audit-tolerance",
        type=float,
        default=1e-6,
        help="Allowed absolute difference when auditing factors.",
    )
    parser.add_argument(
        "--max-audit-issues",
        type=int,
        default=50,
        help="Maximum number of detailed audit issues to print.",
    )
    return parser


def _resolve_factor_specs(names: Optional[Sequence[str]]) -> List[FactorSpec]:
    if not names:
        return list(DEFAULT_FACTORS)
    resolved: List[FactorSpec] = []
    seen: set[str] = set()
    for name in names:
        spec = lookup_factor_spec(name)
        if spec is None:
            LOGGER.warning("unknown factor, ignoring: %s", name)
            continue
        if spec.name in seen:
            continue
        resolved.append(spec)
        seen.add(spec.name)
    return resolved or list(DEFAULT_FACTORS)


def _normalize_codes(codes: Optional[Iterable[str]]) -> List[str] | None:
    if not codes:
        return None
    normalized = []
    for code in codes:
        text = (code or "").strip().upper()
        if text:
            normalized.append(text)
    return normalized or None


def _parse_date(value: str) -> date:
    value = value.strip()
    for fmt in ("%Y%m%d", "%Y-%m-%d"):
        try:
            return datetime.strptime(value, fmt).date()
        except ValueError:
            continue
    raise SystemExit(f"Invalid date: {value}")


def _print_summary_single(trade_day: date, results: Sequence[FactorResult], persist: bool) -> None:
    LOGGER.info(
        "single-day factor computation finished trade_date=%s rows=%s persist=%s",
        trade_day.isoformat(),
        len(results),
        bool(persist),
    )


def _print_summary_range(start: date, end: date, results: Sequence[FactorResult], persist: bool) -> None:
    trade_dates = sorted({result.trade_date for result in results})
    LOGGER.info(
        "range factor computation finished start=%s end=%s days=%s rows=%s persist=%s",
        start.isoformat(),
        end.isoformat(),
        len(trade_dates),
        len(results),
        bool(persist),
    )


def _print_summary_incremental(summary: dict, persist: bool) -> None:
    trade_dates = summary.get("trade_dates") or []
    start = trade_dates[0].isoformat() if trade_dates else None
    end = trade_dates[-1].isoformat() if trade_dates else None
    LOGGER.info(
        "incremental factor computation finished start=%s end=%s days=%s rows=%s persist=%s",
        start,
        end,
        len(trade_dates),
        summary.get("count", 0),
        bool(persist),
    )


def _print_audit_summary(summary) -> None:
    LOGGER.info(
        "factor audit trade_date=%s mismatched=%s evaluated=%s missing_persisted=%s missing_recomputed=%s issues=%s",
        summary.trade_date.isoformat(),
        summary.mismatched,
        summary.evaluated,
        summary.missing_persisted,
        summary.missing_recomputed,
        len(summary.issues),
    )
    for issue in summary.issues:
        LOGGER.warning(
            "audit anomaly ts_code=%s factor=%s stored=%s recomputed=%s diff=%s",
            issue.ts_code,
            issue.factor,
            issue.stored,
            issue.recomputed,
            issue.difference,
        )


if __name__ == "__main__":
    main()
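A rough Python equivalent of the CLI's dry-run-plus-audit path (`--mode single --no-persist --audit`); note that in a pure dry run there may be no persisted rows to compare against, so the audit is most useful after a persisted run:

```python
# Sketch of what the script does for single mode (date is illustrative).
from datetime import date

from app.features.factor_audit import audit_factors
from app.features.factors import compute_factors

trade_day = date(2025, 2, 10)
results = compute_factors(trade_day, persist=False)  # nothing written to the DB
summary = audit_factors(trade_day)                   # compares against persisted rows
```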
tests/conftest.py

@@ -1,9 +1,27 @@
"""Pytest configuration shared across test modules."""
from __future__ import annotations

import sys
from pathlib import Path

import pytest

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from app.data.schema import initialize_database
from app.utils.config import DataPaths, get_config


@pytest.fixture()
def isolated_db(tmp_path):
    cfg = get_config()
    original_paths = cfg.data_paths
    tmp_root = tmp_path / "data"
    tmp_root.mkdir(parents=True, exist_ok=True)
    cfg.data_paths = DataPaths(root=tmp_root)
    try:
        initialize_database()
        yield
    finally:
        cfg.data_paths = original_paths
tests/factor_utils.py (new file, 56 lines)
@@ -0,0 +1,56 @@
from __future__ import annotations

from datetime import date, timedelta

from app.data.schema import initialize_database
from app.utils.db import db_session


def populate_sample_data(ts_code: str, as_of: date, days: int = 60) -> None:
    """Populate the ``daily`` and ``daily_basic`` tables for tests."""
    initialize_database()
    with db_session() as conn:
        for offset in range(days):
            current_day = as_of - timedelta(days=offset)
            trade_date = current_day.strftime("%Y%m%d")
            close = 100 + (days - 1 - offset)
            turnover = 5 + 0.1 * (days - 1 - offset)
            conn.execute(
                """
                INSERT OR REPLACE INTO daily
                (ts_code, trade_date, open, high, low, close, pct_chg, vol, amount)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    ts_code,
                    trade_date,
                    close,
                    close,
                    close,
                    close,
                    0.0,
                    1_000.0,
                    1_000_000.0,
                ),
            )
            pe = 10.0 + (offset % 5)
            pb = 1.5 + (offset % 3) * 0.1
            ps = 2.0 + (offset % 4) * 0.1
            volume_ratio = 0.5 + (offset % 4) * 0.5
            conn.execute(
                """
                INSERT OR REPLACE INTO daily_basic
                (ts_code, trade_date, turnover_rate, turnover_rate_f, volume_ratio, pe, pb, ps)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    ts_code,
                    trade_date,
                    turnover,
                    turnover,
                    volume_ratio,
                    pe,
                    pb,
                    ps,
                ),
            )
tests/test_backtest_engine.py (new file, 31 lines)
@@ -0,0 +1,31 @@
from __future__ import annotations

from datetime import date

from app.backtest.engine import BacktestEngine, BtConfig


def test_required_fields_include_precomputed_factors(isolated_db):
    cfg = BtConfig(
        id="bt-test",
        name="bt-test",
        start_date=date(2025, 1, 1),
        end_date=date(2025, 1, 2),
        universe=["000001.SZ"],
        params={},
    )
    engine = BacktestEngine(cfg)
    required = set(engine.required_fields)
    expected_fields = {
        "factors.mom_5",
        "factors.turn_5",
        "factors.val_pe_score",
        "factors.val_pb_score",
        "factors.volume_ratio_score",
        "factors.val_multiscore",
        "factors.risk_penalty",
        "factors.sent_momentum",
        "factors.sent_market",
        "factors.sent_divergence",
    }
    assert expected_fields.issubset(required)
tests/test_data_quality.py (new file, 27 lines)
@@ -0,0 +1,27 @@
from __future__ import annotations

from app.utils.data_access import DataBroker
from app.utils.data_quality import DataQualityResult, summarize_data_quality


def test_summarize_data_quality_produces_score():
    results = [
        DataQualityResult("check_a", "ERROR", "fatal issue"),
        DataQualityResult("check_b", "WARN", "warning issue"),
        DataQualityResult("check_c", "INFO", "info message"),
    ]

    summary = summarize_data_quality(results, window_days=7)

    assert summary.total_checks == 3
    assert summary.severity_counts["ERROR"] == 1
    assert summary.has_blockers is True
    assert 0.0 <= summary.score < 100.0


def test_data_broker_evaluate_quality_runs_checks(isolated_db):
    broker = DataBroker()
    summary = broker.evaluate_data_quality(window_days=1)

    assert 0.0 <= summary.score <= 100.0
    assert summary.window_days == 1
tests/test_factor_audit.py (new file, 45 lines)
@@ -0,0 +1,45 @@
from __future__ import annotations

from datetime import date

from app.features.factor_audit import audit_factors
from app.features.factors import compute_factors
from app.utils.db import db_session
from tests.factor_utils import populate_sample_data


def test_audit_matches_persisted_values(isolated_db):
    ts_code = "000001.SZ"
    trade_day = date(2025, 2, 14)
    populate_sample_data(ts_code, trade_day)

    compute_factors(trade_day)
    summary = audit_factors(trade_day)

    assert summary.mismatched == 0
    assert summary.missing_persisted == 0
    assert summary.missing_recomputed == 0
    assert not summary.issues


def test_audit_detects_drift(isolated_db):
    ts_code = "000001.SZ"
    trade_day = date(2025, 2, 14)
    populate_sample_data(ts_code, trade_day)

    compute_factors(trade_day)

    trade_date_str = trade_day.strftime("%Y%m%d")
    with db_session() as conn:
        conn.execute(
            "UPDATE factors SET mom_5 = mom_5 + 0.05 WHERE ts_code = ? AND trade_date = ?",
            (ts_code, trade_date_str),
        )

    summary = audit_factors(trade_day, factors=["mom_5"], tolerance=1e-8, max_issues=5)

    assert summary.mismatched >= 1
    assert summary.issues
    first_issue = summary.issues[0]
    assert first_issue.ts_code == ts_code
    assert first_issue.factor == "mom_5"
tests/test_factor_portfolio.py (new file, 62 lines)
@@ -0,0 +1,62 @@
from __future__ import annotations

from datetime import date, timedelta

from app.features.evaluation import (
    FactorPerformance,
    evaluate_factor_portfolio,
    optimize_factor_weights,
)
from app.features.factors import FactorSpec, compute_factor_range
from tests.factor_utils import populate_sample_data


def _seed_factor_history(codes, end_day):
    specs = [
        FactorSpec("mom_5", 5),
        FactorSpec("mom_20", 20),
        FactorSpec("turn_20", 20),
    ]
    start_day = end_day - timedelta(days=5)
    for code in codes:
        populate_sample_data(code, end_day, days=180)
    compute_factor_range(start_day, end_day, ts_codes=codes, factors=specs)
    return specs, start_day


def test_optimize_factor_weights_returns_normalized_vector(isolated_db):
    codes = [f"0000{i:02d}.SZ" for i in range(1, 4)]
    end_day = date(2025, 2, 28)
    specs, start_day = _seed_factor_history(codes, end_day)
    factor_names = [spec.name for spec in specs]

    weights, performances = optimize_factor_weights(
        factor_names,
        start_day,
        end_day,
        universe=codes,
    )

    assert set(weights.keys()) == set(factor_names)
    assert abs(sum(weights.values()) - 1.0) < 1e-6
    for perf in performances.values():
        assert isinstance(perf, FactorPerformance)


def test_evaluate_factor_portfolio_returns_report(isolated_db):
    codes = [f"0000{i:02d}.SZ" for i in range(1, 4)]
    end_day = date(2025, 3, 10)
    specs, start_day = _seed_factor_history(codes, end_day)
    factor_names = [spec.name for spec in specs]

    report = evaluate_factor_portfolio(
        factor_names,
        start_day,
        end_day,
        universe=codes,
    )

    assert set(report.weights.keys()) == set(factor_names)
    assert isinstance(report.combined, FactorPerformance)
    assert report.combined.sample_size >= 0
    assert set(report.components.keys()) == set(factor_names)
tests/test_factors.py

@@ -6,7 +6,6 @@
from datetime import date, timedelta

import pytest

from app.core.indicators import momentum, rolling_mean, volatility
-from app.data.schema import initialize_database
from app.features.factors import (
    DEFAULT_FACTORS,
    FactorResult,

@@ -17,77 +16,15 @@ from app.features.factors import (
    _valuation_score,
    _volume_ratio_score,
)
-from app.utils.config import DataPaths, get_config
from app.utils.data_access import DataBroker
from app.utils.db import db_session


-@pytest.fixture()
-def isolated_db(tmp_path):
-    cfg = get_config()
-    original_paths = cfg.data_paths
-    tmp_root = tmp_path / "data"
-    tmp_root.mkdir(parents=True, exist_ok=True)
-    cfg.data_paths = DataPaths(root=tmp_root)
-    try:
-        yield
-    finally:
-        cfg.data_paths = original_paths
-
-
-def _populate_sample_data(ts_code: str, as_of: date) -> None:
-    initialize_database()
-    with db_session() as conn:
-        for offset in range(60):
-            current_day = as_of - timedelta(days=offset)
-            trade_date = current_day.strftime("%Y%m%d")
-            close = 100 + (59 - offset)
-            turnover = 5 + 0.1 * (59 - offset)
-            conn.execute(
-                """
-                INSERT OR REPLACE INTO daily
-                (ts_code, trade_date, open, high, low, close, pct_chg, vol, amount)
-                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
-                """,
-                (
-                    ts_code,
-                    trade_date,
-                    close,
-                    close,
-                    close,
-                    close,
-                    0.0,
-                    1000.0,
-                    1_000_000.0,
-                ),
-            )
-            pe = 10.0 + (offset % 5)
-            pb = 1.5 + (offset % 3) * 0.1
-            ps = 2.0 + (offset % 4) * 0.1
-            volume_ratio = 0.5 + (offset % 4) * 0.5
-            conn.execute(
-                """
-                INSERT OR REPLACE INTO daily_basic
-                (ts_code, trade_date, turnover_rate, turnover_rate_f, volume_ratio, pe, pb, ps)
-                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
-                """,
-                (
-                    ts_code,
-                    trade_date,
-                    turnover,
-                    turnover,
-                    volume_ratio,
-                    pe,
-                    pb,
-                    ps,
-                ),
-            )
+from tests.factor_utils import populate_sample_data


def test_compute_factors_persists_and_updates(isolated_db):
    ts_code = "000001.SZ"
    trade_day = date(2025, 1, 30)
-    _populate_sample_data(ts_code, trade_day)
+    populate_sample_data(ts_code, trade_day)

    specs = [
        *DEFAULT_FACTORS,

@@ -177,29 +114,57 @@ def test_compute_factors_persists_and_updates(isolated_db):
def test_compute_factors_skip_existing(isolated_db):
    ts_code = "000001.SZ"
    trade_day = date(2025, 2, 10)
-    _populate_sample_data(ts_code, trade_day)
+    populate_sample_data(ts_code, trade_day)

-    compute_factors(trade_day)
-    skipped = compute_factors(trade_day, skip_existing=True)
+    basic_specs = [
+        FactorSpec("mom_5", 5),
+        FactorSpec("mom_20", 20),
+        FactorSpec("volat_20", 20),
+        FactorSpec("turn_5", 5),
+    ]
+    compute_factors(trade_day, basic_specs)
+    skipped = compute_factors(trade_day, basic_specs, skip_existing=True)
    assert skipped == []


+def test_compute_factors_dry_run(isolated_db):
+    ts_code = "000001.SZ"
+    trade_day = date(2025, 2, 12)
+    populate_sample_data(ts_code, trade_day)
+
+    results = compute_factors(trade_day, persist=False)
+    assert results
+
+    trade_date_str = trade_day.strftime("%Y%m%d")
+    with db_session(read_only=True) as conn:
+        count = conn.execute(
+            "SELECT COUNT(*) AS cnt FROM factors WHERE trade_date = ?",
+            (trade_date_str,),
+        ).fetchone()
+    assert count["cnt"] == 0
+
+
def test_compute_factors_incremental(isolated_db):
    ts_code = "000001.SZ"
    latest_day = date(2025, 2, 10)
-    _populate_sample_data(ts_code, latest_day)
+    populate_sample_data(ts_code, latest_day, days=180)

-    first_day = latest_day - timedelta(days=5)
-    compute_factors(first_day)
+    first_day = latest_day - timedelta(days=1)
+    basic_specs = [
+        FactorSpec("mom_5", 5),
+        FactorSpec("mom_20", 20),
+        FactorSpec("turn_20", 20),
+    ]
+    compute_factors(first_day, basic_specs)

-    summary = compute_factors_incremental(max_trading_days=3)
+    summary = compute_factors_incremental(factors=basic_specs, max_trading_days=3)
    trade_dates = summary["trade_dates"]
    assert trade_dates
    assert trade_dates[0] > first_day
    assert summary["count"] > 0

    # No new dates should return empty result
-    summary_again = compute_factors_incremental(max_trading_days=3)
+    summary_again = compute_factors_incremental(factors=basic_specs, max_trading_days=3)
    assert summary_again["count"] == 0


@@ -209,10 +174,16 @@ def test_compute_factor_range_filters_universe(isolated_db):
    end_day = date(2025, 3, 5)
    start_day = end_day - timedelta(days=1)

-    _populate_sample_data(code_a, end_day)
-    _populate_sample_data(code_b, end_day)
+    populate_sample_data(code_a, end_day)
+    populate_sample_data(code_b, end_day)

-    results = compute_factor_range(start_day, end_day, ts_codes=[code_a])
+    basic_specs = [
+        FactorSpec("mom_5", 5),
+        FactorSpec("mom_20", 20),
+        FactorSpec("turn_20", 20),
+    ]
+
+    results = compute_factor_range(start_day, end_day, ts_codes=[code_a], factors=basic_specs)
    assert results
    assert {result.ts_code for result in results} == {code_a}

@@ -220,71 +191,39 @@ def test_compute_factor_range_filters_universe(isolated_db):
    rows = conn.execute("SELECT DISTINCT ts_code FROM factors").fetchall()
    assert {row["ts_code"] for row in rows} == {code_a}

-    repeated = compute_factor_range(start_day, end_day, ts_codes=[code_a])
+    repeated = compute_factor_range(
+        start_day,
+        end_day,
+        ts_codes=[code_a],
+        factors=basic_specs,
+        skip_existing=True,
+    )
    assert repeated == []


def test_compute_extended_factors(isolated_db):
-    """Test computation of extended factors."""
-    # Use the existing _populate_sample_data function
-    from app.utils.data_access import DataBroker
-    broker = DataBroker()
-
-    # Sample data for 5 trading days
-    dates = ["20240101", "20240102", "20240103", "20240104", "20240105"]
-    ts_codes = ["000001.SZ", "000002.SZ", "600000.SH"]
-
-    # Populate daily data
-    for ts_code in ts_codes:
-        for i, trade_date in enumerate(dates):
-            broker.insert_or_update_daily(
-                ts_code,
-                trade_date,
-                open_price=10.0 + i * 0.1,
-                high=10.5 + i * 0.1,
-                low=9.5 + i * 0.1,
-                close=10.0 + i * 0.2,  # upward trend
-                pre_close=10.0 + (i - 1) * 0.2 if i > 0 else 10.0,
-                vol=100000 + i * 10000,
-                amount=1000000 + i * 100000,
-            )
-
-            broker.insert_or_update_daily_basic(
-                ts_code,
-                trade_date,
-                close=10.0 + i * 0.2,
-                turnover_rate=1.0 + i * 0.1,
-                turnover_rate_f=1.0 + i * 0.1,
-                volume_ratio=1.0 + (i % 3) * 0.2,  # varies within 0.8-1.2
-                pe=15.0 + (i % 3) * 2,  # varies within 15-19
-                pe_ttm=15.0 + (i % 3) * 2,
-                pb=1.5 + (i % 3) * 0.1,  # varies within 1.5-1.7
-                ps=3.0 + (i % 3) * 0.2,  # varies within 3.0-3.4
-                ps_ttm=3.0 + (i % 3) * 0.2,
-                dv_ratio=2.0 + (i % 3) * 0.1,  # dividend yield
-                total_mv=1000000 + i * 100000,
-                circ_mv=800000 + i * 80000,
-            )
-
-    # Compute factors with extended factors
+    """Extended factors should be persisted alongside base factors."""
    from app.features.extended_factors import EXTENDED_FACTORS

+    trade_day = date(2025, 2, 28)
+    ts_codes = ["000001.SZ", "000002.SZ"]
+    for code in ts_codes:
+        populate_sample_data(code, trade_day, days=120)
+
    all_factors = list(DEFAULT_FACTORS) + EXTENDED_FACTORS

-    trade_day = date(2024, 1, 5)
    results = compute_factors(trade_day, all_factors)

-    # Verify that we got results
    assert results

-    # Verify that extended factors are computed
    result_map = {result.ts_code: result for result in results}
-    ts_code = "000001.SZ"
-    assert ts_code in result_map
-    result = result_map[ts_code]
-
-    # Check that extended factors are present in the results
-    extended_factor_names = [spec.name for spec in EXTENDED_FACTORS]
-    for factor_name in extended_factor_names:
-        assert factor_name in result.values
-        # Values should not be None
-        assert result.values[factor_name] is not None
+    for code in ts_codes:
+        assert code in result_map
+        factor_payload = result_map[code].values
+        required_extended = {
+            "tech_rsi_14",
+            "tech_macd_signal",
+            "trend_ma_cross",
+            "micro_trade_imbalance",
+        }
+        assert required_extended.issubset(factor_payload.keys())
+        for name in required_extended:
+            assert factor_payload.get(name) is not None
tests/test_news_ingest.py (new file, 50 lines)
@@ -0,0 +1,50 @@
from __future__ import annotations

import json
from datetime import datetime

from app.ingest import entity_recognition
from app.ingest.rss import RssItem, save_news_items
from app.utils.db import db_session


def test_save_news_items_persists_entities_and_heat(isolated_db):
    # Reset mapping state
    entity_recognition._COMPANY_MAPPING_INITIALIZED = False
    mapper = entity_recognition.company_mapper
    mapper.name_to_code.clear()
    mapper.short_names.clear()
    mapper.aliases.clear()

    ts_code = "000001.SZ"
    mapper.add_company(ts_code, "平安银行股份有限公司", "平安银行")

    item = RssItem(
        id="news-1",
        title="平安银行利好消息爆发",
        link="https://example.com/news",
        published=datetime.utcnow(),
        summary="平安银行股份有限公司公布季度业绩,银行板块再迎利好。",
        source="TestWire",
    )

    item.extract_entities()
    assert ts_code in item.ts_codes

    saved = save_news_items([item])
    assert saved == 1

    with db_session(read_only=True) as conn:
        row = conn.execute(
            "SELECT heat, entities FROM news WHERE ts_code = ? ORDER BY pub_time DESC LIMIT 1",
            (ts_code,),
        ).fetchone()

    assert row is not None
    assert 0.0 <= row["heat"] <= 1.0
    assert row["heat"] > 0.6

    entities_payload = json.loads(row["entities"])
    assert ts_code in entities_payload.get("ts_codes", [])
    assert "industries" in entities_payload
    assert "important_keywords" in entities_payload