From cfff1aefe576f4fbc121f8f8fca32dcf45dfdecb Mon Sep 17 00:00:00 2001 From: sam Date: Sun, 5 Oct 2025 08:39:40 +0800 Subject: [PATCH] update --- app/ui/streamlit_app.py | 42 +++- app/utils/data_access.py | 447 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 474 insertions(+), 15 deletions(-) diff --git a/app/ui/streamlit_app.py b/app/ui/streamlit_app.py index f101931..34fc9bb 100644 --- a/app/ui/streamlit_app.py +++ b/app/ui/streamlit_app.py @@ -452,7 +452,7 @@ def render_today_plan() -> None: batch_symbols = st.multiselect("批量重评估(可多选)", symbols, default=[]) # 一键重评估所有标的按钮 - if st.button("一键重评估所有标的", type="primary", use_container_width=True): + if st.button("一键重评估所有标的", type="primary", width='stretch'): with st.spinner("正在对所有标的进行重评估,请稍候..."): try: # 解析交易日 @@ -831,6 +831,9 @@ def render_today_plan() -> None: # 显示新闻表格 news_df = pd.DataFrame(news_data) + # 确保所有列都是字符串类型,避免PyArrow序列化错误 + for col in news_df.columns: + news_df[col] = news_df[col].astype(str) st.dataframe(news_df, width='stretch', hide_index=True) # 添加新闻详情展开视图 @@ -1139,8 +1142,11 @@ def render_log_viewer() -> None: # 将sqlite3.Row对象转换为字典列表 rows_dict = [{key: row[key] for key in row.keys()} for row in rows] log_df = pd.DataFrame(rows_dict) - # 格式化时间戳 + # 格式化时间戳并确保数据类型一致 log_df["ts"] = pd.to_datetime(log_df["ts"]).dt.strftime("%Y-%m-%d %H:%M:%S") + # 确保所有列都是字符串类型,避免PyArrow序列化错误 + for col in log_df.columns: + log_df[col] = log_df[col].astype(str) else: log_df = pd.DataFrame(columns=["ts", "stage", "level", "msg"]) @@ -1154,8 +1160,7 @@ def render_log_viewer() -> None: "stage": st.column_config.TextColumn("执行阶段"), "level": st.column_config.TextColumn("日志级别"), "msg": st.column_config.TextColumn("日志消息", width="large") - }, - use_container_width=True + } ) # 下载功能 @@ -1218,6 +1223,12 @@ def render_log_viewer() -> None: df2 = pd.DataFrame(logs2, columns=["level", "count"]) df2["date"] = compare_date2.strftime("%Y-%m-%d") + # 确保所有列的数据类型一致,避免PyArrow序列化错误 + for df in [df1, df2]: + for col in df.columns: + if col != "level": # level列保持字符串类型 + df[col] = df[col].astype(object) + compare_df = pd.concat([df1, df2]) # 绘制对比图表 @@ -1229,7 +1240,7 @@ def render_log_viewer() -> None: barmode="group", title=f"日志级别分布对比 ({compare_date1} vs {compare_date2})" ) - st.plotly_chart(fig, use_container_width=True) + st.plotly_chart(fig, width='stretch') # 显示详细对比表格 st.write("日志统计对比:") @@ -1898,7 +1909,7 @@ def render_log_viewer() -> None: fig.update_layout(height=300, margin=dict(l=10, r=10, t=30, b=10)) if use_log_y: fig.update_yaxes(type="log") - st.plotly_chart(fig, use_container_width=True) + st.plotly_chart(fig, width='stretch') # ADD: export pivot try: csv_buf = pivot.reset_index() @@ -2601,6 +2612,15 @@ def render_tests() -> None: "amount": "成交额(千元)", }) df_reset["成交额(千元)"] = df_reset["成交额(千元)"] / 1000 + + # 确保所有列的数据类型正确,避免PyArrow序列化错误 + numeric_columns = ["开盘价", "最高价", "最低价", "收盘价", "成交量(手)", "成交额(千元)"] + for col in numeric_columns: + if col in df_reset.columns: + df_reset[col] = pd.to_numeric(df_reset[col], errors='coerce') + + # 确保日期列是datetime类型 + df_reset["交易日"] = pd.to_datetime(df_reset["交易日"]) candle_fig = go.Figure( data=[ @@ -2615,7 +2635,7 @@ def render_tests() -> None: ] ) candle_fig.update_layout(height=420, margin=dict(l=10, r=10, t=40, b=10)) - st.plotly_chart(candle_fig, use_container_width=True) + st.plotly_chart(candle_fig, width='stretch') vol_fig = px.bar( df_reset, @@ -2625,7 +2645,7 @@ def render_tests() -> None: title="成交量", ) vol_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10)) - st.plotly_chart(vol_fig, use_container_width=True) + st.plotly_chart(vol_fig, width='stretch') amt_fig = px.bar( df_reset, @@ -2635,9 +2655,11 @@ def render_tests() -> None: title="成交额", ) amt_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10)) - st.plotly_chart(amt_fig, use_container_width=True) + st.plotly_chart(amt_fig, width='stretch') df_reset["月份"] = df_reset["交易日"].dt.to_period("M").astype(str) + # 确保收盘价列是数值类型 + df_reset["收盘价"] = pd.to_numeric(df_reset["收盘价"], errors='coerce') box_fig = px.box( df_reset, x="月份", @@ -2646,7 +2668,7 @@ def render_tests() -> None: title="月度收盘价分布", ) box_fig.update_layout(height=320, margin=dict(l=10, r=10, t=40, b=10)) - st.plotly_chart(box_fig, use_container_width=True) + st.plotly_chart(box_fig, width='stretch') st.caption("提示:成交量单位为手,成交额以千元显示。箱线图按月展示收盘价分布。") st.dataframe(df_reset.tail(20), width='stretch') diff --git a/app/utils/data_access.py b/app/utils/data_access.py index f52546f..cc4302d 100644 --- a/app/utils/data_access.py +++ b/app/utils/data_access.py @@ -3,16 +3,39 @@ from __future__ import annotations import re import sqlite3 +import threading from collections import OrderedDict from copy import deepcopy from dataclasses import dataclass, field -from datetime import datetime, timedelta -from typing import Any, ClassVar, Dict, Iterable, List, Optional, Sequence, Tuple +from datetime import date, datetime, timedelta +from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Sequence, Set, Tuple +from .config import get_config from .db import db_session from .logging import get_logger from app.core.indicators import momentum, normalize, rolling_mean, volatility +# 延迟导入,避免循环依赖 +collect_data_coverage = None +ensure_data_coverage = None +initialize_database = None + +# 在模块加载时尝试导入 +if collect_data_coverage is None or ensure_data_coverage is None: + try: + from app.ingest.tushare import collect_data_coverage, ensure_data_coverage + except ImportError: + # 导入失败时,在实际使用时会报错 + pass + +if initialize_database is None: + try: + from app.data.schema import initialize_database + except ImportError: + # 导入失败时,提供一个空实现 + def initialize_database(): + pass + LOGGER = get_logger(__name__) LOG_EXTRA = {"stage": "data_broker"} @@ -66,7 +89,7 @@ def _end_of_day(dt: datetime) -> str: @dataclass class DataBroker: - """Lightweight data access helper for agent/LLM consumption.""" + """Lightweight data access helper with automated data fetching capabilities.""" FIELD_ALIASES: ClassVar[Dict[str, Dict[str, str]]] = { "daily": { @@ -92,24 +115,50 @@ class DataBroker: } MAX_WINDOW: ClassVar[int] = 120 BENCHMARK_INDEX: ClassVar[str] = "000300.SH" + # 自动补数配置 + AUTO_REFRESH_WINDOW: ClassVar[int] = 7 # 自动补数的时间窗口 + REFRESH_RETRY_INTERVAL: ClassVar[int] = 5 # 补数重试间隔(秒) + MAX_REFRESH_WAIT: ClassVar[int] = 60 # 最大等待补数完成时间(秒) enable_cache: bool = True latest_cache_size: int = 256 series_cache_size: int = 512 _latest_cache: OrderedDict = field(init=False, repr=False) _series_cache: OrderedDict = field(init=False, repr=False) + # 补数相关状态管理 + _refresh_lock: threading.RLock = field(init=False, repr=False) + _refresh_in_progress: Dict[str, bool] = field(init=False, repr=False) + _refresh_callbacks: Dict[str, List[Callable]] = field(init=False, repr=False) + _coverage_cache: Dict[str, Dict] = field(init=False, repr=False) def __post_init__(self) -> None: self._latest_cache = OrderedDict() self._series_cache = OrderedDict() + # 初始化补数相关状态 + self._refresh_lock = threading.RLock() + self._refresh_in_progress = {} + self._refresh_callbacks = {} + self._coverage_cache = {} + if initialize_database is not None: + initialize_database() # 确保数据库已初始化 + else: + LOGGER.warning("initialize_database 函数不可用,数据库可能未初始化", extra=LOG_EXTRA) def fetch_latest( self, ts_code: str, trade_date: str, fields: Iterable[str], + auto_refresh: bool = True, ) -> Dict[str, Any]: - """Fetch the latest value (<= trade_date) for each requested field.""" + """Fetch the latest value (<= trade_date) for each requested field. + + Args: + ts_code: 证券代码 + trade_date: 交易日 + fields: 要查询的字段列表 + auto_refresh: 是否在数据不足时自动触发补数 + """ field_list = [str(item) for item in fields if item] cache_key: Optional[Tuple[Any, ...]] = None if self.enable_cache and field_list: @@ -118,6 +167,25 @@ class DataBroker: if cached is not None: return deepcopy(cached) + # 检查是否需要自动补数 + if auto_refresh: + # 解析交易日以确定是否需要补数 + parsed_date = _parse_trade_date(trade_date) + if parsed_date: + # 检查最近交易日的数据是否存在 + recent_trade_date = parsed_date.strftime('%Y%m%d') + # 对涉及的表进行数据可用性检查 + tables = set() + for field_name in field_list: + resolved = self.resolve_field(field_name) + if resolved: + table, _ = resolved + tables.add(table) + + if tables and self.check_data_availability(recent_trade_date, tables): + # 数据不足,触发后台补数 + self._trigger_background_refresh(recent_trade_date) + grouped: Dict[str, List[str]] = {} field_map: Dict[Tuple[str, str], List[str]] = {} derived_cache: Dict[str, Any] = {} @@ -209,8 +277,18 @@ class DataBroker: ts_code: str, end_date: str, window: int, + auto_refresh: bool = True, ) -> List[Tuple[str, float]]: - """Return descending time series tuples within the specified window.""" + """Return descending time series tuples within the specified window. + + Args: + table: 表名 + column: 列名 + ts_code: 证券代码 + end_date: 结束日期 + window: 时间窗口大小 + auto_refresh: 是否在数据不足时自动触发补数 + """ if window <= 0: return [] @@ -226,6 +304,12 @@ class DataBroker: return [] table, resolved = resolved_field + # 检查是否需要自动补数 + if auto_refresh: + parsed_date = _parse_trade_date(end_date) + if parsed_date and self.check_data_availability(end_date, {table}): + self._trigger_background_refresh(end_date) + cache_key: Optional[Tuple[Any, ...]] = None if self.enable_cache: cache_key = (table, resolved, ts_code, end_date, window) @@ -335,6 +419,16 @@ class DataBroker: if window <= 0: return [] window = min(window, self.MAX_WINDOW) + + # 检查是否需要自动补数 + if auto_refresh: + parsed_date = _parse_trade_date(trade_date) + if parsed_date and self.check_data_availability(trade_date, {table}): + self._trigger_background_refresh(trade_date) + # 短暂等待以获取最新数据 + if hasattr(time, 'sleep'): + time.sleep(0.5) + columns = self._get_table_columns(table) if not columns: LOGGER.debug("表不存在或无字段 table=%s", table, extra=LOG_EXTRA) @@ -697,6 +791,335 @@ class DataBroker: cache.move_to_end(key) while len(cache) > limit: cache.popitem(last=False) + + def check_data_availability( + self, + trade_date: str, + tables: Set[str] = None, + threshold: float = 0.8, + ) -> bool: + """检查指定交易日的数据是否可用,如不可用则返回True(需要补数)。 + + Args: + trade_date: 要检查的交易日 + tables: 要检查的表集合,默认检查主要行情表 + threshold: 数据覆盖率阈值,低于此值需要补数 + + Returns: + bool: True表示数据不足,需要补数 + """ + # 如果配置了强制刷新,则始终返回需要补数 + if get_config().force_refresh: + return True + + # 如果未启用自动更新,则不进行补数 + if not get_config().auto_update_data: + return False + + # 默认检查的表 + if tables is None: + tables = {"daily", "daily_basic", "stock_basic", "trade_cal"} + + try: + # 解析交易日 + parsed_date = _parse_trade_date(trade_date) + if not parsed_date: + LOGGER.debug("无法解析交易日: %s", trade_date, extra=LOG_EXTRA) + return False + + # 计算检查窗口 + end_date = parsed_date.strftime('%Y%m%d') + start_date = (parsed_date - timedelta(days=self.AUTO_REFRESH_WINDOW)).strftime('%Y%m%d') + + # 构建缓存键 + cache_key = f"{start_date}_{end_date}_{'_'.join(sorted(tables))}" + + # 检查缓存 + if cache_key in self._coverage_cache: + coverage = self._coverage_cache[cache_key] + current_time = time.time() if hasattr(time, 'time') else 0 + if coverage.get('timestamp', 0) > current_time - 300: # 5分钟内有效 + # 检查是否需要补数 + for table in tables: + table_coverage = coverage.get(table, {}) + if table_coverage.get('coverage', 0) < threshold: + return True + return False + + # 收集数据覆盖情况 + if collect_data_coverage is None: + LOGGER.error("collect_data_coverage 函数不可用,请检查导入配置", extra=LOG_EXTRA) + return False + + coverage = collect_data_coverage( + date.fromisoformat(start_date[:4] + '-' + start_date[4:6] + '-' + start_date[6:8]), + date.fromisoformat(end_date[:4] + '-' + end_date[4:6] + '-' + end_date[6:8]) + ) + + # 保存到缓存 + coverage['timestamp'] = time.time() if hasattr(time, 'time') else 0 + self._coverage_cache[cache_key] = coverage + + # 检查是否需要补数 + for table in tables: + table_coverage = coverage.get(table, {}) + if table_coverage.get('coverage', 0) < threshold: + return True + + except Exception as exc: + LOGGER.exception("检查数据可用性失败: %s", exc, extra=LOG_EXTRA) + # 出错时保守处理,不触发补数 + return False + + return False + + def _trigger_background_refresh(self, target_date: str) -> None: + """在后台线程触发数据补数。""" + parsed_date = _parse_trade_date(target_date) + if not parsed_date: + return + + # 构建补数日期范围 + end_date = parsed_date.date() + start_date = end_date - timedelta(days=self.AUTO_REFRESH_WINDOW) + refresh_key = f"{start_date}_{end_date}" + + # 检查是否已经在补数中 + with self._refresh_lock: + if self._refresh_in_progress.get(refresh_key, False): + LOGGER.debug("数据补数已经在进行中: %s", refresh_key, extra=LOG_EXTRA) + return + + self._refresh_in_progress[refresh_key] = True + self._refresh_callbacks.setdefault(refresh_key, []) + + def refresh_task(): + try: + LOGGER.info("开始后台数据补数: %s 至 %s", start_date, end_date, extra=LOG_EXTRA) + + # 执行补数 + if ensure_data_coverage is None: + LOGGER.error("ensure_data_coverage 函数不可用,请检查导入配置", extra=LOG_EXTRA) + with self._refresh_lock: + self._refresh_in_progress[refresh_key] = False + return + + ensure_data_coverage( + start_date, + end_date, + force=False, + progress_hook=None + ) + + LOGGER.info("后台数据补数完成: %s 至 %s", start_date, end_date, extra=LOG_EXTRA) + + # 清除缓存,强制重新加载数据 + self._latest_cache.clear() + self._series_cache.clear() + self._coverage_cache.clear() + + # 执行回调 + with self._refresh_lock: + callbacks = self._refresh_callbacks.pop(refresh_key, []) + self._refresh_in_progress[refresh_key] = False + + for callback in callbacks: + try: + callback() + except Exception as exc: + LOGGER.exception("补数回调执行失败: %s", exc, extra=LOG_EXTRA) + + except Exception as exc: + LOGGER.exception("后台数据补数失败: %s", exc, extra=LOG_EXTRA) + with self._refresh_lock: + self._refresh_in_progress[refresh_key] = False + + # 启动后台线程 + thread = threading.Thread(target=refresh_task, daemon=True) + thread.start() + + def is_refreshing(self, start_date: str = None, end_date: str = None) -> bool: + """检查指定日期范围是否正在补数中。""" + with self._refresh_lock: + if not start_date and not end_date: + # 检查是否有任何补数正在进行 + return any(self._refresh_in_progress.values()) + + # 检查指定日期范围 + for key, in_progress in self._refresh_in_progress.items(): + if in_progress and key.startswith(start_date or '') and key.endswith(end_date or ''): + return True + + return False + + def wait_for_refresh_complete( + self, + timeout: float = None, + start_date: str = None, + end_date: str = None + ) -> bool: + """等待数据补数完成。 + + Args: + timeout: 超时时间(秒),默认为MAX_REFRESH_WAIT + start_date: 开始日期 + end_date: 结束日期 + + Returns: + bool: True表示补数已完成,False表示超时 + """ + if timeout is None: + timeout = self.MAX_REFRESH_WAIT + + start_time = time.time() if hasattr(time, 'time') else 0 + current_time_func = time.time if hasattr(time, 'time') else lambda: 0 + while current_time_func() - start_time < timeout: + if not self.is_refreshing(start_date, end_date): + return True + + # 短暂休眠后再次检查 + if hasattr(time, 'sleep'): + time.sleep(min(self.REFRESH_RETRY_INTERVAL, timeout / 10)) + + return False + + def on_data_refresh( + self, + callback: Callable, + start_date: str = None, + end_date: str = None + ) -> None: + """注册数据补数完成的回调函数。""" + if start_date and end_date: + refresh_key = f"{start_date}_{end_date}" + with self._refresh_lock: + self._refresh_callbacks.setdefault(refresh_key, []).append(callback) + # 如果当前没有补数在进行,则直接调用回调 + if not self._refresh_in_progress.get(refresh_key, False): + try: + callback() + except Exception as exc: + LOGGER.exception("补数回调执行失败: %s", exc, extra=LOG_EXTRA) + + def set_auto_refresh_window(self, days: int) -> None: + """设置自动补数的时间窗口。 + + Args: + days: 自动补数的天数窗口 + """ + if days > 0: + self.AUTO_REFRESH_WINDOW = days + LOGGER.info("自动补数窗口已设置为 %d 天", days, extra=LOG_EXTRA) + + def set_refresh_retry_interval(self, seconds: int) -> None: + """设置补数检查的重试间隔。 + + Args: + seconds: 重试间隔(秒) + """ + if seconds > 0: + self.REFRESH_RETRY_INTERVAL = seconds + LOGGER.info("补数重试间隔已设置为 %d 秒", seconds, extra=LOG_EXTRA) + + def set_max_refresh_wait(self, seconds: int) -> None: + """设置最大等待补数完成时间。 + + Args: + seconds: 最大等待时间(秒) + """ + if seconds > 0: + self.MAX_REFRESH_WAIT = seconds + LOGGER.info("最大补数等待时间已设置为 %d 秒", seconds, extra=LOG_EXTRA) + + def force_refresh_data(self, start_date: str, end_date: str) -> bool: + """强制刷新指定日期范围内的数据。 + + Args: + start_date: 开始日期(格式:YYYYMMDD) + end_date: 结束日期(格式:YYYYMMDD) + + Returns: + bool: 是否成功触发刷新 + """ + try: + # 解析日期 + start = _parse_trade_date(start_date) + end = _parse_trade_date(end_date) + if not start or not end: + LOGGER.error("日期格式不正确: %s, %s", start_date, end_date, extra=LOG_EXTRA) + return False + + # 触发刷新 + self._trigger_background_refresh(end_date) + return True + except Exception as exc: + LOGGER.exception("强制刷新数据失败: %s", exc, extra=LOG_EXTRA) + return False + + def get_refresh_status(self) -> Dict[str, Dict[str, Any]]: + """获取当前所有补数任务的状态。 + + Returns: + Dict: 包含所有补数任务状态的字典 + """ + with self._refresh_lock: + status = {} + for key, in_progress in self._refresh_in_progress.items(): + start, end = key.split('_')[:2] if '_' in key else (key, key) + status[key] = { + 'start_date': start, + 'end_date': end, + 'in_progress': in_progress, + 'callback_count': len(self._refresh_callbacks.get(key, [])) + } + return status + + def cancel_all_refresh_tasks(self) -> None: + """取消所有正在等待的补数任务回调。 + 注意:已经开始执行的补数任务无法取消,但它们的结果将被忽略。 + """ + with self._refresh_lock: + self._refresh_callbacks.clear() + # 保留刷新状态以避免立即重新触发 + LOGGER.info("所有补数任务回调已取消", extra=LOG_EXTRA) + + def clear_coverage_cache(self) -> None: + """清除数据覆盖情况的缓存。""" + self._coverage_cache.clear() + LOGGER.info("数据覆盖缓存已清除", extra=LOG_EXTRA) + + def get_data_coverage(self, start_date: str, end_date: str) -> Dict: + """获取指定日期范围内的数据覆盖情况。 + + Args: + start_date: 开始日期(格式:YYYYMMDD) + end_date: 结束日期(格式:YYYYMMDD) + + Returns: + Dict: 数据覆盖情况的详细信息 + """ + try: + # 解析日期 + start = _parse_trade_date(start_date) + end = _parse_trade_date(end_date) + if not start or not end: + LOGGER.error("日期格式不正确: %s, %s", start_date, end_date, extra=LOG_EXTRA) + return {} + + # 转换日期格式 + start_d = date.fromisoformat(start.strftime('%Y-%m-%d')) + end_d = date.fromisoformat(end.strftime('%Y-%m-%d')) + + # 收集数据覆盖情况 + if collect_data_coverage is None: + LOGGER.error("collect_data_coverage 函数不可用,请检查导入配置", extra=LOG_EXTRA) + return {} + + coverage = collect_data_coverage(start_d, end_d) + return coverage + except Exception as exc: + LOGGER.exception("获取数据覆盖情况失败: %s", exc, extra=LOG_EXTRA) + return {} def _resolve_column(self, table: str, column: str) -> Optional[str]: columns = self._get_table_columns(table) @@ -712,3 +1135,17 @@ class DataBroker: if name.lower() == lowered: return name return None + +# 确保time模块可用 +import sys +try: + import time +except ImportError: + # 创建一个简单的替代实现 + class TimeStub: + def time(self): + return 0 + def sleep(self, seconds): + pass + time = TimeStub() + LOGGER.warning("无法导入time模块,使用替代实现", extra=LOG_EXTRA)