Description
System information
OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
Ubuntu 20.04
TensorFlow version and how it was installed (source or binary):
tensorflow 2.16
TensorFlow-Recommenders-Addons version and how it was installed (source or binary):
tfrs v0.7.3
Python version:
3.10
Is GPU used? (yes/no):
yes
Describe the bug
I am using tfrs with TF 2.16 and get the following exception when model.fit(...) is called:
Exception encountered when calling Retrieval.call().
Can not convert a NoneType into a Tensor or Operation.
I have localised the problem: the exception is raised only after I add a batch metric to the Retrieval task.
The same code worked fine on TF 2.14, but on TF 2.16 I have to remove the batch metric from the task to get training to run.
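A hedged guess at the mechanism, not confirmed anywhere in this issue: TF 2.16 ships Keras 3 as tf.keras, and as far as I can tell Keras 3 metric update_state() methods return None, whereas Keras 2 returned an update op. If the Retrieval task collects the return values of its batch metrics' update_state() calls and passes them to tf.control_dependencies inside the traced train step, a single None in that list would produce exactly this error. The plain-Python sketch below only models that assumption; control_dependencies_check, Keras2StyleMetric, Keras3StyleMetric and retrieval_call are hypothetical stand-ins, not TensorFlow or TFRS APIs.

```python
from typing import Any, List, Sequence


def control_dependencies_check(control_inputs: Sequence[Any]) -> None:
    # Hypothetical stand-in for graph-mode tf.control_dependencies:
    # every control input must be a tensor or op, so None is rejected.
    for c in control_inputs:
        if c is None:
            raise TypeError("Can not convert a NoneType into a Tensor or Operation.")


class Keras2StyleMetric:
    # Hypothetical metric whose update_state returns an update "op" (Keras 2 style).
    def update_state(self, y_true: Any, y_pred: Any) -> str:
        return "update_op"


class Keras3StyleMetric:
    # Hypothetical metric whose update_state returns None (assumed Keras 3 style).
    def update_state(self, y_true: Any, y_pred: Any) -> None:
        return None


def retrieval_call(batch_metrics: List[Any], labels: Any = None, scores: Any = None) -> None:
    # Models the assumed pattern inside Retrieval.call: collect the return
    # values of update_state() and make the loss depend on all of them.
    update_ops = [m.update_state(labels, scores) for m in batch_metrics]
    control_dependencies_check(update_ops)


retrieval_call([])                     # no batch metrics: nothing to depend on
retrieval_call([Keras2StyleMetric()])  # an op is returned: accepted
try:
    retrieval_call([Keras3StyleMetric()])
except TypeError as err:
    print(err)  # Can not convert a NoneType into a Tensor or Operation.
```

If this model is right, it would also explain why standard metrics fail the same way and why TF 2.14 (Keras 2) works. A commonly suggested direction, untested here, is to run TFRS on Keras 2 by installing the tf-keras package and setting TF_USE_LEGACY_KERAS=1 before TensorFlow is imported.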
import tensorflow as tf
import tensorflow_recommenders as tfrs


class LogitsAccuracy(tf.keras.metrics.Accuracy):
    """Custom metric for a diagonal y_true and the matrix of query-candidate scores."""

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Predicted candidate for each query: the column with the highest score.
        new_preds = tf.argmax(y_pred, axis=-1)
        # new_trues = tf.range(tf.linalg.trace(y_true))
        # The positive candidate for query i is candidate i (diagonal labels).
        batch_size = tf.shape(y_pred)[0]
        new_trues = tf.range(batch_size)
        # Explicit casts to ensure correct dtypes
        new_preds = tf.cast(new_preds, tf.int32)
        new_trues = tf.cast(new_trues, tf.int32)
        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, tf.float32)
        return super().update_state(new_trues, new_preds, sample_weight=sample_weight)
# In the model's __init__; uncommenting batch_metrics raises the exception on TF 2.16.
self.retrieval_task: tf.keras.layers.Layer = tfrs.tasks.Retrieval(
    # batch_metrics=[LogitsAccuracy(name='accuracy')]
)
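For sanity-checking metric values on a tiny hand-made batch, the quantity LogitsAccuracy is meant to track can be written without TensorFlow: row i of the score matrix holds query i against every candidate in the batch, the positive is assumed to be candidate i, and accuracy is the (optionally weighted) fraction of rows whose highest score lands on the diagonal. The function name in_batch_accuracy is hypothetical; this is a plain-Python illustration, not the Keras or TFRS implementation.

```python
from typing import Optional, Sequence


def in_batch_accuracy(scores: Sequence[Sequence[float]],
                      sample_weight: Optional[Sequence[float]] = None) -> float:
    """Weighted fraction of queries whose top-scoring candidate is the diagonal one."""
    if sample_weight is None:
        sample_weight = [1.0] * len(scores)
    correct = 0.0
    total = 0.0
    for i, (row, w) in enumerate(zip(scores, sample_weight)):
        # argmax over the row; here ties resolve to the lowest index.
        pred = max(range(len(row)), key=lambda j: row[j])
        correct += w * (pred == i)
        total += w
    # Mirror a weighted mean: return 0.0 when all weights are zero.
    return correct / total if total else 0.0


scores = [[0.9, 0.1, 0.0],
          [0.2, 0.7, 0.1],
          [0.5, 0.3, 0.2]]  # the third query peaks on candidate 0, not 2
print(in_batch_accuracy(scores))                   # 2 of 3 rows correct
print(in_batch_accuracy(scores, [1.0, 1.0, 0.0]))  # 1.0 once the miss is weighted out
```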
How can I fix this? I have also tried other standard metrics as batch metrics, and they all raise the same exception.