diff --git a/tensorflow_addons/metrics/cohens_kappa.py b/tensorflow_addons/metrics/cohens_kappa.py index 40a7680922..9ea4e8f6bc 100644 --- a/tensorflow_addons/metrics/cohens_kappa.py +++ b/tensorflow_addons/metrics/cohens_kappa.py @@ -38,117 +38,117 @@ class CohenKappa(Metric): while calculating the Cohen's Kappa score. Usage: + ```python actuals = np.array([4, 4, 3, 4, 2, 4, 1, 1], dtype=np.int32) preds = np.array([4, 4, 3, 4, 4, 2, 1, 1], dtype=np.int32) weights = np.array([1, 1, 2, 5, 10, 2, 3, 3], dtype=np.int32) - + m = tfa.metrics.CohenKappa(num_classes=5) m.update_state(actuals, preds) print('Final result: ', m.result().numpy()) # Result: 0.61904764 - + # To use this with weights, sample_weight argument can be used. m = tfa.metrics.CohenKappa(num_classes=5) m.update_state(actuals, preds, sample_weight=weights) print('Final result: ', m.result().numpy()) # Result: 0.37209308 ``` + Usage with tf.keras API: + ```python - model = keras.models.Model(inputs, outputs) + model = tf.keras.models.Model(inputs, outputs) model.add_metric(tfa.metrics.CohenKappa(num_classes=5)(outputs)) model.compile('sgd', loss='mse') ``` - - Args: - num_classes : Number of unique classes in your dataset - weightage : Weighting to be considered for calculating - kappa statistics. A valid value is one of - [None, 'linear', 'quadratic']. Defaults to None. - - Returns: - kappa_score : float - The kappa statistic, which is a number between -1 and 1. The maximum - value means complete agreement; zero or lower means chance agreement. - - Raises: - ValueError: If the value passed for `weightage` is invalid - i.e. not any one of [None, 'linear', 'quadratic'] """ def __init__(self, num_classes, name='cohen_kappa', weightage=None, - dtype=tf.float32): + dtype=None): + """Creates a `CohenKappa` instance. + + Args: + num_classes: Number of unique classes in your dataset. + name: (Optional) String name of the metric instance. + weightage: (Optional) Weighting to be considered for calculating + kappa statistics. A valid value is one of + [None, 'linear', 'quadratic']. Defaults to `None`. + dtype: (Optional) Data type of the metric result. + Defaults to `None`. + + Raises: + ValueError: If the value passed for `weightage` is invalid + i.e. not any one of [None, 'linear', 'quadratic'] + """ super(CohenKappa, self).__init__(name=name, dtype=dtype) if weightage not in (None, 'linear', 'quadratic'): raise ValueError("Unknown kappa weighting type.") - else: - self.weightage = weightage + self.weightage = weightage self.num_classes = num_classes self.conf_mtx = self.add_weight( 'conf_mtx', shape=(self.num_classes, self.num_classes), initializer=tf.keras.initializers.zeros, - dtype=tf.int32) + dtype=tf.int64) def update_state(self, y_true, y_pred, sample_weight=None): """Accumulates the confusion matrix condition statistics. Args: - y_true : array, shape = [n_samples] - Labels assigned by the first annotator. - y_pred : array, shape = [n_samples] - Labels assigned by the second annotator. The kappa statistic - is symmetric, so swapping ``y_true`` and ``y_pred`` doesn't - change the value. - sample_weight(optional) : for weighting labels in confusion matrix - Default is None. The dtype for weights should be the same - as the dtype for confusion matrix. For more details, - please check tf.math.confusion_matrix. - + y_true: Labels assigned by the first annotator with shape + `[num_samples,]`. + y_pred: Labels assigned by the second annotator with shape + `[num_samples,]`. The kappa statistic is symmetric, + so swapping `y_true` and `y_pred` doesn't change the value. + sample_weight (optional): for weighting labels in confusion matrix + Defaults to `None`. The dtype for weights should be the same + as the dtype for confusion matrix. For more details, + please check `tf.math.confusion_matrix`. Returns: Update op. """ - y_true = tf.cast(y_true, dtype=tf.int32) - y_pred = tf.cast(y_pred, dtype=tf.int32) + y_true = tf.cast(y_true, dtype=tf.int64) + y_pred = tf.cast(y_pred, dtype=tf.int64) if y_true.shape != y_pred.shape: raise ValueError( - "Number of samples in y_true and y_pred are different") + "Number of samples in `y_true` and `y_pred` are different") # compute the new values of the confusion matrix new_conf_mtx = tf.math.confusion_matrix( labels=y_true, predictions=y_pred, num_classes=self.num_classes, - weights=sample_weight) + weights=sample_weight, + dtype=tf.int64) # update the values in the original confusion matrix return self.conf_mtx.assign_add(new_conf_mtx) def result(self): nb_ratings = tf.shape(self.conf_mtx)[0] - weight_mtx = tf.ones([nb_ratings, nb_ratings], dtype=tf.int32) + weight_mtx = tf.ones([nb_ratings, nb_ratings], dtype=tf.int64) # 2. Create a weight matrix if self.weightage is None: - diagonal = tf.zeros([nb_ratings], dtype=tf.int32) + diagonal = tf.zeros([nb_ratings], dtype=tf.int64) weight_mtx = tf.linalg.set_diag(weight_mtx, diagonal=diagonal) - weight_mtx = tf.cast(weight_mtx, dtype=tf.float32) - else: - weight_mtx += tf.range(nb_ratings, dtype=tf.int32) - weight_mtx = tf.cast(weight_mtx, dtype=tf.float32) + weight_mtx += tf.cast(tf.range(nb_ratings), dtype=tf.int64) + weight_mtx = tf.cast(weight_mtx, dtype=self.dtype) if self.weightage == 'linear': weight_mtx = tf.abs(weight_mtx - tf.transpose(weight_mtx)) else: weight_mtx = tf.pow((weight_mtx - tf.transpose(weight_mtx)), 2) - weight_mtx = tf.cast(weight_mtx, dtype=tf.float32) + + weight_mtx = tf.cast(weight_mtx, dtype=self.dtype) # 3. Get counts actual_ratings_hist = tf.reduce_sum(self.conf_mtx, axis=1) @@ -162,8 +162,8 @@ def result(self): conf_mtx = self.conf_mtx / tf.reduce_sum(self.conf_mtx) out_prod = out_prod / tf.reduce_sum(out_prod) - conf_mtx = tf.cast(conf_mtx, dtype=tf.float32) - out_prod = tf.cast(out_prod, dtype=tf.float32) + conf_mtx = tf.cast(conf_mtx, dtype=self.dtype) + out_prod = tf.cast(out_prod, dtype=self.dtype) # 6. Calculate Kappa score numerator = tf.reduce_sum(conf_mtx * weight_mtx) @@ -186,4 +186,6 @@ def reset_states(self): for v in self.variables: K.set_value( - v, np.zeros((self.num_classes, self.num_classes), np.int32)) + v, + np.zeros((self.num_classes, self.num_classes), + v.dtype.as_numpy_dtype)) diff --git a/tensorflow_addons/metrics/cohens_kappa_test.py b/tensorflow_addons/metrics/cohens_kappa_test.py index 4793bd5c7a..efb33e4fc0 100644 --- a/tensorflow_addons/metrics/cohens_kappa_test.py +++ b/tensorflow_addons/metrics/cohens_kappa_test.py @@ -130,6 +130,16 @@ def test_kappa_with_sample_weights(self): self.check_results([kp_obj1, kp_obj2, kp_obj3], [-0.25473321, -0.38992332, -0.60695344]) + def test_large_values(self): + y_true = [1] * 10000 + [0] * 20000 + [1] * 20000 + y_pred = [0] * 20000 + [1] * 30000 + + obj = CohenKappa(num_classes=2) + self.evaluate(tf.compat.v1.variables_initializer(obj.variables)) + + self.evaluate(obj.update_state(y_true, y_pred)) + self.assertAllClose(0.166666666, obj.result()) + if __name__ == '__main__': tf.test.main()