Skip to content

Commit 92b99f1

Browse files
committed
reverting code formatting changes
1 parent ad89d27 commit 92b99f1

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tensorflow_addons/losses/kappa_loss.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ def __init__(
111111
self.weight_mat = (col_mat - row_mat) ** 2
112112

113113
def call(self, y_true, y_pred):
114-
y_true = tf.cast(y_true, self.col_label_vec.dtype)
115-
y_pred = tf.cast(y_pred, self.weight_mat.dtype)
114+
y_true = tf.cast(y_true, dtype=self.col_label_vec.dtype)
115+
y_pred = tf.cast(y_pred, dtype=self.weight_mat.dtype)
116116
batch_size = tf.shape(y_true)[0]
117117
cat_labels = tf.matmul(y_true, self.col_label_vec)
118118
cat_label_mat = tf.tile(cat_labels, [1, self.num_classes])
@@ -126,7 +126,7 @@ def call(self, y_true, y_pred):
126126
pred_dist = tf.reduce_sum(y_pred, axis=0, keepdims=True)
127127
w_pred_dist = tf.matmul(self.weight_mat, pred_dist, transpose_b=True)
128128
denominator = tf.reduce_sum(tf.matmul(label_dist, w_pred_dist))
129-
denominator /= tf.cast(batch_size, denominator.dtype)
129+
denominator /= tf.cast(batch_size, dtype=denominator.dtype)
130130
loss = tf.math.divide_no_nan(numerator, denominator)
131131
return tf.math.log(loss + self.epsilon)
132132

0 commit comments

Comments
 (0)