166 lines
5.2 KiB
Python
166 lines
5.2 KiB
Python
"""Utility helpers to retrieve structured data slices for agents and departments."""
|
|
from __future__ import annotations
|
|
|
|
import re
|
|
from dataclasses import dataclass
|
|
from typing import Dict, Iterable, List, Sequence, Tuple
|
|
|
|
from .db import db_session
|
|
from .logging import get_logger
|
|
|
|
# Logger for this module; every record it emits below is passed
# ``extra=LOG_EXTRA`` so it carries the ``stage="data_broker"`` tag.
LOGGER = get_logger(__name__)
LOG_EXTRA = dict(stage="data_broker")
|
|
|
|
_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
|
|
|
|
|
def _is_safe_identifier(name: str) -> bool:
|
|
return bool(_IDENTIFIER_RE.match(name))
|
|
|
|
|
|
def _safe_split(path: str) -> Tuple[str, str] | None:
    """Split a ``table.column`` expression into a stripped ``(table, column)`` pair.

    The split happens at the first ``.``. ``None`` is returned when *path*
    has no dot, when either side is blank after stripping, or when either
    side is not a safe identifier (this last case is logged at debug level).
    """
    if "." not in path:
        return None
    head, _sep, tail = path.partition(".")
    table = head.strip()
    column = tail.strip()
    if not (table and column):
        return None
    if _is_safe_identifier(table) and _is_safe_identifier(column):
        return table, column
    LOGGER.debug("忽略非法字段:%s", path, extra=LOG_EXTRA)
    return None
|
|
|
|
|
|
def parse_field_path(path: str) -> Tuple[str, str] | None:
    """Public entry point for validating a ``table.column`` field expression.

    Delegates to the module's internal splitter and hands back its result
    unchanged: a stripped ``(table, column)`` tuple when both parts are valid
    identifiers, otherwise ``None``.
    """
    pair = _safe_split(path)
    return pair
|
|
|
|
|
|
@dataclass
|
|
class DataBroker:
|
|
"""Lightweight data access helper for agent/LLM consumption."""
|
|
|
|
def fetch_latest(
|
|
self,
|
|
ts_code: str,
|
|
trade_date: str,
|
|
fields: Iterable[str],
|
|
) -> Dict[str, float]:
|
|
"""Fetch the latest value (<= trade_date) for each requested field."""
|
|
|
|
grouped: Dict[str, List[str]] = {}
|
|
for item in fields:
|
|
if not item:
|
|
continue
|
|
normalized = _safe_split(str(item))
|
|
if not normalized:
|
|
continue
|
|
table, column = normalized
|
|
grouped.setdefault(table, [])
|
|
if column not in grouped[table]:
|
|
grouped[table].append(column)
|
|
|
|
if not grouped:
|
|
return {}
|
|
|
|
results: Dict[str, float] = {}
|
|
with db_session(read_only=True) as conn:
|
|
for table, columns in grouped.items():
|
|
joined_cols = ", ".join(columns)
|
|
query = (
|
|
f"SELECT trade_date, {joined_cols} FROM {table} "
|
|
"WHERE ts_code = ? AND trade_date <= ? "
|
|
"ORDER BY trade_date DESC LIMIT 1"
|
|
)
|
|
try:
|
|
row = conn.execute(query, (ts_code, trade_date)).fetchone()
|
|
except Exception as exc: # noqa: BLE001
|
|
LOGGER.debug(
|
|
"查询失败 table=%s fields=%s err=%s",
|
|
table,
|
|
columns,
|
|
exc,
|
|
extra=LOG_EXTRA,
|
|
)
|
|
continue
|
|
if not row:
|
|
continue
|
|
for column in columns:
|
|
value = row[column]
|
|
if value is None:
|
|
continue
|
|
key = f"{table}.{column}"
|
|
results[key] = float(value)
|
|
return results
|
|
|
|
def fetch_series(
|
|
self,
|
|
table: str,
|
|
column: str,
|
|
ts_code: str,
|
|
end_date: str,
|
|
window: int,
|
|
) -> List[Tuple[str, float]]:
|
|
"""Return descending time series tuples within the specified window."""
|
|
|
|
if window <= 0:
|
|
return []
|
|
if not (_is_safe_identifier(table) and _is_safe_identifier(column)):
|
|
return []
|
|
query = (
|
|
f"SELECT trade_date, {column} FROM {table} "
|
|
"WHERE ts_code = ? AND trade_date <= ? "
|
|
"ORDER BY trade_date DESC LIMIT ?"
|
|
)
|
|
with db_session(read_only=True) as conn:
|
|
try:
|
|
rows = conn.execute(query, (ts_code, end_date, window)).fetchall()
|
|
except Exception as exc: # noqa: BLE001
|
|
LOGGER.debug(
|
|
"时间序列查询失败 table=%s column=%s err=%s",
|
|
table,
|
|
column,
|
|
exc,
|
|
extra=LOG_EXTRA,
|
|
)
|
|
return []
|
|
series: List[Tuple[str, float]] = []
|
|
for row in rows:
|
|
value = row[column]
|
|
if value is None:
|
|
continue
|
|
series.append((row["trade_date"], float(value)))
|
|
return series
|
|
|
|
def fetch_flags(
|
|
self,
|
|
table: str,
|
|
ts_code: str,
|
|
trade_date: str,
|
|
where_clause: str,
|
|
params: Sequence[object],
|
|
) -> bool:
|
|
"""Generic helper to test if a record exists (used for limit/suspend lookups)."""
|
|
|
|
if not _is_safe_identifier(table):
|
|
return False
|
|
query = (
|
|
f"SELECT 1 FROM {table} WHERE ts_code = ? AND {where_clause} LIMIT 1"
|
|
)
|
|
bind_params = (ts_code, *params)
|
|
with db_session(read_only=True) as conn:
|
|
try:
|
|
row = conn.execute(query, bind_params).fetchone()
|
|
except Exception as exc: # noqa: BLE001
|
|
LOGGER.debug(
|
|
"flag 查询失败 table=%s where=%s err=%s",
|
|
table,
|
|
where_clause,
|
|
exc,
|
|
extra=LOG_EXTRA,
|
|
)
|
|
return False
|
|
return row is not None
|