This service is provided by indexloc; do not enter any passwords.
Skip to content

Also create reduce-scatters when SPMD-partitioning is disabled. #97211

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions third_party/xla/xla/service/collective_opt_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -338,11 +338,12 @@ std::optional<ReduceScatterSpec> MatchReduceScatter(
ar, num_partitions, num_replicas, min_rank, ar->constrain_layout(),
ar->use_global_device_ids(), ar->channel_id().has_value());
}
bool is_cross_module =
ar->channel_id() && ar->opcode() == HloOpcode::kAllReduce;
auto spec = MatchWithDynamicSlice(
ar, num_partitions, num_replicas, allow_multiple_split_dims,
allow_intervening_reshape, min_rank, match_partition_id, match_replica_id,
ar->constrain_layout(), ar->use_global_device_ids(),
ar->channel_id() && ar->opcode() == HloOpcode::kAllReduce,
ar->constrain_layout(), ar->use_global_device_ids(), is_cross_module,
allow_intervening_bitcast);
return spec;
}
Expand All @@ -353,11 +354,12 @@ std::optional<ReduceScatterSpec> AllGatherDynamicSliceCancellation(
bool allow_intervening_reshape, int64_t min_rank,
HloPredicate match_partition_id, HloPredicate match_replica_id,
bool allow_intervening_bitcast, bool allow_multiple_users) {
bool is_cross_module =
ag->channel_id() && ag->opcode() == HloOpcode::kAllGather;
auto spec = MatchWithDynamicSlice(
ag, num_partitions, num_replicas, allow_multiple_split_dims,
allow_intervening_reshape, min_rank, match_partition_id, match_replica_id,
ag->constrain_layout(), ag->use_global_device_ids(),
ag->channel_id() && ag->opcode() == HloOpcode::kAllGather,
ag->constrain_layout(), ag->use_global_device_ids(), is_cross_module,
allow_intervening_bitcast, allow_multiple_users);

if (!spec.has_value()) {
Expand Down Expand Up @@ -693,9 +695,7 @@ std::optional<ReduceScatterSpec> MatchWithDynamicSlice(
HloPredicate match_partition_id, HloPredicate match_replica_id,
bool is_constrain_layout, bool use_global_device_ids, bool is_cross_module,
bool allow_intervening_bitcast, bool allow_multiple_users) {
if (!instruction->shape().IsArray() || is_constrain_layout ||
(is_cross_module &&
!instruction->GetModule()->config().use_spmd_partitioning())) {
if (!instruction->shape().IsArray() || is_constrain_layout) {
VLOG(2) << "Unsupported collective: " << instruction->ToString();
return std::nullopt;
}
Expand Down Expand Up @@ -763,7 +763,8 @@ std::optional<ReduceScatterSpec> MatchWithDynamicSlice(
}
}
}
map_id = [&, orthogonal_replicas](const HloInstruction* hlo, int64_t id) {
map_id = [&, orthogonal_replicas](const HloInstruction* hlo,
int64_t id) -> int64_t {
if (match_replica_id(hlo)) {
return num_partitions == 1 ? id : -1;
}
Expand All @@ -789,7 +790,7 @@ std::optional<ReduceScatterSpec> MatchWithDynamicSlice(
is_replica_mul_num_partitions(hlo->operand(0))))) {
return id;
}
return int64_t{-1};
return -1;
};
} else {
// Right now all cross-partition all-reduces' subgroups refer to replicas
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,10 @@ class GpuReduceScatterCreatorTest : public HloHardwareIndependentTestBase {
public:
absl::StatusOr<std::unique_ptr<HloModule>> RunPass(
absl::string_view hlo_module, int64_t num_replicas,
int64_t num_partitions, bool expect_change) {
int64_t num_partitions, bool use_spmd_partitioning, bool expect_change) {
HloModuleConfig config = GetModuleConfigForTest(
/*replica_count=*/num_replicas,
/*num_partitions=*/num_partitions);
config.set_use_spmd_partitioning(num_partitions > 1);
/*replica_count=*/num_replicas, /*num_partitions=*/num_partitions);
config.set_use_spmd_partitioning(use_spmd_partitioning);
TF_ASSIGN_OR_RETURN(auto module,
ParseAndReturnVerifiedModule(hlo_module, config));
auto changed = ReduceScatterCreator().Run(module.get());
Expand Down Expand Up @@ -109,6 +108,7 @@ ENTRY %AllReduce {
TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
/*num_replicas=*/8,
/*num_partitions=*/1,
/*use_spmd_partitioning=*/false,
/*expect_change=*/true));
ASSERT_THAT(module->entry_computation()->root_instruction(),
GmockMatch(m::ReduceScatter(m::Parameter(0))));
Expand Down Expand Up @@ -147,6 +147,7 @@ ENTRY %AllReduce {
TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
/*num_replicas=*/8,
/*num_partitions=*/1,
/*use_spmd_partitioning=*/false,
/*expect_change=*/true));
ASSERT_THAT(module->entry_computation()->root_instruction(),
GmockMatch(m::ReduceScatter(m::Parameter(0))));
Expand Down Expand Up @@ -186,6 +187,7 @@ ENTRY %AllReduce {
TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
/*num_replicas=*/8,
/*num_partitions=*/1,
/*use_spmd_partitioning=*/false,
/*expect_change=*/true));
EXPECT_THAT(module->entry_computation()->root_instruction(),
GmockMatch(m::Reshape(m::ReduceScatter(m::Parameter(0)))));
Expand Down Expand Up @@ -219,6 +221,7 @@ ENTRY %AllReduce {
TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
/*num_replicas=*/8,
/*num_partitions=*/1,
/*use_spmd_partitioning=*/false,
/*expect_change=*/true));
EXPECT_THAT(module->entry_computation()->root_instruction(),
GmockMatch(m::Reshape(m::ReduceScatter(m::Parameter(0)))));
Expand Down Expand Up @@ -253,6 +256,7 @@ ENTRY %AllReduce {
TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
/*num_replicas=*/8,
/*num_partitions=*/1,
/*use_spmd_partitioning=*/false,
/*expect_change=*/true));
ASSERT_THAT(module->entry_computation()->root_instruction(),
GmockMatch(m::ReduceScatter(m::Parameter(0))));
Expand Down Expand Up @@ -290,6 +294,7 @@ ENTRY %AllReduce {
TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
/*num_replicas=*/8,
/*num_partitions=*/1,
/*use_spmd_partitioning=*/false,
/*expect_change=*/false));
}

