diff --git a/tensorflow_addons/optimizers/discriminative_layer_training.py b/tensorflow_addons/optimizers/discriminative_layer_training.py index a0c3e49409..a82f1b2d3e 100644 --- a/tensorflow_addons/optimizers/discriminative_layer_training.py +++ b/tensorflow_addons/optimizers/discriminative_layer_training.py @@ -18,11 +18,18 @@ import tensorflow as tf +from packaging.version import Version from tensorflow_addons.optimizers import KerasLegacyOptimizer from typeguard import typechecked -from keras import backend -from keras.utils import tf_utils +if Version(tf.__version__).release >= Version("2.13").release: + # New versions of Keras require importing from `keras.src` when + # importing internal symbols. + from keras.src import backend + from keras.src.utils import tf_utils +else: + from keras import backend + from keras.utils import tf_utils @tf.keras.utils.register_keras_serializable(package="Addons") diff --git a/tensorflow_addons/utils/test_utils.py b/tensorflow_addons/utils/test_utils.py index 8c1917a17f..f998fb4a45 100644 --- a/tensorflow_addons/utils/test_utils.py +++ b/tensorflow_addons/utils/test_utils.py @@ -26,7 +26,11 @@ from tensorflow_addons import options from tensorflow_addons.utils import resource_loader -if Version(tf.__version__) >= Version("2.9"): +if Version(tf.__version__).release >= Version("2.13").release: + # New versions of Keras require importing from `keras.src` when + # importing internal symbols. + from keras.src.testing_infra.test_utils import layer_test # noqa: F401 +elif Version(tf.__version__) >= Version("2.9"): from keras.testing_infra.test_utils import layer_test # noqa: F401 else: from keras.testing_utils import layer_test # noqa: F401 diff --git a/tensorflow_addons/utils/types.py b/tensorflow_addons/utils/types.py index 4bfa0dacf6..de8da2a5dd 100644 --- a/tensorflow_addons/utils/types.py +++ b/tensorflow_addons/utils/types.py @@ -20,8 +20,14 @@ import numpy as np import tensorflow as tf +from packaging.version import Version + # TODO: Remove once https://github.com/tensorflow/tensorflow/issues/44613 is resolved -if tf.__version__[:3] > "2.5": +if Version(tf.__version__).release >= Version("2.13").release: + # New versions of Keras require importing from `keras.src` when + # importing internal symbols. + from keras.src.engine import keras_tensor +elif Version(tf.__version__).release >= Version("2.5").release: from keras.engine import keras_tensor else: from tensorflow.python.keras.engine import keras_tensor