Description
Which component requires the feature?
CuTe DSL
Feature Request
Hi! I have recently been using CuTe DSL to implement a fused kernel that performs NVLink-domain communication by accessing peer memory mapped through torch symmetric memory. I need to pass a list of cute.Tensors into the kernel and use the thread index to select an entry from that list. However, the value returned by cute.arch.thread_idx() cannot be used as an integer to index a Python list, so the following error is raised:
cutlass.base_dsl.common.DSLRuntimeError: DSLRuntimeError: '<class 'cutlass.base_dsl._mlir_helpers.arith.ArithValue'>' object cannot be interpreted as an integer
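For reference, the same class of failure can be reproduced in plain Python, which may help narrow down where the message comes from. This is only an illustration under assumptions: `SymbolicValue` is a hypothetical stand-in for a traced value such as ArithValue (it is not a CuTe DSL class), and the sketch assumes the traced value does not convert to a Python integer while the kernel is being traced.

```python
import operator


class SymbolicValue:
    """Hypothetical stand-in for a traced value (NOT a CuTe DSL class).

    It has a name but no concrete integer value and defines no __index__,
    so Python cannot use it to subscript a list.
    """

    def __init__(self, name: str) -> None:
        self.name = name


def index_or_error(seq, idx):
    """Return seq[idx], or the TypeError text if idx is not integer-like."""
    try:
        # operator.index() is the standard conversion for integer-like objects.
        return seq[operator.index(idx)]
    except TypeError as exc:
        return f"TypeError: {exc}"


peer_bufs = ["buf0", "buf1", "buf2", "buf3"]  # placeholders for tensors
tidx = SymbolicValue("tidx")

symbolic_result = index_or_error(peer_bufs, tidx)  # fails: no integer value
concrete_result = index_or_error(peer_bufs, 2)     # works: plain Python int
print(symbolic_result)
print(concrete_result)
```

A Python list can only be subscripted with a value that is known while the Python code runs, whereas the thread index only exists once the kernel executes on the device.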
As a workaround, I currently implement the kernel as follows:
@cute.kernel
def a2a(
    mS: cute.Tensor,
    peer_bufs: List[cute.Tensor],  # List of peer mem tensors
    rank: cutlass.Constexpr[int],
):
    tidx, _, _ = cute.arch.thread_idx()
    # Cannot use tidx to index List[cute.Tensor]
    # peer_bufs[tidx][rank] = mS[tidx]
    if tidx == 0:
        peer_bufs[0][rank] = mS[tidx]
    if tidx == 1:
        peer_bufs[1][rank] = mS[tidx]
    if tidx == 2:
        peer_bufs[2][rank] = mS[tidx]
    if tidx == 3:
        peer_bufs[3][rank] = mS[tidx]
    if tidx == 4:
        peer_bufs[4][rank] = mS[tidx]
    if tidx == 5:
        peer_bufs[5][rank] = mS[tidx]
    if tidx == 6:
        peer_bufs[6][rank] = mS[tidx]
    if tidx == 7:
        peer_bufs[7][rank] = mS[tidx]
This is inelegant and verbose. More importantly, lists of different lengths are not supported, because the hard-coded branches cannot be adjusted at runtime. In principle this should be feasible as long as the length of the list passed to the kernel matches the one passed to cute.compile. However, with the kernel above, passing a List[cute.Tensor] that contains only 4 cute.Tensors raises the following error:
[rank3]: peer_bufs[4][rank] = mS[tidx]
[rank3]: IndexError: list index out of range
I would like to know whether this usage is fundamentally unsupported or whether it is an upcoming feature, or whether there is another way to handle this situation in CuTe DSL. Thanks!
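To make the desired behavior concrete, here is a plain-Python emulation of what a length-agnostic version could look like. The branch chain is generated by a loop over `len(peer_bufs)`, so the number of branches follows the list that is actually passed in. This is a sketch, not verified CuTe DSL code: `a2a_unrolled` and `emulate` are hypothetical names, tensors are modeled as Python lists, and threads are emulated by a host-side loop. CuTe DSL documents `cutlass.range_constexpr` for loops evaluated at trace time, but whether it can iterate over a List[cute.Tensor] argument in this way is an assumption.

```python
from typing import List


def a2a_unrolled(
    mS: List[int],
    peer_bufs: List[List[int]],
    rank: int,
    tidx: int,
) -> None:
    """Emulate one thread: write mS[tidx] into slot `rank` of peer buffer tidx."""
    # The bound is the Python list length, which is known when the kernel is
    # traced. In CuTe DSL this loop is where a trace-time construct such as
    # cutlass.range_constexpr might apply (assumption, not verified).
    for i in range(len(peer_bufs)):
        # Each iteration uses a constant list index i, guarded by a
        # comparison against the dynamic thread index.
        if tidx == i:
            peer_bufs[i][rank] = mS[tidx]


def emulate(num_peers: int, rank: int) -> List[List[int]]:
    """Run the emulated kernel with one 'thread' per peer buffer."""
    mS = [10 * t for t in range(num_peers)]
    peer_bufs = [[0] * num_peers for _ in range(num_peers)]
    for tidx in range(num_peers):  # host-side stand-in for GPU threads
        a2a_unrolled(mS, peer_bufs, rank, tidx)
    return peer_bufs


four_peers = emulate(num_peers=4, rank=3)
eight_peers = emulate(num_peers=8, rank=0)
```

With this shape, a list of 4 buffers yields 4 guarded writes and a list of 8 yields 8, so no index beyond the end of the list is ever touched.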