This commit is contained in:
sam 2025-09-28 22:57:04 +08:00
parent 199cb484b9
commit de88b198b3
8 changed files with 390 additions and 253 deletions

View File

@ -35,10 +35,10 @@
1. **配置声明角色**:在 `config.json`/`DepartmentSettings` 中补充 `description``data_scope``department_prompt()` 拼接角色指令,实现职责以 Prompt 管理而非硬编码。 1. **配置声明角色**:在 `config.json`/`DepartmentSettings` 中补充 `description``data_scope``department_prompt()` 拼接角色指令,实现职责以 Prompt 管理而非硬编码。
2. **统一数据层**:新增 `DataBroker`(或同类工具)封装常用查询,代理与部门通过声明式 JSON 请求所需表/字段/窗口,由服务端执行并返回特征。 2. **统一数据层**:新增 `DataBroker`(或同类工具)封装常用查询,代理与部门通过声明式 JSON 请求所需表/字段/窗口,由服务端执行并返回特征。
3. **函数式工具调用**:通过 DeepSeek/OpenAI 的 function calling 暴露 `fetch_data` 工具LLM 根据 schema 声明字段与窗口请求,系统用 `DataBroker` 校验、补数并回传结果,形成“请求→取数→复议”闭环。 3. **函数式工具调用**:通过 DeepSeek/OpenAI 的 function calling 暴露 `fetch_data` 工具LLM 只需声明所需表(如 `daily`、`daily_basic`)及窗口,系统用 `DataBroker` 拉取整行数据、回传结果,形成“请求→取数→复议”闭环。
4. **审计与前端联动**:把角色提示、数据请求与执行摘要写入 `agent_utils` 附加字段,使 Streamlit 能完整呈现“角色 → 请求 → 决策”的链条。 4. **审计与前端联动**:把角色提示、数据请求与执行摘要写入 `agent_utils` 附加字段,使 Streamlit 能完整呈现“角色 → 请求 → 决策”的链条。
目前部门 LLM 通过 function calling 暴露的 `fetch_data` 工具触发追加查询:模型按 schema 声明需要的 `table.column` 字段与窗口,系统使用 `DataBroker` 验证后补齐数据并作为工具响应回传,再带着查询结果进入下一轮提示,从而形成闭环。 目前部门 LLM 通过 function calling 暴露的 `fetch_data` 工具触发追加查询:模型只需按 schema 声明 `daily` / `daily_basic` 等表名与窗口,系统使用 `DataBroker` 一次性返回指定交易日的全部列,再带着查询结果进入下一轮提示,从而形成闭环。
上述调整可在单个部门先行做 PoC验证闭环能力后再推广至全部角色。 上述调整可在单个部门先行做 PoC验证闭环能力后再推广至全部角色。

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import json import json
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple from typing import Any, Callable, ClassVar, Dict, List, Mapping, Optional, Sequence, Tuple
from app.agents.base import AgentAction from app.agents.base import AgentAction
from app.llm.client import call_endpoint_with_messages, run_llm_with_config, LLMError from app.llm.client import call_endpoint_with_messages, run_llm_with_config, LLMError
@ -17,9 +17,10 @@ LOG_EXTRA = {"stage": "department"}
@dataclass @dataclass
class DataRequest: class TableRequest:
field: str name: str
window: int = 1 window: int = 1
trade_date: Optional[str] = None
@dataclass @dataclass
@ -64,6 +65,8 @@ class DepartmentDecision:
class DepartmentAgent: class DepartmentAgent:
"""Wraps LLM ensemble logic for a single analytical department.""" """Wraps LLM ensemble logic for a single analytical department."""
ALLOWED_TABLES: ClassVar[List[str]] = ["daily", "daily_basic"]
def __init__( def __init__(
self, self,
settings: DepartmentSettings, settings: DepartmentSettings,
@ -106,10 +109,7 @@ class DepartmentAgent:
) )
transcript: List[str] = [] transcript: List[str] = []
delivered_requests: set[Tuple[str, int]] = { delivered_requests: set[Tuple[str, int, str]] = set()
(field, 1)
for field in (mutable_context.raw.get("scope_values") or {}).keys()
}
primary_endpoint = llm_cfg.primary primary_endpoint = llm_cfg.primary
final_message: Optional[Dict[str, Any]] = None final_message: Optional[Dict[str, Any]] = None
@ -135,6 +135,14 @@ class DepartmentAgent:
message = choice.get("message", {}) message = choice.get("message", {})
transcript.append(_message_to_text(message)) transcript.append(_message_to_text(message))
assistant_record: Dict[str, Any] = {
"role": "assistant",
"content": _extract_message_content(message),
}
if message.get("tool_calls"):
assistant_record["tool_calls"] = message.get("tool_calls")
messages.append(assistant_record)
tool_calls = message.get("tool_calls") or [] tool_calls = message.get("tool_calls") or []
if tool_calls: if tool_calls:
for call in tool_calls: for call in tool_calls:
@ -151,7 +159,6 @@ class DepartmentAgent:
{ {
"role": "tool", "role": "tool",
"tool_call_id": call.get("id"), "tool_call_id": call.get("id"),
"name": call.get("function", {}).get("name"),
"content": json.dumps(tool_response, ensure_ascii=False), "content": json.dumps(tool_response, ensure_ascii=False),
} }
) )
@ -213,165 +220,50 @@ class DepartmentAgent:
def _fulfill_data_requests( def _fulfill_data_requests(
self, self,
context: DepartmentContext, context: DepartmentContext,
requests: Sequence[DataRequest], requests: Sequence[TableRequest],
) -> Tuple[List[str], List[Dict[str, Any]], set[Tuple[str, int]]]: ) -> Tuple[List[str], List[Dict[str, Any]], set[Tuple[str, int, str]]]:
lines: List[str] = [] lines: List[str] = []
payload: List[Dict[str, Any]] = [] payload: List[Dict[str, Any]] = []
delivered: set[Tuple[str, int]] = set() delivered: set[Tuple[str, int, str]] = set()
ts_code = context.ts_code ts_code = context.ts_code
trade_date = self._normalize_trade_date(context.trade_date) default_trade_date = self._normalize_trade_date(context.trade_date)
latest_groups: Dict[str, List[str]] = {}
series_requests: List[Tuple[DataRequest, Tuple[str, str]]] = []
values_map, db_alias_map, series_map = _build_context_lookup(context)
for req in requests: for req in requests:
field = req.field.strip() table = (req.name or "").strip().lower()
if not field: if not table:
continue continue
window = req.window if table not in self.ALLOWED_TABLES:
resolved: Optional[Tuple[str, str]] = None lines.append(f"- {table}: 不在允许的表列表中")
if "." in field: continue
resolved = self._broker.resolve_field(field) trade_date = self._normalize_trade_date(req.trade_date or default_trade_date)
elif field in db_alias_map: window = max(1, min(req.window or 1, getattr(self._broker, "MAX_WINDOW", 120)))
resolved = db_alias_map[field] key = (table, window, trade_date)
if key in delivered:
if resolved: lines.append(f"- {table}: 已返回窗口 {window} 的数据,跳过重复请求")
table, column = resolved
canonical = f"{table}.{column}"
if window <= 1:
latest_groups.setdefault(canonical, []).append(field)
delivered.add((field, 1))
delivered.add((canonical, 1))
else:
series_requests.append((req, resolved))
delivered.add((field, window))
delivered.add((canonical, window))
continue continue
if field in values_map: rows = self._broker.fetch_table_rows(table, ts_code, trade_date, window)
value = values_map[field] if rows:
if window <= 1:
payload.append(
{
"field": field,
"window": 1,
"source": "context",
"values": [
{
"trade_date": context.trade_date,
"value": value,
}
],
}
)
lines.append(f"- {field}: {value} (来自上下文)")
else:
series = series_map.get(field)
if series:
trimmed = series[: window]
payload.append(
{
"field": field,
"window": window,
"source": "context_series",
"values": [
{"trade_date": dt, "value": val}
for dt, val in trimmed
],
}
)
preview = ", ".join(
f"{dt}:{val:.4f}" for dt, val in trimmed[: min(len(trimmed), 5)]
)
lines.append(
f"- {field} (window={window} 来自上下文序列): {preview}"
)
else:
payload.append(
{
"field": field,
"window": window,
"source": "context",
"values": [
{
"trade_date": context.trade_date,
"value": value,
}
],
"warning": "仅提供当前值,缺少历史序列",
}
)
lines.append(
f"- {field} (window={window}): 仅有当前值 {value}, 无历史序列"
)
delivered.add((field, window))
if field in db_alias_map:
resolved = db_alias_map[field]
canonical = f"{resolved[0]}.{resolved[1]}"
delivered.add((canonical, window))
continue
lines.append(f"- {field}: 字段不存在或不可用")
if latest_groups:
latest_values = self._broker.fetch_latest(
ts_code, trade_date, list(latest_groups.keys())
)
for canonical, aliases in latest_groups.items():
value = latest_values.get(canonical)
if value is None:
lines.append(f"- {canonical}: (数据缺失)")
else:
lines.append(f"- {canonical}: {value}")
for alias in aliases:
payload.append(
{
"field": alias,
"window": 1,
"source": "database",
"values": [
{
"trade_date": trade_date,
"value": value,
}
],
}
)
for req, resolved in series_requests:
table, column = resolved
series = self._broker.fetch_series(
table,
column,
ts_code,
trade_date,
window=req.window,
)
if series:
preview = ", ".join( preview = ", ".join(
f"{dt}:{val:.4f}" f"{row.get('trade_date', 'NA')}" for row in rows[: min(len(rows), 5)]
for dt, val in series[: min(len(series), 5)]
) )
lines.append( lines.append(
f"- {req.field} (window={req.window}): {preview}" f"- {table} (window={window} trade_date<= {trade_date}): 返回 {len(rows)}{preview}"
) )
else: else:
lines.append( lines.append(
f"- {req.field} (window={req.window}): (数据缺失)" f"- {table} (window={window} trade_date<= {trade_date}): (数据缺失)"
) )
payload.append( payload.append(
{ {
"field": req.field, "table": table,
"window": req.window, "window": window,
"source": "database", "trade_date": trade_date,
"values": [ "rows": rows,
{"trade_date": dt, "value": val}
for dt, val in series
],
} }
) )
delivered.add(key)
return lines, payload, delivered return lines, payload, delivered
@ -380,9 +272,9 @@ class DepartmentAgent:
self, self,
context: DepartmentContext, context: DepartmentContext,
call: Mapping[str, Any], call: Mapping[str, Any],
delivered_requests: set[Tuple[str, int]], delivered_requests: set[Tuple[str, int, str]],
round_idx: int, round_idx: int,
) -> Tuple[Dict[str, Any], set[Tuple[str, int]]]: ) -> Tuple[Dict[str, Any], set[Tuple[str, int, str]]]:
function_block = call.get("function") or {} function_block = call.get("function") or {}
name = function_block.get("name") or "" name = function_block.get("name") or ""
if name != "fetch_data": if name != "fetch_data":
@ -398,23 +290,29 @@ class DepartmentAgent:
}, set() }, set()
args = _parse_tool_arguments(function_block.get("arguments")) args = _parse_tool_arguments(function_block.get("arguments"))
raw_requests = args.get("requests") or [] base_trade_date = self._normalize_trade_date(
requests: List[DataRequest] = [] args.get("trade_date") or context.trade_date
)
raw_requests = args.get("tables") or []
requests: List[TableRequest] = []
skipped: List[str] = [] skipped: List[str] = []
for item in raw_requests: for item in raw_requests:
field = str(item.get("field", "")).strip() name = str(item.get("name", "")).strip().lower()
if not field: if not name:
continue continue
window_raw = item.get("window")
try: try:
window = int(item.get("window", 1)) window = int(window_raw) if window_raw is not None else 1
except (TypeError, ValueError): except (TypeError, ValueError):
window = 1 window = 1
window = max(1, min(window, getattr(self._broker, "MAX_WINDOW", 120))) window = max(1, min(window, getattr(self._broker, "MAX_WINDOW", 120)))
key = (field, window) override_date = item.get("trade_date")
req_date = self._normalize_trade_date(override_date or base_trade_date)
key = (name, window, req_date)
if key in delivered_requests: if key in delivered_requests:
skipped.append(field) skipped.append(name)
continue continue
requests.append(DataRequest(field=field, window=window)) requests.append(TableRequest(name=name, window=window, trade_date=req_date))
if not requests: if not requests:
return { return {
@ -461,34 +359,44 @@ class DepartmentAgent:
"function": { "function": {
"name": "fetch_data", "name": "fetch_data",
"description": ( "description": (
"根据字段请求数据库中的最新值或时间序列。支持 table.column 格式的字段," "根据表名请求指定交易日及窗口的历史数据。当前仅支持 'daily''daily_basic' 表。"
"window 表示希望返回的最近数据点数量。"
), ),
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"requests": { "tables": {
"type": "array", "type": "array",
"items": { "items": {
"type": "object", "type": "object",
"properties": { "properties": {
"field": { "name": {
"type": "string", "type": "string",
"description": "数据字段,格式为 table.column", "enum": self.ALLOWED_TABLES,
"description": "表名,例如 daily 或 daily_basic",
}, },
"window": { "window": {
"type": "integer", "type": "integer",
"minimum": 1, "minimum": 1,
"maximum": max_window, "maximum": max_window,
"description": "返回最近多少个数据点,默认为 1", "description": "向前回溯的记录条数,默认为 1",
},
"trade_date": {
"type": "string",
"pattern": r"^\\d{8}$",
"description": "覆盖默认交易日(格式 YYYYMMDD",
}, },
}, },
"required": ["field"], "required": ["name"],
}, },
"minItems": 1, "minItems": 1,
} },
"trade_date": {
"type": "string",
"pattern": r"^\\d{8}$",
"description": "默认交易日(格式 YYYYMMDD",
},
}, },
"required": ["requests"], "required": ["tables"],
}, },
}, },
} }
@ -555,39 +463,6 @@ def _ensure_mutable_context(context: DepartmentContext) -> DepartmentContext:
return context return context
def _parse_data_requests(payload: Mapping[str, Any]) -> List[DataRequest]:
raw_requests = payload.get("data_requests")
requests: List[DataRequest] = []
if not isinstance(raw_requests, list):
return requests
seen: set[Tuple[str, int]] = set()
for item in raw_requests:
field = ""
window = 1
if isinstance(item, str):
field = item.strip()
elif isinstance(item, Mapping):
candidate = item.get("field")
if candidate is None:
continue
field = str(candidate).strip()
try:
window = int(item.get("window", 1))
except (TypeError, ValueError):
window = 1
else:
continue
if not field:
continue
window = max(1, window)
key = (field, window)
if key in seen:
continue
seen.add(key)
requests.append(DataRequest(field=field, window=window))
return requests
def _parse_tool_arguments(payload: Any) -> Dict[str, Any]: def _parse_tool_arguments(payload: Any) -> Dict[str, Any]:
if isinstance(payload, dict): if isinstance(payload, dict):
return dict(payload) return dict(payload)
@ -636,38 +511,6 @@ def _extract_message_content(message: Mapping[str, Any]) -> str:
return json.dumps(message, ensure_ascii=False) return json.dumps(message, ensure_ascii=False)
def _build_context_lookup(
context: DepartmentContext,
) -> Tuple[Dict[str, Any], Dict[str, Tuple[str, str]], Dict[str, List[Tuple[str, float]]]]:
values: Dict[str, Any] = {}
db_alias: Dict[str, Tuple[str, str]] = {}
series_map: Dict[str, List[Tuple[str, float]]] = {}
for source in (context.features or {}, context.market_snapshot or {}):
for key, value in source.items():
values[str(key)] = value
scope_values = context.raw.get("scope_values") or {}
for key, value in scope_values.items():
key_str = str(key)
values[key_str] = value
if "." in key_str:
table, column = key_str.split(".", 1)
db_alias.setdefault(column, (table, column))
db_alias.setdefault(key_str, (table, column))
values.setdefault(column, value)
close_series = context.raw.get("close_series") or []
if isinstance(close_series, list) and close_series:
series_map["close"] = close_series
series_map["daily.close"] = close_series
turnover_series = context.raw.get("turnover_series") or []
if isinstance(turnover_series, list) and turnover_series:
series_map["turnover_rate"] = turnover_series
series_map["daily_basic.turnover_rate"] = turnover_series
return values, db_alias, series_map
class DepartmentManager: class DepartmentManager:

