From d1336aa74be98d4bde9ccc01ce92ec65291201e6 Mon Sep 17 00:00:00 2001 From: Reed Date: Mon, 10 Apr 2023 20:14:47 -0700 Subject: [PATCH 1/2] Fix Keras imports. --- .../optimizers/discriminative_layer_training.py | 10 ++++++++-- tensorflow_addons/utils/test_utils.py | 7 ++++++- tensorflow_addons/utils/types.py | 7 ++++++- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/tensorflow_addons/optimizers/discriminative_layer_training.py b/tensorflow_addons/optimizers/discriminative_layer_training.py index a0c3e49409..f0f6072868 100644 --- a/tensorflow_addons/optimizers/discriminative_layer_training.py +++ b/tensorflow_addons/optimizers/discriminative_layer_training.py @@ -21,8 +21,14 @@ from tensorflow_addons.optimizers import KerasLegacyOptimizer from typeguard import typechecked -from keras import backend -from keras.utils import tf_utils +try: + # New versions of Keras require importing from `keras.src` when + # importing internal symbols. + from keras.src import backend + from keras.src.utils import tf_utils +except ImportError: + from keras import backend + from keras.utils import tf_utils @tf.keras.utils.register_keras_serializable(package="Addons") diff --git a/tensorflow_addons/utils/test_utils.py b/tensorflow_addons/utils/test_utils.py index 8c1917a17f..01e936ab0a 100644 --- a/tensorflow_addons/utils/test_utils.py +++ b/tensorflow_addons/utils/test_utils.py @@ -27,7 +27,12 @@ from tensorflow_addons.utils import resource_loader if Version(tf.__version__) >= Version("2.9"): - from keras.testing_infra.test_utils import layer_test # noqa: F401 + try: + # New versions of Keras require importing from `keras.src` when + # importing internal symbols. + from keras.src.testing_infra.test_utils import layer_test # noqa: F401 + except ImportError: + from keras.testing_infra.test_utils import layer_test # noqa: F401 else: from keras.testing_utils import layer_test # noqa: F401 diff --git a/tensorflow_addons/utils/types.py b/tensorflow_addons/utils/types.py index 4bfa0dacf6..c4ed6646a4 100644 --- a/tensorflow_addons/utils/types.py +++ b/tensorflow_addons/utils/types.py @@ -22,7 +22,12 @@ # TODO: Remove once https://github.com/tensorflow/tensorflow/issues/44613 is resolved if tf.__version__[:3] > "2.5": - from keras.engine import keras_tensor + try: + # New versions of Keras require importing from `keras.src` when + # importing internal symbols. + from keras.src.engine import keras_tensor + except ImportError: + from keras.engine import keras_tensor else: from tensorflow.python.keras.engine import keras_tensor From a3fa15d1fc8eed111e0b7c1556f9145d8af91807 Mon Sep 17 00:00:00 2001 From: Reed Date: Mon, 17 Apr 2023 18:58:43 -0700 Subject: [PATCH 2/2] Check TF version instead of importing in try-catch. --- .../optimizers/discriminative_layer_training.py | 5 +++-- tensorflow_addons/utils/test_utils.py | 13 ++++++------- tensorflow_addons/utils/types.py | 15 ++++++++------- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/tensorflow_addons/optimizers/discriminative_layer_training.py b/tensorflow_addons/optimizers/discriminative_layer_training.py index f0f6072868..a82f1b2d3e 100644 --- a/tensorflow_addons/optimizers/discriminative_layer_training.py +++ b/tensorflow_addons/optimizers/discriminative_layer_training.py @@ -18,15 +18,16 @@ import tensorflow as tf +from packaging.version import Version from tensorflow_addons.optimizers import KerasLegacyOptimizer from typeguard import typechecked -try: +if Version(tf.__version__).release >= Version("2.13").release: # New versions of Keras require importing from `keras.src` when # importing internal symbols. from keras.src import backend from keras.src.utils import tf_utils -except ImportError: +else: from keras import backend from keras.utils import tf_utils diff --git a/tensorflow_addons/utils/test_utils.py b/tensorflow_addons/utils/test_utils.py index 01e936ab0a..f998fb4a45 100644 --- a/tensorflow_addons/utils/test_utils.py +++ b/tensorflow_addons/utils/test_utils.py @@ -26,13 +26,12 @@ from tensorflow_addons import options from tensorflow_addons.utils import resource_loader -if Version(tf.__version__) >= Version("2.9"): - try: - # New versions of Keras require importing from `keras.src` when - # importing internal symbols. - from keras.src.testing_infra.test_utils import layer_test # noqa: F401 - except ImportError: - from keras.testing_infra.test_utils import layer_test # noqa: F401 +if Version(tf.__version__).release >= Version("2.13").release: + # New versions of Keras require importing from `keras.src` when + # importing internal symbols. + from keras.src.testing_infra.test_utils import layer_test # noqa: F401 +elif Version(tf.__version__) >= Version("2.9"): + from keras.testing_infra.test_utils import layer_test # noqa: F401 else: from keras.testing_utils import layer_test # noqa: F401 diff --git a/tensorflow_addons/utils/types.py b/tensorflow_addons/utils/types.py index c4ed6646a4..de8da2a5dd 100644 --- a/tensorflow_addons/utils/types.py +++ b/tensorflow_addons/utils/types.py @@ -20,14 +20,15 @@ import numpy as np import tensorflow as tf +from packaging.version import Version + # TODO: Remove once https://github.com/tensorflow/tensorflow/issues/44613 is resolved -if tf.__version__[:3] > "2.5": - try: - # New versions of Keras require importing from `keras.src` when - # importing internal symbols. - from keras.src.engine import keras_tensor - except ImportError: - from keras.engine import keras_tensor +if Version(tf.__version__).release >= Version("2.13").release: + # New versions of Keras require importing from `keras.src` when + # importing internal symbols. + from keras.src.engine import keras_tensor +elif Version(tf.__version__).release >= Version("2.5").release: + from keras.engine import keras_tensor else: from tensorflow.python.keras.engine import keras_tensor