diff --git a/tensorflow_addons/BUILD b/tensorflow_addons/BUILD index cfc5ddd659..24033291e0 100644 --- a/tensorflow_addons/BUILD +++ b/tensorflow_addons/BUILD @@ -12,6 +12,7 @@ py_library( data = [ "__init__.py", "options.py", + "register.py", "version.py", ], deps = [ @@ -25,5 +26,18 @@ py_library( "//tensorflow_addons/rnn", "//tensorflow_addons/seq2seq", "//tensorflow_addons/text", + "//tensorflow_addons/utils", + ], +) + +py_test( + name = "register_test", + size = "small", + srcs = [ + "register_test.py", + ], + main = "register_test.py", + deps = [ + ":tensorflow_addons", ], ) diff --git a/tensorflow_addons/__init__.py b/tensorflow_addons/__init__.py index c4706b0d87..d5527ecb6e 100644 --- a/tensorflow_addons/__init__.py +++ b/tensorflow_addons/__init__.py @@ -28,5 +28,6 @@ from tensorflow_addons import rnn from tensorflow_addons import seq2seq from tensorflow_addons import text +from tensorflow_addons.register import register_all from tensorflow_addons.version import __version__ diff --git a/tensorflow_addons/register.py b/tensorflow_addons/register.py new file mode 100644 index 0000000000..9c76d53994 --- /dev/null +++ b/tensorflow_addons/register.py @@ -0,0 +1,109 @@ +import glob +import os +from pathlib import Path + +import tensorflow as tf + +from tensorflow_addons.utils.resource_loader import get_project_root + + +def register_all(keras_objects: bool = True, custom_kernels: bool = True) -> None: + """Register TensorFlow Addons' objects in TensorFlow global dictionaries. + + When loading a Keras model that has a TF Addons' function, it is needed + for this function to be known by the Keras deserialization process. + + There are two ways to do this, either do + + ```python + tf.keras.models.load_model( + "my_model.tf", + custom_objects={"LAMB": tfa.image.optimizer.LAMB} + ) + ``` + + or you can do: + ```python + tfa.register_all() + tf.tf.keras.models.load_model("my_model.tf") + ``` + + If the model contains custom ops (compiled ops) of TensorFlow Addons, + and the graph is loaded with `tf.saved_model.load`, then custom ops need + to be registered before to avoid an error of the type: + + ``` + tensorflow.python.framework.errors_impl.NotFoundError: Op type not registered + '...' in binary running on ... Make sure the Op and Kernel are + registered in the binary running in this process. + ``` + + In this case, the only way to make sure that the ops are registered is to call + this function: + + ```python + tfa.register_all() + tf.saved_model.load("my_model.tf") + ``` + + Note that you can call this function multiple times in the same process, + it only has an effect the first time. Afterward, it's just a no-op. + + Args: + keras_objects: boolean, `True` by default. If `True`, register all + Keras objects + with `tf.keras.utils.register_keras_serializable(package="Addons")` + If set to False, doesn't register any Keras objects + of Addons in TensorFlow. + custom_kernels: boolean, `True` by default. If `True`, loads all + custom kernels of TensorFlow Addons with + `tf.load_op_library("path/to/so/file.so")`. Loading the SO files + register them automatically. If `False` doesn't load and register + the shared objects files. Not that it might be useful to turn it off + if your installation of Addons doesn't work well with custom ops. + Returns: + None + """ + if keras_objects: + register_keras_objects() + if custom_kernels: + register_custom_kernels() + + +def register_keras_objects() -> None: + # TODO: once layer_test is replaced by a public API + # and we can used unregistered objects with it + # we can remove all decorators. + # And register Keras objects here. + pass + + +def register_custom_kernels() -> None: + all_shared_objects = _get_all_shared_objects() + if not all_shared_objects: + raise FileNotFoundError( + "No shared objects files were found in the custom ops " + "directory in Tensorflow Addons, check your installation again," + "or, if you don't need custom ops, call `tfa.register_all(custom_kernels=False)`" + " instead." + ) + try: + for shared_object in all_shared_objects: + tf.load_op_library(shared_object) + except tf.errors.NotFoundError as e: + raise RuntimeError( + "One of the shared objects ({}) could not be loaded. This may be " + "due to a number of reasons (incompatible TensorFlow version, buiding from " + "source with different flags, broken install of TensorFlow Addons...). If you" + "wanted to register the shared objects because you needed them when loading your " + "model, you should fix your install of TensorFlow Addons. If you don't " + "use custom ops in your model, you can skip registering custom ops with " + "`tfa.register_all(custom_kernels=False)`".format(shared_object) + ) from e + + +def _get_all_shared_objects(): + custom_ops_dir = os.path.join(get_project_root(), "custom_ops") + all_shared_objects = glob.glob(custom_ops_dir + "/**/*.so", recursive=True) + all_shared_objects = [x for x in all_shared_objects if Path(x).is_file()] + return all_shared_objects diff --git a/tensorflow_addons/register_test.py b/tensorflow_addons/register_test.py new file mode 100644 index 0000000000..248e6a24c2 --- /dev/null +++ b/tensorflow_addons/register_test.py @@ -0,0 +1,20 @@ +import unittest +import tensorflow as tf +from tensorflow_addons.register import register_all, _get_all_shared_objects + + +class AssertRNNCellTest(unittest.TestCase): + def test_multiple_register(self): + register_all() + register_all() + + def test_get_all_shared_objects(self): + all_shared_objects = _get_all_shared_objects() + self.assertTrue(len(all_shared_objects) >= 4) + + for file in all_shared_objects: + tf.load_op_library(file) + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/ci_build/verify/check_typing_info.py b/tools/ci_build/verify/check_typing_info.py index 8e998be09a..7aea9b2c13 100644 --- a/tools/ci_build/verify/check_typing_info.py +++ b/tools/ci_build/verify/check_typing_info.py @@ -13,12 +13,9 @@ # limitations under the License. # ============================================================================== # - -from types import ModuleType - from typedapi import ensure_api_is_typed -import tensorflow_addons +import tensorflow_addons as tfa TUTORIAL_URL = "https://docs.python.org/3/library/typing.html" HELP_MESSAGE = ( @@ -30,11 +27,18 @@ EXCEPTION_LIST = [] -modules_list = [] -for attr_name in dir(tensorflow_addons): - attr = getattr(tensorflow_addons, attr_name) - if isinstance(attr, ModuleType) and attr is not tensorflow_addons.options: - modules_list.append(attr) +modules_list = [ + tfa, + tfa.activations, + tfa.callbacks, + tfa.image, + tfa.losses, + tfa.metrics, + tfa.optimizers, + tfa.rnn, + tfa.seq2seq, + tfa.text, +] if __name__ == "__main__":