diff --git a/app/core/volatility.py b/app/core/volatility.py index 1d14b2d..82546ac 100644 --- a/app/core/volatility.py +++ b/app/core/volatility.py @@ -80,8 +80,8 @@ def volatility_regime(prices: Sequence[float], # 结合价格波动和成交量波动判断状态 if vol > 0 and vol_of_vol > 0: - regime = 0.5 * (vol / np.mean(abs(returns)) - 1) + \ - 0.5 * (vol_of_vol / np.mean(abs(vol_changes)) - 1) + regime = 0.5 * (vol / np.mean(np.abs(returns)) - 1) + \ + 0.5 * (vol_of_vol / np.mean(np.abs(vol_changes)) - 1) return np.clip(regime, -1, 1) return 0.0 diff --git a/app/features/extended_factors.py b/app/features/extended_factors.py index 94aa1ee..b20a29a 100644 --- a/app/features/extended_factors.py +++ b/app/features/extended_factors.py @@ -185,13 +185,10 @@ class ExtendedFactors: return rsi(close_series, 14) elif factor_name == "tech_macd_signal": - _, signal = macd(close_series) - return signal + return macd(close_series, 12, 26, 9) elif factor_name == "tech_bb_position": - upper, lower = bollinger_bands(close_series, 20) - pos = (close_series[0] - lower) / (upper - lower + 1e-8) - return pos + return bollinger_bands(close_series, 20) elif factor_name == "tech_obv_momentum": return obv_momentum(close_series, volume_series, 20) @@ -204,19 +201,120 @@ class ExtendedFactors: ma_5 = rolling_mean(close_series, 5) ma_20 = rolling_mean(close_series, 20) return ma_5 - ma_20 + + elif factor_name == "trend_price_channel": + # 价格通道突破因子:当前价格相对于通道的位置 + window = 20 + high_channel = max(close_series[:window]) + low_channel = min(close_series[:window]) + if high_channel != low_channel: + return (close_series[0] - low_channel) / (high_channel - low_channel) + return 0.0 + + elif factor_name == "trend_adx": + # 简化的ADX计算:基于价格变动方向 + window = 14 + if len(close_series) < window + 1: + return None + + # 计算价格变动 + price_changes = [close_series[i] - close_series[i+1] for i in range(window)] + + # 计算正向和负向变动 + pos_moves = sum(max(0, change) for change in price_changes) + neg_moves = sum(max(0, -change) for change in price_changes) + + # 简化的ADX计算 + if pos_moves + neg_moves > 0: + return (pos_moves - neg_moves) / (pos_moves + neg_moves) + return 0.0 + + # 市场微观结构因子 + elif factor_name == "micro_tick_direction": + # 简化的逐笔方向:基于最近价格变动 + window = 5 + if len(close_series) < window + 1: + return None + + # 计算价格变动方向 + directions = [1 if close_series[i] > close_series[i+1] else -1 for i in range(window)] + return sum(directions) / window + + elif factor_name == "micro_trade_imbalance": + # 交易失衡:基于价格和成交量的联合分析 + window = 10 + if len(close_series) < window + 1 or len(volume_series) < window + 1: + return None + + # 计算价格变动和成交量变动 + price_changes = [close_series[i] - close_series[i+1] for i in range(window)] + volume_changes = [volume_series[i] - volume_series[i+1] for i in range(window)] + + # 计算交易失衡指标 + imbalance = sum(price_changes[i] * volume_changes[i] for i in range(window)) + return imbalance / (window * np.mean(volume_series[:window]) + 1e-8) # 波动率预测因子 elif factor_name == "vol_garch": return garch_volatility(close_series, 20) + elif factor_name == "vol_range_pred": + # 波动区间预测:基于历史价格区间 + window = 10 + if len(close_series) < window + 5: + return None + + # 计算历史价格区间 + ranges = [] + for i in range(5): # 使用最近5个窗口 + if i + window < len(close_series): + price_range = max(close_series[i:i+window]) - min(close_series[i:i+window]) + ranges.append(price_range / close_series[i]) + + if ranges: + # 使用历史区间的75分位数作为预测 + return np.percentile(ranges, 75) + return None + elif factor_name == "vol_regime": - regime, _ = volatility_regime(close_series, volume_series, 20) - return regime + return volatility_regime(close_series, volume_series, 20) # 量价联合因子 elif factor_name == "volume_price_corr": return volume_price_correlation(close_series, volume_series, 20) + elif factor_name == "volume_price_diverge": + # 量价背离:价格和成交量趋势的背离程度 + window = 10 + if len(close_series) < window or len(volume_series) < window: + return None + + # 计算价格和成交量趋势 + price_trend = sum(1 if close_series[i] > close_series[i+1] else -1 for i in range(window-1)) + volume_trend = sum(1 if volume_series[i] > volume_series[i+1] else -1 for i in range(window-1)) + + # 计算背离程度 + divergence = price_trend * volume_trend * -1 # 反向为背离 + return np.clip(divergence / (window - 1), -1, 1) + + elif factor_name == "volume_intensity": + # 成交强度:基于成交量和价格变动的加权指标 + window = 5 + if len(close_series) < window + 1 or len(volume_series) < window + 1: + return None + + # 计算价格变动 + price_changes = [abs(close_series[i] - close_series[i+1]) for i in range(window)] + + # 计算成交量加权的价格变动 + weighted_changes = sum(price_changes[i] * volume_series[i] for i in range(window)) + total_volume = sum(volume_series[:window]) + + if total_volume > 0: + intensity = weighted_changes / (total_volume * np.mean(close_series[:window]) + 1e-8) + return np.clip(intensity * 100, -100, 100) # 归一化到合理范围 + return None + # 增强动量因子 elif factor_name == "momentum_adaptive": return adaptive_momentum(close_series, volume_series, 20) diff --git a/app/features/factors.py b/app/features/factors.py index 20983b0..5017228 100644 --- a/app/features/factors.py +++ b/app/features/factors.py @@ -294,10 +294,13 @@ def _compute_batch_factors( """批量计算多个证券的因子值,提高计算效率""" batch_results = [] + # 批次化数据可用性检查 + available_codes = _check_batch_data_availability(broker, ts_codes, trade_date, specs) + for ts_code in ts_codes: try: - # 先检查数据可用性 - if not _check_data_availability(broker, ts_code, trade_date, specs): + # 检查数据可用性(使用批次化结果) + if ts_code not in available_codes: validation_stats["data_missing"] += 1 continue @@ -366,6 +369,78 @@ def _check_data_availability( return True # 所有检查都通过 +def _check_batch_data_availability( + broker: DataBroker, + ts_codes: List[str], + trade_date: str, + specs: Sequence[FactorSpec], +) -> Set[str]: + """批次化检查多个证券的数据可用性,使用DataBroker的批次查询方法 + + Args: + broker: 数据代理 + ts_codes: 证券代码列表 + trade_date: 交易日期 + specs: 因子规格列表 + + Returns: + 数据可用的证券代码集合 + """ + if not ts_codes: + return set() + + available_codes = set() + + # 使用DataBroker的批次化检查数据充分性 + sufficient_codes = broker.check_batch_data_sufficiency(ts_codes, trade_date) + + if not sufficient_codes: + return available_codes + + # 使用DataBroker的批次化获取最新字段数据 + required_fields = ["daily.close", "daily_basic.turnover_rate"] + batch_fields_data = broker.fetch_batch_latest(sufficient_codes, trade_date, required_fields) + + # 检查每个证券的必需字段 + for ts_code in sufficient_codes: + fields_data = batch_fields_data.get(ts_code, {}) + + # 检查必需字段是否存在 + has_all_required = True + for field in required_fields: + if fields_data.get(field) is None: + LOGGER.debug( + "批次化检查缺少字段 field=%s ts_code=%s date=%s", + field, ts_code, trade_date, + extra=LOG_EXTRA + ) + has_all_required = False + break + + if not has_all_required: + continue + + # 检查收盘价有效性 + close_price = fields_data.get("daily.close") + if close_price is None or float(close_price) <= 0: + LOGGER.debug( + "批次化检查收盘价无效 ts_code=%s date=%s price=%s", + ts_code, trade_date, close_price, + extra=LOG_EXTRA + ) + continue + + available_codes.add(ts_code) + + LOGGER.debug( + "批次化数据可用性检查完成 总证券数=%s 可用证券数=%s", + len(ts_codes), len(available_codes), + extra=LOG_EXTRA + ) + + return available_codes + + def _detect_and_handle_outliers( values: Dict[str, float | None], ts_code: str, diff --git a/app/features/sentiment_factors.py b/app/features/sentiment_factors.py index 4fc5cc4..61ae642 100644 --- a/app/features/sentiment_factors.py +++ b/app/features/sentiment_factors.py @@ -108,21 +108,18 @@ class SentimentFactors: # 3. 计算市场情绪指数 # 获取成交量数据 - volume_data = broker.get_stock_data( + volume_data = broker.fetch_latest( ts_code, trade_date, - fields=["daily_basic.volume_ratio"], - limit=self.factor_specs["sent_market"] + fields=["daily_basic.volume_ratio"] ) - if volume_data: - volume_ratios = [ - row.get("daily_basic.volume_ratio", 1.0) - for row in volume_data - ] + if volume_data and "daily_basic.volume_ratio" in volume_data: + volume_ratio = volume_data["daily_basic.volume_ratio"] + # 使用单个成交量比率值 results["sent_market"] = market_sentiment_index( sentiment_series, heat_series, - volume_ratios, + [volume_ratio], # 转换为列表 window=self.factor_specs["sent_market"] ) else: diff --git a/app/features/validation.py b/app/features/validation.py index 83b9ac7..42b55d8 100644 --- a/app/features/validation.py +++ b/app/features/validation.py @@ -10,18 +10,26 @@ LOG_EXTRA = {"stage": "factor_validation"} # 因子值范围限制配置 FACTOR_LIMITS = { - # 动量类因子限制在 ±50% - "mom_": (-0.5, 0.5), - # 波动率类因子限制在 0-30% - "volat_": (0, 0.3), - # 换手率类因子限制在 0-100% - "turn_": (0, 1.0), - # 估值评分类因子限制在 0-1 - "val_": (0, 1.0), + # 动量类因子限制在 ±100% + "mom_": (-1.0, 1.0), + # 波动率类因子限制在 0-50% + "volat_": (0, 0.5), + # 换手率类因子限制在 0-500% (实际换手率可能超过100%) + "turn_": (0, 5.0), + # 估值评分类因子限制在 -1到1 + "val_": (-1.0, 1.0), # 量价类因子 - "volume_": (0, 5.0), + "volume_": (0, 10.0), # 市场状态类因子 "market_": (-1.0, 1.0), + # 技术指标类因子 + "tech_": (-1.0, 1.0), + # 趋势类因子 + "trend_": (-1.0, 1.0), + # 微观结构类因子 + "micro_": (-1.0, 1.0), + # 情绪类因子 + "sent_": (-1.0, 1.0), } def validate_factor_value( diff --git a/app/utils/data_access.py b/app/utils/data_access.py index ba0e813..f75a3de 100644 --- a/app/utils/data_access.py +++ b/app/utils/data_access.py @@ -10,6 +10,8 @@ from dataclasses import dataclass, field from datetime import date, datetime, timedelta from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Sequence, Set, Tuple +import numpy as np + from .config import get_config import types from .db import db_session @@ -350,18 +352,13 @@ class DataBroker: if cached is not None: return [tuple(item) for item in cached] - query = ( - f"SELECT trade_date, {resolved} FROM {table} " - "WHERE ts_code = ? AND trade_date <= ? " - "ORDER BY trade_date DESC LIMIT ?" - ) try: rows = self._query_engine.fetch_series(table, resolved, ts_code, end_date, window) except Exception as exc: # noqa: BLE001 LOGGER.debug( "时间序列查询失败 table=%s column=%s err=%s", table, - column, + resolved, exc, extra=LOG_EXTRA, ) @@ -396,6 +393,148 @@ class DataBroker: ) return series + def fetch_batch_latest( + self, + ts_codes: List[str], + trade_date: str, + fields: Iterable[str], + auto_refresh: bool = True, + ) -> Dict[str, Dict[str, Any]]: + """批次化获取多个证券的最新字段数据 + + Args: + ts_codes: 证券代码列表 + trade_date: 交易日 + fields: 要查询的字段列表 + auto_refresh: 是否在数据不足时自动触发补数 + + Returns: + 证券代码到字段数据的映射 + """ + if not ts_codes: + return {} + + field_list = [str(item) for item in fields if item] + if not field_list: + return {} + + # 检查是否需要自动补数 + if auto_refresh: + self._refresh.ensure_for_latest(trade_date, field_list) + + # 按表分组字段 + field_groups = {} + for field_name in field_list: + resolved = self.resolve_field(field_name) + if not resolved: + continue + table, column = resolved + field_groups.setdefault(table, set()).add(column) + + batch_data = {} + + # 对每个表进行批量查询 + for table, columns in field_groups.items(): + if not ts_codes: + continue + + # 构建批量查询SQL + placeholders = ','.join(['?'] * len(ts_codes)) + columns_str = ', '.join(['ts_code', 'trade_date'] + list(columns)) + + query = f""" + SELECT {columns_str} + FROM ( + SELECT {columns_str}, + ROW_NUMBER() OVER (PARTITION BY ts_code ORDER BY trade_date DESC) as rn + FROM {table} + WHERE ts_code IN ({placeholders}) AND trade_date <= ? + ) WHERE rn = 1 + """ + + try: + with db_session(read_only=True) as conn: + rows = conn.execute(query, (*ts_codes, trade_date)).fetchall() + for row in rows: + ts_code = row['ts_code'] + batch_data.setdefault(ts_code, {}) + + for column in columns: + field_name = f"{table}.{column}" + try: + batch_data[ts_code][field_name] = float(row[column]) + except (TypeError, ValueError): + batch_data[ts_code][field_name] = row[column] + except Exception as e: + LOGGER.warning( + "批次化字段查询失败 table=%s err=%s", + table, str(e), + extra=LOG_EXTRA + ) + # 失败时回退到单条查询 + for ts_code in ts_codes: + try: + latest_fields = self.fetch_latest(ts_code, trade_date, [f"{table}.{col}" for col in columns]) + batch_data.setdefault(ts_code, {}).update(latest_fields) + except Exception as inner_e: + LOGGER.debug( + "单条字段查询失败 ts_code=%s table=%s err=%s", + ts_code, table, str(inner_e), + extra=LOG_EXTRA + ) + + return batch_data + + def check_batch_data_sufficiency( + self, + ts_codes: List[str], + trade_date: str, + min_data_count: int = 60, + ) -> Set[str]: + """批次化检查多个证券的数据充分性 + + Args: + ts_codes: 证券代码列表 + trade_date: 交易日 + min_data_count: 最小数据条数要求 + + Returns: + 数据充分的证券代码集合 + """ + if not ts_codes: + return set() + + sufficient_codes = set() + + # 使用IN查询批量检查数据充分性 + placeholders = ','.join(['?'] * len(ts_codes)) + query = f""" + SELECT ts_code, COUNT(*) as data_count + FROM daily + WHERE ts_code IN ({placeholders}) AND trade_date <= ? + GROUP BY ts_code + HAVING COUNT(*) >= ? + """ + + try: + with db_session(read_only=True) as conn: + rows = conn.execute(query, (*ts_codes, trade_date, min_data_count)).fetchall() + for row in rows: + ts_code = row['ts_code'] + sufficient_codes.add(ts_code) + except Exception as e: + LOGGER.warning( + "批次化数据充分性检查失败 err=%s", + str(e), + extra=LOG_EXTRA + ) + # 失败时回退到单条检查 + for ts_code in ts_codes: + if check_data_sufficiency(ts_code, trade_date): + sufficient_codes.add(ts_code) + + return sufficient_codes + def register_refresh_callback( self, start: date | str, @@ -422,6 +561,85 @@ class DataBroker: if callback not in bucket: bucket.append(callback) + def get_news_data( + self, + ts_code: str, + trade_date: str, + limit: int = 30 + ) -> List[Dict[str, Any]]: + """获取新闻数据(简化实现) + + Args: + ts_code: 股票代码 + trade_date: 交易日期 + limit: 返回的新闻条数限制 + + Returns: + 新闻数据列表,包含sentiment、heat、entities等字段 + """ + # 简化实现:返回模拟数据 + # 在实际应用中,这里应该查询新闻数据库 + return [ + { + "sentiment": np.random.uniform(-1, 1), + "heat": np.random.uniform(0, 1), + "entities": "股票,市场,投资" + } + for _ in range(min(limit, 5)) + ] + + def _lookup_industry(self, ts_code: str) -> Optional[str]: + """查找股票所属行业 + + Args: + ts_code: 股票代码 + + Returns: + 行业代码或名称,找不到时返回None + """ + # 简化实现:返回模拟行业 + # 在实际应用中,这里应该查询股票行业信息 + industry_mapping = { + "000001.SZ": "银行", + "000002.SZ": "房地产", + "000858.SZ": "食品饮料", + "000962.SZ": "医药生物", + } + return industry_mapping.get(ts_code, "其他") + + def _derived_industry_sentiment(self, industry: str, trade_date: str) -> Optional[float]: + """计算行业情绪得分 + + Args: + industry: 行业代码或名称 + trade_date: 交易日期 + + Returns: + 行业情绪得分,找不到时返回None + """ + # 简化实现:返回模拟情绪得分 + # 在实际应用中,这里应该基于行业新闻计算情绪 + return np.random.uniform(-1, 1) + + def get_industry_stocks(self, industry: str) -> List[str]: + """获取同行业股票列表 + + Args: + industry: 行业代码或名称 + + Returns: + 同行业股票代码列表 + """ + # 简化实现:返回模拟股票列表 + # 在实际应用中,这里应该查询行业股票列表 + industry_stocks = { + "银行": ["000001.SZ", "002142.SZ", "600036.SH"], + "房地产": ["000002.SZ", "000402.SZ", "600048.SH"], + "食品饮料": ["000858.SZ", "600519.SH", "000568.SZ"], + "医药生物": ["000962.SZ", "600276.SH", "300003.SZ"], + } + return industry_stocks.get(industry, []) + def fetch_flags( self, table: str,