
Commit 2029487

Added an option to use the pure python implementation. (tensorflow#1137)
1 parent d3b2694 commit 2029487

7 files changed: +109, -16 lines

7 files changed

+109
-16
lines changed

README.md

Lines changed: 37 additions & 7 deletions
@@ -14,10 +14,10 @@
 
 | Build Type | Status |
 | --- | --- |
-| **MacOS CPU** | [![Status](https://github.com/tensorflow/addons/workflows/macos-nightly/badge.svg)](https://github.com/tensorflow/addons/actions?query=workflow%3Amacos-nightly) |
-| **Windows CPU** | [![Status](https://github.com/tensorflow/addons/workflows/windows-nightly/badge.svg)](https://github.com/tensorflow/addons/actions?query=workflow%3Awindows-nightly) |
-| **Ubuntu CPU** | [![Status](https://github.com/tensorflow/addons/workflows/manylinux-nightly/badge.svg)](https://github.com/tensorflow/addons/actions?query=workflow%3Amanylinux-nightly) |
-| **Ubuntu GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/addons/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/addons/ubuntu-gpu-py3.html) |
+| **MacOS** | [![Status](https://github.com/tensorflow/addons/workflows/macos-nightly/badge.svg)](https://github.com/tensorflow/addons/actions?query=workflow%3Amacos-nightly) |
+| **Windows** | [![Status](https://github.com/tensorflow/addons/workflows/windows-nightly/badge.svg)](https://github.com/tensorflow/addons/actions?query=workflow%3Awindows-nightly) |
+| **Ubuntu** | [![Status](https://github.com/tensorflow/addons/workflows/manylinux-nightly/badge.svg)](https://github.com/tensorflow/addons/actions?query=workflow%3Amanylinux-nightly) |
+| **Ubuntu custom GPU ops** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/addons/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/addons/ubuntu-gpu-py3.html) |
 
 **TensorFlow Addons** is a repository of contributions that conform to
 well-established API patterns, but implement new functionality
@@ -113,9 +113,39 @@ TF-Addons. In order to achieve these we require that our additions
 conform to established API patterns seen in core TensorFlow.
 
 #### GPU/CPU Custom-Ops
-A major benefit of TensorFlow Addons is that there are precompiled ops. Should
-a CUDA 10.1 installation not be found then the op will automatically fall back to
-a CPU implementation.
+A major benefit of TensorFlow Addons is that there are precompiled ops for CPU/GPU.
+Currently, however, GPU custom ops only work on Linux distributions. For this reason, Windows and MacOS will fall back to pure TensorFlow Python implementations whenever possible.
+
+The order of priority on MacOS/Windows:
+1) Pure TensorFlow + Python implementation (works on CPU and GPU)
+2) C++ implementation for CPU
+
+The order of priority on Linux:
+1) CUDA implementation
+2) C++ implementation
+3) Pure TensorFlow + Python implementation (works on CPU and GPU)
+
+If you want to change the default priority ("C++ and CUDA" vs. "pure TF Python"),
+you can set the variable `TF_ADDONS_PY_OPS` either from the command line or in
+your code.
+
+For example, if you're on Linux, have compatibility problems with the compiled ops,
+and want to give priority to the Python implementation,
+you can do the following.
+
+From the command line:
+```bash
+export TF_ADDONS_PY_OPS=1
+```
+
+or in your code:
+
+```python
+import tensorflow_addons as tfa
+tfa.options.TF_ADDONS_PY_OPS = True
+```
+
+This variable defaults to `True` on Windows and MacOS, and to `False` on Linux.
 
 #### Proxy Maintainership
 Addons has been designed to compartmentalize subpackages and submodules so
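
The flag described above is resolved once, at import time, from the `TF_ADDONS_PY_OPS` environment variable, with a platform-dependent default (see `tensorflow_addons/options.py` below). Here is a minimal sketch of how you could inspect the resolved value on your machine; this script is illustrative and not part of the commit:

```python
# Illustrative only: check which implementation path Addons will prefer here.
import platform

import tensorflow_addons as tfa

print("Platform:", platform.system())  # e.g. "Linux", "Darwin", "Windows"
print("Prefer pure-Python ops:", tfa.options.TF_ADDONS_PY_OPS)
# Per this commit the default is False on Linux and True elsewhere, unless the
# TF_ADDONS_PY_OPS environment variable was set before the interpreter started.
```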

tensorflow_addons/BUILD

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ py_library(
     name = "tensorflow_addons",
     data = [
         "__init__.py",
+        "options.py",
         "version.py",
     ],
     deps = [

tensorflow_addons/activations/BUILD

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ py_library(
         "tanhshrink.py",
     ],
     data = [
+        "//tensorflow_addons:options.py",
         "//tensorflow_addons/custom_ops/activations:_activation_ops.so",
         "//tensorflow_addons/utils",
     ],

tensorflow_addons/activations/hardshrink.py

Lines changed: 13 additions & 1 deletion
@@ -18,6 +18,7 @@
 
 from tensorflow_addons.utils import types
 from tensorflow_addons.utils.resource_loader import LazySO
+from tensorflow_addons import options
 
 _activation_so = LazySO("custom_ops/activations/_activation_ops.so")
 
@@ -40,6 +41,18 @@ def hardshrink(
         A `Tensor`. Has the same type as `x`.
     """
     x = tf.convert_to_tensor(x)
+
+    if not options.TF_ADDONS_PY_OPS:
+        try:
+            return _hardshrink_custom_op(x, lower, upper)
+        except tf.errors.NotFoundError:
+            options.warn_fallback("hardshrink")
+
+    return _hardshrink_py(x, lower, upper)
+
+
+def _hardshrink_custom_op(x, lower=-0.5, upper=0.5):
+    """Alias with lazy loading of the .so file"""
     return _activation_so.ops.addons_hardshrink(x, lower, upper)
 
 
@@ -59,7 +72,6 @@ def _hardshrink_py(
             " not be higher than the value "
             "variable upper, which is {} .".format(lower, upper)
         )
-    x = tf.convert_to_tensor(x)
     mask_lower = x < lower
     mask_upper = upper < x
     mask = tf.logical_or(mask_lower, mask_upper)
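
The added dispatch in `hardshrink` tries the compiled kernel first (unless `TF_ADDONS_PY_OPS` is set) and falls back to the pure-Python variant if the shared object cannot be loaded. The diff only shows a fragment of `_hardshrink_py`; below is a hedged, self-contained sketch of that fallback reconstructed from the visible fragments. The final `tf.where` step and the function/error-message wording are assumptions, and the real code may differ:

```python
# A hedged sketch of the pure-TensorFlow hardshrink fallback, reconstructed
# from the fragments visible in this diff.
import tensorflow as tf


def hardshrink_py_sketch(x, lower=-0.5, upper=0.5):
    """Hardshrink: keep x where it lies outside [lower, upper], zero it inside."""
    if lower > upper:
        raise ValueError(
            "The value of lower, {}, should not be higher than the value "
            "of upper, which is {}.".format(lower, upper)
        )
    x = tf.convert_to_tensor(x)
    mask_lower = x < lower
    mask_upper = upper < x
    mask = tf.logical_or(mask_lower, mask_upper)
    return tf.where(mask, x, tf.zeros_like(x))  # assumed final step


print(hardshrink_py_sketch([-2.0, -0.5, 0.0, 0.5, 2.0]).numpy())
# Matches the expectation in the test below: [-2.  0.  0.  0.  2.]
```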

tensorflow_addons/activations/hardshrink_test.py

Lines changed: 7 additions & 7 deletions
@@ -17,16 +17,16 @@
 
 import numpy as np
 import tensorflow as tf
-from tensorflow_addons.activations import hardshrink
-from tensorflow_addons.utils import test_utils
+from tensorflow_addons.activations.hardshrink import _hardshrink_custom_op
 from tensorflow_addons.activations.hardshrink import _hardshrink_py
+from tensorflow_addons.utils import test_utils
 
 
 @test_utils.run_all_in_graph_and_eager_modes
 class HardshrinkTest(tf.test.TestCase, parameterized.TestCase):
     def test_invalid(self):
         with self.assertRaisesOpError("lower must be less than or equal to upper."):
-            y = hardshrink(tf.ones(shape=(1, 2, 3)), lower=2.0, upper=-2.0)
+            y = _hardshrink_custom_op(tf.ones(shape=(1, 2, 3)), lower=2.0, upper=-2.0)
             self.evaluate(y)
 
     @parameterized.named_parameters(
@@ -35,11 +35,11 @@ def test_invalid(self):
     def test_hardshrink(self, dtype):
         x = tf.constant([-2.0, -0.5, 0.0, 0.5, 2.0], dtype=dtype)
         expected_result = tf.constant([-2.0, 0.0, 0.0, 0.0, 2.0], dtype=dtype)
-        self.assertAllCloseAccordingToType(hardshrink(x), expected_result)
+        self.assertAllCloseAccordingToType(_hardshrink_custom_op(x), expected_result)
 
         expected_result = tf.constant([-2.0, 0.0, 0.0, 0.0, 2.0], dtype=dtype)
         self.assertAllCloseAccordingToType(
-            hardshrink(x, lower=-1.0, upper=1.0), expected_result
+            _hardshrink_custom_op(x, lower=-1.0, upper=1.0), expected_result
         )
 
     @parameterized.named_parameters(("float32", np.float32), ("float64", np.float64))
@@ -51,7 +51,7 @@ def test_theoretical_gradients(self, dtype):
         # Avoid these two points to make gradients smooth.
         x = tf.constant([-2.0, -1.5, 0.0, 1.5, 2.0], dtype=dtype)
 
-        theoretical, numerical = tf.test.compute_gradient(hardshrink, [x])
+        theoretical, numerical = tf.test.compute_gradient(_hardshrink_custom_op, [x])
         self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)
 
     @parameterized.named_parameters(("float32", np.float32), ("float64", np.float64))
@@ -68,7 +68,7 @@ def verify_funcs_are_equivalent(self, dtype):
 
         with tf.GradientTape(persistent=True) as t:
             t.watch(x)
-            y_native = hardshrink(x, lower, upper)
+            y_native = _hardshrink_custom_op(x, lower, upper)
             y_py = _hardshrink_py(x, lower, upper)
 
         self.assertAllCloseAccordingToType(y_native, y_py)

tensorflow_addons/options.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+import os
+import platform
+import warnings
+import traceback
+
+try:
+    TF_ADDONS_PY_OPS = bool(int(os.environ["TF_ADDONS_PY_OPS"]))
+except KeyError:
+    if platform.system() == "Linux":
+        TF_ADDONS_PY_OPS = False
+    else:
+        TF_ADDONS_PY_OPS = True
+
+
+FALLBACK_WARNING_TEMPLATE = """{}
+
+The {} C++/CUDA custom op could not be loaded.
+For this reason, Addons will fall back to an implementation written
+in Python with public TensorFlow ops. The worst you might experience with
+this is a moderate slowdown on GPU. There can be multiple
+reasons for this loading error; one of them may be an ABI incompatibility between
+the TensorFlow installed on your system and the TensorFlow used to compile
+TensorFlow Addons' custom ops. The stacktrace generated when loading the
+shared object file was displayed above.
+
+If you want this warning to disappear, either make sure the TensorFlow installed
+is compatible with this version of Addons, or tell TensorFlow Addons to
+prefer using Python implementations and not custom C++/CUDA ones. You can do that
+by changing the TF_ADDONS_PY_OPS flag
+either with the environment variable:
+```bash
+TF_ADDONS_PY_OPS=1 python my_script.py
+```
+or in your code, after your imports:
+```python
+import tensorflow_addons as tfa
+import ...
+import ...
+
+tfa.options.TF_ADDONS_PY_OPS = True
+```
+"""
+
+
+def warn_fallback(op_name):
+    warning_msg = FALLBACK_WARNING_TEMPLATE.format(traceback.format_exc(), op_name)
+    warnings.warn(warning_msg, RuntimeWarning)
+    global TF_ADDONS_PY_OPS
+    TF_ADDONS_PY_OPS = True
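
`warn_fallback` is meant to be called from inside the `except` block that caught the `.so` loading error, so that `traceback.format_exc()` has something to report; it then flips the module-level flag so subsequent ops skip the custom-op path. A purely illustrative demonstration follows (not part of the commit; the op name and the simulated error are made up):

```python
# Illustrative only: simulate a failed custom-op load and observe the fallback.
import tensorflow as tf
from tensorflow_addons import options

try:
    # Stand-in for the error raised when a custom-op shared object cannot be loaded.
    raise tf.errors.NotFoundError(None, None, "simulated shared-object load failure")
except tf.errors.NotFoundError:
    options.warn_fallback("example_op")  # emits a RuntimeWarning containing the traceback

print(options.TF_ADDONS_PY_OPS)  # True from here on: later calls use the Python ops
```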

tools/ci_build/verify/check_typing_info.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@
 modules_list = []
 for attr_name in dir(tensorflow_addons):
     attr = getattr(tensorflow_addons, attr_name)
-    if isinstance(attr, ModuleType):
+    if isinstance(attr, ModuleType) and attr is not tensorflow_addons.options:
         modules_list.append(attr)
 
 