refactor table data handling with column selection and summary statistics

This commit is contained in:
Your Name 2025-10-11 20:18:20 +08:00
parent 07535d1c19
commit a42f065332

View File

@ -28,6 +28,7 @@ class TableRequest:
name: str
window: int = 1
trade_date: Optional[str] = None
columns: Optional[Sequence[str]] = None
@dataclass
@ -83,6 +84,8 @@ class DepartmentAgent:
"news",
"index_daily",
]
MAX_TOOL_ROWS: ClassVar[int] = 60
MAX_TOOL_COLUMNS: ClassVar[int] = 12
def __init__(
self,
@ -449,6 +452,18 @@ class DepartmentAgent:
window,
auto_refresh=False # 避免在回测过程中触发自动补数
)
selected_columns: List[str] = []
if rows:
selected_columns = self._select_columns(list(rows[0].keys()), req.columns)
rows = [
self._format_row(row, selected_columns)
for row in rows
]
if len(rows) > self.MAX_TOOL_ROWS:
rows = rows[: self.MAX_TOOL_ROWS]
elif req.columns:
selected_columns = list(req.columns)[: self.MAX_TOOL_COLUMNS]
summary = self._summarize_rows(rows)
if rows:
preview = ", ".join(
f"{row.get('trade_date', 'NA')}" for row in rows[: min(len(rows), 5)]
@ -465,13 +480,100 @@ class DepartmentAgent:
"table": table,
"window": window,
"trade_date": trade_date,
"columns": selected_columns,
"rows": rows,
"summary": summary,
"row_limit": self.MAX_TOOL_ROWS,
}
)
delivered.add(key)
return lines, payload, delivered
def _select_columns(
    self,
    available: Sequence[str],
    requested: Optional[Sequence[str]] = None,
) -> List[str]:
    """Pick the columns of a table that are exposed in the tool payload.

    Selection order:

    1. The ``requested`` columns that actually exist in ``available``
       (de-duplicated, request order preserved).
    2. Otherwise, the available columns that belong to a fixed set of
       commonly used names (table order preserved).
    3. Otherwise, the available columns as-is.

    The result never exceeds ``self.MAX_TOOL_COLUMNS`` entries as long as
    that limit is at least 1. When the table has a ``trade_date`` column
    it is always part of the result; if it was not selected it is placed
    first and takes one of the slots instead of being added on top of
    the limit.

    Args:
        available: Column names present in the table.
        requested: Optional column names asked for by the caller.

    Returns:
        The ordered list of column names to keep.
    """
    available_list = [str(col) for col in available]
    known = set(available_list)
    limit = self.MAX_TOOL_COLUMNS

    selected: List[str] = []
    if requested:
        for col in requested:
            name = str(col)
            # Silently ignore unknown and repeated names.
            if name in known and name not in selected:
                selected.append(name)

    if not selected:
        preferred = {
            "trade_date",
            "ts_code",
            "close",
            "open",
            "high",
            "low",
            "vol",
            "volume",
            "amount",
            "turnover",
            "turnover_rate",
            "turnover_rate_f",
            "pct_chg",
            "nav",
            "cash",
            "market_value",
            "net_flow",
            "exposure",
            "sentiment",
            "heat",
        }
        selected = [col for col in available_list if col in preferred]

    if not selected:
        selected = available_list
    selected = selected[:limit]

    if "trade_date" in known and "trade_date" not in selected:
        # Reserve one slot for the date so the column cap still holds.
        selected = ["trade_date"] + selected[: max(limit - 1, 0)]
    return selected
