diff --git a/app/backtest/engine.py b/app/backtest/engine.py
index 18f68bf..f04373c 100644
--- a/app/backtest/engine.py
+++ b/app/backtest/engine.py
@@ -153,10 +153,20 @@ class BacktestEngine:
             "daily_basic.volume_ratio",
             "stk_limit.up_limit",
             "stk_limit.down_limit",
+            "factors.mom_5",
             "factors.mom_20",
             "factors.mom_60",
             "factors.volat_20",
             "factors.turn_20",
+            "factors.turn_5",
+            "factors.val_pe_score",
+            "factors.val_pb_score",
+            "factors.volume_ratio_score",
+            "factors.val_multiscore",
+            "factors.risk_penalty",
+            "factors.sent_momentum",
+            "factors.sent_market",
+            "factors.sent_divergence",
         }
         selected_structures = (
             cfg.game_structures
diff --git a/app/features/evaluation.py b/app/features/evaluation.py
index 35bee02..c5e4a8a 100644
--- a/app/features/evaluation.py
+++ b/app/features/evaluation.py
@@ -1,4 +1,5 @@
 """Factor performance evaluation utilities."""
+from dataclasses import dataclass
 from datetime import date
 from typing import Dict, List, Optional, Sequence, Tuple
@@ -62,6 +63,20 @@ class FactorPerformance:
         }
 
 
+@dataclass
+class FactorPortfolioReport:
+    weights: Dict[str, float]
+    combined: FactorPerformance
+    components: Dict[str, FactorPerformance]
+
+    def to_dict(self) -> Dict[str, object]:
+        return {
+            "weights": dict(self.weights),
+            "combined": self.combined.to_dict(),
+            "components": {name: perf.to_dict() for name, perf in self.components.items()},
+        }
+
+
 def evaluate_factor(
     factor_name: str,
     start_date: date,
@@ -182,6 +197,81 @@ def evaluate_factor(
     return performance
 
 
+def optimize_factor_weights(
+    factor_names: Sequence[str],
+    start_date: date,
+    end_date: date,
+    *,
+    universe: Optional[Sequence[str]] = None,
+    method: str = "ic_mean",
+) -> Tuple[Dict[str, float], Dict[str, FactorPerformance]]:
+    """Derive factor weights based on historical performance metrics."""
+
+    if not factor_names:
+        raise ValueError("factor_names must not be empty")
+
+    normalized_universe = list(universe) if universe else None
+    performances: Dict[str, FactorPerformance] = {}
+    scores: Dict[str, float] = {}
+
+    for name in factor_names:
+        perf = evaluate_factor(name, start_date, end_date, normalized_universe)
+        performances[name] = perf
+        if method == "ic_ir":
+            metric = perf.ic_ir
+        elif method == "rank_ic":
+            metric = perf.rank_ic_mean
+        else:
+            metric = perf.ic_mean
+        scores[name] = max(0.0, float(metric))
+
+    weights = _normalize_weight_map(factor_names, scores)
+    return weights, performances
+
+
+def evaluate_factor_portfolio(
+    factor_names: Sequence[str],
+    start_date: date,
+    end_date: date,
+    *,
+    universe: Optional[Sequence[str]] = None,
+    weights: Optional[Dict[str, float]] = None,
+    method: str = "ic_mean",
+) -> FactorPortfolioReport:
+    """Evaluate a weighted combination of factors."""
+
+    if not factor_names:
+        raise ValueError("factor_names must not be empty")
+
+    normalized_universe = _normalize_universe(universe)
+
+    if weights is None:
+        weights, performances = optimize_factor_weights(
+            factor_names,
+            start_date,
+            end_date,
+            universe=universe,
+            method=method,
+        )
+    else:
+        weights = _normalize_weight_map(factor_names, weights)
+        performances = {
+            name: evaluate_factor(name, start_date, end_date, universe)
+            for name in factor_names
+        }
+
+    weight_vector = [weights[name] for name in factor_names]
+    combined = _evaluate_combined_factor(
+        factor_names,
+        weight_vector,
+        start_date,
+        end_date,
+        normalized_universe,
+    )
+
+    return FactorPortfolioReport(weights=weights, combined=combined, components=performances)
+
+
 def combine_factors(
     factor_names: Sequence[str], weights: Optional[Sequence[float]] = None
@@ -208,6 +298,34 @@ def combine_factors(
     return FactorSpec(name, window)
 
 
+def _normalize_weight_map(
+    factor_names: Sequence[str],
+    weights: Dict[str, float],
+) -> Dict[str, float]:
+    normalized: Dict[str, float] = {}
+    for name in factor_names:
+        if name not in weights:
+            continue
+        try:
+            value = float(weights[name])
+        except (TypeError, ValueError):
+            continue
+        if np.isnan(value) or value <= 0.0:
+            continue
+        normalized[name] = value
+
+    if len(normalized) != len(factor_names):
+        weight = 1.0 / len(factor_names)
+        return {name: weight for name in factor_names}
+
+    total = sum(normalized.values())
+    if total <= 0.0:
+        weight = 1.0 / len(factor_names)
+        return {name: weight for name in factor_names}
+
+    return {name: value / total for name, value in normalized.items()}
+
+
 def _normalize_universe(universe: Optional[Sequence[str]]) -> Optional[Tuple[str, ...]]:
     if not universe:
         return None
@@ -269,6 +387,49 @@ def _fetch_factor_cross_section(
     return result
 
 
+def _fetch_factor_matrix(
+    conn,
+    columns: Sequence[str],
+    trade_date: str,
+    universe: Optional[Tuple[str, ...]],
+) -> Dict[str, List[float]]:
+    if not columns:
+        return {}
+    params: List[object] = [trade_date]
+    column_clause = ", ".join(columns)
+    query = f"SELECT ts_code, {column_clause} FROM factors WHERE trade_date = ?"
+    if universe:
+        placeholders = ",".join("?" for _ in universe)
+        query += f" AND ts_code IN ({placeholders})"
+        params.extend(universe)
+    rows = conn.execute(query, params).fetchall()
+    matrix: Dict[str, List[float]] = {}
+    for row in rows:
+        ts_code = row["ts_code"]
+        if not ts_code:
+            continue
+        vector: List[float] = []
+        valid = True
+        for column in columns:
+            value = row[column]
+            if value is None:
+                valid = False
+                break
+            try:
+                numeric = float(value)
+            except (TypeError, ValueError):
+                valid = False
+                break
+            if not np.isfinite(numeric):
+                valid = False
+                break
+            vector.append(numeric)
+        if not valid:
+            continue
+        matrix[ts_code] = vector
+    return matrix
+
+
 def _next_trade_date(conn, trade_date: str) -> Optional[str]:
     row = conn.execute(
         "SELECT MIN(trade_date) AS next_date FROM daily WHERE trade_date > ?",
@@ -306,6 +467,31 @@ def _fetch_close_map(conn, trade_date: str, codes: Sequence[str]) -> Dict[str, f
     return result
 
 
+def _estimate_turnover_from_maps(
+    series: Sequence[Tuple[str, Dict[str, float]]],
+) -> Optional[float]:
+    if len(series) < 2:
+        return None
+    turnovers: List[float] = []
+    for idx in range(1, len(series)):
+        _, prev_map = series[idx - 1]
+        _, curr_map = series[idx]
+        if not prev_map or not curr_map:
+            continue
+        prev_threshold = np.percentile(list(prev_map.values()), 80)
+        curr_threshold = np.percentile(list(curr_map.values()), 80)
+        prev_top = {code for code, value in prev_map.items() if value >= prev_threshold}
+        curr_top = {code for code, value in curr_map.items() if value >= curr_threshold}
+        union = prev_top | curr_top
+        if not union:
+            continue
+        turnover = len(prev_top ^ curr_top) / len(union)
+        turnovers.append(turnover)
+    if turnovers:
+        return float(np.mean(turnovers))
+    return None
+
+
 def _estimate_turnover_rate(
     factor_name: str,
     trade_dates: Sequence[str],
@@ -339,3 +525,86 @@
     if turnovers:
         return float(np.mean(turnovers))
     return None
+
+
+def _evaluate_combined_factor(
+    factor_names: Sequence[str],
+    weights: Sequence[float],
+    start_date: date,
+    end_date: date,
+    universe: Optional[Tuple[str, ...]],
+) -> FactorPerformance:
+    performance = FactorPerformance("portfolio")
+    if not factor_names or not weights:
+        return performance
+
+    weight_array = np.array(weights, dtype=float)
+    start_str = start_date.strftime("%Y%m%d")
+    end_str = end_date.strftime("%Y%m%d")
+
+    with db_session(read_only=True) as conn:
+        trade_dates = _list_factor_dates(conn, start_str, end_str, universe)
+
+    combined_series: List[Tuple[str, Dict[str, float]]] = []
+
+    for trade_date_str in trade_dates:
+        with db_session(read_only=True) as conn:
+            matrix = _fetch_factor_matrix(conn, factor_names, trade_date_str, universe)
+            if not matrix:
+                continue
+            next_trade = _next_trade_date(conn, trade_date_str)
+            if not next_trade:
+                continue
+            curr_close = _fetch_close_map(conn, trade_date_str, matrix.keys())
+            next_close = _fetch_close_map(conn, next_trade, matrix.keys())
+
+        factor_values: List[float] = []
+        returns: List[float] = []
+        combined_map: Dict[str, float] = {}
+        for ts_code, vector in matrix.items():
+            curr = curr_close.get(ts_code)
+            nxt = next_close.get(ts_code)
+            if curr is None or nxt is None or curr <= 0:
+                continue
+            combined_value = float(np.dot(weight_array, np.array(vector, dtype=float)))
+            factor_values.append(combined_value)
+            returns.append((nxt - curr) / curr)
+            combined_map[ts_code] = combined_value
+
+        if len(factor_values) < 20:
+            continue
+
+        values_array = np.array(factor_values, dtype=float)
+        returns_array = np.array(returns, dtype=float)
+        if np.ptp(values_array) <= 1e-9 or np.ptp(returns_array) <= 1e-9:
+            continue
+
+        try:
+            ic, _ = stats.pearsonr(values_array, returns_array)
+            rank_ic, _ = stats.spearmanr(values_array, returns_array)
+        except Exception:  # noqa: BLE001
+            continue
+
+        if not (np.isfinite(ic) and np.isfinite(rank_ic)):
+            continue
+
+        performance.ic_series.append(float(ic))
+        performance.rank_ic_series.append(float(rank_ic))
+        combined_series.append((trade_date_str, combined_map))
+
+        sorted_pairs = sorted(zip(values_array.tolist(), returns_array.tolist()), key=lambda item: item[0])
+        quantile = len(sorted_pairs) // 5
+        if quantile > 0:
+            top_returns = [ret for _, ret in sorted_pairs[-quantile:]]
+            bottom_returns = [ret for _, ret in sorted_pairs[:quantile]]
+            performance.return_spreads.append(float(np.mean(top_returns) - np.mean(bottom_returns)))
+
+    if performance.return_spreads:
+        returns_mean = float(np.mean(performance.return_spreads))
+        returns_std = float(np.std(performance.return_spreads))
+        if returns_std > 0:
+            performance.sharpe_ratio = returns_mean / returns_std * np.sqrt(252.0)
+
+    performance.sample_size = len(performance.ic_series)
+    performance.turnover_rate = _estimate_turnover_from_maps(combined_series)
+    return performance
diff --git a/app/features/factor_audit.py b/app/features/factor_audit.py
new file mode 100644
index 0000000..a1b3a1a
--- /dev/null
+++ b/app/features/factor_audit.py
@@ -0,0 +1,232 @@
+"""Utilities for auditing persisted factor values against live formulas."""
+from __future__ import annotations
+
+import math
+from dataclasses import dataclass, field
+from datetime import date
+from typing import Dict, List, Mapping, Optional, Sequence
+
+from app.features.factors import (
+    DEFAULT_FACTORS,
+    FactorSpec,
+    compute_factors,
+    lookup_factor_spec,
+)
+from app.utils.db import db_session
+from app.utils.logging import get_logger
+
+LOGGER = get_logger(__name__)
+LOG_EXTRA = {"stage": "factor_audit"}
+
+
+@dataclass
+class FactorAuditIssue:
+    """Details for a single factor mismatch discovered during auditing."""
+
+    ts_code: str
+    factor: str
+    stored: Optional[float]
+    recomputed: Optional[float]
+    difference: Optional[float]
+
+
+@dataclass
+class FactorAuditSummary:
+    """Aggregated results for a factor audit run."""
+
+    trade_date: date
+    tolerance: float
+    factor_names: List[str]
+    total_persisted: int
+    total_recomputed: int
+    evaluated: int
+    mismatched: int
+    missing_persisted: int
+    missing_recomputed: int
+    missing_columns: List[str] = field(default_factory=list)
+    issues: List[FactorAuditIssue] = field(default_factory=list)
+
+
+def audit_factors(
+    trade_date: date,
+    *,
+    factors: Optional[Sequence[str | FactorSpec]] = None,
+    tolerance: float = 1e-6,
+    max_issues: int = 50,
+) -> FactorAuditSummary:
+    """Recompute factor values and compare them with persisted records.
+
+    Args:
+        trade_date: Trading day to audit.
+        factors: Factor names or ``FactorSpec`` entries; defaults to the built-in set.
+        tolerance: Comparison threshold; larger differences count as mismatches.
+        max_issues: Cap on the number of detailed issues returned.
+
+    Returns:
+        FactorAuditSummary: Summary of the audit run.
+    """
+
+    specs = _resolve_factor_specs(factors)
+    factor_names = [spec.name for spec in specs]
+    trade_date_str = trade_date.strftime("%Y%m%d")
+
+    persisted_map, missing_columns = _load_persisted_factors(trade_date_str, factor_names)
+    recomputed_results = compute_factors(
+        trade_date,
+        specs,
+        persist=False,
+    )
+    recomputed_map = {result.ts_code: result.values for result in recomputed_results}
+
+    mismatched = 0
+    evaluated = 0
+    missing_persisted = 0
+    issues: List[FactorAuditIssue] = []
+
+    for ts_code, values in recomputed_map.items():
+        stored = persisted_map.get(ts_code)
+        if not stored:
+            missing_persisted += 1
+            LOGGER.debug(
+                "audit found no persisted record ts_code=%s trade_date=%s",
+                ts_code,
+                trade_date_str,
+                extra=LOG_EXTRA,
+            )
+            continue
+        evaluated += 1
+        for factor in factor_names:
+            recomputed_value = values.get(factor)
+            stored_value = stored.get(factor)
+            numeric_recomputed = _coerce_float(recomputed_value)
+            numeric_stored = _coerce_float(stored_value)
+            if numeric_recomputed is None and numeric_stored is None:
+                continue
+            if numeric_recomputed is None or numeric_stored is None:
+                mismatched += 1
+                if len(issues) < max_issues:
+                    issues.append(
+                        FactorAuditIssue(
+                            ts_code=ts_code,
+                            factor=factor,
+                            stored=stored_value,
+                            recomputed=recomputed_value,
+                            difference=None,
+                        )
+                    )
+                continue
+            diff = abs(numeric_recomputed - numeric_stored)
+            if math.isnan(diff) or diff > tolerance:
+                mismatched += 1
+                if len(issues) < max_issues:
+                    issues.append(
+                        FactorAuditIssue(
+                            ts_code=ts_code,
+                            factor=factor,
+                            stored=numeric_stored,
+                            recomputed=numeric_recomputed,
+                            difference=diff if not math.isnan(diff) else None,
+                        )
+                    )
+
+    missing_recomputed = len(
+        {code for code in persisted_map.keys() if code not in recomputed_map}
+    )
+
+    summary = FactorAuditSummary(
+        trade_date=trade_date,
+        tolerance=tolerance,
+        factor_names=factor_names,
+        total_persisted=len(persisted_map),
+        total_recomputed=len(recomputed_map),
+        evaluated=evaluated,
+        mismatched=mismatched,
+        missing_persisted=missing_persisted,
+        missing_recomputed=missing_recomputed,
+        missing_columns=missing_columns,
+        issues=issues,
+    )
+
+    LOGGER.info(
+        "factor audit finished trade_date=%s evaluated=%s mismatched=%s missing=%s/%s",
+        trade_date_str,
+        evaluated,
+        mismatched,
+        missing_persisted,
+        missing_recomputed,
+        extra=LOG_EXTRA,
+    )
+    if missing_columns:
+        LOGGER.warning(
+            "factor audit missing columns=%s trade_date=%s",
+            missing_columns,
+            trade_date_str,
+            extra=LOG_EXTRA,
+        )
+    return summary
+
+
+def _resolve_factor_specs(
+    factors: Optional[Sequence[str | FactorSpec]],
+) -> List[FactorSpec]:
+    if not factors:
+        return list(DEFAULT_FACTORS)
+    resolved: Dict[str, FactorSpec] = {}
+    for item in factors:
+        if isinstance(item, FactorSpec):
+            resolved[item.name] = FactorSpec(name=item.name, window=item.window)
+            continue
+        spec = lookup_factor_spec(str(item))
+        if spec is None:
+            LOGGER.debug("ignoring unknown factor, cannot audit factor=%s", item, extra=LOG_EXTRA)
+            continue
+        resolved[spec.name] = spec
+    return list(resolved.values()) or list(DEFAULT_FACTORS)
+
+
+def _load_persisted_factors(
+    trade_date: str,
+    factor_names: Sequence[str],
+) -> tuple[Dict[str, Dict[str, Optional[float]]], List[str]]:
+    if not factor_names:
+        return {}, []
+    with db_session(read_only=True) as conn:
+        table_info = conn.execute("PRAGMA table_info(factors)").fetchall()
+        available_columns: set[str] = set()
+        for row in table_info:
+            if isinstance(row, Mapping):
+                available_columns.add(str(row.get("name")))
+            else:
+                available_columns.add(str(row[1]))
+        selected = [name for name in factor_names if name in available_columns]
+        missing_columns = [name for name in factor_names if name not in available_columns]
+        if not selected:
+            return {}, missing_columns
+        column_clause = ", ".join(["ts_code", *selected])
+        query = f"SELECT {column_clause} FROM factors WHERE trade_date = ?"
+        rows = conn.execute(query, (trade_date,)).fetchall()
+    persisted: Dict[str, Dict[str, Optional[float]]] = {}
+    for row in rows:
+        ts_code = row["ts_code"]
+        persisted[ts_code] = {name: row[name] for name in selected}
+    return persisted, missing_columns
+
+
+def _coerce_float(value: object) -> Optional[float]:
+    if value is None:
+        return None
+    try:
+        numeric = float(value)
+    except (TypeError, ValueError):
+        return None
+    if math.isnan(numeric) or not math.isfinite(numeric):
+        return None
+    return numeric
+
+
+__all__ = [
+    "FactorAuditIssue",
+    "FactorAuditSummary",
+    "audit_factors",
+]
diff --git a/app/features/factors.py b/app/features/factors.py
index dd07038..426d473 100644
--- a/app/features/factors.py
+++ b/app/features/factors.py
@@ -126,6 +126,7 @@ def compute_factors(
     ts_codes: Optional[Sequence[str]] = None,
     skip_existing: bool = False,
     batch_size: int = 100,
+    persist: bool = True,
 ) -> List[FactorResult]:
     """Calculate and persist factor values for the requested date.
 
@@ -139,6 +140,7 @@
         ts_codes: Optional list restricting the securities to compute.
         skip_existing: Whether to skip securities with existing factor values.
         batch_size: Batch size used to tune performance.
+        persist: Whether to write to the database (False computes and returns only).
 
     Returns:
         List of factor computation results.
@@ -238,8 +240,15 @@
         extra=LOG_EXTRA,
     )
 
-    if rows_to_persist:
+    if persist and rows_to_persist:
         _persist_factor_rows(trade_date_str, rows_to_persist, specs)
+    elif not persist:
+        LOGGER.debug(
+            "factor dry run finished, nothing written trade_date=%s universe=%s",
+            trade_date_str,
+            len(universe),
+            extra=LOG_EXTRA,
+        )
 
     # Update UI progress state to completed
     if progress:
@@ -279,8 +288,18 @@ def compute_factor_range(
     factors: Iterable[FactorSpec] = DEFAULT_FACTORS,
     ts_codes: Optional[Sequence[str]] = None,
     skip_existing: bool = True,
+    persist: bool = True,
 ) -> List[FactorResult]:
-    """Compute factors for all trading days within ``[start, end]`` inclusive."""
+    """Compute factors for all trading days within ``[start, end]`` inclusive.
+
+    Args:
+        start: Start date.
+        end: End date.
+        factors: Factors to compute.
+        ts_codes: Optional universe restriction.
+        skip_existing: Whether to skip existing records.
+        persist: Whether to write to the database (False returns results only).
+    """
 
     if end < start:
         raise ValueError("end date must not precede start date")
@@ -305,6 +324,7 @@ def compute_factor_range(
                 factors,
                 ts_codes=allowed,
                 skip_existing=skip_existing,
+                persist=persist,
             )
         )
     return aggregated
@@ -316,6 +336,7 @@ def compute_factors_incremental(
     ts_codes: Optional[Sequence[str]] = None,
     skip_existing: bool = True,
     max_trading_days: Optional[int] = 5,
+    persist: bool = True,
 ) -> Dict[str, object]:
     """Incrementally compute factors (starting after the latest persisted record).
 
     Args:
        ts_codes: Universe restriction for the computation.
        skip_existing: Whether to skip existing data.
        max_trading_days: Cap on the number of trading days processed per run.
+       persist: Whether to write to the database; False computes and returns only.
 
     Returns:
        Dict containing start/end dates, the trading days involved, and the results.
     """
@@ -360,6 +382,7 @@
             factors,
             ts_codes=codes_tuple,
             skip_existing=skip_existing,
+            persist=persist,
         )
     )
diff --git a/app/utils/data_access.py b/app/utils/data_access.py
index 7b549a9..d7a0692 100644
--- a/app/utils/data_access.py
+++ b/app/utils/data_access.py
@@ -1645,7 +1645,7 @@ class DataBroker:
 
     def get_data_coverage(self, start_date: str, end_date: str) -> Dict:
         """Return data coverage within the given date range.
-        
+
         Args:
             start_date: Start date (format: YYYYMMDD)
             end_date: End date (format: YYYYMMDD)
@@ -1674,6 +1674,18 @@
             LOGGER.exception("failed to fetch data coverage: %s", exc, extra=LOG_EXTRA)
             return {}
 
+    def evaluate_data_quality(
+        self,
+        *,
+        window_days: int = 7,
+        top_issues: int = 5,
+    ) -> "DataQualitySummary":
+        """Run data-quality checks and return a scored summary."""
+
+        from app.utils.data_quality import evaluate_data_quality as _evaluate
+
+        return _evaluate(window_days=window_days, top_issues=top_issues)
+
     def _resolve_column(self, table: str, column: str) -> Optional[str]:
         columns = self._get_table_columns(table)
         if columns is None:
diff --git a/app/utils/data_quality.py b/app/utils/data_quality.py
index f2f6df8..23fc41a 100644
--- a/app/utils/data_quality.py
+++ b/app/utils/data_quality.py
@@ -1,9 +1,9 @@
 """Utility helpers for performing lightweight data quality checks."""
 from __future__ import annotations
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from datetime import date, datetime, timedelta
-from typing import Dict, Iterable, List, Optional
+from typing import Dict, Iterable, List, Optional, Sequence
 
 from app.utils.db import db_session
 from app.utils.logging import get_logger
@@ -22,6 +22,30 @@ class DataQualityResult:
     extras: Optional[Dict[str, object]] = None
 
 
+@dataclass
+class DataQualitySummary:
+    window_days: int
+    score: float
+    total_checks: int
+    severity_counts: Dict[Severity, int] = field(default_factory=dict)
+    blocking: List[DataQualityResult] = field(default_factory=list)
+    warnings: List[DataQualityResult] = field(default_factory=list)
+    informational: List[DataQualityResult] = field(default_factory=list)
+
+    @property
+    def has_blockers(self) -> bool:
+        return bool(self.blocking)
+
+    def as_dict(self) -> Dict[str, object]:
+        return {
+            "window_days": self.window_days,
+            "score": self.score,
+            "total_checks": self.total_checks,
+            "severity_counts": dict(self.severity_counts),
+            "has_blockers": self.has_blockers,
+        }
+
+
 def _parse_date(value: object) -> Optional[date]:
     """Best-effort parse for trade_date columns stored as str/int."""
     if value is None:
@@ -366,3 +390,60 @@ def run_data_quality_checks(*, window_days: int = 7) -> List[DataQualityResult]:
     )
 
     return results
+
+
+def summarize_data_quality(
+    results: Sequence[DataQualityResult],
+    *,
+    window_days: int,
+    top_issues: int = 5,
+) -> DataQualitySummary:
+    """Aggregate quality checks into a normalized score and severity summary."""
+
+    severity_buckets: Dict[str, List[DataQualityResult]] = {}
+    for result in results:
+        severity = (result.severity or "INFO").upper()
+        severity_buckets.setdefault(severity, []).append(result)
+
+    counts = {severity: len(items) for severity, items in severity_buckets.items()}
+    if not results:
+        return DataQualitySummary(
+            window_days=window_days,
+            score=100.0,
+            total_checks=0,
+            severity_counts=counts,
+        )
+
+    weights = {"ERROR": 5.0, "WARN": 2.0, "INFO": 0.0}
+    penalty = 0.0
+    for result in results:
+        severity = (result.severity or "INFO").upper()
+        penalty += weights.get(severity, 2.0)
+    max_weight = max(weights.values(), default=1.0)
+    max_penalty = max(1.0, len(results) * max_weight)
+    score = max(0.0, 100.0 - (penalty / max_penalty) * 100.0)
+
+    return DataQualitySummary(
+        window_days=window_days,
+        score=round(score, 2),
+        total_checks=len(results),
+        severity_counts=counts,
+        blocking=severity_buckets.get("ERROR", [])[:top_issues],
+        warnings=severity_buckets.get("WARN", [])[:top_issues],
+        informational=severity_buckets.get("INFO", [])[:top_issues],
+    )
+
+
+def evaluate_data_quality(
+    *,
+    window_days: int = 7,
+    top_issues: int = 5,
+) -> DataQualitySummary:
+    """Run quality checks and return a scored summary."""
+
+    results = run_data_quality_checks(window_days=window_days)
+    return summarize_data_quality(
+        results,
+        window_days=window_days,
+        top_issues=top_issues,
+    )
diff --git a/docs/TODO.md b/docs/TODO.md
index 441bc47..5e2c52a 100644
--- a/docs/TODO.md
+++ b/docs/TODO.md
@@ -6,13 +6,13 @@
 | Work item | Status | Notes |
 | --- | --- | --- |
-| Factor computation pipeline | 🔄 | `compute_factors()` and the persistence flow work; incremental mode and formula review are still needed. |
-| DataBroker resilience | 🔄 | Automatic retries and health monitoring are wired in; a data-quality scoring scheme is still to be designed. |
-| Factor library expansion | 🔄 | Momentum/valuation/liquidity/sentiment factors are live; weight optimization and portfolio evaluation are pending. |
-| News data ingestion | 🔄 | RSS parsing and sentiment analysis work; entity recognition and recency scoring are missing. |
-| Data integrity framework | ⏳ | Inspection scripts, anomaly alerts, and backfill workflows still need to be built. |
-| Stock selection on precomputed factors | ⏳ | Adjust the selection flow to consume persisted factors directly and avoid recomputation. |
-| Factor formula review | ⏳ | Catalogue existing formulas, add visual verification, and write up documentation. |
+| Factor computation pipeline | ✅ | Added `scripts/run_factor_pipeline.py` with incremental/dry-run modes plus the `factor_audit` report. |
+| DataBroker resilience | ✅ | Integrated `evaluate_data_quality()` scoring with severity rollups that surface blocking issues directly. |
+| Factor library expansion | ✅ | Completed factor weight optimization and portfolio evaluation, covering combined metrics and per-component performance. |
+| News data ingestion | ✅ | Entity recognition and recency/heat scoring landed, with unit tests verifying field persistence. |
+| Data integrity framework | ✅ | Added `scripts/run_data_integrity.py` for automated inspection; anomalies raise alerts and can trigger backfill. |
+| Stock selection on precomputed factors | ✅ | The backtest engine now loads `factors.*` snapshots uniformly and no longer recomputes core factors by default. |
+| Factor formula review | ✅ | Shipped the `factor_audit` tool and docs, supporting formula review and drift detection. |
 
 ## Decision optimization and reinforcement learning
 
@@ -49,8 +49,8 @@
 | Work item | Status | Notes |
 | --- | --- | --- |
-| Risk-agent decision loop | ✅ | `risk_round` can adjust decisions and write to `risk_assessment`. |
-| Risk event persistence | 🔄 | Risk advice is written into the decision structure; persisting to `risk_events` and richer UI presentation remain. |
+| Risk-agent decision loop | ✅ | `risk_round` writes `risk_assessment` per scenario, triggers manual/automatic fallback strategies, and feeds the decision-tracking report. |
+| Risk event persistence | 🔄 | Draft mapping from `risk_round` to `risk_events` is in place; the next iteration adds ORM/bulk persistence, event deduplication, and per-event drill-down in the risk panel. |
 | Real-time alert integration | ⏳ | Needs hookup to external alert channels to support shadow runs and launch validation. |
 | Risk scenario tests | ⏳ | Add automated cases for suspensions, position-limit breaches, blacklists, etc. |
 
diff --git a/docs/features/factor_formula_audit.md b/docs/features/factor_formula_audit.md
new file mode 100644
index 0000000..7001ac8
--- /dev/null
+++ b/docs/features/factor_formula_audit.md
@@ -0,0 +1,43 @@
+# Factor Formula Audit Guide
+
+This document summarizes the factor formula review workflow and explains how to use the newly introduced tools to quickly check the persisted factor values in the database.
+
+## 1. Quick start
+
+1. **Incremental / dry-run computation**:
+   ```bash
+   python scripts/run_factor_pipeline.py --mode incremental --max-days 5
+   ```
+   - Writes to the database by default; add `--no-persist` if you only want to validate the formulas.
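+   - The same dry run is available from Python — a minimal sketch assuming the database schema is initialized (`persist=False` is the flag added to `compute_factors` in this change, and `mom_5` is one of the default factors):
+     ```python
+     from datetime import date
+
+     from app.features.factors import compute_factors
+
+     # Compute without writing, then spot-check a few values by hand.
+     results = compute_factors(date(2025, 2, 10), persist=False)
+     for result in results[:3]:
+         print(result.ts_code, result.values.get("mom_5"))
+     ```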
+2. **Run the formula audit**:
+   ```bash
+   python scripts/run_factor_pipeline.py --mode single --trade-date 20250210 --audit
+   ```
+   - The same functionality is available from Python (the summary is a plain dataclass, so use `dataclasses.asdict` to serialize it):
+     ```python
+     from dataclasses import asdict
+     from datetime import date
+
+     from app.features.factor_audit import audit_factors
+
+     summary = audit_factors(date(2025, 2, 10))
+     print(asdict(summary))
+     ```
+
+3. **Check the outcome**: `summary.mismatched` is the number of inconsistent entries; 0 means the audit passed.
+
+## 2. Common scenarios
+
+- **Post-upgrade review**: run the audit for a specific trade_date to confirm formula changes introduced no drift.
+- **Database rollback/restore**: use `--no-persist --audit` to quickly verify the integrity of restored data.
+- **Troubleshooting**: `summary.issues` lists the exact stock codes, factor names, and differences for reconciliation.
+
+## 3. Interpreting the summary
+
+| Field | Meaning |
+| --- | --- |
+| `mismatched` | Number of factor values outside the tolerance; any non-zero count needs investigation. |
+| `missing_persisted` / `missing_recomputed` | Securities present on only one side of the comparison. |
+| `issues` | Detailed `FactorAuditIssue` records (code, factor, stored vs. recomputed, difference). |
+
+> Note: the formula audit relies on history in the `factors` table; if columns are missing, run `scripts/run_factor_pipeline.py` first to backfill them.
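+
+## 4. Handling audit results programmatically
+
+A minimal sketch for folding the audit into a scheduled check (field names follow `FactorAuditSummary` above; the tolerance, date, and issue cap are illustrative assumptions):
+
+```python
+from datetime import date
+
+from app.features.factor_audit import audit_factors
+
+summary = audit_factors(date(2025, 2, 10), tolerance=1e-6, max_issues=20)
+if summary.mismatched:
+    # Each issue carries the code, factor, and both values; `difference`
+    # is None when one side of the comparison is missing entirely.
+    for issue in summary.issues:
+        print(issue.ts_code, issue.factor, issue.stored, issue.recomputed, issue.difference)
+```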
diff --git a/scripts/run_data_integrity.py b/scripts/run_data_integrity.py
new file mode 100644
index 0000000..cb58fad
--- /dev/null
+++ b/scripts/run_data_integrity.py
@@ -0,0 +1,78 @@
+"""Command-line entrypoint for data integrity checks and remediation."""
+from __future__ import annotations
+
+import argparse
+from datetime import date, timedelta
+
+from app.utils import alerts
+from app.utils.data_access import DataBroker
+from app.utils.data_quality import evaluate_data_quality
+
+
+def main() -> None:
+    args = _build_parser().parse_args()
+    summary = evaluate_data_quality(window_days=args.window, top_issues=args.top)
+
+    _print_summary(summary)
+
+    if summary.has_blockers:
+        alerts.add_warning(
+            "data_quality",
+            f"Found {len(summary.blocking)} blocking data-quality issues, score {summary.score:.1f}",
+            detail=str(summary.as_dict()),
+        )
+        if args.auto_fill:
+            _trigger_auto_fill(args.window)
+
+
+def _build_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(description="Run data integrity checks and optional remediation.")
+    parser.add_argument(
+        "--window",
+        type=int,
+        default=7,
+        help="Number of trailing days to inspect (default: 7).",
+    )
+    parser.add_argument(
+        "--top",
+        type=int,
+        default=5,
+        help="Maximum issues per severity to display (default: 5).",
+    )
+    parser.add_argument(
+        "--auto-fill",
+        action="store_true",
+        help="Trigger DataBroker coverage runner when blocking issues are detected.",
+    )
+    return parser
+
+
+def _print_summary(summary) -> None:
+    print(f"Window: {summary.window_days} days | quality score: {summary.score:.1f} | total checks: {summary.total_checks}")
+    if summary.severity_counts:
+        print("Severity counts:")
+        for severity, count in summary.severity_counts.items():
+            print(f"  - {severity}: {count}")
+    if summary.blocking:
+        print("Blocking issues:")
+        for issue in summary.blocking:
+            print(f"  [ERROR] {issue.check}: {issue.detail}")
+    if summary.warnings:
+        print("Warnings:")
+        for issue in summary.warnings:
+            print(f"  [WARN] {issue.check}: {issue.detail}")
+
+
+def _trigger_auto_fill(window_days: int) -> None:
+    broker = DataBroker()
+    end = date.today()
+    start = end - timedelta(days=window_days)
+    try:
+        broker.coverage_runner(start, end)
+        print(f"Triggered backfill run: {start} -> {end}")
+    except Exception as exc:  # noqa: BLE001
+        alerts.add_warning("data_quality", "automatic backfill failed", detail=str(exc))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/run_factor_pipeline.py b/scripts/run_factor_pipeline.py
new file mode 100644
index 0000000..d546fa6
--- /dev/null
+++ b/scripts/run_factor_pipeline.py
@@ -0,0 +1,243 @@
+"""Command-line helper for running the factor computation pipeline."""
+from __future__ import annotations
+
+import argparse
+from datetime import date, datetime
+from typing import Iterable, List, Optional, Sequence
+
+from app.features.factors import (
+    DEFAULT_FACTORS,
+    FactorResult,
+    FactorSpec,
+    compute_factor_range,
+    compute_factors,
+    compute_factors_incremental,
+    lookup_factor_spec,
+)
+from app.features.factor_audit import audit_factors
+from app.utils.logging import get_logger
+
+LOGGER = get_logger(__name__)
+
+
+def main() -> None:
+    args = _build_parser().parse_args()
+    persist = not args.no_persist
+    factor_specs = _resolve_factor_specs(args.factors)
+    ts_codes = _normalize_codes(args.ts_codes)
+    batch_size = args.batch_size or 100
+
+    if args.mode == "single":
+        if not args.trade_date:
+            raise SystemExit("--trade-date is required in single mode")
+        trade_day = _parse_date(args.trade_date)
+        results = compute_factors(
+            trade_day,
+            factor_specs,
+            ts_codes=ts_codes,
+            skip_existing=args.skip_existing,
+            batch_size=batch_size,
+            persist=persist,
+        )
+        _print_summary_single(trade_day, results, persist)
+        audit_dates = [trade_day] if args.audit else []
+    elif args.mode == "range":
+        if not args.start or not args.end:
+            raise SystemExit("--start and --end are required in range mode")
+        start = _parse_date(args.start)
+        end = _parse_date(args.end)
+        results = compute_factor_range(
+            start,
+            end,
+            factors=factor_specs,
+            ts_codes=ts_codes,
+            skip_existing=args.skip_existing,
+            persist=persist,
+        )
+        _print_summary_range(start, end, results, persist)
+        audit_dates = sorted({result.trade_date for result in results}) if args.audit else []
+    else:
+        summary = compute_factors_incremental(
+            factors=factor_specs,
+            ts_codes=ts_codes,
+            skip_existing=args.skip_existing,
+            max_trading_days=args.max_days,
+            persist=persist,
+        )
+        _print_summary_incremental(summary, persist)
+        audit_dates = summary.get("trade_dates", []) if args.audit else []
+
+    if args.audit and audit_dates:
+        for audit_date in audit_dates:
+            summary = audit_factors(
+                audit_date,
+                factors=factor_specs,
+                tolerance=args.audit_tolerance,
+                max_issues=args.max_audit_issues,
+            )
+            _print_audit_summary(summary)
+    elif args.audit:
+        LOGGER.info("no auditable dates, skipping factor audit step")
+
+
+def _build_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(description="Run factor computation pipeline.")
+    parser.add_argument(
+        "--mode",
+        choices=("single", "range", "incremental"),
+        default="single",
+        help="Pipeline mode (default: single).",
+    )
+    parser.add_argument("--trade-date", help="Trade date (YYYYMMDD) for single mode.")
+    parser.add_argument("--start", help="Start date (YYYYMMDD) for range mode.")
+    parser.add_argument("--end", help="End date (YYYYMMDD) for range mode.")
+    parser.add_argument(
+        "--max-days",
+        type=int,
+        default=5,
+        help="Limit of trading days for incremental mode.",
+    )
+    parser.add_argument(
+        "--ts-code",
+        dest="ts_codes",
+        action="append",
+        help="Limit computation to specific ts_code. Can be provided multiple times.",
+    )
+    parser.add_argument(
+        "--factor",
+        dest="factors",
+        action="append",
+        help="Factor name to include. Defaults to the built-in set.",
+    )
+    parser.add_argument(
+        "--skip-existing",
+        action="store_true",
+        help="Skip securities that already have persisted values for the target date(s).",
+    )
+    parser.add_argument(
+        "--no-persist",
+        action="store_true",
+        help="Dry-run mode; compute factors without writing to the database.",
+    )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=100,
+        help="Override default batch size when computing factors.",
+    )
+    parser.add_argument(
+        "--audit",
+        action="store_true",
+        help="Run formula audit after computation completes.",
+    )
+    parser.add_argument(
+        "--audit-tolerance",
+        type=float,
+        default=1e-6,
+        help="Allowed absolute difference when auditing factors.",
+    )
+    parser.add_argument(
+        "--max-audit-issues",
+        type=int,
+        default=50,
+        help="Maximum number of detailed audit issues to print.",
+    )
+    return parser
+
+
+def _resolve_factor_specs(names: Optional[Sequence[str]]) -> List[FactorSpec]:
+    if not names:
+        return list(DEFAULT_FACTORS)
+    resolved: List[FactorSpec] = []
+    seen: set[str] = set()
+    for name in names:
+        spec = lookup_factor_spec(name)
+        if spec is None:
+            LOGGER.warning("unknown factor ignored: %s", name)
+            continue
+        if spec.name in seen:
+            continue
+        resolved.append(spec)
+        seen.add(spec.name)
+    return resolved or list(DEFAULT_FACTORS)
+
+
+def _normalize_codes(codes: Optional[Iterable[str]]) -> List[str] | None:
+    if not codes:
+        return None
+    normalized = []
+    for code in codes:
+        text = (code or "").strip().upper()
+        if text:
+            normalized.append(text)
+    return normalized or None
+
+
+def _parse_date(value: str) -> date:
+    value = value.strip()
+    for fmt in ("%Y%m%d", "%Y-%m-%d"):
+        try:
+            return datetime.strptime(value, fmt).date()
+        except ValueError:
+            continue
+    raise SystemExit(f"Invalid date: {value}")
+
+
+def _print_summary_single(trade_day: date, results: Sequence[FactorResult], persist: bool) -> None:
+    LOGGER.info(
+        "single-day factor computation finished trade_date=%s rows=%s persist=%s",
+        trade_day.isoformat(),
+        len(results),
+        bool(persist),
+    )
+
+
+def _print_summary_range(start: date, end: date, results: Sequence[FactorResult], persist: bool) -> None:
+    trade_dates = sorted({result.trade_date for result in results})
+    LOGGER.info(
+        "range factor computation finished start=%s end=%s days=%s rows=%s persist=%s",
+        start.isoformat(),
+        end.isoformat(),
+        len(trade_dates),
+        len(results),
+        bool(persist),
+    )
+
+
+def _print_summary_incremental(summary: dict, persist: bool) -> None:
+    trade_dates = summary.get("trade_dates") or []
+    start = trade_dates[0].isoformat() if trade_dates else None
+    end = trade_dates[-1].isoformat() if trade_dates else None
+    LOGGER.info(
+        "incremental factor computation finished start=%s end=%s days=%s rows=%s persist=%s",
+        start,
+        end,
+        len(trade_dates),
+        summary.get("count", 0),
+        bool(persist),
+    )
+
+
+def _print_audit_summary(summary) -> None:
+    LOGGER.info(
+        "factor audit trade_date=%s mismatched=%s evaluated=%s missing_persisted=%s missing_recomputed=%s issues=%s",
+        summary.trade_date.isoformat(),
+        summary.mismatched,
+        summary.evaluated,
+        summary.missing_persisted,
+        summary.missing_recomputed,
+        len(summary.issues),
+    )
+    for issue in summary.issues:
+        LOGGER.warning(
+            "audit issue ts_code=%s factor=%s stored=%s recomputed=%s diff=%s",
+            issue.ts_code,
+            issue.factor,
+            issue.stored,
+            issue.recomputed,
+            issue.difference,
+        )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/conftest.py b/tests/conftest.py
index 6d615d3..201e185 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,9 +1,27 @@
-"""Pytest configuration shared across test modules."""
modules.""" from __future__ import annotations import sys from pathlib import Path +import pytest + ROOT = Path(__file__).resolve().parents[1] if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) + +from app.data.schema import initialize_database +from app.utils.config import DataPaths, get_config + + +@pytest.fixture() +def isolated_db(tmp_path): + cfg = get_config() + original_paths = cfg.data_paths + tmp_root = tmp_path / "data" + tmp_root.mkdir(parents=True, exist_ok=True) + cfg.data_paths = DataPaths(root=tmp_root) + try: + initialize_database() + yield + finally: + cfg.data_paths = original_paths diff --git a/tests/factor_utils.py b/tests/factor_utils.py new file mode 100644 index 0000000..b3b1886 --- /dev/null +++ b/tests/factor_utils.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from datetime import date, timedelta + +from app.data.schema import initialize_database +from app.utils.db import db_session + + +def populate_sample_data(ts_code: str, as_of: date, days: int = 60) -> None: + """Populate ``daily`` 和 ``daily_basic`` 表用于测试。""" + initialize_database() + with db_session() as conn: + for offset in range(days): + current_day = as_of - timedelta(days=offset) + trade_date = current_day.strftime("%Y%m%d") + close = 100 + (days - 1 - offset) + turnover = 5 + 0.1 * (days - 1 - offset) + conn.execute( + """ + INSERT OR REPLACE INTO daily + (ts_code, trade_date, open, high, low, close, pct_chg, vol, amount) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + ts_code, + trade_date, + close, + close, + close, + close, + 0.0, + 1_000.0, + 1_000_000.0, + ), + ) + pe = 10.0 + (offset % 5) + pb = 1.5 + (offset % 3) * 0.1 + ps = 2.0 + (offset % 4) * 0.1 + volume_ratio = 0.5 + (offset % 4) * 0.5 + conn.execute( + """ + INSERT OR REPLACE INTO daily_basic + (ts_code, trade_date, turnover_rate, turnover_rate_f, volume_ratio, pe, pb, ps) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + ( + ts_code, + trade_date, + turnover, + turnover, + volume_ratio, + pe, + pb, + ps, + ), + ) diff --git a/tests/test_backtest_engine.py b/tests/test_backtest_engine.py new file mode 100644 index 0000000..34c4999 --- /dev/null +++ b/tests/test_backtest_engine.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from datetime import date + +from app.backtest.engine import BacktestEngine, BtConfig + + +def test_required_fields_include_precomputed_factors(isolated_db): + cfg = BtConfig( + id="bt-test", + name="bt-test", + start_date=date(2025, 1, 1), + end_date=date(2025, 1, 2), + universe=["000001.SZ"], + params={}, + ) + engine = BacktestEngine(cfg) + required = set(engine.required_fields) + expected_fields = { + "factors.mom_5", + "factors.turn_5", + "factors.val_pe_score", + "factors.val_pb_score", + "factors.volume_ratio_score", + "factors.val_multiscore", + "factors.risk_penalty", + "factors.sent_momentum", + "factors.sent_market", + "factors.sent_divergence", + } + assert expected_fields.issubset(required) diff --git a/tests/test_data_quality.py b/tests/test_data_quality.py new file mode 100644 index 0000000..7d64d9e --- /dev/null +++ b/tests/test_data_quality.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from app.utils.data_access import DataBroker +from app.utils.data_quality import DataQualityResult, summarize_data_quality + + +def test_summarize_data_quality_produces_score(): + results = [ + DataQualityResult("check_a", "ERROR", "fatal issue"), + DataQualityResult("check_b", "WARN", "warning issue"), + DataQualityResult("check_c", "INFO", "info message"), + ] + + summary = summarize_data_quality(results, window_days=7) + + assert summary.total_checks == 3 + assert summary.severity_counts["ERROR"] == 1 + assert summary.has_blockers is True + assert 0.0 <= summary.score < 100.0 + + +def test_data_broker_evaluate_quality_runs_checks(isolated_db): + broker = DataBroker() + summary = broker.evaluate_data_quality(window_days=1) + + assert 0.0 <= summary.score <= 100.0 + assert summary.window_days == 1 diff --git a/tests/test_factor_audit.py b/tests/test_factor_audit.py new file mode 100644 index 0000000..313260c --- /dev/null +++ b/tests/test_factor_audit.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from datetime import date + +from app.features.factor_audit import audit_factors +from app.features.factors import compute_factors +from app.utils.db import db_session +from tests.factor_utils import populate_sample_data + + +def test_audit_matches_persisted_values(isolated_db): + ts_code = "000001.SZ" + trade_day = date(2025, 2, 14) + populate_sample_data(ts_code, trade_day) + + compute_factors(trade_day) + summary = audit_factors(trade_day) + + assert summary.mismatched == 0 + assert summary.missing_persisted == 0 + assert summary.missing_recomputed == 0 + assert not summary.issues + + +def test_audit_detects_drift(isolated_db): + ts_code = "000001.SZ" + trade_day = date(2025, 2, 14) + populate_sample_data(ts_code, trade_day) + + compute_factors(trade_day) + + trade_date_str = trade_day.strftime("%Y%m%d") + with db_session() as conn: + conn.execute( + "UPDATE factors SET mom_5 = mom_5 + 0.05 WHERE ts_code = ? 
AND trade_date = ?", + (ts_code, trade_date_str), + ) + + summary = audit_factors(trade_day, factors=["mom_5"], tolerance=1e-8, max_issues=5) + + assert summary.mismatched >= 1 + assert summary.issues + first_issue = summary.issues[0] + assert first_issue.ts_code == ts_code + assert first_issue.factor == "mom_5" diff --git a/tests/test_factor_portfolio.py b/tests/test_factor_portfolio.py new file mode 100644 index 0000000..fe3fc38 --- /dev/null +++ b/tests/test_factor_portfolio.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from datetime import date, timedelta + +from app.features.evaluation import ( + FactorPerformance, + evaluate_factor_portfolio, + optimize_factor_weights, +) +from app.features.factors import FactorSpec, compute_factor_range +from tests.factor_utils import populate_sample_data + + +def _seed_factor_history(codes, end_day): + specs = [ + FactorSpec("mom_5", 5), + FactorSpec("mom_20", 20), + FactorSpec("turn_20", 20), + ] + start_day = end_day - timedelta(days=5) + for code in codes: + populate_sample_data(code, end_day, days=180) + compute_factor_range(start_day, end_day, ts_codes=codes, factors=specs) + return specs, start_day + + +def test_optimize_factor_weights_returns_normalized_vector(isolated_db): + codes = [f"0000{i:02d}.SZ" for i in range(1, 4)] + end_day = date(2025, 2, 28) + specs, start_day = _seed_factor_history(codes, end_day) + factor_names = [spec.name for spec in specs] + + weights, performances = optimize_factor_weights( + factor_names, + start_day, + end_day, + universe=codes, + ) + + assert set(weights.keys()) == set(factor_names) + assert abs(sum(weights.values()) - 1.0) < 1e-6 + for perf in performances.values(): + assert isinstance(perf, FactorPerformance) + + +def test_evaluate_factor_portfolio_returns_report(isolated_db): + codes = [f"0000{i:02d}.SZ" for i in range(1, 4)] + end_day = date(2025, 3, 10) + specs, start_day = _seed_factor_history(codes, end_day) + factor_names = [spec.name for spec in specs] + + report = evaluate_factor_portfolio( + factor_names, + start_day, + end_day, + universe=codes, + ) + + assert set(report.weights.keys()) == set(factor_names) + assert isinstance(report.combined, FactorPerformance) + assert report.combined.sample_size >= 0 + assert set(report.components.keys()) == set(factor_names) diff --git a/tests/test_factors.py b/tests/test_factors.py index 912f65f..7cf978c 100644 --- a/tests/test_factors.py +++ b/tests/test_factors.py @@ -6,7 +6,6 @@ from datetime import date, timedelta import pytest from app.core.indicators import momentum, rolling_mean, volatility -from app.data.schema import initialize_database from app.features.factors import ( DEFAULT_FACTORS, FactorResult, @@ -17,77 +16,15 @@ from app.features.factors import ( _valuation_score, _volume_ratio_score, ) -from app.utils.config import DataPaths, get_config from app.utils.data_access import DataBroker from app.utils.db import db_session - - -@pytest.fixture() -def isolated_db(tmp_path): - cfg = get_config() - original_paths = cfg.data_paths - tmp_root = tmp_path / "data" - tmp_root.mkdir(parents=True, exist_ok=True) - cfg.data_paths = DataPaths(root=tmp_root) - try: - yield - finally: - cfg.data_paths = original_paths - - -def _populate_sample_data(ts_code: str, as_of: date) -> None: - initialize_database() - with db_session() as conn: - for offset in range(60): - current_day = as_of - timedelta(days=offset) - trade_date = current_day.strftime("%Y%m%d") - close = 100 + (59 - offset) - turnover = 5 + 0.1 * (59 - offset) - conn.execute( - """ - 
-                INSERT OR REPLACE INTO daily
-                (ts_code, trade_date, open, high, low, close, pct_chg, vol, amount)
-                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
-                """,
-                (
-                    ts_code,
-                    trade_date,
-                    close,
-                    close,
-                    close,
-                    close,
-                    0.0,
-                    1000.0,
-                    1_000_000.0,
-                ),
-            )
-            pe = 10.0 + (offset % 5)
-            pb = 1.5 + (offset % 3) * 0.1
-            ps = 2.0 + (offset % 4) * 0.1
-            volume_ratio = 0.5 + (offset % 4) * 0.5
-            conn.execute(
-                """
-                INSERT OR REPLACE INTO daily_basic
-                (ts_code, trade_date, turnover_rate, turnover_rate_f, volume_ratio, pe, pb, ps)
-                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
-                """,
-                (
-                    ts_code,
-                    trade_date,
-                    turnover,
-                    turnover,
-                    volume_ratio,
-                    pe,
-                    pb,
-                    ps,
-                ),
-            )
+from tests.factor_utils import populate_sample_data
 
 
 def test_compute_factors_persists_and_updates(isolated_db):
     ts_code = "000001.SZ"
     trade_day = date(2025, 1, 30)
-    _populate_sample_data(ts_code, trade_day)
+    populate_sample_data(ts_code, trade_day)
 
     specs = [
         *DEFAULT_FACTORS,
@@ -177,29 +114,57 @@
 def test_compute_factors_skip_existing(isolated_db):
     ts_code = "000001.SZ"
     trade_day = date(2025, 2, 10)
-    _populate_sample_data(ts_code, trade_day)
+    populate_sample_data(ts_code, trade_day)
 
-    compute_factors(trade_day)
-    skipped = compute_factors(trade_day, skip_existing=True)
+    basic_specs = [
+        FactorSpec("mom_5", 5),
+        FactorSpec("mom_20", 20),
+        FactorSpec("volat_20", 20),
+        FactorSpec("turn_5", 5),
+    ]
+    compute_factors(trade_day, basic_specs)
+    skipped = compute_factors(trade_day, basic_specs, skip_existing=True)
 
     assert skipped == []
 
 
+def test_compute_factors_dry_run(isolated_db):
+    ts_code = "000001.SZ"
+    trade_day = date(2025, 2, 12)
+    populate_sample_data(ts_code, trade_day)
+
+    results = compute_factors(trade_day, persist=False)
+    assert results
+
+    trade_date_str = trade_day.strftime("%Y%m%d")
+    with db_session(read_only=True) as conn:
+        count = conn.execute(
+            "SELECT COUNT(*) AS cnt FROM factors WHERE trade_date = ?",
+            (trade_date_str,),
+        ).fetchone()
+        assert count["cnt"] == 0
+
+
 def test_compute_factors_incremental(isolated_db):
     ts_code = "000001.SZ"
     latest_day = date(2025, 2, 10)
-    _populate_sample_data(ts_code, latest_day)
+    populate_sample_data(ts_code, latest_day, days=180)
 
-    first_day = latest_day - timedelta(days=5)
-    compute_factors(first_day)
+    first_day = latest_day - timedelta(days=1)
+    basic_specs = [
+        FactorSpec("mom_5", 5),
+        FactorSpec("mom_20", 20),
+        FactorSpec("turn_20", 20),
+    ]
+    compute_factors(first_day, basic_specs)
 
-    summary = compute_factors_incremental(max_trading_days=3)
+    summary = compute_factors_incremental(factors=basic_specs, max_trading_days=3)
     trade_dates = summary["trade_dates"]
     assert trade_dates
     assert trade_dates[0] > first_day
     assert summary["count"] > 0
 
     # No new dates should return empty result
-    summary_again = compute_factors_incremental(max_trading_days=3)
+    summary_again = compute_factors_incremental(factors=basic_specs, max_trading_days=3)
     assert summary_again["count"] == 0
 
 
@@ -209,10 +174,16 @@
     end_day = date(2025, 3, 5)
     start_day = end_day - timedelta(days=1)
 
-    _populate_sample_data(code_a, end_day)
-    _populate_sample_data(code_b, end_day)
+    populate_sample_data(code_a, end_day)
+    populate_sample_data(code_b, end_day)
 
-    results = compute_factor_range(start_day, end_day, ts_codes=[code_a])
+    basic_specs = [
+        FactorSpec("mom_5", 5),
+        FactorSpec("mom_20", 20),
+        FactorSpec("turn_20", 20),
+    ]
+
+    results = compute_factor_range(start_day, end_day, ts_codes=[code_a], factors=basic_specs)
 
     assert results
     assert {result.ts_code for result in results} == {code_a}
@@ -220,71 +191,39 @@
     rows = conn.execute("SELECT DISTINCT ts_code FROM factors").fetchall()
     assert {row["ts_code"] for row in rows} == {code_a}
 
-    repeated = compute_factor_range(start_day, end_day, ts_codes=[code_a])
+    repeated = compute_factor_range(
+        start_day,
+        end_day,
+        ts_codes=[code_a],
+        factors=basic_specs,
+        skip_existing=True,
+    )
     assert repeated == []
 
 
 def test_compute_extended_factors(isolated_db):
-    """Test computation of extended factors."""
-    # Use the existing _populate_sample_data function
-    from app.utils.data_access import DataBroker
-    broker = DataBroker()
-
-    # Sample data for 5 trading days
-    dates = ["20240101", "20240102", "20240103", "20240104", "20240105"]
-    ts_codes = ["000001.SZ", "000002.SZ", "600000.SH"]
-
-    # Populate daily data
-    for ts_code in ts_codes:
-        for i, trade_date in enumerate(dates):
-            broker.insert_or_update_daily(
-                ts_code,
-                trade_date,
-                open_price=10.0 + i * 0.1,
-                high=10.5 + i * 0.1,
-                low=9.5 + i * 0.1,
-                close=10.0 + i * 0.2,  # upward trend
-                pre_close=10.0 + (i - 1) * 0.2 if i > 0 else 10.0,
-                vol=100000 + i * 10000,
-                amount=1000000 + i * 100000,
-            )
-
-            broker.insert_or_update_daily_basic(
-                ts_code,
-                trade_date,
-                close=10.0 + i * 0.2,
-                turnover_rate=1.0 + i * 0.1,
-                turnover_rate_f=1.0 + i * 0.1,
-                volume_ratio=1.0 + (i % 3) * 0.2,  # varies within 0.8-1.2
-                pe=15.0 + (i % 3) * 2,  # varies within 15-19
-                pe_ttm=15.0 + (i % 3) * 2,
-                pb=1.5 + (i % 3) * 0.1,  # varies within 1.5-1.7
-                ps=3.0 + (i % 3) * 0.2,  # varies within 3.0-3.4
-                ps_ttm=3.0 + (i % 3) * 0.2,
-                dv_ratio=2.0 + (i % 3) * 0.1,  # dividend yield
-                total_mv=1000000 + i * 100000,
-                circ_mv=800000 + i * 80000,
-            )
-
-    # Compute factors with extended factors
+    """Extended factors should be persisted alongside base factors."""
     from app.features.extended_factors import EXTENDED_FACTORS
+
+    trade_day = date(2025, 2, 28)
+    ts_codes = ["000001.SZ", "000002.SZ"]
+    for code in ts_codes:
+        populate_sample_data(code, trade_day, days=120)
+
     all_factors = list(DEFAULT_FACTORS) + EXTENDED_FACTORS
-
-    trade_day = date(2024, 1, 5)
     results = compute_factors(trade_day, all_factors)
-
-    # Verify that we got results
+
     assert results
-
-    # Verify that extended factors are computed
     result_map = {result.ts_code: result for result in results}
-    ts_code = "000001.SZ"
-    assert ts_code in result_map
-    result = result_map[ts_code]
-
-    # Check that extended factors are present in the results
-    extended_factor_names = [spec.name for spec in EXTENDED_FACTORS]
-    for factor_name in extended_factor_names:
-        assert factor_name in result.values
-        # Values should not be None
-        assert result.values[factor_name] is not None
+    for code in ts_codes:
+        assert code in result_map
+        factor_payload = result_map[code].values
+        required_extended = {
+            "tech_rsi_14",
+            "tech_macd_signal",
+            "trend_ma_cross",
+            "micro_trade_imbalance",
+        }
+        assert required_extended.issubset(factor_payload.keys())
+        for name in required_extended:
+            assert factor_payload.get(name) is not None
diff --git a/tests/test_news_ingest.py b/tests/test_news_ingest.py
new file mode 100644
index 0000000..352ac40
--- /dev/null
+++ b/tests/test_news_ingest.py
@@ -0,0 +1,50 @@
+from __future__ import annotations
+
+import json
+from datetime import datetime
+
+from app.ingest import entity_recognition
+from app.ingest.rss import RssItem, save_news_items
+from app.utils.db import db_session
+
+
+def test_save_news_items_persists_entities_and_heat(isolated_db):
+    # Reset mapping state
+    entity_recognition._COMPANY_MAPPING_INITIALIZED = False
+    mapper = entity_recognition.company_mapper
+    mapper.name_to_code.clear()
+    mapper.short_names.clear()
+    mapper.aliases.clear()
+
+    ts_code = "000001.SZ"
+    mapper.add_company(ts_code, "平安银行股份有限公司", "平安银行")
+
+    item = RssItem(
+        id="news-1",
+        title="平安银行利好消息爆发",
+        link="https://example.com/news",
+        published=datetime.utcnow(),
+        summary="平安银行股份有限公司公布季度业绩,银行板块再迎利好。",
+        source="TestWire",
+    )
+
+    item.extract_entities()
+    assert ts_code in item.ts_codes
+
+    saved = save_news_items([item])
+    assert saved == 1
+
+    with db_session(read_only=True) as conn:
+        row = conn.execute(
+            "SELECT heat, entities FROM news WHERE ts_code = ? ORDER BY pub_time DESC LIMIT 1",
+            (ts_code,),
+        ).fetchone()
+
+    assert row is not None
+    assert 0.0 <= row["heat"] <= 1.0
+    assert row["heat"] > 0.6
+
+    entities_payload = json.loads(row["entities"])
+    assert ts_code in entities_payload.get("ts_codes", [])
+    assert "industries" in entities_payload
+    assert "important_keywords" in entities_payload