refactor news data fetching and industry sentiment calculation

This commit is contained in:
Your Name 2025-10-11 09:40:54 +08:00
parent 90fb2a9df6
commit 5b2033f52b

View File

@ -11,14 +11,13 @@ from dataclasses import dataclass, field
from datetime import date, datetime, timedelta from datetime import date, datetime, timedelta
from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Sequence, Set, Tuple from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Sequence, Set, Tuple
import numpy as np
from .config import get_config from .config import get_config
import types import types
from .db import db_session from .db import db_session
from .logging import get_logger from .logging import get_logger
from app.core.indicators import momentum, normalize, rolling_mean, volatility from app.core.indicators import momentum, normalize, rolling_mean, volatility
from app.utils.db_query import BrokerQueryEngine from app.utils.db_query import BrokerQueryEngine
from app.utils import alerts
# 延迟导入,避免循环依赖 # 延迟导入,避免循环依赖
collect_data_coverage = None collect_data_coverage = None
@ -143,6 +142,14 @@ def _end_of_day(dt: datetime) -> str:
return dt.strftime("%Y-%m-%d 23:59:59") return dt.strftime("%Y-%m-%d 23:59:59")
def _iso_start_of_day(dt: datetime) -> str:
return dt.strftime("%Y-%m-%dT00:00:00+00:00")
def _iso_end_of_day(dt: datetime) -> str:
return dt.strftime("%Y-%m-%dT23:59:59+00:00")
def _coerce_date(value: object) -> Optional[date]: def _coerce_date(value: object) -> Optional[date]:
if value is None: if value is None:
return None return None
@ -210,6 +217,7 @@ class DataBroker:
self._coverage_cache = {} self._coverage_cache = {}
self._refresh = _RefreshCoordinator(self) self._refresh = _RefreshCoordinator(self)
self._query_engine = BrokerQueryEngine(db_session) self._query_engine = BrokerQueryEngine(db_session)
self._auto_update_warning_emitted = False
if initialize_database is not None: if initialize_database is not None:
initialize_database() # 确保数据库已初始化 initialize_database() # 确保数据库已初始化
else: else:
@ -566,76 +574,167 @@ class DataBroker:
self, self,
ts_code: str, ts_code: str,
trade_date: str, trade_date: str,
limit: int = 30 limit: int = 30,
lookback_days: int = 3,
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
"""获取新闻数据(简化实现) """获取新闻数据切片。
Args: Args:
ts_code: 股票代码 ts_code: 股票代码
trade_date: 交易日期 trade_date: 交易日期YYYYMMDD/ISO
limit: 返回的新闻条数限制 limit: 返回的新闻条数限制
lookback_days: 回溯天数用于拉取近几日新闻
Returns: Returns:
新闻数据列表包含 sentimentheatentities 等字段 新闻数据列表包含 sentimentheatentities 等字段
""" """
# TODO: 使用真实新闻数据库替换随机生成的占位数据 if not ts_code or limit <= 0:
return [ return []
{ parsed_date = _parse_trade_date(trade_date)
"sentiment": np.random.uniform(-1, 1), if not parsed_date:
"heat": np.random.uniform(0, 1), LOGGER.debug(
"entities": "股票,市场,投资" "新闻数据查询失败,无法解析日期 ts_code=%s trade_date=%s",
} ts_code,
for _ in range(min(limit, 5)) trade_date,
] extra=LOG_EXTRA,
)
return []
window_days = max(1, lookback_days)
end_day = parsed_date.date()
start_day = end_day - timedelta(days=max(window_days - 1, 0))
start_bound = _iso_start_of_day(datetime(start_day.year, start_day.month, start_day.day))
end_bound = _iso_end_of_day(datetime(end_day.year, end_day.month, end_day.day))
query = (
"SELECT sentiment, heat, sentiment_index, heat_score, entities, "
"title, summary, source, url, pub_time "
"FROM news "
"WHERE ts_code = ? AND pub_time BETWEEN ? AND ? "
"ORDER BY pub_time DESC LIMIT ?"
)
try:
with db_session(read_only=True) as conn:
rows = conn.execute(
query,
(ts_code, start_bound, end_bound, int(limit)),
).fetchall()
except sqlite3.OperationalError as exc:
LOGGER.debug(
"新闻数据查询失败 ts_code=%s err=%s",
ts_code,
exc,
extra=LOG_EXTRA,
)
return []
except Exception as exc: # noqa: BLE001
LOGGER.debug(
"新闻数据读取异常 ts_code=%s err=%s",
ts_code,
exc,
extra=LOG_EXTRA,
)
return []
return [dict(row) for row in rows]
def _lookup_industry(self, ts_code: str) -> Optional[str]: def _derived_industry_sentiment(
"""查找股票所属行业 self,
industry: str,
Args: trade_date: str,
ts_code: 股票代码 *,
lookback_days: int = 5,
Returns: ) -> Optional[float]:
行业代码或名称找不到时返回None """根据近几日新闻情绪推导行业层面的情绪指标。"""
""" if not industry:
# TODO: 替换为真实行业映射逻辑(当前仅为占位数据) return None
industry_mapping = { parsed_date = _parse_trade_date(trade_date)
"000001.SZ": "银行", if not parsed_date:
"000002.SZ": "房地产", return None
"000858.SZ": "食品饮料", stocks = self.get_industry_stocks(industry)
"000962.SZ": "医药生物", if not stocks:
} return None
return industry_mapping.get(ts_code, "其他") peers: List[str] = list(dict.fromkeys(stocks))
if not peers:
def _derived_industry_sentiment(self, industry: str, trade_date: str) -> Optional[float]: return None
"""计算行业情绪得分 window_days = max(1, lookback_days)
end_day = parsed_date.date()
Args: start_day = end_day - timedelta(days=max(window_days - 1, 0))
industry: 行业代码或名称 start_bound = _iso_start_of_day(datetime(start_day.year, start_day.month, start_day.day))
trade_date: 交易日期 end_bound = _iso_end_of_day(datetime(end_day.year, end_day.month, end_day.day))
placeholders = ",".join("?" for _ in peers[:200])
Returns: if not placeholders:
行业情绪得分找不到时返回None return None
""" query = (
# TODO: 接入行业情绪数据源,当前随机值仅用于占位显示 f"SELECT sentiment FROM news "
return np.random.uniform(-1, 1) f"WHERE ts_code IN ({placeholders}) "
"AND pub_time BETWEEN ? AND ? "
"AND sentiment IS NOT NULL"
)
params: List[object] = list(peers[:200])
params.extend([start_bound, end_bound])
try:
with db_session(read_only=True) as conn:
rows = conn.execute(query, params).fetchall()
except sqlite3.OperationalError as exc:
LOGGER.debug(
"行业情绪查询失败 industry=%s err=%s",
industry,
exc,
extra=LOG_EXTRA,
)
return None
except Exception as exc: # noqa: BLE001
LOGGER.debug(
"行业情绪读取异常 industry=%s err=%s",
industry,
exc,
extra=LOG_EXTRA,
)
return None
sentiments: List[float] = []
for row in rows:
try:
sentiments.append(float(row["sentiment"]))
except (TypeError, ValueError, KeyError):
continue
if not sentiments:
return None
avg = sum(sentiments) / len(sentiments)
return max(-1.0, min(1.0, avg))
def get_industry_stocks(self, industry: str) -> List[str]:
    """Return the stock codes listed under ``industry`` in ``stock_basic``.

    Results are memoised per industry on the instance (in
    ``_industry_members_cache``, created on first use). A failed query is
    logged at debug level and cached as an empty list, so it is not retried
    for the lifetime of this instance.

    Args:
        industry: Industry value matched against ``stock_basic.industry``.

    Returns:
        A new list of ``ts_code`` strings ordered by ``ts_code``; empty when
        ``industry`` is falsy, the query fails, or nothing matches. The list
        is a copy, so callers may mutate it without affecting the cache.
    """
    if not industry:
        return []

    cache = getattr(self, "_industry_members_cache", None)
    if cache is None:
        cache = {}
        self._industry_members_cache = cache
    if industry in cache:
        # Hand out a copy so caller-side mutation cannot corrupt the cache.
        return list(cache[industry])

    query = "SELECT ts_code FROM stock_basic WHERE industry = ? ORDER BY ts_code"
    try:
        with db_session(read_only=True) as conn:
            rows = conn.execute(query, (industry,)).fetchall()
    except sqlite3.OperationalError as exc:
        LOGGER.debug(
            "行业成分查询失败 industry=%s err=%s",
            industry,
            exc,
            extra=LOG_EXTRA,
        )
        cache[industry] = []
        return []
    except Exception as exc:  # noqa: BLE001
        LOGGER.debug(
            "行业成分读取异常 industry=%s err=%s",
            industry,
            exc,
            extra=LOG_EXTRA,
        )
        cache[industry] = []
        return []

    members = [row["ts_code"] for row in rows if row and row["ts_code"]]
    cache[industry] = members
    return list(members)
def fetch_flags( def fetch_flags(
self, self,
@ -1070,7 +1169,7 @@ class DataBroker:
def check_data_availability( def check_data_availability(
self, self,
trade_date: str, trade_date: str,
tables: Set[str] = None, tables: Optional[Set[str]] = None,
threshold: float = 0.8, threshold: float = 0.8,
) -> bool: ) -> bool:
"""检查指定交易日的数据是否可用如不可用则返回True需要补数 """检查指定交易日的数据是否可用如不可用则返回True需要补数
@ -1083,12 +1182,25 @@ class DataBroker:
Returns: Returns:
bool: True表示数据不足需要补数 bool: True表示数据不足需要补数
""" """
cfg = get_config()
# 如果配置了强制刷新,则始终返回需要补数 # 如果配置了强制刷新,则始终返回需要补数
if get_config().force_refresh: if cfg.force_refresh:
return True return True
# 如果未启用自动更新,则不进行补数 # 如果未启用自动更新,则不进行补数
if not get_config().auto_update_data: if not cfg.auto_update_data:
if not getattr(self, "_auto_update_warning_emitted", False):
message = "自动补数已关闭,系统将跳过缺口检测。"
LOGGER.warning(message, extra=LOG_EXTRA)
try:
alerts.add_warning(
"data_broker",
"自动补数已关闭",
"当前运行模式不会触发数据补齐,请在设置中开启自动更新或手动补数。",
)
except Exception: # noqa: BLE001
LOGGER.debug("自动补数告警发送失败", extra=LOG_EXTRA)
self._auto_update_warning_emitted = True
return False return False
# 默认检查的表 # 默认检查的表