Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions tensorflow_addons/seq2seq/attention_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,7 +946,7 @@ def _monotonic_probability_fn(score,
test-time, and when hard attention is not desired.
mode: How to compute the attention distribution. Must be one of
'recursive', 'parallel', or 'hard'. See the docstring for
`tf.contrib.seq2seq.monotonic_attention` for more information.
`tfa.seq2seq.monotonic_attention` for more information.
seed: (optional) Random seed for pre-sigmoid noise.

Returns:
Expand Down Expand Up @@ -1042,7 +1042,7 @@ def __init__(self,
of the memory is large.
mode: How to compute the attention distribution. Must be one of
'recursive', 'parallel', or 'hard'. See the docstring for
`tf.contrib.seq2seq.monotonic_attention` for more information.
`tfa.seq2seq.monotonic_attention` for more information.
kernel_initializer: (optional), the name of the initializer for the
attention kernel.
dtype: The data type for the query and memory layers of the attention
Expand Down Expand Up @@ -1214,7 +1214,7 @@ def __init__(self,
of the memory is large.
mode: How to compute the attention distribution. Must be one of
'recursive', 'parallel', or 'hard'. See the docstring for
`tf.contrib.seq2seq.monotonic_attention` for more information.
`tfa.seq2seq.monotonic_attention` for more information.
dtype: The data type for the query and memory layers of the attention
mechanism.
name: Name to use when creating ops.
Expand Down Expand Up @@ -1547,7 +1547,7 @@ def __init__(self,
in `AttentionWrapper`, then you must ensure that:

- The encoder output has been tiled to `beam_width` via
`tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
`tfa.seq2seq.tile_batch` (NOT `tf.tile`).
- The `batch_size` argument passed to the `get_initial_state` method of
this wrapper is equal to `true_batch_size * beam_width`.
- The initial state created with `get_initial_state` above contains a
Expand All @@ -1557,11 +1557,11 @@ def __init__(self,
An example:

```
tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
tiled_encoder_outputs = tfa.seq2seq.tile_batch(
encoder_outputs, multiplier=beam_width)
tiled_encoder_final_state = tf.conrib.seq2seq.tile_batch(
tiled_encoder_final_state = tfa.seq2seq.tile_batch(
encoder_final_state, multiplier=beam_width)
tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
tiled_sequence_length = tfa.seq2seq.tile_batch(
sequence_length, multiplier=beam_width)
attention_mechanism = MyFavoriteAttentionMechanism(
num_units=attention_depth,
Expand Down Expand Up @@ -1725,7 +1725,7 @@ def __init__(self,
"Non-matching batch sizes between the memory "
"(encoder output) and initial_cell_state. Are you using "
"the BeamSearchDecoder? You may need to tile your "
"initial state via the tf.contrib.seq2seq.tile_batch "
"initial state via the tfa.seq2seq.tile_batch "
"function with argument multiple=beam_width.")
with tf.control_dependencies(
self._batch_size_checks( # pylint: disable=bad-continuation
Expand Down Expand Up @@ -1827,7 +1827,7 @@ def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
"(encoder output) and the requested batch size. Are you using "
"the BeamSearchDecoder? If so, make sure your encoder output "
"has been tiled to beam_width via "
"tf.contrib.seq2seq.tile_batch, and the batch_size= argument "
"tfa.seq2seq.tile_batch, and the batch_size= argument "
"passed to get_initial_state is batch_size * beam_width.")
with tf.control_dependencies(
self._batch_size_checks(batch_size, error_message)): # pylint: disable=bad-continuation
Expand Down Expand Up @@ -1908,7 +1908,7 @@ def call(self, inputs, state, **kwargs):
"Non-matching batch sizes between the memory "
"(encoder output) and the query (decoder output). Are you using "
"the BeamSearchDecoder? You may need to tile your memory input "
"via the tf.contrib.seq2seq.tile_batch function with argument "
"via the tfa.seq2seq.tile_batch function with argument "
"multiple=beam_width.")
with tf.control_dependencies(
self._batch_size_checks(cell_batch_size, error_message)): # pylint: disable=bad-continuation
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_addons/seq2seq/attention_wrapper_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for contrib.seq2seq.python.ops.attention_wrapper."""
"""Tests for tfa.seq2seq.attention_wrapper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
Expand Down
5 changes: 3 additions & 2 deletions tensorflow_addons/seq2seq/basic_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,18 +111,19 @@ def output_dtype(self):
tf.nest.map_structure(lambda _: dtype, self._rnn_output_size()),
self.sampler.sample_ids_dtype)

def step(self, time, inputs, state):
def step(self, time, inputs, state, training=None):
"""Perform a decoding step.

Args:
time: scalar `int32` tensor.
inputs: A (structure of) input tensors.
state: A (structure of) state tensors and TensorArrays.
training: Python boolean.

Returns:
`(outputs, next_state, next_inputs, finished)`.
"""
cell_outputs, cell_state = self.cell(inputs, state)
cell_outputs, cell_state = self.cell(inputs, state, training=training)
if self.output_layer is not None:
cell_outputs = self.output_layer(cell_outputs)
sample_ids = self.sampler.sample(
Expand Down
24 changes: 17 additions & 7 deletions tensorflow_addons/seq2seq/beam_search_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,13 +518,16 @@ def _maybe_sort_array_beams(self, t, parent_ids, sequence_length):
[_check_batch_beam(t, self._batch_size, self._beam_width)]):
return gather_tree_from_array(t, parent_ids, sequence_length)

def step(self, time, inputs, state, name=None):
def step(self, time, inputs, state, training=None, name=None):
"""Perform a decoding step.

Args:
time: scalar `int32` tensor.
inputs: A (structure of) input tensors.
state: A (structure of) state tensors and TensorArrays.
training: Python boolean. Indicates whether the layer should
behave in training mode or in inference mode. Only relevant
when `dropout` or `recurrent_dropout` is used.
name: Name scope for any created operations.

Returns:
Expand All @@ -544,7 +547,8 @@ def step(self, time, inputs, state, name=None):
cell_state = tf.nest.map_structure(self._maybe_merge_batch_beams,
cell_state,
self._cell.state_size)
cell_outputs, next_cell_state = self._cell(inputs, cell_state)
cell_outputs, next_cell_state = self._cell(
inputs, cell_state, training=training)
cell_outputs = tf.nest.map_structure(
lambda out: self._split_batch_beams(out, out.shape[1:]),
cell_outputs)
Expand Down Expand Up @@ -586,7 +590,7 @@ class BeamSearchDecoder(BeamSearchDecoderMixin, decoder.BaseDecoder):
`AttentionWrapper`, then you must ensure that:

- The encoder output has been tiled to `beam_width` via
`tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
`tfa.seq2seq.tile_batch` (NOT `tf.tile`).
- The `batch_size` argument passed to the `get_initial_state` method of
this wrapper is equal to `true_batch_size * beam_width`.
- The initial state created with `get_initial_state` above contains a
Expand All @@ -596,11 +600,11 @@ class BeamSearchDecoder(BeamSearchDecoderMixin, decoder.BaseDecoder):
An example:

```
tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
tiled_encoder_outputs = tfa.seq2seq.tile_batch(
encoder_outputs, multiplier=beam_width)
tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
tiled_encoder_final_state = tfa.seq2seq.tile_batch(
encoder_final_state, multiplier=beam_width)
tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
tiled_sequence_length = tfa.seq2seq.tile_batch(
sequence_length, multiplier=beam_width)
attention_mechanism = MyFavoriteAttentionMechanism(
num_units=attention_depth,
Expand Down Expand Up @@ -752,7 +756,12 @@ def output_dtype(self):
predicted_ids=tf.int32,
parent_ids=tf.int32)

def call(self, embeddning, start_tokens, end_token, initial_state,
def call(self,
embeddning,
start_tokens,
end_token,
initial_state,
training=None,
**kwargs):
init_kwargs = kwargs
init_kwargs["start_tokens"] = start_tokens
Expand All @@ -765,6 +774,7 @@ def call(self, embeddning, start_tokens, end_token, initial_state,
maximum_iterations=self.maximum_iterations,
parallel_iterations=self.parallel_iterations,
swap_memory=self.swap_memory,
training=training,
decoder_init_input=embeddning,
decoder_init_kwargs=init_kwargs)

Expand Down
2 changes: 1 addition & 1 deletion tensorflow_addons/seq2seq/beam_search_decoder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for contrib.seq2seq.python.seq2seq.beam_search_decoder."""
"""Tests for tfa.seq2seq.seq2seq.beam_search_decoder."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
Expand Down
22 changes: 18 additions & 4 deletions tensorflow_addons/seq2seq/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class Decoder(object):
RNNCell instance as the state.
- `finished`: boolean tensor telling whether each sequence in the batch is
finished.
- `training`: boolean whether it should behave in training mode or in
inference mode.
- `outputs`: Instance of BasicDecoderOutput. Result of the decoding, at
each time step.
"""
Expand Down Expand Up @@ -79,7 +81,7 @@ def initialize(self, name=None):
raise NotImplementedError

@abc.abstractmethod
def step(self, time, inputs, state, name=None):
def step(self, time, inputs, state, training=None, name=None):
"""Called per step of decoding (but only once for dynamic decoding).

Args:
Expand All @@ -88,6 +90,9 @@ def step(self, time, inputs, state, name=None):
time step.
state: RNNCell state (possibly nested tuple of) tensor[s] from
previous time step.
training: Python boolean. Indicates whether the layer should behave
in training mode or in inference mode. Only relevant
when `dropout` or `recurrent_dropout` is used.
name: Name scope for any created operations.

Returns:
Expand Down Expand Up @@ -136,6 +141,8 @@ class BaseDecoder(tf.keras.layers.Layer):
encoder, which will be used for the attention wrapper for the RNNCell.
- `finished`: boolean tensor telling whether each sequence in the batch is
finished.
- `training`: boolean whether it should behave in training mode or in
inference mode.
- `outputs`: Instance of BasicDecoderOutput. Result of the decoding, at
each time step.
"""
Expand All @@ -154,7 +161,7 @@ def __init__(self,
self.swap_memory = swap_memory
super(BaseDecoder, self).__init__(**kwargs)

def call(self, inputs, initial_state=None, **kwargs):
def call(self, inputs, initial_state=None, training=None, **kwargs):
init_kwargs = kwargs
init_kwargs["initial_state"] = initial_state
return dynamic_decode(
Expand All @@ -164,6 +171,7 @@ def call(self, inputs, initial_state=None, **kwargs):
maximum_iterations=self.maximum_iterations,
parallel_iterations=self.parallel_iterations,
swap_memory=self.swap_memory,
training=training,
decoder_init_input=inputs,
decoder_init_kwargs=init_kwargs)

Expand Down Expand Up @@ -204,7 +212,7 @@ def initialize(self, inputs, initial_state=None, **kwargs):
"""
raise NotImplementedError

def step(self, time, inputs, state):
def step(self, time, inputs, state, training):
"""Called per step of decoding (but only once for dynamic decoding).

Args:
Expand All @@ -213,6 +221,8 @@ def step(self, time, inputs, state):
time step.
state: RNNCell state (possibly nested tuple of) tensor[s] from
previous time step.
training: Python boolean. Indicates whether the layer should
behave in training mode or in inference mode.

Returns:
`(outputs, next_state, next_inputs, finished)`: `outputs` is an
Expand Down Expand Up @@ -265,6 +275,7 @@ def dynamic_decode(decoder,
maximum_iterations=None,
parallel_iterations=32,
swap_memory=False,
training=None,
scope=None,
**kwargs):
"""Perform dynamic decoding with `decoder`.
Expand All @@ -287,6 +298,9 @@ def dynamic_decode(decoder,
steps. Default is `None` (decode until the decoder is fully done).
parallel_iterations: Argument passed to `tf.while_loop`.
swap_memory: Argument passed to `tf.while_loop`.
training: Python boolean. Indicates whether the layer should behave
in training mode or in inference mode. Only relevant
when `dropout` or `recurrent_dropout` is used.
scope: Optional variable scope to use.
**kwargs: dict, other keyword arguments for dynamic_decode. It might
contain arguments for `BaseDecoder` to initialize, which takes all
Expand Down Expand Up @@ -389,7 +403,7 @@ def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
```
"""
(next_outputs, decoder_state, next_inputs,
decoder_finished) = decoder.step(time, inputs, state)
decoder_finished) = decoder.step(time, inputs, state, training)
if decoder.tracks_own_finished:
next_finished = decoder_finished
else:
Expand Down