@@ -111,8 +111,8 @@ def __init__(
111111 self .weight_mat = (col_mat - row_mat ) ** 2
112112
113113 def call (self , y_true , y_pred ):
114- y_true = tf .cast (y_true , self .col_label_vec .dtype )
115- y_pred = tf .cast (y_pred , self .weight_mat .dtype )
114+ y_true = tf .cast (y_true , dtype = self .col_label_vec .dtype )
115+ y_pred = tf .cast (y_pred , dtype = self .weight_mat .dtype )
116116 batch_size = tf .shape (y_true )[0 ]
117117 cat_labels = tf .matmul (y_true , self .col_label_vec )
118118 cat_label_mat = tf .tile (cat_labels , [1 , self .num_classes ])
@@ -126,7 +126,7 @@ def call(self, y_true, y_pred):
126126 pred_dist = tf .reduce_sum (y_pred , axis = 0 , keepdims = True )
127127 w_pred_dist = tf .matmul (self .weight_mat , pred_dist , transpose_b = True )
128128 denominator = tf .reduce_sum (tf .matmul (label_dist , w_pred_dist ))
129- denominator /= tf .cast (batch_size , denominator .dtype )
129+ denominator /= tf .cast (batch_size , dtype = denominator .dtype )
130130 loss = tf .math .divide_no_nan (numerator , denominator )
131131 return tf .math .log (loss + self .epsilon )
132132
0 commit comments