add comprehensive logging to decision workflow and risk assessment
commit 4b68d84b3c (parent 1ca2f2be19)
@@ -4,8 +4,13 @@ from __future__ import annotations
 from dataclasses import dataclass
 from typing import Dict, Iterable, List, Optional
 
+from app.utils.logging import get_logger
+
 from .base import AgentAction
 
+LOGGER = get_logger(__name__)
+LOG_EXTRA = {"stage": "decision_belief"}
+
 
 @dataclass
 class BeliefRevisionResult:
@@ -46,12 +51,21 @@ def revise_beliefs(belief_updates: Dict[str, "BeliefUpdate"], default_action: Ag
         "votes": {action.value: count for action, count in action_votes.items()},
         "reasons": reasons,
     }
-    return BeliefRevisionResult(
+    result = BeliefRevisionResult(
         consensus_action=consensus_action,
         consensus_confidence=consensus_confidence,
         conflicts=conflicts,
         notes=notes,
    )
+    LOGGER.debug(
+        "belief revision complete consensus=%s confidence=%.3f conflicts=%s vote_counts=%s",
+        result.consensus_action.value if result.consensus_action else None,
+        result.consensus_confidence,
+        result.conflicts,
+        notes["votes"],
+        extra=LOG_EXTRA,
+    )
+    return result
 
 
 # avoid circular import typing
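Note: every new call site passes extra=LOG_EXTRA, which only pays off if the formatter installed by app.utils.logging knows about a "stage" field. A minimal sketch of that contract, assuming get_logger wraps the stdlib logging module (the helper itself is not part of this diff):

    import logging

    def get_logger(name: str) -> logging.Logger:
        # Hypothetical stand-in for app.utils.logging.get_logger.
        logger = logging.getLogger(name)
        if not logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(
                logging.Formatter("%(levelname)s [%(stage)s] %(name)s: %(message)s")
            )
            logger.addHandler(handler)
            logger.setLevel(logging.DEBUG)
        return logger

    LOGGER = get_logger("decision_belief_demo")
    LOG_EXTRA = {"stage": "decision_belief"}

    LOGGER.debug(
        "belief revision complete consensus=%s confidence=%.3f",
        "hold",
        0.62,
        extra=LOG_EXTRA,
    )

With a %(stage)s-style formatter, a record logged without the extra dict fails at format time, which is one reason the diff threads LOG_EXTRA through every call.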
@@ -5,6 +5,8 @@ from dataclasses import dataclass, field
 from math import log
 from typing import Dict, Iterable, List, Mapping, Optional, Tuple
 
+from app.utils.logging import get_logger
+
 from .base import Agent, AgentAction, AgentContext, UtilityMatrix
 from .departments import DepartmentContext, DepartmentDecision, DepartmentManager
 from .registry import weight_map
@@ -20,6 +22,10 @@ from .protocols import (
 )
 
 
+LOGGER = get_logger(__name__)
+LOG_EXTRA = {"stage": "decision_workflow"}
+
+
 ACTIONS: Tuple[AgentAction, ...] = (
     AgentAction.SELL,
     AgentAction.HOLD,
@@ -188,9 +194,30 @@ class DecisionWorkflow:
         self.norm_weights: Dict[str, float] = {}
         self.filtered_utilities: Dict[AgentAction, Dict[str, float]] = {}
         self.belief_revision: Optional[BeliefRevisionResult] = None
+        LOGGER.debug(
+            "initializing decision workflow ts_code=%s trade_date=%s method=%s agents=%s departments=%s",
+            context.ts_code,
+            context.trade_date,
+            method,
+            len(self.agent_list),
+            bool(self.department_manager),
+            extra=LOG_EXTRA,
+        )
 
     def run(self) -> Decision:
+        LOGGER.debug(
+            "running decision workflow ts_code=%s method=%s feasible=%s",
+            self.context.ts_code,
+            self.method,
+            [action.value for action in self.feasible_actions],
+            extra=LOG_EXTRA,
+        )
         if not self.feasible_actions:
+            LOGGER.warning(
+                "no feasible actions, falling back to HOLD ts_code=%s",
+                self.context.ts_code,
+                extra=LOG_EXTRA,
+            )
             return Decision(
                 action=AgentAction.HOLD,
                 confidence=0.0,
@@ -201,6 +228,13 @@ class DecisionWorkflow:
 
         self._evaluate_departments()
         action, confidence = self._select_action()
+        LOGGER.debug(
+            "preliminary action selected ts_code=%s action=%s confidence=%.3f",
+            self.context.ts_code,
+            action.value,
+            confidence,
+            extra=LOG_EXTRA,
+        )
         risk_assessment = self._apply_risk(action)
         exec_action = self._finalize_execution(action, risk_assessment)
         self._finalize_conflicts(exec_action)
@@ -210,7 +244,7 @@ class DecisionWorkflow:
             self.department_votes,
         )
 
-        return Decision(
+        decision = Decision(
             action=action,
             confidence=confidence,
             target_weight=target_weight_for_action(action),
@@ -224,6 +258,16 @@ class DecisionWorkflow:
             belief_updates=self.belief_updates,
             belief_revision=self.belief_revision,
         )
+        LOGGER.info(
+            "decision complete ts_code=%s action=%s confidence=%.3f review=%s risk_status=%s",
+            self.context.ts_code,
+            decision.action.value,
+            decision.confidence,
+            decision.requires_review,
+            risk_assessment.status,
+            extra=LOG_EXTRA,
+        )
+        return decision
 
     def _evaluate_departments(self) -> None:
         if not self.department_manager:
@@ -236,7 +280,19 @@ class DecisionWorkflow:
             market_snapshot=dict(getattr(self.context, "market_snapshot", {}) or {}),
             raw=dict(getattr(self.context, "raw", {}) or {}),
         )
+        LOGGER.debug(
+            "starting department evaluation ts_code=%s departments=%s",
+            self.context.ts_code,
+            list(self.department_manager.agents.keys()),
+            extra=LOG_EXTRA,
+        )
         self.department_decisions = self.department_manager.evaluate(dept_context)
+        LOGGER.debug(
+            "department evaluation complete ts_code=%s decisions=%s",
+            self.context.ts_code,
+            list(self.department_decisions.keys()),
+            extra=LOG_EXTRA,
+        )
         if self.department_decisions:
             self.department_round = self.host.start_round(
                 self.host_trace,
@@ -285,11 +341,32 @@ class DecisionWorkflow:
         )
 
         if self.method == "vote":
-            return vote(self.filtered_utilities, self.norm_weights)
+            action, confidence = vote(self.filtered_utilities, self.norm_weights)
+            LOGGER.debug(
+                "using vote mechanism ts_code=%s action=%s confidence=%.3f",
+                self.context.ts_code,
+                action.value,
+                confidence,
+                extra=LOG_EXTRA,
+            )
+            return action, confidence
 
         action, confidence = nash_bargain(self.filtered_utilities, self.norm_weights, hold_scores)
         if action not in self.feasible_actions:
-            return vote(self.filtered_utilities, self.norm_weights)
+            LOGGER.debug(
+                "Nash solution infeasible, falling back to vote ts_code=%s invalid_action=%s",
+                self.context.ts_code,
+                action.value,
+                extra=LOG_EXTRA,
+            )
+            action, confidence = vote(self.filtered_utilities, self.norm_weights)
+        LOGGER.debug(
+            "Nash bargaining solution computed ts_code=%s action=%s confidence=%.3f",
+            self.context.ts_code,
+            action.value,
+            confidence,
+            extra=LOG_EXTRA,
+        )
         return action, confidence
 
     def _apply_risk(self, action: AgentAction) -> RiskAssessment:
@@ -307,6 +384,17 @@ class DecisionWorkflow:
             self.department_round.outcome = action.value
             self.host.finalize_round(self.department_round)
 
+        LOGGER.debug(
+            "risk assessment result ts_code=%s action=%s status=%s reason=%s conflict=%s votes=%s",
+            self.context.ts_code,
+            action.value,
+            assessment.status,
+            assessment.reason,
+            conflict_flag,
+            dict(self.department_votes),
+            extra=LOG_EXTRA,
+        )
+
         if assessment.status != "ok":
             self.risk_round = self.host.ensure_round(
                 self.host_trace,
@@ -398,12 +486,28 @@ class DecisionWorkflow:
         )
         self.host.finalize_round(self.execution_round)
         self.execution_round.notes.setdefault("target_weight", exec_weight)
+        LOGGER.info(
+            "execution stage conclusion ts_code=%s final_action=%s original=%s status=%s target_weight=%.4f review=%s",
+            self.context.ts_code,
+            exec_action.value,
+            action.value,
+            exec_status,
+            exec_weight,
+            requires_review,
+            extra=LOG_EXTRA,
+        )
         return exec_action
 
     def _finalize_conflicts(self, exec_action: AgentAction) -> None:
         self.host.close(self.host_trace)
         self.belief_revision = revise_beliefs(self.belief_updates, exec_action)
         if self.belief_revision.conflicts:
+            LOGGER.warning(
+                "belief conflicts detected ts_code=%s conflicts=%s",
+                self.context.ts_code,
+                self.belief_revision.conflicts,
+                extra=LOG_EXTRA,
+            )
             conflict_round = self.host.ensure_round(
                 self.host_trace,
                 agenda="conflict_resolution",
@@ -436,15 +540,35 @@ def decide(
     department_manager: Optional[DepartmentManager] = None,
     department_context: Optional[DepartmentContext] = None,
 ) -> Decision:
+    agent_list = list(agents)
+    LOGGER.debug(
+        "entering multi-agent decision ts_code=%s trade_date=%s agents=%s method=%s",
+        context.ts_code,
+        context.trade_date,
+        len(agent_list),
+        method,
+        extra=LOG_EXTRA,
+    )
     workflow = DecisionWorkflow(
         context,
-        agents,
+        agent_list,
         weights,
         method,
         department_manager,
         department_context,
     )
-    return workflow.run()
+    decision = workflow.run()
+    LOGGER.info(
+        "multi-agent decision complete ts_code=%s trade_date=%s action=%s confidence=%.3f review=%s method=%s",
+        context.ts_code,
+        context.trade_date,
+        decision.action.value,
+        decision.confidence,
+        decision.requires_review,
+        method,
+        extra=LOG_EXTRA,
+    )
+    return decision
 
 
 def _department_scores(decision: DepartmentDecision) -> Dict[AgentAction, float]:
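Note: the recurring refactor above (return X( becomes name = X(, log, return name) exists so the final object can appear in the log line. A small self-contained illustration of the pattern and how it can be asserted in tests; decide_demo and its vote dict are hypothetical stand-ins, not code from this repository:

    import logging
    import unittest
    from typing import Dict

    LOGGER = logging.getLogger("decision_workflow_demo")

    def decide_demo(votes: Dict[str, int]) -> str:
        # Capture the result first so the log line can reference it,
        # then return it; same shape as DecisionWorkflow.run().
        decision = max(votes, key=votes.get)
        LOGGER.info(
            "decision complete action=%s votes=%s",
            decision,
            votes,
            extra={"stage": "decision_workflow"},
        )
        return decision

    class DecideDemoTest(unittest.TestCase):
        def test_logs_final_action(self) -> None:
            with self.assertLogs(LOGGER, level="INFO") as captured:
                self.assertEqual(decide_demo({"buy": 2, "hold": 1}), "buy")
            self.assertIn("action=buy", captured.output[0])

    if __name__ == "__main__":
        unittest.main()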
@@ -1,8 +1,13 @@
 """Risk agent acts as leader with veto rights."""
 from __future__ import annotations
 
+from app.utils.logging import get_logger
+
 from .base import Agent, AgentAction, AgentContext
 
+LOGGER = get_logger(__name__)
+LOG_EXTRA = {"stage": "decision_risk"}
+
 
 class RiskRecommendation:
     """Represents structured recommendation from the risk agent."""
@@ -68,21 +73,44 @@ class RiskAgent(Agent):
         features = dict(context.features or {})
         risk_penalty = float(features.get("risk_penalty") or 0.0)
 
+        def finalize(
+            recommendation: RiskRecommendation,
+            trigger: str,
+        ) -> RiskRecommendation:
+            LOGGER.debug(
+                "risk agent evaluation ts_code=%s action=%s status=%s reason=%s trigger=%s penalty=%.3f conflict=%s",
+                context.ts_code,
+                decision_action.value,
+                recommendation.status,
+                recommendation.reason,
+                trigger,
+                risk_penalty,
+                conflict_flag,
+                extra=LOG_EXTRA,
+            )
+            return recommendation
+
         if bool(features.get("is_suspended")):
-            return RiskRecommendation(
-                status="blocked",
-                reason="suspended",
-                recommended_action=AgentAction.HOLD,
-                notes={"trigger": "is_suspended"},
+            return finalize(
+                RiskRecommendation(
+                    status="blocked",
+                    reason="suspended",
+                    recommended_action=AgentAction.HOLD,
+                    notes={"trigger": "is_suspended"},
+                ),
+                "is_suspended",
             )
 
         if bool(features.get("is_blacklisted")):
             fallback = AgentAction.SELL if decision_action is AgentAction.SELL else AgentAction.HOLD
-            return RiskRecommendation(
-                status="blocked",
-                reason="blacklist",
-                recommended_action=fallback,
-                notes={"trigger": "is_blacklisted"},
+            return finalize(
+                RiskRecommendation(
+                    status="blocked",
+                    reason="blacklist",
+                    recommended_action=fallback,
+                    notes={"trigger": "is_blacklisted"},
+                ),
+                "is_blacklisted",
             )
 
         if bool(features.get("limit_up")) and decision_action in {
@@ -90,22 +118,28 @@ class RiskAgent(Agent):
             AgentAction.BUY_M,
             AgentAction.BUY_L,
         }:
-            return RiskRecommendation(
-                status="blocked",
-                reason="limit_up",
-                recommended_action=AgentAction.HOLD,
-                notes={"trigger": "limit_up"},
+            return finalize(
+                RiskRecommendation(
+                    status="blocked",
+                    reason="limit_up",
+                    recommended_action=AgentAction.HOLD,
+                    notes={"trigger": "limit_up"},
+                ),
+                "limit_up",
             )
 
         if bool(features.get("position_limit")) and decision_action in {
             AgentAction.BUY_M,
             AgentAction.BUY_L,
         }:
-            return RiskRecommendation(
-                status="pending_review",
-                reason="position_limit",
-                recommended_action=AgentAction.BUY_S,
-                notes={"trigger": "position_limit"},
+            return finalize(
+                RiskRecommendation(
+                    status="pending_review",
+                    reason="position_limit",
+                    recommended_action=AgentAction.BUY_S,
+                    notes={"trigger": "position_limit"},
+                ),
+                "position_limit",
             )
 
         if risk_penalty >= 0.9 and decision_action in {
@@ -113,28 +147,40 @@ class RiskAgent(Agent):
             AgentAction.BUY_M,
             AgentAction.BUY_L,
         }:
-            return RiskRecommendation(
-                status="blocked",
-                reason="risk_penalty_extreme",
-                recommended_action=AgentAction.HOLD,
-                notes={"risk_penalty": risk_penalty},
+            return finalize(
+                RiskRecommendation(
+                    status="blocked",
+                    reason="risk_penalty_extreme",
+                    recommended_action=AgentAction.HOLD,
+                    notes={"risk_penalty": risk_penalty},
+                ),
+                "risk_penalty_extreme",
             )
         if risk_penalty >= 0.7 and decision_action in {
             AgentAction.BUY_S,
             AgentAction.BUY_M,
             AgentAction.BUY_L,
         }:
-            return RiskRecommendation(
-                status="pending_review",
-                reason="risk_penalty_high",
-                recommended_action=AgentAction.HOLD,
-                notes={"risk_penalty": risk_penalty},
+            return finalize(
+                RiskRecommendation(
+                    status="pending_review",
+                    reason="risk_penalty_high",
+                    recommended_action=AgentAction.HOLD,
+                    notes={"risk_penalty": risk_penalty},
+                ),
+                "risk_penalty_high",
             )
 
         if conflict_flag:
-            return RiskRecommendation(
-                status="pending_review",
-                reason="conflict_threshold",
+            return finalize(
+                RiskRecommendation(
+                    status="pending_review",
+                    reason="conflict_threshold",
+                ),
+                "conflict_threshold",
             )
 
-        return RiskRecommendation(status="ok", reason="clear")
+        return finalize(
+            RiskRecommendation(status="ok", reason="clear"),
+            "clear",
+        )
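Note: the risk agent change is structural rather than additive. Every early return is rerouted through a nested finalize() closure so the debug line fires exactly once on whichever branch wins. A condensed sketch of the same shape, with simplified stand-in types:

    import logging
    from dataclasses import dataclass, field
    from typing import Dict

    LOGGER = logging.getLogger("risk_demo")

    @dataclass
    class Recommendation:
        status: str
        reason: str
        notes: Dict[str, object] = field(default_factory=dict)

    def evaluate(features: Dict[str, object]) -> Recommendation:
        def finalize(rec: Recommendation, trigger: str) -> Recommendation:
            # Single choke point: every return path passes through here.
            LOGGER.debug(
                "risk evaluation status=%s reason=%s trigger=%s",
                rec.status,
                rec.reason,
                trigger,
            )
            return rec

        if features.get("is_suspended"):
            return finalize(Recommendation("blocked", "suspended"), "is_suspended")
        return finalize(Recommendation("ok", "clear"), "clear")

The closure reads context, decision_action, risk_penalty, and conflict_flag from the enclosing scope, so none of the call sites has to repeat them.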
@@ -2,14 +2,17 @@
 from __future__ import annotations
 
 import json
-from datetime import datetime
 from typing import Any, Dict, Optional
 
 from app.utils.db import db_session
+from app.utils.logging import get_logger
+
+LOGGER = get_logger(__name__)
+LOG_EXTRA = {"stage": "data_ingest"}
 
 
 class JobLogger:
-    """Job logger."""
+    """Job logger that records fetch-job runs in the database."""
 
     def __init__(self, job_type: str) -> None:
         """Initialize the job logger.
@@ -28,17 +31,36 @@ class JobLogger:
                 INSERT INTO fetch_jobs (job_type, status, created_at, updated_at)
                 VALUES (?, 'running', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
                 """,
-                (self.job_type,)
+                (self.job_type,),
             )
             self.job_id = cursor.lastrowid
             session.commit()
+        LOGGER.info(
+            "fetch job started job_type=%s job_id=%s",
+            self.job_type,
+            self.job_id,
+            extra=LOG_EXTRA,
+        )
         return self
 
     def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
         """Finish the job record."""
         if exc_val:
+            LOGGER.exception(
+                "fetch job failed job_type=%s job_id=%s err=%s",
+                self.job_type,
+                self.job_id,
+                exc_val,
+                extra=LOG_EXTRA,
+            )
             self.update_status("failed", str(exc_val))
         else:
+            LOGGER.info(
+                "fetch job finished job_type=%s job_id=%s",
+                self.job_type,
+                self.job_id,
+                extra=LOG_EXTRA,
+            )
             self.update_status("success")
 
     def update_status(self, status: str, error_msg: Optional[str] = None) -> None:
@@ -49,6 +71,7 @@ class JobLogger:
             error_msg: Error message, if any
         """
         if not self.job_id:
+            LOGGER.debug("ignoring status update for uninitialized job job_type=%s status=%s", self.job_type, status, extra=LOG_EXTRA)
             return
 
         with db_session() as session:
@@ -60,9 +83,17 @@ class JobLogger:
                     updated_at = CURRENT_TIMESTAMP
                 WHERE id = ?
                 """,
-                (status, error_msg, self.job_id)
+                (status, error_msg, self.job_id),
             )
             session.commit()
+        LOGGER.debug(
+            "job status updated job_type=%s job_id=%s status=%s error=%s",
+            self.job_type,
+            self.job_id,
+            status,
+            error_msg,
+            extra=LOG_EXTRA,
+        )
 
     def update_metadata(self, metadata: Dict[str, Any]) -> None:
         """Update job metadata.
@@ -71,6 +102,11 @@ class JobLogger:
             metadata: Metadata dictionary
         """
         if not self.job_id:
+            LOGGER.debug(
+                "ignoring metadata update (job not yet initialized) job_type=%s",
+                self.job_type,
+                extra=LOG_EXTRA,
+            )
             return
 
         with db_session() as session:
@@ -80,6 +116,13 @@ class JobLogger:
                 SET metadata = ?
                 WHERE id = ?
                 """,
-                (json.dumps(metadata), self.job_id)
+                (json.dumps(metadata), self.job_id),
             )
             session.commit()
+        LOGGER.debug(
+            "job metadata recorded job_type=%s job_id=%s keys=%s",
+            self.job_type,
+            self.job_id,
+            sorted(metadata.keys()),
+            extra=LOG_EXTRA,
+        )
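Note: for reference, the intended call shape of JobLogger as a context manager. The job_type and metadata values below are illustrative, and the import path is assumed rather than shown in this diff:

    # Hypothetical usage; assumes JobLogger is importable from its module.
    with JobLogger("daily_bars") as job:
        # __enter__ inserted a 'running' row and logged "fetch job started"
        job.update_metadata({"trade_date": "20240102", "rows": 4200})
    # clean exit: logs "fetch job finished" and commits status 'success';
    # an exception inside the block triggers LOGGER.exception and
    # update_status("failed", str(exc_val)), then propagates, since
    # __exit__ returns None rather than True.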
@@ -7,6 +7,10 @@ from typing import Dict, Iterable, List, Mapping, Sequence, Tuple
 import numpy as np
 
 from app.backtest.decision_env import DecisionEnv
+from app.utils.logging import get_logger
+
+LOGGER = get_logger(__name__)
+LOG_EXTRA = {"stage": "decision_env"}
 
 
 @dataclass
@@ -26,6 +30,13 @@ class DecisionEnvAdapter:
         else:
             self._keys = list(self.observation_keys)
         self._last_reset_obs = None
+        LOGGER.debug(
+            "initialized DecisionEnvAdapter obs_dim=%s action_dim=%s keys=%s",
+            len(self._keys),
+            self.env.action_dim,
+            self._keys,
+            extra=LOG_EXTRA,
+        )
 
     @property
     def action_dim(self) -> int:
@@ -38,12 +49,24 @@ class DecisionEnvAdapter:
     def reset(self) -> Tuple[np.ndarray, Dict[str, float]]:
         raw = self.env.reset()
         self._last_reset_obs = raw
+        LOGGER.debug(
+            "environment reset complete episode=%s",
+            raw.get("episode"),
+            extra=LOG_EXTRA,
+        )
         return self._to_array(raw), raw
 
     def step(
         self, action: Sequence[float]
     ) -> Tuple[np.ndarray, float, bool, Mapping[str, object], Mapping[str, float]]:
         obs_dict, reward, done, info = self.env.step(action)
+        LOGGER.debug(
+            "environment step executed action=%s reward=%.4f done=%s",
+            [round(float(a), 4) for a in action],
+            reward,
+            done,
+            extra=LOG_EXTRA,
+        )
         return self._to_array(obs_dict), reward, done, info, obs_dict
 
     def _to_array(self, payload: Mapping[str, float]) -> np.ndarray:
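Note: a quick way to see the new adapter logs without the real backtest environment is a stub that mimics the DecisionEnv surface the adapter touches: reset() returning a mapping, step() returning (obs, reward, done, info), and an action_dim attribute. The constructor call is commented out because DecisionEnvAdapter's exact field list is not shown in this hunk:

    import logging

    logging.basicConfig(level=logging.DEBUG)

    class StubEnv:
        """Test double: only the members the adapter reads."""
        action_dim = 1

        def reset(self):
            return {"episode": 0, "momentum": 0.1}

        def step(self, action):
            return {"episode": 0, "momentum": 0.2}, 0.05, True, {}

    # adapter = DecisionEnvAdapter(env=StubEnv(), observation_keys=["momentum"])
    # adapter.reset()      # -> "environment reset complete episode=0"
    # adapter.step([0.3])  # -> "environment step executed action=[0.3] reward=0.0500 done=True"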
@@ -10,8 +10,13 @@ import torch
 from torch import nn
 from torch.distributions import Beta
 
+from app.utils.logging import get_logger
+
 from .adapters import DecisionEnvAdapter
 
+LOGGER = get_logger(__name__)
+LOG_EXTRA = {"stage": "rl_ppo"}
+
 
 def _init_layer(layer: nn.Module, std: float = 1.0) -> nn.Module:
     if isinstance(layer, nn.Linear):
@@ -168,6 +173,15 @@ class PPOTrainer:
         if config.seed is not None:
             torch.manual_seed(config.seed)
             np.random.seed(config.seed)
+        LOGGER.info(
+            "initialized PPOTrainer obs_dim=%s action_dim=%s total_timesteps=%s rollout=%s device=%s",
+            obs_dim,
+            action_dim,
+            config.total_timesteps,
+            config.rollout_steps,
+            config.device,
+            extra=LOG_EXTRA,
+        )
 
     def train(self) -> TrainingSummary:
         cfg = self.config
@@ -180,6 +194,14 @@ class PPOTrainer:
         diagnostics: List[Dict[str, float]] = []
         current_return = 0.0
         current_length = 0
+        LOGGER.info(
+            "starting PPO training total_timesteps=%s rollout_steps=%s epochs=%s minibatch=%s",
+            cfg.total_timesteps,
+            cfg.rollout_steps,
+            cfg.epochs,
+            cfg.minibatch_size,
+            extra=LOG_EXTRA,
+        )
 
         while timesteps < cfg.total_timesteps:
             rollout.reset()
@@ -203,6 +225,14 @@ class PPOTrainer:
                 if done:
                     episode_rewards.append(current_return)
                     episode_lengths.append(current_length)
+                    LOGGER.info(
+                        "episode finished reward=%.4f length=%s episodes=%s timesteps=%s",
+                        episode_rewards[-1],
+                        episode_lengths[-1],
+                        len(episode_rewards),
+                        timesteps,
+                        extra=LOG_EXTRA,
+                    )
                     current_return = 0.0
                     current_length = 0
                     next_obs_array, _ = self.adapter.reset()
@@ -216,7 +246,17 @@ class PPOTrainer:
             with torch.no_grad():
                 next_value = self.critic(obs.unsqueeze(0)).squeeze(0).item()
             rollout.finish(last_value=next_value, gamma=cfg.gamma, gae_lambda=cfg.gae_lambda)
+            LOGGER.debug(
+                "rollout collection finished batch_size=%s timesteps=%s remaining=%s",
+                rollout._pos,
+                timesteps,
+                cfg.total_timesteps - timesteps,
+                extra=LOG_EXTRA,
+            )
 
+            last_policy_loss = None
+            last_value_loss = None
+            last_entropy = None
             for _ in range(cfg.epochs):
                 for (mb_obs, mb_actions, mb_log_probs, mb_adv, mb_returns, _) in rollout.get_minibatches(
                     cfg.minibatch_size
@@ -241,6 +281,9 @@ class PPOTrainer:
                     value_loss.backward()
                     nn.utils.clip_grad_norm_(self.critic.parameters(), cfg.max_grad_norm)
                     self.value_optimizer.step()
+                    last_policy_loss = float(policy_loss.detach().cpu())
+                    last_value_loss = float(value_loss.detach().cpu())
+                    last_entropy = float(entropy.mean().detach().cpu())
 
             diagnostics.append(
                 {
@@ -249,13 +292,30 @@ class PPOTrainer:
                     "entropy": float(entropy.mean().detach().cpu()),
                 }
             )
+            LOGGER.info(
+                "optimization round finished timesteps=%s/%s policy_loss=%.4f value_loss=%.4f entropy=%.4f",
+                timesteps,
+                cfg.total_timesteps,
+                last_policy_loss if last_policy_loss is not None else 0.0,
+                last_value_loss if last_value_loss is not None else 0.0,
+                last_entropy if last_entropy is not None else 0.0,
+                extra=LOG_EXTRA,
+            )
 
-        return TrainingSummary(
+        summary = TrainingSummary(
             timesteps=timesteps,
             episode_rewards=episode_rewards,
             episode_lengths=episode_lengths,
             diagnostics=diagnostics,
         )
+        LOGGER.info(
+            "PPO training finished timesteps=%s episodes=%s mean_reward=%.4f",
+            summary.timesteps,
+            len(summary.episode_rewards),
+            float(np.mean(summary.episode_rewards)) if summary.episode_rewards else 0.0,
+            extra=LOG_EXTRA,
+        )
+        return summary
 
 
 def train_ppo(adapter: DecisionEnvAdapter, config: PPOConfig) -> TrainingSummary:
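Note: one detail worth calling out in the summary log above: np.mean([]) returns nan and emits a RuntimeWarning, so the mean_reward argument is guarded with an explicit fallback. A minimal check of that guard:

    import numpy as np

    def mean_reward(episode_rewards: list) -> float:
        # Mirrors the guard in the "PPO training finished" log call.
        return float(np.mean(episode_rewards)) if episode_rewards else 0.0

    assert mean_reward([]) == 0.0
    assert abs(mean_reward([0.1, 0.3]) - 0.2) < 1e-9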
@@ -5,6 +5,10 @@ from dataclasses import dataclass
 from typing import Dict, Iterable, Mapping, Optional, Sequence
 
 from .data_access import DataBroker
+from .logging import get_logger
+
+LOGGER = get_logger(__name__)
+LOG_EXTRA = {"stage": "feature_snapshot"}
 
 
 @dataclass
@@ -15,6 +19,11 @@ class FeatureSnapshotService:
 
     def __init__(self, broker: Optional[DataBroker] = None) -> None:
         self.broker = broker or DataBroker()
+        LOGGER.debug(
+            "initialized feature snapshot service broker=%s",
+            type(self.broker).__name__,
+            extra=LOG_EXTRA,
+        )
 
     def load_latest(
         self,
@@ -27,13 +36,34 @@ class FeatureSnapshotService:
         """Fetch a snapshot of feature values for the given universe."""
 
         if not ts_codes:
+            LOGGER.debug(
+                "skipping snapshot load (empty universe) trade_date=%s",
+                trade_date,
+                extra=LOG_EXTRA,
+            )
             return {}
-        return self.broker.fetch_batch_latest(
+        field_count = len(fields)
+        LOGGER.debug(
+            "loading feature snapshot trade_date=%s universe=%s fields=%s auto_refresh=%s",
+            trade_date,
+            len(ts_codes),
+            field_count,
+            auto_refresh,
+            extra=LOG_EXTRA,
+        )
+        snapshot = self.broker.fetch_batch_latest(
             list(ts_codes),
             trade_date,
             fields,
             auto_refresh=auto_refresh,
         )
+        LOGGER.debug(
+            "feature snapshot loaded trade_date=%s universe=%s",
+            trade_date,
+            len(snapshot),
+            extra=LOG_EXTRA,
+        )
+        return snapshot
 
     def load_single(
         self,
@@ -45,14 +75,30 @@ class FeatureSnapshotService:
     ) -> Mapping[str, object]:
         """Convenience wrapper to reuse the snapshot logic for a single symbol."""
 
+        field_list = list(fields)
+        LOGGER.debug(
+            "loading single-symbol snapshot trade_date=%s ts_code=%s fields=%s auto_refresh=%s",
+            trade_date,
+            ts_code,
+            len(field_list),
+            auto_refresh,
+            extra=LOG_EXTRA,
+        )
         snapshot = self.load_latest(
             trade_date,
-            list(fields),
+            field_list,
             [ts_code],
             auto_refresh=auto_refresh,
        )
-        return snapshot.get(ts_code, {})
+        result = snapshot.get(ts_code, {})
+        if not result:
+            LOGGER.debug(
+                "single-symbol snapshot empty trade_date=%s ts_code=%s",
+                trade_date,
+                ts_code,
+                extra=LOG_EXTRA,
+            )
+        return result
 
 
 __all__ = ["FeatureSnapshotService"]
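Note: the expected call shape for the snapshot service, inferred from load_single's delegation to load_latest(trade_date, fields, [ts_code], ...); the date, field names, and ts_code below are illustrative, not taken from the repository:

    service = FeatureSnapshotService()
    row = service.load_single(
        "20240102",
        ["close", "turnover_rate"],
        "600000.SH",
        auto_refresh=False,
    )
    # row == {} when nothing is stored for that symbol; the new
    # "single-symbol snapshot empty" debug line records exactly that case.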