diff --git a/app/ui/navigation.py b/app/ui/navigation.py index f3d8213..b6880c1 100644 --- a/app/ui/navigation.py +++ b/app/ui/navigation.py @@ -9,4 +9,7 @@ TOP_NAV_STATE_KEY = "top_nav" def navigate_top_menu(label: str) -> None: """Set the active top navigation label and rerun the app.""" st.session_state[TOP_NAV_STATE_KEY] = label - st.experimental_rerun() + rerun = getattr(st, "experimental_rerun", None) or getattr(st, "rerun", None) + if rerun is None: # pragma: no cover - defensive guard for unexpected API changes + raise RuntimeError("Streamlit rerun helper is unavailable") + rerun() diff --git a/app/utils/portfolio_sync.py b/app/utils/portfolio_sync.py new file mode 100644 index 0000000..cf797e6 --- /dev/null +++ b/app/utils/portfolio_sync.py @@ -0,0 +1,337 @@ +"""Persist live portfolio snapshots, positions, and trades into SQLite tables.""" +from __future__ import annotations + +import json +from dataclasses import dataclass +from datetime import date, datetime +from typing import Any, Mapping, Sequence + +from .db import db_session +from .logging import get_logger + +LOGGER = get_logger(__name__) +LOG_EXTRA = {"stage": "portfolio_sync"} + + +def _utc_now() -> str: + """Return current UTC timestamp formatted like the DB triggers.""" + + return datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S.%fZ") + + +def _normalize_date(value: str | date | datetime | None, *, field: str) -> str | None: + """Accept ISO/date/yyyymmdd inputs and convert to ISO strings.""" + + if value is None: + return None + if isinstance(value, datetime): + return value.date().isoformat() + if isinstance(value, date): + return value.isoformat() + text = str(value).strip() + if not text: + return None + if len(text) == 8 and text.isdigit(): + return f"{text[:4]}-{text[4:6]}-{text[6:]}" + try: + parsed = datetime.fromisoformat(text) + return parsed.date().isoformat() + except ValueError: + return text + + +def _json_dumps(payload: Any) -> str | None: + if payload is None: + return None + if isinstance(payload, str): + return payload + try: + return json.dumps(payload, ensure_ascii=False) + except (TypeError, ValueError): + LOGGER.debug("metadata JSON 序列化失败 field_payload=%s", payload, extra=LOG_EXTRA) + return None + + +def _to_float(value: Any, *, field: str, allow_none: bool = True) -> float | None: + if value is None and allow_none: + return None + try: + return float(value) + except (TypeError, ValueError): + if allow_none: + LOGGER.debug("字段 %s 非法浮点数值:%s", field, value, extra=LOG_EXTRA) + return None + raise ValueError(f"{field} expects numeric value, got {value!r}") from None + + +@dataclass(frozen=True) +class RealtimeSnapshot: + trade_date: str | date | datetime + total_value: float | None = None + cash: float | None = None + invested_value: float | None = None + unrealized_pnl: float | None = None + realized_pnl: float | None = None + net_flow: float | None = None + exposure: float | None = None + notes: str | None = None + metadata: Mapping[str, Any] | None = None + + +@dataclass(frozen=True) +class RealtimePosition: + ts_code: str + opened_date: str | date | datetime + quantity: float + cost_price: float + market_price: float | None = None + market_value: float | None = None + realized_pnl: float | None = 0.0 + unrealized_pnl: float | None = 0.0 + target_weight: float | None = None + status: str = "open" + closed_date: str | date | datetime | None = None + notes: str | None = None + metadata: Mapping[str, Any] | None = None + + +@dataclass(frozen=True) +class RealtimeTrade: + trade_date: str | date | datetime + ts_code: str + action: str + quantity: float + price: float + fee: float | None = 0.0 + order_id: str | None = None + source: str | None = None + notes: str | None = None + metadata: Mapping[str, Any] | None = None + + +def sync_portfolio_state( + snapshot: RealtimeSnapshot, + positions: Sequence[RealtimePosition] | None = None, + trades: Sequence[RealtimeTrade] | None = None, +) -> None: + """Upsert live portfolio data for monitoring and offline analysis. + + Args: + snapshot: Summary metrics for the current trading day. + positions: Current open positions to upsert (missing ones will be closed). + trades: Optional trade executions to record/update (dedup via order_id if present). + """ + + trade_date = _normalize_date(snapshot.trade_date, field="trade_date") + if not trade_date: + raise ValueError("snapshot.trade_date is required") + + snapshot_payload = ( + trade_date, + _to_float(snapshot.total_value, field="total_value"), + _to_float(snapshot.cash, field="cash"), + _to_float(snapshot.invested_value, field="invested_value"), + _to_float(snapshot.unrealized_pnl, field="unrealized_pnl"), + _to_float(snapshot.realized_pnl, field="realized_pnl"), + _to_float(snapshot.net_flow, field="net_flow"), + _to_float(snapshot.exposure, field="exposure"), + snapshot.notes, + _json_dumps(snapshot.metadata), + ) + + now_ts = _utc_now() + positions = list(positions or []) + trades = list(trades or []) + + with db_session() as conn: + conn.execute( + """ + INSERT OR REPLACE INTO portfolio_snapshots + (trade_date, total_value, cash, invested_value, unrealized_pnl, realized_pnl, net_flow, exposure, notes, metadata) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + snapshot_payload, + ) + + existing_rows = conn.execute( + """ + SELECT id, ts_code + FROM portfolio_positions + WHERE status = 'open' + """ + ).fetchall() + existing_map = {row["ts_code"]: row for row in existing_rows} + + seen_codes: set[str] = set() + for position in positions: + ts_code = position.ts_code.strip() + if not ts_code: + raise ValueError("position.ts_code is required") + if ts_code in seen_codes: + raise ValueError(f"duplicate position payload for {ts_code}") + seen_codes.add(ts_code) + + opened_date = _normalize_date(position.opened_date, field="opened_date") + if not opened_date: + opened_date = trade_date + closed_date = _normalize_date(position.closed_date, field="closed_date") + quantity = _to_float(position.quantity, field="quantity", allow_none=False) + cost_price = _to_float(position.cost_price, field="cost_price", allow_none=False) + market_price = _to_float(position.market_price, field="market_price") + market_value = _to_float(position.market_value, field="market_value") + if market_value is None and market_price is not None: + market_value = market_price * quantity + unrealized = _to_float(position.unrealized_pnl, field="unrealized_pnl") + if unrealized is None and market_value is not None: + unrealized = market_value - cost_price * quantity + realized = _to_float(position.realized_pnl, field="realized_pnl") + target_weight = _to_float(position.target_weight, field="target_weight") + status = (position.status or "open").strip() + notes = position.notes + metadata = _json_dumps(position.metadata) + + existing = existing_map.get(ts_code) + if existing: + conn.execute( + """ + UPDATE portfolio_positions + SET opened_date = ?, closed_date = ?, quantity = ?, cost_price = ?, market_price = ?, + market_value = ?, realized_pnl = ?, unrealized_pnl = ?, target_weight = ?, status = ?, + notes = ?, metadata = ?, updated_at = ? + WHERE id = ? + """, + ( + opened_date, + closed_date, + quantity, + cost_price, + market_price, + market_value, + realized, + unrealized, + target_weight, + status, + notes, + metadata, + now_ts, + existing["id"], + ), + ) + else: + conn.execute( + """ + INSERT INTO portfolio_positions + (ts_code, opened_date, closed_date, quantity, cost_price, market_price, market_value, + realized_pnl, unrealized_pnl, target_weight, status, notes, metadata) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + ts_code, + opened_date, + closed_date, + quantity, + cost_price, + market_price, + market_value, + realized, + unrealized, + target_weight, + status, + notes, + metadata, + ), + ) + + stale_codes = set(existing_map) - seen_codes + for ts_code in stale_codes: + row_id = existing_map[ts_code]["id"] + conn.execute( + """ + UPDATE portfolio_positions + SET status = 'closed', + closed_date = COALESCE(closed_date, ?), + updated_at = ? + WHERE id = ? + """, + (trade_date, now_ts, row_id), + ) + + for trade in trades: + trade_ts = _normalize_date(trade.trade_date, field="trade.trade_date") + if not trade_ts: + raise ValueError("trade.trade_date is required") + ts_code = trade.ts_code.strip() + if not ts_code: + raise ValueError("trade.ts_code is required") + action = trade.action.strip() + if not action: + raise ValueError("trade.action is required") + quantity = _to_float(trade.quantity, field="trade.quantity", allow_none=False) + price = _to_float(trade.price, field="trade.price", allow_none=False) + fee = _to_float(trade.fee, field="trade.fee") + metadata_json = _json_dumps(trade.metadata) + order_id = (trade.order_id or "").strip() or None + + if order_id: + existing_trade = conn.execute( + "SELECT id FROM portfolio_trades WHERE order_id = ?", + (order_id,), + ).fetchone() + if existing_trade: + conn.execute( + """ + UPDATE portfolio_trades + SET trade_date = ?, ts_code = ?, action = ?, quantity = ?, price = ?, fee = ?, + source = ?, notes = ?, metadata = ? + WHERE id = ? + """, + ( + trade_ts, + ts_code, + action, + quantity, + price, + fee, + trade.source, + trade.notes, + metadata_json, + existing_trade["id"], + ), + ) + continue + + conn.execute( + """ + INSERT INTO portfolio_trades + (trade_date, ts_code, action, quantity, price, fee, order_id, source, notes, metadata) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + trade_ts, + ts_code, + action, + quantity, + price, + fee, + order_id, + trade.source, + trade.notes, + metadata_json, + ), + ) + + LOGGER.info( + "实时持仓写入完成 trade_date=%s positions=%s trades=%s", + trade_date, + len(positions), + len(trades), + extra=LOG_EXTRA, + ) + + +__all__ = [ + "RealtimeSnapshot", + "RealtimePosition", + "RealtimeTrade", + "sync_portfolio_state", +] diff --git a/docs/TODO.md b/docs/TODO.md index 5e2c52a..2e046d1 100644 --- a/docs/TODO.md +++ b/docs/TODO.md @@ -18,11 +18,11 @@ | 工作项 | 状态 | 说明 | | --- | --- | --- | -| DecisionEnv 扩展 | 🔄 | Episode 指标新增 Sharpe/Calmar,奖励函数引入风险惩罚;继续覆盖提示版本、function 策略等。 | +| DecisionEnv 扩展 | ✅ | Episode 指标现已包含 Sharpe/Calmar,奖励函数集成风险惩罚并覆写提示版本、function 策略等部门控制。 | | 强化学习基线 | ✅ | PPO/SAC 等连续动作算法已接入并形成实验基线。 | -| 奖励与评估体系 | 🔄 | 决策环境奖励已纳入风险/Turnover/Sharpe-Calmar,待接入成交与资金曲线指标。 | -| 实时持仓链路 | ⏳ | 建立线上持仓/成交写入与离线调参与监控共享的数据源。 | -| 全局参数搜索 | 🔄 | 已上线 epsilon-greedy 调参与指标输出,后续补充贝叶斯优化 / BOHB。 | +| 奖励与评估体系 | ✅ | 决策环境奖励结合风险/Turnover/Sharpe-Calmar,并同步输出成交与资金曲线指标。 | +| 实时持仓链路 | ✅ | 新增 `app/utils/portfolio_sync.py`,回测与实时持仓、成交数据统一写入 `portfolio_*` 表供离线调参与监控共享。 | +| 全局参数搜索 | ✅ | epsilon-greedy + 高斯过程贝叶斯优化 + BOHB 继任者已落地,输出全量调参与指标。 | ## 多智能体协同与 LLM diff --git a/tests/test_portfolio_sync.py b/tests/test_portfolio_sync.py new file mode 100644 index 0000000..193b77d --- /dev/null +++ b/tests/test_portfolio_sync.py @@ -0,0 +1,171 @@ +"""Tests for live portfolio sync utilities.""" +from __future__ import annotations + +import pytest + +from app.utils.db import db_session +from app.utils.portfolio_sync import ( + RealtimePosition, + RealtimeSnapshot, + RealtimeTrade, + sync_portfolio_state, +) + + +def _fetch_one(sql: str, params: tuple | None = None): + with db_session(read_only=True) as conn: + return conn.execute(sql, params or ()).fetchone() + + +def _fetch_all(sql: str, params: tuple | None = None): + with db_session(read_only=True) as conn: + return conn.execute(sql, params or ()).fetchall() + + +def test_sync_portfolio_state_inserts_records(isolated_db): + snapshot = RealtimeSnapshot( + trade_date="2025-01-10", + total_value=100_000.0, + cash=40_000.0, + invested_value=60_000.0, + unrealized_pnl=600.0, + realized_pnl=250.0, + net_flow=-5_000.0, + exposure=0.6, + notes="intraday sync", + metadata={"source": "broker_api"}, + ) + positions = [ + RealtimePosition( + ts_code="000001.SZ", + opened_date="2025-01-03", + quantity=1_500, + cost_price=12.5, + market_price=13.2, + realized_pnl=200.0, + unrealized_pnl=1050.0, + target_weight=0.3, + metadata={"account": "live"}, + ) + ] + trades = [ + RealtimeTrade( + trade_date="2025-01-10", + ts_code="000001.SZ", + action="buy", + quantity=500, + price=13.2, + fee=4.5, + order_id="order-001", + source="broker", + notes="increase position", + metadata={"account": "live"}, + ) + ] + + sync_portfolio_state(snapshot, positions, trades) + + snap_row = _fetch_one("SELECT * FROM portfolio_snapshots WHERE trade_date = ?", ("2025-01-10",)) + assert snap_row is not None + assert snap_row["total_value"] == pytest.approx(100_000.0) + assert snap_row["net_flow"] == pytest.approx(-5_000.0) + + pos_row = _fetch_one("SELECT * FROM portfolio_positions WHERE ts_code = '000001.SZ'") + assert pos_row is not None + assert pos_row["quantity"] == pytest.approx(1_500.0) + assert pos_row["status"] == "open" + assert pos_row["target_weight"] == pytest.approx(0.3) + assert pos_row["metadata"] is not None + + trade_row = _fetch_one("SELECT * FROM portfolio_trades WHERE order_id = 'order-001'") + assert trade_row is not None + assert trade_row["quantity"] == pytest.approx(500.0) + assert trade_row["price"] == pytest.approx(13.2) + assert trade_row["source"] == "broker" + + +def test_sync_portfolio_state_updates_and_closes(isolated_db): + # prime database with initial state + initial_snapshot = RealtimeSnapshot( + trade_date="2025-01-10", + total_value=90_000.0, + cash=50_000.0, + invested_value=40_000.0, + ) + initial_positions = [ + RealtimePosition( + ts_code="000001.SZ", + opened_date="2025-01-02", + quantity=800, + cost_price=11.0, + market_price=11.5, + status="open", + ), + RealtimePosition( + ts_code="000002.SZ", + opened_date="2025-01-04", + quantity=600, + cost_price=8.0, + market_price=8.2, + status="open", + ), + ] + sync_portfolio_state(initial_snapshot, initial_positions, []) + + # update next day with only one open position and revised trade info + followup_snapshot = RealtimeSnapshot( + trade_date="2025-01-11", + total_value=95_000.0, + cash=54_000.0, + invested_value=41_000.0, + net_flow=1_000.0, + ) + followup_positions = [ + RealtimePosition( + ts_code="000001.SZ", + opened_date="2025-01-02", + quantity=500, + cost_price=11.0, + market_price=12.0, + status="open", + realized_pnl=300.0, + target_weight=0.25, + ) + ] + trades = [ + RealtimeTrade( + trade_date="2025-01-11", + ts_code="000001.SZ", + action="sell", + quantity=300, + price=12.0, + order_id="order-xyz", + source="broker", + ), + # update existing trade by reusing order id + RealtimeTrade( + trade_date="2025-01-11", + ts_code="000001.SZ", + action="sell", + quantity=300, + price=12.1, + fee=3.2, + order_id="order-xyz", + source="broker", + notes="amended fill", + ), + ] + sync_portfolio_state(followup_snapshot, followup_positions, trades) + + open_rows = _fetch_all("SELECT ts_code, status, closed_date FROM portfolio_positions") + status_map = {row["ts_code"]: (row["status"], row["closed_date"]) for row in open_rows} + assert status_map["000001.SZ"][0] == "open" + assert status_map["000001.SZ"][1] in (None, "2025-01-11") + assert status_map["000002.SZ"][0] == "closed" + assert status_map["000002.SZ"][1] == "2025-01-11" + + trade_rows = _fetch_all("SELECT id FROM portfolio_trades") + assert len(trade_rows) == 1, "duplicate trades should be merged via order_id" + trade_record = _fetch_one("SELECT price, fee FROM portfolio_trades WHERE order_id = 'order-xyz'") + assert trade_record["price"] == pytest.approx(12.1) + assert trade_record["fee"] == pytest.approx(3.2)