From 1fe1df0583ace85231480f30e469add41c6cc2fc Mon Sep 17 00:00:00 2001 From: John Platts Date: Thu, 16 Oct 2025 14:04:55 -0500 Subject: [PATCH 1/3] Updated SumOfMulQuadAccumulate to use [s]vusdot[q]_s32 on NEON_BF16/SVE2 plus I8MM fixes --- hwy/ops/arm_neon-inl.h | 22 +++++++++------------- hwy/ops/arm_sve-inl.h | 8 +++++--- hwy/ops/set_macros-inl.h | 4 +++- hwy/targets.cc | 3 ++- 4 files changed, 19 insertions(+), 18 deletions(-) diff --git a/hwy/ops/arm_neon-inl.h b/hwy/ops/arm_neon-inl.h index e8fd98e00f..44fc21eff6 100644 --- a/hwy/ops/arm_neon-inl.h +++ b/hwy/ops/arm_neon-inl.h @@ -7662,22 +7662,18 @@ HWY_API VFromD SumOfMulQuadAccumulate( #define HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE #endif -template +template HWY_API VFromD SumOfMulQuadAccumulate( - DI32 di32, VFromD> a_u, + DI32 /*di32*/, VFromD> a_u, VFromD> b_i, VFromD sum) { - // TODO: use vusdot[q]_s32 on NEON targets that require support for NEON I8MM - - const RebindToUnsigned du32; - const Repartition du8; - - const auto b_u = BitCast(du8, b_i); - const auto result_sum0 = - SumOfMulQuadAccumulate(du32, a_u, b_u, BitCast(du32, sum)); - const auto result_sum1 = ShiftLeft<8>( - SumOfMulQuadAccumulate(du32, a_u, ShiftRight<7>(b_u), Zero(du32))); + return VFromD(vusdot_s32(sum.raw, a_u.raw, b_i.raw)); +} - return BitCast(di32, Sub(result_sum0, result_sum1)); +template +HWY_API VFromD SumOfMulQuadAccumulate( + DI32 /*di32*/, VFromD> a_u, + VFromD> b_i, VFromD sum) { + return VFromD(vusdotq_s32(sum.raw, a_u.raw, b_i.raw)); } #endif // HWY_TARGET == HWY_NEON_BF16 diff --git a/hwy/ops/arm_sve-inl.h b/hwy/ops/arm_sve-inl.h index 8acc66d416..982735b22a 100644 --- a/hwy/ops/arm_sve-inl.h +++ b/hwy/ops/arm_sve-inl.h @@ -6545,9 +6545,10 @@ HWY_API VFromD SumOfMulQuadAccumulate(DU32 /*du32*/, svuint8_t a, template HWY_API VFromD SumOfMulQuadAccumulate(DI32 di32, svuint8_t a_u, svint8_t b_i, svint32_t sum) { - // TODO: use svusdot_u32 on SVE targets that require support for both SVE2 - // and SVE I8MM. - +#if HWY_SVE_HAVE_2 + (void)di32; + return svusdot_s32(sum, a_u, b_i); +#else const RebindToUnsigned du32; const Repartition du8; @@ -6557,6 +6558,7 @@ HWY_API VFromD SumOfMulQuadAccumulate(DI32 di32, svuint8_t a_u, ShiftLeft<8>(svdot_u32(Zero(du32), a_u, ShiftRight<7>(b_u))); return BitCast(di32, Sub(result_sum0, result_sum1)); +#endif } #ifdef HWY_NATIVE_I16_I16_SUMOFMULQUADACCUMULATE diff --git a/hwy/ops/set_macros-inl.h b/hwy/ops/set_macros-inl.h index 9871ff1efa..1dfb197ab9 100644 --- a/hwy/ops/set_macros-inl.h +++ b/hwy/ops/set_macros-inl.h @@ -617,6 +617,8 @@ #define HWY_HAVE_SCALABLE 1 #endif +#define HWY_TARGET_STR_I8MM "+i8mm" + // Can use pragmas instead of -march compiler flag #if HWY_HAVE_RUNTIME_DISPATCH #if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 @@ -628,7 +630,7 @@ #define HWY_TARGET_STR "+sve2,+sve" HWY_TARGET_STR_I8MM #endif #else // not SVE2 target -#define HWY_TARGET_STR "+sve" HWY_TARGET_STR_I8MM +#define HWY_TARGET_STR "+sve" #endif #else // !HWY_HAVE_RUNTIME_DISPATCH // HWY_TARGET_STR remains undefined diff --git a/hwy/targets.cc b/hwy/targets.cc index 70c3dda43e..7b2290d4a5 100644 --- a/hwy/targets.cc +++ b/hwy/targets.cc @@ -494,7 +494,8 @@ static int64_t DetectTargets() { if ((HasCpuFeature("hw.optional.AdvSIMD_HPFPCvt") || HasCpuFeature("hw.optional.arm.AdvSIMD_HPFPCvt")) && HasCpuFeature("hw.optional.arm.FEAT_DotProd") && - HasCpuFeature("hw.optional.arm.FEAT_BF16")) { + HasCpuFeature("hw.optional.arm.FEAT_BF16") && + HasCpuFeature("hw.optional.arm.FEAT_I8MM")) { bits |= HWY_NEON_BF16; } } From 732a3e227e74fc5526f1485beff2749434f1d82f Mon Sep 17 00:00:00 2001 From: John Platts Date: Thu, 16 Oct 2025 20:51:22 -0500 Subject: [PATCH 2/3] Added AVX10_2-specific implementations of I8xI8 and U8xU8 SumOfMulQuadAccumulate --- hwy/ops/x86_128-inl.h | 23 +++++++++++++++++++++-- hwy/ops/x86_256-inl.h | 19 ++++++++++++++++++- hwy/ops/x86_512-inl.h | 17 +++++++++++++++++ 3 files changed, 56 insertions(+), 3 deletions(-) diff --git a/hwy/ops/x86_128-inl.h b/hwy/ops/x86_128-inl.h index 3a2820102e..e496e3e9e2 100644 --- a/hwy/ops/x86_128-inl.h +++ b/hwy/ops/x86_128-inl.h @@ -10025,12 +10025,21 @@ HWY_API VFromD SumOfMulQuadAccumulate( #else #define HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE #endif + +#if HWY_X86_HAVE_AVX10_2_OPS +template +HWY_API VFromD SumOfMulQuadAccumulate(DI32 /*di32*/, + VFromD> a, + VFromD> b, + VFromD sum) { + return VFromD{_mm_dpbssd_epi32(sum.raw, a.raw, b.raw)}; +} +#else // !HWY_X86_HAVE_AVX10_2_OPS template HWY_API VFromD SumOfMulQuadAccumulate(DI32 di32, VFromD> a, VFromD> b, VFromD sum) { - // TODO(janwas): AVX-VNNI-INT8 has dpbssd. const Repartition du8; const auto a_u = BitCast(du8, a); @@ -10039,17 +10048,26 @@ HWY_API VFromD SumOfMulQuadAccumulate(DI32 di32, SumOfMulQuadAccumulate(di32, ShiftRight<7>(a_u), b, Zero(di32))); return result_sum_0 - result_sum_1; } +#endif // HWY_X86_HAVE_AVX10_2_OPS #ifdef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE #undef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE #else #define HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE #endif + +#if HWY_X86_HAVE_AVX10_2_OPS +template +HWY_API VFromD SumOfMulQuadAccumulate( + DU32 /*du32*/, VFromD> a, + VFromD> b, VFromD sum) { + return VFromD{_mm_dpbuud_epi32(sum.raw, a.raw, b.raw)}; +} +#else // !HWY_X86_HAVE_AVX10_2_OPS template HWY_API VFromD SumOfMulQuadAccumulate( DU32 du32, VFromD> a, VFromD> b, VFromD sum) { - // TODO(janwas): AVX-VNNI-INT8 has dpbuud. const Repartition du8; const RebindToSigned di8; const RebindToSigned di32; @@ -10062,6 +10080,7 @@ HWY_API VFromD SumOfMulQuadAccumulate( return BitCast(du32, result_sum_0 - result_sum_1); } +#endif // HWY_X86_HAVE_AVX10_2_OPS #endif // HWY_TARGET <= HWY_AVX3_DL diff --git a/hwy/ops/x86_256-inl.h b/hwy/ops/x86_256-inl.h index 4f92c7e79f..9bdbb43e4e 100644 --- a/hwy/ops/x86_256-inl.h +++ b/hwy/ops/x86_256-inl.h @@ -6424,7 +6424,24 @@ HWY_API VFromD SumOfMulQuadAccumulate( return VFromD{_mm256_dpbusd_epi32(sum.raw, a_u.raw, b_i.raw)}; } -#endif +#if HWY_X86_HAVE_AVX10_2_OPS +template +HWY_API VFromD SumOfMulQuadAccumulate(DI32 /*di32*/, + VFromD> a, + VFromD> b, + VFromD sum) { + return VFromD{_mm256_dpbssd_epi32(sum.raw, a.raw, b.raw)}; +} + +template +HWY_API VFromD SumOfMulQuadAccumulate( + DU32 /*du32*/, VFromD> a, + VFromD> b, VFromD sum) { + return VFromD{_mm256_dpbuud_epi32(sum.raw, a.raw, b.raw)}; +} +#endif // HWY_X86_HAVE_AVX10_2_OPS + +#endif // HWY_TARGET <= HWY_AVX3_DL // ================================================== CONVERT diff --git a/hwy/ops/x86_512-inl.h b/hwy/ops/x86_512-inl.h index fc1f61106b..04acdaf3dd 100644 --- a/hwy/ops/x86_512-inl.h +++ b/hwy/ops/x86_512-inl.h @@ -7636,6 +7636,23 @@ HWY_API VFromD SumOfMulQuadAccumulate( return VFromD{_mm512_dpbusd_epi32(sum.raw, a_u.raw, b_i.raw)}; } +#if HWY_X86_HAVE_AVX10_2_OPS +template +HWY_API VFromD SumOfMulQuadAccumulate(DI32 /*di32*/, + VFromD> a, + VFromD> b, + VFromD sum) { + return VFromD{_mm512_dpbssd_epi32(sum.raw, a.raw, b.raw)}; +} + +template +HWY_API VFromD SumOfMulQuadAccumulate( + DU32 /*du32*/, VFromD> a, + VFromD> b, VFromD sum) { + return VFromD{_mm512_dpbuud_epi32(sum.raw, a.raw, b.raw)}; +} +#endif // HWY_X86_HAVE_AVX10_2_OPS + #endif // ------------------------------ Reductions From 54045f12891375a7632a3de1cca5be3ccdc61ae7 Mon Sep 17 00:00:00 2001 From: John Platts Date: Thu, 16 Oct 2025 21:55:01 -0500 Subject: [PATCH 3/3] Updated AVX10.2 implementation of SumsOfAdjQuadAbsDiff --- hwy/ops/x86_256-inl.h | 4 ++-- hwy/ops/x86_512-inl.h | 18 +++++++++++++----- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/hwy/ops/x86_256-inl.h b/hwy/ops/x86_256-inl.h index 4f92c7e79f..60ee7e8d85 100644 --- a/hwy/ops/x86_256-inl.h +++ b/hwy/ops/x86_256-inl.h @@ -2113,8 +2113,8 @@ HWY_INLINE Vec256 SumsOf4(hwy::UnsignedTag /*type_tag*/, // ------------------------------ SumsOfAdjQuadAbsDiff template -static Vec256 SumsOfAdjQuadAbsDiff(Vec256 a, - Vec256 b) { +HWY_API Vec256 SumsOfAdjQuadAbsDiff(Vec256 a, + Vec256 b) { static_assert(0 <= kAOffset && kAOffset <= 1, "kAOffset must be between 0 and 1"); static_assert(0 <= kBOffset && kBOffset <= 3, diff --git a/hwy/ops/x86_512-inl.h b/hwy/ops/x86_512-inl.h index fc1f61106b..d541922222 100644 --- a/hwy/ops/x86_512-inl.h +++ b/hwy/ops/x86_512-inl.h @@ -6744,22 +6744,30 @@ HWY_API Vec512 CLMulUpper(Vec512 va, Vec512 vb) { // SumsOfAdjShufQuadAbsDiff) template -static Vec512 SumsOfAdjQuadAbsDiff(Vec512 a, - Vec512 b) { +HWY_API Vec512 SumsOfAdjQuadAbsDiff(Vec512 a, + Vec512 b) { static_assert(0 <= kAOffset && kAOffset <= 1, "kAOffset must be between 0 and 1"); static_assert(0 <= kBOffset && kBOffset <= 3, "kBOffset must be between 0 and 3"); +#if HWY_X86_HAVE_AVX10_2_OPS + // AVX10.2 now has the _mm512_mpsadbw_epu8 intrinsic available + return Vec512{_mm512_mpsadbw_epu8( + a.raw, b.raw, + (kAOffset << 5) | (kBOffset << 3) | (kAOffset << 2) | kBOffset)}; +#else const DFromV d; const RepartitionToWideX2 du32; - // While AVX3 does not have a _mm512_mpsadbw_epu8 intrinsic, the - // SumsOfAdjQuadAbsDiff operation is implementable for 512-bit vectors on - // AVX3 using SumsOfShuffledQuadAbsDiff and U32 Broadcast. + // The _mm512_mpsadbw_epu8 intrinsic is not available prior to AVX10.2. + // The SumsOfAdjQuadAbsDiff operation is implementable for 512-bit vectors on + // pre-AVX10.2 targets that support AVX3 using SumsOfShuffledQuadAbsDiff and + // U32 Broadcast. return SumsOfShuffledQuadAbsDiff( a, BitCast(d, Broadcast(BitCast(du32, b)))); +#endif } #if !HWY_IS_MSAN