add comprehensive logging to decision workflow and risk assessment
parent 1ca2f2be19
commit 4b68d84b3c
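Every file in this change follows the same logging convention: a module-level LOGGER from app.utils.logging.get_logger plus a LOG_EXTRA = {"stage": ...} dict passed as extra= on every call. The helper itself is not part of this diff, so the snippet below is only a minimal sketch, assuming get_logger wraps the standard logging module; it shows why the extra mapping works — keys passed via extra= become attributes on the LogRecord, so a formatter can reference %(stage)s.

```python
# Hypothetical sketch of get_logger()/LOG_EXTRA; app.utils.logging is not shown in this
# commit, so the formatter and handler wiring here are assumptions for illustration.
import logging


def get_logger(name: str) -> logging.Logger:
    logger = logging.getLogger(name)
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter(
                "%(asctime)s %(levelname)s [%(stage)s] %(name)s: %(message)s",
                defaults={"stage": "-"},  # Python 3.10+: fills records logged without extra=
            )
        )
        logger.addHandler(handler)
        logger.setLevel(logging.DEBUG)
    return logger


LOG_EXTRA = {"stage": "decision_workflow"}
LOGGER = get_logger(__name__)
LOGGER.debug("Running decision workflow ts_code=%s", "000001.SZ", extra=LOG_EXTRA)
```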
@@ -4,8 +4,13 @@ from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional

from app.utils.logging import get_logger

from .base import AgentAction

LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "decision_belief"}


@dataclass
class BeliefRevisionResult:
@@ -46,12 +51,21 @@ def revise_beliefs(belief_updates: Dict[str, "BeliefUpdate"], default_action: Ag
        "votes": {action.value: count for action, count in action_votes.items()},
        "reasons": reasons,
    }
    return BeliefRevisionResult(
    result = BeliefRevisionResult(
        consensus_action=consensus_action,
        consensus_confidence=consensus_confidence,
        conflicts=conflicts,
        notes=notes,
    )
    LOGGER.debug(
        "Belief revision complete consensus=%s confidence=%.3f conflicts=%s vote_counts=%s",
        result.consensus_action.value if result.consensus_action else None,
        result.consensus_confidence,
        result.conflicts,
        notes["votes"],
        extra=LOG_EXTRA,
    )
    return result


# avoid circular import typing
@@ -5,6 +5,8 @@ from dataclasses import dataclass, field
from math import log
from typing import Dict, Iterable, List, Mapping, Optional, Tuple

from app.utils.logging import get_logger

from .base import Agent, AgentAction, AgentContext, UtilityMatrix
from .departments import DepartmentContext, DepartmentDecision, DepartmentManager
from .registry import weight_map
@@ -20,6 +22,10 @@ from .protocols import (
)


LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "decision_workflow"}


ACTIONS: Tuple[AgentAction, ...] = (
    AgentAction.SELL,
    AgentAction.HOLD,
@@ -188,9 +194,30 @@ class DecisionWorkflow:
        self.norm_weights: Dict[str, float] = {}
        self.filtered_utilities: Dict[AgentAction, Dict[str, float]] = {}
        self.belief_revision: Optional[BeliefRevisionResult] = None
        LOGGER.debug(
            "Initializing decision workflow ts_code=%s trade_date=%s method=%s agents=%s departments=%s",
            context.ts_code,
            context.trade_date,
            method,
            len(self.agent_list),
            bool(self.department_manager),
            extra=LOG_EXTRA,
        )

    def run(self) -> Decision:
        LOGGER.debug(
            "Running decision workflow ts_code=%s method=%s feasible=%s",
            self.context.ts_code,
            self.method,
            [action.value for action in self.feasible_actions],
            extra=LOG_EXTRA,
        )
        if not self.feasible_actions:
            LOGGER.warning(
                "No feasible actions, falling back to HOLD ts_code=%s",
                self.context.ts_code,
                extra=LOG_EXTRA,
            )
            return Decision(
                action=AgentAction.HOLD,
                confidence=0.0,
@@ -201,6 +228,13 @@ class DecisionWorkflow:

        self._evaluate_departments()
        action, confidence = self._select_action()
        LOGGER.debug(
            "Preliminary action selected ts_code=%s action=%s confidence=%.3f",
            self.context.ts_code,
            action.value,
            confidence,
            extra=LOG_EXTRA,
        )
        risk_assessment = self._apply_risk(action)
        exec_action = self._finalize_execution(action, risk_assessment)
        self._finalize_conflicts(exec_action)
@@ -210,7 +244,7 @@
            self.department_votes,
        )

        return Decision(
        decision = Decision(
            action=action,
            confidence=confidence,
            target_weight=target_weight_for_action(action),
@@ -224,6 +258,16 @@
            belief_updates=self.belief_updates,
            belief_revision=self.belief_revision,
        )
        LOGGER.info(
            "Decision complete ts_code=%s action=%s confidence=%.3f review=%s risk_status=%s",
            self.context.ts_code,
            decision.action.value,
            decision.confidence,
            decision.requires_review,
            risk_assessment.status,
            extra=LOG_EXTRA,
        )
        return decision

    def _evaluate_departments(self) -> None:
        if not self.department_manager:
@@ -236,7 +280,19 @@
            market_snapshot=dict(getattr(self.context, "market_snapshot", {}) or {}),
            raw=dict(getattr(self.context, "raw", {}) or {}),
        )
        LOGGER.debug(
            "Starting department evaluation ts_code=%s departments=%s",
            self.context.ts_code,
            list(self.department_manager.agents.keys()),
            extra=LOG_EXTRA,
        )
        self.department_decisions = self.department_manager.evaluate(dept_context)
        LOGGER.debug(
            "Department evaluation complete ts_code=%s decisions=%s",
            self.context.ts_code,
            list(self.department_decisions.keys()),
            extra=LOG_EXTRA,
        )
        if self.department_decisions:
            self.department_round = self.host.start_round(
                self.host_trace,
@@ -285,11 +341,32 @@
        )

        if self.method == "vote":
            return vote(self.filtered_utilities, self.norm_weights)
            action, confidence = vote(self.filtered_utilities, self.norm_weights)
            LOGGER.debug(
                "Using voting mechanism ts_code=%s action=%s confidence=%.3f",
                self.context.ts_code,
                action.value,
                confidence,
                extra=LOG_EXTRA,
            )
            return action, confidence

        action, confidence = nash_bargain(self.filtered_utilities, self.norm_weights, hold_scores)
        if action not in self.feasible_actions:
            return vote(self.filtered_utilities, self.norm_weights)
            LOGGER.debug(
                "Nash solution infeasible, falling back to vote ts_code=%s invalid_action=%s",
                self.context.ts_code,
                action.value,
                extra=LOG_EXTRA,
            )
            action, confidence = vote(self.filtered_utilities, self.norm_weights)
        LOGGER.debug(
            "Nash bargaining complete ts_code=%s action=%s confidence=%.3f",
            self.context.ts_code,
            action.value,
            confidence,
            extra=LOG_EXTRA,
        )
        return action, confidence

    def _apply_risk(self, action: AgentAction) -> RiskAssessment:
@@ -307,6 +384,17 @@
            self.department_round.outcome = action.value
            self.host.finalize_round(self.department_round)

        LOGGER.debug(
            "Risk assessment result ts_code=%s action=%s status=%s reason=%s conflict=%s votes=%s",
            self.context.ts_code,
            action.value,
            assessment.status,
            assessment.reason,
            conflict_flag,
            dict(self.department_votes),
            extra=LOG_EXTRA,
        )

        if assessment.status != "ok":
            self.risk_round = self.host.ensure_round(
                self.host_trace,
@@ -398,12 +486,28 @@
        )
        self.host.finalize_round(self.execution_round)
        self.execution_round.notes.setdefault("target_weight", exec_weight)
        LOGGER.info(
            "Execution stage conclusion ts_code=%s final_action=%s original=%s status=%s target_weight=%.4f review=%s",
            self.context.ts_code,
            exec_action.value,
            action.value,
            exec_status,
            exec_weight,
            requires_review,
            extra=LOG_EXTRA,
        )
        return exec_action

    def _finalize_conflicts(self, exec_action: AgentAction) -> None:
        self.host.close(self.host_trace)
        self.belief_revision = revise_beliefs(self.belief_updates, exec_action)
        if self.belief_revision.conflicts:
            LOGGER.warning(
                "Belief conflicts detected ts_code=%s conflicts=%s",
                self.context.ts_code,
                self.belief_revision.conflicts,
                extra=LOG_EXTRA,
            )
            conflict_round = self.host.ensure_round(
                self.host_trace,
                agenda="conflict_resolution",
@@ -436,15 +540,35 @@ def decide(
    department_manager: Optional[DepartmentManager] = None,
    department_context: Optional[DepartmentContext] = None,
) -> Decision:
    agent_list = list(agents)
    LOGGER.debug(
        "Entering multi-agent decision ts_code=%s trade_date=%s agents=%s method=%s",
        context.ts_code,
        context.trade_date,
        len(agent_list),
        method,
        extra=LOG_EXTRA,
    )
    workflow = DecisionWorkflow(
        context,
        agents,
        agent_list,
        weights,
        method,
        department_manager,
        department_context,
    )
    return workflow.run()
    decision = workflow.run()
    LOGGER.info(
        "Multi-agent decision complete ts_code=%s trade_date=%s action=%s confidence=%.3f review=%s method=%s",
        context.ts_code,
        context.trade_date,
        decision.action.value,
        decision.confidence,
        decision.requires_review,
        method,
        extra=LOG_EXTRA,
    )
    return decision


def _department_scores(decision: DepartmentDecision) -> Dict[AgentAction, float]:
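Because every record carries a stage attribute, downstream handlers can slice the logs per pipeline stage without touching the call sites above. A small, optional filter using only the standard library might look like the sketch below; the handler and format details are assumptions, not part of this commit.

```python
# Assumed, optional helper: keep only records tagged with one stage via extra=LOG_EXTRA.
import logging


class StageFilter(logging.Filter):
    def __init__(self, stage: str) -> None:
        super().__init__()
        self.stage = stage

    def filter(self, record: logging.LogRecord) -> bool:
        # Records logged without extra= have no 'stage' attribute and are dropped too.
        return getattr(record, "stage", None) == self.stage


handler = logging.StreamHandler()
handler.addFilter(StageFilter("decision_workflow"))
handler.setFormatter(logging.Formatter("%(levelname)s %(name)s: %(message)s"))

root = logging.getLogger()
root.addHandler(handler)
root.setLevel(logging.DEBUG)

logging.getLogger("demo").debug(
    "Entering multi-agent decision ts_code=%s", "000001.SZ", extra={"stage": "decision_workflow"}
)
```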
@@ -1,8 +1,13 @@
"""Risk agent acts as leader with veto rights."""
from __future__ import annotations

from app.utils.logging import get_logger

from .base import Agent, AgentAction, AgentContext

LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "decision_risk"}


class RiskRecommendation:
    """Represents structured recommendation from the risk agent."""
@@ -68,21 +73,44 @@ class RiskAgent(Agent):
        features = dict(context.features or {})
        risk_penalty = float(features.get("risk_penalty") or 0.0)

        def finalize(
            recommendation: RiskRecommendation,
            trigger: str,
        ) -> RiskRecommendation:
            LOGGER.debug(
                "Risk agent assessment ts_code=%s action=%s status=%s reason=%s trigger=%s penalty=%.3f conflict=%s",
                context.ts_code,
                decision_action.value,
                recommendation.status,
                recommendation.reason,
                trigger,
                risk_penalty,
                conflict_flag,
                extra=LOG_EXTRA,
            )
            return recommendation

        if bool(features.get("is_suspended")):
            return RiskRecommendation(
                status="blocked",
                reason="suspended",
                recommended_action=AgentAction.HOLD,
                notes={"trigger": "is_suspended"},
            return finalize(
                RiskRecommendation(
                    status="blocked",
                    reason="suspended",
                    recommended_action=AgentAction.HOLD,
                    notes={"trigger": "is_suspended"},
                ),
                "is_suspended",
            )

        if bool(features.get("is_blacklisted")):
            fallback = AgentAction.SELL if decision_action is AgentAction.SELL else AgentAction.HOLD
            return RiskRecommendation(
                status="blocked",
                reason="blacklist",
                recommended_action=fallback,
                notes={"trigger": "is_blacklisted"},
            return finalize(
                RiskRecommendation(
                    status="blocked",
                    reason="blacklist",
                    recommended_action=fallback,
                    notes={"trigger": "is_blacklisted"},
                ),
                "is_blacklisted",
            )

        if bool(features.get("limit_up")) and decision_action in {
@@ -90,22 +118,28 @@ class RiskAgent(Agent):
            AgentAction.BUY_M,
            AgentAction.BUY_L,
        }:
            return RiskRecommendation(
                status="blocked",
                reason="limit_up",
                recommended_action=AgentAction.HOLD,
                notes={"trigger": "limit_up"},
            return finalize(
                RiskRecommendation(
                    status="blocked",
                    reason="limit_up",
                    recommended_action=AgentAction.HOLD,
                    notes={"trigger": "limit_up"},
                ),
                "limit_up",
            )

        if bool(features.get("position_limit")) and decision_action in {
            AgentAction.BUY_M,
            AgentAction.BUY_L,
        }:
            return RiskRecommendation(
                status="pending_review",
                reason="position_limit",
                recommended_action=AgentAction.BUY_S,
                notes={"trigger": "position_limit"},
            return finalize(
                RiskRecommendation(
                    status="pending_review",
                    reason="position_limit",
                    recommended_action=AgentAction.BUY_S,
                    notes={"trigger": "position_limit"},
                ),
                "position_limit",
            )

        if risk_penalty >= 0.9 and decision_action in {
@@ -113,28 +147,40 @@ class RiskAgent(Agent):
            AgentAction.BUY_M,
            AgentAction.BUY_L,
        }:
            return RiskRecommendation(
                status="blocked",
                reason="risk_penalty_extreme",
                recommended_action=AgentAction.HOLD,
                notes={"risk_penalty": risk_penalty},
            return finalize(
                RiskRecommendation(
                    status="blocked",
                    reason="risk_penalty_extreme",
                    recommended_action=AgentAction.HOLD,
                    notes={"risk_penalty": risk_penalty},
                ),
                "risk_penalty_extreme",
            )
        if risk_penalty >= 0.7 and decision_action in {
            AgentAction.BUY_S,
            AgentAction.BUY_M,
            AgentAction.BUY_L,
        }:
            return RiskRecommendation(
                status="pending_review",
                reason="risk_penalty_high",
                recommended_action=AgentAction.HOLD,
                notes={"risk_penalty": risk_penalty},
            return finalize(
                RiskRecommendation(
                    status="pending_review",
                    reason="risk_penalty_high",
                    recommended_action=AgentAction.HOLD,
                    notes={"risk_penalty": risk_penalty},
                ),
                "risk_penalty_high",
            )

        if conflict_flag:
            return RiskRecommendation(
                status="pending_review",
                reason="conflict_threshold",
            return finalize(
                RiskRecommendation(
                    status="pending_review",
                    reason="conflict_threshold",
                ),
                "conflict_threshold",
            )

        return RiskRecommendation(status="ok", reason="clear")
        return finalize(
            RiskRecommendation(status="ok", reason="clear"),
            "clear",
        )
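The refactor above threads every early return in the risk check through a nested finalize() helper, so each veto path is logged exactly once together with its trigger. Below is a stripped-down, standalone illustration of the same pattern; the Recommendation class and assess() function are stand-ins, not the repository's types.

```python
# Toy version of the finalize() closure pattern; all names here are hypothetical.
import logging
from dataclasses import dataclass

logging.basicConfig(level=logging.DEBUG)
LOGGER = logging.getLogger("risk_demo")


@dataclass
class Recommendation:
    status: str
    reason: str


def assess(features: dict) -> Recommendation:
    def finalize(rec: Recommendation, trigger: str) -> Recommendation:
        # Single choke point: every branch below logs its outcome and trigger here.
        LOGGER.debug("risk check status=%s reason=%s trigger=%s", rec.status, rec.reason, trigger)
        return rec

    if features.get("is_suspended"):
        return finalize(Recommendation("blocked", "suspended"), "is_suspended")
    if float(features.get("risk_penalty", 0.0)) >= 0.9:
        return finalize(Recommendation("blocked", "risk_penalty_extreme"), "risk_penalty_extreme")
    return finalize(Recommendation("ok", "clear"), "clear")


print(assess({"risk_penalty": 0.95}).reason)  # -> risk_penalty_extreme
```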
@@ -2,14 +2,17 @@
from __future__ import annotations

import json
from datetime import datetime
from typing import Any, Dict, Optional

from app.utils.db import db_session
from app.utils.logging import get_logger

LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "data_ingest"}


class JobLogger:
    """Job logger."""
    """Job logger that records fetch-job runs in the database."""

    def __init__(self, job_type: str) -> None:
        """Initialize the job logger.
@@ -28,17 +31,36 @@ class JobLogger:
                INSERT INTO fetch_jobs (job_type, status, created_at, updated_at)
                VALUES (?, 'running', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
                """,
                (self.job_type,)
                (self.job_type,),
            )
            self.job_id = cursor.lastrowid
            session.commit()
        LOGGER.info(
            "Fetch job started job_type=%s job_id=%s",
            self.job_type,
            self.job_id,
            extra=LOG_EXTRA,
        )
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Finish the job record."""
        if exc_val:
            LOGGER.exception(
                "Fetch job failed job_type=%s job_id=%s err=%s",
                self.job_type,
                self.job_id,
                exc_val,
                extra=LOG_EXTRA,
            )
            self.update_status("failed", str(exc_val))
        else:
            LOGGER.info(
                "Fetch job completed job_type=%s job_id=%s",
                self.job_type,
                self.job_id,
                extra=LOG_EXTRA,
            )
            self.update_status("success")

    def update_status(self, status: str, error_msg: Optional[str] = None) -> None:
@@ -49,6 +71,7 @@ class JobLogger:
            error_msg: error message (if any)
        """
        if not self.job_id:
            LOGGER.debug("Ignoring invalid job status update job_type=%s status=%s", self.job_type, status, extra=LOG_EXTRA)
            return

        with db_session() as session:
@@ -60,9 +83,17 @@ class JobLogger:
                    updated_at = CURRENT_TIMESTAMP
                WHERE id = ?
                """,
                (status, error_msg, self.job_id)
                (status, error_msg, self.job_id),
            )
            session.commit()
        LOGGER.debug(
            "Updated job status job_type=%s job_id=%s status=%s error=%s",
            self.job_type,
            self.job_id,
            status,
            error_msg,
            extra=LOG_EXTRA,
        )

    def update_metadata(self, metadata: Dict[str, Any]) -> None:
        """Update job metadata.
@@ -71,6 +102,11 @@ class JobLogger:
            metadata: metadata dictionary
        """
        if not self.job_id:
            LOGGER.debug(
                "Ignoring metadata update (job not initialized) job_type=%s",
                self.job_type,
                extra=LOG_EXTRA,
            )
            return

        with db_session() as session:
@@ -80,6 +116,13 @@ class JobLogger:
                SET metadata = ?
                WHERE id = ?
                """,
                (json.dumps(metadata), self.job_id)
                (json.dumps(metadata), self.job_id),
            )
            session.commit()
        LOGGER.debug(
            "Recorded job metadata job_type=%s job_id=%s keys=%s",
            self.job_type,
            self.job_id,
            sorted(metadata.keys()),
            extra=LOG_EXTRA,
        )
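JobLogger's new logging hooks rely on the context-manager protocol: __exit__ receives the in-flight exception (if any), so the failure path can be logged and the row marked failed before the exception propagates. The following self-contained sketch shows just that mechanic; JobDemo is a stand-in without the database writes.

```python
# Standalone illustration of the __enter__/__exit__ logging pattern; no database involved.
import logging
from typing import Any, Optional

logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger("job_demo")


class JobDemo:
    def __init__(self, job_type: str) -> None:
        self.job_type = job_type
        self.status: Optional[str] = None

    def __enter__(self) -> "JobDemo":
        LOGGER.info("job started job_type=%s", self.job_type)
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        if exc_val:
            LOGGER.exception("job failed job_type=%s err=%s", self.job_type, exc_val)
            self.status = "failed"  # returning None re-raises the original exception
        else:
            LOGGER.info("job completed job_type=%s", self.job_type)
            self.status = "success"


with JobDemo("daily_quotes") as job:
    pass  # replace with fetch/persist work; a raise here is logged and marked as failed
```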
@@ -7,6 +7,10 @@ from typing import Dict, Iterable, List, Mapping, Sequence, Tuple
import numpy as np

from app.backtest.decision_env import DecisionEnv
from app.utils.logging import get_logger

LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "decision_env"}


@dataclass
@@ -26,6 +30,13 @@ class DecisionEnvAdapter:
        else:
            self._keys = list(self.observation_keys)
        self._last_reset_obs = None
        LOGGER.debug(
            "Initialized DecisionEnvAdapter obs_dim=%s action_dim=%s keys=%s",
            len(self._keys),
            self.env.action_dim,
            self._keys,
            extra=LOG_EXTRA,
        )

    @property
    def action_dim(self) -> int:
@@ -38,12 +49,24 @@
    def reset(self) -> Tuple[np.ndarray, Dict[str, float]]:
        raw = self.env.reset()
        self._last_reset_obs = raw
        LOGGER.debug(
            "Environment reset complete episode=%s",
            raw.get("episode"),
            extra=LOG_EXTRA,
        )
        return self._to_array(raw), raw

    def step(
        self, action: Sequence[float]
    ) -> Tuple[np.ndarray, float, bool, Mapping[str, object], Mapping[str, float]]:
        obs_dict, reward, done, info = self.env.step(action)
        LOGGER.debug(
            "Environment step executed action=%s reward=%.4f done=%s",
            [round(float(a), 4) for a in action],
            reward,
            done,
            extra=LOG_EXTRA,
        )
        return self._to_array(obs_dict), reward, done, info, obs_dict

    def _to_array(self, payload: Mapping[str, float]) -> np.ndarray:
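The adapter's logging assumes observations arrive as a dict keyed by feature name; _to_array, whose body falls outside this hunk, presumably flattens that dict into a fixed-order vector. A minimal sketch of such a conversion follows; the zero-fill behaviour is an assumption, not the repository's actual implementation.

```python
# Assumed dict-to-vector conversion; the real _to_array body is not part of this diff.
from typing import Mapping, Sequence

import numpy as np


def to_array(payload: Mapping[str, float], keys: Sequence[str]) -> np.ndarray:
    # Missing keys default to 0.0 so the observation vector keeps a fixed layout.
    return np.asarray([float(payload.get(key, 0.0)) for key in keys], dtype=np.float32)


print(to_array({"close": 10.2, "turnover": 1.3}, ["close", "turnover", "pe"]))
# -> [10.2  1.3  0. ]
```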
@@ -10,8 +10,13 @@ import torch
from torch import nn
from torch.distributions import Beta

from app.utils.logging import get_logger

from .adapters import DecisionEnvAdapter

LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "rl_ppo"}


def _init_layer(layer: nn.Module, std: float = 1.0) -> nn.Module:
    if isinstance(layer, nn.Linear):
@@ -168,6 +173,15 @@ class PPOTrainer:
        if config.seed is not None:
            torch.manual_seed(config.seed)
            np.random.seed(config.seed)
        LOGGER.info(
            "Initialized PPOTrainer obs_dim=%s action_dim=%s total_timesteps=%s rollout=%s device=%s",
            obs_dim,
            action_dim,
            config.total_timesteps,
            config.rollout_steps,
            config.device,
            extra=LOG_EXTRA,
        )

    def train(self) -> TrainingSummary:
        cfg = self.config
@@ -180,6 +194,14 @@ class PPOTrainer:
        diagnostics: List[Dict[str, float]] = []
        current_return = 0.0
        current_length = 0
        LOGGER.info(
            "Starting PPO training total_timesteps=%s rollout_steps=%s epochs=%s minibatch=%s",
            cfg.total_timesteps,
            cfg.rollout_steps,
            cfg.epochs,
            cfg.minibatch_size,
            extra=LOG_EXTRA,
        )

        while timesteps < cfg.total_timesteps:
            rollout.reset()
@@ -203,6 +225,14 @@ class PPOTrainer:
                if done:
                    episode_rewards.append(current_return)
                    episode_lengths.append(current_length)
                    LOGGER.info(
                        "Episode complete reward=%.4f length=%s episodes=%s timesteps=%s",
                        episode_rewards[-1],
                        episode_lengths[-1],
                        len(episode_rewards),
                        timesteps,
                        extra=LOG_EXTRA,
                    )
                    current_return = 0.0
                    current_length = 0
                    next_obs_array, _ = self.adapter.reset()
@@ -216,7 +246,17 @@ class PPOTrainer:
            with torch.no_grad():
                next_value = self.critic(obs.unsqueeze(0)).squeeze(0).item()
            rollout.finish(last_value=next_value, gamma=cfg.gamma, gae_lambda=cfg.gae_lambda)
            LOGGER.debug(
                "Rollout collection complete batch_size=%s timesteps=%s remaining=%s",
                rollout._pos,
                timesteps,
                cfg.total_timesteps - timesteps,
                extra=LOG_EXTRA,
            )

            last_policy_loss = None
            last_value_loss = None
            last_entropy = None
            for _ in range(cfg.epochs):
                for (mb_obs, mb_actions, mb_log_probs, mb_adv, mb_returns, _) in rollout.get_minibatches(
                    cfg.minibatch_size
@@ -241,6 +281,9 @@ class PPOTrainer:
                    value_loss.backward()
                    nn.utils.clip_grad_norm_(self.critic.parameters(), cfg.max_grad_norm)
                    self.value_optimizer.step()
                    last_policy_loss = float(policy_loss.detach().cpu())
                    last_value_loss = float(value_loss.detach().cpu())
                    last_entropy = float(entropy.mean().detach().cpu())

            diagnostics.append(
                {
@@ -249,13 +292,30 @@ class PPOTrainer:
                    "entropy": float(entropy.mean().detach().cpu()),
                }
            )
            LOGGER.info(
                "Optimization round complete timesteps=%s/%s policy_loss=%.4f value_loss=%.4f entropy=%.4f",
                timesteps,
                cfg.total_timesteps,
                last_policy_loss if last_policy_loss is not None else 0.0,
                last_value_loss if last_value_loss is not None else 0.0,
                last_entropy if last_entropy is not None else 0.0,
                extra=LOG_EXTRA,
            )

        return TrainingSummary(
        summary = TrainingSummary(
            timesteps=timesteps,
            episode_rewards=episode_rewards,
            episode_lengths=episode_lengths,
            diagnostics=diagnostics,
        )
        LOGGER.info(
            "PPO training finished timesteps=%s episodes=%s mean_reward=%.4f",
            summary.timesteps,
            len(summary.episode_rewards),
            float(np.mean(summary.episode_rewards)) if summary.episode_rewards else 0.0,
            extra=LOG_EXTRA,
        )
        return summary


def train_ppo(adapter: DecisionEnvAdapter, config: PPOConfig) -> TrainingSummary:
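The trainer now keeps per-update diagnostics alongside the episode statistics it logs. A small, assumed helper for turning that list into averages is sketched below; only the entropy key is visible in the diff, so policy_loss and value_loss are assumptions, and the aggregation itself is new code rather than part of the commit.

```python
# Hypothetical post-processing of TrainingSummary.diagnostics; the keys are assumptions.
from typing import Dict, List

import numpy as np


def summarise_diagnostics(diagnostics: List[Dict[str, float]]) -> Dict[str, float]:
    """Average each diagnostic key across optimization rounds."""
    if not diagnostics:
        return {}
    return {key: float(np.mean([d[key] for d in diagnostics])) for key in diagnostics[0]}


print(summarise_diagnostics([
    {"policy_loss": 0.21, "value_loss": 1.10, "entropy": 0.95},
    {"policy_loss": 0.17, "value_loss": 0.84, "entropy": 0.88},
]))
```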
@@ -5,6 +5,10 @@ from dataclasses import dataclass
from typing import Dict, Iterable, Mapping, Optional, Sequence

from .data_access import DataBroker
from .logging import get_logger

LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "feature_snapshot"}


@dataclass
@@ -15,6 +19,11 @@ class FeatureSnapshotService:

    def __init__(self, broker: Optional[DataBroker] = None) -> None:
        self.broker = broker or DataBroker()
        LOGGER.debug(
            "Initialized feature snapshot service broker=%s",
            type(self.broker).__name__,
            extra=LOG_EXTRA,
        )

    def load_latest(
        self,
@@ -27,13 +36,34 @@ class FeatureSnapshotService:
        """Fetch a snapshot of feature values for the given universe."""

        if not ts_codes:
            LOGGER.debug(
                "Skipping snapshot load (empty universe) trade_date=%s",
                trade_date,
                extra=LOG_EXTRA,
            )
            return {}
        return self.broker.fetch_batch_latest(
        field_count = len(fields)
        LOGGER.debug(
            "Loading feature snapshot trade_date=%s universe=%s fields=%s auto_refresh=%s",
            trade_date,
            len(ts_codes),
            field_count,
            auto_refresh,
            extra=LOG_EXTRA,
        )
        snapshot = self.broker.fetch_batch_latest(
            list(ts_codes),
            trade_date,
            fields,
            auto_refresh=auto_refresh,
        )
        LOGGER.debug(
            "Feature snapshot loaded trade_date=%s universe=%s",
            trade_date,
            len(snapshot),
            extra=LOG_EXTRA,
        )
        return snapshot

    def load_single(
        self,
@@ -45,14 +75,30 @@
    ) -> Mapping[str, object]:
        """Convenience wrapper to reuse the snapshot logic for a single symbol."""

        field_list = list(fields)
        LOGGER.debug(
            "Loading single-symbol snapshot trade_date=%s ts_code=%s fields=%s auto_refresh=%s",
            trade_date,
            ts_code,
            len(field_list),
            auto_refresh,
            extra=LOG_EXTRA,
        )
        snapshot = self.load_latest(
            trade_date,
            list(fields),
            field_list,
            [ts_code],
            auto_refresh=auto_refresh,
        )
        return snapshot.get(ts_code, {})
        result = snapshot.get(ts_code, {})
        if not result:
            LOGGER.debug(
                "Single-symbol snapshot empty trade_date=%s ts_code=%s",
                trade_date,
                ts_code,
                extra=LOG_EXTRA,
            )
        return result


__all__ = ["FeatureSnapshotService"]