Expand Down Expand Up @@ -321,6 +326,7 @@ ENTRY %AllReduce {
TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
/*num_replicas=*/8,
/*num_partitions=*/2,
/*use_spmd_partitioning=*/true,
/*expect_change=*/true));
EXPECT_THAT(module->entry_computation()->root_instruction(),
GmockMatch(m::ReduceScatter(m::Parameter(0))));
Expand Down Expand Up @@ -356,6 +362,7 @@ ENTRY %AllReduce {
TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
/*num_replicas=*/8,
/*num_partitions=*/2,
/*use_spmd_partitioning=*/true,
/*expect_change=*/true));
EXPECT_THAT(module->entry_computation()->root_instruction(),
GmockMatch(m::ReduceScatter(m::Parameter(0))));
Expand Down Expand Up @@ -390,6 +397,7 @@ ENTRY %AllReduce {
TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
/*num_replicas=*/2,
/*num_partitions=*/8,
/*use_spmd_partitioning=*/true,
/*expect_change=*/true));
EXPECT_THAT(module->entry_computation()->root_instruction(),
GmockMatch(m::ReduceScatter(m::Parameter(0))));
Expand Down Expand Up @@ -426,6 +434,7 @@ ENTRY %AllReduce {
TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
/*num_replicas=*/2,
/*num_partitions=*/8,
/*use_spmd_partitioning=*/true,
/*expect_change=*/true));
EXPECT_EQ(AllReduceCount(module), 1);
EXPECT_EQ(ReduceScatterCount(module), 1);
Expand Down Expand Up @@ -464,6 +473,7 @@ ENTRY %AllReduce {
TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
/*num_replicas=*/2,
/*num_partitions=*/4,
/*use_spmd_partitioning=*/true,
/*expect_change=*/true));
EXPECT_THAT(module->entry_computation()->root_instruction(),
GmockMatch(m::ReduceScatter(m::Parameter(0))));
Expand Down Expand Up @@ -498,6 +508,7 @@ ENTRY %AllReduce {
TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
/*num_replicas=*/2,
/*num_partitions=*/4,
/*use_spmd_partitioning=*/true,
/*expect_change=*/true));
EXPECT_THAT(module->entry_computation()->root_instruction(),
GmockMatch(m::ReduceScatter(m::Parameter(0))));
Expand Down Expand Up @@ -532,6 +543,7 @@ ENTRY %AllReduce {
TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
/*num_replicas=*/2,
/*num_partitions=*/4,
/*use_spmd_partitioning=*/true,
/*expect_change=*/false));
}

