This commit is contained in:
sam 2025-09-26 21:43:44 +08:00
parent c6c781cc6b
commit 2dc2753827
4 changed files with 186 additions and 101 deletions

View File

@ -14,20 +14,17 @@
## 快速开始 ## 快速开始
```bash ```bash
# 初始化数据库结构 # 启动交互界面(内含数据库初始化、开机检查、样例回测入口)
python -m app.cli init-db
# 一键开机检查(默认回溯 365 天,缺失数据会自动补齐)
python -m app.cli boot-check --days 365
# 启动界面
streamlit run app/ui/streamlit_app.py streamlit run app/ui/streamlit_app.py
``` ```
Streamlit `自检测试` 页签提供: Streamlit `自检测试` 页签提供:
- 数据库初始化快捷按钮; - 数据库初始化快捷按钮;
- TuShare 小范围拉取测试; - TuShare 小范围拉取测试;
- 开机检查器(展示当前数据覆盖范围与股票基础信息完整度)。 - 一键开机检查(可自动补数并展示覆盖摘要);
- 股票行情可视化(自动加载近段时间价格、成交量,并展示核心指标)。
`回测与复盘` 页签提供快速回测表单,可调整时间区间、股票池与参数并即时查看回测输出。
## 下一步 ## 下一步

View File

@ -1,77 +0,0 @@
"""Command line entry points for routine tasks."""
from __future__ import annotations
import argparse
from datetime import date
from app.backtest.engine import BtConfig, run_backtest
from app.data.schema import initialize_database
from app.ingest.checker import run_boot_check
def init_db() -> None:
result = initialize_database()
if result.skipped:
print("Database already initialized; skipping schema creation")
else:
print(f"Initialized database with {result.executed} statements")
def run_sample_backtest() -> None:
cfg = BtConfig(
id="demo",
name="Demo Strategy",
start_date=date(2020, 1, 1),
end_date=date(2020, 3, 31),
universe=["000001.SZ"],
params={
"target": 0.035,
"stop": -0.015,
"hold_days": 10,
},
)
run_backtest(cfg)
def run_boot_check_cli(days: int) -> None:
report = run_boot_check(days=days)
print("Boot check summary:")
print(f" Period: {report.start} ~ {report.end}")
print(f" Expected trading days: {report.expected_trading_days}")
for name, info in report.tables.items():
print(
f" {name}: min={info.get('min')}, max={info.get('max')}, "
f"distinct={info.get('distinct_days')}, ok={info.get('meets_expectation')}"
)
stock = report.stock_basic
print(
f" stock_basic: total={stock.get('total')}, "
f"SSE listed={stock.get('sse_listed')}, SZSE listed={stock.get('szse_listed')}"
)
def main() -> None:
parser = argparse.ArgumentParser(description="Investment assistant toolkit")
sub = parser.add_subparsers(dest="command")
sub.add_parser("init-db", help="Initialize SQLite schema")
boot_parser = sub.add_parser("boot-check", help="Run startup data coverage check")
boot_parser.add_argument("--days", type=int, default=365, help="Lookback window in days")
sub.add_parser("sample-backtest", help="Execute demo backtest run")
args = parser.parse_args()
if args.command is None or args.command == "init-db":
init_db()
elif args.command == "boot-check":
run_boot_check_cli(days=args.days)
elif args.command == "sample-backtest":
run_sample_backtest()
else:
parser.print_help()
if __name__ == "__main__":
main()

View File

