graph execution error bug with tfm.nlp.layers.MultiHeadRelativeAttention #94599

@Thorballer

Description

Issue type

Bug

Have you reproduced the bug with TensorFlow Nightly?

No

Source

source

TensorFlow version

2.19

Custom code

Yes

OS platform and distribution

No response

Mobile device

No response

Python version

No response

Bazel version

No response

GCC/compiler version

No response

CUDA/cuDNN version

No response

GPU model and memory

P100

Current behavior?

I'm building a transformer encoder layer and trying to add positional encoding, but I always hit a graph execution error when using tfm.nlp.layers.MultiHeadRelativeAttention. The error posted below appears every time I use the layer, and it may be related to how the layer processes batches. I know the layer is marked experimental; I have tried many ways of working around the error, but it persists.

Standalone code to reproduce the issue

import tensorflow as tf
import tensorflow_models as tfm


class TransformerBlock(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.key_dim = d_model // num_heads

        # Attention layer
        self.att = tfm.nlp.layers.MultiHeadRelativeAttention(
            num_heads=num_heads,
            key_dim=self.key_dim
        )

        # Trainable bias parameters with correct shape
        self.content_bias = self.add_weight(
            name='content_bias',
            shape=[1, 1, num_heads, self.key_dim],  # [1, 1, H, Dk]
            initializer='zeros',
            trainable=True
        )
        self.position_bias = self.add_weight(
            name='position_bias',
            shape=[1, 1, num_heads, self.key_dim],  # [1, 1, H, Dk]
            initializer='zeros',
            trainable=True
        )

        # Rest of the network
        self.ffn = tf.keras.Sequential([
            tf.keras.layers.Dense(dff, activation='relu'),
            tf.keras.layers.Dense(d_model)
        ])
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)
        self.rel_pos_encode = tfm.nlp.layers.RelativePositionEmbedding(d_model)

    def call(self, x, training=False, mask=None):
        batch_size = tf.shape(x)[0]
        seq_len = tf.shape(x)[1]
        H = self.num_heads
        Dk = self.key_dim

        # 1) Prepare biases with correct shape [B, L, H, Dk]
        content_attention_bias = tf.tile(self.content_bias, [batch_size, seq_len, 1, 1])
        positional_attention_bias = tf.tile(self.position_bias, [batch_size, seq_len, 1, 1])

        # 2) Generate relative position encoding [B*H, 2*L-1, Dk]
        rel_len = 2 * seq_len - 1
        rel_embedding = self.rel_pos_encode(inputs=None, length=rel_len)
        rel_embedding = tf.reshape(rel_embedding, [rel_len, H, Dk])
        rel_embedding = tf.transpose(rel_embedding, [1, 0, 2])  # [H, 2*L-1, Dk]
        rel_embedding = tf.tile(rel_embedding, [batch_size, 1, 1])  # [B*H, 2*L-1, Dk]

        # 3) Call attention with properly shaped biases
        attn_output = self.att(
            query=x,
            value=x,
            content_attention_bias=content_attention_bias,
            positional_attention_bias=positional_attention_bias,
            relative_position_encoding=rel_embedding,
            attention_mask=mask
        )

        # 4) Standard transformer operations
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(x + attn_output)
        
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)
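
One way to sanity-check step 2 of call() above is to trace the shapes with plain NumPy, outside of TensorFlow. This is only a sketch: B=8 and H=4 are read off the shapes in the error log, while the sequence length and d_model are made-up small values. The last two lines show an alternative layout with the batch size as the only leading axis; whether that is the layout the layer expects is an assumption to verify against the installed Model Garden version.

```python
import numpy as np

# B and H are taken from the shapes in the error log; L and d_model are
# illustrative values chosen to keep the arrays tiny.
B, H, L, d_model = 8, 4, 6, 16
Dk = d_model // H
rel_len = 2 * L - 1

# Stand-in for the RelativePositionEmbedding output, shape [rel_len, d_model].
rel = np.zeros((rel_len, d_model), dtype=np.float32)

# Same reshape / transpose / tile sequence as step 2 of call().
per_head = rel.reshape(rel_len, H, Dk).transpose(1, 0, 2)  # [H, rel_len, Dk]
tiled = np.tile(per_head, (B, 1, 1))                       # [B*H, rel_len, Dk]
print(tiled.shape)  # (32, 11, 4)

# Alternative (assumption): batch as the only leading axis, [B, rel_len, d_model].
batched = np.tile(rel[None, :, :], (B, 1, 1))
print(batched.shape)  # (8, 11, 16)
```

With these sizes the tiled encoding has a leading dimension of B*H = 32, while the query keeps a leading dimension of B = 8.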

Relevant log output

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
/tmp/ipykernel_35/1955170888.py in <cell line: 0>()
----> 1 history = model.fit(
      2     train,
      3     validation_data=val,
      4     epochs=50,
      5 )

/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
    120             # To get the full stack trace, call:
    121             # `keras.config.disable_traceback_filtering()`
--> 122             raise e.with_traceback(filtered_tb) from None
    123         finally:
    124             del filtered_tb

/usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     57       e.message += " name: " + name
     58     raise core._status_to_exception(e) from None
---> 59   except TypeError as e:
     60     keras_symbolic_tensors = [x for x in inputs if _is_keras_symbolic_tensor(x)]
     61     if keras_symbolic_tensors:

InvalidArgumentError: Graph execution error:

Detected at node gradient_tape/improved_transformer_23_1/transformer_block_45_1/multi_head_relative_attention_44/add_2/BroadcastGradientArgs defined at (most recent call last):
<stack traces unavailable>
Incompatible shapes: [8,4,512,512] vs. [32,4,512,512]

Stack trace for op definition: 
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/usr/local/lib/python3.11/dist-packages/colab_kernel_launcher.py", line 37, in <module>
File "/usr/local/lib/python3.11/dist-packages/traitlets/config/application.py", line 992, in launch_instance
File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelapp.py", line 712, in start
File "/usr/local/lib/python3.11/dist-packages/tornado/platform/asyncio.py", line 205, in start
File "/usr/lib/python3.11/asyncio/base_events.py", line 608, in run_forever
File "/usr/lib/python3.11/asyncio/base_events.py", line 1936, in _run_once
File "/usr/lib/python3.11/asyncio/events.py", line 84, in _run
File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue
File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py", line 499, in process_one
File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell
File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py", line 730, in execute_request
File "/usr/local/lib/python3.11/dist-packages/ipykernel/ipkernel.py", line 383, in do_execute
File "/usr/local/lib/python3.11/dist-packages/ipykernel/zmqshell.py", line 528, in run_cell
File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 2975, in run_cell
File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3030, in _run_cell
File "/usr/local/lib/python3.11/dist-packages/IPython/core/async_helpers.py", line 78, in _pseudo_sync_runner
File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3257, in run_cell_async
File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3473, in run_ast_nodes
File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
File "/tmp/ipykernel_35/1955170888.py", line 1, in <cell line: 0>
File "/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File "/usr/local/lib/python3.11/dist-packages/keras/src/backend/tensorflow/trainer.py", line 371, in fit
File "/usr/local/lib/python3.11/dist-packages/keras/src/backend/tensorflow/trainer.py", line 219, in function
File "/usr/local/lib/python3.11/dist-packages/keras/src/backend/tensorflow/trainer.py", line 132, in multi_step_on_iterator
File "/usr/local/lib/python3.11/dist-packages/keras/src/backend/tensorflow/trainer.py", line 113, in one_step_on_data
File "/usr/local/lib/python3.11/dist-packages/keras/src/backend/tensorflow/trainer.py", line 77, in train_step

	 [[{{node gradient_tape/improved_transformer_23_1/transformer_block_45_1/multi_head_relative_attention_44/add_2/BroadcastGradientArgs}}]]
	tf2xla conversion failed while converting __inference_one_step_on_data_322878[]. Run with TF_DUMP_GRAPH_PREFIX=/path/to/dump/dir and --vmodule=xla_compiler=2 to obtain a dump of the compiled functions.
	 [[StatefulPartitionedCall]] [Op:__inference_multi_step_on_iterator_323180]
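
For reading the "Incompatible shapes" line, here is a minimal pure-Python sketch of the usual broadcasting rule (axes are compared from the right and must be equal or 1). The helper name broadcast_shape is hypothetical and is not part of TensorFlow; it only mirrors the kind of check that fails in the log.

```python
def broadcast_shape(a, b):
    """Return the broadcast of two shapes, or raise ValueError if incompatible."""
    result = []
    # Walk both shapes from the trailing axis, padding the shorter one with 1s.
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and da != 1 and db != 1:
            raise ValueError(f"Incompatible shapes: {list(a)} vs. {list(b)}")
        result.append(db if da == 1 else da)
    return tuple(reversed(result))


# The two shapes reported in the log.
logits_a = (8, 4, 512, 512)
logits_b = (32, 4, 512, 512)

try:
    broadcast_shape(logits_a, logits_b)
    compatible = True
except ValueError:
    compatible = False

# Ratio of the two leading axes, compared with the size of the second axis.
ratio = logits_b[0] // logits_a[0]
print(compatible, ratio, logits_a[1])
```

The leading axes differ by a factor of 4, which equals the size of the second axis in both shapes. That is consistent with one operand carrying the batch size and the other carrying batch size times number of heads in its leading axis, though the log alone does not prove it.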

Labels

TF 2.19, comp:keras, stale, stat:awaiting response, type:bug
