Skip to content

[BUG] The result of tiled_divide on a Tensor causes layout issues when used for Tensor slicing. #2499

@happyflathead

Description

@happyflathead

Which component has the problem?

CUTLASS C++

Bug Report

Describe the bug
The result of tiled_divide on a Tensor causes layout issues when used for Tensor slicing.

Steps/Code to reproduce bug

// Kernel configuration for the GEMM reproduction: CTA tile sizes, pipeline
// depth, and the global->shared (g2s) tiled-copy types.
struct Config{
    constexpr static int TileM=64;   // CTA tile rows of A / C
    constexpr static int TileN=128;  // CTA tile cols of B / C
    constexpr static int TileK=32;   // CTA tile extent along K
    constexpr static int Stage=2;    // pipeline stages (declared; use not shown here)

    // 128-bit cp.async copy (requires SM80+): 4 floats per instruction.
    using g2s_op = SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>;
    using g2s_traits = Copy_Traits<g2s_op>;
    using g2s_atom = Copy_Atom<g2s_traits, float>;
    // Tiled copy for the A tile (64x32): 64x8 thread layout, each thread moving
    // a 1x4 float vector, so one pass covers 64 x (8*4) = the full 64x32 tile.
    using G2SCopyA = decltype( // A:(64, 32)
        make_tiled_copy(g2s_atom{}, 
            make_layout(
                make_shape(Int<64>{}, Int<8>{}),
                make_stride(Int<8>{}, Int<1>{})),
            make_layout(make_shape(Int<1>{}, Int<4>{})))
        );

    // Tiled copy for the B tile (32x128): 16x32 thread layout x (1x4) values
    // covers a 16x128 slab per pass, i.e. two passes over the 32x128 tile.
    using G2SCopyB = decltype( // B:(32, 128)
        make_tiled_copy(g2s_atom{}, 
            make_layout(
                make_shape(Int<16>{},   Int<32>{}),
                make_stride(Int<32>{}, Int<1>{})),
            make_layout(make_shape(Int<1>{}, Int<4>{})))
        );
    constexpr static int ThreadNum = 512;  // 64*8 == 16*32 threads per CTA
    
};

// Reproduction kernel (buggy variant): A/B/C arrive already tiled by
// tiled_divide in the launcher, i.e. A is ((TileM,TileK), M/TileM, K/TileK),
// B is ((TileK,TileN), K/TileK, N/TileN), C is ((TileM,TileN), ...).
// Expected launch: grid (N/TileN, M/TileM), block Config::ThreadNum threads.
template<typename Config, typename TensorA, typename TensorB, typename TensorC>
__global__ void gemm_float(TensorA A, TensorB B, TensorC C){
    using G2SCopyA = typename Config::G2SCopyA;
    using G2SCopyB = typename Config::G2SCopyB;

    
    extern __shared__ float smem_ptr[];
    // Structured binding deliberately avoids naming a variable `_`, which
    // would shadow cute::_ (Underscore) used in the slices below.
    auto [bx, by, bz] = blockIdx; (void)bz;
    // auto [bx, by, _] = blockIdx;
    // NOTE(review): slicing the tiled_divide result with a bare `_` in the
    // first (tile) mode keeps that mode *nested* rather than flattening it to
    // (TileM, TileK, ...) the way local_tile does — this mismatch is the
    // likely source of the bad partition strides printed below; verify
    // against CuTe's tiled_divide/local_tile semantics.
    auto gA = A(_, by, _);
    auto gB = B(_, _, bx);
    auto gC = C(_, by, bx);  
    
    int idx = threadIdx.x;
    G2SCopyA g2s_a;
    auto thr_g2s_a = g2s_a.get_slice(idx);
    // Per-thread source partition of the A tiles for the g2s copy.
    auto gA_cp_s = thr_g2s_a.partition_S(gA);

    G2SCopyB g2s_b;
    auto thr_g2s_b = g2s_b.get_slice(idx);
    // Per-thread source partition of the B tiles for the g2s copy.
    auto gB_cp_s = thr_g2s_b.partition_S(gB);


    // Debug output only: prints the partitioned layouts quoted in the report.
    PRINT(gA_cp_s);
    PRINT(gB_cp_s);
}

