From 5b2033f52bda433168cb8bdc15dd441d6e6cebfd Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 11 Oct 2025 09:40:54 +0800 Subject: [PATCH] refactor news data fetching and industry sentiment calculation --- app/utils/data_access.py | 242 ++++++++++++++++++++++++++++----------- 1 file changed, 177 insertions(+), 65 deletions(-) diff --git a/app/utils/data_access.py b/app/utils/data_access.py index 1f750be..a5387b9 100644 --- a/app/utils/data_access.py +++ b/app/utils/data_access.py @@ -11,14 +11,13 @@ from dataclasses import dataclass, field from datetime import date, datetime, timedelta from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Sequence, Set, Tuple -import numpy as np - from .config import get_config import types from .db import db_session from .logging import get_logger from app.core.indicators import momentum, normalize, rolling_mean, volatility from app.utils.db_query import BrokerQueryEngine +from app.utils import alerts # 延迟导入,避免循环依赖 collect_data_coverage = None @@ -143,6 +142,14 @@ def _end_of_day(dt: datetime) -> str: return dt.strftime("%Y-%m-%d 23:59:59") +def _iso_start_of_day(dt: datetime) -> str: + return dt.strftime("%Y-%m-%dT00:00:00+00:00") + + +def _iso_end_of_day(dt: datetime) -> str: + return dt.strftime("%Y-%m-%dT23:59:59+00:00") + + def _coerce_date(value: object) -> Optional[date]: if value is None: return None @@ -210,6 +217,7 @@ class DataBroker: self._coverage_cache = {} self._refresh = _RefreshCoordinator(self) self._query_engine = BrokerQueryEngine(db_session) + self._auto_update_warning_emitted = False if initialize_database is not None: initialize_database() # 确保数据库已初始化 else: @@ -566,76 +574,167 @@ class DataBroker: self, ts_code: str, trade_date: str, - limit: int = 30 + limit: int = 30, + lookback_days: int = 3, ) -> List[Dict[str, Any]]: - """获取新闻数据(简化实现) - + """获取新闻数据切片。 + Args: ts_code: 股票代码 - trade_date: 交易日期 + trade_date: 交易日期(YYYYMMDD/ISO) limit: 返回的新闻条数限制 - - Returns: - 新闻数据列表,包含sentiment、heat、entities等字段 - """ - # TODO: 使用真实新闻数据库替换随机生成的占位数据 - return [ - { - "sentiment": np.random.uniform(-1, 1), - "heat": np.random.uniform(0, 1), - "entities": "股票,市场,投资" - } - for _ in range(min(limit, 5)) - ] + lookback_days: 回溯天数,用于拉取近几日新闻 - def _lookup_industry(self, ts_code: str) -> Optional[str]: - """查找股票所属行业 - - Args: - ts_code: 股票代码 - Returns: - 行业代码或名称,找不到时返回None + 新闻数据列表,包含 sentiment、heat、entities 等字段 """ - # TODO: 替换为真实行业映射逻辑(当前仅为占位数据) - industry_mapping = { - "000001.SZ": "银行", - "000002.SZ": "房地产", - "000858.SZ": "食品饮料", - "000962.SZ": "医药生物", - } - return industry_mapping.get(ts_code, "其他") + if not ts_code or limit <= 0: + return [] + parsed_date = _parse_trade_date(trade_date) + if not parsed_date: + LOGGER.debug( + "新闻数据查询失败,无法解析日期 ts_code=%s trade_date=%s", + ts_code, + trade_date, + extra=LOG_EXTRA, + ) + return [] + window_days = max(1, lookback_days) + end_day = parsed_date.date() + start_day = end_day - timedelta(days=max(window_days - 1, 0)) + start_bound = _iso_start_of_day(datetime(start_day.year, start_day.month, start_day.day)) + end_bound = _iso_end_of_day(datetime(end_day.year, end_day.month, end_day.day)) + query = ( + "SELECT sentiment, heat, sentiment_index, heat_score, entities, " + "title, summary, source, url, pub_time " + "FROM news " + "WHERE ts_code = ? AND pub_time BETWEEN ? AND ? " + "ORDER BY pub_time DESC LIMIT ?" + ) + try: + with db_session(read_only=True) as conn: + rows = conn.execute( + query, + (ts_code, start_bound, end_bound, int(limit)), + ).fetchall() + except sqlite3.OperationalError as exc: + LOGGER.debug( + "新闻数据查询失败 ts_code=%s err=%s", + ts_code, + exc, + extra=LOG_EXTRA, + ) + return [] + except Exception as exc: # noqa: BLE001 + LOGGER.debug( + "新闻数据读取异常 ts_code=%s err=%s", + ts_code, + exc, + extra=LOG_EXTRA, + ) + return [] + return [dict(row) for row in rows] - def _derived_industry_sentiment(self, industry: str, trade_date: str) -> Optional[float]: - """计算行业情绪得分 - - Args: - industry: 行业代码或名称 - trade_date: 交易日期 - - Returns: - 行业情绪得分,找不到时返回None - """ - # TODO: 接入行业情绪数据源,当前随机值仅用于占位显示 - return np.random.uniform(-1, 1) + def _derived_industry_sentiment( + self, + industry: str, + trade_date: str, + *, + lookback_days: int = 5, + ) -> Optional[float]: + """根据近几日新闻情绪推导行业层面的情绪指标。""" + if not industry: + return None + parsed_date = _parse_trade_date(trade_date) + if not parsed_date: + return None + stocks = self.get_industry_stocks(industry) + if not stocks: + return None + peers: List[str] = list(dict.fromkeys(stocks)) + if not peers: + return None + window_days = max(1, lookback_days) + end_day = parsed_date.date() + start_day = end_day - timedelta(days=max(window_days - 1, 0)) + start_bound = _iso_start_of_day(datetime(start_day.year, start_day.month, start_day.day)) + end_bound = _iso_end_of_day(datetime(end_day.year, end_day.month, end_day.day)) + placeholders = ",".join("?" for _ in peers[:200]) + if not placeholders: + return None + query = ( + f"SELECT sentiment FROM news " + f"WHERE ts_code IN ({placeholders}) " + "AND pub_time BETWEEN ? AND ? " + "AND sentiment IS NOT NULL" + ) + params: List[object] = list(peers[:200]) + params.extend([start_bound, end_bound]) + try: + with db_session(read_only=True) as conn: + rows = conn.execute(query, params).fetchall() + except sqlite3.OperationalError as exc: + LOGGER.debug( + "行业情绪查询失败 industry=%s err=%s", + industry, + exc, + extra=LOG_EXTRA, + ) + return None + except Exception as exc: # noqa: BLE001 + LOGGER.debug( + "行业情绪读取异常 industry=%s err=%s", + industry, + exc, + extra=LOG_EXTRA, + ) + return None + sentiments: List[float] = [] + for row in rows: + try: + sentiments.append(float(row["sentiment"])) + except (TypeError, ValueError, KeyError): + continue + if not sentiments: + return None + avg = sum(sentiments) / len(sentiments) + return max(-1.0, min(1.0, avg)) def get_industry_stocks(self, industry: str) -> List[str]: - """获取同行业股票列表 - - Args: - industry: 行业代码或名称 - - Returns: - 同行业股票代码列表 - """ - # TODO: 使用实际行业成分数据替换占位列表 - industry_stocks = { - "银行": ["000001.SZ", "002142.SZ", "600036.SH"], - "房地产": ["000002.SZ", "000402.SZ", "600048.SH"], - "食品饮料": ["000858.SZ", "600519.SH", "000568.SZ"], - "医药生物": ["000962.SZ", "600276.SH", "300003.SZ"], - } - return industry_stocks.get(industry, []) + """获取同行业股票列表。""" + if not industry: + return [] + cache = getattr(self, "_industry_members_cache", None) + if cache is None: + cache = {} + self._industry_members_cache = cache + if industry in cache: + return cache[industry] + query = "SELECT ts_code FROM stock_basic WHERE industry = ? ORDER BY ts_code" + try: + with db_session(read_only=True) as conn: + rows = conn.execute(query, (industry,)).fetchall() + except sqlite3.OperationalError as exc: + LOGGER.debug( + "行业成分查询失败 industry=%s err=%s", + industry, + exc, + extra=LOG_EXTRA, + ) + cache[industry] = [] + return [] + except Exception as exc: # noqa: BLE001 + LOGGER.debug( + "行业成分读取异常 industry=%s err=%s", + industry, + exc, + extra=LOG_EXTRA, + ) + cache[industry] = [] + return [] + members = [row["ts_code"] for row in rows if row and row["ts_code"]] + cache[industry] = members + return members def fetch_flags( self, @@ -1070,7 +1169,7 @@ class DataBroker: def check_data_availability( self, trade_date: str, - tables: Set[str] = None, + tables: Optional[Set[str]] = None, threshold: float = 0.8, ) -> bool: """检查指定交易日的数据是否可用,如不可用则返回True(需要补数)。 @@ -1083,12 +1182,25 @@ class DataBroker: Returns: bool: True表示数据不足,需要补数 """ + cfg = get_config() # 如果配置了强制刷新,则始终返回需要补数 - if get_config().force_refresh: + if cfg.force_refresh: return True # 如果未启用自动更新,则不进行补数 - if not get_config().auto_update_data: + if not cfg.auto_update_data: + if not getattr(self, "_auto_update_warning_emitted", False): + message = "自动补数已关闭,系统将跳过缺口检测。" + LOGGER.warning(message, extra=LOG_EXTRA) + try: + alerts.add_warning( + "data_broker", + "自动补数已关闭", + "当前运行模式不会触发数据补齐,请在设置中开启自动更新或手动补数。", + ) + except Exception: # noqa: BLE001 + LOGGER.debug("自动补数告警发送失败", extra=LOG_EXTRA) + self._auto_update_warning_emitted = True return False # 默认检查的表