update
This commit is contained in:
parent
c6c781cc6b
commit
2dc2753827
13
README.md
13
README.md
@ -14,20 +14,17 @@
|
||||
## 快速开始
|
||||
|
||||
```bash
|
||||
# 初始化数据库结构
|
||||
python -m app.cli init-db
|
||||
|
||||
# 一键开机检查(默认回溯 365 天,缺失数据会自动补齐)
|
||||
python -m app.cli boot-check --days 365
|
||||
|
||||
# 启动界面
|
||||
# 启动交互界面(内含数据库初始化、开机检查、样例回测入口)
|
||||
streamlit run app/ui/streamlit_app.py
|
||||
```
|
||||
|
||||
Streamlit `自检测试` 页签提供:
|
||||
- 数据库初始化快捷按钮;
|
||||
- TuShare 小范围拉取测试;
|
||||
- 开机检查器(展示当前数据覆盖范围与股票基础信息完整度)。
|
||||
- 一键开机检查(可自动补数并展示覆盖摘要);
|
||||
- 股票行情可视化(自动加载近段时间价格、成交量,并展示核心指标)。
|
||||
|
||||
`回测与复盘` 页签提供快速回测表单,可调整时间区间、股票池与参数并即时查看回测输出。
|
||||
|
||||
## 下一步
|
||||
|
||||
|
||||
77
app/cli.py
77
app/cli.py
@ -1,77 +0,0 @@
|
||||
"""Command line entry points for routine tasks."""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from datetime import date
|
||||
|
||||
from app.backtest.engine import BtConfig, run_backtest
|
||||
from app.data.schema import initialize_database
|
||||
from app.ingest.checker import run_boot_check
|
||||
|
||||
|
||||
def init_db() -> None:
|
||||
result = initialize_database()
|
||||
if result.skipped:
|
||||
print("Database already initialized; skipping schema creation")
|
||||
else:
|
||||
print(f"Initialized database with {result.executed} statements")
|
||||
|
||||
|
||||
def run_sample_backtest() -> None:
|
||||
cfg = BtConfig(
|
||||
id="demo",
|
||||
name="Demo Strategy",
|
||||
start_date=date(2020, 1, 1),
|
||||
end_date=date(2020, 3, 31),
|
||||
universe=["000001.SZ"],
|
||||
params={
|
||||
"target": 0.035,
|
||||
"stop": -0.015,
|
||||
"hold_days": 10,
|
||||
},
|
||||
)
|
||||
run_backtest(cfg)
|
||||
|
||||
|
||||
def run_boot_check_cli(days: int) -> None:
|
||||
report = run_boot_check(days=days)
|
||||
print("Boot check summary:")
|
||||
print(f" Period: {report.start} ~ {report.end}")
|
||||
print(f" Expected trading days: {report.expected_trading_days}")
|
||||
for name, info in report.tables.items():
|
||||
print(
|
||||
f" {name}: min={info.get('min')}, max={info.get('max')}, "
|
||||
f"distinct={info.get('distinct_days')}, ok={info.get('meets_expectation')}"
|
||||
)
|
||||
stock = report.stock_basic
|
||||
print(
|
||||
f" stock_basic: total={stock.get('total')}, "
|
||||
f"SSE listed={stock.get('sse_listed')}, SZSE listed={stock.get('szse_listed')}"
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Investment assistant toolkit")
|
||||
sub = parser.add_subparsers(dest="command")
|
||||
|
||||
sub.add_parser("init-db", help="Initialize SQLite schema")
|
||||
|
||||
boot_parser = sub.add_parser("boot-check", help="Run startup data coverage check")
|
||||
boot_parser.add_argument("--days", type=int, default=365, help="Lookback window in days")
|
||||
|
||||
sub.add_parser("sample-backtest", help="Execute demo backtest run")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command is None or args.command == "init-db":
|
||||
init_db()
|
||||
elif args.command == "boot-check":
|
||||
run_boot_check_cli(days=args.days)
|
||||
elif args.command == "sample-backtest":
|
||||
run_sample_backtest()
|
||||
else:
|
||||
parser.print_help()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -2,8 +2,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Iterable, List
|
||||
|
||||
from app.utils.db import db_session
|
||||
|
||||
@ -86,9 +86,10 @@ SCHEMA_STATEMENTS: Iterable[str] = (
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS trade_calendar (
|
||||
exchange TEXT,
|
||||
cal_date TEXT PRIMARY KEY,
|
||||
cal_date TEXT,
|
||||
is_open INTEGER,
|
||||
pretrade_date TEXT
|
||||
pretrade_date TEXT,
|
||||
PRIMARY KEY (exchange, cal_date)
|
||||
);
|
||||
""",
|
||||
"""
|
||||
@ -203,29 +204,49 @@ SCHEMA_STATEMENTS: Iterable[str] = (
|
||||
"""
|
||||
)
|
||||
|
||||
REQUIRED_TABLES = (
|
||||
"stock_basic",
|
||||
"daily",
|
||||
"daily_basic",
|
||||
"adj_factor",
|
||||
"suspend",
|
||||
"trade_calendar",
|
||||
"stk_limit",
|
||||
"news",
|
||||
"heat_daily",
|
||||
"bt_config",
|
||||
"bt_trades",
|
||||
"bt_nav",
|
||||
"bt_report",
|
||||
"run_log",
|
||||
"agent_utils",
|
||||
"alloc_log",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MigrationResult:
|
||||
executed: int
|
||||
skipped: bool = False
|
||||
missing_tables: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
def _schema_exists() -> bool:
|
||||
def _missing_tables() -> List[str]:
|
||||
try:
|
||||
with db_session(read_only=True) as conn:
|
||||
cursor = conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='news'"
|
||||
)
|
||||
return cursor.fetchone() is not None
|
||||
rows = conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
|
||||
except sqlite3.OperationalError:
|
||||
return False
|
||||
return list(REQUIRED_TABLES)
|
||||
existing = {row["name"] for row in rows}
|
||||
return [name for name in REQUIRED_TABLES if name not in existing]
|
||||
|
||||
|
||||
def initialize_database() -> MigrationResult:
|
||||
"""Create tables and indexes required by the application."""
|
||||
|
||||
if _schema_exists():
|
||||
return MigrationResult(executed=0, skipped=True)
|
||||
missing = _missing_tables()
|
||||
if not missing:
|
||||
return MigrationResult(executed=0, skipped=True, missing_tables=[])
|
||||
|
||||
executed = 0
|
||||
with db_session() as conn:
|
||||
@ -233,4 +254,4 @@ def initialize_database() -> MigrationResult:
|
||||
for statement in SCHEMA_STATEMENTS:
|
||||
cursor.executescript(statement)
|
||||
executed += 1
|
||||
return MigrationResult(executed=executed)
|
||||
return MigrationResult(executed=executed, skipped=False, missing_tables=missing)
|
||||
|
||||
@ -2,21 +2,69 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from datetime import date
|
||||
from datetime import date, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[2]
|
||||
if str(ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
|
||||
from app.backtest.engine import BtConfig, run_backtest
|
||||
from app.data.schema import initialize_database
|
||||
from app.ingest.checker import run_boot_check
|
||||
from app.ingest.tushare import FetchJob, run_ingestion
|
||||
from app.llm.explain import make_human_card
|
||||
from app.utils.config import get_config
|
||||
from app.utils.db import db_session
|
||||
|
||||
|
||||
|
||||
|
||||
def _load_stock_options(limit: int = 500) -> list[str]:
|
||||
try:
|
||||
with db_session(read_only=True) as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT ts_code, name FROM stock_basic WHERE list_status = 'L' ORDER BY ts_code"
|
||||
).fetchall()
|
||||
except Exception:
|
||||
return []
|
||||
options: list[str] = []
|
||||
for row in rows[:limit]:
|
||||
code = row["ts_code"]
|
||||
name = row["name"] or ""
|
||||
label = f"{code} | {name}" if name else code
|
||||
options.append(label)
|
||||
return options
|
||||
|
||||
|
||||
def _parse_ts_code(selection: str) -> str:
|
||||
return selection.split(' | ')[0].strip().upper()
|
||||
|
||||
|
||||
def _load_daily_frame(ts_code: str, start: date, end: date) -> pd.DataFrame:
|
||||
start_str = start.strftime('%Y%m%d')
|
||||
end_str = end.strftime('%Y%m%d')
|
||||
range_query = (
|
||||
"SELECT trade_date, open, high, low, close, vol, amount "
|
||||
"FROM daily WHERE ts_code = ? AND trade_date BETWEEN ? AND ? ORDER BY trade_date"
|
||||
)
|
||||
fallback_query = (
|
||||
"SELECT trade_date, open, high, low, close, vol, amount "
|
||||
"FROM daily WHERE ts_code = ? ORDER BY trade_date DESC LIMIT 200"
|
||||
)
|
||||
with db_session(read_only=True) as conn:
|
||||
df = pd.read_sql_query(range_query, conn, params=(ts_code, start_str, end_str))
|
||||
if df.empty:
|
||||
df = pd.read_sql_query(fallback_query, conn, params=(ts_code,))
|
||||
if df.empty:
|
||||
return df
|
||||
df = df.sort_values('trade_date')
|
||||
df['trade_date'] = pd.to_datetime(df['trade_date'])
|
||||
df.set_index('trade_date', inplace=True)
|
||||
return df
|
||||
def render_today_plan() -> None:
|
||||
st.header("今日计划")
|
||||
st.write("待接入候选池筛选与多智能体决策结果。")
|
||||
@ -27,12 +75,49 @@ def render_today_plan() -> None:
|
||||
def render_backtest() -> None:
|
||||
st.header("回测与复盘")
|
||||
st.write("在此运行回测、展示净值曲线与代理贡献。")
|
||||
st.button("开始回测")
|
||||
|
||||
default_start = date(2020, 1, 1)
|
||||
default_end = date(2020, 3, 31)
|
||||
|
||||
col1, col2 = st.columns(2)
|
||||
start_date = col1.date_input("开始日期", value=default_start)
|
||||
end_date = col2.date_input("结束日期", value=default_end)
|
||||
universe_text = st.text_input("股票列表(逗号分隔)", value="000001.SZ")
|
||||
target = st.number_input("目标收益(例:0.035 表示 3.5%)", value=0.035, step=0.005, format="%.3f")
|
||||
stop = st.number_input("止损收益(例:-0.015 表示 -1.5%)", value=-0.015, step=0.005, format="%.3f")
|
||||
hold_days = st.number_input("持有期(交易日)", value=10, step=1)
|
||||
|
||||
if st.button("运行回测"):
|
||||
with st.spinner("正在执行回测..."):
|
||||
try:
|
||||
universe = [code.strip() for code in universe_text.split(',') if code.strip()]
|
||||
cfg = BtConfig(
|
||||
id="streamlit_demo",
|
||||
name="Streamlit Demo Strategy",
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
universe=universe,
|
||||
params={
|
||||
"target": target,
|
||||
"stop": stop,
|
||||
"hold_days": int(hold_days),
|
||||
},
|
||||
)
|
||||
result = run_backtest(cfg)
|
||||
st.success("回测执行完成,详见回测结果摘要。")
|
||||
st.json({"nav_records": result.nav_series, "trades": result.trades})
|
||||
except Exception as exc: # noqa: BLE001
|
||||
st.error(f"回测执行失败:{exc}")
|
||||
|
||||
|
||||
def render_settings() -> None:
|
||||
st.header("数据与设置")
|
||||
st.text_input("TuShare Token")
|
||||
cfg = get_config()
|
||||
token = st.text_input("TuShare Token", value=cfg.tushare_token or "", type="password")
|
||||
if st.button("保存 Token"):
|
||||
cfg.tushare_token = token.strip() or None
|
||||
st.success("TuShare Token 已更新,仅保存在当前会话。")
|
||||
|
||||
st.write("新闻源开关与数据库备份将在此配置。")
|
||||
|
||||
|
||||
@ -79,6 +164,65 @@ def render_tests() -> None:
|
||||
except Exception as exc: # noqa: BLE001
|
||||
st.error(f"开机检查失败:{exc}")
|
||||
|
||||
st.divider()
|
||||
st.subheader("股票行情可视化")
|
||||
options = _load_stock_options()
|
||||
default_code = options[0] if options else "000001.SZ"
|
||||
|
||||
if options:
|
||||
selection = st.selectbox("选择股票", options, index=0)
|
||||
ts_code = _parse_ts_code(selection)
|
||||
else:
|
||||
ts_code = st.text_input("输入股票代码(如 000001.SZ)", value=default_code).strip().upper()
|
||||
|
||||
viz_col1, viz_col2 = st.columns(2)
|
||||
default_start = date.today() - timedelta(days=180)
|
||||
start_date = viz_col1.date_input("开始日期", value=default_start, key="viz_start")
|
||||
end_date = viz_col2.date_input("结束日期", value=date.today(), key="viz_end")
|
||||
|
||||
if start_date > end_date:
|
||||
st.error("开始日期不能晚于结束日期")
|
||||
return
|
||||
|
||||
with st.spinner("正在加载行情数据..."):
|
||||
try:
|
||||
df = _load_daily_frame(ts_code, start_date, end_date)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
st.error(f"读取数据失败:{exc}")
|
||||
return
|
||||
|
||||
if df.empty:
|
||||
st.warning("未查询到该区间的交易数据,请确认数据库已拉取对应日线。")
|
||||
return
|
||||
|
||||
price_df = df[["close"]].rename(columns={"close": "收盘价"})
|
||||
volume_df = df[["vol"]].rename(columns={"vol": "成交量(手)"})
|
||||
|
||||
if price_df.shape[0] > 180:
|
||||
sampled = price_df.resample('3D').last().dropna()
|
||||
else:
|
||||
sampled = price_df
|
||||
|
||||
if volume_df.shape[0] > 180:
|
||||
volume_sampled = volume_df.resample('3D').mean().dropna()
|
||||
else:
|
||||
volume_sampled = volume_df
|
||||
|
||||
first_close = sampled.iloc[0, 0]
|
||||
last_close = sampled.iloc[-1, 0]
|
||||
delta_abs = last_close - first_close
|
||||
delta_pct = (delta_abs / first_close * 100) if first_close else 0.0
|
||||
|
||||
metric_col1, metric_col2, metric_col3 = st.columns(3)
|
||||
metric_col1.metric("最新收盘价", f"{last_close:.2f}", delta=f"{delta_abs:+.2f}")
|
||||
metric_col2.metric("区间涨跌幅", f"{delta_pct:+.2f}%")
|
||||
metric_col3.metric("平均成交量", f"{volume_sampled['成交量(手)'].mean():.0f}")
|
||||
|
||||
st.line_chart(sampled, width='stretch')
|
||||
st.bar_chart(volume_sampled, width='stretch')
|
||||
st.caption("提示:成交量单位为手,若需更长区间请调整日期后重新加载。")
|
||||
st.dataframe(df.reset_index().tail(10), width='stretch')
|
||||
|
||||
|
||||
def main() -> None:
|
||||
st.set_page_config(page_title="多智能体投资助理", layout="wide")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user