// Host launcher: wraps the raw device pointers in row-major (LayoutRight)
// CuTe tensors, pre-tiles them with tiled_divide, and launches gemm_float.
// Grid mapping: blockIdx.x -> N tiles, blockIdx.y -> M tiles.
template<typename Config, int M, int N, int K>
void launch_gemm(const float* Aptr, 
                 const float* Bptr, 
                 float* Cptr){    
    constexpr int TileM = Config::TileM;
    constexpr int TileN = Config::TileN;
    constexpr int TileK = Config::TileK;
    
    // A: (M, K) row-major; after tiled_divide: ((TileM,TileK), M/TileM, K/TileK).
    Tensor A_total = make_tensor(make_gmem_ptr(Aptr), 
               make_layout(make_shape(Int<M>{}, Int<K>{}), 
                           LayoutRight{}));
    Tensor A = tiled_divide(A_total, make_shape(Int<TileM>{}, Int<TileK>{}));
    
    // B: (K, N) row-major; after tiled_divide: ((TileK,TileN), K/TileK, N/TileN).
    Tensor B_total = make_tensor(make_gmem_ptr(Bptr),
                           make_layout(make_shape(Int<K>{}, Int<N>{}),
                           LayoutRight{}));
    Tensor B = tiled_divide(B_total, make_shape(Int<TileK>{}, Int<TileN>{}));
    

    // C: (M, N) row-major; after tiled_divide: ((TileM,TileN), M/TileM, N/TileN).
    Tensor C_total = make_tensor(make_gmem_ptr(Cptr),
                           make_layout(make_shape(Int<M>{}, Int<N>{}),
                           LayoutRight{}));
    Tensor C = tiled_divide(C_total, make_shape(Int<TileM>{}, Int<TileN>{}));

    dim3 block{Config::ThreadNum, 1, 1};
    dim3 grid{CEILDIV(N, TileN), CEILDIV(M, TileM)};


    // NOTE(review): Config::ShmSize is not declared in the Config shown in
    // this report — presumably defined in the full source; confirm.
    gemm_float<Config, decltype(A), decltype(B), decltype(C)>
        <<<grid, block, Config::ShmSize>>>(A, B, C);
    cudaCheckError(cudaGetLastError());
    cudaCheckError(cudaStreamSynchronize(0));
 
}
  1. The first-mode layout of gA_cp_s is (_4,_1):(_32,_0) and of gB_cp_s is (_4,_1):(_131072,_0) — the four values of each copy are not contiguous in memory.
  2. Given how G2SCopyA and G2SCopyB are defined (a 1x4 value layout per thread), the first mode of each partition should be 4 contiguous elements, i.e. stride _1.
    output:
gA_cp_s gmem_ptr[32b](0x73811ce00000) o ((_4,_1),_32,_1):((_32,_0),_1,_0)
gB_cp_s gmem_ptr[32b](0x738110000000) o ((_4,_1),(_2,_128),_1):((_131072,_0),(_65536,_1),_0)

Expected behavior

gA_cp_s gmem_ptr[32b](0x7a0760e00000) o ((_4,_1),_1,_1):((_1,_0),_0,_0)
gB_cp_s gmem_ptr[32b](0x7a0754000000) o ((_4,_1),_8,_1):((_1,_0),_65536,_0)

P.S.
I would expect `auto sub_tensor = tiled_divide(tensor, shape); auto block_tensor = sub_tensor(_, by, _);` to behave like `local_tile`; however, it does not.

// Working variant: A/B/C are the *untiled* global tensors and local_tile is
// used inside the kernel, producing flat (TileM, TileK, k)-style modes that
// partition_S handles as expected (see the "Expected behavior" output above).
template<typename Config, typename TensorA, typename TensorB, typename TensorC>
__global__ void gemm_float(TensorA A, TensorB B, TensorC C){
    constexpr int TileM = Config::TileM;
    constexpr int TileN = Config::TileN;
    constexpr int TileK = Config::TileK;
    constexpr int Stage = Config::Stage;
    using G2SCopyA = typename Config::G2SCopyA;
    using G2SCopyB = typename Config::G2SCopyB;
    using ShmLayoutA = typename Config::ShmLayoutA;
    using ShmLayoutB = typename Config::ShmLayoutB;
    
    extern __shared__ float smem_ptr[];
    // NOTE(review): binding blockIdx.z to `_` introduces a local variable
    // named `_` that can shadow cute::_ (Underscore) used in make_coord
    // below, depending on language mode — the other variant of this kernel
    // uses `bz` to avoid exactly that; verify which `_` is resolved here.
    auto [bx, by, _] = blockIdx;
    Tensor gA = local_tile(A, make_tile(Int<TileM>{}, Int<TileK>{}),
                            make_coord(by, _));  // (TileM, TileK, k)
    Tensor gB = local_tile(B, make_tile(Int<TileN>{}, Int<TileK>{}),
                            make_coord(bx, _));  // (TileN, TileK, k)
    Tensor gC = local_tile(C, make_tile(Int<TileM>{}, Int<TileN>{}),
                            make_coord(by, bx));  // (TileM, TileN)
    

    // Shared-memory staging buffers: sB starts right after sA's elements.
    auto sA = make_tensor(make_smem_ptr(smem_ptr), ShmLayoutA{});
    auto sB = make_tensor(make_smem_ptr(smem_ptr+size(ShmLayoutA{})), 
                                        ShmLayoutB{});
    
    int idx = threadIdx.x;
    G2SCopyA g2s_a;
    auto thr_g2s_a = g2s_a.get_slice(idx);
    // Per-thread source (gmem) and destination (smem) partitions for A.
    auto gA_cp_s = thr_g2s_a.partition_S(gA);
    auto gA_cp_d = thr_g2s_a.partition_D(sA);
    G2SCopyB g2s_b;
    auto thr_g2s_b = g2s_b.get_slice(idx);
    // Per-thread source (gmem) and destination (smem) partitions for B.
    auto gB_cp_s = thr_g2s_b.partition_S(gB);
    auto gB_cp_d = thr_g2s_b.partition_D(sB);
    // Debug output only: matches the "Expected behavior" layouts above.
    PRINT(gA_cp_s);
    PRINT(gB_cp_s);
}

output:

gA_cp_s gmem_ptr[32b](0x7a0760e00000) o ((_4,_1),_1,_1):((_1,_0),_0,_0)
gB_cp_s gmem_ptr[32b](0x7a0754000000) o ((_4,_1),_8,_1):((_1,_0),_65536,_0)

Environment details (please complete the following information):

  • cutlass 4
  • cuda 12.9

Additional context
Replacing tiled_divide with zipped_divide exhibits the same problem.

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions