Skip to content

Commit c5c7dae

Browse files
seanpmorganfacaiy
authored andcommitted
CLN: Use public APIs (#198)
* CLN: Use public APIs * Use compat for temp dir * Fix linkt
1 parent 6bb0d4b commit c5c7dae

File tree

6 files changed

+27
-33
lines changed

6 files changed

+27
-33
lines changed

tensorflow_addons/image/distance_transform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
_image_ops_so = tf.load_op_library(
2626
get_path_to_datafile("custom_ops/image/_image_ops.so"))
2727

28-
ops.NotDifferentiable("EuclideanDistanceTransform")
28+
tf.no_gradient("EuclideanDistanceTransform")
2929
ops.RegisterShape("EuclideanDistanceTransform")(
3030
common_shapes.call_cpp_shape_fn)
3131

tensorflow_addons/losses/lifted.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919
from __future__ import print_function
2020

2121
import tensorflow as tf
22-
23-
from tensorflow.python.keras import losses
2422
from tensorflow_addons.losses import metric_learning
2523
from tensorflow_addons.utils import keras_utils
2624

@@ -106,7 +104,7 @@ def lifted_struct_loss(labels, embeddings, margin=1.0):
106104

107105

108106
@keras_utils.register_keras_custom_object
109-
class LiftedStructLoss(losses.LossFunctionWrapper):
107+
class LiftedStructLoss(keras_utils.LossFunctionWrapper):
110108
"""Computes the lifted structured loss.
111109
112110
The loss encourages the positive distances (between a pair of embeddings

tensorflow_addons/losses/triplet.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from __future__ import print_function
1919

2020
import tensorflow as tf
21-
from tensorflow.python.keras import losses
2221
from tensorflow_addons.losses import metric_learning
2322
from tensorflow_addons.utils import keras_utils
2423

@@ -134,7 +133,7 @@ def triplet_semihard_loss(y_true, y_pred, margin=1.0):
134133

135134

136135
@keras_utils.register_keras_custom_object
137-
class TripletSemiHardLoss(losses.LossFunctionWrapper):
136+
class TripletSemiHardLoss(keras_utils.LossFunctionWrapper):
138137
"""Computes the triplet loss with semi-hard negative mining.
139138
140139
The loss encourages the positive distances (between a pair of embeddings

tensorflow_addons/optimizers/lazy_adam_test.py

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121
import numpy as np
2222
import tensorflow as tf
2323

24-
from tensorflow.python.eager import context
25-
from tensorflow.python.ops import variables
2624
from tensorflow_addons.optimizers import lazy_adam
2725
from tensorflow_addons.utils import test_utils
2826

@@ -83,7 +81,7 @@ def testSparse(self):
8381
opt = lazy_adam.LazyAdam()
8482
update = opt.apply_gradients(
8583
zip([grads0, grads1], [var0, var1]))
86-
self.evaluate(variables.global_variables_initializer())
84+
self.evaluate(tf.compat.v1.global_variables_initializer())
8785

8886
# Fetch params to validate initial values
8987
self.assertAllClose([1.0, 1.0, 2.0], self.evaluate(var0))
@@ -120,7 +118,7 @@ def testSparseDevicePlacement(self):
120118
g_sum = lambda: tf.math.reduce_sum(tf.gather(var, indices)) # pylint: disable=cell-var-from-loop
121119
optimizer = lazy_adam.LazyAdam(3.0)
122120
minimize_op = optimizer.minimize(g_sum, var_list=[var])
123-
self.evaluate(variables.global_variables_initializer())
121+
self.evaluate(tf.compat.v1.global_variables_initializer())
124122
self.evaluate(minimize_op)
125123

126124
@test_utils.run_deprecated_v1
@@ -143,7 +141,7 @@ def testSparseRepeatedIndices(self):
143141
aggregated_update_opt = lazy_adam.LazyAdam()
144142
aggregated_update = aggregated_update_opt.apply_gradients(
145143
[(grad_aggregated, aggregated_update_var)])
146-
self.evaluate(variables.global_variables_initializer())
144+
self.evaluate(tf.compat.v1.global_variables_initializer())
147145
self.assertAllClose(aggregated_update_var.eval(),
148146
repeated_index_update_var.eval())
149147
for _ in range(3):
@@ -182,10 +180,10 @@ def doTestBasic(self, use_callable_params=False):
182180
epsilon = epsilon()
183181

184182
opt = lazy_adam.LazyAdam(learning_rate=learning_rate)
185-
if not context.executing_eagerly():
183+
if not tf.executing_eagerly():
186184
update = opt.apply_gradients(
187185
zip([grads0, grads1], [var0, var1]))
188-
self.evaluate(variables.global_variables_initializer())
186+
self.evaluate(tf.compat.v1.global_variables_initializer())
189187
# Fetch params to validate initial values
190188
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
191189
self.assertAllClose([3.0, 4.0], self.evaluate(var1))
@@ -198,7 +196,7 @@ def doTestBasic(self, use_callable_params=False):
198196
0.9**(t + 1), self.evaluate(beta_1_power))
199197
self.assertAllCloseAccordingToType(
200198
0.999**(t + 1), self.evaluate(beta_2_power))
201-
if not context.executing_eagerly():
199+
if not tf.executing_eagerly():
202200
self.evaluate(update)
203201
else:
204202
opt.apply_gradients(
@@ -222,8 +220,7 @@ def testResourceBasic(self):
222220
self.doTestBasic()
223221

224222
def testBasicCallableParams(self):
225-
with context.eager_mode():
226-
self.doTestBasic(use_callable_params=True)
223+
self.doTestBasic(use_callable_params=True)
227224

228225
@test_utils.run_deprecated_v1
229226
def testTensorLearningRate(self):
@@ -243,7 +240,7 @@ def testTensorLearningRate(self):
243240
opt = lazy_adam.LazyAdam(tf.constant(0.001))
244241
update = opt.apply_gradients(
245242
zip([grads0, grads1], [var0, var1]))
246-
self.evaluate(variables.global_variables_initializer())
243+
self.evaluate(tf.compat.v1.global_variables_initializer())
247244

248245
# Fetch params to validate initial values
249246
self.assertAllClose([1.0, 2.0], var0.eval())
@@ -289,7 +286,7 @@ def testSharing(self):
289286
zip([grads0, grads1], [var0, var1]))
290287
update2 = opt.apply_gradients(
291288
zip([grads0, grads1], [var0, var1]))
292-
self.evaluate(variables.global_variables_initializer())
289+
self.evaluate(tf.compat.v1.global_variables_initializer())
293290

294291
beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype)
295292

