Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 3 additions & 44 deletions tensorflow_addons/image/transform_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,11 @@

import tensorflow as tf
from tensorflow_addons.image import utils as img_utils
from tensorflow_addons.utils.resource_loader import LazySO
from tensorflow_addons.utils.types import TensorLike
from tensorflow_addons.image.utils import wrap, unwrap

from typing import Optional

_image_so = LazySO("custom_ops/image/_image_ops.so")

_IMAGE_DTYPES = {
tf.dtypes.uint8,
Expand All @@ -34,7 +32,6 @@
}


@tf.function
def transform(
images: TensorLike,
transforms: TensorLike,
Expand Down Expand Up @@ -108,10 +105,10 @@ def transform(
% len(transforms.get_shape())
)

output = _image_so.ops.addons_image_projective_transform_v2(
images,
output_shape=output_shape,
output = tf.raw_ops.ImageProjectiveTransformV2(
images=images,
transforms=transforms,
output_shape=output_shape,
interpolation=interpolation.upper(),
)
return img_utils.from_4D_image(output, original_ndims)
Expand Down Expand Up @@ -271,44 +268,6 @@ def angles_to_projective_transforms(
)


@tf.RegisterGradient("Addons>ImageProjectiveTransformV2")
def _image_projective_transform_grad(op, grad):
"""Computes the gradient for ImageProjectiveTransform."""
images = op.inputs[0]
transforms = op.inputs[1]
interpolation = op.get_attr("interpolation")

image_or_images = tf.convert_to_tensor(images, name="images")
transform_or_transforms = tf.convert_to_tensor(
transforms, name="transforms", dtype=tf.dtypes.float32
)

if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
raise ValueError("Invalid dtype %s." % image_or_images.dtype)
if len(transform_or_transforms.get_shape()) == 1:
transforms = transform_or_transforms[None]
elif len(transform_or_transforms.get_shape()) == 2:
transforms = transform_or_transforms
else:
transforms = transform_or_transforms
raise ValueError(
"transforms should have rank 1 or 2, but got rank %d"
% len(transforms.get_shape())
)

# Invert transformations
transforms = flat_transforms_to_matrices(transforms=transforms)
inverse = tf.linalg.inv(transforms)
transforms = matrices_to_flat_transforms(inverse)
output = _image_so.ops.addons_image_projective_transform_v2(
images=grad,
transforms=transforms,
output_shape=tf.shape(image_or_images)[1:3],
interpolation=interpolation,
)
return [output, None, None]


def rotate(
images: TensorLike,
angles: TensorLike,
Expand Down