Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions tensorflow_addons/rnn/esn_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@ class ESNCell(keras.layers.AbstractRNNCell):
"The "echo state" approach to analysing and training recurrent neural networks".
GMD Report148, German National Research Center for Information Technology, 2001.
https://www.researchgate.net/publication/215385037

Example:

>>> inputs = np.random.random([30,23,9]).astype(np.float32)
>>> ESNCell = tfa.rnn.ESNCell(4)
>>> rnn = tf.keras.layers.RNN(ESNCell, return_sequences=True, return_state=True)
>>> outputs, memory_state = rnn(inputs)
>>> outputs.shape
TensorShape([30, 23, 4])
>>> memory_state.shape
TensorShape([30, 4])

Arguments:
units: Positive integer, dimensionality in the reservoir.
connectivity: Float between 0 and 1.
Expand Down
13 changes: 13 additions & 0 deletions tensorflow_addons/rnn/layer_norm_lstm_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,19 @@ class LayerNormLSTMCell(keras.layers.LSTMCell):

"Recurrent Dropout without Memory Loss"
Stanislau Semeniuta, Aliaksei Severyn, Erhardt Barth.

Example:

>>> inputs = np.random.random([30,23,9]).astype(np.float32)
>>> lnLSTMCell = tfa.rnn.LayerNormLSTMCell(4)
>>> rnn = tf.keras.layers.RNN(lnLSTMCell, return_sequences=True, return_state=True)
>>> outputs, memory_state, carry_state = rnn(inputs)
>>> outputs.shape
TensorShape([30, 23, 4])
>>> memory_state.shape
TensorShape([30, 4])
>>> carry_state.shape
TensorShape([30, 4])
"""

@typechecked
Expand Down
41 changes: 23 additions & 18 deletions tensorflow_addons/rnn/layer_norm_simple_rnn_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,17 @@ class LayerNormSimpleRNNCell(keras.layers.SimpleRNNCell):
"Layer Normalization." ArXiv:1607.06450 [Cs, Stat],
July 21, 2016. http://arxiv.org/abs/1607.06450

Example:

>>> inputs = np.random.random([30,23,9]).astype(np.float32)
>>> lnsRNNCell = tfa.rnn.LayerNormSimpleRNNCell(4)
>>> rnn = tf.keras.layers.RNN(lnsRNNCell, return_sequences=True, return_state=True)
>>> outputs, memory_state = rnn(inputs)
>>> outputs.shape
TensorShape([30, 23, 4])
>>> memory_state.shape
TensorShape([30, 4])

Arguments:
units: Positive integer, dimensionality of the output space.
activation: Activation function to use.
Expand Down Expand Up @@ -89,25 +100,19 @@ class LayerNormSimpleRNNCell(keras.layers.SimpleRNNCell):

Examples:

```python
import numpy as np
import tensorflow.keras as keras
import tensorflow_addons as tfa

inputs = np.random.random([32, 10, 8]).astype(np.float32)
rnn = keras.layers.RNN(tfa.rnn.LayerNormSimpleRNNCell(4))

output = rnn(inputs) # The output has shape `[32, 4]`.

rnn = keras.layers.RNN(
tfa.rnn.LayerNormSimpleRNNCell(4),
return_sequences=True,
return_state=True)
>>> inputs = np.random.random([32, 10, 8]).astype(np.float32)
>>> rnn = tf.keras.layers.RNN(tfa.rnn.LayerNormSimpleRNNCell(4))
>>> output = rnn(inputs) # The output has shape `[32, 4]`.
>>> rnn = tf.keras.layers.RNN(
... tfa.rnn.LayerNormSimpleRNNCell(4),
... return_sequences=True,
... return_state=True)
>>> whole_sequence_output, final_state = rnn(inputs)
>>> whole_sequence_output
<tf.Tensor: shape=(32, 10, 4), dtype=float32, numpy=...>
>>> final_state
<tf.Tensor: shape=(32, 4), dtype=float32, numpy=...>

# whole_sequence_output has shape `[32, 10, 4]`.
# final_state has shape `[32, 4]`.
whole_sequence_output, final_state = rnn(inputs)
```
"""

@typechecked
Expand Down
13 changes: 13 additions & 0 deletions tensorflow_addons/rnn/nas_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,19 @@ class NASCell(keras.layers.AbstractRNNCell):
"Neural Architecture Search with Reinforcement Learning" Proc. ICLR 2017.

The class uses an optional projection layer.

Example:

>>> inputs = np.random.random([30,23,9]).astype(np.float32)
>>> NASCell = tfa.rnn.NASCell(4)
>>> rnn = tf.keras.layers.RNN(NASCell, return_sequences=True, return_state=True)
>>> outputs, memory_state, carry_state = rnn(inputs)
>>> outputs.shape
TensorShape([30, 23, 4])
>>> memory_state.shape
TensorShape([30, 4])
>>> carry_state.shape
TensorShape([30, 4])
"""

# NAS cell's architecture base.
Expand Down
18 changes: 10 additions & 8 deletions tensorflow_addons/rnn/peephole_lstm_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,16 @@ class PeepholeLSTMCell(tf.keras.layers.LSTMCell):

Example:

```python
# Create 2 PeepholeLSTMCells
peephole_lstm_cells = [PeepholeLSTMCell(size) for size in [128, 256]]
# Create a layer composed sequentially of the peephole LSTM cells.
layer = RNN(peephole_lstm_cells)
input = keras.Input((timesteps, input_dim))
output = layer(input)
```
>>> inputs = np.random.random([30,23,9]).astype(np.float32)
>>> LSTMCell = tfa.rnn.PeepholeLSTMCell(4)
>>> rnn = tf.keras.layers.RNN(LSTMCell, return_sequences=True, return_state=True)
>>> outputs, memory_state, carry_state = rnn(inputs)
>>> outputs.shape
TensorShape([30, 23, 4])
>>> memory_state.shape
TensorShape([30, 4])
>>> carry_state.shape
TensorShape([30, 4])
"""

def build(self, input_shape):
Expand Down