llm-quant/app/utils/portfolio.py

260 lines
8.0 KiB
Python

"""Portfolio data access helpers for candidate pool, positions, and PnL tracking."""
from __future__ import annotations
import json
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional
from .db import db_session
from .logging import get_logger
from .portfolio_init import get_portfolio_config
# Module-level logger; LOG_EXTRA is passed as ``extra=`` on the log calls in
# this module so their records carry the {"stage": "portfolio"} tag.
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "portfolio"}
def _loads_or_default(payload: Optional[str], default: Any) -> Any:
if not payload:
return default
try:
return json.loads(payload)
except json.JSONDecodeError:
LOGGER.debug("JSON 解析失败 payload=%s", payload, extra=LOG_EXTRA)
return default
@dataclass
class InvestmentCandidate:
    """One row of the ``investment_pool`` table, as built by list_investment_pool."""

    trade_date: str
    ts_code: str
    score: Optional[float]  # may be None; list_investment_pool orders NULL scores last
    status: str  # list_investment_pool substitutes "unknown" for a NULL status
    rationale: Optional[str]
    tags: List[str]  # decoded from the JSON ``tags`` column
    metadata: Dict[str, Any]  # decoded from the JSON ``metadata`` column
    name: Optional[str] = None
    industry: Optional[str] = None
def list_investment_pool(
    *,
    trade_date: Optional[str] = None,
    status: Optional[Iterable[str]] = None,
    limit: int = 200,
) -> List[InvestmentCandidate]:
    """Return investment candidates for the given trade date (latest if None).

    Args:
        trade_date: Trade date to query. When falsy, the most recent
            ``trade_date`` present in ``investment_pool`` is used.
        status: Optional statuses to keep. Any iterable is accepted,
            including a one-shot generator. A bare string is treated as a
            single status. ``None`` or an empty iterable disables the filter.
        limit: Maximum number of rows to return.

    Returns:
        Candidates ordered by score descending (NULL scores last), then by
        ``ts_code``. An empty list is returned if the query fails.
    """
    # Materialise the filter once so the number of "?" placeholders always
    # matches the bound parameters, even for one-shot iterators. A bare string
    # is one status, not a sequence of single-character statuses.
    if status is None:
        statuses: List[str] = []
    elif isinstance(status, str):
        statuses = [status]
    else:
        statuses = list(status)

    query = [
        "SELECT trade_date, ts_code, score, status, rationale, tags, metadata, name, industry",
        "FROM investment_pool",
    ]
    params: List[Any] = []
    if trade_date:
        query.append("WHERE trade_date = ?")
        params.append(trade_date)
    else:
        # No date supplied: restrict to the newest trade date in the pool.
        query.append(
            "WHERE trade_date = (SELECT MAX(trade_date) FROM investment_pool)"
        )
    if statuses:
        placeholders = ", ".join(["?"] * len(statuses))
        query.append(f"AND status IN ({placeholders})")
        params.extend(statuses)
    # (score IS NULL) sorts rows with a score first; ts_code breaks ties.
    query.append("ORDER BY (score IS NULL), score DESC, ts_code")
    query.append("LIMIT ?")
    params.append(int(limit))
    sql = "\n".join(query)

    with db_session(read_only=True) as conn:
        try:
            rows = conn.execute(sql, params).fetchall()
        except Exception:  # noqa: BLE001
            LOGGER.exception("查询 investment_pool 失败", extra=LOG_EXTRA)
            return []

    return [
        InvestmentCandidate(
            trade_date=row["trade_date"],
            ts_code=row["ts_code"],
            score=row["score"],
            status=row["status"] or "unknown",
            rationale=row["rationale"],
            tags=list(_loads_or_default(row["tags"], [])),
            metadata=dict(_loads_or_default(row["metadata"], {})),
            name=row["name"],
            industry=row["industry"],
        )
        for row in rows
    ]
@dataclass
class PortfolioPosition:
    """One row of the ``portfolio_positions`` table, as built by list_positions."""

    id: int  # the row's ``id`` column
    ts_code: str
    opened_date: str
    closed_date: Optional[str]  # presumably None while the position is open — TODO confirm
    quantity: float
    cost_price: float
    market_price: Optional[float]
    market_value: Optional[float]
    realized_pnl: float
    unrealized_pnl: float
    target_weight: Optional[float]
    status: str  # list_positions(active_only=True) keeps only rows with status 'open'
    notes: Optional[str]
    metadata: Dict[str, Any]  # decoded from the JSON ``metadata`` column
def _float_or_zero(value: Any) -> float:
    """Return ``value`` as a float, treating SQL NULL (``None``) as 0.0."""
    return 0.0 if value is None else float(value)


