diff --git a/tensorflow_addons/seq2seq/attention_wrapper.py b/tensorflow_addons/seq2seq/attention_wrapper.py index 6345d4d43c..4c73f86554 100644 --- a/tensorflow_addons/seq2seq/attention_wrapper.py +++ b/tensorflow_addons/seq2seq/attention_wrapper.py @@ -946,7 +946,7 @@ def _monotonic_probability_fn(score, test-time, and when hard attention is not desired. mode: How to compute the attention distribution. Must be one of 'recursive', 'parallel', or 'hard'. See the docstring for - `tf.contrib.seq2seq.monotonic_attention` for more information. + `tfa.seq2seq.monotonic_attention` for more information. seed: (optional) Random seed for pre-sigmoid noise. Returns: @@ -1042,7 +1042,7 @@ def __init__(self, of the memory is large. mode: How to compute the attention distribution. Must be one of 'recursive', 'parallel', or 'hard'. See the docstring for - `tf.contrib.seq2seq.monotonic_attention` for more information. + `tfa.seq2seq.monotonic_attention` for more information. kernel_initializer: (optional), the name of the initializer for the attention kernel. dtype: The data type for the query and memory layers of the attention @@ -1214,7 +1214,7 @@ def __init__(self, of the memory is large. mode: How to compute the attention distribution. Must be one of 'recursive', 'parallel', or 'hard'. See the docstring for - `tf.contrib.seq2seq.monotonic_attention` for more information. + `tfa.seq2seq.monotonic_attention` for more information. dtype: The data type for the query and memory layers of the attention mechanism. name: Name to use when creating ops. @@ -1547,7 +1547,7 @@ def __init__(self, in `AttentionWrapper`, then you must ensure that: - The encoder output has been tiled to `beam_width` via - `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`). + `tfa.seq2seq.tile_batch` (NOT `tf.tile`). - The `batch_size` argument passed to the `get_initial_state` method of this wrapper is equal to `true_batch_size * beam_width`. 
- The initial state created with `get_initial_state` above contains a @@ -1557,11 +1557,11 @@ def __init__(self, An example: ``` - tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch( + tiled_encoder_outputs = tfa.seq2seq.tile_batch( encoder_outputs, multiplier=beam_width) - tiled_encoder_final_state = tf.conrib.seq2seq.tile_batch( + tiled_encoder_final_state = tfa.seq2seq.tile_batch( encoder_final_state, multiplier=beam_width) - tiled_sequence_length = tf.contrib.seq2seq.tile_batch( + tiled_sequence_length = tfa.seq2seq.tile_batch( sequence_length, multiplier=beam_width) attention_mechanism = MyFavoriteAttentionMechanism( num_units=attention_depth, @@ -1725,7 +1725,7 @@ def __init__(self, "Non-matching batch sizes between the memory " "(encoder output) and initial_cell_state. Are you using " "the BeamSearchDecoder? You may need to tile your " - "initial state via the tf.contrib.seq2seq.tile_batch " + "initial state via the tfa.seq2seq.tile_batch " "function with argument multiple=beam_width.") with tf.control_dependencies( self._batch_size_checks( # pylint: disable=bad-continuation @@ -1827,7 +1827,7 @@ def get_initial_state(self, inputs=None, batch_size=None, dtype=None): "(encoder output) and the requested batch size. Are you using " "the BeamSearchDecoder? If so, make sure your encoder output " "has been tiled to beam_width via " - "tf.contrib.seq2seq.tile_batch, and the batch_size= argument " + "tfa.seq2seq.tile_batch, and the batch_size= argument " "passed to get_initial_state is batch_size * beam_width.") with tf.control_dependencies( self._batch_size_checks(batch_size, error_message)): # pylint: disable=bad-continuation @@ -1908,7 +1908,7 @@ def call(self, inputs, state, **kwargs): "Non-matching batch sizes between the memory " "(encoder output) and the query (decoder output). Are you using " "the BeamSearchDecoder? 
You may need to tile your memory input " - "via the tf.contrib.seq2seq.tile_batch function with argument " + "via the tfa.seq2seq.tile_batch function with argument " "multiple=beam_width.") with tf.control_dependencies( self._batch_size_checks(cell_batch_size, error_message)): # pylint: disable=bad-continuation diff --git a/tensorflow_addons/seq2seq/attention_wrapper_test.py b/tensorflow_addons/seq2seq/attention_wrapper_test.py index d9538436aa..95d95d4ce3 100644 --- a/tensorflow_addons/seq2seq/attention_wrapper_test.py +++ b/tensorflow_addons/seq2seq/attention_wrapper_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for contrib.seq2seq.python.ops.attention_wrapper.""" +"""Tests for tfa.seq2seq.attention_wrapper.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function diff --git a/tensorflow_addons/seq2seq/basic_decoder.py b/tensorflow_addons/seq2seq/basic_decoder.py index 930393a27c..779bfaae39 100644 --- a/tensorflow_addons/seq2seq/basic_decoder.py +++ b/tensorflow_addons/seq2seq/basic_decoder.py @@ -111,18 +111,19 @@ def output_dtype(self): tf.nest.map_structure(lambda _: dtype, self._rnn_output_size()), self.sampler.sample_ids_dtype) - def step(self, time, inputs, state): + def step(self, time, inputs, state, training=None): """Perform a decoding step. Args: time: scalar `int32` tensor. inputs: A (structure of) input tensors. state: A (structure of) state tensors and TensorArrays. + training: Python boolean passed to the cell as `training`. Returns: `(outputs, next_state, next_inputs, finished)`. 
""" - cell_outputs, cell_state = self.cell(inputs, state) + cell_outputs, cell_state = self.cell(inputs, state, training=training) if self.output_layer is not None: cell_outputs = self.output_layer(cell_outputs) sample_ids = self.sampler.sample( diff --git a/tensorflow_addons/seq2seq/beam_search_decoder.py b/tensorflow_addons/seq2seq/beam_search_decoder.py index 09607af5ae..0bc9074c72 100644 --- a/tensorflow_addons/seq2seq/beam_search_decoder.py +++ b/tensorflow_addons/seq2seq/beam_search_decoder.py @@ -518,13 +518,16 @@ def _maybe_sort_array_beams(self, t, parent_ids, sequence_length): [_check_batch_beam(t, self._batch_size, self._beam_width)]): return gather_tree_from_array(t, parent_ids, sequence_length) - def step(self, time, inputs, state, name=None): + def step(self, time, inputs, state, training=None, name=None): """Perform a decoding step. Args: time: scalar `int32` tensor. inputs: A (structure of) input tensors. state: A (structure of) state tensors and TensorArrays. + training: Python boolean. Indicates whether the layer should + behave in training mode or in inference mode. Only relevant + when `dropout` or `recurrent_dropout` is used. name: Name scope for any created operations. Returns: @@ -544,7 +547,8 @@ def step(self, time, inputs, state, name=None): cell_state = tf.nest.map_structure(self._maybe_merge_batch_beams, cell_state, self._cell.state_size) - cell_outputs, next_cell_state = self._cell(inputs, cell_state) + cell_outputs, next_cell_state = self._cell( + inputs, cell_state, training=training) cell_outputs = tf.nest.map_structure( lambda out: self._split_batch_beams(out, out.shape[1:]), cell_outputs) @@ -586,7 +590,7 @@ class BeamSearchDecoder(BeamSearchDecoderMixin, decoder.BaseDecoder): `AttentionWrapper`, then you must ensure that: - The encoder output has been tiled to `beam_width` via - `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`). + `tfa.seq2seq.tile_batch` (NOT `tf.tile`). 
- The `batch_size` argument passed to the `get_initial_state` method of this wrapper is equal to `true_batch_size * beam_width`. - The initial state created with `get_initial_state` above contains a @@ -596,11 +600,11 @@ class BeamSearchDecoder(BeamSearchDecoderMixin, decoder.BaseDecoder): An example: ``` - tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch( + tiled_encoder_outputs = tfa.seq2seq.tile_batch( encoder_outputs, multiplier=beam_width) - tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch( + tiled_encoder_final_state = tfa.seq2seq.tile_batch( encoder_final_state, multiplier=beam_width) - tiled_sequence_length = tf.contrib.seq2seq.tile_batch( + tiled_sequence_length = tfa.seq2seq.tile_batch( sequence_length, multiplier=beam_width) attention_mechanism = MyFavoriteAttentionMechanism( num_units=attention_depth, @@ -752,7 +756,12 @@ def output_dtype(self): predicted_ids=tf.int32, parent_ids=tf.int32) - def call(self, embeddning, start_tokens, end_token, initial_state, + def call(self, + embeddning, + start_tokens, + end_token, + initial_state, + training=None, **kwargs): init_kwargs = kwargs init_kwargs["start_tokens"] = start_tokens @@ -765,6 +774,7 @@ def call(self, embeddning, start_tokens, end_token, initial_state, maximum_iterations=self.maximum_iterations, parallel_iterations=self.parallel_iterations, swap_memory=self.swap_memory, + training=training, decoder_init_input=embeddning, decoder_init_kwargs=init_kwargs) diff --git a/tensorflow_addons/seq2seq/beam_search_decoder_test.py b/tensorflow_addons/seq2seq/beam_search_decoder_test.py index 986b57e351..2e92c75844 100644 --- a/tensorflow_addons/seq2seq/beam_search_decoder_test.py +++ b/tensorflow_addons/seq2seq/beam_search_decoder_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Tests for contrib.seq2seq.python.seq2seq.beam_search_decoder.""" +"""Tests for tfa.seq2seq.beam_search_decoder.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function diff --git a/tensorflow_addons/seq2seq/decoder.py b/tensorflow_addons/seq2seq/decoder.py index 214e5ed2ce..cd5a25c2bc 100644 --- a/tensorflow_addons/seq2seq/decoder.py +++ b/tensorflow_addons/seq2seq/decoder.py @@ -43,6 +43,8 @@ class Decoder(object): RNNCell instance as the state. - `finished`: boolean tensor telling whether each sequence in the batch is finished. + - `training`: boolean indicating whether it should behave in training + mode or in inference mode. - `outputs`: Instance of BasicDecoderOutput. Result of the decoding, at each time step. """ @@ -79,7 +81,7 @@ def initialize(self, name=None): raise NotImplementedError @abc.abstractmethod - def step(self, time, inputs, state, name=None): + def step(self, time, inputs, state, training=None, name=None): """Called per step of decoding (but only once for dynamic decoding). Args: @@ -88,6 +90,9 @@ def step(self, time, inputs, state, name=None): time step. state: RNNCell state (possibly nested tuple of) tensor[s] from previous time step. + training: Python boolean. Indicates whether the layer should behave + in training mode or in inference mode. Only relevant + when `dropout` or `recurrent_dropout` is used. name: Name scope for any created operations. Returns: @@ -136,6 +141,8 @@ class BaseDecoder(tf.keras.layers.Layer): encoder, which will be used for the attention wrapper for the RNNCell. - `finished`: boolean tensor telling whether each sequence in the batch is finished. + - `training`: boolean indicating whether it should behave in training + mode or in inference mode. - `outputs`: Instance of BasicDecoderOutput. Result of the decoding, at each time step. 
""" @@ -154,7 +161,7 @@ def __init__(self, self.swap_memory = swap_memory super(BaseDecoder, self).__init__(**kwargs) - def call(self, inputs, initial_state=None, **kwargs): + def call(self, inputs, initial_state=None, training=None, **kwargs): init_kwargs = kwargs init_kwargs["initial_state"] = initial_state return dynamic_decode( @@ -164,6 +171,7 @@ def call(self, inputs, initial_state=None, **kwargs): maximum_iterations=self.maximum_iterations, parallel_iterations=self.parallel_iterations, swap_memory=self.swap_memory, + training=training, decoder_init_input=inputs, decoder_init_kwargs=init_kwargs) @@ -204,7 +212,7 @@ def initialize(self, inputs, initial_state=None, **kwargs): """ raise NotImplementedError - def step(self, time, inputs, state): + def step(self, time, inputs, state, training): """Called per step of decoding (but only once for dynamic decoding). Args: @@ -213,6 +221,8 @@ def step(self, time, inputs, state): time step. state: RNNCell state (possibly nested tuple of) tensor[s] from previous time step. + training: Python boolean. Indicates whether the layer should + behave in training mode or in inference mode. Returns: `(outputs, next_state, next_inputs, finished)`: `outputs` is an @@ -265,6 +275,7 @@ def dynamic_decode(decoder, maximum_iterations=None, parallel_iterations=32, swap_memory=False, + training=None, scope=None, **kwargs): """Perform dynamic decoding with `decoder`. @@ -287,6 +298,9 @@ def dynamic_decode(decoder, steps. Default is `None` (decode until the decoder is fully done). parallel_iterations: Argument passed to `tf.while_loop`. swap_memory: Argument passed to `tf.while_loop`. + training: Python boolean. Indicates whether the layer should behave + in training mode or in inference mode. Only relevant + when `dropout` or `recurrent_dropout` is used. scope: Optional variable scope to use. **kwargs: dict, other keyword arguments for dynamic_decode. 
It might contain arguments for `BaseDecoder` to initialize, which takes all @@ -389,7 +403,7 @@ def body(time, outputs_ta, state, inputs, finished, sequence_lengths): ``` """ (next_outputs, decoder_state, next_inputs, - decoder_finished) = decoder.step(time, inputs, state) + decoder_finished) = decoder.step(time, inputs, state, training) if decoder.tracks_own_finished: next_finished = decoder_finished else: