Skip to content

Commit 1458f7f

Browse files
Public function to register custom ops (#1193)
* Added public functions to register everything. * Removed decorator * Revert "Removed decorator" This reverts commit ebea5bd. * Added some tests. * Added the two register. * Removed unused variables. * Private func. * Explicit modules. * FLake8 * Added documentation. * Remove useless setup method. * Black/ * Format BUILD.
1 parent 21d0574 commit 1458f7f

File tree

5 files changed

+157
-9
lines changed

5 files changed

+157
-9
lines changed

tensorflow_addons/BUILD

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ py_library(
1212
data = [
1313
"__init__.py",
1414
"options.py",
15+
"register.py",
1516
"version.py",
1617
],
1718
deps = [
@@ -25,5 +26,18 @@ py_library(
2526
"//tensorflow_addons/rnn",
2627
"//tensorflow_addons/seq2seq",
2728
"//tensorflow_addons/text",
29+
"//tensorflow_addons/utils",
30+
],
31+
)
32+
33+
py_test(
34+
name = "register_test",
35+
size = "small",
36+
srcs = [
37+
"register_test.py",
38+
],
39+
main = "register_test.py",
40+
deps = [
41+
":tensorflow_addons",
2842
],
2943
)

tensorflow_addons/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,6 @@
2828
from tensorflow_addons import rnn
2929
from tensorflow_addons import seq2seq
3030
from tensorflow_addons import text
31+
from tensorflow_addons.register import register_all
3132

3233
from tensorflow_addons.version import __version__

tensorflow_addons/register.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import glob
2+
import os
3+
from pathlib import Path
4+
5+
import tensorflow as tf
6+
7+
from tensorflow_addons.utils.resource_loader import get_project_root
8+
9+
10+
def register_all(keras_objects: bool = True, custom_kernels: bool = True) -> None:
11+
"""Register TensorFlow Addons' objects in TensorFlow global dictionaries.
12+
13+
When loading a Keras model that has a TF Addons' function, it is needed
14+
for this function to be known by the Keras deserialization process.
15+
16+
There are two ways to do this, either do
17+
18+
```python
19+
tf.keras.models.load_model(
20+
"my_model.tf",
21+
custom_objects={"LAMB": tfa.image.optimizer.LAMB}
22+
)
23+
```
24+
25+
or you can do:
26+
```python
27+
tfa.register_all()
28+
tf.tf.keras.models.load_model("my_model.tf")
29+
```
30+
31+
If the model contains custom ops (compiled ops) of TensorFlow Addons,
32+
and the graph is loaded with `tf.saved_model.load`, then custom ops need
33+
to be registered before to avoid an error of the type:
34+
35+
```
36+
tensorflow.python.framework.errors_impl.NotFoundError: Op type not registered
37+
'...' in binary running on ... Make sure the Op and Kernel are
38+
registered in the binary running in this process.
39+
```
40+
41+
In this case, the only way to make sure that the ops are registered is to call
42+
this function:
43+
44+
```python
45+
tfa.register_all()
46+
tf.saved_model.load("my_model.tf")
47+
```
48+
49+
Note that you can call this function multiple times in the same process,
50+
it only has an effect the first time. Afterward, it's just a no-op.
51+
52+
Args:
53+
keras_objects: boolean, `True` by default. If `True`, register all
54+
Keras objects
55+
with `tf.keras.utils.register_keras_serializable(package="Addons")`
56+
If set to False, doesn't register any Keras objects
57+
of Addons in TensorFlow.
58+
custom_kernels: boolean, `True` by default. If `True`, loads all
59+
custom kernels of TensorFlow Addons with
60+
`tf.load_op_library("path/to/so/file.so")`. Loading the SO files
61+
register them automatically. If `False` doesn't load and register
62+
the shared objects files. Not that it might be useful to turn it off
63+
if your installation of Addons doesn't work well with custom ops.
64+
Returns:
65+
None
66+
"""
67+
if keras_objects:
68+
register_keras_objects()
69+
if custom_kernels:
70+
register_custom_kernels()
71+
72+
73+
def register_keras_objects() -> None:
74+
# TODO: once layer_test is replaced by a public API
75+
# and we can used unregistered objects with it
76+
# we can remove all decorators.
77+
# And register Keras objects here.
78+
pass
79+
80+
81+
def register_custom_kernels() -> None:
82+
all_shared_objects = _get_all_shared_objects()
83+
if not all_shared_objects:
84+
raise FileNotFoundError(
85+
"No shared objects files were found in the custom ops "
86+
"directory in Tensorflow Addons, check your installation again,"
87+
"or, if you don't need custom ops, call `tfa.register_all(custom_kernels=False)`"
88+
" instead."
89+
)
90+
try:
91+
for shared_object in all_shared_objects:
92+
tf.load_op_library(shared_object)
93+
except tf.errors.NotFoundError as e:
94+
raise RuntimeError(
95+
"One of the shared objects ({}) could not be loaded. This may be "
96+
"due to a number of reasons (incompatible TensorFlow version, buiding from "
97+
"source with different flags, broken install of TensorFlow Addons...). If you"
98+
"wanted to register the shared objects because you needed them when loading your "
99+
"model, you should fix your install of TensorFlow Addons. If you don't "
100+
"use custom ops in your model, you can skip registering custom ops with "
101+
"`tfa.register_all(custom_kernels=False)`".format(shared_object)
102+
) from e
103+
104+
105+
def _get_all_shared_objects():
106+
custom_ops_dir = os.path.join(get_project_root(), "custom_ops")
107+
all_shared_objects = glob.glob(custom_ops_dir + "/**/*.so", recursive=True)
108+
all_shared_objects = [x for x in all_shared_objects if Path(x).is_file()]
109+
return all_shared_objects

tensorflow_addons/register_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import unittest
2+
import tensorflow as tf
3+
from tensorflow_addons.register import register_all, _get_all_shared_objects
4+
5+
6+
class AssertRNNCellTest(unittest.TestCase):
7+
def test_multiple_register(self):
8+
register_all()
9+
register_all()
10+
11+
def test_get_all_shared_objects(self):
12+
all_shared_objects = _get_all_shared_objects()
13+
self.assertTrue(len(all_shared_objects) >= 4)
14+
15+
for file in all_shared_objects:
16+
tf.load_op_library(file)
17+
18+
19+
if __name__ == "__main__":
20+
unittest.main()

tools/ci_build/verify/check_typing_info.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,9 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
#
16-
17-
from types import ModuleType
18-
1916
from typedapi import ensure_api_is_typed
2017

21-
import tensorflow_addons
18+
import tensorflow_addons as tfa
2219

2320
TUTORIAL_URL = "https://docs.python.org/3/library/typing.html"
2421
HELP_MESSAGE = (
@@ -30,11 +27,18 @@
3027
EXCEPTION_LIST = []
3128

3229

33-
modules_list = []
34-
for attr_name in dir(tensorflow_addons):
35-
attr = getattr(tensorflow_addons, attr_name)
36-
if isinstance(attr, ModuleType) and attr is not tensorflow_addons.options:
37-
modules_list.append(attr)
30+
modules_list = [
31+
tfa,
32+
tfa.activations,
33+
tfa.callbacks,
34+
tfa.image,
35+
tfa.losses,
36+
tfa.metrics,
37+
tfa.optimizers,
38+
tfa.rnn,
39+
tfa.seq2seq,
40+
tfa.text,
41+
]
3842

3943

4044
if __name__ == "__main__":

0 commit comments

Comments
 (0)