diff --git a/tensorflow_addons/metrics/cohens_kappa.py b/tensorflow_addons/metrics/cohens_kappa.py index a29d13efd1..a60c0b0553 100644 --- a/tensorflow_addons/metrics/cohens_kappa.py +++ b/tensorflow_addons/metrics/cohens_kappa.py @@ -67,18 +67,26 @@ def __init__(self, num_classes: FloatTensorLike, name: str = 'cohen_kappa', weightage: Optional[str] = None, + sparse_labels: bool = False, + regression: bool = False, dtype: AcceptableDTypes = None, **kwargs): """Creates a `CohenKappa` instance. Args: num_classes: Number of unique classes in your dataset. - name: (Optional) String name of the metric instance. - weightage: (Optional) Weighting to be considered for calculating + weightage: (optional) Weighting to be considered for calculating kappa statistics. A valid value is one of - [None, 'linear', 'quadratic']. Defaults to `None`. - dtype: (Optional) Data type of the metric result. - Defaults to `None`. + [None, 'linear', 'quadratic']. Defaults to `None` + sparse_labels: (bool) Valid only for multi-class scenario. + If True, ground truth labels are expected to be integers + and not one-hot encoded + regression: (bool) If set, that means the problem is being treated + as a regression problem where you are regressing the predictions. + **Note:** If you are regressing for the values, the output layer + should contain a single unit. + name: (optional) String name of the metric instance + dtype: (optional) Data type of the metric result. 
Defaults to `None` Raises: ValueError: If the value passed for `weightage` is invalid @@ -89,8 +97,18 @@ def __init__(self, if weightage not in (None, 'linear', 'quadratic'): raise ValueError("Unknown kappa weighting type.") + if num_classes == 2: + self._update = self._update_binary_class_model + elif num_classes > 2: + self._update = self._update_multi_class_model + else: + raise ValueError("""Number of classes must be + greater than or euqal to two""") + self.weightage = weightage self.num_classes = num_classes + self.regression = regression + self.sparse_labels = sparse_labels self.conf_mtx = self.add_weight( 'conf_mtx', shape=(self.num_classes, self.num_classes), @@ -114,14 +132,35 @@ def update_state(self, y_true, y_pred, sample_weight=None): Returns: Update op. """ + return self._update(y_true, y_pred, sample_weight) + + def _update_binary_class_model(self, y_true, y_pred, sample_weight=None): y_true = tf.cast(y_true, dtype=tf.int64) - y_pred = tf.cast(y_pred, dtype=tf.int64) + y_pred = tf.cast(y_pred, dtype=tf.float32) + y_pred = tf.cast(y_pred > 0.5, dtype=tf.int64) + return self._update_confusion_matrix(y_true, y_pred, sample_weight) + + def _update_multi_class_model(self, y_true, y_pred, sample_weight=None): + if not self.sparse_labels: + y_true = tf.cast(tf.argmax(y_true, axis=-1), dtype=tf.int64) + else: + y_true = tf.cast(y_true, dtype=tf.int64) + + if tf.rank(y_pred) > 1: + if not self.regression: + y_pred = tf.cast(tf.argmax(y_pred, axis=-1), dtype=tf.int64) + else: + y_pred = tf.math.round(tf.math.abs(y_pred)) + y_pred = tf.cast(y_pred, dtype=tf.int64) + else: + y_pred = tf.cast(y_pred, dtype=tf.int64) + + return self._update_confusion_matrix(y_true, y_pred, sample_weight) - if y_true.shape != y_pred.shape: - raise ValueError( - "Number of samples in `y_true` and `y_pred` are different") + def _update_confusion_matrix(self, y_true, y_pred, sample_weight): + y_true = tf.squeeze(y_true) + y_pred = tf.squeeze(y_pred) - # compute the new values of the 
confusion matrix new_conf_mtx = tf.math.confusion_matrix( labels=y_true, predictions=y_pred, @@ -129,7 +168,6 @@ def update_state(self, y_true, y_pred, sample_weight=None): weights=sample_weight, dtype=tf.float32) - # update the values in the original confusion matrix return self.conf_mtx.assign_add(new_conf_mtx) def result(self): @@ -179,6 +217,8 @@ def get_config(self): config = { "num_classes": self.num_classes, "weightage": self.weightage, + "sparse_labels": self.sparse_labels, + "regression": self.regression } base_config = super().get_config() return {**base_config, **config} diff --git a/tensorflow_addons/metrics/cohens_kappa_test.py b/tensorflow_addons/metrics/cohens_kappa_test.py index e3c7227f46..2c6b4b2281 100644 --- a/tensorflow_addons/metrics/cohens_kappa_test.py +++ b/tensorflow_addons/metrics/cohens_kappa_test.py @@ -14,6 +14,7 @@ # ============================================================================== """Tests for Cohen's Kappa Metric.""" +import numpy as np import tensorflow as tf from tensorflow_addons.metrics import CohenKappa from tensorflow_addons.utils import test_utils @@ -34,9 +35,9 @@ def test_config(self): self.assertEqual(kp_obj.num_classes, 5) def initialize_vars(self): - kp_obj1 = CohenKappa(num_classes=5) - kp_obj2 = CohenKappa(num_classes=5, weightage="linear") - kp_obj3 = CohenKappa(num_classes=5, weightage="quadratic") + kp_obj1 = CohenKappa(num_classes=5, sparse_labels=True) + kp_obj2 = CohenKappa(num_classes=5, sparse_labels=True, weightage="linear") + kp_obj3 = CohenKappa(num_classes=5, sparse_labels=True, weightage="quadratic") self.evaluate(tf.compat.v1.variables_initializer(kp_obj1.variables)) self.evaluate(tf.compat.v1.variables_initializer(kp_obj2.variables)) @@ -147,12 +148,85 @@ def test_large_values(self): y_true = [1] * 10000 + [0] * 20000 + [1] * 20000 y_pred = [0] * 20000 + [1] * 30000 + y_true = tf.convert_to_tensor(y_true) + y_pred = tf.convert_to_tensor(y_pred) + obj = CohenKappa(num_classes=2) 
self.evaluate(tf.compat.v1.variables_initializer(obj.variables)) self.evaluate(obj.update_state(y_true, y_pred)) self.assertAllClose(0.166666666, obj.result()) + def test_with_sparse_labels(self): + y_true = np.array([4, 4, 3, 4], dtype=np.int32) + y_pred = np.array([4, 4, 1, 2], dtype=np.int32) + + obj = CohenKappa(num_classes=5, sparse_labels=True) + self.evaluate(tf.compat.v1.variables_initializer(obj.variables)) + + self.evaluate(obj.update_state(y_true, y_pred)) + self.assertAllClose(0.19999999, obj.result()) + + def test_with_ohe_labels(self): + y_true = np.array([4, 4, 3, 4], dtype=np.int32) + y_true = tf.keras.utils.to_categorical(y_true, num_classes=5) + y_pred = np.array([4, 4, 1, 2], dtype=np.int32) + + obj = CohenKappa(num_classes=5, sparse_labels=False) + self.evaluate(tf.compat.v1.variables_initializer(obj.variables)) + + self.evaluate(obj.update_state(y_true, y_pred)) + self.assertAllClose(0.19999999, obj.result()) + + def test_keras_binary_reg_model(self): + kp = CohenKappa(num_classes=2) + inputs = tf.keras.layers.Input(shape=(10,)) + outputs = tf.keras.layers.Dense(1)(inputs) + model = tf.keras.models.Model(inputs, outputs) + model.compile(optimizer="sgd", loss="mse", metrics=[kp]) + + x = np.random.rand(1000, 10).astype(np.float32) + y = np.random.randint(2, size=(1000, 1)).astype(np.float32) + + model.fit(x, y, epochs=1, verbose=0, batch_size=32) + + def test_keras_multiclass_reg_model(self): + kp = CohenKappa(num_classes=5, regression=True, sparse_labels=True) + inputs = tf.keras.layers.Input(shape=(10,)) + outputs = tf.keras.layers.Dense(1)(inputs) + model = tf.keras.models.Model(inputs, outputs) + model.compile(optimizer="sgd", loss="mse", metrics=[kp]) + + x = np.random.rand(1000, 10).astype(np.float32) + y = np.random.randint(5, size=(1000,)).astype(np.float32) + + model.fit(x, y, epochs=1, verbose=0, batch_size=32) + + def test_keras_binary_clasasification_model(self): + kp = CohenKappa(num_classes=2) + inputs = 
tf.keras.layers.Input(shape=(10,)) + outputs = tf.keras.layers.Dense(1, activation="sigmoid")(inputs) + model = tf.keras.models.Model(inputs, outputs) + model.compile(optimizer="sgd", loss="binary_crossentropy", metrics=[kp]) + + x = np.random.rand(1000, 10).astype(np.float32) + y = np.random.randint(2, size=(1000, 1)).astype(np.float32) + + model.fit(x, y, epochs=1, verbose=0, batch_size=32) + + def test_keras_multiclass_classification_model(self): + kp = CohenKappa(num_classes=5) + inputs = tf.keras.layers.Input(shape=(10,)) + outputs = tf.keras.layers.Dense(5, activation="softmax")(inputs) + model = tf.keras.models.Model(inputs, outputs) + model.compile(optimizer="sgd", loss="categorical_crossentropy", metrics=[kp]) + + x = np.random.rand(1000, 10).astype(np.float32) + y = np.random.randint(5, size=(1000,)).astype(np.float32) + y = tf.keras.utils.to_categorical(y, num_classes=5) + + model.fit(x, y, epochs=1, verbose=0, batch_size=32) + if __name__ == "__main__": tf.test.main()