Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions tensorflow_addons/register.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import glob
import os
from pathlib import Path

import tensorflow as tf

from tensorflow_addons.utils.resource_loader import get_project_root
from tensorflow_addons.utils.resource_loader import get_path_to_datafile


def register_all(keras_objects: bool = True, custom_kernels: bool = True) -> None:
Expand Down Expand Up @@ -103,7 +102,7 @@ def register_custom_kernels() -> None:


def _get_all_shared_objects():
custom_ops_dir = os.path.join(get_project_root(), "custom_ops")
custom_ops_dir = get_path_to_datafile("custom_ops", is_so=True)
all_shared_objects = glob.glob(custom_ops_dir + "/**/*.so", recursive=True)
all_shared_objects = [x for x in all_shared_objects if Path(x).is_file()]
return all_shared_objects
10 changes: 8 additions & 2 deletions tensorflow_addons/utils/resource_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def get_project_root():
return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


def get_path_to_datafile(path):
def get_path_to_datafile(path, is_so=False):
"""Get the path to the specified file in the data dependencies.

The path is relative to tensorflow_addons/
Expand All @@ -42,6 +42,10 @@ def get_path_to_datafile(path):
The path to the specified data file
"""
root_dir = get_project_root()
if is_so:
bazel_bin_dir = os.path.join(os.path.dirname(root_dir), "bazel-bin")
if os.path.isdir(bazel_bin_dir):
root_dir = os.path.join(bazel_bin_dir, "tensorflow_addons")
return os.path.join(root_dir, path.replace("/", os.sep))


Expand All @@ -61,7 +65,9 @@ def ops(self):
)
if self._ops is None:
self.display_warning_if_incompatible()
self._ops = tf.load_op_library(get_path_to_datafile(self.relative_path))
self._ops = tf.load_op_library(
get_path_to_datafile(self.relative_path, is_so=True)
)
return self._ops

def display_warning_if_incompatible(self):
Expand Down
2 changes: 1 addition & 1 deletion tools/run_gpu_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ docker build \
--build-arg TF_VERSION=2.2.0 \
--build-arg PY_VERSION=3.5 \
-t tfa_gpu_tests ./
docker run --rm -t -v cache_bazel:/root/.cache/bazel --gpus=all tfa_gpu_tests
docker run --rm -t --gpus=all tfa_gpu_tests
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is working so I don't know if anybody remember why the bazel cache volume was required here but we could revert this change.

2 changes: 1 addition & 1 deletion tools/testing/build_and_run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,5 @@ if ! [ -x "$(command -v nvidia-smi)" ]; then
EXTRA_ARGS="-n auto"
fi


bazel clean
Copy link
Contributor Author

@bhack bhack Jun 4, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@seanpmorgan I want to remove this but if I remove this it find some so in the cache that I don't understands where it comes from:
`

tensorflow.python.framework.errors_impl.NotFoundError: /addons/bazel-bin/tensorflow_addons/custom_ops/layers/lib_correlation_cost_ops_gpu.so: undefined symbol: _ZNK10tensorflow6Tensor21CheckTypeAndIsAlignedENS_8DataTypeE

In the local build I don't find any lib_correlation_cost_ops_gpu.so

python -m pytest -v --functions-durations=20 --modules-durations=5 $EXTRA_ARGS ./tensorflow_addons