commit de88b198b3
parent 199cb484b9

update
@@ -35,10 +35,10 @@
 1. **Declare roles via config**: add `description` and `data_scope` to `config.json`/`DepartmentSettings`, and let `department_prompt()` assemble the role instructions, so responsibilities are managed as prompts rather than hardcoded.
 2. **Unify the data layer**: add a `DataBroker` (or similar utility) that encapsulates common queries; agents and departments request the tables/fields/windows they need as declarative JSON, and the server side executes the query and returns the features.
-3. **Function-style tool calls**: expose a `fetch_data` tool via DeepSeek/OpenAI function calling; the LLM declares the fields and windows it needs per the schema, and the system validates the request with `DataBroker`, backfills the data, and returns the result, closing a "request → fetch → reconsider" loop.
+3. **Function-style tool calls**: expose a `fetch_data` tool via DeepSeek/OpenAI function calling; the LLM only declares the tables it needs (e.g. `daily`, `daily_basic`) and a window, and the system uses `DataBroker` to pull entire rows and return the result, closing a "request → fetch → reconsider" loop.
 4. **Audit and frontend integration**: write the role prompts, data requests, and execution summaries into the extra fields of `agent_utils`, so Streamlit can present the complete "role → request → decision" chain.

-Currently the department LLM triggers follow-up queries through the `fetch_data` tool exposed via function calling: the model declares the `table.column` fields and windows it needs per the schema, the system validates them with `DataBroker`, backfills the data, and returns it as the tool response, which feeds the next round of prompting and closes the loop.
+Currently the department LLM triggers follow-up queries through the `fetch_data` tool exposed via function calling: the model only declares table names such as `daily` / `daily_basic` plus a window per the schema, the system uses `DataBroker` to return every column for the specified trading days in one shot, and the query results feed the next round of prompting, closing the loop.

 The changes above can be piloted in a single department first and rolled out to all roles once the closed loop is validated.
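
For illustration, here is what one round trip might look like under the new table-based contract. This is a sketch only: the argument shape follows the `fetch_data` JSON schema later in this diff, the payload keys follow `_fulfill_data_requests`, and the concrete codes, dates, and values are invented.

```python
# Tool call the model might emit (matches the fetch_data schema below).
tool_call_arguments = {
    "trade_date": "20240513",  # default trading day, YYYYMMDD
    "tables": [
        {"name": "daily", "window": 20},
        {"name": "daily_basic", "window": 5, "trade_date": "20240430"},
    ],
}

# Each fulfilled request becomes one payload entry, per _fulfill_data_requests;
# "rows" holds the full rows returned by DataBroker.fetch_table_rows.
payload_entry = {
    "source": "database",
    "table": "daily",
    "window": 20,
    "trade_date": "20240513",
    "rows": [
        {"ts_code": "000001.SZ", "trade_date": "20240513", "close": 10.40},
    ],
}
```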
@@ -3,7 +3,7 @@ from __future__ import annotations

 import json
 from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple
+from typing import Any, Callable, ClassVar, Dict, List, Mapping, Optional, Sequence, Tuple

 from app.agents.base import AgentAction
 from app.llm.client import call_endpoint_with_messages, run_llm_with_config, LLMError
@@ -17,9 +17,10 @@ LOG_EXTRA = {"stage": "department"}

 @dataclass
-class DataRequest:
-    field: str
+class TableRequest:
+    name: str
     window: int = 1
+    trade_date: Optional[str] = None


 @dataclass
@@ -64,6 +65,8 @@ class DepartmentDecision:
 class DepartmentAgent:
     """Wraps LLM ensemble logic for a single analytical department."""

+    ALLOWED_TABLES: ClassVar[List[str]] = ["daily", "daily_basic"]
+
     def __init__(
         self,
         settings: DepartmentSettings,
@@ -106,10 +109,7 @@ class DepartmentAgent:
         )

         transcript: List[str] = []
-        delivered_requests: set[Tuple[str, int]] = {
-            (field, 1)
-            for field in (mutable_context.raw.get("scope_values") or {}).keys()
-        }
+        delivered_requests: set[Tuple[str, int, str]] = set()

         primary_endpoint = llm_cfg.primary
         final_message: Optional[Dict[str, Any]] = None
@@ -135,6 +135,14 @@ class DepartmentAgent:
             message = choice.get("message", {})
             transcript.append(_message_to_text(message))

+            assistant_record: Dict[str, Any] = {
+                "role": "assistant",
+                "content": _extract_message_content(message),
+            }
+            if message.get("tool_calls"):
+                assistant_record["tool_calls"] = message.get("tool_calls")
+            messages.append(assistant_record)
+
             tool_calls = message.get("tool_calls") or []
             if tool_calls:
                 for call in tool_calls:
@@ -151,7 +159,6 @@ class DepartmentAgent:
                         {
                             "role": "tool",
                             "tool_call_id": call.get("id"),
                             "name": call.get("function", {}).get("name"),
                             "content": json.dumps(tool_response, ensure_ascii=False),
                         }
                     )
@@ -213,165 +220,50 @@ class DepartmentAgent:
     def _fulfill_data_requests(
         self,
         context: DepartmentContext,
-        requests: Sequence[DataRequest],
-    ) -> Tuple[List[str], List[Dict[str, Any]], set[Tuple[str, int]]]:
+        requests: Sequence[TableRequest],
+    ) -> Tuple[List[str], List[Dict[str, Any]], set[Tuple[str, int, str]]]:
         lines: List[str] = []
         payload: List[Dict[str, Any]] = []
-        delivered: set[Tuple[str, int]] = set()
+        delivered: set[Tuple[str, int, str]] = set()

         ts_code = context.ts_code
-        trade_date = self._normalize_trade_date(context.trade_date)
-
-        latest_groups: Dict[str, List[str]] = {}
-        series_requests: List[Tuple[DataRequest, Tuple[str, str]]] = []
-        values_map, db_alias_map, series_map = _build_context_lookup(context)
+        default_trade_date = self._normalize_trade_date(context.trade_date)

         for req in requests:
-            field = req.field.strip()
-            if not field:
+            table = (req.name or "").strip().lower()
+            if not table:
                 continue
-            window = req.window
-            resolved: Optional[Tuple[str, str]] = None
-            if "." in field:
-                resolved = self._broker.resolve_field(field)
-            elif field in db_alias_map:
-                resolved = db_alias_map[field]
-
-            if resolved:
-                table, column = resolved
-                canonical = f"{table}.{column}"
-                if window <= 1:
-                    latest_groups.setdefault(canonical, []).append(field)
-                    delivered.add((field, 1))
-                    delivered.add((canonical, 1))
-                else:
-                    series_requests.append((req, resolved))
-                    delivered.add((field, window))
-                    delivered.add((canonical, window))
+            if table not in self.ALLOWED_TABLES:
+                lines.append(f"- {table}: 不在允许的表列表中")
+                continue
+            trade_date = self._normalize_trade_date(req.trade_date or default_trade_date)
+            window = max(1, min(req.window or 1, getattr(self._broker, "MAX_WINDOW", 120)))
+            key = (table, window, trade_date)
+            if key in delivered:
+                lines.append(f"- {table}: 已返回窗口 {window} 的数据,跳过重复请求")
                 continue

-            if field in values_map:
-                value = values_map[field]
-                if window <= 1:
-                    payload.append(
-                        {
-                            "field": field,
-                            "window": 1,
-                            "source": "context",
-                            "values": [
-                                {
-                                    "trade_date": context.trade_date,
-                                    "value": value,
-                                }
-                            ],
-                        }
-                    )
-                    lines.append(f"- {field}: {value} (来自上下文)")
-                else:
-                    series = series_map.get(field)
-                    if series:
-                        trimmed = series[: window]
-                        payload.append(
-                            {
-                                "field": field,
-                                "window": window,
-                                "source": "context_series",
-                                "values": [
-                                    {"trade_date": dt, "value": val}
-                                    for dt, val in trimmed
-                                ],
-                            }
-                        )
-                        preview = ", ".join(
-                            f"{dt}:{val:.4f}" for dt, val in trimmed[: min(len(trimmed), 5)]
-                        )
-                        lines.append(
-                            f"- {field} (window={window} 来自上下文序列): {preview}"
-                        )
-                    else:
-                        payload.append(
-                            {
-                                "field": field,
-                                "window": window,
-                                "source": "context",
-                                "values": [
-                                    {
-                                        "trade_date": context.trade_date,
-                                        "value": value,
-                                    }
-                                ],
-                                "warning": "仅提供当前值,缺少历史序列",
-                            }
-                        )
-                        lines.append(
-                            f"- {field} (window={window}): 仅有当前值 {value}, 无历史序列"
-                        )
-                delivered.add((field, window))
-                if field in db_alias_map:
-                    resolved = db_alias_map[field]
-                    canonical = f"{resolved[0]}.{resolved[1]}"
-                    delivered.add((canonical, window))
-                continue
-
-            lines.append(f"- {field}: 字段不存在或不可用")
-
-        if latest_groups:
-            latest_values = self._broker.fetch_latest(
-                ts_code, trade_date, list(latest_groups.keys())
-            )
-            for canonical, aliases in latest_groups.items():
-                value = latest_values.get(canonical)
-                if value is None:
-                    lines.append(f"- {canonical}: (数据缺失)")
-                else:
-                    lines.append(f"- {canonical}: {value}")
-                for alias in aliases:
-                    payload.append(
-                        {
-                            "field": alias,
-                            "window": 1,
-                            "source": "database",
-                            "values": [
-                                {
-                                    "trade_date": trade_date,
-                                    "value": value,
-                                }
-                            ],
-                        }
-                    )
-
-        for req, resolved in series_requests:
-            table, column = resolved
-            series = self._broker.fetch_series(
-                table,
-                column,
-                ts_code,
-                trade_date,
-                window=req.window,
-            )
-            if series:
+            rows = self._broker.fetch_table_rows(table, ts_code, trade_date, window)
+            if rows:
                 preview = ", ".join(
-                    f"{dt}:{val:.4f}"
-                    for dt, val in series[: min(len(series), 5)]
+                    f"{row.get('trade_date', 'NA')}" for row in rows[: min(len(rows), 5)]
                 )
                 lines.append(
-                    f"- {req.field} (window={req.window}): {preview}"
+                    f"- {table} (window={window} trade_date<= {trade_date}): 返回 {len(rows)} 行 {preview}"
                 )
             else:
                 lines.append(
-                    f"- {req.field} (window={req.window}): (数据缺失)"
+                    f"- {table} (window={window} trade_date<= {trade_date}): (数据缺失)"
                 )
             payload.append(
                 {
-                    "field": req.field,
-                    "window": req.window,
                     "source": "database",
-                    "values": [
-                        {"trade_date": dt, "value": val}
-                        for dt, val in series
-                    ],
+                    "table": table,
+                    "window": window,
+                    "trade_date": trade_date,
+                    "rows": rows,
                 }
             )
+            delivered.add(key)

         return lines, payload, delivered
@@ -380,9 +272,9 @@ class DepartmentAgent:
         self,
         context: DepartmentContext,
         call: Mapping[str, Any],
-        delivered_requests: set[Tuple[str, int]],
+        delivered_requests: set[Tuple[str, int, str]],
         round_idx: int,
-    ) -> Tuple[Dict[str, Any], set[Tuple[str, int]]]:
+    ) -> Tuple[Dict[str, Any], set[Tuple[str, int, str]]]:
         function_block = call.get("function") or {}
         name = function_block.get("name") or ""
         if name != "fetch_data":
@@ -398,23 +290,29 @@ class DepartmentAgent:
             }, set()

         args = _parse_tool_arguments(function_block.get("arguments"))
-        raw_requests = args.get("requests") or []
-        requests: List[DataRequest] = []
+        base_trade_date = self._normalize_trade_date(
+            args.get("trade_date") or context.trade_date
+        )
+        raw_requests = args.get("tables") or []
+        requests: List[TableRequest] = []
         skipped: List[str] = []
         for item in raw_requests:
-            field = str(item.get("field", "")).strip()
-            if not field:
+            name = str(item.get("name", "")).strip().lower()
+            if not name:
                 continue
+            window_raw = item.get("window")
             try:
-                window = int(item.get("window", 1))
+                window = int(window_raw) if window_raw is not None else 1
             except (TypeError, ValueError):
                 window = 1
             window = max(1, min(window, getattr(self._broker, "MAX_WINDOW", 120)))
-            key = (field, window)
+            override_date = item.get("trade_date")
+            req_date = self._normalize_trade_date(override_date or base_trade_date)
+            key = (name, window, req_date)
             if key in delivered_requests:
-                skipped.append(field)
+                skipped.append(name)
                 continue
-            requests.append(DataRequest(field=field, window=window))
+            requests.append(TableRequest(name=name, window=window, trade_date=req_date))

         if not requests:
             return {
@@ -461,34 +359,44 @@ class DepartmentAgent:
             "function": {
                 "name": "fetch_data",
                 "description": (
-                    "根据字段请求数据库中的最新值或时间序列。支持 table.column 格式的字段,"
-                    "window 表示希望返回的最近数据点数量。"
+                    "根据表名请求指定交易日及窗口的历史数据。当前仅支持 'daily' 与 'daily_basic' 表。"
                 ),
                 "parameters": {
                     "type": "object",
                     "properties": {
-                        "requests": {
+                        "tables": {
                             "type": "array",
                             "items": {
                                 "type": "object",
                                 "properties": {
-                                    "field": {
+                                    "name": {
                                         "type": "string",
-                                        "description": "数据字段,格式为 table.column",
+                                        "enum": self.ALLOWED_TABLES,
+                                        "description": "表名,例如 daily 或 daily_basic",
                                     },
                                     "window": {
                                         "type": "integer",
                                         "minimum": 1,
                                         "maximum": max_window,
-                                        "description": "返回最近多少个数据点,默认为 1",
+                                        "description": "向前回溯的记录条数,默认为 1",
                                     },
+                                    "trade_date": {
+                                        "type": "string",
+                                        "pattern": r"^\d{8}$",
+                                        "description": "覆盖默认交易日(格式 YYYYMMDD)",
+                                    },
                                 },
-                                "required": ["field"],
+                                "required": ["name"],
                             },
                             "minItems": 1,
-                        }
+                        },
+                        "trade_date": {
+                            "type": "string",
+                            "pattern": r"^\d{8}$",
+                            "description": "默认交易日(格式 YYYYMMDD)",
+                        },
                     },
-                    "required": ["requests"],
+                    "required": ["tables"],
                 },
             },
         }
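
As a sanity check, the new schema can be exercised directly. A sketch using the third-party `jsonschema` package (not part of this codebase), with `self.ALLOWED_TABLES` and `max_window` filled in with the values this diff defines:

```python
from jsonschema import ValidationError, validate  # third-party, illustrative only

schema = {
    "type": "object",
    "properties": {
        "tables": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "name": {"type": "string", "enum": ["daily", "daily_basic"]},
                    "window": {"type": "integer", "minimum": 1, "maximum": 120},
                    "trade_date": {"type": "string", "pattern": r"^\d{8}$"},
                },
                "required": ["name"],
            },
            "minItems": 1,
        },
        "trade_date": {"type": "string", "pattern": r"^\d{8}$"},
    },
    "required": ["tables"],
}

validate({"tables": [{"name": "daily", "window": 20}]}, schema)  # passes
try:
    validate({"tables": [{"name": "weekly"}]}, schema)  # rejected by the enum
except ValidationError as exc:
    print(exc.message)
```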
@@ -555,39 +463,6 @@ def _ensure_mutable_context(context: DepartmentContext) -> DepartmentContext:
     return context


-def _parse_data_requests(payload: Mapping[str, Any]) -> List[DataRequest]:
-    raw_requests = payload.get("data_requests")
-    requests: List[DataRequest] = []
-    if not isinstance(raw_requests, list):
-        return requests
-    seen: set[Tuple[str, int]] = set()
-    for item in raw_requests:
-        field = ""
-        window = 1
-        if isinstance(item, str):
-            field = item.strip()
-        elif isinstance(item, Mapping):
-            candidate = item.get("field")
-            if candidate is None:
-                continue
-            field = str(candidate).strip()
-            try:
-                window = int(item.get("window", 1))
-            except (TypeError, ValueError):
-                window = 1
-        else:
-            continue
-        if not field:
-            continue
-        window = max(1, window)
-        key = (field, window)
-        if key in seen:
-            continue
-        seen.add(key)
-        requests.append(DataRequest(field=field, window=window))
-    return requests
-
-
 def _parse_tool_arguments(payload: Any) -> Dict[str, Any]:
     if isinstance(payload, dict):
         return dict(payload)
@@ -636,38 +511,6 @@ def _extract_message_content(message: Mapping[str, Any]) -> str:
     return json.dumps(message, ensure_ascii=False)


-def _build_context_lookup(
-    context: DepartmentContext,
-) -> Tuple[Dict[str, Any], Dict[str, Tuple[str, str]], Dict[str, List[Tuple[str, float]]]]:
-    values: Dict[str, Any] = {}
-    db_alias: Dict[str, Tuple[str, str]] = {}
-    series_map: Dict[str, List[Tuple[str, float]]] = {}
-
-    for source in (context.features or {}, context.market_snapshot or {}):
-        for key, value in source.items():
-            values[str(key)] = value
-
-    scope_values = context.raw.get("scope_values") or {}
-    for key, value in scope_values.items():
-        key_str = str(key)
-        values[key_str] = value
-        if "." in key_str:
-            table, column = key_str.split(".", 1)
-            db_alias.setdefault(column, (table, column))
-            db_alias.setdefault(key_str, (table, column))
-            values.setdefault(column, value)
-
-    close_series = context.raw.get("close_series") or []
-    if isinstance(close_series, list) and close_series:
-        series_map["close"] = close_series
-        series_map["daily.close"] = close_series
-
-    turnover_series = context.raw.get("turnover_series") or []
-    if isinstance(turnover_series, list) and turnover_series:
-        series_map["turnover_rate"] = turnover_series
-        series_map["daily_basic.turnover_rate"] = turnover_series
-
-    return values, db_alias, series_map
-
-
 class DepartmentManager:
@@ -5,11 +5,12 @@ import json

 from dataclasses import dataclass, field
 from datetime import date
 from statistics import mean, pstdev
-from typing import Any, Dict, Iterable, List, Mapping
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional

 from app.agents.base import AgentContext
 from app.agents.departments import DepartmentManager
 from app.agents.game import Decision, decide
+from app.llm.metrics import record_decision as metrics_record_decision
 from app.agents.registry import default_agents
 from app.utils.data_access import DataBroker
 from app.utils.config import get_config
@@ -224,7 +225,12 @@ class BacktestEngine:

         return feature_map

-    def simulate_day(self, trade_date: date, state: PortfolioState) -> List[Decision]:
+    def simulate_day(
+        self,
+        trade_date: date,
+        state: PortfolioState,
+        decision_callback: Optional[Callable[[str, date, AgentContext, Decision], None]] = None,
+    ) -> List[Decision]:
         feature_map = self.load_market_data(trade_date)
         decisions: List[Decision] = []
         for ts_code, payload in feature_map.items():
@@ -247,6 +253,26 @@ class BacktestEngine:
             )
             decisions.append(decision)
             self.record_agent_state(context, decision)
+            if decision_callback:
+                try:
+                    decision_callback(ts_code, trade_date, context, decision)
+                except Exception:  # noqa: BLE001
+                    LOGGER.exception("决策回调执行失败", extra=LOG_EXTRA)
+            try:
+                metrics_record_decision(
+                    ts_code=ts_code,
+                    trade_date=context.trade_date,
+                    action=decision.action.value,
+                    confidence=decision.confidence,
+                    summary=decision.summary,
+                    source="backtest",
+                    departments={
+                        code: dept.to_dict()
+                        for code, dept in decision.department_decisions.items()
+                    },
+                )
+            except Exception:  # noqa: BLE001
+                LOGGER.debug("记录决策指标失败", extra=LOG_EXTRA)
         # TODO: translate decisions into fills, holdings, and NAV updates.
         _ = state
         return decisions
@@ -357,20 +383,27 @@ class BacktestEngine:
         _ = payload
         # TODO: persist payload into bt_trades / audit tables when schema is ready.

-    def run(self) -> BacktestResult:
+    def run(
+        self,
+        decision_callback: Optional[Callable[[str, date, AgentContext, Decision], None]] = None,
+    ) -> BacktestResult:
         state = PortfolioState()
         result = BacktestResult()
         current = self.cfg.start_date
         while current <= self.cfg.end_date:
-            decisions = self.simulate_day(current, state)
+            decisions = self.simulate_day(current, state, decision_callback)
             _ = decisions
             current = date.fromordinal(current.toordinal() + 1)
         return result


-def run_backtest(cfg: BtConfig) -> BacktestResult:
+def run_backtest(
+    cfg: BtConfig,
+    *,
+    decision_callback: Optional[Callable[[str, date, AgentContext, Decision], None]] = None,
+) -> BacktestResult:
     engine = BacktestEngine(cfg)
-    result = engine.run()
+    result = engine.run(decision_callback=decision_callback)
     with db_session() as conn:
         _ = conn
         # Implementation should persist bt_nav, bt_trades, and bt_report rows.
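
As a usage sketch of the new hook, assuming only what this diff shows (the full `BtConfig` field list is not part of this commit, so construction is left to the caller):

```python
from datetime import date

from app.agents.base import AgentContext
from app.agents.game import Decision
from app.backtest.engine import BacktestResult, BtConfig, run_backtest


def run_with_logging(cfg: BtConfig) -> BacktestResult:
    """Run a backtest and print each decision as the engine produces it."""

    def _log(ts_code: str, trade_dt: date, ctx: AgentContext, decision: Decision) -> None:
        # Called once per (stock, day) from simulate_day; the engine catches and
        # logs callback exceptions, so a failing callback cannot abort the run.
        print(f"{trade_dt.isoformat()} {ts_code} -> {decision.action.value} ({decision.confidence:.2f})")

    return run_backtest(cfg, decision_callback=_log)
```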
@@ -17,6 +17,7 @@ from app.utils.config import (
     LLMEndpoint,
     get_config,
 )
+from app.llm.metrics import record_call
 from app.utils.logging import get_logger

 LOGGER = get_logger(__name__)
@@ -220,11 +221,12 @@ def call_endpoint_with_messages(
         )
         if response.status_code != 200:
             raise LLMError(f"Ollama 调用失败: {response.status_code} {response.text}")
+        record_call(provider_key, model)
         return response.json()

     if not api_key:
         raise LLMError(f"缺少 {provider_key} API Key (model={model})")
-    return _request_openai_chat(
+    data = _request_openai_chat(
         base_url=base_url,
         api_key=api_key,
         model=model,
@@ -234,6 +236,11 @@ def call_endpoint_with_messages(
         tools=tools,
         tool_choice=tool_choice,
     )
+    usage = data.get("usage", {}) if isinstance(data, dict) else {}
+    prompt_tokens = usage.get("prompt_tokens") or usage.get("prompt_tokens_total")
+    completion_tokens = usage.get("completion_tokens") or usage.get("completion_tokens_total")
+    record_call(provider_key, model, prompt_tokens, completion_tokens)
+    return data


 def _normalize_response(text: str) -> str:
app/llm/metrics.py (new file, 114 lines)
@@ -0,0 +1,114 @@
"""Simple runtime metrics collector for LLM calls."""
from __future__ import annotations

from collections import deque
from dataclasses import dataclass, field
from threading import Lock
from typing import Deque, Dict, List, Optional


@dataclass
class _Metrics:
    total_calls: int = 0
    total_prompt_tokens: int = 0
    total_completion_tokens: int = 0
    provider_calls: Dict[str, int] = field(default_factory=dict)
    model_calls: Dict[str, int] = field(default_factory=dict)
    decisions: Deque[Dict[str, object]] = field(default_factory=lambda: deque(maxlen=500))
    decision_action_counts: Dict[str, int] = field(default_factory=dict)


_METRICS = _Metrics()
_LOCK = Lock()


def record_call(
    provider: str,
    model: Optional[str] = None,
    prompt_tokens: Optional[int] = None,
    completion_tokens: Optional[int] = None,
) -> None:
    """Record a single LLM API invocation."""

    normalized_provider = (provider or "unknown").lower()
    normalized_model = (model or "").strip()
    with _LOCK:
        _METRICS.total_calls += 1
        _METRICS.provider_calls[normalized_provider] = (
            _METRICS.provider_calls.get(normalized_provider, 0) + 1
        )
        if normalized_model:
            _METRICS.model_calls[normalized_model] = (
                _METRICS.model_calls.get(normalized_model, 0) + 1
            )
        if prompt_tokens:
            _METRICS.total_prompt_tokens += int(prompt_tokens)
        if completion_tokens:
            _METRICS.total_completion_tokens += int(completion_tokens)


def snapshot(reset: bool = False) -> Dict[str, object]:
    """Return a snapshot of current metrics. Optionally reset counters."""

    with _LOCK:
        data = {
            "total_calls": _METRICS.total_calls,
            "total_prompt_tokens": _METRICS.total_prompt_tokens,
            "total_completion_tokens": _METRICS.total_completion_tokens,
            "provider_calls": dict(_METRICS.provider_calls),
            "model_calls": dict(_METRICS.model_calls),
            "decision_action_counts": dict(_METRICS.decision_action_counts),
            "recent_decisions": list(_METRICS.decisions),
        }
        if reset:
            _METRICS.total_calls = 0
            _METRICS.total_prompt_tokens = 0
            _METRICS.total_completion_tokens = 0
            _METRICS.provider_calls.clear()
            _METRICS.model_calls.clear()
            _METRICS.decision_action_counts.clear()
            _METRICS.decisions.clear()
        return data


def reset() -> None:
    """Reset all collected metrics."""

    snapshot(reset=True)


def record_decision(
    *,
    ts_code: str,
    trade_date: str,
    action: str,
    confidence: float,
    summary: str,
    source: str,
    departments: Optional[Dict[str, object]] = None,
) -> None:
    """Record a high-level decision for later inspection."""

    record = {
        "ts_code": ts_code,
        "trade_date": trade_date,
        "action": action,
        "confidence": confidence,
        "summary": summary,
        "source": source,
        "departments": departments or {},
    }
    with _LOCK:
        _METRICS.decisions.append(record)
        _METRICS.decision_action_counts[action] = (
            _METRICS.decision_action_counts.get(action, 0) + 1
        )


def recent_decisions(limit: int = 50) -> List[Dict[str, object]]:
    """Return the most recent decisions up to limit."""

    with _LOCK:
        if limit <= 0:
            return []
        return list(_METRICS.decisions)[-limit:]
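
To make the collector's contract concrete, a small usage sketch (the call values are invented; `reset` clears state between runs, as the backtest page further down does):

```python
from app.llm.metrics import record_call, record_decision, recent_decisions, reset, snapshot

reset()
record_call("deepseek", "deepseek-chat", prompt_tokens=812, completion_tokens=96)
record_decision(
    ts_code="000001.SZ",
    trade_date="20240513",
    action="buy",
    confidence=0.72,
    summary="示例决策",
    source="backtest",
)

stats = snapshot()
assert stats["total_calls"] == 1
assert stats["decision_action_counts"]["buy"] == 1
print(recent_decisions(limit=1)[0]["ts_code"])  # 000001.SZ
```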
@@ -62,7 +62,7 @@ def department_prompt(
     "risks": ["风险点", "..."]
 }}

-如需额外数据,请调用可用工具 `fetch_data` 并在参数中提供 `requests` 数组(元素包含 `field` 以及可选的 `window`);`field` 必须符合【可用数据范围】,`window` 默认为 1。
+如需额外数据,请调用工具 `fetch_data`,仅支持请求 `daily` 或 `daily_basic` 表;在参数中填写 `tables` 数组,元素包含 `name`(表名)与可选的 `window`(向前回溯的条数,默认 1)及 `trade_date`(YYYYMMDD,默认本次交易日)。
 工具返回的数据会在后续消息中提供,请在获取所有必要信息后再给出最终 JSON 答复。

 请严格返回单个 JSON 对象,不要添加额外文本。
@@ -20,11 +20,18 @@ import requests
 from requests.exceptions import RequestException
 import streamlit as st

+from app.agents.base import AgentContext
+from app.agents.game import Decision
 from app.backtest.engine import BtConfig, run_backtest
 from app.data.schema import initialize_database
 from app.ingest.checker import run_boot_check
 from app.ingest.tushare import FetchJob, run_ingestion
 from app.llm.client import llm_config_snapshot, run_llm
+from app.llm.metrics import (
+    reset as reset_llm_metrics,
+    snapshot as snapshot_llm_metrics,
+    recent_decisions as llm_recent_decisions,
+)
 from app.utils.config import (
     ALLOWED_LLM_STRATEGIES,
     DEFAULT_LLM_BASE_URLS,
@@ -152,6 +159,45 @@ def _load_daily_frame(ts_code: str, start: date, end: date) -> pd.DataFrame:
 def render_today_plan() -> None:
     LOGGER.info("渲染今日计划页面", extra=LOG_EXTRA)
     st.header("今日计划")
+    st.caption("统计数据基于最近一次渲染,刷新页面即可获取最新结果。")
+
+    metrics_state = snapshot_llm_metrics()
+    st.subheader("LLM 调用统计 (实时)")
+    stats_col1, stats_col2, stats_col3 = st.columns(3)
+    stats_col1.metric("总调用次数", metrics_state.get("total_calls", 0))
+    stats_col2.metric("Prompt Tokens", metrics_state.get("total_prompt_tokens", 0))
+    stats_col3.metric("Completion Tokens", metrics_state.get("total_completion_tokens", 0))
+    provider_calls = metrics_state.get("provider_calls", {})
+    model_calls = metrics_state.get("model_calls", {})
+    if provider_calls or model_calls:
+        with st.expander("调用明细", expanded=False):
+            if provider_calls:
+                st.write("按 Provider:")
+                st.json(provider_calls)
+            if model_calls:
+                st.write("按模型:")
+                st.json(model_calls)
+
+    st.subheader("最近决策 (全局)")
+    decision_feed = metrics_state.get("recent_decisions", []) or llm_recent_decisions(20)
+    if decision_feed:
+        for record in reversed(decision_feed[-20:]):
+            ts_code = record.get("ts_code")
+            trade_date = record.get("trade_date")
+            action = record.get("action")
+            confidence = record.get("confidence")
+            summary = record.get("summary")
+            departments = record.get("departments", {})
+            st.markdown(
+                f"**{trade_date} {ts_code}** → {action} (信心 {confidence:.2f})"
+            )
+            if summary:
+                st.caption(f"摘要:{summary}")
+            if departments:
+                st.json(departments)
+            st.divider()
+    else:
+        st.caption("暂无决策记录,执行回测或实时评估后可在此查看。")
     try:
         with db_session(read_only=True) as conn:
             date_rows = conn.execute(
@@ -323,7 +369,7 @@ def render_today_plan() -> None:
     st.subheader("部门意见")
     if dept_records:
         dept_df = pd.DataFrame(dept_records)
-        st.dataframe(dept_df, use_container_width=True, hide_index=True)
+        st.dataframe(dept_df, width='stretch', hide_index=True)
         for code, details in dept_details.items():
             with st.expander(f"{code} 补充详情", expanded=False):
                 supplements = details.get("supplements", [])
@@ -345,7 +391,7 @@ def render_today_plan() -> None:
     st.subheader("代理评分")
     if agent_records:
         agent_df = pd.DataFrame(agent_records)
-        st.dataframe(agent_df, use_container_width=True, hide_index=True)
+        st.dataframe(agent_df, width='stretch', hide_index=True)
     else:
         st.info("暂无基础代理评分。")
@@ -390,6 +436,41 @@ def render_backtest() -> None:

     if st.button("运行回测"):
         LOGGER.info("用户点击运行回测按钮", extra=LOG_EXTRA)
+        decision_log_container = st.container()
+        status_placeholder = st.empty()
+        llm_stats_placeholder = st.empty()
+        decision_entries: List[str] = []
+
+        def _decision_callback(ts_code: str, trade_dt: date, ctx: AgentContext, decision: Decision) -> None:
+            ts_label = trade_dt.isoformat()
+            summary = decision.summary
+            entry_lines = [
+                f"**{ts_label} {ts_code}** → {decision.action.value} (信心 {decision.confidence:.2f})",
+            ]
+            if summary:
+                entry_lines.append(f"摘要:{summary}")
+            dep_highlights = []
+            for dept_code, dept_decision in decision.department_decisions.items():
+                dep_highlights.append(
+                    f"{dept_code}:{dept_decision.action.value}({dept_decision.confidence:.2f})"
+                )
+            if dep_highlights:
+                entry_lines.append("部门意见:" + ";".join(dep_highlights))
+            decision_entries.append(" \n".join(entry_lines))
+            decision_log_container.markdown("\n\n".join(decision_entries[-50:]))
+            status_placeholder.info(
+                f"最新决策:{ts_code} -> {decision.action.value} ({decision.confidence:.2f})"
+            )
+            stats = snapshot_llm_metrics()
+            llm_stats_placeholder.json(
+                {
+                    "LLM 调用次数": stats.get("total_calls", 0),
+                    "Prompt Tokens": stats.get("total_prompt_tokens", 0),
+                    "Completion Tokens": stats.get("total_completion_tokens", 0),
+                }
+            )
+
+        reset_llm_metrics()
         with st.spinner("正在执行回测..."):
             try:
                 universe = [code.strip() for code in universe_text.split(',') if code.strip()]
@@ -415,14 +496,24 @@ def render_backtest() -> None:
                         "hold_days": int(hold_days),
                     },
                 )
-                result = run_backtest(cfg)
+                result = run_backtest(cfg, decision_callback=_decision_callback)
                 LOGGER.info(
                     "回测完成:nav_records=%s trades=%s",
                     len(result.nav_series),
                     len(result.trades),
                     extra=LOG_EXTRA,
                 )
-                st.success("回测执行完成,详见回测结果摘要。")
+                st.success("回测执行完成,详见下方结果与统计。")
+                metrics = snapshot_llm_metrics()
+                llm_stats_placeholder.json(
+                    {
+                        "LLM 调用次数": metrics.get("total_calls", 0),
+                        "Prompt Tokens": metrics.get("total_prompt_tokens", 0),
+                        "Completion Tokens": metrics.get("total_completion_tokens", 0),
+                        "按 Provider": metrics.get("provider_calls", {}),
+                        "按模型": metrics.get("model_calls", {}),
+                    }
+                )
                 st.json({"nav_records": result.nav_series, "trades": result.trades})
             except Exception as exc:  # noqa: BLE001
                 LOGGER.exception("回测执行失败", extra=LOG_EXTRA)
@@ -654,7 +745,7 @@ def render_settings() -> None:
         ensemble_rows,
         num_rows="dynamic",
         key="global_ensemble_editor",
-        use_container_width=True,
+        width='stretch',
         hide_index=True,
         column_config={
             "provider": st.column_config.SelectboxColumn("Provider", options=provider_keys),
@@ -735,7 +826,7 @@ def render_settings() -> None:
         dept_rows,
         num_rows="fixed",
         key="department_editor",
-        use_container_width=True,
+        width='stretch',
         hide_index=True,
         column_config={
             "code": st.column_config.TextColumn("编码", disabled=True),
@@ -998,7 +1089,7 @@ def render_tests() -> None:
         ]
     )
     candle_fig.update_layout(height=420, margin=dict(l=10, r=10, t=40, b=10))
-    st.plotly_chart(candle_fig, use_container_width=True)
+    st.plotly_chart(candle_fig, width='stretch')

     vol_fig = px.bar(
         df_reset,
@@ -1008,7 +1099,7 @@ def render_tests() -> None:
         title="成交量",
     )
     vol_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10))
-    st.plotly_chart(vol_fig, use_container_width=True)
+    st.plotly_chart(vol_fig, width='stretch')

     amt_fig = px.bar(
         df_reset,
@@ -1018,7 +1109,7 @@ def render_tests() -> None:
         title="成交额",
     )
     amt_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10))
-    st.plotly_chart(amt_fig, use_container_width=True)
+    st.plotly_chart(amt_fig, width='stretch')

     df_reset["月份"] = df_reset["交易日"].dt.to_period("M").astype(str)
     box_fig = px.box(
@@ -1029,7 +1120,7 @@ def render_tests() -> None:
         title="月度收盘价分布",
     )
     box_fig.update_layout(height=320, margin=dict(l=10, r=10, t=40, b=10))
-    st.plotly_chart(box_fig, use_container_width=True)
+    st.plotly_chart(box_fig, width='stretch')

     st.caption("提示:成交量单位为手,成交额以千元显示。箱线图按月展示收盘价分布。")
     st.dataframe(df_reset.tail(20), width='stretch')
@@ -199,6 +199,55 @@ class DataBroker:
             return False
         return row is not None

+    def fetch_table_rows(
+        self,
+        table: str,
+        ts_code: str,
+        trade_date: str,
+        window: int,
+    ) -> List[Dict[str, object]]:
+        if window <= 0:
+            return []
+        window = min(window, self.MAX_WINDOW)
+        columns = self._get_table_columns(table)
+        if not columns:
+            LOGGER.debug("表不存在或无字段 table=%s", table, extra=LOG_EXTRA)
+            return []
+
+        column_list = ", ".join(columns)
+        has_trade_date = "trade_date" in columns
+        if has_trade_date:
+            query = (
+                f"SELECT {column_list} FROM {table} "
+                "WHERE ts_code = ? AND trade_date <= ? "
+                "ORDER BY trade_date DESC LIMIT ?"
+            )
+            params: Tuple[object, ...] = (ts_code, trade_date, window)
+        else:
+            query = (
+                f"SELECT {column_list} FROM {table} "
+                "WHERE ts_code = ? ORDER BY rowid DESC LIMIT ?"
+            )
+            params = (ts_code, window)
+
+        results: List[Dict[str, object]] = []
+        with db_session(read_only=True) as conn:
+            try:
+                rows = conn.execute(query, params).fetchall()
+            except Exception as exc:  # noqa: BLE001
+                LOGGER.debug(
+                    "表查询失败 table=%s err=%s",
+                    table,
+                    exc,
+                    extra=LOG_EXTRA,
+                )
+                return []
+
+        for row in rows:
+            record = {col: row[col] for col in columns}
+            results.append(record)
+        return results
+
     def resolve_field(self, field: str) -> Optional[Tuple[str, str]]:
         normalized = _safe_split(field)
         if not normalized:
@@ -215,7 +264,7 @@ class DataBroker:
             return None
         return table, resolved

-    def _get_table_columns(self, table: str) -> Optional[set[str]]:
+    def _get_table_columns(self, table: str) -> Optional[List[str]]:
         if not _is_safe_identifier(table):
             return None
         cache = getattr(self, "_column_cache", None)
@@ -234,7 +283,7 @@ class DataBroker:
         if not rows:
             cache[table] = None
             return None
-        columns = {row["name"] for row in rows if row["name"]}
+        columns = [row["name"] for row in rows if row["name"]]
         cache[table] = columns
         return columns
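
A brief note on the `set` → `List` change above: `fetch_table_rows` interpolates these column names into its `SELECT` and then maps each row back by name, so a deterministic order matters. A tiny standalone sketch (illustrative values only) of the failure mode a `set` would invite:

```python
# Illustrative only: a set gives no ordering guarantee, so a SELECT built
# from it could reorder columns between processes; the list preserves the
# PRAGMA table_info order that fetch_table_rows relies on.
cols_list = ["ts_code", "trade_date", "close"]
cols_set = set(cols_list)

print(", ".join(cols_list))  # always: ts_code, trade_date, close
print(", ".join(cols_set))   # arbitrary, hash-dependent order
```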