-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Open
Labels
Description
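CUTLASS 3 and CUTLASS 4 produce different layouts for the same shared-memory tensors, as shown in the code and output below. The CUTLASS 3 layouts are flat, while the CUTLASS 4 layouts split every mode into nested sub-modes. Coalescing each CUTLASS 4 mode gives back exactly the CUTLASS 3 layout. For example, ((_16,_2)):((_128,_2048)) becomes _32:_128 because 16*128 == 2048. So the two forms appear to map coordinates to the same offsets. Given that the CUTLASS 3 form is simpler, why was it necessary to switch to the nested CUTLASS 4 form? (The two runs also used different element widths, smem_ptr[16b] vs smem_ptr[32b]. The question concerns only the layouts.)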
// Parameters of the CuTe swizzle functor Swizzle<BBits, MBase, SShift>.
// NOTE(review): per CuTe's Swizzle documentation, BBits is the number of bits
// in the XOR mask, MBase the number of low bits left unchanged, and SShift the
// distance between the mask bits and the bits they are XORed into. Confirm
// against cute/swizzle.hpp for the CUTLASS version in use.
// These are presumably members of the Config struct referenced later
// (Config::ShmLayoutA / Config::ShmLayoutB).
constexpr static int BBits =3;
constexpr static int MBase =2;
constexpr static int SShift=5;
// Printed as "Sw<3,2,5>" in both the CUTLASS 3 and the CUTLASS 4 output.
constexpr static auto swizzle = Swizzle<BBits, MBase, SShift>{};
// Shared-memory layout atom for A: a (kTileM, kTileK) layout with strides
// (kTileK, 1), so consecutive K indices are adjacent. It is composed with the
// swizzle defined above.
using ShmLayoutAtomA = decltype(
composition(swizzle,
make_layout(make_shape(Int<kTileM>{}, Int<kTileK>{}),
make_stride(Int<kTileK>{}, Int<1>{}) )));
// The atom tiled out to (kTileM, kTileK, kStage). The atom already spans a
// full (kTileM, kTileK) tile, so only the stage mode is repeated.
// With kTileM=64, kTileK=32, kStage=2 (see the trailing comment):
//   CUTLASS 3 prints (_64,_32,_2):(_32,_1,_2048)
//   CUTLASS 4 prints ((_64,_1),(_32,_1),(_1,_2)):((_32,_0),(_1,_0),(_0,_2048))
// Each CUTLASS 4 mode reads as (atom extent, number of atom repeats). The
// rank-2 atom is padded with an extent-1 third mode. Coalescing each CUTLASS 4
// mode yields the CUTLASS 3 layout, so both map coordinates to the same offsets.
using ShmLayoutA = decltype( //(64, 32, 2)
tile_to_shape(ShmLayoutAtomA{},
make_shape(Int<kTileM>{}, Int<kTileK>{}, Int<kStage>{})
));
// Shared-memory layout atom for B: a (16, kTileN) layout with strides
// (kTileN, 1), composed with the same swizzle.
// NOTE(review): the atom height is hard-coded to 16 rather than kTileK, so
// tile_to_shape below must repeat the atom kTileK/16 times along the first
// mode. kTileK presumably has to be a multiple of 16; confirm.
using ShmLayoutAtomB = decltype(
composition(swizzle,
make_layout(make_shape(Int<16>{}, Int<kTileN>{}),
make_stride(Int<kTileN>{}, Int<1>{})))
);
// The atom tiled out to (kTileK, kTileN, kStage).
// With kTileK=32, kTileN=128, kStage=2 (see the trailing comment):
//   CUTLASS 3 prints (_32,_128,_2):(_128,_1,_4096)
//   CUTLASS 4 prints ((_16,_2),(_128,_1),(_1,_2)):((_128,_2048),(_1,_0),(_0,_4096))
// In CUTLASS 4 the first mode is the 16-row atom repeated twice:
// (_16,_2):(_128,_2048). Since 16*128 == 2048, it coalesces to _32:_128, which
// is the CUTLASS 3 first mode. The other modes coalesce the same way.
using ShmLayoutB = decltype( //(32, 128, 2)
tile_to_shape(ShmLayoutAtomB{},
make_shape(Int<kTileK>{}, Int<kTileN>{}, Int<kStage>{}))
);
// shared memory
// Wrap the pointers Ashm/Bshm (tagged as shared memory by make_smem_ptr) in
// CuTe tensors that use the swizzled, multi-stage layouts from Config.
auto sA = make_tensor(make_smem_ptr(Ashm),
typename Config::ShmLayoutA{}); // (kTileM, kTileK, kStage)
auto sB = make_tensor(make_smem_ptr(Bshm),
typename Config::ShmLayoutB{}); // (kTileK, kTileN, kStage)
// PRINT is defined elsewhere. Judging by the output below, it prints the
// variable name followed by the tensor (pointer, swizzle, offset, layout).
PRINT(sA);
PRINT(sB);
CUTLASS 3 output:
sA smem_ptr[16b](0x7f23dd000000) o Sw<3,2,5> o _0 o (_64,_32,_2):(_32,_1,_2048)
sB smem_ptr[16b](0x7f23dd006000) o Sw<3,2,5> o _0 o (_32,_128,_2):(_128,_1,_4096)
CUTLASS 4 output:
sA smem_ptr[32b](0x7f18a5000000) o Sw<3,2,5> o _0 o ((_64,_1),(_32,_1),(_1,_2)):((_32,_0),(_1,_0),(_0,_2048))
sB smem_ptr[32b](0x7f18a5004000) o Sw<3,2,5> o _0 o ((_16,_2),(_128,_1),(_1,_2)):((_128,_2048),(_1,_0),(_0,_4096))