Skip to content

Commit 1e86eec

Browse files
Squadrickseanpmorgan
authored andcommitted
FIX: Internally cast to required DType (#659)
* FIX: Internally cast to required DType * Add pylint disable for unscriptable-object bug
1 parent c3aba08 commit 1e86eec

File tree

2 files changed

+36
-7
lines changed

2 files changed

+36
-7
lines changed

tensorflow_addons/text/crf.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ def crf_sequence_score(inputs, tag_indices, sequence_lengths,
3838
Returns:
3939
sequence_scores: A [batch_size] vector of unnormalized sequence scores.
4040
"""
41+
tag_indices = tf.cast(tag_indices, dtype=tf.int32)
42+
sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
4143

4244
# If max_seq_len is 1, we skip the score calculation and simply gather the
4345
# unary potentials of the single tag.
@@ -92,19 +94,19 @@ def crf_multitag_sequence_score(inputs, tag_bitmap, sequence_lengths,
9294
Returns:
9395
sequence_scores: A [batch_size] vector of unnormalized sequence scores.
9496
"""
97+
tag_bitmap = tf.cast(tag_bitmap, dtype=tf.bool)
98+
sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
99+
filtered_inputs = tf.where(tag_bitmap, inputs,
100+
tf.fill(tf.shape(inputs), float("-inf")))
95101

96102
# If max_seq_len is 1, we skip the score calculation and simply gather the
97103
# unary potentials of all active tags.
98104
def _single_seq_fn():
99-
filtered_inputs = tf.where(tag_bitmap, inputs,
100-
tf.fill(tf.shape(inputs), float("-inf")))
101105
return tf.reduce_logsumexp(
102106
filtered_inputs, axis=[1, 2], keepdims=False)
103107

104108
def _multi_seq_fn():
105109
# Compute the logsumexp of all scores of sequences matching the given tags.
106-
filtered_inputs = tf.where(tag_bitmap, inputs,
107-
tf.fill(tf.shape(inputs), float("-inf")))
108110
return crf_log_norm(
109111
inputs=filtered_inputs,
110112
sequence_lengths=sequence_lengths,
@@ -127,6 +129,7 @@ def crf_log_norm(inputs, sequence_lengths, transition_params):
127129
Returns:
128130
log_norm: A [batch_size] vector of normalizers for a CRF.
129131
"""
132+
sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
130133
# Split up the first and rest of the inputs in preparation for the forward
131134
# algorithm.
132135
first_input = tf.slice(inputs, [0, 0, 0], [-1, 1, -1])
@@ -183,10 +186,12 @@ def crf_log_likelihood(inputs,
183186
transition_params: A [num_tags, num_tags] transition matrix. This is
184187
either provided by the caller or created in this function.
185188
"""
186-
# Get shape information.
187189
num_tags = inputs.shape[2]
188190

189-
# Get the transition matrix if not provided.
191+
# cast type to handle different types
192+
tag_indices = tf.cast(tag_indices, dtype=tf.int32)
193+
sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
194+
190195
if transition_params is None:
191196
initializer = tf.keras.initializers.GlorotUniform()
192197
transition_params = tf.Variable(
@@ -211,6 +216,9 @@ def crf_unary_score(tag_indices, sequence_lengths, inputs):
211216
Returns:
212217
unary_scores: A [batch_size] vector of unary scores.
213218
"""
219+
tag_indices = tf.cast(tag_indices, dtype=tf.int32)
220+
sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
221+
214222
batch_size = tf.shape(inputs)[0]
215223
max_seq_len = tf.shape(inputs)[1]
216224
num_tags = tf.shape(inputs)[2]
@@ -245,7 +253,9 @@ def crf_binary_score(tag_indices, sequence_lengths, transition_params):
245253
Returns:
246254
binary_scores: A [batch_size] vector of binary scores.
247255
"""
248-
# Get shape information.
256+
tag_indices = tf.cast(tag_indices, dtype=tf.int32)
257+
sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
258+
249259
num_tags = tf.shape(transition_params)[0]
250260
num_transitions = tf.shape(tag_indices)[1] - 1
251261

@@ -288,6 +298,7 @@ def crf_forward(inputs, state, transition_params, sequence_lengths):
288298
new_alphas: A [batch_size, num_tags] matrix containing the
289299
new alpha values.
290300
"""
301+
sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
291302

292303
sequence_lengths = tf.maximum(
293304
tf.constant(0, dtype=sequence_lengths.dtype), sequence_lengths - 2)
@@ -399,6 +410,7 @@ def crf_decode_forward(inputs, state, transition_params, sequence_lengths):
399410
backpointers: A [batch_size, num_tags] matrix of backpointers.
400411
new_state: A [batch_size, num_tags] matrix of new score values.
401412
"""
413+
sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
402414
mask = tf.sequence_mask(sequence_lengths, tf.shape(inputs)[1])
403415
crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params)
404416
crf_fwd_layer = tf.keras.layers.RNN(
@@ -446,6 +458,7 @@ def crf_decode(potentials, transition_params, sequence_length):
446458
Contains the highest scoring tag indices.
447459
best_score: A [batch_size] vector, containing the score of `decode_tags`.
448460
"""
461+
sequence_length = tf.cast(sequence_length, dtype=tf.int32)
449462

450463
# If max_seq_len is 1, we skip the algorithm and simply return the argmax tag
451464
# and the max activation.

tensorflow_addons/text/crf_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,11 @@ def testCrfLogLikelihood(self):
209209
transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]],
210210
dtype=np.float32)
211211
sequence_lengths = np.array(3, dtype=np.int32)
212+
# TODO: https://github.com/PyCQA/pylint/issues/3139
213+
# pylint: disable=E1136
212214
num_words = inputs.shape[0]
213215
num_tags = inputs.shape[1]
216+
# pylint: enable=E1136
214217
all_sequence_log_likelihoods = []
215218

216219
# Make sure all probabilities sum to 1.
@@ -241,8 +244,11 @@ def testViterbiDecode(self):
241244
transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]],
242245
dtype=np.float32)
243246
sequence_lengths = np.array(3, dtype=np.int32)
247+
# TODO: https://github.com/PyCQA/pylint/issues/3139
248+
# pylint: disable=E1136
244249
num_words = inputs.shape[0]
245250
num_tags = inputs.shape[1]
251+
# pylint: enable=E1136
246252

247253
all_sequence_scores = []
248254
all_sequences = []
@@ -347,6 +353,16 @@ def testCrfDecodeZeroSeqLength(self):
347353
self.assertEqual(len(tf_tags.shape), 2)
348354
self.assertEqual(len(tf_scores.shape), 1)
349355

356+
def testDifferentDtype(self):
357+
inputs = np.ones([16, 20, 5], dtype=np.float32)
358+
tags = tf.convert_to_tensor(np.ones([16, 20], dtype=np.int64))
359+
seq_lens = np.ones([
360+
16,
361+
], dtype=np.int64) * 20
362+
363+
loss, _ = text.crf_log_likelihood(
364+
inputs=inputs, tag_indices=tags, sequence_lengths=seq_lens)
365+
350366

351367
if __name__ == "__main__":
352368
tf.test.main()

0 commit comments

Comments
 (0)