
[FEA] Support indexing List[cute.Tensor] with thread_idx inside kernel #2773

@Gin-Sin

Description


Which component requires the feature?

CuTe DSL

Feature Request

Hi! I have recently been using CuTe DSL to implement a fused kernel that performs NVLink-domain communication by accessing peer memory mapped through PyTorch symmetric memory. I need to pass a list of cute.Tensor objects into the kernel and use the thread index to select an element of that list. However, the value returned by cute.arch.thread_idx() cannot be used as an integer list index, so the following error is raised:

cutlass.base_dsl.common.DSLRuntimeError: DSLRuntimeError: '<class 'cutlass.base_dsl._mlir_helpers.arith.ArithValue'>' object cannot be interpreted as an integer
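This error follows from how the DSL works, roughly: the kernel body is ordinary Python that runs on the host to trace the kernel, while cute.arch.thread_idx() yields a value (the ArithValue in the message) whose concrete number only exists once the kernel runs on the GPU. A Python list must be subscripted with a real int at trace time. The pure-Python sketch below models this. `SymbolicInt` and `trace_kernel` are hypothetical stand-ins, not CuTe DSL classes or functions.

```python
# Hypothetical model, not CuTe DSL: SymbolicInt stands in for a traced value
# such as the result of cute.arch.thread_idx(). It is NOT the real ArithValue.


class SymbolicInt:
    """A value whose concrete number only exists when the kernel runs."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __repr__(self) -> str:
        return f"SymbolicInt({self.name!r})"


def trace_kernel(peer_bufs: list, tidx):
    # In this model the subscript is evaluated on the host during tracing,
    # so Python needs a real int right now, not a placeholder for one.
    return peer_bufs[tidx]


if __name__ == "__main__":
    bufs = ["buf0", "buf1", "buf2", "buf3"]
    try:
        trace_kernel(bufs, SymbolicInt("tidx"))
    except TypeError as err:
        print(f"TypeError: {err}")
    # An index that is already a plain int at trace time works as usual.
    print(trace_kernel(bufs, 2))
```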

Currently, I work around this as follows:

from typing import List

import cutlass
import cutlass.cute as cute

@cute.kernel
def a2a(
    mS: cute.Tensor,
    peer_bufs: List[cute.Tensor],  # List of peer mem tensors
    rank: cutlass.Constexpr[int],
):
    tidx, _, _ = cute.arch.thread_idx()

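    # Desired, but fails: tidx is a runtime value, not a Python int,
    # so it cannot index a Python list while the kernel is traced.
    # peer_bufs[tidx][rank] = mS[tidx]

    # Workaround: one hard-coded branch per peer buffer
    if tidx == 0: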
        peer_bufs[0][rank] = mS[tidx]
    if tidx == 1:
        peer_bufs[1][rank] = mS[tidx]
    if tidx == 2:
        peer_bufs[2][rank] = mS[tidx]
    if tidx == 3:
        peer_bufs[3][rank] = mS[tidx]
    if tidx == 4:
        peer_bufs[4][rank] = mS[tidx]
    if tidx == 5:
        peer_bufs[5][rank] = mS[tidx]
    if tidx == 6:
        peer_bufs[6][rank] = mS[tidx]
    if tidx == 7:
        peer_bufs[7][rank] = mS[tidx]

This is verbose and inelegant. More importantly, it hard-codes the list length, so the same kernel cannot handle lists of a different length. In principle this should be feasible, as long as the list passed to the kernel at launch has the same length as the one passed to cute.compile. For example, if I pass a List[cute.Tensor] containing only 4 tensors to the kernel above, the following error is raised:

[rank3]:     peer_bufs[4][rank] = mS[tidx]
[rank3]: IndexError: list index out of range
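Because the list length is fixed when the kernel is compiled, the if-chain could in principle be generated from len(peer_bufs) instead of written by hand: a loop that runs at trace time resolves each list index to a plain int, and only the comparison with the thread index is left for run time. The pure-Python sketch below models that idea; `trace_a2a` and the closure-based "kernel" it returns are hypothetical and only imitate the structure of the generated code. In CuTe DSL itself, the analogous loop would presumably be written with `cutlass.range_constexpr(len(peer_bufs))` (the DSL's compile-time-unrolled loop) around the `if tidx == i:` branch; this is an assumption that has not been verified against the kernel above.

```python
from typing import Callable, List

# Hypothetical model, not CuTe DSL: "tracing" builds one guarded branch per
# peer buffer, and the returned "kernel" is then run once per thread index.


def trace_a2a(peer_bufs: List[list], rank: int) -> Callable[[int, list], None]:
    """Specialize the a2a kernel for a fixed number of peer buffers.

    The Python loop runs once, at trace time, over len(peer_bufs), so every
    list index `i` is an ordinary int. Only the `tidx == i` test is deferred
    to run time, mirroring the hand-written if-chain but sized automatically.
    """
    branches = []
    for i in range(len(peer_bufs)):
        buf = peer_bufs[i]  # resolved during tracing with a concrete int

        # Default arguments bind the current i/buf (avoids late binding).
        def branch(tidx: int, mS: list, i: int = i, buf: list = buf) -> None:
            if tidx == i:  # runtime predicate on the thread index
                buf[rank] = mS[tidx]

        branches.append(branch)

    def kernel(tidx: int, mS: list) -> None:
        # Every thread evaluates all branches; at most one of them fires.
        for branch in branches:
            branch(tidx, mS)

    return kernel


if __name__ == "__main__":
    for num_peers in (4, 8):  # same code, different list lengths
        peers = [[0] * num_peers for _ in range(num_peers)]
        kernel = trace_a2a(peers, rank=1)
        mS = [10 * t for t in range(num_peers)]
        for tidx in range(num_peers):  # emulate one thread per peer
            kernel(tidx, mS)
        print(num_peers, [p[1] for p in peers])
```

Each call to `trace_a2a` produces a kernel specialized to one list length, which matches the expectation that the list passed at launch has the same length as the one used at compile time.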

I would like to know whether this usage is fundamentally unsupported, whether it is a planned feature, or whether there is another way in CuTe DSL to handle this situation. Thanks!
