Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tensorflow_addons/text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@

# Conditional Random Field
from tensorflow_addons.text.crf import crf_binary_score
from tensorflow_addons.text.crf import crf_constrained_decode
from tensorflow_addons.text.crf import crf_decode
from tensorflow_addons.text.crf import crf_decode_backward
from tensorflow_addons.text.crf import crf_decode_forward
from tensorflow_addons.text.crf import crf_filtered_inputs
from tensorflow_addons.text.crf import crf_forward
from tensorflow_addons.text.crf import crf_log_likelihood
from tensorflow_addons.text.crf import crf_log_norm
Expand Down
61 changes: 56 additions & 5 deletions tensorflow_addons/text/crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,32 @@
# https://github.com/tensorflow/tensorflow/issues/29075 is resolved


def crf_filtered_inputs(inputs: TensorLike, tag_bitmap: TensorLike) -> tf.Tensor:
"""Constrains the inputs to filter out certain tags at each time step.

tag_bitmap limits the allowed tags at each input time step.
This is useful when an observed output at a given time step needs to be
constrained to a selected set of tags.

Args:
inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
to use as input to the CRF layer.
tag_bitmap: A [batch_size, max_seq_len, num_tags] boolean tensor
representing all active tags at each index for which to calculate the
unnormalized score.
Returns:
filtered_inputs: A [batch_size] vector of unnormalized sequence scores.
"""

# set scores of filtered out inputs to be -inf.
filtered_inputs = tf.where(
tag_bitmap,
inputs,
tf.fill(tf.shape(inputs), tf.cast(float("-inf"), inputs.dtype)),
)
return filtered_inputs


