From 0ef5a9ea11df605c1663e4dd25182a67f7fe2e92 Mon Sep 17 00:00:00 2001 From: "Dheeraj R. Reddy" Date: Thu, 13 Jun 2019 18:08:12 +0530 Subject: [PATCH 01/11] Port CRF from tf.contrib to tfa.text --- tensorflow_addons/text/BUILD | 14 + tensorflow_addons/text/README.md | 1 + tensorflow_addons/text/__init__.py | 12 + tensorflow_addons/text/crf_ops.py | 464 +++++++++++++++++++++++++ tensorflow_addons/text/crf_ops_test.py | 358 +++++++++++++++++++ 5 files changed, 849 insertions(+) create mode 100644 tensorflow_addons/text/crf_ops.py create mode 100644 tensorflow_addons/text/crf_ops_test.py diff --git a/tensorflow_addons/text/BUILD b/tensorflow_addons/text/BUILD index 4787cd8c0c..d96bdd582b 100644 --- a/tensorflow_addons/text/BUILD +++ b/tensorflow_addons/text/BUILD @@ -6,6 +6,7 @@ py_library( name = "text", srcs = ([ "__init__.py", + "crf_ops.py", "skip_gram_ops.py", ]), data = [ @@ -15,6 +16,19 @@ py_library( srcs_version = "PY2AND3", ) +py_test( + name = "crf_ops_test", + size = "small", + srcs = [ + "crf_ops_test.py", + ], + main = "crf_ops_test.py", + srcs_version = "PY2AND3", + deps = [ + ":text", + ], +) + py_test( name = "skip_gram_ops_test", size = "small", diff --git a/tensorflow_addons/text/README.md b/tensorflow_addons/text/README.md index 4b4d948363..d6e60a07b9 100644 --- a/tensorflow_addons/text/README.md +++ b/tensorflow_addons/text/README.md @@ -4,6 +4,7 @@ | Submodule | Maintainers | Contact Info | |:---------- |:----------- |:------------- | | skip_gram_ops | | | +| crf | Dheeraj R. Reddy | dheeraj98reddy@gmail.com | ## Components | Submodule | Text Processing Function | Reference | diff --git a/tensorflow_addons/text/__init__.py b/tensorflow_addons/text/__init__.py index 05c758e26d..6c67afa387 100644 --- a/tensorflow_addons/text/__init__.py +++ b/tensorflow_addons/text/__init__.py @@ -20,3 +20,15 @@ # Skip Gram Sampling from tensorflow_addons.text.skip_gram_ops import skip_gram_sample from tensorflow_addons.text.skip_gram_ops import skip_gram_sample_with_text_vocab + +from tensorflow_addons.text.crf_ops import crf_binary_score +from tensorflow_addons.text.crf_ops import crf_decode +from tensorflow_addons.text.crf_ops import crf_log_likelihood +from tensorflow_addons.text.crf_ops import crf_log_norm +from tensorflow_addons.text.crf_ops import crf_multitag_sequence_score +from tensorflow_addons.text.crf_ops import crf_sequence_score +from tensorflow_addons.text.crf_ops import crf_unary_score +from tensorflow_addons.text.crf_ops import CrfDecodeBackwardRnnCell +from tensorflow_addons.text.crf_ops import CrfDecodeForwardRnnCell +from tensorflow_addons.text.crf_ops import CrfForwardRnnCell +from tensorflow_addons.text.crf_ops import viterbi_decode diff --git a/tensorflow_addons/text/crf_ops.py b/tensorflow_addons/text/crf_ops.py new file mode 100644 index 0000000000..7acd10924a --- /dev/null +++ b/tensorflow_addons/text/crf_ops.py @@ -0,0 +1,464 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + + +def crf_sequence_score(inputs, tag_indices, sequence_lengths, + transition_params): + """Computes the unnormalized score for a tag sequence. + + Args: + inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials + to use as input to the CRF layer. + tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we + compute the unnormalized score. + sequence_lengths: A [batch_size] vector of true sequence lengths. + transition_params: A [num_tags, num_tags] transition matrix. + Returns: + sequence_scores: A [batch_size] vector of unnormalized sequence scores. + """ + + # If max_seq_len is 1, we skip the score calculation and simply gather the + # unary potentials of the single tag. + def _single_seq_fn(): + batch_size = tf.shape(inputs, out_type=tag_indices.dtype)[0] + + example_inds = tf.reshape( + tf.range(batch_size, dtype=tag_indices.dtype), [-1, 1]) + sequence_scores = tf.gather_nd( + tf.squeeze(inputs, [1]), + tf.concat([example_inds, tag_indices], axis=1)) + sequence_scores = tf.where(tf.less_equal(sequence_lengths, 0), + tf.zeros_like(sequence_scores), + sequence_scores) + return sequence_scores + + def _multi_seq_fn(): + # Compute the scores of the given tag sequence. + unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs) + binary_scores = crf_binary_score(tag_indices, sequence_lengths, + transition_params) + sequence_scores = unary_scores + binary_scores + return sequence_scores + + if inputs.shape[1] == 1: + return _single_seq_fn() + else: + return _multi_seq_fn() + + +def crf_multitag_sequence_score(inputs, tag_bitmap, sequence_lengths, + transition_params): + """Computes the unnormalized score of all tag sequences matching tag_bitmap. + + tag_bitmap enables more than one tag to be considered correct at each time + step. This is useful when an observed output at a given time step is + consistent with more than one tag, and thus the log likelihood of that + observation must take into account all possible consistent tags. + + Using one-hot vectors in tag_bitmap gives results identical to + crf_sequence_score. + + Args: + inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials + to use as input to the CRF layer. + tag_bitmap: A [batch_size, max_seq_len, num_tags] boolean tensor + representing all active tags at each index for which to calculate the + unnormalized score. + sequence_lengths: A [batch_size] vector of true sequence lengths. + transition_params: A [num_tags, num_tags] transition matrix. + Returns: + sequence_scores: A [batch_size] vector of unnormalized sequence scores. + """ + + # If max_seq_len is 1, we skip the score calculation and simply gather the + # unary potentials of all active tags. + def _single_seq_fn(): + filtered_inputs = tf.where( + tag_bitmap, inputs, + tf.fill(tf.shape(inputs), float("-inf"))) + return tf.reduce_logsumexp( + filtered_inputs, axis=[1, 2], keepdims=False) + + def _multi_seq_fn(): + # Compute the logsumexp of all scores of sequences matching the given tags. + filtered_inputs = tf.where( + tag_bitmap, inputs, + tf.fill(tf.shape(inputs), float("-inf"))) + return crf_log_norm( + inputs=filtered_inputs, + sequence_lengths=sequence_lengths, + transition_params=transition_params) + + if inputs.shape[1] == 1: + return _single_seq_fn() + else: + return _multi_seq_fn() + + +def crf_log_norm(inputs, sequence_lengths, transition_params): + """Computes the normalization for a CRF. + + Args: + inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials + to use as input to the CRF layer. + sequence_lengths: A [batch_size] vector of true sequence lengths. + transition_params: A [num_tags, num_tags] transition matrix. + Returns: + log_norm: A [batch_size] vector of normalizers for a CRF. + """ + # Split up the first and rest of the inputs in preparation for the forward + # algorithm. + first_input = tf.slice(inputs, [0, 0, 0], [-1, 1, -1]) + first_input = tf.squeeze(first_input, [1]) + + # If max_seq_len is 1, we skip the algorithm and simply reduce_logsumexp over + # the "initial state" (the unary potentials). + def _single_seq_fn(): + log_norm = tf.reduce_logsumexp(first_input, [1]) + # Mask `log_norm` of the sequences with length <= zero. + log_norm = tf.where(tf.less_equal(sequence_lengths, 0), + tf.zeros_like(log_norm), + log_norm) + return log_norm + + def _multi_seq_fn(): + """Forward computation of alpha values.""" + rest_of_input = tf.slice(inputs, [0, 1, 0], [-1, -1, -1]) + # Compute the alpha values in the forward algorithm in order to get the + # partition function. + forward_cell = CrfForwardRnnCell(transition_params) + # Sequence length is not allowed to be less than zero. + sequence_lengths_less_one = tf.maximum( + tf.constant(0, dtype=sequence_lengths.dtype), + sequence_lengths - 1) + + forward_layer = tf.keras.layers.RNN( + forward_cell, + return_sequences=True, + return_state=True) + + _, alphas = forward_layer(rest_of_input, first_input) + + log_norm = tf.reduce_logsumexp(alphas, [1]) + # Mask `log_norm` of the sequences with length <= zero. + log_norm = tf.where(tf.less_equal(sequence_lengths, 0), + tf.zeros_like(log_norm), + log_norm) + return log_norm + + if inputs.shape[1] == 1: + return _single_seq_fn() + else: + return _multi_seq_fn() + + +def crf_log_likelihood(inputs, + tag_indices, + sequence_lengths, + transition_params=None): + """Computes the log-likelihood of tag sequences in a CRF. + + Args: + inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials + to use as input to the CRF layer. + tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we + compute the log-likelihood. + sequence_lengths: A [batch_size] vector of true sequence lengths. + transition_params: A [num_tags, num_tags] transition matrix, if available. + Returns: + log_likelihood: A [batch_size] `Tensor` containing the log-likelihood of + each example, given the sequence of tag indices. + transition_params: A [num_tags, num_tags] transition matrix. This is either + provided by the caller or created in this function. + """ + # Get shape information. + num_tags = inputs.shape[2] + + # Get the transition matrix if not provided. + if transition_params is None: + transition_params = tf.get_variable("transitions", [num_tags, num_tags]) + + sequence_scores = crf_sequence_score(inputs, tag_indices, sequence_lengths, + transition_params) + log_norm = crf_log_norm(inputs, sequence_lengths, transition_params) + + # Normalize the scores to get the log-likelihood per example. + log_likelihood = sequence_scores - log_norm + return log_likelihood, transition_params + + +def crf_unary_score(tag_indices, sequence_lengths, inputs): + """Computes the unary scores of tag sequences. + + Args: + tag_indices: A [batch_size, max_seq_len] matrix of tag indices. + sequence_lengths: A [batch_size] vector of true sequence lengths. + inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials. + Returns: + unary_scores: A [batch_size] vector of unary scores. + """ + batch_size = tf.shape(inputs)[0] + max_seq_len = tf.shape(inputs)[1] + num_tags = tf.shape(inputs)[2] + + flattened_inputs = tf.reshape(inputs, [-1]) + + offsets = tf.expand_dims( + tf.range(batch_size) * max_seq_len * num_tags, 1) + offsets += tf.expand_dims(tf.range(max_seq_len) * num_tags, 0) + # Use int32 or int64 based on tag_indices' dtype. + if tag_indices.dtype == tf.int64: + offsets = tf.cast(offsets, tf.int64) + flattened_tag_indices = tf.reshape(offsets + tag_indices, [-1]) + + unary_scores = tf.reshape( + tf.gather(flattened_inputs, flattened_tag_indices), + [batch_size, max_seq_len]) + + masks = tf.sequence_mask(sequence_lengths, + maxlen=tf.shape(tag_indices)[1], + dtype=tf.float32) + + unary_scores = tf.reduce_sum(unary_scores * masks, 1) + return unary_scores + + +def crf_binary_score(tag_indices, sequence_lengths, transition_params): + """Computes the binary scores of tag sequences. + + Args: + tag_indices: A [batch_size, max_seq_len] matrix of tag indices. + sequence_lengths: A [batch_size] vector of true sequence lengths. + transition_params: A [num_tags, num_tags] matrix of binary potentials. + Returns: + binary_scores: A [batch_size] vector of binary scores. + """ + # Get shape information. + num_tags = tf.shape(transition_params)[0] + num_transitions = tf.shape(tag_indices)[1] - 1 + + # Truncate by one on each side of the sequence to get the start and end + # indices of each transition. + start_tag_indices = tf.slice(tag_indices, [0, 0], + [-1, num_transitions]) + end_tag_indices = tf.slice(tag_indices, [0, 1], [-1, num_transitions]) + + # Encode the indices in a flattened representation. + flattened_transition_indices = start_tag_indices * num_tags + end_tag_indices + flattened_transition_params = tf.reshape(transition_params, [-1]) + + # Get the binary scores based on the flattened representation. + binary_scores = tf.gather(flattened_transition_params, + flattened_transition_indices) + + masks = tf.sequence_mask(sequence_lengths, + maxlen=tf.shape(tag_indices)[1], + dtype=tf.float32) + truncated_masks = tf.slice(masks, [0, 1], [-1, -1]) + binary_scores = tf.reduce_sum(binary_scores * truncated_masks, 1) + return binary_scores + + +class CrfForwardRnnCell(tf.keras.layers.Layer): + def __init__(self, transition_params, **kwargs): + super(CrfForwardRnnCell, self).__init__(**kwargs) + self._transition_params = tf.expand_dims(transition_params, 0) + self._num_tags = transition_params.shape[0] + self.state_size = self._num_tags + self.output_size = self._num_tags + + def build(self, input_shape): + super(CrfForwardRnnCell, self).build(input_shape) + + def call(self, inputs, state, training=None): + state = tf.expand_dims(state[0], 2) + transition_scores = state + self._transition_params + new_alphas = inputs + tf.reduce_logsumexp(transition_scores, [1]) + return new_alphas, new_alphas + + +def viterbi_decode(score, transition_params): + """Decode the highest scoring sequence of tags outside of TensorFlow. + + This should only be used at test time. + + Args: + score: A [seq_len, num_tags] matrix of unary potentials. + transition_params: A [num_tags, num_tags] matrix of binary potentials. + + Returns: + viterbi: A [seq_len] list of integers containing the highest scoring tag + indices. + viterbi_score: A float containing the score for the Viterbi sequence. + """ + trellis = np.zeros_like(score) + backpointers = np.zeros_like(score, dtype=np.int32) + trellis[0] = score[0] + + for t in range(1, score.shape[0]): + v = np.expand_dims(trellis[t - 1], 1) + transition_params + trellis[t] = score[t] + np.max(v, 0) + backpointers[t] = np.argmax(v, 0) + + viterbi = [np.argmax(trellis[-1])] + for bp in reversed(backpointers[1:]): + viterbi.append(bp[viterbi[-1]]) + viterbi.reverse() + + viterbi_score = np.max(trellis[-1]) + return viterbi, viterbi_score + + +class CrfDecodeForwardRnnCell(tf.keras.layers.Layer): + """Computes the forward decoding in a linear-chain CRF. + """ + + def __init__(self, transition_params, **kwargs): + """Initialize the CrfDecodeForwardRnnCell. + + Args: + transition_params: A [num_tags, num_tags] matrix of binary + potentials. This matrix is expanded into a + [1, num_tags, num_tags] in preparation for the broadcast + summation occurring within the cell. + """ + super(CrfDecodeForwardRnnCell, self).__init__(**kwargs) + self._transition_params = tf.expand_dims(transition_params, 0) + self._num_tags = transition_params.shape[0] + self.state_size = self._num_tags + self.output_size = self._num_tags + + def build(self, input_shape): + super(CrfDecodeForwardRnnCell, self).build(input_shape) + + def call(self, inputs, state, training=None): + state = tf.expand_dims(state[0], 2) + transition_scores = state + self._transition_params + new_state = inputs + tf.reduce_max(transition_scores, [1]) + backpointers = tf.argmax(transition_scores, 1) + backpointers = tf.cast(backpointers, dtype=tf.int32) + return backpointers, new_state + + +class CrfDecodeBackwardRnnCell(tf.keras.layers.Layer): + """Computes backward decoding in a linear-chain CRF. + """ + + def __init__(self, num_tags, **kwargs): + """Initialize the CrfDecodeBackwardRnnCell. + + Args: + num_tags: An integer. The number of tags. + """ + super(CrfDecodeBackwardRnnCell, self).__init__(**kwargs) + self._num_tags = num_tags + + self.state_size = 1 + self.output_size = 1 + + def build(self, input_shape): + super(CrfDecodeBackwardRnnCell, self).build(input_shape) + + def call(self, inputs, state, training=None): + state = tf.squeeze(state[0], axis=[1]) + batch_size = tf.shape(inputs)[0] + b_indices = tf.range(batch_size) + indices = tf.stack([b_indices, state], axis=1) + new_tags = tf.expand_dims(tf.gather_nd(inputs, indices), axis=-1) + + return new_tags, new_tags + + +def crf_decode(potentials, transition_params, sequence_length): + """Decode the highest scoring sequence of tags in TensorFlow. + + This is a function for tensor. + + Args: + potentials: A [batch_size, max_seq_len, num_tags] tensor of + unary potentials. + transition_params: A [num_tags, num_tags] matrix of + binary potentials. + sequence_length: A [batch_size] vector of true sequence lengths. + + Returns: + decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`. + Contains the highest scoring tag indices. + best_score: A [batch_size] vector, containing the score of `decode_tags`. + """ + + # If max_seq_len is 1, we skip the algorithm and simply return the argmax tag + # and the max activation. + def _single_seq_fn(): + squeezed_potentials = tf.squeeze(potentials, [1]) + decode_tags = tf.expand_dims( + tf.argmax(squeezed_potentials, axis=1), 1) + best_score = tf.reduce_max(squeezed_potentials, axis=1) + return tf.cast(decode_tags, dtype=tf.int32), best_score + + def _multi_seq_fn(): + """Decoding of highest scoring sequence.""" + + # For simplicity, in shape comments, denote: + # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output). + num_tags = potentials.shape[2] + + # Computes forward decoding. Get last score and backpointers. + initial_state = tf.slice(potentials, [0, 0, 0], [-1, 1, -1]) + initial_state = tf.squeeze(initial_state, axis=[1]) # [B, O] + inputs = tf.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O] + # Sequence length is not allowed to be less than zero. + + sequence_length_less_one = tf.maximum( + tf.constant(0, dtype=sequence_length.dtype), + sequence_length - 1) + + crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params) + crf_fwd_layer = tf.keras.layers.RNN(crf_fwd_cell, + return_sequences=True, + return_state=True, + time_major=False) + backpointers, last_score = crf_fwd_layer(inputs, initial_state) + backpointers = tf.reverse_sequence(backpointers, sequence_length_less_one, seq_axis=1) + + crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags) + initial_state = tf.cast(tf.argmax(last_score, axis=1), dtype=tf.int32) + initial_state = tf.expand_dims(initial_state, axis=-1) + crf_bwd_layer = tf.keras.layers.RNN(crf_bwd_cell, + return_sequences=True, + return_state=True, + time_major=False) + decode_tags, _ = crf_bwd_layer(backpointers, initial_state) + + decode_tags = tf.squeeze(decode_tags, axis=[2]) # [B, T - 1] + decode_tags = tf.concat([initial_state, decode_tags], # [B, T] + axis=1) + decode_tags = tf.reverse_sequence( # [B, T] + decode_tags, sequence_length, seq_axis=1) + + best_score = tf.reduce_max(last_score, axis=1) # [B] + return decode_tags, best_score + + if potentials.shape[1] == 1: + return _single_seq_fn() + else: + return _multi_seq_fn() diff --git a/tensorflow_addons/text/crf_ops_test.py b/tensorflow_addons/text/crf_ops_test.py new file mode 100644 index 0000000000..d706992c32 --- /dev/null +++ b/tensorflow_addons/text/crf_ops_test.py @@ -0,0 +1,358 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for CRF.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +import numpy as np +import tensorflow as tf + +from tensorflow_addons import text +from tensorflow_addons.utils import test_utils + + +class CrfTest(tf.test.TestCase): + + def calculateSequenceScore(self, inputs, transition_params, tag_indices, + sequence_lengths): + expected_unary_score = sum( + inputs[i][tag_indices[i]] for i in range(sequence_lengths)) + expected_binary_score = sum( + transition_params[tag_indices[i], tag_indices[i + 1]] + for i in range(sequence_lengths - 1)) + return expected_unary_score + expected_binary_score + + def testCrfSequenceScore(self): + transition_params = np.array( + [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) + # Test both the length-1 and regular cases. + sequence_lengths_list = [ + np.array(3, dtype=np.int32), + np.array(1, dtype=np.int32) + ] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32), + np.array([[4, 5, -3]], + dtype=np.float32), + ] + tag_indices_list = [ + np.array([1, 2, 1, 0], dtype=np.int32), + np.array([1], dtype=np.int32) + ] + for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list, + inputs_list, + tag_indices_list): + sequence_score = text.crf_sequence_score( + inputs=tf.expand_dims(inputs, 0), + tag_indices=tf.expand_dims(tag_indices, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + transition_params=tf.constant(transition_params)) + sequence_score = tf.squeeze(sequence_score, [0]) + + tf_sequence_score = self.evaluate(sequence_score) + + expected_sequence_score = self.calculateSequenceScore( + inputs, transition_params, tag_indices, sequence_lengths) + self.assertAllClose(tf_sequence_score, expected_sequence_score) + + def testCrfMultiTagSequenceScore(self): + transition_params = np.array( + [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) + # Test both the length-1 and regular cases. + sequence_lengths_list = [ + np.array(3, dtype=np.int32), + np.array(1, dtype=np.int32) + ] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32), + np.array([[4, 5, -3]], + dtype=np.float32), + ] + tag_bitmap_list = [ + np.array( + [[True, True, False], [True, False, True], [False, True, True], + [True, False, True]], + dtype=np.bool), + np.array([[True, True, False]], dtype=np.bool) + ] + for sequence_lengths, inputs, tag_bitmap in zip( + sequence_lengths_list, inputs_list, tag_bitmap_list): + sequence_score = text.crf_multitag_sequence_score( + inputs=tf.expand_dims(inputs, 0), + tag_bitmap=tf.expand_dims(tag_bitmap, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + transition_params=tf.constant(transition_params)) + sequence_score = tf.squeeze(sequence_score, [0]) + tf_sum_sequence_score = self.evaluate(sequence_score) + all_indices_list = [ + single_index_bitmap.nonzero()[0] + for single_index_bitmap in tag_bitmap[:sequence_lengths] + ] + expected_sequence_scores = [ + self.calculateSequenceScore(inputs, transition_params, indices, + sequence_lengths) + for indices in itertools.product(*all_indices_list) + ] + expected_log_sum_exp_sequence_scores = np.logaddexp.reduce( + expected_sequence_scores) + self.assertAllClose(tf_sum_sequence_score, + expected_log_sum_exp_sequence_scores) + + def testCrfUnaryScore(self): + inputs = np.array( + [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) + for dtype in (np.int32, np.int64): + tag_indices = np.array([1, 2, 1, 0], dtype=dtype) + sequence_lengths = np.array(3, dtype=np.int32) + unary_score = text.crf_unary_score( + tag_indices=tf.expand_dims(tag_indices, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + inputs=tf.expand_dims(inputs, 0)) + unary_score = tf.squeeze(unary_score, [0]) + tf_unary_score = self.evaluate(unary_score) + expected_unary_score = sum(inputs[i][tag_indices[i]] + for i in range(sequence_lengths)) + self.assertAllClose(tf_unary_score, expected_unary_score) + + def testCrfBinaryScore(self): + tag_indices = np.array([1, 2, 1, 0], dtype=np.int32) + transition_params = np.array( + [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) + sequence_lengths = np.array(3, dtype=np.int32) + binary_score = text.crf_binary_score( + tag_indices=tf.expand_dims(tag_indices, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + transition_params=tf.constant(transition_params)) + binary_score = tf.squeeze(binary_score, [0]) + tf_binary_score = self.evaluate(binary_score) + expected_binary_score = sum( + transition_params[tag_indices[i], tag_indices[i + 1]] + for i in range(sequence_lengths - 1)) + self.assertAllClose(tf_binary_score, expected_binary_score) + + def testCrfLogNorm(self): + transition_params = np.array( + [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) + # Test both the length-1 and regular cases. + sequence_lengths_list = [ + np.array(3, dtype=np.int32), + np.array(1, dtype=np.int64) + ] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32), + np.array([[3, -1, 3]], + dtype=np.float32), + ] + tag_indices_list = [ + np.array([1, 2, 1, 0], dtype=np.int32), + np.array([2], dtype=np.int32) + ] + + for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list, + inputs_list, + tag_indices_list): + num_words = inputs.shape[0] + num_tags = inputs.shape[1] + all_sequence_scores = [] + + # Compare the dynamic program with brute force computation. + for tag_indices in itertools.product( + range(num_tags), repeat=sequence_lengths): + tag_indices = list(tag_indices) + tag_indices.extend([0] * (num_words - sequence_lengths)) + all_sequence_scores.append( + text.crf_sequence_score( + inputs=tf.expand_dims(inputs, 0), + tag_indices=tf.expand_dims(tag_indices, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + transition_params=tf.constant(transition_params))) + + brute_force_log_norm = tf.reduce_logsumexp(all_sequence_scores) + log_norm = text.crf_log_norm( + inputs=tf.expand_dims(inputs, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + transition_params=tf.constant(transition_params)) + log_norm = tf.squeeze(log_norm, [0]) + tf_brute_force_log_norm, tf_log_norm = self.evaluate( + [brute_force_log_norm, log_norm]) + + self.assertAllClose(tf_log_norm, tf_brute_force_log_norm) + + def testCrfLogNormZeroSeqLength(self): + """ + Test `crf_log_norm` when `sequence_lengths` contains one or more zeros. + """ + inputs = tf.constant(np.ones([2, 10, 5], + dtype=np.float32)) + transition_params = tf.constant(np.ones([5, 5], + dtype=np.float32)) + sequence_lengths = tf.constant(np.zeros([2], + dtype=np.int32)) + expected_log_norm = np.zeros([2], dtype=np.float32) + log_norm = text.crf_log_norm(inputs, sequence_lengths, transition_params) + tf_log_norm = self.evaluate(log_norm) + self.assertAllClose(tf_log_norm, expected_log_norm) + + def testCrfLogLikelihood(self): + inputs = np.array( + [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) + transition_params = np.array( + [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) + sequence_lengths = np.array(3, dtype=np.int32) + num_words = inputs.shape[0] + num_tags = inputs.shape[1] + all_sequence_log_likelihoods = [] + + # Make sure all probabilities sum to 1. + for tag_indices in itertools.product( + range(num_tags), repeat=sequence_lengths): + tag_indices = list(tag_indices) + tag_indices.extend([0] * (num_words - sequence_lengths)) + sequence_log_likelihood, _ = text.crf_log_likelihood( + inputs=tf.expand_dims(inputs, 0), + tag_indices=tf.expand_dims(tag_indices, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + transition_params=tf.constant(transition_params)) + all_sequence_log_likelihoods.append(sequence_log_likelihood) + total_log_likelihood = tf.reduce_logsumexp( + all_sequence_log_likelihoods) + tf_total_log_likelihood = self.evaluate(total_log_likelihood) + self.assertAllClose(tf_total_log_likelihood, 0.0) + + def testViterbiDecode(self): + inputs = np.array( + [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) + transition_params = np.array( + [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) + sequence_lengths = np.array(3, dtype=np.int32) + num_words = inputs.shape[0] + num_tags = inputs.shape[1] + + all_sequence_scores = [] + all_sequences = [] + + # Compare the dynamic program with brute force computation. + for tag_indices in itertools.product( + range(num_tags), repeat=sequence_lengths): + tag_indices = list(tag_indices) + tag_indices.extend([0] * (num_words - sequence_lengths)) + all_sequences.append(tag_indices) + sequence_score = text.crf_sequence_score( + inputs=tf.expand_dims(inputs, 0), + tag_indices=tf.expand_dims(tag_indices, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + transition_params=tf.constant(transition_params)) + sequence_score = tf.squeeze(sequence_score, [0]) + all_sequence_scores.append(sequence_score) + + tf_all_sequence_scores = self.evaluate(all_sequence_scores) + + expected_max_sequence_index = np.argmax(tf_all_sequence_scores) + expected_max_sequence = all_sequences[expected_max_sequence_index] + expected_max_score = tf_all_sequence_scores[expected_max_sequence_index] + + actual_max_sequence, actual_max_score = text.viterbi_decode( + inputs[:sequence_lengths], transition_params) + + self.assertAllClose(actual_max_score, expected_max_score) + self.assertEqual(actual_max_sequence, + expected_max_sequence[:sequence_lengths]) + + def testCrfDecode(self): + transition_params = np.array( + [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) + # Test both the length-1 and regular cases. + sequence_lengths_list = [ + np.array(3, dtype=np.int32), + np.array(1, dtype=np.int64) + ] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32), + np.array([[-1, 2, 1]], + dtype=np.float32), + ] + tag_indices_list = [ + np.array([1, 2, 1, 0], dtype=np.int32), + np.array([2], dtype=np.int32) + ] + + for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list, + inputs_list, + tag_indices_list): + num_words = inputs.shape[0] + num_tags = inputs.shape[1] + + all_sequence_scores = [] + all_sequences = [] + + # Compare the dynamic program with brute force computation. + for tag_indices in itertools.product( + range(num_tags), repeat=sequence_lengths): + tag_indices = list(tag_indices) + tag_indices.extend([0] * (num_words - sequence_lengths)) + all_sequences.append(tag_indices) + sequence_score = text.crf_sequence_score( + inputs=tf.expand_dims(inputs, 0), + tag_indices=tf.expand_dims(tag_indices, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + transition_params=tf.constant(transition_params)) + sequence_score = tf.squeeze(sequence_score, [0]) + all_sequence_scores.append(sequence_score) + + tf_all_sequence_scores = self.evaluate(all_sequence_scores) + + expected_max_sequence_index = np.argmax(tf_all_sequence_scores) + expected_max_sequence = all_sequences[expected_max_sequence_index] + expected_max_score = tf_all_sequence_scores[expected_max_sequence_index] + + actual_max_sequence, actual_max_score = text.crf_decode( + tf.expand_dims(inputs, 0), + tf.constant(transition_params), + tf.expand_dims(sequence_lengths, 0)) + actual_max_sequence = tf.squeeze(actual_max_sequence, [0]) + actual_max_score = tf.squeeze(actual_max_score, [0]) + tf_actual_max_sequence, tf_actual_max_score = self.evaluate( + [actual_max_sequence, actual_max_score]) + + self.assertAllClose(tf_actual_max_score, expected_max_score) + self.assertEqual(list(tf_actual_max_sequence[:sequence_lengths]), + expected_max_sequence[:sequence_lengths]) + + def testCrfDecodeZeroSeqLength(self): + """ + Test that crf_decode works when sequence_length contains one or more zeros. + """ + inputs = tf.constant(np.ones([2, 10, 5], + dtype=np.float32)) + transition_params = tf.constant(np.ones([5, 5], + dtype=np.float32)) + sequence_lengths = tf.constant(np.zeros([2], + dtype=np.int32)) + tags, scores = text.crf_decode(inputs, transition_params, sequence_lengths) + tf_tags, tf_scores = self.evaluate([tags, scores]) + self.assertEqual(len(tf_tags.shape), 2) + self.assertEqual(len(tf_scores.shape), 1) + + +if __name__ == "__main__": + tf.test.main() From 51bebe87cfa252b470f62c8a6dcb81e98498a723 Mon Sep 17 00:00:00 2001 From: Dheeraj Rajaram Reddy Date: Thu, 20 Jun 2019 00:15:09 +0530 Subject: [PATCH 02/11] Format using make code-format --- tensorflow_addons/text/crf_ops.py | 761 ++++++++++++------------- tensorflow_addons/text/crf_ops_test.py | 633 ++++++++++---------- 2 files changed, 687 insertions(+), 707 deletions(-) diff --git a/tensorflow_addons/text/crf_ops.py b/tensorflow_addons/text/crf_ops.py index 7acd10924a..9e5fd02051 100644 --- a/tensorflow_addons/text/crf_ops.py +++ b/tensorflow_addons/text/crf_ops.py @@ -23,442 +23,435 @@ def crf_sequence_score(inputs, tag_indices, sequence_lengths, transition_params): - """Computes the unnormalized score for a tag sequence. - - Args: - inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials - to use as input to the CRF layer. - tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we - compute the unnormalized score. - sequence_lengths: A [batch_size] vector of true sequence lengths. - transition_params: A [num_tags, num_tags] transition matrix. - Returns: - sequence_scores: A [batch_size] vector of unnormalized sequence scores. - """ - - # If max_seq_len is 1, we skip the score calculation and simply gather the - # unary potentials of the single tag. - def _single_seq_fn(): - batch_size = tf.shape(inputs, out_type=tag_indices.dtype)[0] - - example_inds = tf.reshape( - tf.range(batch_size, dtype=tag_indices.dtype), [-1, 1]) - sequence_scores = tf.gather_nd( - tf.squeeze(inputs, [1]), - tf.concat([example_inds, tag_indices], axis=1)) - sequence_scores = tf.where(tf.less_equal(sequence_lengths, 0), - tf.zeros_like(sequence_scores), - sequence_scores) - return sequence_scores - - def _multi_seq_fn(): - # Compute the scores of the given tag sequence. - unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs) - binary_scores = crf_binary_score(tag_indices, sequence_lengths, - transition_params) - sequence_scores = unary_scores + binary_scores - return sequence_scores - - if inputs.shape[1] == 1: - return _single_seq_fn() - else: - return _multi_seq_fn() + """Computes the unnormalized score for a tag sequence. + + Args: + inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials + to use as input to the CRF layer. + tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we + compute the unnormalized score. + sequence_lengths: A [batch_size] vector of true sequence lengths. + transition_params: A [num_tags, num_tags] transition matrix. + Returns: + sequence_scores: A [batch_size] vector of unnormalized sequence scores. + """ + + # If max_seq_len is 1, we skip the score calculation and simply gather the + # unary potentials of the single tag. + def _single_seq_fn(): + batch_size = tf.shape(inputs, out_type=tag_indices.dtype)[0] + + example_inds = tf.reshape( + tf.range(batch_size, dtype=tag_indices.dtype), [-1, 1]) + sequence_scores = tf.gather_nd( + tf.squeeze(inputs, [1]), + tf.concat([example_inds, tag_indices], axis=1)) + sequence_scores = tf.where( + tf.less_equal(sequence_lengths, 0), tf.zeros_like(sequence_scores), + sequence_scores) + return sequence_scores + + def _multi_seq_fn(): + # Compute the scores of the given tag sequence. + unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs) + binary_scores = crf_binary_score(tag_indices, sequence_lengths, + transition_params) + sequence_scores = unary_scores + binary_scores + return sequence_scores + + if inputs.shape[1] == 1: + return _single_seq_fn() + else: + return _multi_seq_fn() def crf_multitag_sequence_score(inputs, tag_bitmap, sequence_lengths, transition_params): - """Computes the unnormalized score of all tag sequences matching tag_bitmap. - - tag_bitmap enables more than one tag to be considered correct at each time - step. This is useful when an observed output at a given time step is - consistent with more than one tag, and thus the log likelihood of that - observation must take into account all possible consistent tags. - - Using one-hot vectors in tag_bitmap gives results identical to - crf_sequence_score. - - Args: - inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials - to use as input to the CRF layer. - tag_bitmap: A [batch_size, max_seq_len, num_tags] boolean tensor - representing all active tags at each index for which to calculate the - unnormalized score. - sequence_lengths: A [batch_size] vector of true sequence lengths. - transition_params: A [num_tags, num_tags] transition matrix. - Returns: - sequence_scores: A [batch_size] vector of unnormalized sequence scores. - """ - - # If max_seq_len is 1, we skip the score calculation and simply gather the - # unary potentials of all active tags. - def _single_seq_fn(): - filtered_inputs = tf.where( - tag_bitmap, inputs, - tf.fill(tf.shape(inputs), float("-inf"))) - return tf.reduce_logsumexp( - filtered_inputs, axis=[1, 2], keepdims=False) - - def _multi_seq_fn(): - # Compute the logsumexp of all scores of sequences matching the given tags. - filtered_inputs = tf.where( - tag_bitmap, inputs, - tf.fill(tf.shape(inputs), float("-inf"))) - return crf_log_norm( - inputs=filtered_inputs, - sequence_lengths=sequence_lengths, - transition_params=transition_params) - - if inputs.shape[1] == 1: - return _single_seq_fn() - else: - return _multi_seq_fn() + """Computes the unnormalized score of all tag sequences matching + tag_bitmap. + + tag_bitmap enables more than one tag to be considered correct at each time + step. This is useful when an observed output at a given time step is + consistent with more than one tag, and thus the log likelihood of that + observation must take into account all possible consistent tags. + + Using one-hot vectors in tag_bitmap gives results identical to + crf_sequence_score. + + Args: + inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials + to use as input to the CRF layer. + tag_bitmap: A [batch_size, max_seq_len, num_tags] boolean tensor + representing all active tags at each index for which to calculate the + unnormalized score. + sequence_lengths: A [batch_size] vector of true sequence lengths. + transition_params: A [num_tags, num_tags] transition matrix. + Returns: + sequence_scores: A [batch_size] vector of unnormalized sequence scores. + """ + + # If max_seq_len is 1, we skip the score calculation and simply gather the + # unary potentials of all active tags. + def _single_seq_fn(): + filtered_inputs = tf.where(tag_bitmap, inputs, + tf.fill(tf.shape(inputs), float("-inf"))) + return tf.reduce_logsumexp( + filtered_inputs, axis=[1, 2], keepdims=False) + + def _multi_seq_fn(): + # Compute the logsumexp of all scores of sequences matching the given tags. + filtered_inputs = tf.where(tag_bitmap, inputs, + tf.fill(tf.shape(inputs), float("-inf"))) + return crf_log_norm( + inputs=filtered_inputs, + sequence_lengths=sequence_lengths, + transition_params=transition_params) + + if inputs.shape[1] == 1: + return _single_seq_fn() + else: + return _multi_seq_fn() def crf_log_norm(inputs, sequence_lengths, transition_params): - """Computes the normalization for a CRF. - - Args: - inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials - to use as input to the CRF layer. - sequence_lengths: A [batch_size] vector of true sequence lengths. - transition_params: A [num_tags, num_tags] transition matrix. - Returns: - log_norm: A [batch_size] vector of normalizers for a CRF. - """ - # Split up the first and rest of the inputs in preparation for the forward - # algorithm. - first_input = tf.slice(inputs, [0, 0, 0], [-1, 1, -1]) - first_input = tf.squeeze(first_input, [1]) - - # If max_seq_len is 1, we skip the algorithm and simply reduce_logsumexp over - # the "initial state" (the unary potentials). - def _single_seq_fn(): - log_norm = tf.reduce_logsumexp(first_input, [1]) - # Mask `log_norm` of the sequences with length <= zero. - log_norm = tf.where(tf.less_equal(sequence_lengths, 0), - tf.zeros_like(log_norm), - log_norm) - return log_norm - - def _multi_seq_fn(): - """Forward computation of alpha values.""" - rest_of_input = tf.slice(inputs, [0, 1, 0], [-1, -1, -1]) - # Compute the alpha values in the forward algorithm in order to get the - # partition function. - forward_cell = CrfForwardRnnCell(transition_params) - # Sequence length is not allowed to be less than zero. - sequence_lengths_less_one = tf.maximum( - tf.constant(0, dtype=sequence_lengths.dtype), - sequence_lengths - 1) - - forward_layer = tf.keras.layers.RNN( - forward_cell, - return_sequences=True, - return_state=True) - - _, alphas = forward_layer(rest_of_input, first_input) - - log_norm = tf.reduce_logsumexp(alphas, [1]) - # Mask `log_norm` of the sequences with length <= zero. - log_norm = tf.where(tf.less_equal(sequence_lengths, 0), - tf.zeros_like(log_norm), - log_norm) - return log_norm - - if inputs.shape[1] == 1: - return _single_seq_fn() - else: - return _multi_seq_fn() + """Computes the normalization for a CRF. + + Args: + inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials + to use as input to the CRF layer. + sequence_lengths: A [batch_size] vector of true sequence lengths. + transition_params: A [num_tags, num_tags] transition matrix. + Returns: + log_norm: A [batch_size] vector of normalizers for a CRF. + """ + # Split up the first and rest of the inputs in preparation for the forward + # algorithm. + first_input = tf.slice(inputs, [0, 0, 0], [-1, 1, -1]) + first_input = tf.squeeze(first_input, [1]) + + # If max_seq_len is 1, we skip the algorithm and simply reduce_logsumexp over + # the "initial state" (the unary potentials). + def _single_seq_fn(): + log_norm = tf.reduce_logsumexp(first_input, [1]) + # Mask `log_norm` of the sequences with length <= zero. + log_norm = tf.where( + tf.less_equal(sequence_lengths, 0), tf.zeros_like(log_norm), + log_norm) + return log_norm + + def _multi_seq_fn(): + """Forward computation of alpha values.""" + rest_of_input = tf.slice(inputs, [0, 1, 0], [-1, -1, -1]) + # Compute the alpha values in the forward algorithm in order to get the + # partition function. + forward_cell = CrfForwardRnnCell(transition_params) + # Sequence length is not allowed to be less than zero. + sequence_lengths_less_one = tf.maximum( + tf.constant(0, dtype=sequence_lengths.dtype), sequence_lengths - 1) + + forward_layer = tf.keras.layers.RNN( + forward_cell, return_sequences=True, return_state=True) + + _, alphas = forward_layer(rest_of_input, first_input) + + log_norm = tf.reduce_logsumexp(alphas, [1]) + # Mask `log_norm` of the sequences with length <= zero. + log_norm = tf.where( + tf.less_equal(sequence_lengths, 0), tf.zeros_like(log_norm), + log_norm) + return log_norm + + if inputs.shape[1] == 1: + return _single_seq_fn() + else: + return _multi_seq_fn() def crf_log_likelihood(inputs, tag_indices, sequence_lengths, transition_params=None): - """Computes the log-likelihood of tag sequences in a CRF. - - Args: - inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials - to use as input to the CRF layer. - tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we - compute the log-likelihood. - sequence_lengths: A [batch_size] vector of true sequence lengths. - transition_params: A [num_tags, num_tags] transition matrix, if available. - Returns: - log_likelihood: A [batch_size] `Tensor` containing the log-likelihood of - each example, given the sequence of tag indices. - transition_params: A [num_tags, num_tags] transition matrix. This is either - provided by the caller or created in this function. - """ - # Get shape information. - num_tags = inputs.shape[2] - - # Get the transition matrix if not provided. - if transition_params is None: - transition_params = tf.get_variable("transitions", [num_tags, num_tags]) - - sequence_scores = crf_sequence_score(inputs, tag_indices, sequence_lengths, - transition_params) - log_norm = crf_log_norm(inputs, sequence_lengths, transition_params) - - # Normalize the scores to get the log-likelihood per example. - log_likelihood = sequence_scores - log_norm - return log_likelihood, transition_params + """Computes the log-likelihood of tag sequences in a CRF. + Args: + inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials + to use as input to the CRF layer. + tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we + compute the log-likelihood. + sequence_lengths: A [batch_size] vector of true sequence lengths. + transition_params: A [num_tags, num_tags] transition matrix, if available. + Returns: + log_likelihood: A [batch_size] `Tensor` containing the log-likelihood of + each example, given the sequence of tag indices. + transition_params: A [num_tags, num_tags] transition matrix. This is either + provided by the caller or created in this function. + """ + # Get shape information. + num_tags = inputs.shape[2] -def crf_unary_score(tag_indices, sequence_lengths, inputs): - """Computes the unary scores of tag sequences. + # Get the transition matrix if not provided. + if transition_params is None: + transition_params = tf.get_variable("transitions", + [num_tags, num_tags]) - Args: - tag_indices: A [batch_size, max_seq_len] matrix of tag indices. - sequence_lengths: A [batch_size] vector of true sequence lengths. - inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials. - Returns: - unary_scores: A [batch_size] vector of unary scores. - """ - batch_size = tf.shape(inputs)[0] - max_seq_len = tf.shape(inputs)[1] - num_tags = tf.shape(inputs)[2] + sequence_scores = crf_sequence_score(inputs, tag_indices, sequence_lengths, + transition_params) + log_norm = crf_log_norm(inputs, sequence_lengths, transition_params) - flattened_inputs = tf.reshape(inputs, [-1]) + # Normalize the scores to get the log-likelihood per example. + log_likelihood = sequence_scores - log_norm + return log_likelihood, transition_params - offsets = tf.expand_dims( - tf.range(batch_size) * max_seq_len * num_tags, 1) - offsets += tf.expand_dims(tf.range(max_seq_len) * num_tags, 0) - # Use int32 or int64 based on tag_indices' dtype. - if tag_indices.dtype == tf.int64: - offsets = tf.cast(offsets, tf.int64) - flattened_tag_indices = tf.reshape(offsets + tag_indices, [-1]) - unary_scores = tf.reshape( - tf.gather(flattened_inputs, flattened_tag_indices), - [batch_size, max_seq_len]) +def crf_unary_score(tag_indices, sequence_lengths, inputs): + """Computes the unary scores of tag sequences. - masks = tf.sequence_mask(sequence_lengths, - maxlen=tf.shape(tag_indices)[1], - dtype=tf.float32) + Args: + tag_indices: A [batch_size, max_seq_len] matrix of tag indices. + sequence_lengths: A [batch_size] vector of true sequence lengths. + inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials. + Returns: + unary_scores: A [batch_size] vector of unary scores. + """ + batch_size = tf.shape(inputs)[0] + max_seq_len = tf.shape(inputs)[1] + num_tags = tf.shape(inputs)[2] - unary_scores = tf.reduce_sum(unary_scores * masks, 1) - return unary_scores + flattened_inputs = tf.reshape(inputs, [-1]) + offsets = tf.expand_dims(tf.range(batch_size) * max_seq_len * num_tags, 1) + offsets += tf.expand_dims(tf.range(max_seq_len) * num_tags, 0) + # Use int32 or int64 based on tag_indices' dtype. + if tag_indices.dtype == tf.int64: + offsets = tf.cast(offsets, tf.int64) + flattened_tag_indices = tf.reshape(offsets + tag_indices, [-1]) -def crf_binary_score(tag_indices, sequence_lengths, transition_params): - """Computes the binary scores of tag sequences. - - Args: - tag_indices: A [batch_size, max_seq_len] matrix of tag indices. - sequence_lengths: A [batch_size] vector of true sequence lengths. - transition_params: A [num_tags, num_tags] matrix of binary potentials. - Returns: - binary_scores: A [batch_size] vector of binary scores. - """ - # Get shape information. - num_tags = tf.shape(transition_params)[0] - num_transitions = tf.shape(tag_indices)[1] - 1 - - # Truncate by one on each side of the sequence to get the start and end - # indices of each transition. - start_tag_indices = tf.slice(tag_indices, [0, 0], - [-1, num_transitions]) - end_tag_indices = tf.slice(tag_indices, [0, 1], [-1, num_transitions]) - - # Encode the indices in a flattened representation. - flattened_transition_indices = start_tag_indices * num_tags + end_tag_indices - flattened_transition_params = tf.reshape(transition_params, [-1]) - - # Get the binary scores based on the flattened representation. - binary_scores = tf.gather(flattened_transition_params, - flattened_transition_indices) - - masks = tf.sequence_mask(sequence_lengths, - maxlen=tf.shape(tag_indices)[1], - dtype=tf.float32) - truncated_masks = tf.slice(masks, [0, 1], [-1, -1]) - binary_scores = tf.reduce_sum(binary_scores * truncated_masks, 1) - return binary_scores + unary_scores = tf.reshape( + tf.gather(flattened_inputs, flattened_tag_indices), + [batch_size, max_seq_len]) + masks = tf.sequence_mask( + sequence_lengths, maxlen=tf.shape(tag_indices)[1], dtype=tf.float32) -class CrfForwardRnnCell(tf.keras.layers.Layer): - def __init__(self, transition_params, **kwargs): - super(CrfForwardRnnCell, self).__init__(**kwargs) - self._transition_params = tf.expand_dims(transition_params, 0) - self._num_tags = transition_params.shape[0] - self.state_size = self._num_tags - self.output_size = self._num_tags + unary_scores = tf.reduce_sum(unary_scores * masks, 1) + return unary_scores - def build(self, input_shape): - super(CrfForwardRnnCell, self).build(input_shape) - def call(self, inputs, state, training=None): - state = tf.expand_dims(state[0], 2) - transition_scores = state + self._transition_params - new_alphas = inputs + tf.reduce_logsumexp(transition_scores, [1]) - return new_alphas, new_alphas +def crf_binary_score(tag_indices, sequence_lengths, transition_params): + """Computes the binary scores of tag sequences. + Args: + tag_indices: A [batch_size, max_seq_len] matrix of tag indices. + sequence_lengths: A [batch_size] vector of true sequence lengths. + transition_params: A [num_tags, num_tags] matrix of binary potentials. + Returns: + binary_scores: A [batch_size] vector of binary scores. + """ + # Get shape information. + num_tags = tf.shape(transition_params)[0] + num_transitions = tf.shape(tag_indices)[1] - 1 -def viterbi_decode(score, transition_params): - """Decode the highest scoring sequence of tags outside of TensorFlow. + # Truncate by one on each side of the sequence to get the start and end + # indices of each transition. + start_tag_indices = tf.slice(tag_indices, [0, 0], [-1, num_transitions]) + end_tag_indices = tf.slice(tag_indices, [0, 1], [-1, num_transitions]) - This should only be used at test time. + # Encode the indices in a flattened representation. + flattened_transition_indices = start_tag_indices * num_tags + end_tag_indices + flattened_transition_params = tf.reshape(transition_params, [-1]) - Args: - score: A [seq_len, num_tags] matrix of unary potentials. - transition_params: A [num_tags, num_tags] matrix of binary potentials. + # Get the binary scores based on the flattened representation. + binary_scores = tf.gather(flattened_transition_params, + flattened_transition_indices) - Returns: - viterbi: A [seq_len] list of integers containing the highest scoring tag - indices. - viterbi_score: A float containing the score for the Viterbi sequence. - """ - trellis = np.zeros_like(score) - backpointers = np.zeros_like(score, dtype=np.int32) - trellis[0] = score[0] + masks = tf.sequence_mask( + sequence_lengths, maxlen=tf.shape(tag_indices)[1], dtype=tf.float32) + truncated_masks = tf.slice(masks, [0, 1], [-1, -1]) + binary_scores = tf.reduce_sum(binary_scores * truncated_masks, 1) + return binary_scores - for t in range(1, score.shape[0]): - v = np.expand_dims(trellis[t - 1], 1) + transition_params - trellis[t] = score[t] + np.max(v, 0) - backpointers[t] = np.argmax(v, 0) - viterbi = [np.argmax(trellis[-1])] - for bp in reversed(backpointers[1:]): - viterbi.append(bp[viterbi[-1]]) - viterbi.reverse() +class CrfForwardRnnCell(tf.keras.layers.Layer): + def __init__(self, transition_params, **kwargs): + super(CrfForwardRnnCell, self).__init__(**kwargs) + self._transition_params = tf.expand_dims(transition_params, 0) + self._num_tags = transition_params.shape[0] + self.state_size = self._num_tags + self.output_size = self._num_tags - viterbi_score = np.max(trellis[-1]) - return viterbi, viterbi_score + def build(self, input_shape): + super(CrfForwardRnnCell, self).build(input_shape) + def call(self, inputs, state, training=None): + state = tf.expand_dims(state[0], 2) + transition_scores = state + self._transition_params + new_alphas = inputs + tf.reduce_logsumexp(transition_scores, [1]) + return new_alphas, new_alphas -class CrfDecodeForwardRnnCell(tf.keras.layers.Layer): - """Computes the forward decoding in a linear-chain CRF. - """ - def __init__(self, transition_params, **kwargs): - """Initialize the CrfDecodeForwardRnnCell. +def viterbi_decode(score, transition_params): + """Decode the highest scoring sequence of tags outside of TensorFlow. + + This should only be used at test time. Args: - transition_params: A [num_tags, num_tags] matrix of binary - potentials. This matrix is expanded into a - [1, num_tags, num_tags] in preparation for the broadcast - summation occurring within the cell. + score: A [seq_len, num_tags] matrix of unary potentials. + transition_params: A [num_tags, num_tags] matrix of binary potentials. + + Returns: + viterbi: A [seq_len] list of integers containing the highest scoring tag + indices. + viterbi_score: A float containing the score for the Viterbi sequence. """ - super(CrfDecodeForwardRnnCell, self).__init__(**kwargs) - self._transition_params = tf.expand_dims(transition_params, 0) - self._num_tags = transition_params.shape[0] - self.state_size = self._num_tags - self.output_size = self._num_tags + trellis = np.zeros_like(score) + backpointers = np.zeros_like(score, dtype=np.int32) + trellis[0] = score[0] - def build(self, input_shape): - super(CrfDecodeForwardRnnCell, self).build(input_shape) + for t in range(1, score.shape[0]): + v = np.expand_dims(trellis[t - 1], 1) + transition_params + trellis[t] = score[t] + np.max(v, 0) + backpointers[t] = np.argmax(v, 0) - def call(self, inputs, state, training=None): - state = tf.expand_dims(state[0], 2) - transition_scores = state + self._transition_params - new_state = inputs + tf.reduce_max(transition_scores, [1]) - backpointers = tf.argmax(transition_scores, 1) - backpointers = tf.cast(backpointers, dtype=tf.int32) - return backpointers, new_state + viterbi = [np.argmax(trellis[-1])] + for bp in reversed(backpointers[1:]): + viterbi.append(bp[viterbi[-1]]) + viterbi.reverse() + + viterbi_score = np.max(trellis[-1]) + return viterbi, viterbi_score + + +class CrfDecodeForwardRnnCell(tf.keras.layers.Layer): + """Computes the forward decoding in a linear-chain CRF.""" + + def __init__(self, transition_params, **kwargs): + """Initialize the CrfDecodeForwardRnnCell. + + Args: + transition_params: A [num_tags, num_tags] matrix of binary + potentials. This matrix is expanded into a + [1, num_tags, num_tags] in preparation for the broadcast + summation occurring within the cell. + """ + super(CrfDecodeForwardRnnCell, self).__init__(**kwargs) + self._transition_params = tf.expand_dims(transition_params, 0) + self._num_tags = transition_params.shape[0] + self.state_size = self._num_tags + self.output_size = self._num_tags + + def build(self, input_shape): + super(CrfDecodeForwardRnnCell, self).build(input_shape) + + def call(self, inputs, state, training=None): + state = tf.expand_dims(state[0], 2) + transition_scores = state + self._transition_params + new_state = inputs + tf.reduce_max(transition_scores, [1]) + backpointers = tf.argmax(transition_scores, 1) + backpointers = tf.cast(backpointers, dtype=tf.int32) + return backpointers, new_state class CrfDecodeBackwardRnnCell(tf.keras.layers.Layer): - """Computes backward decoding in a linear-chain CRF. - """ + """Computes backward decoding in a linear-chain CRF.""" - def __init__(self, num_tags, **kwargs): - """Initialize the CrfDecodeBackwardRnnCell. + def __init__(self, num_tags, **kwargs): + """Initialize the CrfDecodeBackwardRnnCell. - Args: - num_tags: An integer. The number of tags. - """ - super(CrfDecodeBackwardRnnCell, self).__init__(**kwargs) - self._num_tags = num_tags + Args: + num_tags: An integer. The number of tags. + """ + super(CrfDecodeBackwardRnnCell, self).__init__(**kwargs) + self._num_tags = num_tags - self.state_size = 1 - self.output_size = 1 + self.state_size = 1 + self.output_size = 1 - def build(self, input_shape): - super(CrfDecodeBackwardRnnCell, self).build(input_shape) + def build(self, input_shape): + super(CrfDecodeBackwardRnnCell, self).build(input_shape) - def call(self, inputs, state, training=None): - state = tf.squeeze(state[0], axis=[1]) - batch_size = tf.shape(inputs)[0] - b_indices = tf.range(batch_size) - indices = tf.stack([b_indices, state], axis=1) - new_tags = tf.expand_dims(tf.gather_nd(inputs, indices), axis=-1) + def call(self, inputs, state, training=None): + state = tf.squeeze(state[0], axis=[1]) + batch_size = tf.shape(inputs)[0] + b_indices = tf.range(batch_size) + indices = tf.stack([b_indices, state], axis=1) + new_tags = tf.expand_dims(tf.gather_nd(inputs, indices), axis=-1) - return new_tags, new_tags + return new_tags, new_tags def crf_decode(potentials, transition_params, sequence_length): - """Decode the highest scoring sequence of tags in TensorFlow. - - This is a function for tensor. - - Args: - potentials: A [batch_size, max_seq_len, num_tags] tensor of - unary potentials. - transition_params: A [num_tags, num_tags] matrix of - binary potentials. - sequence_length: A [batch_size] vector of true sequence lengths. - - Returns: - decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`. - Contains the highest scoring tag indices. - best_score: A [batch_size] vector, containing the score of `decode_tags`. - """ - - # If max_seq_len is 1, we skip the algorithm and simply return the argmax tag - # and the max activation. - def _single_seq_fn(): - squeezed_potentials = tf.squeeze(potentials, [1]) - decode_tags = tf.expand_dims( - tf.argmax(squeezed_potentials, axis=1), 1) - best_score = tf.reduce_max(squeezed_potentials, axis=1) - return tf.cast(decode_tags, dtype=tf.int32), best_score - - def _multi_seq_fn(): - """Decoding of highest scoring sequence.""" - - # For simplicity, in shape comments, denote: - # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output). - num_tags = potentials.shape[2] - - # Computes forward decoding. Get last score and backpointers. - initial_state = tf.slice(potentials, [0, 0, 0], [-1, 1, -1]) - initial_state = tf.squeeze(initial_state, axis=[1]) # [B, O] - inputs = tf.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O] - # Sequence length is not allowed to be less than zero. - - sequence_length_less_one = tf.maximum( - tf.constant(0, dtype=sequence_length.dtype), - sequence_length - 1) - - crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params) - crf_fwd_layer = tf.keras.layers.RNN(crf_fwd_cell, - return_sequences=True, - return_state=True, - time_major=False) - backpointers, last_score = crf_fwd_layer(inputs, initial_state) - backpointers = tf.reverse_sequence(backpointers, sequence_length_less_one, seq_axis=1) - - crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags) - initial_state = tf.cast(tf.argmax(last_score, axis=1), dtype=tf.int32) - initial_state = tf.expand_dims(initial_state, axis=-1) - crf_bwd_layer = tf.keras.layers.RNN(crf_bwd_cell, - return_sequences=True, - return_state=True, - time_major=False) - decode_tags, _ = crf_bwd_layer(backpointers, initial_state) - - decode_tags = tf.squeeze(decode_tags, axis=[2]) # [B, T - 1] - decode_tags = tf.concat([initial_state, decode_tags], # [B, T] - axis=1) - decode_tags = tf.reverse_sequence( # [B, T] - decode_tags, sequence_length, seq_axis=1) - - best_score = tf.reduce_max(last_score, axis=1) # [B] - return decode_tags, best_score - - if potentials.shape[1] == 1: - return _single_seq_fn() - else: - return _multi_seq_fn() + """Decode the highest scoring sequence of tags in TensorFlow. + + This is a function for tensor. + + Args: + potentials: A [batch_size, max_seq_len, num_tags] tensor of + unary potentials. + transition_params: A [num_tags, num_tags] matrix of + binary potentials. + sequence_length: A [batch_size] vector of true sequence lengths. + + Returns: + decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`. + Contains the highest scoring tag indices. + best_score: A [batch_size] vector, containing the score of `decode_tags`. + """ + + # If max_seq_len is 1, we skip the algorithm and simply return the argmax tag + # and the max activation. + def _single_seq_fn(): + squeezed_potentials = tf.squeeze(potentials, [1]) + decode_tags = tf.expand_dims(tf.argmax(squeezed_potentials, axis=1), 1) + best_score = tf.reduce_max(squeezed_potentials, axis=1) + return tf.cast(decode_tags, dtype=tf.int32), best_score + + def _multi_seq_fn(): + """Decoding of highest scoring sequence.""" + + # For simplicity, in shape comments, denote: + # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output). + num_tags = potentials.shape[2] + + # Computes forward decoding. Get last score and backpointers. + initial_state = tf.slice(potentials, [0, 0, 0], [-1, 1, -1]) + initial_state = tf.squeeze(initial_state, axis=[1]) # [B, O] + inputs = tf.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O] + # Sequence length is not allowed to be less than zero. + + sequence_length_less_one = tf.maximum( + tf.constant(0, dtype=sequence_length.dtype), sequence_length - 1) + + crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params) + crf_fwd_layer = tf.keras.layers.RNN( + crf_fwd_cell, + return_sequences=True, + return_state=True, + time_major=False) + backpointers, last_score = crf_fwd_layer(inputs, initial_state) + backpointers = tf.reverse_sequence( + backpointers, sequence_length_less_one, seq_axis=1) + + crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags) + initial_state = tf.cast(tf.argmax(last_score, axis=1), dtype=tf.int32) + initial_state = tf.expand_dims(initial_state, axis=-1) + crf_bwd_layer = tf.keras.layers.RNN( + crf_bwd_cell, + return_sequences=True, + return_state=True, + time_major=False) + decode_tags, _ = crf_bwd_layer(backpointers, initial_state) + + decode_tags = tf.squeeze(decode_tags, axis=[2]) # [B, T - 1] + decode_tags = tf.concat( + [initial_state, decode_tags], # [B, T] + axis=1) + decode_tags = tf.reverse_sequence( # [B, T] + decode_tags, sequence_length, seq_axis=1) + + best_score = tf.reduce_max(last_score, axis=1) # [B] + return decode_tags, best_score + + if potentials.shape[1] == 1: + return _single_seq_fn() + else: + return _multi_seq_fn() diff --git a/tensorflow_addons/text/crf_ops_test.py b/tensorflow_addons/text/crf_ops_test.py index d706992c32..ad22d95a2f 100644 --- a/tensorflow_addons/text/crf_ops_test.py +++ b/tensorflow_addons/text/crf_ops_test.py @@ -28,331 +28,318 @@ class CrfTest(tf.test.TestCase): - - def calculateSequenceScore(self, inputs, transition_params, tag_indices, - sequence_lengths): - expected_unary_score = sum( - inputs[i][tag_indices[i]] for i in range(sequence_lengths)) - expected_binary_score = sum( - transition_params[tag_indices[i], tag_indices[i + 1]] - for i in range(sequence_lengths - 1)) - return expected_unary_score + expected_binary_score - - def testCrfSequenceScore(self): - transition_params = np.array( - [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) - # Test both the length-1 and regular cases. - sequence_lengths_list = [ - np.array(3, dtype=np.int32), - np.array(1, dtype=np.int32) - ] - inputs_list = [ - np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], - dtype=np.float32), - np.array([[4, 5, -3]], - dtype=np.float32), - ] - tag_indices_list = [ - np.array([1, 2, 1, 0], dtype=np.int32), - np.array([1], dtype=np.int32) - ] - for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list, - inputs_list, - tag_indices_list): - sequence_score = text.crf_sequence_score( - inputs=tf.expand_dims(inputs, 0), - tag_indices=tf.expand_dims(tag_indices, 0), - sequence_lengths=tf.expand_dims(sequence_lengths, 0), - transition_params=tf.constant(transition_params)) - sequence_score = tf.squeeze(sequence_score, [0]) - - tf_sequence_score = self.evaluate(sequence_score) - - expected_sequence_score = self.calculateSequenceScore( - inputs, transition_params, tag_indices, sequence_lengths) - self.assertAllClose(tf_sequence_score, expected_sequence_score) - - def testCrfMultiTagSequenceScore(self): - transition_params = np.array( - [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) - # Test both the length-1 and regular cases. - sequence_lengths_list = [ - np.array(3, dtype=np.int32), - np.array(1, dtype=np.int32) - ] - inputs_list = [ - np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], - dtype=np.float32), - np.array([[4, 5, -3]], - dtype=np.float32), - ] - tag_bitmap_list = [ - np.array( - [[True, True, False], [True, False, True], [False, True, True], - [True, False, True]], - dtype=np.bool), - np.array([[True, True, False]], dtype=np.bool) - ] - for sequence_lengths, inputs, tag_bitmap in zip( - sequence_lengths_list, inputs_list, tag_bitmap_list): - sequence_score = text.crf_multitag_sequence_score( - inputs=tf.expand_dims(inputs, 0), - tag_bitmap=tf.expand_dims(tag_bitmap, 0), - sequence_lengths=tf.expand_dims(sequence_lengths, 0), - transition_params=tf.constant(transition_params)) - sequence_score = tf.squeeze(sequence_score, [0]) - tf_sum_sequence_score = self.evaluate(sequence_score) - all_indices_list = [ - single_index_bitmap.nonzero()[0] - for single_index_bitmap in tag_bitmap[:sequence_lengths] - ] - expected_sequence_scores = [ - self.calculateSequenceScore(inputs, transition_params, indices, - sequence_lengths) - for indices in itertools.product(*all_indices_list) - ] - expected_log_sum_exp_sequence_scores = np.logaddexp.reduce( - expected_sequence_scores) - self.assertAllClose(tf_sum_sequence_score, - expected_log_sum_exp_sequence_scores) - - def testCrfUnaryScore(self): - inputs = np.array( - [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) - for dtype in (np.int32, np.int64): - tag_indices = np.array([1, 2, 1, 0], dtype=dtype) - sequence_lengths = np.array(3, dtype=np.int32) - unary_score = text.crf_unary_score( - tag_indices=tf.expand_dims(tag_indices, 0), - sequence_lengths=tf.expand_dims(sequence_lengths, 0), - inputs=tf.expand_dims(inputs, 0)) - unary_score = tf.squeeze(unary_score, [0]) - tf_unary_score = self.evaluate(unary_score) - expected_unary_score = sum(inputs[i][tag_indices[i]] - for i in range(sequence_lengths)) - self.assertAllClose(tf_unary_score, expected_unary_score) - - def testCrfBinaryScore(self): - tag_indices = np.array([1, 2, 1, 0], dtype=np.int32) - transition_params = np.array( - [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) - sequence_lengths = np.array(3, dtype=np.int32) - binary_score = text.crf_binary_score( - tag_indices=tf.expand_dims(tag_indices, 0), - sequence_lengths=tf.expand_dims(sequence_lengths, 0), - transition_params=tf.constant(transition_params)) - binary_score = tf.squeeze(binary_score, [0]) - tf_binary_score = self.evaluate(binary_score) - expected_binary_score = sum( - transition_params[tag_indices[i], tag_indices[i + 1]] - for i in range(sequence_lengths - 1)) - self.assertAllClose(tf_binary_score, expected_binary_score) - - def testCrfLogNorm(self): - transition_params = np.array( - [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) - # Test both the length-1 and regular cases. - sequence_lengths_list = [ - np.array(3, dtype=np.int32), - np.array(1, dtype=np.int64) - ] - inputs_list = [ - np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], - dtype=np.float32), - np.array([[3, -1, 3]], - dtype=np.float32), - ] - tag_indices_list = [ - np.array([1, 2, 1, 0], dtype=np.int32), - np.array([2], dtype=np.int32) - ] - - for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list, - inputs_list, - tag_indices_list): - num_words = inputs.shape[0] - num_tags = inputs.shape[1] - all_sequence_scores = [] - - # Compare the dynamic program with brute force computation. - for tag_indices in itertools.product( - range(num_tags), repeat=sequence_lengths): - tag_indices = list(tag_indices) - tag_indices.extend([0] * (num_words - sequence_lengths)) - all_sequence_scores.append( - text.crf_sequence_score( - inputs=tf.expand_dims(inputs, 0), + def calculateSequenceScore(self, inputs, transition_params, tag_indices, + sequence_lengths): + expected_unary_score = sum( + inputs[i][tag_indices[i]] for i in range(sequence_lengths)) + expected_binary_score = sum( + transition_params[tag_indices[i], tag_indices[i + 1]] + for i in range(sequence_lengths - 1)) + return expected_unary_score + expected_binary_score + + def testCrfSequenceScore(self): + transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]], + dtype=np.float32) + # Test both the length-1 and regular cases. + sequence_lengths_list = [ + np.array(3, dtype=np.int32), + np.array(1, dtype=np.int32) + ] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32), + np.array([[4, 5, -3]], dtype=np.float32), + ] + tag_indices_list = [ + np.array([1, 2, 1, 0], dtype=np.int32), + np.array([1], dtype=np.int32) + ] + for sequence_lengths, inputs, tag_indices in zip( + sequence_lengths_list, inputs_list, tag_indices_list): + sequence_score = text.crf_sequence_score( + inputs=tf.expand_dims(inputs, 0), + tag_indices=tf.expand_dims(tag_indices, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + transition_params=tf.constant(transition_params)) + sequence_score = tf.squeeze(sequence_score, [0]) + + tf_sequence_score = self.evaluate(sequence_score) + + expected_sequence_score = self.calculateSequenceScore( + inputs, transition_params, tag_indices, sequence_lengths) + self.assertAllClose(tf_sequence_score, expected_sequence_score) + + def testCrfMultiTagSequenceScore(self): + transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]], + dtype=np.float32) + # Test both the length-1 and regular cases. + sequence_lengths_list = [ + np.array(3, dtype=np.int32), + np.array(1, dtype=np.int32) + ] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32), + np.array([[4, 5, -3]], dtype=np.float32), + ] + tag_bitmap_list = [ + np.array([[True, True, False], [True, False, True], + [False, True, True], [True, False, True]], + dtype=np.bool), + np.array([[True, True, False]], dtype=np.bool) + ] + for sequence_lengths, inputs, tag_bitmap in zip( + sequence_lengths_list, inputs_list, tag_bitmap_list): + sequence_score = text.crf_multitag_sequence_score( + inputs=tf.expand_dims(inputs, 0), + tag_bitmap=tf.expand_dims(tag_bitmap, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + transition_params=tf.constant(transition_params)) + sequence_score = tf.squeeze(sequence_score, [0]) + tf_sum_sequence_score = self.evaluate(sequence_score) + all_indices_list = [ + single_index_bitmap.nonzero()[0] + for single_index_bitmap in tag_bitmap[:sequence_lengths] + ] + expected_sequence_scores = [ + self.calculateSequenceScore(inputs, transition_params, indices, + sequence_lengths) + for indices in itertools.product(*all_indices_list) + ] + expected_log_sum_exp_sequence_scores = np.logaddexp.reduce( + expected_sequence_scores) + self.assertAllClose(tf_sum_sequence_score, + expected_log_sum_exp_sequence_scores) + + def testCrfUnaryScore(self): + inputs = np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32) + for dtype in (np.int32, np.int64): + tag_indices = np.array([1, 2, 1, 0], dtype=dtype) + sequence_lengths = np.array(3, dtype=np.int32) + unary_score = text.crf_unary_score( + tag_indices=tf.expand_dims(tag_indices, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + inputs=tf.expand_dims(inputs, 0)) + unary_score = tf.squeeze(unary_score, [0]) + tf_unary_score = self.evaluate(unary_score) + expected_unary_score = sum( + inputs[i][tag_indices[i]] for i in range(sequence_lengths)) + self.assertAllClose(tf_unary_score, expected_unary_score) + + def testCrfBinaryScore(self): + tag_indices = np.array([1, 2, 1, 0], dtype=np.int32) + transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]], + dtype=np.float32) + sequence_lengths = np.array(3, dtype=np.int32) + binary_score = text.crf_binary_score( tag_indices=tf.expand_dims(tag_indices, 0), sequence_lengths=tf.expand_dims(sequence_lengths, 0), - transition_params=tf.constant(transition_params))) - - brute_force_log_norm = tf.reduce_logsumexp(all_sequence_scores) - log_norm = text.crf_log_norm( - inputs=tf.expand_dims(inputs, 0), - sequence_lengths=tf.expand_dims(sequence_lengths, 0), - transition_params=tf.constant(transition_params)) - log_norm = tf.squeeze(log_norm, [0]) - tf_brute_force_log_norm, tf_log_norm = self.evaluate( - [brute_force_log_norm, log_norm]) - - self.assertAllClose(tf_log_norm, tf_brute_force_log_norm) - - def testCrfLogNormZeroSeqLength(self): - """ - Test `crf_log_norm` when `sequence_lengths` contains one or more zeros. - """ - inputs = tf.constant(np.ones([2, 10, 5], - dtype=np.float32)) - transition_params = tf.constant(np.ones([5, 5], - dtype=np.float32)) - sequence_lengths = tf.constant(np.zeros([2], - dtype=np.int32)) - expected_log_norm = np.zeros([2], dtype=np.float32) - log_norm = text.crf_log_norm(inputs, sequence_lengths, transition_params) - tf_log_norm = self.evaluate(log_norm) - self.assertAllClose(tf_log_norm, expected_log_norm) - - def testCrfLogLikelihood(self): - inputs = np.array( - [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) - transition_params = np.array( - [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) - sequence_lengths = np.array(3, dtype=np.int32) - num_words = inputs.shape[0] - num_tags = inputs.shape[1] - all_sequence_log_likelihoods = [] - - # Make sure all probabilities sum to 1. - for tag_indices in itertools.product( - range(num_tags), repeat=sequence_lengths): - tag_indices = list(tag_indices) - tag_indices.extend([0] * (num_words - sequence_lengths)) - sequence_log_likelihood, _ = text.crf_log_likelihood( - inputs=tf.expand_dims(inputs, 0), - tag_indices=tf.expand_dims(tag_indices, 0), - sequence_lengths=tf.expand_dims(sequence_lengths, 0), - transition_params=tf.constant(transition_params)) - all_sequence_log_likelihoods.append(sequence_log_likelihood) - total_log_likelihood = tf.reduce_logsumexp( - all_sequence_log_likelihoods) - tf_total_log_likelihood = self.evaluate(total_log_likelihood) - self.assertAllClose(tf_total_log_likelihood, 0.0) - - def testViterbiDecode(self): - inputs = np.array( - [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) - transition_params = np.array( - [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) - sequence_lengths = np.array(3, dtype=np.int32) - num_words = inputs.shape[0] - num_tags = inputs.shape[1] - - all_sequence_scores = [] - all_sequences = [] - - # Compare the dynamic program with brute force computation. - for tag_indices in itertools.product( - range(num_tags), repeat=sequence_lengths): - tag_indices = list(tag_indices) - tag_indices.extend([0] * (num_words - sequence_lengths)) - all_sequences.append(tag_indices) - sequence_score = text.crf_sequence_score( - inputs=tf.expand_dims(inputs, 0), - tag_indices=tf.expand_dims(tag_indices, 0), - sequence_lengths=tf.expand_dims(sequence_lengths, 0), - transition_params=tf.constant(transition_params)) - sequence_score = tf.squeeze(sequence_score, [0]) - all_sequence_scores.append(sequence_score) - - tf_all_sequence_scores = self.evaluate(all_sequence_scores) - - expected_max_sequence_index = np.argmax(tf_all_sequence_scores) - expected_max_sequence = all_sequences[expected_max_sequence_index] - expected_max_score = tf_all_sequence_scores[expected_max_sequence_index] - - actual_max_sequence, actual_max_score = text.viterbi_decode( - inputs[:sequence_lengths], transition_params) - - self.assertAllClose(actual_max_score, expected_max_score) - self.assertEqual(actual_max_sequence, - expected_max_sequence[:sequence_lengths]) - - def testCrfDecode(self): - transition_params = np.array( - [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) - # Test both the length-1 and regular cases. - sequence_lengths_list = [ - np.array(3, dtype=np.int32), - np.array(1, dtype=np.int64) - ] - inputs_list = [ - np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], - dtype=np.float32), - np.array([[-1, 2, 1]], - dtype=np.float32), - ] - tag_indices_list = [ - np.array([1, 2, 1, 0], dtype=np.int32), - np.array([2], dtype=np.int32) - ] - - for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list, - inputs_list, - tag_indices_list): - num_words = inputs.shape[0] - num_tags = inputs.shape[1] - - all_sequence_scores = [] - all_sequences = [] - - # Compare the dynamic program with brute force computation. - for tag_indices in itertools.product( - range(num_tags), repeat=sequence_lengths): - tag_indices = list(tag_indices) - tag_indices.extend([0] * (num_words - sequence_lengths)) - all_sequences.append(tag_indices) - sequence_score = text.crf_sequence_score( - inputs=tf.expand_dims(inputs, 0), - tag_indices=tf.expand_dims(tag_indices, 0), - sequence_lengths=tf.expand_dims(sequence_lengths, 0), - transition_params=tf.constant(transition_params)) - sequence_score = tf.squeeze(sequence_score, [0]) - all_sequence_scores.append(sequence_score) - - tf_all_sequence_scores = self.evaluate(all_sequence_scores) - - expected_max_sequence_index = np.argmax(tf_all_sequence_scores) - expected_max_sequence = all_sequences[expected_max_sequence_index] - expected_max_score = tf_all_sequence_scores[expected_max_sequence_index] - - actual_max_sequence, actual_max_score = text.crf_decode( - tf.expand_dims(inputs, 0), - tf.constant(transition_params), - tf.expand_dims(sequence_lengths, 0)) - actual_max_sequence = tf.squeeze(actual_max_sequence, [0]) - actual_max_score = tf.squeeze(actual_max_score, [0]) - tf_actual_max_sequence, tf_actual_max_score = self.evaluate( - [actual_max_sequence, actual_max_score]) - - self.assertAllClose(tf_actual_max_score, expected_max_score) - self.assertEqual(list(tf_actual_max_sequence[:sequence_lengths]), - expected_max_sequence[:sequence_lengths]) - - def testCrfDecodeZeroSeqLength(self): - """ - Test that crf_decode works when sequence_length contains one or more zeros. - """ - inputs = tf.constant(np.ones([2, 10, 5], - dtype=np.float32)) - transition_params = tf.constant(np.ones([5, 5], - dtype=np.float32)) - sequence_lengths = tf.constant(np.zeros([2], - dtype=np.int32)) - tags, scores = text.crf_decode(inputs, transition_params, sequence_lengths) - tf_tags, tf_scores = self.evaluate([tags, scores]) - self.assertEqual(len(tf_tags.shape), 2) - self.assertEqual(len(tf_scores.shape), 1) + transition_params=tf.constant(transition_params)) + binary_score = tf.squeeze(binary_score, [0]) + tf_binary_score = self.evaluate(binary_score) + expected_binary_score = sum( + transition_params[tag_indices[i], tag_indices[i + 1]] + for i in range(sequence_lengths - 1)) + self.assertAllClose(tf_binary_score, expected_binary_score) + + def testCrfLogNorm(self): + transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]], + dtype=np.float32) + # Test both the length-1 and regular cases. + sequence_lengths_list = [ + np.array(3, dtype=np.int32), + np.array(1, dtype=np.int64) + ] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32), + np.array([[3, -1, 3]], dtype=np.float32), + ] + tag_indices_list = [ + np.array([1, 2, 1, 0], dtype=np.int32), + np.array([2], dtype=np.int32) + ] + + for sequence_lengths, inputs, tag_indices in zip( + sequence_lengths_list, inputs_list, tag_indices_list): + num_words = inputs.shape[0] + num_tags = inputs.shape[1] + all_sequence_scores = [] + + # Compare the dynamic program with brute force computation. + for tag_indices in itertools.product( + range(num_tags), repeat=sequence_lengths): + tag_indices = list(tag_indices) + tag_indices.extend([0] * (num_words - sequence_lengths)) + all_sequence_scores.append( + text.crf_sequence_score( + inputs=tf.expand_dims(inputs, 0), + tag_indices=tf.expand_dims(tag_indices, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + transition_params=tf.constant(transition_params))) + + brute_force_log_norm = tf.reduce_logsumexp(all_sequence_scores) + log_norm = text.crf_log_norm( + inputs=tf.expand_dims(inputs, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + transition_params=tf.constant(transition_params)) + log_norm = tf.squeeze(log_norm, [0]) + tf_brute_force_log_norm, tf_log_norm = self.evaluate( + [brute_force_log_norm, log_norm]) + + self.assertAllClose(tf_log_norm, tf_brute_force_log_norm) + + def testCrfLogNormZeroSeqLength(self): + """Test `crf_log_norm` when `sequence_lengths` contains one or more + zeros.""" + inputs = tf.constant(np.ones([2, 10, 5], dtype=np.float32)) + transition_params = tf.constant(np.ones([5, 5], dtype=np.float32)) + sequence_lengths = tf.constant(np.zeros([2], dtype=np.int32)) + expected_log_norm = np.zeros([2], dtype=np.float32) + log_norm = text.crf_log_norm(inputs, sequence_lengths, + transition_params) + tf_log_norm = self.evaluate(log_norm) + self.assertAllClose(tf_log_norm, expected_log_norm) + + def testCrfLogLikelihood(self): + inputs = np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32) + transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]], + dtype=np.float32) + sequence_lengths = np.array(3, dtype=np.int32) + num_words = inputs.shape[0] + num_tags = inputs.shape[1] + all_sequence_log_likelihoods = [] + + # Make sure all probabilities sum to 1. + for tag_indices in itertools.product( + range(num_tags), repeat=sequence_lengths): + tag_indices = list(tag_indices) + tag_indices.extend([0] * (num_words - sequence_lengths)) + sequence_log_likelihood, _ = text.crf_log_likelihood( + inputs=tf.expand_dims(inputs, 0), + tag_indices=tf.expand_dims(tag_indices, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + transition_params=tf.constant(transition_params)) + all_sequence_log_likelihoods.append(sequence_log_likelihood) + total_log_likelihood = tf.reduce_logsumexp( + all_sequence_log_likelihoods) + tf_total_log_likelihood = self.evaluate(total_log_likelihood) + self.assertAllClose(tf_total_log_likelihood, 0.0) + + def testViterbiDecode(self): + inputs = np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32) + transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]], + dtype=np.float32) + sequence_lengths = np.array(3, dtype=np.int32) + num_words = inputs.shape[0] + num_tags = inputs.shape[1] + + all_sequence_scores = [] + all_sequences = [] + + # Compare the dynamic program with brute force computation. + for tag_indices in itertools.product( + range(num_tags), repeat=sequence_lengths): + tag_indices = list(tag_indices) + tag_indices.extend([0] * (num_words - sequence_lengths)) + all_sequences.append(tag_indices) + sequence_score = text.crf_sequence_score( + inputs=tf.expand_dims(inputs, 0), + tag_indices=tf.expand_dims(tag_indices, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + transition_params=tf.constant(transition_params)) + sequence_score = tf.squeeze(sequence_score, [0]) + all_sequence_scores.append(sequence_score) + + tf_all_sequence_scores = self.evaluate(all_sequence_scores) + + expected_max_sequence_index = np.argmax(tf_all_sequence_scores) + expected_max_sequence = all_sequences[expected_max_sequence_index] + expected_max_score = tf_all_sequence_scores[ + expected_max_sequence_index] + + actual_max_sequence, actual_max_score = text.viterbi_decode( + inputs[:sequence_lengths], transition_params) + + self.assertAllClose(actual_max_score, expected_max_score) + self.assertEqual(actual_max_sequence, + expected_max_sequence[:sequence_lengths]) + + def testCrfDecode(self): + transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]], + dtype=np.float32) + # Test both the length-1 and regular cases. + sequence_lengths_list = [ + np.array(3, dtype=np.int32), + np.array(1, dtype=np.int64) + ] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32), + np.array([[-1, 2, 1]], dtype=np.float32), + ] + tag_indices_list = [ + np.array([1, 2, 1, 0], dtype=np.int32), + np.array([2], dtype=np.int32) + ] + + for sequence_lengths, inputs, tag_indices in zip( + sequence_lengths_list, inputs_list, tag_indices_list): + num_words = inputs.shape[0] + num_tags = inputs.shape[1] + + all_sequence_scores = [] + all_sequences = [] + + # Compare the dynamic program with brute force computation. + for tag_indices in itertools.product( + range(num_tags), repeat=sequence_lengths): + tag_indices = list(tag_indices) + tag_indices.extend([0] * (num_words - sequence_lengths)) + all_sequences.append(tag_indices) + sequence_score = text.crf_sequence_score( + inputs=tf.expand_dims(inputs, 0), + tag_indices=tf.expand_dims(tag_indices, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + transition_params=tf.constant(transition_params)) + sequence_score = tf.squeeze(sequence_score, [0]) + all_sequence_scores.append(sequence_score) + + tf_all_sequence_scores = self.evaluate(all_sequence_scores) + + expected_max_sequence_index = np.argmax(tf_all_sequence_scores) + expected_max_sequence = all_sequences[expected_max_sequence_index] + expected_max_score = tf_all_sequence_scores[ + expected_max_sequence_index] + + actual_max_sequence, actual_max_score = text.crf_decode( + tf.expand_dims(inputs, 0), tf.constant(transition_params), + tf.expand_dims(sequence_lengths, 0)) + actual_max_sequence = tf.squeeze(actual_max_sequence, [0]) + actual_max_score = tf.squeeze(actual_max_score, [0]) + tf_actual_max_sequence, tf_actual_max_score = self.evaluate( + [actual_max_sequence, actual_max_score]) + + self.assertAllClose(tf_actual_max_score, expected_max_score) + self.assertEqual( + list(tf_actual_max_sequence[:sequence_lengths]), + expected_max_sequence[:sequence_lengths]) + + def testCrfDecodeZeroSeqLength(self): + """Test that crf_decode works when sequence_length contains one or more + zeros.""" + inputs = tf.constant(np.ones([2, 10, 5], dtype=np.float32)) + transition_params = tf.constant(np.ones([5, 5], dtype=np.float32)) + sequence_lengths = tf.constant(np.zeros([2], dtype=np.int32)) + tags, scores = text.crf_decode(inputs, transition_params, + sequence_lengths) + tf_tags, tf_scores = self.evaluate([tags, scores]) + self.assertEqual(len(tf_tags.shape), 2) + self.assertEqual(len(tf_scores.shape), 1) if __name__ == "__main__": - tf.test.main() + tf.test.main() From 69093696329f7b0b12e03420ef83dd99abe5c692 Mon Sep 17 00:00:00 2001 From: Dheeraj Rajaram Reddy Date: Fri, 21 Jun 2019 13:34:43 +0530 Subject: [PATCH 03/11] Add tf.function to all the CRF functions --- tensorflow_addons/text/crf_ops.py | 7 +++++++ tensorflow_addons/text/crf_ops_test.py | 1 + 2 files changed, 8 insertions(+) diff --git a/tensorflow_addons/text/crf_ops.py b/tensorflow_addons/text/crf_ops.py index 9e5fd02051..0d77f94207 100644 --- a/tensorflow_addons/text/crf_ops.py +++ b/tensorflow_addons/text/crf_ops.py @@ -21,6 +21,7 @@ import tensorflow as tf +@tf.function def crf_sequence_score(inputs, tag_indices, sequence_lengths, transition_params): """Computes the unnormalized score for a tag sequence. @@ -65,6 +66,7 @@ def _multi_seq_fn(): return _multi_seq_fn() +@tf.function def crf_multitag_sequence_score(inputs, tag_bitmap, sequence_lengths, transition_params): """Computes the unnormalized score of all tag sequences matching @@ -113,6 +115,7 @@ def _multi_seq_fn(): return _multi_seq_fn() +@tf.function def crf_log_norm(inputs, sequence_lengths, transition_params): """Computes the normalization for a CRF. @@ -167,6 +170,7 @@ def _multi_seq_fn(): return _multi_seq_fn() +@tf.function def crf_log_likelihood(inputs, tag_indices, sequence_lengths, @@ -203,6 +207,7 @@ def crf_log_likelihood(inputs, return log_likelihood, transition_params +@tf.function def crf_unary_score(tag_indices, sequence_lengths, inputs): """Computes the unary scores of tag sequences. @@ -237,6 +242,7 @@ def crf_unary_score(tag_indices, sequence_lengths, inputs): return unary_scores +@tf.function def crf_binary_score(tag_indices, sequence_lengths, transition_params): """Computes the binary scores of tag sequences. @@ -379,6 +385,7 @@ def call(self, inputs, state, training=None): return new_tags, new_tags +@tf.function def crf_decode(potentials, transition_params, sequence_length): """Decode the highest scoring sequence of tags in TensorFlow. diff --git a/tensorflow_addons/text/crf_ops_test.py b/tensorflow_addons/text/crf_ops_test.py index ad22d95a2f..84c09b539b 100644 --- a/tensorflow_addons/text/crf_ops_test.py +++ b/tensorflow_addons/text/crf_ops_test.py @@ -27,6 +27,7 @@ from tensorflow_addons.utils import test_utils +@test_utils.run_all_in_graph_and_eager_modes class CrfTest(tf.test.TestCase): def calculateSequenceScore(self, inputs, transition_params, tag_indices, sequence_lengths): From 426e129c0d325c5e7c2e5ae1649d06d35473b2ff Mon Sep 17 00:00:00 2001 From: "Dheeraj R. Reddy" Date: Mon, 24 Jun 2019 19:03:14 +0530 Subject: [PATCH 04/11] RNN call masks computation based on seq len --- tensorflow_addons/text/crf_ops.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tensorflow_addons/text/crf_ops.py b/tensorflow_addons/text/crf_ops.py index 0d77f94207..6e04b8bf2b 100644 --- a/tensorflow_addons/text/crf_ops.py +++ b/tensorflow_addons/text/crf_ops.py @@ -155,8 +155,9 @@ def _multi_seq_fn(): forward_layer = tf.keras.layers.RNN( forward_cell, return_sequences=True, return_state=True) - _, alphas = forward_layer(rest_of_input, first_input) - + mask = tf.sequence_mask(sequence_lengths_less_one, + tf.shape(inputs)[1] - 1) + _, alphas = forward_layer(rest_of_input, first_input, mask=mask) log_norm = tf.reduce_logsumexp(alphas, [1]) # Mask `log_norm` of the sequences with length <= zero. log_norm = tf.where( @@ -170,7 +171,6 @@ def _multi_seq_fn(): return _multi_seq_fn() -@tf.function def crf_log_likelihood(inputs, tag_indices, sequence_lengths, @@ -428,13 +428,15 @@ def _multi_seq_fn(): sequence_length_less_one = tf.maximum( tf.constant(0, dtype=sequence_length.dtype), sequence_length - 1) + mask = tf.sequence_mask(sequence_length_less_one, tf.shape(inputs)[1]) crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params) crf_fwd_layer = tf.keras.layers.RNN( crf_fwd_cell, return_sequences=True, return_state=True, time_major=False) - backpointers, last_score = crf_fwd_layer(inputs, initial_state) + backpointers, last_score = crf_fwd_layer( + inputs, initial_state, mask=mask) backpointers = tf.reverse_sequence( backpointers, sequence_length_less_one, seq_axis=1) From d43d1acb29eef927cf5b832f3b0f096c17dc4ba6 Mon Sep 17 00:00:00 2001 From: "Dheeraj R. Reddy" Date: Tue, 25 Jun 2019 11:44:45 +0530 Subject: [PATCH 05/11] Rename files and minor fixes * Rename crf_ops* -> crf* * The RNN cells inherit `AbstractRNNCell` instead of `Layer` * Remove used `training` variable * Add docstring for RNN Cells --- tensorflow_addons/text/BUILD | 8 +-- tensorflow_addons/text/__init__.py | 27 +++++---- tensorflow_addons/text/{crf_ops.py => crf.py} | 60 +++++++++++++++---- .../text/{crf_ops_test.py => crf_test.py} | 0 4 files changed, 67 insertions(+), 28 deletions(-) rename tensorflow_addons/text/{crf_ops.py => crf.py} (92%) rename tensorflow_addons/text/{crf_ops_test.py => crf_test.py} (100%) diff --git a/tensorflow_addons/text/BUILD b/tensorflow_addons/text/BUILD index d96bdd582b..21306ef3f9 100644 --- a/tensorflow_addons/text/BUILD +++ b/tensorflow_addons/text/BUILD @@ -6,7 +6,7 @@ py_library( name = "text", srcs = ([ "__init__.py", - "crf_ops.py", + "crf.py", "skip_gram_ops.py", ]), data = [ @@ -17,12 +17,12 @@ py_library( ) py_test( - name = "crf_ops_test", + name = "crf_test", size = "small", srcs = [ - "crf_ops_test.py", + "crf_test.py", ], - main = "crf_ops_test.py", + main = "crf_test.py", srcs_version = "PY2AND3", deps = [ ":text", diff --git a/tensorflow_addons/text/__init__.py b/tensorflow_addons/text/__init__.py index 6c67afa387..865b72725a 100644 --- a/tensorflow_addons/text/__init__.py +++ b/tensorflow_addons/text/__init__.py @@ -17,18 +17,19 @@ from __future__ import division from __future__ import print_function +# Conditional Random Field +from tensorflow_addons.text.crf import crf_binary_score +from tensorflow_addons.text.crf import crf_decode +from tensorflow_addons.text.crf import crf_log_likelihood +from tensorflow_addons.text.crf import crf_log_norm +from tensorflow_addons.text.crf import crf_multitag_sequence_score +from tensorflow_addons.text.crf import crf_sequence_score +from tensorflow_addons.text.crf import crf_unary_score +from tensorflow_addons.text.crf import CrfDecodeBackwardRnnCell +from tensorflow_addons.text.crf import CrfDecodeForwardRnnCell +from tensorflow_addons.text.crf import CrfForwardRnnCell +from tensorflow_addons.text.crf import viterbi_decode + # Skip Gram Sampling from tensorflow_addons.text.skip_gram_ops import skip_gram_sample -from tensorflow_addons.text.skip_gram_ops import skip_gram_sample_with_text_vocab - -from tensorflow_addons.text.crf_ops import crf_binary_score -from tensorflow_addons.text.crf_ops import crf_decode -from tensorflow_addons.text.crf_ops import crf_log_likelihood -from tensorflow_addons.text.crf_ops import crf_log_norm -from tensorflow_addons.text.crf_ops import crf_multitag_sequence_score -from tensorflow_addons.text.crf_ops import crf_sequence_score -from tensorflow_addons.text.crf_ops import crf_unary_score -from tensorflow_addons.text.crf_ops import CrfDecodeBackwardRnnCell -from tensorflow_addons.text.crf_ops import CrfDecodeForwardRnnCell -from tensorflow_addons.text.crf_ops import CrfForwardRnnCell -from tensorflow_addons.text.crf_ops import viterbi_decode +from tensorflow_addons.text.skip_gram_ops import skip_gram_sample_with_text_vocab \ No newline at end of file diff --git a/tensorflow_addons/text/crf_ops.py b/tensorflow_addons/text/crf.py similarity index 92% rename from tensorflow_addons/text/crf_ops.py rename to tensorflow_addons/text/crf.py index 6e04b8bf2b..aba57c7431 100644 --- a/tensorflow_addons/text/crf_ops.py +++ b/tensorflow_addons/text/crf.py @@ -277,18 +277,45 @@ def crf_binary_score(tag_indices, sequence_lengths, transition_params): return binary_scores -class CrfForwardRnnCell(tf.keras.layers.Layer): +class CrfForwardRnnCell(tf.keras.layers.AbstractRNNCell): + """Computes the alpha values in a linear-chain CRF. + + See http://www.cs.columbia.edu/~mcollins/fb.pdf for reference. + """ def __init__(self, transition_params, **kwargs): + """Initialize the CrfForwardRnnCell. + Args: + transition_params: A [num_tags, num_tags] matrix of binary + potentials. This matrix is expanded into a + [1, num_tags, num_tags] in preparation for the + broadcast summation occurring within the cell. + """ super(CrfForwardRnnCell, self).__init__(**kwargs) self._transition_params = tf.expand_dims(transition_params, 0) self._num_tags = transition_params.shape[0] - self.state_size = self._num_tags - self.output_size = self._num_tags + + @property + def state_size(self): + return self._num_tags + + @property + def output_size(self): + return self._num_tags def build(self, input_shape): super(CrfForwardRnnCell, self).build(input_shape) - def call(self, inputs, state, training=None): + def call(self, inputs, state): + """Build the CrfForwardRnnCell. + Args: + inputs: A [batch_size, num_tags] matrix of unary potentials. + state: A [batch_size, num_tags] matrix containing the + previous alpha values. + scope: Unused variable scope of this cell. + Returns: + new_alphas, new_alphas: A pair of [batch_size, num_tags] + matrices values containing the new alpha values. + """ state = tf.expand_dims(state[0], 2) transition_scores = state + self._transition_params new_alphas = inputs + tf.reduce_logsumexp(transition_scores, [1]) @@ -327,7 +354,7 @@ def viterbi_decode(score, transition_params): return viterbi, viterbi_score -class CrfDecodeForwardRnnCell(tf.keras.layers.Layer): +class CrfDecodeForwardRnnCell(tf.keras.layers.AbstractRNNCell): """Computes the forward decoding in a linear-chain CRF.""" def __init__(self, transition_params, **kwargs): @@ -342,13 +369,19 @@ def __init__(self, transition_params, **kwargs): super(CrfDecodeForwardRnnCell, self).__init__(**kwargs) self._transition_params = tf.expand_dims(transition_params, 0) self._num_tags = transition_params.shape[0] - self.state_size = self._num_tags - self.output_size = self._num_tags + + @property + def state_size(self): + return self._num_tags + + @property + def output_size(self): + return self._num_tags def build(self, input_shape): super(CrfDecodeForwardRnnCell, self).build(input_shape) - def call(self, inputs, state, training=None): + def call(self, inputs, state): state = tf.expand_dims(state[0], 2) transition_scores = state + self._transition_params new_state = inputs + tf.reduce_max(transition_scores, [1]) @@ -369,13 +402,18 @@ def __init__(self, num_tags, **kwargs): super(CrfDecodeBackwardRnnCell, self).__init__(**kwargs) self._num_tags = num_tags - self.state_size = 1 - self.output_size = 1 + @property + def state_size(self): + return 1 + + @property + def output_size(self): + return 1 def build(self, input_shape): super(CrfDecodeBackwardRnnCell, self).build(input_shape) - def call(self, inputs, state, training=None): + def call(self, inputs, state): state = tf.squeeze(state[0], axis=[1]) batch_size = tf.shape(inputs)[0] b_indices = tf.range(batch_size) diff --git a/tensorflow_addons/text/crf_ops_test.py b/tensorflow_addons/text/crf_test.py similarity index 100% rename from tensorflow_addons/text/crf_ops_test.py rename to tensorflow_addons/text/crf_test.py From 69cc6caae085854f908d358bd3b444baa07340fc Mon Sep 17 00:00:00 2001 From: "Dheeraj R. Reddy" Date: Tue, 25 Jun 2019 11:49:56 +0530 Subject: [PATCH 06/11] code format --- tensorflow_addons/text/crf.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tensorflow_addons/text/crf.py b/tensorflow_addons/text/crf.py index aba57c7431..4e2837f5e0 100644 --- a/tensorflow_addons/text/crf.py +++ b/tensorflow_addons/text/crf.py @@ -282,6 +282,7 @@ class CrfForwardRnnCell(tf.keras.layers.AbstractRNNCell): See http://www.cs.columbia.edu/~mcollins/fb.pdf for reference. """ + def __init__(self, transition_params, **kwargs): """Initialize the CrfForwardRnnCell. Args: @@ -307,13 +308,14 @@ def build(self, input_shape): def call(self, inputs, state): """Build the CrfForwardRnnCell. + Args: inputs: A [batch_size, num_tags] matrix of unary potentials. - state: A [batch_size, num_tags] matrix containing the + state: A [batch_size, num_tags] matrix containing the previous alpha values. scope: Unused variable scope of this cell. Returns: - new_alphas, new_alphas: A pair of [batch_size, num_tags] + new_alphas, new_alphas: A pair of [batch_size, num_tags] matrices values containing the new alpha values. """ state = tf.expand_dims(state[0], 2) @@ -369,7 +371,7 @@ def __init__(self, transition_params, **kwargs): super(CrfDecodeForwardRnnCell, self).__init__(**kwargs) self._transition_params = tf.expand_dims(transition_params, 0) self._num_tags = transition_params.shape[0] - + @property def state_size(self): return self._num_tags From aaddd73450a5db5ecb819e25372b3ff027c7b535 Mon Sep 17 00:00:00 2001 From: "Dheeraj R. Reddy" Date: Wed, 26 Jun 2019 14:42:19 +0530 Subject: [PATCH 07/11] Remove unnecessary code and params --- tensorflow_addons/text/crf.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/tensorflow_addons/text/crf.py b/tensorflow_addons/text/crf.py index 4e2837f5e0..398bd58ac1 100644 --- a/tensorflow_addons/text/crf.py +++ b/tensorflow_addons/text/crf.py @@ -152,12 +152,11 @@ def _multi_seq_fn(): sequence_lengths_less_one = tf.maximum( tf.constant(0, dtype=sequence_lengths.dtype), sequence_lengths - 1) - forward_layer = tf.keras.layers.RNN( - forward_cell, return_sequences=True, return_state=True) + forward_layer = tf.keras.layers.RNN(forward_cell) mask = tf.sequence_mask(sequence_lengths_less_one, tf.shape(inputs)[1] - 1) - _, alphas = forward_layer(rest_of_input, first_input, mask=mask) + alphas = forward_layer(rest_of_input, first_input, mask=mask) log_norm = tf.reduce_logsumexp(alphas, [1]) # Mask `log_norm` of the sequences with length <= zero. log_norm = tf.where( @@ -402,7 +401,6 @@ def __init__(self, num_tags, **kwargs): num_tags: An integer. The number of tags. """ super(CrfDecodeBackwardRnnCell, self).__init__(**kwargs) - self._num_tags = num_tags @property def state_size(self): @@ -471,10 +469,7 @@ def _multi_seq_fn(): mask = tf.sequence_mask(sequence_length_less_one, tf.shape(inputs)[1]) crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params) crf_fwd_layer = tf.keras.layers.RNN( - crf_fwd_cell, - return_sequences=True, - return_state=True, - time_major=False) + crf_fwd_cell, return_sequences=True, return_state=True) backpointers, last_score = crf_fwd_layer( inputs, initial_state, mask=mask) backpointers = tf.reverse_sequence( @@ -484,11 +479,8 @@ def _multi_seq_fn(): initial_state = tf.cast(tf.argmax(last_score, axis=1), dtype=tf.int32) initial_state = tf.expand_dims(initial_state, axis=-1) crf_bwd_layer = tf.keras.layers.RNN( - crf_bwd_cell, - return_sequences=True, - return_state=True, - time_major=False) - decode_tags, _ = crf_bwd_layer(backpointers, initial_state) + crf_bwd_cell, return_sequences=True) + decode_tags = crf_bwd_layer(backpointers, initial_state) decode_tags = tf.squeeze(decode_tags, axis=[2]) # [B, T - 1] decode_tags = tf.concat( From 0b781c984b608b9cc4c4f4721de543e07c3f4490 Mon Sep 17 00:00:00 2001 From: "Dheeraj R. Reddy" Date: Fri, 5 Jul 2019 16:03:45 +0530 Subject: [PATCH 08/11] Replace RNN Cells with tf.scan --- tensorflow_addons/text/__init__.py | 6 +- tensorflow_addons/text/crf.py | 143 +++++++++-------------------- 2 files changed, 48 insertions(+), 101 deletions(-) diff --git a/tensorflow_addons/text/__init__.py b/tensorflow_addons/text/__init__.py index 865b72725a..11f8f9fecb 100644 --- a/tensorflow_addons/text/__init__.py +++ b/tensorflow_addons/text/__init__.py @@ -20,14 +20,14 @@ # Conditional Random Field from tensorflow_addons.text.crf import crf_binary_score from tensorflow_addons.text.crf import crf_decode +from tensorflow_addons.text.crf import crf_decode_backward +from tensorflow_addons.text.crf import crf_decode_forward +from tensorflow_addons.text.crf import crf_forward from tensorflow_addons.text.crf import crf_log_likelihood from tensorflow_addons.text.crf import crf_log_norm from tensorflow_addons.text.crf import crf_multitag_sequence_score from tensorflow_addons.text.crf import crf_sequence_score from tensorflow_addons.text.crf import crf_unary_score -from tensorflow_addons.text.crf import CrfDecodeBackwardRnnCell -from tensorflow_addons.text.crf import CrfDecodeForwardRnnCell -from tensorflow_addons.text.crf import CrfForwardRnnCell from tensorflow_addons.text.crf import viterbi_decode # Skip Gram Sampling diff --git a/tensorflow_addons/text/crf.py b/tensorflow_addons/text/crf.py index 398bd58ac1..e9c58d2b0b 100644 --- a/tensorflow_addons/text/crf.py +++ b/tensorflow_addons/text/crf.py @@ -147,16 +147,9 @@ def _multi_seq_fn(): rest_of_input = tf.slice(inputs, [0, 1, 0], [-1, -1, -1]) # Compute the alpha values in the forward algorithm in order to get the # partition function. - forward_cell = CrfForwardRnnCell(transition_params) - # Sequence length is not allowed to be less than zero. - sequence_lengths_less_one = tf.maximum( - tf.constant(0, dtype=sequence_lengths.dtype), sequence_lengths - 1) - forward_layer = tf.keras.layers.RNN(forward_cell) - - mask = tf.sequence_mask(sequence_lengths_less_one, - tf.shape(inputs)[1] - 1) - alphas = forward_layer(rest_of_input, first_input, mask=mask) + alphas = crf_forward(rest_of_input, first_input, transition_params, + sequence_lengths) log_norm = tf.reduce_logsumexp(alphas, [1]) # Mask `log_norm` of the sequences with length <= zero. log_norm = tf.where( @@ -276,51 +269,28 @@ def crf_binary_score(tag_indices, sequence_lengths, transition_params): return binary_scores -class CrfForwardRnnCell(tf.keras.layers.AbstractRNNCell): +@tf.function +def crf_forward(inputs, state, transition_params, sequence_lengths): """Computes the alpha values in a linear-chain CRF. See http://www.cs.columbia.edu/~mcollins/fb.pdf for reference. """ - def __init__(self, transition_params, **kwargs): - """Initialize the CrfForwardRnnCell. - Args: - transition_params: A [num_tags, num_tags] matrix of binary - potentials. This matrix is expanded into a - [1, num_tags, num_tags] in preparation for the - broadcast summation occurring within the cell. - """ - super(CrfForwardRnnCell, self).__init__(**kwargs) - self._transition_params = tf.expand_dims(transition_params, 0) - self._num_tags = transition_params.shape[0] - - @property - def state_size(self): - return self._num_tags - - @property - def output_size(self): - return self._num_tags - - def build(self, input_shape): - super(CrfForwardRnnCell, self).build(input_shape) + sequence_lengths = tf.maximum( + tf.constant(0, dtype=sequence_lengths.dtype), sequence_lengths - 2) + inputs = tf.transpose(inputs, [1, 0, 2]) + transition_params = tf.expand_dims(transition_params, 0) - def call(self, inputs, state): - """Build the CrfForwardRnnCell. - - Args: - inputs: A [batch_size, num_tags] matrix of unary potentials. - state: A [batch_size, num_tags] matrix containing the - previous alpha values. - scope: Unused variable scope of this cell. - Returns: - new_alphas, new_alphas: A pair of [batch_size, num_tags] - matrices values containing the new alpha values. - """ - state = tf.expand_dims(state[0], 2) - transition_scores = state + self._transition_params + def _scan_fn(state, inputs): + state = tf.expand_dims(state, 2) + transition_scores = state + transition_params new_alphas = inputs + tf.reduce_logsumexp(transition_scores, [1]) - return new_alphas, new_alphas + return new_alphas + + all_alphas = tf.transpose(tf.scan(_scan_fn, inputs, state), [1, 0, 2]) + idxs = tf.stack( + [tf.range(tf.shape(sequence_lengths)[0]), sequence_lengths], axis=1) + return tf.gather_nd(all_alphas, idxs) def viterbi_decode(score, transition_params): @@ -391,36 +361,27 @@ def call(self, inputs, state): return backpointers, new_state -class CrfDecodeBackwardRnnCell(tf.keras.layers.Layer): - """Computes backward decoding in a linear-chain CRF.""" - - def __init__(self, num_tags, **kwargs): - """Initialize the CrfDecodeBackwardRnnCell. - - Args: - num_tags: An integer. The number of tags. - """ - super(CrfDecodeBackwardRnnCell, self).__init__(**kwargs) +@tf.function +def crf_decode_forward(inputs, state, transition_params, sequence_lengths): + mask = tf.sequence_mask(sequence_lengths, tf.shape(inputs)[1]) + crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params) + crf_fwd_layer = tf.keras.layers.RNN( + crf_fwd_cell, return_sequences=True, return_state=True) + return crf_fwd_layer(inputs, state, mask=mask) - @property - def state_size(self): - return 1 - @property - def output_size(self): - return 1 - - def build(self, input_shape): - super(CrfDecodeBackwardRnnCell, self).build(input_shape) +@tf.function +def crf_decode_backward(inputs, state): + """Computes backward decoding in a linear-chain CRF.""" + inputs = tf.transpose(inputs, [1, 0, 2]) - def call(self, inputs, state): - state = tf.squeeze(state[0], axis=[1]) - batch_size = tf.shape(inputs)[0] - b_indices = tf.range(batch_size) - indices = tf.stack([b_indices, state], axis=1) - new_tags = tf.expand_dims(tf.gather_nd(inputs, indices), axis=-1) + def _scan_fn(state, inputs): + state = tf.squeeze(state, axis=[1]) + idxs = tf.stack([tf.range(tf.shape(inputs)[0]), state], axis=1) + new_tags = tf.expand_dims(tf.gather_nd(inputs, idxs), axis=-1) + return new_tags - return new_tags, new_tags + return tf.transpose(tf.scan(_scan_fn, inputs, state), [1, 0, 2]) @tf.function @@ -452,44 +413,30 @@ def _single_seq_fn(): def _multi_seq_fn(): """Decoding of highest scoring sequence.""" - - # For simplicity, in shape comments, denote: - # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output). - num_tags = potentials.shape[2] - # Computes forward decoding. Get last score and backpointers. initial_state = tf.slice(potentials, [0, 0, 0], [-1, 1, -1]) - initial_state = tf.squeeze(initial_state, axis=[1]) # [B, O] - inputs = tf.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O] - # Sequence length is not allowed to be less than zero. + initial_state = tf.squeeze(initial_state, axis=[1]) + inputs = tf.slice(potentials, [0, 1, 0], [-1, -1, -1]) sequence_length_less_one = tf.maximum( tf.constant(0, dtype=sequence_length.dtype), sequence_length - 1) - mask = tf.sequence_mask(sequence_length_less_one, tf.shape(inputs)[1]) - crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params) - crf_fwd_layer = tf.keras.layers.RNN( - crf_fwd_cell, return_sequences=True, return_state=True) - backpointers, last_score = crf_fwd_layer( - inputs, initial_state, mask=mask) + backpointers, last_score = crf_decode_forward( + inputs, initial_state, transition_params, sequence_length_less_one) + backpointers = tf.reverse_sequence( backpointers, sequence_length_less_one, seq_axis=1) - crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags) initial_state = tf.cast(tf.argmax(last_score, axis=1), dtype=tf.int32) initial_state = tf.expand_dims(initial_state, axis=-1) - crf_bwd_layer = tf.keras.layers.RNN( - crf_bwd_cell, return_sequences=True) - decode_tags = crf_bwd_layer(backpointers, initial_state) - - decode_tags = tf.squeeze(decode_tags, axis=[2]) # [B, T - 1] - decode_tags = tf.concat( - [initial_state, decode_tags], # [B, T] - axis=1) - decode_tags = tf.reverse_sequence( # [B, T] + + decode_tags = crf_decode_backward(backpointers, initial_state) + decode_tags = tf.squeeze(decode_tags, axis=[2]) + decode_tags = tf.concat([initial_state, decode_tags], axis=1) + decode_tags = tf.reverse_sequence( decode_tags, sequence_length, seq_axis=1) - best_score = tf.reduce_max(last_score, axis=1) # [B] + best_score = tf.reduce_max(last_score, axis=1) return decode_tags, best_score if potentials.shape[1] == 1: From ee052ceb8d4b6ff416d8bbd798f1197c677ef3f8 Mon Sep 17 00:00:00 2001 From: Dheeraj Date: Thu, 1 Aug 2019 01:42:08 +0530 Subject: [PATCH 09/11] Remove @tf.function wrappers --- tensorflow_addons/text/crf.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/tensorflow_addons/text/crf.py b/tensorflow_addons/text/crf.py index e9c58d2b0b..53447bf30c 100644 --- a/tensorflow_addons/text/crf.py +++ b/tensorflow_addons/text/crf.py @@ -20,8 +20,9 @@ import numpy as np import tensorflow as tf +# TODO: Wrap functions in @tf.function once +# https://github.com/tensorflow/tensorflow/issues/29075 is resolved -@tf.function def crf_sequence_score(inputs, tag_indices, sequence_lengths, transition_params): """Computes the unnormalized score for a tag sequence. @@ -66,7 +67,6 @@ def _multi_seq_fn(): return _multi_seq_fn() -@tf.function def crf_multitag_sequence_score(inputs, tag_bitmap, sequence_lengths, transition_params): """Computes the unnormalized score of all tag sequences matching @@ -115,7 +115,6 @@ def _multi_seq_fn(): return _multi_seq_fn() -@tf.function def crf_log_norm(inputs, sequence_lengths, transition_params): """Computes the normalization for a CRF. @@ -199,7 +198,6 @@ def crf_log_likelihood(inputs, return log_likelihood, transition_params -@tf.function def crf_unary_score(tag_indices, sequence_lengths, inputs): """Computes the unary scores of tag sequences. @@ -234,7 +232,6 @@ def crf_unary_score(tag_indices, sequence_lengths, inputs): return unary_scores -@tf.function def crf_binary_score(tag_indices, sequence_lengths, transition_params): """Computes the binary scores of tag sequences. @@ -269,7 +266,6 @@ def crf_binary_score(tag_indices, sequence_lengths, transition_params): return binary_scores -@tf.function def crf_forward(inputs, state, transition_params, sequence_lengths): """Computes the alpha values in a linear-chain CRF. @@ -361,7 +357,6 @@ def call(self, inputs, state): return backpointers, new_state -@tf.function def crf_decode_forward(inputs, state, transition_params, sequence_lengths): mask = tf.sequence_mask(sequence_lengths, tf.shape(inputs)[1]) crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params) @@ -370,7 +365,6 @@ def crf_decode_forward(inputs, state, transition_params, sequence_lengths): return crf_fwd_layer(inputs, state, mask=mask) -@tf.function def crf_decode_backward(inputs, state): """Computes backward decoding in a linear-chain CRF.""" inputs = tf.transpose(inputs, [1, 0, 2]) @@ -384,7 +378,6 @@ def _scan_fn(state, inputs): return tf.transpose(tf.scan(_scan_fn, inputs, state), [1, 0, 2]) -@tf.function def crf_decode(potentials, transition_params, sequence_length): """Decode the highest scoring sequence of tags in TensorFlow. From 7a0435bf845a0ac6ae5301dd5a249704cd776df6 Mon Sep 17 00:00:00 2001 From: Dheeraj Date: Thu, 1 Aug 2019 01:54:02 +0530 Subject: [PATCH 10/11] Add missing docstrings --- tensorflow_addons/text/crf.py | 49 ++++++++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/tensorflow_addons/text/crf.py b/tensorflow_addons/text/crf.py index 53447bf30c..a3d5928960 100644 --- a/tensorflow_addons/text/crf.py +++ b/tensorflow_addons/text/crf.py @@ -270,6 +270,19 @@ def crf_forward(inputs, state, transition_params, sequence_lengths): """Computes the alpha values in a linear-chain CRF. See http://www.cs.columbia.edu/~mcollins/fb.pdf for reference. + + Args: + inputs: A [batch_size, num_tags] matrix of unary potentials. + state: A [batch_size, num_tags] matrix containing the previous alpha + values. + transition_params: A [num_tags, num_tags] matrix of binary potentials. + This matrix is expanded into a [1, num_tags, num_tags] in preparation + for the broadcast summation occurring within the cell. + sequence_lengths: A [batch_size] vector of true sequence lengths. + + Returns: + new_alphas: A [batch_size, num_tags] matrix containing the + new alpha values. """ sequence_lengths = tf.maximum( @@ -349,6 +362,17 @@ def build(self, input_shape): super(CrfDecodeForwardRnnCell, self).build(input_shape) def call(self, inputs, state): + """Build the CrfDecodeForwardRnnCell. + + Args: + inputs: A [batch_size, num_tags] matrix of unary potentials. + state: A [batch_size, num_tags] matrix containing the previous step's + score values. + + Returns: + backpointers: A [batch_size, num_tags] matrix of backpointers. + new_state: A [batch_size, num_tags] matrix of new score values. + """ state = tf.expand_dims(state[0], 2) transition_scores = state + self._transition_params new_state = inputs + tf.reduce_max(transition_scores, [1]) @@ -358,6 +382,19 @@ def call(self, inputs, state): def crf_decode_forward(inputs, state, transition_params, sequence_lengths): + """Computes forward decoding in a linear-chain CRF. + + Args: + inputs: A [batch_size, num_tags] matrix of unary potentials. + state: A [batch_size, num_tags] matrix containing the previous step's + score values. + transition_params: A [num_tags, num_tags] matrix of binary potentials. + sequence_lengths: A [batch_size] vector of true sequence lengths. + + Returns: + backpointers: A [batch_size, num_tags] matrix of backpointers. + new_state: A [batch_size, num_tags] matrix of new score values. + """ mask = tf.sequence_mask(sequence_lengths, tf.shape(inputs)[1]) crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params) crf_fwd_layer = tf.keras.layers.RNN( @@ -366,7 +403,17 @@ def crf_decode_forward(inputs, state, transition_params, sequence_lengths): def crf_decode_backward(inputs, state): - """Computes backward decoding in a linear-chain CRF.""" + """Computes backward decoding in a linear-chain CRF. + + Args: + inputs: A [batch_size, num_tags] matrix of + backpointer of next step (in time order). + state: A [batch_size, 1] matrix of tag index of next step. + + Returns: + new_tags: A [batch_size, num_tags] + tensor containing the new tag indices. + """ inputs = tf.transpose(inputs, [1, 0, 2]) def _scan_fn(state, inputs): From 205ebd11db9d7c019483119fc372f4032bf82963 Mon Sep 17 00:00:00 2001 From: Dheeraj Date: Thu, 1 Aug 2019 02:32:12 +0530 Subject: [PATCH 11/11] reformat --- tensorflow_addons/text/crf.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/tensorflow_addons/text/crf.py b/tensorflow_addons/text/crf.py index a3d5928960..d8d97bf216 100644 --- a/tensorflow_addons/text/crf.py +++ b/tensorflow_addons/text/crf.py @@ -23,6 +23,7 @@ # TODO: Wrap functions in @tf.function once # https://github.com/tensorflow/tensorflow/issues/29075 is resolved + def crf_sequence_score(inputs, tag_indices, sequence_lengths, transition_params): """Computes the unnormalized score for a tag sequence. @@ -30,8 +31,8 @@ def crf_sequence_score(inputs, tag_indices, sequence_lengths, Args: inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials to use as input to the CRF layer. - tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we - compute the unnormalized score. + tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which + we compute the unnormalized score. sequence_lengths: A [batch_size] vector of true sequence lengths. transition_params: A [num_tags, num_tags] transition matrix. Returns: @@ -171,15 +172,16 @@ def crf_log_likelihood(inputs, Args: inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials to use as input to the CRF layer. - tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we - compute the log-likelihood. + tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which + we compute the log-likelihood. sequence_lengths: A [batch_size] vector of true sequence lengths. - transition_params: A [num_tags, num_tags] transition matrix, if available. + transition_params: A [num_tags, num_tags] transition matrix, + if available. Returns: log_likelihood: A [batch_size] `Tensor` containing the log-likelihood of each example, given the sequence of tag indices. - transition_params: A [num_tags, num_tags] transition matrix. This is either - provided by the caller or created in this function. + transition_params: A [num_tags, num_tags] transition matrix. This is + either provided by the caller or created in this function. """ # Get shape information. num_tags = inputs.shape[2] @@ -252,7 +254,8 @@ def crf_binary_score(tag_indices, sequence_lengths, transition_params): end_tag_indices = tf.slice(tag_indices, [0, 1], [-1, num_transitions]) # Encode the indices in a flattened representation. - flattened_transition_indices = start_tag_indices * num_tags + end_tag_indices + flattened_transition_indices = start_tag_indices * \ + num_tags + end_tag_indices flattened_transition_params = tf.reshape(transition_params, [-1]) # Get the binary scores based on the flattened representation. @@ -281,7 +284,7 @@ def crf_forward(inputs, state, transition_params, sequence_lengths): sequence_lengths: A [batch_size] vector of true sequence lengths. Returns: - new_alphas: A [batch_size, num_tags] matrix containing the + new_alphas: A [batch_size, num_tags] matrix containing the new alpha values. """