Open
Labels
TF 2.19, comp:data (tf.data related issues), comp:keras (Keras related issues), stale, stat:awaiting response, type:performance
Description
Issue type
Performance
Have you reproduced the bug with TensorFlow Nightly?
Yes
Source
binary
TensorFlow version
tf 2.19.0
Custom code
No
OS platform and distribution
RHEL 9.4
Python version
3.11
CUDA/cuDNN version
12.5
Current behavior?
MemcpyH2D (the host-to-device copy) does not overlap with model computation when tf.data.experimental.prefetch_to_device is used inside the dataset function passed to tf.distribute.MirroredStrategy.distribute_datasets_from_function. I would expect these operations to overlap.
Standalone code to reproduce the issue
import tensorflow as tf


class Model(tf.keras.Model):
    def call(self, x):
        y = x / 1000
        for i in range(3):
            y = tf.matmul(y, x / 1000)
        return tf.reduce_sum(y, axis=[1, 2])


def get_dataset(ictx):
    ds = tf.data.Dataset.range(1, 1001, output_type=tf.float32)
    ds = ds.map(lambda i: (tf.ones((1024 * 5, 1024 * 5)) / i, 0.0))
    ds = ds.batch(8)
    ds = ds.apply(tf.data.experimental.prefetch_to_device('gpu'))
    return ds


strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    ds = strategy.distribute_datasets_from_function(get_dataset)
    model = Model()
    model.compile(loss='mse')
    model.fit(
        ds,
        epochs=1,
        steps_per_epoch=30,
        callbacks=tf.keras.callbacks.TensorBoard(profile_batch=(15, 25)))
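As a sanity check alongside the reproduction above, the following is a minimal sketch that confirms the batches produced by prefetch_to_device actually land on the target device when no distribution strategy is involved. The helper names make_dataset and first_batch_device, the reduced tensor size, and the CPU fallback are assumptions made for illustration and are not part of the original report. It only verifies placement; whether MemcpyH2D overlaps with compute still has to be read from the profiler trace.

```python
import tensorflow as tf


def make_dataset(device, side=64, batch_size=2, num_elements=8):
    """Builds a scaled-down version of the issue's pipeline, prefetched to `device`."""
    ds = tf.data.Dataset.range(1, num_elements + 1, output_type=tf.float32)
    ds = ds.map(lambda i: (tf.ones((side, side)) / i, 0.0))
    ds = ds.batch(batch_size)
    # prefetch_to_device has to be the final transformation in the pipeline.
    return ds.apply(tf.data.experimental.prefetch_to_device(device))


def first_batch_device(ds):
    """Returns the device string of the first batch's feature tensor."""
    x, _ = next(iter(ds))
    return x.device


# Assumption: fall back to the CPU so the sketch also runs without a GPU.
target = "/GPU:0" if tf.config.list_physical_devices("GPU") else "/CPU:0"
placed_on = first_batch_device(make_dataset(target))
print(placed_on)
```

If the printed device matches the target here but the copy still does not overlap under MirroredStrategy, that points at how the strategy consumes the dataset rather than at prefetch_to_device itself.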