Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions tensorflow_addons/metrics/cohens_kappa.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(self,
'conf_mtx',
shape=(self.num_classes, self.num_classes),
initializer=tf.keras.initializers.zeros,
dtype=tf.int64)
dtype=tf.float32)

def update_state(self, y_true, y_pred, sample_weight=None):
"""Accumulates the confusion matrix condition statistics.
Expand Down Expand Up @@ -125,21 +125,21 @@ def update_state(self, y_true, y_pred, sample_weight=None):
predictions=y_pred,
num_classes=self.num_classes,
weights=sample_weight,
dtype=tf.int64)
dtype=tf.float32)

# update the values in the original confusion matrix
return self.conf_mtx.assign_add(new_conf_mtx)

def result(self):
nb_ratings = tf.shape(self.conf_mtx)[0]
weight_mtx = tf.ones([nb_ratings, nb_ratings], dtype=tf.int64)
weight_mtx = tf.ones([nb_ratings, nb_ratings], dtype=tf.float32)

# 2. Create a weight matrix
if self.weightage is None:
diagonal = tf.zeros([nb_ratings], dtype=tf.int64)
diagonal = tf.zeros([nb_ratings], dtype=tf.float32)
weight_mtx = tf.linalg.set_diag(weight_mtx, diagonal=diagonal)
else:
weight_mtx += tf.cast(tf.range(nb_ratings), dtype=tf.int64)
weight_mtx += tf.cast(tf.range(nb_ratings), dtype=tf.float32)
weight_mtx = tf.cast(weight_mtx, dtype=self.dtype)

if self.weightage == 'linear':
Expand Down Expand Up @@ -167,8 +167,10 @@ def result(self):
# 6. Calculate Kappa score
numerator = tf.reduce_sum(conf_mtx * weight_mtx)
denominator = tf.reduce_sum(out_prod * weight_mtx)
kp = 1 - (numerator / denominator)
return kp
return tf.cond(
tf.math.is_nan(denominator),
true_fn=lambda: 0.0,
false_fn=lambda: 1 - (numerator / denominator))

def get_config(self):
"""Returns the serializable config of the metric."""
Expand Down
15 changes: 15 additions & 0 deletions tensorflow_addons/metrics/cohens_kappa_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ def update_obj_states(self, obj1, obj2, obj3, actuals, preds, weights):
self.evaluate(update_op2)
self.evaluate(update_op3)

def reset_obj_states(self, obj1, obj2, obj3):
obj1.reset_states()
obj2.reset_states()
obj3.reset_states()

def check_results(self, objs, values):
obj1, obj2, obj3 = objs
val1, val2, val3 = values
Expand Down Expand Up @@ -130,6 +135,16 @@ def test_kappa_with_sample_weights(self):
self.check_results([kp_obj1, kp_obj2, kp_obj3],
[-0.25473321, -0.38992332, -0.60695344])

def test_kappa_reset_states(self):
# Initialize
kp_obj1, kp_obj2, kp_obj3 = self.initialize_vars()

# reset states
self.reset_obj_states(kp_obj1, kp_obj2, kp_obj3)

# check results
self.check_results([kp_obj1, kp_obj2, kp_obj3], [0.0, 0.0, 0.0])

def test_large_values(self):
y_true = [1] * 10000 + [0] * 20000 + [1] * 20000
y_pred = [0] * 20000 + [1] * 30000
Expand Down