diff --git a/tensorflow_addons/layers/BUILD b/tensorflow_addons/layers/BUILD index 0de36d33e0..b5814c8f02 100644 --- a/tensorflow_addons/layers/BUILD +++ b/tensorflow_addons/layers/BUILD @@ -4,12 +4,13 @@ package(default_visibility = ["//visibility:public"]) py_library( name = "layers_py", - srcs = ([ + srcs = [ "__init__.py", "python/__init__.py", + "python/maxout.py", "python/poincare.py", "python/wrappers.py", - ]), + ], srcs_version = "PY2AND3", ) @@ -20,12 +21,25 @@ py_test( ], main = "python/wrappers_test.py", deps = [ - ":layers_py", - ], + ":layers_py", + ], srcs_version = "PY2AND3", ) py_test( + name = "maxout_py_test", + size = "small", + srcs = [ + "python/maxout_test.py", + ], + main = "python/maxout_test.py", + deps = [ + ":layers_py", + ], + srcs_version = "PY2AND3", + ) + + py_test( name = "poincare_py_test", size = "small", srcs = [ diff --git a/tensorflow_addons/layers/__init__.py b/tensorflow_addons/layers/__init__.py index 0141f32668..09a236c8c9 100644 --- a/tensorflow_addons/layers/__init__.py +++ b/tensorflow_addons/layers/__init__.py @@ -19,5 +19,5 @@ from __future__ import division from __future__ import print_function -# Weight Normalization Wrapper +from tensorflow_addons.layers.python.maxout import Maxout from tensorflow_addons.layers.python.wrappers import WeightNormalization diff --git a/tensorflow_addons/layers/python/maxout.py b/tensorflow_addons/layers/python/maxout.py new file mode 100644 index 0000000000..0beb2ae24d --- /dev/null +++ b/tensorflow_addons/layers/python/maxout.py @@ -0,0 +1,98 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementing Maxout layer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops


class Maxout(Layer):
    """Applies Maxout to the input.

    "Maxout Networks" Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza,
    Aaron Courville, Yoshua Bengio. https://arxiv.org/abs/1302.4389

    Usually the operation is performed in the filter/channel dimension. This
    can also be used after Dense layers to reduce number of features.

    The `axis` dimension of size `C` is reshaped into
    `(C // num_units, num_units)` and the maximum is taken over the first of
    those two dimensions, leaving `num_units` features.

    Arguments:
        num_units: Specifies how many features will remain after maxout
            in the `axis` dimension (usually channel).
            This must be a factor of number of features.
        axis: The dimension where max pooling will be performed. Default is
            the last dimension. Negative values count from the end.

    Input shape:
        nD tensor with shape: `(batch_size, ..., axis_dim, ...)`.

    Output shape:
        nD tensor with shape: `(batch_size, ..., num_units, ...)`.
    """

    def __init__(self, num_units, axis=-1, **kwargs):
        """Stores the layer configuration.

        Arguments:
            num_units: Number of features kept along `axis`.
            axis: Dimension to reduce; defaults to the last one.
            **kwargs: Forwarded unchanged to the base `Layer` constructor.
        """
        super(Maxout, self).__init__(**kwargs)
        self.num_units = num_units
        self.axis = axis

    def call(self, inputs):
        """Computes the maxout of `inputs` along `self.axis`.

        Arguments:
            inputs: Tensor (or value convertible to one) to reduce.

        Returns:
            A tensor in which the `axis` dimension has size `num_units`.

        Raises:
            ValueError: If `self.axis` is out of range for the rank of
                `inputs`, or if the statically known size of the `axis`
                dimension is not a multiple of `num_units`.
        """
        inputs = ops.convert_to_tensor(inputs)
        shape = inputs.get_shape().as_list()
        rank = len(shape)

        # Validate the axis before it is used as an index. A real exception
        # is raised (rather than an `assert`) so the check survives
        # `python -O` and covers both ends of the valid range.
        if not -rank <= self.axis < rank:
            raise ValueError(
                'Find invalid axis: {} for input of rank {}'.format(
                    self.axis, rank))
        axis = self.axis + rank if self.axis < 0 else self.axis

        # Dealing with batches with arbitrary sizes: statically unknown
        # dimensions are replaced by slices of a single dynamic shape tensor.
        if None in shape:
            dynamic_shape = array_ops.shape(inputs)
            shape = [
                dynamic_shape[i] if dim is None else dim
                for i, dim in enumerate(shape)
            ]

        num_channels = shape[axis]
        # Divisibility can only be verified here when the size is static;
        # a dynamic size is a tensor and is left to the reshape op below.
        if (not isinstance(num_channels, ops.Tensor)
                and num_channels % self.num_units):
            raise ValueError('number of features({}) is not '
                             'a multiple of num_units({})'.format(
                                 num_channels, self.num_units))

        # Split `axis` into (k, num_units); the new `k` dimension sits at
        # position `axis` and is the one reduced away.
        k = num_channels // self.num_units
        expand_shape = shape[:axis] + [k, self.num_units] + shape[axis + 1:]

        outputs = math_ops.reduce_max(
            array_ops.reshape(inputs, expand_shape), axis, keepdims=False)
        return outputs

    def compute_output_shape(self, input_shape):
        """Returns `input_shape` with the `axis` entry set to `num_units`."""
        input_shape = tensor_shape.TensorShape(input_shape).as_list()
        input_shape[self.axis] = self.num_units
        return tensor_shape.TensorShape(input_shape)

    def get_config(self):
        """Returns the base layer config extended with this layer's args."""
        config = {'num_units': self.num_units, 'axis': self.axis}
        base_config = super(Maxout, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


# Register under the name 'Maxout' in Keras' global custom-object table.
generic_utils._GLOBAL_CUSTOM_OBJECTS['Maxout'] = Maxout

# ---- tensorflow_addons/layers/python/maxout_test.py (new file) ----
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Maxout layer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.keras import testing_utils +from tensorflow.python.platform import test +from tensorflow_addons.layers.python.maxout import Maxout + + +class MaxOutTest(test.TestCase): + def test_simple(self): + testing_utils.layer_test( + Maxout, kwargs={'num_units': 3}, input_shape=(5, 4, 2, 18)) + + def test_nchw(self): + testing_utils.layer_test( + Maxout, + kwargs={ + 'num_units': 4, + 'axis': 1 + }, + input_shape=(2, 20, 3, 6)) + + testing_utils.layer_test( + Maxout, + kwargs={ + 'num_units': 4, + 'axis': -3 + }, + input_shape=(2, 20, 3, 6)) + + def test_unknown(self): + inputs = np.random.random((5, 4, 2, 18)).astype('float32') + testing_utils.layer_test( + Maxout, + kwargs={'num_units': 3}, + input_shape=(5, 4, 2, None), + input_data=inputs) + + testing_utils.layer_test( + Maxout, + kwargs={'num_units': 3}, + input_shape=(None, None, None, None), + input_data=inputs) + + def test_invalid_shape(self): + with self.assertRaisesRegexp(ValueError, r'number of features'): + testing_utils.layer_test( + Maxout, kwargs={'num_units': 3}, input_shape=(5, 4, 2, 7)) + + +if __name__ == '__main__': + test.main()