这是indexloc提供的服务,不要输入任何密码
Skip to content

Introduce AllGatherDynamicSliceShuffledOffsetSimplifier, a new HLO pass that collapses dynamic-slice(all-gather) with shuffled offsets into a collective-permute. #97287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions third_party/xla/xla/hlo/transforms/simplifiers/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1541,6 +1541,44 @@ cc_library(
],
)

# Library target for the all-gather + permuted dynamic-slice simplifier pass,
# built from all_gather_permuted_ds_simplifier.{h,cc}.
cc_library(
name = "all_gather_permuted_ds_simplifier",
srcs = ["all_gather_permuted_ds_simplifier.cc"],
hdrs = ["all_gather_permuted_ds_simplifier.h"],
deps = [
"//xla:shape_util",
"//xla/hlo/ir:hlo",
"//xla/hlo/pass:hlo_pass",
"//xla/service:collective_opt_utils",
"//xla/service:hlo_module_config",
"//xla/service:pattern_matcher",
"//xla/tsl/platform:errors",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@local_tsl//tsl/platform:statusor",
],
)

# Test target for :all_gather_permuted_ds_simplifier, built from
# all_gather_permuted_ds_simplifier_test.cc and linked against gtest_main.
xla_cc_test(
name = "all_gather_permuted_ds_simplifier_test",
srcs = ["all_gather_permuted_ds_simplifier_test.cc"],
deps = [
":all_gather_permuted_ds_simplifier",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
"//xla/hlo/testlib:verified_hlo_module",
"//xla/hlo/utils:hlo_matchers",
"//xla/service:hlo_module_config",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest_main",
],
)

cc_library(
name = "tree_reduction_rewriter",
srcs = ["tree_reduction_rewriter.cc"],
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/hlo/transforms/simplifiers/all_gather_permuted_ds_simplifier.h"

#include <algorithm>
#include <cstdint>
#include <optional>
#include <utility>

#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/collective_opt_utils.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/pattern_matcher.h"
#include "xla/shape_util.h"
#include "xla/tsl/platform/errors.h"

namespace xla {

// Rewrites `dynamic-slice(all-gather(x))` (optionally with a reshape between
// the two) when the slice offset is recognized by
// MatchPermutedSliceAndPartitionOffset as a per-partition permutation.
//
// * If at least one matched pair moves data between different partitions, the
//   dynamic-slice is replaced by `collective-permute(x)`.
// * If every matched pair is an identity pair, the dynamic-slice is replaced
//   by `x` itself, since no communication is required.
//
// Returns OkStatus and leaves the graph untouched when the pattern does not
// match; otherwise returns the status of the replacement.
absl::Status
AllGatherDynamicSlicePermutedOffsetSimplifierVisitor::HandleDynamicSlice(
    HloInstruction* dynamic_slice_hlo) {
  HloDynamicSliceInstruction* dynamic_slice =
      Cast<HloDynamicSliceInstruction>(dynamic_slice_hlo);
  HloInstruction* operand = dynamic_slice->mutable_operand(0);

  // The sliced operand must be an all-gather, possibly behind one reshape.
  namespace m = match;
  HloInstruction* all_gather_hlo = nullptr;
  if (!Match(operand, m::AllGather(&all_gather_hlo)) &&
      !Match(operand, m::Reshape(m::AllGather(&all_gather_hlo)))) {
    return absl::OkStatus();
  }

  HloAllGatherInstruction* all_gather =
      Cast<HloAllGatherInstruction>(all_gather_hlo);

  // Shape check: the dynamic-slice result must be compatible with the
  // all-gather operand, because that operand (or a permute of it) is what
  // replaces the dynamic-slice below.
  if (!ShapeUtil::Compatible(dynamic_slice->shape(),
                             all_gather->operand(0)->shape())) {
    return absl::OkStatus();
  }

  const HloModuleConfig& config = dynamic_slice->GetModule()->config();
  std::optional<AllGatherDynamicSliceMatchSpec> offset_spec =
      MatchPermutedSliceAndPartitionOffset(
          all_gather, config.num_partitions(), config.replica_count(),
          HloPredicateIsOp<HloOpcode::kPartitionId>,
          /*allow_multiple_users=*/false);

  if (!offset_spec.has_value() || offset_spec->permutation_pairs.empty()) {
    return absl::OkStatus();
  }

  // Drop identity pairs (source == target): they describe a partition reading
  // its own shard.
  // NOTE(review): a collective-permute zero-fills devices that are not the
  // target of any pair, so when only *some* pairs are identity the partitions
  // behind the dropped pairs may no longer see their own data. Confirm this
  // against MatchPermutedSliceAndPartitionOffset and the backend before relying
  // on the mixed case.
  auto& pairs = offset_spec->permutation_pairs;
  pairs.erase(std::remove_if(pairs.begin(), pairs.end(),
                             [](const std::pair<int64_t, int64_t>& pair) {
                               return pair.first == pair.second;
                             }),
              pairs.end());

  // Every partition slices out exactly its own shard: the result is the
  // all-gather operand itself. Emitting a collective-permute with no
  // source-target pairs here would not forward any data.
  if (pairs.empty()) {
    return ReplaceInstruction(dynamic_slice, all_gather->mutable_operand(0));
  }

  // Replace the pattern with a collective permute over the original operand,
  // reusing the all-gather's channel id.
  HloInstruction* cp =
      dynamic_slice->AddInstruction(HloInstruction::CreateCollectivePermute(
          dynamic_slice->shape(), all_gather->mutable_operand(0), pairs,
          all_gather->channel_id()));
  return ReplaceInstruction(dynamic_slice, cp);
}

// Applies the dynamic-slice rewrite visitor to every non-fusion computation of
// `module` that belongs to `execution_threads`. Returns true if any visitor
// reported a change, or the first error produced while visiting.
absl::StatusOr<bool> AllGatherDynamicSlicePermutedOffsetSimplifier::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  bool any_rewrite = false;
  const auto computations =
      module->MakeNonfusionComputations(execution_threads);
  for (HloComputation* comp : computations) {
    // A fresh visitor per computation keeps its changed() flag local to it.
    AllGatherDynamicSlicePermutedOffsetSimplifierVisitor rewriter;
    TF_RETURN_IF_ERROR(comp->Accept(&rewriter));
    if (rewriter.changed()) {
      any_rewrite = true;
    }
  }
  return any_rewrite;
}

} // namespace xla
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_ALL_GATHER_PERMUTED_DS_SIMPLIFIER_H_
#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_ALL_GATHER_PERMUTED_DS_SIMPLIFIER_H_

