Skip to content

Annotate some XLA:GPU flags as stable i.e. they should provide 6 month deprecation notice. #97134

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 18, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 32 additions & 16 deletions third_party/xla/xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1118,8 +1118,8 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
"xla_gpu_autotune_level",
int32_setter_for(&DebugOptions::set_xla_gpu_autotune_level),
debug_options->xla_gpu_autotune_level(),
"Set GEMM and Convolution auto-tuning level. 0 = off; 1 = on; 2 = "
"on+init; 3 = on+init+reinit; 4 = on+init+reinit+check; "
"[Stable] Set GEMM and Convolution auto-tuning level. 0 = off; 1 = on; "
"2 = on+init; 3 = on+init+reinit; 4 = on+init+reinit+check; "
"5 = on+init+reinit+check and skip WRONG_RESULT solutions. See also "
"the related flag xla_gpu_autotune_gemm_rtol. Remark that, setting the "
"level to 5 only makes sense if you are sure that the reference (first "
Expand Down Expand Up @@ -1425,7 +1425,7 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
int64_setter_for(
&DebugOptions::set_xla_gpu_all_reduce_combine_threshold_bytes),
debug_options->xla_gpu_all_reduce_combine_threshold_bytes(),
"Size threshold (in bytes) for the GPU all-reduce combiner."));
"[Stable] Size threshold (in bytes) for the GPU all-reduce combiner."));
flag_list->push_back(tsl::Flag(
"xla_gpu_all_gather_combine_threshold_bytes",
int64_setter_for(
Expand All @@ -1437,7 +1437,8 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
int64_setter_for(
&DebugOptions::set_xla_gpu_reduce_scatter_combine_threshold_bytes),
debug_options->xla_gpu_reduce_scatter_combine_threshold_bytes(),
"Size threshold (in bytes) for the GPU reduce-scatter combiner."));
"[Stable] Size threshold (in bytes) for the GPU reduce-scatter "
"combiner."));
flag_list->push_back(tsl::Flag(
"xla_gpu_collective_permute_combine_threshold_bytes",
int64_setter_for(
Expand Down Expand Up @@ -1628,7 +1629,8 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
"xla_gpu_enable_dynamic_slice_fusion",
bool_setter_for(&DebugOptions::set_xla_gpu_enable_dynamic_slice_fusion),
debug_options->xla_gpu_enable_dynamic_slice_fusion(),
"Whether to enable XLA address computation fusion"));
"[Stable] Whether to enable address computation fusion to optimize "
"dynamic-slice and dynamic-update-slice operations."));
flag_list->push_back(tsl::Flag(
"xla_gpu_nccl_termination_timeout_seconds",
int64_setter_for(
Expand Down Expand Up @@ -1720,7 +1722,7 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
bool_setter_for(
&DebugOptions::set_xla_gpu_enable_latency_hiding_scheduler),
debug_options->xla_gpu_enable_latency_hiding_scheduler(),
"Enable latency-hiding scheduler for XLA:GPU"));
"[Stable] Enable latency-hiding scheduler for XLA:GPU"));
flag_list->push_back(tsl::Flag(
"xla_gpu_enable_analytical_latency_estimator",
bool_setter_for(
Expand Down Expand Up @@ -1787,18 +1789,18 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
"xla_gpu_enable_pipelined_all_reduce",
bool_setter_for(&DebugOptions::set_xla_gpu_enable_pipelined_all_reduce),
debug_options->xla_gpu_enable_pipelined_all_reduce(),
"Enable pipelinling of all-reduce instructions."));
"[Stable] Enable pipelining of all-reduce instructions."));
flag_list->push_back(tsl::Flag(
"xla_gpu_enable_pipelined_all_gather",
bool_setter_for(&DebugOptions::set_xla_gpu_enable_pipelined_all_gather),
debug_options->xla_gpu_enable_pipelined_all_gather(),
"Enable pipelinling of all-gather instructions."));
"[Stable] Enable pipelining of all-gather instructions."));
flag_list->push_back(
tsl::Flag("xla_gpu_enable_pipelined_reduce_scatter",
bool_setter_for(
&DebugOptions::set_xla_gpu_enable_pipelined_reduce_scatter),
debug_options->xla_gpu_enable_pipelined_reduce_scatter(),
"Enable pipelinling of reduce-scatter instructions."));
"[Stable] Enable pipelining of reduce-scatter instructions."));
flag_list->push_back(tsl::Flag(
"xla_gpu_enable_pipelined_p2p",
bool_setter_for(&DebugOptions::set_xla_gpu_enable_pipelined_p2p),
Expand All @@ -1809,7 +1811,7 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
int64_setter_for(
&DebugOptions::set_xla_gpu_collective_permute_decomposer_threshold),
debug_options->xla_gpu_collective_permute_decomposer_threshold(),
"Collective permute decomposer threshold."));
"[Stable] Collective permute decomposer threshold."));
flag_list->push_back(tsl::Flag(
"xla_gpu_experimental_pipeline_parallelism_opt_level",
setter_for_xla_gpu_experimental_pipeline_parallelism_opt_level,
Expand All @@ -1826,7 +1828,7 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
tsl::Flag("xla_gpu_enable_triton_gemm",
bool_setter_for(&DebugOptions::set_xla_gpu_enable_triton_gemm),
debug_options->xla_gpu_enable_triton_gemm(),
"Use Triton-based matrix multiplication."));
"[Stable] Whether to use Triton-based matrix multiplication."));
flag_list->push_back(tsl::Flag(
"xla_gpu_unsupported_generic_triton_emitter_features",
SetterForRepeatedEnum<DebugOptions::GenericTritonEmitterFeature>(
Expand Down Expand Up @@ -1871,7 +1873,8 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
"xla_gpu_exhaustive_tiling_search",
bool_setter_for(&DebugOptions::set_xla_gpu_exhaustive_tiling_search),
debug_options->xla_gpu_exhaustive_tiling_search(),
"Enable (slow) search for the Triton GEMM fusion tilings."));
"[Stable] Search for Triton GEMM tilings exhaustively during autotuning. "
"This increases the compile time."));
flag_list->push_back(tsl::Flag(
"xla_gpu_experimental_enable_subchannel_dequantisation_fusion",
bool_setter_for(
Expand Down Expand Up @@ -1978,8 +1981,8 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
tsl::Flag("xla_gpu_cublas_fallback",
bool_setter_for(&DebugOptions::set_xla_gpu_cublas_fallback),
debug_options->xla_gpu_cublas_fallback(),
"Allow GEMM fusion autotuning to fall back to cuBLAS when that "
"is faster."));
"[Stable] Whether to allow GEMM fusion autotuning to fall back "
"to cuBLAS when it is faster than Triton."));
flag_list->push_back(tsl::Flag(
"xla_gpu_cudnn_gemm_fusion_level",
int32_setter_for(&DebugOptions::set_xla_gpu_cudnn_gemm_fusion_level),
Expand All @@ -1996,7 +1999,7 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
bool_setter_for(
&DebugOptions::set_xla_gpu_enable_while_loop_double_buffering),
debug_options->xla_gpu_enable_while_loop_double_buffering(),
"Enable double buffering for while loop"));
"[Stable] Enable double buffering for while loop"));
flag_list->push_back(tsl::Flag(
"xla_gpu_filter_kernels_spilling_registers_on_autotuning",
bool_setter_for(
Expand Down Expand Up @@ -2287,7 +2290,7 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
"xla_gpu_dot_merger_threshold_mb",
int32_setter_for(&DebugOptions::set_xla_gpu_dot_merger_threshold_mb),
debug_options->xla_gpu_dot_merger_threshold_mb(),
"Dot merger pass threshold to be set in MB."));
"[Stable] Dot merger pass threshold to be set in MB."));
flag_list->push_back(
tsl::Flag("xla_enable_fast_math",
bool_setter_for(&DebugOptions::set_xla_enable_fast_math),
Expand Down Expand Up @@ -2548,6 +2551,19 @@ FlagStatus GetFlagStatus(absl::string_view flag_name) {
static const absl::NoDestructor<absl::flat_hash_set<std::string>>
kStableFlags({
// go/keep-sorted start
"xla_gpu_all_reduce_combine_threshold_bytes",
"xla_gpu_autotune_level",
"xla_gpu_collective_permute_decomposer_threshold",
"xla_gpu_cublas_fallback",
"xla_gpu_dot_merger_threshold_mb",
"xla_gpu_enable_dynamic_slice_fusion",
"xla_gpu_enable_latency_hiding_scheduler",
"xla_gpu_enable_pipelined_all_gather",
"xla_gpu_enable_pipelined_all_reduce",
"xla_gpu_enable_pipelined_reduce_scatter",
"xla_gpu_enable_triton_gemm",
"xla_gpu_enable_while_loop_double_buffering",
"xla_gpu_exhaustive_tiling_search",
"xla_gpu_reduce_scatter_combine_threshold_bytes",
// go/keep-sorted end
});
static const absl::NoDestructor<absl::flat_hash_set<std::string>>
Expand Down