@ -2,8 +2,8 @@
from __future__ import annotations from __future__ import annotations
import sqlite3 import sqlite3
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Iterable from typing import Iterable, List
from app.utils.db import db_session from app.utils.db import db_session
@ -86,9 +86,10 @@ SCHEMA_STATEMENTS: Iterable[str] = (
""" """
CREATE TABLE IF NOT EXISTS trade_calendar ( CREATE TABLE IF NOT EXISTS trade_calendar (
exchange TEXT, exchange TEXT,
cal_date TEXT PRIMARY KEY, cal_date TEXT,
is_open INTEGER, is_open INTEGER,
pretrade_date TEXT pretrade_date TEXT,
PRIMARY KEY (exchange, cal_date)
); );
""", """,
""" """
@ -203,29 +204,49 @@ SCHEMA_STATEMENTS: Iterable[str] = (
""" """
) )
REQUIRED_TABLES = (
"stock_basic",
"daily",
"daily_basic",
"adj_factor",
"suspend",
"trade_calendar",
"stk_limit",
"news",
"heat_daily",
"bt_config",
"bt_trades",
"bt_nav",
"bt_report",
"run_log",
"agent_utils",
"alloc_log",
)
@dataclass @dataclass
class MigrationResult: class MigrationResult:
executed: int executed: int
skipped: bool = False skipped: bool = False
missing_tables: List[str] = field(default_factory=list)
def _schema_exists() -> bool: def _missing_tables() -> List[str]:
try: try:
with db_session(read_only=True) as conn: with db_session(read_only=True) as conn:
cursor = conn.execute( rows = conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
"SELECT name FROM sqlite_master WHERE type='table' AND name='news'"
)
return cursor.fetchone() is not None
except sqlite3.OperationalError: except sqlite3.OperationalError:
return False return list(REQUIRED_TABLES)
existing = {row["name"] for row in rows}
return [name for name in REQUIRED_TABLES if name not in existing]
def initialize_database() -> MigrationResult: def initialize_database() -> MigrationResult:
"""Create tables and indexes required by the application.""" """Create tables and indexes required by the application."""
if _schema_exists(): missing = _missing_tables()
return MigrationResult(executed=0, skipped=True) if not missing:
return MigrationResult(executed=0, skipped=True, missing_tables=[])
executed = 0 executed = 0
with db_session() as conn: with db_session() as conn:
@ -233,4 +254,4 @@ def initialize_database() -> MigrationResult:
for statement in SCHEMA_STATEMENTS: for statement in SCHEMA_STATEMENTS:
cursor.executescript(statement) cursor.executescript(statement)
executed += 1 executed += 1
return MigrationResult(executed=executed) return MigrationResult(executed=executed, skipped=False, missing_tables=missing)

View File

