update
This commit is contained in:
parent
c6c781cc6b
commit
2dc2753827
13
README.md
13
README.md
@ -14,20 +14,17 @@
|
|||||||
## 快速开始
|
## 快速开始
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 初始化数据库结构
|
# 启动交互界面(内含数据库初始化、开机检查、样例回测入口)
|
||||||
python -m app.cli init-db
|
|
||||||
|
|
||||||
# 一键开机检查(默认回溯 365 天,缺失数据会自动补齐)
|
|
||||||
python -m app.cli boot-check --days 365
|
|
||||||
|
|
||||||
# 启动界面
|
|
||||||
streamlit run app/ui/streamlit_app.py
|
streamlit run app/ui/streamlit_app.py
|
||||||
```
|
```
|
||||||
|
|
||||||
Streamlit `自检测试` 页签提供:
|
Streamlit `自检测试` 页签提供:
|
||||||
- 数据库初始化快捷按钮;
|
- 数据库初始化快捷按钮;
|
||||||
- TuShare 小范围拉取测试;
|
- TuShare 小范围拉取测试;
|
||||||
- 开机检查器(展示当前数据覆盖范围与股票基础信息完整度)。
|
- 一键开机检查(可自动补数并展示覆盖摘要);
|
||||||
|
- 股票行情可视化(自动加载近段时间价格、成交量,并展示核心指标)。
|
||||||
|
|
||||||
|
`回测与复盘` 页签提供快速回测表单,可调整时间区间、股票池与参数并即时查看回测输出。
|
||||||
|
|
||||||
## 下一步
|
## 下一步
|
||||||
|
|
||||||
|
|||||||
77
app/cli.py
77
app/cli.py
@ -1,77 +0,0 @@
|
|||||||
"""Command line entry points for routine tasks."""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
from datetime import date
|
|
||||||
|
|
||||||
from app.backtest.engine import BtConfig, run_backtest
|
|
||||||
from app.data.schema import initialize_database
|
|
||||||
from app.ingest.checker import run_boot_check
|
|
||||||
|
|
||||||
|
|
||||||
def init_db() -> None:
|
|
||||||
result = initialize_database()
|
|
||||||
if result.skipped:
|
|
||||||
print("Database already initialized; skipping schema creation")
|
|
||||||
else:
|
|
||||||
print(f"Initialized database with {result.executed} statements")
|
|
||||||
|
|
||||||
|
|
||||||
def run_sample_backtest() -> None:
|
|
||||||
cfg = BtConfig(
|
|
||||||
id="demo",
|
|
||||||
name="Demo Strategy",
|
|
||||||
start_date=date(2020, 1, 1),
|
|
||||||
end_date=date(2020, 3, 31),
|
|
||||||
universe=["000001.SZ"],
|
|
||||||
params={
|
|
||||||
"target": 0.035,
|
|
||||||
"stop": -0.015,
|
|
||||||
"hold_days": 10,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
run_backtest(cfg)
|
|
||||||
|
|
||||||
|
|
||||||
def run_boot_check_cli(days: int) -> None:
|
|
||||||
report = run_boot_check(days=days)
|
|
||||||
print("Boot check summary:")
|
|
||||||
print(f" Period: {report.start} ~ {report.end}")
|
|
||||||
print(f" Expected trading days: {report.expected_trading_days}")
|
|
||||||
for name, info in report.tables.items():
|
|
||||||
print(
|
|
||||||
f" {name}: min={info.get('min')}, max={info.get('max')}, "
|
|
||||||
f"distinct={info.get('distinct_days')}, ok={info.get('meets_expectation')}"
|
|
||||||
)
|
|
||||||
stock = report.stock_basic
|
|
||||||
print(
|
|
||||||
f" stock_basic: total={stock.get('total')}, "
|
|
||||||
f"SSE listed={stock.get('sse_listed')}, SZSE listed={stock.get('szse_listed')}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
parser = argparse.ArgumentParser(description="Investment assistant toolkit")
|
|
||||||
sub = parser.add_subparsers(dest="command")
|
|
||||||
|
|
||||||
sub.add_parser("init-db", help="Initialize SQLite schema")
|
|
||||||
|
|
||||||
boot_parser = sub.add_parser("boot-check", help="Run startup data coverage check")
|
|
||||||
boot_parser.add_argument("--days", type=int, default=365, help="Lookback window in days")
|
|
||||||
|
|
||||||
sub.add_parser("sample-backtest", help="Execute demo backtest run")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if args.command is None or args.command == "init-db":
|
|
||||||
init_db()
|
|
||||||
elif args.command == "boot-check":
|
|
||||||
run_boot_check_cli(days=args.days)
|
|
||||||
elif args.command == "sample-backtest":
|
|
||||||
run_sample_backtest()
|
|
||||||
else:
|
|
||||||
parser.print_help()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@ -2,8 +2,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from typing import Iterable
|
from typing import Iterable, List
|
||||||
|
|
||||||
from app.utils.db import db_session
|
from app.utils.db import db_session
|
||||||
|
|
||||||
@ -86,9 +86,10 @@ SCHEMA_STATEMENTS: Iterable[str] = (
|
|||||||
"""
|
"""
|
||||||
CREATE TABLE IF NOT EXISTS trade_calendar (
|
CREATE TABLE IF NOT EXISTS trade_calendar (
|
||||||
exchange TEXT,
|
exchange TEXT,
|
||||||
cal_date TEXT PRIMARY KEY,
|
cal_date TEXT,
|
||||||
is_open INTEGER,
|
is_open INTEGER,
|
||||||
pretrade_date TEXT
|
pretrade_date TEXT,
|
||||||
|
PRIMARY KEY (exchange, cal_date)
|
||||||
);
|
);
|
||||||
""",
|
""",
|
||||||
"""
|
"""
|
||||||
@ -203,29 +204,49 @@ SCHEMA_STATEMENTS: Iterable[str] = (
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
REQUIRED_TABLES = (
|
||||||
|
"stock_basic",
|
||||||
|
"daily",
|
||||||
|
"daily_basic",
|
||||||
|
"adj_factor",
|
||||||
|
"suspend",
|
||||||
|
"trade_calendar",
|
||||||
|
"stk_limit",
|
||||||
|
"news",
|
||||||
|
"heat_daily",
|
||||||
|
"bt_config",
|
||||||
|
"bt_trades",
|
||||||
|
"bt_nav",
|
||||||
|
"bt_report",
|
||||||
|
"run_log",
|
||||||
|
"agent_utils",
|
||||||
|
"alloc_log",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MigrationResult:
|
class MigrationResult:
|
||||||
executed: int
|
executed: int
|
||||||
skipped: bool = False
|
skipped: bool = False
|
||||||
|
missing_tables: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
def _schema_exists() -> bool:
|
def _missing_tables() -> List[str]:
|
||||||
try:
|
try:
|
||||||
with db_session(read_only=True) as conn:
|
with db_session(read_only=True) as conn:
|
||||||
cursor = conn.execute(
|
rows = conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
|
||||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='news'"
|
|
||||||
)
|
|
||||||
return cursor.fetchone() is not None
|
|
||||||
except sqlite3.OperationalError:
|
except sqlite3.OperationalError:
|
||||||
return False
|
return list(REQUIRED_TABLES)
|
||||||
|
existing = {row["name"] for row in rows}
|
||||||
|
return [name for name in REQUIRED_TABLES if name not in existing]
|
||||||
|
|
||||||
|
|
||||||
def initialize_database() -> MigrationResult:
|
def initialize_database() -> MigrationResult:
|
||||||
"""Create tables and indexes required by the application."""
|
"""Create tables and indexes required by the application."""
|
||||||
|
|
||||||
if _schema_exists():
|
missing = _missing_tables()
|
||||||
return MigrationResult(executed=0, skipped=True)
|
if not missing:
|
||||||
|
return MigrationResult(executed=0, skipped=True, missing_tables=[])
|
||||||
|
|
||||||
executed = 0
|
executed = 0
|
||||||
with db_session() as conn:
|
with db_session() as conn:
|
||||||
@ -233,4 +254,4 @@ def initialize_database() -> MigrationResult:
|
|||||||
for statement in SCHEMA_STATEMENTS:
|
for statement in SCHEMA_STATEMENTS:
|
||||||
cursor.executescript(statement)
|
cursor.executescript(statement)
|
||||||
executed += 1
|
executed += 1
|
||||||
return MigrationResult(executed=executed)
|
return MigrationResult(executed=executed, skipped=False, missing_tables=missing)
|
||||||
|
|||||||
@ -2,21 +2,69 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
from datetime import date
|
from datetime import date, timedelta
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
ROOT = Path(__file__).resolve().parents[2]
|
ROOT = Path(__file__).resolve().parents[2]
|
||||||
if str(ROOT) not in sys.path:
|
if str(ROOT) not in sys.path:
|
||||||
sys.path.insert(0, str(ROOT))
|
sys.path.insert(0, str(ROOT))
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
|
||||||
|
from app.backtest.engine import BtConfig, run_backtest
|
||||||
from app.data.schema import initialize_database
|
from app.data.schema import initialize_database
|
||||||
from app.ingest.checker import run_boot_check
|
from app.ingest.checker import run_boot_check
|
||||||
from app.ingest.tushare import FetchJob, run_ingestion
|
from app.ingest.tushare import FetchJob, run_ingestion
|
||||||
from app.llm.explain import make_human_card
|
from app.llm.explain import make_human_card
|
||||||
|
from app.utils.config import get_config
|
||||||
|
from app.utils.db import db_session
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _load_stock_options(limit: int = 500) -> list[str]:
|
||||||
|
try:
|
||||||
|
with db_session(read_only=True) as conn:
|
||||||
|
rows = conn.execute(
|
||||||
|
"SELECT ts_code, name FROM stock_basic WHERE list_status = 'L' ORDER BY ts_code"
|
||||||
|
).fetchall()
|
||||||
|
except Exception:
|
||||||
|
return []
|
||||||
|
options: list[str] = []
|
||||||
|
for row in rows[:limit]:
|
||||||
|
code = row["ts_code"]
|
||||||
|
name = row["name"] or ""
|
||||||
|
label = f"{code} | {name}" if name else code
|
||||||
|
options.append(label)
|
||||||
|
return options
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_ts_code(selection: str) -> str:
|
||||||
|
return selection.split(' | ')[0].strip().upper()
|
||||||
|
|
||||||
|
|
||||||
|
def _load_daily_frame(ts_code: str, start: date, end: date) -> pd.DataFrame:
|
||||||
|
start_str = start.strftime('%Y%m%d')
|
||||||
|
end_str = end.strftime('%Y%m%d')
|
||||||
|
range_query = (
|
||||||
|
"SELECT trade_date, open, high, low, close, vol, amount "
|
||||||
|
"FROM daily WHERE ts_code = ? AND trade_date BETWEEN ? AND ? ORDER BY trade_date"
|
||||||
|
)
|
||||||
|
fallback_query = (
|
||||||
|
"SELECT trade_date, open, high, low, close, vol, amount "
|
||||||
|
"FROM daily WHERE ts_code = ? ORDER BY trade_date DESC LIMIT 200"
|
||||||
|
)
|
||||||
|
with db_session(read_only=True) as conn:
|
||||||
|
df = pd.read_sql_query(range_query, conn, params=(ts_code, start_str, end_str))
|
||||||
|
if df.empty:
|
||||||
|
df = pd.read_sql_query(fallback_query, conn, params=(ts_code,))
|
||||||
|
if df.empty:
|
||||||
|
return df
|
||||||
|
df = df.sort_values('trade_date')
|
||||||
|
df['trade_date'] = pd.to_datetime(df['trade_date'])
|
||||||
|
df.set_index('trade_date', inplace=True)
|
||||||
|
return df
|
||||||
def render_today_plan() -> None:
|
def render_today_plan() -> None:
|
||||||
st.header("今日计划")
|
st.header("今日计划")
|
||||||
st.write("待接入候选池筛选与多智能体决策结果。")
|
st.write("待接入候选池筛选与多智能体决策结果。")
|
||||||
@ -27,12 +75,49 @@ def render_today_plan() -> None:
|
|||||||
def render_backtest() -> None:
|
def render_backtest() -> None:
|
||||||
st.header("回测与复盘")
|
st.header("回测与复盘")
|
||||||
st.write("在此运行回测、展示净值曲线与代理贡献。")
|
st.write("在此运行回测、展示净值曲线与代理贡献。")
|
||||||
st.button("开始回测")
|
|
||||||
|
default_start = date(2020, 1, 1)
|
||||||
|
default_end = date(2020, 3, 31)
|
||||||
|
|
||||||
|
col1, col2 = st.columns(2)
|
||||||
|
start_date = col1.date_input("开始日期", value=default_start)
|
||||||
|
end_date = col2.date_input("结束日期", value=default_end)
|
||||||
|
universe_text = st.text_input("股票列表(逗号分隔)", value="000001.SZ")
|
||||||
|
target = st.number_input("目标收益(例:0.035 表示 3.5%)", value=0.035, step=0.005, format="%.3f")
|
||||||
|
stop = st.number_input("止损收益(例:-0.015 表示 -1.5%)", value=-0.015, step=0.005, format="%.3f")
|
||||||
|
hold_days = st.number_input("持有期(交易日)", value=10, step=1)
|
||||||
|
|
||||||
|
if st.button("运行回测"):
|
||||||
|
with st.spinner("正在执行回测..."):
|
||||||
|
try:
|
||||||
|
universe = [code.strip() for code in universe_text.split(',') if code.strip()]
|
||||||
|
cfg = BtConfig(
|
||||||
|
id="streamlit_demo",
|
||||||
|
name="Streamlit Demo Strategy",
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
universe=universe,
|
||||||
|
params={
|
||||||
|
"target": target,
|
||||||
|
"stop": stop,
|
||||||
|
"hold_days": int(hold_days),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
result = run_backtest(cfg)
|
||||||
|
st.success("回测执行完成,详见回测结果摘要。")
|
||||||
|
st.json({"nav_records": result.nav_series, "trades": result.trades})
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
st.error(f"回测执行失败:{exc}")
|
||||||
|
|
||||||
|
|
||||||
def render_settings() -> None:
|
def render_settings() -> None:
|
||||||
st.header("数据与设置")
|
st.header("数据与设置")
|
||||||
st.text_input("TuShare Token")
|
cfg = get_config()
|
||||||
|
token = st.text_input("TuShare Token", value=cfg.tushare_token or "", type="password")
|
||||||
|
if st.button("保存 Token"):
|
||||||
|
cfg.tushare_token = token.strip() or None
|
||||||
|
st.success("TuShare Token 已更新,仅保存在当前会话。")
|
||||||
|
|
||||||
st.write("新闻源开关与数据库备份将在此配置。")
|
st.write("新闻源开关与数据库备份将在此配置。")
|
||||||
|
|
||||||
|
|
||||||
@ -79,6 +164,65 @@ def render_tests() -> None:
|
|||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
st.error(f"开机检查失败:{exc}")
|
st.error(f"开机检查失败:{exc}")
|
||||||
|
|
||||||
|
st.divider()
|
||||||
|
st.subheader("股票行情可视化")
|
||||||
|
options = _load_stock_options()
|
||||||
|
default_code = options[0] if options else "000001.SZ"
|
||||||
|
|
||||||
|
if options:
|
||||||
|
selection = st.selectbox("选择股票", options, index=0)
|
||||||
|
ts_code = _parse_ts_code(selection)
|
||||||
|
else:
|
||||||
|
ts_code = st.text_input("输入股票代码(如 000001.SZ)", value=default_code).strip().upper()
|
||||||
|
|
||||||
|
viz_col1, viz_col2 = st.columns(2)
|
||||||
|
default_start = date.today() - timedelta(days=180)
|
||||||
|
start_date = viz_col1.date_input("开始日期", value=default_start, key="viz_start")
|
||||||
|
end_date = viz_col2.date_input("结束日期", value=date.today(), key="viz_end")
|
||||||
|
|
||||||
|
if start_date > end_date:
|
||||||
|
st.error("开始日期不能晚于结束日期")
|
||||||
|
return
|
||||||
|
|
||||||
|
with st.spinner("正在加载行情数据..."):
|
||||||
|
try:
|
||||||
|
df = _load_daily_frame(ts_code, start_date, end_date)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
st.error(f"读取数据失败:{exc}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if df.empty:
|
||||||
|
st.warning("未查询到该区间的交易数据,请确认数据库已拉取对应日线。")
|
||||||
|
return
|
||||||
|
|
||||||
|
price_df = df[["close"]].rename(columns={"close": "收盘价"})
|
||||||
|
volume_df = df[["vol"]].rename(columns={"vol": "成交量(手)"})
|
||||||
|
|
||||||
|
if price_df.shape[0] > 180:
|
||||||
|
sampled = price_df.resample('3D').last().dropna()
|
||||||
|
else:
|
||||||
|
sampled = price_df
|
||||||
|
|
||||||
|
if volume_df.shape[0] > 180:
|
||||||
|
volume_sampled = volume_df.resample('3D').mean().dropna()
|
||||||
|
else:
|
||||||
|
volume_sampled = volume_df
|
||||||
|
|
||||||
|
first_close = sampled.iloc[0, 0]
|
||||||
|
last_close = sampled.iloc[-1, 0]
|
||||||
|
delta_abs = last_close - first_close
|
||||||
|
delta_pct = (delta_abs / first_close * 100) if first_close else 0.0
|
||||||
|
|
||||||
|
metric_col1, metric_col2, metric_col3 = st.columns(3)
|
||||||
|
metric_col1.metric("最新收盘价", f"{last_close:.2f}", delta=f"{delta_abs:+.2f}")
|
||||||
|
metric_col2.metric("区间涨跌幅", f"{delta_pct:+.2f}%")
|
||||||
|
metric_col3.metric("平均成交量", f"{volume_sampled['成交量(手)'].mean():.0f}")
|
||||||
|
|
||||||
|
st.line_chart(sampled, width='stretch')
|
||||||
|
st.bar_chart(volume_sampled, width='stretch')
|
||||||
|
st.caption("提示:成交量单位为手,若需更长区间请调整日期后重新加载。")
|
||||||
|
st.dataframe(df.reset_index().tail(10), width='stretch')
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
st.set_page_config(page_title="多智能体投资助理", layout="wide")
|
st.set_page_config(page_title="多智能体投资助理", layout="wide")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user