|
| 1 | +import glob |
| 2 | +import os |
| 3 | +from pathlib import Path |
| 4 | + |
| 5 | +import tensorflow as tf |
| 6 | + |
| 7 | +from tensorflow_addons.utils.resource_loader import get_project_root |
| 8 | + |
| 9 | + |
| 10 | +def register_all(keras_objects: bool = True, custom_kernels: bool = True) -> None: |
| 11 | + """Register TensorFlow Addons' objects in TensorFlow global dictionaries. |
| 12 | +
|
| 13 | + When loading a Keras model that has a TF Addons' function, it is needed |
| 14 | + for this function to be known by the Keras deserialization process. |
| 15 | +
|
| 16 | + There are two ways to do this, either do |
| 17 | +
|
| 18 | + ```python |
| 19 | + tf.keras.models.load_model( |
| 20 | + "my_model.tf", |
| 21 | + custom_objects={"LAMB": tfa.image.optimizer.LAMB} |
| 22 | + ) |
| 23 | + ``` |
| 24 | +
|
| 25 | + or you can do: |
| 26 | + ```python |
| 27 | + tfa.register_all() |
| 28 | + tf.tf.keras.models.load_model("my_model.tf") |
| 29 | + ``` |
| 30 | +
|
| 31 | + If the model contains custom ops (compiled ops) of TensorFlow Addons, |
| 32 | + and the graph is loaded with `tf.saved_model.load`, then custom ops need |
| 33 | + to be registered before to avoid an error of the type: |
| 34 | +
|
| 35 | + ``` |
| 36 | + tensorflow.python.framework.errors_impl.NotFoundError: Op type not registered |
| 37 | + '...' in binary running on ... Make sure the Op and Kernel are |
| 38 | + registered in the binary running in this process. |
| 39 | + ``` |
| 40 | +
|
| 41 | + In this case, the only way to make sure that the ops are registered is to call |
| 42 | + this function: |
| 43 | +
|
| 44 | + ```python |
| 45 | + tfa.register_all() |
| 46 | + tf.saved_model.load("my_model.tf") |
| 47 | + ``` |
| 48 | +
|
| 49 | + Note that you can call this function multiple times in the same process, |
| 50 | + it only has an effect the first time. Afterward, it's just a no-op. |
| 51 | +
|
| 52 | + Args: |
| 53 | + keras_objects: boolean, `True` by default. If `True`, register all |
| 54 | + Keras objects |
| 55 | + with `tf.keras.utils.register_keras_serializable(package="Addons")` |
| 56 | + If set to False, doesn't register any Keras objects |
| 57 | + of Addons in TensorFlow. |
| 58 | + custom_kernels: boolean, `True` by default. If `True`, loads all |
| 59 | + custom kernels of TensorFlow Addons with |
| 60 | + `tf.load_op_library("path/to/so/file.so")`. Loading the SO files |
| 61 | + register them automatically. If `False` doesn't load and register |
| 62 | + the shared objects files. Not that it might be useful to turn it off |
| 63 | + if your installation of Addons doesn't work well with custom ops. |
| 64 | + Returns: |
| 65 | + None |
| 66 | + """ |
| 67 | + if keras_objects: |
| 68 | + register_keras_objects() |
| 69 | + if custom_kernels: |
| 70 | + register_custom_kernels() |
| 71 | + |
| 72 | + |
| 73 | +def register_keras_objects() -> None: |
| 74 | + # TODO: once layer_test is replaced by a public API |
| 75 | + # and we can used unregistered objects with it |
| 76 | + # we can remove all decorators. |
| 77 | + # And register Keras objects here. |
| 78 | + pass |
| 79 | + |
| 80 | + |
| 81 | +def register_custom_kernels() -> None: |
| 82 | + all_shared_objects = _get_all_shared_objects() |
| 83 | + if not all_shared_objects: |
| 84 | + raise FileNotFoundError( |
| 85 | + "No shared objects files were found in the custom ops " |
| 86 | + "directory in Tensorflow Addons, check your installation again," |
| 87 | + "or, if you don't need custom ops, call `tfa.register_all(custom_kernels=False)`" |
| 88 | + " instead." |
| 89 | + ) |
| 90 | + try: |
| 91 | + for shared_object in all_shared_objects: |
| 92 | + tf.load_op_library(shared_object) |
| 93 | + except tf.errors.NotFoundError as e: |
| 94 | + raise RuntimeError( |
| 95 | + "One of the shared objects ({}) could not be loaded. This may be " |
| 96 | + "due to a number of reasons (incompatible TensorFlow version, buiding from " |
| 97 | + "source with different flags, broken install of TensorFlow Addons...). If you" |
| 98 | + "wanted to register the shared objects because you needed them when loading your " |
| 99 | + "model, you should fix your install of TensorFlow Addons. If you don't " |
| 100 | + "use custom ops in your model, you can skip registering custom ops with " |
| 101 | + "`tfa.register_all(custom_kernels=False)`".format(shared_object) |
| 102 | + ) from e |
| 103 | + |
| 104 | + |
| 105 | +def _get_all_shared_objects(): |
| 106 | + custom_ops_dir = os.path.join(get_project_root(), "custom_ops") |
| 107 | + all_shared_objects = glob.glob(custom_ops_dir + "/**/*.so", recursive=True) |
| 108 | + all_shared_objects = [x for x in all_shared_objects if Path(x).is_file()] |
| 109 | + return all_shared_objects |
0 commit comments