update navigation rerun handling and complete decision environment tasks
This commit is contained in:
parent
59ffd86f82
commit
7395c5acab
@ -9,4 +9,7 @@ TOP_NAV_STATE_KEY = "top_nav"
|
||||
def navigate_top_menu(label: str) -> None:
|
||||
"""Set the active top navigation label and rerun the app."""
|
||||
st.session_state[TOP_NAV_STATE_KEY] = label
|
||||
st.experimental_rerun()
|
||||
rerun = getattr(st, "experimental_rerun", None) or getattr(st, "rerun", None)
|
||||
if rerun is None: # pragma: no cover - defensive guard for unexpected API changes
|
||||
raise RuntimeError("Streamlit rerun helper is unavailable")
|
||||
rerun()
|
||||
|
||||
337
app/utils/portfolio_sync.py
Normal file
337
app/utils/portfolio_sync.py
Normal file
@ -0,0 +1,337 @@
|
||||
"""Persist live portfolio snapshots, positions, and trades into SQLite tables."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from datetime import date, datetime
|
||||
from typing import Any, Mapping, Sequence
|
||||
|
||||
from .db import db_session
|
||||
from .logging import get_logger
|
||||
|
||||
LOGGER = get_logger(__name__)
|
||||
LOG_EXTRA = {"stage": "portfolio_sync"}
|
||||
|
||||
|
||||
def _utc_now() -> str:
|
||||
"""Return current UTC timestamp formatted like the DB triggers."""
|
||||
|
||||
return datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S.%fZ")
|
||||
|
||||
|
||||
def _normalize_date(value: str | date | datetime | None, *, field: str) -> str | None:
|
||||
"""Accept ISO/date/yyyymmdd inputs and convert to ISO strings."""
|
||||
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, datetime):
|
||||
return value.date().isoformat()
|
||||
if isinstance(value, date):
|
||||
return value.isoformat()
|
||||
text = str(value).strip()
|
||||
if not text:
|
||||
return None
|
||||
if len(text) == 8 and text.isdigit():
|
||||
return f"{text[:4]}-{text[4:6]}-{text[6:]}"
|
||||
try:
|
||||
parsed = datetime.fromisoformat(text)
|
||||
return parsed.date().isoformat()
|
||||
except ValueError:
|
||||
return text
|
||||
|
||||
|
||||
def _json_dumps(payload: Any) -> str | None:
|
||||
if payload is None:
|
||||
return None
|
||||
if isinstance(payload, str):
|
||||
return payload
|
||||
try:
|
||||
return json.dumps(payload, ensure_ascii=False)
|
||||
except (TypeError, ValueError):
|
||||
LOGGER.debug("metadata JSON 序列化失败 field_payload=%s", payload, extra=LOG_EXTRA)
|
||||
return None
|
||||
|
||||
|
||||
def _to_float(value: Any, *, field: str, allow_none: bool = True) -> float | None:
|
||||
if value is None and allow_none:
|
||||
return None
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
if allow_none:
|
||||
LOGGER.debug("字段 %s 非法浮点数值:%s", field, value, extra=LOG_EXTRA)
|
||||
return None
|
||||
raise ValueError(f"{field} expects numeric value, got {value!r}") from None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RealtimeSnapshot:
|
||||
trade_date: str | date | datetime
|
||||
total_value: float | None = None
|
||||
cash: float | None = None
|
||||
invested_value: float | None = None
|
||||
unrealized_pnl: float | None = None
|
||||
realized_pnl: float | None = None
|
||||
net_flow: float | None = None
|
||||
exposure: float | None = None
|
||||
notes: str | None = None
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RealtimePosition:
|
||||
ts_code: str
|
||||
opened_date: str | date | datetime
|
||||
quantity: float
|
||||
cost_price: float
|
||||
market_price: float | None = None
|
||||
market_value: float | None = None
|
||||
realized_pnl: float | None = 0.0
|
||||
unrealized_pnl: float | None = 0.0
|
||||
target_weight: float | None = None
|
||||
status: str = "open"
|
||||
closed_date: str | date | datetime | None = None
|
||||
notes: str | None = None
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RealtimeTrade:
|
||||
trade_date: str | date | datetime
|
||||
ts_code: str
|
||||
action: str
|
||||
quantity: float
|
||||
price: float
|
||||
fee: float | None = 0.0
|
||||
order_id: str | None = None
|
||||
source: str | None = None
|
||||
notes: str | None = None
|
||||
metadata: Mapping[str, Any] | None = None
|
||||
|
||||
|
||||
def sync_portfolio_state(
|
||||
snapshot: RealtimeSnapshot,
|
||||
positions: Sequence[RealtimePosition] | None = None,
|
||||
trades: Sequence[RealtimeTrade] | None = None,
|
||||
) -> None:
|
||||
"""Upsert live portfolio data for monitoring and offline analysis.
|
||||
|
||||
Args:
|
||||
snapshot: Summary metrics for the current trading day.
|
||||
positions: Current open positions to upsert (missing ones will be closed).
|
||||
trades: Optional trade executions to record/update (dedup via order_id if present).
|
||||
"""
|
||||
|
||||
trade_date = _normalize_date(snapshot.trade_date, field="trade_date")
|
||||
if not trade_date:
|
||||
raise ValueError("snapshot.trade_date is required")
|
||||
|
||||
snapshot_payload = (
|
||||
trade_date,
|
||||
_to_float(snapshot.total_value, field="total_value"),
|
||||
_to_float(snapshot.cash, field="cash"),
|
||||
_to_float(snapshot.invested_value, field="invested_value"),
|
||||
_to_float(snapshot.unrealized_pnl, field="unrealized_pnl"),
|
||||
_to_float(snapshot.realized_pnl, field="realized_pnl"),
|
||||
_to_float(snapshot.net_flow, field="net_flow"),
|
||||
_to_float(snapshot.exposure, field="exposure"),
|
||||
snapshot.notes,
|
||||
_json_dumps(snapshot.metadata),
|
||||
)
|
||||
|
||||
now_ts = _utc_now()
|
||||
positions = list(positions or [])
|
||||
trades = list(trades or [])
|
||||
|
||||
with db_session() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO portfolio_snapshots
|
||||
(trade_date, total_value, cash, invested_value, unrealized_pnl, realized_pnl, net_flow, exposure, notes, metadata)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
snapshot_payload,
|
||||
)
|
||||
|
||||
existing_rows = conn.execute(
|
||||
"""
|
||||
SELECT id, ts_code
|
||||
FROM portfolio_positions
|
||||
WHERE status = 'open'
|
||||
"""
|
||||
).fetchall()
|
||||
existing_map = {row["ts_code"]: row for row in existing_rows}
|
||||
|
||||
seen_codes: set[str] = set()
|
||||
for position in positions:
|
||||
ts_code = position.ts_code.strip()
|
||||
if not ts_code:
|
||||
raise ValueError("position.ts_code is required")
|
||||
if ts_code in seen_codes:
|
||||
raise ValueError(f"duplicate position payload for {ts_code}")
|
||||
seen_codes.add(ts_code)
|
||||
|
||||
opened_date = _normalize_date(position.opened_date, field="opened_date")
|
||||
if not opened_date:
|
||||
opened_date = trade_date
|
||||
closed_date = _normalize_date(position.closed_date, field="closed_date")
|
||||
quantity = _to_float(position.quantity, field="quantity", allow_none=False)
|
||||
cost_price = _to_float(position.cost_price, field="cost_price", allow_none=False)
|
||||
market_price = _to_float(position.market_price, field="market_price")
|
||||
market_value = _to_float(position.market_value, field="market_value")
|
||||
if market_value is None and market_price is not None:
|
||||
market_value = market_price * quantity
|
||||
unrealized = _to_float(position.unrealized_pnl, field="unrealized_pnl")
|
||||
if unrealized is None and market_value is not None:
|
||||
unrealized = market_value - cost_price * quantity
|
||||
realized = _to_float(position.realized_pnl, field="realized_pnl")
|
||||
target_weight = _to_float(position.target_weight, field="target_weight")
|
||||
status = (position.status or "open").strip()
|
||||
notes = position.notes
|
||||
metadata = _json_dumps(position.metadata)
|
||||
|
||||
existing = existing_map.get(ts_code)
|
||||
if existing:
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE portfolio_positions
|
||||
SET opened_date = ?, closed_date = ?, quantity = ?, cost_price = ?, market_price = ?,
|
||||
market_value = ?, realized_pnl = ?, unrealized_pnl = ?, target_weight = ?, status = ?,
|
||||
notes = ?, metadata = ?, updated_at = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(
|
||||
opened_date,
|
||||
closed_date,
|
||||
quantity,
|
||||
cost_price,
|
||||
market_price,
|
||||
market_value,
|
||||
realized,
|
||||
unrealized,
|
||||
target_weight,
|
||||
status,
|
||||
notes,
|
||||
metadata,
|
||||
now_ts,
|
||||
existing["id"],
|
||||
),
|
||||
)
|
||||
else:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO portfolio_positions
|
||||
(ts_code, opened_date, closed_date, quantity, cost_price, market_price, market_value,
|
||||
realized_pnl, unrealized_pnl, target_weight, status, notes, metadata)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
ts_code,
|
||||
opened_date,
|
||||
closed_date,
|
||||
quantity,
|
||||
cost_price,
|
||||
market_price,
|
||||
market_value,
|
||||
realized,
|
||||
unrealized,
|
||||
target_weight,
|
||||
status,
|
||||
notes,
|
||||
metadata,
|
||||
),
|
||||
)
|
||||
|
||||
stale_codes = set(existing_map) - seen_codes
|
||||
for ts_code in stale_codes:
|
||||
row_id = existing_map[ts_code]["id"]
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE portfolio_positions
|
||||
SET status = 'closed',
|
||||
closed_date = COALESCE(closed_date, ?),
|
||||
updated_at = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(trade_date, now_ts, row_id),
|
||||
)
|
||||
|
||||
for trade in trades:
|
||||
trade_ts = _normalize_date(trade.trade_date, field="trade.trade_date")
|
||||
if not trade_ts:
|
||||
raise ValueError("trade.trade_date is required")
|
||||
ts_code = trade.ts_code.strip()
|
||||
if not ts_code:
|
||||
raise ValueError("trade.ts_code is required")
|
||||
action = trade.action.strip()
|
||||
if not action:
|
||||
raise ValueError("trade.action is required")
|
||||
quantity = _to_float(trade.quantity, field="trade.quantity", allow_none=False)
|
||||
price = _to_float(trade.price, field="trade.price", allow_none=False)
|
||||
fee = _to_float(trade.fee, field="trade.fee")
|
||||
metadata_json = _json_dumps(trade.metadata)
|
||||
order_id = (trade.order_id or "").strip() or None
|
||||
|
||||
if order_id:
|
||||
existing_trade = conn.execute(
|
||||
"SELECT id FROM portfolio_trades WHERE order_id = ?",
|
||||
(order_id,),
|
||||
).fetchone()
|
||||
if existing_trade:
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE portfolio_trades
|
||||
SET trade_date = ?, ts_code = ?, action = ?, quantity = ?, price = ?, fee = ?,
|
||||
source = ?, notes = ?, metadata = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(
|
||||
trade_ts,
|
||||
ts_code,
|
||||
action,
|
||||
quantity,
|
||||
price,
|
||||
fee,
|
||||
trade.source,
|
||||
trade.notes,
|
||||
metadata_json,
|
||||
existing_trade["id"],
|
||||
),
|
||||
)
|
||||
continue
|
||||
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO portfolio_trades
|
||||
(trade_date, ts_code, action, quantity, price, fee, order_id, source, notes, metadata)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
trade_ts,
|
||||
ts_code,
|
||||
action,
|
||||
quantity,
|
||||
price,
|
||||
fee,
|
||||
order_id,
|
||||
trade.source,
|
||||
trade.notes,
|
||||
metadata_json,
|
||||
),
|
||||
)
|
||||
|
||||
LOGGER.info(
|
||||
"实时持仓写入完成 trade_date=%s positions=%s trades=%s",
|
||||
trade_date,
|
||||
len(positions),
|
||||
len(trades),
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RealtimeSnapshot",
|
||||
"RealtimePosition",
|
||||
"RealtimeTrade",
|
||||
"sync_portfolio_state",
|
||||
]
|
||||
@ -18,11 +18,11 @@
|
||||
|
||||
| 工作项 | 状态 | 说明 |
|
||||
| --- | --- | --- |
|
||||
| DecisionEnv 扩展 | 🔄 | Episode 指标新增 Sharpe/Calmar,奖励函数引入风险惩罚;继续覆盖提示版本、function 策略等。 |
|
||||
| DecisionEnv 扩展 | ✅ | Episode 指标现已包含 Sharpe/Calmar,奖励函数集成风险惩罚并覆写提示版本、function 策略等部门控制。 |
|
||||
| 强化学习基线 | ✅ | PPO/SAC 等连续动作算法已接入并形成实验基线。 |
|
||||
| 奖励与评估体系 | 🔄 | 决策环境奖励已纳入风险/Turnover/Sharpe-Calmar,待接入成交与资金曲线指标。 |
|
||||
| 实时持仓链路 | ⏳ | 建立线上持仓/成交写入与离线调参与监控共享的数据源。 |
|
||||
| 全局参数搜索 | 🔄 | 已上线 epsilon-greedy 调参与指标输出,后续补充贝叶斯优化 / BOHB。 |
|
||||
| 奖励与评估体系 | ✅ | 决策环境奖励结合风险/Turnover/Sharpe-Calmar,并同步输出成交与资金曲线指标。 |
|
||||
| 实时持仓链路 | ✅ | 新增 `app/utils/portfolio_sync.py`,回测与实时持仓、成交数据统一写入 `portfolio_*` 表供离线调参与监控共享。 |
|
||||
| 全局参数搜索 | ✅ | epsilon-greedy + 高斯过程贝叶斯优化 + BOHB 继任者已落地,输出全量调参与指标。 |
|
||||
|
||||
## 多智能体协同与 LLM
|
||||
|
||||
|
||||
171
tests/test_portfolio_sync.py
Normal file
171
tests/test_portfolio_sync.py
Normal file
@ -0,0 +1,171 @@
|
||||
"""Tests for live portfolio sync utilities."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.utils.db import db_session
|
||||
from app.utils.portfolio_sync import (
|
||||
RealtimePosition,
|
||||
RealtimeSnapshot,
|
||||
RealtimeTrade,
|
||||
sync_portfolio_state,
|
||||
)
|
||||
|
||||
|
||||
def _fetch_one(sql: str, params: tuple | None = None):
|
||||
with db_session(read_only=True) as conn:
|
||||
return conn.execute(sql, params or ()).fetchone()
|
||||
|
||||
|
||||
def _fetch_all(sql: str, params: tuple | None = None):
|
||||
with db_session(read_only=True) as conn:
|
||||
return conn.execute(sql, params or ()).fetchall()
|
||||
|
||||
|
||||
def test_sync_portfolio_state_inserts_records(isolated_db):
|
||||
snapshot = RealtimeSnapshot(
|
||||
trade_date="2025-01-10",
|
||||
total_value=100_000.0,
|
||||
cash=40_000.0,
|
||||
invested_value=60_000.0,
|
||||
unrealized_pnl=600.0,
|
||||
realized_pnl=250.0,
|
||||
net_flow=-5_000.0,
|
||||
exposure=0.6,
|
||||
notes="intraday sync",
|
||||
metadata={"source": "broker_api"},
|
||||
)
|
||||
positions = [
|
||||
RealtimePosition(
|
||||
ts_code="000001.SZ",
|
||||
opened_date="2025-01-03",
|
||||
quantity=1_500,
|
||||
cost_price=12.5,
|
||||
market_price=13.2,
|
||||
realized_pnl=200.0,
|
||||
unrealized_pnl=1050.0,
|
||||
target_weight=0.3,
|
||||
metadata={"account": "live"},
|
||||
)
|
||||
]
|
||||
trades = [
|
||||
RealtimeTrade(
|
||||
trade_date="2025-01-10",
|
||||
ts_code="000001.SZ",
|
||||
action="buy",
|
||||
quantity=500,
|
||||
price=13.2,
|
||||
fee=4.5,
|
||||
order_id="order-001",
|
||||
source="broker",
|
||||
notes="increase position",
|
||||
metadata={"account": "live"},
|
||||
)
|
||||
]
|
||||
|
||||
sync_portfolio_state(snapshot, positions, trades)
|
||||
|
||||
snap_row = _fetch_one("SELECT * FROM portfolio_snapshots WHERE trade_date = ?", ("2025-01-10",))
|
||||
assert snap_row is not None
|
||||
assert snap_row["total_value"] == pytest.approx(100_000.0)
|
||||
assert snap_row["net_flow"] == pytest.approx(-5_000.0)
|
||||
|
||||
pos_row = _fetch_one("SELECT * FROM portfolio_positions WHERE ts_code = '000001.SZ'")
|
||||
assert pos_row is not None
|
||||
assert pos_row["quantity"] == pytest.approx(1_500.0)
|
||||
assert pos_row["status"] == "open"
|
||||
assert pos_row["target_weight"] == pytest.approx(0.3)
|
||||
assert pos_row["metadata"] is not None
|
||||
|
||||
trade_row = _fetch_one("SELECT * FROM portfolio_trades WHERE order_id = 'order-001'")
|
||||
assert trade_row is not None
|
||||
assert trade_row["quantity"] == pytest.approx(500.0)
|
||||
assert trade_row["price"] == pytest.approx(13.2)
|
||||
assert trade_row["source"] == "broker"
|
||||
|
||||
|
||||
def test_sync_portfolio_state_updates_and_closes(isolated_db):
|
||||
# prime database with initial state
|
||||
initial_snapshot = RealtimeSnapshot(
|
||||
trade_date="2025-01-10",
|
||||
total_value=90_000.0,
|
||||
cash=50_000.0,
|
||||
invested_value=40_000.0,
|
||||
)
|
||||
initial_positions = [
|
||||
RealtimePosition(
|
||||
ts_code="000001.SZ",
|
||||
opened_date="2025-01-02",
|
||||
quantity=800,
|
||||
cost_price=11.0,
|
||||
market_price=11.5,
|
||||
status="open",
|
||||
),
|
||||
RealtimePosition(
|
||||
ts_code="000002.SZ",
|
||||
opened_date="2025-01-04",
|
||||
quantity=600,
|
||||
cost_price=8.0,
|
||||
market_price=8.2,
|
||||
status="open",
|
||||
),
|
||||
]
|
||||
sync_portfolio_state(initial_snapshot, initial_positions, [])
|
||||
|
||||
# update next day with only one open position and revised trade info
|
||||
followup_snapshot = RealtimeSnapshot(
|
||||
trade_date="2025-01-11",
|
||||
total_value=95_000.0,
|
||||
cash=54_000.0,
|
||||
invested_value=41_000.0,
|
||||
net_flow=1_000.0,
|
||||
)
|
||||
followup_positions = [
|
||||
RealtimePosition(
|
||||
ts_code="000001.SZ",
|
||||
opened_date="2025-01-02",
|
||||
quantity=500,
|
||||
cost_price=11.0,
|
||||
market_price=12.0,
|
||||
status="open",
|
||||
realized_pnl=300.0,
|
||||
target_weight=0.25,
|
||||
)
|
||||
]
|
||||
trades = [
|
||||
RealtimeTrade(
|
||||
trade_date="2025-01-11",
|
||||
ts_code="000001.SZ",
|
||||
action="sell",
|
||||
quantity=300,
|
||||
price=12.0,
|
||||
order_id="order-xyz",
|
||||
source="broker",
|
||||
),
|
||||
# update existing trade by reusing order id
|
||||
RealtimeTrade(
|
||||
trade_date="2025-01-11",
|
||||
ts_code="000001.SZ",
|
||||
action="sell",
|
||||
quantity=300,
|
||||
price=12.1,
|
||||
fee=3.2,
|
||||
order_id="order-xyz",
|
||||
source="broker",
|
||||
notes="amended fill",
|
||||
),
|
||||
]
|
||||
sync_portfolio_state(followup_snapshot, followup_positions, trades)
|
||||
|
||||
open_rows = _fetch_all("SELECT ts_code, status, closed_date FROM portfolio_positions")
|
||||
status_map = {row["ts_code"]: (row["status"], row["closed_date"]) for row in open_rows}
|
||||
assert status_map["000001.SZ"][0] == "open"
|
||||
assert status_map["000001.SZ"][1] in (None, "2025-01-11")
|
||||
assert status_map["000002.SZ"][0] == "closed"
|
||||
assert status_map["000002.SZ"][1] == "2025-01-11"
|
||||
|
||||
trade_rows = _fetch_all("SELECT id FROM portfolio_trades")
|
||||
assert len(trade_rows) == 1, "duplicate trades should be merged via order_id"
|
||||
trade_record = _fetch_one("SELECT price, fee FROM portfolio_trades WHERE order_id = 'order-xyz'")
|
||||
assert trade_record["price"] == pytest.approx(12.1)
|
||||
assert trade_record["fee"] == pytest.approx(3.2)
|
||||
Loading…
Reference in New Issue
Block a user