llm-quant/app/utils/data_access.py


"""Utility helpers to retrieve structured data slices for agents and departments."""
from __future__ import annotations
import re
import sqlite3
import threading
import time
from collections import OrderedDict
from copy import deepcopy
from dataclasses import dataclass, field
from datetime import date, datetime, timedelta
from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Sequence, Set, Tuple
import types
from .config import get_config
from .db import db_session
from .logging import get_logger
from app.core.indicators import momentum, normalize, rolling_mean, volatility
from app.utils.db_query import BrokerQueryEngine
from app.utils import alerts
from app.ingest.coverage import collect_data_coverage as _collect_coverage, ensure_data_coverage as _ensure_coverage
try:
from app.data.schema import initialize_database
except ImportError:
def initialize_database():
"""Fallback stub used when the real initializer cannot be imported.
Return a lightweight object with the attributes callers expect
(executed, skipped, missing_tables) so code that calls
`initialize_database()` can safely inspect the result.
"""
return types.SimpleNamespace(executed=0, skipped=True, missing_tables=[])
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "data_broker"}
_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
def _is_safe_identifier(name: str) -> bool:
return bool(_IDENTIFIER_RE.match(name))
def _default_coverage_runner(start: date, end: date) -> None:
if _ensure_coverage is None:
LOGGER.debug("默认补数函数不可用,跳过自动补数", extra=LOG_EXTRA)
return
_ensure_coverage(
start,
end,
include_limits=False,
include_extended=False,
force=False,
progress_hook=None,
)
def _default_coverage_collector(start: date, end: date) -> Dict[str, Dict[str, object]]:
if _collect_coverage is None:
LOGGER.debug("默认覆盖统计函数不可用,返回空结果", extra=LOG_EXTRA)
return {}
return _collect_coverage(start, end)
def _safe_split(path: str) -> Tuple[str, str] | None:
if "." not in path:
return None
table, column = path.split(".", 1)
table = table.strip()
column = column.strip()
if not table or not column:
return None
if not (_is_safe_identifier(table) and _is_safe_identifier(column)):
LOGGER.debug("忽略非法字段:%s", path, extra=LOG_EXTRA)
return None
return table, column
@dataclass
class _RefreshCoordinator:
"""Orchestrates background refresh requests for the broker."""
broker: "DataBroker"
def ensure_for_latest(self, trade_date: str, fields: Iterable[str]) -> None:
parsed_date = _parse_trade_date(trade_date)
if not parsed_date:
return
normalized = parsed_date.strftime("%Y%m%d")
tables = self._collect_tables(fields)
if tables and self.broker.check_data_availability(normalized, tables):
LOGGER.debug(
"触发近端数据刷新 trade_date=%s tables=%s",
normalized,
sorted(tables),
extra=LOG_EXTRA,
)
self.broker._trigger_background_refresh(normalized)
def ensure_for_series(self, end_date: str, table: str) -> None:
parsed_date = _parse_trade_date(end_date)
if not parsed_date:
return
normalized = parsed_date.strftime("%Y%m%d")
if self.broker.check_data_availability(normalized, {table}):
LOGGER.debug(
"触发序列刷新 trade_date=%s table=%s",
normalized,
table,
extra=LOG_EXTRA,
)
self.broker._trigger_background_refresh(normalized)
def _collect_tables(self, fields: Iterable[str]) -> Set[str]:
tables: Set[str] = set()
for field_name in fields:
resolved = self.broker.resolve_field(field_name)
if resolved:
table, _ = resolved
tables.add(table)
return tables
def parse_field_path(path: str) -> Tuple[str, str] | None:
"""Validate and split a `table.column` field expression."""
return _safe_split(path)
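# A minimal usage sketch (illustrative only, not part of the module API):
# ``parse_field_path`` validates a ``table.column`` expression and returns the
# split pieces, or None when either identifier is unsafe.
#
#     parse_field_path("daily.close")        # -> ("daily", "close")
#     parse_field_path("daily.close; DROP")  # -> None (column is not a safe identifier)
#     parse_field_path("close")              # -> None (no table qualifier)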
def _parse_trade_date(value: object) -> Optional[datetime]:
if value is None:
return None
text = str(value).strip()
if not text:
return None
text = text.replace("-", "")
try:
return datetime.strptime(text[:8], "%Y%m%d")
except ValueError:
return None
def _start_of_day(dt: datetime) -> str:
return dt.strftime("%Y-%m-%d 00:00:00")
def _end_of_day(dt: datetime) -> str:
return dt.strftime("%Y-%m-%d 23:59:59")
def _iso_start_of_day(dt: datetime) -> str:
return dt.strftime("%Y-%m-%dT00:00:00+00:00")
def _iso_end_of_day(dt: datetime) -> str:
return dt.strftime("%Y-%m-%dT23:59:59+00:00")
def _coerce_date(value: object) -> Optional[date]:
if value is None:
return None
if isinstance(value, date):
return value
parsed = _parse_trade_date(value)
if parsed:
return parsed.date()
return None
@dataclass
class DataBroker:
"""Lightweight data access helper with automated data fetching capabilities."""
FIELD_ALIASES: ClassVar[Dict[str, Dict[str, str]]] = {
"daily": {
"volume": "vol",
"vol": "vol",
"turnover": "amount",
},
"daily_basic": {
"turnover": "turnover_rate",
"turnover_rate": "turnover_rate",
"turnover_rate_f": "turnover_rate_f",
"volume_ratio": "volume_ratio",
"pe": "pe",
"pb": "pb",
"ps": "ps",
"ps_ttm": "ps_ttm",
"dividend_yield": "dv_ratio",
},
"stk_limit": {
"up": "up_limit",
"down": "down_limit",
},
}
MAX_WINDOW: ClassVar[int] = 120
BENCHMARK_INDEX: ClassVar[str] = "000300.SH"
# Auto-refresh (backfill) configuration
AUTO_REFRESH_WINDOW: ClassVar[int] = 7  # number of days covered by an automatic backfill
REFRESH_RETRY_INTERVAL: ClassVar[int] = 5  # polling interval while waiting for a backfill (seconds)
MAX_REFRESH_WAIT: ClassVar[int] = 60  # maximum time to wait for a backfill to finish (seconds)
enable_cache: bool = True
latest_cache_size: int = 256
series_cache_size: int = 512
coverage_runner: Callable[[date, date], None] = field(default=_default_coverage_runner)
coverage_collector: Callable[[date, date], Dict[str, Dict[str, object]]] = field(default=_default_coverage_collector)
_latest_cache: OrderedDict = field(init=False, repr=False)
_series_cache: OrderedDict = field(init=False, repr=False)
# 补数相关状态管理
_refresh_lock: threading.RLock = field(init=False, repr=False)
_refresh_in_progress: Dict[str, bool] = field(init=False, repr=False)
_refresh_callbacks: Dict[str, List[Callable]] = field(init=False, repr=False)
_coverage_cache: Dict[str, Dict] = field(init=False, repr=False)
_refresh: _RefreshCoordinator = field(init=False, repr=False)
_query_engine: BrokerQueryEngine = field(init=False, repr=False)
def __post_init__(self) -> None:
self._latest_cache = OrderedDict()
self._series_cache = OrderedDict()
# 初始化补数相关状态
self._refresh_lock = threading.RLock()
self._refresh_in_progress = {}
self._refresh_callbacks = {}
self._coverage_cache = {}
self._refresh = _RefreshCoordinator(self)
self._query_engine = BrokerQueryEngine(db_session)
self._auto_update_warning_emitted = False
if initialize_database is not None:
initialize_database() # 确保数据库已初始化
else:
LOGGER.warning("initialize_database 函数不可用,数据库可能未初始化", extra=LOG_EXTRA)
def fetch_latest(
self,
ts_code: str,
trade_date: str,
fields: Iterable[str],
auto_refresh: bool = True,
) -> Dict[str, Any]:
"""Fetch the latest value (<= trade_date) for each requested field.
Args:
ts_code: security code
trade_date: trade date (YYYYMMDD or ISO)
fields: ``table.column`` field expressions to query
auto_refresh: trigger a background backfill when data looks insufficient
"""
field_list = [str(item) for item in fields if item]
cache_key: Optional[Tuple[Any, ...]] = None
if self.enable_cache and field_list:
cache_key = (ts_code, trade_date, tuple(sorted(field_list)))
cached = self._cache_lookup(self._latest_cache, cache_key)
if cached is not None:
return deepcopy(cached)
# 检查是否需要自动补数
if auto_refresh:
self._refresh.ensure_for_latest(trade_date, field_list)
grouped: Dict[str, List[str]] = {}
field_map: Dict[Tuple[str, str], List[str]] = {}
derived_cache: Dict[str, Any] = {}
results: Dict[str, Any] = {}
for field_name in field_list:
resolved = self.resolve_field(field_name)
if not resolved:
derived = self._resolve_derived_field(
ts_code,
trade_date,
field_name,
derived_cache,
)
if derived is not None:
results[field_name] = derived
continue
table, column = resolved
grouped.setdefault(table, [])
if column not in grouped[table]:
grouped[table].append(column)
field_map.setdefault((table, column), []).append(field_name)
if grouped:
for table, columns in grouped.items():
try:
row = self._query_engine.fetch_latest(table, ts_code, trade_date, columns)
except Exception as exc: # noqa: BLE001
LOGGER.debug(
"查询失败 table=%s fields=%s err=%s",
table,
columns,
exc,
extra=LOG_EXTRA,
)
continue
if not row:
continue
for column in columns:
value = row[column]
if value is None:
continue
for original in field_map.get((table, column), [f"{table}.{column}"]):
try:
results[original] = float(value)
except (TypeError, ValueError):
results[original] = value
if cache_key is not None and not results:
cached = self._cache_lookup(self._latest_cache, cache_key)
if cached is not None:
LOGGER.debug(
"使用缓存结果 ts_code=%s trade_date=%s",
ts_code,
trade_date,
extra=LOG_EXTRA,
)
return deepcopy(cached)
if cache_key is not None and results:
self._cache_store(
self._latest_cache,
cache_key,
deepcopy(results),
self.latest_cache_size,
)
return results
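# Hedged usage sketch for ``fetch_latest`` (the code and date below are purely
# illustrative; field availability depends on the local schema):
#
#     broker = DataBroker()
#     snapshot = broker.fetch_latest(
#         "000001.SZ",
#         "20240105",
#         ["daily.close", "daily_basic.pe", "factors.mom_20"],
#     )
#     # -> e.g. {"daily.close": 10.5, "daily_basic.pe": 8.2, "factors.mom_20": 0.03}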
def fetch_series(
self,
table: str,
column: str,
ts_code: str,
end_date: str,
window: int,
auto_refresh: bool = True,
) -> List[Tuple[str, float]]:
"""Return descending time series tuples within the specified window.
Args:
table: table name
column: column name
ts_code: security code
end_date: end date of the window (inclusive)
window: number of points to return
auto_refresh: trigger a background backfill when data looks insufficient
"""
if window <= 0:
return []
window = min(window, self.MAX_WINDOW)
resolved_field = self.resolve_field(f"{table}.{column}")
if not resolved_field:
LOGGER.debug(
"时间序列字段不存在 table=%s column=%s",
table,
column,
extra=LOG_EXTRA,
)
return []
table, resolved = resolved_field
# 检查是否需要自动补数
if auto_refresh:
self._refresh.ensure_for_series(end_date, table)
cache_key: Optional[Tuple[Any, ...]] = None
if self.enable_cache:
cache_key = (table, resolved, ts_code, end_date, window)
cached = self._cache_lookup(self._series_cache, cache_key)
if cached is not None:
return [tuple(item) for item in cached]
try:
rows = self._query_engine.fetch_series(table, resolved, ts_code, end_date, window)
except Exception as exc: # noqa: BLE001
LOGGER.debug(
"时间序列查询失败 table=%s column=%s err=%s",
table,
resolved,
exc,
extra=LOG_EXTRA,
)
if cache_key is not None:
cached = self._cache_lookup(self._series_cache, cache_key)
if cached is not None:
LOGGER.debug(
"使用缓存时间序列 table=%s column=%s ts_code=%s",
table,
resolved,
ts_code,
extra=LOG_EXTRA,
)
return [tuple(item) for item in cached]
return []
series: List[Tuple[str, float]] = []
for row in rows:
value = row[resolved]
trade_dt = row["trade_date"]
if value is None or trade_dt is None:
continue
try:
series.append((trade_dt, float(value)))
except (TypeError, ValueError):
continue
if cache_key is not None and series:
self._cache_store(
self._series_cache,
cache_key,
tuple(series),
self.series_cache_size,
)
return series
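# Hedged usage sketch for ``fetch_series`` (code and date are illustrative):
#
#     closes = broker.fetch_series("daily", "close", "000001.SZ", "20240105", 20)
#     # -> [("20240105", 10.5), ("20240104", 10.4), ...] newest first, at most
#     #    min(window, MAX_WINDOW) points; an empty list when nothing is available.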
def fetch_batch_latest(
self,
ts_codes: List[str],
trade_date: str,
fields: Iterable[str],
auto_refresh: bool = True,
) -> Dict[str, Dict[str, Any]]:
"""批次化获取多个证券的最新字段数据
Args:
ts_codes: 证券代码列表
trade_date: 交易日
fields: 要查询的字段列表
auto_refresh: 是否在数据不足时自动触发补数
Returns:
证券代码到字段数据的映射
"""
if not ts_codes:
return {}
field_list = [str(item) for item in fields if item]
if not field_list:
return {}
# 检查是否需要自动补数
if auto_refresh:
self._refresh.ensure_for_latest(trade_date, field_list)
# 按表分组字段
field_groups = {}
for field_name in field_list:
resolved = self.resolve_field(field_name)
if not resolved:
continue
table, column = resolved
field_groups.setdefault(table, set()).add(column)
batch_data = {}
# 对每个表进行批量查询
for table, columns in field_groups.items():
if not ts_codes:
continue
# 构建批量查询SQL
placeholders = ','.join(['?'] * len(ts_codes))
columns_str = ', '.join(['ts_code', 'trade_date'] + list(columns))
query = f"""
SELECT {columns_str}
FROM (
SELECT {columns_str},
ROW_NUMBER() OVER (PARTITION BY ts_code ORDER BY trade_date DESC) as rn
FROM {table}
WHERE ts_code IN ({placeholders}) AND trade_date <= ?
) WHERE rn = 1
"""
try:
with db_session(read_only=True) as conn:
rows = conn.execute(query, (*ts_codes, trade_date)).fetchall()
for row in rows:
ts_code = row['ts_code']
batch_data.setdefault(ts_code, {})
for column in columns:
field_name = f"{table}.{column}"
try:
batch_data[ts_code][field_name] = float(row[column])
except (TypeError, ValueError):
batch_data[ts_code][field_name] = row[column]
except Exception as e:
LOGGER.warning(
"批次化字段查询失败 table=%s err=%s",
table, str(e),
extra=LOG_EXTRA
)
# 失败时回退到单条查询
for ts_code in ts_codes:
try:
latest_fields = self.fetch_latest(ts_code, trade_date, [f"{table}.{col}" for col in columns])
batch_data.setdefault(ts_code, {}).update(latest_fields)
except Exception as inner_e:
LOGGER.debug(
"单条字段查询失败 ts_code=%s table=%s err=%s",
ts_code, table, str(inner_e),
extra=LOG_EXTRA
)
return batch_data
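# Hedged sketch: the batch variant issues one windowed query per table, using
# ROW_NUMBER() to pick each code's most recent row, and falls back to per-code
# ``fetch_latest`` calls when the batch query fails.
#
#     batch = broker.fetch_batch_latest(
#         ["000001.SZ", "600000.SH"], "20240105", ["daily.close", "daily.vol"],
#     )
#     # -> {"000001.SZ": {"daily.close": ..., "daily.vol": ...}, "600000.SH": {...}}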
def check_batch_data_sufficiency(
self,
ts_codes: List[str],
trade_date: str,
min_data_count: int = 60,
) -> Set[str]:
"""批次化检查多个证券的数据充分性
Args:
ts_codes: 证券代码列表
trade_date: 交易日
min_data_count: 最小数据条数要求
Returns:
数据充分的证券代码集合
"""
if not ts_codes:
return set()
sufficient_codes = set()
# 使用IN查询批量检查数据充分性
placeholders = ','.join(['?'] * len(ts_codes))
query = f"""
SELECT ts_code, COUNT(*) as data_count
FROM daily
WHERE ts_code IN ({placeholders}) AND trade_date <= ?
GROUP BY ts_code
HAVING COUNT(*) >= ?
"""
try:
with db_session(read_only=True) as conn:
rows = conn.execute(query, (*ts_codes, trade_date, min_data_count)).fetchall()
for row in rows:
ts_code = row['ts_code']
sufficient_codes.add(ts_code)
except Exception as e:
LOGGER.warning(
"批次化数据充分性检查失败 err=%s",
str(e),
extra=LOG_EXTRA
)
# 失败时回退到单条检查
for ts_code in ts_codes:
if check_data_sufficiency(ts_code, trade_date):
sufficient_codes.add(ts_code)
return sufficient_codes
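# Hedged sketch: sufficiency here means at least ``min_data_count`` rows in the
# ``daily`` table on or before the trade date.
#
#     ready = broker.check_batch_data_sufficiency(codes, "20240105", min_data_count=60)
#     thin = set(codes) - ready  # codes that may still need a backfill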
def register_refresh_callback(
self,
start: date | str,
end: date | str,
callback: Callable[[], None],
) -> None:
"""Register a hook invoked after background refresh completes for the window."""
if callback is None:
return
start_date = _coerce_date(start)
end_date = _coerce_date(end)
if not start_date or not end_date:
LOGGER.debug(
"忽略无效补数回调窗口 start=%s end=%s",
start,
end,
extra=LOG_EXTRA,
)
return
key = f"{start_date}_{end_date}"
with self._refresh_lock:
bucket = self._refresh_callbacks.setdefault(key, [])
if callback not in bucket:
bucket.append(callback)
def get_news_data(
self,
ts_code: str,
trade_date: str,
limit: int = 30,
lookback_days: int = 3,
) -> List[Dict[str, Any]]:
"""获取新闻数据切片。
Args:
ts_code: 股票代码
trade_date: 交易日期YYYYMMDD/ISO
limit: 返回的新闻条数限制
lookback_days: 回溯天数,用于拉取近几日新闻
Returns:
新闻数据列表,包含 sentiment、heat、entities 等字段
"""
if not ts_code or limit <= 0:
return []
parsed_date = _parse_trade_date(trade_date)
if not parsed_date:
LOGGER.debug(
"新闻数据查询失败,无法解析日期 ts_code=%s trade_date=%s",
ts_code,
trade_date,
extra=LOG_EXTRA,
)
return []
window_days = max(1, lookback_days)
end_day = parsed_date.date()
start_day = end_day - timedelta(days=max(window_days - 1, 0))
start_bound = _iso_start_of_day(datetime(start_day.year, start_day.month, start_day.day))
end_bound = _iso_end_of_day(datetime(end_day.year, end_day.month, end_day.day))
query = (
"SELECT sentiment, heat, sentiment_index, heat_score, entities, "
"title, summary, source, url, pub_time "
"FROM news "
"WHERE ts_code = ? AND pub_time BETWEEN ? AND ? "
"ORDER BY pub_time DESC LIMIT ?"
)
try:
with db_session(read_only=True) as conn:
rows = conn.execute(
query,
(ts_code, start_bound, end_bound, int(limit)),
).fetchall()
except sqlite3.OperationalError as exc:
LOGGER.debug(
"新闻数据查询失败 ts_code=%s err=%s",
ts_code,
exc,
extra=LOG_EXTRA,
)
return []
except Exception as exc: # noqa: BLE001
LOGGER.debug(
"新闻数据读取异常 ts_code=%s err=%s",
ts_code,
exc,
extra=LOG_EXTRA,
)
return []
return [dict(row) for row in rows]
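# Hedged usage sketch: news rows come from a ``lookback_days`` window ending on
# the trade date, newest first (code and date below are illustrative):
#
#     news = broker.get_news_data("000001.SZ", "20240105", limit=10, lookback_days=3)
#     for item in news:
#         print(item.get("pub_time"), item.get("sentiment"), item.get("title"))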
def _derived_industry_sentiment(
self,
industry: str,
trade_date: str,
*,
lookback_days: int = 5,
) -> Optional[float]:
"""根据近几日新闻情绪推导行业层面的情绪指标。"""
if not industry:
return None
parsed_date = _parse_trade_date(trade_date)
if not parsed_date:
return None
stocks = self.get_industry_stocks(industry)
if not stocks:
return None
peers: List[str] = list(dict.fromkeys(stocks))
if not peers:
return None
window_days = max(1, lookback_days)
end_day = parsed_date.date()
start_day = end_day - timedelta(days=max(window_days - 1, 0))
start_bound = _iso_start_of_day(datetime(start_day.year, start_day.month, start_day.day))
end_bound = _iso_end_of_day(datetime(end_day.year, end_day.month, end_day.day))
placeholders = ",".join("?" for _ in peers[:200])
if not placeholders:
return None
query = (
f"SELECT sentiment FROM news "
f"WHERE ts_code IN ({placeholders}) "
"AND pub_time BETWEEN ? AND ? "
"AND sentiment IS NOT NULL"
)
params: List[object] = list(peers[:200])
params.extend([start_bound, end_bound])
try:
with db_session(read_only=True) as conn:
rows = conn.execute(query, params).fetchall()
except sqlite3.OperationalError as exc:
LOGGER.debug(
"行业情绪查询失败 industry=%s err=%s",
industry,
exc,
extra=LOG_EXTRA,
)
return None
except Exception as exc: # noqa: BLE001
LOGGER.debug(
"行业情绪读取异常 industry=%s err=%s",
industry,
exc,
extra=LOG_EXTRA,
)
return None
sentiments: List[float] = []
for row in rows:
try:
sentiments.append(float(row["sentiment"]))
except (TypeError, ValueError, KeyError):
continue
if not sentiments:
return None
avg = sum(sentiments) / len(sentiments)
return max(-1.0, min(1.0, avg))
def get_industry_stocks(self, industry: str) -> List[str]:
"""获取同行业股票列表。"""
if not industry:
return []
cache = getattr(self, "_industry_members_cache", None)
if cache is None:
cache = {}
self._industry_members_cache = cache
if industry in cache:
return cache[industry]
query = "SELECT ts_code FROM stock_basic WHERE industry = ? ORDER BY ts_code"
try:
with db_session(read_only=True) as conn:
rows = conn.execute(query, (industry,)).fetchall()
except sqlite3.OperationalError as exc:
LOGGER.debug(
"行业成分查询失败 industry=%s err=%s",
industry,
exc,
extra=LOG_EXTRA,
)
cache[industry] = []
return []
except Exception as exc: # noqa: BLE001
LOGGER.debug(
"行业成分读取异常 industry=%s err=%s",
industry,
exc,
extra=LOG_EXTRA,
)
cache[industry] = []
return []
members = [row["ts_code"] for row in rows if row and row["ts_code"]]
cache[industry] = members
return members
def fetch_flags(
self,
table: str,
ts_code: str,
trade_date: str,
where_clause: str,
params: Sequence[object],
auto_refresh: bool = True,
) -> bool:
"""Generic helper to test if a record exists (used for limit/suspend lookups)."""
if not _is_safe_identifier(table):
return False
parsed_date = _parse_trade_date(trade_date)
trade_key = parsed_date.strftime("%Y%m%d") if parsed_date else str(trade_date)
if auto_refresh and parsed_date and self.check_data_availability(trade_key, {table}):
self._trigger_background_refresh(trade_key)
if hasattr(time, "sleep"):
time.sleep(0.5)
if table == "suspend":
query = (
"SELECT 1 FROM suspend "
"WHERE ts_code = ? "
"AND suspend_date IS NOT NULL "
"AND suspend_date <> '' "
"AND suspend_date <= ? "
"AND (resume_date IS NULL OR resume_date = '' OR resume_date > ?) "
"LIMIT 1"
)
bind_params = (ts_code, trade_key, trade_key)
else:
clauses = ["ts_code = ?"]
bind_params_list: List[object] = [ts_code]
clause_text = (where_clause or "").strip()
if clause_text:
clauses.append(clause_text)
bind_params_list.extend(params)
query = f"SELECT 1 FROM {table} WHERE {' AND '.join(clauses)} LIMIT 1"
bind_params = tuple(bind_params_list)
try:
with db_session(read_only=True) as conn:
try:
row = conn.execute(query, bind_params).fetchone()
except Exception as exc: # noqa: BLE001
LOGGER.debug(
"flag 查询失败 table=%s where=%s err=%s",
table,
where_clause,
exc,
extra=LOG_EXTRA,
)
return False
except sqlite3.OperationalError as exc:
LOGGER.debug(
"flag 查询连接失败 table=%s err=%s",
table,
exc,
extra=LOG_EXTRA,
)
return False
return row is not None
def fetch_table_rows(
self,
table: str,
ts_code: str,
trade_date: str,
window: int,
auto_refresh: bool = True,
) -> List[Dict[str, object]]:
if window <= 0:
return []
window = min(window, self.MAX_WINDOW)
# 检查是否需要自动补数
if auto_refresh:
parsed_date = _parse_trade_date(trade_date)
if parsed_date and self.check_data_availability(trade_date, {table}):
self._trigger_background_refresh(trade_date)
# 短暂等待以获取最新数据
if hasattr(time, 'sleep'):
time.sleep(0.5)
columns = self._get_table_columns(table)
if not columns:
LOGGER.debug("表不存在或无字段 table=%s", table, extra=LOG_EXTRA)
return []
try:
rows = self._query_engine.fetch_table(
table,
columns,
ts_code,
trade_date if "trade_date" in columns else None,
window,
)
except Exception as exc: # noqa: BLE001
LOGGER.debug("表查询失败 table=%s err=%s", table, exc, extra=LOG_EXTRA)
return []
return [{col: row[col] for col in columns} for row in rows]
def _resolve_derived_field(
self,
ts_code: str,
trade_date: str,
field: str,
cache: Dict[str, Any],
) -> Optional[Any]:
if field in cache:
return cache[field]
value: Optional[Any] = None
if field == "factors.mom_20":
value = self._derived_price_momentum(ts_code, trade_date, 20)
elif field == "factors.mom_60":
value = self._derived_price_momentum(ts_code, trade_date, 60)
elif field == "factors.volat_20":
value = self._derived_price_volatility(ts_code, trade_date, 20)
elif field == "factors.turn_20":
value = self._derived_turnover_mean(ts_code, trade_date, 20)
elif field == "news.sentiment_index":
rows = cache.get("__news_rows__")
if rows is None:
rows = self._fetch_recent_news(ts_code, trade_date)
cache["__news_rows__"] = rows
value = self._news_sentiment_from_rows(rows)
elif field == "news.heat_score":
rows = cache.get("__news_rows__")
if rows is None:
rows = self._fetch_recent_news(ts_code, trade_date)
cache["__news_rows__"] = rows
value = self._news_heat_from_rows(rows)
elif field == "macro.industry_heat":
value = self._derived_industry_heat(ts_code, trade_date)
elif field in {"macro.relative_strength", "index.performance_peers"}:
value = self._derived_relative_strength(ts_code, trade_date, cache)
cache[field] = value
return value
def _derived_price_momentum(
self,
ts_code: str,
trade_date: str,
window: int,
) -> Optional[float]:
series = self.fetch_series("daily", "close", ts_code, trade_date, window)
values = [value for _dt, value in series]
if not values:
return None
return momentum(values, window)
def _derived_price_volatility(
self,
ts_code: str,
trade_date: str,
window: int,
) -> Optional[float]:
series = self.fetch_series("daily", "close", ts_code, trade_date, window)
values = [value for _dt, value in series]
if len(values) < 2:
return None
return volatility(values, window)
def _derived_turnover_mean(
self,
ts_code: str,
trade_date: str,
window: int,
) -> Optional[float]:
series = self.fetch_series(
"daily_basic",
"turnover_rate",
ts_code,
trade_date,
window,
)
values = [value for _dt, value in series]
if not values:
return None
return rolling_mean(values, window)
def _fetch_recent_news(
self,
ts_code: str,
trade_date: str,
days: int = 3,
limit: int = 120,
) -> List[Dict[str, Any]]:
baseline = _parse_trade_date(trade_date)
if baseline is None:
return []
start = _start_of_day(baseline - timedelta(days=days))
end = _end_of_day(baseline)
query = (
"SELECT sentiment, heat FROM news "
"WHERE ts_code = ? AND pub_time BETWEEN ? AND ? "
"ORDER BY pub_time DESC LIMIT ?"
)
try:
with db_session(read_only=True) as conn:
rows = conn.execute(query, (ts_code, start, end, limit)).fetchall()
except sqlite3.OperationalError as exc:
LOGGER.debug(
"新闻查询连接失败 ts_code=%s err=%s",
ts_code,
exc,
extra=LOG_EXTRA,
)
return []
except Exception as exc: # noqa: BLE001
LOGGER.debug(
"新闻查询失败 ts_code=%s err=%s",
ts_code,
exc,
extra=LOG_EXTRA,
)
return []
return [dict(row) for row in rows]
@staticmethod
def _news_sentiment_from_rows(rows: List[Dict[str, Any]]) -> Optional[float]:
sentiments: List[float] = []
for row in rows:
value = row.get("sentiment")
if value is None:
continue
try:
sentiments.append(float(value))
except (TypeError, ValueError):
continue
if not sentiments:
return None
avg = sum(sentiments) / len(sentiments)
return max(-1.0, min(1.0, avg))
@staticmethod
def _news_heat_from_rows(rows: List[Dict[str, Any]]) -> Optional[float]:
if not rows:
return None
total_heat = 0.0
for row in rows:
value = row.get("heat")
if value is None:
continue
try:
total_heat += max(float(value), 0.0)
except (TypeError, ValueError):
continue
if total_heat > 0:
return normalize(total_heat, factor=100.0)
return normalize(len(rows), factor=20.0)
def _derived_industry_heat(self, ts_code: str, trade_date: str) -> Optional[float]:
industry = self._lookup_industry(ts_code)
if not industry:
return None
query = (
"SELECT heat FROM heat_daily "
"WHERE scope = ? AND key = ? AND trade_date <= ? "
"ORDER BY trade_date DESC LIMIT 1"
)
try:
with db_session(read_only=True) as conn:
row = conn.execute(query, ("industry", industry, trade_date)).fetchone()
except sqlite3.OperationalError as exc:
LOGGER.debug(
"行业热度查询失败 ts_code=%s err=%s",
ts_code,
exc,
extra=LOG_EXTRA,
)
return None
except Exception as exc: # noqa: BLE001
LOGGER.debug(
"行业热度读取异常 ts_code=%s err=%s",
ts_code,
exc,
extra=LOG_EXTRA,
)
return None
if not row:
return None
heat_value = row["heat"]
if heat_value is None:
return None
return normalize(heat_value, factor=100.0)
def _lookup_industry(self, ts_code: str) -> Optional[str]:
cache = getattr(self, "_industry_cache", None)
if cache is None:
cache = {}
self._industry_cache = cache
if ts_code in cache:
return cache[ts_code]
query = "SELECT industry FROM stock_basic WHERE ts_code = ?"
try:
with db_session(read_only=True) as conn:
row = conn.execute(query, (ts_code,)).fetchone()
except sqlite3.OperationalError as exc:
LOGGER.debug(
"行业查询连接失败 ts_code=%s err=%s",
ts_code,
exc,
extra=LOG_EXTRA,
)
cache[ts_code] = None
return None
except Exception as exc: # noqa: BLE001
LOGGER.debug(
"行业查询失败 ts_code=%s err=%s",
ts_code,
exc,
extra=LOG_EXTRA,
)
cache[ts_code] = None
return None
industry = None
if row:
industry = row["industry"]
cache[ts_code] = industry
return industry
def _derived_relative_strength(
self,
ts_code: str,
trade_date: str,
cache: Dict[str, Any],
) -> Optional[float]:
window = 20
series = self.fetch_series("daily", "close", ts_code, trade_date, max(window, 30))
values = [value for _dt, value in series]
if not values:
return None
stock_momentum = momentum(values, window)
bench_key = f"__benchmark_mom_{window}"
benchmark = cache.get(bench_key)
if benchmark is None:
benchmark = self._index_momentum(trade_date, window)
cache[bench_key] = benchmark
diff = stock_momentum if benchmark is None else stock_momentum - benchmark
diff = max(-0.2, min(0.2, diff))
return (diff + 0.2) / 0.4
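# Worked example of the clamping above (assuming ``momentum`` returns a simple
# fractional return): stock momentum 0.05 against a benchmark momentum of 0.02
# gives diff = 0.03, which stays inside [-0.2, 0.2] and rescales to
# (0.03 + 0.2) / 0.4 = 0.575, so the relative-strength score always lies in [0, 1].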
def _index_momentum(self, trade_date: str, window: int) -> Optional[float]:
series = self.fetch_series(
"index_daily",
"close",
self.BENCHMARK_INDEX,
trade_date,
window,
)
values = [value for _dt, value in series]
if not values:
return None
return momentum(values, window)
def resolve_field(self, field: str) -> Optional[Tuple[str, str]]:
normalized = _safe_split(field)
if not normalized:
return None
table, column = normalized
resolved = self._resolve_column(table, column)
if not resolved:
# Certain fields are derived at runtime and intentionally
# do not require physical columns. Suppress noisy debug logs
# for those known derived fields so startup isn't spammy.
derived_fields = {
"macro.industry_heat",
"macro.relative_strength",
"index.performance_peers",
"news.heat_score",
"news.sentiment_index",
}
if f"{table}.{column}" in derived_fields:
return None
LOGGER.debug(
"字段不存在 table=%s column=%s",
table,
column,
extra=LOG_EXTRA,
)
return None
return table, resolved
def _get_table_columns(self, table: str) -> Optional[List[str]]:
if not _is_safe_identifier(table):
return None
cache = getattr(self, "_column_cache", None)
if cache is None:
cache = {}
self._column_cache = cache
if table in cache:
return cache[table]
try:
with db_session(read_only=True) as conn:
rows = conn.execute(f"PRAGMA table_info({table})").fetchall()
except Exception as exc: # noqa: BLE001
LOGGER.debug("获取表字段失败 table=%s err=%s", table, exc, extra=LOG_EXTRA)
cache[table] = None
return None
if not rows:
cache[table] = None
return None
columns = [row["name"] for row in rows if row["name"]]
cache[table] = columns
return columns
def _cache_lookup(self, cache: OrderedDict, key: Tuple[Any, ...]) -> Optional[Any]:
if key in cache:
cache.move_to_end(key)
return cache[key]
return None
def _cache_store(
self,
cache: OrderedDict,
key: Tuple[Any, ...],
value: Any,
limit: int,
) -> None:
if not self.enable_cache or limit <= 0:
return
cache[key] = value
cache.move_to_end(key)
while len(cache) > limit:
cache.popitem(last=False)
def check_data_availability(
self,
trade_date: str,
tables: Optional[Set[str]] = None,
threshold: float = 0.8,
) -> bool:
"""检查指定交易日的数据是否可用如不可用则返回True需要补数
Args:
trade_date: 要检查的交易日
tables: 要检查的表集合,默认检查主要行情表
threshold: 数据覆盖率阈值,低于此值需要补数
Returns:
bool: True表示数据不足需要补数
"""
cfg = get_config()
# 如果配置了强制刷新,则始终返回需要补数
if cfg.force_refresh:
return True
# 如果未启用自动更新,则不进行补数
if not cfg.auto_update_data:
if not getattr(self, "_auto_update_warning_emitted", False):
message = "自动补数已关闭,系统将跳过缺口检测。"
LOGGER.warning(message, extra=LOG_EXTRA)
try:
alerts.add_warning(
"data_broker",
"自动补数已关闭",
"当前运行模式不会触发数据补齐,请在设置中开启自动更新或手动补数。",
)
except Exception: # noqa: BLE001
LOGGER.debug("自动补数告警发送失败", extra=LOG_EXTRA)
self._auto_update_warning_emitted = True
return False
# 默认检查的表
if tables is None:
tables = {"daily", "daily_basic", "stock_basic", "trade_cal"}
try:
# 解析交易日
parsed_date = _parse_trade_date(trade_date)
if not parsed_date:
LOGGER.debug("无法解析交易日: %s", trade_date, extra=LOG_EXTRA)
return False
# 计算检查窗口
end_date = parsed_date.strftime('%Y%m%d')
start_date = (parsed_date - timedelta(days=self.AUTO_REFRESH_WINDOW)).strftime('%Y%m%d')
# 构建缓存键
cache_key = f"{start_date}_{end_date}_{'_'.join(sorted(tables))}"
# 检查缓存
if cache_key in self._coverage_cache:
coverage = self._coverage_cache[cache_key]
current_time = time.time() if hasattr(time, 'time') else 0
if coverage.get('timestamp', 0) > current_time - 300: # 5分钟内有效
# 检查是否需要补数
for table in tables:
table_coverage = coverage.get(table, {})
if table_coverage.get('coverage', 0) < threshold:
return True
return False
# 收集数据覆盖情况
if self.coverage_collector is None:
LOGGER.debug("未配置覆盖统计函数,无法判断是否需要补数", extra=LOG_EXTRA)
return False
start_d = datetime.strptime(start_date, "%Y%m%d").date()
end_d = datetime.strptime(end_date, "%Y%m%d").date()
coverage = self.coverage_collector(start_d, end_d)
# 保存到缓存
coverage['timestamp'] = time.time() if hasattr(time, 'time') else 0
self._coverage_cache[cache_key] = coverage
# 检查是否需要补数
for table in tables:
table_coverage = coverage.get(table, {})
if table_coverage.get('coverage', 0) < threshold:
return True
except Exception as exc:
LOGGER.exception("检查数据可用性失败: %s", exc, extra=LOG_EXTRA)
# 出错时保守处理,不触发补数
return False
return False
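# Hedged usage sketch (the coverage payload shape follows this method's own
# reads: per-table dicts carrying a ``coverage`` ratio):
#
#     if broker.check_data_availability("20240105", {"daily", "daily_basic"}, threshold=0.8):
#         broker.force_refresh_data("20231229", "20240105")  # coverage below 80%, backfill the window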
def _trigger_background_refresh(self, target_date: str) -> None:
"""在后台线程触发数据补数。"""
parsed_date = _parse_trade_date(target_date)
if not parsed_date:
return
# 构建补数日期范围
end_date = parsed_date.date()
start_date = end_date - timedelta(days=self.AUTO_REFRESH_WINDOW)
refresh_key = f"{start_date}_{end_date}"
# 检查是否已经在补数中
with self._refresh_lock:
if self._refresh_in_progress.get(refresh_key, False):
LOGGER.debug("数据补数已经在进行中: %s", refresh_key, extra=LOG_EXTRA)
return
self._refresh_in_progress[refresh_key] = True
self._refresh_callbacks.setdefault(refresh_key, [])
def refresh_task():
try:
LOGGER.info("开始后台数据补数: %s%s", start_date, end_date, extra=LOG_EXTRA)
# 执行补数
if self.coverage_runner is None:
LOGGER.debug("未配置覆盖补数函数,跳过自动补数", extra=LOG_EXTRA)
with self._refresh_lock:
self._refresh_in_progress[refresh_key] = False
self._refresh_callbacks.pop(refresh_key, None)
return
self.coverage_runner(start_date, end_date)
LOGGER.info("后台数据补数完成: %s%s", start_date, end_date, extra=LOG_EXTRA)
# 清除缓存,强制重新加载数据
self._latest_cache.clear()
self._series_cache.clear()
self._coverage_cache.clear()
# 执行回调
with self._refresh_lock:
callbacks = self._refresh_callbacks.pop(refresh_key, [])
self._refresh_in_progress[refresh_key] = False
if callbacks:
LOGGER.info(
"执行补数回调 count=%s key=%s",
len(callbacks),
refresh_key,
extra=LOG_EXTRA,
)
for callback in callbacks:
try:
callback()
except Exception as exc:
LOGGER.exception("补数回调执行失败: %s", exc, extra=LOG_EXTRA)
except Exception as exc:
LOGGER.exception("后台数据补数失败: %s", exc, extra=LOG_EXTRA)
with self._refresh_lock:
self._refresh_in_progress[refresh_key] = False
# 启动后台线程
thread = threading.Thread(target=refresh_task, daemon=True)
thread.start()
def is_refreshing(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> bool:
"""Check whether a backfill is currently running for the given date range."""
with self._refresh_lock:
if not start_date and not end_date:
# 检查是否有任何补数正在进行
return any(self._refresh_in_progress.values())
# 检查指定日期范围
for key, in_progress in self._refresh_in_progress.items():
if in_progress and key.startswith(start_date or '') and key.endswith(end_date or ''):
return True
return False
def wait_for_refresh_complete(
self,
timeout: Optional[float] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
) -> bool:
"""Wait for a data backfill to complete.
Args:
timeout: timeout in seconds; defaults to MAX_REFRESH_WAIT
start_date: start date of the window
end_date: end date of the window
Returns:
bool: True when the backfill finished, False on timeout.
"""
if timeout is None:
timeout = self.MAX_REFRESH_WAIT
start_time = time.time() if hasattr(time, 'time') else 0
current_time_func = time.time if hasattr(time, 'time') else lambda: 0
while current_time_func() - start_time < timeout:
if not self.is_refreshing(start_date, end_date):
return True
# 短暂休眠后再次检查
if hasattr(time, 'sleep'):
time.sleep(min(self.REFRESH_RETRY_INTERVAL, timeout / 10))
return False
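# Hedged usage sketch: trigger a refresh, then block (up to MAX_REFRESH_WAIT
# seconds by default) until the background thread finishes before re-reading data.
#
#     broker.force_refresh_data("20231229", "20240105")
#     if broker.wait_for_refresh_complete(timeout=30):
#         snapshot = broker.fetch_latest("000001.SZ", "20240105", ["daily.close"])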
def on_data_refresh(
self,
callback: Callable,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
) -> None:
"""Register a callback invoked once the backfill for the given window completes."""
if start_date and end_date:
refresh_key = f"{start_date}_{end_date}"
with self._refresh_lock:
self._refresh_callbacks.setdefault(refresh_key, []).append(callback)
# 如果当前没有补数在进行,则直接调用回调
if not self._refresh_in_progress.get(refresh_key, False):
try:
callback()
except Exception as exc:
LOGGER.exception("补数回调执行失败: %s", exc, extra=LOG_EXTRA)
def set_auto_refresh_window(self, days: int) -> None:
"""设置自动补数的时间窗口。
Args:
days: 自动补数的天数窗口
"""
if days > 0:
self.AUTO_REFRESH_WINDOW = days
LOGGER.info("自动补数窗口已设置为 %d", days, extra=LOG_EXTRA)
def set_refresh_retry_interval(self, seconds: int) -> None:
"""设置补数检查的重试间隔。
Args:
seconds: 重试间隔(秒)
"""
if seconds > 0:
self.REFRESH_RETRY_INTERVAL = seconds
LOGGER.info("补数重试间隔已设置为 %d", seconds, extra=LOG_EXTRA)
def set_max_refresh_wait(self, seconds: int) -> None:
"""设置最大等待补数完成时间。
Args:
seconds: 最大等待时间(秒)
"""
if seconds > 0:
self.MAX_REFRESH_WAIT = seconds
LOGGER.info("最大补数等待时间已设置为 %d", seconds, extra=LOG_EXTRA)
def force_refresh_data(self, start_date: str, end_date: str) -> bool:
"""强制刷新指定日期范围内的数据。
Args:
start_date: 开始日期格式YYYYMMDD
end_date: 结束日期格式YYYYMMDD
Returns:
bool: 是否成功触发刷新
"""
try:
# 解析日期
start = _parse_trade_date(start_date)
end = _parse_trade_date(end_date)
if not start or not end:
LOGGER.error("日期格式不正确: %s, %s", start_date, end_date, extra=LOG_EXTRA)
return False
# 触发刷新
self._trigger_background_refresh(end_date)
return True
except Exception as exc:
LOGGER.exception("强制刷新数据失败: %s", exc, extra=LOG_EXTRA)
return False
def get_index_stocks(
self,
index_code: str,
trade_date: str,
min_weight: float = 0.0
) -> List[str]:
"""获取指数成分股列表。
Args:
index_code: 指数代码(如 000300.SH)
trade_date: 交易日期
min_weight: 最小权重筛选
Returns:
成分股代码列表
"""
try:
with db_session(read_only=True) as conn:
# 获取小于等于给定日期的最新一期成分股
rows = conn.execute(
"""
SELECT DISTINCT ts_code
FROM index_weight
WHERE index_code = ?
AND trade_date = (
SELECT MAX(trade_date)
FROM index_weight
WHERE index_code = ?
AND trade_date <= ?
)
AND weight >= ?
ORDER BY weight DESC
""",
(index_code, index_code, trade_date, min_weight)
).fetchall()
return [row["ts_code"] for row in rows if row and row["ts_code"]]
except Exception as exc:
LOGGER.exception(
"获取指数成分股失败 index=%s date=%s err=%s",
index_code,
trade_date,
exc,
extra=LOG_EXTRA
)
return []
def get_refresh_status(self) -> Dict[str, Dict[str, Any]]:
"""获取当前所有补数任务的状态。
Returns:
Dict: 包含所有补数任务状态的字典
"""
with self._refresh_lock:
status = {}
for key, in_progress in self._refresh_in_progress.items():
start, end = key.split('_')[:2] if '_' in key else (key, key)
status[key] = {
'start_date': start,
'end_date': end,
'in_progress': in_progress,
'callback_count': len(self._refresh_callbacks.get(key, []))
}
return status
def cancel_all_refresh_tasks(self) -> None:
"""取消所有正在等待的补数任务回调。
注意:已经开始执行的补数任务无法取消,但它们的结果将被忽略。
"""
with self._refresh_lock:
self._refresh_callbacks.clear()
# 保留刷新状态以避免立即重新触发
LOGGER.info("所有补数任务回调已取消", extra=LOG_EXTRA)
def clear_coverage_cache(self) -> None:
"""清除数据覆盖情况的缓存。"""
self._coverage_cache.clear()
LOGGER.info("数据覆盖缓存已清除", extra=LOG_EXTRA)
def get_stock_info(self, ts_code: str, trade_date: Optional[str] = None) -> Optional[Dict[str, Any]]:
"""Return basic stock information.
Args:
ts_code: security code
trade_date: trade date; defaults to the current date
Returns:
Dict with basic fields such as name and industry, or None when unavailable.
"""
if not trade_date:
# Fall back to the current date when no trade date is supplied.
trade_date = datetime.now().strftime("%Y%m%d")
try:
# 获取股票基本信息
raw_info = self.fetch_latest(
ts_code=ts_code,
trade_date=trade_date,
fields=["stock_basic.name", "stock_basic.industry"]
)
if not raw_info:
return None
info: Dict[str, Any] = {"ts_code": ts_code}
for key, value in raw_info.items():
if key == "ts_code":
continue
alias = key.split(".", 1)[-1] if isinstance(key, str) and "." in key else key
info[alias] = value
return info
except Exception as exc:
LOGGER.debug(
"获取股票信息失败 ts_code=%s err=%s",
ts_code,
exc,
extra=LOG_EXTRA
)
return None
def fetch_latest_factor(self, ts_code: str, factor: str, eval_date: date) -> Optional[float]:
"""获取指定股票的最新因子值。
Args:
ts_code: 股票代码
factor: 因子名称
eval_date: 评估日期
Returns:
float: 因子值如果获取失败则返回None
"""
trade_date = eval_date.strftime("%Y%m%d")
try:
# 构建因子字段名称
factor_field = f"factors.{factor}"
# 获取因子值
result = self.fetch_latest(
ts_code=ts_code,
trade_date=trade_date,
fields=[factor_field]
)
if not result or factor_field not in result:
return None
return result[factor_field]
except Exception as exc:
LOGGER.debug(
"获取因子值失败 ts_code=%s factor=%s err=%s",
ts_code,
factor,
exc,
extra=LOG_EXTRA
)
return None
def get_all_stocks(self, trade_date: str) -> List[str]:
"""获取所有股票代码列表。
Args:
trade_date: 交易日期
Returns:
List[str]: 股票代码列表
"""
try:
with db_session(read_only=True) as conn:
# 获取指定日期之前所有有交易的股票
rows = conn.execute(
"""
SELECT DISTINCT ts_code
FROM daily
WHERE trade_date <= ?
ORDER BY ts_code
""",
(trade_date,)
).fetchall()
return [row["ts_code"] for row in rows if row and row["ts_code"]]
except Exception as exc:
LOGGER.exception(
"获取所有股票失败 date=%s err=%s",
trade_date,
exc,
extra=LOG_EXTRA
)
return []
def get_data_coverage(self, start_date: str, end_date: str) -> Dict:
"""获取指定日期范围内的数据覆盖情况。
Args:
start_date: 开始日期格式YYYYMMDD
end_date: 结束日期格式YYYYMMDD
Returns:
Dict: 数据覆盖情况的详细信息
"""
try:
# 解析日期
start = _parse_trade_date(start_date)
end = _parse_trade_date(end_date)
if not start or not end:
LOGGER.error("日期格式不正确: %s, %s", start_date, end_date, extra=LOG_EXTRA)
return {}
# 转换日期格式
start_d = date.fromisoformat(start.strftime('%Y-%m-%d'))
end_d = date.fromisoformat(end.strftime('%Y-%m-%d'))
if self.coverage_collector is None:
LOGGER.debug("未配置覆盖统计函数,返回空覆盖结果", extra=LOG_EXTRA)
return {}
return self.coverage_collector(start_d, end_d)
except Exception as exc:
LOGGER.exception("获取数据覆盖情况失败: %s", exc, extra=LOG_EXTRA)
return {}
def evaluate_data_quality(
self,
*,
window_days: int = 7,
top_issues: int = 5,
) -> "DataQualitySummary":
"""Run data-quality checks and return a scored summary."""
from app.utils.data_quality import evaluate_data_quality as _evaluate
return _evaluate(window_days=window_days, top_issues=top_issues)
def _resolve_column(self, table: str, column: str) -> Optional[str]:
columns = self._get_table_columns(table)
if columns is None:
return None
alias_map = self.FIELD_ALIASES.get(table, {})
candidate = alias_map.get(column, column)
if candidate in columns:
return candidate
# Try lower-case or fallback alias normalization
lowered = candidate.lower()
for name in columns:
if name.lower() == lowered:
return name
return None
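# Hedged sketch of alias resolution through FIELD_ALIASES and _resolve_column
# (actual results depend on which tables and columns exist in the local schema):
#
#     broker.resolve_field("daily.volume")                # -> ("daily", "vol") when the column exists
#     broker.resolve_field("daily_basic.dividend_yield")  # -> ("daily_basic", "dv_ratio")
#     broker.resolve_field("factors.mom_20")              # -> typically None; handled as a derived field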
# Ensure the standard-library ``time`` module is available. It is already imported
# at the top of this file, so this fallback is purely defensive.
import sys
try:
    import time
except ImportError:  # pragma: no cover - the stdlib module should always be importable
    # Minimal stand-in so time.time()/time.sleep() calls degrade gracefully.
    class TimeStub:
        def time(self):
            return 0
        def sleep(self, seconds):
            pass
    time = TimeStub()
    LOGGER.warning("Failed to import the time module; using a stub implementation", extra=LOG_EXTRA)