[determinism] Add softmax/cross-entropy op exceptions for GPU determinism #47925
Conversation
@sanjoy: It would be so cool for this one to get into TensorFlow version 2.5 as well. :-)
tensorflow/core/kernels/xent_op.cc
```c++
(!RequireDeterminism() ||
 DisableSoftmaxXentWithLogitsOpDeterminismExceptions()),
```
Unnecessary parens.
```c++
@@ -58,6 +86,17 @@ class SoftmaxXentWithLogitsOp : public OpKernel {
                    "2-dimensional, or broadcasted to be "
                    "2-dimensional"));

    if (std::is_same<Device, GPUDevice>::value) {
```
Is the CPU implementation deterministic?
We just confirmed that the CPU implementation is deterministic (thanks @wenscarl). I'm considering adding tests to confirm that in a future PR.
@sanjoy, will it be possible to get this merged before the version 2.5 branch is cut, which I believe will be on March 25 (tomorrow)?
I just approved it, but I can't guarantee that it will make it through the merge process before the branch cut.
My spam filter was acting up and I missed your comment, @sanjoy, sorry. I also missed the internal checks failure. This PR didn't make it into 2.5. Will you let me know what the internal checks failure was?
Looks like some internal tooling failure, I'll merge it manually. Unfortunately this won't make 2.5, as you noted.
High-Level Summary
This PR adds and tests the following functionality:
When the environment variable `TF_DETERMINISTIC_OPS` is set to `"true"` or `"1"`, an attempt to run the following ops on a GPU will throw `tf.errors.UnimplementedError` (with an understandable message):

- `tf.nn.softmax_cross_entropy_with_logits`
- `tf.nn.sparse_softmax_cross_entropy_with_logits`
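
For illustration, here is a minimal sketch of the behavior described above (not taken from the PR or its tests), assuming TensorFlow 2.x with a GPU visible:

```python
import os

# TF_DETERMINISTIC_OPS is read when ops execute, but setting it before
# importing TensorFlow is the safest convention.
os.environ["TF_DETERMINISTIC_OPS"] = "1"

import tensorflow as tf

logits = tf.random.normal([4, 10])
labels = tf.one_hot([1, 3, 5, 7], depth=10)

try:
    with tf.device("/GPU:0"):
        loss = tf.nn.softmax_cross_entropy_with_logits(
            labels=labels, logits=logits)
except tf.errors.UnimplementedError as e:
    # Expected (after this PR): a clear message explaining that a
    # deterministic GPU implementation is not available.
    print(e)
```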
Please see RFC: Enhancing determinism in TF (being added via tensorflow/community PR 346).
Additional Notes
Data Types
The exceptions will be thrown for all currently GPU-supported data types for the `logits` input: `tf.float16` and `tf.float32` for both ops, and, additionally, `tf.float64` for `tf.nn.softmax_cross_entropy_with_logits`.

Exception-throwing for all combinations of relevant data types for `logits` and `labels` (`tf.int32` and `tf.int64`) is tested in both eager and graph mode when the op is used in the forward direction.
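
As a hedged illustration (not the PR's actual test code), a check over the dtype combinations for the sparse op might look like this, assuming `TF_DETERMINISTIC_OPS` is already set and a GPU is available:

```python
import itertools
import tensorflow as tf

# Relevant dtype combinations for the sparse op; the exception should be
# raised on GPU for every one of them when TF_DETERMINISTIC_OPS is set.
for logits_dtype, labels_dtype in itertools.product(
        [tf.float16, tf.float32], [tf.int32, tf.int64]):
    logits = tf.random.normal([4, 10], dtype=logits_dtype)
    labels = tf.constant([1, 3, 5, 7], dtype=labels_dtype)
    try:
        with tf.device("/GPU:0"):
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=labels, logits=logits)
        raise AssertionError("expected tf.errors.UnimplementedError")
    except tf.errors.UnimplementedError:
        pass  # expected
```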
Forward vs Backward

It is currently suspected that the random noise observed in the gradients passed backwards from this op actually originates in the forward path algorithm, though the backward path algorithm might add additional noise. However, the backprop path for this op is not, and cannot be, used without the forward path algorithm also being used (due to this being a loss function). Therefore, exception-throwing on the backward path specifically is not necessary and is not implemented or tested by this PR.
When these ops have a fully deterministic mode of operation, the bit-exact reproducibility of the outputs of both the forward and backward paths of the ops should be verified.
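
A minimal sketch of how such verification might look (an assumption, not code from this PR) could run the op twice and compare the forward and backward outputs bit-exactly:

```python
import tensorflow as tf

def forward_and_grad(logits, labels):
    # Compute the loss and its gradient with respect to the logits.
    with tf.GradientTape() as tape:
        tape.watch(logits)
        loss = tf.nn.softmax_cross_entropy_with_logits(
            labels=labels, logits=logits)
    return loss, tape.gradient(loss, logits)

logits = tf.random.normal([4, 10])
labels = tf.one_hot([1, 3, 5, 7], depth=10)

loss_a, grad_a = forward_and_grad(logits, labels)
loss_b, grad_b = forward_and_grad(logits, labels)

# Bit-exact: every element must match exactly, not just approximately.
assert bool(tf.reduce_all(tf.equal(loss_a, loss_b)))
assert bool(tf.reduce_all(tf.equal(grad_a, grad_b)))
```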
XLA
The tests will not be run with XLA auto-jit enabled because any XLA implementation of these ops will not throw these exceptions.
When a fully deterministic mode for these ops is implemented, the bit-exact reproducibility of the outputs of both the forward and backward paths of the ops should be verified both with and without XLA auto-jit enabled.
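
To illustrate why (a hedged sketch, not from the PR, assuming TF 2.5+ where `tf.function` accepts `jit_compile`): when the op is compiled with XLA, an XLA-generated kernel replaces the TensorFlow GPU kernel containing the new exception-throwing code.

```python
import tensorflow as tf

# With jit_compile=True, XLA lowers the op to its own fused kernel, so the
# UnimplementedError added to the TF GPU kernel by this PR is never reached.
@tf.function(jit_compile=True)
def xla_loss(labels, logits):
    return tf.nn.softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
```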