@staticmethod
def _format_row(row: Mapping[str, Any], columns: Sequence[str]) -> Dict[str, Any]:
    """Project ``row`` onto ``columns`` and normalise each cell value.

    Normalisation rules, checked in this order:

    * ``float`` values are rounded to 6 decimal places;
    * ``int`` and ``str`` values are kept unchanged;
    * values exposing an ``isoformat`` attribute are converted by calling
      it, falling back to ``str(value)`` if the call raises;
    * ``None`` stays ``None`` (this also covers columns missing from
      ``row``);
    * anything else is converted with ``str``.

    Args:
        row: Source record.
        columns: Names of the columns to keep, in output order.

    Returns:
        A new dict holding one normalised entry per requested column.
    """

    def _normalize(cell: Any) -> Any:
        if isinstance(cell, float):
            return round(cell, 6)
        if isinstance(cell, (int, str)):
            return cell
        if hasattr(cell, "isoformat"):
            try:
                return cell.isoformat()  # type: ignore[attr-defined]
            except Exception:  # noqa: BLE001
                # Best effort: an unusable isoformat must not break the row.
                return str(cell)
        return None if cell is None else str(cell)

    return {name: _normalize(row.get(name)) for name in columns}
@staticmethod
def _summarize_rows(rows: Sequence[Mapping[str, Any]]) -> Dict[str, Any]:
    """Compute per-column summary statistics over ``rows``.

    For every key that holds at least one finite ``int``/``float`` value
    the summary contains ``min``, ``max``, ``avg`` and ``last`` (the value
    from the last row that had a usable number for that key), each rounded
    to 6 decimal places. NaN and infinite values are skipped: they would
    otherwise make ``min``/``max`` depend on row order and turn ``avg``
    into NaN. Integers too large to convert to ``float`` are skipped too.

    Args:
        rows: Records to summarise, in the order used for ``last``.

    Returns:
        An empty dict when ``rows`` is empty; otherwise a dict mapping
        column names to their statistics plus a ``"row_count"`` entry
        holding ``len(rows)``.
    """
    import math  # local import: only this helper needs it

    summary: Dict[str, Any] = {}
    if not rows:
        return summary

    numeric_columns: Dict[str, List[float]] = {}
    for row in rows:
        for key, value in row.items():
            if not isinstance(value, (int, float)):
                continue
            try:
                number = float(value)
            except OverflowError:
                # Arbitrarily large ints cannot be represented as float.
                continue
            if math.isfinite(number):
                numeric_columns.setdefault(key, []).append(number)

    # Lists are only created on append, so none of them is empty here.
    for key, values in numeric_columns.items():
        summary[key] = {
            "min": round(min(values), 6),
            "max": round(max(values), 6),
            "avg": round(sum(values) / len(values), 6),
            "last": round(values[-1], 6),
        }
    # Written last so the count wins over a data column of the same name.
    summary["row_count"] = len(rows)
    return summary
def _handle_tool_call(
self,
@ -513,11 +615,19 @@ class DepartmentAgent:
window = max(1, min(window, getattr(self._broker, "MAX_WINDOW", 120)))
override_date = item.get("trade_date")
req_date = self._normalize_trade_date(override_date or base_trade_date)
columns_raw = item.get("columns") or item.get("fields")
columns: Optional[List[str]] = None
if isinstance(columns_raw, str):
columns = [col.strip() for col in columns_raw.split(",") if col and col.strip()]
elif isinstance(columns_raw, Sequence):
columns = [str(col).strip() for col in columns_raw if str(col).strip()]
if columns:
columns = columns[: self.MAX_TOOL_COLUMNS]
key = (name, window, req_date)
if key in delivered_requests:
skipped.append(name)
continue
requests.append(TableRequest(name=name, window=window, trade_date=req_date))
requests.append(TableRequest(name=name, window=window, trade_date=req_date, columns=columns))
if not requests:
return {
@ -591,6 +701,11 @@ class DepartmentAgent:
"pattern": r"^\\d{8}$",
"description": "覆盖默认交易日(格式 YYYYMMDD",
},
"columns": {
"type": "array",
"items": {"type": "string"},
"description": "可选字段列表,未指定时自动选择常用列",
},
},
"required": ["name"],
},