enhance validation ranges and improve news processing with GDELT support

sam 2025-10-19 21:08:07 +08:00
parent ead4d0d28e
commit 2779d21d97
5 changed files with 127 additions and 67 deletions

View File

@@ -24,8 +24,8 @@ FACTOR_LIMITS = {
     # Valuation score factors: standardized scores, limited to -3..3 (Z-score range)
     "val_": (-3.0, 3.0),
-    # Volume/price factors: volume ratios, limited to 0-10x
-    "volume_": (0, 10.0),
+    # Volume/price factors: volume ratios, allow -1 to 10 (covers correlation-style factors)
+    "volume_": (-1.0, 10.0),
     "volume_ratio": (0, 10.0),
     # Market regime factors: standardized regime indicators, limited to -3..3
@@ -52,8 +52,8 @@ FACTOR_LIMITS = {
     # Risk factors: risk penalty factors
     "risk_": (0, 1.0),
-    # Price ratio factors: price vs. moving average, limited to 0.5-2.0 (50%-200%)
-    "price_ma_": (0.5, 2.0),
+    # Price ratio factors: price vs. moving average, widened to 0.2-5.0 to cover extreme swings
+    "price_ma_": (0.2, 5.0),
     # Volume ratio factors: volume vs. moving average, limited to 0.1-10.0
     "volume_ma_": (0.1, 10.0),
@@ -97,7 +97,7 @@ def validate_factor_value(
     exact_matches = {
         # Exact ranges for technical indicators
         "tech_rsi_14": (0, 100.0),  # RSI range 0-100
-        "tech_macd_signal": (-5, 5),  # MACD signal range
+        "tech_macd_signal": (-20, 20),  # MACD signal range widened for extreme markets
         "tech_bb_position": (-3.0, 3.0),  # Bollinger Band position, in standard deviations
         "tech_obv_momentum": (-10.0, 10.0),  # normalized OBV momentum
         "tech_pv_trend": (-1.0, 1.0),  # price/volume trend correlation
@@ -108,8 +108,8 @@ def validate_factor_value(
         "trend_price_channel": (-1.0, 1.0),  # price channel position
         # Exact ranges for volatility indicators
-        "vol_garch": (0, 50),  # GARCH volatility forecast capped at 50%
-        "vol_range_pred": (0, 20),  # volatility range forecast capped at 20%
+        "vol_garch": (0, 400),  # GARCH volatility forecast, much wider cap
+        "vol_range_pred": (0, 100),  # volatility range forecast
         "vol_regime": (0, 1.0),  # volatility regime, between 0 and 1
         # Exact ranges for microstructure factors
@@ -120,6 +120,7 @@ def validate_factor_value(
         "sent_impact": (0, 1.0),  # sentiment impact
         "sent_divergence": (-1.0, 1.0),  # sentiment divergence
         "volume_price_diverge": (-1.0, 1.0),  # volume/price divergence
+        "volume_price_corr": (-1.0, 1.0),  # volume/price correlation
     }
 
     # Check for an exact match
@@ -127,26 +128,26 @@ def validate_factor_value(
         min_val, max_val = exact_matches[name]
         if min_val <= value <= max_val:
             return value
-        else:
-            LOGGER.warning(
-                "Factor value outside exact range factor=%s value=%f range=[%f,%f] ts_code=%s date=%s",
-                name, value, min_val, max_val, ts_code, trade_date,
-                extra=LOG_EXTRA
-            )
-            return None
+        clipped = max(min(value, max_val), min_val)
+        LOGGER.warning(
+            "Factor value outside exact range factor=%s value=%f range=[%f,%f] ts_code=%s date=%s -> clipped=%f",
+            name, value, min_val, max_val, ts_code, trade_date, clipped,
+            extra=LOG_EXTRA,
+        )
+        return clipped
 
     # Check prefix pattern matches
     for prefix, (min_val, max_val) in FACTOR_LIMITS.items():
         if name.startswith(prefix):
             if min_val <= value <= max_val:
                 return value
-            else:
-                LOGGER.warning(
-                    "Factor value outside prefix range factor=%s value=%f range=[%f,%f] ts_code=%s date=%s",
-                    name, value, min_val, max_val, ts_code, trade_date,
-                    extra=LOG_EXTRA
-                )
-                return None
+            clipped = max(min(value, max_val), min_val)
+            LOGGER.warning(
+                "Factor value outside prefix range factor=%s value=%f range=[%f,%f] ts_code=%s date=%s -> clipped=%f",
+                name, value, min_val, max_val, ts_code, trade_date, clipped,
+                extra=LOG_EXTRA,
+            )
+            return clipped
 
     # If nothing matched, fall back to a stricter default range
     default_min, default_max = -5.0, 5.0
@@ -157,13 +158,14 @@ def validate_factor_value(
             extra=LOG_EXTRA
         )
         return value
-    else:
-        LOGGER.warning(
-            "Factor value outside default range factor=%s value=%f range=[%f,%f] ts_code=%s date=%s",
-            name, value, default_min, default_max, ts_code, trade_date,
-            extra=LOG_EXTRA
-        )
-        return None
+
+    clipped = max(min(value, default_max), default_min)
+    LOGGER.warning(
+        "Factor value outside default range factor=%s value=%f range=[%f,%f] ts_code=%s date=%s -> clipped=%f",
+        name, value, default_min, default_max, ts_code, trade_date, clipped,
+        extra=LOG_EXTRA,
+    )
+    return clipped
 
 
 def detect_outliers(
     values: Dict[str, float],
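Reviewer note: the change above turns out-of-range factor values from hard rejections (return None) into clipped values at the nearest bound. A stripped-down sketch of that lookup-and-clip behaviour, using a tiny limits table instead of the module's full FACTOR_LIMITS (names here are illustrative):

import logging
from typing import Dict, Optional, Tuple

logging.basicConfig(level=logging.WARNING)
LOGGER = logging.getLogger("factors")

# Two prefix rules taken from the diff; the real table has many more.
LIMITS: Dict[str, Tuple[float, float]] = {"volume_": (-1.0, 10.0), "price_ma_": (0.2, 5.0)}

def clip_factor(name: str, value: float) -> Optional[float]:
    """Return value unchanged when in range, otherwise clip it to the nearest bound."""
    for prefix, (lo, hi) in LIMITS.items():
        if name.startswith(prefix):
            if lo <= value <= hi:
                return value
            clipped = max(min(value, hi), lo)
            LOGGER.warning("factor=%s value=%f outside [%f, %f] -> clipped=%f",
                           name, value, lo, hi, clipped)
            return clipped
    return value  # no matching prefix: left untouched in this sketch

print(clip_factor("price_ma_20", 7.3))   # 5.0 after clipping, with a warning logged
print(clip_factor("momentum_20", 2.0))   # 2.0, no prefix rule applies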

View File

@@ -209,6 +209,17 @@ def ensure_data_coverage(
         progress_hook(message, progress)
         LOGGER.info(message, extra=LOG_EXTRA)
 
+    if news_enabled:
+        advance("Fetching GDELT news data")
+        try:
+            ingest_configured_gdelt(
+                start=start,
+                end=end,
+                incremental=not force,
+            )
+        except Exception as exc:  # noqa: BLE001
+            LOGGER.warning("GDELT news fetch failed: %s", exc, extra=LOG_EXTRA)
+
     advance("Preparing stock basics and the trade calendar")
     ensure_stock_basic()
     ensure_trade_calendar(start, end)
@@ -363,17 +374,6 @@ def ensure_data_coverage(
     _save_with_codes("hk_daily", fetch_hk_daily, targets=HK_CODES)
     _save_with_codes("us_daily", fetch_us_daily, targets=US_CODES)
 
-    if news_enabled:
-        advance("Fetching GDELT news data")
-        try:
-            ingest_configured_gdelt(
-                start=start,
-                end=end,
-                incremental=not force,
-            )
-        except Exception as exc:  # noqa: BLE001
-            LOGGER.warning("GDELT news fetch failed: %s", exc, extra=LOG_EXTRA)
-
     if progress_hook:
         progress_hook("Data coverage check complete", 1.0)

View File

@@ -214,7 +214,7 @@ def _build_rss_item(record: Dict[str, object], config: GdeltSourceConfig) -> Opt
     source = config.label or "GDELT"
     source = source.strip()
-    fingerprint = f"{url}|{published.isoformat()}"
+    fingerprint = f"{url}|{published.isoformat()}|{config.key}"
     article_id = hashlib.blake2s(fingerprint.encode("utf-8"), digest_size=16).hexdigest()
     return rss_ingest.RssItem(
@@ -227,6 +227,7 @@ def _build_rss_item(record: Dict[str, object], config: GdeltSourceConfig) -> Opt
         metadata={
             "source_key": config.key,
             "source_label": config.label,
+            "source_type": "gdelt",
         },
     )
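Reviewer note: adding config.key to the fingerprint means the same URL ingested under two different GDELT source configs now hashes to two distinct article ids. A self-contained illustration (the URL, timestamp, and "example_key" are made-up values, not repository data):

import hashlib
from datetime import datetime, timezone

url = "https://example.com/markets/article-1"
published = datetime(2025, 10, 18, 12, 0, tzinfo=timezone.utc)

old_fp = f"{url}|{published.isoformat()}"
new_fp = f"{url}|{published.isoformat()}|example_key"

old_id = hashlib.blake2s(old_fp.encode("utf-8"), digest_size=16).hexdigest()
new_id = hashlib.blake2s(new_fp.encode("utf-8"), digest_size=16).hexdigest()
print(old_id != new_id)  # True: ids are now unique per source key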
@@ -418,7 +419,10 @@ def ingest_configured_gdelt(
     fetched = 0
     for config in sources:
         source_start = start_dt
-        if incremental:
+        effective_incremental = incremental
+        if start_dt is not None or end_dt is not None:
+            effective_incremental = False
+        elif incremental:
             last_seen = _load_last_published(config.key)
             if last_seen:
                 candidate = last_seen + timedelta(seconds=1)
@@ -429,10 +433,21 @@ def ingest_configured_gdelt(
             config.label,
             source_start.isoformat() if source_start else None,
             end_dt.isoformat() if end_dt else None,
-            incremental,
+            effective_incremental,
             extra=LOG_EXTRA,
         )
-        items = fetch_gdelt_articles(config, start=source_start, end=end_dt)
+
+        items: List[rss_ingest.RssItem] = []
+        if source_start and end_dt and source_start <= end_dt:
+            chunk_start = source_start
+            while chunk_start <= end_dt:
+                chunk_end = min(chunk_start + timedelta(days=1) - timedelta(seconds=1), end_dt)
+                chunk_items = fetch_gdelt_articles(config, start=chunk_start, end=chunk_end)
+                if chunk_items:
+                    items.extend(chunk_items)
+                chunk_start = chunk_end + timedelta(seconds=1)
+        else:
+            items = fetch_gdelt_articles(config, start=source_start, end=end_dt)
         if not items:
             continue
         aggregated.extend(items)
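Reviewer note: the while-loop above splits an explicit date range into one-day windows so each GDELT query stays small. The same chunking can be read as a tiny window generator; a minimal sketch, assuming windows are inclusive on both ends and using made-up dates:

from datetime import datetime, timedelta, timezone
from typing import Iterator, Tuple

def day_windows(start: datetime, end: datetime) -> Iterator[Tuple[datetime, datetime]]:
    """Yield inclusive [chunk_start, chunk_end] windows of at most one day."""
    chunk_start = start
    while chunk_start <= end:
        chunk_end = min(chunk_start + timedelta(days=1) - timedelta(seconds=1), end)
        yield chunk_start, chunk_end
        chunk_start = chunk_end + timedelta(seconds=1)

start = datetime(2025, 10, 1, tzinfo=timezone.utc)
end = datetime(2025, 10, 3, 12, 0, tzinfo=timezone.utc)
for lo, hi in day_windows(start, end):
    print(lo.isoformat(), "->", hi.isoformat())
# Three windows: all of Oct 1, all of Oct 2, and the partial Oct 3 up to 12:00.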

View File

@@ -7,7 +7,7 @@ import sqlite3
 from dataclasses import dataclass, field, replace
 from datetime import datetime, timedelta, timezone
 from email.utils import parsedate_to_datetime
-from typing import Dict, Iterable, List, Optional, Sequence, Tuple
+from typing import Dict, Iterable, List, Optional, Sequence, Set, Tuple
 from urllib.parse import urlparse, urljoin
 from xml.etree import ElementTree as ET
@@ -472,31 +472,62 @@ def _fetch_feed_items(
     return items
 
 
-def deduplicate_items(items: Iterable[RssItem]) -> List[RssItem]:
-    """Drop duplicate stories by link/id fingerprint and process entities."""
+def _canonical_link(item: RssItem) -> str:
+    link = (item.link or "").strip().lower()
+    if link:
+        return link
+    if item.id:
+        return item.id
+    fingerprint = f"{item.title}|{item.published.isoformat() if item.published else ''}"
+    return hashlib.sha1(fingerprint.encode("utf-8")).hexdigest()
+
+
+def _is_gdelt_item(item: RssItem) -> bool:
+    metadata = item.metadata or {}
+    return metadata.get("source_type") == "gdelt" or bool(metadata.get("source_key"))
+
+
+def deduplicate_items(items: Iterable[RssItem]) -> List[RssItem]:
+    """Drop duplicate stories by canonical link while preferring GDELT sources."""
+    selected: Dict[str, RssItem] = {}
+    order: List[str] = []
-    seen = set()
-    unique: List[RssItem] = []
     for item in items:
-        key = item.id or item.link
-        if key in seen:
-            continue
-        seen.add(key)
         preassigned_codes = list(item.ts_codes or [])
         # Extract entities and related information
         item.extract_entities()
-        # Keep the story if related stocks were found
-        if item.stock_mentions:
-            unique.append(item)
-            continue
-        # Otherwise keep any preconfigured stock codes
+
+        keep = False
-        if preassigned_codes:
+        if _is_gdelt_item(item):
+            keep = True
+        elif item.stock_mentions:
+            keep = True
+        elif preassigned_codes:
             if not item.ts_codes:
                 item.ts_codes = preassigned_codes
-            unique.append(item)
-    return unique
+            keep = True
+
+        if not keep:
+            continue
+
+        key = _canonical_link(item)
+        existing = selected.get(key)
+        if existing is None:
+            selected[key] = item
+            order.append(key)
+            continue
+        if _is_gdelt_item(item) and not _is_gdelt_item(existing):
+            selected[key] = item
+        elif _is_gdelt_item(item) == _is_gdelt_item(existing):
+            if item.published and existing.published:
+                if item.published > existing.published:
+                    selected[key] = item
+            else:
+                selected[key] = item
+    return [selected[key] for key in order if key in selected]
 
 
 def save_news_items(items: Iterable[RssItem]) -> int:
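Reviewer note: the new deduplication keys stories by canonical link and resolves collisions with a preference order (GDELT copy beats a plain RSS copy; within the same tier, the fresher item wins). A minimal, self-contained sketch of that rule with a simplified stand-in type, not the repository's RssItem:

from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, List, Optional

@dataclass
class Story:  # simplified stand-in for RssItem
    link: str
    published: Optional[datetime] = None
    metadata: Dict[str, str] = field(default_factory=dict)

def is_gdelt(story: Story) -> bool:
    return story.metadata.get("source_type") == "gdelt"

def dedupe(stories: List[Story]) -> List[Story]:
    selected: Dict[str, Story] = {}
    for story in stories:
        key = story.link.strip().lower()
        existing = selected.get(key)
        if existing is None:
            selected[key] = story
        elif is_gdelt(story) and not is_gdelt(existing):
            selected[key] = story  # GDELT copy wins over a plain RSS copy
        elif is_gdelt(story) == is_gdelt(existing) and (
            story.published and existing.published and story.published > existing.published
        ):
            selected[key] = story  # same tier: keep the fresher copy
    return list(selected.values())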
@@ -507,6 +538,7 @@ def save_news_items(items: Iterable[RssItem]) -> int:
     rows: List[Tuple[object, ...]] = []
     processed = 0
+    gdelt_urls: Set[str] = set()
     for item in items:
         text_payload = f"{item.title}\n{item.summary}"
         sentiment = _estimate_sentiment(text_payload)
@@ -530,6 +562,8 @@ def save_news_items(items: Iterable[RssItem]) -> int:
         }
         if item.metadata:
             entity_payload["metadata"] = dict(item.metadata)
+        if _is_gdelt_item(item) and item.link:
+            gdelt_urls.add(item.link.strip())
         entities = json.dumps(entity_payload, ensure_ascii=False)
         resolved_codes = base_codes or (None,)
         for ts_code in resolved_codes:
@@ -556,9 +590,19 @@ def save_news_items(items: Iterable[RssItem]) -> int:
     inserted = 0
     try:
         with db_session() as conn:
+            if gdelt_urls:
+                conn.executemany(
+                    """
+                    DELETE FROM news
+                    WHERE url = ?
+                      AND (json_extract(entities, '$.metadata.source_type') IS NULL
+                           OR json_extract(entities, '$.metadata.source_type') != 'gdelt')
+                    """,
+                    [(url,) for url in gdelt_urls],
+                )
             conn.executemany(
                 """
-                INSERT OR IGNORE INTO news
+                INSERT OR REPLACE INTO news
                 (id, ts_code, pub_time, source, title, summary, url, entities, sentiment, heat)
                 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                 """,

View File

@@ -5,7 +5,7 @@ from dataclasses import dataclass
 from datetime import date
 from typing import Callable, Iterable, List, Optional, Sequence
 
-from app.features.factors import compute_factor_range
+from app.features.factors import compute_factors_incremental
 from app.utils import alerts
 from app.utils.logging import get_logger
@@ -44,11 +44,10 @@ def _default_post_tasks(job: FetchJob) -> List[PostTask]:
 def _run_factor_backfill(job: FetchJob) -> None:
     LOGGER.info("Starting factor computation: %s", job.name, extra=LOG_EXTRA)
-    compute_factor_range(
-        job.start,
-        job.end,
+    compute_factors_incremental(
         ts_codes=job.ts_codes,
-        skip_existing=False,
-        persist=True,
+        skip_existing=True,
     )
     alerts.clear_warnings("Factors")