Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions tensorflow_addons/text/crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def crf_sequence_score(inputs, tag_indices, sequence_lengths,
Returns:
sequence_scores: A [batch_size] vector of unnormalized sequence scores.
"""
tag_indices = tf.cast(tag_indices, dtype=tf.int32)
sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)

# If max_seq_len is 1, we skip the score calculation and simply gather the
# unary potentials of the single tag.
Expand Down Expand Up @@ -92,19 +94,19 @@ def crf_multitag_sequence_score(inputs, tag_bitmap, sequence_lengths,
Returns:
sequence_scores: A [batch_size] vector of unnormalized sequence scores.
"""
tag_bitmap = tf.cast(tag_bitmap, dtype=tf.bool)
sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
filtered_inputs = tf.where(tag_bitmap, inputs,
tf.fill(tf.shape(inputs), float("-inf")))

# If max_seq_len is 1, we skip the score calculation and simply gather the
# unary potentials of all active tags.
def _single_seq_fn():
filtered_inputs = tf.where(tag_bitmap, inputs,
tf.fill(tf.shape(inputs), float("-inf")))
return tf.reduce_logsumexp(
filtered_inputs, axis=[1, 2], keepdims=False)

def _multi_seq_fn():
# Compute the logsumexp of all scores of sequences matching the given tags.
filtered_inputs = tf.where(tag_bitmap, inputs,
tf.fill(tf.shape(inputs), float("-inf")))
return crf_log_norm(
inputs=filtered_inputs,
sequence_lengths=sequence_lengths,
Expand All @@ -127,6 +129,7 @@ def crf_log_norm(inputs, sequence_lengths, transition_params):
Returns:
log_norm: A [batch_size] vector of normalizers for a CRF.
"""
sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
# Split up the first and rest of the inputs in preparation for the forward
# algorithm.
first_input = tf.slice(inputs, [0, 0, 0], [-1, 1, -1])
Expand Down Expand Up @@ -183,10 +186,12 @@ def crf_log_likelihood(inputs,
transition_params: A [num_tags, num_tags] transition matrix. This is
either provided by the caller or created in this function.
"""
# Get shape information.
num_tags = inputs.shape[2]

# Get the transition matrix if not provided.
# Cast to int32 so that tag indices and sequence lengths of other integer dtypes (e.g. int64) are accepted.
tag_indices = tf.cast(tag_indices, dtype=tf.int32)
sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)

if transition_params is None:
initializer = tf.keras.initializers.GlorotUniform()
transition_params = tf.Variable(
Expand All @@ -211,6 +216,9 @@ def crf_unary_score(tag_indices, sequence_lengths, inputs):
Returns:
unary_scores: A [batch_size] vector of unary scores.
"""
tag_indices = tf.cast(tag_indices, dtype=tf.int32)
sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)

batch_size = tf.shape(inputs)[0]
max_seq_len = tf.shape(inputs)[1]
num_tags = tf.shape(inputs)[2]
Expand Down Expand Up @@ -245,7 +253,9 @@ def crf_binary_score(tag_indices, sequence_lengths, transition_params):
Returns:
binary_scores: A [batch_size] vector of binary scores.
"""
# Get shape information.
tag_indices = tf.cast(tag_indices, dtype=tf.int32)
sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)

num_tags = tf.shape(transition_params)[0]
num_transitions = tf.shape(tag_indices)[1] - 1

Expand Down Expand Up @@ -288,6 +298,7 @@ def crf_forward(inputs, state, transition_params, sequence_lengths):
new_alphas: A [batch_size, num_tags] matrix containing the
new alpha values.
"""
sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)

sequence_lengths = tf.maximum(
tf.constant(0, dtype=sequence_lengths.dtype), sequence_lengths - 2)
Expand Down Expand Up @@ -399,6 +410,7 @@ def crf_decode_forward(inputs, state, transition_params, sequence_lengths):
backpointers: A [batch_size, num_tags] matrix of backpointers.
new_state: A [batch_size, num_tags] matrix of new score values.
"""
sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
mask = tf.sequence_mask(sequence_lengths, tf.shape(inputs)[1])
crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params)
crf_fwd_layer = tf.keras.layers.RNN(
Expand Down Expand Up @@ -446,6 +458,7 @@ def crf_decode(potentials, transition_params, sequence_length):
Contains the highest scoring tag indices.
best_score: A [batch_size] vector, containing the score of `decode_tags`.
"""
sequence_length = tf.cast(sequence_length, dtype=tf.int32)

# If max_seq_len is 1, we skip the algorithm and simply return the argmax tag
# and the max activation.
Expand Down
16 changes: 16 additions & 0 deletions tensorflow_addons/text/crf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,11 @@ def testCrfLogLikelihood(self):
transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]],
dtype=np.float32)
sequence_lengths = np.array(3, dtype=np.int32)
# TODO: https://github.com/PyCQA/pylint/issues/3139
# pylint: disable=E1136
num_words = inputs.shape[0]
num_tags = inputs.shape[1]
# pylint: enable=E1136
all_sequence_log_likelihoods = []

# Make sure all probabilities sum to 1.
Expand Down Expand Up @@ -241,8 +244,11 @@ def testViterbiDecode(self):
transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]],
dtype=np.float32)
sequence_lengths = np.array(3, dtype=np.int32)
# TODO: https://github.com/PyCQA/pylint/issues/3139
# pylint: disable=E1136
num_words = inputs.shape[0]
num_tags = inputs.shape[1]
# pylint: enable=E1136

all_sequence_scores = []
all_sequences = []
Expand Down Expand Up @@ -347,6 +353,16 @@ def testCrfDecodeZeroSeqLength(self):
self.assertEqual(len(tf_tags.shape), 2)
self.assertEqual(len(tf_scores.shape), 1)

def testDifferentDtype(self):
    """Smoke-check crf_log_likelihood with int64 tags and sequence lengths.

    The potentials are float32 while the tag indices and the sequence
    lengths are int64. Nothing is asserted on the result; the call only
    has to complete without raising.
    """
    batch_size, max_seq_len, num_tags = 16, 20, 5

    inputs = np.ones([batch_size, max_seq_len, num_tags], dtype=np.float32)
    # Tags are handed over as a tensor, lengths as a plain NumPy array.
    tags = tf.convert_to_tensor(
        np.ones([batch_size, max_seq_len], dtype=np.int64))
    # Every sequence spans the full max_seq_len steps.
    seq_lens = np.full([batch_size], max_seq_len, dtype=np.int64)

    # No transition matrix is passed in this call.
    log_likelihood, _ = text.crf_log_likelihood(
        inputs=inputs, tag_indices=tags, sequence_lengths=seq_lens)


# Hand control to tf.test.main() when this file is executed directly.
if __name__ == "__main__":
    tf.test.main()