This commit is contained in:
sam 2025-09-29 07:50:06 +08:00
parent de88b198b3
commit 2afdaa76ed
4 changed files with 196 additions and 101 deletions

View File

@ -9,11 +9,12 @@ from app.agents.base import AgentAction
from app.llm.client import call_endpoint_with_messages, run_llm_with_config, LLMError from app.llm.client import call_endpoint_with_messages, run_llm_with_config, LLMError
from app.llm.prompts import department_prompt from app.llm.prompts import department_prompt
from app.utils.config import AppConfig, DepartmentSettings, LLMConfig from app.utils.config import AppConfig, DepartmentSettings, LLMConfig
from app.utils.logging import get_logger from app.utils.logging import get_logger, get_conversation_logger
from app.utils.data_access import DataBroker from app.utils.data_access import DataBroker
LOGGER = get_logger(__name__) LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "department"} LOG_EXTRA = {"stage": "department"}
CONV_LOGGER = get_conversation_logger()
@dataclass @dataclass
@ -113,6 +114,12 @@ class DepartmentAgent:
primary_endpoint = llm_cfg.primary primary_endpoint = llm_cfg.primary
final_message: Optional[Dict[str, Any]] = None final_message: Optional[Dict[str, Any]] = None
CONV_LOGGER.info(
"dept=%s ts_code=%s trade_date=%s start",
self.settings.code,
context.ts_code,
context.trade_date,
)
for round_idx in range(self._max_rounds): for round_idx in range(self._max_rounds):
try: try:
@ -142,6 +149,12 @@ class DepartmentAgent:
if message.get("tool_calls"): if message.get("tool_calls"):
assistant_record["tool_calls"] = message.get("tool_calls") assistant_record["tool_calls"] = message.get("tool_calls")
messages.append(assistant_record) messages.append(assistant_record)
CONV_LOGGER.info(
"dept=%s round=%s assistant=%s",
self.settings.code,
round_idx + 1,
assistant_record,
)
tool_calls = message.get("tool_calls") or [] tool_calls = message.get("tool_calls") or []
if tool_calls: if tool_calls:
@ -163,6 +176,13 @@ class DepartmentAgent:
} }
) )
delivered_requests.update(delivered) delivered_requests.update(delivered)
CONV_LOGGER.info(
"dept=%s round=%s tool_call=%s response=%s",
self.settings.code,
round_idx + 1,
call,
tool_response,
)
continue continue
final_message = message final_message = message
@ -175,6 +195,11 @@ class DepartmentAgent:
extra=LOG_EXTRA, extra=LOG_EXTRA,
) )
final_message = message final_message = message
CONV_LOGGER.warning(
"dept=%s rounds_exhausted last_message=%s",
self.settings.code,
final_message,
)
mutable_context.raw["supplement_transcript"] = list(transcript) mutable_context.raw["supplement_transcript"] = list(transcript)
@ -209,6 +234,13 @@ class DepartmentAgent:
decision.confidence, decision.confidence,
extra=LOG_EXTRA, extra=LOG_EXTRA,
) )
CONV_LOGGER.info(
"dept=%s decision action=%s confidence=%.2f summary=%s",
self.settings.code,
decision.action.value,
decision.confidence,
summary or "",
)
return decision return decision
@staticmethod @staticmethod
@ -418,6 +450,11 @@ class DepartmentAgent:
exc, exc,
extra=LOG_EXTRA, extra=LOG_EXTRA,
) )
CONV_LOGGER.error(
"dept=%s legacy_call_failed err=%s",
self.settings.code,
exc,
)
return DepartmentDecision( return DepartmentDecision(
department=self.settings.code, department=self.settings.code,
action=AgentAction.HOLD, action=AgentAction.HOLD,
@ -427,6 +464,7 @@ class DepartmentAgent:
) )
context.raw["supplement_transcript"] = [response] context.raw["supplement_transcript"] = [response]
CONV_LOGGER.info("dept=%s legacy_response=%s", self.settings.code, response)
decision_data = _parse_department_response(response) decision_data = _parse_department_response(response)
action = _normalize_action(decision_data.get("action")) action = _normalize_action(decision_data.get("action"))
confidence = _clamp_float(decision_data.get("confidence"), default=0.5) confidence = _clamp_float(decision_data.get("confidence"), default=0.5)

