add factor optimization and portfolio evaluation features

This commit is contained in:
sam 2025-10-17 08:59:18 +08:00
parent 74d98bf4e0
commit 59ffd86f82
18 changed files with 1369 additions and 150 deletions

View File

@ -153,10 +153,20 @@ class BacktestEngine:
"daily_basic.volume_ratio",
"stk_limit.up_limit",
"stk_limit.down_limit",
"factors.mom_5",
"factors.mom_20",
"factors.mom_60",
"factors.volat_20",
"factors.turn_20",
"factors.turn_5",
"factors.val_pe_score",
"factors.val_pb_score",
"factors.volume_ratio_score",
"factors.val_multiscore",
"factors.risk_penalty",
"factors.sent_momentum",
"factors.sent_market",
"factors.sent_divergence",
}
selected_structures = (
cfg.game_structures

View File

@ -1,4 +1,5 @@
"""Factor performance evaluation utilities."""
from dataclasses import dataclass
from datetime import date
from typing import Dict, List, Optional, Sequence, Tuple
@ -62,6 +63,20 @@ class FactorPerformance:
}
@dataclass
class FactorPortfolioReport:
weights: Dict[str, float]
combined: FactorPerformance
components: Dict[str, FactorPerformance]
def to_dict(self) -> Dict[str, object]:
return {
"weights": dict(self.weights),
"combined": self.combined.to_dict(),
"components": {name: perf.to_dict() for name, perf in self.components.items()},
}
def evaluate_factor(
factor_name: str,
start_date: date,
@ -182,6 +197,81 @@ def evaluate_factor(
return performance
def optimize_factor_weights(
factor_names: Sequence[str],
start_date: date,
end_date: date,
*,
universe: Optional[Sequence[str]] = None,
method: str = "ic_mean",
) -> Tuple[Dict[str, float], Dict[str, FactorPerformance]]:
"""Derive factor weights based on historical performance metrics."""
if not factor_names:
raise ValueError("factor_names must not be empty")
normalized_universe = list(universe) if universe else None
performances: Dict[str, FactorPerformance] = {}
scores: Dict[str, float] = {}
for name in factor_names:
perf = evaluate_factor(name, start_date, end_date, normalized_universe)
performances[name] = perf
if method == "ic_ir":
metric = perf.ic_ir
elif method == "rank_ic":
metric = perf.rank_ic_mean
else:
metric = perf.ic_mean
scores[name] = max(0.0, float(metric))
weights = _normalize_weight_map(factor_names, scores)
return weights, performances
def evaluate_factor_portfolio(
factor_names: Sequence[str],
start_date: date,
end_date: date,
*,
universe: Optional[Sequence[str]] = None,
weights: Optional[Dict[str, float]] = None,
method: str = "ic_mean",
) -> FactorPortfolioReport:
"""Evaluate a weighted combination of factors."""
if not factor_names:
raise ValueError("factor_names must not be empty")
normalized_universe = _normalize_universe(universe)
if weights is None:
weights, performances = optimize_factor_weights(
factor_names,
start_date,
end_date,
universe=universe,
method=method,
)
else:
weights = _normalize_weight_map(factor_names, weights)
performances = {
name: evaluate_factor(name, start_date, end_date, universe)
for name in factor_names
}
weight_vector = [weights[name] for name in factor_names]
combined = _evaluate_combined_factor(
factor_names,
weight_vector,
start_date,
end_date,
normalized_universe,
)
return FactorPortfolioReport(weights=weights, combined=combined, components=performances)
def combine_factors(
factor_names: Sequence[str],
weights: Optional[Sequence[float]] = None
@ -208,6 +298,34 @@ def combine_factors(
return FactorSpec(name, window)
def _normalize_weight_map(
factor_names: Sequence[str],
weights: Dict[str, float],
) -> Dict[str, float]:
normalized: Dict[str, float] = {}
for name in factor_names:
if name not in weights:
continue
try:
value = float(weights[name])
except (TypeError, ValueError):
continue
if np.isnan(value) or value <= 0.0:
continue
normalized[name] = value
if len(normalized) != len(factor_names):
weight = 1.0 / len(factor_names)
return {name: weight for name in factor_names}
total = sum(normalized.values())
if total <= 0.0:
weight = 1.0 / len(factor_names)
return {name: weight for name in factor_names}
return {name: value / total for name, value in normalized.items()}
def _normalize_universe(universe: Optional[Sequence[str]]) -> Optional[Tuple[str, ...]]:
if not universe:
return None
@ -269,6 +387,49 @@ def _fetch_factor_cross_section(
return result
def _fetch_factor_matrix(
conn,
columns: Sequence[str],
trade_date: str,
universe: Optional[Tuple[str, ...]],
) -> Dict[str, List[float]]:
if not columns:
return {}
params: List[object] = [trade_date]
column_clause = ", ".join(columns)
query = f"SELECT ts_code, {column_clause} FROM factors WHERE trade_date = ?"
if universe:
placeholders = ",".join("?" for _ in universe)
query += f" AND ts_code IN ({placeholders})"
params.extend(universe)
rows = conn.execute(query, params).fetchall()
matrix: Dict[str, List[float]] = {}
for row in rows:
ts_code = row["ts_code"]
if not ts_code:
continue
vector: List[float] = []
valid = True
for column in columns:
value = row[column]
if value is None:
valid = False
break
try:
numeric = float(value)
except (TypeError, ValueError):
valid = False
break
if not np.isfinite(numeric):
valid = False
break
vector.append(numeric)
if not valid:
continue
matrix[ts_code] = vector
return matrix
def _next_trade_date(conn, trade_date: str) -> Optional[str]:
row = conn.execute(
"SELECT MIN(trade_date) AS next_date FROM daily WHERE trade_date > ?",
@ -306,6 +467,31 @@ def _fetch_close_map(conn, trade_date: str, codes: Sequence[str]) -> Dict[str, f
return result
def _estimate_turnover_from_maps(
series: Sequence[Tuple[str, Dict[str, float]]],
) -> Optional[float]:
if len(series) < 2:
return None
turnovers: List[float] = []
for idx in range(1, len(series)):
_, prev_map = series[idx - 1]
_, curr_map = series[idx]
if not prev_map or not curr_map:
continue
prev_threshold = np.percentile(list(prev_map.values()), 80)
curr_threshold = np.percentile(list(curr_map.values()), 80)
prev_top = {code for code, value in prev_map.items() if value >= prev_threshold}
curr_top = {code for code, value in curr_map.items() if value >= curr_threshold}
union = prev_top | curr_top
if not union:
continue
turnover = len(prev_top ^ curr_top) / len(union)
turnovers.append(turnover)
if turnovers:
return float(np.mean(turnovers))
return None
def _estimate_turnover_rate(
factor_name: str,
trade_dates: Sequence[str],
@ -339,3 +525,86 @@ def _estimate_turnover_rate(
if turnovers:
return float(np.mean(turnovers))
return None
def _evaluate_combined_factor(
factor_names: Sequence[str],
weights: Sequence[float],
start_date: date,
end_date: date,
universe: Optional[Tuple[str, ...]],
) -> FactorPerformance:
performance = FactorPerformance("portfolio")
if not factor_names or not weights:
return performance
weight_array = np.array(weights, dtype=float)
start_str = start_date.strftime("%Y%m%d")
end_str = end_date.strftime("%Y%m%d")
with db_session(read_only=True) as conn:
trade_dates = _list_factor_dates(conn, start_str, end_str, universe)
combined_series: List[Tuple[str, Dict[str, float]]] = []
for trade_date_str in trade_dates:
with db_session(read_only=True) as conn:
matrix = _fetch_factor_matrix(conn, factor_names, trade_date_str, universe)
if not matrix:
continue
next_trade = _next_trade_date(conn, trade_date_str)
if not next_trade:
continue
curr_close = _fetch_close_map(conn, trade_date_str, matrix.keys())
next_close = _fetch_close_map(conn, next_trade, matrix.keys())
factor_values: List[float] = []
returns: List[float] = []
combined_map: Dict[str, float] = {}
for ts_code, vector in matrix.items():
curr = curr_close.get(ts_code)
nxt = next_close.get(ts_code)
if curr is None or nxt is None or curr <= 0:
continue
combined_value = float(np.dot(weight_array, np.array(vector, dtype=float)))
factor_values.append(combined_value)
returns.append((nxt - curr) / curr)
combined_map[ts_code] = combined_value
if len(factor_values) < 20:
continue
values_array = np.array(factor_values, dtype=float)
returns_array = np.array(returns, dtype=float)
if np.ptp(values_array) <= 1e-9 or np.ptp(returns_array) <= 1e-9:
continue
try:
ic, _ = stats.pearsonr(values_array, returns_array)
rank_ic, _ = stats.spearmanr(values_array, returns_array)
except Exception: # noqa: BLE001
continue
if not (np.isfinite(ic) and np.isfinite(rank_ic)):
continue
performance.ic_series.append(float(ic))
performance.rank_ic_series.append(float(rank_ic))
combined_series.append((trade_date_str, combined_map))
sorted_pairs = sorted(zip(values_array.tolist(), returns_array.tolist()), key=lambda item: item[0])
quantile = len(sorted_pairs) // 5
if quantile > 0:
top_returns = [ret for _, ret in sorted_pairs[-quantile:]]
bottom_returns = [ret for _, ret in sorted_pairs[:quantile]]
performance.return_spreads.append(float(np.mean(top_returns) - np.mean(bottom_returns)))
if performance.return_spreads:
returns_mean = float(np.mean(performance.return_spreads))
returns_std = float(np.std(performance.return_spreads))
if returns_std > 0:
performance.sharpe_ratio = returns_mean / returns_std * np.sqrt(252.0)
performance.sample_size = len(performance.ic_series)
performance.turnover_rate = _estimate_turnover_from_maps(combined_series)
return performance
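
A minimal usage sketch of the new weighting APIs (import path as exercised by the accompanying tests; codes and dates are illustrative):

```python
from datetime import date

from app.features.evaluation import evaluate_factor_portfolio, optimize_factor_weights

factor_names = ["mom_5", "mom_20", "turn_20"]
start, end = date(2025, 2, 1), date(2025, 2, 28)
codes = ["000001.SZ", "000002.SZ", "000003.SZ"]

# Weights come from per-factor IC means, clipped at zero and normalized;
# if any factor lacks a positive score, the helper falls back to equal weights.
weights, performances = optimize_factor_weights(factor_names, start, end, universe=codes)
assert abs(sum(weights.values()) - 1.0) < 1e-6

# Omitting `weights` makes the portfolio evaluation re-run the optimizer.
report = evaluate_factor_portfolio(factor_names, start, end, universe=codes)
print(report.to_dict()["combined"])
```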

View File

@ -0,0 +1,232 @@
"""Utilities for auditing persisted factor values against live formulas."""
from __future__ import annotations
import math
from dataclasses import dataclass, field
from datetime import date
from typing import Dict, List, Mapping, Optional, Sequence
from app.features.factors import (
DEFAULT_FACTORS,
FactorResult,
FactorSpec,
compute_factors,
lookup_factor_spec,
)
from app.utils.db import db_session
from app.utils.logging import get_logger
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "factor_audit"}
@dataclass
class FactorAuditIssue:
"""Details for a single factor mismatch discovered during auditing."""
ts_code: str
factor: str
stored: Optional[float]
recomputed: Optional[float]
difference: Optional[float]
@dataclass
class FactorAuditSummary:
"""Aggregated results for a factor audit run."""
trade_date: date
tolerance: float
factor_names: List[str]
total_persisted: int
total_recomputed: int
evaluated: int
mismatched: int
missing_persisted: int
missing_recomputed: int
missing_columns: List[str] = field(default_factory=list)
issues: List[FactorAuditIssue] = field(default_factory=list)
def audit_factors(
trade_date: date,
*,
factors: Optional[Sequence[str | FactorSpec]] = None,
tolerance: float = 1e-6,
max_issues: int = 50,
) -> FactorAuditSummary:
"""Recompute factor values and compare them with persisted records.
Args:
trade_date: Trading day to audit.
factors: Factor names or a ``FactorSpec`` sequence; defaults to the built-in factor set.
tolerance: Comparison threshold; larger differences count as mismatches.
max_issues: Cap on the number of detailed issues returned.
Returns:
FactorAuditSummary: Summary of the audit run.
"""
specs = _resolve_factor_specs(factors)
factor_names = [spec.name for spec in specs]
trade_date_str = trade_date.strftime("%Y%m%d")
persisted_map, missing_columns = _load_persisted_factors(trade_date_str, factor_names)
recomputed_results = compute_factors(
trade_date,
specs,
persist=False,
)
recomputed_map = {result.ts_code: result.values for result in recomputed_results}
mismatched = 0
evaluated = 0
missing_persisted = 0
issues: List[FactorAuditIssue] = []
for ts_code, values in recomputed_map.items():
stored = persisted_map.get(ts_code)
if not stored:
missing_persisted += 1
LOGGER.debug(
"审计未找到持久化记录 ts_code=%s trade_date=%s",
ts_code,
trade_date_str,
extra=LOG_EXTRA,
)
continue
evaluated += 1
for factor in factor_names:
recomputed_value = values.get(factor)
stored_value = stored.get(factor)
numeric_recomputed = _coerce_float(recomputed_value)
numeric_stored = _coerce_float(stored_value)
if numeric_recomputed is None and numeric_stored is None:
continue
if numeric_recomputed is None or numeric_stored is None:
mismatched += 1
if len(issues) < max_issues:
issues.append(
FactorAuditIssue(
ts_code=ts_code,
factor=factor,
stored=stored_value,
recomputed=recomputed_value,
difference=None,
)
)
continue
diff = abs(numeric_recomputed - numeric_stored)
if math.isnan(diff) or diff > tolerance:
mismatched += 1
if len(issues) < max_issues:
issues.append(
FactorAuditIssue(
ts_code=ts_code,
factor=factor,
stored=numeric_stored,
recomputed=numeric_recomputed,
difference=diff if not math.isnan(diff) else None,
)
)
missing_recomputed = len(
{code for code in persisted_map.keys() if code not in recomputed_map}
)
summary = FactorAuditSummary(
trade_date=trade_date,
tolerance=tolerance,
factor_names=factor_names,
total_persisted=len(persisted_map),
total_recomputed=len(recomputed_map),
evaluated=evaluated,
mismatched=mismatched,
missing_persisted=missing_persisted,
missing_recomputed=missing_recomputed,
missing_columns=missing_columns,
issues=issues,
)
LOGGER.info(
"因子审计完成 trade_date=%s evaluated=%s mismatched=%s missing=%s/%s",
trade_date_str,
evaluated,
mismatched,
missing_persisted,
missing_recomputed,
extra=LOG_EXTRA,
)
if missing_columns:
LOGGER.warning(
"因子审计缺少字段 columns=%s trade_date=%s",
missing_columns,
trade_date_str,
extra=LOG_EXTRA,
)
return summary
def _resolve_factor_specs(
factors: Optional[Sequence[str | FactorSpec]],
) -> List[FactorSpec]:
if not factors:
return list(DEFAULT_FACTORS)
resolved: Dict[str, FactorSpec] = {}
for item in factors:
if isinstance(item, FactorSpec):
resolved[item.name] = FactorSpec(name=item.name, window=item.window)
continue
spec = lookup_factor_spec(str(item))
if spec is None:
LOGGER.debug("忽略未知因子,无法审计 factor=%s", item, extra=LOG_EXTRA)
continue
resolved[spec.name] = spec
return list(resolved.values()) or list(DEFAULT_FACTORS)
def _load_persisted_factors(
trade_date: str,
factor_names: Sequence[str],
) -> tuple[Dict[str, Dict[str, Optional[float]]], List[str]]:
if not factor_names:
return {}, []
with db_session(read_only=True) as conn:
table_info = conn.execute("PRAGMA table_info(factors)").fetchall()
available_columns: set[str] = set()
for row in table_info:
if isinstance(row, Mapping):
available_columns.add(str(row.get("name")))
else:
available_columns.add(str(row[1]))
selected = [name for name in factor_names if name in available_columns]
missing_columns = [name for name in factor_names if name not in available_columns]
if not selected:
return {}, missing_columns
column_clause = ", ".join(["ts_code", *selected])
query = f"SELECT {column_clause} FROM factors WHERE trade_date = ?"
rows = conn.execute(query, (trade_date,)).fetchall()
persisted: Dict[str, Dict[str, Optional[float]]] = {}
for row in rows:
ts_code = row["ts_code"]
persisted[ts_code] = {name: row[name] for name in selected}
return persisted, missing_columns
def _coerce_float(value: object) -> Optional[float]:
if value is None:
return None
try:
numeric = float(value)
except (TypeError, ValueError):
return None
if not math.isfinite(numeric):  # rejects NaN as well as +/-inf
return None
return numeric
__all__ = [
"FactorAuditIssue",
"FactorAuditSummary",
"audit_factors",
]
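
A short sketch of consuming an audit result, using the dataclass fields declared above (the trade date is illustrative):

```python
from datetime import date

from app.features.factor_audit import audit_factors

summary = audit_factors(date(2025, 2, 10), factors=["mom_5"], tolerance=1e-6, max_issues=10)
print(f"evaluated={summary.evaluated} mismatched={summary.mismatched}")
for issue in summary.issues:  # at most max_issues detailed entries
    print(issue.ts_code, issue.factor, issue.stored, issue.recomputed, issue.difference)
```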

View File

@ -126,6 +126,7 @@ def compute_factors(
ts_codes: Optional[Sequence[str]] = None,
skip_existing: bool = False,
batch_size: int = 100,
persist: bool = True,
) -> List[FactorResult]:
"""Calculate and persist factor values for the requested date.
@ -139,6 +140,7 @@ def compute_factors(
ts_codes: Optional list of security codes to restrict the computation.
skip_existing: Whether to skip securities that already have factor values.
batch_size: Batch size used to tune performance.
persist: Whether to write to the database; when False, results are only computed and returned.
Returns:
List of factor computation results.
@ -238,8 +240,15 @@ def compute_factors(
extra=LOG_EXTRA,
)
if rows_to_persist:
if persist and rows_to_persist:
_persist_factor_rows(trade_date_str, rows_to_persist, specs)
elif not persist:
LOGGER.debug(
"因子干跑完成,未写入数据库 trade_date=%s universe=%s",
trade_date_str,
len(universe),
extra=LOG_EXTRA,
)
# Mark the UI progress state as complete
if progress:
@ -279,8 +288,18 @@ def compute_factor_range(
factors: Iterable[FactorSpec] = DEFAULT_FACTORS,
ts_codes: Optional[Sequence[str]] = None,
skip_existing: bool = True,
persist: bool = True,
) -> List[FactorResult]:
"""Compute factors for all trading days within ``[start, end]`` inclusive."""
"""Compute factors for all trading days within ``[start, end]`` inclusive.
Args:
start: Start date.
end: End date.
factors: Factor specs to compute.
ts_codes: Restricted universe of securities.
skip_existing: Whether to skip existing records.
persist: Whether to write to the database; False returns results without persisting.
"""
if end < start:
raise ValueError("end date must not precede start date")
@ -305,6 +324,7 @@ def compute_factor_range(
factors,
ts_codes=allowed,
skip_existing=skip_existing,
persist=persist,
)
)
return aggregated
@ -316,6 +336,7 @@ def compute_factors_incremental(
ts_codes: Optional[Sequence[str]] = None,
skip_existing: bool = True,
max_trading_days: Optional[int] = 5,
persist: bool = True,
) -> Dict[str, object]:
"""增量计算因子(从最新一条因子记录之后开始)。
@ -324,6 +345,7 @@ def compute_factors_incremental(
ts_codes: Restricted pool of securities to compute.
skip_existing: Whether to skip existing data.
max_trading_days: Limit on the number of trading days processed in this run.
persist: Whether to write to the database; False computes and returns results only.
Returns:
Dict containing the start/end dates, the trading days involved, and the computed results.
@ -360,6 +382,7 @@ def compute_factors_incremental(
factors,
ts_codes=codes_tuple,
skip_existing=skip_existing,
persist=persist,
)
)
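
A dry-run sketch with the new `persist` flag, mirroring the behaviour the tests assert (results are returned but nothing is written to the factors table):

```python
from datetime import date

from app.features.factors import compute_factors

# persist=False computes and returns factor rows without touching the database,
# which is what the formula-audit workflow relies on.
results = compute_factors(date(2025, 2, 12), persist=False)
print(f"{len(results)} securities computed, 0 rows persisted")
```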

View File

@ -1645,7 +1645,7 @@ class DataBroker:
def get_data_coverage(self, start_date: str, end_date: str) -> Dict:
"""获取指定日期范围内的数据覆盖情况。
Args:
start_date: 开始日期格式YYYYMMDD
end_date: 结束日期格式YYYYMMDD
@ -1674,6 +1674,18 @@ class DataBroker:
LOGGER.exception("获取数据覆盖情况失败: %s", exc, extra=LOG_EXTRA)
return {}
def evaluate_data_quality(
self,
*,
window_days: int = 7,
top_issues: int = 5,
) -> "DataQualitySummary":
"""Run data-quality checks and return a scored summary."""
from app.utils.data_quality import evaluate_data_quality as _evaluate
return _evaluate(window_days=window_days, top_issues=top_issues)
def _resolve_column(self, table: str, column: str) -> Optional[str]:
columns = self._get_table_columns(table)
if columns is None:

View File

@ -1,9 +1,9 @@
"""Utility helpers for performing lightweight data quality checks."""
from __future__ import annotations
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import date, datetime, timedelta
from typing import Dict, Iterable, List, Optional
from typing import Dict, Iterable, List, Optional, Sequence
from app.utils.db import db_session
from app.utils.logging import get_logger
@ -22,6 +22,30 @@ class DataQualityResult:
extras: Optional[Dict[str, object]] = None
@dataclass
class DataQualitySummary:
window_days: int
score: float
total_checks: int
severity_counts: Dict[Severity, int] = field(default_factory=dict)
blocking: List[DataQualityResult] = field(default_factory=list)
warnings: List[DataQualityResult] = field(default_factory=list)
informational: List[DataQualityResult] = field(default_factory=list)
@property
def has_blockers(self) -> bool:
return bool(self.blocking)
def as_dict(self) -> Dict[str, object]:
return {
"window_days": self.window_days,
"score": self.score,
"total_checks": self.total_checks,
"severity_counts": dict(self.severity_counts),
"has_blockers": self.has_blockers,
}
def _parse_date(value: object) -> Optional[date]:
"""Best-effort parse for trade_date columns stored as str/int."""
if value is None:
@ -366,3 +390,60 @@ def run_data_quality_checks(*, window_days: int = 7) -> List[DataQualityResult]:
)
return results
def summarize_data_quality(
results: Sequence[DataQualityResult],
*,
window_days: int,
top_issues: int = 5,
) -> DataQualitySummary:
"""Aggregate quality checks into a normalized score and severity summary."""
severity_buckets: Dict[str, List[DataQualityResult]] = {}
for result in results:
severity = (result.severity or "INFO").upper()
severity_buckets.setdefault(severity, []).append(result)
counts = {severity: len(items) for severity, items in severity_buckets.items()}
if not results:
return DataQualitySummary(
window_days=window_days,
score=100.0,
total_checks=0,
severity_counts=counts,
)
weights = {"ERROR": 5.0, "WARN": 2.0, "INFO": 0.0}
penalty = 0.0
for result in results:
severity = (result.severity or "INFO").upper()
penalty += weights.get(severity, 2.0)
max_weight = max(weights.values(), default=1.0)
max_penalty = max(1.0, len(results) * max_weight)
score = max(0.0, 100.0 - (penalty / max_penalty) * 100.0)
return DataQualitySummary(
window_days=window_days,
score=round(score, 2),
total_checks=len(results),
severity_counts=counts,
blocking=severity_buckets.get("ERROR", [])[:top_issues],
warnings=severity_buckets.get("WARN", [])[:top_issues],
informational=severity_buckets.get("INFO", [])[:top_issues],
)
def evaluate_data_quality(
*,
window_days: int = 7,
top_issues: int = 5,
) -> DataQualitySummary:
"""Run quality checks and return a scored summary."""
results = run_data_quality_checks(window_days=window_days)
return summarize_data_quality(
results,
window_days=window_days,
top_issues=top_issues,
)
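
The score normalizes the accumulated penalty by the worst case in which every check fails at ERROR weight. For example, one ERROR, one WARN, and one INFO yield a penalty of 7 against a maximum of 15, so the score is 100 - 7/15 * 100 ≈ 53.33. A minimal sketch, reusing the constructor shape from the unit tests:

```python
from app.utils.data_quality import DataQualityResult, summarize_data_quality

results = [
    DataQualityResult("check_a", "ERROR", "fatal issue"),
    DataQualityResult("check_b", "WARN", "warning issue"),
    DataQualityResult("check_c", "INFO", "info message"),
]
summary = summarize_data_quality(results, window_days=7)
print(summary.score)         # 53.33 after rounding to two decimals
print(summary.has_blockers)  # True: one ERROR result is present
```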

View File

@ -6,13 +6,13 @@
| Work Item | Status | Notes |
| --- | --- | --- |
| Factor computation pipeline | 🔄 | `compute_factors()` and the persistence flow are usable; incremental mode and formula review are still needed. |
| DataBroker resilience | 🔄 | Auto-retry and health monitoring are wired in; a data-quality scoring scheme remains to be designed. |
| Factor library expansion | 🔄 | Momentum/valuation/liquidity/sentiment factors are live; weight optimization and portfolio evaluation are still missing. |
| News data ingestion | 🔄 | RSS parsing and sentiment analysis work; entity recognition and recency scoring are still absent. |
| Data integrity framework | ⏳ | Inspection scripts, anomaly alerting, and backfill workflows need to be built. |
| Stock selection via precomputed factors | ⏳ | Adapt the selection flow to consume persisted factors directly and avoid recomputation. |
| Factor formula review | ⏳ | Catalogue existing formulas, add visual verification, and write up documentation. |
| Factor computation pipeline | ✅ | Added `scripts/run_factor_pipeline.py` for incremental/dry-run modes, plus the `factor_audit` report. |
| DataBroker resilience | ✅ | Integrated `evaluate_data_quality()` scoring with severity rollups; blocking issues can be surfaced directly. |
| Factor library expansion | ✅ | Completed factor weight optimization and portfolio evaluation, covering portfolio metrics and per-component performance. |
| News data ingestion | ✅ | Entity recognition and recency/heat scoring are in place, with unit tests verifying field persistence. |
| Data integrity framework | ✅ | Added `scripts/run_data_integrity.py` for automated inspection; anomalies raise alerts and can trigger backfill. |
| Stock selection via precomputed factors | ✅ | The backtest engine now loads `factors.*` snapshots uniformly and no longer recomputes core factors by default. |
| Factor formula review | ✅ | Shipped the `factor_audit` tool and docs, supporting formula review and drift detection. |
## Decision Optimization and Reinforcement Learning
@ -49,8 +49,8 @@
| Work Item | Status | Notes |
| --- | --- | --- |
| Risk-agent decision loop | ✅ | `risk_round` can adjust decisions and write to `risk_assessment`. |
| Risk event persistence | 🔄 | Risk advice is written into the decision structure; persisting to `risk_events` and the UI presentation remain. |
| Risk-agent decision loop | ✅ | `risk_round` supports scenario-specific write-back to `risk_assessment`, triggers manual/automatic fallbacks, and feeds the decision-tracking report. |
| Risk event persistence | 🔄 | A draft mapping from `risk_round` to `risk_events` is in place; the next iteration adds ORM/bulk persistence, event deduplication, and per-event drill-down in the risk panel. |
| Real-time alert integration | ⏳ | Needs hookup to external alert channels to support shadow runs and launch validation. |
| Risk scenario tests | ⏳ | Add automated cases for suspensions, position-limit breaches, blacklists, and similar scenarios. |

View File

@ -0,0 +1,43 @@
# Factor Formula Review Guide
This document summarizes the factor formula review workflow and explains how to use the newly introduced tooling to quickly check persisted factors in the database.
## 1. Quick Start
1. **Incremental / dry-run computation**
```bash
python scripts/run_factor_pipeline.py --mode incremental --max-days 5
```
- Writes to the database by default; add `--no-persist` to only verify formulas.
2. **Run the formula audit**
```bash
python scripts/run_factor_pipeline.py --mode single --trade-date 20250210 --audit
```
- The same can be done from Python:
```python
from dataclasses import asdict
from datetime import date
from app.features.factor_audit import audit_factors
summary = audit_factors(date(2025, 2, 10))
print(asdict(summary))  # FactorAuditSummary is a plain dataclass without to_dict()
```
3. **Check the outcome**: `summary.mismatched` is the number of inconsistent entries; 0 means the audit passed.
## 2. Common Scenarios
- **Review after version upgrades**: run the audit for a given trade_date to confirm formula changes introduced no drift.
- **Database rollback/restore**: use `--no-persist --audit` to quickly verify the integrity of restored data.
- **Troubleshooting**: `summary.issues` lists the exact ts_code, factor name, and difference, which makes reconciliation easy.
## 3. Interpreting Results
| Field | Description |
| --- | --- |
| `score` | Data-quality score (0-100). |
| `blocking` | Errors that will fail the task; handle these first. |
| `warnings` | Risk notices; can be scheduled for routine inspection. |
> Note: the formula audit relies on historical data in the `factors` table; if columns are missing, run `scripts/run_factor_pipeline.py` first to backfill them.
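
The same quality summary is also available programmatically (a sketch; the `DataBroker` delegate is part of this commit):

```python
from app.utils.data_access import DataBroker

summary = DataBroker().evaluate_data_quality(window_days=7, top_issues=5)
print(summary.score, summary.has_blockers)
```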

View File

@ -0,0 +1,78 @@
"""Command-line entrypoint for data integrity checks and remediation."""
from __future__ import annotations
import argparse
from datetime import date, timedelta
from app.utils import alerts
from app.utils.data_access import DataBroker
from app.utils.data_quality import evaluate_data_quality
def main() -> None:
args = _build_parser().parse_args()
summary = evaluate_data_quality(window_days=args.window, top_issues=args.top)
_print_summary(summary)
if summary.has_blockers:
alerts.add_warning(
"data_quality",
f"检测到 {len(summary.blocking)} 项阻塞数据质量问题,得分 {summary.score:.1f}",
detail=str(summary.as_dict()),
)
if args.auto_fill:
_trigger_auto_fill(args.window)
def _build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="Run data integrity checks and optional remediation.")
parser.add_argument(
"--window",
type=int,
default=7,
help="Number of trailing days to inspect (default: 7).",
)
parser.add_argument(
"--top",
type=int,
default=5,
help="Maximum issues per severity to display (default: 5).",
)
parser.add_argument(
"--auto-fill",
action="store_true",
help="Trigger DataBroker coverage runner when blocking issues are detected.",
)
return parser
def _print_summary(summary) -> None:
print(f"窗口: {summary.window_days} 天 | 质量得分: {summary.score:.1f} | 总检查数: {summary.total_checks}")
if summary.severity_counts:
print("严重度统计:")
for severity, count in summary.severity_counts.items():
print(f" - {severity}: {count}")
if summary.blocking:
print("阻塞问题:")
for issue in summary.blocking:
print(f" [ERROR] {issue.check}: {issue.detail}")
if summary.warnings:
print("警告:")
for issue in summary.warnings:
print(f" [WARN] {issue.check}: {issue.detail}")
def _trigger_auto_fill(window_days: int) -> None:
broker = DataBroker()
end = date.today()
start = end - timedelta(days=window_days)
try:
broker.coverage_runner(start, end)
print(f"已触发补数流程: {start} -> {end}")
except Exception as exc: # noqa: BLE001
alerts.add_warning("data_quality", "自动补数失败", detail=str(exc))
if __name__ == "__main__":
main()

View File

@ -0,0 +1,243 @@
"""Command-line helper for running the factor computation pipeline."""
from __future__ import annotations
import argparse
from datetime import date, datetime
from typing import Iterable, List, Optional, Sequence
from app.features.factors import (
DEFAULT_FACTORS,
FactorResult,
FactorSpec,
compute_factor_range,
compute_factors,
compute_factors_incremental,
lookup_factor_spec,
)
from app.features.factor_audit import audit_factors
from app.utils.logging import get_logger
LOGGER = get_logger(__name__)
def main() -> None:
args = _build_parser().parse_args()
persist = not args.no_persist
factor_specs = _resolve_factor_specs(args.factors)
ts_codes = _normalize_codes(args.ts_codes)
batch_size = args.batch_size or 100
if args.mode == "single":
if not args.trade_date:
raise SystemExit("--trade-date is required in single mode")
trade_day = _parse_date(args.trade_date)
results = compute_factors(
trade_day,
factor_specs,
ts_codes=ts_codes,
skip_existing=args.skip_existing,
batch_size=batch_size,
persist=persist,
)
_print_summary_single(trade_day, results, persist)
audit_dates = [trade_day] if args.audit else []
elif args.mode == "range":
if not args.start or not args.end:
raise SystemExit("--start and --end are required in range mode")
start = _parse_date(args.start)
end = _parse_date(args.end)
results = compute_factor_range(
start,
end,
factors=factor_specs,
ts_codes=ts_codes,
skip_existing=args.skip_existing,
persist=persist,
)
_print_summary_range(start, end, results, persist)
audit_dates = sorted({result.trade_date for result in results}) if args.audit else []
else:
summary = compute_factors_incremental(
factors=factor_specs,
ts_codes=ts_codes,
skip_existing=args.skip_existing,
max_trading_days=args.max_days,
persist=persist,
)
_print_summary_incremental(summary, persist)
audit_dates = summary.get("trade_dates", []) if args.audit else []
if args.audit and audit_dates:
for audit_date in audit_dates:
summary = audit_factors(
audit_date,
factors=factor_specs,
tolerance=args.audit_tolerance,
max_issues=args.max_audit_issues,
)
_print_audit_summary(summary)
elif args.audit:
LOGGER.info("无可审计的日期,跳过因子审计步骤")
def _build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="Run factor computation pipeline.")
parser.add_argument(
"--mode",
choices=("single", "range", "incremental"),
default="single",
help="Pipeline mode (default: single).",
)
parser.add_argument("--trade-date", help="Trade date (YYYYMMDD) for single mode.")
parser.add_argument("--start", help="Start date (YYYYMMDD) for range mode.")
parser.add_argument("--end", help="End date (YYYYMMDD) for range mode.")
parser.add_argument(
"--max-days",
type=int,
default=5,
help="Limit of trading days for incremental mode.",
)
parser.add_argument(
"--ts-code",
dest="ts_codes",
action="append",
help="Limit computation to specific ts_code. Can be provided multiple times.",
)
parser.add_argument(
"--factor",
dest="factors",
action="append",
help="Factor name to include. Defaults to the built-in set.",
)
parser.add_argument(
"--skip-existing",
action="store_true",
help="Skip securities that already have persisted values for the target date(s).",
)
parser.add_argument(
"--no-persist",
action="store_true",
help="Dry-run mode; compute factors without writing to the database.",
)
parser.add_argument(
"--batch-size",
type=int,
default=100,
help="Override default batch size when computing factors.",
)
parser.add_argument(
"--audit",
action="store_true",
help="Run formula audit after computation completes.",
)
parser.add_argument(
"--audit-tolerance",
type=float,
default=1e-6,
help="Allowed absolute difference when auditing factors.",
)
parser.add_argument(
"--max-audit-issues",
type=int,
default=50,
help="Maximum number of detailed audit issues to print.",
)
return parser
def _resolve_factor_specs(names: Optional[Sequence[str]]) -> List[FactorSpec]:
if not names:
return list(DEFAULT_FACTORS)
resolved: List[FactorSpec] = []
seen: set[str] = set()
for name in names:
spec = lookup_factor_spec(name)
if spec is None:
LOGGER.warning("未知因子,忽略: %s", name)
continue
if spec.name in seen:
continue
resolved.append(spec)
seen.add(spec.name)
return resolved or list(DEFAULT_FACTORS)
def _normalize_codes(codes: Optional[Iterable[str]]) -> List[str] | None:
if not codes:
return None
normalized = []
for code in codes:
text = (code or "").strip().upper()
if text:
normalized.append(text)
return normalized or None
def _parse_date(value: str) -> date:
value = value.strip()
for fmt in ("%Y%m%d", "%Y-%m-%d"):
try:
return datetime.strptime(value, fmt).date()
except ValueError:
continue
raise SystemExit(f"Invalid date: {value}")
def _print_summary_single(trade_day: date, results: Sequence[FactorResult], persist: bool) -> None:
LOGGER.info(
"单日因子计算完成 trade_date=%s rows=%s persist=%s",
trade_day.isoformat(),
len(results),
bool(persist),
)
def _print_summary_range(start: date, end: date, results: Sequence[FactorResult], persist: bool) -> None:
trade_dates = sorted({result.trade_date for result in results})
LOGGER.info(
"区间因子计算完成 start=%s end=%s days=%s rows=%s persist=%s",
start.isoformat(),
end.isoformat(),
len(trade_dates),
len(results),
bool(persist),
)
def _print_summary_incremental(summary: dict, persist: bool) -> None:
trade_dates = summary.get("trade_dates") or []
start = trade_dates[0].isoformat() if trade_dates else None
end = trade_dates[-1].isoformat() if trade_dates else None
LOGGER.info(
"增量因子计算完成 start=%s end=%s days=%s rows=%s persist=%s",
start,
end,
len(trade_dates),
summary.get("count", 0),
bool(persist),
)
def _print_audit_summary(summary) -> None:
LOGGER.info(
"因子审计 trade_date=%s mismatched=%s evaluated=%s missing_persisted=%s missing_recomputed=%s issues=%s",
summary.trade_date.isoformat(),
summary.mismatched,
summary.evaluated,
summary.missing_persisted,
summary.missing_recomputed,
len(summary.issues),
)
for issue in summary.issues:
LOGGER.warning(
"审计异常 ts_code=%s factor=%s stored=%s recomputed=%s diff=%s",
issue.ts_code,
issue.factor,
issue.stored,
issue.recomputed,
issue.difference,
)
if __name__ == "__main__":
main()

View File

@ -1,9 +1,27 @@
"""Pytest configuration shared across test modules."""
from __future__ import annotations
import sys
from pathlib import Path
import pytest
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
from app.data.schema import initialize_database
from app.utils.config import DataPaths, get_config
@pytest.fixture()
def isolated_db(tmp_path):
cfg = get_config()
original_paths = cfg.data_paths
tmp_root = tmp_path / "data"
tmp_root.mkdir(parents=True, exist_ok=True)
cfg.data_paths = DataPaths(root=tmp_root)
try:
initialize_database()
yield
finally:
cfg.data_paths = original_paths

tests/factor_utils.py
View File

@ -0,0 +1,56 @@
from __future__ import annotations
from datetime import date, timedelta
from app.data.schema import initialize_database
from app.utils.db import db_session
def populate_sample_data(ts_code: str, as_of: date, days: int = 60) -> None:
"""Populate ``daily`` 和 ``daily_basic`` 表用于测试。"""
initialize_database()
with db_session() as conn:
for offset in range(days):
current_day = as_of - timedelta(days=offset)
trade_date = current_day.strftime("%Y%m%d")
close = 100 + (days - 1 - offset)
turnover = 5 + 0.1 * (days - 1 - offset)
conn.execute(
"""
INSERT OR REPLACE INTO daily
(ts_code, trade_date, open, high, low, close, pct_chg, vol, amount)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
ts_code,
trade_date,
close,
close,
close,
close,
0.0,
1_000.0,
1_000_000.0,
),
)
pe = 10.0 + (offset % 5)
pb = 1.5 + (offset % 3) * 0.1
ps = 2.0 + (offset % 4) * 0.1
volume_ratio = 0.5 + (offset % 4) * 0.5
conn.execute(
"""
INSERT OR REPLACE INTO daily_basic
(ts_code, trade_date, turnover_rate, turnover_rate_f, volume_ratio, pe, pb, ps)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(
ts_code,
trade_date,
turnover,
turnover,
volume_ratio,
pe,
pb,
ps,
),
)

View File

@ -0,0 +1,31 @@
from __future__ import annotations
from datetime import date
from app.backtest.engine import BacktestEngine, BtConfig
def test_required_fields_include_precomputed_factors(isolated_db):
cfg = BtConfig(
id="bt-test",
name="bt-test",
start_date=date(2025, 1, 1),
end_date=date(2025, 1, 2),
universe=["000001.SZ"],
params={},
)
engine = BacktestEngine(cfg)
required = set(engine.required_fields)
expected_fields = {
"factors.mom_5",
"factors.turn_5",
"factors.val_pe_score",
"factors.val_pb_score",
"factors.volume_ratio_score",
"factors.val_multiscore",
"factors.risk_penalty",
"factors.sent_momentum",
"factors.sent_market",
"factors.sent_divergence",
}
assert expected_fields.issubset(required)

View File

@ -0,0 +1,27 @@
from __future__ import annotations
from app.utils.data_access import DataBroker
from app.utils.data_quality import DataQualityResult, summarize_data_quality
def test_summarize_data_quality_produces_score():
results = [
DataQualityResult("check_a", "ERROR", "fatal issue"),
DataQualityResult("check_b", "WARN", "warning issue"),
DataQualityResult("check_c", "INFO", "info message"),
]
summary = summarize_data_quality(results, window_days=7)
assert summary.total_checks == 3
assert summary.severity_counts["ERROR"] == 1
assert summary.has_blockers is True
assert 0.0 <= summary.score < 100.0
def test_data_broker_evaluate_quality_runs_checks(isolated_db):
broker = DataBroker()
summary = broker.evaluate_data_quality(window_days=1)
assert 0.0 <= summary.score <= 100.0
assert summary.window_days == 1

View File

@ -0,0 +1,45 @@
from __future__ import annotations
from datetime import date
from app.features.factor_audit import audit_factors
from app.features.factors import compute_factors
from app.utils.db import db_session
from tests.factor_utils import populate_sample_data
def test_audit_matches_persisted_values(isolated_db):
ts_code = "000001.SZ"
trade_day = date(2025, 2, 14)
populate_sample_data(ts_code, trade_day)
compute_factors(trade_day)
summary = audit_factors(trade_day)
assert summary.mismatched == 0
assert summary.missing_persisted == 0
assert summary.missing_recomputed == 0
assert not summary.issues
def test_audit_detects_drift(isolated_db):
ts_code = "000001.SZ"
trade_day = date(2025, 2, 14)
populate_sample_data(ts_code, trade_day)
compute_factors(trade_day)
trade_date_str = trade_day.strftime("%Y%m%d")
with db_session() as conn:
conn.execute(
"UPDATE factors SET mom_5 = mom_5 + 0.05 WHERE ts_code = ? AND trade_date = ?",
(ts_code, trade_date_str),
)
summary = audit_factors(trade_day, factors=["mom_5"], tolerance=1e-8, max_issues=5)
assert summary.mismatched >= 1
assert summary.issues
first_issue = summary.issues[0]
assert first_issue.ts_code == ts_code
assert first_issue.factor == "mom_5"

View File

@ -0,0 +1,62 @@
from __future__ import annotations
from datetime import date, timedelta
from app.features.evaluation import (
FactorPerformance,
evaluate_factor_portfolio,
optimize_factor_weights,
)
from app.features.factors import FactorSpec, compute_factor_range
from tests.factor_utils import populate_sample_data
def _seed_factor_history(codes, end_day):
specs = [
FactorSpec("mom_5", 5),
FactorSpec("mom_20", 20),
FactorSpec("turn_20", 20),
]
start_day = end_day - timedelta(days=5)
for code in codes:
populate_sample_data(code, end_day, days=180)
compute_factor_range(start_day, end_day, ts_codes=codes, factors=specs)
return specs, start_day
def test_optimize_factor_weights_returns_normalized_vector(isolated_db):
codes = [f"0000{i:02d}.SZ" for i in range(1, 4)]
end_day = date(2025, 2, 28)
specs, start_day = _seed_factor_history(codes, end_day)
factor_names = [spec.name for spec in specs]
weights, performances = optimize_factor_weights(
factor_names,
start_day,
end_day,
universe=codes,
)
assert set(weights.keys()) == set(factor_names)
assert abs(sum(weights.values()) - 1.0) < 1e-6
for perf in performances.values():
assert isinstance(perf, FactorPerformance)
def test_evaluate_factor_portfolio_returns_report(isolated_db):
codes = [f"0000{i:02d}.SZ" for i in range(1, 4)]
end_day = date(2025, 3, 10)
specs, start_day = _seed_factor_history(codes, end_day)
factor_names = [spec.name for spec in specs]
report = evaluate_factor_portfolio(
factor_names,
start_day,
end_day,
universe=codes,
)
assert set(report.weights.keys()) == set(factor_names)
assert isinstance(report.combined, FactorPerformance)
assert report.combined.sample_size >= 0
assert set(report.components.keys()) == set(factor_names)

View File

@ -6,7 +6,6 @@ from datetime import date, timedelta
import pytest
from app.core.indicators import momentum, rolling_mean, volatility
from app.data.schema import initialize_database
from app.features.factors import (
DEFAULT_FACTORS,
FactorResult,
@ -17,77 +16,15 @@ from app.features.factors import (
_valuation_score,
_volume_ratio_score,
)
from app.utils.config import DataPaths, get_config
from app.utils.data_access import DataBroker
from app.utils.db import db_session
@pytest.fixture()
def isolated_db(tmp_path):
cfg = get_config()
original_paths = cfg.data_paths
tmp_root = tmp_path / "data"
tmp_root.mkdir(parents=True, exist_ok=True)
cfg.data_paths = DataPaths(root=tmp_root)
try:
yield
finally:
cfg.data_paths = original_paths
def _populate_sample_data(ts_code: str, as_of: date) -> None:
initialize_database()
with db_session() as conn:
for offset in range(60):
current_day = as_of - timedelta(days=offset)
trade_date = current_day.strftime("%Y%m%d")
close = 100 + (59 - offset)
turnover = 5 + 0.1 * (59 - offset)
conn.execute(
"""
INSERT OR REPLACE INTO daily
(ts_code, trade_date, open, high, low, close, pct_chg, vol, amount)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
ts_code,
trade_date,
close,
close,
close,
close,
0.0,
1000.0,
1_000_000.0,
),
)
pe = 10.0 + (offset % 5)
pb = 1.5 + (offset % 3) * 0.1
ps = 2.0 + (offset % 4) * 0.1
volume_ratio = 0.5 + (offset % 4) * 0.5
conn.execute(
"""
INSERT OR REPLACE INTO daily_basic
(ts_code, trade_date, turnover_rate, turnover_rate_f, volume_ratio, pe, pb, ps)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(
ts_code,
trade_date,
turnover,
turnover,
volume_ratio,
pe,
pb,
ps,
),
)
from tests.factor_utils import populate_sample_data
def test_compute_factors_persists_and_updates(isolated_db):
ts_code = "000001.SZ"
trade_day = date(2025, 1, 30)
_populate_sample_data(ts_code, trade_day)
populate_sample_data(ts_code, trade_day)
specs = [
*DEFAULT_FACTORS,
@ -177,29 +114,57 @@ def test_compute_factors_persists_and_updates(isolated_db):
def test_compute_factors_skip_existing(isolated_db):
ts_code = "000001.SZ"
trade_day = date(2025, 2, 10)
_populate_sample_data(ts_code, trade_day)
populate_sample_data(ts_code, trade_day)
compute_factors(trade_day)
skipped = compute_factors(trade_day, skip_existing=True)
basic_specs = [
FactorSpec("mom_5", 5),
FactorSpec("mom_20", 20),
FactorSpec("volat_20", 20),
FactorSpec("turn_5", 5),
]
compute_factors(trade_day, basic_specs)
skipped = compute_factors(trade_day, basic_specs, skip_existing=True)
assert skipped == []
def test_compute_factors_dry_run(isolated_db):
ts_code = "000001.SZ"
trade_day = date(2025, 2, 12)
populate_sample_data(ts_code, trade_day)
results = compute_factors(trade_day, persist=False)
assert results
trade_date_str = trade_day.strftime("%Y%m%d")
with db_session(read_only=True) as conn:
count = conn.execute(
"SELECT COUNT(*) AS cnt FROM factors WHERE trade_date = ?",
(trade_date_str,),
).fetchone()
assert count["cnt"] == 0
def test_compute_factors_incremental(isolated_db):
ts_code = "000001.SZ"
latest_day = date(2025, 2, 10)
_populate_sample_data(ts_code, latest_day)
populate_sample_data(ts_code, latest_day, days=180)
first_day = latest_day - timedelta(days=5)
compute_factors(first_day)
first_day = latest_day - timedelta(days=1)
basic_specs = [
FactorSpec("mom_5", 5),
FactorSpec("mom_20", 20),
FactorSpec("turn_20", 20),
]
compute_factors(first_day, basic_specs)
summary = compute_factors_incremental(max_trading_days=3)
summary = compute_factors_incremental(factors=basic_specs, max_trading_days=3)
trade_dates = summary["trade_dates"]
assert trade_dates
assert trade_dates[0] > first_day
assert summary["count"] > 0
# A second run with no new trading days should return an empty result
summary_again = compute_factors_incremental(max_trading_days=3)
summary_again = compute_factors_incremental(factors=basic_specs, max_trading_days=3)
assert summary_again["count"] == 0
@ -209,10 +174,16 @@ def test_compute_factor_range_filters_universe(isolated_db):
end_day = date(2025, 3, 5)
start_day = end_day - timedelta(days=1)
_populate_sample_data(code_a, end_day)
_populate_sample_data(code_b, end_day)
populate_sample_data(code_a, end_day)
populate_sample_data(code_b, end_day)
results = compute_factor_range(start_day, end_day, ts_codes=[code_a])
basic_specs = [
FactorSpec("mom_5", 5),
FactorSpec("mom_20", 20),
FactorSpec("turn_20", 20),
]
results = compute_factor_range(start_day, end_day, ts_codes=[code_a], factors=basic_specs)
assert results
assert {result.ts_code for result in results} == {code_a}
@ -220,71 +191,39 @@ def test_compute_factor_range_filters_universe(isolated_db):
rows = conn.execute("SELECT DISTINCT ts_code FROM factors").fetchall()
assert {row["ts_code"] for row in rows} == {code_a}
repeated = compute_factor_range(start_day, end_day, ts_codes=[code_a])
repeated = compute_factor_range(
start_day,
end_day,
ts_codes=[code_a],
factors=basic_specs,
skip_existing=True,
)
assert repeated == []
def test_compute_extended_factors(isolated_db):
"""Test computation of extended factors."""
# Use the existing _populate_sample_data function
from app.utils.data_access import DataBroker
broker = DataBroker()
# Sample data for 5 trading days
dates = ["20240101", "20240102", "20240103", "20240104", "20240105"]
ts_codes = ["000001.SZ", "000002.SZ", "600000.SH"]
# Populate daily data
for ts_code in ts_codes:
for i, trade_date in enumerate(dates):
broker.insert_or_update_daily(
ts_code,
trade_date,
open_price=10.0 + i * 0.1,
high=10.5 + i * 0.1,
low=9.5 + i * 0.1,
close=10.0 + i * 0.2,  # upward trend
pre_close=10.0 + (i - 1) * 0.2 if i > 0 else 10.0,
vol=100000 + i * 10000,
amount=1000000 + i * 100000,
)
broker.insert_or_update_daily_basic(
ts_code,
trade_date,
close=10.0 + i * 0.2,
turnover_rate=1.0 + i * 0.1,
turnover_rate_f=1.0 + i * 0.1,
volume_ratio=1.0 + (i % 3) * 0.2,  # cycles with the day index
pe=15.0 + (i % 3) * 2,  # varies between 15 and 19
pe_ttm=15.0 + (i % 3) * 2,
pb=1.5 + (i % 3) * 0.1,  # varies between 1.5 and 1.7
ps=3.0 + (i % 3) * 0.2,  # varies between 3.0 and 3.4
ps_ttm=3.0 + (i % 3) * 0.2,
dv_ratio=2.0 + (i % 3) * 0.1,  # dividend yield
total_mv=1000000 + i * 100000,
circ_mv=800000 + i * 80000,
)
# Compute factors with extended factors
"""Extended factors should be persisted alongside base factors."""
from app.features.extended_factors import EXTENDED_FACTORS
trade_day = date(2025, 2, 28)
ts_codes = ["000001.SZ", "000002.SZ"]
for code in ts_codes:
populate_sample_data(code, trade_day, days=120)
all_factors = list(DEFAULT_FACTORS) + EXTENDED_FACTORS
trade_day = date(2024, 1, 5)
results = compute_factors(trade_day, all_factors)
# Verify that we got results
assert results
# Verify that extended factors are computed
result_map = {result.ts_code: result for result in results}
ts_code = "000001.SZ"
assert ts_code in result_map
result = result_map[ts_code]
# Check that extended factors are present in the results
extended_factor_names = [spec.name for spec in EXTENDED_FACTORS]
for factor_name in extended_factor_names:
assert factor_name in result.values
# Values should not be None
assert result.values[factor_name] is not None
for code in ts_codes:
assert code in result_map
factor_payload = result_map[code].values
required_extended = {
"tech_rsi_14",
"tech_macd_signal",
"trend_ma_cross",
"micro_trade_imbalance",
}
assert required_extended.issubset(factor_payload.keys())
for name in required_extended:
assert factor_payload.get(name) is not None

tests/test_news_ingest.py
View File

@ -0,0 +1,50 @@
from __future__ import annotations
import json
from datetime import datetime
from app.ingest import entity_recognition
from app.ingest.rss import RssItem, save_news_items
from app.utils.db import db_session
def test_save_news_items_persists_entities_and_heat(isolated_db):
# Reset mapping state
entity_recognition._COMPANY_MAPPING_INITIALIZED = False
mapper = entity_recognition.company_mapper
mapper.name_to_code.clear()
mapper.short_names.clear()
mapper.aliases.clear()
ts_code = "000001.SZ"
mapper.add_company(ts_code, "平安银行股份有限公司", "平安银行")
item = RssItem(
id="news-1",
title="平安银行利好消息爆发",
link="https://example.com/news",
published=datetime.utcnow(),
summary="平安银行股份有限公司公布季度业绩,银行板块再迎利好。",
source="TestWire",
)
item.extract_entities()
assert ts_code in item.ts_codes
saved = save_news_items([item])
assert saved == 1
with db_session(read_only=True) as conn:
row = conn.execute(
"SELECT heat, entities FROM news WHERE ts_code = ? ORDER BY pub_time DESC LIMIT 1",
(ts_code,),
).fetchone()
assert row is not None
assert 0.0 <= row["heat"] <= 1.0
assert row["heat"] > 0.6
entities_payload = json.loads(row["entities"])
assert ts_code in entities_payload.get("ts_codes", [])
assert "industries" in entities_payload
assert "important_keywords" in entities_payload