update
This commit is contained in:
parent
a3b16ffa8d
commit
9c7a68d313
@ -29,6 +29,13 @@ LOG_EXTRA = {"stage": "data_ingest"}
|
|||||||
_CALL_QUEUE = deque()
|
_CALL_QUEUE = deque()
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_date_str(value: Optional[str]) -> Optional[str]:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
text = str(value).strip()
|
||||||
|
return text or None
|
||||||
|
|
||||||
|
|
||||||
def _respect_rate_limit(cfg) -> None:
|
def _respect_rate_limit(cfg) -> None:
|
||||||
max_calls = cfg.max_calls_per_minute
|
max_calls = cfg.max_calls_per_minute
|
||||||
if max_calls <= 0:
|
if max_calls <= 0:
|
||||||
@ -44,23 +51,6 @@ def _respect_rate_limit(cfg) -> None:
|
|||||||
_CALL_QUEUE.append(time.time())
|
_CALL_QUEUE.append(time.time())
|
||||||
|
|
||||||
|
|
||||||
def _existing_date_range(
|
|
||||||
table: str,
|
|
||||||
date_col: str,
|
|
||||||
ts_code: str | None = None,
|
|
||||||
) -> Tuple[str | None, str | None]:
|
|
||||||
query = f"SELECT MIN({date_col}) AS min_d, MAX({date_col}) AS max_d FROM {table}"
|
|
||||||
params: Tuple = ()
|
|
||||||
if ts_code:
|
|
||||||
query += " WHERE ts_code = ?"
|
|
||||||
params = (ts_code,)
|
|
||||||
with db_session(read_only=True) as conn:
|
|
||||||
row = conn.execute(query, params).fetchone()
|
|
||||||
if row is None:
|
|
||||||
return None, None
|
|
||||||
return row["min_d"], row["max_d"]
|
|
||||||
|
|
||||||
|
|
||||||
def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]:
|
def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]:
|
||||||
if df is None or df.empty:
|
if df is None or df.empty:
|
||||||
return []
|
return []
|
||||||
@ -352,25 +342,63 @@ def _record_exists(
|
|||||||
|
|
||||||
|
|
||||||
def _should_skip_range(table: str, date_col: str, start: date, end: date, ts_code: str | None = None) -> bool:
|
def _should_skip_range(table: str, date_col: str, start: date, end: date, ts_code: str | None = None) -> bool:
|
||||||
min_d, max_d = _existing_date_range(table, date_col, ts_code)
|
|
||||||
if min_d is None or max_d is None:
|
|
||||||
return False
|
|
||||||
start_str = _format_date(start)
|
start_str = _format_date(start)
|
||||||
end_str = _format_date(end)
|
end_str = _format_date(end)
|
||||||
return min_d <= start_str and max_d >= end_str
|
|
||||||
|
effective_start = start_str
|
||||||
|
effective_end = end_str
|
||||||
|
|
||||||
|
if ts_code:
|
||||||
|
list_date, delist_date = _listing_window(ts_code)
|
||||||
|
if list_date:
|
||||||
|
effective_start = max(effective_start, list_date)
|
||||||
|
if delist_date:
|
||||||
|
effective_end = min(effective_end, delist_date)
|
||||||
|
if effective_start > effective_end:
|
||||||
|
LOGGER.debug(
|
||||||
|
"股票 %s 在目标区间之外,跳过补数",
|
||||||
|
ts_code,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
stats = _range_stats(table, date_col, effective_start, effective_end, ts_code=ts_code)
|
||||||
|
else:
|
||||||
|
stats = _range_stats(table, date_col, effective_start, effective_end)
|
||||||
|
|
||||||
|
if stats["min"] is None or stats["max"] is None:
|
||||||
|
return False
|
||||||
|
if stats["min"] > effective_start or stats["max"] < effective_end:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if ts_code is None:
|
||||||
|
expected_days = _expected_trading_days(effective_start, effective_end)
|
||||||
|
if expected_days and (stats["distinct"] or 0) < expected_days:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _range_stats(table: str, date_col: str, start_str: str, end_str: str) -> Dict[str, Optional[str]]:
|
def _range_stats(
|
||||||
|
table: str,
|
||||||
|
date_col: str,
|
||||||
|
start_str: str,
|
||||||
|
end_str: str,
|
||||||
|
ts_code: str | None = None,
|
||||||
|
) -> Dict[str, Optional[str]]:
|
||||||
sql = (
|
sql = (
|
||||||
f"SELECT MIN({date_col}) AS min_d, MAX({date_col}) AS max_d, "
|
f"SELECT MIN({date_col}) AS min_d, MAX({date_col}) AS max_d, "
|
||||||
f"COUNT(DISTINCT {date_col}) AS distinct_days FROM {table} "
|
f"COUNT(DISTINCT {date_col}) AS distinct_days FROM {table} "
|
||||||
f"WHERE {date_col} BETWEEN ? AND ?"
|
f"WHERE {date_col} BETWEEN ? AND ?"
|
||||||
)
|
)
|
||||||
|
params: List[object] = [start_str, end_str]
|
||||||
|
if ts_code:
|
||||||
|
sql += " AND ts_code = ?"
|
||||||
|
params.append(ts_code)
|
||||||
with db_session(read_only=True) as conn:
|
with db_session(read_only=True) as conn:
|
||||||
row = conn.execute(sql, (start_str, end_str)).fetchone()
|
row = conn.execute(sql, tuple(params)).fetchone()
|
||||||
return {
|
return {
|
||||||
"min": row["min_d"],
|
"min": row["min_d"] if row else None,
|
||||||
"max": row["max_d"],
|
"max": row["max_d"] if row else None,
|
||||||
"distinct": row["distinct_days"] if row else 0,
|
"distinct": row["distinct_days"] if row else 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -392,6 +420,17 @@ def _range_needs_refresh(
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _listing_window(ts_code: str) -> Tuple[Optional[str], Optional[str]]:
|
||||||
|
with db_session(read_only=True) as conn:
|
||||||
|
row = conn.execute(
|
||||||
|
"SELECT list_date, delist_date FROM stock_basic WHERE ts_code = ?",
|
||||||
|
(ts_code,),
|
||||||
|
).fetchone()
|
||||||
|
if not row:
|
||||||
|
return None, None
|
||||||
|
return _normalize_date_str(row["list_date"]), _normalize_date_str(row["delist_date"]) # type: ignore[index]
|
||||||
|
|
||||||
|
|
||||||
def _calendar_needs_refresh(exchange: str, start_str: str, end_str: str) -> bool:
|
def _calendar_needs_refresh(exchange: str, start_str: str, end_str: str) -> bool:
|
||||||
sql = """
|
sql = """
|
||||||
SELECT MIN(cal_date) AS min_d, MAX(cal_date) AS max_d, COUNT(*) AS cnt
|
SELECT MIN(cal_date) AS min_d, MAX(cal_date) AS max_d, COUNT(*) AS cnt
|
||||||
@ -421,7 +460,12 @@ def _expected_trading_days(start_str: str, end_str: str, exchange: str = "SSE")
|
|||||||
|
|
||||||
def fetch_stock_basic(exchange: Optional[str] = None, list_status: str = "L") -> Iterable[Dict]:
|
def fetch_stock_basic(exchange: Optional[str] = None, list_status: str = "L") -> Iterable[Dict]:
|
||||||
client = _ensure_client()
|
client = _ensure_client()
|
||||||
LOGGER.info("拉取股票基础信息(交易所:%s,状态:%s)", exchange or "全部", list_status)
|
LOGGER.info(
|
||||||
|
"拉取股票基础信息(交易所:%s,状态:%s)",
|
||||||
|
exchange or "全部",
|
||||||
|
list_status,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
fields = "ts_code,symbol,name,area,industry,market,exchange,list_status,list_date,delist_date"
|
fields = "ts_code,symbol,name,area,industry,market,exchange,list_status,list_date,delist_date"
|
||||||
df = client.stock_basic(exchange=exchange, list_status=list_status, fields=fields)
|
df = client.stock_basic(exchange=exchange, list_status=list_status, fields=fields)
|
||||||
return _df_to_records(df, _TABLE_COLUMNS["stock_basic"])
|
return _df_to_records(df, _TABLE_COLUMNS["stock_basic"])
|
||||||
@ -626,7 +670,13 @@ def fetch_trade_calendar(start: date, end: date, exchange: str = "SSE") -> Itera
|
|||||||
client = _ensure_client()
|
client = _ensure_client()
|
||||||
start_date = _format_date(start)
|
start_date = _format_date(start)
|
||||||
end_date = _format_date(end)
|
end_date = _format_date(end)
|
||||||
LOGGER.info("拉取交易日历(交易所:%s,区间:%s-%s)", exchange, start_date, end_date)
|
LOGGER.info(
|
||||||
|
"拉取交易日历(交易所:%s,区间:%s-%s)",
|
||||||
|
exchange,
|
||||||
|
start_date,
|
||||||
|
end_date,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
df = client.trade_cal(exchange=exchange, start_date=start_date, end_date=end_date)
|
df = client.trade_cal(exchange=exchange, start_date=start_date, end_date=end_date)
|
||||||
if df is not None and not df.empty and "is_open" in df.columns:
|
if df is not None and not df.empty and "is_open" in df.columns:
|
||||||
df["is_open"] = pd.to_numeric(df["is_open"], errors="coerce").fillna(0).astype(int)
|
df["is_open"] = pd.to_numeric(df["is_open"], errors="coerce").fillna(0).astype(int)
|
||||||
@ -672,7 +722,7 @@ def fetch_stk_limit(
|
|||||||
def save_records(table: str, rows: Iterable[Dict]) -> None:
|
def save_records(table: str, rows: Iterable[Dict]) -> None:
|
||||||
items = list(rows)
|
items = list(rows)
|
||||||
if not items:
|
if not items:
|
||||||
LOGGER.info("表 %s 没有新增记录,跳过写入", table)
|
LOGGER.info("表 %s 没有新增记录,跳过写入", table, extra=LOG_EXTRA)
|
||||||
return
|
return
|
||||||
|
|
||||||
schema = _TABLE_SCHEMAS.get(table)
|
schema = _TABLE_SCHEMAS.get(table)
|
||||||
@ -683,7 +733,7 @@ def save_records(table: str, rows: Iterable[Dict]) -> None:
|
|||||||
placeholders = ",".join([f":{col}" for col in columns])
|
placeholders = ",".join([f":{col}" for col in columns])
|
||||||
col_clause = ",".join(columns)
|
col_clause = ",".join(columns)
|
||||||
|
|
||||||
LOGGER.info("表 %s 写入 %d 条记录", table, len(items))
|
LOGGER.info("表 %s 写入 %d 条记录", table, len(items), extra=LOG_EXTRA)
|
||||||
with db_session() as conn:
|
with db_session() as conn:
|
||||||
conn.executescript(schema)
|
conn.executescript(schema)
|
||||||
conn.executemany(
|
conn.executemany(
|
||||||
@ -700,7 +750,11 @@ def ensure_stock_basic(list_status: str = "L") -> None:
|
|||||||
(*exchanges, list_status),
|
(*exchanges, list_status),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
if row and row["cnt"]:
|
if row and row["cnt"]:
|
||||||
LOGGER.info("股票基础信息已存在 %d 条记录,跳过拉取", row["cnt"])
|
LOGGER.info(
|
||||||
|
"股票基础信息已存在 %d 条记录,跳过拉取",
|
||||||
|
row["cnt"],
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
for exch in exchanges:
|
for exch in exchanges:
|
||||||
@ -736,7 +790,7 @@ def ensure_data_coverage(
|
|||||||
progress = min(current_step / total_steps, 1.0)
|
progress = min(current_step / total_steps, 1.0)
|
||||||
if progress_hook:
|
if progress_hook:
|
||||||
progress_hook(message, progress)
|
progress_hook(message, progress)
|
||||||
LOGGER.info(message)
|
LOGGER.info(message, extra=LOG_EXTRA)
|
||||||
|
|
||||||
advance("准备股票基础信息与交易日历")
|
advance("准备股票基础信息与交易日历")
|
||||||
ensure_stock_basic()
|
ensure_stock_basic()
|
||||||
@ -824,6 +878,8 @@ def ensure_data_coverage(
|
|||||||
|
|
||||||
if progress_hook:
|
if progress_hook:
|
||||||
progress_hook("数据覆盖检查完成", 1.0)
|
progress_hook("数据覆盖检查完成", 1.0)
|
||||||
|
|
||||||
|
|
||||||
def collect_data_coverage(start: date, end: date) -> Dict[str, Dict[str, object]]:
|
def collect_data_coverage(start: date, end: date) -> Dict[str, Dict[str, object]]:
|
||||||
start_str = _format_date(start)
|
start_str = _format_date(start)
|
||||||
end_str = _format_date(end)
|
end_str = _format_date(end)
|
||||||
@ -876,7 +932,7 @@ def collect_data_coverage(start: date, end: date) -> Dict[str, Dict[str, object]
|
|||||||
|
|
||||||
|
|
||||||
def run_ingestion(job: FetchJob, include_limits: bool = True) -> None:
|
def run_ingestion(job: FetchJob, include_limits: bool = True) -> None:
|
||||||
LOGGER.info("启动 TuShare 拉取任务:%s", job.name)
|
LOGGER.info("启动 TuShare 拉取任务:%s", job.name, extra=LOG_EXTRA)
|
||||||
ensure_data_coverage(
|
ensure_data_coverage(
|
||||||
job.start,
|
job.start,
|
||||||
job.end,
|
job.end,
|
||||||
@ -884,4 +940,4 @@ def run_ingestion(job: FetchJob, include_limits: bool = True) -> None:
|
|||||||
include_limits=include_limits,
|
include_limits=include_limits,
|
||||||
force=True,
|
force=True,
|
||||||
)
|
)
|
||||||
LOGGER.info("任务 %s 完成", job.name)
|
LOGGER.info("任务 %s 完成", job.name, extra=LOG_EXTRA)
|
||||||
|
|||||||
@ -21,8 +21,11 @@ from app.ingest.tushare import FetchJob, run_ingestion
|
|||||||
from app.llm.explain import make_human_card
|
from app.llm.explain import make_human_card
|
||||||
from app.utils.config import get_config
|
from app.utils.config import get_config
|
||||||
from app.utils.db import db_session
|
from app.utils.db import db_session
|
||||||
|
from app.utils.logging import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
LOGGER = get_logger(__name__)
|
||||||
|
LOG_EXTRA = {"stage": "ui"}
|
||||||
|
|
||||||
|
|
||||||
def _load_stock_options(limit: int = 500) -> list[str]:
|
def _load_stock_options(limit: int = 500) -> list[str]:
|
||||||
@ -32,6 +35,7 @@ def _load_stock_options(limit: int = 500) -> list[str]:
|
|||||||
"SELECT ts_code, name FROM stock_basic WHERE list_status = 'L' ORDER BY ts_code"
|
"SELECT ts_code, name FROM stock_basic WHERE list_status = 'L' ORDER BY ts_code"
|
||||||
).fetchall()
|
).fetchall()
|
||||||
except Exception:
|
except Exception:
|
||||||
|
LOGGER.exception("加载股票列表失败", extra=LOG_EXTRA)
|
||||||
return []
|
return []
|
||||||
options: list[str] = []
|
options: list[str] = []
|
||||||
for row in rows[:limit]:
|
for row in rows[:limit]:
|
||||||
@ -39,6 +43,7 @@ def _load_stock_options(limit: int = 500) -> list[str]:
|
|||||||
name = row["name"] or ""
|
name = row["name"] or ""
|
||||||
label = f"{code} | {name}" if name else code
|
label = f"{code} | {name}" if name else code
|
||||||
options.append(label)
|
options.append(label)
|
||||||
|
LOGGER.info("加载股票选项完成,数量=%s", len(options), extra=LOG_EXTRA)
|
||||||
return options
|
return options
|
||||||
|
|
||||||
|
|
||||||
@ -47,6 +52,13 @@ def _parse_ts_code(selection: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def _load_daily_frame(ts_code: str, start: date, end: date) -> pd.DataFrame:
|
def _load_daily_frame(ts_code: str, start: date, end: date) -> pd.DataFrame:
|
||||||
|
LOGGER.info(
|
||||||
|
"加载行情数据:ts_code=%s start=%s end=%s",
|
||||||
|
ts_code,
|
||||||
|
start,
|
||||||
|
end,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
start_str = start.strftime('%Y%m%d')
|
start_str = start.strftime('%Y%m%d')
|
||||||
end_str = end.strftime('%Y%m%d')
|
end_str = end.strftime('%Y%m%d')
|
||||||
range_query = (
|
range_query = (
|
||||||
@ -62,24 +74,47 @@ def _load_daily_frame(ts_code: str, start: date, end: date) -> pd.DataFrame:
|
|||||||
if df.empty:
|
if df.empty:
|
||||||
df = pd.read_sql_query(fallback_query, conn, params=(ts_code,))
|
df = pd.read_sql_query(fallback_query, conn, params=(ts_code,))
|
||||||
if df.empty:
|
if df.empty:
|
||||||
|
LOGGER.warning(
|
||||||
|
"行情数据为空:ts_code=%s start=%s end=%s",
|
||||||
|
ts_code,
|
||||||
|
start,
|
||||||
|
end,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
return df
|
return df
|
||||||
df = df.sort_values('trade_date')
|
df = df.sort_values('trade_date')
|
||||||
df['trade_date'] = pd.to_datetime(df['trade_date'])
|
df['trade_date'] = pd.to_datetime(df['trade_date'])
|
||||||
df.set_index('trade_date', inplace=True)
|
df.set_index('trade_date', inplace=True)
|
||||||
|
LOGGER.info("行情数据加载完成:条数=%s", len(df), extra=LOG_EXTRA)
|
||||||
return df
|
return df
|
||||||
|
|
||||||
|
|
||||||
def render_today_plan() -> None:
|
def render_today_plan() -> None:
|
||||||
|
LOGGER.info("渲染今日计划页面", extra=LOG_EXTRA)
|
||||||
st.header("今日计划")
|
st.header("今日计划")
|
||||||
st.write("待接入候选池筛选与多智能体决策结果。")
|
st.write("待接入候选池筛选与多智能体决策结果。")
|
||||||
sample = make_human_card("000001.SZ", "2025-01-01", {"decisions": []})
|
sample = make_human_card("000001.SZ", "2025-01-01", {"decisions": []})
|
||||||
|
LOGGER.debug("示例卡片内容:%s", sample, extra=LOG_EXTRA)
|
||||||
st.json(sample)
|
st.json(sample)
|
||||||
|
|
||||||
|
|
||||||
def render_backtest() -> None:
|
def render_backtest() -> None:
|
||||||
|
LOGGER.info("渲染回测页面", extra=LOG_EXTRA)
|
||||||
st.header("回测与复盘")
|
st.header("回测与复盘")
|
||||||
st.write("在此运行回测、展示净值曲线与代理贡献。")
|
st.write("在此运行回测、展示净值曲线与代理贡献。")
|
||||||
|
|
||||||
default_start = date(2020, 1, 1)
|
default_start = date(2020, 1, 1)
|
||||||
default_end = date(2020, 3, 31)
|
default_end = date(2020, 3, 31)
|
||||||
|
LOGGER.debug(
|
||||||
|
"回测默认参数:start=%s end=%s universe=%s target=%s stop=%s hold_days=%s",
|
||||||
|
default_start,
|
||||||
|
default_end,
|
||||||
|
"000001.SZ",
|
||||||
|
0.035,
|
||||||
|
-0.015,
|
||||||
|
10,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
|
|
||||||
col1, col2 = st.columns(2)
|
col1, col2 = st.columns(2)
|
||||||
start_date = col1.date_input("开始日期", value=default_start)
|
start_date = col1.date_input("开始日期", value=default_start)
|
||||||
@ -88,11 +123,32 @@ def render_backtest() -> None:
|
|||||||
target = st.number_input("目标收益(例:0.035 表示 3.5%)", value=0.035, step=0.005, format="%.3f")
|
target = st.number_input("目标收益(例:0.035 表示 3.5%)", value=0.035, step=0.005, format="%.3f")
|
||||||
stop = st.number_input("止损收益(例:-0.015 表示 -1.5%)", value=-0.015, step=0.005, format="%.3f")
|
stop = st.number_input("止损收益(例:-0.015 表示 -1.5%)", value=-0.015, step=0.005, format="%.3f")
|
||||||
hold_days = st.number_input("持有期(交易日)", value=10, step=1)
|
hold_days = st.number_input("持有期(交易日)", value=10, step=1)
|
||||||
|
LOGGER.debug(
|
||||||
|
"当前回测表单输入:start=%s end=%s universe_text=%s target=%.3f stop=%.3f hold_days=%s",
|
||||||
|
start_date,
|
||||||
|
end_date,
|
||||||
|
universe_text,
|
||||||
|
target,
|
||||||
|
stop,
|
||||||
|
hold_days,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
|
|
||||||
if st.button("运行回测"):
|
if st.button("运行回测"):
|
||||||
|
LOGGER.info("用户点击运行回测按钮", extra=LOG_EXTRA)
|
||||||
with st.spinner("正在执行回测..."):
|
with st.spinner("正在执行回测..."):
|
||||||
try:
|
try:
|
||||||
universe = [code.strip() for code in universe_text.split(',') if code.strip()]
|
universe = [code.strip() for code in universe_text.split(',') if code.strip()]
|
||||||
|
LOGGER.info(
|
||||||
|
"回测参数:start=%s end=%s universe=%s target=%s stop=%s hold_days=%s",
|
||||||
|
start_date,
|
||||||
|
end_date,
|
||||||
|
universe,
|
||||||
|
target,
|
||||||
|
stop,
|
||||||
|
hold_days,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
cfg = BtConfig(
|
cfg = BtConfig(
|
||||||
id="streamlit_demo",
|
id="streamlit_demo",
|
||||||
name="Streamlit Demo Strategy",
|
name="Streamlit Demo Strategy",
|
||||||
@ -106,39 +162,55 @@ def render_backtest() -> None:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
result = run_backtest(cfg)
|
result = run_backtest(cfg)
|
||||||
|
LOGGER.info(
|
||||||
|
"回测完成:nav_records=%s trades=%s",
|
||||||
|
len(result.nav_series),
|
||||||
|
len(result.trades),
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
st.success("回测执行完成,详见回测结果摘要。")
|
st.success("回测执行完成,详见回测结果摘要。")
|
||||||
st.json({"nav_records": result.nav_series, "trades": result.trades})
|
st.json({"nav_records": result.nav_series, "trades": result.trades})
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
|
LOGGER.exception("回测执行失败", extra=LOG_EXTRA)
|
||||||
st.error(f"回测执行失败:{exc}")
|
st.error(f"回测执行失败:{exc}")
|
||||||
|
|
||||||
|
|
||||||
def render_settings() -> None:
|
def render_settings() -> None:
|
||||||
|
LOGGER.info("渲染设置页面", extra=LOG_EXTRA)
|
||||||
st.header("数据与设置")
|
st.header("数据与设置")
|
||||||
cfg = get_config()
|
cfg = get_config()
|
||||||
|
LOGGER.debug("当前 TuShare Token 是否已配置=%s", bool(cfg.tushare_token), extra=LOG_EXTRA)
|
||||||
token = st.text_input("TuShare Token", value=cfg.tushare_token or "", type="password")
|
token = st.text_input("TuShare Token", value=cfg.tushare_token or "", type="password")
|
||||||
|
|
||||||
if st.button("保存设置"):
|
if st.button("保存设置"):
|
||||||
|
LOGGER.info("保存设置按钮被点击", extra=LOG_EXTRA)
|
||||||
cfg.tushare_token = token.strip() or None
|
cfg.tushare_token = token.strip() or None
|
||||||
|
LOGGER.info("TuShare Token 更新,是否为空=%s", cfg.tushare_token is None, extra=LOG_EXTRA)
|
||||||
st.success("设置已保存,仅在当前会话生效。")
|
st.success("设置已保存,仅在当前会话生效。")
|
||||||
|
|
||||||
st.write("新闻源开关与数据库备份将在此配置。")
|
st.write("新闻源开关与数据库备份将在此配置。")
|
||||||
|
|
||||||
|
|
||||||
def render_tests() -> None:
|
def render_tests() -> None:
|
||||||
|
LOGGER.info("渲染自检页面", extra=LOG_EXTRA)
|
||||||
st.header("自检测试")
|
st.header("自检测试")
|
||||||
st.write("用于快速检查数据库与数据拉取是否正常工作。")
|
st.write("用于快速检查数据库与数据拉取是否正常工作。")
|
||||||
|
|
||||||
if st.button("测试数据库初始化"):
|
if st.button("测试数据库初始化"):
|
||||||
|
LOGGER.info("点击测试数据库初始化按钮", extra=LOG_EXTRA)
|
||||||
with st.spinner("正在检查数据库..."):
|
with st.spinner("正在检查数据库..."):
|
||||||
result = initialize_database()
|
result = initialize_database()
|
||||||
if result.skipped:
|
if result.skipped:
|
||||||
|
LOGGER.info("数据库已存在,无需初始化", extra=LOG_EXTRA)
|
||||||
st.success("数据库已存在,检查通过。")
|
st.success("数据库已存在,检查通过。")
|
||||||
else:
|
else:
|
||||||
|
LOGGER.info("数据库初始化完成,执行语句数=%s", result.executed, extra=LOG_EXTRA)
|
||||||
st.success(f"数据库初始化完成,共执行 {result.executed} 条语句。")
|
st.success(f"数据库初始化完成,共执行 {result.executed} 条语句。")
|
||||||
|
|
||||||
st.divider()
|
st.divider()
|
||||||
|
|
||||||
if st.button("测试 TuShare 拉取(示例 2024-01-01 至 2024-01-03)"):
|
if st.button("测试 TuShare 拉取(示例 2024-01-01 至 2024-01-03)"):
|
||||||
|
LOGGER.info("点击示例 TuShare 拉取按钮", extra=LOG_EXTRA)
|
||||||
with st.spinner("正在调用 TuShare 接口..."):
|
with st.spinner("正在调用 TuShare 接口..."):
|
||||||
try:
|
try:
|
||||||
run_ingestion(
|
run_ingestion(
|
||||||
@ -150,14 +222,17 @@ def render_tests() -> None:
|
|||||||
),
|
),
|
||||||
include_limits=False,
|
include_limits=False,
|
||||||
)
|
)
|
||||||
|
LOGGER.info("示例 TuShare 拉取成功", extra=LOG_EXTRA)
|
||||||
st.success("TuShare 示例拉取完成,数据已写入数据库。")
|
st.success("TuShare 示例拉取完成,数据已写入数据库。")
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
|
LOGGER.exception("示例 TuShare 拉取失败", extra=LOG_EXTRA)
|
||||||
st.error(f"拉取失败:{exc}")
|
st.error(f"拉取失败:{exc}")
|
||||||
|
|
||||||
st.info("注意:TuShare 拉取依赖网络与 Token,若环境未配置将出现错误提示。")
|
st.info("注意:TuShare 拉取依赖网络与 Token,若环境未配置将出现错误提示。")
|
||||||
|
|
||||||
st.divider()
|
st.divider()
|
||||||
days = int(st.number_input("检查窗口(天数)", min_value=30, max_value=1095, value=365, step=30))
|
days = int(st.number_input("检查窗口(天数)", min_value=30, max_value=1095, value=365, step=30))
|
||||||
|
LOGGER.debug("检查窗口天数=%s", days, extra=LOG_EXTRA)
|
||||||
cfg = get_config()
|
cfg = get_config()
|
||||||
force_refresh = st.checkbox(
|
force_refresh = st.checkbox(
|
||||||
"强制刷新数据(关闭增量跳过)",
|
"强制刷新数据(关闭增量跳过)",
|
||||||
@ -166,8 +241,10 @@ def render_tests() -> None:
|
|||||||
)
|
)
|
||||||
if force_refresh != cfg.force_refresh:
|
if force_refresh != cfg.force_refresh:
|
||||||
cfg.force_refresh = force_refresh
|
cfg.force_refresh = force_refresh
|
||||||
|
LOGGER.info("更新 force_refresh=%s", force_refresh, extra=LOG_EXTRA)
|
||||||
|
|
||||||
if st.button("执行开机检查"):
|
if st.button("执行开机检查"):
|
||||||
|
LOGGER.info("点击执行开机检查按钮", extra=LOG_EXTRA)
|
||||||
progress_bar = st.progress(0.0)
|
progress_bar = st.progress(0.0)
|
||||||
status_placeholder = st.empty()
|
status_placeholder = st.empty()
|
||||||
log_placeholder = st.empty()
|
log_placeholder = st.empty()
|
||||||
@ -177,6 +254,7 @@ def render_tests() -> None:
|
|||||||
progress_bar.progress(min(max(value, 0.0), 1.0))
|
progress_bar.progress(min(max(value, 0.0), 1.0))
|
||||||
status_placeholder.write(message)
|
status_placeholder.write(message)
|
||||||
messages.append(message)
|
messages.append(message)
|
||||||
|
LOGGER.debug("开机检查进度:%s -> %.2f", message, value, extra=LOG_EXTRA)
|
||||||
|
|
||||||
with st.spinner("正在执行开机检查..."):
|
with st.spinner("正在执行开机检查..."):
|
||||||
try:
|
try:
|
||||||
@ -185,11 +263,13 @@ def render_tests() -> None:
|
|||||||
progress_hook=hook,
|
progress_hook=hook,
|
||||||
force_refresh=force_refresh,
|
force_refresh=force_refresh,
|
||||||
)
|
)
|
||||||
|
LOGGER.info("开机检查成功", extra=LOG_EXTRA)
|
||||||
st.success("开机检查完成,以下为数据覆盖摘要。")
|
st.success("开机检查完成,以下为数据覆盖摘要。")
|
||||||
st.json(report.to_dict())
|
st.json(report.to_dict())
|
||||||
if messages:
|
if messages:
|
||||||
log_placeholder.markdown("\n".join(f"- {msg}" for msg in messages))
|
log_placeholder.markdown("\n".join(f"- {msg}" for msg in messages))
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
|
LOGGER.exception("开机检查失败", extra=LOG_EXTRA)
|
||||||
st.error(f"开机检查失败:{exc}")
|
st.error(f"开机检查失败:{exc}")
|
||||||
if messages:
|
if messages:
|
||||||
log_placeholder.markdown("\n".join(f"- {msg}" for msg in messages))
|
log_placeholder.markdown("\n".join(f"- {msg}" for msg in messages))
|
||||||
@ -204,15 +284,19 @@ def render_tests() -> None:
|
|||||||
if options:
|
if options:
|
||||||
selection = st.selectbox("选择股票", options, index=0)
|
selection = st.selectbox("选择股票", options, index=0)
|
||||||
ts_code = _parse_ts_code(selection)
|
ts_code = _parse_ts_code(selection)
|
||||||
|
LOGGER.debug("选择股票:%s", ts_code, extra=LOG_EXTRA)
|
||||||
else:
|
else:
|
||||||
ts_code = st.text_input("输入股票代码(如 000001.SZ)", value=default_code).strip().upper()
|
ts_code = st.text_input("输入股票代码(如 000001.SZ)", value=default_code).strip().upper()
|
||||||
|
LOGGER.debug("输入股票:%s", ts_code, extra=LOG_EXTRA)
|
||||||
|
|
||||||
viz_col1, viz_col2 = st.columns(2)
|
viz_col1, viz_col2 = st.columns(2)
|
||||||
default_start = date.today() - timedelta(days=180)
|
default_start = date.today() - timedelta(days=180)
|
||||||
start_date = viz_col1.date_input("开始日期", value=default_start, key="viz_start")
|
start_date = viz_col1.date_input("开始日期", value=default_start, key="viz_start")
|
||||||
end_date = viz_col2.date_input("结束日期", value=date.today(), key="viz_end")
|
end_date = viz_col2.date_input("结束日期", value=date.today(), key="viz_end")
|
||||||
|
LOGGER.debug("行情可视化日期范围:%s-%s", start_date, end_date, extra=LOG_EXTRA)
|
||||||
|
|
||||||
if start_date > end_date:
|
if start_date > end_date:
|
||||||
|
LOGGER.warning("无效日期范围:%s>%s", start_date, end_date, extra=LOG_EXTRA)
|
||||||
st.error("开始日期不能晚于结束日期")
|
st.error("开始日期不能晚于结束日期")
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -220,10 +304,12 @@ def render_tests() -> None:
|
|||||||
try:
|
try:
|
||||||
df = _load_daily_frame(ts_code, start_date, end_date)
|
df = _load_daily_frame(ts_code, start_date, end_date)
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
|
LOGGER.exception("加载行情数据失败", extra=LOG_EXTRA)
|
||||||
st.error(f"读取数据失败:{exc}")
|
st.error(f"读取数据失败:{exc}")
|
||||||
return
|
return
|
||||||
|
|
||||||
if df.empty:
|
if df.empty:
|
||||||
|
LOGGER.warning("指定区间无行情数据:%s %s-%s", ts_code, start_date, end_date, extra=LOG_EXTRA)
|
||||||
st.warning("未查询到该区间的交易数据,请确认数据库已拉取对应日线。")
|
st.warning("未查询到该区间的交易数据,请确认数据库已拉取对应日线。")
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -309,11 +395,14 @@ def render_tests() -> None:
|
|||||||
|
|
||||||
st.caption("提示:成交量单位为手,成交额以千元显示。箱线图按月展示收盘价分布。")
|
st.caption("提示:成交量单位为手,成交额以千元显示。箱线图按月展示收盘价分布。")
|
||||||
st.dataframe(df_reset.tail(20), width='stretch')
|
st.dataframe(df_reset.tail(20), width='stretch')
|
||||||
|
LOGGER.info("行情可视化完成,展示行数=%s", len(df_reset), extra=LOG_EXTRA)
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
|
LOGGER.info("初始化 Streamlit UI", extra=LOG_EXTRA)
|
||||||
st.set_page_config(page_title="多智能体投资助理", layout="wide")
|
st.set_page_config(page_title="多智能体投资助理", layout="wide")
|
||||||
tabs = st.tabs(["今日计划", "回测与复盘", "数据与设置", "自检测试"])
|
tabs = st.tabs(["今日计划", "回测与复盘", "数据与设置", "自检测试"])
|
||||||
|
LOGGER.debug("Tabs 初始化完成:%s", ["今日计划", "回测与复盘", "数据与设置", "自检测试"], extra=LOG_EXTRA)
|
||||||
with tabs[0]:
|
with tabs[0]:
|
||||||
render_today_plan()
|
render_today_plan()
|
||||||
with tabs[1]:
|
with tabs[1]:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user