diff --git a/tensorflow_addons/rnn/esn_cell.py b/tensorflow_addons/rnn/esn_cell.py index d36be9f369..e85f638f43 100644 --- a/tensorflow_addons/rnn/esn_cell.py +++ b/tensorflow_addons/rnn/esn_cell.py @@ -32,6 +32,18 @@ class ESNCell(keras.layers.AbstractRNNCell): "The "echo state" approach to analysing and training recurrent neural networks". GMD Report148, German National Research Center for Information Technology, 2001. https://www.researchgate.net/publication/215385037 + + Example: + + >>> inputs = np.random.random([30,23,9]).astype(np.float32) + >>> ESNCell = tfa.rnn.ESNCell(4) + >>> rnn = tf.keras.layers.RNN(ESNCell, return_sequences=True, return_state=True) + >>> outputs, memory_state = rnn(inputs) + >>> outputs.shape + TensorShape([30, 23, 4]) + >>> memory_state.shape + TensorShape([30, 4]) + Arguments: units: Positive integer, dimensionality in the reservoir. connectivity: Float between 0 and 1. diff --git a/tensorflow_addons/rnn/layer_norm_lstm_cell.py b/tensorflow_addons/rnn/layer_norm_lstm_cell.py index 3a3a359663..adff626673 100644 --- a/tensorflow_addons/rnn/layer_norm_lstm_cell.py +++ b/tensorflow_addons/rnn/layer_norm_lstm_cell.py @@ -46,6 +46,19 @@ class LayerNormLSTMCell(keras.layers.LSTMCell): "Recurrent Dropout without Memory Loss" Stanislau Semeniuta, Aliaksei Severyn, Erhardt Barth. 
+ + Example: + + >>> inputs = np.random.random([30,23,9]).astype(np.float32) + >>> lnLSTMCell = tfa.rnn.LayerNormLSTMCell(4) + >>> rnn = tf.keras.layers.RNN(lnLSTMCell, return_sequences=True, return_state=True) + >>> outputs, memory_state, carry_state = rnn(inputs) + >>> outputs.shape + TensorShape([30, 23, 4]) + >>> memory_state.shape + TensorShape([30, 4]) + >>> carry_state.shape + TensorShape([30, 4]) """ @typechecked diff --git a/tensorflow_addons/rnn/layer_norm_simple_rnn_cell.py b/tensorflow_addons/rnn/layer_norm_simple_rnn_cell.py index f9562c6329..cc5c2b6d70 100644 --- a/tensorflow_addons/rnn/layer_norm_simple_rnn_cell.py +++ b/tensorflow_addons/rnn/layer_norm_simple_rnn_cell.py @@ -37,6 +37,17 @@ class LayerNormSimpleRNNCell(keras.layers.SimpleRNNCell): "Layer Normalization." ArXiv:1607.06450 [Cs, Stat], July 21, 2016. http://arxiv.org/abs/1607.06450 + Example: + + >>> inputs = np.random.random([30,23,9]).astype(np.float32) + >>> lnsRNNCell = tfa.rnn.LayerNormSimpleRNNCell(4) + >>> rnn = tf.keras.layers.RNN(lnsRNNCell, return_sequences=True, return_state=True) + >>> outputs, memory_state = rnn(inputs) + >>> outputs.shape + TensorShape([30, 23, 4]) + >>> memory_state.shape + TensorShape([30, 4]) + Arguments: units: Positive integer, dimensionality of the output space. activation: Activation function to use. @@ -89,25 +100,19 @@ class LayerNormSimpleRNNCell(keras.layers.SimpleRNNCell): Examples: - ```python - import numpy as np - import tensorflow.keras as keras - import tensorflow_addons as tfa - - inputs = np.random.random([32, 10, 8]).astype(np.float32) - rnn = keras.layers.RNN(tfa.rnn.LayerNormSimpleRNNCell(4)) - - output = rnn(inputs) # The output has shape `[32, 4]`. 
- - rnn = keras.layers.RNN( - tfa.rnn.LayerNormSimpleRNNCell(4), - return_sequences=True, - return_state=True) + >>> inputs = np.random.random([32, 10, 8]).astype(np.float32) + >>> rnn = tf.keras.layers.RNN(tfa.rnn.LayerNormSimpleRNNCell(4)) + >>> output = rnn(inputs) # The output has shape `[32, 4]`. + >>> rnn = tf.keras.layers.RNN( + ... tfa.rnn.LayerNormSimpleRNNCell(4), + ... return_sequences=True, + ... return_state=True) + >>> whole_sequence_output, final_state = rnn(inputs) + >>> whole_sequence_output.shape + TensorShape([32, 10, 4]) + >>> final_state.shape + TensorShape([32, 4]) - # whole_sequence_output has shape `[32, 10, 4]`. - # final_state has shape `[32, 4]`. - whole_sequence_output, final_state = rnn(inputs) - ``` """ @typechecked diff --git a/tensorflow_addons/rnn/nas_cell.py b/tensorflow_addons/rnn/nas_cell.py index 05054a0233..6b6686ba4b 100644 --- a/tensorflow_addons/rnn/nas_cell.py +++ b/tensorflow_addons/rnn/nas_cell.py @@ -38,6 +38,19 @@ class NASCell(keras.layers.AbstractRNNCell): "Neural Architecture Search with Reinforcement Learning" Proc. ICLR 2017. The class uses an optional projection layer. + + Example: + + >>> inputs = np.random.random([30,23,9]).astype(np.float32) + >>> NASCell = tfa.rnn.NASCell(4) + >>> rnn = tf.keras.layers.RNN(NASCell, return_sequences=True, return_state=True) + >>> outputs, memory_state, carry_state = rnn(inputs) + >>> outputs.shape + TensorShape([30, 23, 4]) + >>> memory_state.shape + TensorShape([30, 4]) + >>> carry_state.shape + TensorShape([30, 4]) """ # NAS cell's architecture base. 
diff --git a/tensorflow_addons/rnn/peephole_lstm_cell.py b/tensorflow_addons/rnn/peephole_lstm_cell.py index 1f791d53d8..658db084f4 100644 --- a/tensorflow_addons/rnn/peephole_lstm_cell.py +++ b/tensorflow_addons/rnn/peephole_lstm_cell.py @@ -39,14 +39,16 @@ class PeepholeLSTMCell(tf.keras.layers.LSTMCell): Example: - ```python - # Create 2 PeepholeLSTMCells - peephole_lstm_cells = [PeepholeLSTMCell(size) for size in [128, 256]] - # Create a layer composed sequentially of the peephole LSTM cells. - layer = RNN(peephole_lstm_cells) - input = keras.Input((timesteps, input_dim)) - output = layer(input) - ``` + >>> inputs = np.random.random([30,23,9]).astype(np.float32) + >>> LSTMCell = tfa.rnn.PeepholeLSTMCell(4) + >>> rnn = tf.keras.layers.RNN(LSTMCell, return_sequences=True, return_state=True) + >>> outputs, memory_state, carry_state = rnn(inputs) + >>> outputs.shape + TensorShape([30, 23, 4]) + >>> memory_state.shape + TensorShape([30, 4]) + >>> carry_state.shape + TensorShape([30, 4]) """ def build(self, input_shape):