def list_positions(*, active_only: bool = True) -> List[PortfolioPosition]:
    """Return portfolio positions.

    Args:
        active_only: When true (the default), only rows whose ``status`` is
            ``'open'`` are returned; otherwise every row is returned.

    Returns:
        Positions ordered by ``status`` descending, then ``opened_date``
        descending. ``realized_pnl`` and ``unrealized_pnl`` are always floats
        (NULL becomes 0.0), as the PortfolioPosition annotations promise. An
        empty list is returned if the query fails.
    """
    # The WHERE clause is a fixed literal (never caller-supplied text), so
    # interpolating it into the SQL is safe.
    where_clause = "WHERE status = 'open'" if active_only else ""
    sql = f"""
        SELECT id, ts_code, opened_date, closed_date, quantity, cost_price,
               market_price, market_value, realized_pnl, unrealized_pnl,
               target_weight, status, notes, metadata
        FROM portfolio_positions
        {where_clause}
        ORDER BY status DESC, opened_date DESC
    """
    with db_session(read_only=True) as conn:
        try:
            rows = conn.execute(sql, ()).fetchall()
        except Exception:  # noqa: BLE001
            LOGGER.exception("查询 portfolio_positions 失败", extra=LOG_EXTRA)
            return []

    return [
        PortfolioPosition(
            id=row["id"],
            ts_code=row["ts_code"],
            opened_date=row["opened_date"],
            closed_date=row["closed_date"],
            quantity=float(row["quantity"]),
            cost_price=float(row["cost_price"]),
            market_price=row["market_price"],
            market_value=row["market_value"],
            # Declared as plain float: coerce so NULL/int columns never leak
            # through as None/int.
            realized_pnl=_float_or_zero(row["realized_pnl"]),
            unrealized_pnl=_float_or_zero(row["unrealized_pnl"]),
            target_weight=row["target_weight"],
            status=row["status"],
            notes=row["notes"],
            metadata=dict(_loads_or_default(row["metadata"], {})),
        )
        for row in rows
    ]
@dataclass
class PortfolioSnapshot:
    """One row of ``portfolio_snapshots``, or the synthetic initial state built by get_latest_snapshot."""

    trade_date: str  # "" for the synthetic initial state
    total_value: Optional[float]
    cash: Optional[float]
    invested_value: Optional[float]
    unrealized_pnl: Optional[float]
    realized_pnl: Optional[float]
    net_flow: Optional[float]
    exposure: Optional[float]
    notes: Optional[str]
    metadata: Dict[str, Any]  # decoded from the JSON ``metadata`` column
def get_latest_snapshot() -> Optional[PortfolioSnapshot]:
    """Fetch the most recent portfolio snapshot.

    Returns:
        The newest row of ``portfolio_snapshots`` (by ``trade_date``). If the
        table has no rows, an initial snapshot is returned instead. It holds
        only the configured initial capital, all in cash, with an empty
        ``trade_date``. ``None`` is returned only when the query raises.
    """
    sql = """
        SELECT trade_date, total_value, cash, invested_value, unrealized_pnl,
               realized_pnl, net_flow, exposure, notes, metadata
        FROM portfolio_snapshots
        ORDER BY trade_date DESC
        LIMIT 1
    """
    with db_session(read_only=True) as conn:
        try:
            latest = conn.execute(sql).fetchone()
        except Exception:  # noqa: BLE001
            LOGGER.exception("查询 portfolio_snapshots 失败", extra=LOG_EXTRA)
            return None

    if latest:
        return PortfolioSnapshot(
            trade_date=latest["trade_date"],
            total_value=latest["total_value"],
            cash=latest["cash"],
            invested_value=latest["invested_value"],
            unrealized_pnl=latest["unrealized_pnl"],
            realized_pnl=latest["realized_pnl"],
            net_flow=latest["net_flow"],
            exposure=latest["exposure"],
            notes=latest["notes"],
            metadata=dict(_loads_or_default(latest["metadata"], {})),
        )

    # No snapshot recorded yet: synthesise the starting state, in which the
    # whole initial capital sits in cash and nothing is invested.
    portfolio_config = get_portfolio_config()
    capital = portfolio_config["initial_capital"]
    return PortfolioSnapshot(
        trade_date="",  # an empty date marks the initial state
        total_value=capital,
        cash=capital,
        invested_value=0.0,
        unrealized_pnl=0.0,
        realized_pnl=0.0,
        net_flow=0.0,
        exposure=0.0,
        notes="Initial portfolio state",
        metadata={"initial_capital": capital, "currency": portfolio_config["currency"]},
    )
def list_recent_trades(limit: int = 50) -> List[Dict[str, Any]]:
    """Return recent trades for monitoring purposes.

    Args:
        limit: Maximum number of trades to fetch.

    Returns:
        One plain dict per trade, newest first (``trade_date``, then ``id``,
        both descending). The ``metadata`` value is the decoded JSON column.
        An empty list is returned if the query fails.
    """
    sql = """
        SELECT trade_date, ts_code, action, quantity, price, fee, order_id, source, notes, metadata
        FROM portfolio_trades
        ORDER BY trade_date DESC, id DESC
        LIMIT ?
    """
    with db_session(read_only=True) as conn:
        try:
            rows = conn.execute(sql, (int(limit),)).fetchall()
        except Exception:  # noqa: BLE001
            LOGGER.exception("查询 portfolio_trades 失败", extra=LOG_EXTRA)
            return []

    # Columns copied verbatim. "metadata" is added last, after JSON decoding,
    # so every dict keeps the same key order as the SELECT list.
    plain_columns = (
        "trade_date",
        "ts_code",
        "action",
        "quantity",
        "price",
        "fee",
        "order_id",
        "source",
        "notes",
    )
    return [
        {
            **{column: record[column] for column in plain_columns},
            "metadata": _loads_or_default(record["metadata"], {}),
        }
        for record in rows
    ]