260 lines
8.0 KiB
Python
260 lines
8.0 KiB
Python
"""Portfolio data access helpers for candidate pool, positions, and PnL tracking."""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict, Iterable, List, Optional
|
|
|
|
from .db import db_session
|
|
from .logging import get_logger
|
|
from .portfolio_init import get_portfolio_config
|
|
|
|
# Module logger. LOG_EXTRA is the ``extra=`` mapping this module's log calls
# attach to their records, tagging them with the portfolio stage.
LOGGER = get_logger(__name__)
LOG_EXTRA = dict(stage="portfolio")
|
|
|
|
|
|
def _loads_or_default(payload: Optional[str], default: Any) -> Any:
    """Decode a JSON column value, falling back to *default* on any problem.

    Args:
        payload: Raw JSON read from the database. ``None`` and other falsy
            values are treated as "no data".
        default: Value returned when *payload* is empty or cannot be decoded.
            It is returned as-is, not copied.

    Returns:
        The decoded JSON value, or *default*.
    """
    if not payload:
        return default
    try:
        return json.loads(payload)
    except (ValueError, TypeError):
        # json.JSONDecodeError is a ValueError subclass. Catching the base
        # class also covers UnicodeDecodeError raised for undecodable bytes.
        # TypeError covers payloads that are not str/bytes/bytearray, such as
        # a non-text value handed back by the driver.
        LOGGER.debug("JSON 解析失败 payload=%s", payload, extra=LOG_EXTRA)
        return default
|
|
|
|
|
|
@dataclass
class InvestmentCandidate:
    """One ``investment_pool`` row, with its JSON columns already decoded."""

    # Trade date the candidate was recorded for (``trade_date`` column).
    trade_date: str
    # Instrument code from the ``ts_code`` column.
    ts_code: str
    # Ranking score; ``None`` when the row carries no score.
    score: Optional[float]
    # Status label of the candidate.
    status: str
    # Optional free-text explanation.
    rationale: Optional[str]
    # Decoded ``tags`` JSON column.
    tags: List[str]
    # Decoded ``metadata`` JSON column.
    metadata: Dict[str, Any]
    # Optional descriptive attributes; both default to ``None``.
    name: Optional[str] = None
    industry: Optional[str] = None
|
|
|
|
|
|
def list_investment_pool(
    *,
    trade_date: Optional[str] = None,
    status: Optional[Iterable[str]] = None,
    limit: int = 200,
) -> List[InvestmentCandidate]:
    """Return investment candidates for one trade date.

    Args:
        trade_date: Trade date to load. When falsy, the latest ``trade_date``
            present in ``investment_pool`` is used.
        status: Optional status filter. Any iterable is accepted, including a
            one-shot generator. A single string is treated as one status, not
            as a sequence of characters. An empty filter applies no status
            restriction.
        limit: Maximum number of rows to return.

    Returns:
        Candidates ordered by score (highest first, rows without a score last)
        and then by ``ts_code``. An empty list is returned if the query fails.
    """
    # Materialise the filter exactly once. It is used both to build the
    # placeholders and as bound parameters, and a one-shot iterator would be
    # exhausted after the first pass.
    if status is None:
        statuses: List[str] = []
    elif isinstance(status, str):
        statuses = [status]
    else:
        statuses = list(status)

    query = [
        "SELECT trade_date, ts_code, score, status, rationale, tags, metadata, name, industry",
        "FROM investment_pool",
    ]
    params: List[Any] = []

    if trade_date:
        query.append("WHERE trade_date = ?")
        params.append(trade_date)
    else:
        query.append(
            "WHERE trade_date = (SELECT MAX(trade_date) FROM investment_pool)"
        )

    if statuses:
        placeholders = ", ".join("?" for _ in statuses)
        query.append(f"AND status IN ({placeholders})")
        params.extend(statuses)

    # ``(score IS NULL)`` sorts scored rows before unscored ones.
    query.append("ORDER BY (score IS NULL), score DESC, ts_code")
    query.append("LIMIT ?")
    params.append(int(limit))

    sql = "\n".join(query)
    with db_session(read_only=True) as conn:
        try:
            rows = conn.execute(sql, params).fetchall()
        except Exception:  # noqa: BLE001
            LOGGER.exception("查询 investment_pool 失败", extra=LOG_EXTRA)
            return []

    candidates: List[InvestmentCandidate] = []
    for row in rows:
        raw_tags = _loads_or_default(row["tags"], [])
        raw_metadata = _loads_or_default(row["metadata"], {})
        candidates.append(
            InvestmentCandidate(
                trade_date=row["trade_date"],
                ts_code=row["ts_code"],
                score=row["score"],
                status=row["status"] or "unknown",
                rationale=row["rationale"],
                # Ignore JSON of the wrong shape rather than splitting a string
                # into characters or raising mid-listing on one bad row.
                tags=list(raw_tags) if isinstance(raw_tags, list) else [],
                metadata=dict(raw_metadata) if isinstance(raw_metadata, dict) else {},
                name=row["name"],
                industry=row["industry"],
            )
        )
    return candidates
|
|
|
|
|
|
@dataclass
class PortfolioPosition:
    """One ``portfolio_positions`` row, with its ``metadata`` JSON decoded."""

    # Identity of the position row and the instrument it holds.
    id: int
    ts_code: str
    # Lifecycle dates; ``closed_date`` is ``None`` while it is not set.
    opened_date: str
    closed_date: Optional[str]
    # Size and entry cost.
    quantity: float
    cost_price: float
    # Latest valuation; either field may be absent (``None``).
    market_price: Optional[float]
    market_value: Optional[float]
    # Profit and loss figures as stored on the row.
    realized_pnl: float
    unrealized_pnl: float
    # Optional target allocation for this position.
    target_weight: Optional[float]
    # Status label; active positions are stored with ``'open'``.
    status: str
    # Free-form annotations and decoded ``metadata`` JSON column.
    notes: Optional[str]
    metadata: Dict[str, Any]
|
|
|
|
|
|
def list_positions(*, active_only: bool = True) -> List[PortfolioPosition]:
    """Return portfolio positions from ``portfolio_positions``.

    Args:
        active_only: When true (the default), only rows whose ``status`` is
            ``'open'`` are returned; otherwise every row is returned.

    Returns:
        Positions ordered by ``status`` descending, then ``opened_date``
        descending. An empty list is returned if the query fails.
    """
    clauses = [
        "SELECT id, ts_code, opened_date, closed_date, quantity, cost_price,",
        "       market_price, market_value, realized_pnl, unrealized_pnl,",
        "       target_weight, status, notes, metadata",
        "FROM portfolio_positions",
    ]
    params: List[Any] = []
    if active_only:
        # Bind the status value as a parameter instead of splicing it into the
        # SQL text with str.format.
        clauses.append("WHERE status = ?")
        params.append("open")
    clauses.append("ORDER BY status DESC, opened_date DESC")
    sql = "\n".join(clauses)

    with db_session(read_only=True) as conn:
        try:
            rows = conn.execute(sql, params).fetchall()
        except Exception:  # noqa: BLE001
            LOGGER.exception("查询 portfolio_positions 失败", extra=LOG_EXTRA)
            return []

    positions: List[PortfolioPosition] = []
    for row in rows:
        raw_metadata = _loads_or_default(row["metadata"], {})
        positions.append(
            PortfolioPosition(
                id=row["id"],
                ts_code=row["ts_code"],
                opened_date=row["opened_date"],
                closed_date=row["closed_date"],
                quantity=float(row["quantity"]),
                cost_price=float(row["cost_price"]),
                market_price=row["market_price"],
                market_value=row["market_value"],
                realized_pnl=row["realized_pnl"],
                unrealized_pnl=row["unrealized_pnl"],
                target_weight=row["target_weight"],
                status=row["status"],
                notes=row["notes"],
                # Non-object JSON (e.g. a list) would make dict() raise and
                # abort the whole listing; fall back to an empty mapping.
                metadata=dict(raw_metadata) if isinstance(raw_metadata, dict) else {},
            )
        )
    return positions
|
|
|
|
|
|
@dataclass
class PortfolioSnapshot:
    """One ``portfolio_snapshots`` row, with its ``metadata`` JSON decoded."""

    # Date of the snapshot; the empty string is used for a synthetic
    # initial-state snapshot.
    trade_date: str
    # Aggregate valuation figures; any of them may be ``None``.
    total_value: Optional[float]
    cash: Optional[float]
    invested_value: Optional[float]
    # Profit and loss, flows and exposure as stored on the row.
    unrealized_pnl: Optional[float]
    realized_pnl: Optional[float]
    net_flow: Optional[float]
    exposure: Optional[float]
    # Free-form annotation and decoded ``metadata`` JSON column.
    notes: Optional[str]
    metadata: Dict[str, Any]
|
|
|
|
|
|
def get_latest_snapshot() -> Optional[PortfolioSnapshot]:
    """Return the newest row of ``portfolio_snapshots``.

    Returns:
        The snapshot with the greatest ``trade_date``. If the table has no
        rows, a synthetic snapshot is returned instead. Its ``trade_date`` is
        the empty string, and its cash and total value both equal the
        configured initial capital. ``None`` is returned when the query itself
        fails.
    """
    latest_sql = """
        SELECT trade_date, total_value, cash, invested_value, unrealized_pnl,
               realized_pnl, net_flow, exposure, notes, metadata
        FROM portfolio_snapshots
        ORDER BY trade_date DESC
        LIMIT 1
    """
    with db_session(read_only=True) as conn:
        try:
            latest = conn.execute(latest_sql).fetchone()
        except Exception:  # noqa: BLE001
            LOGGER.exception("查询 portfolio_snapshots 失败", extra=LOG_EXTRA)
            return None

    if latest:
        return PortfolioSnapshot(
            trade_date=latest["trade_date"],
            total_value=latest["total_value"],
            cash=latest["cash"],
            invested_value=latest["invested_value"],
            unrealized_pnl=latest["unrealized_pnl"],
            realized_pnl=latest["realized_pnl"],
            net_flow=latest["net_flow"],
            exposure=latest["exposure"],
            notes=latest["notes"],
            metadata=dict(_loads_or_default(latest["metadata"], {})),
        )

    # No snapshot recorded yet: synthesise the starting state, in which the
    # configured initial capital is held entirely as cash.
    cfg = get_portfolio_config()
    capital = cfg["initial_capital"]
    return PortfolioSnapshot(
        trade_date="",  # an empty date marks the synthetic initial state
        total_value=capital,
        cash=capital,
        invested_value=0.0,
        unrealized_pnl=0.0,
        realized_pnl=0.0,
        net_flow=0.0,
        exposure=0.0,
        notes="Initial portfolio state",
        metadata={"initial_capital": capital, "currency": cfg["currency"]},
    )
|
|
|
|
|
|
def list_recent_trades(limit: int = 50) -> List[Dict[str, Any]]:
    """Return the most recent trades as plain dictionaries for monitoring.

    Args:
        limit: Maximum number of trades to return (converted with ``int``).

    Returns:
        Trades ordered by ``trade_date`` then ``id``, both descending. Each
        dict's ``metadata`` entry holds the decoded JSON column. An empty list
        is returned if the query fails.
    """
    recent_sql = """
        SELECT trade_date, ts_code, action, quantity, price, fee, order_id, source, notes, metadata
        FROM portfolio_trades
        ORDER BY trade_date DESC, id DESC
        LIMIT ?
    """
    with db_session(read_only=True) as conn:
        try:
            fetched = conn.execute(recent_sql, (int(limit),)).fetchall()
        except Exception:  # noqa: BLE001
            LOGGER.exception("查询 portfolio_trades 失败", extra=LOG_EXTRA)
            return []

    # Columns copied through unchanged; ``metadata`` is decoded separately.
    passthrough = (
        "trade_date",
        "ts_code",
        "action",
        "quantity",
        "price",
        "fee",
        "order_id",
        "source",
        "notes",
    )
    return [
        {
            **{column: record[column] for column in passthrough},
            "metadata": _loads_or_default(record["metadata"], {}),
        }
        for record in fetched
    ]
|