@@ -320,16 +317,14 @@ def testSharing(self):
320317
self.evaluate(var1))
321318

322319
def testSlotsUniqueEager(self):
323-
with context.eager_mode():
324-
v1 = tf.Variable(1.)
325-
v2 = tf.Variable(1.)
326-
opt = lazy_adam.LazyAdam(1.)
327-
opt.minimize(lambda: v1 + v2, var_list=[v1, v2])
328-
# There should be iteration, and two unique slot variables for v1 and v2.
329-
self.assertEqual(5, len(set(opt.variables())))
330-
self.assertEqual(
331-
self.evaluate(opt.variables()[0]),
332-
self.evaluate(opt.iterations))
320+
v1 = tf.Variable(1.)
321+
v2 = tf.Variable(1.)
322+
opt = lazy_adam.LazyAdam(1.)
323+
opt.minimize(lambda: v1 + v2, var_list=[v1, v2])
324+
# There should be iteration, and two unique slot variables for v1 and v2.
325+
self.assertEqual(5, len(set(opt.variables())))
326+
self.assertEqual(
327+
self.evaluate(opt.variables()[0]), self.evaluate(opt.iterations))
333328

334329

335330
if __name__ == "__main__":

tensorflow_addons/text/skip_gram_ops_test.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import os
2323
import tensorflow as tf
2424

25-
from tensorflow.compat.v1 import test
2625
from tensorflow_addons import text
2726
from tensorflow_addons.text import skip_gram_ops
2827
from tensorflow_addons.utils import test_utils
@@ -368,7 +367,8 @@ def test_filter_input_subsample_vocab(self):
368367

369368
@staticmethod
370369
def _make_text_vocab_freq_file():
371-
filepath = os.path.join(test.get_temp_dir(), "vocab_freq.txt")
370+
filepath = os.path.join(tf.compat.v1.test.get_temp_dir(),
371+
"vocab_freq.txt")
372372
with open(filepath, "w") as f:
373373
writer = csv.writer(f)
374374
writer.writerows([
@@ -382,7 +382,8 @@ def _make_text_vocab_freq_file():
382382

383383
@staticmethod
384384
def _make_text_vocab_float_file():
385-
filepath = os.path.join(test.get_temp_dir(), "vocab_freq_float.txt")
385+
filepath = os.path.join(tf.compat.v1.test.get_temp_dir(),
386+
"vocab_freq_float.txt")
386387
with open(filepath, "w") as f:
387388
writer = csv.writer(f)
388389
writer.writerows([

tensorflow_addons/utils/keras_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
import tensorflow as tf
21+
2022
# TODO: find public API alternative to these
2123
from tensorflow.python.keras.losses import LossFunctionWrapper # pylint: disable=unused-import
22-
from tensorflow.keras.utils import get_custom_objects
2324

2425

2526
def register_keras_custom_object(cls):
26-
get_custom_objects()[cls.__name__] = cls
27+
tf.keras.utils.get_custom_objects()[cls.__name__] = cls
2728
return cls

0 commit comments

Comments
 (0)