Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 21 additions & 37 deletions tensorflow_addons/image/filters_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,72 +22,56 @@
from tensorflow_addons.utils import test_utils


@test_utils.run_all_in_graph_and_eager_modes
class MedianFilter2dTest(tf.test.TestCase):
def _validate_median_filter2d(self,
inputs,
expected_values,
filter_shape=(3, 3)):
output = median_filter2d(inputs, filter_shape)
self.assertAllClose(output, expected_values)

values_op = median_filter2d(inputs)
with self.test_session(use_gpu=False) as sess:
if tf.executing_eagerly():
expected_values = expected_values.numpy()
values = values_op.numpy()
else:
expected_values = expected_values.eval()
values = values_op.eval()
self.assertShapeEqual(values, inputs)
self.assertShapeEqual(expected_values, values_op)
self.assertAllClose(expected_values, values)

@test_utils.run_in_graph_and_eager_modes
def test_filter_tuple(self):
tf_img = tf.zeros([3, 4, 3], tf.int32)

with self.assertRaisesRegexp(TypeError,
'Filter shape must be a tuple'):
median_filter2d(tf_img, 3)
median_filter2d(tf_img, 3.5)
median_filter2d(tf_img, 'dt')
median_filter2d(tf_img, None)
for filter_shape in [3, 3.5, 'dt', None]:
with self.assertRaisesRegexp(TypeError,
'Filter shape must be a tuple'):
median_filter2d(tf_img, filter_shape)

filter_shape = (3, 3, 3)
msg = 'Filter shape must be a tuple of 2 integers. ' \
'Got %s values in tuple' % len(filter_shape)
msg = ('Filter shape must be a tuple of 2 integers. '
'Got %s values in tuple' % len(filter_shape))
with self.assertRaisesRegexp(ValueError, msg):
median_filter2d(tf_img, filter_shape)

with self.assertRaisesRegexp(TypeError,
'Size of the filter must be Integers'):
median_filter2d(tf_img, (3.5, 3))
median_filter2d(tf_img, (None, 3))
msg = 'Size of the filter must be Integers'
for filter_shape in [(3.5, 3), (None, 3)]:
with self.assertRaisesRegexp(TypeError, msg):
median_filter2d(tf_img, filter_shape)

@test_utils.run_in_graph_and_eager_modes
def test_filter_value(self):
tf_img = tf.zeros([3, 4, 3], tf.int32)

with self.assertRaises(ValueError):
median_filter2d(tf_img, (4, 3))

@test_utils.run_deprecated_v1
def test_dimension(self):
tf.compat.v1.disable_eager_execution()
tf_img = tf.compat.v1.placeholder(tf.int32, shape=[3, 4, None])
tf_img1 = tf.compat.v1.placeholder(tf.int32, shape=[3, None, 4])
tf_img2 = tf.compat.v1.placeholder(tf.int32, shape=[None, 3, 4])

with self.assertRaises(TypeError):
median_filter2d(tf_img)
median_filter2d(tf_img1)
median_filter2d(tf_img2)
for image_shape in [(3, 4, None), (3, None, 4), (None, 3, 4)]:
with self.assertRaises(TypeError):
tf_img = tf.compat.v1.placeholder(tf.int32, shape=image_shape)
median_filter2d(tf_img)

@test_utils.run_in_graph_and_eager_modes
def test_image_vs_filter(self):
tf_img = tf.zeros([3, 4, 3], tf.int32)
m = tf_img.shape[0]
no = tf_img.shape[1]
ch = tf_img.shape[2]
filter_shape = (3, 5)
with self.assertRaises(ValueError):
median_filter2d(tf_img, filter_shape)

@test_utils.run_in_graph_and_eager_modes
def test_three_channels(self):
tf_img = [[[0.32801723, 0.08863795, 0.79119259],
[0.35526001, 0.79388736, 0.55435993],
Expand Down