add comprehensive logging to decision workflow and risk assessment

This commit is contained in:
sam 2025-10-17 18:37:08 +08:00
parent 1ca2f2be19
commit 4b68d84b3c
7 changed files with 416 additions and 60 deletions
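Every file touched below follows the same pattern: a module-level LOGGER obtained from app.utils.logging.get_logger plus a fixed LOG_EXTRA dict that tags each record with a "stage" field via extra=. The helper itself is not part of this diff; a minimal sketch of what it might look like, assuming it simply wraps the standard logging module and that a filter/formatter surfaces the stage tag, is:

# Hypothetical sketch of app/utils/logging.py -- not included in this commit.
import logging

class _StageDefault(logging.Filter):
    """Give records a default "stage" so the format string below never fails."""

    def filter(self, record: logging.LogRecord) -> bool:
        if not hasattr(record, "stage"):
            record.stage = "-"
        return True

def get_logger(name: str) -> logging.Logger:
    logger = logging.getLogger(name)
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter("%(asctime)s %(levelname)s [%(stage)s] %(name)s %(message)s")
        )
        handler.addFilter(_StageDefault())
        logger.addHandler(handler)
        logger.setLevel(logging.DEBUG)
    return logger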

View File

@ -4,8 +4,13 @@ from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional
from app.utils.logging import get_logger
from .base import AgentAction
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "decision_belief"}
@dataclass
class BeliefRevisionResult:
@ -46,12 +51,21 @@ def revise_beliefs(belief_updates: Dict[str, "BeliefUpdate"], default_action: Ag
"votes": {action.value: count for action, count in action_votes.items()}, "votes": {action.value: count for action, count in action_votes.items()},
"reasons": reasons, "reasons": reasons,
} }
return BeliefRevisionResult( result = BeliefRevisionResult(
consensus_action=consensus_action, consensus_action=consensus_action,
consensus_confidence=consensus_confidence, consensus_confidence=consensus_confidence,
conflicts=conflicts, conflicts=conflicts,
notes=notes, notes=notes,
) )
LOGGER.debug(
"信念修正完成 consensus=%s confidence=%.3f conflicts=%s vote_counts=%s",
result.consensus_action.value if result.consensus_action else None,
result.consensus_confidence,
result.conflicts,
notes["votes"],
extra=LOG_EXTRA,
)
return result
# avoid circular import typing # avoid circular import typing

View File

@ -5,6 +5,8 @@ from dataclasses import dataclass, field
from math import log
from typing import Dict, Iterable, List, Mapping, Optional, Tuple
from app.utils.logging import get_logger
from .base import Agent, AgentAction, AgentContext, UtilityMatrix
from .departments import DepartmentContext, DepartmentDecision, DepartmentManager
from .registry import weight_map
@ -20,6 +22,10 @@ from .protocols import (
)
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "decision_workflow"}
ACTIONS: Tuple[AgentAction, ...] = (
AgentAction.SELL,
AgentAction.HOLD,
@ -188,9 +194,30 @@ class DecisionWorkflow:
self.norm_weights: Dict[str, float] = {}
self.filtered_utilities: Dict[AgentAction, Dict[str, float]] = {}
self.belief_revision: Optional[BeliefRevisionResult] = None
LOGGER.debug(
"Initialized decision workflow ts_code=%s trade_date=%s method=%s agents=%s departments=%s",
context.ts_code,
context.trade_date,
method,
len(self.agent_list),
bool(self.department_manager),
extra=LOG_EXTRA,
)
def run(self) -> Decision:
LOGGER.debug(
"Running decision workflow ts_code=%s method=%s feasible=%s",
self.context.ts_code,
self.method,
[action.value for action in self.feasible_actions],
extra=LOG_EXTRA,
)
if not self.feasible_actions:
LOGGER.warning(
"No feasible actions, falling back to HOLD ts_code=%s",
self.context.ts_code,
extra=LOG_EXTRA,
)
return Decision(
action=AgentAction.HOLD,
confidence=0.0,
@ -201,6 +228,13 @@ class DecisionWorkflow:
self._evaluate_departments()
action, confidence = self._select_action()
LOGGER.debug(
"Preliminary action selected ts_code=%s action=%s confidence=%.3f",
self.context.ts_code,
action.value,
confidence,
extra=LOG_EXTRA,
)
risk_assessment = self._apply_risk(action)
exec_action = self._finalize_execution(action, risk_assessment)
self._finalize_conflicts(exec_action)
@ -210,7 +244,7 @@ class DecisionWorkflow:
self.department_votes,
)
decision = Decision(
action=action,
confidence=confidence,
target_weight=target_weight_for_action(action),
@ -224,6 +258,16 @@ class DecisionWorkflow:
belief_updates=self.belief_updates,
belief_revision=self.belief_revision,
)
LOGGER.info(
"Decision completed ts_code=%s action=%s confidence=%.3f review=%s risk_status=%s",
self.context.ts_code,
decision.action.value,
decision.confidence,
decision.requires_review,
risk_assessment.status,
extra=LOG_EXTRA,
)
return decision
def _evaluate_departments(self) -> None:
if not self.department_manager:
@ -236,7 +280,19 @@ class DecisionWorkflow:
market_snapshot=dict(getattr(self.context, "market_snapshot", {}) or {}),
raw=dict(getattr(self.context, "raw", {}) or {}),
)
LOGGER.debug(
"Starting department evaluation ts_code=%s departments=%s",
self.context.ts_code,
list(self.department_manager.agents.keys()),
extra=LOG_EXTRA,
)
self.department_decisions = self.department_manager.evaluate(dept_context)
LOGGER.debug(
"Department evaluation completed ts_code=%s decisions=%s",
self.context.ts_code,
list(self.department_decisions.keys()),
extra=LOG_EXTRA,
)
if self.department_decisions:
self.department_round = self.host.start_round(
self.host_trace,
@ -285,11 +341,32 @@ class DecisionWorkflow:
)
if self.method == "vote":
action, confidence = vote(self.filtered_utilities, self.norm_weights)
LOGGER.debug(
"Using voting mechanism ts_code=%s action=%s confidence=%.3f",
self.context.ts_code,
action.value,
confidence,
extra=LOG_EXTRA,
)
return action, confidence
action, confidence = nash_bargain(self.filtered_utilities, self.norm_weights, hold_scores)
if action not in self.feasible_actions:
LOGGER.debug(
"Nash solution infeasible, falling back to vote ts_code=%s invalid_action=%s",
self.context.ts_code,
action.value,
extra=LOG_EXTRA,
)
action, confidence = vote(self.filtered_utilities, self.norm_weights)
LOGGER.debug(
"Nash bargaining completed ts_code=%s action=%s confidence=%.3f",
self.context.ts_code,
action.value,
confidence,
extra=LOG_EXTRA,
)
return action, confidence
def _apply_risk(self, action: AgentAction) -> RiskAssessment:
@ -307,6 +384,17 @@ class DecisionWorkflow:
self.department_round.outcome = action.value
self.host.finalize_round(self.department_round)
LOGGER.debug(
"Risk assessment result ts_code=%s action=%s status=%s reason=%s conflict=%s votes=%s",
self.context.ts_code,
action.value,
assessment.status,
assessment.reason,
conflict_flag,
dict(self.department_votes),
extra=LOG_EXTRA,
)
if assessment.status != "ok":
self.risk_round = self.host.ensure_round(
self.host_trace,
@ -398,12 +486,28 @@ class DecisionWorkflow:
)
self.host.finalize_round(self.execution_round)
self.execution_round.notes.setdefault("target_weight", exec_weight)
LOGGER.info(
"Execution stage conclusion ts_code=%s final_action=%s original=%s status=%s target_weight=%.4f review=%s",
self.context.ts_code,
exec_action.value,
action.value,
exec_status,
exec_weight,
requires_review,
extra=LOG_EXTRA,
)
return exec_action
def _finalize_conflicts(self, exec_action: AgentAction) -> None:
self.host.close(self.host_trace)
self.belief_revision = revise_beliefs(self.belief_updates, exec_action)
if self.belief_revision.conflicts:
LOGGER.warning(
"Belief conflicts detected ts_code=%s conflicts=%s",
self.context.ts_code,
self.belief_revision.conflicts,
extra=LOG_EXTRA,
)
conflict_round = self.host.ensure_round(
self.host_trace,
agenda="conflict_resolution",
@ -436,15 +540,35 @@ def decide(
department_manager: Optional[DepartmentManager] = None,
department_context: Optional[DepartmentContext] = None,
) -> Decision:
agent_list = list(agents)
LOGGER.debug(
"Entering multi-agent decision ts_code=%s trade_date=%s agents=%s method=%s",
context.ts_code,
context.trade_date,
len(agent_list),
method,
extra=LOG_EXTRA,
)
workflow = DecisionWorkflow(
context,
agent_list,
weights,
method,
department_manager,
department_context,
)
decision = workflow.run()
LOGGER.info(
"Multi-agent decision completed ts_code=%s trade_date=%s action=%s confidence=%.3f review=%s method=%s",
context.ts_code,
context.trade_date,
decision.action.value,
decision.confidence,
decision.requires_review,
method,
extra=LOG_EXTRA,
)
return decision
def _department_scores(decision: DepartmentDecision) -> Dict[AgentAction, float]:
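Since every record above carries extra={"stage": "decision_workflow"} (and the risk agent below uses "decision_risk"), a consumer can isolate one stage without touching any call site. A small sketch; the handler target and file name are purely illustrative, and the "app" logger name is inferred from the app.* module paths:

# Hypothetical consumer; the file name and the stage value to keep are illustrative.
import logging

class StageFilter(logging.Filter):
    """Keep only records whose extra carries the requested stage tag."""

    def __init__(self, stage: str) -> None:
        super().__init__()
        self.stage = stage

    def filter(self, record: logging.LogRecord) -> bool:
        return getattr(record, "stage", None) == self.stage

handler = logging.FileHandler("decision_workflow.log")
handler.addFilter(StageFilter("decision_workflow"))
logging.getLogger("app").addHandler(handler)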

View File

@ -1,8 +1,13 @@
"""Risk agent acts as leader with veto rights.""" """Risk agent acts as leader with veto rights."""
from __future__ import annotations from __future__ import annotations
from app.utils.logging import get_logger
from .base import Agent, AgentAction, AgentContext from .base import Agent, AgentAction, AgentContext
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "decision_risk"}
class RiskRecommendation: class RiskRecommendation:
"""Represents structured recommendation from the risk agent.""" """Represents structured recommendation from the risk agent."""
@ -68,21 +73,44 @@ class RiskAgent(Agent):
features = dict(context.features or {})
risk_penalty = float(features.get("risk_penalty") or 0.0)
def finalize(
recommendation: RiskRecommendation,
trigger: str,
) -> RiskRecommendation:
LOGGER.debug(
"Risk agent assessment ts_code=%s action=%s status=%s reason=%s trigger=%s penalty=%.3f conflict=%s",
context.ts_code,
decision_action.value,
recommendation.status,
recommendation.reason,
trigger,
risk_penalty,
conflict_flag,
extra=LOG_EXTRA,
)
return recommendation
if bool(features.get("is_suspended")):
return finalize(
RiskRecommendation(
status="blocked",
reason="suspended",
recommended_action=AgentAction.HOLD,
notes={"trigger": "is_suspended"},
),
"is_suspended",
)
if bool(features.get("is_blacklisted")):
fallback = AgentAction.SELL if decision_action is AgentAction.SELL else AgentAction.HOLD
return finalize(
RiskRecommendation(
status="blocked",
reason="blacklist",
recommended_action=fallback,
notes={"trigger": "is_blacklisted"},
),
"is_blacklisted",
)
if bool(features.get("limit_up")) and decision_action in {
@ -90,22 +118,28 @@ class RiskAgent(Agent):
AgentAction.BUY_M,
AgentAction.BUY_L,
}:
return finalize(
RiskRecommendation(
status="blocked",
reason="limit_up",
recommended_action=AgentAction.HOLD,
notes={"trigger": "limit_up"},
),
"limit_up",
)
if bool(features.get("position_limit")) and decision_action in {
AgentAction.BUY_M,
AgentAction.BUY_L,
}:
return finalize(
RiskRecommendation(
status="pending_review",
reason="position_limit",
recommended_action=AgentAction.BUY_S,
notes={"trigger": "position_limit"},
),
"position_limit",
)
if risk_penalty >= 0.9 and decision_action in {
@ -113,28 +147,40 @@ class RiskAgent(Agent):
AgentAction.BUY_M,
AgentAction.BUY_L,
}:
return finalize(
RiskRecommendation(
status="blocked",
reason="risk_penalty_extreme",
recommended_action=AgentAction.HOLD,
notes={"risk_penalty": risk_penalty},
),
"risk_penalty_extreme",
)
if risk_penalty >= 0.7 and decision_action in {
AgentAction.BUY_S,
AgentAction.BUY_M,
AgentAction.BUY_L,
}:
return finalize(
RiskRecommendation(
status="pending_review",
reason="risk_penalty_high",
recommended_action=AgentAction.HOLD,
notes={"risk_penalty": risk_penalty},
),
"risk_penalty_high",
)
if conflict_flag:
return finalize(
RiskRecommendation(
status="pending_review",
reason="conflict_threshold",
),
"conflict_threshold",
)
return finalize(
RiskRecommendation(status="ok", reason="clear"),
"clear",
)

View File

@ -2,24 +2,27 @@
from __future__ import annotations
import json
from datetime import datetime
from typing import Any, Dict, Optional
from app.utils.db import db_session
from app.utils.logging import get_logger
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "data_ingest"}
class JobLogger:
"""Task recorder that tracks fetch-job runs in the database."""
def __init__(self, job_type: str) -> None:
"""Initialize the task recorder.
Args:
job_type: Job type.
"""
self.job_type = job_type
self.job_id: Optional[int] = None
def __enter__(self) -> "JobLogger":
"""Start recording the job."""
with db_session() as session:
@ -28,29 +31,49 @@ class JobLogger:
INSERT INTO fetch_jobs (job_type, status, created_at, updated_at)
VALUES (?, 'running', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
""",
(self.job_type,),
)
self.job_id = cursor.lastrowid
session.commit()
LOGGER.info(
"Fetch job started job_type=%s job_id=%s",
self.job_type,
self.job_id,
extra=LOG_EXTRA,
)
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""Finish recording the job."""
if exc_val:
LOGGER.exception(
"Fetch job failed job_type=%s job_id=%s err=%s",
self.job_type,
self.job_id,
exc_val,
extra=LOG_EXTRA,
)
self.update_status("failed", str(exc_val))
else:
LOGGER.info(
"Fetch job completed job_type=%s job_id=%s",
self.job_type,
self.job_id,
extra=LOG_EXTRA,
)
self.update_status("success")
def update_status(self, status: str, error_msg: Optional[str] = None) -> None:
"""Update the job status.
Args:
status: New status.
error_msg: Error message, if any.
"""
if not self.job_id:
LOGGER.debug("Ignoring invalid job status update job_type=%s status=%s", self.job_type, status, extra=LOG_EXTRA)
return
with db_session() as session:
session.execute(
"""
@ -60,19 +83,32 @@ class JobLogger:
updated_at = CURRENT_TIMESTAMP
WHERE id = ?
""",
(status, error_msg, self.job_id),
)
session.commit()
LOGGER.debug(
"Updated job status job_type=%s job_id=%s status=%s error=%s",
self.job_type,
self.job_id,
status,
error_msg,
extra=LOG_EXTRA,
)
def update_metadata(self, metadata: Dict[str, Any]) -> None:
"""Update the job metadata.
Args:
metadata: Metadata dictionary.
"""
if not self.job_id:
LOGGER.debug(
"Ignoring metadata update (job not initialized) job_type=%s",
self.job_type,
extra=LOG_EXTRA,
)
return
with db_session() as session:
session.execute(
"""
@ -80,6 +116,13 @@ class JobLogger:
SET metadata = ?
WHERE id = ?
""",
(json.dumps(metadata), self.job_id),
)
session.commit()
LOGGER.debug(
"Recorded job metadata job_type=%s job_id=%s keys=%s",
self.job_type,
self.job_id,
sorted(metadata.keys()),
extra=LOG_EXTRA,
)
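For reference, the intended call pattern for JobLogger (a usage sketch; the job type, fetch routine, and metadata keys are illustrative, only the JobLogger methods themselves appear in this diff):

# Hypothetical usage; "daily_bars", fetch_daily_bars() and the metadata key are illustrative.
with JobLogger("daily_bars") as job:
    rows = fetch_daily_bars()                  # whatever ingest routine the job wraps
    job.update_metadata({"rows": len(rows)})
# A clean exit logs "Fetch job completed" and marks the row success; an exception is
# logged via LOGGER.exception, the row is marked failed, and the error still propagates.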

View File

@ -7,6 +7,10 @@ from typing import Dict, Iterable, List, Mapping, Sequence, Tuple
import numpy as np
from app.backtest.decision_env import DecisionEnv
from app.utils.logging import get_logger
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "decision_env"}
@dataclass
@ -26,6 +30,13 @@ class DecisionEnvAdapter:
else:
self._keys = list(self.observation_keys)
self._last_reset_obs = None
LOGGER.debug(
"Initialized DecisionEnvAdapter obs_dim=%s action_dim=%s keys=%s",
len(self._keys),
self.env.action_dim,
self._keys,
extra=LOG_EXTRA,
)
@property
def action_dim(self) -> int:
@ -38,12 +49,24 @@ class DecisionEnvAdapter:
def reset(self) -> Tuple[np.ndarray, Dict[str, float]]:
raw = self.env.reset()
self._last_reset_obs = raw
LOGGER.debug(
"Environment reset completed episode=%s",
raw.get("episode"),
extra=LOG_EXTRA,
)
return self._to_array(raw), raw
def step(
self, action: Sequence[float]
) -> Tuple[np.ndarray, float, bool, Mapping[str, object], Mapping[str, float]]:
obs_dict, reward, done, info = self.env.step(action)
LOGGER.debug(
"Environment step action=%s reward=%.4f done=%s",
[round(float(a), 4) for a in action],
reward,
done,
extra=LOG_EXTRA,
)
return self._to_array(obs_dict), reward, done, info, obs_dict
def _to_array(self, payload: Mapping[str, float]) -> np.ndarray:

View File

@ -10,8 +10,13 @@ import torch
from torch import nn
from torch.distributions import Beta
from app.utils.logging import get_logger
from .adapters import DecisionEnvAdapter
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "rl_ppo"}
def _init_layer(layer: nn.Module, std: float = 1.0) -> nn.Module:
if isinstance(layer, nn.Linear):
@ -168,6 +173,15 @@ class PPOTrainer:
if config.seed is not None:
torch.manual_seed(config.seed)
np.random.seed(config.seed)
LOGGER.info(
"Initialized PPOTrainer obs_dim=%s action_dim=%s total_timesteps=%s rollout=%s device=%s",
obs_dim,
action_dim,
config.total_timesteps,
config.rollout_steps,
config.device,
extra=LOG_EXTRA,
)
def train(self) -> TrainingSummary:
cfg = self.config
@ -180,6 +194,14 @@ class PPOTrainer:
diagnostics: List[Dict[str, float]] = []
current_return = 0.0
current_length = 0
LOGGER.info(
"Starting PPO training total_timesteps=%s rollout_steps=%s epochs=%s minibatch=%s",
cfg.total_timesteps,
cfg.rollout_steps,
cfg.epochs,
cfg.minibatch_size,
extra=LOG_EXTRA,
)
while timesteps < cfg.total_timesteps:
rollout.reset()
@ -203,6 +225,14 @@ class PPOTrainer:
if done:
episode_rewards.append(current_return)
episode_lengths.append(current_length)
LOGGER.info(
"Episode finished reward=%.4f length=%s episodes=%s timesteps=%s",
episode_rewards[-1],
episode_lengths[-1],
len(episode_rewards),
timesteps,
extra=LOG_EXTRA,
)
current_return = 0.0
current_length = 0
next_obs_array, _ = self.adapter.reset()
@ -216,7 +246,17 @@ class PPOTrainer:
with torch.no_grad():
next_value = self.critic(obs.unsqueeze(0)).squeeze(0).item()
rollout.finish(last_value=next_value, gamma=cfg.gamma, gae_lambda=cfg.gae_lambda)
LOGGER.debug(
"Rollout collection finished batch_size=%s timesteps=%s remaining=%s",
rollout._pos,
timesteps,
cfg.total_timesteps - timesteps,
extra=LOG_EXTRA,
)
last_policy_loss = None
last_value_loss = None
last_entropy = None
for _ in range(cfg.epochs):
for (mb_obs, mb_actions, mb_log_probs, mb_adv, mb_returns, _) in rollout.get_minibatches(
cfg.minibatch_size
@ -241,6 +281,9 @@ class PPOTrainer:
value_loss.backward()
nn.utils.clip_grad_norm_(self.critic.parameters(), cfg.max_grad_norm)
self.value_optimizer.step()
last_policy_loss = float(policy_loss.detach().cpu())
last_value_loss = float(value_loss.detach().cpu())
last_entropy = float(entropy.mean().detach().cpu())
diagnostics.append(
{
@ -249,13 +292,30 @@ class PPOTrainer:
"entropy": float(entropy.mean().detach().cpu()), "entropy": float(entropy.mean().detach().cpu()),
} }
) )
LOGGER.info(
"优化轮次完成 timesteps=%s/%s policy_loss=%.4f value_loss=%.4f entropy=%.4f",
timesteps,
cfg.total_timesteps,
last_policy_loss if last_policy_loss is not None else 0.0,
last_value_loss if last_value_loss is not None else 0.0,
last_entropy if last_entropy is not None else 0.0,
extra=LOG_EXTRA,
)
return TrainingSummary( summary = TrainingSummary(
timesteps=timesteps, timesteps=timesteps,
episode_rewards=episode_rewards, episode_rewards=episode_rewards,
episode_lengths=episode_lengths, episode_lengths=episode_lengths,
diagnostics=diagnostics, diagnostics=diagnostics,
) )
LOGGER.info(
"PPO 训练结束 timesteps=%s episodes=%s mean_reward=%.4f",
summary.timesteps,
len(summary.episode_rewards),
float(np.mean(summary.episode_rewards)) if summary.episode_rewards else 0.0,
extra=LOG_EXTRA,
)
return summary
def train_ppo(adapter: DecisionEnvAdapter, config: PPOConfig) -> TrainingSummary: def train_ppo(adapter: DecisionEnvAdapter, config: PPOConfig) -> TrainingSummary:

View File

@ -5,6 +5,10 @@ from dataclasses import dataclass
from typing import Dict, Iterable, Mapping, Optional, Sequence
from .data_access import DataBroker
from .logging import get_logger
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "feature_snapshot"}
@dataclass
@ -15,6 +19,11 @@ class FeatureSnapshotService:
def __init__(self, broker: Optional[DataBroker] = None) -> None:
self.broker = broker or DataBroker()
LOGGER.debug(
"Initialized feature snapshot service broker=%s",
type(self.broker).__name__,
extra=LOG_EXTRA,
)
def load_latest(
self,
@ -27,13 +36,34 @@ class FeatureSnapshotService:
"""Fetch a snapshot of feature values for the given universe.""" """Fetch a snapshot of feature values for the given universe."""
if not ts_codes: if not ts_codes:
LOGGER.debug(
"跳过快照加载(标的为空) trade_date=%s",
trade_date,
extra=LOG_EXTRA,
)
return {} return {}
return self.broker.fetch_batch_latest( field_count = len(fields)
LOGGER.debug(
"加载特征快照 trade_date=%s universe=%s fields=%s auto_refresh=%s",
trade_date,
len(ts_codes),
field_count,
auto_refresh,
extra=LOG_EXTRA,
)
snapshot = self.broker.fetch_batch_latest(
list(ts_codes), list(ts_codes),
trade_date, trade_date,
fields, fields,
auto_refresh=auto_refresh, auto_refresh=auto_refresh,
) )
LOGGER.debug(
"特征快照加载完成 trade_date=%s universe=%s",
trade_date,
len(snapshot),
extra=LOG_EXTRA,
)
return snapshot
def load_single( def load_single(
self, self,
@ -45,14 +75,30 @@ class FeatureSnapshotService:
) -> Mapping[str, object]:
"""Convenience wrapper to reuse the snapshot logic for a single symbol."""
field_list = list(fields)
LOGGER.debug(
"Loading single-symbol snapshot trade_date=%s ts_code=%s fields=%s auto_refresh=%s",
trade_date,
ts_code,
len(field_list),
auto_refresh,
extra=LOG_EXTRA,
)
snapshot = self.load_latest(
trade_date,
field_list,
[ts_code],
auto_refresh=auto_refresh,
)
result = snapshot.get(ts_code, {})
if not result:
LOGGER.debug(
"Single-symbol snapshot empty trade_date=%s ts_code=%s",
trade_date,
ts_code,
extra=LOG_EXTRA,
)
return result
__all__ = ["FeatureSnapshotService"]