Merged pull request: replace Ops.ASSIGN with Ops.STORE across codegen, scheduling, gradient, and tests. Changes from all commits are shown below.
2 changes: 1 addition & 1 deletion extra/backends/triton.py
@@ -90,7 +90,7 @@ def int_div(x,y): return f"({x}//{y})" if y != '0' else f"{x}*tl.where({x}==0, f
else: kk(f"{ssa(u, 'val')} = {render_cast(f'tl.where({r[vin[2]]}, tl.load({r[vin[0]]}+{fill_dims_for_idx(r[vin[1]],dims)} , mask={render_valid(valid+[r[vin[2]]])}), 0.0)', dtype)}")
elif uop == Ops.DEFINE_REG: kk(f"{ssa(u, 'acc')} = {define_scalar(local_size, dtype, args).replace('//', '/')}")
elif uop == Ops.CONST: r[u] = define_scalar([], dtype, args)
- elif uop == Ops.ASSIGN:
+ elif uop == Ops.STORE:
kk(f"{r[vin[0]]} = {r[vin[1]].replace('//', '/')}")
r[u] = r[vin[0]]
elif uop == Ops.STORE:
4 changes: 2 additions & 2 deletions extra/sched/fuzz_schedule.py
@@ -32,7 +32,7 @@ def fuzz_schedule(outs:List[UOp]):
for lsi in ts:
for out in lsi.outputs:
# freeze assign state before exec
- if out.op is Ops.ASSIGN:
+ if out.op is Ops.STORE:
prerealized[out] = out.buffer.as_buffer()
assign_targets[out.srcs[1]] = out
for x in lsi.inputs:
@@ -51,7 +51,7 @@ def fuzz_schedule(outs:List[UOp]):
for out in lsi.outputs:
base = rawbufs[lsi.inputs[0]].base if out.op is Ops.BUFFER_VIEW else None
rawbufs[out] = Buffer(out.buffer.device, out.buffer.size, out.buffer.dtype, base=base)
- if out.op is Ops.ASSIGN: rawbufs[out].ensure_allocated().copyin(prerealized[out])
+ if out.op is Ops.STORE: rawbufs[out].ensure_allocated().copyin(prerealized[out])
for x in lsi.inputs:
if x not in rawbufs:
# override the assign_target after ASSIGN
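A note on the two hunks above: the fuzzer snapshots the pre-mutation contents of every output that writes into an existing buffer (previously tagged Ops.ASSIGN, now Ops.STORE), then copies the snapshot back in before each permutation so every schedule replays from identical state. A minimal self-contained sketch of the freeze step, with dicts standing in for tinygrad's ScheduleItem/Buffer objects (names here are illustrative, not the real API):

```python
def freeze_store_targets(schedule_items):
    """Snapshot every STORE output's buffer before any kernel runs."""
    prerealized = {}
    for lsi in schedule_items:
        for out in lsi["outputs"]:
            if out["op"] == "STORE":                         # was Ops.ASSIGN pre-PR
                prerealized[out["name"]] = bytes(out["buffer"])  # copy, not alias
    return prerealized

items = [{"outputs": [{"op": "STORE", "name": "a", "buffer": bytearray(b"\x01\x02")}]}]
snap = freeze_store_targets(items)
items[0]["outputs"][0]["buffer"][0] = 9   # a kernel mutates the buffer...
assert snap["a"] == b"\x01\x02"           # ...but the snapshot is unchanged
```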
14 changes: 7 additions & 7 deletions test/test_linearizer.py
@@ -449,7 +449,7 @@ def test_tensor_cores_unroll_phi(self):
k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4)]], apply_tc=True, atol=3e-2, rtol=1e-3)[-1]
for u in k.uops:
if u.op is Ops.WMMA:
- assert u.src[-1].src[0].op != Ops.ASSIGN
+ assert u.src[-1].src[0].op != Ops.STORE

@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
@unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"}, "CPU does not support using a different type for accumulation")
@@ -461,7 +461,7 @@ def test_tensor_cores_unroll_casted_phi(self):
for u in k.uops:
if u.op is Ops.WMMA:
#assert u.src[-1].dtype == dtypes.float.vec(prod(tc.thread_local_sizes[2]))
- assert u.src[-1].src[0].op != Ops.ASSIGN
+ assert u.src[-1].src[0].op != Ops.STORE

@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
@unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"}, "CPU does not support using a different type for accumulation")
@@ -474,7 +474,7 @@ def test_tensor_cores_unroll_casted_phi_with_children(self):
for u in k.uops:
if u.op is Ops.WMMA:
#assert u.src[-1].dtype == dtypes.float.vec(prod(tc.thread_local_sizes[2]))
- assert u.src[-1].src[0].op != Ops.ASSIGN
+ assert u.src[-1].src[0].op != Ops.STORE

@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_simple_unroll_no_between_phi_dependencies(self):
@@ -483,10 +483,10 @@ def test_simple_unroll_no_between_phi_dependencies(self):
k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4)]])[-1]
# the uops graph is RANGE -> DEFINE_ACC -> 4x ALU -> 4x ASSIGN -> ENDRANGE
for u in k.uops:
- if u.op is Ops.ASSIGN:
+ if u.op is Ops.STORE:
assert u.src[1].op in GroupOp.ALU
- # children of ASSIGN are placed after ENDRANGE
- if any(x.op is Ops.ASSIGN for x in u.src):
+ # children of STORE are placed after ENDRANGE
+ if any(x.op is Ops.STORE for x in u.src):
end_range = [i for i, x in enumerate(k.uops) if x.op is Ops.ENDRANGE][0]
assert end_range < k.uops.index(u)

@@ -610,7 +610,7 @@ def helper(t, max_ops=0):
if if_op:=next((u for u in uops if u.op is Ops.IF), None):
uops = uops[:uops.index(if_op)]
assert len(set([u.op for u in uops if u.op in {Ops.RANGE, Ops.SPECIAL}])) == 1, "has either specials or ranges, not both"
- assert len([u for u in uops if u.op is Ops.ASSIGN]) == 0, "ASSIGN should have been simplified"
+ assert len([u for u in uops if u.op is Ops.STORE]) == 0, "ASSIGN should have been simplified"
# TODO: once uops track min/max this will be fixed
#assert len([u for u in uops if u.op is Ops.MAX]) <= max_ops, "no unnecessary MAX ops"

2 changes: 1 addition & 1 deletion test/test_schedule.py
@@ -2006,7 +2006,7 @@ def test_no_reshape_reduceop(self):
self.assertEqual(a.uop.shape, (32,))

def swizzle_cnt(u:UOp) -> int:
- return len([x for x in u.toposort() if x.op is Ops.VIEW and len(x.src) != 0 and x.src[0].op not in {Ops.BUFFER, Ops.DEFINE_GLOBAL, Ops.ASSIGN}])
+ return len([x for x in u.toposort() if x.op is Ops.VIEW and len(x.src) != 0 and x.src[0].op not in {Ops.BUFFER, Ops.DEFINE_GLOBAL, Ops.STORE}])

class TestSwizzle(unittest.TestCase):
def test_swizzle_simple(self):
10 changes: 5 additions & 5 deletions test/test_uops.py
@@ -573,12 +573,12 @@ def test_scalar_var(self):
t = Tensor(vv).uop
self.assertEqual(t.st, ShapeTracker.from_shape(()))

- # ** ASSIGN is ASSIGN(VIEW(BUFFER), new_val)
+ # ** ASSIGN is STORE(VIEW(BUFFER), new_val)

def test_assign_flat(self):
buffer = Tensor.arange(4).realize()
a = buffer.assign(Tensor.zeros((4,), dtype=dtypes.int))
- assign_pattern = UPat(Ops.ASSIGN, src=(UPat(Ops.BUFFER), UPat()))
+ assign_pattern = UPat(Ops.STORE, src=(UPat(Ops.BUFFER), UPat()))
assert assign_pattern.match(a.uop, {})
a.realize()
self.assertEqual(buffer.tolist(), [0, 0, 0, 0])
@@ -592,7 +592,7 @@ def test_assign_permuted(self):
def test_assign_reshaped(self):
buffer = Tensor.ones((4,)).contiguous().realize()
a = buffer.reshape((2, 2)).assign(Tensor.zeros((2, 2)))
- assign_pattern = UPat(Ops.ASSIGN, src=(UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER))), UPat()))
+ assign_pattern = UPat(Ops.STORE, src=(UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER))), UPat()))
assert assign_pattern.match(a.uop, {})
a.realize()
self.assertEqual(buffer.tolist(), [0, 0, 0, 0])
@@ -601,9 +601,9 @@ def test_assign_reshaped(self):
def test_setitem(self):
a = Tensor.ones((4,)).contiguous().realize()
assign = a.shrink(((1, 2),)).assign(Tensor.zeros((1,)))
- # the ASSIGN UOp has size=1
+ # the STORE UOp has size=1
self.assertEqual(assign.uop.size, 1)
- # the ASSIGN views the buffer with a shrunk st
+ # the STORE views the buffer with a shrunk st
self.assertEqual(assign.uop.src[0].st, ShapeTracker.from_shape((4,)).shrink(((1, 2),)))
# the underlying BUFFER has a size=4
self.assertEqual(assign.uop.buf_uop.size, 4)
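Context for these test edits: after the rename, an in-place assign is represented as STORE(view-of-BUFFER, new_val). The shape check the UPat patterns above perform can be sketched with plain stand-in classes; FakeUOp below is illustrative, not tinygrad's real UOp/UPat:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class FakeUOp:          # illustrative stand-in for tinygrad's UOp
    op: str
    src: tuple = ()

def is_assign_shape(u: FakeUOp) -> bool:
    # post-PR layout: STORE(BUFFER or movement-op chain over BUFFER, new_val)
    if u.op != "STORE" or len(u.src) != 2:
        return False
    tgt = u.src[0]
    while tgt.op in {"RESHAPE", "PERMUTE", "SHRINK"}:  # peel views off the target
        tgt = tgt.src[0]
    return tgt.op == "BUFFER"

buf = FakeUOp("BUFFER")
flat = FakeUOp("STORE", (buf, FakeUOp("CONST")))
reshaped = FakeUOp("STORE", (FakeUOp("RESHAPE", (buf,)), FakeUOp("CONST")))
assert is_assign_shape(flat) and is_assign_shape(reshaped)
```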
4 changes: 2 additions & 2 deletions test/unit/test_kernelize.py
@@ -17,7 +17,7 @@ def test_two_reduce(self):
a1 = a.sum(axis=1)
a0 = a1.sum(axis=0)
a0.kernelize()
- self.assertIs(a1.uop.base.op, Ops.ASSIGN)
+ self.assertIs(a1.uop.base.op, Ops.STORE)

def test_two_reduce_w_add(self):
a = Tensor.ones(16,16).contiguous()
@@ -27,7 +27,7 @@ def test_two_reduce_w_add(self):
# NOTE: the +1 is fused with a1, so a1 is not kernelized
self.assertIs(a1.uop.base.op, Ops.REDUCE_AXIS)
# the input to the REDUCE_AXIS is an ASSIGN though
- self.assertIs(a1.uop.base.src[0].base.op, Ops.ASSIGN)
+ self.assertIs(a1.uop.base.src[0].base.op, Ops.STORE)

if __name__ == '__main__':
unittest.main()
4 changes: 2 additions & 2 deletions tinygrad/codegen/devectorizer.py
@@ -288,7 +288,7 @@ def no_vectorized_acc(acc:UOp):

devectorize = PatternMatcher([
# no ALU on vectorized dtypes
- (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN), name="alu"), no_vectorized_alu),
+ (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.STORE), name="alu"), no_vectorized_alu),
(UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
(UPat(Ops.DEFINE_REG, name="acc"), no_vectorized_acc),
])
@@ -309,7 +309,7 @@ def no_vectorized_acc(acc:UOp):
lambda store,idx: UOp(Ops.STORE, src=store.src+(UOp(Ops.IF, src=(idx.src[2],)),))),
])

- # *** Ops.REDUCE -> Ops.DEFINE_ACC+Ops.ASSIGN ***
+ # *** Ops.REDUCE -> Ops.DEFINE_ACC+Ops.STORE ***

@dataclass
class ReduceContext:
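Both hunks in this file edit (pattern, rewrite) tables. For readers unfamiliar with the mechanism: a PatternMatcher tries each pattern against a node and calls the paired callback on the first match, so adding Ops.STORE to the devectorize pattern routes STORE through the same "no ALU on vectorized dtypes" rewrite that ASSIGN used to take. The toy below mimics that dispatch with dicts; it is a sketch, not tinygrad's UPat/PatternMatcher implementation:

```python
class MiniPat:
    def __init__(self, ops, name): self.ops, self.name = set(ops), name
    def match(self, node, bindings):
        if node["op"] in self.ops:
            bindings[self.name] = node
            return True
        return False

class MiniMatcher:
    def __init__(self, patterns): self.patterns = patterns
    def rewrite(self, node):
        for pat, fn in self.patterns:
            bindings = {}
            if pat.match(node, bindings):
                return fn(**bindings)   # first match wins
        return None

split_log = []
def no_vectorized_alu(alu):            # stand-in rewrite callback
    split_log.append(alu["op"])
    return alu                         # a real rewrite would return new nodes

pm = MiniMatcher([(MiniPat({"ADD", "CAST", "BITCAST", "STORE"}, "alu"), no_vectorized_alu)])
pm.rewrite({"op": "STORE"})
assert split_log == ["STORE"]          # STORE now hits the devectorize rule
```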
2 changes: 1 addition & 1 deletion tinygrad/codegen/expander.py
@@ -83,7 +83,7 @@ def do_contract(con:UOp):
(UPat(Ops.UNROLL, name="outer", src=(UPat(Ops.UNROLL, name="inner"),)),
lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
# do expansion
- (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN,
+ (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX,
Ops.VECTORIZE, Ops.IF, Ops.REDUCE), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
(UPat(Ops.CONTRACT, name="con"), do_contract),
# vectorize DEFINE_ACC
2 changes: 1 addition & 1 deletion tinygrad/codegen/linearize.py
@@ -98,7 +98,7 @@ def from_sink(sink:UOp) -> BlockContext:
idx_context, store_context = ctx.last_ctx(u.src[0]), ctx.last_ctx(u.src[1])
ctx.child_ctxs[u] = tuple([y for y in store_context if y not in idx_context and y.op is Ops.RANGE])
else: ctx.child_ctxs[u] = ()
- elif u.op is Ops.ASSIGN:
+ elif u.op is Ops.STORE:
assert u.src[0].op is Ops.DEFINE_REG
ctx.child_ctxs[u] = tuple([y for y in ctx.last_ctx(u.src[1]) if y not in u.src[0].src[1:]])
return ctx
10 changes: 5 additions & 5 deletions tinygrad/engine/schedule.py
@@ -37,28 +37,28 @@ def unbind_bind(ctx:list[dict[Variable, int]], x:UOp):
# **** schedule linearizer

def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int]]:
- # construct the KERNEL children graph based on assigns
+ # construct the KERNEL children graph based on stores
children: defaultdict[UOp, list[UOp]] = defaultdict(list)
in_degree: dict[UOp, int] = {}
for u in sched_sink.toposort():
- if u.op is not Ops.ASSIGN: continue # anything that's not an ASSIGN doesn't write a kernel, so we can skip
+ if u.op is not Ops.STORE: continue # anything that's not a STORE doesn't write a kernel, so we can skip
k = u.src[1]
in_degree.setdefault(k, 0)
for s in k.src:
- if s.op is Ops.ASSIGN:
+ if s.op is Ops.STORE:
children[s.src[1]].append(k)
in_degree[k] += 1
elif s.op in {Ops.MSELECT, Ops.MSTACK}:
for ss in s.src:
if ss.op is Ops.MSELECT: ss = ss.src[0]
if ss.op is not Ops.BUFFER:
- assert ss.op is Ops.ASSIGN
+ assert ss.op is Ops.STORE
children[ss.src[1]].append(k)
in_degree[k] += 1
elif s.op is Ops.BUFFER:
pass # a BUFFER is already realized, nothing to do here
else:
- raise RuntimeError(f"input to kernel must be ASSIGN or BUFFER, not {s.op}")
+ raise RuntimeError(f"input to kernel must be STORE or BUFFER, not {s.op}")

# linearize KERNEL UOps into ScheduleItems in BFS order
queue = deque(k for k,v in in_degree.items() if v == 0)
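The children/in_degree maps built above feed a plain Kahn-style BFS: a kernel enters the queue once every STORE it depends on has been scheduled. A generic sketch of that linearization, using strings for kernels rather than tinygrad's UOp/ScheduleItem:

```python
from collections import deque

def bfs_schedule(children: dict[str, list[str]], in_degree: dict[str, int]) -> list[str]:
    queue = deque(k for k, v in in_degree.items() if v == 0)
    order = []
    while queue:
        k = queue.popleft()
        order.append(k)
        for child in children.get(k, []):
            in_degree[child] -= 1          # one dependency satisfied
            if in_degree[child] == 0:
                queue.append(child)
    assert len(order) == len(in_degree), "cycle in kernel graph"
    return order

# toy graph: k1 writes a buffer that k2 and k3 both read
print(bfs_schedule({"k1": ["k2", "k3"]}, {"k1": 0, "k2": 1, "k3": 1}))  # ['k1', 'k2', 'k3']
```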
4 changes: 2 additions & 2 deletions tinygrad/gradient.py
@@ -50,8 +50,8 @@ def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:
# compute the target path (top down)
in_target_path: dict[UOp, bool] = {}
for u in root.toposort(): in_target_path[u] = any(x in targets or in_target_path[x] for x in u.src)
- # don't flow through DETACH/ASSIGN or anything not in target path
- return list(root.toposort(lambda node: node.op not in {Ops.DETACH, Ops.ASSIGN} and in_target_path[node]))
+ # don't flow through DETACH/STORE or anything not in target path
+ return list(root.toposort(lambda node: node.op not in {Ops.DETACH, Ops.STORE} and in_target_path[node]))

def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
grads = {root: root_grad}
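What this hunk changes is only the gate predicate handed to toposort: DETACH, and now STORE, act as gradient barriers the walk refuses to descend through. A self-contained illustration with a toy node class (not tinygrad's UOp):

```python
class Node:   # toy stand-in for a UOp
    def __init__(self, op, *src): self.op, self.src = op, src

def gated_toposort(root, gate):
    seen, out = set(), []
    def visit(n):
        if id(n) in seen or not gate(n): return   # barrier: do not descend
        seen.add(id(n))
        for s in n.src: visit(s)
        out.append(n)
    visit(root)
    return out

leaf = Node("BUFFER")
blocked = Node("DETACH", Node("MUL", leaf, leaf))
root = Node("ADD", leaf, blocked)
path = gated_toposort(root, lambda n: n.op not in {"DETACH", "STORE"})
assert leaf in path and blocked not in path   # gradient reaches leaf, never crosses DETACH
```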
10 changes: 5 additions & 5 deletions tinygrad/kernelize/grouper.py
@@ -2,7 +2,7 @@
from tinygrad.helpers import all_int, prod, unwrap, dedup, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, FUSE_CONV_BW
from tinygrad.shape.shapetracker import ShapeTracker

- ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
+ ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.STORE, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.GBARRIER}

# **** Grouper decides which of the UOps realize
@@ -25,8 +25,8 @@ def realize_before_view(ctx:dict[UOp, None], view:UOp, tr:UOp) -> None:
do_realize = PatternMatcher([
# always realize SINK parents
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
- # always realize ASSIGN/CONTIGUOUS/GroupOp.Meta
- (UPat({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}, name="tr"), realize),
+ # always realize STORE/CONTIGUOUS/GroupOp.Meta
+ (UPat({Ops.STORE, Ops.CONTIGUOUS, *GroupOp.Meta}, name="tr"), realize),
# realize before expand or unsafe pad ops
(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="tr"),), name="view"), realize_before_view),
# realize parents of COPY, MSELECT, MSTACK
@@ -61,7 +61,7 @@ def group_realizes(sink:UOp) -> dict[UOp, None]:
assigns: dict[UOp, None] = {}
for u in (toposort:=sink.toposort()):
if u.op in {Ops.VIEW, Ops.SINK}: continue
- if u.op is Ops.ASSIGN: assigns[u.buf_uop] = None
+ if u.op is Ops.STORE: assigns[u.buf_uop] = None
for s in u.src: children.setdefault(s.base, {})[u] = None

# find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
@@ -85,7 +85,7 @@ def group_realizes(sink:UOp) -> dict[UOp, None]:
# can only have one output
if not forced_realize and len(group) > 1: forced_realize = True
# can only fuse assign if no other assign_target is used in the kernel
- if not forced_realize and (assign_targets:={x.buf_uop for x in group if x.op is Ops.ASSIGN}):
+ if not forced_realize and (assign_targets:={x.buf_uop for x in group if x.op is Ops.STORE}):
parents = [r, *group]
while parents and not forced_realize:
p = parents.pop().base
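The last hunk guards fusion: per the comment above it, a group that mutates a buffer in place may only fuse if no other node in the kernel still needs that buffer's old contents. A loose, hypothetical rendering of that check with dicts standing in for UOps (the real walk inspects parents of the reduce group, not a flat list):

```python
def can_fuse(group, upstream_nodes):
    # buffers the group writes in place via STORE
    store_targets = {n["buf"] for n in group if n["op"] == "STORE"}
    # refuse fusion if anything upstream still reads one of those buffers
    return not any(p["op"] == "LOAD" and p["buf"] in store_targets
                   for p in upstream_nodes)

group = [{"op": "STORE", "buf": "B0"}]
assert can_fuse(group, [{"op": "LOAD", "buf": "B1"}])      # different buffer: ok to fuse
assert not can_fuse(group, [{"op": "LOAD", "buf": "B0"}])  # reads old contents: force realize
```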