这是indexloc提供的服务,不要输入任何密码
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 9 additions & 13 deletions hwy/ops/arm_neon-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -7662,22 +7662,18 @@ HWY_API VFromD<DU32> SumOfMulQuadAccumulate(
#define HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE
#endif

template <class DI32, HWY_IF_I32_D(DI32)>
template <class DI32, HWY_IF_I32_D(DI32), HWY_IF_V_SIZE_LE_D(DI32, 8)>
HWY_API VFromD<DI32> SumOfMulQuadAccumulate(
DI32 di32, VFromD<Repartition<uint8_t, DI32>> a_u,
DI32 /*di32*/, VFromD<Repartition<uint8_t, DI32>> a_u,
VFromD<Repartition<int8_t, DI32>> b_i, VFromD<DI32> sum) {
// TODO: use vusdot[q]_s32 on NEON targets that require support for NEON I8MM

const RebindToUnsigned<decltype(di32)> du32;
const Repartition<uint8_t, decltype(di32)> du8;

const auto b_u = BitCast(du8, b_i);
const auto result_sum0 =
SumOfMulQuadAccumulate(du32, a_u, b_u, BitCast(du32, sum));
const auto result_sum1 = ShiftLeft<8>(
SumOfMulQuadAccumulate(du32, a_u, ShiftRight<7>(b_u), Zero(du32)));
return VFromD<DI32>(vusdot_s32(sum.raw, a_u.raw, b_i.raw));
}

return BitCast(di32, Sub(result_sum0, result_sum1));
template <class DI32, HWY_IF_I32_D(DI32), HWY_IF_V_SIZE_D(DI32, 16)>
HWY_API VFromD<DI32> SumOfMulQuadAccumulate(
DI32 /*di32*/, VFromD<Repartition<uint8_t, DI32>> a_u,
VFromD<Repartition<int8_t, DI32>> b_i, VFromD<DI32> sum) {
return VFromD<DI32>(vusdotq_s32(sum.raw, a_u.raw, b_i.raw));
}

#endif // HWY_TARGET == HWY_NEON_BF16
Expand Down
8 changes: 5 additions & 3 deletions hwy/ops/arm_sve-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -6555,9 +6555,10 @@ HWY_API VFromD<DU32> SumOfMulQuadAccumulate(DU32 /*du32*/, svuint8_t a,
template <class DI32, HWY_IF_I32_D(DI32)>
HWY_API VFromD<DI32> SumOfMulQuadAccumulate(DI32 di32, svuint8_t a_u,
svint8_t b_i, svint32_t sum) {
// TODO: use svusdot_u32 on SVE targets that require support for both SVE2
// and SVE I8MM.

#if HWY_SVE_HAVE_2
(void)di32;
return svusdot_s32(sum, a_u, b_i);
#else
const RebindToUnsigned<decltype(di32)> du32;
const Repartition<uint8_t, decltype(di32)> du8;

Expand All @@ -6567,6 +6568,7 @@ HWY_API VFromD<DI32> SumOfMulQuadAccumulate(DI32 di32, svuint8_t a_u,
ShiftLeft<8>(svdot_u32(Zero(du32), a_u, ShiftRight<7>(b_u)));

return BitCast(di32, Sub(result_sum0, result_sum1));
#endif
}

#ifdef HWY_NATIVE_I16_I16_SUMOFMULQUADACCUMULATE
Expand Down
4 changes: 3 additions & 1 deletion hwy/ops/set_macros-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,8 @@
#define HWY_HAVE_SCALABLE 1
#endif

#define HWY_TARGET_STR_I8MM "+i8mm"

// Can use pragmas instead of -march compiler flag
#if HWY_HAVE_RUNTIME_DISPATCH
#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128
Expand All @@ -628,7 +630,7 @@
#define HWY_TARGET_STR "+sve2,+sve" HWY_TARGET_STR_I8MM
#endif
#else // not SVE2 target
#define HWY_TARGET_STR "+sve" HWY_TARGET_STR_I8MM
#define HWY_TARGET_STR "+sve"
#endif
#else // !HWY_HAVE_RUNTIME_DISPATCH
// HWY_TARGET_STR remains undefined
Expand Down
23 changes: 21 additions & 2 deletions hwy/ops/x86_128-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -10025,12 +10025,21 @@ HWY_API VFromD<DI32> SumOfMulQuadAccumulate(
#else
#define HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE
#endif

#if HWY_X86_HAVE_AVX10_2_OPS
template <class DI32, HWY_IF_I32_D(DI32), HWY_IF_V_SIZE_LE_D(DI32, 16)>
HWY_API VFromD<DI32> SumOfMulQuadAccumulate(DI32 /*di32*/,
VFromD<Repartition<int8_t, DI32>> a,
VFromD<Repartition<int8_t, DI32>> b,
VFromD<DI32> sum) {
return VFromD<DI32>{_mm_dpbssd_epi32(sum.raw, a.raw, b.raw)};
}
#else // !HWY_X86_HAVE_AVX10_2_OPS
template <class DI32, HWY_IF_I32_D(DI32)>
HWY_API VFromD<DI32> SumOfMulQuadAccumulate(DI32 di32,
VFromD<Repartition<int8_t, DI32>> a,
VFromD<Repartition<int8_t, DI32>> b,
VFromD<DI32> sum) {
// TODO(janwas): AVX-VNNI-INT8 has dpbssd.
const Repartition<uint8_t, decltype(di32)> du8;

const auto a_u = BitCast(du8, a);
Expand All @@ -10039,17 +10048,26 @@ HWY_API VFromD<DI32> SumOfMulQuadAccumulate(DI32 di32,
SumOfMulQuadAccumulate(di32, ShiftRight<7>(a_u), b, Zero(di32)));
return result_sum_0 - result_sum_1;
}
#endif // HWY_X86_HAVE_AVX10_2_OPS

#ifdef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE
#undef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE
#else
#define HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE
#endif

#if HWY_X86_HAVE_AVX10_2_OPS
template <class DU32, HWY_IF_U32_D(DU32), HWY_IF_V_SIZE_LE_D(DU32, 16)>
HWY_API VFromD<DU32> SumOfMulQuadAccumulate(
DU32 /*du32*/, VFromD<Repartition<uint8_t, DU32>> a,
VFromD<Repartition<uint8_t, DU32>> b, VFromD<DU32> sum) {
return VFromD<DU32>{_mm_dpbuud_epi32(sum.raw, a.raw, b.raw)};
}
#else // !HWY_X86_HAVE_AVX10_2_OPS
template <class DU32, HWY_IF_U32_D(DU32)>
HWY_API VFromD<DU32> SumOfMulQuadAccumulate(
DU32 du32, VFromD<Repartition<uint8_t, DU32>> a,
VFromD<Repartition<uint8_t, DU32>> b, VFromD<DU32> sum) {
// TODO(janwas): AVX-VNNI-INT8 has dpbuud.
const Repartition<uint8_t, decltype(du32)> du8;
const RebindToSigned<decltype(du8)> di8;
const RebindToSigned<decltype(du32)> di32;
Expand All @@ -10062,6 +10080,7 @@ HWY_API VFromD<DU32> SumOfMulQuadAccumulate(

return BitCast(du32, result_sum_0 - result_sum_1);
}
#endif // HWY_X86_HAVE_AVX10_2_OPS

#endif // HWY_TARGET <= HWY_AVX3_DL

Expand Down
23 changes: 20 additions & 3 deletions hwy/ops/x86_256-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2113,8 +2113,8 @@ HWY_INLINE Vec256<uint32_t> SumsOf4(hwy::UnsignedTag /*type_tag*/,
// ------------------------------ SumsOfAdjQuadAbsDiff

template <int kAOffset, int kBOffset>
static Vec256<uint16_t> SumsOfAdjQuadAbsDiff(Vec256<uint8_t> a,
Vec256<uint8_t> b) {
HWY_API Vec256<uint16_t> SumsOfAdjQuadAbsDiff(Vec256<uint8_t> a,
Vec256<uint8_t> b) {
static_assert(0 <= kAOffset && kAOffset <= 1,
"kAOffset must be between 0 and 1");
static_assert(0 <= kBOffset && kBOffset <= 3,
Expand Down Expand Up @@ -6424,7 +6424,24 @@ HWY_API VFromD<DI32> SumOfMulQuadAccumulate(
return VFromD<DI32>{_mm256_dpbusd_epi32(sum.raw, a_u.raw, b_i.raw)};
}

#endif
#if HWY_X86_HAVE_AVX10_2_OPS
template <class DI32, HWY_IF_I32_D(DI32), HWY_IF_V_SIZE_D(DI32, 32)>
HWY_API VFromD<DI32> SumOfMulQuadAccumulate(DI32 /*di32*/,
VFromD<Repartition<int8_t, DI32>> a,
VFromD<Repartition<int8_t, DI32>> b,
VFromD<DI32> sum) {
return VFromD<DI32>{_mm256_dpbssd_epi32(sum.raw, a.raw, b.raw)};
}

template <class DU32, HWY_IF_U32_D(DU32), HWY_IF_V_SIZE_D(DU32, 32)>
HWY_API VFromD<DU32> SumOfMulQuadAccumulate(
DU32 /*du32*/, VFromD<Repartition<uint8_t, DU32>> a,
VFromD<Repartition<uint8_t, DU32>> b, VFromD<DU32> sum) {
return VFromD<DU32>{_mm256_dpbuud_epi32(sum.raw, a.raw, b.raw)};
}
#endif // HWY_X86_HAVE_AVX10_2_OPS

#endif // HWY_TARGET <= HWY_AVX3_DL

// ================================================== CONVERT

Expand Down
35 changes: 30 additions & 5 deletions hwy/ops/x86_512-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -6744,22 +6744,30 @@ HWY_API Vec512<uint64_t> CLMulUpper(Vec512<uint64_t> va, Vec512<uint64_t> vb) {
// SumsOfAdjShufQuadAbsDiff)

template <int kAOffset, int kBOffset>
static Vec512<uint16_t> SumsOfAdjQuadAbsDiff(Vec512<uint8_t> a,
Vec512<uint8_t> b) {
HWY_API Vec512<uint16_t> SumsOfAdjQuadAbsDiff(Vec512<uint8_t> a,
Vec512<uint8_t> b) {
static_assert(0 <= kAOffset && kAOffset <= 1,
"kAOffset must be between 0 and 1");
static_assert(0 <= kBOffset && kBOffset <= 3,
"kBOffset must be between 0 and 3");

#if HWY_X86_HAVE_AVX10_2_OPS
// AVX10.2 now has the _mm512_mpsadbw_epu8 intrinsic available
return Vec512<uint16_t>{_mm512_mpsadbw_epu8(
a.raw, b.raw,
(kAOffset << 5) | (kBOffset << 3) | (kAOffset << 2) | kBOffset)};
#else
const DFromV<decltype(a)> d;
const RepartitionToWideX2<decltype(d)> du32;

// While AVX3 does not have a _mm512_mpsadbw_epu8 intrinsic, the
// SumsOfAdjQuadAbsDiff operation is implementable for 512-bit vectors on
// AVX3 using SumsOfShuffledQuadAbsDiff and U32 Broadcast.
// The _mm512_mpsadbw_epu8 intrinsic is not available prior to AVX10.2.
// The SumsOfAdjQuadAbsDiff operation is implementable for 512-bit vectors on
// pre-AVX10.2 targets that support AVX3 using SumsOfShuffledQuadAbsDiff and
// U32 Broadcast.
return SumsOfShuffledQuadAbsDiff<kAOffset + 2, kAOffset + 1, kAOffset + 1,
kAOffset>(
a, BitCast(d, Broadcast<kBOffset>(BitCast(du32, b))));
#endif
}

#if !HWY_IS_MSAN
Expand Down Expand Up @@ -7636,6 +7644,23 @@ HWY_API VFromD<DI32> SumOfMulQuadAccumulate(
return VFromD<DI32>{_mm512_dpbusd_epi32(sum.raw, a_u.raw, b_i.raw)};
}

#if HWY_X86_HAVE_AVX10_2_OPS
template <class DI32, HWY_IF_I32_D(DI32), HWY_IF_V_SIZE_D(DI32, 64)>
HWY_API VFromD<DI32> SumOfMulQuadAccumulate(DI32 /*di32*/,
VFromD<Repartition<int8_t, DI32>> a,
VFromD<Repartition<int8_t, DI32>> b,
VFromD<DI32> sum) {
return VFromD<DI32>{_mm512_dpbssd_epi32(sum.raw, a.raw, b.raw)};
}

template <class DU32, HWY_IF_U32_D(DU32), HWY_IF_V_SIZE_D(DU32, 64)>
HWY_API VFromD<DU32> SumOfMulQuadAccumulate(
DU32 /*du32*/, VFromD<Repartition<uint8_t, DU32>> a,
VFromD<Repartition<uint8_t, DU32>> b, VFromD<DU32> sum) {
return VFromD<DU32>{_mm512_dpbuud_epi32(sum.raw, a.raw, b.raw)};
}
#endif // HWY_X86_HAVE_AVX10_2_OPS

#endif

// ------------------------------ Reductions
Expand Down
3 changes: 2 additions & 1 deletion hwy/targets.cc
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,8 @@ static int64_t DetectTargets() {
if ((HasCpuFeature("hw.optional.AdvSIMD_HPFPCvt") ||
HasCpuFeature("hw.optional.arm.AdvSIMD_HPFPCvt")) &&
HasCpuFeature("hw.optional.arm.FEAT_DotProd") &&
HasCpuFeature("hw.optional.arm.FEAT_BF16")) {
HasCpuFeature("hw.optional.arm.FEAT_BF16") &&
HasCpuFeature("hw.optional.arm.FEAT_I8MM")) {
bits |= HWY_NEON_BF16;
}
}
Expand Down