llm-quant/tests/test_portfolio_sync.py

172 lines
5.3 KiB
Python

"""Tests for live portfolio sync utilities."""
from __future__ import annotations
import pytest
from app.utils.db import db_session
from app.utils.portfolio_sync import (
RealtimePosition,
RealtimeSnapshot,
RealtimeTrade,
sync_portfolio_state,
)
def _fetch_one(sql: str, params: tuple | None = None):
with db_session(read_only=True) as conn:
return conn.execute(sql, params or ()).fetchone()
def _fetch_all(sql: str, params: tuple | None = None):
with db_session(read_only=True) as conn:
return conn.execute(sql, params or ()).fetchall()
def test_sync_portfolio_state_inserts_records(isolated_db):
snapshot = RealtimeSnapshot(
trade_date="2025-01-10",
total_value=100_000.0,
cash=40_000.0,
invested_value=60_000.0,
unrealized_pnl=600.0,
realized_pnl=250.0,
net_flow=-5_000.0,
exposure=0.6,
notes="intraday sync",
metadata={"source": "broker_api"},
)
positions = [
RealtimePosition(
ts_code="000001.SZ",
opened_date="2025-01-03",
quantity=1_500,
cost_price=12.5,
market_price=13.2,
realized_pnl=200.0,
unrealized_pnl=1050.0,
target_weight=0.3,
metadata={"account": "live"},
)
]
trades = [
RealtimeTrade(
trade_date="2025-01-10",
ts_code="000001.SZ",
action="buy",
quantity=500,
price=13.2,
fee=4.5,
order_id="order-001",
source="broker",
notes="increase position",
metadata={"account": "live"},
)
]
sync_portfolio_state(snapshot, positions, trades)
snap_row = _fetch_one("SELECT * FROM portfolio_snapshots WHERE trade_date = ?", ("2025-01-10",))
assert snap_row is not None
assert snap_row["total_value"] == pytest.approx(100_000.0)
assert snap_row["net_flow"] == pytest.approx(-5_000.0)
pos_row = _fetch_one("SELECT * FROM portfolio_positions WHERE ts_code = '000001.SZ'")
assert pos_row is not None
assert pos_row["quantity"] == pytest.approx(1_500.0)
assert pos_row["status"] == "open"
assert pos_row["target_weight"] == pytest.approx(0.3)
assert pos_row["metadata"] is not None
trade_row = _fetch_one("SELECT * FROM portfolio_trades WHERE order_id = 'order-001'")
assert trade_row is not None
assert trade_row["quantity"] == pytest.approx(500.0)
assert trade_row["price"] == pytest.approx(13.2)
assert trade_row["source"] == "broker"
def test_sync_portfolio_state_updates_and_closes(isolated_db):
# prime database with initial state
initial_snapshot = RealtimeSnapshot(
trade_date="2025-01-10",
total_value=90_000.0,
cash=50_000.0,
invested_value=40_000.0,
)
initial_positions = [
RealtimePosition(
ts_code="000001.SZ",
opened_date="2025-01-02",
quantity=800,
cost_price=11.0,
market_price=11.5,
status="open",
),
RealtimePosition(
ts_code="000002.SZ",
opened_date="2025-01-04",
quantity=600,
cost_price=8.0,
market_price=8.2,
status="open",
),
]
sync_portfolio_state(initial_snapshot, initial_positions, [])
# update next day with only one open position and revised trade info
followup_snapshot = RealtimeSnapshot(
trade_date="2025-01-11",
total_value=95_000.0,
cash=54_000.0,
invested_value=41_000.0,
net_flow=1_000.0,
)
followup_positions = [
RealtimePosition(
ts_code="000001.SZ",
opened_date="2025-01-02",
quantity=500,
cost_price=11.0,
market_price=12.0,
status="open",
realized_pnl=300.0,
target_weight=0.25,
)
]
trades = [
RealtimeTrade(
trade_date="2025-01-11",
ts_code="000001.SZ",
action="sell",
quantity=300,
price=12.0,
order_id="order-xyz",
source="broker",
),
# update existing trade by reusing order id
RealtimeTrade(
trade_date="2025-01-11",
ts_code="000001.SZ",
action="sell",
quantity=300,
price=12.1,
fee=3.2,
order_id="order-xyz",
source="broker",
notes="amended fill",
),
]
sync_portfolio_state(followup_snapshot, followup_positions, trades)
open_rows = _fetch_all("SELECT ts_code, status, closed_date FROM portfolio_positions")
status_map = {row["ts_code"]: (row["status"], row["closed_date"]) for row in open_rows}
assert status_map["000001.SZ"][0] == "open"
assert status_map["000001.SZ"][1] in (None, "2025-01-11")
assert status_map["000002.SZ"][0] == "closed"
assert status_map["000002.SZ"][1] == "2025-01-11"
trade_rows = _fetch_all("SELECT id FROM portfolio_trades")
assert len(trade_rows) == 1, "duplicate trades should be merged via order_id"
trade_record = _fetch_one("SELECT price, fee FROM portfolio_trades WHERE order_id = 'order-xyz'")
assert trade_record["price"] == pytest.approx(12.1)
assert trade_record["fee"] == pytest.approx(3.2)