enhance factors calculation with new technical and market microstructure features
This commit is contained in:
parent
27b7f024c0
commit
34f0758135
@ -80,8 +80,8 @@ def volatility_regime(prices: Sequence[float],
|
|||||||
|
|
||||||
# 结合价格波动和成交量波动判断状态
|
# 结合价格波动和成交量波动判断状态
|
||||||
if vol > 0 and vol_of_vol > 0:
|
if vol > 0 and vol_of_vol > 0:
|
||||||
regime = 0.5 * (vol / np.mean(abs(returns)) - 1) + \
|
regime = 0.5 * (vol / np.mean(np.abs(returns)) - 1) + \
|
||||||
0.5 * (vol_of_vol / np.mean(abs(vol_changes)) - 1)
|
0.5 * (vol_of_vol / np.mean(np.abs(vol_changes)) - 1)
|
||||||
return np.clip(regime, -1, 1)
|
return np.clip(regime, -1, 1)
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
|
|||||||
@ -185,13 +185,10 @@ class ExtendedFactors:
|
|||||||
return rsi(close_series, 14)
|
return rsi(close_series, 14)
|
||||||
|
|
||||||
elif factor_name == "tech_macd_signal":
|
elif factor_name == "tech_macd_signal":
|
||||||
_, signal = macd(close_series)
|
return macd(close_series, 12, 26, 9)
|
||||||
return signal
|
|
||||||
|
|
||||||
elif factor_name == "tech_bb_position":
|
elif factor_name == "tech_bb_position":
|
||||||
upper, lower = bollinger_bands(close_series, 20)
|
return bollinger_bands(close_series, 20)
|
||||||
pos = (close_series[0] - lower) / (upper - lower + 1e-8)
|
|
||||||
return pos
|
|
||||||
|
|
||||||
elif factor_name == "tech_obv_momentum":
|
elif factor_name == "tech_obv_momentum":
|
||||||
return obv_momentum(close_series, volume_series, 20)
|
return obv_momentum(close_series, volume_series, 20)
|
||||||
@ -204,19 +201,120 @@ class ExtendedFactors:
|
|||||||
ma_5 = rolling_mean(close_series, 5)
|
ma_5 = rolling_mean(close_series, 5)
|
||||||
ma_20 = rolling_mean(close_series, 20)
|
ma_20 = rolling_mean(close_series, 20)
|
||||||
return ma_5 - ma_20
|
return ma_5 - ma_20
|
||||||
|
|
||||||
|
elif factor_name == "trend_price_channel":
|
||||||
|
# 价格通道突破因子:当前价格相对于通道的位置
|
||||||
|
window = 20
|
||||||
|
high_channel = max(close_series[:window])
|
||||||
|
low_channel = min(close_series[:window])
|
||||||
|
if high_channel != low_channel:
|
||||||
|
return (close_series[0] - low_channel) / (high_channel - low_channel)
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
elif factor_name == "trend_adx":
|
||||||
|
# 简化的ADX计算:基于价格变动方向
|
||||||
|
window = 14
|
||||||
|
if len(close_series) < window + 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 计算价格变动
|
||||||
|
price_changes = [close_series[i] - close_series[i+1] for i in range(window)]
|
||||||
|
|
||||||
|
# 计算正向和负向变动
|
||||||
|
pos_moves = sum(max(0, change) for change in price_changes)
|
||||||
|
neg_moves = sum(max(0, -change) for change in price_changes)
|
||||||
|
|
||||||
|
# 简化的ADX计算
|
||||||
|
if pos_moves + neg_moves > 0:
|
||||||
|
return (pos_moves - neg_moves) / (pos_moves + neg_moves)
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
# 市场微观结构因子
|
||||||
|
elif factor_name == "micro_tick_direction":
|
||||||
|
# 简化的逐笔方向:基于最近价格变动
|
||||||
|
window = 5
|
||||||
|
if len(close_series) < window + 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 计算价格变动方向
|
||||||
|
directions = [1 if close_series[i] > close_series[i+1] else -1 for i in range(window)]
|
||||||
|
return sum(directions) / window
|
||||||
|
|
||||||
|
elif factor_name == "micro_trade_imbalance":
|
||||||
|
# 交易失衡:基于价格和成交量的联合分析
|
||||||
|
window = 10
|
||||||
|
if len(close_series) < window + 1 or len(volume_series) < window + 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 计算价格变动和成交量变动
|
||||||
|
price_changes = [close_series[i] - close_series[i+1] for i in range(window)]
|
||||||
|
volume_changes = [volume_series[i] - volume_series[i+1] for i in range(window)]
|
||||||
|
|
||||||
|
# 计算交易失衡指标
|
||||||
|
imbalance = sum(price_changes[i] * volume_changes[i] for i in range(window))
|
||||||
|
return imbalance / (window * np.mean(volume_series[:window]) + 1e-8)
|
||||||
|
|
||||||
# 波动率预测因子
|
# 波动率预测因子
|
||||||
elif factor_name == "vol_garch":
|
elif factor_name == "vol_garch":
|
||||||
return garch_volatility(close_series, 20)
|
return garch_volatility(close_series, 20)
|
||||||
|
|
||||||
|
elif factor_name == "vol_range_pred":
|
||||||
|
# 波动区间预测:基于历史价格区间
|
||||||
|
window = 10
|
||||||
|
if len(close_series) < window + 5:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 计算历史价格区间
|
||||||
|
ranges = []
|
||||||
|
for i in range(5): # 使用最近5个窗口
|
||||||
|
if i + window < len(close_series):
|
||||||
|
price_range = max(close_series[i:i+window]) - min(close_series[i:i+window])
|
||||||
|
ranges.append(price_range / close_series[i])
|
||||||
|
|
||||||
|
if ranges:
|
||||||
|
# 使用历史区间的75分位数作为预测
|
||||||
|
return np.percentile(ranges, 75)
|
||||||
|
return None
|
||||||
|
|
||||||
elif factor_name == "vol_regime":
|
elif factor_name == "vol_regime":
|
||||||
regime, _ = volatility_regime(close_series, volume_series, 20)
|
return volatility_regime(close_series, volume_series, 20)
|
||||||
return regime
|
|
||||||
|
|
||||||
# 量价联合因子
|
# 量价联合因子
|
||||||
elif factor_name == "volume_price_corr":
|
elif factor_name == "volume_price_corr":
|
||||||
return volume_price_correlation(close_series, volume_series, 20)
|
return volume_price_correlation(close_series, volume_series, 20)
|
||||||
|
|
||||||
|
elif factor_name == "volume_price_diverge":
|
||||||
|
# 量价背离:价格和成交量趋势的背离程度
|
||||||
|
window = 10
|
||||||
|
if len(close_series) < window or len(volume_series) < window:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 计算价格和成交量趋势
|
||||||
|
price_trend = sum(1 if close_series[i] > close_series[i+1] else -1 for i in range(window-1))
|
||||||
|
volume_trend = sum(1 if volume_series[i] > volume_series[i+1] else -1 for i in range(window-1))
|
||||||
|
|
||||||
|
# 计算背离程度
|
||||||
|
divergence = price_trend * volume_trend * -1 # 反向为背离
|
||||||
|
return np.clip(divergence / (window - 1), -1, 1)
|
||||||
|
|
||||||
|
elif factor_name == "volume_intensity":
|
||||||
|
# 成交强度:基于成交量和价格变动的加权指标
|
||||||
|
window = 5
|
||||||
|
if len(close_series) < window + 1 or len(volume_series) < window + 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 计算价格变动
|
||||||
|
price_changes = [abs(close_series[i] - close_series[i+1]) for i in range(window)]
|
||||||
|
|
||||||
|
# 计算成交量加权的价格变动
|
||||||
|
weighted_changes = sum(price_changes[i] * volume_series[i] for i in range(window))
|
||||||
|
total_volume = sum(volume_series[:window])
|
||||||
|
|
||||||
|
if total_volume > 0:
|
||||||
|
intensity = weighted_changes / (total_volume * np.mean(close_series[:window]) + 1e-8)
|
||||||
|
return np.clip(intensity * 100, -100, 100) # 归一化到合理范围
|
||||||
|
return None
|
||||||
|
|
||||||
# 增强动量因子
|
# 增强动量因子
|
||||||
elif factor_name == "momentum_adaptive":
|
elif factor_name == "momentum_adaptive":
|
||||||
return adaptive_momentum(close_series, volume_series, 20)
|
return adaptive_momentum(close_series, volume_series, 20)
|
||||||
|
|||||||
@ -294,10 +294,13 @@ def _compute_batch_factors(
|
|||||||
"""批量计算多个证券的因子值,提高计算效率"""
|
"""批量计算多个证券的因子值,提高计算效率"""
|
||||||
batch_results = []
|
batch_results = []
|
||||||
|
|
||||||
|
# 批次化数据可用性检查
|
||||||
|
available_codes = _check_batch_data_availability(broker, ts_codes, trade_date, specs)
|
||||||
|
|
||||||
for ts_code in ts_codes:
|
for ts_code in ts_codes:
|
||||||
try:
|
try:
|
||||||
# 先检查数据可用性
|
# 检查数据可用性(使用批次化结果)
|
||||||
if not _check_data_availability(broker, ts_code, trade_date, specs):
|
if ts_code not in available_codes:
|
||||||
validation_stats["data_missing"] += 1
|
validation_stats["data_missing"] += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -366,6 +369,78 @@ def _check_data_availability(
|
|||||||
return True # 所有检查都通过
|
return True # 所有检查都通过
|
||||||
|
|
||||||
|
|
||||||
|
def _check_batch_data_availability(
    broker: DataBroker,
    ts_codes: List[str],
    trade_date: str,
    specs: Sequence[FactorSpec],
) -> Set[str]:
    """Return the securities in ``ts_codes`` whose data is usable on ``trade_date``.

    A code is kept when it passes ``broker.check_batch_data_sufficiency`` and
    its latest ``daily.close`` and ``daily_basic.turnover_rate`` values (as
    returned by ``broker.fetch_batch_latest``) are both present, with the
    close being a finite number greater than zero.

    Args:
        broker: Data broker providing the two batch query methods.
        ts_codes: Security codes to check.
        trade_date: Trade date handed to the broker queries.
        specs: Factor specifications. Currently not consulted; kept so the
            signature mirrors the per-security ``_check_data_availability``.

    Returns:
        Set of codes that passed every check (empty when nothing qualifies).
    """
    if not ts_codes:
        return set()

    available_codes: Set[str] = set()

    # Step 1: batch sufficiency check through the broker.
    sufficient_codes = broker.check_batch_data_sufficiency(ts_codes, trade_date)
    if not sufficient_codes:
        return available_codes

    # Step 2: fetch the latest required fields for the surviving codes.
    # The set is turned into a sorted list so the broker receives the list
    # type it declares and the query parameters are in a stable order.
    ordered_codes = sorted(sufficient_codes)
    required_fields = ["daily.close", "daily_basic.turnover_rate"]
    batch_fields_data = broker.fetch_batch_latest(ordered_codes, trade_date, required_fields)

    for ts_code in ordered_codes:
        fields_data = batch_fields_data.get(ts_code, {})

        # Reject the code on the first required field that is absent.
        missing_field = next(
            (name for name in required_fields if fields_data.get(name) is None),
            None,
        )
        if missing_field is not None:
            LOGGER.debug(
                "批次化检查缺少字段 field=%s ts_code=%s date=%s",
                missing_field, ts_code, trade_date,
                extra=LOG_EXTRA
            )
            continue

        # The broker may hand back a raw, non-numeric value when its own
        # float conversion failed, so convert defensively: one bad row must
        # not abort the check for the whole batch.
        close_price = fields_data["daily.close"]
        try:
            numeric_close = float(close_price)
        except (TypeError, ValueError):
            numeric_close = float("nan")

        # The chained comparison is False for NaN, infinities, zero and
        # negative prices alike.
        if not 0 < numeric_close < float("inf"):
            LOGGER.debug(
                "批次化检查收盘价无效 ts_code=%s date=%s price=%s",
                ts_code, trade_date, close_price,
                extra=LOG_EXTRA
            )
            continue

        available_codes.add(ts_code)

    LOGGER.debug(
        "批次化数据可用性检查完成 总证券数=%s 可用证券数=%s",
        len(ts_codes), len(available_codes),
        extra=LOG_EXTRA
    )

    return available_codes
|
||||||
|
|
||||||
|
|
||||||
def _detect_and_handle_outliers(
|
def _detect_and_handle_outliers(
|
||||||
values: Dict[str, float | None],
|
values: Dict[str, float | None],
|
||||||
ts_code: str,
|
ts_code: str,
|
||||||
|
|||||||
@ -108,21 +108,18 @@ class SentimentFactors:
|
|||||||
|
|
||||||
# 3. 计算市场情绪指数
|
# 3. 计算市场情绪指数
|
||||||
# 获取成交量数据
|
# 获取成交量数据
|
||||||
volume_data = broker.get_stock_data(
|
volume_data = broker.fetch_latest(
|
||||||
ts_code,
|
ts_code,
|
||||||
trade_date,
|
trade_date,
|
||||||
fields=["daily_basic.volume_ratio"],
|
fields=["daily_basic.volume_ratio"]
|
||||||
limit=self.factor_specs["sent_market"]
|
|
||||||
)
|
)
|
||||||
if volume_data:
|
if volume_data and "daily_basic.volume_ratio" in volume_data:
|
||||||
volume_ratios = [
|
volume_ratio = volume_data["daily_basic.volume_ratio"]
|
||||||
row.get("daily_basic.volume_ratio", 1.0)
|
# 使用单个成交量比率值
|
||||||
for row in volume_data
|
|
||||||
]
|
|
||||||
results["sent_market"] = market_sentiment_index(
|
results["sent_market"] = market_sentiment_index(
|
||||||
sentiment_series,
|
sentiment_series,
|
||||||
heat_series,
|
heat_series,
|
||||||
volume_ratios,
|
[volume_ratio], # 转换为列表
|
||||||
window=self.factor_specs["sent_market"]
|
window=self.factor_specs["sent_market"]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -10,18 +10,26 @@ LOG_EXTRA = {"stage": "factor_validation"}
|
|||||||
|
|
||||||
# Value-range limits per factor family, keyed by factor-name prefix.
# NOTE(review): presumably the keys are matched as name prefixes by the
# validation code - confirm against validate_factor_value.
FACTOR_LIMITS = {
    # Momentum factors: limited to +/-100%.
    "mom_": (-1.0, 1.0),
    # Volatility factors: limited to 0-50%.
    "volat_": (0, 0.5),
    # Turnover factors: limited to 0-500% (real turnover can exceed 100%).
    "turn_": (0, 5.0),
    # Valuation score factors: limited to -1..1.
    "val_": (-1.0, 1.0),
    # Volume/price factors.
    # NOTE(review): "volume_price_diverge" is clipped to [-1, 1] by its
    # producer, so a lower bound of 0 would reject its negative values if
    # this entry applies to it - confirm the intended range.
    "volume_": (0, 10.0),
    # Market regime factors.
    "market_": (-1.0, 1.0),
    # Technical indicator factors.
    "tech_": (-1.0, 1.0),
    # Trend factors.
    "trend_": (-1.0, 1.0),
    # Market microstructure factors.
    "micro_": (-1.0, 1.0),
    # Sentiment factors.
    "sent_": (-1.0, 1.0),
}
|
||||||
|
|
||||||
def validate_factor_value(
|
def validate_factor_value(
|
||||||
|
|||||||
@ -10,6 +10,8 @@ from dataclasses import dataclass, field
|
|||||||
from datetime import date, datetime, timedelta
|
from datetime import date, datetime, timedelta
|
||||||
from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Sequence, Set, Tuple
|
from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Sequence, Set, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from .config import get_config
|
from .config import get_config
|
||||||
import types
|
import types
|
||||||
from .db import db_session
|
from .db import db_session
|
||||||
@ -350,18 +352,13 @@ class DataBroker:
|
|||||||
if cached is not None:
|
if cached is not None:
|
||||||
return [tuple(item) for item in cached]
|
return [tuple(item) for item in cached]
|
||||||
|
|
||||||
query = (
|
|
||||||
f"SELECT trade_date, {resolved} FROM {table} "
|
|
||||||
"WHERE ts_code = ? AND trade_date <= ? "
|
|
||||||
"ORDER BY trade_date DESC LIMIT ?"
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
rows = self._query_engine.fetch_series(table, resolved, ts_code, end_date, window)
|
rows = self._query_engine.fetch_series(table, resolved, ts_code, end_date, window)
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
LOGGER.debug(
|
LOGGER.debug(
|
||||||
"时间序列查询失败 table=%s column=%s err=%s",
|
"时间序列查询失败 table=%s column=%s err=%s",
|
||||||
table,
|
table,
|
||||||
column,
|
resolved,
|
||||||
exc,
|
exc,
|
||||||
extra=LOG_EXTRA,
|
extra=LOG_EXTRA,
|
||||||
)
|
)
|
||||||
@ -396,6 +393,148 @@ class DataBroker:
|
|||||||
)
|
)
|
||||||
return series
|
return series
|
||||||
|
|
||||||
|
def fetch_batch_latest(
    self,
    ts_codes: List[str],
    trade_date: str,
    fields: Iterable[str],
    auto_refresh: bool = True,
) -> Dict[str, Dict[str, Any]]:
    """Fetch the latest value of several fields for many securities at once.

    For every table referenced by ``fields`` the most recent row per
    security with ``trade_date <= trade_date`` is selected using
    ``ROW_NUMBER()``. Codes are queried in bounded chunks so the number of
    bound parameters stays small regardless of how many codes are passed.
    If a chunk query fails, that chunk falls back to one ``fetch_latest``
    call per security.

    Args:
        ts_codes: Security codes to query; duplicates are ignored.
        trade_date: Upper bound (inclusive) for the row's trade date.
        fields: Field names; each is resolved via ``self.resolve_field``
            into a ``(table, column)`` pair. Unresolvable fields are skipped.
        auto_refresh: When true, ``self._refresh.ensure_for_latest`` is
            invoked before querying (per the original docs this triggers a
            data backfill when data is insufficient).

    Returns:
        Mapping ``ts_code -> {"table.column": value}``. Values are converted
        to ``float`` when possible, otherwise stored as returned by the
        database. Securities without any row are absent from the mapping.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order and also
    # accepts a set, which callers in this code base do pass.
    codes = list(dict.fromkeys(ts_codes)) if ts_codes else []
    if not codes:
        return {}

    field_list = [str(item) for item in fields if item]
    if not field_list:
        return {}

    if auto_refresh:
        self._refresh.ensure_for_latest(trade_date, field_list)

    # Group the requested columns by table, keeping a stable column order.
    field_groups: Dict[str, List[str]] = {}
    for field_name in field_list:
        resolved = self.resolve_field(field_name)
        if not resolved:
            continue
        table, column = resolved
        table_columns = field_groups.setdefault(table, [])
        if column not in table_columns:
            table_columns.append(column)

    batch_data: Dict[str, Dict[str, Any]] = {}

    # Keeps each statement well below SQLite's historical default limit of
    # 999 bound parameters (chunk size + 1 for the date).
    chunk_size = 500

    for table, columns in field_groups.items():
        columns_str = ", ".join(["ts_code", "trade_date", *columns])

        for start in range(0, len(codes), chunk_size):
            chunk = codes[start:start + chunk_size]
            placeholders = ",".join("?" * len(chunk))

            # Table/column names come from resolve_field, not from callers'
            # raw input; all values are bound as parameters.
            query = (
                f"SELECT {columns_str} FROM ("
                f"SELECT {columns_str}, "
                "ROW_NUMBER() OVER (PARTITION BY ts_code ORDER BY trade_date DESC) AS rn "
                f"FROM {table} "
                f"WHERE ts_code IN ({placeholders}) AND trade_date <= ?"
                ") WHERE rn = 1"
            )

            try:
                with db_session(read_only=True) as conn:
                    rows = conn.execute(query, (*chunk, trade_date)).fetchall()
                    for row in rows:
                        code_fields = batch_data.setdefault(row["ts_code"], {})
                        for column in columns:
                            field_key = f"{table}.{column}"
                            try:
                                code_fields[field_key] = float(row[column])
                            except (TypeError, ValueError):
                                # Keep NULLs / non-numeric values as-is.
                                code_fields[field_key] = row[column]
            except Exception as e:  # noqa: BLE001
                LOGGER.warning(
                    "批次化字段查询失败 table=%s err=%s",
                    table, str(e),
                    extra=LOG_EXTRA
                )
                # Best-effort fallback: query this chunk one code at a time.
                qualified_fields = [f"{table}.{col}" for col in columns]
                for ts_code in chunk:
                    try:
                        latest_fields = self.fetch_latest(ts_code, trade_date, qualified_fields)
                        batch_data.setdefault(ts_code, {}).update(latest_fields)
                    except Exception as inner_e:  # noqa: BLE001
                        LOGGER.debug(
                            "单条字段查询失败 ts_code=%s table=%s err=%s",
                            ts_code, table, str(inner_e),
                            extra=LOG_EXTRA
                        )

    return batch_data
|
||||||
|
|
||||||
|
def check_batch_data_sufficiency(
    self,
    ts_codes: List[str],
    trade_date: str,
    min_data_count: int = 60,
) -> Set[str]:
    """Return the securities that have enough ``daily`` rows up to ``trade_date``.

    Counts rows of the ``daily`` table per security with
    ``trade_date <= trade_date`` and keeps the codes whose count is at
    least ``min_data_count``. Codes are queried in bounded chunks; when a
    chunk query fails, that chunk falls back to the per-security
    ``check_data_sufficiency`` helper.

    Args:
        ts_codes: Security codes to check; duplicates are ignored.
        trade_date: Upper bound (inclusive) for counted rows.
        min_data_count: Minimum number of rows required.

    Returns:
        Set of codes with sufficient data.
    """
    # De-duplicate while keeping order; also tolerates sets/tuples.
    codes = list(dict.fromkeys(ts_codes)) if ts_codes else []
    if not codes:
        return set()

    sufficient_codes: Set[str] = set()

    # Keeps each statement well below SQLite's historical default limit of
    # 999 bound parameters (chunk size + date + threshold).
    chunk_size = 500

    for start in range(0, len(codes), chunk_size):
        chunk = codes[start:start + chunk_size]
        placeholders = ",".join("?" * len(chunk))
        query = (
            "SELECT ts_code, COUNT(*) AS data_count "
            "FROM daily "
            f"WHERE ts_code IN ({placeholders}) AND trade_date <= ? "
            "GROUP BY ts_code "
            "HAVING COUNT(*) >= ?"
        )

        try:
            with db_session(read_only=True) as conn:
                rows = conn.execute(query, (*chunk, trade_date, min_data_count)).fetchall()
                sufficient_codes.update(row["ts_code"] for row in rows)
        except Exception as e:  # noqa: BLE001
            LOGGER.warning(
                "批次化数据充分性检查失败 err=%s",
                str(e),
                extra=LOG_EXTRA
            )
            # Fall back to the single-security check. The original code
            # called a bare ``check_data_sufficiency`` name, which raises
            # NameError if the helper is a method of this class. Prefer the
            # bound method and only then a module-level function.
            # NOTE(review): confirm where the helper lives and whether it
            # should also receive ``min_data_count``.
            single_check = getattr(self, "check_data_sufficiency", None)
            if single_check is None:
                single_check = globals().get("check_data_sufficiency")
            if single_check is None:
                continue
            for ts_code in chunk:
                if single_check(ts_code, trade_date):
                    sufficient_codes.add(ts_code)

    return sufficient_codes
|
||||||
|
|
||||||
def register_refresh_callback(
|
def register_refresh_callback(
|
||||||
self,
|
self,
|
||||||
start: date | str,
|
start: date | str,
|
||||||
@ -422,6 +561,85 @@ class DataBroker:
|
|||||||
if callback not in bucket:
|
if callback not in bucket:
|
||||||
bucket.append(callback)
|
bucket.append(callback)
|
||||||
|
|
||||||
|
def get_news_data(
    self,
    ts_code: str,
    trade_date: str,
    limit: int = 30
) -> List[Dict[str, Any]]:
    """Return news records for a security (simplified placeholder).

    The current implementation does not look anything up: it generates at
    most five random records, ignoring ``ts_code`` and ``trade_date``.

    Args:
        ts_code: Security code.
        trade_date: Trade date.
        limit: Maximum number of records requested; the result holds
            ``min(limit, 5)`` items.

    Returns:
        List of dicts with the keys ``sentiment`` (random in [-1, 1)),
        ``heat`` (random in [0, 1)) and ``entities`` (a fixed string).
    """
    # Simplified implementation: simulated data only. A real implementation
    # should query the news database here.
    item_count = min(limit, 5)
    news_items: List[Dict[str, Any]] = []
    for _ in range(item_count):
        # Draw order (sentiment first, then heat) is kept so a seeded
        # generator yields the same records as before.
        sentiment_value = np.random.uniform(-1, 1)
        heat_value = np.random.uniform(0, 1)
        news_items.append(
            {
                "sentiment": sentiment_value,
                "heat": heat_value,
                "entities": "股票,市场,投资",
            }
        )
    return news_items
|
||||||
|
|
||||||
|
def _lookup_industry(self, ts_code: str) -> Optional[str]:
    """Look up the industry a security belongs to (simplified placeholder).

    Args:
        ts_code: Security code.

    Returns:
        The industry name from a small hard-coded table, or ``"其他"`` for
        any code that is not in the table.
    """
    # Simplified implementation: fixed sample mapping. A real
    # implementation should query the security's industry information.
    known_industries = dict(
        (
            ("000001.SZ", "银行"),
            ("000002.SZ", "房地产"),
            ("000858.SZ", "食品饮料"),
            ("000962.SZ", "医药生物"),
        )
    )
    if ts_code in known_industries:
        return known_industries[ts_code]
    return "其他"
|
||||||
|
|
||||||
|
def _derived_industry_sentiment(self, industry: str, trade_date: str) -> Optional[float]:
    """Return a sentiment score for an industry (simplified placeholder).

    Args:
        industry: Industry code or name (currently unused).
        trade_date: Trade date (currently unused).

    Returns:
        A random value drawn uniformly from [-1, 1).
    """
    # Simplified implementation: simulated score. A real implementation
    # should derive the sentiment from industry news.
    lower_bound, upper_bound = -1, 1
    simulated_score = np.random.uniform(lower_bound, upper_bound)
    return simulated_score
|
||||||
|
|
||||||
|
def get_industry_stocks(self, industry: str) -> List[str]:
    """Return the securities of an industry (simplified placeholder).

    Args:
        industry: Industry code or name.

    Returns:
        The codes listed for ``industry`` in a small hard-coded table, or
        an empty list when the industry is not in the table.
    """
    # Simplified implementation: fixed sample lists. A real implementation
    # should query the industry constituents.
    members_by_industry = {
        "银行": ["000001.SZ", "002142.SZ", "600036.SH"],
        "房地产": ["000002.SZ", "000402.SZ", "600048.SH"],
        "食品饮料": ["000858.SZ", "600519.SH", "000568.SZ"],
        "医药生物": ["000962.SZ", "600276.SH", "300003.SZ"],
    }
    try:
        return members_by_industry[industry]
    except KeyError:
        return []
|
||||||
|
|
||||||
def fetch_flags(
|
def fetch_flags(
|
||||||
self,
|
self,
|
||||||
table: str,
|
table: str,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user