84 changes: 68 additions & 16 deletions tinygrad/tensor.py
@@ -15,6 +15,8 @@
from tinygrad.engine.memory import memory_planner
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
from tinygrad.kernelize.kernelize import get_kernelize_map
+from tinygrad.shape.shapetracker import View, ShapeTracker
+

# *** all in scope Tensors are here. this gets relevant UOps ***

@@ -45,6 +47,16 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str|None=None) -> None:
    t.uop = ns

# **** Tensor helper functions ****
+def _shares_buffer(a: "Tensor", b: "Tensor") -> bool:
+  """
+  True iff both tensors ultimately resolve to the same physical Buffer.
+  Safe on any tensor: .buffer is only touched when the base op really is Ops.BUFFER.
+  """
+  return (
+    a.uop.base.op is Ops.BUFFER
+    and b.uop.base.op is Ops.BUFFER
+    and a.uop.base.buffer is b.uop.base.buffer
+  )

# this tracks the tensor.py METADATA
_METADATA: contextvars.ContextVar[Optional[Metadata]] = contextvars.ContextVar("_METADATA", default=None)
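
A quick sketch of what `_shares_buffer` distinguishes (hypothetical snippet, not part of the diff; it assumes realized tensors, since an unrealized tensor's base op is not Ops.BUFFER and the helper then returns False):

    from tinygrad import Tensor

    a = Tensor.arange(6).realize()     # backed by a physical Buffer
    b = a.reshape(2, 3)                # a view: same base Buffer as a
    c = Tensor.arange(6).realize()     # separate allocation: different Buffer
    assert _shares_buffer(a, b)        # True: b aliases a's storage
    assert not _shares_buffer(a, c)    # False: distinct Buffers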
@@ -1256,23 +1268,63 @@ def __getitem__(self, indices) -> Tensor:
"""
return self._getitem(indices)

-  def __setitem__(self, indices, v:Tensor|ConstType) -> None:
-    if isinstance(self.device, str) and self.device.startswith("DISK"):
-      self.realize()._getitem(indices).assign(v)
-      return
+  def __setitem__(self, idx, val):
+    # normalize a bare int index into a length-1 slice
+    if isinstance(idx, int):
+      idx = slice(idx, idx + 1)
+    # NOTE: check that the setitem target is valid first
+    if not unwrap(self.uop.st).contiguous:
+      raise RuntimeError("setitem target needs to be contiguous")
+    if isinstance(val, get_args(ConstType)):
+      val = Tensor(val, device=self.device, dtype=self.dtype)
+    if not isinstance(val, Tensor):
+      raise TypeError(f"can't set a {type(val).__name__} to a Tensor")
+    if self.requires_grad or val.requires_grad:
+      raise NotImplementedError("setitem with requires_grad is not supported")
+    tgt = self._getitem(idx, val)
+    # if val aliases self's buffer, insert a copy so the in-place write can't clobber its own source
+    if _shares_buffer(val, self):
+      val = val.contiguous()
+    # fast path: the target view is the whole buffer, contiguous and unmasked, so assign directly
+    if tgt.uop.base.op is Ops.BUFFER and _shares_buffer(tgt, self) and unwrap(tgt.uop.st).contiguous:
+      if tgt.shape != val.shape:
+        val = val.cast(tgt.dtype)._broadcast_to(tgt.shape)
+      if not unwrap(val.uop.st).contiguous:
+        val = val.contiguous()
+      tgt.assign(val)
+      return
-    # NOTE: check that setitem target is valid first
-    if not unwrap(self.uop.st).contiguous: raise RuntimeError("setitem target needs to be contiguous")
-    if isinstance(v, get_args(ConstType)): v = Tensor(v, device=self.device, dtype=self.dtype)
-    if not isinstance(v, Tensor): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
-    if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
-
-    res = self.realize()._getitem(indices, v)
-    # if shapes match and data is not shared it's a copy and we assign to self
-    if res.shape == self.shape and res.uop is not self.uop:
-      self.assign(res).realize()
-    else: # no copy, basic setitem
-      v = v.cast(res.dtype)._broadcast_to(_broadcast_shape(res.shape, v.shape)).contiguous()
-      res.assign(v).realize()

+    def _mask_from_tgt(base_st, sliced_st):
+      # rebuild the last view of base_st with a mask covering only the region the slice touches
+      pad = len(base_st.shape) - len(sliced_st.shape)
+      v_s = sliced_st.views[-1]
+      if pad:
+        v_s = View((1,) * pad + v_s.shape, (0,) * pad + v_s.strides, v_s.offset, v_s.mask, False)
+
+      # per-dimension [lo, hi) bounds of the offsets the slice can reach
+      lo, hi = [], []
+      start = v_s.offset
+      for n, st in zip(v_s.shape, v_s.strides):
+        if st >= 0:
+          lo.append(start)
+          hi.append(start + (n - 1) * st + 1)
+        else:
+          lo.append(start + (n - 1) * st)
+          hi.append(start + 1)  # with a negative stride the highest offset touched is start itself
+      mask = tuple(zip(lo, hi))
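+      # illustration (assuming t of shape (10,)): for t[2:5] the sliced view has
+      # offset=2, shape=(3,), strides=(1,), so lo=[2], hi=[5] and mask=((2, 5),);
+      # only base indices 2..4 are writable through the masked view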

+      v0 = base_st.views[-1]
+      masked_view = View(v0.shape, v0.strides, v0.offset, mask, v0.contiguous)
+      return ShapeTracker(base_st.views[:-1] + (masked_view,))
+
+    masked_st = _mask_from_tgt(unwrap(self.uop.st), unwrap(tgt.uop.st))
+    view = Tensor(UOp(Ops.VIEW, self.dtype, arg=masked_st, src=(self.uop,)), device=self.device, dtype=self.dtype)
+
+    if view.shape != val.shape:
+      val = val.cast(view.dtype)._broadcast_to(view.shape)
+    if not unwrap(val.uop.st).contiguous:
+      val = val.contiguous()
+    view.assign(val)
+
+    # advanced indexing returns a full-shape copy rather than a view; fold it back into self
+    if tgt.shape == self.shape and tgt.uop is not self.uop:
+      self.assign(tgt)

  def gather(self:Tensor, dim:int, index:Tensor) -> Tensor:
    """
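
For reference, a minimal usage sketch of the setitem paths this diff introduces (hypothetical, assuming the PR branch; shapes and values are illustrative):

    from tinygrad import Tensor

    t = Tensor.arange(6).realize()
    t[2:4] = 0                 # basic slice: writes through a masked view of t's buffer
    t[0] = 42                  # an int index is normalized to the slice 0:1
    t[1:3] = Tensor([7, 8])    # a Tensor value is cast/broadcast to the target shape
    print(t.tolist())          # -> [42, 7, 8, 0, 4, 5]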