1818from __future__ import print_function
1919
2020import tensorflow as tf
21+ from tensorflow_addons .image import utils as img_utils
2122from tensorflow_addons .utils import keras_utils
2223
2324
@@ -59,7 +60,8 @@ def mean_filter2d(image,
5960 """Perform mean filtering on image(s).
6061
6162 Args:
62- image: Either a 3-D `Tensor` of shape `[height, width, channels]`,
63+ image: Either a 2-D `Tensor` of shape `[height, width]`,
64+ a 3-D `Tensor` of shape `[height, width, channels]`,
6365 or a 4-D `Tensor` of shape `[batch_size, height, width, channels]`.
6466 filter_shape: An `integer` or `tuple`/`list` of 2 integers, specifying
6567 the height and width of the 2-D mean filter. Can be a single integer
@@ -74,16 +76,14 @@ def mean_filter2d(image,
7476 Returns:
7577 3-D or 4-D `Tensor` of the same dtype as input.
7678 Raises:
77- ValueError: If `image` is not 3 or 4-dimensional,
79+ ValueError: If `image` is not 2, 3 or 4-dimensional,
7880 if `padding` is other than "REFLECT", "CONSTANT" or "SYMMETRIC",
7981 or if `filter_shape` is invalid.
8082 """
8183 with tf .name_scope (name or "mean_filter2d" ):
8284 image = tf .convert_to_tensor (image , name = "image" )
83-
84- rank = image .shape .rank
85- if rank != 3 and rank != 4 :
86- raise ValueError ("image should be either 3 or 4-dimensional." )
85+ original_ndims = img_utils .get_ndims (image )
86+ image = img_utils .to_4D_image (image )
8787
8888 if padding not in ["REFLECT" , "CONSTANT" , "SYMMETRIC" ]:
8989 raise ValueError (
@@ -93,10 +93,6 @@ def mean_filter2d(image,
9393 filter_shape = keras_utils .normalize_tuple (filter_shape , 2 ,
9494 "filter_shape" )
9595
96- # Expand to a 4-D tensor
97- if rank == 3 :
98- image = tf .expand_dims (image , axis = 0 )
99-
10096 # Keep the precision if it's float;
10197 # otherwise, convert to float32 for computing.
10298 orig_dtype = image .dtype
@@ -119,11 +115,7 @@ def mean_filter2d(image,
119115
120116 output /= area
121117
122- # Squeeze out the first axis to make sure
123- # output has the same dimension with image.
124- if rank == 3 :
125- output = tf .squeeze (output , axis = 0 )
126-
118+ output = img_utils .from_4D_image (output , original_ndims )
127119 return tf .dtypes .cast (output , orig_dtype )
128120
129121
@@ -136,7 +128,8 @@ def median_filter2d(image,
136128 """Perform median filtering on image(s).
137129
138130 Args:
139- image: Either a 3-D `Tensor` of shape `[height, width, channels]`,
131+ image: Either a 2-D `Tensor` of shape `[height, width]`,
132+ a 3-D `Tensor` of shape `[height, width, channels]`,
140133 or a 4-D `Tensor` of shape `[batch_size, height, width, channels]`.
141134 filter_shape: An `integer` or `tuple`/`list` of 2 integers, specifying
142135 the height and width of the 2-D median filter. Can be a single integer
@@ -151,16 +144,14 @@ def median_filter2d(image,
151144 Returns:
152145 3-D or 4-D `Tensor` of the same dtype as input.
153146 Raises:
154- ValueError: If `image` is not 3 or 4-dimensional,
147+ ValueError: If `image` is not 2, 3 or 4-dimensional,
155148 if `padding` is other than "REFLECT", "CONSTANT" or "SYMMETRIC",
156149 or if `filter_shape` is invalid.
157150 """
158151 with tf .name_scope (name or "median_filter2d" ):
159152 image = tf .convert_to_tensor (image , name = "image" )
160-
161- rank = image .shape .rank
162- if rank != 3 and rank != 4 :
163- raise ValueError ("image should be either 3 or 4-dimensional." )
153+ original_ndims = img_utils .get_ndims (image )
154+ image = img_utils .to_4D_image (image )
164155
165156 if padding not in ["REFLECT" , "CONSTANT" , "SYMMETRIC" ]:
166157 raise ValueError (
@@ -170,10 +161,6 @@ def median_filter2d(image,
170161 filter_shape = keras_utils .normalize_tuple (filter_shape , 2 ,
171162 "filter_shape" )
172163
173- # Expand to a 4-D tensor
174- if rank == 3 :
175- image = tf .expand_dims (image , axis = 0 )
176-
177164 image_shape = tf .shape (image )
178165 batch_size = image_shape [0 ]
179166 height = image_shape [1 ]
@@ -212,10 +199,5 @@ def median_filter2d(image,
212199 top [:, :, :, :, floor - 1 ] + top [:, :, :, :, ceil - 1 ]) / 2
213200
214201 output = tf .cast (median , image .dtype )
215-
216- # Squeeze out the first axis to make sure
217- # output has the same dimension with image.
218- if rank == 3 :
219- output = tf .squeeze (output , axis = 0 )
220-
202+ output = img_utils .from_4D_image (output , original_ndims )
221203 return output
0 commit comments