diff --git a/tensorflow_addons/image/BUILD b/tensorflow_addons/image/BUILD index 2cf85c3071..92ccf97fdc 100644 --- a/tensorflow_addons/image/BUILD +++ b/tensorflow_addons/image/BUILD @@ -9,6 +9,7 @@ py_library( "dense_image_warp.py", "distance_transform.py", "distort_image_ops.py", + "mean_filter_2d.py", "median_filter_2d.py", "transform_ops.py", ]), @@ -59,6 +60,19 @@ py_test( ], ) +py_test( + name = "mean_filter_2d_test", + size = "medium", + srcs = [ + "mean_filter_2d_test.py", + ], + main = "mean_filter_2d_test.py", + srcs_version = "PY2AND3", + deps = [ + ":image", + ], +) + py_test( name = "median_filter_2d_test", size = "medium", diff --git a/tensorflow_addons/image/README.md b/tensorflow_addons/image/README.md index 49902ea052..6968975230 100644 --- a/tensorflow_addons/image/README.md +++ b/tensorflow_addons/image/README.md @@ -6,6 +6,7 @@ | dense_image_warp | @WindQAQ | windqaq@gmail.com | | distance_transform_ops | | | | distort_image_ops | @WindQAQ | windqaq@gmail.com | +| mean_filter_2d | @Mainak431 | mainakdutta76@gmail.com | | median_filter_2d | @Mainak431 | mainakdutta76@gmail.com | | transform_ops | | | @@ -17,6 +18,7 @@ | distance_transform_ops | euclidean_distance_transform | | | distort_image_ops | adjust_hsv_in_yiq | | | distort_image_ops | random_hsv_in_yiq | | +| mean_filter_2d | mean_filter_2D | | | median_filter_2d | median_filter_2D | | | transform_ops | angles_to_projective_transforms | | | transform_ops | matrices_to_flat_transforms | | diff --git a/tensorflow_addons/image/__init__.py b/tensorflow_addons/image/__init__.py index c6caf78232..3d9d5d4f64 100644 --- a/tensorflow_addons/image/__init__.py +++ b/tensorflow_addons/image/__init__.py @@ -22,6 +22,7 @@ from tensorflow_addons.image.distance_transform import euclidean_dist_transform from tensorflow_addons.image.distort_image_ops import adjust_hsv_in_yiq from tensorflow_addons.image.distort_image_ops import random_hsv_in_yiq +from tensorflow_addons.image.mean_filter_2d import mean_filter_2D from tensorflow_addons.image.median_filter_2d import median_filter_2D from tensorflow_addons.image.transform_ops import rotate from tensorflow_addons.image.transform_ops import transform diff --git a/tensorflow_addons/image/mean_filter_2d.py b/tensorflow_addons/image/mean_filter_2d.py new file mode 100644 index 0000000000..350bb6064c --- /dev/null +++ b/tensorflow_addons/image/mean_filter_2d.py @@ -0,0 +1,107 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + + +@tf.function +def mean_filter_2D(image, filter_shape=(3, 3)): + """This method performs Mean Filtering on image. Filter shape can be user + given. + + This method takes both kind of images where pixel values lie between 0 to + 255 and where it lies between 0.0 and 1.0 + Args: + image: A 3D `Tensor` of type `float32` or 'int32' or 'float64' or + 'int64 and of shape`[rows, columns, channels]` + + filter_shape: Optional Argument. A tuple of 2 integers (R,C). + R is the first value is the number of rows in the filter and + C is the second value in the filter is the number of columns + in the filter. This creates a filter of shape (R,C) or RxC + filter. Default value = (3,3) + + Returns: + A 3D mean filtered image tensor of shape [rows,columns,channels] and + type 'int32'. Pixel value of returned tensor ranges between 0 to 255 + """ + + def _normalize(li): + one = tf.convert_to_tensor(1.0) + two = tf.convert_to_tensor(255.0) + + def func1(): + return li + + def func2(): + return tf.math.truediv(li, two) + + return tf.cond(tf.math.greater(ma, one), func2, func1) + + if not isinstance(filter_shape, tuple): + raise TypeError('Filter shape must be a tuple') + if len(filter_shape) != 2: + raise ValueError('Filter shape must be a tuple of 2 integers. ' + 'Got %s values in tuple' % len(filter_shape)) + filter_shapex = filter_shape[0] + filter_shapey = filter_shape[1] + if not isinstance(filter_shapex, int) or not isinstance( + filter_shapey, int): + raise TypeError('Size of the filter must be Integers') + (row, col, ch) = (image.shape[0], image.shape[1], image.shape[2]) + if row != None and col != None and ch != None: + (row, col, ch) = (int(row), int(col), int(ch)) + else: + raise TypeError( + 'All the Dimensions of the input image tensor must be Integers.') + if row < filter_shapex or col < filter_shapey: + raise ValueError( + 'No of Pixels in each dimension of the image should be more \ + than the filter size. Got filter_shape (%sx' % filter_shape[0] + + '%s).' % filter_shape[1] + ' Image Shape (%s)' % image.shape) + if filter_shapex % 2 == 0 or filter_shapey % 2 == 0: + raise ValueError('Filter size should be odd. Got filter_shape (%sx' % + filter_shape[0] + '%s)' % filter_shape[1]) + image = tf.cast(image, tf.float32) + tf_i = tf.reshape(image, [row * col * ch]) + ma = tf.math.reduce_max(tf_i) + image = _normalize(image) + + # k and l is the Zero-padding size + + listi = [] + for a in range(ch): + img = image[:, :, a:a + 1] + img = tf.reshape(img, [1, row, col, 1]) + slic = tf.image.extract_image_patches( + img, [1, filter_shapex, filter_shapey, 1], [1, 1, 1, 1], + [1, 1, 1, 1], + padding='SAME') + li = tf.reduce_mean(slic, axis=-1) + li = tf.reshape(li, [row, col, 1]) + listi.append(li) + y = tf.concat(listi[0], 2) + + for i in range(len(listi) - 1): + y = tf.concat([y, listi[i + 1]], 2) + + y *= 255 + y = tf.cast(y, tf.int32) + + return y diff --git a/tensorflow_addons/image/mean_filter_2d_test.py b/tensorflow_addons/image/mean_filter_2d_test.py new file mode 100644 index 0000000000..985a7db745 --- /dev/null +++ b/tensorflow_addons/image/mean_filter_2d_test.py @@ -0,0 +1,108 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may noa use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +import mean_filter_2d as md +from tensorflow_addons.utils import test_utils + + +@test_utils.run_all_in_graph_and_eager_modes +class Mean2DTest(tf.test.TestCase): + def _validateMean_2d(self, inputs, expected_values, filter_shape=(3, 3)): + + values_op = md.mean_filter_2D(inputs) + with self.test_session(use_gpu=False) as sess: + if tf.executing_eagerly(): + expected_values = expected_values.numpy() + values = values_op.numpy() + else: + expected_values = expected_values.eval() + values = values_op.eval() + self.assertShapeEqual(values, inputs) + self.assertShapeEqual(expected_values, values_op) + self.assertAllClose(expected_values, values) + + def testfiltertuple(self): + tf_img = tf.zeros([3, 4, 3], tf.int32) + + with self.assertRaisesRegexp(TypeError, + 'Filter shape must be a tuple'): + md.mean_filter_2D(tf_img, 3) + md.mean_filter_2D(tf_img, 3.5) + md.mean_filter_2D(tf_img, 'dt') + md.mean_filter_2D(tf_img, None) + + filter_shape = (3, 3, 3) + msg = 'Filter shape must be a tuple of 2 integers. ' \ + 'Got %s values in tuple' % len(filter_shape) + with self.assertRaisesRegexp(ValueError, msg): + md.mean_filter_2D(tf_img, filter_shape) + + with self.assertRaisesRegexp(TypeError, + 'Size of the filter must be Integers'): + md.mean_filter_2D(tf_img, (3.5, 3)) + md.mean_filter_2D(tf_img, (None, 3)) + + def testfiltervalue(self): + tf_img = tf.zeros([3, 4, 3], tf.int32) + + with self.assertRaises(ValueError): + md.mean_filter_2D(tf_img, (4, 3)) + + def testDimension(self): + tf.compat.v1.disable_eager_execution() + tf_img = tf.compat.v1.placeholder(tf.int32, shape=[3, 4, None]) + tf_img1 = tf.compat.v1.placeholder(tf.int32, shape=[3, None, 4]) + tf_img2 = tf.compat.v1.placeholder(tf.int32, shape=[None, 3, 4]) + + with self.assertRaises(TypeError): + md.mean_filter_2D(tf_img) + md.mean_filter_2D(tf_img1) + md.mean_filter_2D(tf_img2) + + def test_imagevsfilter(self): + tf_img = tf.zeros([3, 4, 3], tf.int32) + m = tf_img.shape[0] + no = tf_img.shape[1] + ch = tf_img.shape[2] + filter_shape = (3, 5) + with self.assertRaises(ValueError): + md.mean_filter_2D(tf_img, filter_shape) + + def testcase(self): + tf_img = [[[0.32801723, 0.08863795, 0.79119259], + [0.35526001, 0.79388736, 0.55435993], + [0.11607035, 0.55673079, 0.99473371]], + [[0.53240645, 0.74684819, 0.33700031], + [0.01760473, 0.28181609, 0.9751476], + [0.01605137, 0.8292904, 0.56405609]], + [[0.57215374, 0.10155051, 0.64836128], + [0.36533048, 0.91401874, 0.02524159], + [0.56379134, 0.9028874, 0.19505117]]] + + tf_img = tf.convert_to_tensor(value=tf_img) + expt = [[[34, 54, 75], [38, 93, 119], [14, 69, 87]], + [[61, 82, 94], [81, 147, 144], [40, 121, 93]], + [[42, 57, 56], [58, 106, 77], [27, 82, 49]]] + expt = tf.convert_to_tensor(value=expt) + self._validateMean_2d(tf_img, expt) + + +if __name__ == "__main__": + tf.test.main()