View File

@ -5,11 +5,12 @@ import json
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import date from datetime import date
from statistics import mean, pstdev from statistics import mean, pstdev
from typing import Any, Dict, Iterable, List, Mapping from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional
from app.agents.base import AgentContext from app.agents.base import AgentContext
from app.agents.departments import DepartmentManager from app.agents.departments import DepartmentManager
from app.agents.game import Decision, decide from app.agents.game import Decision, decide
from app.llm.metrics import record_decision as metrics_record_decision
from app.agents.registry import default_agents from app.agents.registry import default_agents
from app.utils.data_access import DataBroker from app.utils.data_access import DataBroker
from app.utils.config import get_config from app.utils.config import get_config
@ -224,7 +225,12 @@ class BacktestEngine:
return feature_map return feature_map
def simulate_day(self, trade_date: date, state: PortfolioState) -> List[Decision]: def simulate_day(
self,
trade_date: date,
state: PortfolioState,
decision_callback: Optional[Callable[[str, date, AgentContext, Decision], None]] = None,
) -> List[Decision]:
feature_map = self.load_market_data(trade_date) feature_map = self.load_market_data(trade_date)
decisions: List[Decision] = [] decisions: List[Decision] = []
for ts_code, payload in feature_map.items(): for ts_code, payload in feature_map.items():
@ -247,6 +253,26 @@ class BacktestEngine:
) )
decisions.append(decision) decisions.append(decision)
self.record_agent_state(context, decision) self.record_agent_state(context, decision)
if decision_callback:
try:
decision_callback(ts_code, trade_date, context, decision)
except Exception: # noqa: BLE001
LOGGER.exception("决策回调执行失败", extra=LOG_EXTRA)
try:
metrics_record_decision(
ts_code=ts_code,
trade_date=context.trade_date,
action=decision.action.value,
confidence=decision.confidence,
summary=decision.summary,
source="backtest",
departments={
code: dept.to_dict()
for code, dept in decision.department_decisions.items()
},
)
except Exception: # noqa: BLE001
LOGGER.debug("记录决策指标失败", extra=LOG_EXTRA)
# TODO: translate decisions into fills, holdings, and NAV updates. # TODO: translate decisions into fills, holdings, and NAV updates.
_ = state _ = state
return decisions return decisions
@ -357,20 +383,27 @@ class BacktestEngine:
_ = payload _ = payload
# TODO: persist payload into bt_trades / audit tables when schema is ready. # TODO: persist payload into bt_trades / audit tables when schema is ready.
def run(self) -> BacktestResult: def run(
self,
decision_callback: Optional[Callable[[str, date, AgentContext, Decision], None]] = None,
) -> BacktestResult:
state = PortfolioState() state = PortfolioState()
result = BacktestResult() result = BacktestResult()
current = self.cfg.start_date current = self.cfg.start_date
while current <= self.cfg.end_date: while current <= self.cfg.end_date:
decisions = self.simulate_day(current, state) decisions = self.simulate_day(current, state, decision_callback)
_ = decisions _ = decisions
current = date.fromordinal(current.toordinal() + 1) current = date.fromordinal(current.toordinal() + 1)
return result return result
def run_backtest(cfg: BtConfig) -> BacktestResult: def run_backtest(
cfg: BtConfig,
*,
decision_callback: Optional[Callable[[str, date, AgentContext, Decision], None]] = None,
) -> BacktestResult:
engine = BacktestEngine(cfg) engine = BacktestEngine(cfg)
result = engine.run() result = engine.run(decision_callback=decision_callback)
with db_session() as conn: with db_session() as conn:
_ = conn _ = conn
# Implementation should persist bt_nav, bt_trades, and bt_report rows. # Implementation should persist bt_nav, bt_trades, and bt_report rows.