View File

@ -251,20 +251,13 @@ class BacktestEngine:
method=self.cfg.method, method=self.cfg.method,
department_manager=self.department_manager, department_manager=self.department_manager,
) )
decisions.append(decision)
self.record_agent_state(context, decision)
if decision_callback:
try:
decision_callback(ts_code, trade_date, context, decision)
except Exception: # noqa: BLE001
LOGGER.exception("决策回调执行失败", extra=LOG_EXTRA)
try: try:
metrics_record_decision( metrics_record_decision(
ts_code=ts_code, ts_code=ts_code,
trade_date=context.trade_date, trade_date=context.trade_date,
action=decision.action.value, action=decision.action.value,
confidence=decision.confidence, confidence=decision.confidence,
summary=decision.summary, summary=_extract_summary(decision),
source="backtest", source="backtest",
departments={ departments={
code: dept.to_dict() code: dept.to_dict()
@ -273,6 +266,13 @@ class BacktestEngine:
) )
except Exception: # noqa: BLE001 except Exception: # noqa: BLE001
LOGGER.debug("记录决策指标失败", extra=LOG_EXTRA) LOGGER.debug("记录决策指标失败", extra=LOG_EXTRA)
decisions.append(decision)
self.record_agent_state(context, decision)
if decision_callback:
try:
decision_callback(ts_code, trade_date, context, decision)
except Exception: # noqa: BLE001
LOGGER.exception("决策回调执行失败", extra=LOG_EXTRA)
# TODO: translate decisions into fills, holdings, and NAV updates. # TODO: translate decisions into fills, holdings, and NAV updates.
_ = state _ = state
return decisions return decisions
@ -408,3 +408,9 @@ def run_backtest(
_ = conn _ = conn
# Implementation should persist bt_nav, bt_trades, and bt_report rows. # Implementation should persist bt_nav, bt_trades, and bt_report rows.
return result return result
def _extract_summary(decision: Decision) -> str:
for dept_decision in decision.department_decisions.values():
summary = getattr(dept_decision, "summary", "")
if summary:
return str(summary)
return ""

View File

@ -51,6 +51,60 @@ LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "ui"} LOG_EXTRA = {"stage": "ui"}
def render_global_dashboard() -> None:
"""Render a persistent sidebar with realtime LLM stats and recent decisions."""
metrics_container = st.sidebar.container()
decisions_container = st.sidebar.container()
st.session_state["dashboard_placeholders"] = (metrics_container, decisions_container)
_update_dashboard_sidebar()
def _update_dashboard_sidebar(metrics: Optional[Dict[str, object]] = None) -> None:
placeholders = st.session_state.get("dashboard_placeholders")
if not placeholders:
return
metrics_container, decisions_container = placeholders
metrics = metrics or snapshot_llm_metrics()
metrics_container.empty()
with metrics_container.container():
st.header("系统监控")
col_a, col_b, col_c = st.columns(3)
col_a.metric("LLM 调用", metrics.get("total_calls", 0))
col_b.metric("Prompt Tokens", metrics.get("total_prompt_tokens", 0))
col_c.metric("Completion Tokens", metrics.get("total_completion_tokens", 0))
provider_calls = metrics.get("provider_calls", {})
model_calls = metrics.get("model_calls", {})
if provider_calls or model_calls:
with st.expander("调用分布", expanded=False):
if provider_calls:
st.write("按 Provider")
st.json(provider_calls)
if model_calls:
st.write("按模型:")
st.json(model_calls)
decisions_container.empty()
with decisions_container.container():
st.subheader("最新决策")
decisions = metrics.get("recent_decisions") or llm_recent_decisions(10)
if decisions:
for record in reversed(decisions[-10:]):
ts_code = record.get("ts_code")
trade_date = record.get("trade_date")
action = record.get("action")
confidence = record.get("confidence", 0.0)
summary = record.get("summary")
st.markdown(
f"**{trade_date} {ts_code}** → {action} (置信度 {confidence:.2f})"
)
if summary:
st.caption(summary)
else:
st.caption("暂无决策记录。执行回测或实时评估后可在此查看。")
def _discover_provider_models(provider: LLMProvider, base_override: str = "", api_override: Optional[str] = None) -> tuple[list[str], Optional[str]]: def _discover_provider_models(provider: LLMProvider, base_override: str = "", api_override: Optional[str] = None) -> tuple[list[str], Optional[str]]:
"""Attempt to query provider API and return available model ids.""" """Attempt to query provider API and return available model ids."""
@ -159,45 +213,7 @@ def _load_daily_frame(ts_code: str, start: date, end: date) -> pd.DataFrame:
def render_today_plan() -> None: def render_today_plan() -> None:
LOGGER.info("渲染今日计划页面", extra=LOG_EXTRA) LOGGER.info("渲染今日计划页面", extra=LOG_EXTRA)
st.header("今日计划") st.header("今日计划")
st.caption("统计数据基于最近一次渲染,刷新页面即可获取最新结果。") st.caption("统计与决策概览现已移至左侧“系统监控”侧栏。")
metrics_state = snapshot_llm_metrics()
st.subheader("LLM 调用统计 (实时)")
stats_col1, stats_col2, stats_col3 = st.columns(3)
stats_col1.metric("总调用次数", metrics_state.get("total_calls", 0))
stats_col2.metric("Prompt Tokens", metrics_state.get("total_prompt_tokens", 0))
stats_col3.metric("Completion Tokens", metrics_state.get("total_completion_tokens", 0))
provider_calls = metrics_state.get("provider_calls", {})
model_calls = metrics_state.get("model_calls", {})
if provider_calls or model_calls:
with st.expander("调用明细", expanded=False):
if provider_calls:
st.write("按 Provider")
st.json(provider_calls)
if model_calls:
st.write("按模型:")
st.json(model_calls)
st.subheader("最近决策 (全局)")
decision_feed = metrics_state.get("recent_decisions", []) or llm_recent_decisions(20)
if decision_feed:
for record in reversed(decision_feed[-20:]):
ts_code = record.get("ts_code")
trade_date = record.get("trade_date")
action = record.get("action")
confidence = record.get("confidence")
summary = record.get("summary")
departments = record.get("departments", {})
st.markdown(
f"**{trade_date} {ts_code}** → {action} (信心 {confidence:.2f})"
)
if summary:
st.caption(f"摘要:{summary}")
if departments:
st.json(departments)
st.divider()
else:
st.caption("暂无决策记录,执行回测或实时评估后可在此查看。")
try: try:
with db_session(read_only=True) as conn: with db_session(read_only=True) as conn:
date_rows = conn.execute( date_rows = conn.execute(
@ -437,13 +453,17 @@ def render_backtest() -> None:
if st.button("运行回测"): if st.button("运行回测"):
LOGGER.info("用户点击运行回测按钮", extra=LOG_EXTRA) LOGGER.info("用户点击运行回测按钮", extra=LOG_EXTRA)
decision_log_container = st.container() decision_log_container = st.container()
status_placeholder = st.empty() status_box = st.status("准备执行回测...", expanded=True)
llm_stats_placeholder = st.empty() llm_stats_placeholder = st.empty()
decision_entries: List[str] = [] decision_entries: List[str] = []
def _decision_callback(ts_code: str, trade_dt: date, ctx: AgentContext, decision: Decision) -> None: def _decision_callback(ts_code: str, trade_dt: date, ctx: AgentContext, decision: Decision) -> None:
ts_label = trade_dt.isoformat() ts_label = trade_dt.isoformat()
summary = decision.summary summary = ""
for dept_decision in decision.department_decisions.values():
if getattr(dept_decision, "summary", ""):
summary = str(dept_decision.summary)
break
entry_lines = [ entry_lines = [
f"**{ts_label} {ts_code}** → {decision.action.value} (信心 {decision.confidence:.2f})", f"**{ts_label} {ts_code}** → {decision.action.value} (信心 {decision.confidence:.2f})",
] ]
@ -457,67 +477,71 @@ def render_backtest() -> None:
if dep_highlights: if dep_highlights:
entry_lines.append("部门意见:" + "".join(dep_highlights)) entry_lines.append("部门意见:" + "".join(dep_highlights))
decision_entries.append(" \n".join(entry_lines)) decision_entries.append(" \n".join(entry_lines))
decision_log_container.markdown("\n\n".join(decision_entries[-50:])) decision_log_container.markdown("\n\n".join(decision_entries[-200:]))
status_placeholder.info( status_box.write(f"{ts_label} {ts_code}{decision.action.value} (信心 {decision.confidence:.2f})")
f"最新决策:{ts_code} -> {decision.action.value} ({decision.confidence:.2f})"
)
stats = snapshot_llm_metrics() stats = snapshot_llm_metrics()
llm_stats_placeholder.json( llm_stats_placeholder.json(
{ {
"LLM 调用次数": stats.get("total_calls", 0), "LLM 调用次数": stats.get("total_calls", 0),
"Prompt Tokens": stats.get("total_prompt_tokens", 0), "Prompt Tokens": stats.get("total_prompt_tokens", 0),
"Completion Tokens": stats.get("total_completion_tokens", 0), "Completion Tokens": stats.get("total_completion_tokens", 0),
"按 Provider": stats.get("provider_calls", {}),
"按模型": stats.get("model_calls", {}),
} }
) )
_update_dashboard_sidebar(stats)
reset_llm_metrics() reset_llm_metrics()
with st.spinner("正在执行回测..."): status_box.update(label="执行回测中...", state="running")
try: try:
universe = [code.strip() for code in universe_text.split(',') if code.strip()] universe = [code.strip() for code in universe_text.split(',') if code.strip()]
LOGGER.info( LOGGER.info(
"回测参数start=%s end=%s universe=%s target=%s stop=%s hold_days=%s", "回测参数start=%s end=%s universe=%s target=%s stop=%s hold_days=%s",
start_date, start_date,
end_date, end_date,
universe, universe,
target, target,
stop, stop,
hold_days, hold_days,
extra=LOG_EXTRA, extra=LOG_EXTRA,
) )
cfg = BtConfig( cfg = BtConfig(
id="streamlit_demo", id="streamlit_demo",
name="Streamlit Demo Strategy", name="Streamlit Demo Strategy",
start_date=start_date, start_date=start_date,
end_date=end_date, end_date=end_date,
universe=universe, universe=universe,
params={ params={
"target": target, "target": target,
"stop": stop, "stop": stop,
"hold_days": int(hold_days), "hold_days": int(hold_days),
}, },
) )
result = run_backtest(cfg, decision_callback=_decision_callback) result = run_backtest(cfg, decision_callback=_decision_callback)
LOGGER.info( LOGGER.info(
"回测完成nav_records=%s trades=%s", "回测完成nav_records=%s trades=%s",
len(result.nav_series), len(result.nav_series),
len(result.trades), len(result.trades),
extra=LOG_EXTRA, extra=LOG_EXTRA,
) )
st.success("回测执行完成,详见下方结果与统计。") status_box.update(label="回测执行完成", state="complete")
metrics = snapshot_llm_metrics() st.success("回测执行完成,详见下方结果与统计。")
llm_stats_placeholder.json( metrics = snapshot_llm_metrics()
{ llm_stats_placeholder.json(
"LLM 调用次数": metrics.get("total_calls", 0), {
"Prompt Tokens": metrics.get("total_prompt_tokens", 0), "LLM 调用次数": metrics.get("total_calls", 0),
"Completion Tokens": metrics.get("total_completion_tokens", 0), "Prompt Tokens": metrics.get("total_prompt_tokens", 0),
"按 Provider": metrics.get("provider_calls", {}), "Completion Tokens": metrics.get("total_completion_tokens", 0),
"按模型": metrics.get("model_calls", {}), "按 Provider": metrics.get("provider_calls", {}),
} "按模型": metrics.get("model_calls", {}),
) }
st.json({"nav_records": result.nav_series, "trades": result.trades}) )
except Exception as exc: # noqa: BLE001 _update_dashboard_sidebar(metrics)
LOGGER.exception("回测执行失败", extra=LOG_EXTRA) st.json({"nav_records": result.nav_series, "trades": result.trades})
st.error(f"回测执行失败:{exc}") except Exception as exc: # noqa: BLE001
LOGGER.exception("回测执行失败", extra=LOG_EXTRA)
status_box.update(label="回测执行失败", state="error")
st.error(f"回测执行失败:{exc}")
def render_settings() -> None: def render_settings() -> None:
@ -1151,6 +1175,7 @@ def render_tests() -> None:
def main() -> None: def main() -> None:
LOGGER.info("初始化 Streamlit UI", extra=LOG_EXTRA) LOGGER.info("初始化 Streamlit UI", extra=LOG_EXTRA)
st.set_page_config(page_title="多智能体投资助理", layout="wide") st.set_page_config(page_title="多智能体投资助理", layout="wide")
render_global_dashboard()
tabs = st.tabs(["今日计划", "回测与复盘", "数据与设置", "自检测试"]) tabs = st.tabs(["今日计划", "回测与复盘", "数据与设置", "自检测试"])
LOGGER.debug("Tabs 初始化完成:%s", ["今日计划", "回测与复盘", "数据与设置", "自检测试"], extra=LOG_EXTRA) LOGGER.debug("Tabs 初始化完成:%s", ["今日计划", "回测与复盘", "数据与设置", "自检测试"], extra=LOG_EXTRA)
with tabs[0]: with tabs[0]:

View File

@ -20,6 +20,9 @@ from .db import db_session
_LOGGER_NAME = "app.logging" _LOGGER_NAME = "app.logging"
_IS_CONFIGURED = False _IS_CONFIGURED = False
_CONVERSATION_LOGGER_NAME = "app.conversation"
_CONVERSATION_HANDLER: Optional[Handler] = None
_CONVERSATION_LOGFILE: Optional[Path] = None
class DatabaseLogHandler(Handler): class DatabaseLogHandler(Handler):
@ -77,6 +80,7 @@ def setup_logging(
log_dir: Path = cfg.data_paths.root / "logs" log_dir: Path = cfg.data_paths.root / "logs"
log_dir.mkdir(parents=True, exist_ok=True) log_dir.mkdir(parents=True, exist_ok=True)
logfile = log_dir / f"app_{timestamp}.log" logfile = log_dir / f"app_{timestamp}.log"
conversation_logfile = log_dir / f"agent_{timestamp}.log"
root = logging.getLogger() root = logging.getLogger()
root.setLevel(level) root.setLevel(level)
@ -113,6 +117,19 @@ def setup_logging(
}, },
}, },
) )
conversation_logger = logging.getLogger(_CONVERSATION_LOGGER_NAME)
conversation_logger.setLevel(level)
conversation_logger.handlers.clear()
conversation_logger.propagate = False
conv_handler = logging.FileHandler(conversation_logfile, encoding="utf-8")
conv_handler.setLevel(level)
conv_handler.setFormatter(formatter)
conversation_logger.addHandler(conv_handler)
global _CONVERSATION_HANDLER, _CONVERSATION_LOGFILE
_CONVERSATION_HANDLER = conv_handler
_CONVERSATION_LOGFILE = conversation_logfile
return root return root
@ -128,5 +145,14 @@ def get_logger(name: Optional[str] = None) -> logging.Logger:
return logger return logger
def get_conversation_logger() -> logging.Logger:
"""Return conversation logger for agent dialogues."""
setup_logging()
logger = logging.getLogger(_CONVERSATION_LOGGER_NAME)
logger.propagate = False
return logger
# 默认在模块导入时完成配置,适配现有调用方式。 # 默认在模块导入时完成配置,适配现有调用方式。
setup_logging() setup_logging()