diff --git a/app/features/factors.py b/app/features/factors.py
index 426d473..ad0701e 100644
--- a/app/features/factors.py
+++ b/app/features/factors.py
@@ -17,6 +17,7 @@ from app.utils.logging import get_logger
 from app.features.extended_factors import ExtendedFactors
 from app.features.sentiment_factors import SentimentFactors
 from app.features.value_risk_factors import ValueRiskFactors
+from app.ingest.news import prepare_news_for_factors
 # 导入因子验证功能
 from app.features.validation import check_data_sufficiency, check_data_sufficiency_for_zero_window, detect_outliers
 # 导入UI进度状态管理
@@ -161,6 +162,10 @@ def compute_factors(
         LOGGER.info("无可用标的生成因子 trade_date=%s", trade_date_str, extra=LOG_EXTRA)
         return []
 
+    if any(spec.name.startswith("sent_") for spec in specs):
+        # 情绪因子需要依赖最新的新闻情绪/热度评分,先确保新闻数据落库
+        prepare_news_for_factors(trade_date, lookback_days=7)
+
     if skip_existing:
         # 检查所有因子名称
         factor_names = [spec.name for spec in specs]
diff --git a/app/ingest/api_client.py b/app/ingest/api_client.py
index 6b90d73..a00eb76 100644
--- a/app/ingest/api_client.py
+++ b/app/ingest/api_client.py
@@ -5,7 +5,7 @@ import os
 import sqlite3
 import time
 from collections import defaultdict, deque
-from datetime import date
+from datetime import date, datetime, timedelta
 from typing import Dict, Iterable, List, Optional, Sequence, Set, Tuple
 
 import pandas as pd
@@ -266,6 +266,112 @@ def _record_exists(
     return row is not None
 
 
+def _parse_date_str(value: str) -> Optional[date]:
+    try:
+        return datetime.strptime(value, "%Y%m%d").date()
+    except (TypeError, ValueError):
+        return None
+
+
+def _increment_date_str(value: str) -> str:
+    parsed = _parse_date_str(value)
+    if parsed is None:
+        return value
+    return _format_date(parsed + timedelta(days=1))
+
+
+def _infer_exchange(ts_code: Optional[str]) -> str:
+    if not ts_code:
+        return "SSE"
+    if ts_code.endswith(".SZ"):
+        return "SZSE"
+    if ts_code.endswith(".BJ"):
+        return "BSE"
+    return "SSE"
+
+
+def _trading_dates_for_range(start_str: str, end_str: str, ts_code: Optional[str] = None) -> List[str]:
+    start_dt = _parse_date_str(start_str)
+    end_dt = _parse_date_str(end_str)
+    if start_dt is None or end_dt is None or start_dt > end_dt:
+        return []
+    exchange = _infer_exchange(ts_code)
+    trade_dates = _load_trade_dates(start_dt, end_dt, exchange=exchange)
+    if not trade_dates and exchange != "SSE":
+        trade_dates = _load_trade_dates(start_dt, end_dt, exchange="SSE")
+    return trade_dates
+
+
+def _distinct_dates(
+    table: str,
+    date_col: str,
+    start_str: str,
+    end_str: str,
+    ts_code: Optional[str] = None,
+) -> Set[str]:
+    sql = f"SELECT DISTINCT {date_col} AS trade_date FROM {table} WHERE {date_col} BETWEEN ? AND ?"
+    params: List[object] = [start_str, end_str]
+    if ts_code:
+        sql += " AND ts_code = ?"
+        params.append(ts_code)
+    try:
+        with db_session(read_only=True) as conn:
+            rows = conn.execute(sql, tuple(params)).fetchall()
+    except sqlite3.OperationalError:
+        return set()
+    return {str(row["trade_date"]) for row in rows if row["trade_date"]}
+
+
+def _first_missing_date(
+    table: str,
+    date_col: str,
+    start_str: str,
+    end_str: str,
+    ts_code: Optional[str] = None,
+) -> Optional[str]:
+    trade_calendar = _trading_dates_for_range(start_str, end_str, ts_code=ts_code)
+    if not trade_calendar:
+        return None
+    existing = _distinct_dates(table, date_col, start_str, end_str, ts_code=ts_code)
+    for trade_date in trade_calendar:
+        if trade_date not in existing:
+            return trade_date
+    return None
+
+
+def _latest_existing_date(
+    table: str,
+    date_col: str,
+    *,
+    start_str: Optional[str] = None,
+    end_str: Optional[str] = None,
+    ts_code: Optional[str] = None,
+) -> Optional[str]:
+    sql = f"SELECT MAX({date_col}) AS max_date FROM {table}"
+    clauses: List[str] = []
+    params: List[object] = []
+    if start_str:
+        clauses.append(f"{date_col} >= ?")
+        params.append(start_str)
+    if end_str:
+        clauses.append(f"{date_col} <= ?")
+        params.append(end_str)
+    if ts_code:
+        clauses.append("ts_code = ?")
+        params.append(ts_code)
+    if clauses:
+        sql += " WHERE " + " AND ".join(clauses)
+    try:
+        with db_session(read_only=True) as conn:
+            row = conn.execute(sql, tuple(params)).fetchone()
+    except sqlite3.OperationalError:
+        return None
+    if not row:
+        return None
+    value = row["max_date"]
+    return str(value) if value else None
+
+
 def _existing_suspend_dates(start_str: str, end_str: str, ts_code: str | None = None) -> Set[str]:
     sql = "SELECT DISTINCT suspend_date FROM suspend WHERE suspend_date BETWEEN ? AND ?"
     params: List[object] = [start_str, end_str]
@@ -1089,6 +1195,21 @@ def fetch_suspensions(
     }
     if ts_code:
         params["ts_code"] = ts_code
+    if skip_existing:
+        resume_start = start_str
+        latest = _latest_existing_date(
+            "suspend",
+            "suspend_date",
+            start_str=start_str,
+            end_str=end_str,
+            ts_code=ts_code,
+        )
+        if latest and latest >= start_str:
+            resume_start = _increment_date_str(latest)
+        if resume_start > end_str:
+            LOGGER.info("停复牌信息已覆盖 %s-%s,跳过", start_str, end_str, extra=LOG_EXTRA)
+            return []
+        params["start_date"] = resume_start
     df = _fetch_paginated("suspend_d", params)
     if df.empty:
         return []
@@ -1195,6 +1316,25 @@ def fetch_stk_limit(
     params: Dict[str, object] = {"start_date": start_str, "end_date": end_str}
     if ts_code:
         params["ts_code"] = ts_code
+    if skip_existing:
+        resume_start = start_str
+        latest = _latest_existing_date(
+            "stk_limit",
+            "trade_date",
+            start_str=start_str,
+            end_str=end_str,
+            ts_code=ts_code,
+        )
+        if latest and latest >= start_str:
+            missing = _first_missing_date("stk_limit", "trade_date", start_str, latest, ts_code=ts_code)
+            if missing:
+                resume_start = missing
+            else:
+                resume_start = _increment_date_str(latest)
+        if resume_start > end_str:
+            LOGGER.info("涨跌停数据已覆盖 %s-%s,跳过", start_str, end_str, extra=LOG_EXTRA)
+            return []
+        params["start_date"] = resume_start
     df = _fetch_paginated("stk_limit", params, limit=4000)
     if df.empty:
         return []
diff --git a/app/ingest/gdelt.py b/app/ingest/gdelt.py
index f0aa6a8..3a0c063 100644
--- a/app/ingest/gdelt.py
+++ b/app/ingest/gdelt.py
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import hashlib
+import re
 import sqlite3
 from dataclasses import dataclass, field, replace
 from datetime import date, datetime, timedelta, timezone
@@ -23,6 +24,19 @@ LOGGER = get_logger(__name__)
 LOG_EXTRA = {"stage": "gdelt_ingest"}
 DateLike = Union[date, datetime]
 
+_LANGUAGE_CANONICAL: Dict[str, str] = {
+    "en": "en",
+    "eng": "en",
+    "english": "en",
+    "zh": "zh",
+    "zho": "zh",
+    "zh-cn": "zh",
+    "zh-hans": "zh",
+    "zh-hant": "zh",
+    "zh_tw": "zh",
+    "chinese": "zh",
+}
+
 
 @dataclass
 class GdeltSourceConfig:
@@ -227,28 +241,132 @@ def fetch_gdelt_articles(
         LOGGER.warning("未安装 gdeltdoc,跳过 GDELT 拉取", extra=LOG_EXTRA)
         return []
 
-    filters_kwargs = dict(config.filters)
-    filters_kwargs.setdefault("num_records", config.num_records)
-    if start:
+    base_filters = dict(config.filters)
+    base_filters.setdefault("num_records", config.num_records)
+    original_timespan = base_filters.get("timespan")
+    filters_kwargs = dict(base_filters)
+
+    def _strip_quotes(token: str) -> str:
+        stripped = token.strip()
+        if (stripped.startswith('"') and stripped.endswith('"')) or (stripped.startswith("'") and stripped.endswith("'")):
+            return stripped[1:-1].strip()
+        return stripped
+
+    def _normalize_keywords(value: object) -> object:
+        if isinstance(value, str):
+            parts = [part.strip() for part in re.split(r"\s+OR\s+", value) if part.strip()]
+            if len(parts) <= 1:
+                return _strip_quotes(value)
+            normalized = [_strip_quotes(part) for part in parts]
+            return normalized
+        if isinstance(value, (list, tuple, set)):
+            normalized = [_strip_quotes(str(item)) for item in value if str(item).strip()]
+            return normalized
+        return value
+
+    def _sanitize(filters: Dict[str, object]) -> Dict[str, object]:
+        cleaned = dict(filters)
+        def _normalise_sequence_field(field: str, mapping: Optional[Dict[str, str]] = None) -> None:
+            value = cleaned.get(field)
+            if isinstance(value, (list, tuple, set)):
+                items: List[str] = []
+                for token in value:
+                    if not token:
+                        continue
+                    token_str = str(token).strip()
+                    if not token_str:
+                        continue
+                    mapped = mapping.get(token_str.lower(), token_str) if mapping else token_str
+                    if mapped not in items:
+                        items.append(mapped)
+                if not items:
+                    cleaned.pop(field, None)
+                elif len(items) == 1:
+                    cleaned[field] = items[0]
+                else:
+                    cleaned[field] = items
+            elif isinstance(value, str):
+                stripped = value.strip()
+                if not stripped:
+                    cleaned.pop(field, None)
+                elif mapping:
+                    cleaned[field] = mapping.get(stripped.lower(), stripped)
+                else:
+                    cleaned[field] = stripped
+            elif value is None:
+                cleaned.pop(field, None)
+
+        _normalise_sequence_field("language", _LANGUAGE_CANONICAL)
+        _normalise_sequence_field("country")
+        _normalise_sequence_field("domain")
+        _normalise_sequence_field("domain_exact")
+
+        keyword_value = cleaned.get("keyword")
+        if keyword_value is not None:
+            normalized_keyword = _normalize_keywords(keyword_value)
+            if isinstance(normalized_keyword, list):
+                if not normalized_keyword:
+                    cleaned.pop("keyword", None)
+                elif len(normalized_keyword) == 1:
+                    cleaned["keyword"] = normalized_keyword[0]
+                else:
+                    cleaned["keyword"] = normalized_keyword
+            elif isinstance(normalized_keyword, str):
+                cleaned["keyword"] = normalized_keyword
+            else:
+                cleaned.pop("keyword", None)
+
+        return cleaned
+
+    if start or end:
         filters_kwargs.pop("timespan", None)
+    if start:
         filters_kwargs["start_date"] = start
     if end:
-        filters_kwargs.pop("timespan", None)
         filters_kwargs["end_date"] = end
-    try:
-        filter_obj = Filters(**filters_kwargs)
-    except Exception as exc:  # noqa: BLE001 - guard misconfigured filters
-        LOGGER.error("GDELT 过滤器解析失败 key=%s err=%s", config.key, exc, extra=LOG_EXTRA)
-        return []
+    filters_kwargs = _sanitize(filters_kwargs)
 
     client = GdeltDoc()
-    try:
-        df = client.article_search(filter_obj)
-    except Exception as exc:  # noqa: BLE001 - network/service issues
-        LOGGER.warning("GDELT 请求失败 key=%s err=%s", config.key, exc, extra=LOG_EXTRA)
-        return []
+
+    def _run_query(kwargs: Dict[str, object]) -> Optional[pd.DataFrame]:
+        try:
+            filter_obj = Filters(**kwargs)
+        except Exception as exc:  # noqa: BLE001
+            LOGGER.error("GDELT 过滤器解析失败 key=%s err=%s", config.key, exc, extra=LOG_EXTRA)
+            return None
+        try:
+            return client.article_search(filter_obj)
+        except Exception as exc:  # noqa: BLE001
+            message = str(exc)
+            if "Invalid/Unsupported Language" in message and kwargs.get("language"):
+                LOGGER.warning(
+                    "GDELT 语言过滤不被支持,移除后重试 key=%s languages=%s",
+                    config.key,
+                    kwargs.get("language"),
+                    extra=LOG_EXTRA,
+                )
+                retry_kwargs = dict(kwargs)
+                retry_kwargs.pop("language", None)
+                return _run_query(retry_kwargs)
+            LOGGER.warning("GDELT 请求失败 key=%s err=%s", config.key, exc, extra=LOG_EXTRA)
+            return None
+
+    df = _run_query(filters_kwargs)
+    if df is None or df.empty:
+        if (start or end) and original_timespan:
+            fallback_kwargs = dict(base_filters)
+            fallback_kwargs["timespan"] = original_timespan
+            fallback_kwargs.pop("start_date", None)
+            fallback_kwargs.pop("end_date", None)
+            fallback_kwargs = _sanitize(fallback_kwargs)
+            LOGGER.info(
+                "GDELT 无匹配结果,尝试使用 timespan 回退 key=%s timespan=%s",
+                config.key,
+                original_timespan,
+                extra=LOG_EXTRA,
+            )
+            df = _run_query(fallback_kwargs)
     if df is None or df.empty:
         LOGGER.info("GDELT 无匹配结果 key=%s", config.key, extra=LOG_EXTRA)
         return []
diff --git a/app/ingest/news.py b/app/ingest/news.py
new file mode 100644
index 0000000..532b62f
--- /dev/null
+++ b/app/ingest/news.py
@@ -0,0 +1,119 @@
+"""Unified news ingestion orchestration with GDELT as the primary source."""
+from __future__ import annotations
+
+from datetime import date, datetime, timedelta
+from typing import Set, Tuple
+
+from app.data.schema import initialize_database
+from app.utils.logging import get_logger
+
+from .gdelt import ingest_configured_gdelt
+
+LOGGER = get_logger(__name__)
+LOG_EXTRA = {"stage": "news_ingest"}
+
+_PREPARED_WINDOWS: Set[Tuple[str, int]] = set()
+
+
+def _normalize_date(value: date | datetime) -> datetime:
+    if isinstance(value, datetime):
+        return value
+    return datetime.combine(value, datetime.min.time())
+
+
+def ingest_latest_news(
+    *,
+    days_back: int = 1,
+    force: bool = False,
+) -> int:
+    """Fetch latest news primarily via GDELT within a day-level window."""
+
+    initialize_database()
+    now = datetime.utcnow()
+    days = max(days_back, 1)
+    start_day = (now.date() - timedelta(days=days - 1))
+    start_dt = datetime.combine(start_day, datetime.min.time())
+    end_dt = datetime.combine(now.date(), datetime.max.time())
+    LOGGER.info(
+        "触发 GDELT 新闻拉取 days=%s force=%s",
+        days,
+        force,
+        extra=LOG_EXTRA,
+    )
+    inserted = ingest_configured_gdelt(
+        start=start_dt,
+        end=end_dt,
+        incremental=not force,
+    )
+    LOGGER.info("新闻拉取完成 inserted=%s", inserted, extra=LOG_EXTRA)
+    return inserted
+
+
+def ensure_news_range(
+    start: date | datetime,
+    end: date | datetime,
+    *,
+    force: bool = False,
+) -> int:
+    """Ensure the news store covers the requested window."""
+
+    initialize_database()
+    start_dt = _normalize_date(start)
+    end_dt = _normalize_date(end)
+    if start_dt > end_dt:
+        start_dt, end_dt = end_dt, start_dt
+    start_dt = datetime.combine(start_dt.date(), datetime.min.time())
+    end_dt = datetime.combine(end_dt.date(), datetime.max.time())
+    LOGGER.info(
+        "同步 GDELT 新闻数据 start=%s end=%s force=%s",
+        start_dt.isoformat(),
+        end_dt.isoformat(),
+        force,
+        extra=LOG_EXTRA,
+    )
+    inserted = ingest_configured_gdelt(
+        start=start_dt,
+        end=end_dt,
+        incremental=not force,
+    )
+    LOGGER.info(
+        "新闻窗口同步完成 inserted=%s start=%s end=%s",
+        inserted,
+        start_dt.isoformat(),
+        end_dt.isoformat(),
+        extra=LOG_EXTRA,
+    )
+    return inserted
+
+
+def prepare_news_for_factors(
+    trade_date: date,
+    *,
+    lookback_days: int = 3,
+    force: bool = False,
+) -> int:
+    """Prepare news data before sentiment factor computation."""
+
+    key = (trade_date.strftime("%Y%m%d"), max(lookback_days, 1))
+    if not force and key in _PREPARED_WINDOWS:
+        LOGGER.debug(
+            "新闻窗口已准备完成 trade_date=%s lookback=%s",
+            key[0],
+            key[1],
+            extra=LOG_EXTRA,
+        )
+        return 0
+
+    end_date = trade_date
+    start_date = trade_date - timedelta(days=max(lookback_days - 1, 0))
+    inserted = ensure_news_range(start_date, end_date, force=force)
+    if not force:
+        _PREPARED_WINDOWS.add(key)
+    return inserted
+
+
+__all__ = [
+    "ensure_news_range",
+    "ingest_latest_news",
+    "prepare_news_for_factors",
+]
diff --git a/app/ui/streamlit_app.py b/app/ui/streamlit_app.py
index aaa6347..86dd9b8 100644
--- a/app/ui/streamlit_app.py
+++ b/app/ui/streamlit_app.py
@@ -13,7 +13,7 @@ if str(ROOT) not in sys.path:
 
 from app.data.schema import initialize_database
 from app.ingest.checker import run_boot_check
-from app.ingest.rss import ingest_configured_rss
+from app.ingest.news import ingest_latest_news
 from app.ui.portfolio_config import render_portfolio_config
 from app.ui.progress_state import render_factor_progress
 from app.ui.shared import LOGGER, LOG_EXTRA, render_tuning_backtest_hints
@@ -58,15 +58,15 @@ def main() -> None:
                 progress_hook=progress_hook,
                 force_refresh=False,
             )
-            rss_count = ingest_configured_rss(hours_back=24, max_items_per_feed=50)
+            news_count = ingest_latest_news(days_back=1, force=False)
             LOGGER.info(
-                "自动数据更新完成:日线数据覆盖%s-%s,RSS新闻%s条",
+                "自动数据更新完成:日线数据覆盖%s-%s,GDELT新闻%s条",
                 report.start,
                 report.end,
-                rss_count,
+                news_count,
                 extra=LOG_EXTRA,
             )
-            st.success(f"✅ 自动数据更新完成:获取RSS新闻 {rss_count} 条")
+            st.success(f"✅ 自动数据更新完成:获取 GDELT 新闻 {news_count} 条")
         except Exception as exc:  # noqa: BLE001
             LOGGER.exception("自动数据更新失败", extra=LOG_EXTRA)
             st.error(f"❌ 自动数据更新失败:{exc}")
diff --git a/app/ui/views/tests.py b/app/ui/views/tests.py
index 5930cac..1224341 100644
--- a/app/ui/views/tests.py
+++ b/app/ui/views/tests.py
@@ -71,61 +71,39 @@ def render_tests() -> None:
 
     st.divider()
 
-    st.subheader("RSS 数据测试")
-    st.write("用于验证 RSS 配置是否能够正常抓取新闻并写入数据库。")
-    rss_url = st.text_input(
-        "测试 RSS 地址",
-        value="https://rsshub.app/cls/depth/1000",
-        help="留空则使用默认配置的全部 RSS 来源。",
-    ).strip()
-    rss_hours = int(
+    st.subheader("新闻数据测试(GDELT)")
+    st.write("用于验证 GDELT 配置是否能够正常抓取新闻并写入数据库。")
+    news_days = int(
         st.number_input(
-            "回溯窗口(小时)",
+            "回溯窗口(天)",
             min_value=1,
-            max_value=168,
-            value=24,
-            step=6,
-            help="仅抓取最近指定小时内的新闻。",
+            max_value=30,
+            value=1,
+            step=1,
+            help="按天抓取最近区间的新闻。",
         )
     )
-    rss_limit = int(
-        st.number_input(
-            "单源抓取条数",
-            min_value=1,
-            max_value=200,
-            value=50,
-            step=10,
-        )
+    force_news = st.checkbox(
+        "强制重新抓取(忽略增量状态)",
+        value=False,
     )
-    if st.button("运行 RSS 测试"):
-        from app.ingest import rss as rss_ingest
+    if st.button("运行 GDELT 新闻测试"):
+        from app.ingest.news import ingest_latest_news
 
         LOGGER.info(
-            "点击 RSS 测试按钮 rss_url=%s hours=%s limit=%s",
-            rss_url,
-            rss_hours,
-            rss_limit,
+            "点击 GDELT 新闻测试按钮 days=%s force=%s",
+            news_days,
+            force_news,
            extra=LOG_EXTRA,
         )
-        with st.spinner("正在抓取 RSS 新闻..."):
+        with st.spinner("正在抓取 GDELT 新闻..."):
            try:
-                if rss_url:
-                    items = rss_ingest.fetch_rss_feed(
-                        rss_url,
-                        hours_back=rss_hours,
-                        max_items=rss_limit,
-                    )
-                    count = rss_ingest.save_news_items(items)
-                else:
-                    count = rss_ingest.ingest_configured_rss(
-                        hours_back=rss_hours,
-                        max_items_per_feed=rss_limit,
-                    )
-                st.success(f"RSS 测试完成,新增 {count} 条新闻记录。")
+                count = ingest_latest_news(days_back=news_days, force=force_news)
+                st.success(f"GDELT 新闻测试完成,新增 {count} 条新闻记录。")
             except Exception as exc:  # noqa: BLE001
-                LOGGER.exception("RSS 测试失败", extra=LOG_EXTRA)
-                st.error(f"RSS 测试失败:{exc}")
-                alerts.add_warning("RSS", "RSS 测试执行失败", str(exc))
+                LOGGER.exception("GDELT 新闻测试失败", extra=LOG_EXTRA)
+                st.error(f"GDELT 新闻测试失败:{exc}")
+                alerts.add_warning("GDELT", "GDELT 新闻测试执行失败", str(exc))
             update_dashboard_sidebar()
 
     st.divider()
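
Usage sketch for the new entry points in app/ingest/news.py, assuming the module paths shown in this diff and that the package is importable (e.g. run from the repo root). The trade date below is an arbitrary example; in the patch itself these calls are made from compute_factors and the Streamlit startup hook rather than by hand.

    # Sketch only: driving the new news entry points directly.
    from datetime import date

    from app.ingest.news import ingest_latest_news, prepare_news_for_factors

    # Daily refresh, as wired into the Streamlit startup hook.
    print(ingest_latest_news(days_back=1, force=False))

    # Factor-time preparation, as wired into compute_factors for "sent_" factors.
    # Repeat calls for the same (trade_date, lookback) window return 0 without refetching.
    print(prepare_news_for_factors(date(2024, 5, 10), lookback_days=7))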
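The skip_existing changes to fetch_suspensions and fetch_stk_limit resume the TuShare request at the first trading day that is not yet stored locally instead of re-downloading the whole window. A minimal standalone sketch of that resume decision, with plain lists standing in for the trade calendar and the dates already present in SQLite (the real helpers _trading_dates_for_range and _distinct_dates consult the calendar table and the target table):

    # Simplified illustration of the resume logic; stdlib only.
    from datetime import datetime, timedelta
    from typing import List, Optional, Set


    def first_missing_date(calendar: List[str], existing: Set[str]) -> Optional[str]:
        # Counterpart of _first_missing_date: first trading day with no stored rows.
        for trade_date in calendar:
            if trade_date not in existing:
                return trade_date
        return None


    def increment_date_str(value: str) -> str:
        # Counterpart of _increment_date_str: day after the latest stored date.
        parsed = datetime.strptime(value, "%Y%m%d").date()
        return (parsed + timedelta(days=1)).strftime("%Y%m%d")


    calendar = ["20240506", "20240507", "20240508"]
    existing = {"20240506", "20240508"}            # 20240507 was never ingested
    print(first_missing_date(calendar, existing))  # -> 20240507, so the fetch resumes there
    print(increment_date_str("20240508"))          # -> 20240509 when no gap exists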
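The GDELT changes normalize filters before building gdeltdoc Filters: "A OR B" keyword expressions are split into lists, surrounding quotes are dropped, and language tokens are mapped to canonical codes, with a retry that drops the language filter if the API still rejects it. A simplified standalone sketch of that normalization; the mapping below covers only a few of the English/Chinese aliases from _LANGUAGE_CANONICAL, and the quote stripping is cruder than _strip_quotes in the patch:

    # Standalone sketch of the filter normalization added in gdelt.py.
    import re
    from typing import Dict, List, Union

    LANGUAGE_CANONICAL: Dict[str, str] = {"eng": "en", "english": "en", "zh-cn": "zh", "chinese": "zh"}


    def normalize_keywords(value: str) -> Union[str, List[str]]:
        # Split an "A OR B" expression into separate keywords and drop surrounding quotes.
        parts = [part.strip().strip('"\'') for part in re.split(r"\s+OR\s+", value) if part.strip()]
        return parts[0] if len(parts) == 1 else parts


    def canonical_language(value: str) -> str:
        # Map a language alias onto the code gdeltdoc expects, leaving unknown values untouched.
        return LANGUAGE_CANONICAL.get(value.strip().lower(), value.strip())


    print(normalize_keywords('"reserve requirement" OR "interest rate"'))
    # -> ['reserve requirement', 'interest rate']
    print(canonical_language("Chinese"))  # -> zh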