refactor news data fetching and industry sentiment calculation
This commit is contained in:
parent
90fb2a9df6
commit
5b2033f52b
@ -11,14 +11,13 @@ from dataclasses import dataclass, field
|
|||||||
from datetime import date, datetime, timedelta
|
from datetime import date, datetime, timedelta
|
||||||
from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Sequence, Set, Tuple
|
from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Sequence, Set, Tuple
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from .config import get_config
|
from .config import get_config
|
||||||
import types
|
import types
|
||||||
from .db import db_session
|
from .db import db_session
|
||||||
from .logging import get_logger
|
from .logging import get_logger
|
||||||
from app.core.indicators import momentum, normalize, rolling_mean, volatility
|
from app.core.indicators import momentum, normalize, rolling_mean, volatility
|
||||||
from app.utils.db_query import BrokerQueryEngine
|
from app.utils.db_query import BrokerQueryEngine
|
||||||
|
from app.utils import alerts
|
||||||
|
|
||||||
# 延迟导入,避免循环依赖
|
# 延迟导入,避免循环依赖
|
||||||
collect_data_coverage = None
|
collect_data_coverage = None
|
||||||
@ -143,6 +142,14 @@ def _end_of_day(dt: datetime) -> str:
|
|||||||
return dt.strftime("%Y-%m-%d 23:59:59")
|
return dt.strftime("%Y-%m-%d 23:59:59")
|
||||||
|
|
||||||
|
|
||||||
|
def _iso_start_of_day(dt: datetime) -> str:
|
||||||
|
return dt.strftime("%Y-%m-%dT00:00:00+00:00")
|
||||||
|
|
||||||
|
|
||||||
|
def _iso_end_of_day(dt: datetime) -> str:
|
||||||
|
return dt.strftime("%Y-%m-%dT23:59:59+00:00")
|
||||||
|
|
||||||
|
|
||||||
def _coerce_date(value: object) -> Optional[date]:
|
def _coerce_date(value: object) -> Optional[date]:
|
||||||
if value is None:
|
if value is None:
|
||||||
return None
|
return None
|
||||||
@ -210,6 +217,7 @@ class DataBroker:
|
|||||||
self._coverage_cache = {}
|
self._coverage_cache = {}
|
||||||
self._refresh = _RefreshCoordinator(self)
|
self._refresh = _RefreshCoordinator(self)
|
||||||
self._query_engine = BrokerQueryEngine(db_session)
|
self._query_engine = BrokerQueryEngine(db_session)
|
||||||
|
self._auto_update_warning_emitted = False
|
||||||
if initialize_database is not None:
|
if initialize_database is not None:
|
||||||
initialize_database() # 确保数据库已初始化
|
initialize_database() # 确保数据库已初始化
|
||||||
else:
|
else:
|
||||||
@ -566,76 +574,167 @@ class DataBroker:
|
|||||||
self,
|
self,
|
||||||
ts_code: str,
|
ts_code: str,
|
||||||
trade_date: str,
|
trade_date: str,
|
||||||
limit: int = 30
|
limit: int = 30,
|
||||||
|
lookback_days: int = 3,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""获取新闻数据(简化实现)
|
"""获取新闻数据切片。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ts_code: 股票代码
|
ts_code: 股票代码
|
||||||
trade_date: 交易日期
|
trade_date: 交易日期(YYYYMMDD/ISO)
|
||||||
limit: 返回的新闻条数限制
|
limit: 返回的新闻条数限制
|
||||||
|
lookback_days: 回溯天数,用于拉取近几日新闻
|
||||||
Returns:
|
|
||||||
新闻数据列表,包含sentiment、heat、entities等字段
|
|
||||||
"""
|
|
||||||
# TODO: 使用真实新闻数据库替换随机生成的占位数据
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"sentiment": np.random.uniform(-1, 1),
|
|
||||||
"heat": np.random.uniform(0, 1),
|
|
||||||
"entities": "股票,市场,投资"
|
|
||||||
}
|
|
||||||
for _ in range(min(limit, 5))
|
|
||||||
]
|
|
||||||
|
|
||||||
def _lookup_industry(self, ts_code: str) -> Optional[str]:
|
|
||||||
"""查找股票所属行业
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ts_code: 股票代码
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
行业代码或名称,找不到时返回None
|
新闻数据列表,包含 sentiment、heat、entities 等字段
|
||||||
"""
|
"""
|
||||||
# TODO: 替换为真实行业映射逻辑(当前仅为占位数据)
|
if not ts_code or limit <= 0:
|
||||||
industry_mapping = {
|
return []
|
||||||
"000001.SZ": "银行",
|
parsed_date = _parse_trade_date(trade_date)
|
||||||
"000002.SZ": "房地产",
|
if not parsed_date:
|
||||||
"000858.SZ": "食品饮料",
|
LOGGER.debug(
|
||||||
"000962.SZ": "医药生物",
|
"新闻数据查询失败,无法解析日期 ts_code=%s trade_date=%s",
|
||||||
}
|
ts_code,
|
||||||
return industry_mapping.get(ts_code, "其他")
|
trade_date,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
window_days = max(1, lookback_days)
|
||||||
|
end_day = parsed_date.date()
|
||||||
|
start_day = end_day - timedelta(days=max(window_days - 1, 0))
|
||||||
|
start_bound = _iso_start_of_day(datetime(start_day.year, start_day.month, start_day.day))
|
||||||
|
end_bound = _iso_end_of_day(datetime(end_day.year, end_day.month, end_day.day))
|
||||||
|
query = (
|
||||||
|
"SELECT sentiment, heat, sentiment_index, heat_score, entities, "
|
||||||
|
"title, summary, source, url, pub_time "
|
||||||
|
"FROM news "
|
||||||
|
"WHERE ts_code = ? AND pub_time BETWEEN ? AND ? "
|
||||||
|
"ORDER BY pub_time DESC LIMIT ?"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with db_session(read_only=True) as conn:
|
||||||
|
rows = conn.execute(
|
||||||
|
query,
|
||||||
|
(ts_code, start_bound, end_bound, int(limit)),
|
||||||
|
).fetchall()
|
||||||
|
except sqlite3.OperationalError as exc:
|
||||||
|
LOGGER.debug(
|
||||||
|
"新闻数据查询失败 ts_code=%s err=%s",
|
||||||
|
ts_code,
|
||||||
|
exc,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
LOGGER.debug(
|
||||||
|
"新闻数据读取异常 ts_code=%s err=%s",
|
||||||
|
ts_code,
|
||||||
|
exc,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
return [dict(row) for row in rows]
|
||||||
|
|
||||||
def _derived_industry_sentiment(self, industry: str, trade_date: str) -> Optional[float]:
|
def _derived_industry_sentiment(
|
||||||
"""计算行业情绪得分
|
self,
|
||||||
|
industry: str,
|
||||||
Args:
|
trade_date: str,
|
||||||
industry: 行业代码或名称
|
*,
|
||||||
trade_date: 交易日期
|
lookback_days: int = 5,
|
||||||
|
) -> Optional[float]:
|
||||||
Returns:
|
"""根据近几日新闻情绪推导行业层面的情绪指标。"""
|
||||||
行业情绪得分,找不到时返回None
|
if not industry:
|
||||||
"""
|
return None
|
||||||
# TODO: 接入行业情绪数据源,当前随机值仅用于占位显示
|
parsed_date = _parse_trade_date(trade_date)
|
||||||
return np.random.uniform(-1, 1)
|
if not parsed_date:
|
||||||
|
return None
|
||||||
|
stocks = self.get_industry_stocks(industry)
|
||||||
|
if not stocks:
|
||||||
|
return None
|
||||||
|
peers: List[str] = list(dict.fromkeys(stocks))
|
||||||
|
if not peers:
|
||||||
|
return None
|
||||||
|
window_days = max(1, lookback_days)
|
||||||
|
end_day = parsed_date.date()
|
||||||
|
start_day = end_day - timedelta(days=max(window_days - 1, 0))
|
||||||
|
start_bound = _iso_start_of_day(datetime(start_day.year, start_day.month, start_day.day))
|
||||||
|
end_bound = _iso_end_of_day(datetime(end_day.year, end_day.month, end_day.day))
|
||||||
|
placeholders = ",".join("?" for _ in peers[:200])
|
||||||
|
if not placeholders:
|
||||||
|
return None
|
||||||
|
query = (
|
||||||
|
f"SELECT sentiment FROM news "
|
||||||
|
f"WHERE ts_code IN ({placeholders}) "
|
||||||
|
"AND pub_time BETWEEN ? AND ? "
|
||||||
|
"AND sentiment IS NOT NULL"
|
||||||
|
)
|
||||||
|
params: List[object] = list(peers[:200])
|
||||||
|
params.extend([start_bound, end_bound])
|
||||||
|
try:
|
||||||
|
with db_session(read_only=True) as conn:
|
||||||
|
rows = conn.execute(query, params).fetchall()
|
||||||
|
except sqlite3.OperationalError as exc:
|
||||||
|
LOGGER.debug(
|
||||||
|
"行业情绪查询失败 industry=%s err=%s",
|
||||||
|
industry,
|
||||||
|
exc,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
LOGGER.debug(
|
||||||
|
"行业情绪读取异常 industry=%s err=%s",
|
||||||
|
industry,
|
||||||
|
exc,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
sentiments: List[float] = []
|
||||||
|
for row in rows:
|
||||||
|
try:
|
||||||
|
sentiments.append(float(row["sentiment"]))
|
||||||
|
except (TypeError, ValueError, KeyError):
|
||||||
|
continue
|
||||||
|
if not sentiments:
|
||||||
|
return None
|
||||||
|
avg = sum(sentiments) / len(sentiments)
|
||||||
|
return max(-1.0, min(1.0, avg))
|
||||||
|
|
||||||
def get_industry_stocks(self, industry: str) -> List[str]:
    """Return the stock codes that belong to ``industry``.

    Members are read from ``stock_basic`` (rows whose ``industry`` column
    equals ``industry``, ordered by ``ts_code``) and memoised per instance in
    ``self._industry_members_cache``.

    Args:
        industry: Industry value to match against ``stock_basic.industry``.

    Returns:
        A new list of ``ts_code`` strings. Empty when ``industry`` is falsy,
        when no row matches, or when the query fails.
    """
    if not industry:
        return []

    # The cache is created lazily so the method also works on instances
    # that never had the attribute assigned.
    cache = getattr(self, "_industry_members_cache", None)
    if cache is None:
        cache = {}
        self._industry_members_cache = cache
    if industry in cache:
        # Hand out a copy so callers cannot mutate the cached list.
        return list(cache[industry])

    query = "SELECT ts_code FROM stock_basic WHERE industry = ? ORDER BY ts_code"
    try:
        with db_session(read_only=True) as conn:
            rows = conn.execute(query, (industry,)).fetchall()
    except sqlite3.OperationalError as exc:
        LOGGER.debug(
            "行业成分查询失败 industry=%s err=%s",
            industry,
            exc,
            extra=LOG_EXTRA,
        )
        # Failures are not cached: a transient error must not pin an empty
        # member list for the lifetime of this instance.
        return []
    except Exception as exc:  # noqa: BLE001
        LOGGER.debug(
            "行业成分读取异常 industry=%s err=%s",
            industry,
            exc,
            extra=LOG_EXTRA,
        )
        return []

    members = [row["ts_code"] for row in rows if row and row["ts_code"]]
    cache[industry] = members
    return list(members)
|
||||||
|
|
||||||
def fetch_flags(
|
def fetch_flags(
|
||||||
self,
|
self,
|
||||||
@ -1070,7 +1169,7 @@ class DataBroker:
|
|||||||
def check_data_availability(
|
def check_data_availability(
|
||||||
self,
|
self,
|
||||||
trade_date: str,
|
trade_date: str,
|
||||||
tables: Set[str] = None,
|
tables: Optional[Set[str]] = None,
|
||||||
threshold: float = 0.8,
|
threshold: float = 0.8,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""检查指定交易日的数据是否可用,如不可用则返回True(需要补数)。
|
"""检查指定交易日的数据是否可用,如不可用则返回True(需要补数)。
|
||||||
@ -1083,12 +1182,25 @@ class DataBroker:
|
|||||||
Returns:
|
Returns:
|
||||||
bool: True表示数据不足,需要补数
|
bool: True表示数据不足,需要补数
|
||||||
"""
|
"""
|
||||||
|
cfg = get_config()
|
||||||
# 如果配置了强制刷新,则始终返回需要补数
|
# 如果配置了强制刷新,则始终返回需要补数
|
||||||
if get_config().force_refresh:
|
if cfg.force_refresh:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# 如果未启用自动更新,则不进行补数
|
# 如果未启用自动更新,则不进行补数
|
||||||
if not get_config().auto_update_data:
|
if not cfg.auto_update_data:
|
||||||
|
if not getattr(self, "_auto_update_warning_emitted", False):
|
||||||
|
message = "自动补数已关闭,系统将跳过缺口检测。"
|
||||||
|
LOGGER.warning(message, extra=LOG_EXTRA)
|
||||||
|
try:
|
||||||
|
alerts.add_warning(
|
||||||
|
"data_broker",
|
||||||
|
"自动补数已关闭",
|
||||||
|
"当前运行模式不会触发数据补齐,请在设置中开启自动更新或手动补数。",
|
||||||
|
)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
LOGGER.debug("自动补数告警发送失败", extra=LOG_EXTRA)
|
||||||
|
self._auto_update_warning_emitted = True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 默认检查的表
|
# 默认检查的表
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user