diff --git a/app/agents/departments.py b/app/agents/departments.py index c866e3d..ec0091c 100644 --- a/app/agents/departments.py +++ b/app/agents/departments.py @@ -28,6 +28,7 @@ class TableRequest: name: str window: int = 1 trade_date: Optional[str] = None + columns: Optional[Sequence[str]] = None @dataclass @@ -83,6 +84,8 @@ class DepartmentAgent: "news", "index_daily", ] + MAX_TOOL_ROWS: ClassVar[int] = 60 + MAX_TOOL_COLUMNS: ClassVar[int] = 12 def __init__( self, @@ -449,6 +452,18 @@ class DepartmentAgent: window, auto_refresh=False # 避免在回测过程中触发自动补数 ) + selected_columns: List[str] = [] + if rows: + selected_columns = self._select_columns(list(rows[0].keys()), req.columns) + rows = [ + self._format_row(row, selected_columns) + for row in rows + ] + if len(rows) > self.MAX_TOOL_ROWS: + rows = rows[: self.MAX_TOOL_ROWS] + elif req.columns: + selected_columns = list(req.columns)[: self.MAX_TOOL_COLUMNS] + summary = self._summarize_rows(rows) if rows: preview = ", ".join( f"{row.get('trade_date', 'NA')}" for row in rows[: min(len(rows), 5)] @@ -465,13 +480,100 @@ class DepartmentAgent: "table": table, "window": window, "trade_date": trade_date, + "columns": selected_columns, "rows": rows, + "summary": summary, + "row_limit": self.MAX_TOOL_ROWS, } ) delivered.add(key) return lines, payload, delivered + def _select_columns( + self, + available: Sequence[str], + requested: Optional[Sequence[str]] = None, + ) -> List[str]: + available_list = [str(col) for col in available] + selected: List[str] = [] + if requested: + for col in requested: + name = str(col) + if name in available_list and name not in selected: + selected.append(name) + if not selected: + preferred = { + "trade_date", + "ts_code", + "close", + "open", + "high", + "low", + "vol", + "volume", + "amount", + "turnover", + "turnover_rate", + "turnover_rate_f", + "pct_chg", + "nav", + "cash", + "market_value", + "net_flow", + "exposure", + "sentiment", + "heat", + } + selected = [col for col in available_list if col in preferred] + if not selected: + selected = available_list[: self.MAX_TOOL_COLUMNS] + else: + selected = selected[: self.MAX_TOOL_COLUMNS] + if "trade_date" in available_list and "trade_date" not in selected: + selected = ["trade_date"] + [col for col in selected if col != "trade_date"] + return selected + + @staticmethod + def _format_row(row: Mapping[str, Any], columns: Sequence[str]) -> Dict[str, Any]: + formatted: Dict[str, Any] = {} + for col in columns: + value = row.get(col) + if isinstance(value, float): + formatted[col] = round(value, 6) + elif isinstance(value, (int, str)): + formatted[col] = value + elif hasattr(value, "isoformat"): + try: + formatted[col] = value.isoformat() # type: ignore[attr-defined] + except Exception: # noqa: BLE001 + formatted[col] = str(value) + else: + formatted[col] = str(value) if value is not None else None + return formatted + + @staticmethod + def _summarize_rows(rows: Sequence[Mapping[str, Any]]) -> Dict[str, Any]: + summary: Dict[str, Any] = {} + if not rows: + return summary + numeric_columns: Dict[str, List[float]] = {} + for row in rows: + for key, value in row.items(): + if isinstance(value, (int, float)): + numeric_columns.setdefault(key, []).append(float(value)) + for key, values in numeric_columns.items(): + if not values: + continue + summary[key] = { + "min": round(min(values), 6), + "max": round(max(values), 6), + "avg": round(sum(values) / len(values), 6), + "last": round(values[-1], 6), + } + summary["row_count"] = len(rows) + return summary + def _handle_tool_call( self, @@ -513,11 +615,19 @@ class DepartmentAgent: window = max(1, min(window, getattr(self._broker, "MAX_WINDOW", 120))) override_date = item.get("trade_date") req_date = self._normalize_trade_date(override_date or base_trade_date) + columns_raw = item.get("columns") or item.get("fields") + columns: Optional[List[str]] = None + if isinstance(columns_raw, str): + columns = [col.strip() for col in columns_raw.split(",") if col and col.strip()] + elif isinstance(columns_raw, Sequence): + columns = [str(col).strip() for col in columns_raw if str(col).strip()] + if columns: + columns = columns[: self.MAX_TOOL_COLUMNS] key = (name, window, req_date) if key in delivered_requests: skipped.append(name) continue - requests.append(TableRequest(name=name, window=window, trade_date=req_date)) + requests.append(TableRequest(name=name, window=window, trade_date=req_date, columns=columns)) if not requests: return { @@ -591,6 +701,11 @@ class DepartmentAgent: "pattern": r"^\\d{8}$", "description": "覆盖默认交易日(格式 YYYYMMDD)", }, + "columns": { + "type": "array", + "items": {"type": "string"}, + "description": "可选字段列表,未指定时自动选择常用列", + }, }, "required": ["name"], },