refactor news data fetching and industry sentiment calculation

This commit is contained in:
Your Name 2025-10-11 09:40:54 +08:00
parent 90fb2a9df6
commit 5b2033f52b


@@ -11,14 +11,13 @@ from dataclasses import dataclass, field
from datetime import date, datetime, timedelta
from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Sequence, Set, Tuple
import numpy as np
from .config import get_config
import types
from .db import db_session
from .logging import get_logger
from app.core.indicators import momentum, normalize, rolling_mean, volatility
from app.utils.db_query import BrokerQueryEngine
from app.utils import alerts
# Deferred import to avoid circular dependencies
collect_data_coverage = None
@@ -143,6 +142,14 @@ def _end_of_day(dt: datetime) -> str:
return dt.strftime("%Y-%m-%d 23:59:59")
def _iso_start_of_day(dt: datetime) -> str:
return dt.strftime("%Y-%m-%dT00:00:00+00:00")
def _iso_end_of_day(dt: datetime) -> str:
return dt.strftime("%Y-%m-%dT23:59:59+00:00")
def _coerce_date(value: object) -> Optional[date]:
if value is None:
return None
@@ -210,6 +217,7 @@ class DataBroker:
self._coverage_cache = {}
self._refresh = _RefreshCoordinator(self)
self._query_engine = BrokerQueryEngine(db_session)
self._auto_update_warning_emitted = False
if initialize_database is not None:
initialize_database()  # ensure the database has been initialized
else:
@@ -566,76 +574,167 @@ class DataBroker:
self,
ts_code: str,
trade_date: str,
limit: int = 30
limit: int = 30,
lookback_days: int = 3,
) -> List[Dict[str, Any]]:
"""获取新闻数据(简化实现)
"""获取新闻数据切片。
Args:
ts_code: 股票代码
trade_date: 交易日期
trade_date: 交易日期YYYYMMDD/ISO
limit: 返回的新闻条数限制
lookback_days: 回溯天数用于拉取近几日新闻
Returns:
新闻数据列表包含sentimentheatentities等字段
新闻数据列表包含 sentimentheatentities 等字段
"""
# TODO: replace the randomly generated placeholder data with a real news database
return [
{
"sentiment": np.random.uniform(-1, 1),
"heat": np.random.uniform(0, 1),
"entities": "股票,市场,投资"
}
for _ in range(min(limit, 5))
]
if not ts_code or limit <= 0:
return []
parsed_date = _parse_trade_date(trade_date)
if not parsed_date:
LOGGER.debug(
"新闻数据查询失败,无法解析日期 ts_code=%s trade_date=%s",
ts_code,
trade_date,
extra=LOG_EXTRA,
)
return []
window_days = max(1, lookback_days)
end_day = parsed_date.date()
start_day = end_day - timedelta(days=max(window_days - 1, 0))
start_bound = _iso_start_of_day(datetime(start_day.year, start_day.month, start_day.day))
end_bound = _iso_end_of_day(datetime(end_day.year, end_day.month, end_day.day))
query = (
"SELECT sentiment, heat, sentiment_index, heat_score, entities, "
"title, summary, source, url, pub_time "
"FROM news "
"WHERE ts_code = ? AND pub_time BETWEEN ? AND ? "
"ORDER BY pub_time DESC LIMIT ?"
)
try:
with db_session(read_only=True) as conn:
rows = conn.execute(
query,
(ts_code, start_bound, end_bound, int(limit)),
).fetchall()
except sqlite3.OperationalError as exc:
LOGGER.debug(
"新闻数据查询失败 ts_code=%s err=%s",
ts_code,
exc,
extra=LOG_EXTRA,
)
return []
except Exception as exc: # noqa: BLE001
LOGGER.debug(
"新闻数据读取异常 ts_code=%s err=%s",
ts_code,
exc,
extra=LOG_EXTRA,
)
return []
return [dict(row) for row in rows]
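The method now returns plain dicts ordered by pub_time DESC, or an empty list on any failure. A hedged usage sketch follows; the method name is a stand-in because only the signature tail is visible in this hunk, and the aggregation is purely illustrative:
# Sketch: consuming the news slice (method name assumed, see note above).
broker = DataBroker()
rows = broker.fetch_news("000001.SZ", "20251011", limit=30, lookback_days=3)
if rows:
    latest = rows[0]  # newest first, per ORDER BY pub_time DESC
    scores = [float(r["sentiment"]) for r in rows if r.get("sentiment") is not None]
    avg_sentiment = sum(scores) / len(scores) if scores else None
    print(latest.get("title"), avg_sentiment)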
def _lookup_industry(self, ts_code: str) -> Optional[str]:
"""查找股票所属行业
Args:
ts_code: 股票代码
Returns:
行业代码或名称找不到时返回None
"""
# TODO: 替换为真实行业映射逻辑(当前仅为占位数据)
industry_mapping = {
"000001.SZ": "银行",
"000002.SZ": "房地产",
"000858.SZ": "食品饮料",
"000962.SZ": "医药生物",
}
return industry_mapping.get(ts_code, "其他")
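This mapping remains a hardcoded placeholder. Given that get_industry_stocks below reads stock_basic.industry, a database-backed lookup could take the same shape; the following is only a sketch under that assumption, not code from this commit:
# Sketch: stock_basic-backed industry lookup (assumes the same
# stock_basic(ts_code, industry) columns queried further down).
def _lookup_industry_from_db(self, ts_code: str) -> Optional[str]:
    if not ts_code:
        return None
    query = "SELECT industry FROM stock_basic WHERE ts_code = ? LIMIT 1"
    try:
        with db_session(read_only=True) as conn:
            row = conn.execute(query, (ts_code,)).fetchone()
    except Exception:  # noqa: BLE001
        return None
    if row and row["industry"]:
        return str(row["industry"])
    return None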
def _derived_industry_sentiment(self, industry: str, trade_date: str) -> Optional[float]:
"""计算行业情绪得分
Args:
industry: 行业代码或名称
trade_date: 交易日期
Returns:
行业情绪得分找不到时返回None
"""
# TODO: 接入行业情绪数据源,当前随机值仅用于占位显示
return np.random.uniform(-1, 1)
def _derived_industry_sentiment(
self,
industry: str,
trade_date: str,
*,
lookback_days: int = 5,
) -> Optional[float]:
"""根据近几日新闻情绪推导行业层面的情绪指标。"""
if not industry:
return None
parsed_date = _parse_trade_date(trade_date)
if not parsed_date:
return None
stocks = self.get_industry_stocks(industry)
if not stocks:
return None
peers: List[str] = list(dict.fromkeys(stocks))
if not peers:
return None
window_days = max(1, lookback_days)
end_day = parsed_date.date()
start_day = end_day - timedelta(days=max(window_days - 1, 0))
start_bound = _iso_start_of_day(datetime(start_day.year, start_day.month, start_day.day))
end_bound = _iso_end_of_day(datetime(end_day.year, end_day.month, end_day.day))
placeholders = ",".join("?" for _ in peers[:200])
if not placeholders:
return None
query = (
f"SELECT sentiment FROM news "
f"WHERE ts_code IN ({placeholders}) "
"AND pub_time BETWEEN ? AND ? "
"AND sentiment IS NOT NULL"
)
params: List[object] = list(peers[:200])
params.extend([start_bound, end_bound])
try:
with db_session(read_only=True) as conn:
rows = conn.execute(query, params).fetchall()
except sqlite3.OperationalError as exc:
LOGGER.debug(
"行业情绪查询失败 industry=%s err=%s",
industry,
exc,
extra=LOG_EXTRA,
)
return None
except Exception as exc: # noqa: BLE001
LOGGER.debug(
"行业情绪读取异常 industry=%s err=%s",
industry,
exc,
extra=LOG_EXTRA,
)
return None
sentiments: List[float] = []
for row in rows:
try:
sentiments.append(float(row["sentiment"]))
except (TypeError, ValueError, KeyError):
continue
if not sentiments:
return None
avg = sum(sentiments) / len(sentiments)
return max(-1.0, min(1.0, avg))
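The tail of this method is a clamped mean over whatever sentiment values parse as floats; extracted as a standalone helper (hypothetical name), the behaviour is easy to sanity-check:
# Sketch: the clamped-mean aggregation used above, pulled out for testing.
from typing import Iterable, Optional

def clamped_mean_sentiment(values: Iterable[object]) -> Optional[float]:
    sentiments = []
    for value in values:
        try:
            sentiments.append(float(value))
        except (TypeError, ValueError):
            continue
    if not sentiments:
        return None
    avg = sum(sentiments) / len(sentiments)
    return max(-1.0, min(1.0, avg))

assert clamped_mean_sentiment([]) is None
assert clamped_mean_sentiment([0.5, None, "1.5"]) == 1.0
assert clamped_mean_sentiment([5, 5]) == 1.0  # clamped into [-1, 1]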
def get_industry_stocks(self, industry: str) -> List[str]:
"""获取同行业股票列表
Args:
industry: 行业代码或名称
Returns:
同行业股票代码列表
"""
# TODO: 使用实际行业成分数据替换占位列表
industry_stocks = {
"银行": ["000001.SZ", "002142.SZ", "600036.SH"],
"房地产": ["000002.SZ", "000402.SZ", "600048.SH"],
"食品饮料": ["000858.SZ", "600519.SH", "000568.SZ"],
"医药生物": ["000962.SZ", "600276.SH", "300003.SZ"],
}
return industry_stocks.get(industry, [])
"""获取同行业股票列表。"""
if not industry:
return []
cache = getattr(self, "_industry_members_cache", None)
if cache is None:
cache = {}
self._industry_members_cache = cache
if industry in cache:
return cache[industry]
query = "SELECT ts_code FROM stock_basic WHERE industry = ? ORDER BY ts_code"
try:
with db_session(read_only=True) as conn:
rows = conn.execute(query, (industry,)).fetchall()
except sqlite3.OperationalError as exc:
LOGGER.debug(
"行业成分查询失败 industry=%s err=%s",
industry,
exc,
extra=LOG_EXTRA,
)
cache[industry] = []
return []
except Exception as exc: # noqa: BLE001
LOGGER.debug(
"行业成分读取异常 industry=%s err=%s",
industry,
exc,
extra=LOG_EXTRA,
)
cache[industry] = []
return []
members = [row["ts_code"] for row in rows if row and row["ts_code"]]
cache[industry] = members
return members
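A short usage sketch of the cached constituents lookup (it assumes a populated stock_basic table; the industry name and date are illustrative):
# Sketch: repeated lookups are served from the per-instance cache.
broker = DataBroker()
members_first = broker.get_industry_stocks("银行")   # hits stock_basic once
members_again = broker.get_industry_stocks("银行")   # served from _industry_members_cache
assert members_first is members_again                # same cached list object
score = broker._derived_industry_sentiment("银行", "20251011", lookback_days=5)
print(len(members_first), score)  # score is None when no scored news is found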
def fetch_flags(
self,
@@ -1070,7 +1169,7 @@ class DataBroker:
def check_data_availability(
self,
trade_date: str,
tables: Set[str] = None,
tables: Optional[Set[str]] = None,
threshold: float = 0.8,
) -> bool:
"""检查指定交易日的数据是否可用如不可用则返回True需要补数
@@ -1083,12 +1182,25 @@ class DataBroker:
Returns:
bool: True means the data is insufficient and needs a backfill
"""
cfg = get_config()
# If force refresh is configured, always report that a backfill is needed
if get_config().force_refresh:
if cfg.force_refresh:
return True
# If auto update is disabled, skip backfilling
if not get_config().auto_update_data:
if not cfg.auto_update_data:
if not getattr(self, "_auto_update_warning_emitted", False):
message = "Auto backfill is disabled; gap detection will be skipped."
LOGGER.warning(message, extra=LOG_EXTRA)
try:
alerts.add_warning(
"data_broker",
"自动补数已关闭",
"当前运行模式不会触发数据补齐,请在设置中开启自动更新或手动补数。",
)
except Exception: # noqa: BLE001
LOGGER.debug("自动补数告警发送失败", extra=LOG_EXTRA)
self._auto_update_warning_emitted = True
return False
# Tables checked by default
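The new _auto_update_warning_emitted flag makes the disabled-auto-update warning fire only once per broker instance. The pattern in isolation looks like the sketch below; names mirror the diff, and the class is a hypothetical test double rather than code from this commit:
# Sketch: the warn-once guard around the auto-update alert above.
class _WarnOnceGuard:
    def __init__(self) -> None:
        self._auto_update_warning_emitted = False

    def needs_backfill_check(self, auto_update_data: bool) -> bool:
        if not auto_update_data:
            if not self._auto_update_warning_emitted:
                print("Auto backfill is disabled; gap detection will be skipped.")
                self._auto_update_warning_emitted = True
            return False  # skip backfill while auto update is off
        return True       # proceed to the coverage checks

guard = _WarnOnceGuard()
guard.needs_backfill_check(False)  # warns once
guard.needs_backfill_check(False)  # silent on subsequent calls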