llm-quant/app/ui/views/market.py
2025-10-06 13:44:41 +08:00

211 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""行情可视化页面。"""
from __future__ import annotations
from datetime import date, datetime, timedelta
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from app.utils.db import db_session
from app.ui.shared import LOGGER, LOG_EXTRA
def _load_stock_options(limit: int = 500, min_history: int = 30) -> list[str]:
    """Return the stock codes offered in the symbol picker.

    Queries the ``daily`` table for codes that have at least ``min_history``
    rows, ordered by code ascending and capped at ``limit`` entries.

    Any exception raised while querying is logged and an empty list is
    returned instead, so the caller can show a "no data" hint.
    """
    # SQL text is intentionally flush-left: the literal carries no indentation.
    query = """
SELECT ts_code
FROM (
SELECT ts_code, MAX(trade_date) AS latest_date, COUNT(*) AS history_rows
FROM daily
GROUP BY ts_code
)
WHERE history_rows >= ?
ORDER BY ts_code ASC
LIMIT ?
"""
    try:
        with db_session(read_only=True) as conn:
            records = conn.execute(query, (min_history, limit)).fetchall()
    except Exception:  # noqa: BLE001
        LOGGER.exception("加载股票列表失败", extra=LOG_EXTRA)
        return []
    else:
        return [record["ts_code"] for record in records]
def _parse_ts_code(selection: str) -> str:
return selection.split(" ", 1)[0]
def _load_daily_frame(ts_code: str, start: date, end: date) -> pd.DataFrame:
    """Load daily bars for ``ts_code`` between ``start`` and ``end`` inclusive.

    The frame has the columns ``trade_date, open, high, low, close, vol,
    amount`` sorted by ``trade_date``. For a non-empty result ``trade_date``
    is converted from ``YYYYMMDD`` text to pandas datetimes; an empty frame
    is returned as-is.
    """
    # SQL text is intentionally flush-left: the literal carries no indentation.
    query = """
