diff --git a/hwy/contrib/thread_pool/thread_pool.h b/hwy/contrib/thread_pool/thread_pool.h
index bc483bd479..5d46622dab 100644
--- a/hwy/contrib/thread_pool/thread_pool.h
+++ b/hwy/contrib/thread_pool/thread_pool.h
@@ -77,6 +77,8 @@ static inline void SetThreadName(const char* format, int thread) {
 // Whether workers should block or spin.
 enum class PoolWaitMode : uint8_t { kBlock = 1, kSpin };
 
+enum class Exit : uint32_t { kNone, kLoop, kThread };
+
 // Upper bound on non-empty `ThreadPool` (single-worker pools do not count).
 // Turin has 16 clusters. Add one for the across-cluster pool.
 HWY_INLINE_VAR constexpr size_t kMaxClusters = 32 + 1;
@@ -89,29 +91,33 @@ HWY_INLINE_VAR constexpr size_t kAllClusters = kMaxClusters - 1;
 class PoolWorkerMapping {
  public:
   // Backward-compatible mode: returns local worker index.
-  PoolWorkerMapping() : cluster_idx_(0), workers_per_cluster_(0) {}
-  PoolWorkerMapping(size_t cluster_idx, size_t workers_per_cluster)
-      : cluster_idx_(cluster_idx), workers_per_cluster_(workers_per_cluster) {
+  PoolWorkerMapping() : cluster_idx_(0), max_cluster_workers_(0) {}
+  PoolWorkerMapping(size_t cluster_idx, size_t max_cluster_workers)
+      : cluster_idx_(cluster_idx), max_cluster_workers_(max_cluster_workers) {
     HWY_DASSERT(cluster_idx <= kAllClusters);
     // Only use this ctor for the new global worker index mode. If this were
     // zero, we would still return local indices.
-    HWY_DASSERT(workers_per_cluster != 0);
+    HWY_DASSERT(max_cluster_workers != 0);
   }
 
   size_t ClusterIdx() const { return cluster_idx_; }
+  size_t MaxClusterWorkers() const { return max_cluster_workers_; }
 
   // Returns global_idx, or unchanged local worker_idx if default-constructed.
   size_t operator()(size_t worker_idx) const {
     if (cluster_idx_ == kAllClusters) {
-      // Main thread, plus the first core of each subsequent cluster.
-      return worker_idx * workers_per_cluster_ + 0;
+      const size_t cluster_idx = worker_idx;
+      HWY_DASSERT(cluster_idx < kAllClusters);
+      // First index within the N-th cluster. The main thread is the first.
+      return cluster_idx * max_cluster_workers_;
     }
-    return cluster_idx_ * workers_per_cluster_ + worker_idx;
+    HWY_DASSERT(max_cluster_workers_ == 0 || worker_idx < max_cluster_workers_);
+    return cluster_idx_ * max_cluster_workers_ + worker_idx;
   }
 
  private:
   size_t cluster_idx_;
-  size_t workers_per_cluster_;
+  size_t max_cluster_workers_;
 };
 
 namespace pool {
@@ -414,10 +420,10 @@ class Stats {
         Avg(sum_tasks_stolen_, num_run_dynamic_ * num_workers);
 
     printf(
-        "%3zu: static %5d, %.2f tasks; dyn %5d, %.1f tasks, %.2f steals; "
+        "%3zu: static %5d, %.2f tasks; dyn %5d, %4.1f tasks, %.2f steals; "
         "wake %7.3f ns, latency %6.3f < %7.3f us, barrier %7.3f us; "
         "func: static %6.3f + dyn %7.3f = %.1f%% of total run %7.3f s, "
-        "%.1f%% of thread time %6.3f s; main run share %5.1f%%\n",
+        "%.1f%% of thread time %7.3f s; main run share %5.1f%%\n",
        num_threads, static_cast<int>(num_run_static_), avg_tasks_static,
        static_cast<int>(num_run_dynamic_), avg_tasks_dynamic, avg_steals,
        ns(per_run(Seconds(sum_d_wake_))),
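The renamed `PoolWorkerMapping` flattens a (cluster, worker) pair into a global index with stride `max_cluster_workers`; in the across-cluster mode (`cluster_idx_ == kAllClusters`), `worker_idx` names a cluster and maps to that cluster's first slot. A minimal standalone sketch of this arithmetic; `kAllClusters` and the stride of 8 here are illustrative values, not taken from a real topology:

```cpp
#include <assert.h>
#include <stddef.h>
#include <stdio.h>

constexpr size_t kAllClusters = 32;  // illustrative, mirrors kMaxClusters - 1

size_t GlobalIdx(size_t cluster_idx, size_t max_cluster_workers,
                 size_t worker_idx) {
  if (cluster_idx == kAllClusters) {
    // Across-cluster pool: worker_idx names a cluster; return the first
    // slot within that cluster. The main thread owns slot 0 of cluster 0.
    return worker_idx * max_cluster_workers;
  }
  return cluster_idx * max_cluster_workers + worker_idx;
}

int main() {
  const size_t stride = 8;  // hypothetical max workers per cluster
  assert(GlobalIdx(0, stride, 3) == 3);              // cluster 0, worker 3
  assert(GlobalIdx(2, stride, 3) == 19);             // 2 * 8 + 3
  assert(GlobalIdx(kAllClusters, stride, 2) == 16);  // first slot of cluster 2
  printf("global indices ok\n");
  return 0;
}
```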
@@ -608,8 +614,8 @@ class alignas(HWY_ALIGNMENT) Worker {  // HWY_ALIGNMENT bytes
   // avoids a separate `ThreadPool` member which risks going out of sync.
   void SetNextConfig(Config copy) { next_config_ = copy; }
 
-  uint32_t GetExit() const { return exit_; }
-  void SetExit(uint32_t exit) { exit_ = exit; }
+  Exit GetExit() const { return exit_; }
+  void SetExit(Exit exit) { exit_ = exit; }
 
   uint32_t WorkerEpoch() const { return worker_epoch_; }
   uint32_t AdvanceWorkerEpoch() { return ++worker_epoch_; }
@@ -678,7 +684,7 @@ class alignas(HWY_ALIGNMENT) Worker {  // HWY_ALIGNMENT bytes
 
   // Written and read by the same thread, hence not atomic.
   Config next_config_;
-  uint32_t exit_ = 0;
+  Exit exit_ = Exit::kNone;
 
   // thread_pool_test requires nonzero epoch.
   uint32_t worker_epoch_ = 1;
@@ -1099,7 +1105,7 @@ class alignas(HWY_ALIGNMENT) ThreadPool {
     (void)RunWithoutAutotune(
         0, NumWorkers(), [this](HWY_MAYBE_UNUSED uint64_t task, size_t worker) {
           HWY_DASSERT(task == worker);
-          workers_[worker].SetExit(1);
+          workers_[worker].SetExit(Exit::kThread);
         });
 
     for (std::thread& thread : threads_) {
@@ -1245,14 +1251,28 @@ class alignas(HWY_ALIGNMENT) ThreadPool {
   // Called by `std::thread`. Could also be a lambda, but annotating with
   // `HWY_POOL_PROFILE` makes it easier to inspect the generated code.
   class ThreadFunc {
-    // Functor called by `CallWithConfig`.
-    // TODO: loop until config changes.
-    struct WorkerWait {
+    // Functor called by `CallWithConfig`. Loops until `SendConfig` changes the
+    // Spin or Wait policy or the pool is destroyed.
+    struct WorkerLoop {
       template <class Spin, class Wait>
-      void operator()(const Spin& spin, const Wait& wait,
-                      pool::Worker& worker) const {
-        // TODO: log number of spin-wait iterations.
-        (void)wait.UntilWoken(worker, spin);
+      void operator()(const Spin& spin, const Wait& wait, pool::Worker& worker,
+                      pool::Tasks& tasks, pool::Shared& shared) const {
+        do {
+          // Main worker also calls this, so their epochs match.
+          const uint32_t epoch = worker.AdvanceWorkerEpoch();
+
+          // TODO: log number of spin-wait iterations.
+          (void)wait.UntilWoken(worker, spin);
+
+          Stopwatch stopwatch = worker.MakeStopwatch();
+          tasks.WorkerRun(&worker);
+          shared.stats.NotifyThreadRun(worker.Index(), stopwatch);
+
+          // Notify barrier after `WorkerRun`. Note that we cannot send an
+          // after-barrier timestamp, see above.
+          pool::Barrier().WorkerReached(worker, epoch);
+          // Check after `WorkerReached`, otherwise the main thread deadlocks.
+        } while (worker.GetExit() == Exit::kNone);
       }
     };
 
@@ -1275,23 +1295,15 @@ class alignas(HWY_ALIGNMENT) ThreadPool {
      // be counted. Instead, `ProfilerFunc` records the elapsed time.
 
      // Loop termination via `GetExit` is triggered by `~ThreadPool`.
-      do {
-        // Main worker also calls this, so their epochs match.
-        const uint32_t epoch = worker_.AdvanceWorkerEpoch();
+      for (;;) {
         // Uses the initial config, or the last one set during WorkerRun.
-        CallWithConfig(worker_.NextConfig(), WorkerWait(), worker_);
-
-        Stopwatch stopwatch = worker_.MakeStopwatch();
-        tasks_.WorkerRun(&worker_);
-        shared_.stats.NotifyThreadRun(worker_.Index(), stopwatch);
+        CallWithConfig(worker_.NextConfig(), WorkerLoop(), worker_, tasks_,
+                       shared_);
 
-        // Notify barrier after `WorkerRun`. Note that we cannot send an
-        // after-barrier timestamp, see above.
-        pool::Barrier().WorkerReached(worker_, epoch);
-
-        // Check after notifying the barrier, otherwise the main thread
-        // deadlocks.
-      } while (!worker_.GetExit());
+        // Exit or reset the flag and return to WorkerLoop with a new config.
+        if (worker_.GetExit() == Exit::kThread) break;
+        worker_.SetExit(Exit::kNone);
+      }
 
      worker_.GetProfiler().SetGlobalIdx(~size_t{0});
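The restructuring splits the worker's run loop in two: the inner `WorkerLoop` repeats wake/run/barrier iterations under one fixed Spin/Wait policy until the exit flag changes, and the outer loop in `ThreadFunc` either terminates the thread (`Exit::kThread`, set by `~ThreadPool`) or clears the flag and re-enters `CallWithConfig` with the new config (`Exit::kLoop`, set by `SendConfig`). This hoists the policy dispatch out of the per-run path. A self-contained sketch of that control flow; `Worker`, the scripted flag sequence, and the integer `config` are stand-ins for illustration, not the pool's actual types:

```cpp
#include <stdint.h>
#include <stdio.h>

enum class Exit : uint32_t { kNone, kLoop, kThread };

// In the pool, tasks set the flag from SendConfig (kLoop) or ~ThreadPool
// (kThread); here a scripted sequence plays that role.
struct Worker {
  Exit exit = Exit::kNone;
  int config = 0;  // stand-in for the Spin/Wait policy
  int step = 0;
  Exit Script() {
    static const Exit kScript[] = {Exit::kNone, Exit::kNone, Exit::kLoop,
                                   Exit::kNone, Exit::kThread};
    return kScript[step++];
  }
};

// Inner loop: one fixed config; exits only when the flag changes.
void WorkerLoop(Worker& w) {
  do {
    printf("WorkerRun under config %d\n", w.config);
    w.exit = w.Script();  // in the pool, a task sets this
  } while (w.exit == Exit::kNone);
}

// Outer loop: re-dispatch with the new config on kLoop, terminate on kThread.
void ThreadFunc(Worker& w) {
  for (;;) {
    WorkerLoop(w);
    if (w.exit == Exit::kThread) break;  // ~ThreadPool
    w.exit = Exit::kNone;                // SendConfig: resume with new config
    ++w.config;
  }
}

int main() {
  Worker w;
  ThreadFunc(w);  // prints three runs under config 0, two under config 1
  return 0;
}
```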
@@ -1376,6 +1388,7 @@ class alignas(HWY_ALIGNMENT) ThreadPool {
         [this, next_config](HWY_MAYBE_UNUSED uint64_t task, size_t worker) {
           HWY_DASSERT(task == worker);  // one task per worker
           workers_[worker].SetNextConfig(next_config);
+          workers_[worker].SetExit(Exit::kLoop);
         });
 
     // All have woken and are, or will be, waiting per `next_config`. Now we
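With `SendConfig` now also setting `Exit::kLoop`, workers drop out of the old-policy `WorkerLoop` and re-enter it with `next_config` instead of tearing anything down. A usage sketch from the caller's side, assuming the public `ThreadPool(size_t)` constructor, `Run`, and `SetWaitMode` entry points; that `SetWaitMode` reaches `SendConfig` is an inference from the identifiers in this diff, not shown here:

```cpp
#include <stdint.h>

#include "hwy/contrib/thread_pool/thread_pool.h"

int main() {
  hwy::ThreadPool pool(4);

  // Low-latency phase: workers spin instead of blocking on a futex.
  pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
  pool.Run(0, 100, [](uint64_t task, size_t worker) {
    (void)task;
    (void)worker;  // ... per-task work ...
  });

  // Idle phase: switch back so blocked workers free their cores. Each
  // switch broadcasts a new config and (per this diff) Exit::kLoop.
  pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
  return 0;
}
```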