diff --git a/README.md b/README.md index 35ac27f..ca187d6 100644 --- a/README.md +++ b/README.md @@ -35,10 +35,10 @@ 1. **配置声明角色**:在 `config.json`/`DepartmentSettings` 中补充 `description` 与 `data_scope`,`department_prompt()` 拼接角色指令,实现职责以 Prompt 管理而非硬编码。 2. **统一数据层**:新增 `DataBroker`(或同类工具)封装常用查询,代理与部门通过声明式 JSON 请求所需表/字段/窗口,由服务端执行并返回特征。 -3. **函数式工具调用**:通过 DeepSeek/OpenAI 的 function calling 暴露 `fetch_data` 工具,LLM 根据 schema 声明字段与窗口请求,系统用 `DataBroker` 校验、补数并回传结果,形成“请求→取数→复议”闭环。 +3. **函数式工具调用**:通过 DeepSeek/OpenAI 的 function calling 暴露 `fetch_data` 工具,LLM 只需声明所需表(如 `daily`、`daily_basic`)及窗口,系统用 `DataBroker` 拉取整行数据、回传结果,形成“请求→取数→复议”闭环。 4. **审计与前端联动**:把角色提示、数据请求与执行摘要写入 `agent_utils` 附加字段,使 Streamlit 能完整呈现“角色 → 请求 → 决策”的链条。 -目前部门 LLM 通过 function calling 暴露的 `fetch_data` 工具触发追加查询:模型按 schema 声明需要的 `table.column` 字段与窗口,系统使用 `DataBroker` 验证后补齐数据并作为工具响应回传,再带着查询结果进入下一轮提示,从而形成闭环。 +目前部门 LLM 通过 function calling 暴露的 `fetch_data` 工具触发追加查询:模型只需按 schema 声明 `daily` / `daily_basic` 等表名与窗口,系统使用 `DataBroker` 一次性返回指定交易日的全部列,再带着查询结果进入下一轮提示,从而形成闭环。 上述调整可在单个部门先行做 PoC,验证闭环能力后再推广至全部角色。 diff --git a/app/agents/departments.py b/app/agents/departments.py index d0789d0..5ef5fb4 100644 --- a/app/agents/departments.py +++ b/app/agents/departments.py @@ -3,7 +3,7 @@ from __future__ import annotations import json from dataclasses import dataclass, field -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple +from typing import Any, Callable, ClassVar, Dict, List, Mapping, Optional, Sequence, Tuple from app.agents.base import AgentAction from app.llm.client import call_endpoint_with_messages, run_llm_with_config, LLMError @@ -17,9 +17,10 @@ LOG_EXTRA = {"stage": "department"} @dataclass -class DataRequest: - field: str +class TableRequest: + name: str window: int = 1 + trade_date: Optional[str] = None @dataclass @@ -64,6 +65,8 @@ class DepartmentDecision: class DepartmentAgent: """Wraps LLM ensemble logic for a single analytical department.""" + ALLOWED_TABLES: ClassVar[List[str]] = ["daily", "daily_basic"] + def __init__( self, settings: DepartmentSettings, @@ -106,10 +109,7 @@ class DepartmentAgent: ) transcript: List[str] = [] - delivered_requests: set[Tuple[str, int]] = { - (field, 1) - for field in (mutable_context.raw.get("scope_values") or {}).keys() - } + delivered_requests: set[Tuple[str, int, str]] = set() primary_endpoint = llm_cfg.primary final_message: Optional[Dict[str, Any]] = None @@ -135,6 +135,14 @@ class DepartmentAgent: message = choice.get("message", {}) transcript.append(_message_to_text(message)) + assistant_record: Dict[str, Any] = { + "role": "assistant", + "content": _extract_message_content(message), + } + if message.get("tool_calls"): + assistant_record["tool_calls"] = message.get("tool_calls") + messages.append(assistant_record) + tool_calls = message.get("tool_calls") or [] if tool_calls: for call in tool_calls: @@ -151,7 +159,6 @@ class DepartmentAgent: { "role": "tool", "tool_call_id": call.get("id"), - "name": call.get("function", {}).get("name"), "content": json.dumps(tool_response, ensure_ascii=False), } ) @@ -213,165 +220,50 @@ class DepartmentAgent: def _fulfill_data_requests( self, context: DepartmentContext, - requests: Sequence[DataRequest], - ) -> Tuple[List[str], List[Dict[str, Any]], set[Tuple[str, int]]]: + requests: Sequence[TableRequest], + ) -> Tuple[List[str], List[Dict[str, Any]], set[Tuple[str, int, str]]]: lines: List[str] = [] payload: List[Dict[str, Any]] = [] - delivered: set[Tuple[str, int]] = set() + delivered: set[Tuple[str, int, str]] = set() ts_code = context.ts_code - trade_date = self._normalize_trade_date(context.trade_date) - - latest_groups: Dict[str, List[str]] = {} - series_requests: List[Tuple[DataRequest, Tuple[str, str]]] = [] - values_map, db_alias_map, series_map = _build_context_lookup(context) + default_trade_date = self._normalize_trade_date(context.trade_date) for req in requests: - field = req.field.strip() - if not field: + table = (req.name or "").strip().lower() + if not table: continue - window = req.window - resolved: Optional[Tuple[str, str]] = None - if "." in field: - resolved = self._broker.resolve_field(field) - elif field in db_alias_map: - resolved = db_alias_map[field] - - if resolved: - table, column = resolved - canonical = f"{table}.{column}" - if window <= 1: - latest_groups.setdefault(canonical, []).append(field) - delivered.add((field, 1)) - delivered.add((canonical, 1)) - else: - series_requests.append((req, resolved)) - delivered.add((field, window)) - delivered.add((canonical, window)) + if table not in self.ALLOWED_TABLES: + lines.append(f"- {table}: 不在允许的表列表中") + continue + trade_date = self._normalize_trade_date(req.trade_date or default_trade_date) + window = max(1, min(req.window or 1, getattr(self._broker, "MAX_WINDOW", 120))) + key = (table, window, trade_date) + if key in delivered: + lines.append(f"- {table}: 已返回窗口 {window} 的数据,跳过重复请求") continue - if field in values_map: - value = values_map[field] - if window <= 1: - payload.append( - { - "field": field, - "window": 1, - "source": "context", - "values": [ - { - "trade_date": context.trade_date, - "value": value, - } - ], - } - ) - lines.append(f"- {field}: {value} (来自上下文)") - else: - series = series_map.get(field) - if series: - trimmed = series[: window] - payload.append( - { - "field": field, - "window": window, - "source": "context_series", - "values": [ - {"trade_date": dt, "value": val} - for dt, val in trimmed - ], - } - ) - preview = ", ".join( - f"{dt}:{val:.4f}" for dt, val in trimmed[: min(len(trimmed), 5)] - ) - lines.append( - f"- {field} (window={window} 来自上下文序列): {preview}" - ) - else: - payload.append( - { - "field": field, - "window": window, - "source": "context", - "values": [ - { - "trade_date": context.trade_date, - "value": value, - } - ], - "warning": "仅提供当前值,缺少历史序列", - } - ) - lines.append( - f"- {field} (window={window}): 仅有当前值 {value}, 无历史序列" - ) - delivered.add((field, window)) - if field in db_alias_map: - resolved = db_alias_map[field] - canonical = f"{resolved[0]}.{resolved[1]}" - delivered.add((canonical, window)) - continue - - lines.append(f"- {field}: 字段不存在或不可用") - - if latest_groups: - latest_values = self._broker.fetch_latest( - ts_code, trade_date, list(latest_groups.keys()) - ) - for canonical, aliases in latest_groups.items(): - value = latest_values.get(canonical) - if value is None: - lines.append(f"- {canonical}: (数据缺失)") - else: - lines.append(f"- {canonical}: {value}") - for alias in aliases: - payload.append( - { - "field": alias, - "window": 1, - "source": "database", - "values": [ - { - "trade_date": trade_date, - "value": value, - } - ], - } - ) - - for req, resolved in series_requests: - table, column = resolved - series = self._broker.fetch_series( - table, - column, - ts_code, - trade_date, - window=req.window, - ) - if series: + rows = self._broker.fetch_table_rows(table, ts_code, trade_date, window) + if rows: preview = ", ".join( - f"{dt}:{val:.4f}" - for dt, val in series[: min(len(series), 5)] + f"{row.get('trade_date', 'NA')}" for row in rows[: min(len(rows), 5)] ) lines.append( - f"- {req.field} (window={req.window}): {preview}" + f"- {table} (window={window} trade_date<= {trade_date}): 返回 {len(rows)} 行 {preview}" ) else: lines.append( - f"- {req.field} (window={req.window}): (数据缺失)" + f"- {table} (window={window} trade_date<= {trade_date}): (数据缺失)" ) payload.append( { - "field": req.field, - "window": req.window, - "source": "database", - "values": [ - {"trade_date": dt, "value": val} - for dt, val in series - ], + "table": table, + "window": window, + "trade_date": trade_date, + "rows": rows, } ) + delivered.add(key) return lines, payload, delivered @@ -380,9 +272,9 @@ class DepartmentAgent: self, context: DepartmentContext, call: Mapping[str, Any], - delivered_requests: set[Tuple[str, int]], + delivered_requests: set[Tuple[str, int, str]], round_idx: int, - ) -> Tuple[Dict[str, Any], set[Tuple[str, int]]]: + ) -> Tuple[Dict[str, Any], set[Tuple[str, int, str]]]: function_block = call.get("function") or {} name = function_block.get("name") or "" if name != "fetch_data": @@ -398,23 +290,29 @@ class DepartmentAgent: }, set() args = _parse_tool_arguments(function_block.get("arguments")) - raw_requests = args.get("requests") or [] - requests: List[DataRequest] = [] + base_trade_date = self._normalize_trade_date( + args.get("trade_date") or context.trade_date + ) + raw_requests = args.get("tables") or [] + requests: List[TableRequest] = [] skipped: List[str] = [] for item in raw_requests: - field = str(item.get("field", "")).strip() - if not field: + name = str(item.get("name", "")).strip().lower() + if not name: continue + window_raw = item.get("window") try: - window = int(item.get("window", 1)) + window = int(window_raw) if window_raw is not None else 1 except (TypeError, ValueError): window = 1 window = max(1, min(window, getattr(self._broker, "MAX_WINDOW", 120))) - key = (field, window) + override_date = item.get("trade_date") + req_date = self._normalize_trade_date(override_date or base_trade_date) + key = (name, window, req_date) if key in delivered_requests: - skipped.append(field) + skipped.append(name) continue - requests.append(DataRequest(field=field, window=window)) + requests.append(TableRequest(name=name, window=window, trade_date=req_date)) if not requests: return { @@ -461,34 +359,44 @@ class DepartmentAgent: "function": { "name": "fetch_data", "description": ( - "根据字段请求数据库中的最新值或时间序列。支持 table.column 格式的字段," - "window 表示希望返回的最近数据点数量。" + "根据表名请求指定交易日及窗口的历史数据。当前仅支持 'daily' 与 'daily_basic' 表。" ), "parameters": { "type": "object", "properties": { - "requests": { + "tables": { "type": "array", "items": { "type": "object", "properties": { - "field": { + "name": { "type": "string", - "description": "数据字段,格式为 table.column", + "enum": self.ALLOWED_TABLES, + "description": "表名,例如 daily 或 daily_basic", }, "window": { "type": "integer", "minimum": 1, "maximum": max_window, - "description": "返回最近多少个数据点,默认为 1", + "description": "向前回溯的记录条数,默认为 1", + }, + "trade_date": { + "type": "string", + "pattern": r"^\\d{8}$", + "description": "覆盖默认交易日(格式 YYYYMMDD)", }, }, - "required": ["field"], + "required": ["name"], }, "minItems": 1, - } + }, + "trade_date": { + "type": "string", + "pattern": r"^\\d{8}$", + "description": "默认交易日(格式 YYYYMMDD)", + }, }, - "required": ["requests"], + "required": ["tables"], }, }, } @@ -555,39 +463,6 @@ def _ensure_mutable_context(context: DepartmentContext) -> DepartmentContext: return context -def _parse_data_requests(payload: Mapping[str, Any]) -> List[DataRequest]: - raw_requests = payload.get("data_requests") - requests: List[DataRequest] = [] - if not isinstance(raw_requests, list): - return requests - seen: set[Tuple[str, int]] = set() - for item in raw_requests: - field = "" - window = 1 - if isinstance(item, str): - field = item.strip() - elif isinstance(item, Mapping): - candidate = item.get("field") - if candidate is None: - continue - field = str(candidate).strip() - try: - window = int(item.get("window", 1)) - except (TypeError, ValueError): - window = 1 - else: - continue - if not field: - continue - window = max(1, window) - key = (field, window) - if key in seen: - continue - seen.add(key) - requests.append(DataRequest(field=field, window=window)) - return requests - - def _parse_tool_arguments(payload: Any) -> Dict[str, Any]: if isinstance(payload, dict): return dict(payload) @@ -636,38 +511,6 @@ def _extract_message_content(message: Mapping[str, Any]) -> str: return json.dumps(message, ensure_ascii=False) -def _build_context_lookup( - context: DepartmentContext, -) -> Tuple[Dict[str, Any], Dict[str, Tuple[str, str]], Dict[str, List[Tuple[str, float]]]]: - values: Dict[str, Any] = {} - db_alias: Dict[str, Tuple[str, str]] = {} - series_map: Dict[str, List[Tuple[str, float]]] = {} - - for source in (context.features or {}, context.market_snapshot or {}): - for key, value in source.items(): - values[str(key)] = value - - scope_values = context.raw.get("scope_values") or {} - for key, value in scope_values.items(): - key_str = str(key) - values[key_str] = value - if "." in key_str: - table, column = key_str.split(".", 1) - db_alias.setdefault(column, (table, column)) - db_alias.setdefault(key_str, (table, column)) - values.setdefault(column, value) - - close_series = context.raw.get("close_series") or [] - if isinstance(close_series, list) and close_series: - series_map["close"] = close_series - series_map["daily.close"] = close_series - - turnover_series = context.raw.get("turnover_series") or [] - if isinstance(turnover_series, list) and turnover_series: - series_map["turnover_rate"] = turnover_series - series_map["daily_basic.turnover_rate"] = turnover_series - - return values, db_alias, series_map class DepartmentManager: diff --git a/app/backtest/engine.py b/app/backtest/engine.py index 912c6cf..080d4f5 100644 --- a/app/backtest/engine.py +++ b/app/backtest/engine.py @@ -5,11 +5,12 @@ import json from dataclasses import dataclass, field from datetime import date from statistics import mean, pstdev -from typing import Any, Dict, Iterable, List, Mapping +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional from app.agents.base import AgentContext from app.agents.departments import DepartmentManager from app.agents.game import Decision, decide +from app.llm.metrics import record_decision as metrics_record_decision from app.agents.registry import default_agents from app.utils.data_access import DataBroker from app.utils.config import get_config @@ -224,7 +225,12 @@ class BacktestEngine: return feature_map - def simulate_day(self, trade_date: date, state: PortfolioState) -> List[Decision]: + def simulate_day( + self, + trade_date: date, + state: PortfolioState, + decision_callback: Optional[Callable[[str, date, AgentContext, Decision], None]] = None, + ) -> List[Decision]: feature_map = self.load_market_data(trade_date) decisions: List[Decision] = [] for ts_code, payload in feature_map.items(): @@ -247,6 +253,26 @@ class BacktestEngine: ) decisions.append(decision) self.record_agent_state(context, decision) + if decision_callback: + try: + decision_callback(ts_code, trade_date, context, decision) + except Exception: # noqa: BLE001 + LOGGER.exception("决策回调执行失败", extra=LOG_EXTRA) + try: + metrics_record_decision( + ts_code=ts_code, + trade_date=context.trade_date, + action=decision.action.value, + confidence=decision.confidence, + summary=decision.summary, + source="backtest", + departments={ + code: dept.to_dict() + for code, dept in decision.department_decisions.items() + }, + ) + except Exception: # noqa: BLE001 + LOGGER.debug("记录决策指标失败", extra=LOG_EXTRA) # TODO: translate decisions into fills, holdings, and NAV updates. _ = state return decisions @@ -357,20 +383,27 @@ class BacktestEngine: _ = payload # TODO: persist payload into bt_trades / audit tables when schema is ready. - def run(self) -> BacktestResult: + def run( + self, + decision_callback: Optional[Callable[[str, date, AgentContext, Decision], None]] = None, + ) -> BacktestResult: state = PortfolioState() result = BacktestResult() current = self.cfg.start_date while current <= self.cfg.end_date: - decisions = self.simulate_day(current, state) + decisions = self.simulate_day(current, state, decision_callback) _ = decisions current = date.fromordinal(current.toordinal() + 1) return result -def run_backtest(cfg: BtConfig) -> BacktestResult: +def run_backtest( + cfg: BtConfig, + *, + decision_callback: Optional[Callable[[str, date, AgentContext, Decision], None]] = None, +) -> BacktestResult: engine = BacktestEngine(cfg) - result = engine.run() + result = engine.run(decision_callback=decision_callback) with db_session() as conn: _ = conn # Implementation should persist bt_nav, bt_trades, and bt_report rows. diff --git a/app/llm/client.py b/app/llm/client.py index 596c003..7a7e915 100644 --- a/app/llm/client.py +++ b/app/llm/client.py @@ -17,6 +17,7 @@ from app.utils.config import ( LLMEndpoint, get_config, ) +from app.llm.metrics import record_call from app.utils.logging import get_logger LOGGER = get_logger(__name__) @@ -220,11 +221,12 @@ def call_endpoint_with_messages( ) if response.status_code != 200: raise LLMError(f"Ollama 调用失败: {response.status_code} {response.text}") + record_call(provider_key, model) return response.json() if not api_key: raise LLMError(f"缺少 {provider_key} API Key (model={model})") - return _request_openai_chat( + data = _request_openai_chat( base_url=base_url, api_key=api_key, model=model, @@ -234,6 +236,11 @@ def call_endpoint_with_messages( tools=tools, tool_choice=tool_choice, ) + usage = data.get("usage", {}) if isinstance(data, dict) else {} + prompt_tokens = usage.get("prompt_tokens") or usage.get("prompt_tokens_total") + completion_tokens = usage.get("completion_tokens") or usage.get("completion_tokens_total") + record_call(provider_key, model, prompt_tokens, completion_tokens) + return data def _normalize_response(text: str) -> str: diff --git a/app/llm/metrics.py b/app/llm/metrics.py new file mode 100644 index 0000000..c14ff1e --- /dev/null +++ b/app/llm/metrics.py @@ -0,0 +1,114 @@ +"""Simple runtime metrics collector for LLM calls.""" +from __future__ import annotations + +from collections import deque +from dataclasses import dataclass, field +from threading import Lock +from typing import Deque, Dict, List, Optional + + +@dataclass +class _Metrics: + total_calls: int = 0 + total_prompt_tokens: int = 0 + total_completion_tokens: int = 0 + provider_calls: Dict[str, int] = field(default_factory=dict) + model_calls: Dict[str, int] = field(default_factory=dict) + decisions: Deque[Dict[str, object]] = field(default_factory=lambda: deque(maxlen=500)) + decision_action_counts: Dict[str, int] = field(default_factory=dict) + + +_METRICS = _Metrics() +_LOCK = Lock() + + +def record_call( + provider: str, + model: Optional[str] = None, + prompt_tokens: Optional[int] = None, + completion_tokens: Optional[int] = None, +) -> None: + """Record a single LLM API invocation.""" + + normalized_provider = (provider or "unknown").lower() + normalized_model = (model or "").strip() + with _LOCK: + _METRICS.total_calls += 1 + _METRICS.provider_calls[normalized_provider] = ( + _METRICS.provider_calls.get(normalized_provider, 0) + 1 + ) + if normalized_model: + _METRICS.model_calls[normalized_model] = ( + _METRICS.model_calls.get(normalized_model, 0) + 1 + ) + if prompt_tokens: + _METRICS.total_prompt_tokens += int(prompt_tokens) + if completion_tokens: + _METRICS.total_completion_tokens += int(completion_tokens) + + +def snapshot(reset: bool = False) -> Dict[str, object]: + """Return a snapshot of current metrics. Optionally reset counters.""" + + with _LOCK: + data = { + "total_calls": _METRICS.total_calls, + "total_prompt_tokens": _METRICS.total_prompt_tokens, + "total_completion_tokens": _METRICS.total_completion_tokens, + "provider_calls": dict(_METRICS.provider_calls), + "model_calls": dict(_METRICS.model_calls), + "decision_action_counts": dict(_METRICS.decision_action_counts), + "recent_decisions": list(_METRICS.decisions), + } + if reset: + _METRICS.total_calls = 0 + _METRICS.total_prompt_tokens = 0 + _METRICS.total_completion_tokens = 0 + _METRICS.provider_calls.clear() + _METRICS.model_calls.clear() + _METRICS.decision_action_counts.clear() + _METRICS.decisions.clear() + return data + + +def reset() -> None: + """Reset all collected metrics.""" + + snapshot(reset=True) + + +def record_decision( + *, + ts_code: str, + trade_date: str, + action: str, + confidence: float, + summary: str, + source: str, + departments: Optional[Dict[str, object]] = None, +) -> None: + """Record a high-level decision for later inspection.""" + + record = { + "ts_code": ts_code, + "trade_date": trade_date, + "action": action, + "confidence": confidence, + "summary": summary, + "source": source, + "departments": departments or {}, + } + with _LOCK: + _METRICS.decisions.append(record) + _METRICS.decision_action_counts[action] = ( + _METRICS.decision_action_counts.get(action, 0) + 1 + ) + + +def recent_decisions(limit: int = 50) -> List[Dict[str, object]]: + """Return the most recent decisions up to limit.""" + + with _LOCK: + if limit <= 0: + return [] + return list(_METRICS.decisions)[-limit:] diff --git a/app/llm/prompts.py b/app/llm/prompts.py index 86340af..c0db8aa 100644 --- a/app/llm/prompts.py +++ b/app/llm/prompts.py @@ -62,7 +62,7 @@ def department_prompt( "risks": ["风险点", "..."] }} -如需额外数据,请调用可用工具 `fetch_data` 并在参数中提供 `requests` 数组(元素包含 `field` 以及可选的 `window`);`field` 必须符合【可用数据范围】,`window` 默认为 1。 +如需额外数据,请调用工具 `fetch_data`,仅支持请求 `daily` 或 `daily_basic` 表;在参数中填写 `tables` 数组,元素包含 `name`(表名)与可选的 `window`(向前回溯的条数,默认 1)及 `trade_date`(YYYYMMDD,默认本次交易日)。 工具返回的数据会在后续消息中提供,请在获取所有必要信息后再给出最终 JSON 答复。 请严格返回单个 JSON 对象,不要添加额外文本。 diff --git a/app/ui/streamlit_app.py b/app/ui/streamlit_app.py index bb6c340..5f7f025 100644 --- a/app/ui/streamlit_app.py +++ b/app/ui/streamlit_app.py @@ -20,11 +20,18 @@ import requests from requests.exceptions import RequestException import streamlit as st +from app.agents.base import AgentContext +from app.agents.game import Decision from app.backtest.engine import BtConfig, run_backtest from app.data.schema import initialize_database from app.ingest.checker import run_boot_check from app.ingest.tushare import FetchJob, run_ingestion from app.llm.client import llm_config_snapshot, run_llm +from app.llm.metrics import ( + reset as reset_llm_metrics, + snapshot as snapshot_llm_metrics, + recent_decisions as llm_recent_decisions, +) from app.utils.config import ( ALLOWED_LLM_STRATEGIES, DEFAULT_LLM_BASE_URLS, @@ -152,6 +159,45 @@ def _load_daily_frame(ts_code: str, start: date, end: date) -> pd.DataFrame: def render_today_plan() -> None: LOGGER.info("渲染今日计划页面", extra=LOG_EXTRA) st.header("今日计划") + st.caption("统计数据基于最近一次渲染,刷新页面即可获取最新结果。") + + metrics_state = snapshot_llm_metrics() + st.subheader("LLM 调用统计 (实时)") + stats_col1, stats_col2, stats_col3 = st.columns(3) + stats_col1.metric("总调用次数", metrics_state.get("total_calls", 0)) + stats_col2.metric("Prompt Tokens", metrics_state.get("total_prompt_tokens", 0)) + stats_col3.metric("Completion Tokens", metrics_state.get("total_completion_tokens", 0)) + provider_calls = metrics_state.get("provider_calls", {}) + model_calls = metrics_state.get("model_calls", {}) + if provider_calls or model_calls: + with st.expander("调用明细", expanded=False): + if provider_calls: + st.write("按 Provider:") + st.json(provider_calls) + if model_calls: + st.write("按模型:") + st.json(model_calls) + + st.subheader("最近决策 (全局)") + decision_feed = metrics_state.get("recent_decisions", []) or llm_recent_decisions(20) + if decision_feed: + for record in reversed(decision_feed[-20:]): + ts_code = record.get("ts_code") + trade_date = record.get("trade_date") + action = record.get("action") + confidence = record.get("confidence") + summary = record.get("summary") + departments = record.get("departments", {}) + st.markdown( + f"**{trade_date} {ts_code}** → {action} (信心 {confidence:.2f})" + ) + if summary: + st.caption(f"摘要:{summary}") + if departments: + st.json(departments) + st.divider() + else: + st.caption("暂无决策记录,执行回测或实时评估后可在此查看。") try: with db_session(read_only=True) as conn: date_rows = conn.execute( @@ -323,7 +369,7 @@ def render_today_plan() -> None: st.subheader("部门意见") if dept_records: dept_df = pd.DataFrame(dept_records) - st.dataframe(dept_df, use_container_width=True, hide_index=True) + st.dataframe(dept_df, width='stretch', hide_index=True) for code, details in dept_details.items(): with st.expander(f"{code} 补充详情", expanded=False): supplements = details.get("supplements", []) @@ -345,7 +391,7 @@ def render_today_plan() -> None: st.subheader("代理评分") if agent_records: agent_df = pd.DataFrame(agent_records) - st.dataframe(agent_df, use_container_width=True, hide_index=True) + st.dataframe(agent_df, width='stretch', hide_index=True) else: st.info("暂无基础代理评分。") @@ -390,6 +436,41 @@ def render_backtest() -> None: if st.button("运行回测"): LOGGER.info("用户点击运行回测按钮", extra=LOG_EXTRA) + decision_log_container = st.container() + status_placeholder = st.empty() + llm_stats_placeholder = st.empty() + decision_entries: List[str] = [] + + def _decision_callback(ts_code: str, trade_dt: date, ctx: AgentContext, decision: Decision) -> None: + ts_label = trade_dt.isoformat() + summary = decision.summary + entry_lines = [ + f"**{ts_label} {ts_code}** → {decision.action.value} (信心 {decision.confidence:.2f})", + ] + if summary: + entry_lines.append(f"摘要:{summary}") + dep_highlights = [] + for dept_code, dept_decision in decision.department_decisions.items(): + dep_highlights.append( + f"{dept_code}:{dept_decision.action.value}({dept_decision.confidence:.2f})" + ) + if dep_highlights: + entry_lines.append("部门意见:" + ";".join(dep_highlights)) + decision_entries.append(" \n".join(entry_lines)) + decision_log_container.markdown("\n\n".join(decision_entries[-50:])) + status_placeholder.info( + f"最新决策:{ts_code} -> {decision.action.value} ({decision.confidence:.2f})" + ) + stats = snapshot_llm_metrics() + llm_stats_placeholder.json( + { + "LLM 调用次数": stats.get("total_calls", 0), + "Prompt Tokens": stats.get("total_prompt_tokens", 0), + "Completion Tokens": stats.get("total_completion_tokens", 0), + } + ) + + reset_llm_metrics() with st.spinner("正在执行回测..."): try: universe = [code.strip() for code in universe_text.split(',') if code.strip()] @@ -415,14 +496,24 @@ def render_backtest() -> None: "hold_days": int(hold_days), }, ) - result = run_backtest(cfg) + result = run_backtest(cfg, decision_callback=_decision_callback) LOGGER.info( "回测完成:nav_records=%s trades=%s", len(result.nav_series), len(result.trades), extra=LOG_EXTRA, ) - st.success("回测执行完成,详见回测结果摘要。") + st.success("回测执行完成,详见下方结果与统计。") + metrics = snapshot_llm_metrics() + llm_stats_placeholder.json( + { + "LLM 调用次数": metrics.get("total_calls", 0), + "Prompt Tokens": metrics.get("total_prompt_tokens", 0), + "Completion Tokens": metrics.get("total_completion_tokens", 0), + "按 Provider": metrics.get("provider_calls", {}), + "按模型": metrics.get("model_calls", {}), + } + ) st.json({"nav_records": result.nav_series, "trades": result.trades}) except Exception as exc: # noqa: BLE001 LOGGER.exception("回测执行失败", extra=LOG_EXTRA) @@ -654,7 +745,7 @@ def render_settings() -> None: ensemble_rows, num_rows="dynamic", key="global_ensemble_editor", - use_container_width=True, + width='stretch', hide_index=True, column_config={ "provider": st.column_config.SelectboxColumn("Provider", options=provider_keys), @@ -735,7 +826,7 @@ def render_settings() -> None: dept_rows, num_rows="fixed", key="department_editor", - use_container_width=True, + width='stretch', hide_index=True, column_config={ "code": st.column_config.TextColumn("编码", disabled=True), @@ -998,7 +1089,7 @@ def render_tests() -> None: ] ) candle_fig.update_layout(height=420, margin=dict(l=10, r=10, t=40, b=10)) - st.plotly_chart(candle_fig, use_container_width=True) + st.plotly_chart(candle_fig, width='stretch') vol_fig = px.bar( df_reset, @@ -1008,7 +1099,7 @@ def render_tests() -> None: title="成交量", ) vol_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10)) - st.plotly_chart(vol_fig, use_container_width=True) + st.plotly_chart(vol_fig, width='stretch') amt_fig = px.bar( df_reset, @@ -1018,7 +1109,7 @@ def render_tests() -> None: title="成交额", ) amt_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10)) - st.plotly_chart(amt_fig, use_container_width=True) + st.plotly_chart(amt_fig, width='stretch') df_reset["月份"] = df_reset["交易日"].dt.to_period("M").astype(str) box_fig = px.box( @@ -1029,7 +1120,7 @@ def render_tests() -> None: title="月度收盘价分布", ) box_fig.update_layout(height=320, margin=dict(l=10, r=10, t=40, b=10)) - st.plotly_chart(box_fig, use_container_width=True) + st.plotly_chart(box_fig, width='stretch') st.caption("提示:成交量单位为手,成交额以千元显示。箱线图按月展示收盘价分布。") st.dataframe(df_reset.tail(20), width='stretch') diff --git a/app/utils/data_access.py b/app/utils/data_access.py index baf9fb5..8155f7a 100644 --- a/app/utils/data_access.py +++ b/app/utils/data_access.py @@ -199,6 +199,55 @@ class DataBroker: return False return row is not None + def fetch_table_rows( + self, + table: str, + ts_code: str, + trade_date: str, + window: int, + ) -> List[Dict[str, object]]: + if window <= 0: + return [] + window = min(window, self.MAX_WINDOW) + columns = self._get_table_columns(table) + if not columns: + LOGGER.debug("表不存在或无字段 table=%s", table, extra=LOG_EXTRA) + return [] + + column_list = ", ".join(columns) + has_trade_date = "trade_date" in columns + if has_trade_date: + query = ( + f"SELECT {column_list} FROM {table} " + "WHERE ts_code = ? AND trade_date <= ? " + "ORDER BY trade_date DESC LIMIT ?" + ) + params: Tuple[object, ...] = (ts_code, trade_date, window) + else: + query = ( + f"SELECT {column_list} FROM {table} " + "WHERE ts_code = ? ORDER BY rowid DESC LIMIT ?" + ) + params = (ts_code, window) + + results: List[Dict[str, object]] = [] + with db_session(read_only=True) as conn: + try: + rows = conn.execute(query, params).fetchall() + except Exception as exc: # noqa: BLE001 + LOGGER.debug( + "表查询失败 table=%s err=%s", + table, + exc, + extra=LOG_EXTRA, + ) + return [] + + for row in rows: + record = {col: row[col] for col in columns} + results.append(record) + return results + def resolve_field(self, field: str) -> Optional[Tuple[str, str]]: normalized = _safe_split(field) if not normalized: @@ -215,7 +264,7 @@ class DataBroker: return None return table, resolved - def _get_table_columns(self, table: str) -> Optional[set[str]]: + def _get_table_columns(self, table: str) -> Optional[List[str]]: if not _is_safe_identifier(table): return None cache = getattr(self, "_column_cache", None) @@ -234,7 +283,7 @@ class DataBroker: if not rows: cache[table] = None return None - columns = {row["name"] for row in rows if row["name"]} + columns = [row["name"] for row in rows if row["name"]] cache[table] = columns return columns