refactor table data handling with column selection and summary statistics
This commit is contained in:
parent
07535d1c19
commit
a42f065332
@ -28,6 +28,7 @@ class TableRequest:
|
|||||||
name: str
|
name: str
|
||||||
window: int = 1
|
window: int = 1
|
||||||
trade_date: Optional[str] = None
|
trade_date: Optional[str] = None
|
||||||
|
columns: Optional[Sequence[str]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -83,6 +84,8 @@ class DepartmentAgent:
|
|||||||
"news",
|
"news",
|
||||||
"index_daily",
|
"index_daily",
|
||||||
]
|
]
|
||||||
|
MAX_TOOL_ROWS: ClassVar[int] = 60
|
||||||
|
MAX_TOOL_COLUMNS: ClassVar[int] = 12
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -449,6 +452,18 @@ class DepartmentAgent:
|
|||||||
window,
|
window,
|
||||||
auto_refresh=False # 避免在回测过程中触发自动补数
|
auto_refresh=False # 避免在回测过程中触发自动补数
|
||||||
)
|
)
|
||||||
|
selected_columns: List[str] = []
|
||||||
|
if rows:
|
||||||
|
selected_columns = self._select_columns(list(rows[0].keys()), req.columns)
|
||||||
|
rows = [
|
||||||
|
self._format_row(row, selected_columns)
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
if len(rows) > self.MAX_TOOL_ROWS:
|
||||||
|
rows = rows[: self.MAX_TOOL_ROWS]
|
||||||
|
elif req.columns:
|
||||||
|
selected_columns = list(req.columns)[: self.MAX_TOOL_COLUMNS]
|
||||||
|
summary = self._summarize_rows(rows)
|
||||||
if rows:
|
if rows:
|
||||||
preview = ", ".join(
|
preview = ", ".join(
|
||||||
f"{row.get('trade_date', 'NA')}" for row in rows[: min(len(rows), 5)]
|
f"{row.get('trade_date', 'NA')}" for row in rows[: min(len(rows), 5)]
|
||||||
@ -465,13 +480,100 @@ class DepartmentAgent:
|
|||||||
"table": table,
|
"table": table,
|
||||||
"window": window,
|
"window": window,
|
||||||
"trade_date": trade_date,
|
"trade_date": trade_date,
|
||||||
|
"columns": selected_columns,
|
||||||
"rows": rows,
|
"rows": rows,
|
||||||
|
"summary": summary,
|
||||||
|
"row_limit": self.MAX_TOOL_ROWS,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
delivered.add(key)
|
delivered.add(key)
|
||||||
|
|
||||||
return lines, payload, delivered
|
return lines, payload, delivered
|
||||||
|
|
||||||
|
def _select_columns(
|
||||||
|
self,
|
||||||
|
available: Sequence[str],
|
||||||
|
requested: Optional[Sequence[str]] = None,
|
||||||
|
) -> List[str]:
|
||||||
|
available_list = [str(col) for col in available]
|
||||||
|
selected: List[str] = []
|
||||||
|
if requested:
|
||||||
|
for col in requested:
|
||||||
|
name = str(col)
|
||||||
|
if name in available_list and name not in selected:
|
||||||
|
selected.append(name)
|
||||||
|
if not selected:
|
||||||
|
preferred = {
|
||||||
|
"trade_date",
|
||||||
|
"ts_code",
|
||||||
|
"close",
|
||||||
|
"open",
|
||||||
|
"high",
|
||||||
|
"low",
|
||||||
|
"vol",
|
||||||
|
"volume",
|
||||||
|
"amount",
|
||||||
|
"turnover",
|
||||||
|
"turnover_rate",
|
||||||
|
"turnover_rate_f",
|
||||||
|
"pct_chg",
|
||||||
|
"nav",
|
||||||
|
"cash",
|
||||||
|
"market_value",
|
||||||
|
"net_flow",
|
||||||
|
"exposure",
|
||||||
|
"sentiment",
|
||||||
|
"heat",
|
||||||
|
}
|
||||||
|
selected = [col for col in available_list if col in preferred]
|
||||||
|
if not selected:
|
||||||
|
selected = available_list[: self.MAX_TOOL_COLUMNS]
|
||||||
|
else:
|
||||||
|
selected = selected[: self.MAX_TOOL_COLUMNS]
|
||||||
|
if "trade_date" in available_list and "trade_date" not in selected:
|
||||||
|
selected = ["trade_date"] + [col for col in selected if col != "trade_date"]
|
||||||
|
return selected
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _format_row(row: Mapping[str, Any], columns: Sequence[str]) -> Dict[str, Any]:
|
||||||
|
formatted: Dict[str, Any] = {}
|
||||||
|
for col in columns:
|
||||||
|
value = row.get(col)
|
||||||
|
if isinstance(value, float):
|
||||||
|
formatted[col] = round(value, 6)
|
||||||
|
elif isinstance(value, (int, str)):
|
||||||
|
formatted[col] = value
|
||||||
|
elif hasattr(value, "isoformat"):
|
||||||
|
try:
|
||||||
|
formatted[col] = value.isoformat() # type: ignore[attr-defined]
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
formatted[col] = str(value)
|
||||||
|
else:
|
||||||
|
formatted[col] = str(value) if value is not None else None
|
||||||
|
return formatted
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _summarize_rows(rows: Sequence[Mapping[str, Any]]) -> Dict[str, Any]:
|
||||||
|
summary: Dict[str, Any] = {}
|
||||||
|
if not rows:
|
||||||
|
return summary
|
||||||
|
numeric_columns: Dict[str, List[float]] = {}
|
||||||
|
for row in rows:
|
||||||
|
for key, value in row.items():
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
numeric_columns.setdefault(key, []).append(float(value))
|
||||||
|
for key, values in numeric_columns.items():
|
||||||
|
if not values:
|
||||||
|
continue
|
||||||
|
summary[key] = {
|
||||||
|
"min": round(min(values), 6),
|
||||||
|
"max": round(max(values), 6),
|
||||||
|
"avg": round(sum(values) / len(values), 6),
|
||||||
|
"last": round(values[-1], 6),
|
||||||
|
}
|
||||||
|
summary["row_count"] = len(rows)
|
||||||
|
return summary
|
||||||
|
|
||||||
|
|
||||||
def _handle_tool_call(
|
def _handle_tool_call(
|
||||||
self,
|
self,
|
||||||
@ -513,11 +615,19 @@ class DepartmentAgent:
|
|||||||
window = max(1, min(window, getattr(self._broker, "MAX_WINDOW", 120)))
|
window = max(1, min(window, getattr(self._broker, "MAX_WINDOW", 120)))
|
||||||
override_date = item.get("trade_date")
|
override_date = item.get("trade_date")
|
||||||
req_date = self._normalize_trade_date(override_date or base_trade_date)
|
req_date = self._normalize_trade_date(override_date or base_trade_date)
|
||||||
|
columns_raw = item.get("columns") or item.get("fields")
|
||||||
|
columns: Optional[List[str]] = None
|
||||||
|
if isinstance(columns_raw, str):
|
||||||
|
columns = [col.strip() for col in columns_raw.split(",") if col and col.strip()]
|
||||||
|
elif isinstance(columns_raw, Sequence):
|
||||||
|
columns = [str(col).strip() for col in columns_raw if str(col).strip()]
|
||||||
|
if columns:
|
||||||
|
columns = columns[: self.MAX_TOOL_COLUMNS]
|
||||||
key = (name, window, req_date)
|
key = (name, window, req_date)
|
||||||
if key in delivered_requests:
|
if key in delivered_requests:
|
||||||
skipped.append(name)
|
skipped.append(name)
|
||||||
continue
|
continue
|
||||||
requests.append(TableRequest(name=name, window=window, trade_date=req_date))
|
requests.append(TableRequest(name=name, window=window, trade_date=req_date, columns=columns))
|
||||||
|
|
||||||
if not requests:
|
if not requests:
|
||||||
return {
|
return {
|
||||||
@ -591,6 +701,11 @@ class DepartmentAgent:
|
|||||||
"pattern": r"^\\d{8}$",
|
"pattern": r"^\\d{8}$",
|
||||||
"description": "覆盖默认交易日(格式 YYYYMMDD)",
|
"description": "覆盖默认交易日(格式 YYYYMMDD)",
|
||||||
},
|
},
|
||||||
|
"columns": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": "可选字段列表,未指定时自动选择常用列",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"required": ["name"],
|
"required": ["name"],
|
||||||
},
|
},
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user