diff --git a/tensorflow_addons/text/crf.py b/tensorflow_addons/text/crf.py index b20b0476ae..f94160b78d 100644 --- a/tensorflow_addons/text/crf.py +++ b/tensorflow_addons/text/crf.py @@ -38,6 +38,8 @@ def crf_sequence_score(inputs, tag_indices, sequence_lengths, Returns: sequence_scores: A [batch_size] vector of unnormalized sequence scores. """ + tag_indices = tf.cast(tag_indices, dtype=tf.int32) + sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32) # If max_seq_len is 1, we skip the score calculation and simply gather the # unary potentials of the single tag. @@ -92,19 +94,19 @@ def crf_multitag_sequence_score(inputs, tag_bitmap, sequence_lengths, Returns: sequence_scores: A [batch_size] vector of unnormalized sequence scores. """ + tag_bitmap = tf.cast(tag_bitmap, dtype=tf.bool) + sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32) + filtered_inputs = tf.where(tag_bitmap, inputs, + tf.fill(tf.shape(inputs), float("-inf"))) # If max_seq_len is 1, we skip the score calculation and simply gather the # unary potentials of all active tags. def _single_seq_fn(): - filtered_inputs = tf.where(tag_bitmap, inputs, - tf.fill(tf.shape(inputs), float("-inf"))) return tf.reduce_logsumexp( filtered_inputs, axis=[1, 2], keepdims=False) def _multi_seq_fn(): # Compute the logsumexp of all scores of sequences matching the given tags. - filtered_inputs = tf.where(tag_bitmap, inputs, - tf.fill(tf.shape(inputs), float("-inf"))) return crf_log_norm( inputs=filtered_inputs, sequence_lengths=sequence_lengths, @@ -127,6 +129,7 @@ def crf_log_norm(inputs, sequence_lengths, transition_params): Returns: log_norm: A [batch_size] vector of normalizers for a CRF. """ + sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32) # Split up the first and rest of the inputs in preparation for the forward # algorithm. first_input = tf.slice(inputs, [0, 0, 0], [-1, 1, -1]) @@ -183,10 +186,12 @@ def crf_log_likelihood(inputs, transition_params: A [num_tags, num_tags] transition matrix. This is either provided by the caller or created in this function. """ - # Get shape information. num_tags = inputs.shape[2] - # Get the transition matrix if not provided. + # cast type to handle different types + tag_indices = tf.cast(tag_indices, dtype=tf.int32) + sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32) + if transition_params is None: initializer = tf.keras.initializers.GlorotUniform() transition_params = tf.Variable( @@ -211,6 +216,9 @@ def crf_unary_score(tag_indices, sequence_lengths, inputs): Returns: unary_scores: A [batch_size] vector of unary scores. """ + tag_indices = tf.cast(tag_indices, dtype=tf.int32) + sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32) + batch_size = tf.shape(inputs)[0] max_seq_len = tf.shape(inputs)[1] num_tags = tf.shape(inputs)[2] @@ -245,7 +253,9 @@ def crf_binary_score(tag_indices, sequence_lengths, transition_params): Returns: binary_scores: A [batch_size] vector of binary scores. """ - # Get shape information. + tag_indices = tf.cast(tag_indices, dtype=tf.int32) + sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32) + num_tags = tf.shape(transition_params)[0] num_transitions = tf.shape(tag_indices)[1] - 1 @@ -288,6 +298,7 @@ def crf_forward(inputs, state, transition_params, sequence_lengths): new_alphas: A [batch_size, num_tags] matrix containing the new alpha values. """ + sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32) sequence_lengths = tf.maximum( tf.constant(0, dtype=sequence_lengths.dtype), sequence_lengths - 2) @@ -399,6 +410,7 @@ def crf_decode_forward(inputs, state, transition_params, sequence_lengths): backpointers: A [batch_size, num_tags] matrix of backpointers. new_state: A [batch_size, num_tags] matrix of new score values. """ + sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32) mask = tf.sequence_mask(sequence_lengths, tf.shape(inputs)[1]) crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params) crf_fwd_layer = tf.keras.layers.RNN( @@ -446,6 +458,7 @@ def crf_decode(potentials, transition_params, sequence_length): Contains the highest scoring tag indices. best_score: A [batch_size] vector, containing the score of `decode_tags`. """ + sequence_length = tf.cast(sequence_length, dtype=tf.int32) # If max_seq_len is 1, we skip the algorithm and simply return the argmax tag # and the max activation. diff --git a/tensorflow_addons/text/crf_test.py b/tensorflow_addons/text/crf_test.py index 1c76d0b0ec..0848042022 100644 --- a/tensorflow_addons/text/crf_test.py +++ b/tensorflow_addons/text/crf_test.py @@ -209,8 +209,11 @@ def testCrfLogLikelihood(self): transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) sequence_lengths = np.array(3, dtype=np.int32) + # TODO: https://github.com/PyCQA/pylint/issues/3139 + # pylint: disable=E1136 num_words = inputs.shape[0] num_tags = inputs.shape[1] + # pylint: enable=E1136 all_sequence_log_likelihoods = [] # Make sure all probabilities sum to 1. @@ -241,8 +244,11 @@ def testViterbiDecode(self): transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) sequence_lengths = np.array(3, dtype=np.int32) + # TODO: https://github.com/PyCQA/pylint/issues/3139 + # pylint: disable=E1136 num_words = inputs.shape[0] num_tags = inputs.shape[1] + # pylint: enable=E1136 all_sequence_scores = [] all_sequences = [] @@ -347,6 +353,16 @@ def testCrfDecodeZeroSeqLength(self): self.assertEqual(len(tf_tags.shape), 2) self.assertEqual(len(tf_scores.shape), 1) + def testDifferentDtype(self): + inputs = np.ones([16, 20, 5], dtype=np.float32) + tags = tf.convert_to_tensor(np.ones([16, 20], dtype=np.int64)) + seq_lens = np.ones([ + 16, + ], dtype=np.int64) * 20 + + loss, _ = text.crf_log_likelihood( + inputs=inputs, tag_indices=tags, sequence_lengths=seq_lens) + if __name__ == "__main__": tf.test.main()