Clarifying global floating point policy #357
Description
The issue of float precision affects many computations in tensorflow_ranking, such as:

ranking/tensorflow_ranking/python/metrics_impl.py, lines 603 to 628 in a928e2b:
  def _compute_impl(self, labels, predictions, weights, mask):
    """See `_RankingMetric`."""
    topn = tf.shape(predictions)[1] if self._topn is None else self._topn
    # Relevance = 1.0 when labels >= 1.0.
    relevance = tf.cast(tf.greater_equal(labels, 1.0), dtype=tf.float32)
    sorted_relevance, sorted_weights = utils.sort_by_scores(
        predictions, [relevance, weights], topn=topn, mask=mask)
    per_list_relevant_counts = tf.cumsum(sorted_relevance, axis=1)
    per_list_cutoffs = tf.cumsum(tf.ones_like(sorted_relevance), axis=1)
    per_list_precisions = tf.math.divide_no_nan(per_list_relevant_counts,
                                                per_list_cutoffs)
    total_precision = tf.reduce_sum(
        input_tensor=per_list_precisions * sorted_weights * sorted_relevance,
        axis=1,
        keepdims=True)
    # Compute the total relevance regardless of self._topn.
    total_relevance = tf.reduce_sum(
        input_tensor=weights * relevance, axis=1, keepdims=True)
    per_list_map = tf.math.divide_no_nan(total_precision, total_relevance)
    # per_list_weights are computed from the whole list to avoid the problem of
    # 0 when there is no relevant example in topn.
    per_list_weights = _per_example_weights_to_per_list_weights(
        weights, relevance)
    return per_list_map, per_list_weights
This has been mentioned before in #254, but I want to elaborate on our difficulties. Hardcoded dtypes like this make it extremely hard to move our programs to float64. For example, if we call tf.keras.backend.set_floatx('float64') anywhere, we get errors inside tensorflow_ranking due to conflicting dtypes.
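To make the conflict concrete, here is a minimal sketch (not taken from our codebase) of the kind of failure we hit: under set_floatx('float64') our tensors are float64, while the hardcoded tf.float32 cast shown above produces float32, so the later multiplication with the weights fails.

```python
import tensorflow as tf

tf.keras.backend.set_floatx('float64')

# Labels and weights created under the float64 policy.
labels = tf.constant([[1.0, 0.0, 2.0]], dtype=tf.keras.backend.floatx())  # float64
weights = tf.ones_like(labels)                                            # float64

# The hardcoded cast inside the metric produces float32 ...
relevance = tf.cast(tf.greater_equal(labels, 1.0), dtype=tf.float32)

# ... so mixing it with the float64 weights raises a dtype-mismatch error
# (roughly: "expected to be a double tensor but is a float tensor").
total_relevance = tf.reduce_sum(weights * relevance, axis=1, keepdims=True)
```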
Will the global floating point policy (tf.keras.mixed_precision.set_global_policy and tf.keras.backend.floatx) be supported? If the official stance is to ignore the global policy, can that stance be documented?
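For reference, a hypothetical adjustment (not the actual TF Ranking API) illustrating what honoring the policy could look like: derive the cast dtype from tf.keras.backend.floatx(), or from the inputs themselves, instead of hardcoding tf.float32.

```python
import tensorflow as tf

def compute_relevance(labels):
  """Hypothetical variant of the cast above that follows the global policy."""
  # Option 1: use the Keras floatx policy (float64 after set_floatx('float64')).
  policy_dtype = tf.keras.backend.floatx()
  # Option 2: simply follow the dtype of the incoming labels.
  # policy_dtype = labels.dtype
  return tf.cast(tf.greater_equal(labels, 1.0), dtype=policy_dtype)

tf.keras.backend.set_floatx('float64')
labels = tf.constant([[1.0, 0.0, 2.0]], dtype=tf.keras.backend.floatx())
relevance = compute_relevance(labels)  # float64, consistent with the policy
```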