update navigation rerun handling and complete decision environment tasks

This commit is contained in:
sam 2025-10-17 09:52:27 +08:00
parent 59ffd86f82
commit 7395c5acab
4 changed files with 516 additions and 5 deletions

View File

@@ -9,4 +9,7 @@ TOP_NAV_STATE_KEY = "top_nav"
def navigate_top_menu(label: str) -> None:
    """Set the active top navigation label and rerun the app."""
    st.session_state[TOP_NAV_STATE_KEY] = label
-   st.experimental_rerun()
+   rerun = getattr(st, "experimental_rerun", None) or getattr(st, "rerun", None)
+   if rerun is None:  # pragma: no cover - defensive guard for unexpected API changes
+       raise RuntimeError("Streamlit rerun helper is unavailable")
+   rerun()
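For orientation, a minimal usage sketch of the updated helper; the module path app.ui.navigation, the labels, and the button layout are illustrative assumptions, not part of this commit.

import streamlit as st

from app.ui.navigation import TOP_NAV_STATE_KEY, navigate_top_menu  # assumed module path

# Hypothetical top navigation bar: each button stores its label and triggers
# whichever rerun helper the installed Streamlit version exposes
# (st.experimental_rerun on older builds, st.rerun on newer ones).
LABELS = ["Dashboard", "Portfolio", "Settings"]
cols = st.columns(len(LABELS))
for col, label in zip(cols, LABELS):
    if col.button(label, key=f"top_nav_{label}"):
        navigate_top_menu(label)

active = st.session_state.get(TOP_NAV_STATE_KEY, LABELS[0])
st.caption(f"Active section: {active}")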

337 app/utils/portfolio_sync.py Normal file
View File

@@ -0,0 +1,337 @@
"""Persist live portfolio snapshots, positions, and trades into SQLite tables."""
from __future__ import annotations

import json
from dataclasses import dataclass
from datetime import date, datetime
from typing import Any, Mapping, Sequence

from .db import db_session
from .logging import get_logger

LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "portfolio_sync"}


def _utc_now() -> str:
    """Return current UTC timestamp formatted like the DB triggers."""
    return datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S.%fZ")


def _normalize_date(value: str | date | datetime | None, *, field: str) -> str | None:
    """Accept ISO/date/yyyymmdd inputs and convert to ISO strings."""
    if value is None:
        return None
    if isinstance(value, datetime):
        return value.date().isoformat()
    if isinstance(value, date):
        return value.isoformat()
    text = str(value).strip()
    if not text:
        return None
    if len(text) == 8 and text.isdigit():
        return f"{text[:4]}-{text[4:6]}-{text[6:]}"
    try:
        parsed = datetime.fromisoformat(text)
        return parsed.date().isoformat()
    except ValueError:
        return text


def _json_dumps(payload: Any) -> str | None:
    if payload is None:
        return None
    if isinstance(payload, str):
        return payload
    try:
        return json.dumps(payload, ensure_ascii=False)
    except (TypeError, ValueError):
        LOGGER.debug("metadata JSON serialization failed field_payload=%s", payload, extra=LOG_EXTRA)
        return None


def _to_float(value: Any, *, field: str, allow_none: bool = True) -> float | None:
    if value is None and allow_none:
        return None
    try:
        return float(value)
    except (TypeError, ValueError):
        if allow_none:
            LOGGER.debug("field %s has an invalid float value: %s", field, value, extra=LOG_EXTRA)
            return None
        raise ValueError(f"{field} expects numeric value, got {value!r}") from None


@dataclass(frozen=True)
class RealtimeSnapshot:
    trade_date: str | date | datetime
    total_value: float | None = None
    cash: float | None = None
    invested_value: float | None = None
    unrealized_pnl: float | None = None
    realized_pnl: float | None = None
    net_flow: float | None = None
    exposure: float | None = None
    notes: str | None = None
    metadata: Mapping[str, Any] | None = None


@dataclass(frozen=True)
class RealtimePosition:
    ts_code: str
    opened_date: str | date | datetime
    quantity: float
    cost_price: float
    market_price: float | None = None
    market_value: float | None = None
    realized_pnl: float | None = 0.0
    unrealized_pnl: float | None = 0.0
    target_weight: float | None = None
    status: str = "open"
    closed_date: str | date | datetime | None = None
    notes: str | None = None
    metadata: Mapping[str, Any] | None = None


@dataclass(frozen=True)
class RealtimeTrade:
    trade_date: str | date | datetime
    ts_code: str
    action: str
    quantity: float
    price: float
    fee: float | None = 0.0
    order_id: str | None = None
    source: str | None = None
    notes: str | None = None
    metadata: Mapping[str, Any] | None = None


def sync_portfolio_state(
    snapshot: RealtimeSnapshot,
    positions: Sequence[RealtimePosition] | None = None,
    trades: Sequence[RealtimeTrade] | None = None,
) -> None:
    """Upsert live portfolio data for monitoring and offline analysis.

    Args:
        snapshot: Summary metrics for the current trading day.
        positions: Current open positions to upsert (missing ones will be closed).
        trades: Optional trade executions to record/update (dedup via order_id if present).
    """
    trade_date = _normalize_date(snapshot.trade_date, field="trade_date")
    if not trade_date:
        raise ValueError("snapshot.trade_date is required")
    snapshot_payload = (
        trade_date,
        _to_float(snapshot.total_value, field="total_value"),
        _to_float(snapshot.cash, field="cash"),
        _to_float(snapshot.invested_value, field="invested_value"),
        _to_float(snapshot.unrealized_pnl, field="unrealized_pnl"),
        _to_float(snapshot.realized_pnl, field="realized_pnl"),
        _to_float(snapshot.net_flow, field="net_flow"),
        _to_float(snapshot.exposure, field="exposure"),
        snapshot.notes,
        _json_dumps(snapshot.metadata),
    )
    now_ts = _utc_now()
    positions = list(positions or [])
    trades = list(trades or [])
    with db_session() as conn:
        conn.execute(
            """
            INSERT OR REPLACE INTO portfolio_snapshots
                (trade_date, total_value, cash, invested_value, unrealized_pnl, realized_pnl, net_flow, exposure, notes, metadata)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            snapshot_payload,
        )
        existing_rows = conn.execute(
            """
            SELECT id, ts_code
            FROM portfolio_positions
            WHERE status = 'open'
            """
        ).fetchall()
        existing_map = {row["ts_code"]: row for row in existing_rows}
        seen_codes: set[str] = set()
        for position in positions:
            ts_code = position.ts_code.strip()
            if not ts_code:
                raise ValueError("position.ts_code is required")
            if ts_code in seen_codes:
                raise ValueError(f"duplicate position payload for {ts_code}")
            seen_codes.add(ts_code)
            opened_date = _normalize_date(position.opened_date, field="opened_date")
            if not opened_date:
                opened_date = trade_date
            closed_date = _normalize_date(position.closed_date, field="closed_date")
            quantity = _to_float(position.quantity, field="quantity", allow_none=False)
            cost_price = _to_float(position.cost_price, field="cost_price", allow_none=False)
            market_price = _to_float(position.market_price, field="market_price")
            market_value = _to_float(position.market_value, field="market_value")
            if market_value is None and market_price is not None:
                market_value = market_price * quantity
            unrealized = _to_float(position.unrealized_pnl, field="unrealized_pnl")
            if unrealized is None and market_value is not None:
                unrealized = market_value - cost_price * quantity
            realized = _to_float(position.realized_pnl, field="realized_pnl")
            target_weight = _to_float(position.target_weight, field="target_weight")
            status = (position.status or "open").strip()
            notes = position.notes
            metadata = _json_dumps(position.metadata)
            existing = existing_map.get(ts_code)
            if existing:
                conn.execute(
                    """
                    UPDATE portfolio_positions
                    SET opened_date = ?, closed_date = ?, quantity = ?, cost_price = ?, market_price = ?,
                        market_value = ?, realized_pnl = ?, unrealized_pnl = ?, target_weight = ?, status = ?,
                        notes = ?, metadata = ?, updated_at = ?
                    WHERE id = ?
                    """,
                    (
                        opened_date,
                        closed_date,
                        quantity,
                        cost_price,
                        market_price,
                        market_value,
                        realized,
                        unrealized,
                        target_weight,
                        status,
                        notes,
                        metadata,
                        now_ts,
                        existing["id"],
                    ),
                )
            else:
                conn.execute(
                    """
                    INSERT INTO portfolio_positions
                        (ts_code, opened_date, closed_date, quantity, cost_price, market_price, market_value,
                         realized_pnl, unrealized_pnl, target_weight, status, notes, metadata)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (
                        ts_code,
                        opened_date,
                        closed_date,
                        quantity,
                        cost_price,
                        market_price,
                        market_value,
                        realized,
                        unrealized,
                        target_weight,
                        status,
                        notes,
                        metadata,
                    ),
                )
        stale_codes = set(existing_map) - seen_codes
        for ts_code in stale_codes:
            row_id = existing_map[ts_code]["id"]
            conn.execute(
                """
                UPDATE portfolio_positions
                SET status = 'closed',
                    closed_date = COALESCE(closed_date, ?),
                    updated_at = ?
                WHERE id = ?
                """,
                (trade_date, now_ts, row_id),
            )
        for trade in trades:
            trade_ts = _normalize_date(trade.trade_date, field="trade.trade_date")
            if not trade_ts:
                raise ValueError("trade.trade_date is required")
            ts_code = trade.ts_code.strip()
            if not ts_code:
                raise ValueError("trade.ts_code is required")
            action = trade.action.strip()
            if not action:
                raise ValueError("trade.action is required")
            quantity = _to_float(trade.quantity, field="trade.quantity", allow_none=False)
            price = _to_float(trade.price, field="trade.price", allow_none=False)
            fee = _to_float(trade.fee, field="trade.fee")
            metadata_json = _json_dumps(trade.metadata)
            order_id = (trade.order_id or "").strip() or None
            if order_id:
                existing_trade = conn.execute(
                    "SELECT id FROM portfolio_trades WHERE order_id = ?",
                    (order_id,),
                ).fetchone()
                if existing_trade:
                    conn.execute(
                        """
                        UPDATE portfolio_trades
                        SET trade_date = ?, ts_code = ?, action = ?, quantity = ?, price = ?, fee = ?,
                            source = ?, notes = ?, metadata = ?
                        WHERE id = ?
                        """,
                        (
                            trade_ts,
                            ts_code,
                            action,
                            quantity,
                            price,
                            fee,
                            trade.source,
                            trade.notes,
                            metadata_json,
                            existing_trade["id"],
                        ),
                    )
                    continue
            conn.execute(
                """
                INSERT INTO portfolio_trades
                    (trade_date, ts_code, action, quantity, price, fee, order_id, source, notes, metadata)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    trade_ts,
                    ts_code,
                    action,
                    quantity,
                    price,
                    fee,
                    order_id,
                    trade.source,
                    trade.notes,
                    metadata_json,
                ),
            )
    LOGGER.info(
        "realtime portfolio write complete trade_date=%s positions=%s trades=%s",
        trade_date,
        len(positions),
        len(trades),
        extra=LOG_EXTRA,
    )


__all__ = [
    "RealtimeSnapshot",
    "RealtimePosition",
    "RealtimeTrade",
    "sync_portfolio_state",
]
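For reference, a minimal caller sketch of the module's API; the ticker, amounts, and order id below are made-up illustration values, not data from this commit.

from datetime import date

from app.utils.portfolio_sync import (
    RealtimePosition,
    RealtimeSnapshot,
    RealtimeTrade,
    sync_portfolio_state,
)

snapshot = RealtimeSnapshot(
    trade_date=date.today(),          # str / date / datetime / yyyymmdd are all accepted
    total_value=1_000_000.0,
    cash=250_000.0,
    invested_value=750_000.0,
    exposure=0.75,
    metadata={"source": "broker_api"},
)
positions = [
    RealtimePosition(
        ts_code="600519.SH",
        opened_date="2025-01-06",
        quantity=200,
        cost_price=1_500.0,
        market_price=1_550.0,         # market_value / unrealized_pnl are derived when omitted
        target_weight=0.3,
    )
]
trades = [
    RealtimeTrade(
        trade_date=date.today(),
        ts_code="600519.SH",
        action="buy",
        quantity=100,
        price=1_548.0,
        fee=7.7,
        order_id="demo-001",          # re-sending this order_id later updates the same row
    )
]

# Upserts the snapshot, reconciles open positions (closing any not listed),
# and records/updates trades, deduplicating by order_id.
sync_portfolio_state(snapshot, positions, trades)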

View File

@@ -18,11 +18,11 @@
| Work item | Status | Notes |
| --- | --- | --- |
-| DecisionEnv extension | 🔄 | Episode metrics add Sharpe/Calmar and the reward function introduces a risk penalty; coverage of prompt versions, function strategies, and other controls continues. |
+| DecisionEnv extension | ✅ | Episode metrics now include Sharpe/Calmar, the reward function integrates a risk penalty, and prompt version, function strategy, and related controls can be overridden. |
| Reinforcement learning baselines | ✅ | Continuous-action algorithms such as PPO/SAC are wired in and form the experiment baseline. |
-| Reward and evaluation framework | 🔄 | The decision-environment reward already covers risk/turnover/Sharpe-Calmar; fill and equity-curve metrics still need to be wired in. |
+| Reward and evaluation framework | ✅ | The decision-environment reward combines risk/turnover/Sharpe-Calmar and also emits fill and equity-curve metrics. |
-| Live position pipeline | ⏳ | Build a data source for online position/trade writes shared by offline tuning and monitoring. |
+| Live position pipeline | ✅ | Added `app/utils/portfolio_sync.py`; backtest and live positions and trades are written into the shared `portfolio_*` tables for offline tuning and monitoring. |
-| Global parameter search | 🔄 | Epsilon-greedy tuning with metric output is live; Bayesian optimization / BOHB to follow. |
+| Global parameter search | ✅ | Epsilon-greedy, Gaussian-process Bayesian optimization, and a BOHB successor are in place, emitting the full set of tuning runs and metrics. |

## Multi-agent collaboration and LLMs
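The DecisionEnv changes themselves are not part of this commit; as a rough illustration of the Sharpe/Calmar reward with risk and turnover penalties described in the table above, a self-contained sketch (function name, weights, and annualization factor are assumptions) might look like this:

import numpy as np

def episode_reward(equity_curve, turnover, *, risk_lambda=0.5, turnover_lambda=0.1):
    """Score one episode from its equity curve and total turnover (illustrative weights)."""
    equity = np.asarray(equity_curve, dtype=float)
    if equity.size < 2:
        return 0.0
    returns = np.diff(equity) / equity[:-1]
    vol = returns.std()
    sharpe = float(np.sqrt(252) * returns.mean() / vol) if vol > 0 else 0.0
    running_max = np.maximum.accumulate(equity)
    max_drawdown = float(((running_max - equity) / running_max).max())
    total_return = equity[-1] / equity[0] - 1.0
    calmar = total_return / max_drawdown if max_drawdown > 0 else total_return
    # Blend the two risk-adjusted ratios, then subtract explicit risk/turnover penalties.
    return 0.5 * sharpe + 0.5 * calmar - risk_lambda * max_drawdown - turnover_lambda * turnover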

View File

@@ -0,0 +1,171 @@
"""Tests for live portfolio sync utilities."""
from __future__ import annotations

import pytest

from app.utils.db import db_session
from app.utils.portfolio_sync import (
    RealtimePosition,
    RealtimeSnapshot,
    RealtimeTrade,
    sync_portfolio_state,
)


def _fetch_one(sql: str, params: tuple | None = None):
    with db_session(read_only=True) as conn:
        return conn.execute(sql, params or ()).fetchone()


def _fetch_all(sql: str, params: tuple | None = None):
    with db_session(read_only=True) as conn:
        return conn.execute(sql, params or ()).fetchall()


def test_sync_portfolio_state_inserts_records(isolated_db):
    snapshot = RealtimeSnapshot(
        trade_date="2025-01-10",
        total_value=100_000.0,
        cash=40_000.0,
        invested_value=60_000.0,
        unrealized_pnl=600.0,
        realized_pnl=250.0,
        net_flow=-5_000.0,
        exposure=0.6,
        notes="intraday sync",
        metadata={"source": "broker_api"},
    )
    positions = [
        RealtimePosition(
            ts_code="000001.SZ",
            opened_date="2025-01-03",
            quantity=1_500,
            cost_price=12.5,
            market_price=13.2,
            realized_pnl=200.0,
            unrealized_pnl=1050.0,
            target_weight=0.3,
            metadata={"account": "live"},
        )
    ]
    trades = [
        RealtimeTrade(
            trade_date="2025-01-10",
            ts_code="000001.SZ",
            action="buy",
            quantity=500,
            price=13.2,
            fee=4.5,
            order_id="order-001",
            source="broker",
            notes="increase position",
            metadata={"account": "live"},
        )
    ]
    sync_portfolio_state(snapshot, positions, trades)

    snap_row = _fetch_one("SELECT * FROM portfolio_snapshots WHERE trade_date = ?", ("2025-01-10",))
    assert snap_row is not None
    assert snap_row["total_value"] == pytest.approx(100_000.0)
    assert snap_row["net_flow"] == pytest.approx(-5_000.0)

    pos_row = _fetch_one("SELECT * FROM portfolio_positions WHERE ts_code = '000001.SZ'")
    assert pos_row is not None
    assert pos_row["quantity"] == pytest.approx(1_500.0)
    assert pos_row["status"] == "open"
    assert pos_row["target_weight"] == pytest.approx(0.3)
    assert pos_row["metadata"] is not None

    trade_row = _fetch_one("SELECT * FROM portfolio_trades WHERE order_id = 'order-001'")
    assert trade_row is not None
    assert trade_row["quantity"] == pytest.approx(500.0)
    assert trade_row["price"] == pytest.approx(13.2)
    assert trade_row["source"] == "broker"


def test_sync_portfolio_state_updates_and_closes(isolated_db):
    # prime database with initial state
    initial_snapshot = RealtimeSnapshot(
        trade_date="2025-01-10",
        total_value=90_000.0,
        cash=50_000.0,
        invested_value=40_000.0,
    )
    initial_positions = [
        RealtimePosition(
            ts_code="000001.SZ",
            opened_date="2025-01-02",
            quantity=800,
            cost_price=11.0,
            market_price=11.5,
            status="open",
        ),
        RealtimePosition(
            ts_code="000002.SZ",
            opened_date="2025-01-04",
            quantity=600,
            cost_price=8.0,
            market_price=8.2,
            status="open",
        ),
    ]
    sync_portfolio_state(initial_snapshot, initial_positions, [])

    # update next day with only one open position and revised trade info
    followup_snapshot = RealtimeSnapshot(
        trade_date="2025-01-11",
        total_value=95_000.0,
        cash=54_000.0,
        invested_value=41_000.0,
        net_flow=1_000.0,
    )
    followup_positions = [
        RealtimePosition(
            ts_code="000001.SZ",
            opened_date="2025-01-02",
            quantity=500,
            cost_price=11.0,
            market_price=12.0,
            status="open",
            realized_pnl=300.0,
            target_weight=0.25,
        )
    ]
    trades = [
        RealtimeTrade(
            trade_date="2025-01-11",
            ts_code="000001.SZ",
            action="sell",
            quantity=300,
            price=12.0,
            order_id="order-xyz",
            source="broker",
        ),
        # update existing trade by reusing order id
        RealtimeTrade(
            trade_date="2025-01-11",
            ts_code="000001.SZ",
            action="sell",
            quantity=300,
            price=12.1,
            fee=3.2,
            order_id="order-xyz",
            source="broker",
            notes="amended fill",
        ),
    ]
    sync_portfolio_state(followup_snapshot, followup_positions, trades)

    open_rows = _fetch_all("SELECT ts_code, status, closed_date FROM portfolio_positions")
    status_map = {row["ts_code"]: (row["status"], row["closed_date"]) for row in open_rows}
    assert status_map["000001.SZ"][0] == "open"
    assert status_map["000001.SZ"][1] in (None, "2025-01-11")
    assert status_map["000002.SZ"][0] == "closed"
    assert status_map["000002.SZ"][1] == "2025-01-11"

    trade_rows = _fetch_all("SELECT id FROM portfolio_trades")
    assert len(trade_rows) == 1, "duplicate trades should be merged via order_id"

    trade_record = _fetch_one("SELECT price, fee FROM portfolio_trades WHERE order_id = 'order-xyz'")
    assert trade_record["price"] == pytest.approx(12.1)
    assert trade_record["fee"] == pytest.approx(3.2)