Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions tensorflow_addons/image/tests/transform_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,62 @@ def test_extreme_projective_transform(dtype):
)


@pytest.mark.with_device(["cpu", "gpu"])
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.parametrize("dtype", _DTYPES)
def test_transform_constant_fill_mode(dtype):
    """Shift a 4x4 ramp one column to the right with `fill_mode="constant"`.

    The column vacated on the left edge is expected to hold zeros.
    """
    source_rows = [
        [0, 1, 2, 3],
        [4, 5, 6, 7],
        [8, 9, 10, 11],
        [12, 13, 14, 15],
    ]
    wanted_rows = [
        [0, 0, 1, 2],
        [0, 4, 5, 6],
        [0, 8, 9, 10],
        [0, 12, 13, 14],
    ]
    source = tf.constant(source_rows, dtype=dtype)

    # The transformation matrix is always inverted by the op, so moving the
    # content right by one pixel is written with a -1 offset.
    shift_right = tf.constant([1, 0, -1, 0, 1, 0, 0, 0], dtype=tf.float32)
    result = transform_ops.transform(source, shift_right, fill_mode="constant")

    wanted = np.asarray(wanted_rows, dtype=dtype.as_numpy_dtype)
    np.testing.assert_equal(result.numpy(), wanted)


@pytest.mark.with_device(["cpu", "gpu"])
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.parametrize("dtype", _DTYPES)
def test_transform_reflect_fill_mode(dtype):
    """Shift a 4x4 ramp one column to the right with `fill_mode="reflect"`.

    The column vacated on the left edge is expected to repeat the first
    column of the input (the image mirrored about its edge).
    """
    source_rows = [
        [0, 1, 2, 3],
        [4, 5, 6, 7],
        [8, 9, 10, 11],
        [12, 13, 14, 15],
    ]
    wanted_rows = [
        [0, 0, 1, 2],
        [4, 4, 5, 6],
        [8, 8, 9, 10],
        [12, 12, 13, 14],
    ]
    source = tf.constant(source_rows, dtype=dtype)

    # The transformation matrix is always inverted by the op, so moving the
    # content right by one pixel is written with a -1 offset.
    shift_right = tf.constant([1, 0, -1, 0, 1, 0, 0, 0], dtype=tf.float32)
    result = transform_ops.transform(source, shift_right, fill_mode="reflect")

    wanted = np.asarray(wanted_rows, dtype=dtype.as_numpy_dtype)
    np.testing.assert_equal(result.numpy(), wanted)


