111 lines
4.1 KiB
Python
111 lines
4.1 KiB
Python
"""Shared read-only query helpers for database access."""
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Callable, Iterable, List, Mapping, Optional, Sequence
|
|
|
|
|
|
@dataclass
|
|
class BrokerQueryEngine:
|
|
"""Lightweight wrapper around standard query patterns."""
|
|
|
|
session_factory: Callable[..., object]
|
|
_date_cache: dict = None
|
|
|
|
def _find_date_column(self, conn, table: str) -> str | None:
|
|
"""Return the best date column for the table or None if none found."""
|
|
if self._date_cache is None:
|
|
self._date_cache = {}
|
|
if table in self._date_cache:
|
|
return self._date_cache[table]
|
|
try:
|
|
rows = conn.execute(f"PRAGMA table_info({table})").fetchall()
|
|
except Exception:
|
|
self._date_cache[table] = None
|
|
return None
|
|
cols = [row[1] if isinstance(row, tuple) else row["name"] for row in rows]
|
|
# Prefer canonical 'trade_date'
|
|
if "trade_date" in cols:
|
|
self._date_cache[table] = "trade_date"
|
|
return "trade_date"
|
|
# Prefer any column that ends with '_date'
|
|
for c in cols:
|
|
if isinstance(c, str) and c.endswith("_date"):
|
|
self._date_cache[table] = c
|
|
return c
|
|
# No date-like column
|
|
self._date_cache[table] = None
|
|
return None
|
|
|
|
def fetch_latest(
|
|
self,
|
|
table: str,
|
|
ts_code: str,
|
|
trade_date: str,
|
|
columns: Sequence[str],
|
|
) -> Optional[Mapping[str, object]]:
|
|
if not columns:
|
|
return None
|
|
joined_cols = ", ".join(columns)
|
|
with self.session_factory(read_only=True) as conn:
|
|
date_col = self._find_date_column(conn, table)
|
|
if table == "suspend" or date_col is None:
|
|
# For suspend table we prefer to query by ts_code only
|
|
query = f"SELECT {joined_cols} FROM {table} WHERE ts_code = ? ORDER BY rowid DESC LIMIT 1"
|
|
return conn.execute(query, (ts_code,)).fetchone()
|
|
query = (
|
|
f"SELECT {date_col}, {joined_cols} FROM {table} "
|
|
f"WHERE ts_code = ? AND {date_col} <= ? "
|
|
f"ORDER BY {date_col} DESC LIMIT 1"
|
|
)
|
|
return conn.execute(query, (ts_code, trade_date)).fetchone()
|
|
|
|
def fetch_series(
|
|
self,
|
|
table: str,
|
|
column: str,
|
|
ts_code: str,
|
|
end_date: str,
|
|
limit: int,
|
|
) -> List[Mapping[str, object]]:
|
|
with self.session_factory(read_only=True) as conn:
|
|
date_col = self._find_date_column(conn, table)
|
|
if date_col is None:
|
|
# No date column: return most recent rows by rowid
|
|
query = f"SELECT rowid AS trade_date, {column} FROM {table} WHERE ts_code = ? ORDER BY rowid DESC LIMIT ?"
|
|
rows = conn.execute(query, (ts_code, limit)).fetchall()
|
|
else:
|
|
query = (
|
|
f"SELECT {date_col} AS trade_date, {column} FROM {table} "
|
|
f"WHERE ts_code = ? AND {date_col} <= ? "
|
|
f"ORDER BY {date_col} DESC LIMIT ?"
|
|
)
|
|
rows = conn.execute(query, (ts_code, end_date, limit)).fetchall()
|
|
return list(rows)
|
|
|
|
def fetch_table(
|
|
self,
|
|
table: str,
|
|
columns: Iterable[str],
|
|
ts_code: str,
|
|
trade_date: Optional[str],
|
|
limit: int,
|
|
) -> List[Mapping[str, object]]:
|
|
cols = ", ".join(columns)
|
|
if trade_date is None:
|
|
query = (
|
|
f"SELECT {cols} FROM {table} "
|
|
"WHERE ts_code = ? ORDER BY rowid DESC LIMIT ?"
|
|
)
|
|
params: Sequence[object] = (ts_code, limit)
|
|
else:
|
|
query = (
|
|
f"SELECT {cols} FROM {table} "
|
|
"WHERE ts_code = ? AND trade_date <= ? "
|
|
"ORDER BY trade_date DESC LIMIT ?"
|
|
)
|
|
params = (ts_code, trade_date, limit)
|
|
with self.session_factory(read_only=True) as conn:
|
|
rows = conn.execute(query, params).fetchall()
|
|
return list(rows)
|