update
This commit is contained in:
parent
229bb60e74
commit
cfff1aefe5
@ -452,7 +452,7 @@ def render_today_plan() -> None:
|
|||||||
batch_symbols = st.multiselect("批量重评估(可多选)", symbols, default=[])
|
batch_symbols = st.multiselect("批量重评估(可多选)", symbols, default=[])
|
||||||
|
|
||||||
# 一键重评估所有标的按钮
|
# 一键重评估所有标的按钮
|
||||||
if st.button("一键重评估所有标的", type="primary", use_container_width=True):
|
if st.button("一键重评估所有标的", type="primary", width='stretch'):
|
||||||
with st.spinner("正在对所有标的进行重评估,请稍候..."):
|
with st.spinner("正在对所有标的进行重评估,请稍候..."):
|
||||||
try:
|
try:
|
||||||
# 解析交易日
|
# 解析交易日
|
||||||
@ -831,6 +831,9 @@ def render_today_plan() -> None:
|
|||||||
|
|
||||||
# 显示新闻表格
|
# 显示新闻表格
|
||||||
news_df = pd.DataFrame(news_data)
|
news_df = pd.DataFrame(news_data)
|
||||||
|
# 确保所有列都是字符串类型,避免PyArrow序列化错误
|
||||||
|
for col in news_df.columns:
|
||||||
|
news_df[col] = news_df[col].astype(str)
|
||||||
st.dataframe(news_df, width='stretch', hide_index=True)
|
st.dataframe(news_df, width='stretch', hide_index=True)
|
||||||
|
|
||||||
# 添加新闻详情展开视图
|
# 添加新闻详情展开视图
|
||||||
@ -1139,8 +1142,11 @@ def render_log_viewer() -> None:
|
|||||||
# 将sqlite3.Row对象转换为字典列表
|
# 将sqlite3.Row对象转换为字典列表
|
||||||
rows_dict = [{key: row[key] for key in row.keys()} for row in rows]
|
rows_dict = [{key: row[key] for key in row.keys()} for row in rows]
|
||||||
log_df = pd.DataFrame(rows_dict)
|
log_df = pd.DataFrame(rows_dict)
|
||||||
# 格式化时间戳
|
# 格式化时间戳并确保数据类型一致
|
||||||
log_df["ts"] = pd.to_datetime(log_df["ts"]).dt.strftime("%Y-%m-%d %H:%M:%S")
|
log_df["ts"] = pd.to_datetime(log_df["ts"]).dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
# 确保所有列都是字符串类型,避免PyArrow序列化错误
|
||||||
|
for col in log_df.columns:
|
||||||
|
log_df[col] = log_df[col].astype(str)
|
||||||
else:
|
else:
|
||||||
log_df = pd.DataFrame(columns=["ts", "stage", "level", "msg"])
|
log_df = pd.DataFrame(columns=["ts", "stage", "level", "msg"])
|
||||||
|
|
||||||
@ -1154,8 +1160,7 @@ def render_log_viewer() -> None:
|
|||||||
"stage": st.column_config.TextColumn("执行阶段"),
|
"stage": st.column_config.TextColumn("执行阶段"),
|
||||||
"level": st.column_config.TextColumn("日志级别"),
|
"level": st.column_config.TextColumn("日志级别"),
|
||||||
"msg": st.column_config.TextColumn("日志消息", width="large")
|
"msg": st.column_config.TextColumn("日志消息", width="large")
|
||||||
},
|
}
|
||||||
use_container_width=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 下载功能
|
# 下载功能
|
||||||
@ -1218,6 +1223,12 @@ def render_log_viewer() -> None:
|
|||||||
df2 = pd.DataFrame(logs2, columns=["level", "count"])
|
df2 = pd.DataFrame(logs2, columns=["level", "count"])
|
||||||
df2["date"] = compare_date2.strftime("%Y-%m-%d")
|
df2["date"] = compare_date2.strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
# 确保所有列的数据类型一致,避免PyArrow序列化错误
|
||||||
|
for df in [df1, df2]:
|
||||||
|
for col in df.columns:
|
||||||
|
if col != "level": # level列保持字符串类型
|
||||||
|
df[col] = df[col].astype(object)
|
||||||
|
|
||||||
compare_df = pd.concat([df1, df2])
|
compare_df = pd.concat([df1, df2])
|
||||||
|
|
||||||
# 绘制对比图表
|
# 绘制对比图表
|
||||||
@ -1229,7 +1240,7 @@ def render_log_viewer() -> None:
|
|||||||
barmode="group",
|
barmode="group",
|
||||||
title=f"日志级别分布对比 ({compare_date1} vs {compare_date2})"
|
title=f"日志级别分布对比 ({compare_date1} vs {compare_date2})"
|
||||||
)
|
)
|
||||||
st.plotly_chart(fig, use_container_width=True)
|
st.plotly_chart(fig, width='stretch')
|
||||||
|
|
||||||
# 显示详细对比表格
|
# 显示详细对比表格
|
||||||
st.write("日志统计对比:")
|
st.write("日志统计对比:")
|
||||||
@ -1898,7 +1909,7 @@ def render_log_viewer() -> None:
|
|||||||
fig.update_layout(height=300, margin=dict(l=10, r=10, t=30, b=10))
|
fig.update_layout(height=300, margin=dict(l=10, r=10, t=30, b=10))
|
||||||
if use_log_y:
|
if use_log_y:
|
||||||
fig.update_yaxes(type="log")
|
fig.update_yaxes(type="log")
|
||||||
st.plotly_chart(fig, use_container_width=True)
|
st.plotly_chart(fig, width='stretch')
|
||||||
# ADD: export pivot
|
# ADD: export pivot
|
||||||
try:
|
try:
|
||||||
csv_buf = pivot.reset_index()
|
csv_buf = pivot.reset_index()
|
||||||
@ -2602,6 +2613,15 @@ def render_tests() -> None:
|
|||||||
})
|
})
|
||||||
df_reset["成交额(千元)"] = df_reset["成交额(千元)"] / 1000
|
df_reset["成交额(千元)"] = df_reset["成交额(千元)"] / 1000
|
||||||
|
|
||||||
|
# 确保所有列的数据类型正确,避免PyArrow序列化错误
|
||||||
|
numeric_columns = ["开盘价", "最高价", "最低价", "收盘价", "成交量(手)", "成交额(千元)"]
|
||||||
|
for col in numeric_columns:
|
||||||
|
if col in df_reset.columns:
|
||||||
|
df_reset[col] = pd.to_numeric(df_reset[col], errors='coerce')
|
||||||
|
|
||||||
|
# 确保日期列是datetime类型
|
||||||
|
df_reset["交易日"] = pd.to_datetime(df_reset["交易日"])
|
||||||
|
|
||||||
candle_fig = go.Figure(
|
candle_fig = go.Figure(
|
||||||
data=[
|
data=[
|
||||||
go.Candlestick(
|
go.Candlestick(
|
||||||
@ -2615,7 +2635,7 @@ def render_tests() -> None:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
candle_fig.update_layout(height=420, margin=dict(l=10, r=10, t=40, b=10))
|
candle_fig.update_layout(height=420, margin=dict(l=10, r=10, t=40, b=10))
|
||||||
st.plotly_chart(candle_fig, use_container_width=True)
|
st.plotly_chart(candle_fig, width='stretch')
|
||||||
|
|
||||||
vol_fig = px.bar(
|
vol_fig = px.bar(
|
||||||
df_reset,
|
df_reset,
|
||||||
@ -2625,7 +2645,7 @@ def render_tests() -> None:
|
|||||||
title="成交量",
|
title="成交量",
|
||||||
)
|
)
|
||||||
vol_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10))
|
vol_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10))
|
||||||
st.plotly_chart(vol_fig, use_container_width=True)
|
st.plotly_chart(vol_fig, width='stretch')
|
||||||
|
|
||||||
amt_fig = px.bar(
|
amt_fig = px.bar(
|
||||||
df_reset,
|
df_reset,
|
||||||
@ -2635,9 +2655,11 @@ def render_tests() -> None:
|
|||||||
title="成交额",
|
title="成交额",
|
||||||
)
|
)
|
||||||
amt_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10))
|
amt_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10))
|
||||||
st.plotly_chart(amt_fig, use_container_width=True)
|
st.plotly_chart(amt_fig, width='stretch')
|
||||||
|
|
||||||
df_reset["月份"] = df_reset["交易日"].dt.to_period("M").astype(str)
|
df_reset["月份"] = df_reset["交易日"].dt.to_period("M").astype(str)
|
||||||
|
# 确保收盘价列是数值类型
|
||||||
|
df_reset["收盘价"] = pd.to_numeric(df_reset["收盘价"], errors='coerce')
|
||||||
box_fig = px.box(
|
box_fig = px.box(
|
||||||
df_reset,
|
df_reset,
|
||||||
x="月份",
|
x="月份",
|
||||||
@ -2646,7 +2668,7 @@ def render_tests() -> None:
|
|||||||
title="月度收盘价分布",
|
title="月度收盘价分布",
|
||||||
)
|
)
|
||||||
box_fig.update_layout(height=320, margin=dict(l=10, r=10, t=40, b=10))
|
box_fig.update_layout(height=320, margin=dict(l=10, r=10, t=40, b=10))
|
||||||
st.plotly_chart(box_fig, use_container_width=True)
|
st.plotly_chart(box_fig, width='stretch')
|
||||||
|
|
||||||
st.caption("提示:成交量单位为手,成交额以千元显示。箱线图按月展示收盘价分布。")
|
st.caption("提示:成交量单位为手,成交额以千元显示。箱线图按月展示收盘价分布。")
|
||||||
st.dataframe(df_reset.tail(20), width='stretch')
|
st.dataframe(df_reset.tail(20), width='stretch')
|
||||||
|
|||||||
@ -3,16 +3,39 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import re
|
import re
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
import threading
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime, timedelta
|
from datetime import date, datetime, timedelta
|
||||||
from typing import Any, ClassVar, Dict, Iterable, List, Optional, Sequence, Tuple
|
from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Sequence, Set, Tuple
|
||||||
|
|
||||||
|
from .config import get_config
|
||||||
from .db import db_session
|
from .db import db_session
|
||||||
from .logging import get_logger
|
from .logging import get_logger
|
||||||
from app.core.indicators import momentum, normalize, rolling_mean, volatility
|
from app.core.indicators import momentum, normalize, rolling_mean, volatility
|
||||||
|
|
||||||
|
# 延迟导入,避免循环依赖
|
||||||
|
collect_data_coverage = None
|
||||||
|
ensure_data_coverage = None
|
||||||
|
initialize_database = None
|
||||||
|
|
||||||
|
# 在模块加载时尝试导入
|
||||||
|
if collect_data_coverage is None or ensure_data_coverage is None:
|
||||||
|
try:
|
||||||
|
from app.ingest.tushare import collect_data_coverage, ensure_data_coverage
|
||||||
|
except ImportError:
|
||||||
|
# 导入失败时,在实际使用时会报错
|
||||||
|
pass
|
||||||
|
|
||||||
|
if initialize_database is None:
|
||||||
|
try:
|
||||||
|
from app.data.schema import initialize_database
|
||||||
|
except ImportError:
|
||||||
|
# 导入失败时,提供一个空实现
|
||||||
|
def initialize_database():
|
||||||
|
pass
|
||||||
|
|
||||||
LOGGER = get_logger(__name__)
|
LOGGER = get_logger(__name__)
|
||||||
LOG_EXTRA = {"stage": "data_broker"}
|
LOG_EXTRA = {"stage": "data_broker"}
|
||||||
|
|
||||||
@ -66,7 +89,7 @@ def _end_of_day(dt: datetime) -> str:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DataBroker:
|
class DataBroker:
|
||||||
"""Lightweight data access helper for agent/LLM consumption."""
|
"""Lightweight data access helper with automated data fetching capabilities."""
|
||||||
|
|
||||||
FIELD_ALIASES: ClassVar[Dict[str, Dict[str, str]]] = {
|
FIELD_ALIASES: ClassVar[Dict[str, Dict[str, str]]] = {
|
||||||
"daily": {
|
"daily": {
|
||||||
@ -92,24 +115,50 @@ class DataBroker:
|
|||||||
}
|
}
|
||||||
MAX_WINDOW: ClassVar[int] = 120
|
MAX_WINDOW: ClassVar[int] = 120
|
||||||
BENCHMARK_INDEX: ClassVar[str] = "000300.SH"
|
BENCHMARK_INDEX: ClassVar[str] = "000300.SH"
|
||||||
|
# 自动补数配置
|
||||||
|
AUTO_REFRESH_WINDOW: ClassVar[int] = 7 # 自动补数的时间窗口
|
||||||
|
REFRESH_RETRY_INTERVAL: ClassVar[int] = 5 # 补数重试间隔(秒)
|
||||||
|
MAX_REFRESH_WAIT: ClassVar[int] = 60 # 最大等待补数完成时间(秒)
|
||||||
|
|
||||||
enable_cache: bool = True
|
enable_cache: bool = True
|
||||||
latest_cache_size: int = 256
|
latest_cache_size: int = 256
|
||||||
series_cache_size: int = 512
|
series_cache_size: int = 512
|
||||||
_latest_cache: OrderedDict = field(init=False, repr=False)
|
_latest_cache: OrderedDict = field(init=False, repr=False)
|
||||||
_series_cache: OrderedDict = field(init=False, repr=False)
|
_series_cache: OrderedDict = field(init=False, repr=False)
|
||||||
|
# 补数相关状态管理
|
||||||
|
_refresh_lock: threading.RLock = field(init=False, repr=False)
|
||||||
|
_refresh_in_progress: Dict[str, bool] = field(init=False, repr=False)
|
||||||
|
_refresh_callbacks: Dict[str, List[Callable]] = field(init=False, repr=False)
|
||||||
|
_coverage_cache: Dict[str, Dict] = field(init=False, repr=False)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
self._latest_cache = OrderedDict()
|
self._latest_cache = OrderedDict()
|
||||||
self._series_cache = OrderedDict()
|
self._series_cache = OrderedDict()
|
||||||
|
# 初始化补数相关状态
|
||||||
|
self._refresh_lock = threading.RLock()
|
||||||
|
self._refresh_in_progress = {}
|
||||||
|
self._refresh_callbacks = {}
|
||||||
|
self._coverage_cache = {}
|
||||||
|
if initialize_database is not None:
|
||||||
|
initialize_database() # 确保数据库已初始化
|
||||||
|
else:
|
||||||
|
LOGGER.warning("initialize_database 函数不可用,数据库可能未初始化", extra=LOG_EXTRA)
|
||||||
|
|
||||||
def fetch_latest(
|
def fetch_latest(
|
||||||
self,
|
self,
|
||||||
ts_code: str,
|
ts_code: str,
|
||||||
trade_date: str,
|
trade_date: str,
|
||||||
fields: Iterable[str],
|
fields: Iterable[str],
|
||||||
|
auto_refresh: bool = True,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Fetch the latest value (<= trade_date) for each requested field."""
|
"""Fetch the latest value (<= trade_date) for each requested field.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ts_code: 证券代码
|
||||||
|
trade_date: 交易日
|
||||||
|
fields: 要查询的字段列表
|
||||||
|
auto_refresh: 是否在数据不足时自动触发补数
|
||||||
|
"""
|
||||||
field_list = [str(item) for item in fields if item]
|
field_list = [str(item) for item in fields if item]
|
||||||
cache_key: Optional[Tuple[Any, ...]] = None
|
cache_key: Optional[Tuple[Any, ...]] = None
|
||||||
if self.enable_cache and field_list:
|
if self.enable_cache and field_list:
|
||||||
@ -118,6 +167,25 @@ class DataBroker:
|
|||||||
if cached is not None:
|
if cached is not None:
|
||||||
return deepcopy(cached)
|
return deepcopy(cached)
|
||||||
|
|
||||||
|
# 检查是否需要自动补数
|
||||||
|
if auto_refresh:
|
||||||
|
# 解析交易日以确定是否需要补数
|
||||||
|
parsed_date = _parse_trade_date(trade_date)
|
||||||
|
if parsed_date:
|
||||||
|
# 检查最近交易日的数据是否存在
|
||||||
|
recent_trade_date = parsed_date.strftime('%Y%m%d')
|
||||||
|
# 对涉及的表进行数据可用性检查
|
||||||
|
tables = set()
|
||||||
|
for field_name in field_list:
|
||||||
|
resolved = self.resolve_field(field_name)
|
||||||
|
if resolved:
|
||||||
|
table, _ = resolved
|
||||||
|
tables.add(table)
|
||||||
|
|
||||||
|
if tables and self.check_data_availability(recent_trade_date, tables):
|
||||||
|
# 数据不足,触发后台补数
|
||||||
|
self._trigger_background_refresh(recent_trade_date)
|
||||||
|
|
||||||
grouped: Dict[str, List[str]] = {}
|
grouped: Dict[str, List[str]] = {}
|
||||||
field_map: Dict[Tuple[str, str], List[str]] = {}
|
field_map: Dict[Tuple[str, str], List[str]] = {}
|
||||||
derived_cache: Dict[str, Any] = {}
|
derived_cache: Dict[str, Any] = {}
|
||||||
@ -209,8 +277,18 @@ class DataBroker:
|
|||||||
ts_code: str,
|
ts_code: str,
|
||||||
end_date: str,
|
end_date: str,
|
||||||
window: int,
|
window: int,
|
||||||
|
auto_refresh: bool = True,
|
||||||
) -> List[Tuple[str, float]]:
|
) -> List[Tuple[str, float]]:
|
||||||
"""Return descending time series tuples within the specified window."""
|
"""Return descending time series tuples within the specified window.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table: 表名
|
||||||
|
column: 列名
|
||||||
|
ts_code: 证券代码
|
||||||
|
end_date: 结束日期
|
||||||
|
window: 时间窗口大小
|
||||||
|
auto_refresh: 是否在数据不足时自动触发补数
|
||||||
|
"""
|
||||||
|
|
||||||
if window <= 0:
|
if window <= 0:
|
||||||
return []
|
return []
|
||||||
@ -226,6 +304,12 @@ class DataBroker:
|
|||||||
return []
|
return []
|
||||||
table, resolved = resolved_field
|
table, resolved = resolved_field
|
||||||
|
|
||||||
|
# 检查是否需要自动补数
|
||||||
|
if auto_refresh:
|
||||||
|
parsed_date = _parse_trade_date(end_date)
|
||||||
|
if parsed_date and self.check_data_availability(end_date, {table}):
|
||||||
|
self._trigger_background_refresh(end_date)
|
||||||
|
|
||||||
cache_key: Optional[Tuple[Any, ...]] = None
|
cache_key: Optional[Tuple[Any, ...]] = None
|
||||||
if self.enable_cache:
|
if self.enable_cache:
|
||||||
cache_key = (table, resolved, ts_code, end_date, window)
|
cache_key = (table, resolved, ts_code, end_date, window)
|
||||||
@ -335,6 +419,16 @@ class DataBroker:
|
|||||||
if window <= 0:
|
if window <= 0:
|
||||||
return []
|
return []
|
||||||
window = min(window, self.MAX_WINDOW)
|
window = min(window, self.MAX_WINDOW)
|
||||||
|
|
||||||
|
# 检查是否需要自动补数
|
||||||
|
if auto_refresh:
|
||||||
|
parsed_date = _parse_trade_date(trade_date)
|
||||||
|
if parsed_date and self.check_data_availability(trade_date, {table}):
|
||||||
|
self._trigger_background_refresh(trade_date)
|
||||||
|
# 短暂等待以获取最新数据
|
||||||
|
if hasattr(time, 'sleep'):
|
||||||
|
time.sleep(0.5)
|
||||||
|
|
||||||
columns = self._get_table_columns(table)
|
columns = self._get_table_columns(table)
|
||||||
if not columns:
|
if not columns:
|
||||||
LOGGER.debug("表不存在或无字段 table=%s", table, extra=LOG_EXTRA)
|
LOGGER.debug("表不存在或无字段 table=%s", table, extra=LOG_EXTRA)
|
||||||
@ -698,6 +792,335 @@ class DataBroker:
|
|||||||
while len(cache) > limit:
|
while len(cache) > limit:
|
||||||
cache.popitem(last=False)
|
cache.popitem(last=False)
|
||||||
|
|
||||||
|
def check_data_availability(
|
||||||
|
self,
|
||||||
|
trade_date: str,
|
||||||
|
tables: Set[str] = None,
|
||||||
|
threshold: float = 0.8,
|
||||||
|
) -> bool:
|
||||||
|
"""检查指定交易日的数据是否可用,如不可用则返回True(需要补数)。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trade_date: 要检查的交易日
|
||||||
|
tables: 要检查的表集合,默认检查主要行情表
|
||||||
|
threshold: 数据覆盖率阈值,低于此值需要补数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True表示数据不足,需要补数
|
||||||
|
"""
|
||||||
|
# 如果配置了强制刷新,则始终返回需要补数
|
||||||
|
if get_config().force_refresh:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 如果未启用自动更新,则不进行补数
|
||||||
|
if not get_config().auto_update_data:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 默认检查的表
|
||||||
|
if tables is None:
|
||||||
|
tables = {"daily", "daily_basic", "stock_basic", "trade_cal"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 解析交易日
|
||||||
|
parsed_date = _parse_trade_date(trade_date)
|
||||||
|
if not parsed_date:
|
||||||
|
LOGGER.debug("无法解析交易日: %s", trade_date, extra=LOG_EXTRA)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 计算检查窗口
|
||||||
|
end_date = parsed_date.strftime('%Y%m%d')
|
||||||
|
start_date = (parsed_date - timedelta(days=self.AUTO_REFRESH_WINDOW)).strftime('%Y%m%d')
|
||||||
|
|
||||||
|
# 构建缓存键
|
||||||
|
cache_key = f"{start_date}_{end_date}_{'_'.join(sorted(tables))}"
|
||||||
|
|
||||||
|
# 检查缓存
|
||||||
|
if cache_key in self._coverage_cache:
|
||||||
|
coverage = self._coverage_cache[cache_key]
|
||||||
|
current_time = time.time() if hasattr(time, 'time') else 0
|
||||||
|
if coverage.get('timestamp', 0) > current_time - 300: # 5分钟内有效
|
||||||
|
# 检查是否需要补数
|
||||||
|
for table in tables:
|
||||||
|
table_coverage = coverage.get(table, {})
|
||||||
|
if table_coverage.get('coverage', 0) < threshold:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 收集数据覆盖情况
|
||||||
|
if collect_data_coverage is None:
|
||||||
|
LOGGER.error("collect_data_coverage 函数不可用,请检查导入配置", extra=LOG_EXTRA)
|
||||||
|
return False
|
||||||
|
|
||||||
|
coverage = collect_data_coverage(
|
||||||
|
date.fromisoformat(start_date[:4] + '-' + start_date[4:6] + '-' + start_date[6:8]),
|
||||||
|
date.fromisoformat(end_date[:4] + '-' + end_date[4:6] + '-' + end_date[6:8])
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保存到缓存
|
||||||
|
coverage['timestamp'] = time.time() if hasattr(time, 'time') else 0
|
||||||
|
self._coverage_cache[cache_key] = coverage
|
||||||
|
|
||||||
|
# 检查是否需要补数
|
||||||
|
for table in tables:
|
||||||
|
table_coverage = coverage.get(table, {})
|
||||||
|
if table_coverage.get('coverage', 0) < threshold:
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
LOGGER.exception("检查数据可用性失败: %s", exc, extra=LOG_EXTRA)
|
||||||
|
# 出错时保守处理,不触发补数
|
||||||
|
return False
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _trigger_background_refresh(self, target_date: str) -> None:
|
||||||
|
"""在后台线程触发数据补数。"""
|
||||||
|
parsed_date = _parse_trade_date(target_date)
|
||||||
|
if not parsed_date:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 构建补数日期范围
|
||||||
|
end_date = parsed_date.date()
|
||||||
|
start_date = end_date - timedelta(days=self.AUTO_REFRESH_WINDOW)
|
||||||
|
refresh_key = f"{start_date}_{end_date}"
|
||||||
|
|
||||||
|
# 检查是否已经在补数中
|
||||||
|
with self._refresh_lock:
|
||||||
|
if self._refresh_in_progress.get(refresh_key, False):
|
||||||
|
LOGGER.debug("数据补数已经在进行中: %s", refresh_key, extra=LOG_EXTRA)
|
||||||
|
return
|
||||||
|
|
||||||
|
self._refresh_in_progress[refresh_key] = True
|
||||||
|
self._refresh_callbacks.setdefault(refresh_key, [])
|
||||||
|
|
||||||
|
def refresh_task():
|
||||||
|
try:
|
||||||
|
LOGGER.info("开始后台数据补数: %s 至 %s", start_date, end_date, extra=LOG_EXTRA)
|
||||||
|
|
||||||
|
# 执行补数
|
||||||
|
if ensure_data_coverage is None:
|
||||||
|
LOGGER.error("ensure_data_coverage 函数不可用,请检查导入配置", extra=LOG_EXTRA)
|
||||||
|
with self._refresh_lock:
|
||||||
|
self._refresh_in_progress[refresh_key] = False
|
||||||
|
return
|
||||||
|
|
||||||
|
ensure_data_coverage(
|
||||||
|
start_date,
|
||||||
|
end_date,
|
||||||
|
force=False,
|
||||||
|
progress_hook=None
|
||||||
|
)
|
||||||
|
|
||||||
|
LOGGER.info("后台数据补数完成: %s 至 %s", start_date, end_date, extra=LOG_EXTRA)
|
||||||
|
|
||||||
|
# 清除缓存,强制重新加载数据
|
||||||
|
self._latest_cache.clear()
|
||||||
|
self._series_cache.clear()
|
||||||
|
self._coverage_cache.clear()
|
||||||
|
|
||||||
|
# 执行回调
|
||||||
|
with self._refresh_lock:
|
||||||
|
callbacks = self._refresh_callbacks.pop(refresh_key, [])
|
||||||
|
self._refresh_in_progress[refresh_key] = False
|
||||||
|
|
||||||
|
for callback in callbacks:
|
||||||
|
try:
|
||||||
|
callback()
|
||||||
|
except Exception as exc:
|
||||||
|
LOGGER.exception("补数回调执行失败: %s", exc, extra=LOG_EXTRA)
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
LOGGER.exception("后台数据补数失败: %s", exc, extra=LOG_EXTRA)
|
||||||
|
with self._refresh_lock:
|
||||||
|
self._refresh_in_progress[refresh_key] = False
|
||||||
|
|
||||||
|
# 启动后台线程
|
||||||
|
thread = threading.Thread(target=refresh_task, daemon=True)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
def is_refreshing(self, start_date: str = None, end_date: str = None) -> bool:
|
||||||
|
"""检查指定日期范围是否正在补数中。"""
|
||||||
|
with self._refresh_lock:
|
||||||
|
if not start_date and not end_date:
|
||||||
|
# 检查是否有任何补数正在进行
|
||||||
|
return any(self._refresh_in_progress.values())
|
||||||
|
|
||||||
|
# 检查指定日期范围
|
||||||
|
for key, in_progress in self._refresh_in_progress.items():
|
||||||
|
if in_progress and key.startswith(start_date or '') and key.endswith(end_date or ''):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def wait_for_refresh_complete(
|
||||||
|
self,
|
||||||
|
timeout: float = None,
|
||||||
|
start_date: str = None,
|
||||||
|
end_date: str = None
|
||||||
|
) -> bool:
|
||||||
|
"""等待数据补数完成。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeout: 超时时间(秒),默认为MAX_REFRESH_WAIT
|
||||||
|
start_date: 开始日期
|
||||||
|
end_date: 结束日期
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True表示补数已完成,False表示超时
|
||||||
|
"""
|
||||||
|
if timeout is None:
|
||||||
|
timeout = self.MAX_REFRESH_WAIT
|
||||||
|
|
||||||
|
start_time = time.time() if hasattr(time, 'time') else 0
|
||||||
|
current_time_func = time.time if hasattr(time, 'time') else lambda: 0
|
||||||
|
while current_time_func() - start_time < timeout:
|
||||||
|
if not self.is_refreshing(start_date, end_date):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 短暂休眠后再次检查
|
||||||
|
if hasattr(time, 'sleep'):
|
||||||
|
time.sleep(min(self.REFRESH_RETRY_INTERVAL, timeout / 10))
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def on_data_refresh(
|
||||||
|
self,
|
||||||
|
callback: Callable,
|
||||||
|
start_date: str = None,
|
||||||
|
end_date: str = None
|
||||||
|
) -> None:
|
||||||
|
"""注册数据补数完成的回调函数。"""
|
||||||
|
if start_date and end_date:
|
||||||
|
refresh_key = f"{start_date}_{end_date}"
|
||||||
|
with self._refresh_lock:
|
||||||
|
self._refresh_callbacks.setdefault(refresh_key, []).append(callback)
|
||||||
|
# 如果当前没有补数在进行,则直接调用回调
|
||||||
|
if not self._refresh_in_progress.get(refresh_key, False):
|
||||||
|
try:
|
||||||
|
callback()
|
||||||
|
except Exception as exc:
|
||||||
|
LOGGER.exception("补数回调执行失败: %s", exc, extra=LOG_EXTRA)
|
||||||
|
|
||||||
|
def set_auto_refresh_window(self, days: int) -> None:
|
||||||
|
"""设置自动补数的时间窗口。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
days: 自动补数的天数窗口
|
||||||
|
"""
|
||||||
|
if days > 0:
|
||||||
|
self.AUTO_REFRESH_WINDOW = days
|
||||||
|
LOGGER.info("自动补数窗口已设置为 %d 天", days, extra=LOG_EXTRA)
|
||||||
|
|
||||||
|
def set_refresh_retry_interval(self, seconds: int) -> None:
|
||||||
|
"""设置补数检查的重试间隔。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seconds: 重试间隔(秒)
|
||||||
|
"""
|
||||||
|
if seconds > 0:
|
||||||
|
self.REFRESH_RETRY_INTERVAL = seconds
|
||||||
|
LOGGER.info("补数重试间隔已设置为 %d 秒", seconds, extra=LOG_EXTRA)
|
||||||
|
|
||||||
|
def set_max_refresh_wait(self, seconds: int) -> None:
|
||||||
|
"""设置最大等待补数完成时间。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seconds: 最大等待时间(秒)
|
||||||
|
"""
|
||||||
|
if seconds > 0:
|
||||||
|
self.MAX_REFRESH_WAIT = seconds
|
||||||
|
LOGGER.info("最大补数等待时间已设置为 %d 秒", seconds, extra=LOG_EXTRA)
|
||||||
|
|
||||||
|
def force_refresh_data(self, start_date: str, end_date: str) -> bool:
|
||||||
|
"""强制刷新指定日期范围内的数据。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start_date: 开始日期(格式:YYYYMMDD)
|
||||||
|
end_date: 结束日期(格式:YYYYMMDD)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否成功触发刷新
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 解析日期
|
||||||
|
start = _parse_trade_date(start_date)
|
||||||
|
end = _parse_trade_date(end_date)
|
||||||
|
if not start or not end:
|
||||||
|
LOGGER.error("日期格式不正确: %s, %s", start_date, end_date, extra=LOG_EXTRA)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 触发刷新
|
||||||
|
self._trigger_background_refresh(end_date)
|
||||||
|
return True
|
||||||
|
except Exception as exc:
|
||||||
|
LOGGER.exception("强制刷新数据失败: %s", exc, extra=LOG_EXTRA)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_refresh_status(self) -> Dict[str, Dict[str, Any]]:
|
||||||
|
"""获取当前所有补数任务的状态。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: 包含所有补数任务状态的字典
|
||||||
|
"""
|
||||||
|
with self._refresh_lock:
|
||||||
|
status = {}
|
||||||
|
for key, in_progress in self._refresh_in_progress.items():
|
||||||
|
start, end = key.split('_')[:2] if '_' in key else (key, key)
|
||||||
|
status[key] = {
|
||||||
|
'start_date': start,
|
||||||
|
'end_date': end,
|
||||||
|
'in_progress': in_progress,
|
||||||
|
'callback_count': len(self._refresh_callbacks.get(key, []))
|
||||||
|
}
|
||||||
|
return status
|
||||||
|
|
||||||
|
def cancel_all_refresh_tasks(self) -> None:
|
||||||
|
"""取消所有正在等待的补数任务回调。
|
||||||
|
注意:已经开始执行的补数任务无法取消,但它们的结果将被忽略。
|
||||||
|
"""
|
||||||
|
with self._refresh_lock:
|
||||||
|
self._refresh_callbacks.clear()
|
||||||
|
# 保留刷新状态以避免立即重新触发
|
||||||
|
LOGGER.info("所有补数任务回调已取消", extra=LOG_EXTRA)
|
||||||
|
|
||||||
|
def clear_coverage_cache(self) -> None:
|
||||||
|
"""清除数据覆盖情况的缓存。"""
|
||||||
|
self._coverage_cache.clear()
|
||||||
|
LOGGER.info("数据覆盖缓存已清除", extra=LOG_EXTRA)
|
||||||
|
|
||||||
|
def get_data_coverage(self, start_date: str, end_date: str) -> Dict:
|
||||||
|
"""获取指定日期范围内的数据覆盖情况。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start_date: 开始日期(格式:YYYYMMDD)
|
||||||
|
end_date: 结束日期(格式:YYYYMMDD)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: 数据覆盖情况的详细信息
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 解析日期
|
||||||
|
start = _parse_trade_date(start_date)
|
||||||
|
end = _parse_trade_date(end_date)
|
||||||
|
if not start or not end:
|
||||||
|
LOGGER.error("日期格式不正确: %s, %s", start_date, end_date, extra=LOG_EXTRA)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# 转换日期格式
|
||||||
|
start_d = date.fromisoformat(start.strftime('%Y-%m-%d'))
|
||||||
|
end_d = date.fromisoformat(end.strftime('%Y-%m-%d'))
|
||||||
|
|
||||||
|
# 收集数据覆盖情况
|
||||||
|
if collect_data_coverage is None:
|
||||||
|
LOGGER.error("collect_data_coverage 函数不可用,请检查导入配置", extra=LOG_EXTRA)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
coverage = collect_data_coverage(start_d, end_d)
|
||||||
|
return coverage
|
||||||
|
except Exception as exc:
|
||||||
|
LOGGER.exception("获取数据覆盖情况失败: %s", exc, extra=LOG_EXTRA)
|
||||||
|
return {}
|
||||||
|
|
||||||
def _resolve_column(self, table: str, column: str) -> Optional[str]:
|
def _resolve_column(self, table: str, column: str) -> Optional[str]:
|
||||||
columns = self._get_table_columns(table)
|
columns = self._get_table_columns(table)
|
||||||
if columns is None:
|
if columns is None:
|
||||||
@ -712,3 +1135,17 @@ class DataBroker:
|
|||||||
if name.lower() == lowered:
|
if name.lower() == lowered:
|
||||||
return name
|
return name
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# 确保time模块可用
|
||||||
|
import sys
|
||||||
|
try:
|
||||||
|
import time
|
||||||
|
except ImportError:
|
||||||
|
# 创建一个简单的替代实现
|
||||||
|
class TimeStub:
|
||||||
|
def time(self):
|
||||||
|
return 0
|
||||||
|
def sleep(self, seconds):
|
||||||
|
pass
|
||||||
|
time = TimeStub()
|
||||||
|
LOGGER.warning("无法导入time模块,使用替代实现", extra=LOG_EXTRA)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user