refactor news data fetching and industry sentiment calculation
parent 90fb2a9df6
commit 5b2033f52b
@@ -11,14 +11,13 @@ from dataclasses import dataclass, field
from datetime import date, datetime, timedelta
from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Sequence, Set, Tuple

import numpy as np

from .config import get_config
import types
from .db import db_session
from .logging import get_logger
from app.core.indicators import momentum, normalize, rolling_mean, volatility
from app.utils.db_query import BrokerQueryEngine
from app.utils import alerts

# Lazy import to avoid circular dependencies
collect_data_coverage = None
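
The module initializes collect_data_coverage to None and, per the comment, resolves it lazily to avoid a circular import. A minimal sketch of that pattern, assuming a hypothetical app.core.coverage module (the real import path is not shown in this diff):

def _resolve_coverage_collector():
    """Import the coverage collector on first use to dodge the circular import."""
    global collect_data_coverage
    if collect_data_coverage is None:
        # Hypothetical import path; the actual module is not visible in this diff.
        from app.core.coverage import collect_data_coverage as _collect
        collect_data_coverage = _collect
    return collect_data_coverage
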
@@ -143,6 +142,14 @@ def _end_of_day(dt: datetime) -> str:
    return dt.strftime("%Y-%m-%d 23:59:59")


def _iso_start_of_day(dt: datetime) -> str:
    return dt.strftime("%Y-%m-%dT00:00:00+00:00")


def _iso_end_of_day(dt: datetime) -> str:
    return dt.strftime("%Y-%m-%dT23:59:59+00:00")


def _coerce_date(value: object) -> Optional[date]:
    if value is None:
        return None
@@ -210,6 +217,7 @@ class DataBroker:
        self._coverage_cache = {}
        self._refresh = _RefreshCoordinator(self)
        self._query_engine = BrokerQueryEngine(db_session)
        self._auto_update_warning_emitted = False
        if initialize_database is not None:
            initialize_database()  # make sure the database is initialized
        else:
@@ -566,76 +574,167 @@ class DataBroker:
        self,
        ts_code: str,
        trade_date: str,
        limit: int = 30
        limit: int = 30,
        lookback_days: int = 3,
    ) -> List[Dict[str, Any]]:
        """Fetch news data (simplified implementation).
        """Fetch a slice of news data.

        Args:
            ts_code: Stock code
            trade_date: Trade date
            trade_date: Trade date (YYYYMMDD/ISO)
            limit: Maximum number of news items to return
            lookback_days: Number of days to look back when pulling recent news

        Returns:
            List of news records with fields such as sentiment, heat and entities
        """
        # TODO: replace the randomly generated placeholder data with a real news database
        return [
            {
                "sentiment": np.random.uniform(-1, 1),
                "heat": np.random.uniform(0, 1),
                "entities": "股票,市场,投资"
            }
            for _ in range(min(limit, 5))
        ]
        if not ts_code or limit <= 0:
            return []
        parsed_date = _parse_trade_date(trade_date)
        if not parsed_date:
            LOGGER.debug(
                "新闻数据查询失败,无法解析日期 ts_code=%s trade_date=%s",
                ts_code,
                trade_date,
                extra=LOG_EXTRA,
            )
            return []
        window_days = max(1, lookback_days)
        end_day = parsed_date.date()
        start_day = end_day - timedelta(days=max(window_days - 1, 0))
        start_bound = _iso_start_of_day(datetime(start_day.year, start_day.month, start_day.day))
        end_bound = _iso_end_of_day(datetime(end_day.year, end_day.month, end_day.day))
        query = (
            "SELECT sentiment, heat, sentiment_index, heat_score, entities, "
            "title, summary, source, url, pub_time "
            "FROM news "
            "WHERE ts_code = ? AND pub_time BETWEEN ? AND ? "
            "ORDER BY pub_time DESC LIMIT ?"
        )
        try:
            with db_session(read_only=True) as conn:
                rows = conn.execute(
                    query,
                    (ts_code, start_bound, end_bound, int(limit)),
                ).fetchall()
        except sqlite3.OperationalError as exc:
            LOGGER.debug(
                "新闻数据查询失败 ts_code=%s err=%s",
                ts_code,
                exc,
                extra=LOG_EXTRA,
            )
            return []
        except Exception as exc:  # noqa: BLE001
            LOGGER.debug(
                "新闻数据读取异常 ts_code=%s err=%s",
                ts_code,
                exc,
                extra=LOG_EXTRA,
            )
            return []
        return [dict(row) for row in rows]
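
The reworked get_news_data now reads from a local news table through db_session instead of fabricating random rows. A minimal sketch of the table layout the SELECT assumes: the column names come from the query above, while the types, the index, and the database path are guesses rather than facts from this diff. Since the method returns dict(row), the sketch also sets sqlite3.Row as the row factory.

import sqlite3

conn = sqlite3.connect("data/app.db")  # hypothetical path; db_session hides the real one
conn.row_factory = sqlite3.Row         # needed so dict(row) works as in get_news_data
conn.execute(
    """
    CREATE TABLE IF NOT EXISTS news (
        ts_code         TEXT NOT NULL,
        pub_time        TEXT NOT NULL,  -- ISO-8601 string, compared against the day bounds
        sentiment       REAL,
        heat            REAL,
        sentiment_index REAL,
        heat_score      REAL,
        entities        TEXT,
        title           TEXT,
        summary         TEXT,
        source          TEXT,
        url             TEXT
    )
    """
)
conn.execute("CREATE INDEX IF NOT EXISTS idx_news_code_time ON news (ts_code, pub_time)")
conn.commit()
# Typical call: latest 20 items across the 3 days ending on the trade date.
# rows = broker.get_news_data("000001.SZ", "20240510", limit=20, lookback_days=3)
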

    def _lookup_industry(self, ts_code: str) -> Optional[str]:
        """Look up the industry a stock belongs to.

        Args:
            ts_code: Stock code

        Returns:
            Industry code or name, or None if not found
        """
        # TODO: replace with real industry mapping logic (placeholder data only)
        industry_mapping = {
            "000001.SZ": "银行",
            "000002.SZ": "房地产",
            "000858.SZ": "食品饮料",
            "000962.SZ": "医药生物",
        }
        return industry_mapping.get(ts_code, "其他")

    def _derived_industry_sentiment(self, industry: str, trade_date: str) -> Optional[float]:
        """Compute an industry sentiment score.

        Args:
            industry: Industry code or name
            trade_date: Trade date

        Returns:
            Industry sentiment score, or None if unavailable
        """
        # TODO: hook up a real industry sentiment data source; the random value is only a display placeholder
        return np.random.uniform(-1, 1)

    def _derived_industry_sentiment(
        self,
        industry: str,
        trade_date: str,
        *,
        lookback_days: int = 5,
    ) -> Optional[float]:
        """Derive an industry-level sentiment indicator from recent news sentiment."""
        if not industry:
            return None
        parsed_date = _parse_trade_date(trade_date)
        if not parsed_date:
            return None
        stocks = self.get_industry_stocks(industry)
        if not stocks:
            return None
        peers: List[str] = list(dict.fromkeys(stocks))
        if not peers:
            return None
        window_days = max(1, lookback_days)
        end_day = parsed_date.date()
        start_day = end_day - timedelta(days=max(window_days - 1, 0))
        start_bound = _iso_start_of_day(datetime(start_day.year, start_day.month, start_day.day))
        end_bound = _iso_end_of_day(datetime(end_day.year, end_day.month, end_day.day))
        placeholders = ",".join("?" for _ in peers[:200])
        if not placeholders:
            return None
        query = (
            f"SELECT sentiment FROM news "
            f"WHERE ts_code IN ({placeholders}) "
            "AND pub_time BETWEEN ? AND ? "
            "AND sentiment IS NOT NULL"
        )
        params: List[object] = list(peers[:200])
        params.extend([start_bound, end_bound])
        try:
            with db_session(read_only=True) as conn:
                rows = conn.execute(query, params).fetchall()
        except sqlite3.OperationalError as exc:
            LOGGER.debug(
                "行业情绪查询失败 industry=%s err=%s",
                industry,
                exc,
                extra=LOG_EXTRA,
            )
            return None
        except Exception as exc:  # noqa: BLE001
            LOGGER.debug(
                "行业情绪读取异常 industry=%s err=%s",
                industry,
                exc,
                extra=LOG_EXTRA,
            )
            return None
        sentiments: List[float] = []
        for row in rows:
            try:
                sentiments.append(float(row["sentiment"]))
            except (TypeError, ValueError, KeyError):
                continue
        if not sentiments:
            return None
        avg = sum(sentiments) / len(sentiments)
        return max(-1.0, min(1.0, avg))
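
The aggregation at the end is an equal-weighted mean over every matching article, clipped to [-1, 1]; each news row counts once no matter which peer stock it belongs to. A tiny standalone sketch of the same reduction with made-up scores:

sentiments = [0.8, -0.2, 0.5, 0.9]             # per-article scores read from the news table
avg = sum(sentiments) / len(sentiments)        # 0.5
industry_sentiment = max(-1.0, min(1.0, avg))  # clamp keeps the result inside [-1, 1]
print(industry_sentiment)                      # 0.5
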

    def get_industry_stocks(self, industry: str) -> List[str]:
        """Get the list of stocks in the same industry.

        Args:
            industry: Industry code or name

        Returns:
            List of stock codes in the same industry
        """
        # TODO: replace the placeholder list with real industry constituent data
        industry_stocks = {
            "银行": ["000001.SZ", "002142.SZ", "600036.SH"],
            "房地产": ["000002.SZ", "000402.SZ", "600048.SH"],
            "食品饮料": ["000858.SZ", "600519.SH", "000568.SZ"],
            "医药生物": ["000962.SZ", "600276.SH", "300003.SZ"],
        }
        return industry_stocks.get(industry, [])
        """Get the list of stocks in the same industry."""
        if not industry:
            return []
        cache = getattr(self, "_industry_members_cache", None)
        if cache is None:
            cache = {}
            self._industry_members_cache = cache
        if industry in cache:
            return cache[industry]
        query = "SELECT ts_code FROM stock_basic WHERE industry = ? ORDER BY ts_code"
        try:
            with db_session(read_only=True) as conn:
                rows = conn.execute(query, (industry,)).fetchall()
        except sqlite3.OperationalError as exc:
            LOGGER.debug(
                "行业成分查询失败 industry=%s err=%s",
                industry,
                exc,
                extra=LOG_EXTRA,
            )
            cache[industry] = []
            return []
        except Exception as exc:  # noqa: BLE001
            LOGGER.debug(
                "行业成分读取异常 industry=%s err=%s",
                industry,
                exc,
                extra=LOG_EXTRA,
            )
            cache[industry] = []
            return []
        members = [row["ts_code"] for row in rows if row and row["ts_code"]]
        cache[industry] = members
        return members
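
get_industry_stocks now memoizes per DataBroker instance: the first lookup for an industry runs the stock_basic query, and later calls (including lookups that failed and were cached as an empty list) are served from _industry_members_cache without another query. A rough usage sketch, assuming a default-constructible DataBroker and that "银行" exists in stock_basic:

broker = DataBroker()
first = broker.get_industry_stocks("银行")   # executes the SELECT against stock_basic
second = broker.get_industry_stocks("银行")  # returned from _industry_members_cache
assert first is second                        # the cached list object itself is reused
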

    def fetch_flags(
        self,
@@ -1070,7 +1169,7 @@ class DataBroker:
    def check_data_availability(
        self,
        trade_date: str,
        tables: Set[str] = None,
        tables: Optional[Set[str]] = None,
        threshold: float = 0.8,
    ) -> bool:
        """Check whether data for the given trade date is available; return True if a backfill is needed.
@@ -1083,12 +1182,25 @@ class DataBroker:
        Returns:
            bool: True means the data is insufficient and a backfill is needed
        """
        cfg = get_config()
        # If force refresh is configured, always report that a backfill is needed
        if get_config().force_refresh:
        if cfg.force_refresh:
            return True

        # If auto update is disabled, skip the backfill
        if not get_config().auto_update_data:
        if not cfg.auto_update_data:
            if not getattr(self, "_auto_update_warning_emitted", False):
                message = "自动补数已关闭,系统将跳过缺口检测。"
                LOGGER.warning(message, extra=LOG_EXTRA)
                try:
                    alerts.add_warning(
                        "data_broker",
                        "自动补数已关闭",
                        "当前运行模式不会触发数据补齐,请在设置中开启自动更新或手动补数。",
                    )
                except Exception:  # noqa: BLE001
                    LOGGER.debug("自动补数告警发送失败", extra=LOG_EXTRA)
                self._auto_update_warning_emitted = True
            return False

        # Tables checked by default
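
check_data_availability now reads the config once into cfg and, when auto update is disabled, emits a one-time warning plus an alerts entry before reporting that no backfill is needed. A rough caller-side sketch; the import path and the table names are assumptions, not taken from this diff:

from app.core.data_broker import DataBroker  # hypothetical module path

broker = DataBroker()
needs_backfill = broker.check_data_availability(
    "20240510",
    tables={"daily", "news"},  # illustrative table set; omit to use the defaults
    threshold=0.8,
)
if needs_backfill:
    ...  # hand off to the project's backfill / refresh pipeline here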