diff --git a/tensorflow_addons/image/filters.py b/tensorflow_addons/image/filters.py index 503d4657c9..861badd256 100644 --- a/tensorflow_addons/image/filters.py +++ b/tensorflow_addons/image/filters.py @@ -174,34 +174,44 @@ def median_filter2d(image, if rank == 3: image = tf.expand_dims(image, axis=0) + image_shape = tf.shape(image) + batch_size = image_shape[0] + height = image_shape[1] + width = image_shape[2] + channels = image_shape[3] + # Explicitly pad the image image = _pad( image, filter_shape, mode=padding, constant_values=constant_values) - floor = (filter_shape[0] * filter_shape[1] + 1) // 2 - ceil = (filter_shape[0] * filter_shape[1]) // 2 + 1 - - def _median_filter2d_single_channel(x): - x = tf.expand_dims(x, axis=-1) - patches = tf.image.extract_patches( - x, - sizes=[1, filter_shape[0], filter_shape[1], 1], - strides=[1, 1, 1, 1], - rates=[1, 1, 1, 1], - padding="VALID") - - # Note the returned median is casted back to the original type - # Take [5, 6, 7, 8] for example, the median is (6 + 7) / 2 = 3.5 - # It turns out to be int(6.5) = 6 if the original type is int - top = tf.nn.top_k(patches, k=ceil).values - median = (top[:, :, :, floor - 1] + top[:, :, :, ceil - 1]) / 2 - return tf.dtypes.cast(median, x.dtype) - - output = tf.map_fn( - _median_filter2d_single_channel, - elems=tf.transpose(image, [3, 0, 1, 2]), - dtype=image.dtype) - output = tf.transpose(output, [1, 2, 3, 0]) + area = filter_shape[0] * filter_shape[1] + + floor = (area + 1) // 2 + ceil = area // 2 + 1 + + patches = tf.image.extract_patches( + image, + sizes=[1, filter_shape[0], filter_shape[1], 1], + strides=[1, 1, 1, 1], + rates=[1, 1, 1, 1], + padding="VALID") + + patches = tf.reshape( + patches, shape=[batch_size, height, width, area, channels]) + + patches = tf.transpose(patches, [0, 1, 2, 4, 3]) + + # Note the returned median is cast back to the original type + # Take [5, 6, 7, 8] for example, the median is (6 + 7) / 2 = 6.5 + # It turns out to be int(6.5) = 6 if the 
original type is int + top = tf.nn.top_k(patches, k=ceil).values + if area % 2 == 1: + median = top[:, :, :, :, floor - 1] + else: + median = ( + top[:, :, :, :, floor - 1] + top[:, :, :, :, ceil - 1]) / 2 + + output = tf.cast(median, image.dtype) # Squeeze out the first axis to make sure # output has the same dimension with image.