SELECT trade_date, open, high, low, close, vol, amount
FROM daily
WHERE ts_code = ? AND trade_date BETWEEN ? AND ?
ORDER BY trade_date
"""
    # Dates are bound in the same YYYYMMDD form the query compares against.
    bounds = (ts_code, start.strftime("%Y%m%d"), end.strftime("%Y%m%d"))
    with db_session(read_only=True) as conn:
        frame = pd.read_sql_query(query, conn, params=bounds)
    if not frame.empty:
        frame["trade_date"] = pd.to_datetime(frame["trade_date"], format="%Y%m%d")
    return frame
def _load_trade_date_range(ts_code: str) -> tuple[date | None, date | None]:
    """Return the earliest and latest trade dates stored for ``ts_code``.

    Yields ``(None, None)`` when the query produces no row or when either
    aggregate is empty (for example, the code has no rows in ``daily``).
    """
    query = "SELECT MIN(trade_date) AS min_date, MAX(trade_date) AS max_date FROM daily WHERE ts_code = ?"
    with db_session(read_only=True) as conn:
        record = conn.execute(query, (ts_code,)).fetchone()
    if not record:
        return None, None

    earliest_raw, latest_raw = record["min_date"], record["max_date"]
    if not (earliest_raw and latest_raw):
        return None, None

    # Both values are parsed as YYYYMMDD text and reduced to plain dates.
    earliest, latest = (
        datetime.strptime(raw, "%Y%m%d").date() for raw in (earliest_raw, latest_raw)
    )
    return earliest, latest
def render_market_visualization() -> None:
    """Render the market page: symbol picker, date range, charts and table.

    Flow:
      1. Pick a symbol from the dropdown; a non-blank manual entry overrides it.
      2. Look up the symbol's available trade-date range and offer two date
         pickers bounded by it (default window: the last 180 days).
      3. Load the daily bars for the chosen window and draw the candlestick
         chart, the volume bars, the close/MA5/MA20/MA60 lines and the raw table.

    Each stage returns early with a Streamlit message when it has nothing to
    show (no symbols, no data for the symbol, inverted range, load failure,
    empty window).
    """
    st.header("行情可视化")
    st.caption("按标的查看 K 线、成交量以及常用指标。")

    options = _load_stock_options()
    if not options:
        st.warning("暂未加载到可用的行情标的,请先执行数据同步。")
        return

    chosen = st.selectbox("选择标的", options, index=0)
    ts_code = _parse_ts_code(chosen)
    typed = st.text_input("或直接输入标的代码", value=ts_code, key="market_manual_ts_code")
    # A non-blank manual entry (trimmed, upper-cased) wins over the dropdown.
    typed_code = typed.strip().upper() if typed else ""
    if typed_code:
        ts_code = typed_code

    min_date, max_date = _load_trade_date_range(ts_code)
    if not max_date:
        st.info("所选标的暂无可视化数据,请先同步行情。")
        return

    # Default window: the last 180 days, never earlier than the first bar.
    default_end = max_date
    default_start = max_date - timedelta(days=180)
    if min_date:
        default_start = max(min_date, default_start)

    session = st.session_state
    start_store_key = "market_start_date_value"
    end_store_key = "market_end_date_value"
    start_widget_key = "market_start_date_picker"
    end_widget_key = "market_end_date_picker"

    # Initialise or refresh the session state when the symbol changes.
    symbol_changed = session.get("market_selected_ts_code") != ts_code
    if symbol_changed:
        session["market_selected_ts_code"] = ts_code
        session[start_store_key] = default_start
        session[end_store_key] = default_end

    # Lay the two pickers out side by side.
    left, right = st.columns(2)
    with left:
        # New symbol: seed with the default; otherwise reuse the widget's value.
        if symbol_changed:
            start_seed = default_start
        else:
            start_seed = session.get(start_widget_key, default_start)
        start_date = st.date_input(
            "开始日期",
            value=start_seed,
            key=start_widget_key,
            min_value=min_date,
            max_value=max_date,
        )
    with right:
        if symbol_changed:
            end_seed = default_end
        else:
            end_seed = session.get(end_widget_key, default_end)
        end_date = st.date_input(
            "结束日期",
            value=end_seed,
            key=end_widget_key,
            min_value=min_date,
            max_value=max_date,
        )

    # Clamp both ends into the available [min_date, max_date] range.
    if min_date:
        start_date, end_date = max(start_date, min_date), max(end_date, min_date)
    start_date, end_date = min(start_date, max_date), min(end_date, max_date)
    session[start_store_key] = start_date
    session[end_store_key] = end_date

    if start_date > end_date:
        st.error("开始日期不能晚于结束日期。")
        return

    try:
        df = _load_daily_frame(ts_code, start_date, end_date)
    except Exception as exc:  # noqa: BLE001
        LOGGER.exception("加载行情数据失败", extra=LOG_EXTRA)
        st.error(f"加载行情数据失败:{exc}")
        return

    if df.empty:
        st.info("所选区间内无行情数据。")
        return

    latest_close = df["close"].iloc[-1]
    st.metric("最新收盘价", f"{latest_close:.2f}")

    # Candlestick chart.
    candles = go.Candlestick(
        x=df["trade_date"],
        open=df["open"],
        high=df["high"],
        low=df["low"],
        close=df["close"],
        name="K线",
    )
    fig = go.Figure(data=[candles])
    fig.update_layout(title=f"{ts_code} K线图", xaxis_title="日期", yaxis_title="价格")
    st.plotly_chart(fig, use_container_width=True)

    # Volume bars.
    fig_vol = px.bar(df, x="trade_date", y="vol", title="成交量")
    st.plotly_chart(fig_vol, use_container_width=True)

    # Moving averages are computed on a copy so the table below stays raw.
    df_ma = df.copy()
    ma_windows = {"MA5": 5, "MA20": 20, "MA60": 60}
    for column, window in ma_windows.items():
        df_ma[column] = df_ma["close"].rolling(window=window).mean()
    fig_ma = px.line(df_ma, x="trade_date", y=["close", "MA5", "MA20", "MA60"], title="均线对比")
    st.plotly_chart(fig_ma, use_container_width=True)

    st.dataframe(df, hide_index=True, width="stretch")