diff --git a/app/ingest/tushare.py b/app/ingest/tushare.py index 8d74b2d..55a0274 100644 --- a/app/ingest/tushare.py +++ b/app/ingest/tushare.py @@ -29,6 +29,13 @@ LOG_EXTRA = {"stage": "data_ingest"} _CALL_QUEUE = deque() +def _normalize_date_str(value: Optional[str]) -> Optional[str]: + if value is None: + return None + text = str(value).strip() + return text or None + + def _respect_rate_limit(cfg) -> None: max_calls = cfg.max_calls_per_minute if max_calls <= 0: @@ -44,23 +51,6 @@ def _respect_rate_limit(cfg) -> None: _CALL_QUEUE.append(time.time()) -def _existing_date_range( - table: str, - date_col: str, - ts_code: str | None = None, -) -> Tuple[str | None, str | None]: - query = f"SELECT MIN({date_col}) AS min_d, MAX({date_col}) AS max_d FROM {table}" - params: Tuple = () - if ts_code: - query += " WHERE ts_code = ?" - params = (ts_code,) - with db_session(read_only=True) as conn: - row = conn.execute(query, params).fetchone() - if row is None: - return None, None - return row["min_d"], row["max_d"] - - def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]: if df is None or df.empty: return [] @@ -352,25 +342,63 @@ def _record_exists( def _should_skip_range(table: str, date_col: str, start: date, end: date, ts_code: str | None = None) -> bool: - min_d, max_d = _existing_date_range(table, date_col, ts_code) - if min_d is None or max_d is None: - return False start_str = _format_date(start) end_str = _format_date(end) - return min_d <= start_str and max_d >= end_str + + effective_start = start_str + effective_end = end_str + + if ts_code: + list_date, delist_date = _listing_window(ts_code) + if list_date: + effective_start = max(effective_start, list_date) + if delist_date: + effective_end = min(effective_end, delist_date) + if effective_start > effective_end: + LOGGER.debug( + "股票 %s 在目标区间之外,跳过补数", + ts_code, + extra=LOG_EXTRA, + ) + return True + stats = _range_stats(table, date_col, effective_start, effective_end, ts_code=ts_code) + else: + stats = _range_stats(table, date_col, effective_start, effective_end) + + if stats["min"] is None or stats["max"] is None: + return False + if stats["min"] > effective_start or stats["max"] < effective_end: + return False + + if ts_code is None: + expected_days = _expected_trading_days(effective_start, effective_end) + if expected_days and (stats["distinct"] or 0) < expected_days: + return False + + return True -def _range_stats(table: str, date_col: str, start_str: str, end_str: str) -> Dict[str, Optional[str]]: +def _range_stats( + table: str, + date_col: str, + start_str: str, + end_str: str, + ts_code: str | None = None, +) -> Dict[str, Optional[str]]: sql = ( f"SELECT MIN({date_col}) AS min_d, MAX({date_col}) AS max_d, " f"COUNT(DISTINCT {date_col}) AS distinct_days FROM {table} " f"WHERE {date_col} BETWEEN ? AND ?" ) + params: List[object] = [start_str, end_str] + if ts_code: + sql += " AND ts_code = ?" + params.append(ts_code) with db_session(read_only=True) as conn: - row = conn.execute(sql, (start_str, end_str)).fetchone() + row = conn.execute(sql, tuple(params)).fetchone() return { - "min": row["min_d"], - "max": row["max_d"], + "min": row["min_d"] if row else None, + "max": row["max_d"] if row else None, "distinct": row["distinct_days"] if row else 0, } @@ -392,6 +420,17 @@ def _range_needs_refresh( return False +def _listing_window(ts_code: str) -> Tuple[Optional[str], Optional[str]]: + with db_session(read_only=True) as conn: + row = conn.execute( + "SELECT list_date, delist_date FROM stock_basic WHERE ts_code = ?", + (ts_code,), + ).fetchone() + if not row: + return None, None + return _normalize_date_str(row["list_date"]), _normalize_date_str(row["delist_date"]) # type: ignore[index] + + def _calendar_needs_refresh(exchange: str, start_str: str, end_str: str) -> bool: sql = """ SELECT MIN(cal_date) AS min_d, MAX(cal_date) AS max_d, COUNT(*) AS cnt @@ -421,7 +460,12 @@ def _expected_trading_days(start_str: str, end_str: str, exchange: str = "SSE") def fetch_stock_basic(exchange: Optional[str] = None, list_status: str = "L") -> Iterable[Dict]: client = _ensure_client() - LOGGER.info("拉取股票基础信息(交易所:%s,状态:%s)", exchange or "全部", list_status) + LOGGER.info( + "拉取股票基础信息(交易所:%s,状态:%s)", + exchange or "全部", + list_status, + extra=LOG_EXTRA, + ) fields = "ts_code,symbol,name,area,industry,market,exchange,list_status,list_date,delist_date" df = client.stock_basic(exchange=exchange, list_status=list_status, fields=fields) return _df_to_records(df, _TABLE_COLUMNS["stock_basic"]) @@ -626,7 +670,13 @@ def fetch_trade_calendar(start: date, end: date, exchange: str = "SSE") -> Itera client = _ensure_client() start_date = _format_date(start) end_date = _format_date(end) - LOGGER.info("拉取交易日历(交易所:%s,区间:%s-%s)", exchange, start_date, end_date) + LOGGER.info( + "拉取交易日历(交易所:%s,区间:%s-%s)", + exchange, + start_date, + end_date, + extra=LOG_EXTRA, + ) df = client.trade_cal(exchange=exchange, start_date=start_date, end_date=end_date) if df is not None and not df.empty and "is_open" in df.columns: df["is_open"] = pd.to_numeric(df["is_open"], errors="coerce").fillna(0).astype(int) @@ -672,7 +722,7 @@ def fetch_stk_limit( def save_records(table: str, rows: Iterable[Dict]) -> None: items = list(rows) if not items: - LOGGER.info("表 %s 没有新增记录,跳过写入", table) + LOGGER.info("表 %s 没有新增记录,跳过写入", table, extra=LOG_EXTRA) return schema = _TABLE_SCHEMAS.get(table) @@ -683,7 +733,7 @@ def save_records(table: str, rows: Iterable[Dict]) -> None: placeholders = ",".join([f":{col}" for col in columns]) col_clause = ",".join(columns) - LOGGER.info("表 %s 写入 %d 条记录", table, len(items)) + LOGGER.info("表 %s 写入 %d 条记录", table, len(items), extra=LOG_EXTRA) with db_session() as conn: conn.executescript(schema) conn.executemany( @@ -700,7 +750,11 @@ def ensure_stock_basic(list_status: str = "L") -> None: (*exchanges, list_status), ).fetchone() if row and row["cnt"]: - LOGGER.info("股票基础信息已存在 %d 条记录,跳过拉取", row["cnt"]) + LOGGER.info( + "股票基础信息已存在 %d 条记录,跳过拉取", + row["cnt"], + extra=LOG_EXTRA, + ) return for exch in exchanges: @@ -736,7 +790,7 @@ def ensure_data_coverage( progress = min(current_step / total_steps, 1.0) if progress_hook: progress_hook(message, progress) - LOGGER.info(message) + LOGGER.info(message, extra=LOG_EXTRA) advance("准备股票基础信息与交易日历") ensure_stock_basic() @@ -824,6 +878,8 @@ def ensure_data_coverage( if progress_hook: progress_hook("数据覆盖检查完成", 1.0) + + def collect_data_coverage(start: date, end: date) -> Dict[str, Dict[str, object]]: start_str = _format_date(start) end_str = _format_date(end) @@ -876,7 +932,7 @@ def collect_data_coverage(start: date, end: date) -> Dict[str, Dict[str, object] def run_ingestion(job: FetchJob, include_limits: bool = True) -> None: - LOGGER.info("启动 TuShare 拉取任务:%s", job.name) + LOGGER.info("启动 TuShare 拉取任务:%s", job.name, extra=LOG_EXTRA) ensure_data_coverage( job.start, job.end, @@ -884,4 +940,4 @@ def run_ingestion(job: FetchJob, include_limits: bool = True) -> None: include_limits=include_limits, force=True, ) - LOGGER.info("任务 %s 完成", job.name) + LOGGER.info("任务 %s 完成", job.name, extra=LOG_EXTRA) diff --git a/app/ui/streamlit_app.py b/app/ui/streamlit_app.py index 0a9e1ed..50c24c9 100644 --- a/app/ui/streamlit_app.py +++ b/app/ui/streamlit_app.py @@ -21,8 +21,11 @@ from app.ingest.tushare import FetchJob, run_ingestion from app.llm.explain import make_human_card from app.utils.config import get_config from app.utils.db import db_session +from app.utils.logging import get_logger +LOGGER = get_logger(__name__) +LOG_EXTRA = {"stage": "ui"} def _load_stock_options(limit: int = 500) -> list[str]: @@ -32,6 +35,7 @@ def _load_stock_options(limit: int = 500) -> list[str]: "SELECT ts_code, name FROM stock_basic WHERE list_status = 'L' ORDER BY ts_code" ).fetchall() except Exception: + LOGGER.exception("加载股票列表失败", extra=LOG_EXTRA) return [] options: list[str] = [] for row in rows[:limit]: @@ -39,6 +43,7 @@ def _load_stock_options(limit: int = 500) -> list[str]: name = row["name"] or "" label = f"{code} | {name}" if name else code options.append(label) + LOGGER.info("加载股票选项完成,数量=%s", len(options), extra=LOG_EXTRA) return options @@ -47,6 +52,13 @@ def _parse_ts_code(selection: str) -> str: def _load_daily_frame(ts_code: str, start: date, end: date) -> pd.DataFrame: + LOGGER.info( + "加载行情数据:ts_code=%s start=%s end=%s", + ts_code, + start, + end, + extra=LOG_EXTRA, + ) start_str = start.strftime('%Y%m%d') end_str = end.strftime('%Y%m%d') range_query = ( @@ -62,24 +74,47 @@ def _load_daily_frame(ts_code: str, start: date, end: date) -> pd.DataFrame: if df.empty: df = pd.read_sql_query(fallback_query, conn, params=(ts_code,)) if df.empty: + LOGGER.warning( + "行情数据为空:ts_code=%s start=%s end=%s", + ts_code, + start, + end, + extra=LOG_EXTRA, + ) return df df = df.sort_values('trade_date') df['trade_date'] = pd.to_datetime(df['trade_date']) df.set_index('trade_date', inplace=True) + LOGGER.info("行情数据加载完成:条数=%s", len(df), extra=LOG_EXTRA) return df + + def render_today_plan() -> None: + LOGGER.info("渲染今日计划页面", extra=LOG_EXTRA) st.header("今日计划") st.write("待接入候选池筛选与多智能体决策结果。") sample = make_human_card("000001.SZ", "2025-01-01", {"decisions": []}) + LOGGER.debug("示例卡片内容:%s", sample, extra=LOG_EXTRA) st.json(sample) def render_backtest() -> None: + LOGGER.info("渲染回测页面", extra=LOG_EXTRA) st.header("回测与复盘") st.write("在此运行回测、展示净值曲线与代理贡献。") default_start = date(2020, 1, 1) default_end = date(2020, 3, 31) + LOGGER.debug( + "回测默认参数:start=%s end=%s universe=%s target=%s stop=%s hold_days=%s", + default_start, + default_end, + "000001.SZ", + 0.035, + -0.015, + 10, + extra=LOG_EXTRA, + ) col1, col2 = st.columns(2) start_date = col1.date_input("开始日期", value=default_start) @@ -88,11 +123,32 @@ def render_backtest() -> None: target = st.number_input("目标收益(例:0.035 表示 3.5%)", value=0.035, step=0.005, format="%.3f") stop = st.number_input("止损收益(例:-0.015 表示 -1.5%)", value=-0.015, step=0.005, format="%.3f") hold_days = st.number_input("持有期(交易日)", value=10, step=1) + LOGGER.debug( + "当前回测表单输入:start=%s end=%s universe_text=%s target=%.3f stop=%.3f hold_days=%s", + start_date, + end_date, + universe_text, + target, + stop, + hold_days, + extra=LOG_EXTRA, + ) if st.button("运行回测"): + LOGGER.info("用户点击运行回测按钮", extra=LOG_EXTRA) with st.spinner("正在执行回测..."): try: universe = [code.strip() for code in universe_text.split(',') if code.strip()] + LOGGER.info( + "回测参数:start=%s end=%s universe=%s target=%s stop=%s hold_days=%s", + start_date, + end_date, + universe, + target, + stop, + hold_days, + extra=LOG_EXTRA, + ) cfg = BtConfig( id="streamlit_demo", name="Streamlit Demo Strategy", @@ -106,39 +162,55 @@ def render_backtest() -> None: }, ) result = run_backtest(cfg) + LOGGER.info( + "回测完成:nav_records=%s trades=%s", + len(result.nav_series), + len(result.trades), + extra=LOG_EXTRA, + ) st.success("回测执行完成,详见回测结果摘要。") st.json({"nav_records": result.nav_series, "trades": result.trades}) except Exception as exc: # noqa: BLE001 + LOGGER.exception("回测执行失败", extra=LOG_EXTRA) st.error(f"回测执行失败:{exc}") def render_settings() -> None: + LOGGER.info("渲染设置页面", extra=LOG_EXTRA) st.header("数据与设置") cfg = get_config() + LOGGER.debug("当前 TuShare Token 是否已配置=%s", bool(cfg.tushare_token), extra=LOG_EXTRA) token = st.text_input("TuShare Token", value=cfg.tushare_token or "", type="password") if st.button("保存设置"): + LOGGER.info("保存设置按钮被点击", extra=LOG_EXTRA) cfg.tushare_token = token.strip() or None + LOGGER.info("TuShare Token 更新,是否为空=%s", cfg.tushare_token is None, extra=LOG_EXTRA) st.success("设置已保存,仅在当前会话生效。") st.write("新闻源开关与数据库备份将在此配置。") def render_tests() -> None: + LOGGER.info("渲染自检页面", extra=LOG_EXTRA) st.header("自检测试") st.write("用于快速检查数据库与数据拉取是否正常工作。") if st.button("测试数据库初始化"): + LOGGER.info("点击测试数据库初始化按钮", extra=LOG_EXTRA) with st.spinner("正在检查数据库..."): result = initialize_database() if result.skipped: + LOGGER.info("数据库已存在,无需初始化", extra=LOG_EXTRA) st.success("数据库已存在,检查通过。") else: + LOGGER.info("数据库初始化完成,执行语句数=%s", result.executed, extra=LOG_EXTRA) st.success(f"数据库初始化完成,共执行 {result.executed} 条语句。") st.divider() if st.button("测试 TuShare 拉取(示例 2024-01-01 至 2024-01-03)"): + LOGGER.info("点击示例 TuShare 拉取按钮", extra=LOG_EXTRA) with st.spinner("正在调用 TuShare 接口..."): try: run_ingestion( @@ -150,14 +222,17 @@ def render_tests() -> None: ), include_limits=False, ) + LOGGER.info("示例 TuShare 拉取成功", extra=LOG_EXTRA) st.success("TuShare 示例拉取完成,数据已写入数据库。") except Exception as exc: # noqa: BLE001 + LOGGER.exception("示例 TuShare 拉取失败", extra=LOG_EXTRA) st.error(f"拉取失败:{exc}") st.info("注意:TuShare 拉取依赖网络与 Token,若环境未配置将出现错误提示。") st.divider() days = int(st.number_input("检查窗口(天数)", min_value=30, max_value=1095, value=365, step=30)) + LOGGER.debug("检查窗口天数=%s", days, extra=LOG_EXTRA) cfg = get_config() force_refresh = st.checkbox( "强制刷新数据(关闭增量跳过)", @@ -166,8 +241,10 @@ def render_tests() -> None: ) if force_refresh != cfg.force_refresh: cfg.force_refresh = force_refresh + LOGGER.info("更新 force_refresh=%s", force_refresh, extra=LOG_EXTRA) if st.button("执行开机检查"): + LOGGER.info("点击执行开机检查按钮", extra=LOG_EXTRA) progress_bar = st.progress(0.0) status_placeholder = st.empty() log_placeholder = st.empty() @@ -177,6 +254,7 @@ def render_tests() -> None: progress_bar.progress(min(max(value, 0.0), 1.0)) status_placeholder.write(message) messages.append(message) + LOGGER.debug("开机检查进度:%s -> %.2f", message, value, extra=LOG_EXTRA) with st.spinner("正在执行开机检查..."): try: @@ -185,11 +263,13 @@ def render_tests() -> None: progress_hook=hook, force_refresh=force_refresh, ) + LOGGER.info("开机检查成功", extra=LOG_EXTRA) st.success("开机检查完成,以下为数据覆盖摘要。") st.json(report.to_dict()) if messages: log_placeholder.markdown("\n".join(f"- {msg}" for msg in messages)) except Exception as exc: # noqa: BLE001 + LOGGER.exception("开机检查失败", extra=LOG_EXTRA) st.error(f"开机检查失败:{exc}") if messages: log_placeholder.markdown("\n".join(f"- {msg}" for msg in messages)) @@ -204,15 +284,19 @@ def render_tests() -> None: if options: selection = st.selectbox("选择股票", options, index=0) ts_code = _parse_ts_code(selection) + LOGGER.debug("选择股票:%s", ts_code, extra=LOG_EXTRA) else: ts_code = st.text_input("输入股票代码(如 000001.SZ)", value=default_code).strip().upper() + LOGGER.debug("输入股票:%s", ts_code, extra=LOG_EXTRA) viz_col1, viz_col2 = st.columns(2) default_start = date.today() - timedelta(days=180) start_date = viz_col1.date_input("开始日期", value=default_start, key="viz_start") end_date = viz_col2.date_input("结束日期", value=date.today(), key="viz_end") + LOGGER.debug("行情可视化日期范围:%s-%s", start_date, end_date, extra=LOG_EXTRA) if start_date > end_date: + LOGGER.warning("无效日期范围:%s>%s", start_date, end_date, extra=LOG_EXTRA) st.error("开始日期不能晚于结束日期") return @@ -220,10 +304,12 @@ def render_tests() -> None: try: df = _load_daily_frame(ts_code, start_date, end_date) except Exception as exc: # noqa: BLE001 + LOGGER.exception("加载行情数据失败", extra=LOG_EXTRA) st.error(f"读取数据失败:{exc}") return if df.empty: + LOGGER.warning("指定区间无行情数据:%s %s-%s", ts_code, start_date, end_date, extra=LOG_EXTRA) st.warning("未查询到该区间的交易数据,请确认数据库已拉取对应日线。") return @@ -309,11 +395,14 @@ def render_tests() -> None: st.caption("提示:成交量单位为手,成交额以千元显示。箱线图按月展示收盘价分布。") st.dataframe(df_reset.tail(20), width='stretch') + LOGGER.info("行情可视化完成,展示行数=%s", len(df_reset), extra=LOG_EXTRA) def main() -> None: + LOGGER.info("初始化 Streamlit UI", extra=LOG_EXTRA) st.set_page_config(page_title="多智能体投资助理", layout="wide") tabs = st.tabs(["今日计划", "回测与复盘", "数据与设置", "自检测试"]) + LOGGER.debug("Tabs 初始化完成:%s", ["今日计划", "回测与复盘", "数据与设置", "自检测试"], extra=LOG_EXTRA) with tabs[0]: render_today_plan() with tabs[1]: