update
This commit is contained in:
parent
2e98e81715
commit
ee853333a8
@ -7,7 +7,7 @@ from datetime import date
|
||||
from statistics import mean, pstdev
|
||||
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional
|
||||
|
||||
from app.agents.base import AgentContext
|
||||
from app.agents.base import AgentAction, AgentContext
|
||||
from app.agents.departments import DepartmentManager
|
||||
from app.agents.game import Decision, decide
|
||||
from app.llm.metrics import record_decision as metrics_record_decision
|
||||
@ -76,6 +76,9 @@ class BtConfig:
|
||||
class PortfolioState:
|
||||
cash: float = 1_000_000.0
|
||||
holdings: Dict[str, float] = field(default_factory=dict)
|
||||
cost_basis: Dict[str, float] = field(default_factory=dict)
|
||||
opened_dates: Dict[str, str] = field(default_factory=dict)
|
||||
realized_pnl: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -230,9 +233,9 @@ class BacktestEngine:
|
||||
trade_date: date,
|
||||
state: PortfolioState,
|
||||
decision_callback: Optional[Callable[[str, date, AgentContext, Decision], None]] = None,
|
||||
) -> List[Decision]:
|
||||
) -> List[tuple[str, AgentContext, Decision]]:
|
||||
feature_map = self.load_market_data(trade_date)
|
||||
decisions: List[Decision] = []
|
||||
records: List[tuple[str, AgentContext, Decision]] = []
|
||||
for ts_code, payload in feature_map.items():
|
||||
features = payload.get("features", {})
|
||||
market_snapshot = payload.get("market_snapshot", {})
|
||||
@ -266,7 +269,7 @@ class BacktestEngine:
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.debug("记录决策指标失败", extra=LOG_EXTRA)
|
||||
decisions.append(decision)
|
||||
records.append((ts_code, context, decision))
|
||||
self.record_agent_state(context, decision)
|
||||
if decision_callback:
|
||||
try:
|
||||
@ -275,7 +278,7 @@ class BacktestEngine:
|
||||
LOGGER.exception("决策回调执行失败", extra=LOG_EXTRA)
|
||||
# TODO: translate decisions into fills, holdings, and NAV updates.
|
||||
_ = state
|
||||
return decisions
|
||||
return records
|
||||
|
||||
def record_agent_state(self, context: AgentContext, decision: Decision) -> None:
|
||||
payload = {
|
||||
@ -390,6 +393,309 @@ class BacktestEngine:
|
||||
_ = payload
|
||||
# TODO: persist payload into bt_trades / audit tables when schema is ready.
|
||||
|
||||
try:
|
||||
self._record_investment_candidate(context, decision)
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.exception("写入 investment_pool 失败", extra=LOG_EXTRA)
|
||||
|
||||
def _apply_portfolio_updates(
|
||||
self,
|
||||
trade_date: date,
|
||||
state: PortfolioState,
|
||||
records: List[tuple[str, AgentContext, Decision]],
|
||||
result: BacktestResult,
|
||||
) -> None:
|
||||
trade_date_str = trade_date.isoformat()
|
||||
price_map: Dict[str, float] = {}
|
||||
decisions_map: Dict[str, Decision] = {}
|
||||
for ts_code, context, decision in records:
|
||||
scope_values = context.raw.get("scope_values") if context.raw else {}
|
||||
if not isinstance(scope_values, Mapping):
|
||||
scope_values = {}
|
||||
price = scope_values.get("daily.close") or scope_values.get("close")
|
||||
if price is None:
|
||||
continue
|
||||
try:
|
||||
price = float(price)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
price_map[ts_code] = price
|
||||
decisions_map[ts_code] = decision
|
||||
|
||||
if not price_map and state.holdings:
|
||||
trade_date_compact = trade_date.strftime("%Y%m%d")
|
||||
for ts_code in state.holdings.keys():
|
||||
fetched = self.data_broker.fetch_latest(ts_code, trade_date_compact, ["daily.close"])
|
||||
price = fetched.get("daily.close")
|
||||
if price:
|
||||
price_map[ts_code] = float(price)
|
||||
|
||||
portfolio_value_before = state.cash
|
||||
for ts_code, qty in state.holdings.items():
|
||||
price = price_map.get(ts_code)
|
||||
if price is None:
|
||||
continue
|
||||
portfolio_value_before += qty * price
|
||||
|
||||
if portfolio_value_before <= 0:
|
||||
portfolio_value_before = state.cash or 1.0
|
||||
|
||||
trades_records: List[Dict[str, Any]] = []
|
||||
for ts_code, decision in decisions_map.items():
|
||||
price = price_map.get(ts_code)
|
||||
if price is None or price <= 0:
|
||||
continue
|
||||
current_qty = state.holdings.get(ts_code, 0.0)
|
||||
desired_qty = current_qty
|
||||
if decision.action is AgentAction.SELL:
|
||||
desired_qty = 0.0
|
||||
elif decision.action is AgentAction.HOLD:
|
||||
desired_qty = current_qty
|
||||
else:
|
||||
target_weight = max(decision.target_weight, 0.0)
|
||||
desired_value = target_weight * portfolio_value_before
|
||||
if desired_value > 0:
|
||||
desired_qty = desired_value / price
|
||||
else:
|
||||
desired_qty = current_qty
|
||||
|
||||
delta = desired_qty - current_qty
|
||||
if abs(delta) < 1e-6:
|
||||
continue
|
||||
|
||||
if delta > 0:
|
||||
cost = delta * price
|
||||
if cost > state.cash:
|
||||
affordable_qty = state.cash / price if price > 0 else 0.0
|
||||
delta = max(0.0, affordable_qty)
|
||||
cost = delta * price
|
||||
desired_qty = current_qty + delta
|
||||
if delta <= 0:
|
||||
continue
|
||||
total_cost = state.cost_basis.get(ts_code, 0.0) * current_qty + cost
|
||||
new_qty = current_qty + delta
|
||||
state.cost_basis[ts_code] = total_cost / new_qty if new_qty > 0 else 0.0
|
||||
state.cash -= cost
|
||||
state.holdings[ts_code] = new_qty
|
||||
state.opened_dates.setdefault(ts_code, trade_date_str)
|
||||
trades_records.append(
|
||||
{
|
||||
"trade_date": trade_date_str,
|
||||
"ts_code": ts_code,
|
||||
"action": "buy",
|
||||
"quantity": float(delta),
|
||||
"price": price,
|
||||
"value": cost,
|
||||
"confidence": decision.confidence,
|
||||
"target_weight": decision.target_weight,
|
||||
}
|
||||
)
|
||||
else:
|
||||
sell_qty = abs(delta)
|
||||
if sell_qty > current_qty:
|
||||
sell_qty = current_qty
|
||||
delta = -sell_qty
|
||||
proceeds = sell_qty * price
|
||||
cost_basis = state.cost_basis.get(ts_code, 0.0)
|
||||
realized = (price - cost_basis) * sell_qty
|
||||
state.cash += proceeds
|
||||
state.realized_pnl += realized
|
||||
new_qty = current_qty + delta
|
||||
if new_qty <= 1e-6:
|
||||
state.holdings.pop(ts_code, None)
|
||||
state.cost_basis.pop(ts_code, None)
|
||||
state.opened_dates.pop(ts_code, None)
|
||||
else:
|
||||
state.holdings[ts_code] = new_qty
|
||||
trades_records.append(
|
||||
{
|
||||
"trade_date": trade_date_str,
|
||||
"ts_code": ts_code,
|
||||
"action": "sell",
|
||||
"quantity": float(sell_qty),
|
||||
"price": price,
|
||||
"value": proceeds,
|
||||
"confidence": decision.confidence,
|
||||
"target_weight": decision.target_weight,
|
||||
"realized_pnl": realized,
|
||||
}
|
||||
)
|
||||
|
||||
market_value = 0.0
|
||||
unrealized_pnl = 0.0
|
||||
for ts_code, qty in state.holdings.items():
|
||||
price = price_map.get(ts_code)
|
||||
if price is None:
|
||||
continue
|
||||
market_value += qty * price
|
||||
cost_basis = state.cost_basis.get(ts_code, 0.0)
|
||||
unrealized_pnl += (price - cost_basis) * qty
|
||||
|
||||
nav = state.cash + market_value
|
||||
result.nav_series.append(
|
||||
{
|
||||
"trade_date": trade_date_str,
|
||||
"nav": nav,
|
||||
"cash": state.cash,
|
||||
"market_value": market_value,
|
||||
"realized_pnl": state.realized_pnl,
|
||||
"unrealized_pnl": unrealized_pnl,
|
||||
}
|
||||
)
|
||||
if trades_records:
|
||||
result.trades.extend(trades_records)
|
||||
|
||||
try:
|
||||
self._persist_portfolio(
|
||||
trade_date_str,
|
||||
state,
|
||||
market_value,
|
||||
unrealized_pnl,
|
||||
trades_records,
|
||||
price_map,
|
||||
decisions_map,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.exception("持仓数据写入失败", extra=LOG_EXTRA)
|
||||
|
||||
def _record_investment_candidate(
|
||||
self, context: AgentContext, decision: Decision
|
||||
) -> None:
|
||||
status = _candidate_status(decision.action, decision.requires_review)
|
||||
summary = _extract_summary(decision)
|
||||
if not summary:
|
||||
collected_signals: List[str] = []
|
||||
for dept in decision.department_decisions.values():
|
||||
collected_signals.extend(dept.signals)
|
||||
summary = ";".join(str(sig) for sig in collected_signals[:3])
|
||||
|
||||
metadata = {
|
||||
"target_weight": decision.target_weight,
|
||||
"feasible_actions": [action.value for action in decision.feasible_actions],
|
||||
"department_votes": decision.department_votes,
|
||||
"requires_review": decision.requires_review,
|
||||
"confidence": decision.confidence,
|
||||
}
|
||||
if decision.department_decisions:
|
||||
metadata["departments"] = {
|
||||
code: dept.to_dict()
|
||||
for code, dept in decision.department_decisions.items()
|
||||
}
|
||||
|
||||
with db_session() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO investment_pool
|
||||
(trade_date, ts_code, score, status, rationale, tags, metadata)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
context.trade_date,
|
||||
context.ts_code,
|
||||
float(decision.confidence or 0.0),
|
||||
status,
|
||||
summary or None,
|
||||
json.dumps(_department_tags(decision), ensure_ascii=False),
|
||||
json.dumps(metadata, ensure_ascii=False),
|
||||
),
|
||||
)
|
||||
|
||||
def _persist_portfolio(
|
||||
self,
|
||||
trade_date: str,
|
||||
state: PortfolioState,
|
||||
market_value: float,
|
||||
unrealized_pnl: float,
|
||||
trades: List[Dict[str, Any]],
|
||||
price_map: Dict[str, float],
|
||||
decisions_map: Dict[str, Decision],
|
||||
) -> None:
|
||||
holdings_rows: List[tuple] = []
|
||||
for ts_code, qty in state.holdings.items():
|
||||
price = price_map.get(ts_code)
|
||||
market_val = qty * price if price is not None else None
|
||||
cost_basis = state.cost_basis.get(ts_code, 0.0)
|
||||
unrealized = (price - cost_basis) * qty if price is not None else None
|
||||
decision = decisions_map.get(ts_code)
|
||||
target_weight = decision.target_weight if decision else None
|
||||
metadata = {
|
||||
"last_action": decision.action.value if decision else None,
|
||||
"confidence": decision.confidence if decision else None,
|
||||
}
|
||||
holdings_rows.append(
|
||||
(
|
||||
ts_code,
|
||||
state.opened_dates.get(ts_code, trade_date),
|
||||
None,
|
||||
qty,
|
||||
cost_basis,
|
||||
price,
|
||||
market_val,
|
||||
state.realized_pnl,
|
||||
unrealized,
|
||||
target_weight,
|
||||
"open",
|
||||
None,
|
||||
json.dumps(metadata, ensure_ascii=False),
|
||||
)
|
||||
)
|
||||
|
||||
snapshot_metadata = {
|
||||
"holdings": len(state.holdings),
|
||||
}
|
||||
|
||||
with db_session() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO portfolio_snapshots
|
||||
(trade_date, total_value, cash, invested_value, unrealized_pnl, realized_pnl, net_flow, exposure, notes, metadata)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
trade_date,
|
||||
market_value + state.cash,
|
||||
state.cash,
|
||||
market_value,
|
||||
unrealized_pnl,
|
||||
state.realized_pnl,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
json.dumps(snapshot_metadata, ensure_ascii=False),
|
||||
),
|
||||
)
|
||||
|
||||
conn.execute("DELETE FROM portfolio_positions")
|
||||
if holdings_rows:
|
||||
conn.executemany(
|
||||
"""
|
||||
INSERT INTO portfolio_positions
|
||||
(ts_code, opened_date, closed_date, quantity, cost_price, market_price, market_value, realized_pnl, unrealized_pnl, target_weight, status, notes, metadata)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
holdings_rows,
|
||||
)
|
||||
|
||||
if trades:
|
||||
conn.executemany(
|
||||
"""
|
||||
INSERT INTO portfolio_trades
|
||||
(trade_date, ts_code, action, quantity, price, fee, order_id, source, notes, metadata)
|
||||
VALUES (?, ?, ?, ?, ?, 0, NULL, 'backtest', NULL, ?)
|
||||
""",
|
||||
[
|
||||
(
|
||||
trade["trade_date"],
|
||||
trade["ts_code"],
|
||||
trade["action"],
|
||||
trade["quantity"],
|
||||
trade["price"],
|
||||
json.dumps(trade, ensure_ascii=False),
|
||||
)
|
||||
for trade in trades
|
||||
],
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
decision_callback: Optional[Callable[[str, date, AgentContext, Decision], None]] = None,
|
||||
@ -398,8 +704,8 @@ class BacktestEngine:
|
||||
result = BacktestResult()
|
||||
current = self.cfg.start_date
|
||||
while current <= self.cfg.end_date:
|
||||
decisions = self.simulate_day(current, state, decision_callback)
|
||||
_ = decisions
|
||||
records = self.simulate_day(current, state, decision_callback)
|
||||
self._apply_portfolio_updates(current, state, records, result)
|
||||
current = date.fromordinal(current.toordinal() + 1)
|
||||
return result
|
||||
|
||||
@ -415,9 +721,33 @@ def run_backtest(
|
||||
_ = conn
|
||||
# Implementation should persist bt_nav, bt_trades, and bt_report rows.
|
||||
return result
|
||||
|
||||
|
||||
def _candidate_status(action: AgentAction, requires_review: bool) -> str:
|
||||
mapping = {
|
||||
AgentAction.SELL: "exit",
|
||||
AgentAction.HOLD: "watch",
|
||||
AgentAction.BUY_S: "buy_s",
|
||||
AgentAction.BUY_M: "buy_m",
|
||||
AgentAction.BUY_L: "buy_l",
|
||||
}
|
||||
base = mapping.get(action, "candidate")
|
||||
if requires_review:
|
||||
return f"{base}_review"
|
||||
return base
|
||||
def _extract_summary(decision: Decision) -> str:
|
||||
for dept_decision in decision.department_decisions.values():
|
||||
summary = getattr(dept_decision, "summary", "")
|
||||
if summary:
|
||||
return str(summary)
|
||||
return ""
|
||||
|
||||
|
||||
def _department_tags(decision: Decision) -> List[str]:
|
||||
tags: List[str] = []
|
||||
for code, dept in decision.department_decisions.items():
|
||||
action = getattr(dept, "action", None)
|
||||
if action is None:
|
||||
continue
|
||||
tags.append(f"{code}:{action.value}")
|
||||
return sorted(set(tags))
|
||||
|
||||
@ -362,6 +362,67 @@ SCHEMA_STATEMENTS: Iterable[str] = (
|
||||
reason TEXT,
|
||||
PRIMARY KEY (trade_date, ts_code)
|
||||
);
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS investment_pool (
|
||||
trade_date TEXT,
|
||||
ts_code TEXT,
|
||||
score REAL,
|
||||
status TEXT,
|
||||
rationale TEXT,
|
||||
tags TEXT,
|
||||
metadata TEXT,
|
||||
created_at TEXT DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now')),
|
||||
PRIMARY KEY (trade_date, ts_code)
|
||||
);
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS portfolio_positions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
ts_code TEXT NOT NULL,
|
||||
opened_date TEXT NOT NULL,
|
||||
closed_date TEXT,
|
||||
quantity REAL NOT NULL,
|
||||
cost_price REAL NOT NULL,
|
||||
market_price REAL,
|
||||
market_value REAL,
|
||||
realized_pnl REAL DEFAULT 0,
|
||||
unrealized_pnl REAL DEFAULT 0,
|
||||
target_weight REAL,
|
||||
status TEXT NOT NULL DEFAULT 'open',
|
||||
notes TEXT,
|
||||
metadata TEXT,
|
||||
updated_at TEXT DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
|
||||
);
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS portfolio_trades (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
trade_date TEXT NOT NULL,
|
||||
ts_code TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
quantity REAL NOT NULL,
|
||||
price REAL NOT NULL,
|
||||
fee REAL DEFAULT 0,
|
||||
order_id TEXT,
|
||||
source TEXT,
|
||||
notes TEXT,
|
||||
metadata TEXT
|
||||
);
|
||||
""",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS portfolio_snapshots (
|
||||
trade_date TEXT PRIMARY KEY,
|
||||
total_value REAL,
|
||||
cash REAL,
|
||||
invested_value REAL,
|
||||
unrealized_pnl REAL,
|
||||
realized_pnl REAL,
|
||||
net_flow REAL,
|
||||
exposure REAL,
|
||||
notes TEXT,
|
||||
metadata TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
@ -391,6 +452,10 @@ REQUIRED_TABLES = (
|
||||
"run_log",
|
||||
"agent_utils",
|
||||
"alloc_log",
|
||||
"investment_pool",
|
||||
"portfolio_positions",
|
||||
"portfolio_trades",
|
||||
"portfolio_snapshots",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -47,6 +47,12 @@ from app.utils.config import (
|
||||
)
|
||||
from app.utils.db import db_session
|
||||
from app.utils.logging import get_logger
|
||||
from app.utils.portfolio import (
|
||||
get_latest_snapshot,
|
||||
list_investment_pool,
|
||||
list_positions,
|
||||
list_recent_trades,
|
||||
)
|
||||
|
||||
|
||||
LOGGER = get_logger(__name__)
|
||||
@ -529,7 +535,87 @@ def render_today_plan() -> None:
|
||||
else:
|
||||
st.info("暂无基础代理评分。")
|
||||
|
||||
st.caption("以上内容来源于 agent_utils 表,可通过回测或实时评估自动更新。")
|
||||
st.divider()
|
||||
st.subheader("投资池与仓位概览")
|
||||
|
||||
snapshot = get_latest_snapshot()
|
||||
if snapshot:
|
||||
col_a, col_b, col_c = st.columns(3)
|
||||
if snapshot.total_value is not None:
|
||||
col_a.metric("组合净值", f"{snapshot.total_value:,.2f}")
|
||||
if snapshot.cash is not None:
|
||||
col_b.metric("现金余额", f"{snapshot.cash:,.2f}")
|
||||
if snapshot.invested_value is not None:
|
||||
col_c.metric("持仓市值", f"{snapshot.invested_value:,.2f}")
|
||||
detail_cols = st.columns(4)
|
||||
if snapshot.unrealized_pnl is not None:
|
||||
detail_cols[0].metric("浮盈", f"{snapshot.unrealized_pnl:,.2f}")
|
||||
if snapshot.realized_pnl is not None:
|
||||
detail_cols[1].metric("已实现盈亏", f"{snapshot.realized_pnl:,.2f}")
|
||||
if snapshot.net_flow is not None:
|
||||
detail_cols[2].metric("净流入", f"{snapshot.net_flow:,.2f}")
|
||||
if snapshot.exposure is not None:
|
||||
detail_cols[3].metric("风险敞口", f"{snapshot.exposure:.2%}")
|
||||
if snapshot.notes:
|
||||
st.caption(f"备注:{snapshot.notes}")
|
||||
else:
|
||||
st.info("暂无组合快照,请在执行回测或实盘同步后写入 portfolio_snapshots。")
|
||||
|
||||
candidates = list_investment_pool(trade_date=trade_date)
|
||||
if candidates:
|
||||
candidate_df = pd.DataFrame(
|
||||
[
|
||||
{
|
||||
"交易日": item.trade_date,
|
||||
"代码": item.ts_code,
|
||||
"评分": item.score,
|
||||
"状态": item.status,
|
||||
"标签": "、".join(item.tags) if item.tags else "-",
|
||||
"理由": item.rationale or "",
|
||||
}
|
||||
for item in candidates
|
||||
]
|
||||
)
|
||||
st.write("候选投资池:")
|
||||
st.dataframe(candidate_df, width='stretch', hide_index=True)
|
||||
else:
|
||||
st.caption("候选投资池暂无数据。")
|
||||
|
||||
positions = list_positions(active_only=False)
|
||||
if positions:
|
||||
position_df = pd.DataFrame(
|
||||
[
|
||||
{
|
||||
"ID": pos.id,
|
||||
"代码": pos.ts_code,
|
||||
"开仓日": pos.opened_date,
|
||||
"平仓日": pos.closed_date or "-",
|
||||
"状态": pos.status,
|
||||
"数量": pos.quantity,
|
||||
"成本": pos.cost_price,
|
||||
"现价": pos.market_price,
|
||||
"市值": pos.market_value,
|
||||
"浮盈": pos.unrealized_pnl,
|
||||
"已实现": pos.realized_pnl,
|
||||
"目标权重": pos.target_weight,
|
||||
}
|
||||
for pos in positions
|
||||
]
|
||||
)
|
||||
st.write("组合持仓:")
|
||||
st.dataframe(position_df, width='stretch', hide_index=True)
|
||||
else:
|
||||
st.caption("组合持仓暂无记录。")
|
||||
|
||||
trades = list_recent_trades(limit=20)
|
||||
if trades:
|
||||
trades_df = pd.DataFrame(trades)
|
||||
st.write("近期成交:")
|
||||
st.dataframe(trades_df, width='stretch', hide_index=True)
|
||||
else:
|
||||
st.caption("近期成交暂无记录。")
|
||||
|
||||
st.caption("数据来源:agent_utils、investment_pool、portfolio_positions、portfolio_trades、portfolio_snapshots。")
|
||||
|
||||
|
||||
def render_backtest() -> None:
|
||||
|
||||
236
app/utils/portfolio.py
Normal file
236
app/utils/portfolio.py
Normal file
@ -0,0 +1,236 @@
|
||||
"""Portfolio data access helpers for candidate pool, positions, and PnL tracking."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
from .db import db_session
|
||||
from .logging import get_logger
|
||||
|
||||
LOGGER = get_logger(__name__)
|
||||
LOG_EXTRA = {"stage": "portfolio"}
|
||||
|
||||
|
||||
def _loads_or_default(payload: Optional[str], default: Any) -> Any:
|
||||
if not payload:
|
||||
return default
|
||||
try:
|
||||
return json.loads(payload)
|
||||
except json.JSONDecodeError:
|
||||
LOGGER.debug("JSON 解析失败 payload=%s", payload, extra=LOG_EXTRA)
|
||||
return default
|
||||
|
||||
|
||||
@dataclass
|
||||
class InvestmentCandidate:
|
||||
trade_date: str
|
||||
ts_code: str
|
||||
score: Optional[float]
|
||||
status: str
|
||||
rationale: Optional[str]
|
||||
tags: List[str]
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
def list_investment_pool(
|
||||
*,
|
||||
trade_date: Optional[str] = None,
|
||||
status: Optional[Iterable[str]] = None,
|
||||
limit: int = 200,
|
||||
) -> List[InvestmentCandidate]:
|
||||
"""Return investment candidates for the given trade date (latest if None)."""
|
||||
|
||||
query = [
|
||||
"SELECT trade_date, ts_code, score, status, rationale, tags, metadata",
|
||||
"FROM investment_pool",
|
||||
]
|
||||
params: List[Any] = []
|
||||
|
||||
if trade_date:
|
||||
query.append("WHERE trade_date = ?")
|
||||
params.append(trade_date)
|
||||
else:
|
||||
query.append(
|
||||
"WHERE trade_date = (SELECT MAX(trade_date) FROM investment_pool)"
|
||||
)
|
||||
|
||||
if status:
|
||||
placeholders = ", ".join("?" for _ in status)
|
||||
query.append(f"AND status IN ({placeholders})")
|
||||
params.extend(list(status))
|
||||
|
||||
query.append("ORDER BY score DESC NULLS LAST, ts_code")
|
||||
query.append("LIMIT ?")
|
||||
params.append(int(limit))
|
||||
|
||||
sql = "\n".join(query)
|
||||
with db_session(read_only=True) as conn:
|
||||
try:
|
||||
rows = conn.execute(sql, params).fetchall()
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.exception("查询 investment_pool 失败", extra=LOG_EXTRA)
|
||||
return []
|
||||
|
||||
candidates: List[InvestmentCandidate] = []
|
||||
for row in rows:
|
||||
candidates.append(
|
||||
InvestmentCandidate(
|
||||
trade_date=row["trade_date"],
|
||||
ts_code=row["ts_code"],
|
||||
score=row["score"],
|
||||
status=row["status"] or "unknown",
|
||||
rationale=row["rationale"],
|
||||
tags=list(_loads_or_default(row["tags"], [])),
|
||||
metadata=dict(_loads_or_default(row["metadata"], {})),
|
||||
)
|
||||
)
|
||||
return candidates
|
||||
|
||||
|
||||
@dataclass
|
||||
class PortfolioPosition:
|
||||
id: int
|
||||
ts_code: str
|
||||
opened_date: str
|
||||
closed_date: Optional[str]
|
||||
quantity: float
|
||||
cost_price: float
|
||||
market_price: Optional[float]
|
||||
market_value: Optional[float]
|
||||
realized_pnl: float
|
||||
unrealized_pnl: float
|
||||
target_weight: Optional[float]
|
||||
status: str
|
||||
notes: Optional[str]
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
def list_positions(*, active_only: bool = True) -> List[PortfolioPosition]:
|
||||
"""Return current portfolio positions."""
|
||||
|
||||
sql = """
|
||||
SELECT id, ts_code, opened_date, closed_date, quantity, cost_price,
|
||||
market_price, market_value, realized_pnl, unrealized_pnl,
|
||||
target_weight, status, notes, metadata
|
||||
FROM portfolio_positions
|
||||
{where_clause}
|
||||
ORDER BY status DESC, opened_date DESC
|
||||
"""
|
||||
|
||||
where_clause = ""
|
||||
params: List[Any] = []
|
||||
if active_only:
|
||||
where_clause = "WHERE status = 'open'"
|
||||
|
||||
sql = sql.format(where_clause=where_clause)
|
||||
with db_session(read_only=True) as conn:
|
||||
try:
|
||||
rows = conn.execute(sql, params).fetchall()
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.exception("查询 portfolio_positions 失败", extra=LOG_EXTRA)
|
||||
return []
|
||||
|
||||
positions: List[PortfolioPosition] = []
|
||||
for row in rows:
|
||||
positions.append(
|
||||
PortfolioPosition(
|
||||
id=row["id"],
|
||||
ts_code=row["ts_code"],
|
||||
opened_date=row["opened_date"],
|
||||
closed_date=row["closed_date"],
|
||||
quantity=float(row["quantity"]),
|
||||
cost_price=float(row["cost_price"]),
|
||||
market_price=row["market_price"],
|
||||
market_value=row["market_value"],
|
||||
realized_pnl=row["realized_pnl"],
|
||||
unrealized_pnl=row["unrealized_pnl"],
|
||||
target_weight=row["target_weight"],
|
||||
status=row["status"],
|
||||
notes=row["notes"],
|
||||
metadata=dict(_loads_or_default(row["metadata"], {})),
|
||||
)
|
||||
)
|
||||
return positions
|
||||
|
||||
|
||||
@dataclass
|
||||
class PortfolioSnapshot:
|
||||
trade_date: str
|
||||
total_value: Optional[float]
|
||||
cash: Optional[float]
|
||||
invested_value: Optional[float]
|
||||
unrealized_pnl: Optional[float]
|
||||
realized_pnl: Optional[float]
|
||||
net_flow: Optional[float]
|
||||
exposure: Optional[float]
|
||||
notes: Optional[str]
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
def get_latest_snapshot() -> Optional[PortfolioSnapshot]:
|
||||
"""Fetch the most recent portfolio snapshot."""
|
||||
|
||||
sql = """
|
||||
SELECT trade_date, total_value, cash, invested_value, unrealized_pnl,
|
||||
realized_pnl, net_flow, exposure, notes, metadata
|
||||
FROM portfolio_snapshots
|
||||
ORDER BY trade_date DESC
|
||||
LIMIT 1
|
||||
"""
|
||||
with db_session(read_only=True) as conn:
|
||||
try:
|
||||
row = conn.execute(sql).fetchone()
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.exception("查询 portfolio_snapshots 失败", extra=LOG_EXTRA)
|
||||
return None
|
||||
|
||||
if not row:
|
||||
return None
|
||||
return PortfolioSnapshot(
|
||||
trade_date=row["trade_date"],
|
||||
total_value=row["total_value"],
|
||||
cash=row["cash"],
|
||||
invested_value=row["invested_value"],
|
||||
unrealized_pnl=row["unrealized_pnl"],
|
||||
realized_pnl=row["realized_pnl"],
|
||||
net_flow=row["net_flow"],
|
||||
exposure=row["exposure"],
|
||||
notes=row["notes"],
|
||||
metadata=dict(_loads_or_default(row["metadata"], {})),
|
||||
)
|
||||
|
||||
|
||||
def list_recent_trades(limit: int = 50) -> List[Dict[str, Any]]:
|
||||
"""Return recent trades for monitoring purposes."""
|
||||
|
||||
sql = """
|
||||
SELECT trade_date, ts_code, action, quantity, price, fee, order_id, source, notes, metadata
|
||||
FROM portfolio_trades
|
||||
ORDER BY trade_date DESC, id DESC
|
||||
LIMIT ?
|
||||
"""
|
||||
with db_session(read_only=True) as conn:
|
||||
try:
|
||||
rows = conn.execute(sql, (int(limit),)).fetchall()
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.exception("查询 portfolio_trades 失败", extra=LOG_EXTRA)
|
||||
return []
|
||||
|
||||
trades: List[Dict[str, Any]] = []
|
||||
for row in rows:
|
||||
trades.append(
|
||||
{
|
||||
"trade_date": row["trade_date"],
|
||||
"ts_code": row["ts_code"],
|
||||
"action": row["action"],
|
||||
"quantity": row["quantity"],
|
||||
"price": row["price"],
|
||||
"fee": row["fee"],
|
||||
"order_id": row["order_id"],
|
||||
"source": row["source"],
|
||||
"notes": row["notes"],
|
||||
"metadata": _loads_or_default(row["metadata"], {}),
|
||||
}
|
||||
)
|
||||
return trades
|
||||
@ -34,3 +34,10 @@
|
||||
- `agent_utils` 表新增 `_telemetry` 与 `_department_telemetry` JSON 字段(存于 `utils` 列内部),记录每个部门的 provider、模型、温度、回合数、工具调用列表与 token 统计,可在 Streamlit “部门意见”详情页展开查看。
|
||||
- `app/data/logs/agent_*.log` 会追加 `telemetry` 行,保存每轮函数调用的摘要,方便离线分析提示版本与 LLM 配置对决策的影响。
|
||||
- Streamlit 侧边栏监听 `llm.metrics` 的实时事件,并以 ~0.75 秒节流频率刷新“系统监控”,既保证日志到达后快速更新,也避免刷屏造成 UI 闪烁。
|
||||
- 新增投资管理数据层:SQLite 中创建 `investment_pool`、`portfolio_positions`、`portfolio_trades`、`portfolio_snapshots` 四张表;`app/utils/portfolio.py` 提供访问接口,今日计划页可实时展示候选池、持仓与成交。
|
||||
- 回测引擎 `record_agent_state()` 现同步写入 `investment_pool`,将每日全局决策的置信度、部门标签与目标权重落库,作为后续提示参数调优与候选池管理的基础数据。
|
||||
|
||||
## 下一阶段路线图
|
||||
- 将 `BacktestEngine` 封装为 `DecisionEnv`,让一次策略配置跑完整个回测周期并输出奖励、约束违例等指标。
|
||||
- 接入 Bandit/贝叶斯优化,对 Prompt 版本、部门权重、温度范围做离线搜索,利用新增的 snapshot/positions 数据衡量风险与收益。
|
||||
- 构建持仓/成交写入流程(回测与实时),确保 RL 训练能复原资金曲线、资金占用与调仓成本。
|
||||
|
||||
Loading…
Reference in New Issue
Block a user