@@ -38,6 +38,8 @@ def crf_sequence_score(inputs, tag_indices, sequence_lengths,
3838 Returns:
3939 sequence_scores: A [batch_size] vector of unnormalized sequence scores.
4040 """
41+ tag_indices = tf .cast (tag_indices , dtype = tf .int32 )
42+ sequence_lengths = tf .cast (sequence_lengths , dtype = tf .int32 )
4143
4244 # If max_seq_len is 1, we skip the score calculation and simply gather the
4345 # unary potentials of the single tag.
@@ -92,19 +94,19 @@ def crf_multitag_sequence_score(inputs, tag_bitmap, sequence_lengths,
9294 Returns:
9395 sequence_scores: A [batch_size] vector of unnormalized sequence scores.
9496 """
97+ tag_bitmap = tf .cast (tag_bitmap , dtype = tf .bool )
98+ sequence_lengths = tf .cast (sequence_lengths , dtype = tf .int32 )
99+ filtered_inputs = tf .where (tag_bitmap , inputs ,
100+ tf .fill (tf .shape (inputs ), float ("-inf" )))
95101
96102 # If max_seq_len is 1, we skip the score calculation and simply gather the
97103 # unary potentials of all active tags.
98104 def _single_seq_fn ():
99- filtered_inputs = tf .where (tag_bitmap , inputs ,
100- tf .fill (tf .shape (inputs ), float ("-inf" )))
101105 return tf .reduce_logsumexp (
102106 filtered_inputs , axis = [1 , 2 ], keepdims = False )
103107
104108 def _multi_seq_fn ():
105109 # Compute the logsumexp of all scores of sequences matching the given tags.
106- filtered_inputs = tf .where (tag_bitmap , inputs ,
107- tf .fill (tf .shape (inputs ), float ("-inf" )))
108110 return crf_log_norm (
109111 inputs = filtered_inputs ,
110112 sequence_lengths = sequence_lengths ,
@@ -127,6 +129,7 @@ def crf_log_norm(inputs, sequence_lengths, transition_params):
127129 Returns:
128130 log_norm: A [batch_size] vector of normalizers for a CRF.
129131 """
132+ sequence_lengths = tf .cast (sequence_lengths , dtype = tf .int32 )
130133 # Split up the first and rest of the inputs in preparation for the forward
131134 # algorithm.
132135 first_input = tf .slice (inputs , [0 , 0 , 0 ], [- 1 , 1 , - 1 ])
@@ -183,10 +186,12 @@ def crf_log_likelihood(inputs,
183186 transition_params: A [num_tags, num_tags] transition matrix. This is
184187 either provided by the caller or created in this function.
185188 """
186- # Get shape information.
187189 num_tags = inputs .shape [2 ]
188190
189- # Get the transition matrix if not provided.
191+ # cast type to handle different types
192+ tag_indices = tf .cast (tag_indices , dtype = tf .int32 )
193+ sequence_lengths = tf .cast (sequence_lengths , dtype = tf .int32 )
194+
190195 if transition_params is None :
191196 initializer = tf .keras .initializers .GlorotUniform ()
192197 transition_params = tf .Variable (
@@ -211,6 +216,9 @@ def crf_unary_score(tag_indices, sequence_lengths, inputs):
211216 Returns:
212217 unary_scores: A [batch_size] vector of unary scores.
213218 """
219+ tag_indices = tf .cast (tag_indices , dtype = tf .int32 )
220+ sequence_lengths = tf .cast (sequence_lengths , dtype = tf .int32 )
221+
214222 batch_size = tf .shape (inputs )[0 ]
215223 max_seq_len = tf .shape (inputs )[1 ]
216224 num_tags = tf .shape (inputs )[2 ]
@@ -245,7 +253,9 @@ def crf_binary_score(tag_indices, sequence_lengths, transition_params):
245253 Returns:
246254 binary_scores: A [batch_size] vector of binary scores.
247255 """
248- # Get shape information.
256+ tag_indices = tf .cast (tag_indices , dtype = tf .int32 )
257+ sequence_lengths = tf .cast (sequence_lengths , dtype = tf .int32 )
258+
249259 num_tags = tf .shape (transition_params )[0 ]
250260 num_transitions = tf .shape (tag_indices )[1 ] - 1
251261
@@ -288,6 +298,7 @@ def crf_forward(inputs, state, transition_params, sequence_lengths):
288298 new_alphas: A [batch_size, num_tags] matrix containing the
289299 new alpha values.
290300 """
301+ sequence_lengths = tf .cast (sequence_lengths , dtype = tf .int32 )
291302
292303 sequence_lengths = tf .maximum (
293304 tf .constant (0 , dtype = sequence_lengths .dtype ), sequence_lengths - 2 )
@@ -399,6 +410,7 @@ def crf_decode_forward(inputs, state, transition_params, sequence_lengths):
399410 backpointers: A [batch_size, num_tags] matrix of backpointers.
400411 new_state: A [batch_size, num_tags] matrix of new score values.
401412 """
413+ sequence_lengths = tf .cast (sequence_lengths , dtype = tf .int32 )
402414 mask = tf .sequence_mask (sequence_lengths , tf .shape (inputs )[1 ])
403415 crf_fwd_cell = CrfDecodeForwardRnnCell (transition_params )
404416 crf_fwd_layer = tf .keras .layers .RNN (
@@ -446,6 +458,7 @@ def crf_decode(potentials, transition_params, sequence_length):
446458 Contains the highest scoring tag indices.
447459 best_score: A [batch_size] vector, containing the score of `decode_tags`.
448460 """
461+ sequence_length = tf .cast (sequence_length , dtype = tf .int32 )
449462
450463 # If max_seq_len is 1, we skip the algorithm and simply return the argmax tag
451464 # and the max activation.
0 commit comments