diff --git a/tensorflow_addons/metrics/cohens_kappa.py b/tensorflow_addons/metrics/cohens_kappa.py index 3f9340e5f7..2b14c20310 100644 --- a/tensorflow_addons/metrics/cohens_kappa.py +++ b/tensorflow_addons/metrics/cohens_kappa.py @@ -93,7 +93,7 @@ def __init__(self, 'conf_mtx', shape=(self.num_classes, self.num_classes), initializer=tf.keras.initializers.zeros, - dtype=tf.int64) + dtype=tf.float32) def update_state(self, y_true, y_pred, sample_weight=None): """Accumulates the confusion matrix condition statistics. @@ -125,21 +125,21 @@ def update_state(self, y_true, y_pred, sample_weight=None): predictions=y_pred, num_classes=self.num_classes, weights=sample_weight, - dtype=tf.int64) + dtype=tf.float32) # update the values in the original confusion matrix return self.conf_mtx.assign_add(new_conf_mtx) def result(self): nb_ratings = tf.shape(self.conf_mtx)[0] - weight_mtx = tf.ones([nb_ratings, nb_ratings], dtype=tf.int64) + weight_mtx = tf.ones([nb_ratings, nb_ratings], dtype=tf.float32) # 2. Create a weight matrix if self.weightage is None: - diagonal = tf.zeros([nb_ratings], dtype=tf.int64) + diagonal = tf.zeros([nb_ratings], dtype=tf.float32) weight_mtx = tf.linalg.set_diag(weight_mtx, diagonal=diagonal) else: - weight_mtx += tf.cast(tf.range(nb_ratings), dtype=tf.int64) + weight_mtx += tf.cast(tf.range(nb_ratings), dtype=tf.float32) weight_mtx = tf.cast(weight_mtx, dtype=self.dtype) if self.weightage == 'linear': @@ -167,8 +167,10 @@ def result(self): # 6. Calculate Kappa score numerator = tf.reduce_sum(conf_mtx * weight_mtx) denominator = tf.reduce_sum(out_prod * weight_mtx) - kp = 1 - (numerator / denominator) - return kp + return tf.cond( + tf.math.is_nan(denominator), + true_fn=lambda: 0.0, + false_fn=lambda: 1 - (numerator / denominator)) def get_config(self): """Returns the serializable config of the metric.""" diff --git a/tensorflow_addons/metrics/cohens_kappa_test.py b/tensorflow_addons/metrics/cohens_kappa_test.py index efb33e4fc0..40b0c9a0a1 100644 --- a/tensorflow_addons/metrics/cohens_kappa_test.py +++ b/tensorflow_addons/metrics/cohens_kappa_test.py @@ -56,6 +56,11 @@ def update_obj_states(self, obj1, obj2, obj3, actuals, preds, weights): self.evaluate(update_op2) self.evaluate(update_op3) + def reset_obj_states(self, obj1, obj2, obj3): + obj1.reset_states() + obj2.reset_states() + obj3.reset_states() + def check_results(self, objs, values): obj1, obj2, obj3 = objs val1, val2, val3 = values @@ -130,6 +135,16 @@ def test_kappa_with_sample_weights(self): self.check_results([kp_obj1, kp_obj2, kp_obj3], [-0.25473321, -0.38992332, -0.60695344]) + def test_kappa_reset_states(self): + # Initialize + kp_obj1, kp_obj2, kp_obj3 = self.initialize_vars() + + # reset states + self.reset_obj_states(kp_obj1, kp_obj2, kp_obj3) + + # check results + self.check_results([kp_obj1, kp_obj2, kp_obj3], [0.0, 0.0, 0.0]) + def test_large_values(self): y_true = [1] * 10000 + [0] * 20000 + [1] * 20000 y_pred = [0] * 20000 + [1] * 30000