Expand Down Expand Up @@ -563,6 +575,7 @@ ENTRY %AllReduce {
TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
/*num_replicas=*/1,
/*num_partitions=*/8,
/*use_spmd_partitioning=*/true,
/*expect_change=*/true));
EXPECT_THAT(module->entry_computation()->root_instruction(),
GmockMatch(m::ReduceScatter(m::Slice(m::Parameter(0)))));
Expand All @@ -573,6 +586,47 @@ TEST_F(GpuReduceScatterCreatorTest,
absl::string_view hlo_string = R"(
HloModule AllReduce

%sum {
%a = f32[] parameter(0)
%b = f32[] parameter(1)
ROOT %add = f32[] add(%a, %b)
}

ENTRY %AllReduce {
%param = f32[32,8,128]{2,1,0} parameter(0)
%all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
replica_groups={}, to_apply=%sum, backend_config={"collective_backend_config":{"is_pipelined":true}}
%table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
%rid = u32[] replica-id()
%id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1}
%reshape = s32[] reshape(%id)
%slice_size = s32[] constant(4)
%offset = s32[] multiply(%reshape, %slice_size)
%zero = s32[] constant(0)
ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %offset, %zero, %zero),
dynamic_slice_sizes={4,8,128}
}
)";

TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
/*num_replicas=*/8,
/*num_partitions=*/1,
/*use_spmd_partitioning=*/false,
/*expect_change=*/true));
ASSERT_THAT(module->entry_computation()->root_instruction(),
GmockMatch(m::ReduceScatter(m::Parameter(0))));
const auto *rs = Cast<HloReduceScatterInstruction>(
module->entry_computation()->root_instruction());
EXPECT_TRUE(rs->backend_config<GpuBackendConfig>()
->collective_backend_config()
.is_pipelined());
}

TEST_F(GpuReduceScatterCreatorTest,
ReduceScatterCreatorWithSPMDPartitioningDisabled) {
absl::string_view hlo_string = R"(
HloModule test

%sum {
%a = f32[] parameter(0)
%b = f32[] parameter(1)
Expand All @@ -582,30 +636,26 @@ HloModule AllReduce
ENTRY %AllReduce {
%param = f32[32,8,128]{2,1,0} parameter(0)
%all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
replica_groups={}, to_apply=%sum, backend_config={"collective_backend_config":{"is_pipelined":true}}
replica_groups={{0,1,2,3,4,5,6,7}}, channel_id=1,
use_global_device_ids=true, to_apply=%sum
%table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
%rid = u32[] replica-id()
%id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1}
%pid = u32[] partition-id()
%id = s32[1] dynamic-slice(%table, %pid), dynamic_slice_sizes={1}
%reshape = s32[] reshape(%id)
%slice_size = s32[] constant(4)
%offset = s32[] multiply(%reshape, %slice_size)
%zero = s32[] constant(0)
ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %offset, %zero, %zero),
dynamic_slice_sizes={4,8,128}
ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %offset, %zero,
%zero), dynamic_slice_sizes={4,8,128}
}
)";

TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
/*num_replicas=*/8,
/*num_partitions=*/1,
/*num_replicas=*/1,
/*num_partitions=*/8,
/*use_spmd_partitioning=*/false,
/*expect_change=*/true));
ASSERT_THAT(module->entry_computation()->root_instruction(),
GmockMatch(m::ReduceScatter(m::Parameter(0))));
const auto *rs = Cast<HloReduceScatterInstruction>(
module->entry_computation()->root_instruction());
EXPECT_TRUE(rs->backend_config<GpuBackendConfig>()
->collective_backend_config()
.is_pipelined());
}

} // namespace
Expand Down