Closed
Labels
TF 2.19, comp:keras (Keras related issues), stale (to be closed automatically if no activity), stat:awaiting response (awaiting response from author), type:bug
Description
Issue type
Bug
Have you reproduced the bug with TensorFlow Nightly?
No
Source
source
TensorFlow version
2.19
Custom code
Yes
OS platform and distribution
No response
Mobile device
No response
Python version
No response
Bazel version
No response
GCC/compiler version
No response
CUDA/cuDNN version
No response
GPU model and memory
P100
Current behavior?
I'm building a transformer encoder layer and trying to add positional encoding, but I always hit a graph execution error when using tfm.nlp.layers.MultiHeadRelativeAttention. The error below is raised every time, and it may be related to how batches are handled inside the layer. I know the layer is experimental, but I have tried several workarounds and the error persists.
Standalone code to reproduce the issue
import tensorflow as tf
import tensorflow_models as tfm


class TransformerBlock(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.key_dim = d_model // num_heads
        # Attention layer
        self.att = tfm.nlp.layers.MultiHeadRelativeAttention(
            num_heads=num_heads,
            key_dim=self.key_dim
        )
        # Trainable bias parameters with correct shape
        self.content_bias = self.add_weight(
            name='content_bias',
            shape=[1, 1, num_heads, self.key_dim],  # [1, 1, H, Dk]
            initializer='zeros',
            trainable=True
        )
        self.position_bias = self.add_weight(
            name='position_bias',
            shape=[1, 1, num_heads, self.key_dim],  # [1, 1, H, Dk]
            initializer='zeros',
            trainable=True
        )
        # Rest of the network
        self.ffn = tf.keras.Sequential([
            tf.keras.layers.Dense(dff, activation='relu'),
            tf.keras.layers.Dense(d_model)
        ])
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)
        self.rel_pos_encode = tfm.nlp.layers.RelativePositionEmbedding(d_model)

    def call(self, x, training=False, mask=None):
        batch_size = tf.shape(x)[0]
        seq_len = tf.shape(x)[1]
        H = self.num_heads
        Dk = self.key_dim
        # 1) Prepare biases with correct shape [B, L, H, Dk]
        content_attention_bias = tf.tile(self.content_bias, [batch_size, seq_len, 1, 1])
        positional_attention_bias = tf.tile(self.position_bias, [batch_size, seq_len, 1, 1])
        # 2) Generate relative position encoding [B*H, 2*L-1, Dk]
        rel_len = 2 * seq_len - 1
        rel_embedding = self.rel_pos_encode(inputs=None, length=rel_len)
        rel_embedding = tf.reshape(rel_embedding, [rel_len, H, Dk])
        rel_embedding = tf.transpose(rel_embedding, [1, 0, 2])  # [H, 2*L-1, Dk]
        rel_embedding = tf.tile(rel_embedding, [batch_size, 1, 1])  # [B*H, 2*L-1, Dk]
        # 3) Call attention with properly shaped biases
        attn_output = self.att(
            query=x,
            value=x,
            content_attention_bias=content_attention_bias,
            positional_attention_bias=positional_attention_bias,
            relative_position_encoding=rel_embedding,
            attention_mask=mask
        )
        # 4) Standard transformer operations
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(x + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)
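For comparison, here is a minimal call sketch with differently shaped arguments. The [num_heads, key_dim] bias shape and the un-tiled relative position encoding with a singleton batch dimension are assumptions on my part (based on how the companion tfm.nlp.layers.TransformerXLBlock appears to drive this layer), not a confirmed fix:

# Hypothetical alternative call; the shapes below are assumptions, not the documented contract.
import tensorflow as tf
import tensorflow_models as tfm

batch_size, seq_len, d_model, num_heads = 8, 16, 64, 4
key_dim = d_model // num_heads

att = tfm.nlp.layers.MultiHeadRelativeAttention(num_heads=num_heads, key_dim=key_dim)
rel_pos_encode = tfm.nlp.layers.RelativePositionEmbedding(d_model)

x = tf.random.normal([batch_size, seq_len, d_model])  # [B, L, d_model]

# Biases kept at [H, Dk] so they can broadcast against the per-head query projection (assumption).
content_bias = tf.zeros([num_heads, key_dim])
position_bias = tf.zeros([num_heads, key_dim])

# Relative position encoding left in model width and shared across the batch (assumption).
rel_embedding = rel_pos_encode(inputs=None, length=2 * seq_len - 1)  # [2L-1, d_model]
rel_embedding = rel_embedding[tf.newaxis, ...]                       # [1, 2L-1, d_model]

out = att(
    query=x,
    value=x,
    content_attention_bias=content_bias,
    positional_attention_bias=position_bias,
    relative_position_encoding=rel_embedding,
)
print(out.shape)  # would expect [8, 16, 64] if the assumed shapes are accepted

The only real difference from the reproduction code is that nothing is tiled by batch_size, so every tensor keeps a single, consistent batch dimension.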
Relevant log output
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
/tmp/ipykernel_35/1955170888.py in <cell line: 0>()
----> 1 history = model.fit(
2 train,
3 validation_data=val,
4 epochs=50,
5 )
/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
120 # To get the full stack trace, call:
121 # `keras.config.disable_traceback_filtering()`
--> 122 raise e.with_traceback(filtered_tb) from None
123 finally:
124 del filtered_tb
/usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
57 e.message += " name: " + name
58 raise core._status_to_exception(e) from None
---> 59 except TypeError as e:
60 keras_symbolic_tensors = [x for x in inputs if _is_keras_symbolic_tensor(x)]
61 if keras_symbolic_tensors:
InvalidArgumentError: Graph execution error:
Detected at node gradient_tape/improved_transformer_23_1/transformer_block_45_1/multi_head_relative_attention_44/add_2/BroadcastGradientArgs defined at (most recent call last):
<stack traces unavailable>
Incompatible shapes: [8,4,512,512] vs. [32,4,512,512]
Stack trace for op definition:
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/usr/local/lib/python3.11/dist-packages/colab_kernel_launcher.py", line 37, in <module>
File "/usr/local/lib/python3.11/dist-packages/traitlets/config/application.py", line 992, in launch_instance
File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelapp.py", line 712, in start
File "/usr/local/lib/python3.11/dist-packages/tornado/platform/asyncio.py", line 205, in start
File "/usr/lib/python3.11/asyncio/base_events.py", line 608, in run_forever
File "/usr/lib/python3.11/asyncio/base_events.py", line 1936, in _run_once
File "/usr/lib/python3.11/asyncio/events.py", line 84, in _run
File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue
File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py", line 499, in process_one
File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell
File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py", line 730, in execute_request
File "/usr/local/lib/python3.11/dist-packages/ipykernel/ipkernel.py", line 383, in do_execute
File "/usr/local/lib/python3.11/dist-packages/ipykernel/zmqshell.py", line 528, in run_cell
File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 2975, in run_cell
File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3030, in _run_cell
File "/usr/local/lib/python3.11/dist-packages/IPython/core/async_helpers.py", line 78, in _pseudo_sync_runner
File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3257, in run_cell_async
File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3473, in run_ast_nodes
File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
File "/tmp/ipykernel_35/1955170888.py", line 1, in <cell line: 0>
File "/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File "/usr/local/lib/python3.11/dist-packages/keras/src/backend/tensorflow/trainer.py", line 371, in fit
File "/usr/local/lib/python3.11/dist-packages/keras/src/backend/tensorflow/trainer.py", line 219, in function
File "/usr/local/lib/python3.11/dist-packages/keras/src/backend/tensorflow/trainer.py", line 132, in multi_step_on_iterator
File "/usr/local/lib/python3.11/dist-packages/keras/src/backend/tensorflow/trainer.py", line 113, in one_step_on_data
File "/usr/local/lib/python3.11/dist-packages/keras/src/backend/tensorflow/trainer.py", line 77, in train_step
[[{{node gradient_tape/improved_transformer_23_1/transformer_block_45_1/multi_head_relative_attention_44/add_2/BroadcastGradientArgs}}]]
tf2xla conversion failed while converting __inference_one_step_on_data_322878[]. Run with TF_DUMP_GRAPH_PREFIX=/path/to/dump/dir and --vmodule=xla_compiler=2 to obtain a dump of the compiled functions.
[[StatefulPartitionedCall]] [Op:__inference_multi_step_on_iterator_323180]
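One observation on the shapes in the error: the mismatched leading dimension looks like batch_size * num_heads (8 * 4 = 32), which is the same leading dimension the tiled [H, 2*L-1, Dk] relative embedding ends up with in the reproduction code:

# Reading the dimensions off the error message [8,4,512,512] vs. [32,4,512,512]:
batch_size, num_heads, seq_len = 8, 4, 512
print([batch_size, num_heads, seq_len, seq_len])              # [8, 4, 512, 512]
print([batch_size * num_heads, num_heads, seq_len, seq_len])  # [32, 4, 512, 512]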