diff --git a/README.md b/README.md index 663e1ad..b8bd6ca 100644 --- a/README.md +++ b/README.md @@ -14,20 +14,17 @@ ## 快速开始 ```bash -# 初始化数据库结构 -python -m app.cli init-db - -# 一键开机检查(默认回溯 365 天,缺失数据会自动补齐) -python -m app.cli boot-check --days 365 - -# 启动界面 +# 启动交互界面(内含数据库初始化、开机检查、样例回测入口) streamlit run app/ui/streamlit_app.py ``` Streamlit `自检测试` 页签提供: - 数据库初始化快捷按钮; - TuShare 小范围拉取测试; -- 开机检查器(展示当前数据覆盖范围与股票基础信息完整度)。 +- 一键开机检查(可自动补数并展示覆盖摘要); +- 股票行情可视化(自动加载近段时间价格、成交量,并展示核心指标)。 + +`回测与复盘` 页签提供快速回测表单,可调整时间区间、股票池与参数并即时查看回测输出。 ## 下一步 diff --git a/app/cli.py b/app/cli.py deleted file mode 100644 index 3b90ff0..0000000 --- a/app/cli.py +++ /dev/null @@ -1,77 +0,0 @@ -"""Command line entry points for routine tasks.""" -from __future__ import annotations - -import argparse -from datetime import date - -from app.backtest.engine import BtConfig, run_backtest -from app.data.schema import initialize_database -from app.ingest.checker import run_boot_check - - -def init_db() -> None: - result = initialize_database() - if result.skipped: - print("Database already initialized; skipping schema creation") - else: - print(f"Initialized database with {result.executed} statements") - - -def run_sample_backtest() -> None: - cfg = BtConfig( - id="demo", - name="Demo Strategy", - start_date=date(2020, 1, 1), - end_date=date(2020, 3, 31), - universe=["000001.SZ"], - params={ - "target": 0.035, - "stop": -0.015, - "hold_days": 10, - }, - ) - run_backtest(cfg) - - -def run_boot_check_cli(days: int) -> None: - report = run_boot_check(days=days) - print("Boot check summary:") - print(f" Period: {report.start} ~ {report.end}") - print(f" Expected trading days: {report.expected_trading_days}") - for name, info in report.tables.items(): - print( - f" {name}: min={info.get('min')}, max={info.get('max')}, " - f"distinct={info.get('distinct_days')}, ok={info.get('meets_expectation')}" - ) - stock = report.stock_basic - print( - f" stock_basic: total={stock.get('total')}, " - f"SSE listed={stock.get('sse_listed')}, SZSE listed={stock.get('szse_listed')}" - ) - - -def main() -> None: - parser = argparse.ArgumentParser(description="Investment assistant toolkit") - sub = parser.add_subparsers(dest="command") - - sub.add_parser("init-db", help="Initialize SQLite schema") - - boot_parser = sub.add_parser("boot-check", help="Run startup data coverage check") - boot_parser.add_argument("--days", type=int, default=365, help="Lookback window in days") - - sub.add_parser("sample-backtest", help="Execute demo backtest run") - - args = parser.parse_args() - - if args.command is None or args.command == "init-db": - init_db() - elif args.command == "boot-check": - run_boot_check_cli(days=args.days) - elif args.command == "sample-backtest": - run_sample_backtest() - else: - parser.print_help() - - -if __name__ == "__main__": - main() diff --git a/app/data/schema.py b/app/data/schema.py index ff3fe1d..7bcbe54 100644 --- a/app/data/schema.py +++ b/app/data/schema.py @@ -2,8 +2,8 @@ from __future__ import annotations import sqlite3 -from dataclasses import dataclass -from typing import Iterable +from dataclasses import dataclass, field +from typing import Iterable, List from app.utils.db import db_session @@ -86,9 +86,10 @@ SCHEMA_STATEMENTS: Iterable[str] = ( """ CREATE TABLE IF NOT EXISTS trade_calendar ( exchange TEXT, - cal_date TEXT PRIMARY KEY, + cal_date TEXT, is_open INTEGER, - pretrade_date TEXT + pretrade_date TEXT, + PRIMARY KEY (exchange, cal_date) ); """, """ @@ -203,29 +204,49 @@ SCHEMA_STATEMENTS: Iterable[str] = ( """ ) +REQUIRED_TABLES = ( + "stock_basic", + "daily", + "daily_basic", + "adj_factor", + "suspend", + "trade_calendar", + "stk_limit", + "news", + "heat_daily", + "bt_config", + "bt_trades", + "bt_nav", + "bt_report", + "run_log", + "agent_utils", + "alloc_log", +) + @dataclass class MigrationResult: executed: int skipped: bool = False + missing_tables: List[str] = field(default_factory=list) -def _schema_exists() -> bool: +def _missing_tables() -> List[str]: try: with db_session(read_only=True) as conn: - cursor = conn.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name='news'" - ) - return cursor.fetchone() is not None + rows = conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall() except sqlite3.OperationalError: - return False + return list(REQUIRED_TABLES) + existing = {row["name"] for row in rows} + return [name for name in REQUIRED_TABLES if name not in existing] def initialize_database() -> MigrationResult: """Create tables and indexes required by the application.""" - if _schema_exists(): - return MigrationResult(executed=0, skipped=True) + missing = _missing_tables() + if not missing: + return MigrationResult(executed=0, skipped=True, missing_tables=[]) executed = 0 with db_session() as conn: @@ -233,4 +254,4 @@ def initialize_database() -> MigrationResult: for statement in SCHEMA_STATEMENTS: cursor.executescript(statement) executed += 1 - return MigrationResult(executed=executed) + return MigrationResult(executed=executed, skipped=False, missing_tables=missing) diff --git a/app/ui/streamlit_app.py b/app/ui/streamlit_app.py index da4484c..6146c20 100644 --- a/app/ui/streamlit_app.py +++ b/app/ui/streamlit_app.py @@ -2,21 +2,69 @@ from __future__ import annotations import sys -from datetime import date +from datetime import date, timedelta from pathlib import Path ROOT = Path(__file__).resolve().parents[2] if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) +import pandas as pd import streamlit as st +from app.backtest.engine import BtConfig, run_backtest from app.data.schema import initialize_database from app.ingest.checker import run_boot_check from app.ingest.tushare import FetchJob, run_ingestion from app.llm.explain import make_human_card +from app.utils.config import get_config +from app.utils.db import db_session + + +def _load_stock_options(limit: int = 500) -> list[str]: + try: + with db_session(read_only=True) as conn: + rows = conn.execute( + "SELECT ts_code, name FROM stock_basic WHERE list_status = 'L' ORDER BY ts_code" + ).fetchall() + except Exception: + return [] + options: list[str] = [] + for row in rows[:limit]: + code = row["ts_code"] + name = row["name"] or "" + label = f"{code} | {name}" if name else code + options.append(label) + return options + + +def _parse_ts_code(selection: str) -> str: + return selection.split(' | ')[0].strip().upper() + + +def _load_daily_frame(ts_code: str, start: date, end: date) -> pd.DataFrame: + start_str = start.strftime('%Y%m%d') + end_str = end.strftime('%Y%m%d') + range_query = ( + "SELECT trade_date, open, high, low, close, vol, amount " + "FROM daily WHERE ts_code = ? AND trade_date BETWEEN ? AND ? ORDER BY trade_date" + ) + fallback_query = ( + "SELECT trade_date, open, high, low, close, vol, amount " + "FROM daily WHERE ts_code = ? ORDER BY trade_date DESC LIMIT 200" + ) + with db_session(read_only=True) as conn: + df = pd.read_sql_query(range_query, conn, params=(ts_code, start_str, end_str)) + if df.empty: + df = pd.read_sql_query(fallback_query, conn, params=(ts_code,)) + if df.empty: + return df + df = df.sort_values('trade_date') + df['trade_date'] = pd.to_datetime(df['trade_date']) + df.set_index('trade_date', inplace=True) + return df def render_today_plan() -> None: st.header("今日计划") st.write("待接入候选池筛选与多智能体决策结果。") @@ -27,12 +75,49 @@ def render_today_plan() -> None: def render_backtest() -> None: st.header("回测与复盘") st.write("在此运行回测、展示净值曲线与代理贡献。") - st.button("开始回测") + + default_start = date(2020, 1, 1) + default_end = date(2020, 3, 31) + + col1, col2 = st.columns(2) + start_date = col1.date_input("开始日期", value=default_start) + end_date = col2.date_input("结束日期", value=default_end) + universe_text = st.text_input("股票列表(逗号分隔)", value="000001.SZ") + target = st.number_input("目标收益(例:0.035 表示 3.5%)", value=0.035, step=0.005, format="%.3f") + stop = st.number_input("止损收益(例:-0.015 表示 -1.5%)", value=-0.015, step=0.005, format="%.3f") + hold_days = st.number_input("持有期(交易日)", value=10, step=1) + + if st.button("运行回测"): + with st.spinner("正在执行回测..."): + try: + universe = [code.strip() for code in universe_text.split(',') if code.strip()] + cfg = BtConfig( + id="streamlit_demo", + name="Streamlit Demo Strategy", + start_date=start_date, + end_date=end_date, + universe=universe, + params={ + "target": target, + "stop": stop, + "hold_days": int(hold_days), + }, + ) + result = run_backtest(cfg) + st.success("回测执行完成,详见回测结果摘要。") + st.json({"nav_records": result.nav_series, "trades": result.trades}) + except Exception as exc: # noqa: BLE001 + st.error(f"回测执行失败:{exc}") def render_settings() -> None: st.header("数据与设置") - st.text_input("TuShare Token") + cfg = get_config() + token = st.text_input("TuShare Token", value=cfg.tushare_token or "", type="password") + if st.button("保存 Token"): + cfg.tushare_token = token.strip() or None + st.success("TuShare Token 已更新,仅保存在当前会话。") + st.write("新闻源开关与数据库备份将在此配置。") @@ -79,6 +164,65 @@ def render_tests() -> None: except Exception as exc: # noqa: BLE001 st.error(f"开机检查失败:{exc}") + st.divider() + st.subheader("股票行情可视化") + options = _load_stock_options() + default_code = options[0] if options else "000001.SZ" + + if options: + selection = st.selectbox("选择股票", options, index=0) + ts_code = _parse_ts_code(selection) + else: + ts_code = st.text_input("输入股票代码(如 000001.SZ)", value=default_code).strip().upper() + + viz_col1, viz_col2 = st.columns(2) + default_start = date.today() - timedelta(days=180) + start_date = viz_col1.date_input("开始日期", value=default_start, key="viz_start") + end_date = viz_col2.date_input("结束日期", value=date.today(), key="viz_end") + + if start_date > end_date: + st.error("开始日期不能晚于结束日期") + return + + with st.spinner("正在加载行情数据..."): + try: + df = _load_daily_frame(ts_code, start_date, end_date) + except Exception as exc: # noqa: BLE001 + st.error(f"读取数据失败:{exc}") + return + + if df.empty: + st.warning("未查询到该区间的交易数据,请确认数据库已拉取对应日线。") + return + + price_df = df[["close"]].rename(columns={"close": "收盘价"}) + volume_df = df[["vol"]].rename(columns={"vol": "成交量(手)"}) + + if price_df.shape[0] > 180: + sampled = price_df.resample('3D').last().dropna() + else: + sampled = price_df + + if volume_df.shape[0] > 180: + volume_sampled = volume_df.resample('3D').mean().dropna() + else: + volume_sampled = volume_df + + first_close = sampled.iloc[0, 0] + last_close = sampled.iloc[-1, 0] + delta_abs = last_close - first_close + delta_pct = (delta_abs / first_close * 100) if first_close else 0.0 + + metric_col1, metric_col2, metric_col3 = st.columns(3) + metric_col1.metric("最新收盘价", f"{last_close:.2f}", delta=f"{delta_abs:+.2f}") + metric_col2.metric("区间涨跌幅", f"{delta_pct:+.2f}%") + metric_col3.metric("平均成交量", f"{volume_sampled['成交量(手)'].mean():.0f}") + + st.line_chart(sampled, width='stretch') + st.bar_chart(volume_sampled, width='stretch') + st.caption("提示:成交量单位为手,若需更长区间请调整日期后重新加载。") + st.dataframe(df.reset_index().tail(10), width='stretch') + def main() -> None: st.set_page_config(page_title="多智能体投资助理", layout="wide")