diff --git a/hwy/bit_set.h b/hwy/bit_set.h
index f8f921becf..b3b18e565f 100644
--- a/hwy/bit_set.h
+++ b/hwy/bit_set.h
@@ -16,7 +16,7 @@
 #ifndef HIGHWAY_HWY_BIT_SET_H_
 #define HIGHWAY_HWY_BIT_SET_H_
 
-// BitSet with fast Foreach for up to 64 and 4096 members.
+// Various BitSet for 64, up to 4096, or any number of bits.
 
 #include <stddef.h>
 
@@ -24,9 +24,11 @@
 
 namespace hwy {
 
-// 64-bit specialization of std::bitset, which lacks Foreach.
+// 64-bit specialization of `std::bitset`, which lacks `Foreach`.
 class BitSet64 {
  public:
+  constexpr size_t MaxSize() const { return 64; }
+
   // No harm if `i` is already set.
   void Set(size_t i) {
     HWY_DASSERT(i < 64);
@@ -48,15 +50,24 @@ class BitSet64 {
     return (bits_ & (1ULL << i)) != 0;
   }
 
-  // Returns true if any Get(i) would return true for i in [0, 64).
+  // Returns true if Get(i) would return true for any i in [0, 64).
   bool Any() const { return bits_ != 0; }
 
-  // Returns lowest i such that Get(i). Caller must ensure Any() beforehand!
+  // Returns true if Get(i) would return true for all i in [0, 64).
+  bool All() const { return bits_ == ~uint64_t{0}; }
+
+  // Returns lowest i such that `Get(i)`. Caller must first ensure `Any()`!
   size_t First() const {
     HWY_DASSERT(Any());
     return Num0BitsBelowLS1Bit_Nonzero64(bits_);
   }
 
+  // Returns lowest i such that `!Get(i)`. Caller must first ensure `!All()`!
+  size_t First0() const {
+    HWY_DASSERT(!All());
+    return Num0BitsBelowLS1Bit_Nonzero64(~bits_);
+  }
+
   // Returns uint64_t(Get(i)) << i for i in [0, 64).
   uint64_t Get64() const { return bits_; }
 
@@ -78,10 +89,226 @@ class BitSet64 {
   uint64_t bits_ = 0;
 };
 
-// Two-level bitset for up to kMaxSize <= 4096 values.
+// Any number of bits, flat array.
+template <size_t kMaxSize>
+class BitSet {
+  static_assert(kMaxSize != 0, "BitSet requires non-zero size");
+
+ public:
+  constexpr size_t MaxSize() const { return kMaxSize; }
+
+  // No harm if `i` is already set.
+  void Set(size_t i) {
+    HWY_DASSERT(i < kMaxSize);
+    const size_t idx = i / 64;
+    const size_t mod = i % 64;
+    bits_[idx].Set(mod);
+  }
+
+  void Clear(size_t i) {
+    HWY_DASSERT(i < kMaxSize);
+    const size_t idx = i / 64;
+    const size_t mod = i % 64;
+    bits_[idx].Clear(mod);
+    HWY_DASSERT(!Get(i));
+  }
+
+  bool Get(size_t i) const {
+    HWY_DASSERT(i < kMaxSize);
+    const size_t idx = i / 64;
+    const size_t mod = i % 64;
+    return bits_[idx].Get(mod);
+  }
+
+  // Returns true if Get(i) would return true for any i in [0, kMaxSize).
+  bool Any() const {
+    for (const BitSet64& bits : bits_) {
+      if (bits.Any()) return true;
+    }
+    return false;
+  }
+
+  // Returns true if Get(i) would return true for all i in [0, kMaxSize).
+  bool All() const {
+    for (size_t idx = 0; idx < kNum64 - 1; ++idx) {
+      if (!bits_[idx].All()) return false;
+    }
+
+    constexpr size_t kRemainder = kMaxSize % 64;
+    if (kRemainder == 0) {
+      return bits_[kNum64 - 1].All();
+    }
+    return bits_[kNum64 - 1].Count() == kRemainder;
+  }
+
+  // Returns lowest i such that `Get(i)`. Caller must first ensure `Any()`!
+  size_t First() const {
+    HWY_DASSERT(Any());
+    for (size_t idx = 0;; ++idx) {
+      HWY_DASSERT(idx < kNum64);
+      if (bits_[idx].Any()) return idx * 64 + bits_[idx].First();
+    }
+  }
+
+  // Returns lowest i such that `!Get(i)`. Caller must first ensure `!All()`!
+  size_t First0() const {
+    HWY_DASSERT(!All());
+    for (size_t idx = 0;; ++idx) {
+      HWY_DASSERT(idx < kNum64);
+      if (!bits_[idx].All()) {
+        const size_t first0 = idx * 64 + bits_[idx].First0();
+        HWY_DASSERT(first0 < kMaxSize);
+        return first0;
+      }
+    }
+  }
+
+  // Calls `func(i)` for each `i` in the set. It is safe for `func` to modify
+  // the set, but the current Foreach call is only affected if changing one of
+  // the not yet visited BitSet64.
+  template <class Func>
+  void Foreach(const Func& func) const {
+    for (size_t idx = 0; idx < kNum64; ++idx) {
+      bits_[idx].Foreach([idx, &func](size_t mod) { func(idx * 64 + mod); });
+    }
+  }
+
+  size_t Count() const {
+    size_t total = 0;
+    for (const BitSet64& bits : bits_) {
+      total += bits.Count();
+    }
+    return total;
+  }
+
+ private:
+  static constexpr size_t kNum64 = DivCeil(kMaxSize, size_t{64});
+  BitSet64 bits_[kNum64];
+};
+
+// Any number of bits, flat array, atomic updates to the u64.
+template <size_t kMaxSize>
+class AtomicBitSet {
+  static_assert(kMaxSize != 0, "AtomicBitSet requires non-zero size");
+
+  // Bits may signal something to other threads, hence relaxed is insufficient.
+  // Acq/Rel ensures a happens-before relationship.
+  static constexpr auto kAcq = std::memory_order_acquire;
+  static constexpr auto kRel = std::memory_order_release;
+
+ public:
+  constexpr size_t MaxSize() const { return kMaxSize; }
+
+  // No harm if `i` is already set.
+  void Set(size_t i) {
+    HWY_DASSERT(i < kMaxSize);
+    const size_t idx = i / 64;
+    const size_t mod = i % 64;
+    bits_[idx].fetch_or(1ULL << mod, kRel);
+  }
+
+  void Clear(size_t i) {
+    HWY_DASSERT(i < kMaxSize);
+    const size_t idx = i / 64;
+    const size_t mod = i % 64;
+    bits_[idx].fetch_and(~(1ULL << mod), kRel);
+    HWY_DASSERT(!Get(i));
+  }
+
+  bool Get(size_t i) const {
+    HWY_DASSERT(i < kMaxSize);
+    const size_t idx = i / 64;
+    const size_t mod = i % 64;
+    return (bits_[idx].load(kAcq) & (1ULL << mod)) != 0;
+  }
+
+  // Returns true if Get(i) would return true for any i in [0, kMaxSize).
+  bool Any() const {
+    for (const std::atomic<uint64_t>& bits : bits_) {
+      if (bits.load(kAcq)) return true;
+    }
+    return false;
+  }
+
+  // Returns true if Get(i) would return true for all i in [0, kMaxSize).
+  bool All() const {
+    for (size_t idx = 0; idx < kNum64 - 1; ++idx) {
+      if (bits_[idx].load(kAcq) != ~uint64_t{0}) return false;
+    }
+
+    constexpr size_t kRemainder = kMaxSize % 64;
+    const uint64_t last_bits = bits_[kNum64 - 1].load(kAcq);
+    if (kRemainder == 0) {
+      return last_bits == ~uint64_t{0};
+    }
+    return PopCount(last_bits) == kRemainder;
+  }
+
+  // Returns lowest i such that `Get(i)`. Caller must first ensure `Any()`!
+  size_t First() const {
+    HWY_DASSERT(Any());
+    for (size_t idx = 0;; ++idx) {
+      HWY_DASSERT(idx < kNum64);
+      const uint64_t bits = bits_[idx].load(kAcq);
+      if (bits != 0) {
+        return idx * 64 + Num0BitsBelowLS1Bit_Nonzero64(bits);
+      }
+    }
+  }
+
+  // Returns lowest i such that `!Get(i)`. Caller must first ensure `!All()`!
+  size_t First0() const {
+    HWY_DASSERT(!All());
+    for (size_t idx = 0;; ++idx) {
+      HWY_DASSERT(idx < kNum64);
+      const uint64_t inv_bits = ~bits_[idx].load(kAcq);
+      if (inv_bits != 0) {
+        const size_t first0 =
+            idx * 64 + Num0BitsBelowLS1Bit_Nonzero64(inv_bits);
+        HWY_DASSERT(first0 < kMaxSize);
+        return first0;
+      }
+    }
+  }
+
+  // Calls `func(i)` for each `i` in the set. It is safe for `func` to modify
+  // the set, but the current Foreach call is only affected if changing one of
+  // the not yet visited uint64_t.
+  template <class Func>
+  void Foreach(const Func& func) const {
+    for (size_t idx = 0; idx < kNum64; ++idx) {
+      uint64_t remaining_bits = bits_[idx].load(kAcq);
+      while (remaining_bits != 0) {
+        const size_t i = Num0BitsBelowLS1Bit_Nonzero64(remaining_bits);
+        remaining_bits &= remaining_bits - 1;  // clear LSB
+        func(idx * 64 + i);
+      }
+    }
+  }
+
+  size_t Count() const {
+    size_t total = 0;
+    for (const std::atomic<uint64_t>& bits : bits_) {
+      total += PopCount(bits.load(kAcq));
+    }
+    return total;
+  }
+
+ private:
+  static constexpr size_t kNum64 = DivCeil(kMaxSize, size_t{64});
+  std::atomic<uint64_t> bits_[kNum64] = {};
+};
+
+// Two-level bitset for up to `kMaxSize` <= 4096 values. The iterators
+// (`Any/First/Foreach/Count`) are more efficient than `BitSet` for sparse sets.
+// This comes at the cost of slightly slower mutators (`Set/Clear`).
 template <size_t kMaxSize = 4096>
 class BitSet4096 {
+  static_assert(kMaxSize != 0, "BitSet4096 requires non-zero size");
+
  public:
+  constexpr size_t MaxSize() const { return kMaxSize; }
+
   // No harm if `i` is already set.
   void Set(size_t i) {
     HWY_DASSERT(i < kMaxSize);
@@ -117,16 +344,38 @@ class BitSet4096 {
     return bits_[idx].Get(mod);
   }
 
-  // Returns true if any Get(i) would return true for i in [0, 64).
+  // Returns true if `Get(i)` would return true for any i in [0, kMaxSize).
   bool Any() const { return nonzero_.Any(); }
 
-  // Returns lowest i such that Get(i). Caller must ensure Any() beforehand!
+  // Returns true if `Get(i)` would return true for all i in [0, kMaxSize).
+  bool All() const {
+    // Do not check `nonzero_.All()` - that only works if `kMaxSize` is 4096.
+    if (nonzero_.Count() != kNum64) return false;
+    return Count() == kMaxSize;
+  }
+
+  // Returns lowest i such that `Get(i)`. Caller must first ensure `Any()`!
   size_t First() const {
     HWY_DASSERT(Any());
     const size_t idx = nonzero_.First();
     return idx * 64 + bits_[idx].First();
   }
 
+  // Returns lowest i such that `!Get(i)`. Caller must first ensure `!All()`!
+  size_t First0() const {
+    HWY_DASSERT(!All());
+    // It is likely not worthwhile to have a separate `BitSet64` for `not_all_`,
+    // hence iterate over all u64.
+    for (size_t idx = 0;; ++idx) {
+      HWY_DASSERT(idx < kNum64);
+      if (!bits_[idx].All()) {
+        const size_t first0 = idx * 64 + bits_[idx].First0();
+        HWY_DASSERT(first0 < kMaxSize);
+        return first0;
+      }
+    }
+  }
+
   // Returns uint64_t(Get(i)) << i for i in [0, 64).
   uint64_t Get64() const { return bits_[0].Get64(); }
 
@@ -149,8 +398,9 @@ class BitSet4096 {
 
  private:
   static_assert(kMaxSize <= 64 * 64, "One BitSet64 insufficient");
+  static constexpr size_t kNum64 = DivCeil(kMaxSize, size_t{64});
   BitSet64 nonzero_;
-  BitSet64 bits_[kMaxSize / 64];
+  BitSet64 bits_[kNum64];
 };
 
 }  // namespace hwy
diff --git a/hwy/bit_set_test.cc b/hwy/bit_set_test.cc
index daaa6ec5d5..4091da54ed 100644
--- a/hwy/bit_set_test.cc
+++ b/hwy/bit_set_test.cc
@@ -32,50 +32,79 @@
 namespace hwy {
 namespace {
 
-// Template arg for kMin avoids compiler behavior mismatch for lambda capture.
-template <class Set, size_t kMax, size_t kMin>
-void TestSet() {
+template <class Set>
+void SmokeTest() {
+  constexpr size_t kMax = Set().MaxSize() - 1;
+
   Set set;  // Defaults to empty.
   HWY_ASSERT(!set.Any());
-  HWY_ASSERT(set.Count() == 0);
-  set.Foreach(
-      [](size_t i) { HWY_ABORT("Set should be empty but got %zu\n", i); });
+  HWY_ASSERT(!set.All());
   HWY_ASSERT(!set.Get(0));
   HWY_ASSERT(!set.Get(kMax));
+  HWY_ASSERT(set.First0() == 0);
+  set.Foreach(
+      [](size_t i) { HWY_ABORT("Set should be empty but got %zu\n", i); });
+  HWY_ASSERT(set.Count() == 0);
 
   // After setting, we can retrieve it.
   set.Set(kMax);
   HWY_ASSERT(set.Get(kMax));
   HWY_ASSERT(set.Any());
+  HWY_ASSERT(!set.All());
   HWY_ASSERT(set.First() == kMax);
-  HWY_ASSERT(set.Count() == 1);
+  HWY_ASSERT(set.First0() == 0);
   set.Foreach([](size_t i) { HWY_ASSERT(i == kMax); });
-
-  // SetNonzeroBitsFrom64 does not clear old bits.
-  set.SetNonzeroBitsFrom64(1ull << kMin);
-  HWY_ASSERT(set.Any());
-  HWY_ASSERT(set.First() == kMin);
-  HWY_ASSERT(set.Get(kMin));
-  HWY_ASSERT(set.Get(kMax));
-  HWY_ASSERT(set.Count() == 2);
-  set.Foreach([](size_t i) { HWY_ASSERT(i == kMin || i == kMax); });
+  HWY_ASSERT(set.Count() == 1);
 
   // After clearing, it is empty again.
-  set.Clear(kMin);
   set.Clear(kMax);
+  set.Clear(0);  // was not set
+  HWY_ASSERT(!set.Get(0));
+  HWY_ASSERT(!set.Get(kMax));
   HWY_ASSERT(!set.Any());
-  HWY_ASSERT(set.Count() == 0);
+  HWY_ASSERT(!set.All());
+  HWY_ASSERT(set.First0() == 0);
   set.Foreach(
       [](size_t i) { HWY_ABORT("Set should be empty but got %zu\n", i); });
-  HWY_ASSERT(!set.Get(0));
-  HWY_ASSERT(!set.Get(kMax));
+  HWY_ASSERT(set.Count() == 0);
+}
+
+TEST(BitSetTest, SmokeTestSet64) { SmokeTest<BitSet64>(); }
+TEST(BitSetTest, SmokeTestSet) { SmokeTest<BitSet<...>>(); }
+TEST(BitSetTest, SmokeTestAtomicSet) { SmokeTest<AtomicBitSet<...>>(); }
+TEST(BitSetTest, SmokeTestSet4096) { SmokeTest<BitSet4096<>>(); }
+
+template <class Set>
+void TestSetNonzeroBitsFrom64() {
+  constexpr size_t kMin = 0;
+  Set set;
+  set.SetNonzeroBitsFrom64(1ull << kMin);
+  HWY_ASSERT(set.Any());
+  HWY_ASSERT(!set.All());
+  HWY_ASSERT(set.Get(kMin));
+  HWY_ASSERT(set.First() == kMin);
+  HWY_ASSERT(set.First0() == kMin + 1);
+  set.Foreach([](size_t i) { HWY_ASSERT(i == kMin); });
+  HWY_ASSERT(set.Count() == 1);
+
+  set.SetNonzeroBitsFrom64(0x70ULL);
+  HWY_ASSERT(set.Get(kMin) && set.Get(4) && set.Get(5) && set.Get(6));
+  HWY_ASSERT(set.Any());
+  HWY_ASSERT(!set.All());
+  HWY_ASSERT(set.First() == kMin);  // does not clear existing bits
+  HWY_ASSERT(set.First0() == kMin + 1);
+  set.Foreach([](size_t i) { HWY_ASSERT(i == kMin || (4 <= i && i <= 6)); });
+  HWY_ASSERT(set.Count() == 4);
 }
 
-TEST(BitSetTest, TestSet64) { TestSet<BitSet64, 63>(); }
-TEST(BitSetTest, TestSet4096) { TestSet<BitSet4096<4096>, 4095>(); }
+TEST(BitSetTest, TestSetNonzeroBits64) { TestSetNonzeroBitsFrom64<BitSet64>(); }
+TEST(BitSetTest, TestSetNonzeroBits4096) {
+  TestSetNonzeroBitsFrom64<BitSet4096<>>();
+}
 
-// Supports membership and random choice, for testing BitSet4096.
+// Reference implementation using a map (for the sparse `BitSet4096`) and a
+// vector for random choice of elements.
 class SlowSet {
  public:
   // Inserting multiple times is a no-op.
@@ -136,6 +165,7 @@ class SlowSet {
   template <class Set>
   void CheckSame(const Set& set) {
     HWY_ASSERT(set.Any() == (set.Count() != 0));
+    HWY_ASSERT(set.All() == (set.Count() == set.MaxSize()));
     HWY_ASSERT(Count() == set.Count());
     // Everything set has, we also have.
     set.Foreach([this](size_t i) { HWY_ASSERT(Get(i)); });
@@ -146,6 +176,12 @@ class SlowSet {
     if (set.Any()) {
       HWY_ASSERT(set.First() == idx_for_i_.begin()->first);
     }
+    if (!set.All()) {
+      const size_t idx0 = set.First0();
+      HWY_ASSERT(idx0 < set.MaxSize());
+      HWY_ASSERT(!set.Get(idx0));
+      HWY_ASSERT(!Get(idx0));
+    }
   }
 
  private:
@@ -153,16 +189,17 @@
   std::map<size_t, size_t> idx_for_i_;
 };
 
-void TestSetRandom(uint64_t grow_prob) {
-  const uint32_t mod = 4096;
+template <class Set>
+void TestSetWithGrowProb(uint64_t grow_prob) {
+  constexpr uint32_t max_size = static_cast<uint32_t>(Set().MaxSize());
   RandomState rng;
 
   // Multiple independent random tests:
   for (size_t rep = 0; rep < AdjustedReps(100); ++rep) {
-    BitSet4096<> set;
+    Set set;
     SlowSet slow_set;
 
    // Mutate sets via random walk and ensure they are the same afterwards.
-    for (size_t iter = 0; iter < 200; ++iter) {
+    for (size_t iter = 0; iter < AdjustedReps(1000); ++iter) {
      const uint64_t bits = (Random64(&rng) >> 10) & 0x3FF;
      if (bits > 980 && slow_set.Count() != 0) {
        // Small chance of reinsertion: already present, unchanged after.
@@ -175,7 +212,7 @@ void TestSetRandom(uint64_t grow_prob) {
        HWY_ASSERT(count == set.Count());
      } else if (bits < grow_prob) {
        // Set random value; no harm if already set.
-        const size_t i = static_cast<size_t>(Random32(&rng) % mod);
+        const size_t i = static_cast<size_t>(Random32(&rng) % max_size);
        slow_set.Set(i);
        set.Set(i);
        HWY_ASSERT(set.Get(i));
@@ -194,9 +231,23 @@ void TestSetRandom(uint64_t grow_prob) {
   }
 }
 
-// Lower probability of growth so that the set is often nearly empty.
-TEST(BitSetTest, TestSetRandomShrink) { TestSetRandom(400); }
-TEST(BitSetTest, TestSetRandomGrow) { TestSetRandom(600); }
+template <class Set>
+void TestSetRandom() {
+  // Lower probability of growth so that the set is often nearly empty.
+  TestSetWithGrowProb<Set>(400);
+
+  TestSetWithGrowProb<Set>(600);
+}
+
+TEST(BitSetTest, TestSet64) { TestSetRandom<BitSet64>(); }
+TEST(BitSetTest, TestSet41) { TestSetRandom<BitSet<41>>(); }
+TEST(BitSetTest, TestSet) { TestSetRandom<BitSet<...>>(); }
+// One partial u64
+TEST(BitSetTest, TestAtomicSet32) { TestSetRandom<AtomicBitSet<32>>(); }
+// 3 whole u64
+TEST(BitSetTest, TestAtomicSet192) { TestSetRandom<AtomicBitSet<192>>(); }
+TEST(BitSetTest, TestSet3000) { TestSetRandom<BitSet<3000>>(); }
+TEST(BitSetTest, TestSet4096) { TestSetRandom<BitSet4096<>>(); }
 
 }  // namespace
 }  // namespace hwy
diff --git a/hwy/contrib/thread_pool/topology.cc b/hwy/contrib/thread_pool/topology.cc
index d524120344..8da62a09d6 100644
--- a/hwy/contrib/thread_pool/topology.cc
+++ b/hwy/contrib/thread_pool/topology.cc
@@ -114,7 +114,9 @@ bool ForEachSLPI(LOGICAL_PROCESSOR_RELATIONSHIP rel, Func&& func) {
   }
   HWY_ASSERT(GetLastError() == ERROR_INSUFFICIENT_BUFFER);
   // Note: `buf_bytes` may be less than `sizeof(SLPI)`, which has padding.
-  uint8_t* buf = static_cast<uint8_t*>(malloc(buf_bytes));
+  // `calloc` zero-initializes the `Reserved` field, part of which has been
+  // repurposed into `GroupCount` in SDK 10.0.22000.0, or possibly earlier.
+  uint8_t* buf = static_cast<uint8_t*>(calloc(1, buf_bytes));
   HWY_ASSERT(buf);
 
   // Fill the buffer.
@@ -658,6 +660,27 @@ void SetClusterCacheSizes(std::vector<Topology::Package>& packages) {
 
 #elif HWY_OS_WIN
 
+// See #2734. GroupCount was added around Windows 10, but SDK docs do not
+// mention the actual version required. It is known to be absent in 8.1 and
+// MinGW 5.0.1, and present in the 10.0.22000.0 SDK. However, the OS must also
+// know about the field. Thus we zero-initialize the reserved field, assume it
+// remains zero, and return 1 if zero (old-style single GroupMask), otherwise
+// the number of groups. There are two such structures, but note that
+// `PROCESSOR_RELATIONSHIP` already had this field.
+static size_t GroupCount(const CACHE_RELATIONSHIP& cr) {
+  // Added as the last u16 in the reserved area before GroupMask. We only read
+  // one byte because 256*64 processor bits are plenty.
+  const uint8_t* pcount =
+      reinterpret_cast<const uint8_t*>(&cr.GroupMask) - sizeof(uint16_t);
+  return HWY_MAX(pcount[HWY_IS_BIG_ENDIAN], 1);
+}
+
+static size_t GroupCount(const NUMA_NODE_RELATIONSHIP& nn) {
+  const uint8_t* pcount =
+      reinterpret_cast<const uint8_t*>(&nn.GroupMask) - sizeof(uint16_t);
+  return HWY_MAX(pcount[HWY_IS_BIG_ENDIAN], 1);
+}
+
 // Also sets LP.core and LP.smt.
 size_t MaxLpsPerCore(std::vector<Topology::LP>& lps) {
   size_t max_lps_per_core = 0;
@@ -711,7 +734,7 @@ size_t MaxCoresPerCluster(const size_t max_lps_per_core,
     const CACHE_RELATIONSHIP& cr = info.Cache;
     if (cr.Type != CacheUnified && cr.Type != CacheData) return;
     if (cr.Level != 3) return;
-    foreach_cluster(cr.GroupCount, cr.GroupMasks);
+    foreach_cluster(GroupCount(cr), cr.GroupMasks);
   };
 
   if (!ForEachSLPI(RelationProcessorDie, foreach_die)) {
@@ -768,7 +791,7 @@ void SetNodes(std::vector<Topology::LP>& lps) {
     if (info.Relationship != RelationNumaNode) return;
     const NUMA_NODE_RELATIONSHIP& nn = info.NumaNode;
     // This field was previously reserved/zero. There is at least one group.
-    const size_t num_groups = HWY_MAX(1, nn.GroupCount);
+    const size_t num_groups = HWY_MAX(1, GroupCount(nn));
     const uint8_t node = static_cast<uint8_t>(nn.NodeNumber);
     ForeachBit(num_groups, nn.GroupMasks, lps, __LINE__,
                [node](size_t lp, std::vector<Topology::LP>& lps) {
@@ -1027,7 +1050,7 @@ bool InitCachesWin(Caches& caches) {
                 : cr.Associativity;
 
     // How many cores share this cache?
-    size_t shared_with = NumBits(cr.GroupCount, cr.GroupMasks);
+    size_t shared_with = NumBits(GroupCount(cr), cr.GroupMasks);
     // Divide out hyperthreads. This core may have fewer than
    // `max_lps_per_core`, hence round up.
    shared_with = DivCeil(shared_with, max_lps_per_core);
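
The following usage sketch is illustrative only and not part of the patch; it assumes the `hwy/bit_set.h` from this change is on the include path, and the sizes and printed values are arbitrary examples:

    #include <cstdio>

    #include "hwy/bit_set.h"

    int main() {
      hwy::BitSet<100> set;  // flat array of two uint64_t
      set.Set(3);
      set.Set(99);
      // Prints count=2 first=3 first0=0.
      printf("count=%zu first=%zu first0=%zu\n", set.Count(), set.First(),
             set.First0());
      set.Foreach([](size_t i) { printf("member %zu\n", i); });  // 3, then 99

      hwy::BitSet4096<> sparse;  // two-level: fast Foreach/First when sparse
      sparse.Set(4000);
      printf("first=%zu\n", sparse.First());  // 4000
      return 0;
    }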
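The acquire/release choice in `AtomicBitSet` matters when a bit is used to publish data written by another thread. A minimal sketch of that pattern, also not part of the patch (the slot index 5 and value 42 are arbitrary):

    #include <cstdio>
    #include <thread>

    #include "hwy/bit_set.h"

    hwy::AtomicBitSet<64> ready;  // one bit per result slot
    int results[64];

    int main() {
      std::thread producer([] {
        results[5] = 42;  // plain store..
        ready.Set(5);     // ..published by the release in fetch_or
      });
      // The acquire load in Get() pairs with that release: once the bit is
      // observed, the write to results[5] is guaranteed to be visible.
      while (!ready.Get(5)) {
      }
      printf("%d\n", results[5]);  // prints 42
      producer.join();
      return 0;
    }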
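To see why `pcount[HWY_IS_BIG_ENDIAN]` reads the low byte of that trailing u16, one can mirror the newer-SDK layout in a stand-in struct. The structs below are hypothetical stand-ins, not the real Windows definitions; they only reproduce the documented property that `GroupCount` occupies the last two bytes of the formerly reserved area, immediately before `GroupMask`:

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    struct FakeGroupAffinity {  // stand-in for GROUP_AFFINITY on x64
      uint64_t Mask;
      uint16_t Group;
      uint16_t Reserved[3];
    };
    struct FakeCacheRelationship {  // stand-in for new-SDK CACHE_RELATIONSHIP
      uint8_t Level, Associativity;
      uint16_t LineSize;
      uint32_t CacheSize;
      uint32_t Type;
      uint8_t Reserved[18];
      uint16_t GroupCount;  // last two bytes of the formerly reserved area
      FakeGroupAffinity GroupMask;
    };

    // Same read as the patch: step back one u16 from GroupMask and take its
    // low byte; which of the two bytes is "low" depends on endianness.
    size_t ReadGroupCount(const FakeCacheRelationship& cr, int is_big_endian) {
      const uint8_t* pcount =
          reinterpret_cast<const uint8_t*>(&cr.GroupMask) - sizeof(uint16_t);
      const size_t count = pcount[is_big_endian];
      return count == 0 ? 1 : count;  // zero (old OS/SDK): single GroupMask
    }

    int main() {
      FakeCacheRelationship cr = {};
      printf("%zu\n", ReadGroupCount(cr, 0));  // 1: reserved bytes remain zero
      cr.GroupCount = 3;
      printf("%zu\n", ReadGroupCount(cr, 0));  // 3 on a little-endian build
      return 0;
    }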