Skip to content

Commit 58f4917

Browse files
fsx950223ashutosh1919
authored andcommitted
use raw ops (tensorflow#1914)
1 parent fd032e6 commit 58f4917

File tree

1 file changed

+3
-44
lines changed

1 file changed

+3
-44
lines changed

tensorflow_addons/image/transform_ops.py

Lines changed: 3 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,11 @@
1616

1717
import tensorflow as tf
1818
from tensorflow_addons.image import utils as img_utils
19-
from tensorflow_addons.utils.resource_loader import LazySO
2019
from tensorflow_addons.utils.types import TensorLike
2120
from tensorflow_addons.image.utils import wrap, unwrap
2221

2322
from typing import Optional
2423

25-
_image_so = LazySO("custom_ops/image/_image_ops.so")
2624

2725
_IMAGE_DTYPES = {
2826
tf.dtypes.uint8,
@@ -34,7 +32,6 @@
3432
}
3533

3634

37-
@tf.function
3835
def transform(
3936
images: TensorLike,
4037
transforms: TensorLike,
@@ -108,10 +105,10 @@ def transform(
108105
% len(transforms.get_shape())
109106
)
110107

111-
output = _image_so.ops.addons_image_projective_transform_v2(
112-
images,
113-
output_shape=output_shape,
108+
output = tf.raw_ops.ImageProjectiveTransformV2(
109+
images=images,
114110
transforms=transforms,
111+
output_shape=output_shape,
115112
interpolation=interpolation.upper(),
116113
)
117114
return img_utils.from_4D_image(output, original_ndims)
@@ -271,44 +268,6 @@ def angles_to_projective_transforms(
271268
)
272269

273270

274-
@tf.RegisterGradient("Addons>ImageProjectiveTransformV2")
275-
def _image_projective_transform_grad(op, grad):
276-
"""Computes the gradient for ImageProjectiveTransform."""
277-
images = op.inputs[0]
278-
transforms = op.inputs[1]
279-
interpolation = op.get_attr("interpolation")
280-
281-
image_or_images = tf.convert_to_tensor(images, name="images")
282-
transform_or_transforms = tf.convert_to_tensor(
283-
transforms, name="transforms", dtype=tf.dtypes.float32
284-
)
285-
286-
if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
287-
raise ValueError("Invalid dtype %s." % image_or_images.dtype)
288-
if len(transform_or_transforms.get_shape()) == 1:
289-
transforms = transform_or_transforms[None]
290-
elif len(transform_or_transforms.get_shape()) == 2:
291-
transforms = transform_or_transforms
292-
else:
293-
transforms = transform_or_transforms
294-
raise ValueError(
295-
"transforms should have rank 1 or 2, but got rank %d"
296-
% len(transforms.get_shape())
297-
)
298-
299-
# Invert transformations
300-
transforms = flat_transforms_to_matrices(transforms=transforms)
301-
inverse = tf.linalg.inv(transforms)
302-
transforms = matrices_to_flat_transforms(inverse)
303-
output = _image_so.ops.addons_image_projective_transform_v2(
304-
images=grad,
305-
transforms=transforms,
306-
output_shape=tf.shape(image_or_images)[1:3],
307-
interpolation=interpolation,
308-
)
309-
return [output, None, None]
310-
311-
312271
def rotate(
313272
images: TensorLike,
314273
angles: TensorLike,

0 commit comments

Comments
 (0)