From ccbd2c7eb1b121733ee04c80ce17672ea1b080cb Mon Sep 17 00:00:00 2001
From: 0xSG <100090997+sxsmg@users.noreply.github.com>
Date: Fri, 4 Jul 2025 14:19:50 +0530
Subject: [PATCH] implement loop splitting in cat

---
 tinygrad/tensor.py | 31 +++++++++++++++++++++++++++----
 1 file changed, 27 insertions(+), 4 deletions(-)

diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 5d37ec4ea4b98..afb223cef0e94 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -1309,11 +1309,34 @@ def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
     ```
     """
     dim = self._resolve_dim(dim)
-    for arg in args: assert arg.ndim==self.ndim and all(ti==ai for i,(ti,ai) in enumerate(zip(self.shape, arg.shape)) if i!=dim)
+    for arg in args: assert arg.ndim == self.ndim and all(ti == ai for i,(ti,ai) in enumerate(zip(self.shape, arg.shape)) if i != dim)
     tensors = [self, *args]
-    dim_cumsum = list(itertools.accumulate([t.shape[dim] for t in tensors], initial=0))
-    for i,t in enumerate(tensors): tensors[i] = t.pad([(dim_cumsum[i], dim_cumsum[-1]-dim_cumsum[i+1]) if j==dim else None for j in range(t.ndim)])
-    return functools.reduce(Tensor.add, tensors)
+
+    def _cat_impl(ts: list[Tensor]) -> Tensor:
+      # pad each tensor out to the full concatenated extent along dim, then sum them
+      dim_cumsum = list(itertools.accumulate([t.shape[dim] for t in ts], initial=0))
+      for i,t in enumerate(ts):
+        ts[i] = t.pad([(dim_cumsum[i], dim_cumsum[-1]-dim_cumsum[i+1]) if j == dim else None for j in range(t.ndim)])
+      return functools.reduce(Tensor.add, ts)
+
+    if all_int([t.shape[dim] for t in tensors]):
+      SPLIT = 256
+      if sum(t.shape[dim] for t in tensors) > SPLIT:
+        out, cur, sz = [], [], 0
+        for t in tensors:
+          if sz and sz + t.shape[dim] > SPLIT:
+            out.append(_cat_impl(cur))
+            cur, sz = [], 0
+          cur.append(t)
+          sz += t.shape[dim]
+        if cur:
+          out.append(_cat_impl(cur))
+        if len(out) == 1: return out[0]
+        # only recurse when grouping actually merged inputs; if every group is a singleton,
+        # the recursive cat sees the same sizes again and would recurse forever
+        if len(out) < len(tensors): return Tensor.cat(*out, dim=dim)
+
+    return _cat_impl(tensors)
 
   def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
     """