update
This commit is contained in:
parent
199cb484b9
commit
de88b198b3
@ -35,10 +35,10 @@
|
|||||||
|
|
||||||
1. **配置声明角色**:在 `config.json`/`DepartmentSettings` 中补充 `description` 与 `data_scope`,`department_prompt()` 拼接角色指令,实现职责以 Prompt 管理而非硬编码。
|
1. **配置声明角色**:在 `config.json`/`DepartmentSettings` 中补充 `description` 与 `data_scope`,`department_prompt()` 拼接角色指令,实现职责以 Prompt 管理而非硬编码。
|
||||||
2. **统一数据层**:新增 `DataBroker`(或同类工具)封装常用查询,代理与部门通过声明式 JSON 请求所需表/字段/窗口,由服务端执行并返回特征。
|
2. **统一数据层**:新增 `DataBroker`(或同类工具)封装常用查询,代理与部门通过声明式 JSON 请求所需表/字段/窗口,由服务端执行并返回特征。
|
||||||
3. **函数式工具调用**:通过 DeepSeek/OpenAI 的 function calling 暴露 `fetch_data` 工具,LLM 根据 schema 声明字段与窗口请求,系统用 `DataBroker` 校验、补数并回传结果,形成“请求→取数→复议”闭环。
|
3. **函数式工具调用**:通过 DeepSeek/OpenAI 的 function calling 暴露 `fetch_data` 工具,LLM 只需声明所需表(如 `daily`、`daily_basic`)及窗口,系统用 `DataBroker` 拉取整行数据、回传结果,形成“请求→取数→复议”闭环。
|
||||||
4. **审计与前端联动**:把角色提示、数据请求与执行摘要写入 `agent_utils` 附加字段,使 Streamlit 能完整呈现“角色 → 请求 → 决策”的链条。
|
4. **审计与前端联动**:把角色提示、数据请求与执行摘要写入 `agent_utils` 附加字段,使 Streamlit 能完整呈现“角色 → 请求 → 决策”的链条。
|
||||||
|
|
||||||
目前部门 LLM 通过 function calling 暴露的 `fetch_data` 工具触发追加查询:模型按 schema 声明需要的 `table.column` 字段与窗口,系统使用 `DataBroker` 验证后补齐数据并作为工具响应回传,再带着查询结果进入下一轮提示,从而形成闭环。
|
目前部门 LLM 通过 function calling 暴露的 `fetch_data` 工具触发追加查询:模型只需按 schema 声明 `daily` / `daily_basic` 等表名与窗口,系统使用 `DataBroker` 一次性返回指定交易日的全部列,再带着查询结果进入下一轮提示,从而形成闭环。
|
||||||
|
|
||||||
上述调整可在单个部门先行做 PoC,验证闭环能力后再推广至全部角色。
|
上述调整可在单个部门先行做 PoC,验证闭环能力后再推广至全部角色。
|
||||||
|
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple
|
from typing import Any, Callable, ClassVar, Dict, List, Mapping, Optional, Sequence, Tuple
|
||||||
|
|
||||||
from app.agents.base import AgentAction
|
from app.agents.base import AgentAction
|
||||||
from app.llm.client import call_endpoint_with_messages, run_llm_with_config, LLMError
|
from app.llm.client import call_endpoint_with_messages, run_llm_with_config, LLMError
|
||||||
@ -17,9 +17,10 @@ LOG_EXTRA = {"stage": "department"}
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DataRequest:
|
class TableRequest:
|
||||||
field: str
|
name: str
|
||||||
window: int = 1
|
window: int = 1
|
||||||
|
trade_date: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -64,6 +65,8 @@ class DepartmentDecision:
|
|||||||
class DepartmentAgent:
|
class DepartmentAgent:
|
||||||
"""Wraps LLM ensemble logic for a single analytical department."""
|
"""Wraps LLM ensemble logic for a single analytical department."""
|
||||||
|
|
||||||
|
ALLOWED_TABLES: ClassVar[List[str]] = ["daily", "daily_basic"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
settings: DepartmentSettings,
|
settings: DepartmentSettings,
|
||||||
@ -106,10 +109,7 @@ class DepartmentAgent:
|
|||||||
)
|
)
|
||||||
|
|
||||||
transcript: List[str] = []
|
transcript: List[str] = []
|
||||||
delivered_requests: set[Tuple[str, int]] = {
|
delivered_requests: set[Tuple[str, int, str]] = set()
|
||||||
(field, 1)
|
|
||||||
for field in (mutable_context.raw.get("scope_values") or {}).keys()
|
|
||||||
}
|
|
||||||
|
|
||||||
primary_endpoint = llm_cfg.primary
|
primary_endpoint = llm_cfg.primary
|
||||||
final_message: Optional[Dict[str, Any]] = None
|
final_message: Optional[Dict[str, Any]] = None
|
||||||
@ -135,6 +135,14 @@ class DepartmentAgent:
|
|||||||
message = choice.get("message", {})
|
message = choice.get("message", {})
|
||||||
transcript.append(_message_to_text(message))
|
transcript.append(_message_to_text(message))
|
||||||
|
|
||||||
|
assistant_record: Dict[str, Any] = {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": _extract_message_content(message),
|
||||||
|
}
|
||||||
|
if message.get("tool_calls"):
|
||||||
|
assistant_record["tool_calls"] = message.get("tool_calls")
|
||||||
|
messages.append(assistant_record)
|
||||||
|
|
||||||
tool_calls = message.get("tool_calls") or []
|
tool_calls = message.get("tool_calls") or []
|
||||||
if tool_calls:
|
if tool_calls:
|
||||||
for call in tool_calls:
|
for call in tool_calls:
|
||||||
@ -151,7 +159,6 @@ class DepartmentAgent:
|
|||||||
{
|
{
|
||||||
"role": "tool",
|
"role": "tool",
|
||||||
"tool_call_id": call.get("id"),
|
"tool_call_id": call.get("id"),
|
||||||
"name": call.get("function", {}).get("name"),
|
|
||||||
"content": json.dumps(tool_response, ensure_ascii=False),
|
"content": json.dumps(tool_response, ensure_ascii=False),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -213,165 +220,50 @@ class DepartmentAgent:
|
|||||||
def _fulfill_data_requests(
|
def _fulfill_data_requests(
|
||||||
self,
|
self,
|
||||||
context: DepartmentContext,
|
context: DepartmentContext,
|
||||||
requests: Sequence[DataRequest],
|
requests: Sequence[TableRequest],
|
||||||
) -> Tuple[List[str], List[Dict[str, Any]], set[Tuple[str, int]]]:
|
) -> Tuple[List[str], List[Dict[str, Any]], set[Tuple[str, int, str]]]:
|
||||||
lines: List[str] = []
|
lines: List[str] = []
|
||||||
payload: List[Dict[str, Any]] = []
|
payload: List[Dict[str, Any]] = []
|
||||||
delivered: set[Tuple[str, int]] = set()
|
delivered: set[Tuple[str, int, str]] = set()
|
||||||
|
|
||||||
ts_code = context.ts_code
|
ts_code = context.ts_code
|
||||||
trade_date = self._normalize_trade_date(context.trade_date)
|
default_trade_date = self._normalize_trade_date(context.trade_date)
|
||||||
|
|
||||||
latest_groups: Dict[str, List[str]] = {}
|
|
||||||
series_requests: List[Tuple[DataRequest, Tuple[str, str]]] = []
|
|
||||||
values_map, db_alias_map, series_map = _build_context_lookup(context)
|
|
||||||
|
|
||||||
for req in requests:
|
for req in requests:
|
||||||
field = req.field.strip()
|
table = (req.name or "").strip().lower()
|
||||||
if not field:
|
if not table:
|
||||||
continue
|
continue
|
||||||
window = req.window
|
if table not in self.ALLOWED_TABLES:
|
||||||
resolved: Optional[Tuple[str, str]] = None
|
lines.append(f"- {table}: 不在允许的表列表中")
|
||||||
if "." in field:
|
continue
|
||||||
resolved = self._broker.resolve_field(field)
|
trade_date = self._normalize_trade_date(req.trade_date or default_trade_date)
|
||||||
elif field in db_alias_map:
|
window = max(1, min(req.window or 1, getattr(self._broker, "MAX_WINDOW", 120)))
|
||||||
resolved = db_alias_map[field]
|
key = (table, window, trade_date)
|
||||||
|
if key in delivered:
|
||||||
if resolved:
|
lines.append(f"- {table}: 已返回窗口 {window} 的数据,跳过重复请求")
|
||||||
table, column = resolved
|
|
||||||
canonical = f"{table}.{column}"
|
|
||||||
if window <= 1:
|
|
||||||
latest_groups.setdefault(canonical, []).append(field)
|
|
||||||
delivered.add((field, 1))
|
|
||||||
delivered.add((canonical, 1))
|
|
||||||
else:
|
|
||||||
series_requests.append((req, resolved))
|
|
||||||
delivered.add((field, window))
|
|
||||||
delivered.add((canonical, window))
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if field in values_map:
|
rows = self._broker.fetch_table_rows(table, ts_code, trade_date, window)
|
||||||
value = values_map[field]
|
if rows:
|
||||||
if window <= 1:
|
|
||||||
payload.append(
|
|
||||||
{
|
|
||||||
"field": field,
|
|
||||||
"window": 1,
|
|
||||||
"source": "context",
|
|
||||||
"values": [
|
|
||||||
{
|
|
||||||
"trade_date": context.trade_date,
|
|
||||||
"value": value,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
lines.append(f"- {field}: {value} (来自上下文)")
|
|
||||||
else:
|
|
||||||
series = series_map.get(field)
|
|
||||||
if series:
|
|
||||||
trimmed = series[: window]
|
|
||||||
payload.append(
|
|
||||||
{
|
|
||||||
"field": field,
|
|
||||||
"window": window,
|
|
||||||
"source": "context_series",
|
|
||||||
"values": [
|
|
||||||
{"trade_date": dt, "value": val}
|
|
||||||
for dt, val in trimmed
|
|
||||||
],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
preview = ", ".join(
|
|
||||||
f"{dt}:{val:.4f}" for dt, val in trimmed[: min(len(trimmed), 5)]
|
|
||||||
)
|
|
||||||
lines.append(
|
|
||||||
f"- {field} (window={window} 来自上下文序列): {preview}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
payload.append(
|
|
||||||
{
|
|
||||||
"field": field,
|
|
||||||
"window": window,
|
|
||||||
"source": "context",
|
|
||||||
"values": [
|
|
||||||
{
|
|
||||||
"trade_date": context.trade_date,
|
|
||||||
"value": value,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"warning": "仅提供当前值,缺少历史序列",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
lines.append(
|
|
||||||
f"- {field} (window={window}): 仅有当前值 {value}, 无历史序列"
|
|
||||||
)
|
|
||||||
delivered.add((field, window))
|
|
||||||
if field in db_alias_map:
|
|
||||||
resolved = db_alias_map[field]
|
|
||||||
canonical = f"{resolved[0]}.{resolved[1]}"
|
|
||||||
delivered.add((canonical, window))
|
|
||||||
continue
|
|
||||||
|
|
||||||
lines.append(f"- {field}: 字段不存在或不可用")
|
|
||||||
|
|
||||||
if latest_groups:
|
|
||||||
latest_values = self._broker.fetch_latest(
|
|
||||||
ts_code, trade_date, list(latest_groups.keys())
|
|
||||||
)
|
|
||||||
for canonical, aliases in latest_groups.items():
|
|
||||||
value = latest_values.get(canonical)
|
|
||||||
if value is None:
|
|
||||||
lines.append(f"- {canonical}: (数据缺失)")
|
|
||||||
else:
|
|
||||||
lines.append(f"- {canonical}: {value}")
|
|
||||||
for alias in aliases:
|
|
||||||
payload.append(
|
|
||||||
{
|
|
||||||
"field": alias,
|
|
||||||
"window": 1,
|
|
||||||
"source": "database",
|
|
||||||
"values": [
|
|
||||||
{
|
|
||||||
"trade_date": trade_date,
|
|
||||||
"value": value,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
for req, resolved in series_requests:
|
|
||||||
table, column = resolved
|
|
||||||
series = self._broker.fetch_series(
|
|
||||||
table,
|
|
||||||
column,
|
|
||||||
ts_code,
|
|
||||||
trade_date,
|
|
||||||
window=req.window,
|
|
||||||
)
|
|
||||||
if series:
|
|
||||||
preview = ", ".join(
|
preview = ", ".join(
|
||||||
f"{dt}:{val:.4f}"
|
f"{row.get('trade_date', 'NA')}" for row in rows[: min(len(rows), 5)]
|
||||||
for dt, val in series[: min(len(series), 5)]
|
|
||||||
)
|
)
|
||||||
lines.append(
|
lines.append(
|
||||||
f"- {req.field} (window={req.window}): {preview}"
|
f"- {table} (window={window} trade_date<= {trade_date}): 返回 {len(rows)} 行 {preview}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
lines.append(
|
lines.append(
|
||||||
f"- {req.field} (window={req.window}): (数据缺失)"
|
f"- {table} (window={window} trade_date<= {trade_date}): (数据缺失)"
|
||||||
)
|
)
|
||||||
payload.append(
|
payload.append(
|
||||||
{
|
{
|
||||||
"field": req.field,
|
"table": table,
|
||||||
"window": req.window,
|
"window": window,
|
||||||
"source": "database",
|
"trade_date": trade_date,
|
||||||
"values": [
|
"rows": rows,
|
||||||
{"trade_date": dt, "value": val}
|
|
||||||
for dt, val in series
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
delivered.add(key)
|
||||||
|
|
||||||
return lines, payload, delivered
|
return lines, payload, delivered
|
||||||
|
|
||||||
@ -380,9 +272,9 @@ class DepartmentAgent:
|
|||||||
self,
|
self,
|
||||||
context: DepartmentContext,
|
context: DepartmentContext,
|
||||||
call: Mapping[str, Any],
|
call: Mapping[str, Any],
|
||||||
delivered_requests: set[Tuple[str, int]],
|
delivered_requests: set[Tuple[str, int, str]],
|
||||||
round_idx: int,
|
round_idx: int,
|
||||||
) -> Tuple[Dict[str, Any], set[Tuple[str, int]]]:
|
) -> Tuple[Dict[str, Any], set[Tuple[str, int, str]]]:
|
||||||
function_block = call.get("function") or {}
|
function_block = call.get("function") or {}
|
||||||
name = function_block.get("name") or ""
|
name = function_block.get("name") or ""
|
||||||
if name != "fetch_data":
|
if name != "fetch_data":
|
||||||
@ -398,23 +290,29 @@ class DepartmentAgent:
|
|||||||
}, set()
|
}, set()
|
||||||
|
|
||||||
args = _parse_tool_arguments(function_block.get("arguments"))
|
args = _parse_tool_arguments(function_block.get("arguments"))
|
||||||
raw_requests = args.get("requests") or []
|
base_trade_date = self._normalize_trade_date(
|
||||||
requests: List[DataRequest] = []
|
args.get("trade_date") or context.trade_date
|
||||||
|
)
|
||||||
|
raw_requests = args.get("tables") or []
|
||||||
|
requests: List[TableRequest] = []
|
||||||
skipped: List[str] = []
|
skipped: List[str] = []
|
||||||
for item in raw_requests:
|
for item in raw_requests:
|
||||||
field = str(item.get("field", "")).strip()
|
name = str(item.get("name", "")).strip().lower()
|
||||||
if not field:
|
if not name:
|
||||||
continue
|
continue
|
||||||
|
window_raw = item.get("window")
|
||||||
try:
|
try:
|
||||||
window = int(item.get("window", 1))
|
window = int(window_raw) if window_raw is not None else 1
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
window = 1
|
window = 1
|
||||||
window = max(1, min(window, getattr(self._broker, "MAX_WINDOW", 120)))
|
window = max(1, min(window, getattr(self._broker, "MAX_WINDOW", 120)))
|
||||||
key = (field, window)
|
override_date = item.get("trade_date")
|
||||||
|
req_date = self._normalize_trade_date(override_date or base_trade_date)
|
||||||
|
key = (name, window, req_date)
|
||||||
if key in delivered_requests:
|
if key in delivered_requests:
|
||||||
skipped.append(field)
|
skipped.append(name)
|
||||||
continue
|
continue
|
||||||
requests.append(DataRequest(field=field, window=window))
|
requests.append(TableRequest(name=name, window=window, trade_date=req_date))
|
||||||
|
|
||||||
if not requests:
|
if not requests:
|
||||||
return {
|
return {
|
||||||
@ -461,34 +359,44 @@ class DepartmentAgent:
|
|||||||
"function": {
|
"function": {
|
||||||
"name": "fetch_data",
|
"name": "fetch_data",
|
||||||
"description": (
|
"description": (
|
||||||
"根据字段请求数据库中的最新值或时间序列。支持 table.column 格式的字段,"
|
"根据表名请求指定交易日及窗口的历史数据。当前仅支持 'daily' 与 'daily_basic' 表。"
|
||||||
"window 表示希望返回的最近数据点数量。"
|
|
||||||
),
|
),
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"requests": {
|
"tables": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"items": {
|
"items": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"field": {
|
"name": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "数据字段,格式为 table.column",
|
"enum": self.ALLOWED_TABLES,
|
||||||
|
"description": "表名,例如 daily 或 daily_basic",
|
||||||
},
|
},
|
||||||
"window": {
|
"window": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"minimum": 1,
|
"minimum": 1,
|
||||||
"maximum": max_window,
|
"maximum": max_window,
|
||||||
"description": "返回最近多少个数据点,默认为 1",
|
"description": "向前回溯的记录条数,默认为 1",
|
||||||
|
},
|
||||||
|
"trade_date": {
|
||||||
|
"type": "string",
|
||||||
|
"pattern": r"^\\d{8}$",
|
||||||
|
"description": "覆盖默认交易日(格式 YYYYMMDD)",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"required": ["field"],
|
"required": ["name"],
|
||||||
},
|
},
|
||||||
"minItems": 1,
|
"minItems": 1,
|
||||||
}
|
},
|
||||||
|
"trade_date": {
|
||||||
|
"type": "string",
|
||||||
|
"pattern": r"^\\d{8}$",
|
||||||
|
"description": "默认交易日(格式 YYYYMMDD)",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"required": ["requests"],
|
"required": ["tables"],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -555,39 +463,6 @@ def _ensure_mutable_context(context: DepartmentContext) -> DepartmentContext:
|
|||||||
return context
|
return context
|
||||||
|
|
||||||
|
|
||||||
def _parse_data_requests(payload: Mapping[str, Any]) -> List[DataRequest]:
|
|
||||||
raw_requests = payload.get("data_requests")
|
|
||||||
requests: List[DataRequest] = []
|
|
||||||
if not isinstance(raw_requests, list):
|
|
||||||
return requests
|
|
||||||
seen: set[Tuple[str, int]] = set()
|
|
||||||
for item in raw_requests:
|
|
||||||
field = ""
|
|
||||||
window = 1
|
|
||||||
if isinstance(item, str):
|
|
||||||
field = item.strip()
|
|
||||||
elif isinstance(item, Mapping):
|
|
||||||
candidate = item.get("field")
|
|
||||||
if candidate is None:
|
|
||||||
continue
|
|
||||||
field = str(candidate).strip()
|
|
||||||
try:
|
|
||||||
window = int(item.get("window", 1))
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
window = 1
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
if not field:
|
|
||||||
continue
|
|
||||||
window = max(1, window)
|
|
||||||
key = (field, window)
|
|
||||||
if key in seen:
|
|
||||||
continue
|
|
||||||
seen.add(key)
|
|
||||||
requests.append(DataRequest(field=field, window=window))
|
|
||||||
return requests
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_tool_arguments(payload: Any) -> Dict[str, Any]:
|
def _parse_tool_arguments(payload: Any) -> Dict[str, Any]:
|
||||||
if isinstance(payload, dict):
|
if isinstance(payload, dict):
|
||||||
return dict(payload)
|
return dict(payload)
|
||||||
@ -636,38 +511,6 @@ def _extract_message_content(message: Mapping[str, Any]) -> str:
|
|||||||
return json.dumps(message, ensure_ascii=False)
|
return json.dumps(message, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
def _build_context_lookup(
|
|
||||||
context: DepartmentContext,
|
|
||||||
) -> Tuple[Dict[str, Any], Dict[str, Tuple[str, str]], Dict[str, List[Tuple[str, float]]]]:
|
|
||||||
values: Dict[str, Any] = {}
|
|
||||||
db_alias: Dict[str, Tuple[str, str]] = {}
|
|
||||||
series_map: Dict[str, List[Tuple[str, float]]] = {}
|
|
||||||
|
|
||||||
for source in (context.features or {}, context.market_snapshot or {}):
|
|
||||||
for key, value in source.items():
|
|
||||||
values[str(key)] = value
|
|
||||||
|
|
||||||
scope_values = context.raw.get("scope_values") or {}
|
|
||||||
for key, value in scope_values.items():
|
|
||||||
key_str = str(key)
|
|
||||||
values[key_str] = value
|
|
||||||
if "." in key_str:
|
|
||||||
table, column = key_str.split(".", 1)
|
|
||||||
db_alias.setdefault(column, (table, column))
|
|
||||||
db_alias.setdefault(key_str, (table, column))
|
|
||||||
values.setdefault(column, value)
|
|
||||||
|
|
||||||
close_series = context.raw.get("close_series") or []
|
|
||||||
if isinstance(close_series, list) and close_series:
|
|
||||||
series_map["close"] = close_series
|
|
||||||
series_map["daily.close"] = close_series
|
|
||||||
|
|
||||||
turnover_series = context.raw.get("turnover_series") or []
|
|
||||||
if isinstance(turnover_series, list) and turnover_series:
|
|
||||||
series_map["turnover_rate"] = turnover_series
|
|
||||||
series_map["daily_basic.turnover_rate"] = turnover_series
|
|
||||||
|
|
||||||
return values, db_alias, series_map
|
|
||||||
|
|
||||||
|
|
||||||
class DepartmentManager:
|
class DepartmentManager:
|
||||||
|
|||||||
@ -5,11 +5,12 @@ import json
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import date
|
from datetime import date
|
||||||
from statistics import mean, pstdev
|
from statistics import mean, pstdev
|
||||||
from typing import Any, Dict, Iterable, List, Mapping
|
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional
|
||||||
|
|
||||||
from app.agents.base import AgentContext
|
from app.agents.base import AgentContext
|
||||||
from app.agents.departments import DepartmentManager
|
from app.agents.departments import DepartmentManager
|
||||||
from app.agents.game import Decision, decide
|
from app.agents.game import Decision, decide
|
||||||
|
from app.llm.metrics import record_decision as metrics_record_decision
|
||||||
from app.agents.registry import default_agents
|
from app.agents.registry import default_agents
|
||||||
from app.utils.data_access import DataBroker
|
from app.utils.data_access import DataBroker
|
||||||
from app.utils.config import get_config
|
from app.utils.config import get_config
|
||||||
@ -224,7 +225,12 @@ class BacktestEngine:
|
|||||||
|
|
||||||
return feature_map
|
return feature_map
|
||||||
|
|
||||||
def simulate_day(self, trade_date: date, state: PortfolioState) -> List[Decision]:
|
def simulate_day(
|
||||||
|
self,
|
||||||
|
trade_date: date,
|
||||||
|
state: PortfolioState,
|
||||||
|
decision_callback: Optional[Callable[[str, date, AgentContext, Decision], None]] = None,
|
||||||
|
) -> List[Decision]:
|
||||||
feature_map = self.load_market_data(trade_date)
|
feature_map = self.load_market_data(trade_date)
|
||||||
decisions: List[Decision] = []
|
decisions: List[Decision] = []
|
||||||
for ts_code, payload in feature_map.items():
|
for ts_code, payload in feature_map.items():
|
||||||
@ -247,6 +253,26 @@ class BacktestEngine:
|
|||||||
)
|
)
|
||||||
decisions.append(decision)
|
decisions.append(decision)
|
||||||
self.record_agent_state(context, decision)
|
self.record_agent_state(context, decision)
|
||||||
|
if decision_callback:
|
||||||
|
try:
|
||||||
|
decision_callback(ts_code, trade_date, context, decision)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
LOGGER.exception("决策回调执行失败", extra=LOG_EXTRA)
|
||||||
|
try:
|
||||||
|
metrics_record_decision(
|
||||||
|
ts_code=ts_code,
|
||||||
|
trade_date=context.trade_date,
|
||||||
|
action=decision.action.value,
|
||||||
|
confidence=decision.confidence,
|
||||||
|
summary=decision.summary,
|
||||||
|
source="backtest",
|
||||||
|
departments={
|
||||||
|
code: dept.to_dict()
|
||||||
|
for code, dept in decision.department_decisions.items()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
LOGGER.debug("记录决策指标失败", extra=LOG_EXTRA)
|
||||||
# TODO: translate decisions into fills, holdings, and NAV updates.
|
# TODO: translate decisions into fills, holdings, and NAV updates.
|
||||||
_ = state
|
_ = state
|
||||||
return decisions
|
return decisions
|
||||||
@ -357,20 +383,27 @@ class BacktestEngine:
|
|||||||
_ = payload
|
_ = payload
|
||||||
# TODO: persist payload into bt_trades / audit tables when schema is ready.
|
# TODO: persist payload into bt_trades / audit tables when schema is ready.
|
||||||
|
|
||||||
def run(self) -> BacktestResult:
|
def run(
|
||||||
|
self,
|
||||||
|
decision_callback: Optional[Callable[[str, date, AgentContext, Decision], None]] = None,
|
||||||
|
) -> BacktestResult:
|
||||||
state = PortfolioState()
|
state = PortfolioState()
|
||||||
result = BacktestResult()
|
result = BacktestResult()
|
||||||
current = self.cfg.start_date
|
current = self.cfg.start_date
|
||||||
while current <= self.cfg.end_date:
|
while current <= self.cfg.end_date:
|
||||||
decisions = self.simulate_day(current, state)
|
decisions = self.simulate_day(current, state, decision_callback)
|
||||||
_ = decisions
|
_ = decisions
|
||||||
current = date.fromordinal(current.toordinal() + 1)
|
current = date.fromordinal(current.toordinal() + 1)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def run_backtest(cfg: BtConfig) -> BacktestResult:
|
def run_backtest(
|
||||||
|
cfg: BtConfig,
|
||||||
|
*,
|
||||||
|
decision_callback: Optional[Callable[[str, date, AgentContext, Decision], None]] = None,
|
||||||
|
) -> BacktestResult:
|
||||||
engine = BacktestEngine(cfg)
|
engine = BacktestEngine(cfg)
|
||||||
result = engine.run()
|
result = engine.run(decision_callback=decision_callback)
|
||||||
with db_session() as conn:
|
with db_session() as conn:
|
||||||
_ = conn
|
_ = conn
|
||||||
# Implementation should persist bt_nav, bt_trades, and bt_report rows.
|
# Implementation should persist bt_nav, bt_trades, and bt_report rows.
|
||||||
|
|||||||
@ -17,6 +17,7 @@ from app.utils.config import (
|
|||||||
LLMEndpoint,
|
LLMEndpoint,
|
||||||
get_config,
|
get_config,
|
||||||
)
|
)
|
||||||
|
from app.llm.metrics import record_call
|
||||||
from app.utils.logging import get_logger
|
from app.utils.logging import get_logger
|
||||||
|
|
||||||
LOGGER = get_logger(__name__)
|
LOGGER = get_logger(__name__)
|
||||||
@ -220,11 +221,12 @@ def call_endpoint_with_messages(
|
|||||||
)
|
)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
raise LLMError(f"Ollama 调用失败: {response.status_code} {response.text}")
|
raise LLMError(f"Ollama 调用失败: {response.status_code} {response.text}")
|
||||||
|
record_call(provider_key, model)
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
if not api_key:
|
if not api_key:
|
||||||
raise LLMError(f"缺少 {provider_key} API Key (model={model})")
|
raise LLMError(f"缺少 {provider_key} API Key (model={model})")
|
||||||
return _request_openai_chat(
|
data = _request_openai_chat(
|
||||||
base_url=base_url,
|
base_url=base_url,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
model=model,
|
model=model,
|
||||||
@ -234,6 +236,11 @@ def call_endpoint_with_messages(
|
|||||||
tools=tools,
|
tools=tools,
|
||||||
tool_choice=tool_choice,
|
tool_choice=tool_choice,
|
||||||
)
|
)
|
||||||
|
usage = data.get("usage", {}) if isinstance(data, dict) else {}
|
||||||
|
prompt_tokens = usage.get("prompt_tokens") or usage.get("prompt_tokens_total")
|
||||||
|
completion_tokens = usage.get("completion_tokens") or usage.get("completion_tokens_total")
|
||||||
|
record_call(provider_key, model, prompt_tokens, completion_tokens)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
def _normalize_response(text: str) -> str:
|
def _normalize_response(text: str) -> str:
|
||||||
|
|||||||
114
app/llm/metrics.py
Normal file
114
app/llm/metrics.py
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
"""Simple runtime metrics collector for LLM calls."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections import deque
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from threading import Lock
|
||||||
|
from typing import Deque, Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _Metrics:
|
||||||
|
total_calls: int = 0
|
||||||
|
total_prompt_tokens: int = 0
|
||||||
|
total_completion_tokens: int = 0
|
||||||
|
provider_calls: Dict[str, int] = field(default_factory=dict)
|
||||||
|
model_calls: Dict[str, int] = field(default_factory=dict)
|
||||||
|
decisions: Deque[Dict[str, object]] = field(default_factory=lambda: deque(maxlen=500))
|
||||||
|
decision_action_counts: Dict[str, int] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
_METRICS = _Metrics()
|
||||||
|
_LOCK = Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def record_call(
|
||||||
|
provider: str,
|
||||||
|
model: Optional[str] = None,
|
||||||
|
prompt_tokens: Optional[int] = None,
|
||||||
|
completion_tokens: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Record a single LLM API invocation."""
|
||||||
|
|
||||||
|
normalized_provider = (provider or "unknown").lower()
|
||||||
|
normalized_model = (model or "").strip()
|
||||||
|
with _LOCK:
|
||||||
|
_METRICS.total_calls += 1
|
||||||
|
_METRICS.provider_calls[normalized_provider] = (
|
||||||
|
_METRICS.provider_calls.get(normalized_provider, 0) + 1
|
||||||
|
)
|
||||||
|
if normalized_model:
|
||||||
|
_METRICS.model_calls[normalized_model] = (
|
||||||
|
_METRICS.model_calls.get(normalized_model, 0) + 1
|
||||||
|
)
|
||||||
|
if prompt_tokens:
|
||||||
|
_METRICS.total_prompt_tokens += int(prompt_tokens)
|
||||||
|
if completion_tokens:
|
||||||
|
_METRICS.total_completion_tokens += int(completion_tokens)
|
||||||
|
|
||||||
|
|
||||||
|
def snapshot(reset: bool = False) -> Dict[str, object]:
|
||||||
|
"""Return a snapshot of current metrics. Optionally reset counters."""
|
||||||
|
|
||||||
|
with _LOCK:
|
||||||
|
data = {
|
||||||
|
"total_calls": _METRICS.total_calls,
|
||||||
|
"total_prompt_tokens": _METRICS.total_prompt_tokens,
|
||||||
|
"total_completion_tokens": _METRICS.total_completion_tokens,
|
||||||
|
"provider_calls": dict(_METRICS.provider_calls),
|
||||||
|
"model_calls": dict(_METRICS.model_calls),
|
||||||
|
"decision_action_counts": dict(_METRICS.decision_action_counts),
|
||||||
|
"recent_decisions": list(_METRICS.decisions),
|
||||||
|
}
|
||||||
|
if reset:
|
||||||
|
_METRICS.total_calls = 0
|
||||||
|
_METRICS.total_prompt_tokens = 0
|
||||||
|
_METRICS.total_completion_tokens = 0
|
||||||
|
_METRICS.provider_calls.clear()
|
||||||
|
_METRICS.model_calls.clear()
|
||||||
|
_METRICS.decision_action_counts.clear()
|
||||||
|
_METRICS.decisions.clear()
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def reset() -> None:
|
||||||
|
"""Reset all collected metrics."""
|
||||||
|
|
||||||
|
snapshot(reset=True)
|
||||||
|
|
||||||
|
|
||||||
|
def record_decision(
|
||||||
|
*,
|
||||||
|
ts_code: str,
|
||||||
|
trade_date: str,
|
||||||
|
action: str,
|
||||||
|
confidence: float,
|
||||||
|
summary: str,
|
||||||
|
source: str,
|
||||||
|
departments: Optional[Dict[str, object]] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Record a high-level decision for later inspection."""
|
||||||
|
|
||||||
|
record = {
|
||||||
|
"ts_code": ts_code,
|
||||||
|
"trade_date": trade_date,
|
||||||
|
"action": action,
|
||||||
|
"confidence": confidence,
|
||||||
|
"summary": summary,
|
||||||
|
"source": source,
|
||||||
|
"departments": departments or {},
|
||||||
|
}
|
||||||
|
with _LOCK:
|
||||||
|
_METRICS.decisions.append(record)
|
||||||
|
_METRICS.decision_action_counts[action] = (
|
||||||
|
_METRICS.decision_action_counts.get(action, 0) + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def recent_decisions(limit: int = 50) -> List[Dict[str, object]]:
|
||||||
|
"""Return the most recent decisions up to limit."""
|
||||||
|
|
||||||
|
with _LOCK:
|
||||||
|
if limit <= 0:
|
||||||
|
return []
|
||||||
|
return list(_METRICS.decisions)[-limit:]
|
||||||
@ -62,7 +62,7 @@ def department_prompt(
|
|||||||
"risks": ["风险点", "..."]
|
"risks": ["风险点", "..."]
|
||||||
}}
|
}}
|
||||||
|
|
||||||
如需额外数据,请调用可用工具 `fetch_data` 并在参数中提供 `requests` 数组(元素包含 `field` 以及可选的 `window`);`field` 必须符合【可用数据范围】,`window` 默认为 1。
|
如需额外数据,请调用工具 `fetch_data`,仅支持请求 `daily` 或 `daily_basic` 表;在参数中填写 `tables` 数组,元素包含 `name`(表名)与可选的 `window`(向前回溯的条数,默认 1)及 `trade_date`(YYYYMMDD,默认本次交易日)。
|
||||||
工具返回的数据会在后续消息中提供,请在获取所有必要信息后再给出最终 JSON 答复。
|
工具返回的数据会在后续消息中提供,请在获取所有必要信息后再给出最终 JSON 答复。
|
||||||
|
|
||||||
请严格返回单个 JSON 对象,不要添加额外文本。
|
请严格返回单个 JSON 对象,不要添加额外文本。
|
||||||
|
|||||||
@ -20,11 +20,18 @@ import requests
|
|||||||
from requests.exceptions import RequestException
|
from requests.exceptions import RequestException
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
|
||||||
|
from app.agents.base import AgentContext
|
||||||
|
from app.agents.game import Decision
|
||||||
from app.backtest.engine import BtConfig, run_backtest
|
from app.backtest.engine import BtConfig, run_backtest
|
||||||
from app.data.schema import initialize_database
|
from app.data.schema import initialize_database
|
||||||
from app.ingest.checker import run_boot_check
|
from app.ingest.checker import run_boot_check
|
||||||
from app.ingest.tushare import FetchJob, run_ingestion
|
from app.ingest.tushare import FetchJob, run_ingestion
|
||||||
from app.llm.client import llm_config_snapshot, run_llm
|
from app.llm.client import llm_config_snapshot, run_llm
|
||||||
|
from app.llm.metrics import (
|
||||||
|
reset as reset_llm_metrics,
|
||||||
|
snapshot as snapshot_llm_metrics,
|
||||||
|
recent_decisions as llm_recent_decisions,
|
||||||
|
)
|
||||||
from app.utils.config import (
|
from app.utils.config import (
|
||||||
ALLOWED_LLM_STRATEGIES,
|
ALLOWED_LLM_STRATEGIES,
|
||||||
DEFAULT_LLM_BASE_URLS,
|
DEFAULT_LLM_BASE_URLS,
|
||||||
@ -152,6 +159,45 @@ def _load_daily_frame(ts_code: str, start: date, end: date) -> pd.DataFrame:
|
|||||||
def render_today_plan() -> None:
|
def render_today_plan() -> None:
|
||||||
LOGGER.info("渲染今日计划页面", extra=LOG_EXTRA)
|
LOGGER.info("渲染今日计划页面", extra=LOG_EXTRA)
|
||||||
st.header("今日计划")
|
st.header("今日计划")
|
||||||
|
st.caption("统计数据基于最近一次渲染,刷新页面即可获取最新结果。")
|
||||||
|
|
||||||
|
metrics_state = snapshot_llm_metrics()
|
||||||
|
st.subheader("LLM 调用统计 (实时)")
|
||||||
|
stats_col1, stats_col2, stats_col3 = st.columns(3)
|
||||||
|
stats_col1.metric("总调用次数", metrics_state.get("total_calls", 0))
|
||||||
|
stats_col2.metric("Prompt Tokens", metrics_state.get("total_prompt_tokens", 0))
|
||||||
|
stats_col3.metric("Completion Tokens", metrics_state.get("total_completion_tokens", 0))
|
||||||
|
provider_calls = metrics_state.get("provider_calls", {})
|
||||||
|
model_calls = metrics_state.get("model_calls", {})
|
||||||
|
if provider_calls or model_calls:
|
||||||
|
with st.expander("调用明细", expanded=False):
|
||||||
|
if provider_calls:
|
||||||
|
st.write("按 Provider:")
|
||||||
|
st.json(provider_calls)
|
||||||
|
if model_calls:
|
||||||
|
st.write("按模型:")
|
||||||
|
st.json(model_calls)
|
||||||
|
|
||||||
|
st.subheader("最近决策 (全局)")
|
||||||
|
decision_feed = metrics_state.get("recent_decisions", []) or llm_recent_decisions(20)
|
||||||
|
if decision_feed:
|
||||||
|
for record in reversed(decision_feed[-20:]):
|
||||||
|
ts_code = record.get("ts_code")
|
||||||
|
trade_date = record.get("trade_date")
|
||||||
|
action = record.get("action")
|
||||||
|
confidence = record.get("confidence")
|
||||||
|
summary = record.get("summary")
|
||||||
|
departments = record.get("departments", {})
|
||||||
|
st.markdown(
|
||||||
|
f"**{trade_date} {ts_code}** → {action} (信心 {confidence:.2f})"
|
||||||
|
)
|
||||||
|
if summary:
|
||||||
|
st.caption(f"摘要:{summary}")
|
||||||
|
if departments:
|
||||||
|
st.json(departments)
|
||||||
|
st.divider()
|
||||||
|
else:
|
||||||
|
st.caption("暂无决策记录,执行回测或实时评估后可在此查看。")
|
||||||
try:
|
try:
|
||||||
with db_session(read_only=True) as conn:
|
with db_session(read_only=True) as conn:
|
||||||
date_rows = conn.execute(
|
date_rows = conn.execute(
|
||||||
@ -323,7 +369,7 @@ def render_today_plan() -> None:
|
|||||||
st.subheader("部门意见")
|
st.subheader("部门意见")
|
||||||
if dept_records:
|
if dept_records:
|
||||||
dept_df = pd.DataFrame(dept_records)
|
dept_df = pd.DataFrame(dept_records)
|
||||||
st.dataframe(dept_df, use_container_width=True, hide_index=True)
|
st.dataframe(dept_df, width='stretch', hide_index=True)
|
||||||
for code, details in dept_details.items():
|
for code, details in dept_details.items():
|
||||||
with st.expander(f"{code} 补充详情", expanded=False):
|
with st.expander(f"{code} 补充详情", expanded=False):
|
||||||
supplements = details.get("supplements", [])
|
supplements = details.get("supplements", [])
|
||||||
@ -345,7 +391,7 @@ def render_today_plan() -> None:
|
|||||||
st.subheader("代理评分")
|
st.subheader("代理评分")
|
||||||
if agent_records:
|
if agent_records:
|
||||||
agent_df = pd.DataFrame(agent_records)
|
agent_df = pd.DataFrame(agent_records)
|
||||||
st.dataframe(agent_df, use_container_width=True, hide_index=True)
|
st.dataframe(agent_df, width='stretch', hide_index=True)
|
||||||
else:
|
else:
|
||||||
st.info("暂无基础代理评分。")
|
st.info("暂无基础代理评分。")
|
||||||
|
|
||||||
@ -390,6 +436,41 @@ def render_backtest() -> None:
|
|||||||
|
|
||||||
if st.button("运行回测"):
|
if st.button("运行回测"):
|
||||||
LOGGER.info("用户点击运行回测按钮", extra=LOG_EXTRA)
|
LOGGER.info("用户点击运行回测按钮", extra=LOG_EXTRA)
|
||||||
|
decision_log_container = st.container()
|
||||||
|
status_placeholder = st.empty()
|
||||||
|
llm_stats_placeholder = st.empty()
|
||||||
|
decision_entries: List[str] = []
|
||||||
|
|
||||||
|
def _decision_callback(ts_code: str, trade_dt: date, ctx: AgentContext, decision: Decision) -> None:
|
||||||
|
ts_label = trade_dt.isoformat()
|
||||||
|
summary = decision.summary
|
||||||
|
entry_lines = [
|
||||||
|
f"**{ts_label} {ts_code}** → {decision.action.value} (信心 {decision.confidence:.2f})",
|
||||||
|
]
|
||||||
|
if summary:
|
||||||
|
entry_lines.append(f"摘要:{summary}")
|
||||||
|
dep_highlights = []
|
||||||
|
for dept_code, dept_decision in decision.department_decisions.items():
|
||||||
|
dep_highlights.append(
|
||||||
|
f"{dept_code}:{dept_decision.action.value}({dept_decision.confidence:.2f})"
|
||||||
|
)
|
||||||
|
if dep_highlights:
|
||||||
|
entry_lines.append("部门意见:" + ";".join(dep_highlights))
|
||||||
|
decision_entries.append(" \n".join(entry_lines))
|
||||||
|
decision_log_container.markdown("\n\n".join(decision_entries[-50:]))
|
||||||
|
status_placeholder.info(
|
||||||
|
f"最新决策:{ts_code} -> {decision.action.value} ({decision.confidence:.2f})"
|
||||||
|
)
|
||||||
|
stats = snapshot_llm_metrics()
|
||||||
|
llm_stats_placeholder.json(
|
||||||
|
{
|
||||||
|
"LLM 调用次数": stats.get("total_calls", 0),
|
||||||
|
"Prompt Tokens": stats.get("total_prompt_tokens", 0),
|
||||||
|
"Completion Tokens": stats.get("total_completion_tokens", 0),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
reset_llm_metrics()
|
||||||
with st.spinner("正在执行回测..."):
|
with st.spinner("正在执行回测..."):
|
||||||
try:
|
try:
|
||||||
universe = [code.strip() for code in universe_text.split(',') if code.strip()]
|
universe = [code.strip() for code in universe_text.split(',') if code.strip()]
|
||||||
@ -415,14 +496,24 @@ def render_backtest() -> None:
|
|||||||
"hold_days": int(hold_days),
|
"hold_days": int(hold_days),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
result = run_backtest(cfg)
|
result = run_backtest(cfg, decision_callback=_decision_callback)
|
||||||
LOGGER.info(
|
LOGGER.info(
|
||||||
"回测完成:nav_records=%s trades=%s",
|
"回测完成:nav_records=%s trades=%s",
|
||||||
len(result.nav_series),
|
len(result.nav_series),
|
||||||
len(result.trades),
|
len(result.trades),
|
||||||
extra=LOG_EXTRA,
|
extra=LOG_EXTRA,
|
||||||
)
|
)
|
||||||
st.success("回测执行完成,详见回测结果摘要。")
|
st.success("回测执行完成,详见下方结果与统计。")
|
||||||
|
metrics = snapshot_llm_metrics()
|
||||||
|
llm_stats_placeholder.json(
|
||||||
|
{
|
||||||
|
"LLM 调用次数": metrics.get("total_calls", 0),
|
||||||
|
"Prompt Tokens": metrics.get("total_prompt_tokens", 0),
|
||||||
|
"Completion Tokens": metrics.get("total_completion_tokens", 0),
|
||||||
|
"按 Provider": metrics.get("provider_calls", {}),
|
||||||
|
"按模型": metrics.get("model_calls", {}),
|
||||||
|
}
|
||||||
|
)
|
||||||
st.json({"nav_records": result.nav_series, "trades": result.trades})
|
st.json({"nav_records": result.nav_series, "trades": result.trades})
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
LOGGER.exception("回测执行失败", extra=LOG_EXTRA)
|
LOGGER.exception("回测执行失败", extra=LOG_EXTRA)
|
||||||
@ -654,7 +745,7 @@ def render_settings() -> None:
|
|||||||
ensemble_rows,
|
ensemble_rows,
|
||||||
num_rows="dynamic",
|
num_rows="dynamic",
|
||||||
key="global_ensemble_editor",
|
key="global_ensemble_editor",
|
||||||
use_container_width=True,
|
width='stretch',
|
||||||
hide_index=True,
|
hide_index=True,
|
||||||
column_config={
|
column_config={
|
||||||
"provider": st.column_config.SelectboxColumn("Provider", options=provider_keys),
|
"provider": st.column_config.SelectboxColumn("Provider", options=provider_keys),
|
||||||
@ -735,7 +826,7 @@ def render_settings() -> None:
|
|||||||
dept_rows,
|
dept_rows,
|
||||||
num_rows="fixed",
|
num_rows="fixed",
|
||||||
key="department_editor",
|
key="department_editor",
|
||||||
use_container_width=True,
|
width='stretch',
|
||||||
hide_index=True,
|
hide_index=True,
|
||||||
column_config={
|
column_config={
|
||||||
"code": st.column_config.TextColumn("编码", disabled=True),
|
"code": st.column_config.TextColumn("编码", disabled=True),
|
||||||
@ -998,7 +1089,7 @@ def render_tests() -> None:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
candle_fig.update_layout(height=420, margin=dict(l=10, r=10, t=40, b=10))
|
candle_fig.update_layout(height=420, margin=dict(l=10, r=10, t=40, b=10))
|
||||||
st.plotly_chart(candle_fig, use_container_width=True)
|
st.plotly_chart(candle_fig, width='stretch')
|
||||||
|
|
||||||
vol_fig = px.bar(
|
vol_fig = px.bar(
|
||||||
df_reset,
|
df_reset,
|
||||||
@ -1008,7 +1099,7 @@ def render_tests() -> None:
|
|||||||
title="成交量",
|
title="成交量",
|
||||||
)
|
)
|
||||||
vol_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10))
|
vol_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10))
|
||||||
st.plotly_chart(vol_fig, use_container_width=True)
|
st.plotly_chart(vol_fig, width='stretch')
|
||||||
|
|
||||||
amt_fig = px.bar(
|
amt_fig = px.bar(
|
||||||
df_reset,
|
df_reset,
|
||||||
@ -1018,7 +1109,7 @@ def render_tests() -> None:
|
|||||||
title="成交额",
|
title="成交额",
|
||||||
)
|
)
|
||||||
amt_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10))
|
amt_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10))
|
||||||
st.plotly_chart(amt_fig, use_container_width=True)
|
st.plotly_chart(amt_fig, width='stretch')
|
||||||
|
|
||||||
df_reset["月份"] = df_reset["交易日"].dt.to_period("M").astype(str)
|
df_reset["月份"] = df_reset["交易日"].dt.to_period("M").astype(str)
|
||||||
box_fig = px.box(
|
box_fig = px.box(
|
||||||
@ -1029,7 +1120,7 @@ def render_tests() -> None:
|
|||||||
title="月度收盘价分布",
|
title="月度收盘价分布",
|
||||||
)
|
)
|
||||||
box_fig.update_layout(height=320, margin=dict(l=10, r=10, t=40, b=10))
|
box_fig.update_layout(height=320, margin=dict(l=10, r=10, t=40, b=10))
|
||||||
st.plotly_chart(box_fig, use_container_width=True)
|
st.plotly_chart(box_fig, width='stretch')
|
||||||
|
|
||||||
st.caption("提示:成交量单位为手,成交额以千元显示。箱线图按月展示收盘价分布。")
|
st.caption("提示:成交量单位为手,成交额以千元显示。箱线图按月展示收盘价分布。")
|
||||||
st.dataframe(df_reset.tail(20), width='stretch')
|
st.dataframe(df_reset.tail(20), width='stretch')
|
||||||
|
|||||||
@ -199,6 +199,55 @@ class DataBroker:
|
|||||||
return False
|
return False
|
||||||
return row is not None
|
return row is not None
|
||||||
|
|
||||||
|
def fetch_table_rows(
|
||||||
|
self,
|
||||||
|
table: str,
|
||||||
|
ts_code: str,
|
||||||
|
trade_date: str,
|
||||||
|
window: int,
|
||||||
|
) -> List[Dict[str, object]]:
|
||||||
|
if window <= 0:
|
||||||
|
return []
|
||||||
|
window = min(window, self.MAX_WINDOW)
|
||||||
|
columns = self._get_table_columns(table)
|
||||||
|
if not columns:
|
||||||
|
LOGGER.debug("表不存在或无字段 table=%s", table, extra=LOG_EXTRA)
|
||||||
|
return []
|
||||||
|
|
||||||
|
column_list = ", ".join(columns)
|
||||||
|
has_trade_date = "trade_date" in columns
|
||||||
|
if has_trade_date:
|
||||||
|
query = (
|
||||||
|
f"SELECT {column_list} FROM {table} "
|
||||||
|
"WHERE ts_code = ? AND trade_date <= ? "
|
||||||
|
"ORDER BY trade_date DESC LIMIT ?"
|
||||||
|
)
|
||||||
|
params: Tuple[object, ...] = (ts_code, trade_date, window)
|
||||||
|
else:
|
||||||
|
query = (
|
||||||
|
f"SELECT {column_list} FROM {table} "
|
||||||
|
"WHERE ts_code = ? ORDER BY rowid DESC LIMIT ?"
|
||||||
|
)
|
||||||
|
params = (ts_code, window)
|
||||||
|
|
||||||
|
results: List[Dict[str, object]] = []
|
||||||
|
with db_session(read_only=True) as conn:
|
||||||
|
try:
|
||||||
|
rows = conn.execute(query, params).fetchall()
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
LOGGER.debug(
|
||||||
|
"表查询失败 table=%s err=%s",
|
||||||
|
table,
|
||||||
|
exc,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
record = {col: row[col] for col in columns}
|
||||||
|
results.append(record)
|
||||||
|
return results
|
||||||
|
|
||||||
def resolve_field(self, field: str) -> Optional[Tuple[str, str]]:
|
def resolve_field(self, field: str) -> Optional[Tuple[str, str]]:
|
||||||
normalized = _safe_split(field)
|
normalized = _safe_split(field)
|
||||||
if not normalized:
|
if not normalized:
|
||||||
@ -215,7 +264,7 @@ class DataBroker:
|
|||||||
return None
|
return None
|
||||||
return table, resolved
|
return table, resolved
|
||||||
|
|
||||||
def _get_table_columns(self, table: str) -> Optional[set[str]]:
|
def _get_table_columns(self, table: str) -> Optional[List[str]]:
|
||||||
if not _is_safe_identifier(table):
|
if not _is_safe_identifier(table):
|
||||||
return None
|
return None
|
||||||
cache = getattr(self, "_column_cache", None)
|
cache = getattr(self, "_column_cache", None)
|
||||||
@ -234,7 +283,7 @@ class DataBroker:
|
|||||||
if not rows:
|
if not rows:
|
||||||
cache[table] = None
|
cache[table] = None
|
||||||
return None
|
return None
|
||||||
columns = {row["name"] for row in rows if row["name"]}
|
columns = [row["name"] for row in rows if row["name"]]
|
||||||
cache[table] = columns
|
cache[table] = columns
|
||||||
return columns
|
return columns
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user