View File

@ -17,6 +17,7 @@ from app.utils.config import (
LLMEndpoint, LLMEndpoint,
get_config, get_config,
) )
from app.llm.metrics import record_call
from app.utils.logging import get_logger from app.utils.logging import get_logger
LOGGER = get_logger(__name__) LOGGER = get_logger(__name__)
@ -220,11 +221,12 @@ def call_endpoint_with_messages(
) )
if response.status_code != 200: if response.status_code != 200:
raise LLMError(f"Ollama 调用失败: {response.status_code} {response.text}") raise LLMError(f"Ollama 调用失败: {response.status_code} {response.text}")
record_call(provider_key, model)
return response.json() return response.json()
if not api_key: if not api_key:
raise LLMError(f"缺少 {provider_key} API Key (model={model})") raise LLMError(f"缺少 {provider_key} API Key (model={model})")
return _request_openai_chat( data = _request_openai_chat(
base_url=base_url, base_url=base_url,
api_key=api_key, api_key=api_key,
model=model, model=model,
@ -234,6 +236,11 @@ def call_endpoint_with_messages(
tools=tools, tools=tools,
tool_choice=tool_choice, tool_choice=tool_choice,
) )
usage = data.get("usage", {}) if isinstance(data, dict) else {}
prompt_tokens = usage.get("prompt_tokens") or usage.get("prompt_tokens_total")
completion_tokens = usage.get("completion_tokens") or usage.get("completion_tokens_total")
record_call(provider_key, model, prompt_tokens, completion_tokens)
return data
def _normalize_response(text: str) -> str: def _normalize_response(text: str) -> str:

114
app/llm/metrics.py Normal file
View File

@ -0,0 +1,114 @@
"""Simple runtime metrics collector for LLM calls."""
from __future__ import annotations
from collections import deque
from dataclasses import dataclass, field
from threading import Lock
from typing import Deque, Dict, List, Optional
@dataclass
class _Metrics:
total_calls: int = 0
total_prompt_tokens: int = 0
total_completion_tokens: int = 0
provider_calls: Dict[str, int] = field(default_factory=dict)
model_calls: Dict[str, int] = field(default_factory=dict)
decisions: Deque[Dict[str, object]] = field(default_factory=lambda: deque(maxlen=500))
decision_action_counts: Dict[str, int] = field(default_factory=dict)
_METRICS = _Metrics()
_LOCK = Lock()
def record_call(
provider: str,
model: Optional[str] = None,
prompt_tokens: Optional[int] = None,
completion_tokens: Optional[int] = None,
) -> None:
"""Record a single LLM API invocation."""
normalized_provider = (provider or "unknown").lower()
normalized_model = (model or "").strip()
with _LOCK:
_METRICS.total_calls += 1
_METRICS.provider_calls[normalized_provider] = (
_METRICS.provider_calls.get(normalized_provider, 0) + 1
)
if normalized_model:
_METRICS.model_calls[normalized_model] = (
_METRICS.model_calls.get(normalized_model, 0) + 1
)
if prompt_tokens:
_METRICS.total_prompt_tokens += int(prompt_tokens)
if completion_tokens:
_METRICS.total_completion_tokens += int(completion_tokens)
def snapshot(reset: bool = False) -> Dict[str, object]:
"""Return a snapshot of current metrics. Optionally reset counters."""
with _LOCK:
data = {
"total_calls": _METRICS.total_calls,
"total_prompt_tokens": _METRICS.total_prompt_tokens,
"total_completion_tokens": _METRICS.total_completion_tokens,
"provider_calls": dict(_METRICS.provider_calls),
"model_calls": dict(_METRICS.model_calls),
"decision_action_counts": dict(_METRICS.decision_action_counts),
"recent_decisions": list(_METRICS.decisions),
}
if reset:
_METRICS.total_calls = 0
_METRICS.total_prompt_tokens = 0
_METRICS.total_completion_tokens = 0
_METRICS.provider_calls.clear()
_METRICS.model_calls.clear()
_METRICS.decision_action_counts.clear()
_METRICS.decisions.clear()
return data
def reset() -> None:
"""Reset all collected metrics."""
snapshot(reset=True)
def record_decision(
*,
ts_code: str,
trade_date: str,
action: str,
confidence: float,
summary: str,
source: str,
departments: Optional[Dict[str, object]] = None,
) -> None:
"""Record a high-level decision for later inspection."""
record = {
"ts_code": ts_code,
"trade_date": trade_date,
"action": action,
"confidence": confidence,
"summary": summary,
"source": source,
"departments": departments or {},
}
with _LOCK:
_METRICS.decisions.append(record)
_METRICS.decision_action_counts[action] = (
_METRICS.decision_action_counts.get(action, 0) + 1
)
def recent_decisions(limit: int = 50) -> List[Dict[str, object]]:
"""Return the most recent decisions up to limit."""
with _LOCK:
if limit <= 0:
return []
return list(_METRICS.decisions)[-limit:]

View File

@ -62,7 +62,7 @@ def department_prompt(
"risks": ["风险点", "..."] "risks": ["风险点", "..."]
}} }}
如需额外数据请调用可用工具 `fetch_data` 并在参数中提供 `requests` 数组元素包含 `field` 以及可选的 `window``field` 必须符合可用数据范围`window` 默认为 1 如需额外数据请调用工具 `fetch_data`仅支持请求 `daily` `daily_basic` 在参数中填写 `tables` 数组元素包含 `name`表名与可选的 `window`向前回溯的条数默认 1 `trade_date`YYYYMMDD默认本次交易日
工具返回的数据会在后续消息中提供请在获取所有必要信息后再给出最终 JSON 答复 工具返回的数据会在后续消息中提供请在获取所有必要信息后再给出最终 JSON 答复
请严格返回单个 JSON 对象不要添加额外文本 请严格返回单个 JSON 对象不要添加额外文本

