这是indexloc提供的服务,不要输入任何密码
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions hwy/bit_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

#include <stddef.h>

#include <atomic>

#include "hwy/base.h"

namespace hwy {
Expand Down
2 changes: 1 addition & 1 deletion hwy/contrib/math/math_tan_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -537,4 +537,4 @@ HWY_AFTER_TEST();
HWY_TEST_MAIN();
#endif // HWY_ONCE

#endif // HWY_ARCH_RVV
#endif // HWY_ARCH_RVV
5 changes: 2 additions & 3 deletions hwy/contrib/random/random-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ namespace hwy {
namespace HWY_NAMESPACE { // required: unique per target
namespace internal {

namespace {
#if HWY_HAVE_FLOAT64
// C++ < 17 does not support hexfloat
#if __cpp_hex_float > 201603L
Expand All @@ -52,7 +51,6 @@ constexpr std::uint64_t kJump[] = {0x180ec6d33cfd0aba, 0xd5a61266f0c9392c,

constexpr std::uint64_t kLongJump[] = {0x76e15d3efefdcbbf, 0xc5004e441c522fb3,
0x77710069854ee241, 0x39109bb02acbe635};
} // namespace

class SplitMix64 {
public:
Expand Down Expand Up @@ -177,6 +175,7 @@ class VectorXoshiro {
#if HWY_HAVE_FLOAT64
using VF64 = Vec<ScalableTag<double>>;
#endif

public:
explicit VectorXoshiro(const std::uint64_t seed,
const std::uint64_t threadNumber = 0)
Expand Down Expand Up @@ -381,4 +380,4 @@ class CachedXoshiro {

HWY_AFTER_NAMESPACE();

#endif // HIGHWAY_HWY_CONTRIB_MATH_MATH_INL_H_
#endif // HIGHWAY_HWY_CONTRIB_MATH_MATH_INL_H_
34 changes: 8 additions & 26 deletions hwy/contrib/sort/algo-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,36 +208,18 @@ enum class Algo {
};

// Returns true if `algo` is one of the VQ enumerators (kVQSort,
// kVQPartialSort, kVQSelect) and false for every other `Algo` value.
//
// Written as a single boolean expression: the earlier `switch` already
// returned on every path, so a second `return` placed after it could never
// execute.
static inline bool IsVQ(Algo algo) {
  return algo == Algo::kVQSort || algo == Algo::kVQPartialSort ||
         algo == Algo::kVQSelect;
}

// Returns true if `algo` is one of the select enumerators (kStdSelect,
// kVQSelect, kHeapSelect) and false for every other `Algo` value.
//
// Written as a single boolean expression: the earlier `switch` already
// returned on every path, so a second `return` placed after it could never
// execute.
static inline bool IsSelect(Algo algo) {
  return algo == Algo::kStdSelect || algo == Algo::kVQSelect ||
         algo == Algo::kHeapSelect;
}

// Returns true if `algo` is one of the partial-sort enumerators
// (kStdPartialSort, kVQPartialSort, kHeapPartialSort) and false for every
// other `Algo` value.
//
// Written as a single boolean expression: the earlier `switch` already
// returned on every path, so a second `return` placed after it could never
// execute.
static inline bool IsPartialSort(Algo algo) {
  return algo == Algo::kStdPartialSort || algo == Algo::kVQPartialSort ||
         algo == Algo::kHeapPartialSort;
}

static inline Algo ReferenceAlgoFor(Algo algo) {
Expand Down Expand Up @@ -452,8 +434,8 @@ InputStats<T> GenerateInput(const Dist dist, T* v, size_t num_lanes) {
}

InputStats<T> input_stats;
for (size_t i = 0; i < num_lanes; ++i) {
input_stats.Notify(v[i]);
for (size_t j = 0; j < num_lanes; ++j) {
input_stats.Notify(v[j]);
}
return input_stats;
}
Expand Down
2 changes: 1 addition & 1 deletion hwy/contrib/sort/bench_sort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ std::vector<size_t> SizesToBenchmark(BenchmarkModes mode) {
HWY_NOINLINE void BenchAllSort() {
// Not interested in benchmark results for these targets. Note that SSE4 is
// numerically less than SSE2, hence it is the lower bound.
if (HWY_SSE4 <= HWY_TARGET && HWY_TARGET <= HWY_SSE2) {
if (HWY_SSE4 <= HWY_TARGET && HWY_TARGET <= HWY_SSE2 && Unpredictable1()) {
return;
}
#if HAVE_INTEL
Expand Down
2 changes: 1 addition & 1 deletion hwy/contrib/sort/print_network.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ static void PrintMergeNetwork(int rows, int cols) {
printf("\n");
}

int main(int argc, char** argv) {
int main(int /*argc*/, char** /*argv*/) {
PrintMergeNetwork(8, 2);
PrintMergeNetwork(8, 4);
PrintMergeNetwork(16, 4);
Expand Down
12 changes: 7 additions & 5 deletions hwy/contrib/sort/sort_unit_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -499,11 +499,13 @@ static HWY_NOINLINE void TestRandomGenerator() {
sum_lo += bits & 0xFFFFFFFF;
sum_hi += bits >> 32;
}
const double expected = 1000 * (1ULL << 31);
HWY_ASSERT(0.9 * expected <= static_cast<double>(sum_lo) &&
static_cast<double>(sum_lo) <= 1.1 * expected);
HWY_ASSERT(0.9 * expected <= static_cast<double>(sum_hi) &&
static_cast<double>(sum_hi) <= 1.1 * expected);
{
const double expected = 1000 * (1ULL << 31);
HWY_ASSERT(0.9 * expected <= static_cast<double>(sum_lo) &&
static_cast<double>(sum_lo) <= 1.1 * expected);
HWY_ASSERT(0.9 * expected <= static_cast<double>(sum_hi) &&
static_cast<double>(sum_hi) <= 1.1 * expected);
}

const size_t lanes_per_block = HWY_MAX(64 / sizeof(TU), N); // power of two

Expand Down
12 changes: 10 additions & 2 deletions hwy/contrib/thread_pool/spin.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,19 +82,27 @@ struct SpinResult {
// `HWY_TARGET` and its runtime dispatch mechanism. Returned by `Type()`, also
// used by callers to set the `disabled` argument for `DetectSpin`.
// Every enumerator that can be compiled out has its own guard and an explicit
// value, so the numeric value of each remaining enumerator is the same
// whichever of HWY_ENABLE_MONITORX / HWY_ENABLE_UMONITOR are set.
enum class SpinType : uint8_t {
#if HWY_ENABLE_MONITORX
  kMonitorX = 1,  // AMD
#endif
#if HWY_ENABLE_UMONITOR
  kUMonitor = 2,  // Intel
#endif
  kPause = 3,
  kSentinel  // for iterating over all enumerators. Must be last.
};

// For printing which is in use.
static inline const char* ToString(SpinType type) {
switch (type) {
#if HWY_ENABLE_MONITORX
case SpinType::kMonitorX:
return "MonitorX_C1";
#endif
#if HWY_ENABLE_UMONITOR
case SpinType::kUMonitor:
return "UMonitor_C0.2";
#endif
case SpinType::kPause:
return "Pause";
case SpinType::kSentinel:
Expand Down
2 changes: 1 addition & 1 deletion hwy/contrib/thread_pool/thread_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -963,7 +963,7 @@ HWY_INLINE void CallWithConfig(const Config& config, Func&& func,
case WaitType::kSpinSeparate:
return CallWithSpin(config.spin_type, func, WaitSpinSeparate(),
std::forward<Args>(args)...);
default:
case WaitType::kSentinel:
HWY_UNREACHABLE;
}
}
Expand Down
2 changes: 1 addition & 1 deletion hwy/detect_targets.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
#define HWY_SVE2 (1LL << 23)
#define HWY_SVE (1LL << 24)
// Bit 25 reserved for NEON
#define HWY_NEON_BF16 (1LL << 26) // fp16/dot/bf16 (e.g. Neoverse V2/N2/N3)
#define HWY_NEON_BF16 (1LL << 26) // fp16/dot/bf16/i8mm (e.g. Neoverse V2/N2)
// Bit 27 reserved for NEON
#define HWY_NEON (1LL << 28) // Implies support for AES
#define HWY_NEON_WITHOUT_AES (1LL << 29)
Expand Down
6 changes: 3 additions & 3 deletions hwy/nanobenchmark.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ HWY_DLLEXPORT size_t Measure(Func func, const uint8_t* arg,

// Calls operator() of the given closure (lambda function).
//
// `f` is received type-erased as `const void*` and must point to an object of
// type `Closure`; it is converted back to `const Closure*` before the call.
// Keeping `Closure` out of the parameter list gives every instantiation the
// same function type.
// NOTE(review): that type is presumably the one `Func` names, which would
// avoid calling through a function pointer of a different type (undefined
// behavior) - confirm against the declaration of `Func`.
template <class Closure>
static FuncOutput CallClosure(const void* f, const FuncInput input) {
  return (*static_cast<const Closure*>(f))(input);
}

// Same as Measure, except "closure" is typically a lambda function of
Expand All @@ -143,7 +143,7 @@ static inline size_t MeasureClosure(const Closure& closure,
const FuncInput* inputs,
const size_t num_inputs, Result* results,
const Params& p = Params()) {
return Measure(reinterpret_cast<Func>(&CallClosure<Closure>),
return Measure(static_cast<Func>(&CallClosure<Closure>),
reinterpret_cast<const uint8_t*>(&closure), inputs, num_inputs,
results, p);
}
Expand Down
11 changes: 7 additions & 4 deletions hwy/ops/set_macros-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,8 @@
#define HWY_TARGET_STR_FP16 "+fp16"
#endif

#define HWY_TARGET_STR_I8MM "+i8mm"

#if HWY_TARGET == HWY_NEON_WITHOUT_AES
#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1400
// Prevents inadvertent use of SVE by GCC 13.4 and earlier, see #2689.
Expand All @@ -562,7 +564,8 @@
#elif HWY_TARGET == HWY_NEON
#define HWY_TARGET_STR HWY_TARGET_STR_NEON
#elif HWY_TARGET == HWY_NEON_BF16
#define HWY_TARGET_STR HWY_TARGET_STR_FP16 "+bf16+dotprod" HWY_TARGET_STR_NEON
#define HWY_TARGET_STR \
HWY_TARGET_STR_FP16 HWY_TARGET_STR_I8MM "+bf16+dotprod" HWY_TARGET_STR_NEON
#else
#error "Logic error, missing case"
#endif // HWY_TARGET
Expand Down Expand Up @@ -620,12 +623,12 @@
// Static dispatch with -march=armv8-a+sve2+aes, or no baseline, hence dynamic
// dispatch, which checks for AES support at runtime.
#if defined(__ARM_FEATURE_SVE2_AES) || (HWY_BASELINE_SVE2 == 0)
#define HWY_TARGET_STR "+sve2+sve2-aes,+sve"
#define HWY_TARGET_STR "+sve2+sve2-aes,+sve" HWY_TARGET_STR_I8MM
#else // SVE2 without AES
#define HWY_TARGET_STR "+sve2,+sve"
#define HWY_TARGET_STR "+sve2,+sve" HWY_TARGET_STR_I8MM
#endif
#else // not SVE2 target
#define HWY_TARGET_STR "+sve"
#define HWY_TARGET_STR "+sve" HWY_TARGET_STR_I8MM
#endif
#else // !HWY_HAVE_RUNTIME_DISPATCH
// HWY_TARGET_STR remains undefined
Expand Down
17 changes: 14 additions & 3 deletions hwy/targets.cc
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,10 @@ static int64_t DetectTargets() {
#elif HWY_ARCH_ARM && HWY_HAVE_RUNTIME_DISPATCH
namespace arm {

#ifndef HWCAP2_I8MM
#define HWCAP2_I8MM (1 << 13)
#endif

#if HWY_ARCH_ARM_A64 && !HWY_OS_APPLE && \
(HWY_COMPILER_GCC || HWY_COMPILER_CLANG) && \
((HWY_TARGETS & HWY_ALL_SVE) != 0)
Expand Down Expand Up @@ -502,8 +506,10 @@ static int64_t DetectTargets() {

#if defined(HWCAP_ASIMDHP) && defined(HWCAP_ASIMDDP) && defined(HWCAP2_BF16)
const CapBits hw2 = getauxval(AT_HWCAP2);
const int64_t kGroupF16Dot = HWCAP_ASIMDHP | HWCAP_ASIMDDP;
if ((hw & kGroupF16Dot) == kGroupF16Dot && (hw2 & HWCAP2_BF16)) {
constexpr CapBits kGroupF16Dot = HWCAP_ASIMDHP | HWCAP_ASIMDDP;
constexpr CapBits kGroupBF16 = HWCAP2_BF16 | HWCAP2_I8MM;
if ((hw & kGroupF16Dot) == kGroupF16Dot &&
(hw2 & kGroupBF16) == kGroupBF16) {
bits |= HWY_NEON_BF16;
}
#endif // HWCAP_ASIMDHP && HWCAP_ASIMDDP && HWCAP2_BF16
Expand All @@ -522,8 +528,13 @@ static int64_t DetectTargets() {
#ifndef HWCAP2_SVEAES
#define HWCAP2_SVEAES (1 << 2)
#endif
#ifndef HWCAP2_SVEI8MM
#define HWCAP2_SVEI8MM (1 << 9)
#endif
constexpr CapBits kGroupSVE2 =
HWCAP2_SVE2 | HWCAP2_SVEAES | HWCAP2_SVEI8MM | HWCAP2_I8MM;
const CapBits hw2 = getauxval(AT_HWCAP2);
if ((hw2 & HWCAP2_SVE2) && (hw2 & HWCAP2_SVEAES)) {
if ((hw2 & kGroupSVE2) == kGroupSVE2) {
bits |= HWY_SVE2;
}

Expand Down