@pytest.mark.with_device(["cpu", "gpu"])
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.parametrize("dtype", _DTYPES)
def test_transform_wrap_fill_mode(dtype):
    """Shift a 4x4 ramp one column to the right with `fill_mode="wrap"`.

    The column vacated on the left edge is expected to hold the last column
    of the input (values wrap around from the opposite edge).
    """
    source_rows = [
        [0, 1, 2, 3],
        [4, 5, 6, 7],
        [8, 9, 10, 11],
        [12, 13, 14, 15],
    ]
    wanted_rows = [
        [3, 0, 1, 2],
        [7, 4, 5, 6],
        [11, 8, 9, 10],
        [15, 12, 13, 14],
    ]
    source = tf.constant(source_rows, dtype=dtype)

    # The transformation matrix is always inverted by the op, so moving the
    # content right by one pixel is written with a -1 offset.
    shift_right = tf.constant([1, 0, -1, 0, 1, 0, 0, 0], dtype=tf.float32)
    result = transform_ops.transform(source, shift_right, fill_mode="wrap")

    wanted = np.asarray(wanted_rows, dtype=dtype.as_numpy_dtype)
    np.testing.assert_equal(result.numpy(), wanted)


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
def test_transform_static_output_shape():
image = tf.constant([[1.0, 2.0], [3.0, 4.0]])
Expand Down
23 changes: 23 additions & 0 deletions tensorflow_addons/image/transform_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def transform(
images: TensorLike,
transforms: TensorLike,
interpolation: str = "NEAREST",
fill_mode: str = "CONSTANT",
output_shape: Optional[list] = None,
name: Optional[str] = None,
) -> tf.Tensor:
Expand All @@ -55,6 +56,15 @@ def transform(
gradients are not backpropagated into transformation parameters.
interpolation: Interpolation mode.
Supported values: "NEAREST", "BILINEAR".
fill_mode: Points outside the boundaries of the input are filled according
to the given mode (one of `{'constant', 'reflect', 'wrap'}`).
- *reflect*: `(d c b a | a b c d | d c b a)`
The input is extended by reflecting about the edge of the last pixel.
- *constant*: `(k k k k | a b c d | k k k k)`
The input is extended by filling all values beyond the edge with the
same constant value k = 0.
- *wrap*: `(a b c d | a b c d | a b c d)`
The input is extended by wrapping around to the opposite edge.
output_shape: Output dimension after the transform, [height, width].
If None, output is the same size as input image.

Expand Down Expand Up @@ -105,11 +115,13 @@ def transform(
% len(transforms.get_shape())
)

# TODO(WindQAQ): Support "nearest" `fill_mode` and `fill_value` in TF2.4.
output = tf.raw_ops.ImageProjectiveTransformV2(
images=images,
transforms=transforms,
output_shape=output_shape,
interpolation=interpolation.upper(),
fill_mode=fill_mode.upper(),
)
return img_utils.from_4D_image(output, original_ndims)

Expand Down Expand Up @@ -268,6 +280,7 @@ def rotate(
images: TensorLike,
angles: TensorLike,
interpolation: str = "NEAREST",
fill_mode: str = "CONSTANT",
name: Optional[str] = None,
) -> tf.Tensor:
"""Rotate image(s) counterclockwise by the passed angle(s) in radians.
Expand All @@ -282,6 +295,15 @@ def rotate(
batch.
interpolation: Interpolation mode. Supported values: "NEAREST",
"BILINEAR".
fill_mode: Points outside the boundaries of the input are filled according
to the given mode (one of `{'constant', 'reflect', 'wrap'}`).
- *reflect*: `(d c b a | a b c d | d c b a)`
The input is extended by reflecting about the edge of the last pixel.
- *constant*: `(k k k k | a b c d | k k k k)`
The input is extended by filling all values beyond the edge with the
same constant value k = 0.
- *wrap*: `(a b c d | a b c d | a b c d)`
The input is extended by wrapping around to the opposite edge.
name: The name of the op.

Returns:
Expand All @@ -304,6 +326,7 @@ def rotate(
images,
angles_to_projective_transforms(angles, image_height, image_width),
interpolation=interpolation,
fill_mode=fill_mode,
)
return img_utils.from_4D_image(output, original_ndims)

Expand Down
29 changes: 20 additions & 9 deletions tensorflow_addons/image/translate_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,26 +75,36 @@ def translate(
images: TensorLike,
translations: TensorLike,
interpolation: str = "NEAREST",
fill_mode: str = "CONSTANT",
name: Optional[str] = None,
) -> tf.Tensor:
"""Translate image(s) by the passed vectors(s).

Args:
images: A tensor of shape
`(num_images, num_rows, num_columns, num_channels)` (NHWC),
`(num_rows, num_columns, num_channels)` (HWC), or
`(num_rows, num_columns)` (HW). The rank must be statically known (the
shape is not `TensorShape(None)`).
`(num_images, num_rows, num_columns, num_channels)` (NHWC),
`(num_rows, num_columns, num_channels)` (HWC), or
`(num_rows, num_columns)` (HW). The rank must be statically known (the
shape is not `TensorShape(None)`).
translations: A vector representing `[dx, dy]` or (if `images` has rank 4)
a matrix of length num_images, with a `[dx, dy]` vector for each image
in the batch.
a matrix of length num_images, with a `[dx, dy]` vector for each image
in the batch.
interpolation: Interpolation mode. Supported values: "NEAREST",
"BILINEAR".
"BILINEAR".
fill_mode: Points outside the boundaries of the input are filled according
to the given mode (one of `{'constant', 'reflect', 'wrap'}`).
- *reflect*: `(d c b a | a b c d | d c b a)`
The input is extended by reflecting about the edge of the last pixel.
- *constant*: `(k k k k | a b c d | k k k k)`
The input is extended by filling all values beyond the edge with the
same constant value k = 0.
- *wrap*: `(a b c d | a b c d | a b c d)`
The input is extended by wrapping around to the opposite edge.
name: The name of the op.
Returns:
Image(s) with the same type and shape as `images`, translated by the
given vector(s). Empty space due to the translation will be filled with
zeros.
given vector(s). Empty space due to the translation will be filled
according to `fill_mode` (zeros by default).
Raises:
TypeError: If `images` is an invalid type.
"""
Expand All @@ -103,6 +113,7 @@ def translate(
images,
translations_to_projective_transforms(translations),
interpolation=interpolation,
fill_mode=fill_mode,
)


Expand Down