def crf_sequence_score(
inputs: TensorLike,
tag_indices: TensorLike,
Expand Down Expand Up @@ -107,11 +133,7 @@ def crf_multitag_sequence_score(
"""
tag_bitmap = tf.cast(tag_bitmap, dtype=tf.bool)
sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
filtered_inputs = tf.where(
tag_bitmap,
inputs,
tf.fill(tf.shape(inputs), tf.cast(float("-inf"), inputs.dtype)),
)
filtered_inputs = crf_filtered_inputs(inputs, tag_bitmap)

# If max_seq_len is 1, we skip the score calculation and simply gather the
# unary potentials of all active tags.
Expand Down Expand Up @@ -559,3 +581,32 @@ def _multi_seq_fn():
return tf.cond(
tf.equal(tf.shape(potentials)[1], 1), _single_seq_fn, _multi_seq_fn
)


def crf_constrained_decode(
potentials: TensorLike,
tag_bitmap: TensorLike,
transition_params: TensorLike,
sequence_length: TensorLike,
) -> tf.Tensor:
"""Decode the highest scoring sequence of tags under constraints.

This is a function for tensor.

Args:
potentials: A [batch_size, max_seq_len, num_tags] tensor of
unary potentials.
tag_bitmap: A [batch_size, max_seq_len, num_tags] boolean tensor
representing all active tags at each index for which to calculate the
unnormalized score.
transition_params: A [num_tags, num_tags] matrix of
binary potentials.
sequence_length: A [batch_size] vector of true sequence lengths.
Returns:
decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`.
Contains the highest scoring tag indices.
best_score: A [batch_size] vector, containing the score of `decode_tags`.
"""

filtered_potentials = crf_filtered_inputs(potentials, tag_bitmap)
return crf_decode(filtered_potentials, transition_params, sequence_length)
157 changes: 134 additions & 23 deletions tensorflow_addons/text/tests/crf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,82 @@ def calculate_sequence_score(inputs, transition_params, tag_indices, sequence_le
return expected_unary_score + expected_binary_score


def brute_force_decode(sequence_lengths, inputs, transition_params):
num_words = inputs.shape[0]
num_tags = inputs.shape[1]

all_sequence_scores = []
all_sequences = []

tag_indices_iterator = itertools.product(range(num_tags), repeat=sequence_lengths)
inputs = tf.expand_dims(inputs, 0)
sequence_lengths = tf.expand_dims(sequence_lengths, 0)
transition_params = tf.constant(transition_params)

# Compare the dynamic program with brute force computation.
for tag_indices in tag_indices_iterator:
tag_indices = list(tag_indices)
tag_indices.extend([0] * (num_words - sequence_lengths))
all_sequences.append(tag_indices)
sequence_score = text.crf_sequence_score(
inputs=inputs,
tag_indices=tf.expand_dims(tag_indices, 0),
sequence_lengths=sequence_lengths,
transition_params=transition_params,
)
sequence_score = tf.squeeze(sequence_score, [0])
all_sequence_scores.append(sequence_score)

expected_max_sequence_index = np.argmax(all_sequence_scores)
expected_max_sequence = all_sequences[expected_max_sequence_index]
expected_max_score = all_sequence_scores[expected_max_sequence_index]
return expected_max_sequence, expected_max_score


@pytest.mark.parametrize("dtype", [np.float16, np.float32])
def test_crf_filtered_inputs(dtype):
# Test both the length-1 and regular cases.
sequence_lengths_list = [np.array(3, dtype=np.int32), np.array(1, dtype=np.int32)]
inputs_list = [
np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=dtype),
np.array([[4, 5, -3]], dtype=dtype),
]
tag_bitmap_list = [
np.array(
[
[True, False, False],
[False, True, True],
[False, True, True],
[False, True, True],
],
dtype=np.bool,
),
np.array([[False, True, True]], dtype=np.bool),
]
neg_inf = float("-inf")
expected_filtered_inputs_list = [
np.array(
[[4, neg_inf, neg_inf], [neg_inf, -1, 3], [neg_inf, 2, 1], [neg_inf, 0, 0]],
dtype=dtype,
),
np.array([[neg_inf, 5, -3]], dtype=dtype),
]
for sequence_lengths, inputs, tag_bitmap, expected_filtered_inputs in zip(
sequence_lengths_list,
inputs_list,
tag_bitmap_list,
expected_filtered_inputs_list,
):
filtered_inputs = text.crf_filtered_inputs(
inputs=tf.expand_dims(inputs, 0), tag_bitmap=tf.expand_dims(tag_bitmap, 0)
)
filtered_inputs = tf.squeeze(filtered_inputs, [0])

test_utils.assert_allclose_according_to_type(
filtered_inputs, expected_filtered_inputs
)


@pytest.mark.parametrize("dtype", [np.float16, np.float32])
def test_crf_sequence_score(dtype):
transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=dtype)
Expand Down Expand Up @@ -309,29 +385,9 @@ def test_crf_decode(dtype):
for sequence_lengths, inputs, tag_indices in zip(
sequence_lengths_list, inputs_list, tag_indices_list
):
num_words = inputs.shape[0]
num_tags = inputs.shape[1]

all_sequence_scores = []
all_sequences = []

# Compare the dynamic program with brute force computation.
for tag_indices in itertools.product(range(num_tags), repeat=sequence_lengths):
tag_indices = list(tag_indices)
tag_indices.extend([0] * (num_words - sequence_lengths))
all_sequences.append(tag_indices)
sequence_score = text.crf_sequence_score(
inputs=tf.expand_dims(inputs, 0),
tag_indices=tf.expand_dims(tag_indices, 0),
sequence_lengths=tf.expand_dims(sequence_lengths, 0),
transition_params=tf.constant(transition_params),
)
sequence_score = tf.squeeze(sequence_score, [0])
all_sequence_scores.append(sequence_score)

expected_max_sequence_index = np.argmax(all_sequence_scores)
expected_max_sequence = all_sequences[expected_max_sequence_index]
expected_max_score = all_sequence_scores[expected_max_sequence_index]
expected_max_sequence, expected_max_score = brute_force_decode(
sequence_lengths, inputs, transition_params
)

actual_max_sequence, actual_max_score = text.crf_decode(
tf.expand_dims(inputs, 0),
Expand All @@ -350,6 +406,61 @@ def test_crf_decode(dtype):
)


@pytest.mark.parametrize("dtype", [np.float16, np.float32])
def test_crf_constrained_decode(dtype):
transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=dtype)
# Test both the length-1 and regular cases.
sequence_lengths_list = [np.array(3, dtype=np.int32), np.array(1, dtype=np.int32)]
inputs_list = [
np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=dtype),
np.array([[4, 5, -3]], dtype=dtype),
]
tag_bitmap_list = [
np.array(
[
[True, False, False],
[False, True, True],
[False, True, True],
[False, True, True],
],
dtype=np.bool,
),
np.array([[False, True, True]], dtype=np.bool),
]
for sequence_lengths, inputs, tag_bitmap in zip(
sequence_lengths_list, inputs_list, tag_bitmap_list
):
filtered_inputs = text.crf_filtered_inputs(
inputs=tf.expand_dims(inputs, 0), tag_bitmap=tf.expand_dims(tag_bitmap, 0)
)

expected_max_sequence, expected_max_score = text.crf_decode(
filtered_inputs,
tf.constant(transition_params),
tf.expand_dims(sequence_lengths, 0),
)

expected_max_sequence = tf.squeeze(expected_max_sequence, [0])
expected_max_score = tf.squeeze(expected_max_score, [0])

actual_max_sequence, actual_max_score = text.crf_constrained_decode(
tf.expand_dims(inputs, 0),
tf.expand_dims(tag_bitmap, 0),
tf.constant(transition_params),
tf.expand_dims(sequence_lengths, 0),
)

actual_max_sequence = tf.squeeze(actual_max_sequence, [0])
actual_max_score = tf.squeeze(actual_max_score, [0])

test_utils.assert_allclose_according_to_type(
actual_max_score, expected_max_score, 1e-6, 1e-6
)
assert list(actual_max_sequence[:sequence_lengths]) == list(
expected_max_sequence[:sequence_lengths]
)


def test_crf_decode_zero_seq_length():
"""Test that crf_decode works when sequence_length contains one or more
zeros."""
Expand Down