Commit 4e0aefe

qlzh727 authored and seanpmorgan committed

Add subpackage for RNN related code. (#189)
* Update .gitignore to include IntelliJ project files (*.iml).
* Add a package for the customized RNN cell.
* Update the unit test for the RNN cell.
* Fix more code and tests:
  1. The test case was missing test.main, which means the test wasn't executed at all.
  2. Add initializer params so the user can control how weights are initialized.
  3. Update the tests with the fix.
* Fix code format.
* Fix more lint errors.
* Address the PR review comments:
  1. Update the addon init/BUILD/README.
  2. Add a README for rnn.
* Fix annotation in doc.
* Update git owner for the rnn package.

1 parent 4a15afd commit 4e0aefe

File tree

9 files changed: +476 -1 lines changed


.github/CODEOWNERS

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
 # Subpackage Owners
+/tensorflow_addons/rnn/ @qlzh727
 /tensorflow_addons/seq2seq/ @qlzh727
 /tensorflow_addons/custom_ops/seq2seq/ @qlzh727

BUILD

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ sh_binary(
         "//tensorflow_addons/layers",
         "//tensorflow_addons/losses",
         "//tensorflow_addons/optimizers",
+        "//tensorflow_addons/rnn",
         "//tensorflow_addons/seq2seq",
         "//tensorflow_addons/text",
     ],

README.md

Lines changed: 2 additions & 1 deletion
@@ -34,7 +34,8 @@ developments that cannot be integrated into core TensorFlow
 | [tfa.layers](tensorflow_addons/layers/README.md) | SIG-Addons | [email protected] |
 | [tfa.losses](tensorflow_addons/losses/README.md) | SIG-Addons | [email protected] |
 | [tfa.optimizers](tensorflow_addons/optimizers/README.md) | SIG-Addons | [email protected] |
-| [tfa.seq2seq](tensorflow_addons/seq2seq/README.md) | Google | @qlzh727 |
+| [tfa.rnn](tensorflow_addons/rnn/README.md) | Google | @qlzh727 |
+| [tfa.seq2seq](tensorflow_addons/seq2seq/README.md) | Google | @qlzh727 |
 | [tfa.text](tensorflow_addons/text/README.md) | | |

 ## Core Concepts

tensorflow_addons/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -76,6 +76,7 @@ def _ensure_tf_install():
 from tensorflow_addons import layers
 from tensorflow_addons import losses
 from tensorflow_addons import optimizers
+from tensorflow_addons import rnn
 from tensorflow_addons import seq2seq
 from tensorflow_addons import text

tensorflow_addons/rnn/BUILD

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
licenses(["notice"])  # Apache 2.0

package(default_visibility = ["//visibility:public"])

py_library(
    name = "rnn",
    srcs = [
        "__init__.py",
        "cell.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow_addons/utils",
    ],
)

py_test(
    name = "cell_test",
    size = "small",
    srcs = ["cell_test.py"],
    deps = [
        ":rnn",
    ],
)
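Usage note (not part of the commit): with these targets defined, the new library and its test can presumably be exercised with standard Bazel commands from the repository root, e.g. `bazel test //tensorflow_addons/rnn:cell_test` to run the cell tests.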

tensorflow_addons/rnn/README.md

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
# Addons - RNN

## Maintainers
| Submodule | Maintainers | Contact Info |
|:----------|:------------|:-------------|
| cell      | Google      | @qlzh727     |

## Components
| Submodule | Class   | Reference |
|:----------|:--------|:----------|
| cell      | NASCell | https://arxiv.org/abs/1611.01578 |


## Contribution Guidelines
#### Prerequisites
* For any cell based on a research paper, the original paper has to be well recognized.
  The criterion here is >= 100 citations on Google Scholar. If the contributor feels
  this requirement needs to be overruled, please specify the detailed justification in
  the PR.

#### Standard API
In order to conform to the current API standard, all cells must:
* Inherit from either `keras.layers.AbstractRNNCell` or `keras.layers.Layer` with the
  required properties.
* [Register as a Keras global object](https://github.com/tensorflow/addons/blob/master/tensorflow_addons/utils/python/keras_utils.py)
  so they can be serialized properly.
* Add the addon to the `py_library` in this sub-package's BUILD file.

#### Testing Requirements
* When applicable, run all tests with TensorFlow's
  `@run_in_graph_and_eager_modes` (for test methods)
  or `@run_all_in_graph_and_eager_modes` (for TestCase subclasses)
  decorator.
* Add a `py_test` to this sub-package's BUILD file.

#### Documentation Requirements
* Update the table of contents in this sub-package's README.
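
For illustration only (not part of this change), a minimal sketch of a cell following the Standard API guidelines above. The class and its names are hypothetical; it assumes the `keras_utils.register_keras_custom_object` helper linked in the guidelines and a TF 2.x-era Keras API:

# Hypothetical example, not part of the commit: a minimal cell conforming to
# the Standard API guidelines above.
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow_addons.utils import keras_utils


@keras_utils.register_keras_custom_object  # register so Keras can deserialize it
class MinimalRNNCell(keras.layers.AbstractRNNCell):
    """Toy cell: new state = old state + a learned projection of the input."""

    def __init__(self, units, **kwargs):
        super(MinimalRNNCell, self).__init__(**kwargs)
        self.units = units

    @property
    def state_size(self):  # required property
        return self.units

    @property
    def output_size(self):  # required property
        return self.units

    def build(self, input_shape):
        # One kernel mapping the per-timestep input to the state size.
        self.kernel = self.add_weight(
            name="kernel", shape=[int(input_shape[-1]), self.units])
        self.built = True

    def call(self, inputs, states):
        new_state = tf.matmul(inputs, self.kernel) + states[0]
        return new_state, [new_state]

    def get_config(self):  # needed so the registered object round-trips
        config = {"units": self.units}
        base_config = super(MinimalRNNCell, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))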

tensorflow_addons/rnn/__init__.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Customized RNN cells."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow_addons.rnn.cell import NASCell

tensorflow_addons/rnn/cell.py

Lines changed: 210 additions & 0 deletions
@@ -0,0 +1,210 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module for RNN Cells."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tensorflow.keras as keras
from tensorflow_addons.utils import keras_utils


@keras_utils.register_keras_custom_object
class NASCell(keras.layers.AbstractRNNCell):
    """Neural Architecture Search (NAS) recurrent network cell.

    This implements the recurrent cell from the paper:

        https://arxiv.org/abs/1611.01578

    Barret Zoph and Quoc V. Le.
    "Neural Architecture Search with Reinforcement Learning" Proc. ICLR 2017.

    The class uses an optional projection layer.
    """

    # NAS cell's architecture base.
    _NAS_BASE = 8

    def __init__(self,
                 units,
                 projection=None,
                 use_bias=False,
                 kernel_initializer="glorot_uniform",
                 recurrent_initializer="glorot_uniform",
                 projection_initializer="glorot_uniform",
                 bias_initializer="zeros",
                 **kwargs):
        """Initialize the parameters for a NAS cell.

        Args:
          units: int, The number of units in the NAS cell.
          projection: (optional) int, The output dimensionality for the
            projection matrices. If None, no projection is performed.
          use_bias: (optional) bool, If True then use biases within the cell.
            This is False by default.
          kernel_initializer: Initializer for kernel weight.
          recurrent_initializer: Initializer for recurrent kernel weight.
          projection_initializer: Initializer for projection weight, used when
            projection is not None.
          bias_initializer: Initializer for bias, used when use_bias is True.
          **kwargs: Additional keyword arguments.
        """
        super(NASCell, self).__init__(**kwargs)
        self.units = units
        self.projection = projection
        self.use_bias = use_bias
        self.kernel_initializer = kernel_initializer
        self.recurrent_initializer = recurrent_initializer
        self.projection_initializer = projection_initializer
        self.bias_initializer = bias_initializer

        if projection is not None:
            self._state_size = [units, projection]
            self._output_size = projection
        else:
            self._state_size = [units, units]
            self._output_size = units

    @property
    def state_size(self):
        return self._state_size

    @property
    def output_size(self):
        return self._output_size

    def build(self, inputs_shape):
        input_size = tf.compat.dimension_value(
            tf.TensorShape(inputs_shape).with_rank(2)[1])
        if input_size is None:
            raise ValueError(
                "Could not infer input size from inputs.get_shape()[-1]")

        # Variables for the NAS cell. `recurrent_kernel` is all matrices
        # multiplying the hidden state and `kernel` is all matrices
        # multiplying the inputs.
        self.recurrent_kernel = self.add_variable(
            name="recurrent_kernel",
            shape=[self.output_size, self._NAS_BASE * self.units],
            initializer=self.recurrent_initializer)
        self.kernel = self.add_variable(
            name="kernel",
            shape=[input_size, self._NAS_BASE * self.units],
            initializer=self.kernel_initializer)

        if self.use_bias:
            self.bias = self.add_variable(
                name="bias",
                shape=[self._NAS_BASE * self.units],
                initializer=self.bias_initializer)
        # Projection layer if specified
        if self.projection is not None:
            self.projection_weights = self.add_variable(
                name="projection_weights",
                shape=[self.units, self.projection],
                initializer=self.projection_initializer)

        self.built = True

    def call(self, inputs, state):
        """Run one step of NAS Cell.

        Args:
          inputs: input Tensor, 2D, batch x num_units.
          state: This must be a list of state Tensors, both `2-D`, with column
            sizes `c_state` and `m_state`.

        Returns:
          A tuple containing:
          - A `2-D, [batch x output_dim]`, Tensor representing the output of
            the NAS Cell after reading `inputs` when previous state was
            `state`. Here output_dim is: projection if projection was set,
            units otherwise.
          - Tensor(s) representing the new state of NAS Cell after reading
            `inputs` when the previous state was `state`. Same type and
            shape(s) as `state`.

        Raises:
          ValueError: If input size cannot be inferred from inputs via
            static shape inference.
        """
        sigmoid = tf.math.sigmoid
        tanh = tf.math.tanh
        relu = tf.nn.relu

        c_prev, m_prev = state

        m_matrix = tf.matmul(m_prev, self.recurrent_kernel)
        inputs_matrix = tf.matmul(inputs, self.kernel)

        if self.use_bias:
            m_matrix = tf.nn.bias_add(m_matrix, self.bias)

        # The NAS cell branches into 8 different splits for both the hidden
        # state and the input
        m_matrix_splits = tf.split(
            axis=1, num_or_size_splits=self._NAS_BASE, value=m_matrix)
        inputs_matrix_splits = tf.split(
            axis=1, num_or_size_splits=self._NAS_BASE, value=inputs_matrix)

        # First layer
        layer1_0 = sigmoid(inputs_matrix_splits[0] + m_matrix_splits[0])
        layer1_1 = relu(inputs_matrix_splits[1] + m_matrix_splits[1])
        layer1_2 = sigmoid(inputs_matrix_splits[2] + m_matrix_splits[2])
        layer1_3 = relu(inputs_matrix_splits[3] * m_matrix_splits[3])
        layer1_4 = tanh(inputs_matrix_splits[4] + m_matrix_splits[4])
        layer1_5 = sigmoid(inputs_matrix_splits[5] + m_matrix_splits[5])
        layer1_6 = tanh(inputs_matrix_splits[6] + m_matrix_splits[6])
        layer1_7 = sigmoid(inputs_matrix_splits[7] + m_matrix_splits[7])

        # Second layer
        l2_0 = tanh(layer1_0 * layer1_1)
        l2_1 = tanh(layer1_2 + layer1_3)
        l2_2 = tanh(layer1_4 * layer1_5)
        l2_3 = sigmoid(layer1_6 + layer1_7)

        # Inject the cell
        l2_0 = tanh(l2_0 + c_prev)

        # Third layer
        l3_0_pre = l2_0 * l2_1
        new_c = l3_0_pre  # create new cell
        l3_0 = l3_0_pre
        l3_1 = tanh(l2_2 + l2_3)

        # Final layer
        new_m = tanh(l3_0 * l3_1)

        # Projection layer if specified
        if self.projection is not None:
            new_m = tf.matmul(new_m, self.projection_weights)

        return new_m, [new_c, new_m]

    def get_config(self):
        config = {
            "units": self.units,
            "projection": self.projection,
            "use_bias": self.use_bias,
            "kernel_initializer": self.kernel_initializer,
            "recurrent_initializer": self.recurrent_initializer,
            "bias_initializer": self.bias_initializer,
            "projection_initializer": self.projection_initializer,
        }
        base_config = super(NASCell, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
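
For illustration only (not part of the commit), a short usage sketch of the new cell wrapped in the generic Keras RNN layer; it assumes a TF 2.x-era API with this addon built and installed:

# Hypothetical usage sketch, not part of this commit.
import tensorflow as tf
from tensorflow_addons.rnn import NASCell

cell = NASCell(units=64, projection=32, use_bias=True)
layer = tf.keras.layers.RNN(cell)        # wrap the cell in the generic Keras RNN layer

inputs = tf.random.normal([8, 10, 16])   # [batch, time, features]
outputs = layer(inputs)                  # shape [8, 32]: output dim follows `projection`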
