llm-quant/app/utils/db_query.py

111 lines
4.1 KiB
Python

"""Shared read-only query helpers for database access."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Iterable, List, Mapping, Optional, Sequence
@dataclass
class BrokerQueryEngine:
    """Lightweight wrapper around standard query patterns.

    Every query opens its own connection via ``session_factory(read_only=True)``
    and uses the result as a context manager. Table and column names are
    interpolated directly into the SQL text, so callers must pass trusted
    identifiers only; ``ts_code``, dates and limits are bound as parameters.
    """

    # Callable that, given ``read_only=True``, returns a context manager yielding
    # a connection with an ``execute(sql, params)`` method.
    session_factory: Callable[..., object]
    # Lazily created per-instance cache: table name -> detected date column (or None).
    _date_cache: dict[str, str | None] | None = None

    def _find_date_column(self, conn, table: str) -> str | None:
        """Return the best date column for ``table`` or ``None`` if none is found.

        Preference order: a column named exactly ``trade_date``, then the first
        column (in ``PRAGMA table_info`` order) whose name ends with ``_date``.
        Completed lookups are cached per table. A lookup that raises is *not*
        cached, so a transient error cannot permanently disable date filtering
        for that table.
        """
        if self._date_cache is None:
            self._date_cache = {}
        if table in self._date_cache:
            return self._date_cache[table]
        try:
            rows = conn.execute(f"PRAGMA table_info({table})").fetchall()
        except Exception:
            # Best effort: report "no date column" for this call only; retry next time.
            return None
        # PRAGMA table_info puts the column name at index 1; mapping-style rows
        # (e.g. sqlite3.Row) expose it under the "name" key.
        cols = [
            row[1] if isinstance(row, (tuple, list)) else row["name"]
            for row in rows
        ]
        if "trade_date" in cols:
            date_col: str | None = "trade_date"
        else:
            date_col = next(
                (c for c in cols if isinstance(c, str) and c.endswith("_date")),
                None,
            )
        self._date_cache[table] = date_col
        return date_col

    def fetch_latest(
        self,
        table: str,
        ts_code: str,
        trade_date: str,
        columns: Sequence[str],
    ) -> Optional[Mapping[str, object]]:
        """Return the most recent row for ``ts_code`` on or before ``trade_date``.

        When the table has a detectable date column, the selected columns are
        that date column followed by ``columns``. For the ``suspend`` table, or
        a table with no date column, ``trade_date`` is ignored and the row with
        the highest ``rowid`` for ``ts_code`` is returned with only ``columns``.

        Returns ``None`` when ``columns`` is empty or no row matches.
        """
        if not columns:
            return None
        joined_cols = ", ".join(columns)
        with self.session_factory(read_only=True) as conn:
            date_col = self._find_date_column(conn, table)
            if table == "suspend" or date_col is None:
                # For suspend table we prefer to query by ts_code only
                query = (
                    f"SELECT {joined_cols} FROM {table} "
                    "WHERE ts_code = ? ORDER BY rowid DESC LIMIT 1"
                )
                return conn.execute(query, (ts_code,)).fetchone()
            query = (
                f"SELECT {date_col}, {joined_cols} FROM {table} "
                f"WHERE ts_code = ? AND {date_col} <= ? "
                f"ORDER BY {date_col} DESC LIMIT 1"
            )
            return conn.execute(query, (ts_code, trade_date)).fetchone()

    def fetch_series(
        self,
        table: str,
        column: str,
        ts_code: str,
        end_date: str,
        limit: int,
    ) -> List[Mapping[str, object]]:
        """Return up to ``limit`` ``(trade_date, column)`` rows, newest first.

        The first selected value is aliased ``trade_date``. It is the table's
        date column (filtered ``<= end_date``), or, when the table has no date
        column, the ``rowid``, in which case ``end_date`` is ignored.
        """
        with self.session_factory(read_only=True) as conn:
            date_col = self._find_date_column(conn, table)
            if date_col is None:
                # No date column: return most recent rows by rowid
                query = (
                    f"SELECT rowid AS trade_date, {column} FROM {table} "
                    "WHERE ts_code = ? ORDER BY rowid DESC LIMIT ?"
                )
                rows = conn.execute(query, (ts_code, limit)).fetchall()
            else:
                query = (
                    f"SELECT {date_col} AS trade_date, {column} FROM {table} "
                    f"WHERE ts_code = ? AND {date_col} <= ? "
                    f"ORDER BY {date_col} DESC LIMIT ?"
                )
                rows = conn.execute(query, (ts_code, end_date, limit)).fetchall()
        return list(rows)

    def fetch_table(
        self,
        table: str,
        columns: Iterable[str],
        ts_code: str,
        trade_date: Optional[str],
        limit: int,
    ) -> List[Mapping[str, object]]:
        """Return up to ``limit`` rows of ``columns`` for ``ts_code``, newest first.

        Without ``trade_date``, rows are ordered by ``rowid`` descending. With
        ``trade_date``, the table's date column (detected the same way as in
        the other fetch helpers) is filtered ``<= trade_date`` and used for
        ordering. If the table has no date column, the ``rowid`` ordering is
        used and ``trade_date`` is ignored.

        Returns ``[]`` when ``columns`` is empty.
        """
        cols = ", ".join(columns)
        if not cols:
            # Nothing to select; avoid issuing a malformed "SELECT  FROM ..." query.
            return []
        with self.session_factory(read_only=True) as conn:
            date_col = (
                None if trade_date is None else self._find_date_column(conn, table)
            )
            params: Sequence[object]
            if date_col is None:
                query = (
                    f"SELECT {cols} FROM {table} "
                    "WHERE ts_code = ? ORDER BY rowid DESC LIMIT ?"
                )
                params = (ts_code, limit)
            else:
                query = (
                    f"SELECT {cols} FROM {table} "
                    f"WHERE ts_code = ? AND {date_col} <= ? "
                    f"ORDER BY {date_col} DESC LIMIT ?"
                )
                params = (ts_code, trade_date, limit)
            rows = conn.execute(query, params).fetchall()
        return list(rows)