View File

@ -20,11 +20,18 @@ import requests
from requests.exceptions import RequestException from requests.exceptions import RequestException
import streamlit as st import streamlit as st
from app.agents.base import AgentContext
from app.agents.game import Decision
from app.backtest.engine import BtConfig, run_backtest from app.backtest.engine import BtConfig, run_backtest
from app.data.schema import initialize_database from app.data.schema import initialize_database
from app.ingest.checker import run_boot_check from app.ingest.checker import run_boot_check
from app.ingest.tushare import FetchJob, run_ingestion from app.ingest.tushare import FetchJob, run_ingestion
from app.llm.client import llm_config_snapshot, run_llm from app.llm.client import llm_config_snapshot, run_llm
from app.llm.metrics import (
reset as reset_llm_metrics,
snapshot as snapshot_llm_metrics,
recent_decisions as llm_recent_decisions,
)
from app.utils.config import ( from app.utils.config import (
ALLOWED_LLM_STRATEGIES, ALLOWED_LLM_STRATEGIES,
DEFAULT_LLM_BASE_URLS, DEFAULT_LLM_BASE_URLS,
@ -152,6 +159,45 @@ def _load_daily_frame(ts_code: str, start: date, end: date) -> pd.DataFrame:
def render_today_plan() -> None: def render_today_plan() -> None:
LOGGER.info("渲染今日计划页面", extra=LOG_EXTRA) LOGGER.info("渲染今日计划页面", extra=LOG_EXTRA)
st.header("今日计划") st.header("今日计划")
st.caption("统计数据基于最近一次渲染,刷新页面即可获取最新结果。")
metrics_state = snapshot_llm_metrics()
st.subheader("LLM 调用统计 (实时)")
stats_col1, stats_col2, stats_col3 = st.columns(3)
stats_col1.metric("总调用次数", metrics_state.get("total_calls", 0))
stats_col2.metric("Prompt Tokens", metrics_state.get("total_prompt_tokens", 0))
stats_col3.metric("Completion Tokens", metrics_state.get("total_completion_tokens", 0))
provider_calls = metrics_state.get("provider_calls", {})
model_calls = metrics_state.get("model_calls", {})
if provider_calls or model_calls:
with st.expander("调用明细", expanded=False):
if provider_calls:
st.write("按 Provider")
st.json(provider_calls)
if model_calls:
st.write("按模型:")
st.json(model_calls)
st.subheader("最近决策 (全局)")
decision_feed = metrics_state.get("recent_decisions", []) or llm_recent_decisions(20)
if decision_feed:
for record in reversed(decision_feed[-20:]):
ts_code = record.get("ts_code")
trade_date = record.get("trade_date")
action = record.get("action")
confidence = record.get("confidence")
summary = record.get("summary")
departments = record.get("departments", {})
st.markdown(
f"**{trade_date} {ts_code}** → {action} (信心 {confidence:.2f})"
)
if summary:
st.caption(f"摘要:{summary}")
if departments:
st.json(departments)
st.divider()
else:
st.caption("暂无决策记录,执行回测或实时评估后可在此查看。")
try: try:
with db_session(read_only=True) as conn: with db_session(read_only=True) as conn:
date_rows = conn.execute( date_rows = conn.execute(
@ -323,7 +369,7 @@ def render_today_plan() -> None:
st.subheader("部门意见") st.subheader("部门意见")
if dept_records: if dept_records:
dept_df = pd.DataFrame(dept_records) dept_df = pd.DataFrame(dept_records)
st.dataframe(dept_df, use_container_width=True, hide_index=True) st.dataframe(dept_df, width='stretch', hide_index=True)
for code, details in dept_details.items(): for code, details in dept_details.items():
with st.expander(f"{code} 补充详情", expanded=False): with st.expander(f"{code} 补充详情", expanded=False):
supplements = details.get("supplements", []) supplements = details.get("supplements", [])
@ -345,7 +391,7 @@ def render_today_plan() -> None:
st.subheader("代理评分") st.subheader("代理评分")
if agent_records: if agent_records:
agent_df = pd.DataFrame(agent_records) agent_df = pd.DataFrame(agent_records)
st.dataframe(agent_df, use_container_width=True, hide_index=True) st.dataframe(agent_df, width='stretch', hide_index=True)
else: else:
st.info("暂无基础代理评分。") st.info("暂无基础代理评分。")
@ -390,6 +436,41 @@ def render_backtest() -> None:
if st.button("运行回测"): if st.button("运行回测"):
LOGGER.info("用户点击运行回测按钮", extra=LOG_EXTRA) LOGGER.info("用户点击运行回测按钮", extra=LOG_EXTRA)
decision_log_container = st.container()
status_placeholder = st.empty()
llm_stats_placeholder = st.empty()
decision_entries: List[str] = []
def _decision_callback(ts_code: str, trade_dt: date, ctx: AgentContext, decision: Decision) -> None:
ts_label = trade_dt.isoformat()
summary = decision.summary
entry_lines = [
f"**{ts_label} {ts_code}** → {decision.action.value} (信心 {decision.confidence:.2f})",
]
if summary:
entry_lines.append(f"摘要:{summary}")
dep_highlights = []
for dept_code, dept_decision in decision.department_decisions.items():
dep_highlights.append(
f"{dept_code}:{dept_decision.action.value}({dept_decision.confidence:.2f})"
)
if dep_highlights:
entry_lines.append("部门意见:" + "".join(dep_highlights))
decision_entries.append(" \n".join(entry_lines))
decision_log_container.markdown("\n\n".join(decision_entries[-50:]))
status_placeholder.info(
f"最新决策:{ts_code} -> {decision.action.value} ({decision.confidence:.2f})"
)
stats = snapshot_llm_metrics()
llm_stats_placeholder.json(
{
"LLM 调用次数": stats.get("total_calls", 0),
"Prompt Tokens": stats.get("total_prompt_tokens", 0),
"Completion Tokens": stats.get("total_completion_tokens", 0),
}
)
reset_llm_metrics()
with st.spinner("正在执行回测..."): with st.spinner("正在执行回测..."):
try: try:
universe = [code.strip() for code in universe_text.split(',') if code.strip()] universe = [code.strip() for code in universe_text.split(',') if code.strip()]
@ -415,14 +496,24 @@ def render_backtest() -> None:
"hold_days": int(hold_days), "hold_days": int(hold_days),
}, },
) )
result = run_backtest(cfg) result = run_backtest(cfg, decision_callback=_decision_callback)
LOGGER.info( LOGGER.info(
"回测完成nav_records=%s trades=%s", "回测完成nav_records=%s trades=%s",
len(result.nav_series), len(result.nav_series),
len(result.trades), len(result.trades),
extra=LOG_EXTRA, extra=LOG_EXTRA,
) )
st.success("回测执行完成,详见回测结果摘要。") st.success("回测执行完成,详见下方结果与统计。")
metrics = snapshot_llm_metrics()
llm_stats_placeholder.json(
{
"LLM 调用次数": metrics.get("total_calls", 0),
"Prompt Tokens": metrics.get("total_prompt_tokens", 0),
"Completion Tokens": metrics.get("total_completion_tokens", 0),
"按 Provider": metrics.get("provider_calls", {}),
"按模型": metrics.get("model_calls", {}),
}
)
st.json({"nav_records": result.nav_series, "trades": result.trades}) st.json({"nav_records": result.nav_series, "trades": result.trades})
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
LOGGER.exception("回测执行失败", extra=LOG_EXTRA) LOGGER.exception("回测执行失败", extra=LOG_EXTRA)
@ -654,7 +745,7 @@ def render_settings() -> None:
ensemble_rows, ensemble_rows,
num_rows="dynamic", num_rows="dynamic",
key="global_ensemble_editor", key="global_ensemble_editor",
use_container_width=True, width='stretch',
hide_index=True, hide_index=True,
column_config={ column_config={
"provider": st.column_config.SelectboxColumn("Provider", options=provider_keys), "provider": st.column_config.SelectboxColumn("Provider", options=provider_keys),
@ -735,7 +826,7 @@ def render_settings() -> None:
dept_rows, dept_rows,
num_rows="fixed", num_rows="fixed",
key="department_editor", key="department_editor",
use_container_width=True, width='stretch',
hide_index=True, hide_index=True,
column_config={ column_config={
"code": st.column_config.TextColumn("编码", disabled=True), "code": st.column_config.TextColumn("编码", disabled=True),
@ -998,7 +1089,7 @@ def render_tests() -> None:
] ]
) )
candle_fig.update_layout(height=420, margin=dict(l=10, r=10, t=40, b=10)) candle_fig.update_layout(height=420, margin=dict(l=10, r=10, t=40, b=10))
st.plotly_chart(candle_fig, use_container_width=True) st.plotly_chart(candle_fig, width='stretch')
vol_fig = px.bar( vol_fig = px.bar(
df_reset, df_reset,
@ -1008,7 +1099,7 @@ def render_tests() -> None:
title="成交量", title="成交量",
) )
vol_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10)) vol_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10))
st.plotly_chart(vol_fig, use_container_width=True) st.plotly_chart(vol_fig, width='stretch')
amt_fig = px.bar( amt_fig = px.bar(
df_reset, df_reset,
@ -1018,7 +1109,7 @@ def render_tests() -> None:
title="成交额", title="成交额",
) )
amt_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10)) amt_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10))
st.plotly_chart(amt_fig, use_container_width=True) st.plotly_chart(amt_fig, width='stretch')
df_reset["月份"] = df_reset["交易日"].dt.to_period("M").astype(str) df_reset["月份"] = df_reset["交易日"].dt.to_period("M").astype(str)
box_fig = px.box( box_fig = px.box(
@ -1029,7 +1120,7 @@ def render_tests() -> None:
title="月度收盘价分布", title="月度收盘价分布",
) )
box_fig.update_layout(height=320, margin=dict(l=10, r=10, t=40, b=10)) box_fig.update_layout(height=320, margin=dict(l=10, r=10, t=40, b=10))
st.plotly_chart(box_fig, use_container_width=True) st.plotly_chart(box_fig, width='stretch')
st.caption("提示:成交量单位为手,成交额以千元显示。箱线图按月展示收盘价分布。") st.caption("提示:成交量单位为手,成交额以千元显示。箱线图按月展示收盘价分布。")
st.dataframe(df_reset.tail(20), width='stretch') st.dataframe(df_reset.tail(20), width='stretch')