@ -2,21 +2,69 @@
from __future__ import annotations from __future__ import annotations
import sys import sys
from datetime import date from datetime import date, timedelta
from pathlib import Path from pathlib import Path
ROOT = Path(__file__).resolve().parents[2] ROOT = Path(__file__).resolve().parents[2]
if str(ROOT) not in sys.path: if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT)) sys.path.insert(0, str(ROOT))
import pandas as pd
import streamlit as st import streamlit as st
from app.backtest.engine import BtConfig, run_backtest
from app.data.schema import initialize_database from app.data.schema import initialize_database
from app.ingest.checker import run_boot_check from app.ingest.checker import run_boot_check
from app.ingest.tushare import FetchJob, run_ingestion from app.ingest.tushare import FetchJob, run_ingestion
from app.llm.explain import make_human_card from app.llm.explain import make_human_card
from app.utils.config import get_config
from app.utils.db import db_session
def _load_stock_options(limit: int = 500) -> list[str]:
try:
with db_session(read_only=True) as conn:
rows = conn.execute(
"SELECT ts_code, name FROM stock_basic WHERE list_status = 'L' ORDER BY ts_code"
).fetchall()
except Exception:
return []
options: list[str] = []
for row in rows[:limit]:
code = row["ts_code"]
name = row["name"] or ""
label = f"{code} | {name}" if name else code
options.append(label)
return options
def _parse_ts_code(selection: str) -> str:
return selection.split(' | ')[0].strip().upper()
def _load_daily_frame(ts_code: str, start: date, end: date) -> pd.DataFrame:
start_str = start.strftime('%Y%m%d')
end_str = end.strftime('%Y%m%d')
range_query = (
"SELECT trade_date, open, high, low, close, vol, amount "
"FROM daily WHERE ts_code = ? AND trade_date BETWEEN ? AND ? ORDER BY trade_date"
)
fallback_query = (
"SELECT trade_date, open, high, low, close, vol, amount "
"FROM daily WHERE ts_code = ? ORDER BY trade_date DESC LIMIT 200"
)
with db_session(read_only=True) as conn:
df = pd.read_sql_query(range_query, conn, params=(ts_code, start_str, end_str))
if df.empty:
df = pd.read_sql_query(fallback_query, conn, params=(ts_code,))
if df.empty:
return df
df = df.sort_values('trade_date')
df['trade_date'] = pd.to_datetime(df['trade_date'])
df.set_index('trade_date', inplace=True)
return df
def render_today_plan() -> None: def render_today_plan() -> None:
st.header("今日计划") st.header("今日计划")
st.write("待接入候选池筛选与多智能体决策结果。") st.write("待接入候选池筛选与多智能体决策结果。")
@ -27,12 +75,49 @@ def render_today_plan() -> None:
def render_backtest() -> None: def render_backtest() -> None:
st.header("回测与复盘") st.header("回测与复盘")
st.write("在此运行回测、展示净值曲线与代理贡献。") st.write("在此运行回测、展示净值曲线与代理贡献。")
st.button("开始回测")
default_start = date(2020, 1, 1)
default_end = date(2020, 3, 31)
col1, col2 = st.columns(2)
start_date = col1.date_input("开始日期", value=default_start)
end_date = col2.date_input("结束日期", value=default_end)
universe_text = st.text_input("股票列表(逗号分隔)", value="000001.SZ")
target = st.number_input("目标收益0.035 表示 3.5%", value=0.035, step=0.005, format="%.3f")
stop = st.number_input("止损收益(例:-0.015 表示 -1.5%", value=-0.015, step=0.005, format="%.3f")
hold_days = st.number_input("持有期(交易日)", value=10, step=1)
if st.button("运行回测"):
with st.spinner("正在执行回测..."):
try:
universe = [code.strip() for code in universe_text.split(',') if code.strip()]
cfg = BtConfig(
id="streamlit_demo",
name="Streamlit Demo Strategy",
start_date=start_date,
end_date=end_date,
universe=universe,
params={
"target": target,
"stop": stop,
"hold_days": int(hold_days),
},
)
result = run_backtest(cfg)
st.success("回测执行完成,详见回测结果摘要。")
st.json({"nav_records": result.nav_series, "trades": result.trades})
except Exception as exc: # noqa: BLE001
st.error(f"回测执行失败:{exc}")
def render_settings() -> None: def render_settings() -> None:
st.header("数据与设置") st.header("数据与设置")
st.text_input("TuShare Token") cfg = get_config()
token = st.text_input("TuShare Token", value=cfg.tushare_token or "", type="password")
if st.button("保存 Token"):
cfg.tushare_token = token.strip() or None
st.success("TuShare Token 已更新,仅保存在当前会话。")
st.write("新闻源开关与数据库备份将在此配置。") st.write("新闻源开关与数据库备份将在此配置。")
@ -79,6 +164,65 @@ def render_tests() -> None:
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
st.error(f"开机检查失败:{exc}") st.error(f"开机检查失败:{exc}")
st.divider()
st.subheader("股票行情可视化")
options = _load_stock_options()
default_code = options[0] if options else "000001.SZ"
if options:
selection = st.selectbox("选择股票", options, index=0)
ts_code = _parse_ts_code(selection)
else:
ts_code = st.text_input("输入股票代码(如 000001.SZ", value=default_code).strip().upper()
viz_col1, viz_col2 = st.columns(2)
default_start = date.today() - timedelta(days=180)
start_date = viz_col1.date_input("开始日期", value=default_start, key="viz_start")
end_date = viz_col2.date_input("结束日期", value=date.today(), key="viz_end")
if start_date > end_date:
st.error("开始日期不能晚于结束日期")
return
with st.spinner("正在加载行情数据..."):
try:
df = _load_daily_frame(ts_code, start_date, end_date)
except Exception as exc: # noqa: BLE001
st.error(f"读取数据失败:{exc}")
return
if df.empty:
st.warning("未查询到该区间的交易数据,请确认数据库已拉取对应日线。")
return
price_df = df[["close"]].rename(columns={"close": "收盘价"})
volume_df = df[["vol"]].rename(columns={"vol": "成交量(手)"})
if price_df.shape[0] > 180:
sampled = price_df.resample('3D').last().dropna()
else:
sampled = price_df
if volume_df.shape[0] > 180:
volume_sampled = volume_df.resample('3D').mean().dropna()
else:
volume_sampled = volume_df
first_close = sampled.iloc[0, 0]
last_close = sampled.iloc[-1, 0]
delta_abs = last_close - first_close
delta_pct = (delta_abs / first_close * 100) if first_close else 0.0
metric_col1, metric_col2, metric_col3 = st.columns(3)
metric_col1.metric("最新收盘价", f"{last_close:.2f}", delta=f"{delta_abs:+.2f}")
metric_col2.metric("区间涨跌幅", f"{delta_pct:+.2f}%")
metric_col3.metric("平均成交量", f"{volume_sampled['成交量(手)'].mean():.0f}")
st.line_chart(sampled, width='stretch')
st.bar_chart(volume_sampled, width='stretch')
st.caption("提示:成交量单位为手,若需更长区间请调整日期后重新加载。")
st.dataframe(df.reset_index().tail(10), width='stretch')
def main() -> None: def main() -> None:
st.set_page_config(page_title="多智能体投资助理", layout="wide") st.set_page_config(page_title="多智能体投资助理", layout="wide")