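Introduces a new utility function, MatchPermutedSliceAndPartitionOffset, to detect a pattern where a DynamicSlice consumes the output of an AllGather with a permuted set of offsets, so that each partition ends up with exactly one shard taken from a (generally different) source partition. This pattern is equivalent to a CollectivePermute and can be rewritten as one, which sends each partition only the shard it needs instead of gathering the full tensor everywhere.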
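To make the equivalence concrete, here is a small Python simulation. It is a sketch, not XLA code, and every function name in it is hypothetical. It assumes shards are concatenated along dimension 0, that each slice is exactly one shard long and shard-aligned, and that the slice offsets form a permutation. Under those assumptions, the partition whose shard sits at a given offset becomes the CollectivePermute source for the partition that reads that offset. The zero-fill for partitions that are not a target mirrors XLA's CollectivePermute semantics.

```python
from typing import List, Sequence, Tuple

Shard = List[int]


def all_gather(shards: Sequence[Shard]) -> List[int]:
    """Concatenates every partition's shard; each partition sees the same result."""
    gathered: List[int] = []
    for shard in shards:
        gathered.extend(shard)
    return gathered


def all_gather_then_dynamic_slice(shards: Sequence[Shard],
                                  offsets: Sequence[int]) -> List[Shard]:
    """Partition p gathers everything, then keeps one shard-sized slice at offsets[p]."""
    shard_size = len(shards[0])
    gathered = all_gather(shards)
    return [gathered[off:off + shard_size] for off in offsets]


def collective_permute(shards: Sequence[Shard],
                       pairs: Sequence[Tuple[int, int]]) -> List[Shard]:
    """Partition dst receives src's shard; partitions that are not a target get zeros."""
    result = [[0] * len(shard) for shard in shards]
    for src, dst in pairs:
        result[dst] = list(shards[src])
    return result


def pairs_from_offsets(offsets: Sequence[int], shard_size: int) -> List[Tuple[int, int]]:
    """A shard-aligned offset identifies the partition the slice originated on."""
    return [(off // shard_size, dst) for dst, off in enumerate(offsets)]


# Four partitions, two elements per shard; partition p reads partition (p + 1) % 4's shard.
shards = [[0, 1], [2, 3], [4, 5], [6, 7]]
offsets = [2, 4, 6, 0]
pairs = pairs_from_offsets(offsets, shard_size=2)
assert all_gather_then_dynamic_slice(shards, offsets) == collective_permute(shards, pairs)
```

Both paths produce the same per-partition data, but the CollectivePermute moves one shard per partition rather than all four.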
#97189
The logic is divided into four main sections:
1. Verifies that the AllGather is a suitable candidate for this optimization. It ensures the operation is performed across multiple partitions with a single replica and uses flattened IDs (i.e., use_global_device_ids is enabled).
2. Finds the DynamicSlice user of the AllGather, correctly traversing through any intervening Reshape or Bitcast operations, which do not alter the data.
3. Computes the offset-to-partition mapping for both the AllGather (the source layout) and the DynamicSlice (the destination/permuted layout).
4. Matches the two layouts offset by offset, pairing the source partition ID from the AllGather with the destination partition ID from the DynamicSlice.

This CL also introduces a few key data structures to support this logic:
- PartitionOffsetSpec: represents the mapping from a memory offset to a partition ID for each replica group. It models both the data layout produced by the AllGather and the permuted access pattern of the DynamicSlice.
- PermutationPairs: a type alias for a list of (source_id, destination_id) pairs, representing the permutation performed by a CollectivePermute (CP).
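The first two steps (the candidate check and the walk to the DynamicSlice user) might look roughly like this on a toy instruction graph. ToyInstr, its fields, and the opcode strings are hypothetical stand-ins rather than XLA's HloInstruction API, and requiring a single user at every hop is an assumption of this sketch, not necessarily what MatchPermutedSliceAndPartitionOffset does.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToyInstr:
    """Hypothetical, minimal stand-in for an HLO instruction (not XLA's HloInstruction)."""
    opcode: str
    users: List["ToyInstr"] = field(default_factory=list)
    # Only meaningful for all-gather; hypothetical mirrors of the real attributes.
    num_partitions: int = 1
    num_replicas: int = 1
    use_global_device_ids: bool = False


# Ops assumed to leave the data untouched, so the walk may look through them.
PASS_THROUGH_OPCODES = {"reshape", "bitcast"}


def is_candidate_all_gather(instr: ToyInstr) -> bool:
    """Multiple partitions, a single replica, and flattened (global) device IDs."""
    return (instr.opcode == "all-gather"
            and instr.num_partitions > 1
            and instr.num_replicas == 1
            and instr.use_global_device_ids)


def find_dynamic_slice_user(all_gather: ToyInstr) -> Optional[ToyInstr]:
    """Follows the single-user chain from the all-gather, skipping reshape/bitcast."""
    current = all_gather
    while len(current.users) == 1:
        user = current.users[0]
        if user.opcode == "dynamic-slice":
            return user
        if user.opcode not in PASS_THROUGH_OPCODES:
            return None
        current = user
    return None  # zero or several users: this sketch gives up


# all-gather -> reshape -> bitcast -> dynamic-slice
ds = ToyInstr("dynamic-slice")
ag = ToyInstr("all-gather",
              users=[ToyInstr("reshape", users=[ToyInstr("bitcast", users=[ds])])],
              num_partitions=4, use_global_device_ids=True)
assert is_candidate_all_gather(ag) and find_dynamic_slice_user(ag) is ds
```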
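The last two steps can be illustrated with plain dictionaries standing in for PartitionOffsetSpec (one offset-to-partition map per replica group) and a list of tuples standing in for PermutationPairs. This is a hedged sketch: offsets are element offsets along the gathered dimension, the DynamicSlice offset is modeled as an arbitrary function of the partition ID (slice_offset), and all helper names are hypothetical.

```python
from typing import Callable, Dict, List, Optional, Sequence, Tuple

# Hypothetical Python mirrors of the PR's types.
# One {offset: partition_id} map per replica group.
PartitionOffsetSpec = List[Dict[int, int]]
# (source_id, destination_id) pairs for a CollectivePermute.
PermutationPairs = List[Tuple[int, int]]


def all_gather_spec(replica_groups: Sequence[Sequence[int]],
                    shard_size: int) -> PartitionOffsetSpec:
    """Source layout: the i-th partition of a group lands at offset i * shard_size."""
    return [{i * shard_size: pid for i, pid in enumerate(group)}
            for group in replica_groups]


def dynamic_slice_spec(replica_groups: Sequence[Sequence[int]],
                       slice_offset: Callable[[int], int]) -> Optional[PartitionOffsetSpec]:
    """Destination layout: partition pid reads the slice starting at slice_offset(pid)."""
    spec: PartitionOffsetSpec = []
    for group in replica_groups:
        mapping: Dict[int, int] = {}
        for pid in group:
            offset = slice_offset(pid)
            if offset in mapping:  # two partitions read the same chunk: not a permutation
                return None
            mapping[offset] = pid
        spec.append(mapping)
    return spec


def match_specs(src: PartitionOffsetSpec,
                dst: PartitionOffsetSpec) -> Optional[PermutationPairs]:
    """Pairs the partition that wrote each offset with the partition that reads it."""
    if len(src) != len(dst):
        return None
    pairs: PermutationPairs = []
    for src_map, dst_map in zip(src, dst):
        if src_map.keys() != dst_map.keys():  # a slice is misaligned or out of range
            return None
        for offset in sorted(dst_map):
            pairs.append((src_map[offset], dst_map[offset]))
    return pairs


# Partition p reads the shard of partition (p + 1) % 4; shards hold 2 elements.
groups = [[0, 1, 2, 3]]
src = all_gather_spec(groups, shard_size=2)
dst = dynamic_slice_spec(groups, lambda pid: ((pid + 1) % 4) * 2)
print(match_specs(src, dst))  # [(0, 3), (1, 0), (2, 1), (3, 2)]
```

Rejecting duplicate or unmatched offsets keeps the result a true permutation, which a CollectivePermute needs: no two pairs may share a source or a target.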