add comprehensive logging to decision workflow and risk assessment

sam 2025-10-17 18:37:08 +08:00
parent 1ca2f2be19
commit 4b68d84b3c
7 changed files with 416 additions and 60 deletions
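Every file touched by this commit follows the same pattern: a module-level logger obtained from app.utils.logging.get_logger plus a LOG_EXTRA dict whose stage field is attached to each record via extra=. A minimal sketch of that pattern, assuming get_logger returns a standard logging.Logger (the wrapper itself is not part of this commit):

# Sketch only: assumes get_logger returns a standard logging.Logger.
from app.utils.logging import get_logger

LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "decision_workflow"}  # stage tag shared by every record in the module

def do_work(ts_code: str) -> None:
    # extra=... copies these keys onto the LogRecord, so handlers and filters can read record.stage
    LOGGER.debug("starting work ts_code=%s", ts_code, extra=LOG_EXTRA)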

View File

@ -4,8 +4,13 @@ from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional
from app.utils.logging import get_logger
from .base import AgentAction
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "decision_belief"}
@dataclass
class BeliefRevisionResult:
@ -46,12 +51,21 @@ def revise_beliefs(belief_updates: Dict[str, "BeliefUpdate"], default_action: Ag
"votes": {action.value: count for action, count in action_votes.items()},
"reasons": reasons,
}
return BeliefRevisionResult(
result = BeliefRevisionResult(
consensus_action=consensus_action,
consensus_confidence=consensus_confidence,
conflicts=conflicts,
notes=notes,
)
LOGGER.debug(
"信念修正完成 consensus=%s confidence=%.3f conflicts=%s vote_counts=%s",
result.consensus_action.value if result.consensus_action else None,
result.consensus_confidence,
result.conflicts,
notes["votes"],
extra=LOG_EXTRA,
)
return result
# avoid circular import typing
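To confirm that the new debug record carries its stage tag, pytest's caplog fixture can capture it; the snippet below demonstrates the mechanism only and is not a test from the repository:

import logging

def test_stage_tag_is_visible(caplog):
    # Demonstration only: extra={"stage": ...} ends up as an attribute on the captured record.
    demo = logging.getLogger("belief_demo")
    with caplog.at_level(logging.DEBUG, logger="belief_demo"):
        demo.debug("Belief revision complete consensus=%s", "HOLD", extra={"stage": "decision_belief"})
    assert any(getattr(r, "stage", None) == "decision_belief" for r in caplog.records)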

View File

@ -5,6 +5,8 @@ from dataclasses import dataclass, field
from math import log
from typing import Dict, Iterable, List, Mapping, Optional, Tuple
from app.utils.logging import get_logger
from .base import Agent, AgentAction, AgentContext, UtilityMatrix
from .departments import DepartmentContext, DepartmentDecision, DepartmentManager
from .registry import weight_map
@ -20,6 +22,10 @@ from .protocols import (
)
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "decision_workflow"}
ACTIONS: Tuple[AgentAction, ...] = (
AgentAction.SELL,
AgentAction.HOLD,
@ -188,9 +194,30 @@ class DecisionWorkflow:
self.norm_weights: Dict[str, float] = {}
self.filtered_utilities: Dict[AgentAction, Dict[str, float]] = {}
self.belief_revision: Optional[BeliefRevisionResult] = None
LOGGER.debug(
"初始化决策流程 ts_code=%s trade_date=%s method=%s agents=%s departments=%s",
context.ts_code,
context.trade_date,
method,
len(self.agent_list),
bool(self.department_manager),
extra=LOG_EXTRA,
)
def run(self) -> Decision:
LOGGER.debug(
"执行决策流程 ts_code=%s method=%s feasible=%s",
self.context.ts_code,
self.method,
[action.value for action in self.feasible_actions],
extra=LOG_EXTRA,
)
if not self.feasible_actions:
LOGGER.warning(
"无可行动作,回退到 HOLD ts_code=%s",
self.context.ts_code,
extra=LOG_EXTRA,
)
return Decision(
action=AgentAction.HOLD,
confidence=0.0,
@ -201,6 +228,13 @@ class DecisionWorkflow:
self._evaluate_departments()
action, confidence = self._select_action()
LOGGER.debug(
"初步动作选择完成 ts_code=%s action=%s confidence=%.3f",
self.context.ts_code,
action.value,
confidence,
extra=LOG_EXTRA,
)
risk_assessment = self._apply_risk(action)
exec_action = self._finalize_execution(action, risk_assessment)
self._finalize_conflicts(exec_action)
@ -210,7 +244,7 @@ class DecisionWorkflow:
self.department_votes,
)
return Decision(
decision = Decision(
action=action,
confidence=confidence,
target_weight=target_weight_for_action(action),
@ -224,6 +258,16 @@ class DecisionWorkflow:
belief_updates=self.belief_updates,
belief_revision=self.belief_revision,
)
LOGGER.info(
"决策完成 ts_code=%s action=%s confidence=%.3f review=%s risk_status=%s",
self.context.ts_code,
decision.action.value,
decision.confidence,
decision.requires_review,
risk_assessment.status,
extra=LOG_EXTRA,
)
return decision
def _evaluate_departments(self) -> None:
if not self.department_manager:
@ -236,7 +280,19 @@ class DecisionWorkflow:
market_snapshot=dict(getattr(self.context, "market_snapshot", {}) or {}),
raw=dict(getattr(self.context, "raw", {}) or {}),
)
LOGGER.debug(
"开始部门评估 ts_code=%s departments=%s",
self.context.ts_code,
list(self.department_manager.agents.keys()),
extra=LOG_EXTRA,
)
self.department_decisions = self.department_manager.evaluate(dept_context)
LOGGER.debug(
"部门评估完成 ts_code=%s decisions=%s",
self.context.ts_code,
list(self.department_decisions.keys()),
extra=LOG_EXTRA,
)
if self.department_decisions:
self.department_round = self.host.start_round(
self.host_trace,
@ -285,11 +341,32 @@ class DecisionWorkflow:
)
if self.method == "vote":
return vote(self.filtered_utilities, self.norm_weights)
action, confidence = vote(self.filtered_utilities, self.norm_weights)
LOGGER.debug(
"采用投票机制 ts_code=%s action=%s confidence=%.3f",
self.context.ts_code,
action.value,
confidence,
extra=LOG_EXTRA,
)
return action, confidence
action, confidence = nash_bargain(self.filtered_utilities, self.norm_weights, hold_scores)
if action not in self.feasible_actions:
return vote(self.filtered_utilities, self.norm_weights)
LOGGER.debug(
"纳什解不可行,改用投票 ts_code=%s invalid_action=%s",
self.context.ts_code,
action.value,
extra=LOG_EXTRA,
)
action, confidence = vote(self.filtered_utilities, self.norm_weights)
LOGGER.debug(
"纳什解计算完成 ts_code=%s action=%s confidence=%.3f",
self.context.ts_code,
action.value,
confidence,
extra=LOG_EXTRA,
)
return action, confidence
def _apply_risk(self, action: AgentAction) -> RiskAssessment:
@ -307,6 +384,17 @@ class DecisionWorkflow:
self.department_round.outcome = action.value
self.host.finalize_round(self.department_round)
LOGGER.debug(
"风险评估结果 ts_code=%s action=%s status=%s reason=%s conflict=%s votes=%s",
self.context.ts_code,
action.value,
assessment.status,
assessment.reason,
conflict_flag,
dict(self.department_votes),
extra=LOG_EXTRA,
)
if assessment.status != "ok":
self.risk_round = self.host.ensure_round(
self.host_trace,
@ -398,12 +486,28 @@ class DecisionWorkflow:
)
self.host.finalize_round(self.execution_round)
self.execution_round.notes.setdefault("target_weight", exec_weight)
LOGGER.info(
"执行阶段结论 ts_code=%s final_action=%s original=%s status=%s target_weight=%.4f review=%s",
self.context.ts_code,
exec_action.value,
action.value,
exec_status,
exec_weight,
requires_review,
extra=LOG_EXTRA,
)
return exec_action
def _finalize_conflicts(self, exec_action: AgentAction) -> None:
self.host.close(self.host_trace)
self.belief_revision = revise_beliefs(self.belief_updates, exec_action)
if self.belief_revision.conflicts:
LOGGER.warning(
"发现信念冲突 ts_code=%s conflicts=%s",
self.context.ts_code,
self.belief_revision.conflicts,
extra=LOG_EXTRA,
)
conflict_round = self.host.ensure_round(
self.host_trace,
agenda="conflict_resolution",
@ -436,15 +540,35 @@ def decide(
department_manager: Optional[DepartmentManager] = None,
department_context: Optional[DepartmentContext] = None,
) -> Decision:
agent_list = list(agents)
LOGGER.debug(
"进入多智能体决策 ts_code=%s trade_date=%s agents=%s method=%s",
context.ts_code,
context.trade_date,
len(agent_list),
method,
extra=LOG_EXTRA,
)
workflow = DecisionWorkflow(
context,
agents,
agent_list,
weights,
method,
department_manager,
department_context,
)
return workflow.run()
decision = workflow.run()
LOGGER.info(
"完成多智能体决策 ts_code=%s trade_date=%s action=%s confidence=%.3f review=%s method=%s",
context.ts_code,
context.trade_date,
decision.action.value,
decision.confidence,
decision.requires_review,
method,
extra=LOG_EXTRA,
)
return decision
def _department_scores(decision: DepartmentDecision) -> Dict[AgentAction, float]:
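Because every record from this module is tagged stage="decision_workflow" (the risk agent uses "decision_risk"), a standard logging.Filter can isolate one stage while debugging. A small sketch, assuming handlers are configured on the root logger:

import logging

class StageFilter(logging.Filter):
    """Keep only records whose extra carried a matching stage tag."""
    def __init__(self, stage: str) -> None:
        super().__init__()
        self.stage = stage

    def filter(self, record: logging.LogRecord) -> bool:
        return getattr(record, "stage", None) == self.stage

handler = logging.StreamHandler()
handler.addFilter(StageFilter("decision_workflow"))   # drop everything except workflow records
logging.getLogger().addHandler(handler)
logging.getLogger().setLevel(logging.DEBUG)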

View File

@ -1,8 +1,13 @@
"""Risk agent acts as leader with veto rights."""
from __future__ import annotations
from app.utils.logging import get_logger
from .base import Agent, AgentAction, AgentContext
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "decision_risk"}
class RiskRecommendation:
"""Represents structured recommendation from the risk agent."""
@ -68,21 +73,44 @@ class RiskAgent(Agent):
features = dict(context.features or {})
risk_penalty = float(features.get("risk_penalty") or 0.0)
def finalize(
recommendation: RiskRecommendation,
trigger: str,
) -> RiskRecommendation:
LOGGER.debug(
"风险代理评估 ts_code=%s action=%s status=%s reason=%s trigger=%s penalty=%.3f conflict=%s",
context.ts_code,
decision_action.value,
recommendation.status,
recommendation.reason,
trigger,
risk_penalty,
conflict_flag,
extra=LOG_EXTRA,
)
return recommendation
if bool(features.get("is_suspended")):
return RiskRecommendation(
return finalize(
RiskRecommendation(
status="blocked",
reason="suspended",
recommended_action=AgentAction.HOLD,
notes={"trigger": "is_suspended"},
),
"is_suspended",
)
if bool(features.get("is_blacklisted")):
fallback = AgentAction.SELL if decision_action is AgentAction.SELL else AgentAction.HOLD
return RiskRecommendation(
return finalize(
RiskRecommendation(
status="blocked",
reason="blacklist",
recommended_action=fallback,
notes={"trigger": "is_blacklisted"},
),
"is_blacklisted",
)
if bool(features.get("limit_up")) and decision_action in {
@ -90,22 +118,28 @@ class RiskAgent(Agent):
AgentAction.BUY_M,
AgentAction.BUY_L,
}:
return RiskRecommendation(
return finalize(
RiskRecommendation(
status="blocked",
reason="limit_up",
recommended_action=AgentAction.HOLD,
notes={"trigger": "limit_up"},
),
"limit_up",
)
if bool(features.get("position_limit")) and decision_action in {
AgentAction.BUY_M,
AgentAction.BUY_L,
}:
return RiskRecommendation(
return finalize(
RiskRecommendation(
status="pending_review",
reason="position_limit",
recommended_action=AgentAction.BUY_S,
notes={"trigger": "position_limit"},
),
"position_limit",
)
if risk_penalty >= 0.9 and decision_action in {
@ -113,28 +147,40 @@ class RiskAgent(Agent):
AgentAction.BUY_M,
AgentAction.BUY_L,
}:
return RiskRecommendation(
return finalize(
RiskRecommendation(
status="blocked",
reason="risk_penalty_extreme",
recommended_action=AgentAction.HOLD,
notes={"risk_penalty": risk_penalty},
),
"risk_penalty_extreme",
)
if risk_penalty >= 0.7 and decision_action in {
AgentAction.BUY_S,
AgentAction.BUY_M,
AgentAction.BUY_L,
}:
return RiskRecommendation(
return finalize(
RiskRecommendation(
status="pending_review",
reason="risk_penalty_high",
recommended_action=AgentAction.HOLD,
notes={"risk_penalty": risk_penalty},
),
"risk_penalty_high",
)
if conflict_flag:
return RiskRecommendation(
return finalize(
RiskRecommendation(
status="pending_review",
reason="conflict_threshold",
),
"conflict_threshold",
)
return RiskRecommendation(status="ok", reason="clear")
return finalize(
RiskRecommendation(status="ok", reason="clear"),
"clear",
)
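The finalize closure above gives every guard branch a single logging choke point instead of a debug call before each return. A standalone sketch of that wrap-the-return pattern, with illustrative names rather than the project's:

import logging

LOGGER = logging.getLogger(__name__)

def assess(features: dict) -> str:
    def finalize(verdict: str, trigger: str) -> str:
        # single choke point: every branch below funnels its result through here
        LOGGER.debug("risk verdict=%s trigger=%s", verdict, trigger)
        return verdict

    if features.get("is_suspended"):
        return finalize("blocked", "is_suspended")
    if features.get("risk_penalty", 0.0) >= 0.9:
        return finalize("blocked", "risk_penalty_extreme")
    return finalize("ok", "clear")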

View File

@ -2,14 +2,17 @@
from __future__ import annotations
import json
from datetime import datetime
from typing import Any, Dict, Optional
from app.utils.db import db_session
from app.utils.logging import get_logger
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "data_ingest"}
class JobLogger:
"""任务记录器。"""
"""任务记录器,通过数据库记录抓取作业运行情况"""
def __init__(self, job_type: str) -> None:
"""初始化任务记录器。
@ -28,17 +31,36 @@ class JobLogger:
INSERT INTO fetch_jobs (job_type, status, created_at, updated_at)
VALUES (?, 'running', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
""",
(self.job_type,)
(self.job_type,),
)
self.job_id = cursor.lastrowid
session.commit()
LOGGER.info(
"抓取任务启动 job_type=%s job_id=%s",
self.job_type,
self.job_id,
extra=LOG_EXTRA,
)
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""结束任务记录。"""
if exc_val:
LOGGER.exception(
"抓取任务失败 job_type=%s job_id=%s err=%s",
self.job_type,
self.job_id,
exc_val,
extra=LOG_EXTRA,
)
self.update_status("failed", str(exc_val))
else:
LOGGER.info(
"抓取任务完成 job_type=%s job_id=%s",
self.job_type,
self.job_id,
extra=LOG_EXTRA,
)
self.update_status("success")
def update_status(self, status: str, error_msg: Optional[str] = None) -> None:
@ -49,6 +71,7 @@ class JobLogger:
error_msg: Error message, if any.
"""
if not self.job_id:
LOGGER.debug("忽略无效任务状态更新 job_type=%s status=%s", self.job_type, status, extra=LOG_EXTRA)
return
with db_session() as session:
@ -60,9 +83,17 @@ class JobLogger:
updated_at = CURRENT_TIMESTAMP
WHERE id = ?
""",
(status, error_msg, self.job_id)
(status, error_msg, self.job_id),
)
session.commit()
LOGGER.debug(
"更新任务状态 job_type=%s job_id=%s status=%s error=%s",
self.job_type,
self.job_id,
status,
error_msg,
extra=LOG_EXTRA,
)
def update_metadata(self, metadata: Dict[str, Any]) -> None:
"""更新任务元数据。
@ -71,6 +102,11 @@ class JobLogger:
metadata: Metadata dictionary.
"""
if not self.job_id:
LOGGER.debug(
"忽略元数据更新(尚未初始化) job_type=%s",
self.job_type,
extra=LOG_EXTRA,
)
return
with db_session() as session:
@ -80,6 +116,13 @@ class JobLogger:
SET metadata = ?
WHERE id = ?
""",
(json.dumps(metadata), self.job_id)
(json.dumps(metadata), self.job_id),
)
session.commit()
LOGGER.debug(
"记录任务元数据 job_type=%s job_id=%s keys=%s",
self.job_type,
self.job_id,
sorted(metadata.keys()),
extra=LOG_EXTRA,
)
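For reference, the flow the new JobLogger logs describe: entering the context manager inserts a running row and logs its id, a clean exit logs and records success, and an exception is logged before the failed status is written and the error propagates. A hedged usage sketch; the import path, job_type value, and metadata keys are illustrative:

# Hypothetical call site; "daily_quotes" and the metadata keys are placeholders.
def run_ingest_job(rows: list) -> None:
    with JobLogger("daily_quotes") as job:        # __enter__ inserts a 'running' row and logs the new job_id
        job.update_metadata({"rows": len(rows)})  # logged at DEBUG with the metadata keys
    # __exit__ logs success, or logs the exception and marks the job failed before letting it propagate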

View File

@ -7,6 +7,10 @@ from typing import Dict, Iterable, List, Mapping, Sequence, Tuple
import numpy as np
from app.backtest.decision_env import DecisionEnv
from app.utils.logging import get_logger
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "decision_env"}
@dataclass
@ -26,6 +30,13 @@ class DecisionEnvAdapter:
else:
self._keys = list(self.observation_keys)
self._last_reset_obs = None
LOGGER.debug(
"初始化 DecisionEnvAdapter obs_dim=%s action_dim=%s keys=%s",
len(self._keys),
self.env.action_dim,
self._keys,
extra=LOG_EXTRA,
)
@property
def action_dim(self) -> int:
@ -38,12 +49,24 @@ class DecisionEnvAdapter:
def reset(self) -> Tuple[np.ndarray, Dict[str, float]]:
raw = self.env.reset()
self._last_reset_obs = raw
LOGGER.debug(
"环境重置完成 episode=%s",
raw.get("episode"),
extra=LOG_EXTRA,
)
return self._to_array(raw), raw
def step(
self, action: Sequence[float]
) -> Tuple[np.ndarray, float, bool, Mapping[str, object], Mapping[str, float]]:
obs_dict, reward, done, info = self.env.step(action)
LOGGER.debug(
"环境执行动作 action=%s reward=%.4f done=%s",
[round(float(a), 4) for a in action],
reward,
done,
extra=LOG_EXTRA,
)
return self._to_array(obs_dict), reward, done, info, obs_dict
def _to_array(self, payload: Mapping[str, float]) -> np.ndarray:
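The adapter's reset/step pair is what the PPO trainer consumes; a minimal interaction loop looks like the sketch below. The zero action is a placeholder policy, and reward/done semantics come from DecisionEnv, which this diff does not show:

import numpy as np

# Sketch: `adapter` is an already-constructed DecisionEnvAdapter.
obs, raw = adapter.reset()                    # logs "Environment reset complete episode=..."
done = False
while not done:
    action = np.zeros(adapter.action_dim)     # stand-in for a real policy output
    obs, reward, done, info, raw = adapter.step(action)   # logs action/reward/done at DEBUG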

View File

@ -10,8 +10,13 @@ import torch
from torch import nn
from torch.distributions import Beta
from app.utils.logging import get_logger
from .adapters import DecisionEnvAdapter
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "rl_ppo"}
def _init_layer(layer: nn.Module, std: float = 1.0) -> nn.Module:
if isinstance(layer, nn.Linear):
@ -168,6 +173,15 @@ class PPOTrainer:
if config.seed is not None:
torch.manual_seed(config.seed)
np.random.seed(config.seed)
LOGGER.info(
"初始化 PPOTrainer obs_dim=%s action_dim=%s total_timesteps=%s rollout=%s device=%s",
obs_dim,
action_dim,
config.total_timesteps,
config.rollout_steps,
config.device,
extra=LOG_EXTRA,
)
def train(self) -> TrainingSummary:
cfg = self.config
@ -180,6 +194,14 @@ class PPOTrainer:
diagnostics: List[Dict[str, float]] = []
current_return = 0.0
current_length = 0
LOGGER.info(
"开始 PPO 训练 total_timesteps=%s rollout_steps=%s epochs=%s minibatch=%s",
cfg.total_timesteps,
cfg.rollout_steps,
cfg.epochs,
cfg.minibatch_size,
extra=LOG_EXTRA,
)
while timesteps < cfg.total_timesteps:
rollout.reset()
@ -203,6 +225,14 @@ class PPOTrainer:
if done:
episode_rewards.append(current_return)
episode_lengths.append(current_length)
LOGGER.info(
"episode 完成 reward=%.4f length=%s episodes=%s timesteps=%s",
episode_rewards[-1],
episode_lengths[-1],
len(episode_rewards),
timesteps,
extra=LOG_EXTRA,
)
current_return = 0.0
current_length = 0
next_obs_array, _ = self.adapter.reset()
@ -216,7 +246,17 @@ class PPOTrainer:
with torch.no_grad():
next_value = self.critic(obs.unsqueeze(0)).squeeze(0).item()
rollout.finish(last_value=next_value, gamma=cfg.gamma, gae_lambda=cfg.gae_lambda)
LOGGER.debug(
"完成样本收集 batch_size=%s timesteps=%s remaining=%s",
rollout._pos,
timesteps,
cfg.total_timesteps - timesteps,
extra=LOG_EXTRA,
)
last_policy_loss = None
last_value_loss = None
last_entropy = None
for _ in range(cfg.epochs):
for (mb_obs, mb_actions, mb_log_probs, mb_adv, mb_returns, _) in rollout.get_minibatches(
cfg.minibatch_size
@ -241,6 +281,9 @@ class PPOTrainer:
value_loss.backward()
nn.utils.clip_grad_norm_(self.critic.parameters(), cfg.max_grad_norm)
self.value_optimizer.step()
last_policy_loss = float(policy_loss.detach().cpu())
last_value_loss = float(value_loss.detach().cpu())
last_entropy = float(entropy.mean().detach().cpu())
diagnostics.append(
{
@ -249,13 +292,30 @@ class PPOTrainer:
"entropy": float(entropy.mean().detach().cpu()),
}
)
LOGGER.info(
"优化轮次完成 timesteps=%s/%s policy_loss=%.4f value_loss=%.4f entropy=%.4f",
timesteps,
cfg.total_timesteps,
last_policy_loss if last_policy_loss is not None else 0.0,
last_value_loss if last_value_loss is not None else 0.0,
last_entropy if last_entropy is not None else 0.0,
extra=LOG_EXTRA,
)
return TrainingSummary(
summary = TrainingSummary(
timesteps=timesteps,
episode_rewards=episode_rewards,
episode_lengths=episode_lengths,
diagnostics=diagnostics,
)
LOGGER.info(
"PPO 训练结束 timesteps=%s episodes=%s mean_reward=%.4f",
summary.timesteps,
len(summary.episode_rewards),
float(np.mean(summary.episode_rewards)) if summary.episode_rewards else 0.0,
extra=LOG_EXTRA,
)
return summary
def train_ppo(adapter: DecisionEnvAdapter, config: PPOConfig) -> TrainingSummary:
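train_ppo (signature above) wires an adapter and a PPOConfig into PPOTrainer; with this commit, initialization and per-episode progress land at INFO while rollout internals stay at DEBUG. A hedged call sketch; the PPOConfig field names follow the attributes referenced in the trainer, but the constructor shape and the values are assumptions:

# Sketch: values are illustrative, and PPOConfig may carry more fields (gamma, device, ...) than shown.
config = PPOConfig(total_timesteps=50_000, rollout_steps=512, epochs=4, minibatch_size=64, seed=7)
adapter = DecisionEnvAdapter(env)          # env: a DecisionEnv built elsewhere; constructor shape assumed
summary = train_ppo(adapter, config)       # INFO: trainer init, per-episode rewards, final summary
print(summary.timesteps, len(summary.episode_rewards))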

View File

@ -5,6 +5,10 @@ from dataclasses import dataclass
from typing import Dict, Iterable, Mapping, Optional, Sequence
from .data_access import DataBroker
from .logging import get_logger
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "feature_snapshot"}
@dataclass
@ -15,6 +19,11 @@ class FeatureSnapshotService:
def __init__(self, broker: Optional[DataBroker] = None) -> None:
self.broker = broker or DataBroker()
LOGGER.debug(
"初始化特征快照服务 broker=%s",
type(self.broker).__name__,
extra=LOG_EXTRA,
)
def load_latest(
self,
@ -27,13 +36,34 @@ class FeatureSnapshotService:
"""Fetch a snapshot of feature values for the given universe."""
if not ts_codes:
LOGGER.debug(
"跳过快照加载(标的为空) trade_date=%s",
trade_date,
extra=LOG_EXTRA,
)
return {}
return self.broker.fetch_batch_latest(
field_count = len(fields)
LOGGER.debug(
"加载特征快照 trade_date=%s universe=%s fields=%s auto_refresh=%s",
trade_date,
len(ts_codes),
field_count,
auto_refresh,
extra=LOG_EXTRA,
)
snapshot = self.broker.fetch_batch_latest(
list(ts_codes),
trade_date,
fields,
auto_refresh=auto_refresh,
)
LOGGER.debug(
"特征快照加载完成 trade_date=%s universe=%s",
trade_date,
len(snapshot),
extra=LOG_EXTRA,
)
return snapshot
def load_single(
self,
@ -45,14 +75,30 @@ class FeatureSnapshotService:
) -> Mapping[str, object]:
"""Convenience wrapper to reuse the snapshot logic for a single symbol."""
field_list = list(fields)
LOGGER.debug(
"加载单标的快照 trade_date=%s ts_code=%s fields=%s auto_refresh=%s",
trade_date,
ts_code,
len(field_list),
auto_refresh,
extra=LOG_EXTRA,
)
snapshot = self.load_latest(
trade_date,
list(fields),
field_list,
[ts_code],
auto_refresh=auto_refresh,
)
return snapshot.get(ts_code, {})
result = snapshot.get(ts_code, {})
if not result:
LOGGER.debug(
"单标的快照为空 trade_date=%s ts_code=%s",
trade_date,
ts_code,
extra=LOG_EXTRA,
)
return result
__all__ = ["FeatureSnapshotService"]
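Typical use of the snapshot service, matching the new debug lines; the trade date, field names, and codes below are placeholders, and the argument order is inferred from the method bodies above:

# Illustrative values only.
service = FeatureSnapshotService()
fields = ["close", "turnover_rate"]
batch = service.load_latest("2025-10-17", fields, ["000001.SZ", "600000.SH"])   # batch snapshot
single = service.load_single("2025-10-17", "000001.SZ", fields)                 # DEBUG log if the result is empty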