#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/pass/hlo_pass_interface.h"
#include "tsl/platform/statusor.h"

namespace xla {

// Visitor for AllGatherDynamicSlicePermutedOffsetSimplifier.
//
// A rewrite visitor that only overrides the dynamic-slice handler; the pass
// instantiates one visitor per computation and reads its changed() flag.
class AllGatherDynamicSlicePermutedOffsetSimplifierVisitor
: public DfsHloRewriteVisitor {
public:
// Inspects `dynamic_slice` and, when it slices an all-gather (optionally
// through a reshape) with a matched permuted per-partition offset, replaces it
// with a collective-permute of the all-gather operand. Returns OkStatus when
// the pattern does not match.
absl::Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
};

// A pass that simplifies a pattern of `all-gather` followed by a permuted
// `dynamic-slice` into a single `collective-permute`.
//
// The slice offset must be derived from `partition-id` so that every partition
// selects one shard of the gathered result; the shard-to-partition mapping
// becomes the `source_target_pairs` of the collective-permute.
//
// For example:
//
// Before:
//
// ENTRY entry {
// p = f32[32,8,128] parameter(0)
// ag = f32[256,8,128] all-gather(p), replica_groups={{0,1,2,3,4,5,6,7}},
// dimensions={0}
// pid = u32[] partition-id()
// permuted_idx_list = s32[8]{0} constant({224,192,160,128,96,64,32,0})
// offset = s32[1] dynamic-slice(permuted_idx_list, pid),
// dynamic_slice_sizes={1}
// offset_reshape = s32[] reshape(offset)
// ...
// ROOT ds = f32[32,8,128] dynamic-slice(ag, offset_reshape, ...),
// dynamic_slice_sizes={32,8,128}
// }
//
// After:
//
// ENTRY entry {
// p = f32[32,8,128] parameter(0)
// ROOT cp = f32[32,8,128] collective-permute(p),
// source_target_pairs={{0,7},{1,6},{2,5},{3,4},{4,3},{5,2},{6,1},{7,0}}
// }
class AllGatherDynamicSlicePermutedOffsetSimplifier : public HloModulePass {
public:
// Name under which this pass reports itself.
absl::string_view name() const override {
return "all-gather-to-collective-permute-simplifier";
}

using HloModulePass::Run;
// Runs the rewrite over every non-fusion computation of `module` in
// `execution_threads`. Returns true if any computation was changed, or an
// error status if visiting a computation failed.
absl::StatusOr<bool> Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) override;
};

} // namespace xla

#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_ALL_GATHER_PERMUTED_DS_SIMPLIFIER_H_
Loading