llm-quant/app/utils/data_access.py

"""Utility helpers to retrieve structured data slices for agents and departments."""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Dict, Iterable, List, Sequence, Tuple
from .db import db_session
from .logging import get_logger
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "data_broker"}
_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")


def _is_safe_identifier(name: str) -> bool:
    return bool(_IDENTIFIER_RE.match(name))


def _safe_split(path: str) -> Tuple[str, str] | None:
    if "." not in path:
        return None
    table, column = path.split(".", 1)
    table = table.strip()
    column = column.strip()
    if not table or not column:
        return None
    if not (_is_safe_identifier(table) and _is_safe_identifier(column)):
        LOGGER.debug("Ignoring invalid field: %s", path, extra=LOG_EXTRA)
        return None
    return table, column


def parse_field_path(path: str) -> Tuple[str, str] | None:
    """Validate and split a `table.column` field expression."""
    return _safe_split(path)
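

# Usage sketch for the helpers above. The field strings are illustrative only;
# "daily" / "close" are placeholder names, not guaranteed schema:
#
#     parse_field_path("daily.close")        # -> ("daily", "close")
#     parse_field_path("close")              # -> None (no table prefix)
#     parse_field_path("daily.close; DROP")  # -> None (column fails identifier check)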


@dataclass
class DataBroker:
    """Lightweight data access helper for agent/LLM consumption."""

    def fetch_latest(
        self,
        ts_code: str,
        trade_date: str,
        fields: Iterable[str],
    ) -> Dict[str, float]:
        """Fetch the latest value (<= trade_date) for each requested field."""
        grouped: Dict[str, List[str]] = {}
        for item in fields:
            if not item:
                continue
            normalized = _safe_split(str(item))
            if not normalized:
                continue
            table, column = normalized
            grouped.setdefault(table, [])
            if column not in grouped[table]:
                grouped[table].append(column)
        if not grouped:
            return {}
        results: Dict[str, float] = {}
        with db_session(read_only=True) as conn:
            for table, columns in grouped.items():
                joined_cols = ", ".join(columns)
                query = (
                    f"SELECT trade_date, {joined_cols} FROM {table} "
                    "WHERE ts_code = ? AND trade_date <= ? "
                    "ORDER BY trade_date DESC LIMIT 1"
                )
                try:
                    row = conn.execute(query, (ts_code, trade_date)).fetchone()
                except Exception as exc:  # noqa: BLE001
                    LOGGER.debug(
                        "Query failed table=%s fields=%s err=%s",
                        table,
                        columns,
                        exc,
                        extra=LOG_EXTRA,
                    )
                    continue
                if not row:
                    continue
                for column in columns:
                    value = row[column]
                    if value is None:
                        continue
                    key = f"{table}.{column}"
                    results[key] = float(value)
        return results
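
    # Example (hedged): fetching a couple of fields for one stock/date. The
    # field names below ("daily.close", "daily_basic.pe") and the returned
    # values are placeholders for whatever the local schema actually provides.
    #
    #     broker = DataBroker()
    #     snapshot = broker.fetch_latest(
    #         "000001.SZ", "20250926", ["daily.close", "daily_basic.pe"]
    #     )
    #     # -> {"daily.close": 10.52, "daily_basic.pe": 6.1}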

    def fetch_series(
        self,
        table: str,
        column: str,
        ts_code: str,
        end_date: str,
        window: int,
    ) -> List[Tuple[str, float]]:
        """Return descending time series tuples within the specified window."""
        if window <= 0:
            return []
        if not (_is_safe_identifier(table) and _is_safe_identifier(column)):
            return []
        query = (
            f"SELECT trade_date, {column} FROM {table} "
            "WHERE ts_code = ? AND trade_date <= ? "
            "ORDER BY trade_date DESC LIMIT ?"
        )
        with db_session(read_only=True) as conn:
            try:
                rows = conn.execute(query, (ts_code, end_date, window)).fetchall()
            except Exception as exc:  # noqa: BLE001
                LOGGER.debug(
                    "Time-series query failed table=%s column=%s err=%s",
                    table,
                    column,
                    exc,
                    extra=LOG_EXTRA,
                )
                return []
        series: List[Tuple[str, float]] = []
        for row in rows:
            value = row[column]
            if value is None:
                continue
            series.append((row["trade_date"], float(value)))
        return series
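
    # Example (hedged): deriving a short moving average from the descending
    # series. "daily" / "close" are assumed table/column names.
    #
    #     closes = broker.fetch_series("daily", "close", "000001.SZ", "20250926", 5)
    #     if closes:
    #         ma5 = sum(value for _, value in closes) / len(closes)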

    def fetch_flags(
        self,
        table: str,
        ts_code: str,
        trade_date: str,
        where_clause: str,
        params: Sequence[object],
    ) -> bool:
        """Generic helper to test whether a record exists (used for limit/suspend lookups)."""
        if not _is_safe_identifier(table):
            return False
        query = (
            f"SELECT 1 FROM {table} WHERE ts_code = ? AND {where_clause} LIMIT 1"
        )
        bind_params = (ts_code, *params)
        with db_session(read_only=True) as conn:
            try:
                row = conn.execute(query, bind_params).fetchone()
            except Exception as exc:  # noqa: BLE001
                LOGGER.debug(
                    "Flag query failed table=%s where=%s err=%s",
                    table,
                    where_clause,
                    exc,
                    extra=LOG_EXTRA,
                )
                return False
        return row is not None
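
    # Example (hedged): checking whether a stock is suspended on a given date.
    # The "suspend_d" table and "trade_date = ?" clause are assumptions about
    # the upstream schema. Note that where_clause is interpolated into the SQL
    # directly, so it should only come from trusted, hard-coded call sites.
    #
    #     is_suspended = broker.fetch_flags(
    #         "suspend_d", "000001.SZ", "20250926",
    #         "trade_date = ?", ("20250926",),
    #     )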