View File

@ -199,6 +199,55 @@ class DataBroker:
return False return False
return row is not None return row is not None
def fetch_table_rows(
self,
table: str,
ts_code: str,
trade_date: str,
window: int,
) -> List[Dict[str, object]]:
if window <= 0:
return []
window = min(window, self.MAX_WINDOW)
columns = self._get_table_columns(table)
if not columns:
LOGGER.debug("表不存在或无字段 table=%s", table, extra=LOG_EXTRA)
return []
column_list = ", ".join(columns)
has_trade_date = "trade_date" in columns
if has_trade_date:
query = (
f"SELECT {column_list} FROM {table} "
"WHERE ts_code = ? AND trade_date <= ? "
"ORDER BY trade_date DESC LIMIT ?"
)
params: Tuple[object, ...] = (ts_code, trade_date, window)
else:
query = (
f"SELECT {column_list} FROM {table} "
"WHERE ts_code = ? ORDER BY rowid DESC LIMIT ?"
)
params = (ts_code, window)
results: List[Dict[str, object]] = []
with db_session(read_only=True) as conn:
try:
rows = conn.execute(query, params).fetchall()
except Exception as exc: # noqa: BLE001
LOGGER.debug(
"表查询失败 table=%s err=%s",
table,
exc,
extra=LOG_EXTRA,
)
return []
for row in rows:
record = {col: row[col] for col in columns}
results.append(record)
return results
def resolve_field(self, field: str) -> Optional[Tuple[str, str]]: def resolve_field(self, field: str) -> Optional[Tuple[str, str]]:
normalized = _safe_split(field) normalized = _safe_split(field)
if not normalized: if not normalized:
@ -215,7 +264,7 @@ class DataBroker:
return None return None
return table, resolved return table, resolved
def _get_table_columns(self, table: str) -> Optional[set[str]]: def _get_table_columns(self, table: str) -> Optional[List[str]]:
if not _is_safe_identifier(table): if not _is_safe_identifier(table):
return None return None
cache = getattr(self, "_column_cache", None) cache = getattr(self, "_column_cache", None)
@ -234,7 +283,7 @@ class DataBroker:
if not rows: if not rows:
cache[table] = None cache[table] = None
return None return None
columns = {row["name"] for row in rows if row["name"]} columns = [row["name"] for row in rows if row["name"]]
cache[table] = columns cache[table] = columns
return columns return columns