diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 952459ac4054b..11d2c40ca25b0 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -15,6 +15,8 @@
 from tinygrad.engine.memory import memory_planner
 from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
 from tinygrad.kernelize.kernelize import get_kernelize_map
+from tinygrad.shape.shapetracker import View, ShapeTracker
+
 
 # *** all in scope Tensors are here. this gets relevant UOps ***
@@ -45,6 +47,16 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str|None=None) -> Non
     t.uop = ns
 
 # **** Tensor helper functions ****
+def _shares_buffer(a: "Tensor", b: "Tensor") -> bool:
+  """
+  True iff both tensors ultimately resolve to the same physical Buffer.
+  Only touches .buffer when the base op really is Ops.BUFFER.
+  """
+  return (
+    a.uop.base.op is Ops.BUFFER
+    and b.uop.base.op is Ops.BUFFER
+    and a.uop.base.buffer is b.uop.base.buffer
+  )
 
 # this tracks the tensor.py METADATA
 _METADATA: contextvars.ContextVar[Optional[Metadata]] = contextvars.ContextVar("_METADATA", default=None)
@@ -1256,23 +1268,63 @@ def __getitem__(self, indices) -> Tensor:
     """
     return self._getitem(indices)
 
-  def __setitem__(self, indices, v:Tensor|ConstType) -> None:
-    if isinstance(self.device, str) and self.device.startswith("DISK"):
-      self.realize()._getitem(indices).assign(v)
+  def __setitem__(self, idx, val):
+    if isinstance(idx, int):
+      idx = slice(idx, idx + 1)
+    if not unwrap(self.uop.st).contiguous:
+      raise RuntimeError("setitem target needs to be contiguous")
+    if isinstance(val, get_args(ConstType)):
+      val = Tensor(val, device=self.device, dtype=self.dtype)
+    if not isinstance(val, Tensor):
+      raise TypeError(f"can't set a {type(val).__name__} to a Tensor")
+    if self.requires_grad or val.requires_grad:
+      raise NotImplementedError("setitem with requires_grad is not supported")
+    tgt = self._getitem(idx, val)
+    # if val aliases this tensor's buffer, force a copy so the assignment doesn't read data it is overwriting
+    if _shares_buffer(val, self):
+      val = val.contiguous()
+    # fast path: the sliced target is already a contiguous view of this tensor's buffer, assign into it directly
+    if tgt.uop.base.op is Ops.BUFFER and _shares_buffer(tgt, self) and unwrap(tgt.uop.st).contiguous:
+      if tgt.shape != val.shape:
+        val = val.cast(tgt.dtype)._broadcast_to(tgt.shape)
+      if not unwrap(val.uop.st).contiguous:
+        val = val.contiguous()
+      tgt.assign(val)
       return
-    # NOTE: check that setitem target is valid first
-    if not unwrap(self.uop.st).contiguous: raise RuntimeError("setitem target needs to be contiguous")
-    if isinstance(v, get_args(ConstType)): v = Tensor(v, device=self.device, dtype=self.dtype)
-    if not isinstance(v, Tensor): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
-    if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
-
-    res = self.realize()._getitem(indices, v)
-    # if shapes match and data is not shared it's a copy and we assign to self
-    if res.shape == self.shape and res.uop is not self.uop:
-      self.assign(res).realize()
-    else: # no copy, basic setitem
-      v = v.cast(res.dtype)._broadcast_to(_broadcast_shape(res.shape, v.shape)).contiguous()
-      res.assign(v).realize()
+
+    # general path: mask the base ShapeTracker so the assignment only writes the region selected by idx
+    def _mask_from_tgt(base_st, sliced_st):
+      pad = len(base_st.shape) - len(sliced_st.shape)
+      v_s = sliced_st.views[-1]
+      if pad:
+        v_s = View((1,) * pad + v_s.shape, (0,) * pad + v_s.strides, v_s.offset, v_s.mask, False)
+
+      lo, hi = [], []
+      start = v_s.offset
+      for n, st in zip(v_s.shape, v_s.strides):
+        if st >= 0:
+          lo.append(start)
+          hi.append(start + (n - 1) * st + 1)
+        else:
+          lo.append(start + (n - 1) * st)
+          hi.append(start + st)
+      mask = tuple(zip(lo, hi))
+
+      v0 = base_st.views[-1]
+      masked_view = View(v0.shape, v0.strides, v0.offset, mask, v0.contiguous)
+      return ShapeTracker(base_st.views[:-1] + (masked_view,))
+
+    masked_st = _mask_from_tgt(unwrap(self.uop.st), unwrap(tgt.uop.st))
+    view = Tensor(UOp(Ops.VIEW, self.dtype, arg=masked_st, src=(self.uop,)), device=self.device, dtype=self.dtype)
+
+    if view.shape != val.shape:
+      val = val.cast(view.dtype)._broadcast_to(view.shape)
+    if not unwrap(val.uop.st).contiguous:
+      val = val.contiguous()
+    view.assign(val)
+
+    if tgt.shape == self.shape and tgt.uop is not self.uop:
+      self.assign(tgt)
 
   def gather(self:Tensor, dim:int, index:Tensor) -> Tensor:
     """
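
For illustration only (not part of the patch): a minimal sketch of the kind of in-place sliced assignment this `__setitem__` rewrite targets, assuming a contiguous base tensor. The shapes and values here are made up, not taken from the PR's tests.

```python
from tinygrad import Tensor

# a contiguous, realized base tensor (setitem requires the target's ShapeTracker to be contiguous)
t = Tensor.zeros(4, 4).contiguous().realize()

t[1:3] = 1.0            # broadcast a scalar into rows 1 and 2 of the same buffer
t[0] = Tensor.ones(4)   # write a tensor into row 0

print(t.numpy())
```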