enhance validation ranges and improve news processing with GDELT support

sam 2025-10-19 21:08:07 +08:00
parent ead4d0d28e
commit 2779d21d97
5 changed files with 127 additions and 67 deletions

View File

@@ -24,8 +24,8 @@ FACTOR_LIMITS = {
# Valuation score factors: standardized scores, capped at -3 to 3 (Z-score normalization range)
"val_": (-3.0, 3.0),
# Volume factors: volume ratios, capped at 0-10x
"volume_": (0, 10.0),
# Volume factors: volume ratios, allow -1 to 10 to accommodate correlation-style factors
"volume_": (-1.0, 10.0),
"volume_ratio": (0, 10.0),
# Market regime factors: standardized regime indicators, capped at -3 to 3
@@ -52,8 +52,8 @@ FACTOR_LIMITS = {
# Risk factors: risk penalty factors
"risk_": (0, 1.0),
# Price ratio factors: price-to-moving-average ratios, capped at 0.5-2.0 (50%-200%)
"price_ma_": (0.5, 2.0),
# Price ratio factors: price-to-moving-average ratios, widened to 0.2-5.0 to cover extreme swings
"price_ma_": (0.2, 5.0),
# Volume ratio factors: volume-to-moving-average ratios, capped at 0.1-10.0
"volume_ma_": (0.1, 10.0),
@@ -97,7 +97,7 @@ def validate_factor_value(
exact_matches = {
# Exact ranges for technical indicators
"tech_rsi_14": (0, 100.0), # RSI range 0-100
"tech_macd_signal": (-5, 5), # MACD signal range
"tech_macd_signal": (-20, 20), # MACD signal range widened to handle extreme markets
"tech_bb_position": (-3.0, 3.0), # Bollinger Band position, in standard deviations
"tech_obv_momentum": (-10.0, 10.0), # normalized OBV momentum
"tech_pv_trend": (-1.0, 1.0), # volume-price trend correlation
@@ -108,8 +108,8 @@ def validate_factor_value(
"trend_price_channel": (-1.0, 1.0), # price channel position
# Exact ranges for volatility indicators
"vol_garch": (0, 50), # GARCH volatility forecast capped at 50%
"vol_range_pred": (0, 20), # volatility range forecast capped at 20%
"vol_garch": (0, 400), # GARCH volatility forecast allowed a wider range
"vol_range_pred": (0, 100), # volatility range forecast
"vol_regime": (0, 1.0), # volatility regime, between 0 and 1
# Exact ranges for microstructure factors
@@ -120,6 +120,7 @@ def validate_factor_value(
"sent_impact": (0, 1.0), # sentiment impact
"sent_divergence": (-1.0, 1.0), # sentiment divergence
"volume_price_diverge": (-1.0, 1.0), # volume-price divergence
"volume_price_corr": (-1.0, 1.0), # volume-price correlation
}
# Check exact matches first
@@ -127,26 +128,26 @@ def validate_factor_value(
min_val, max_val = exact_matches[name]
if min_val <= value <= max_val:
return value
else:
LOGGER.warning(
"因子值超出精确范围 factor=%s value=%f range=[%f,%f] ts_code=%s date=%s",
name, value, min_val, max_val, ts_code, trade_date,
extra=LOG_EXTRA
)
return None
clipped = max(min(value, max_val), min_val)
LOGGER.warning(
"因子值超出精确范围 factor=%s value=%f range=[%f,%f] ts_code=%s date=%s -> clipped=%f",
name, value, min_val, max_val, ts_code, trade_date, clipped,
extra=LOG_EXTRA,
)
return clipped
# Check prefix pattern matches
for prefix, (min_val, max_val) in FACTOR_LIMITS.items():
if name.startswith(prefix):
if min_val <= value <= max_val:
return value
else:
LOGGER.warning(
"因子值超出前缀范围 factor=%s value=%f range=[%f,%f] ts_code=%s date=%s",
name, value, min_val, max_val, ts_code, trade_date,
extra=LOG_EXTRA
)
return None
clipped = max(min(value, max_val), min_val)
LOGGER.warning(
"因子值超出前缀范围 factor=%s value=%f range=[%f,%f] ts_code=%s date=%s -> clipped=%f",
name, value, min_val, max_val, ts_code, trade_date, clipped,
extra=LOG_EXTRA,
)
return clipped
# If nothing matches, apply a stricter default range
default_min, default_max = -5.0, 5.0
@@ -157,13 +158,14 @@ def validate_factor_value(
extra=LOG_EXTRA
)
return value
else:
LOGGER.warning(
"因子值超出默认范围 factor=%s value=%f range=[%f,%f] ts_code=%s date=%s",
name, value, default_min, default_max, ts_code, trade_date,
extra=LOG_EXTRA
)
return None
clipped = max(min(value, default_max), default_min)
LOGGER.warning(
"因子值超出默认范围 factor=%s value=%f range=[%f,%f] ts_code=%s date=%s -> clipped=%f",
name, value, default_min, default_max, ts_code, trade_date, clipped,
extra=LOG_EXTRA,
)
return clipped
def detect_outliers(
values: Dict[str, float],
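
Note on the validation change above: out-of-range factor values are now clipped to their configured bounds instead of being discarded. A minimal, self-contained sketch of that behaviour (the helper name clip_to_range is illustrative, not part of the module):

def clip_to_range(value: float, min_val: float, max_val: float) -> float:
    # Same arithmetic used in all three branches (exact match, prefix match,
    # default range): clamp the value into [min_val, max_val] and return it
    # instead of returning None.
    return max(min(value, max_val), min_val)

assert clip_to_range(12.0, -1.0, 10.0) == 10.0   # volume_ factor above the new cap
assert clip_to_range(-0.4, -1.0, 10.0) == -0.4   # in-range values pass through unchanged

The warning log keeps the original value alongside the clipped result, so persisted factors stay bounded without losing the audit trail.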

View File

@@ -209,6 +209,17 @@ def ensure_data_coverage(
progress_hook(message, progress)
LOGGER.info(message, extra=LOG_EXTRA)
if news_enabled:
advance("Fetching GDELT news data")
try:
ingest_configured_gdelt(
start=start,
end=end,
incremental=not force,
)
except Exception as exc: # noqa: BLE001
LOGGER.warning("GDELT 新闻拉取失败:%s", exc, extra=LOG_EXTRA)
advance("准备股票基础信息与交易日历")
ensure_stock_basic()
ensure_trade_calendar(start, end)
@@ -363,17 +374,6 @@ def ensure_data_coverage(
_save_with_codes("hk_daily", fetch_hk_daily, targets=HK_CODES)
_save_with_codes("us_daily", fetch_us_daily, targets=US_CODES)
if news_enabled:
advance("Fetching GDELT news data")
try:
ingest_configured_gdelt(
start=start,
end=end,
incremental=not force,
)
except Exception as exc: # noqa: BLE001
LOGGER.warning("GDELT 新闻拉取失败:%s", exc, extra=LOG_EXTRA)
if progress_hook:
progress_hook("数据覆盖检查完成", 1.0)

View File

@@ -214,7 +214,7 @@ def _build_rss_item(record: Dict[str, object], config: GdeltSourceConfig) -> Opt
source = config.label or "GDELT"
source = source.strip()
fingerprint = f"{url}|{published.isoformat()}"
fingerprint = f"{url}|{published.isoformat()}|{config.key}"
article_id = hashlib.blake2s(fingerprint.encode("utf-8"), digest_size=16).hexdigest()
return rss_ingest.RssItem(
@@ -227,6 +227,7 @@ def _build_rss_item(record: Dict[str, object], config: GdeltSourceConfig) -> Opt
metadata={
"source_key": config.key,
"source_label": config.label,
"source_type": "gdelt",
},
)
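
A short, self-contained illustration of the fingerprint change above: with config.key folded into the fingerprint (alongside the new source_type metadata), the same URL ingested via two different GDELT source configs no longer collides on article_id. The URL, timestamp, and source keys below are hypothetical:

import hashlib

url = "https://example.com/story"            # hypothetical article URL
published_iso = "2025-10-19T12:00:00+00:00"   # hypothetical publish time
for source_key in ("gdelt_energy", "gdelt_metals"):  # hypothetical config keys
    fingerprint = f"{url}|{published_iso}|{source_key}"
    digest = hashlib.blake2s(fingerprint.encode("utf-8"), digest_size=16).hexdigest()
    print(source_key, digest)  # two distinct 32-character ids for the same URL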
@@ -418,7 +419,10 @@ def ingest_configured_gdelt(
fetched = 0
for config in sources:
source_start = start_dt
if incremental:
effective_incremental = incremental
if start_dt is not None or end_dt is not None:
effective_incremental = False
elif incremental:
last_seen = _load_last_published(config.key)
if last_seen:
candidate = last_seen + timedelta(seconds=1)
@@ -429,10 +433,21 @@ def ingest_configured_gdelt(
config.label,
source_start.isoformat() if source_start else None,
end_dt.isoformat() if end_dt else None,
incremental,
effective_incremental,
extra=LOG_EXTRA,
)
items = fetch_gdelt_articles(config, start=source_start, end=end_dt)
items: List[rss_ingest.RssItem] = []
if source_start and end_dt and source_start <= end_dt:
chunk_start = source_start
while chunk_start <= end_dt:
chunk_end = min(chunk_start + timedelta(days=1) - timedelta(seconds=1), end_dt)
chunk_items = fetch_gdelt_articles(config, start=chunk_start, end=chunk_end)
if chunk_items:
items.extend(chunk_items)
chunk_start = chunk_end + timedelta(seconds=1)
else:
items = fetch_gdelt_articles(config, start=source_start, end=end_dt)
if not items:
continue
aggregated.extend(items)
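
The loop above splits an explicit [start, end] request into day-sized windows and fetches each window separately, presumably to keep individual GDELT queries small. A self-contained sketch of the boundary arithmetic (chunk_ranges is an illustrative helper, not part of the module):

from datetime import datetime, timedelta
from typing import Iterator, Tuple

def chunk_ranges(start: datetime, end: datetime) -> Iterator[Tuple[datetime, datetime]]:
    # Yield inclusive, non-overlapping windows of at most one day,
    # mirroring the chunk_start/chunk_end stepping in the diff above.
    chunk_start = start
    while chunk_start <= end:
        chunk_end = min(chunk_start + timedelta(days=1) - timedelta(seconds=1), end)
        yield chunk_start, chunk_end
        chunk_start = chunk_end + timedelta(seconds=1)

for lo, hi in chunk_ranges(datetime(2025, 10, 1), datetime(2025, 10, 2, 12)):
    print(lo, "->", hi)
# 2025-10-01 00:00:00 -> 2025-10-01 23:59:59
# 2025-10-02 00:00:00 -> 2025-10-02 12:00:00

Note also that passing an explicit start or end now forces effective_incremental to False, so a manual backfill is not shortened by the last-published watermark.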

View File

@@ -7,7 +7,7 @@ import sqlite3
from dataclasses import dataclass, field, replace
from datetime import datetime, timedelta, timezone
from email.utils import parsedate_to_datetime
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
from typing import Dict, Iterable, List, Optional, Sequence, Set, Tuple
from urllib.parse import urlparse, urljoin
from xml.etree import ElementTree as ET
@@ -472,31 +472,62 @@ def _fetch_feed_items(
return items
def deduplicate_items(items: Iterable[RssItem]) -> List[RssItem]:
"""Drop duplicate stories by link/id fingerprint and process entities."""
def _canonical_link(item: RssItem) -> str:
link = (item.link or "").strip().lower()
if link:
return link
if item.id:
return item.id
fingerprint = f"{item.title}|{item.published.isoformat() if item.published else ''}"
return hashlib.sha1(fingerprint.encode("utf-8")).hexdigest()
def _is_gdelt_item(item: RssItem) -> bool:
metadata = item.metadata or {}
return metadata.get("source_type") == "gdelt" or bool(metadata.get("source_key"))
def deduplicate_items(items: Iterable[RssItem]) -> List[RssItem]:
"""Drop duplicate stories by canonical link while preferring GDELT sources."""
selected: Dict[str, RssItem] = {}
order: List[str] = []
seen = set()
unique: List[RssItem] = []
for item in items:
key = item.id or item.link
if key in seen:
continue
seen.add(key)
preassigned_codes = list(item.ts_codes or [])
# Extract entities and related information
item.extract_entities()
# Keep the item if related stocks were found
if item.stock_mentions:
unique.append(item)
continue
# Otherwise keep the preconfigured stock codes, if any
if preassigned_codes:
keep = False
if _is_gdelt_item(item):
keep = True
elif item.stock_mentions:
keep = True
elif preassigned_codes:
if not item.ts_codes:
item.ts_codes = preassigned_codes
unique.append(item)
return unique
keep = True
if not keep:
continue
key = _canonical_link(item)
existing = selected.get(key)
if existing is None:
selected[key] = item
order.append(key)
continue
if _is_gdelt_item(item) and not _is_gdelt_item(existing):
selected[key] = item
elif _is_gdelt_item(item) == _is_gdelt_item(existing):
if item.published and existing.published:
if item.published > existing.published:
selected[key] = item
else:
selected[key] = item
return [selected[key] for key in order if key in selected]
def save_news_items(items: Iterable[RssItem]) -> int:
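
The replacement dedup policy above keys items by canonical link and prefers GDELT copies over plain RSS copies, falling back to the more recently published copy. A minimal sketch of that tie-break over plain dicts (field names mirror the metadata keys used above; this is an illustration, not the module's API):

from datetime import datetime

def is_gdelt(rec: dict) -> bool:
    meta = rec.get("metadata") or {}
    return meta.get("source_type") == "gdelt" or bool(meta.get("source_key"))

def prefer(existing: dict, candidate: dict) -> dict:
    # A GDELT-sourced copy beats a plain RSS copy for the same canonical link.
    if is_gdelt(candidate) and not is_gdelt(existing):
        return candidate
    # Same source class: the more recently published copy wins;
    # if either timestamp is missing, the later-seen copy wins.
    if is_gdelt(candidate) == is_gdelt(existing):
        if candidate.get("published") and existing.get("published"):
            return candidate if candidate["published"] > existing["published"] else existing
        return candidate
    return existing

rss = {"metadata": {}, "published": datetime(2025, 10, 19, 8)}
gdelt = {"metadata": {"source_type": "gdelt"}, "published": datetime(2025, 10, 19, 7)}
assert prefer(rss, gdelt) is gdelt  # GDELT wins even though it is older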
@@ -507,6 +538,7 @@ def save_news_items(items: Iterable[RssItem]) -> int:
rows: List[Tuple[object, ...]] = []
processed = 0
gdelt_urls: Set[str] = set()
for item in items:
text_payload = f"{item.title}\n{item.summary}"
sentiment = _estimate_sentiment(text_payload)
@@ -530,6 +562,8 @@ def save_news_items(items: Iterable[RssItem]) -> int:
}
if item.metadata:
entity_payload["metadata"] = dict(item.metadata)
if _is_gdelt_item(item) and item.link:
gdelt_urls.add(item.link.strip())
entities = json.dumps(entity_payload, ensure_ascii=False)
resolved_codes = base_codes or (None,)
for ts_code in resolved_codes:
@@ -556,9 +590,19 @@ def save_news_items(items: Iterable[RssItem]) -> int:
inserted = 0
try:
with db_session() as conn:
if gdelt_urls:
conn.executemany(
"""
DELETE FROM news
WHERE url = ?
AND (json_extract(entities, '$.metadata.source_type') IS NULL
OR json_extract(entities, '$.metadata.source_type') != 'gdelt')
""",
[(url,) for url in gdelt_urls],
)
conn.executemany(
"""
INSERT OR IGNORE INTO news
INSERT OR REPLACE INTO news
(id, ts_code, pub_time, source, title, summary, url, entities, sentiment, heat)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",

View File

@@ -5,7 +5,7 @@ from dataclasses import dataclass
from datetime import date
from typing import Callable, Iterable, List, Optional, Sequence
from app.features.factors import compute_factor_range
from app.features.factors import compute_factors_incremental
from app.utils import alerts
from app.utils.logging import get_logger
@@ -44,11 +44,10 @@ def _default_post_tasks(job: FetchJob) -> List[PostTask]:
def _run_factor_backfill(job: FetchJob) -> None:
LOGGER.info("Starting factor computation: %s", job.name, extra=LOG_EXTRA)
compute_factor_range(
job.start,
job.end,
compute_factors_incremental(
ts_codes=job.ts_codes,
skip_existing=False,
skip_existing=True,
persist=True,
)
alerts.clear_warnings("Factors")