🐛 Describe the bug
We're implementing our own tensor-parallel mode with a "compile-first" approach. We'd thus like to be able to write "inefficient" code, such as each of `wq`, `wk` and `wv` issuing its own all-gather of the activations (instead of using `PrepareModuleInput` to do it ahead of time), and we'd like torch.compile to optimize the resulting graph for us.
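For reference, here is a minimal sketch (not our actual code) of the pattern we mean, written with functional collectives so the all-gathers are traceable by torch.compile; the module and parameter names here are purely illustrative:

```python
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol


class NaiveTPProj(nn.Module):
    """Each projection all-gathers the sequence-sharded activations itself."""

    def __init__(self, dim: int, group: dist.ProcessGroup):
        super().__init__()
        self.group = group
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)

    def forward(self, x_sharded: torch.Tensor):
        # Three identical all-gathers, one per projection, instead of a single
        # PrepareModuleInput-style gather ahead of time.
        q = self.wq(funcol.all_gather_tensor(x_sharded, gather_dim=0, group=self.group))
        k = self.wk(funcol.all_gather_tensor(x_sharded, gather_dim=0, group=self.group))
        v = self.wv(funcol.all_gather_tensor(x_sharded, gather_dim=0, group=self.group))
        return q, k, v
```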
What we're observing is that indeed the three redundant all-gathers get merged into a single one. Concretely, it seems that one of these all-gathers is chosen, and the usages of the other all-gathers are modified to ingest the output of the chosen all-gather. The issue is that these other all-gathers are not removed from the graph! They stick around as dead code (no usages) and are moved to the end of the graph. This is inefficient, as they are still executed, thus wasting time.
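This is roughly how we observe it (assuming an initialized process group and the sketch module above; file name and log settings are just examples). The merged all-gather and the unused duplicates are visible in the dumped post-grad graph:

```python
# repro.py -- run on 2+ ranks, e.g.:
#   TORCH_LOGS="post_grad_graphs" torchrun --nproc-per-node=2 repro.py
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

model = NaiveTPProj(1024, dist.group.WORLD).cuda()
compiled = torch.compile(model)

# One sharded activation chunk per rank; the three all-gathers get fused into
# one, but the two now-unused all-gathers remain at the end of the graph.
out = compiled(torch.randn(128, 1024, device="cuda"))
```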
It appears that this choice is deliberate (see #131023 and #132341), motivated by supporting non-SPMD scenarios, as detailed in #130918. In that case, only some ranks used the output of a collective, and eliminating it on the other ranks would have caused a deadlock.
Our application is purely SPMD, so this design choice is problematic for us. We'd appreciate a way to change the compiler's behavior when we can inform it of that assumption.
Error logs
No response
Versions
Nightly
cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @chauhang @penguinwu @tianyu-l @XilunWu