Skip to content

[xla:cpu] Add Thunk::async_resume() property to decide if thunk execution should be resumed using the TaskRunner or the caller thread #97473

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions third_party/xla/xla/backends/cpu/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ cc_library(
hdrs = ["thread_pool_task_runner.h"],
deps = [
":thunk",
"@com_google_absl//absl/base:core_headers",
"@eigen_archive//:eigen3",
],
)
Expand Down
14 changes: 2 additions & 12 deletions third_party/xla/xla/backends/cpu/runtime/thread_pool_task_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@ limitations under the License.

#define EIGEN_USE_THREADS

#include <cstdint>
#include <optional>
#include <utility>

#include "absl/base/optimization.h"
#include "unsupported/Eigen/CXX11/ThreadPool"
#include "xla/backends/cpu/runtime/thunk.h"

Expand All @@ -36,22 +35,13 @@ class ThreadPoolTaskRunner : public Thunk::TaskRunner {
: thread_pool_(thread_pool) {}

void operator()(Thunk::Task task) final {
if (thread_pool_ == nullptr) {
if (ABSL_PREDICT_FALSE(thread_pool_ == nullptr)) {
task();
} else {
thread_pool_->Schedule(std::move(task));
}
}

std::optional<int64_t> current_worker_id() const final {
if (thread_pool_ == nullptr) {
return {0};
} else {
int64_t thread_id = thread_pool_->CurrentThreadId();
return thread_id == -1 ? std::nullopt : std::make_optional(thread_id);
}
}

private:
Eigen::ThreadPoolInterface* thread_pool_;
};
Expand Down
27 changes: 13 additions & 14 deletions third_party/xla/xla/backends/cpu/runtime/thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,17 +111,7 @@ class Thunk {
class TaskRunner {
public:
virtual ~TaskRunner() = default;

virtual void operator()(Task task) = 0;

// Returns the current worker id if the caller happens to run on a thread
// managed by the task runner. Otherwise returns empty optional. Thunk
// executor relies on this information to do a best-effort resource
// isolation by making sure that all thunks are executed inside a task
// runner, and do not "leak" into arbitrary thread pools in the process,
// because by default we resume execution on a thread that completed thunk
// execute event AsyncValue, and it can be an external thread pool.
virtual std::optional<int64_t> current_worker_id() const = 0;
};

Thunk(Kind kind, Info info);
Expand All @@ -147,10 +137,19 @@ class Thunk {
using ResourceUses = absl::InlinedVector<ResourceUse, 4>;
virtual ResourceUses resource_uses() const { return {}; }

virtual std::vector<std::pair<std::string, const ThunkSequence*>>
nested_thunks() const {
return {};
}
// Returns the list of nested thunk sequences together with their names (i.e.
// for ConditionalThunk it returns thunk sequences for all branches).
using NamedThunkSequence = std::pair<std::string, const ThunkSequence*>;
virtual std::vector<NamedThunkSequence> nested_thunks() const { return {}; }

// Returns `true` if thunk execution uses thread pool(s) not owned by the
// XLA:CPU runtime, i.e. thunk execution happens asynchronously on the IO
// event manager thread pool. Thunk executor takes extra care to resume
// execution using the TaskRunner passed via the ExecuteParams, otherwise we
// can accidentally take over the thread pool that we do not own. By default
// thunk execution is resumed on a thread that sets the ExecuteEvent async
// value concrete.
virtual bool async_resume() const { return false; }

//===--------------------------------------------------------------------===//
// CollectiveExecuteParams
Expand Down
70 changes: 40 additions & 30 deletions third_party/xla/xla/backends/cpu/runtime/thunk_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -350,29 +350,34 @@ ThunkExecutor::ExecuteSequential(const Thunk::ExecuteParams& params) {
// resume sequential execution starting from the next thunk.
if (ABSL_PREDICT_FALSE(!execute_event.IsAvailable())) {
auto event = tsl::MakeConstructedAsyncValueRef<ExecuteEvent>();
execute_event.AndThen([this, &params, it, event](absl::Status status) {
Thunk::TaskRunner* runner = params.task_runner;

if (ABSL_PREDICT_FALSE(!status.ok())) {
event.SetError(std::move(status));
} else if (ABSL_PREDICT_TRUE(!runner || runner->current_worker_id())) {
// Resume execution in the current thread if we are already running
// on a thread managed by the task runner.
ResumeExecuteSequential(it + 1, params, std::move(event));
} else {
// Resume execution in the task runner to avoid thread "leaks".
(*runner)([this, &params, it, event = std::move(event)] {
ResumeExecuteSequential(it + 1, params, std::move(event));
execute_event.AndThen(
[this, &params, &thunk, it, event](absl::Status status) {
Thunk::TaskRunner* runner = params.task_runner;

if (ABSL_PREDICT_FALSE(!status.ok())) {
event.SetError(std::move(status));
} else if (ABSL_PREDICT_FALSE(thunk.async_resume() && runner)) {
// Resume execution using the task runner to avoid executing
// remaining thunks on a thread pool that we don't own.
(*runner)([this, &params, it, event = std::move(event)] {
ResumeExecuteSequential(it + 1, params, std::move(event));
});
} else {
// Resume execution on a thread that completed thunk execution.
ResumeExecuteSequential(it + 1, params, std::move(event));
}
});
}
});
return event;
}

// Abort execution if any of the thunks failed.
if (ABSL_PREDICT_FALSE(execute_event.IsError())) {
return execute_event;
}

// At this point execute_event must be concrete (completed successfully),
// and we can move on to the next thunk.
DCHECK(execute_event.IsConcrete());
}

// If we got to the end of the sequence it means that all thunks have
Expand All @@ -395,21 +400,21 @@ void ThunkExecutor::ResumeExecuteSequential(
// If thunk execution is not completed yet, attach a continuation to
// resume sequential execution starting from the next thunk.
if (ABSL_PREDICT_FALSE(!execute_event.IsAvailable())) {
execute_event.AndThen([this, &params, it,
execute_event.AndThen([this, &params, &thunk, it,
event = std::move(event)](absl::Status status) {
Thunk::TaskRunner* runner = params.task_runner;

if (ABSL_PREDICT_FALSE(!status.ok())) {
event.SetError(std::move(status));
} else if (ABSL_PREDICT_TRUE(!runner || runner->current_worker_id())) {
// Resume execution in the current thread if we are already
// running on a thread managed by the task runner.
ResumeExecuteSequential(it + 1, params, std::move(event));
} else {
// Resume execution in the task runner to avoid thread "leaks".
} else if (ABSL_PREDICT_FALSE(thunk.async_resume() && runner)) {
// Resume execution using the task runner to avoid executing
// remaining thunks on a thread pool that we don't own.
(*runner)([this, &params, it, event = std::move(event)] {
ResumeExecuteSequential(it + 1, params, std::move(event));
});
} else {
// Resume execution on a thread that completed thunk execution.
ResumeExecuteSequential(it + 1, params, std::move(event));
}
});
return;
Expand All @@ -420,6 +425,10 @@ void ThunkExecutor::ResumeExecuteSequential(
event.SetError(execute_event.GetError());
return;
}

// At this point execute_event must be concrete (completed successfully),
// and we can move on to the next thunk.
DCHECK(execute_event.IsConcrete());
}

// If we got to the end of the sequence it means that all thunks have
Expand Down Expand Up @@ -498,7 +507,8 @@ void ThunkExecutor::Execute(ExecuteState* state,
// execute session. If we happen to process the last thunk in the ready
// queue, we will forward the lock that we already hold (note that the
// lock might be empty, if `Execute` was called by the main thread).
execute_event.AndThen([&params, &node, state, is_sink, inc_pending_nodes,
execute_event.AndThen([&params, &node, &thunk, state, is_sink,
inc_pending_nodes,
execute_event = execute_event.AsPtr(),
ready_queue = ready_queue.CreateEmptyReadyQueue(),
lock = ready_queue.Empty()
Expand All @@ -520,18 +530,18 @@ void ThunkExecutor::Execute(ExecuteState* state,
}

Thunk::TaskRunner* runner = state->runner;
if (ABSL_PREDICT_TRUE(!runner || runner->current_worker_id())) {
// Resume execution in the current thread if we are already
// running on a thread managed by the task runner.
state->executor->Execute(state, params, std::move(ready_queue),
std::move(lock));
} else {
// Resume execution in the task runner to avoid thread "leaks".
if (ABSL_PREDICT_FALSE(thunk.async_resume() && runner)) {
// Resume execution using the task runner to avoid executing
// remaining thunks on a thread pool that we don't own.
(*runner)([state, &params, ready_queue = std::move(ready_queue),
lock = std::move(lock)] {
state->executor->Execute(state, params, std::move(ready_queue),
std::move(lock));
});
} else {
// Resume execution on a thread that completed thunk execution.
state->executor->Execute(state, params, std::move(ready_queue),
std::move(lock));
}
});
}
Expand Down
33 changes: 9 additions & 24 deletions third_party/xla/xla/backends/cpu/runtime/thunk_executor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,34 +66,20 @@ using ::testing::ElementsAre;
static int64_t shared_resource;

// An adaptor from a lambda that runs tasks and a TaskRunner API.
template <typename Runner, typename WorkerId>
template <typename Runner>
class TaskRunnerAdaptor : public Thunk::TaskRunner {
public:
TaskRunnerAdaptor(Runner runner, WorkerId worker_id)
: runner_(std::move(runner)), worker_id_(std::move(worker_id)) {}
explicit TaskRunnerAdaptor(Runner runner) : runner_(std::move(runner)) {}

void operator()(Thunk::Task task) final { runner_(std::move(task)); }

std::optional<int64_t> current_worker_id() const final {
return worker_id_();
}

private:
Runner runner_;
WorkerId worker_id_;
};

template <typename Runner>
auto MakeTaskRunnerFrom(Runner&& runner) {
auto no_id = []() { return std::nullopt; };
return TaskRunnerAdaptor<Runner, decltype(no_id)>(
std::forward<Runner>(runner), no_id);
}

template <typename Runner, typename WorkerId>
auto MakeTaskRunnerFrom(Runner&& runner, WorkerId&& worker_id) {
return TaskRunnerAdaptor<Runner, WorkerId>(std::forward<Runner>(runner),
std::forward<WorkerId>(worker_id));
return TaskRunnerAdaptor<Runner>(std::forward<Runner>(runner));
}

// A test-only thunk for verifying thunk executor implementation:
Expand Down Expand Up @@ -475,13 +461,10 @@ TEST(ThunkExecutorTest, Execute) {
auto data = LiteralUtil::CreateFull({20}, int32_t{1});
BufferAllocations allocations = CreateBufferAllocations(data);

auto task_runner = MakeTaskRunnerFrom(
[&](Thunk::Task task) {
trace.push_back("<TaskRunner>");
task();
},
// Always return current worker id as 0.
[] { return 0; });
auto task_runner = MakeTaskRunnerFrom([&](Thunk::Task task) {
trace.push_back("<TaskRunner>");
task();
});

Thunk::ExecuteParams params = {nullptr, &allocations};
params.task_runner = &task_runner;
Expand Down Expand Up @@ -532,6 +515,8 @@ class NoOpAsyncThunk : public Thunk {
return BufferUses{BufferUse::Write(slice_)};
}

bool async_resume() const override { return true; }

private:
static tsl::thread::ThreadPool* ThreadPool() {
static auto* thread_pool =
Expand Down