From 2b004243f5a3f57e24f075a8139594916a994fb4 Mon Sep 17 00:00:00 2001 From: Sean Morgan Date: Fri, 4 Jan 2019 14:06:31 -0500 Subject: [PATCH 1/8] Directory structure for discussion --- .gitignore | 37 ++ BUILD | 13 + CONTRIBUTING.md | 30 + LICENSE | 2 +- MANIFEST.in | 1 + README.md | 16 +- WORKSPACE | 5 + build_pip_pkg.sh | 68 ++ configure.sh | 35 + setup.py | 65 ++ tensorflow_addons/__init__.py | 0 tensorflow_addons/examples/demo.py | 23 + tensorflow_addons/layers/BUILD | 33 + tensorflow_addons/layers/README.md | 4 + tensorflow_addons/layers/__init__.py | 26 + tensorflow_addons/layers/python/__init__.py | 0 .../layers/python/layers/__init__.py | 0 .../python/layers/poincare_normalize.py | 58 ++ .../python/layers/poincare_normalize_test.py | 97 +++ .../layers/python/layers/wrappers.py | 155 +++++ .../layers/python/layers/wrappers_test.py | 86 +++ tensorflow_addons/opt/BUILD | 23 + tensorflow_addons/opt/README.md | 3 + tensorflow_addons/opt/__init__.py | 24 + tensorflow_addons/opt/python/__init__.py | 0 tensorflow_addons/opt/python/opt/__init__.py | 0 .../opt/python/opt/lazy_adam_optimizer.py | 118 ++++ .../python/opt/lazy_adam_optimizer_test.py | 366 +++++++++++ tensorflow_addons/text/BUILD | 55 ++ tensorflow_addons/text/README.md | 3 + tensorflow_addons/text/__init__.py | 24 + .../text/cc/kernels/skip_gram_kernels.cc | 138 ++++ .../text/cc/ops/skip_gram_ops.cc | 54 ++ tensorflow_addons/text/python/__init__.py | 0 tensorflow_addons/text/python/ops/__init__.py | 0 .../text/python/ops/skip_gram_ops.py | 448 +++++++++++++ .../text/python/ops/skip_gram_ops_test.py | 606 ++++++++++++++++++ tf/BUILD | 0 tf/BUILD.tpl | 18 + tf/tf_configure.bzl | 206 ++++++ 40 files changed, 2835 insertions(+), 5 deletions(-) create mode 100644 BUILD create mode 100644 CONTRIBUTING.md create mode 100644 MANIFEST.in create mode 100644 WORKSPACE create mode 100755 build_pip_pkg.sh create mode 100755 configure.sh create mode 100644 setup.py create mode 100644 tensorflow_addons/__init__.py create mode 100644 tensorflow_addons/examples/demo.py create mode 100644 tensorflow_addons/layers/BUILD create mode 100644 tensorflow_addons/layers/README.md create mode 100644 tensorflow_addons/layers/__init__.py create mode 100644 tensorflow_addons/layers/python/__init__.py create mode 100644 tensorflow_addons/layers/python/layers/__init__.py create mode 100644 tensorflow_addons/layers/python/layers/poincare_normalize.py create mode 100644 tensorflow_addons/layers/python/layers/poincare_normalize_test.py create mode 100644 tensorflow_addons/layers/python/layers/wrappers.py create mode 100644 tensorflow_addons/layers/python/layers/wrappers_test.py create mode 100644 tensorflow_addons/opt/BUILD create mode 100644 tensorflow_addons/opt/README.md create mode 100644 tensorflow_addons/opt/__init__.py create mode 100644 tensorflow_addons/opt/python/__init__.py create mode 100644 tensorflow_addons/opt/python/opt/__init__.py create mode 100644 tensorflow_addons/opt/python/opt/lazy_adam_optimizer.py create mode 100644 tensorflow_addons/opt/python/opt/lazy_adam_optimizer_test.py create mode 100644 tensorflow_addons/text/BUILD create mode 100644 tensorflow_addons/text/README.md create mode 100644 tensorflow_addons/text/__init__.py create mode 100644 tensorflow_addons/text/cc/kernels/skip_gram_kernels.cc create mode 100644 tensorflow_addons/text/cc/ops/skip_gram_ops.cc create mode 100644 tensorflow_addons/text/python/__init__.py create mode 100644 tensorflow_addons/text/python/ops/__init__.py create mode 100644 
tensorflow_addons/text/python/ops/skip_gram_ops.py create mode 100644 tensorflow_addons/text/python/ops/skip_gram_ops_test.py create mode 100644 tf/BUILD create mode 100644 tf/BUILD.tpl create mode 100644 tf/tf_configure.bzl diff --git a/.gitignore b/.gitignore index e69de29bb2..9b3bc500eb 100644 --- a/.gitignore +++ b/.gitignore @@ -0,0 +1,37 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Jupyter Notebook +.ipynb_checkpoints + +# IDE +.vscode/ +.idea/ + +# Build +/.bazelrc +/bazel-* +/artifacts \ No newline at end of file diff --git a/BUILD b/BUILD new file mode 100644 index 0000000000..702c5dc395 --- /dev/null +++ b/BUILD @@ -0,0 +1,13 @@ +sh_binary( + name = "build_pip_pkg", + srcs = ["build_pip_pkg.sh"], + data = [ + "LICENSE", + "MANIFEST.in", + "setup.py", + "tensorflow_addons/__init__.py", + "//tensorflow_addons/layers:layers_py", + "//tensorflow_addons/opt:opt_py", + "//tensorflow_addons/text:text_py", + ], +) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000..38d439115b --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,30 @@ +Want to contribute? Great! First, read this page (including the small print at the end). + +### Before you contribute + +Before we can use your code, you must sign the +[Google Individual Contributor License Agreement] +(https://cla.developers.google.com/about/google-individual) +(CLA), which you can do online. The CLA is necessary mainly because you own the +copyright to your changes, even after your contribution becomes part of our +codebase, so we need your permission to use and distribute your code. We also +need to be sure of various other things—for instance that you'll tell us if you +know that your code infringes on other people's patents. You don't have to sign +the CLA until after you've submitted your code for review and a member has +approved it, but you must do it before we can put your code into our codebase. +Before you start working on a larger contribution, you should get in touch with +us first through the issue tracker with your idea so that we can help out and +possibly guide you. Coordinating up front makes it much easier to avoid +frustration later on. + +### Code reviews + +All submissions, including submissions by project members, require review. We +use Github pull requests for this purpose. + +### The small print + +Contributions made by corporations are covered by a different agreement than +the one above, the +[Software Grant and Corporate Contributor License Agreement] +(https://cla.developers.google.com/about/google-corporate). \ No newline at end of file diff --git a/LICENSE b/LICENSE index 67cc16395d..8fc9e29688 100644 --- a/LICENSE +++ b/LICENSE @@ -188,7 +188,7 @@ Copyright 2018 The TensorFlow Authors. All rights reserved. same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright 2017, The TensorFlow Authors. + Copyright 2019, The TensorFlow Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000000..66661736de --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +recursive-include tensorflow_addons/ *.so diff --git a/README.md b/README.md index efbc1fb51a..9aba5d7b42 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,15 @@ The tensorflow/addons repository, will contain additional functionality fitting --- -

-
- -
+# Developing + +## Docker +``` +``` + +## Packaging +``` +./configure.sh +bazel build build_pip_pkg +bazel-bin/build_pip_pkg artifacts +``` diff --git a/WORKSPACE b/WORKSPACE new file mode 100644 index 0000000000..ee0f0eb0a8 --- /dev/null +++ b/WORKSPACE @@ -0,0 +1,5 @@ +load("//tf:tf_configure.bzl", "tf_configure") + +tf_configure( + name = "local_config_tf", +) diff --git a/build_pip_pkg.sh b/build_pip_pkg.sh new file mode 100755 index 0000000000..5f799447b4 --- /dev/null +++ b/build_pip_pkg.sh @@ -0,0 +1,68 @@ +#!/usr/bin/env bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +PLATFORM="$(uname -s | tr 'A-Z' 'a-z')" + +PIP_FILE_PREFIX="bazel-bin/build_pip_pkg.runfiles/__main__/" + +function main() { + while [[ ! -z "${1}" ]]; do + if [[ ${1} == "make" ]]; then + echo "Using Makefile to build pip package." + PIP_FILE_PREFIX="" + else + DEST=${1} + fi + shift + done + + if [[ -z ${DEST} ]]; then + echo "No destination dir provided" + exit 1 + fi + + # Create the directory, then do dirname on a non-existent file inside it to + # give us an absolute paths with tilde characters resolved to the destination + # directory. + mkdir -p ${DEST} + DEST=$(readlink -f "${DEST}") + echo "=== destination directory: ${DEST}" + + TMPDIR=$(mktemp -d -t tmp.XXXXXXXXXX) + + echo $(date) : "=== Using tmpdir: ${TMPDIR}" + + echo "=== Copy TensorFlow Addons files" + + cp ${PIP_FILE_PREFIX}setup.py "${TMPDIR}" + cp ${PIP_FILE_PREFIX}MANIFEST.in "${TMPDIR}" + cp ${PIP_FILE_PREFIX}LICENSE "${TMPDIR}" + rsync -avm -L --exclude='*_test.py' ${PIP_FILE_PREFIX}tensorflow_addons "${TMPDIR}" + + pushd ${TMPDIR} + echo $(date) : "=== Building wheel" + + python setup.py bdist_wheel > /dev/null + + cp dist/*.whl "${DEST}" + popd + rm -rf ${TMPDIR} + echo $(date) : "=== Output wheel file is in: ${DEST}" +} + +main "$@" \ No newline at end of file diff --git a/configure.sh b/configure.sh new file mode 100755 index 0000000000..a4f4283d6d --- /dev/null +++ b/configure.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +function write_to_bazelrc() { + echo "$1" >> .bazelrc +} + +function write_action_env_to_bazelrc() { + write_to_bazelrc "build --action_env $1=\"$2\"" +} + +rm .bazelrc +if python -c "import tensorflow" &> /dev/null; then + echo 'using installed tensorflow' +else + pip install tensorflow +fi + +TF_CFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') ) +TF_LFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') ) + +write_action_env_to_bazelrc "TF_HEADER_DIR" ${TF_CFLAGS:2} +write_action_env_to_bazelrc "TF_SHARED_LIBRARY_DIR" ${TF_LFLAGS:2} \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000..0a379b1b52 --- /dev/null +++ b/setup.py @@ -0,0 +1,65 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Setup for pip package.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from setuptools import find_packages +from setuptools import setup +from setuptools.dist import Distribution + +__version__ = '0.0.1' +REQUIRED_PACKAGES = [ + 'tf-nightly-2.0-preview', +] +project_name = 'tensorflow-addons' + + +class BinaryDistribution(Distribution): + """This class is needed in order to create OS specific wheels.""" + + def has_ext_modules(self): + return True + + +setup( + name=project_name, + version=__version__, + description=('TensorFlow Addons'), + author='Google Inc.', + author_email='opensource@google.com', + packages=find_packages(), + install_requires=REQUIRED_PACKAGES, + include_package_data=True, + zip_safe=False, + distclass=BinaryDistribution, + classifiers=[ + 'Development Status :: 4 - Beta', + 'Intended Audience :: Developers', + 'Intended Audience :: Education', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: Apache Software License', + 'Programming Language :: Python :: 2.7', + 'Programming Language :: Python :: 3.4', + 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', + 'Topic :: Scientific/Engineering :: Mathematics', + 'Topic :: Software Development :: Libraries :: Python Modules', + 'Topic :: Software Development :: Libraries', + ], + license='Apache 2.0', + keywords='tensorflow addons machine learning', +) diff --git a/tensorflow_addons/__init__.py b/tensorflow_addons/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tensorflow_addons/examples/demo.py b/tensorflow_addons/examples/demo.py new file mode 100644 index 0000000000..b3e1709ed8 --- /dev/null +++ b/tensorflow_addons/examples/demo.py @@ -0,0 +1,23 @@ +# Copyright 2019 The TensorFlow Probability Authors. 
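For reference, a sketch of the `.bazelrc` that configure.sh leaves behind. The format of each line comes from `write_action_env_to_bazelrc` above; the quoted paths are only illustrative, the real values are whatever `tf.sysconfig.get_compile_flags()` and `get_link_flags()` report for the local TensorFlow install:

```
build --action_env TF_HEADER_DIR="/usr/lib/python3.6/site-packages/tensorflow/include"
build --action_env TF_SHARED_LIBRARY_DIR="/usr/lib/python3.6/site-packages/tensorflow"
```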
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +# import tensorflow as tf +# import tensorflow_addons as tfa +# from tensorflow_addons.opt import LazyAdamOptimizer + + +# TODO: Build this out +if __name__ == "__main__": + pass diff --git a/tensorflow_addons/layers/BUILD b/tensorflow_addons/layers/BUILD new file mode 100644 index 0000000000..797fbc60a4 --- /dev/null +++ b/tensorflow_addons/layers/BUILD @@ -0,0 +1,33 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:public"]) + +py_test( + name = "layers_poincare_py_test", + srcs = [ + "python/layers/poincare_normalize_test.py", + ], + main = "python/layers/poincare_normalize_test.py", + srcs_version = "PY2AND3", +) + +py_test( + name = "layers_wrappers_py_test", + srcs = [ + "python/layers/wrappers_test.py", + ], + main = "python/layers/wrappers_test.py", + srcs_version = "PY2AND3", +) + +py_library( + name = "layers_py", + srcs = ([ + "__init__.py", + "python/__init__.py", + "python/layers/__init__.py", + "python/layers/poincare_normalize.py", + "python/layers/wrappers.py", + ]), + srcs_version = "PY2AND3", +) \ No newline at end of file diff --git a/tensorflow_addons/layers/README.md b/tensorflow_addons/layers/README.md new file mode 100644 index 0000000000..e3bfb00a0c --- /dev/null +++ b/tensorflow_addons/layers/README.md @@ -0,0 +1,4 @@ +# Addons - Layers + + +## Standard API \ No newline at end of file diff --git a/tensorflow_addons/layers/__init__.py b/tensorflow_addons/layers/__init__.py new file mode 100644 index 0000000000..007cf4d683 --- /dev/null +++ b/tensorflow_addons/layers/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Ops for building neural network layers, regularizers, summaries, etc. 
+""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Poincare Normalize +from tensorflow_addons.layers.python.layers.poincare_normalize import poincare_normalize + +# Weight Normalization +from tensorflow_addons.layers.python.layers.wrappers import WeightNorm diff --git a/tensorflow_addons/layers/python/__init__.py b/tensorflow_addons/layers/python/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tensorflow_addons/layers/python/layers/__init__.py b/tensorflow_addons/layers/python/layers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tensorflow_addons/layers/python/layers/poincare_normalize.py b/tensorflow_addons/layers/python/layers/poincare_normalize.py new file mode 100644 index 0000000000..851c38c4d4 --- /dev/null +++ b/tensorflow_addons/layers/python/layers/poincare_normalize.py @@ -0,0 +1,58 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops + + +def poincare_normalize(x, axis=1, epsilon=1e-5, name=None): + """Project into the Poincare ball with norm <= 1.0 - epsilon. + + https://en.wikipedia.org/wiki/Poincare_ball_model + + Used in + Poincare Embeddings for Learning Hierarchical Representations + Maximilian Nickel, Douwe Kiela + https://arxiv.org/pdf/1705.08039.pdf + + For a 1-D tensor with `axis = 0`, computes + + (x * (1 - epsilon)) / ||x|| if ||x|| > 1 - epsilon + output = + x otherwise + + For `x` with more dimensions, independently normalizes each 1-D slice along + dimension `axis`. + + Args: + x: A `Tensor`. + axis: Axis along which to normalize. A scalar or a vector of + integers. + epsilon: A small deviation from the edge of the unit sphere for numerical + stability. + name: A name for this operation (optional). + + Returns: + A `Tensor` with the same shape as `x`. + """ + with ops.name_scope(name, 'poincare_normalize', [x]) as name: + x = ops.convert_to_tensor(x, name='x') + square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims=True) + x_inv_norm = math_ops.rsqrt(square_sum) + x_inv_norm = math_ops.minimum((1. - epsilon) * x_inv_norm, 1.) + return math_ops.multiply(x, x_inv_norm, name=name) diff --git a/tensorflow_addons/layers/python/layers/poincare_normalize_test.py b/tensorflow_addons/layers/python/layers/poincare_normalize_test.py new file mode 100644 index 0000000000..eae709c02b --- /dev/null +++ b/tensorflow_addons/layers/python/layers/poincare_normalize_test.py @@ -0,0 +1,97 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.platform import test +from tensorflow.python.framework import test_util +from tensorflow.python.ops import gradient_checker +from tensorflow.python.framework import constant_op + +from tensorflow_addons.layers.python.layers.poincare_normalize import poincare_normalize + + +# TODO: Is this the prefered way to run tests in TF2? +@test_util.run_all_in_graph_and_eager_modes +class PoincareNormalizeTest(test.TestCase): + + def _PoincareNormalize(self, x, dim, epsilon=1e-5): + if isinstance(dim, list): + norm = np.linalg.norm(x, axis=tuple(dim)) + for d in dim: + norm = np.expand_dims(norm, d) + norm_x = ((1. - epsilon) * x) / norm + else: + norm = np.expand_dims(np.apply_along_axis(np.linalg.norm, dim, x), dim) + norm_x = ((1. - epsilon) * x) / norm + return np.where(norm > 1.0 - epsilon, norm_x, x) + + def testPoincareNormalize(self): + x_shape = [20, 7, 3] + epsilon = 1e-5 + tol = 1e-6 + np.random.seed(1) + x_np = np.random.random_sample(x_shape).astype(np.float32) + for dim in range(len(x_shape)): + y_np = self._PoincareNormalize(x_np, dim, epsilon) + with self.cached_session(): + x_tf = constant_op.constant(x_np, name='x') + y_tf = poincare_normalize(x_tf, dim, epsilon) + y_tf_eval = y_tf.numpy() + norm = np.linalg.norm(y_np, axis=dim) + self.assertLessEqual(norm.max(), 1. - epsilon + tol) + norm = np.linalg.norm(y_tf_eval, axis=dim) + self.assertLessEqual(norm.max(), 1. - epsilon + tol) + self.assertAllClose(y_np, y_tf_eval) + + def testPoincareNormalizeDimArray(self): + x_shape = [20, 7, 3] + epsilon = 1e-5 + tol = 1e-6 + np.random.seed(1) + x_np = np.random.random_sample(x_shape).astype(np.float32) + dim = [1, 2] + y_np = self._PoincareNormalize(x_np, dim, epsilon) + with self.cached_session(): + x_tf = constant_op.constant(x_np, name='x') + y_tf = poincare_normalize(x_tf, dim, epsilon) + y_tf_eval = y_tf.numpy() + norm = np.linalg.norm(y_np, axis=tuple(dim)) + self.assertLess(norm.max(), 1. - epsilon + tol) + norm = np.linalg.norm(y_tf_eval, axis=tuple(dim)) + self.assertLess(norm.max(), 1. 
- epsilon + tol) + self.assertAllClose(y_np, y_tf_eval, rtol=1e-6, atol=1e-6) + + def testPoincareNormalizeGradient(self): + x_shape = [20, 7, 3] + np.random.seed(1) + x_np = np.random.random_sample(x_shape).astype(np.float64) + for dim in range(len(x_shape)): + with self.cached_session(): + x_tf = constant_op.constant(x_np, name='x') + y_tf = poincare_normalize(x_tf, dim) + err = gradient_checker.compute_gradient_error(x_tf, + x_shape, + y_tf, + x_shape) + print('PoinCareNormalize gradient err = %g ' % err) + self.assertLess(err, 1e-4) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow_addons/layers/python/layers/wrappers.py b/tensorflow_addons/layers/python/layers/wrappers.py new file mode 100644 index 0000000000..90f2f3c105 --- /dev/null +++ b/tensorflow_addons/layers/python/layers/wrappers.py @@ -0,0 +1,155 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +from tensorflow import name_scope +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn_impl +from tensorflow.python.keras import initializers +from tensorflow.python.eager import context +from tensorflow.python.keras.engine.base_layer import Layer +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.layers import Wrapper +from tensorflow.python.ops import variables as tf_variables + + +class WeightNorm(Wrapper): + """ This wrapper reparameterizes a layer by decoupling the weight's + magnitude and direction. This speeds up convergence by improving the + conditioning of the optimization problem. + Weight Normalization: A Simple Reparameterization to Accelerate + Training of Deep Neural Networks: https://arxiv.org/abs/1602.07868 + Tim Salimans, Diederik P. Kingma (2016) + WeightNorm wrapper works for keras and tf layers. + ```python + net = WeightNorm(tf.keras.layers.Conv2D(2, 2, activation='relu'), + input_shape=(32, 32, 3), data_init=True)(x) + net = WeightNorm(tf.keras.layers.Conv2D(16, 5, activation='relu'), + data_init=True) + net = WeightNorm(tf.keras.layers.Dense(120, activation='relu'), + data_init=True)(net) + net = WeightNorm(tf.keras.layers.Dense(n_classes), + data_init=True)(net) + ``` + Arguments: + layer: a layer instance. + data_init: If `True` use data dependent variable initialization + Raises: + ValueError: If not initialized with a `Layer` instance. + ValueError: If `Layer` does not contain a `kernel` of weights + NotImplementedError: If `data_init` is True and running graph execution + """ + def __init__(self, layer, data_init=False, **kwargs): + if not isinstance(layer, Layer): + raise ValueError( + 'Please initialize `WeightNorm` layer with a ' + '`Layer` instance. 
You passed: {input}'.format(input=layer)) + + if not context.executing_eagerly() and data_init: + raise NotImplementedError( + 'Data dependent variable initialization is not available for ' + 'graph execution') + + self.initialized = True + if data_init: + self.initialized = False + + super(WeightNorm, self).__init__(layer, **kwargs) + self._track_checkpointable(layer, name='layer') + + def _compute_weights(self): + """Generate weights by combining the direction of weight vector + with it's norm """ + with name_scope('compute_weights'): + self.layer.kernel = nn_impl.l2_normalize( + self.layer.v, axis=self.norm_axes) * self.layer.g + + def _init_norm(self, weights): + """Set the norm of the weight vector""" + from tensorflow.python.ops.linalg_ops import norm + with name_scope('init_norm'): + flat = array_ops.reshape(weights, [-1, self.layer_depth]) + return array_ops.reshape(norm(flat, axis=0), (self.layer_depth,)) + + def _data_dep_init(self, inputs): + """Data dependent initialization for eager execution""" + from tensorflow.python.ops.nn import moments + from tensorflow.python.ops.math_ops import sqrt + + with name_scope('data_dep_init'): + # Generate data dependent init values + activation = self.layer.activation + self.layer.activation = None + x_init = self.layer.call(inputs) + m_init, v_init = moments(x_init, self.norm_axes) + scale_init = 1. / sqrt(v_init + 1e-10) + + # Assign data dependent init values + self.layer.g = self.layer.g * scale_init + self.layer.bias = (-m_init * scale_init) + self.layer.activation = activation + self.initialized = True + + def build(self, input_shape): + """Build `Layer`""" + input_shape = tensor_shape.TensorShape(input_shape).as_list() + self.input_spec = InputSpec(shape=input_shape) + + if not self.layer.built: + self.layer.build(input_shape) + self.layer.built = False + + if not hasattr(self.layer, 'kernel'): + raise ValueError( + '`WeightNorm` must wrap a layer that' + ' contains a `kernel` for weights' + ) + + # The kernel's filter or unit dimension is -1 + self.layer_depth = int(self.layer.kernel.shape[-1]) + self.norm_axes = list(range(self.layer.kernel.shape.ndims - 1)) + + self.layer.v = self.layer.kernel + self.layer.g = self.layer.add_variable( + name="g", + shape=(self.layer_depth,), + initializer=initializers.get('ones'), + dtype=self.layer.kernel.dtype, + trainable=True, + aggregation=tf_variables.VariableAggregation.MEAN) + + with ops.control_dependencies([self.layer.g.assign( + self._init_norm(self.layer.v))]): + self._compute_weights() + + self.layer.built = True + + super(WeightNorm, self).build() + self.built = True + + def call(self, inputs): + """Call `Layer`""" + if context.executing_eagerly(): + if not self.initialized: + self._data_dep_init(inputs) + self._compute_weights() # Recompute weights for each forward pass + + output = self.layer.call(inputs) + return output + + def compute_output_shape(self, input_shape): + return tensor_shape.TensorShape( + self.layer.compute_output_shape(input_shape).as_list()) diff --git a/tensorflow_addons/layers/python/layers/wrappers_test.py b/tensorflow_addons/layers/python/layers/wrappers_test.py new file mode 100644 index 0000000000..8c8dbae5b0 --- /dev/null +++ b/tensorflow_addons/layers/python/layers/wrappers_test.py @@ -0,0 +1,86 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
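A small eager-mode check of the reparameterization that `WeightNorm` performs, assuming the wrapper builds as in the tests below: after the first call, the wrapped layer's `kernel` is recomputed as `g * v / ||v||`, with the norm taken over every axis except the last (filter/unit) axis. The layer and shapes here are only illustrative:

```python
import numpy as np
import tensorflow as tf

from tensorflow_addons.layers.python.layers.wrappers import WeightNorm

dense = tf.keras.layers.Dense(2)
wn = WeightNorm(dense)
wn(tf.constant(np.random.randn(8, 4).astype(np.float32)))  # builds the wrapper, recomputes kernel

v = dense.v.numpy()                           # direction parameters, shape (4, 2)
g = dense.g.numpy()                           # per-unit magnitudes, shape (2,)
expected = g * v / np.linalg.norm(v, axis=0)  # w = g * v / ||v||
np.testing.assert_allclose(dense.kernel.numpy(), expected, rtol=1e-5)
```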
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow_addons.layers.python.layers import wrappers + +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import test +from tensorflow.python.layers import layers +from tensorflow.python.training.rmsprop import RMSPropOptimizer + +from tensorflow.python.framework import test_util as tf_test_util +from tensorflow.python import keras + + +class WeightNormTest(test.TestCase): + + @tf_test_util.run_all_in_graph_and_eager_modes + def test_weightnorm_dense_train(self): + model = keras.models.Sequential() + model.add(wrappers.WeightNorm( + keras.layers.Dense(2), input_shape=(3, 4))) + + model.compile(optimizer=RMSPropOptimizer(0.01), loss='mse') + model.fit( + np.random.random((10, 3, 4)), + np.random.random((10, 3, 2)), + epochs=1, + batch_size=10) + self.assertTrue(hasattr(model.layers[0].layer, 'g')) + + # @tf_test_util.run_all_in_graph_and_eager_modes + # def test_weightnorm_conv2d(self): + # with self.test_session(): + # model = keras.models.Sequential() + # model.add(wrappers.WeightNorm( + # keras.layers.Conv2D(5, (2, 2), padding='same'), + # input_shape=(4, 4, 3))) + # + # model.add(keras.layers.Activation('relu')) + # model.compile(optimizer='rmsprop', loss='mse') + # model.train_on_batch( + # np.random.random((2, 4, 4, 3)), + # np.random.random((2, 4, 4, 8))) + # + # self.assertTrue(hasattr(model.layers[0].layer, 'g')) + + @tf_test_util.run_all_in_graph_and_eager_modes + def test_weight_norm_tflayers(self): + images = random_ops.random_uniform((2, 4, 4, 3)) + wn_wrapper = wrappers.WeightNorm(layers.Conv2D(32, [2, 2]), + input_shape=(4, 4, 3)) + wn_wrapper.apply(images) + self.assertTrue(hasattr(wn_wrapper.layer, 'g')) + + @tf_test_util.run_all_in_graph_and_eager_modes + def test_weight_norm_nonlayer(self): + images = random_ops.random_uniform((2, 4, 43)) + with self.assertRaises(ValueError): + wrappers.WeightNorm(images) + + @tf_test_util.run_all_in_graph_and_eager_modes + def test_weight_norm_nokernel(self): + with self.assertRaises(ValueError): + wrappers.WeightNorm(layers.MaxPooling2D(2, 2)).build((2, 2)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow_addons/opt/BUILD b/tensorflow_addons/opt/BUILD new file mode 100644 index 0000000000..b5b8c5f3fa --- /dev/null +++ b/tensorflow_addons/opt/BUILD @@ -0,0 +1,23 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:public"]) + +py_test( + name = "opt_py_test", + srcs = [ + "python/opt/lazy_adam_optimizer_test.py", + ], + main = "python/opt/lazy_adam_optimizer_test.py", + srcs_version = "PY2AND3", +) + +py_library( + name = "opt_py", + srcs = ([ + "__init__.py", + "python/__init__.py", + "python/opt/__init__.py", + "python/opt/lazy_adam_optimizer.py", + ]), + srcs_version = "PY2AND3", +) \ No newline at end of file diff --git a/tensorflow_addons/opt/README.md b/tensorflow_addons/opt/README.md new file mode 100644 index 
0000000000..966e5f1c65 --- /dev/null +++ b/tensorflow_addons/opt/README.md @@ -0,0 +1,3 @@ +# Addons - Optimizers + +## Standard API diff --git a/tensorflow_addons/opt/__init__.py b/tensorflow_addons/opt/__init__.py new file mode 100644 index 0000000000..c14aaab5b6 --- /dev/null +++ b/tensorflow_addons/opt/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +A module containing optimization routines. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Lazy Adam Optimizer +from tensorflow_addons.opt.python.opt.lazy_adam_optimizer import LazyAdamOptimizer diff --git a/tensorflow_addons/opt/python/__init__.py b/tensorflow_addons/opt/python/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tensorflow_addons/opt/python/opt/__init__.py b/tensorflow_addons/opt/python/opt/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tensorflow_addons/opt/python/opt/lazy_adam_optimizer.py b/tensorflow_addons/opt/python/opt/lazy_adam_optimizer.py new file mode 100644 index 0000000000..caa46e961f --- /dev/null +++ b/tensorflow_addons/opt/python/opt/lazy_adam_optimizer.py @@ -0,0 +1,118 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Variant of the Adam optimizer that handles sparse updates more efficiently. + +Compared with the original Adam optimizer, the one in this file can provide a +large improvement in model training throughput for some applications. However, +it provides slightly different semantics than the original Adam algorithm, and +may lead to different empirical results. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +# FIXME: Which way to import? 
+from tensorflow.python.keras.optimizer_v2.adam import Adam +# from tensorflow.keras.optimizers import Adam (package_hook) + +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops + + +class LazyAdamOptimizer(Adam): + """Variant of the Adam optimizer that handles sparse updates more efficiently. + + The original Adam algorithm maintains two moving-average accumulators for + each trainable variable; the accumulators are updated at every step. + This class provides lazier handling of gradient updates for sparse variables. + It only updates moving-average accumulators for sparse variable indices that + appear in the current batch, rather than updating the accumulators for all + indices. Compared with the original Adam optimizer, it can provide large + improvements in model training throughput for some applications. However, it + provides slightly different semantics than the original Adam algorithm, and + may lead to different empirical results. + """ + + def _apply_sparse(self, grad, var): + beta1_power, beta2_power = self._get_beta_accumulators() + beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) + beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) + lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) + beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) + beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) + epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) + lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) + + # \\(m := beta1 * m + (1 - beta1) * g_t\\) + m = self.get_slot(var, "m") + m_t = state_ops.scatter_update(m, grad.indices, + beta1_t * array_ops.gather(m, grad.indices) + + (1 - beta1_t) * grad.values, + use_locking=self._use_locking) + + # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\) + v = self.get_slot(var, "v") + v_t = state_ops.scatter_update(v, grad.indices, + beta2_t * array_ops.gather(v, grad.indices) + + (1 - beta2_t) * math_ops.square(grad.values), + use_locking=self._use_locking) + + # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\) + m_t_slice = array_ops.gather(m_t, grad.indices) + v_t_slice = array_ops.gather(v_t, grad.indices) + denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t + var_update = state_ops.scatter_sub(var, grad.indices, + lr * m_t_slice / denominator_slice, + use_locking=self._use_locking) + return control_flow_ops.group(var_update, m_t, v_t) + + def _resource_apply_sparse(self, grad, var, indices): + beta1_power, beta2_power = self._get_beta_accumulators() + beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) + beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) + lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) + beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) + beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) + epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) + lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) + + # \\(m := beta1 * m + (1 - beta1) * g_t\\) + m = self.get_slot(var, "m") + m_t_slice = beta1_t * array_ops.gather(m, indices) + (1 - beta1_t) * grad + m_update_op = resource_variable_ops.resource_scatter_update(m.handle, + indices, + m_t_slice) + + # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\) + v = self.get_slot(var, "v") + v_t_slice = (beta2_t * array_ops.gather(v, 
indices) + + (1 - beta2_t) * math_ops.square(grad)) + v_update_op = resource_variable_ops.resource_scatter_update(v.handle, + indices, + v_t_slice) + + # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\) + var_slice = lr * m_t_slice / (math_ops.sqrt(v_t_slice) + epsilon_t) + var_update_op = resource_variable_ops.resource_scatter_sub(var.handle, + indices, + var_slice) + + return control_flow_ops.group(var_update_op, m_update_op, v_update_op) diff --git a/tensorflow_addons/opt/python/opt/lazy_adam_optimizer_test.py b/tensorflow_addons/opt/python/opt/lazy_adam_optimizer_test.py new file mode 100644 index 0000000000..2ca96a6e5e --- /dev/null +++ b/tensorflow_addons/opt/python/opt/lazy_adam_optimizer_test.py @@ -0,0 +1,366 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for LazyAdamOptimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from absl.testing import parameterized +from tensorflow_addons.opt.python.opt import lazy_adam_optimizer + +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def adam_update_numpy(param, + g_t, + t, + m, + v, + alpha=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-8): + alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t) + + m_t = beta1 * m + (1 - beta1) * g_t + v_t = beta2 * v + (1 - beta2) * g_t * g_t + + param_t = param - alpha_t * m_t / (np.sqrt(v_t) + epsilon) + return param_t, m_t, v_t + + +class AdamOptimizerTest(test.TestCase, parameterized.TestCase): + + @parameterized.parameters([False, True]) + def testSparse(self, use_resource): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + # Initialize variables for numpy implementation. 
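As context for the sparse tests that follow: the point of the lazy variant is that when a gradient arrives as `IndexedSlices` (for example from an embedding lookup), only the touched rows of the `m`/`v` accumulators are written. A rough eager-mode sketch in the style of `testSlotsUniqueEager` below; given the FIXME above, the exact constructor and `minimize` signature may still shift with the targeted nightly:

```python
import tensorflow as tf

from tensorflow_addons.opt.python.opt.lazy_adam_optimizer import LazyAdamOptimizer

embeddings = tf.Variable(tf.random.uniform([100, 16]))
indices = tf.constant([3, 7, 42])

opt = LazyAdamOptimizer(learning_rate=0.01)
# The gradient w.r.t. `embeddings` is an IndexedSlices covering rows 3, 7 and 42,
# so only those rows of the m/v slot variables are updated on this step.
opt.minimize(lambda: tf.reduce_sum(tf.nn.embedding_lookup(embeddings, indices)))
```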
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + else: + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + + grads0_np_indices = np.array([0, 1], dtype=np.int32) + grads0 = ops.IndexedSlices( + constant_op.constant(grads0_np), + constant_op.constant(grads0_np_indices), constant_op.constant([2])) + grads1_np_indices = np.array([0, 1], dtype=np.int32) + grads1 = ops.IndexedSlices( + constant_op.constant(grads1_np), + constant_op.constant(grads1_np_indices), constant_op.constant([2])) + opt = lazy_adam_optimizer.LazyAdamOptimizer() + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + update.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + @parameterized.parameters([False, True]) + def testSparseDevicePlacement(self, use_resource): + for index_dtype in [dtypes.int32, dtypes.int64]: + with self.cached_session(force_gpu=test.is_gpu_available()): + # If a GPU is available, tests that all optimizer ops can be placed on + # it (i.e. they have GPU kernels). 
+ if use_resource: + var = resource_variable_ops.ResourceVariable([[1.0], [2.0]]) + else: + var = variables.Variable([[1.0], [2.0]]) + + indices = constant_op.constant([0, 1], dtype=index_dtype) + gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices)) + optimizer = lazy_adam_optimizer.LazyAdamOptimizer(3.0) + minimize_op = optimizer.minimize(gathered_sum) + variables.global_variables_initializer().run() + minimize_op.run() + + @parameterized.parameters([False, True]) + def testSparseRepeatedIndices(self, use_resource): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + if use_resource: + repeated_index_update_var = resource_variable_ops.ResourceVariable( + [[1.0], [2.0]], dtype=dtype) + aggregated_update_var = resource_variable_ops.ResourceVariable( + [[1.0], [2.0]], dtype=dtype) + else: + repeated_index_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + aggregated_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + + grad_repeated_index = ops.IndexedSlices( + constant_op.constant( + [0.1, 0.1], shape=[2, 1], dtype=dtype), + constant_op.constant([1, 1]), + constant_op.constant([2, 1])) + grad_aggregated = ops.IndexedSlices( + constant_op.constant( + [0.2], shape=[1, 1], dtype=dtype), + constant_op.constant([1]), + constant_op.constant([2, 1])) + repeated_update_opt = lazy_adam_optimizer.LazyAdamOptimizer() + repeated_update = repeated_update_opt.apply_gradients( + [(grad_repeated_index, repeated_index_update_var)]) + aggregated_update_opt = lazy_adam_optimizer.LazyAdamOptimizer() + aggregated_update = aggregated_update_opt.apply_gradients( + [(grad_aggregated, aggregated_update_var)]) + variables.global_variables_initializer().run() + self.assertAllClose(aggregated_update_var.eval(), + repeated_index_update_var.eval()) + for _ in range(3): + repeated_update.run() + aggregated_update.run() + self.assertAllClose(aggregated_update_var.eval(), + repeated_index_update_var.eval()) + + def doTestBasic(self, use_resource=False, use_callable_params=False): + for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + with self.session(graph=ops.Graph()): + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + else: + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + learning_rate = lambda: 0.001 + beta1 = lambda: 0.9 + beta2 = lambda: 0.999 + epsilon = lambda: 1e-8 + if not use_callable_params: + learning_rate = learning_rate() + beta1 = beta1() + beta2 = beta2() + epsilon = epsilon() + + opt = lazy_adam_optimizer.LazyAdamOptimizer(learning_rate=learning_rate) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + opt_variables = opt.variables() + beta1_power, beta2_power = opt._get_beta_accumulators() + self.assertIsNotNone(beta1_power) + self.assertIsNotNone(beta2_power is not None) + self.assertIn(beta1_power, opt_variables) + self.assertIn(beta2_power, opt_variables) + + if not context.executing_eagerly(): + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs. + self.assertEqual(0, len(opt.variables())) + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + if not context.executing_eagerly(): + self.evaluate(update) + elif t > 1: + opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + self.assertAllCloseAccordingToType(0.9**(t + 1), + self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**(t + 1), + self.evaluate(beta2_power)) + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + if use_resource: + self.assertEqual("var0_%d/Adam:0" % (i,), + opt.get_slot(var=var0, name="m").name) + + def testBasic(self): + with self.cached_session(): + self.doTestBasic(use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTestBasic(use_resource=True) + + def testBasicCallableParams(self): + with context.eager_mode(): + self.doTestBasic(use_resource=True, use_callable_params=True) + + def testTensorLearningRate(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = lazy_adam_optimizer.LazyAdamOptimizer(constant_op.constant(0.001)) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + update.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testSharing(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = lazy_adam_optimizer.LazyAdamOptimizer() + update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run 3 steps of intertwined Adam1 and Adam2. 
+ for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + if t % 2 == 0: + update1.run() + else: + update2.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testTwoSessions(self): + optimizer = lazy_adam_optimizer.LazyAdamOptimizer() + + with context.eager_mode(): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + optimizer.apply_gradients([(grads0, var0)]) + + g = ops.Graph() + with g.as_default(): + with self.session(graph=g): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + optimizer.apply_gradients([(grads0, var0)]) + + gg = ops.Graph() + with gg.as_default(): + with self.session(graph=gg): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + + # If the optimizer saves any state not keyed by graph the following line + # fails. + optimizer.apply_gradients([(grads0, var0)]) + + def testSlotsUniqueEager(self): + with context.eager_mode(): + v1 = resource_variable_ops.ResourceVariable(1.) + v2 = resource_variable_ops.ResourceVariable(1.) + opt = lazy_adam_optimizer.LazyAdamOptimizer(1.) + opt.minimize(lambda: v1 + v2) + # There should be two non-slot variables, and two unique slot variables + # for v1 and v2 respectively. + self.assertEqual(6, len(set(opt.variables()))) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow_addons/text/BUILD b/tensorflow_addons/text/BUILD new file mode 100644 index 0000000000..fa71b7ce73 --- /dev/null +++ b/tensorflow_addons/text/BUILD @@ -0,0 +1,55 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:public"]) + + +cc_binary( + name = 'python/ops/_skip_gram_ops.so', + srcs = [ + "cc/kernels/skip_gram_kernels.cc", + "cc/ops/skip_gram_ops.cc", + ], + linkshared = 1, + deps = [ + "@local_config_tf//:libtensorflow_framework", + "@local_config_tf//:tf_header_lib", + ], + copts = ["-pthread", "-std=c++11", "-D_GLIBCXX_USE_CXX11_ABI=0"] +) + + +py_library( + name = "text_ops_py", + srcs = ([ + "python/ops/skip_gram_ops.py", + ]), + data = [ + ":python/ops/_skip_gram_ops.so" + ], + srcs_version = "PY2AND3", +) + +py_test( + name = "text_ops_py_test", + srcs = [ + "python/ops/skip_gram_ops_test.py" + ], + main = "python/ops/skip_gram_ops_test.py", + deps = [ + ":text_ops_py", + ], + srcs_version = "PY2AND3", +) + +py_library( + name = "text_py", + srcs = ([ + "__init__.py", + "python/__init__.py", + "python/ops/__init__.py", + ]), + deps = [ + ":text_ops_py" + ], + srcs_version = "PY2AND3", +) \ No newline at end of file diff --git a/tensorflow_addons/text/README.md b/tensorflow_addons/text/README.md new file mode 100644 index 0000000000..82c60c97b7 --- /dev/null +++ b/tensorflow_addons/text/README.md @@ -0,0 +1,3 @@ +# Addons Text + +## Standard API \ No newline at end of file diff --git a/tensorflow_addons/text/__init__.py b/tensorflow_addons/text/__init__.py new file mode 100644 index 0000000000..34bd006be3 --- /dev/null +++ b/tensorflow_addons/text/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
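A note on the `cc_binary` target in `text/BUILD` above: it produces `_skip_gram_ops.so`, a custom-op shared library compiled against the headers and `libtensorflow_framework` exposed through `@local_config_tf`. How `skip_gram_ops.py` actually loads it is not shown in this hunk; the generic pattern, with a hypothetical path, would look like:

```python
import tensorflow as tf

# Hypothetical load site; the real glue lives in skip_gram_ops.py.
_skip_gram_ops = tf.load_op_library(
    "tensorflow_addons/text/python/ops/_skip_gram_ops.so")
# Generated wrappers hang off the returned module under the snake_case op name,
# e.g. _skip_gram_ops.skip_gram_generate_candidates(...).
```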
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Text-processing ops. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Skip Gram Sample +from tensorflow_addons.text.python.ops.skip_gram_ops import skip_gram_sample +from tensorflow_addons.text.python.ops.skip_gram_ops import skip_gram_sample_with_text_vocab diff --git a/tensorflow_addons/text/cc/kernels/skip_gram_kernels.cc b/tensorflow_addons/text/cc/kernels/skip_gram_kernels.cc new file mode 100644 index 0000000000..c75b98a924 --- /dev/null +++ b/tensorflow_addons/text/cc/kernels/skip_gram_kernels.cc @@ -0,0 +1,138 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/util/guarded_philox_random.h"
+
+namespace tensorflow {
+
+template <typename T>
+class SkipGramGenerateCandidatesOp : public OpKernel {
+ public:
+  explicit SkipGramGenerateCandidatesOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, generator_.Init(context));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor* input_tensor;
+    OP_REQUIRES_OK(context, context->input("input_tensor", &input_tensor));
+    const auto input = input_tensor->flat<T>();
+
+    const Tensor* min_skips_tensor;
+    OP_REQUIRES_OK(context, context->input("min_skips", &min_skips_tensor));
+    const int min_skips = *(min_skips_tensor->scalar<int>().data());
+    const Tensor* max_skips_tensor;
+    OP_REQUIRES_OK(context, context->input("max_skips", &max_skips_tensor));
+    const int max_skips = *(max_skips_tensor->scalar<int>().data());
+
+    OP_REQUIRES(
+        context, min_skips >= 0 && max_skips >= 0,
+        errors::InvalidArgument("Both min_skips and max_skips must be >= 0."));
+    OP_REQUIRES(context, min_skips <= max_skips,
+                errors::InvalidArgument("min_skips must be <= max_skips."));
+
+    const Tensor* start_tensor;
+    OP_REQUIRES_OK(context, context->input("start", &start_tensor));
+    const int start = *(start_tensor->scalar<int>().data());
+    const Tensor* limit_tensor;
+    OP_REQUIRES_OK(context, context->input("limit", &limit_tensor));
+    const int limit = *(limit_tensor->scalar<int>().data());
+    const int end =
+        limit < 0 ? input.size()
+                  : std::min(start + limit, static_cast<int>(input.size()));
+
+    const Tensor* emit_self_tensor;
+    OP_REQUIRES_OK(context,
+                   context->input("emit_self_as_target", &emit_self_tensor));
+    const bool emit_self_as_target = *(emit_self_tensor->scalar<bool>().data());
+
+    std::vector<T> tokens;
+    std::vector<T> labels;
+
+    // Reserve the number of random numbers we will use - we use one for each
+    // token between start and end.
+    random::PhiloxRandom local_gen =
+        generator_.ReserveSamples32(end - start + 1);
+    random::SimplePhilox rng(&local_gen);
+
+    // For each token in the sentence, pick a random skip, then generate
+    // (token, label) pairs for all labels whose distances from the token are
+    // within the range [-skip, skip].
+    for (int i = start; i < end; ++i) {
+      const int skips = min_skips + rng.Uniform(max_skips - min_skips + 1);
+      for (int j = -skips; j <= skips; ++j) {
+        if ((i + j < start) || (i + j >= end) ||
+            (j == 0 && !emit_self_as_target)) {
+          continue;
+        }
+        tokens.push_back(input(i));
+        labels.push_back(input(i + j));
+      }
+    }
+
+    Tensor* tokens_output = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       "tokens", TensorShape({static_cast<int>(tokens.size())}),
+                       &tokens_output));
+    Tensor* labels_output = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       "labels", TensorShape({static_cast<int>(labels.size())}),
+                       &labels_output));
+    OP_REQUIRES(
+        context, tokens_output->IsSameSize(*labels_output),
+        errors::Internal(strings::StrCat(
+            "Mismatch between tokens_output shape of ",
+            tokens_output->shape().DebugString(),
+            " and labels_output shape of ",
+            labels_output->shape().DebugString(),
+            ". This should never happen - contact ami-team@ if it does.")));
+
+    // Copies results to output tensors.
+    for (int i = 0; i < tokens.size(); ++i) {
+      tokens_output->vec<T>()(i) = tokens[i];
+      labels_output->vec<T>()(i) = labels[i];
+    }
+  }
+
+ private:
+  GuardedPhiloxRandom generator_;
+};
+
+#define REGISTER_KERNEL(type)                                 \
+  REGISTER_KERNEL_BUILDER(Name("SkipGramGenerateCandidates")  \
+                              .Device(DEVICE_CPU)             \
+                              .TypeConstraint<type>("T"),     \
+                          SkipGramGenerateCandidatesOp<type>)
+
+REGISTER_KERNEL(string);
+REGISTER_KERNEL(int64);
+REGISTER_KERNEL(int32);
+REGISTER_KERNEL(int16);
+
+#undef REGISTER_KERNEL
+
+}  // namespace tensorflow
diff --git a/tensorflow_addons/text/cc/ops/skip_gram_ops.cc b/tensorflow_addons/text/cc/ops/skip_gram_ops.cc
new file mode 100644
index 0000000000..a7b2275024
--- /dev/null
+++ b/tensorflow_addons/text/cc/ops/skip_gram_ops.cc
@@ -0,0 +1,54 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+REGISTER_OP("SkipGramGenerateCandidates")
+    .Input("input_tensor: T")
+    .Input("min_skips: int32")
+    .Input("max_skips: int32")
+    .Input("start: int32")
+    .Input("limit: int32")
+    .Input("emit_self_as_target: bool")
+    .Output("tokens: T")
+    .Output("labels: T")
+    .Attr("T: type")
+    // The seed attributes are needed by GuardedPhiloxRandom
+    .Attr("seed: int = 0")
+    .Attr("seed2: int = 0")
+    .SetIsStateful()
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle unused;
+      // input_tensor must be of rank-1.
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
+      // All other args must be scalar.
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
+
+      // Due to possible randomness in selecting skips, we only know that the
+      // outputs will be of rank-1, but not their sizes.
+      c->set_output(0, c->Vector(c->UnknownDim()));
+      c->set_output(1, c->Vector(c->UnknownDim()));
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Generates skip-gram token and label paired Tensors from the input tensor.
+See docs for the public-facing skip_gram_sample() Python op for more details.
+)doc"); +} // namespace tensorflow diff --git a/tensorflow_addons/text/python/__init__.py b/tensorflow_addons/text/python/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tensorflow_addons/text/python/ops/__init__.py b/tensorflow_addons/text/python/ops/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tensorflow_addons/text/python/ops/skip_gram_ops.py b/tensorflow_addons/text/python/ops/skip_gram_ops.py new file mode 100644 index 0000000000..8eeb2dfe8e --- /dev/null +++ b/tensorflow_addons/text/python/ops/skip_gram_ops.py @@ -0,0 +1,448 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Skip-gram sampling ops from https://arxiv.org/abs/1301.3781.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import csv + +from tensorflow.python.ops import lookup_ops +from tensorflow.python.framework import load_library +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import gfile +from tensorflow.python.platform import resource_loader +from tensorflow.python.training import input as input_ops + +skip_gram_ops = load_library.load_op_library( + resource_loader.get_path_to_datafile("_skip_gram_ops.so")) + +ops.NotDifferentiable("SkipGramGenerateCandidates") + + +def skip_gram_sample(input_tensor, + min_skips=1, + max_skips=5, + start=0, + limit=-1, + emit_self_as_target=False, + vocab_freq_table=None, + vocab_min_count=None, + vocab_subsampling=None, + corpus_size=None, + batch_size=None, + batch_capacity=None, + seed=None, + name=None): + """Generates skip-gram token and label paired Tensors from the input tensor. + + Generates skip-gram `("token", "label")` pairs using each element in the + rank-1 `input_tensor` as a token. The window size used for each token will be + randomly selected from the range specified by `[min_skips, max_skips]`, + inclusive. See https://arxiv.org/abs/1301.3781 for more details about + skip-gram. + + For example, given `input_tensor = ["the", "quick", "brown", "fox", "jumps"]`, + `min_skips = 1`, `max_skips = 2`, `emit_self_as_target = False`, the output + `(tokens, labels)` pairs for the token "quick" will be randomly selected from + either `(tokens=["quick", "quick"], labels=["the", "brown"])` for 1 skip, or + `(tokens=["quick", "quick", "quick"], labels=["the", "brown", "fox"])` for 2 + skips. + + If `emit_self_as_target = True`, each token will also be emitted as a label + for itself. 
From the previous example, the output will be either + `(tokens=["quick", "quick", "quick"], labels=["the", "quick", "brown"])` for 1 + skip, or `(tokens=["quick", "quick", "quick", "quick"], labels=["the", + "quick", "brown", "fox"])` for 2 skips. + + The same process is repeated for each element of `input_tensor` and + concatenated together into the two output rank-1 `Tensors` (one for all the + tokens, another for all the labels). + + If `vocab_freq_table` is specified, tokens in `input_tensor` that are not + present in the vocabulary are discarded. Tokens whose frequency counts are + below `vocab_min_count` are also discarded. Tokens whose frequency proportions + in the corpus exceed `vocab_subsampling` may be randomly down-sampled. See + Eq. 5 in http://arxiv.org/abs/1310.4546 for more details about subsampling. + + Due to the random window sizes used for each token, the lengths of the outputs + are non-deterministic, unless `batch_size` is specified to batch the outputs + to always return `Tensors` of length `batch_size`. + + Args: + input_tensor: A rank-1 `Tensor` from which to generate skip-gram candidates. + min_skips: `int` or scalar `Tensor` specifying the minimum window size to + randomly use for each token. Must be >= 0 and <= `max_skips`. If + `min_skips` and `max_skips` are both 0, the only label outputted will be + the token itself when `emit_self_as_target = True` - or no output + otherwise. + max_skips: `int` or scalar `Tensor` specifying the maximum window size to + randomly use for each token. Must be >= 0. + start: `int` or scalar `Tensor` specifying the position in + `input_tensor` from which to start generating skip-gram candidates. + limit: `int` or scalar `Tensor` specifying the maximum number of + elements in `input_tensor` to use in generating skip-gram candidates. -1 + means to use the rest of the `Tensor` after `start`. + emit_self_as_target: `bool` or scalar `Tensor` specifying whether to emit + each token as a label for itself. + vocab_freq_table: (Optional) A lookup table (subclass of + `lookup.InitializableLookupTableBase`) that maps tokens to their raw + frequency counts. If specified, any token in `input_tensor` that is not + found in `vocab_freq_table` will be filtered out before generating + skip-gram candidates. While this will typically map to integer raw + frequency counts, it could also map to float frequency proportions. + `vocab_min_count` and `corpus_size` should be in the same units as this. + vocab_min_count: (Optional) `int`, `float`, or scalar `Tensor` specifying + minimum frequency threshold (from `vocab_freq_table`) for a token to be + kept in `input_tensor`. If this is specified, `vocab_freq_table` must also + be specified - and they should both be in the same units. + vocab_subsampling: (Optional) `float` specifying frequency proportion + threshold for tokens from `input_tensor`. Tokens that occur more + frequently (based on the ratio of the token's `vocab_freq_table` value to + the `corpus_size`) will be randomly down-sampled. Reasonable starting + values may be around 1e-3 or 1e-5. If this is specified, both + `vocab_freq_table` and `corpus_size` must also be specified. See Eq. 5 + in http://arxiv.org/abs/1310.4546 for more details. + corpus_size: (Optional) `int`, `float`, or scalar `Tensor` specifying the + total number of tokens in the corpus (e.g., sum of all the frequency + counts of `vocab_freq_table`). Used with `vocab_subsampling` for + down-sampling frequently occurring tokens. 
If this is specified, + `vocab_freq_table` and `vocab_subsampling` must also be specified. + batch_size: (Optional) `int` specifying batch size of returned `Tensors`. + batch_capacity: (Optional) `int` specifying batch capacity for the queue + used for batching returned `Tensors`. Only has an effect if + `batch_size` > 0. Defaults to 100 * `batch_size` if not specified. + seed: (Optional) `int` used to create a random seed for window size and + subsampling. See `set_random_seed` docs for behavior. + name: (Optional) A `string` name or a name scope for the operations. + + Returns: + A `tuple` containing (token, label) `Tensors`. Each output `Tensor` is of + rank-1 and has the same type as `input_tensor`. The `Tensors` will be of + length `batch_size`; if `batch_size` is not specified, they will be of + random length, though they will be in sync with each other as long as they + are evaluated together. + + Raises: + ValueError: If `vocab_freq_table` is not provided, but `vocab_min_count`, + `vocab_subsampling`, or `corpus_size` is specified. If `vocab_subsampling` + and `corpus_size` are not both present or both absent. + """ + + if vocab_freq_table is None and (vocab_min_count is not None or + vocab_subsampling is not None or + corpus_size is not None): + raise ValueError( + "vocab_freq_table is not provided, but vocab_min_count={}, " + "vocab_subsampling={}, or corpus_size={} is not None. These settings " + "are useless without a vocab_freq_table.".format( + vocab_min_count, vocab_subsampling, corpus_size)) + + if (vocab_subsampling is None) != (corpus_size is None): + raise ValueError( + "vocab_subsampling is {} while corpus_size is {} - both must be " + "provided in order for subsampling to work.".format( + vocab_subsampling, corpus_size)) + + with ops.name_scope( + name, + "skip_gram_sample", + values=[input_tensor, min_skips, max_skips, start, limit]): + + input_tensor = _filter_input( + input_tensor=input_tensor, + vocab_freq_table=vocab_freq_table, + vocab_min_count=vocab_min_count, + vocab_subsampling=vocab_subsampling, + corpus_size=corpus_size, + seed=seed) + + seed1, seed2 = random_seed.get_seed(seed) + tokens, labels = skip_gram_ops.skip_gram_generate_candidates( + input_tensor=input_tensor, + min_skips=min_skips, + max_skips=max_skips, + start=start, + limit=limit, + emit_self_as_target=emit_self_as_target, + # Note that seed here should be seed1! This is due to + # GuardedPhiloxRandom's hard-coded attributes of "seed" and "seed2". + seed=seed1, + seed2=seed2) + + # TODO(weiho): If the need arises, add support for sparse input_tensor that + # figures out sentence boundaries, then calls + # skip_gram_generate_candidates() on each sentence. + + # Batches the (tokens, labels) outputs so that they will be of deterministic + # batch_size, to facilitate feeding them into the rest of the network. 
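+    # Note that input_ops.batch here is the queue-based tf.train.batch, so
+    # when batch_size is set the returned tensors only yield values after
+    # queue runners have been started (e.g. via
+    # tf.train.start_queue_runners(), as done in the batching unit test).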
+ if batch_size is not None and batch_size > 0: + batch_capacity = (batch_capacity + if (batch_capacity is not None and batch_capacity > 0) + else 100 * batch_size) + return input_ops.batch( + [tokens, labels], + batch_size, + capacity=batch_capacity, + enqueue_many=True) + + return tokens, labels + + +def skip_gram_sample_with_text_vocab(input_tensor, + vocab_freq_file, + vocab_token_index=0, + vocab_token_dtype=dtypes.string, + vocab_freq_index=1, + vocab_freq_dtype=dtypes.float64, + vocab_delimiter=",", + vocab_min_count=0, + vocab_subsampling=None, + corpus_size=None, + min_skips=1, + max_skips=5, + start=0, + limit=-1, + emit_self_as_target=False, + batch_size=None, + batch_capacity=None, + seed=None, + name=None): + """Skip-gram sampling with a text vocabulary file. + + Wrapper around `skip_gram_sample()` for use with a text vocabulary file. The + vocabulary file is expected to be a plain-text file, with lines of + `vocab_delimiter`-separated columns. The `vocab_token_index` column should + contain the vocabulary term, while the `vocab_freq_index` column should + contain the number of times that term occurs in the corpus. For example, with + a text vocabulary file of: + + ``` + bonjour,fr,42 + hello,en,777 + hola,es,99 + ``` + + You should set `vocab_delimiter=","`, `vocab_token_index=0`, and + `vocab_freq_index=2`. + + See `skip_gram_sample()` documentation for more details about the skip-gram + sampling process. + + Args: + input_tensor: A rank-1 `Tensor` from which to generate skip-gram candidates. + vocab_freq_file: `string` specifying full file path to the text vocab file. + vocab_token_index: `int` specifying which column in the text vocab file + contains the tokens. + vocab_token_dtype: `DType` specifying the format of the tokens in the text + vocab file. + vocab_freq_index: `int` specifying which column in the text vocab file + contains the frequency counts of the tokens. + vocab_freq_dtype: `DType` specifying the format of the frequency counts in + the text vocab file. + vocab_delimiter: `string` specifying the delimiter used in the text vocab + file. + vocab_min_count: `int`, `float`, or scalar `Tensor` specifying + minimum frequency threshold (from `vocab_freq_file`) for a token to be + kept in `input_tensor`. This should correspond with `vocab_freq_dtype`. + vocab_subsampling: (Optional) `float` specifying frequency proportion + threshold for tokens from `input_tensor`. Tokens that occur more + frequently will be randomly down-sampled. Reasonable starting values may + be around 1e-3 or 1e-5. See Eq. 5 in http://arxiv.org/abs/1310.4546 for + more details. + corpus_size: (Optional) `int`, `float`, or scalar `Tensor` specifying the + total number of tokens in the corpus (e.g., sum of all the frequency + counts of `vocab_freq_file`). Used with `vocab_subsampling` for + down-sampling frequently occurring tokens. If this is specified, + `vocab_freq_file` and `vocab_subsampling` must also be specified. + If `corpus_size` is needed but not supplied, then it will be calculated + from `vocab_freq_file`. You might want to supply your own value if you + have already eliminated infrequent tokens from your vocabulary files + (where frequency < vocab_min_count) to save memory in the internal token + lookup table. Otherwise, the unused tokens' variables will waste memory. + The user-supplied `corpus_size` value must be greater than or equal to the + sum of all the frequency counts of `vocab_freq_file`. 
+ min_skips: `int` or scalar `Tensor` specifying the minimum window size to + randomly use for each token. Must be >= 0 and <= `max_skips`. If + `min_skips` and `max_skips` are both 0, the only label outputted will be + the token itself. + max_skips: `int` or scalar `Tensor` specifying the maximum window size to + randomly use for each token. Must be >= 0. + start: `int` or scalar `Tensor` specifying the position in `input_tensor` + from which to start generating skip-gram candidates. + limit: `int` or scalar `Tensor` specifying the maximum number of elements in + `input_tensor` to use in generating skip-gram candidates. -1 means to use + the rest of the `Tensor` after `start`. + emit_self_as_target: `bool` or scalar `Tensor` specifying whether to emit + each token as a label for itself. + batch_size: (Optional) `int` specifying batch size of returned `Tensors`. + batch_capacity: (Optional) `int` specifying batch capacity for the queue + used for batching returned `Tensors`. Only has an effect if + `batch_size` > 0. Defaults to 100 * `batch_size` if not specified. + seed: (Optional) `int` used to create a random seed for window size and + subsampling. See + [`set_random_seed`](../../g3doc/python/constant_op.md#set_random_seed) + for behavior. + name: (Optional) A `string` name or a name scope for the operations. + + Returns: + A `tuple` containing (token, label) `Tensors`. Each output `Tensor` is of + rank-1 and has the same type as `input_tensor`. The `Tensors` will be of + length `batch_size`; if `batch_size` is not specified, they will be of + random length, though they will be in sync with each other as long as they + are evaluated together. + + Raises: + ValueError: If `vocab_token_index` or `vocab_freq_index` is less than 0 or + exceeds the number of columns in `vocab_freq_file`. If `vocab_token_index` + and `vocab_freq_index` are both set to the same column. If any token in + `vocab_freq_file` has a negative frequency. + """ + + if vocab_token_index < 0 or vocab_freq_index < 0: + raise ValueError( + "vocab_token_index={} and vocab_freq_index={} must both be >= 0.". + format(vocab_token_index, vocab_freq_index)) + if vocab_token_index == vocab_freq_index: + raise ValueError( + "vocab_token_index and vocab_freq_index should be different, but are " + "both {}.".format(vocab_token_index)) + + # Iterates through the vocab file and calculates the number of vocab terms as + # well as the total corpus size (by summing the frequency counts of all the + # vocab terms). + calculated_corpus_size = 0.0 + vocab_size = 0 + with gfile.GFile(vocab_freq_file, mode="r") as f: + reader = csv.reader(f, delimiter=vocab_delimiter) + for row in reader: + if vocab_token_index >= len(row) or vocab_freq_index >= len(row): + raise ValueError( + "Row in vocab file only has {} columns, so vocab_token_index={} or " + "vocab_freq_index={} is out of bounds. Row content: {}".format( + len(row), vocab_token_index, vocab_freq_index, row)) + vocab_size += 1 + freq = vocab_freq_dtype.as_numpy_dtype(row[vocab_freq_index]) + if freq < 0: + raise ValueError( + "Row in vocab file has negative frequency of {}. Row content: {}". + format(freq, row)) + # Note: tokens whose frequencies are below vocab_min_count will still + # contribute to the total corpus size used for vocab subsampling. 
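+      # For example, with the three-row vocab file shown in the docstring
+      # above (frequency column values 42, 777, and 99),
+      # calculated_corpus_size works out to 918, regardless of any
+      # vocab_min_count filtering.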
+ calculated_corpus_size += freq + + if not corpus_size: + corpus_size = calculated_corpus_size + elif calculated_corpus_size - corpus_size > 1e-6: + raise ValueError( + "`corpus_size`={} must be greater than or equal to the sum of all the " + "frequency counts ({}) of `vocab_freq_file` ({}).".format( + corpus_size, calculated_corpus_size, vocab_freq_file)) + + vocab_freq_table = lookup_ops.HashTable( + lookup_ops.TextFileInitializer( + filename=vocab_freq_file, + key_dtype=vocab_token_dtype, + key_index=vocab_token_index, + value_dtype=vocab_freq_dtype, + value_index=vocab_freq_index, + vocab_size=vocab_size, + delimiter=vocab_delimiter), + # For vocab terms not in vocab file, use a default value of -1. + default_value=-1) + + return skip_gram_sample( + input_tensor, + min_skips=min_skips, + max_skips=max_skips, + start=start, + limit=limit, + emit_self_as_target=emit_self_as_target, + vocab_freq_table=vocab_freq_table, + vocab_min_count=vocab_min_count, + vocab_subsampling=vocab_subsampling, + # corpus_size is not used unless vocab_subsampling is specified. + corpus_size=None if vocab_subsampling is None else corpus_size, + batch_size=batch_size, + batch_capacity=batch_capacity, + seed=seed, + name=name) + + +def _filter_input(input_tensor, vocab_freq_table, vocab_min_count, + vocab_subsampling, corpus_size, seed): + """Filters input tensor based on vocab freq, threshold, and subsampling.""" + if vocab_freq_table is None: + return input_tensor + + if not isinstance(vocab_freq_table, lookup_ops.InitializableLookupTableBase): + raise ValueError( + "vocab_freq_table must be a subclass of " + "InitializableLookupTableBase (such as HashTable) instead of type " + "{}.".format(type(vocab_freq_table))) + + with ops.name_scope( + "filter_vocab", values=[vocab_freq_table, input_tensor, vocab_min_count]): + freq = vocab_freq_table.lookup(input_tensor) + # Filters out elements in input_tensor that are not found in + # vocab_freq_table (table returns a default value of -1 specified above when + # an element is not found). + mask = math_ops.not_equal(freq, vocab_freq_table.default_value) + + # Filters out elements whose vocab frequencies are less than the threshold. + if vocab_min_count is not None: + cast_threshold = math_ops.cast(vocab_min_count, freq.dtype) + mask = math_ops.logical_and(mask, + math_ops.greater_equal(freq, cast_threshold)) + + input_tensor = array_ops.boolean_mask(input_tensor, mask) + freq = array_ops.boolean_mask(freq, mask) + + if not vocab_subsampling: + return input_tensor + + if vocab_subsampling < 0 or vocab_subsampling > 1: + raise ValueError( + "Invalid vocab_subsampling={} - it should be within range [0, 1].". + format(vocab_subsampling)) + + # Subsamples the input tokens based on vocabulary frequency and + # vocab_subsampling threshold (ie randomly discard commonly appearing + # tokens). + with ops.name_scope( + "subsample_vocab", values=[input_tensor, freq, vocab_subsampling]): + corpus_size = math_ops.cast(corpus_size, dtypes.float64) + freq = math_ops.cast(freq, dtypes.float64) + vocab_subsampling = math_ops.cast(vocab_subsampling, dtypes.float64) + + # From tensorflow_models/tutorials/embedding/word2vec_kernels.cc, which is + # suppose to correlate with Eq. 5 in http://arxiv.org/abs/1310.4546. 
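+    # Written out, the expression below computes
+    #   keep_prob = (sqrt(freq / (vocab_subsampling * corpus_size)) + 1.0)
+    #               * (vocab_subsampling * corpus_size / freq)
+    # which shrinks as a token's share of the corpus grows, so the most
+    # frequent tokens are the ones most likely to be discarded.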
+ keep_prob = ((math_ops.sqrt(freq / + (vocab_subsampling * corpus_size)) + 1.0) * + (vocab_subsampling * corpus_size / freq)) + random_prob = random_ops.random_uniform( + array_ops.shape(freq), + minval=0, + maxval=1, + dtype=dtypes.float64, + seed=seed) + + mask = math_ops.less_equal(random_prob, keep_prob) + return array_ops.boolean_mask(input_tensor, mask) diff --git a/tensorflow_addons/text/python/ops/skip_gram_ops_test.py b/tensorflow_addons/text/python/ops/skip_gram_ops_test.py new file mode 100644 index 0000000000..c2256dac16 --- /dev/null +++ b/tensorflow_addons/text/python/ops/skip_gram_ops_test.py @@ -0,0 +1,606 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Skip-gram sampling ops tests.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import csv +import os + +from tensorflow_addons.text.python.ops import skip_gram_ops +from tensorflow_addons import text + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import random_seed +from tensorflow.python.ops import lookup_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test +from tensorflow.python.training import coordinator +from tensorflow.python.training import queue_runner_impl + + +class SkipGramOpsTest(test.TestCase): + + def _split_tokens_labels(self, output): + tokens = [x[0] for x in output] + labels = [x[1] for x in output] + return tokens, labels + + def test_skip_gram_sample_skips_2(self): + """Tests skip-gram with min_skips = max_skips = 2.""" + input_tensor = constant_op.constant( + [b"the", b"quick", b"brown", b"fox", b"jumps"]) + tokens, labels = text.skip_gram_sample( + input_tensor, min_skips=2, max_skips=2) + expected_tokens, expected_labels = self._split_tokens_labels([ + (b"the", b"quick"), + (b"the", b"brown"), + (b"quick", b"the"), + (b"quick", b"brown"), + (b"quick", b"fox"), + (b"brown", b"the"), + (b"brown", b"quick"), + (b"brown", b"fox"), + (b"brown", b"jumps"), + (b"fox", b"quick"), + (b"fox", b"brown"), + (b"fox", b"jumps"), + (b"jumps", b"brown"), + (b"jumps", b"fox"), + ]) + with self.cached_session(): + self.assertAllEqual(expected_tokens, tokens.eval()) + self.assertAllEqual(expected_labels, labels.eval()) + + def test_skip_gram_sample_emit_self(self): + """Tests skip-gram with emit_self_as_target = True.""" + input_tensor = constant_op.constant( + [b"the", b"quick", b"brown", b"fox", b"jumps"]) + tokens, labels = text.skip_gram_sample( + input_tensor, min_skips=2, max_skips=2, emit_self_as_target=True) + expected_tokens, expected_labels = self._split_tokens_labels([ + (b"the", b"the"), + (b"the", b"quick"), + (b"the", b"brown"), + (b"quick", b"the"), + (b"quick", b"quick"), + (b"quick", b"brown"), + (b"quick", 
b"fox"), + (b"brown", b"the"), + (b"brown", b"quick"), + (b"brown", b"brown"), + (b"brown", b"fox"), + (b"brown", b"jumps"), + (b"fox", b"quick"), + (b"fox", b"brown"), + (b"fox", b"fox"), + (b"fox", b"jumps"), + (b"jumps", b"brown"), + (b"jumps", b"fox"), + (b"jumps", b"jumps"), + ]) + with self.cached_session(): + self.assertAllEqual(expected_tokens, tokens.eval()) + self.assertAllEqual(expected_labels, labels.eval()) + + def test_skip_gram_sample_skips_0(self): + """Tests skip-gram with min_skips = max_skips = 0.""" + input_tensor = constant_op.constant([b"the", b"quick", b"brown"]) + + # If emit_self_as_target is False (default), output will be empty. + tokens, labels = text.skip_gram_sample( + input_tensor, min_skips=0, max_skips=0, emit_self_as_target=False) + with self.cached_session(): + self.assertEqual(0, tokens.eval().size) + self.assertEqual(0, labels.eval().size) + + # If emit_self_as_target is True, each token will be its own label. + tokens, labels = text.skip_gram_sample( + input_tensor, min_skips=0, max_skips=0, emit_self_as_target=True) + expected_tokens, expected_labels = self._split_tokens_labels([ + (b"the", b"the"), + (b"quick", b"quick"), + (b"brown", b"brown"), + ]) + with self.cached_session(): + self.assertAllEqual(expected_tokens, tokens.eval()) + self.assertAllEqual(expected_labels, labels.eval()) + + def test_skip_gram_sample_skips_exceed_length(self): + """Tests skip-gram when min/max_skips exceed length of input.""" + input_tensor = constant_op.constant([b"the", b"quick", b"brown"]) + tokens, labels = text.skip_gram_sample( + input_tensor, min_skips=100, max_skips=100) + expected_tokens, expected_labels = self._split_tokens_labels([ + (b"the", b"quick"), + (b"the", b"brown"), + (b"quick", b"the"), + (b"quick", b"brown"), + (b"brown", b"the"), + (b"brown", b"quick"), + ]) + with self.cached_session(): + self.assertAllEqual(expected_tokens, tokens.eval()) + self.assertAllEqual(expected_labels, labels.eval()) + + def test_skip_gram_sample_start_limit(self): + """Tests skip-gram over a limited portion of the input.""" + input_tensor = constant_op.constant( + [b"foo", b"the", b"quick", b"brown", b"bar"]) + tokens, labels = text.skip_gram_sample( + input_tensor, min_skips=1, max_skips=1, start=1, limit=3) + expected_tokens, expected_labels = self._split_tokens_labels([ + (b"the", b"quick"), + (b"quick", b"the"), + (b"quick", b"brown"), + (b"brown", b"quick"), + ]) + with self.cached_session(): + self.assertAllEqual(expected_tokens, tokens.eval()) + self.assertAllEqual(expected_labels, labels.eval()) + + def test_skip_gram_sample_limit_exceeds(self): + """Tests skip-gram when limit exceeds the length of the input.""" + input_tensor = constant_op.constant([b"foo", b"the", b"quick", b"brown"]) + tokens, labels = text.skip_gram_sample( + input_tensor, min_skips=1, max_skips=1, start=1, limit=100) + expected_tokens, expected_labels = self._split_tokens_labels([ + (b"the", b"quick"), + (b"quick", b"the"), + (b"quick", b"brown"), + (b"brown", b"quick"), + ]) + with self.cached_session(): + self.assertAllEqual(expected_tokens, tokens.eval()) + self.assertAllEqual(expected_labels, labels.eval()) + + def test_skip_gram_sample_random_skips(self): + """Tests skip-gram with min_skips != max_skips, with random output.""" + # The number of outputs is non-deterministic in this case, so set random + # seed to help ensure the outputs remain constant for this test case. 
+ random_seed.set_random_seed(42) + + input_tensor = constant_op.constant( + [b"the", b"quick", b"brown", b"fox", b"jumps", b"over"]) + tokens, labels = text.skip_gram_sample( + input_tensor, min_skips=1, max_skips=2, seed=9) + expected_tokens, expected_labels = self._split_tokens_labels([ + (b"the", b"quick"), + (b"the", b"brown"), + (b"quick", b"the"), + (b"quick", b"brown"), + (b"quick", b"fox"), + (b"brown", b"the"), + (b"brown", b"quick"), + (b"brown", b"fox"), + (b"brown", b"jumps"), + (b"fox", b"brown"), + (b"fox", b"jumps"), + (b"jumps", b"fox"), + (b"jumps", b"over"), + (b"over", b"fox"), + (b"over", b"jumps"), + ]) + with self.cached_session() as sess: + tokens_eval, labels_eval = sess.run([tokens, labels]) + self.assertAllEqual(expected_tokens, tokens_eval) + self.assertAllEqual(expected_labels, labels_eval) + + def test_skip_gram_sample_random_skips_default_seed(self): + """Tests outputs are still random when no op-level seed is specified.""" + # This is needed since tests set a graph-level seed by default. We want to + # explicitly avoid setting both graph-level seed and op-level seed, to + # simulate behavior under non-test settings when the user doesn't provide a + # seed to us. This results in random_seed.get_seed() returning None for both + # seeds, forcing the C++ kernel to execute its default seed logic. + random_seed.set_random_seed(None) + + # Uses an input tensor with 10 words, with possible skip ranges in [1, + # 5]. Thus, the probability that two random samplings would result in the + # same outputs is 1/5^10 ~ 1e-7 (aka the probability of this test being + # flaky). + input_tensor = constant_op.constant([str(x) for x in range(10)]) + + # Do not provide an op-level seed here! + tokens_1, labels_1 = text.skip_gram_sample( + input_tensor, min_skips=1, max_skips=5) + tokens_2, labels_2 = text.skip_gram_sample( + input_tensor, min_skips=1, max_skips=5) + + with self.cached_session() as sess: + tokens_1_eval, labels_1_eval, tokens_2_eval, labels_2_eval = sess.run( + [tokens_1, labels_1, tokens_2, labels_2]) + + if len(tokens_1_eval) == len(tokens_2_eval): + self.assertNotEqual(tokens_1_eval.tolist(), tokens_2_eval.tolist()) + if len(labels_1_eval) == len(labels_2_eval): + self.assertNotEqual(labels_1_eval.tolist(), labels_2_eval.tolist()) + + def test_skip_gram_sample_batch(self): + """Tests skip-gram with batching.""" + input_tensor = constant_op.constant([b"the", b"quick", b"brown", b"fox"]) + tokens, labels = text.skip_gram_sample( + input_tensor, min_skips=1, max_skips=1, batch_size=3) + expected_tokens, expected_labels = self._split_tokens_labels([ + (b"the", b"quick"), + (b"quick", b"the"), + (b"quick", b"brown"), + (b"brown", b"quick"), + (b"brown", b"fox"), + (b"fox", b"brown"), + ]) + with self.cached_session() as sess: + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord) + + tokens_eval, labels_eval = sess.run([tokens, labels]) + self.assertAllEqual(expected_tokens[:3], tokens_eval) + self.assertAllEqual(expected_labels[:3], labels_eval) + tokens_eval, labels_eval = sess.run([tokens, labels]) + self.assertAllEqual(expected_tokens[3:6], tokens_eval) + self.assertAllEqual(expected_labels[3:6], labels_eval) + + coord.request_stop() + coord.join(threads) + + def test_skip_gram_sample_non_string_input(self): + """Tests skip-gram with non-string input.""" + input_tensor = constant_op.constant([1, 2, 3], dtype=dtypes.int16) + tokens, labels = text.skip_gram_sample( + input_tensor, min_skips=1, max_skips=1) + 
expected_tokens, expected_labels = self._split_tokens_labels([ + (1, 2), + (2, 1), + (2, 3), + (3, 2), + ]) + with self.cached_session(): + self.assertAllEqual(expected_tokens, tokens.eval()) + self.assertAllEqual(expected_labels, labels.eval()) + + def test_skip_gram_sample_errors(self): + """Tests various errors raised by skip_gram_sample().""" + input_tensor = constant_op.constant([b"the", b"quick", b"brown"]) + + invalid_skips = ( + # min_skips and max_skips must be >= 0. + (-1, 2), + (1, -2), + # min_skips must be <= max_skips. + (2, 1)) + for min_skips, max_skips in invalid_skips: + tokens, labels = text.skip_gram_sample( + input_tensor, min_skips=min_skips, max_skips=max_skips) + with self.cached_session() as sess, self.assertRaises( + errors.InvalidArgumentError): + sess.run([tokens, labels]) + + # input_tensor must be of rank 1. + with self.assertRaises(ValueError): + invalid_tensor = constant_op.constant([[b"the"], [b"quick"], [b"brown"]]) + text.skip_gram_sample(invalid_tensor) + + # vocab_freq_table must be provided if vocab_min_count, vocab_subsampling, + # or corpus_size is specified. + dummy_input = constant_op.constant([""]) + with self.assertRaises(ValueError): + text.skip_gram_sample( + dummy_input, vocab_freq_table=None, vocab_min_count=1) + with self.assertRaises(ValueError): + text.skip_gram_sample( + dummy_input, vocab_freq_table=None, vocab_subsampling=1e-5) + with self.assertRaises(ValueError): + text.skip_gram_sample(dummy_input, vocab_freq_table=None, corpus_size=100) + with self.assertRaises(ValueError): + text.skip_gram_sample( + dummy_input, + vocab_freq_table=None, + vocab_subsampling=1e-5, + corpus_size=100) + + # vocab_subsampling and corpus_size must both be present or absent. + dummy_table = lookup_ops.HashTable( + lookup_ops.KeyValueTensorInitializer([b"foo"], [10]), -1) + with self.assertRaises(ValueError): + text.skip_gram_sample( + dummy_input, + vocab_freq_table=dummy_table, + vocab_subsampling=None, + corpus_size=100) + with self.assertRaises(ValueError): + text.skip_gram_sample( + dummy_input, + vocab_freq_table=dummy_table, + vocab_subsampling=1e-5, + corpus_size=None) + + def test_filter_input_filter_vocab(self): + """Tests input filtering based on vocab frequency table and thresholds.""" + input_tensor = constant_op.constant( + [b"the", b"answer", b"to", b"life", b"and", b"universe"]) + keys = constant_op.constant([b"and", b"life", b"the", b"to", b"universe"]) + values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64) + vocab_freq_table = lookup_ops.HashTable( + lookup_ops.KeyValueTensorInitializer(keys, values), -1) + + with self.cached_session(): + vocab_freq_table.initializer.run() + + # No vocab_freq_table specified - output should be the same as input. + no_table_output = skip_gram_ops._filter_input( + input_tensor=input_tensor, + vocab_freq_table=None, + vocab_min_count=None, + vocab_subsampling=None, + corpus_size=None, + seed=None) + self.assertAllEqual(input_tensor.eval(), no_table_output.eval()) + + # vocab_freq_table specified, but no vocab_min_count - output should have + # filtered out tokens not in the table (b"answer"). 
+ table_output = skip_gram_ops._filter_input( + input_tensor=input_tensor, + vocab_freq_table=vocab_freq_table, + vocab_min_count=None, + vocab_subsampling=None, + corpus_size=None, + seed=None) + self.assertAllEqual([b"the", b"to", b"life", b"and", b"universe"], + table_output.eval()) + + # vocab_freq_table and vocab_min_count specified - output should have + # filtered out tokens whose frequencies are below the threshold + # (b"and": 0, b"life": 1). + threshold_output = skip_gram_ops._filter_input( + input_tensor=input_tensor, + vocab_freq_table=vocab_freq_table, + vocab_min_count=2, + vocab_subsampling=None, + corpus_size=None, + seed=None) + self.assertAllEqual([b"the", b"to", b"universe"], threshold_output.eval()) + + def test_filter_input_subsample_vocab(self): + """Tests input filtering based on vocab subsampling.""" + # The outputs are non-deterministic, so set random seed to help ensure that + # the outputs remain constant for testing. + random_seed.set_random_seed(42) + + input_tensor = constant_op.constant([ + # keep_prob = (sqrt(30/(0.05*100)) + 1) * (0.05*100/30) = 0.57. + b"the", + b"answer", # Not in vocab. (Always discarded) + b"to", # keep_prob = 0.75. + b"life", # keep_prob > 1. (Always kept) + b"and", # keep_prob = 0.48. + b"universe" # Below vocab threshold of 3. (Always discarded) + ]) + keys = constant_op.constant([b"and", b"life", b"the", b"to", b"universe"]) + values = constant_op.constant([40, 8, 30, 20, 2], dtypes.int64) + vocab_freq_table = lookup_ops.HashTable( + lookup_ops.KeyValueTensorInitializer(keys, values), -1) + + with self.cached_session(): + vocab_freq_table.initializer.run() + output = skip_gram_ops._filter_input( + input_tensor=input_tensor, + vocab_freq_table=vocab_freq_table, + vocab_min_count=3, + vocab_subsampling=0.05, + corpus_size=math_ops.reduce_sum(values), + seed=9) + self.assertAllEqual([b"the", b"to", b"life", b"and"], output.eval()) + + def _make_text_vocab_freq_file(self): + filepath = os.path.join(test.get_temp_dir(), "vocab_freq.txt") + with open(filepath, "w") as f: + writer = csv.writer(f) + writer.writerows([ + ["and", 40], + ["life", 8], + ["the", 30], + ["to", 20], + ["universe", 2], + ]) + return filepath + + def _make_text_vocab_float_file(self): + filepath = os.path.join(test.get_temp_dir(), "vocab_freq_float.txt") + with open(filepath, "w") as f: + writer = csv.writer(f) + writer.writerows([ + ["and", 0.4], + ["life", 0.08], + ["the", 0.3], + ["to", 0.2], + ["universe", 0.02], + ]) + return filepath + + def test_skip_gram_sample_with_text_vocab_filter_vocab(self): + """Tests skip-gram sampling with text vocab and freq threshold filtering.""" + input_tensor = constant_op.constant([ + b"the", + b"answer", # Will be filtered before candidate generation. + b"to", + b"life", + b"and", + b"universe" # Will be filtered before candidate generation. + ]) + + # b"answer" is not in vocab file, and b"universe"'s frequency is below + # threshold of 3. 
+ vocab_freq_file = self._make_text_vocab_freq_file() + + tokens, labels = text.skip_gram_sample_with_text_vocab( + input_tensor=input_tensor, + vocab_freq_file=vocab_freq_file, + vocab_token_index=0, + vocab_freq_index=1, + vocab_min_count=3, + min_skips=1, + max_skips=1) + + expected_tokens, expected_labels = self._split_tokens_labels([ + (b"the", b"to"), + (b"to", b"the"), + (b"to", b"life"), + (b"life", b"to"), + (b"life", b"and"), + (b"and", b"life"), + ]) + with self.cached_session(): + lookup_ops.tables_initializer().run() + self.assertAllEqual(expected_tokens, tokens.eval()) + self.assertAllEqual(expected_labels, labels.eval()) + + def _text_vocab_subsample_vocab_helper(self, vocab_freq_file, vocab_min_count, + vocab_freq_dtype, corpus_size=None): + # The outputs are non-deterministic, so set random seed to help ensure that + # the outputs remain constant for testing. + random_seed.set_random_seed(42) + + input_tensor = constant_op.constant([ + # keep_prob = (sqrt(30/(0.05*100)) + 1) * (0.05*100/30) = 0.57. + b"the", + b"answer", # Not in vocab. (Always discarded) + b"to", # keep_prob = 0.75. + b"life", # keep_prob > 1. (Always kept) + b"and", # keep_prob = 0.48. + b"universe" # Below vocab threshold of 3. (Always discarded) + ]) + # keep_prob calculated from vocab file with relative frequencies of: + # and: 40 + # life: 8 + # the: 30 + # to: 20 + # universe: 2 + + tokens, labels = text.skip_gram_sample_with_text_vocab( + input_tensor=input_tensor, + vocab_freq_file=vocab_freq_file, + vocab_token_index=0, + vocab_freq_index=1, + vocab_freq_dtype=vocab_freq_dtype, + vocab_min_count=vocab_min_count, + vocab_subsampling=0.05, + corpus_size=corpus_size, + min_skips=1, + max_skips=1, + seed=123) + + expected_tokens, expected_labels = self._split_tokens_labels([ + (b"the", b"to"), + (b"to", b"the"), + (b"to", b"life"), + (b"life", b"to"), + ]) + with self.cached_session() as sess: + lookup_ops.tables_initializer().run() + tokens_eval, labels_eval = sess.run([tokens, labels]) + self.assertAllEqual(expected_tokens, tokens_eval) + self.assertAllEqual(expected_labels, labels_eval) + + def test_skip_gram_sample_with_text_vocab_subsample_vocab(self): + """Tests skip-gram sampling with text vocab and vocab subsampling.""" + # Vocab file frequencies + # and: 40 + # life: 8 + # the: 30 + # to: 20 + # universe: 2 + # + # corpus_size for the above vocab is 40+8+30+20+2 = 100. + text_vocab_freq_file = self._make_text_vocab_freq_file() + self._text_vocab_subsample_vocab_helper( + vocab_freq_file=text_vocab_freq_file, + vocab_min_count=3, + vocab_freq_dtype=dtypes.int64) + self._text_vocab_subsample_vocab_helper( + vocab_freq_file=text_vocab_freq_file, + vocab_min_count=3, + vocab_freq_dtype=dtypes.int64, + corpus_size=100) + + # The user-supplied corpus_size should not be less than the sum of all + # the frequency counts of vocab_freq_file, which is 100. + with self.assertRaises(ValueError): + self._text_vocab_subsample_vocab_helper( + vocab_freq_file=text_vocab_freq_file, + vocab_min_count=3, + vocab_freq_dtype=dtypes.int64, + corpus_size=99) + + def test_skip_gram_sample_with_text_vocab_subsample_vocab_float(self): + """Tests skip-gram sampling with text vocab and subsampling with floats.""" + # Vocab file frequencies + # and: 0.4 + # life: 0.08 + # the: 0.3 + # to: 0.2 + # universe: 0.02 + # + # corpus_size for the above vocab is 0.4+0.08+0.3+0.2+0.02 = 1. 
+ text_vocab_float_file = self._make_text_vocab_float_file() + self._text_vocab_subsample_vocab_helper( + vocab_freq_file=text_vocab_float_file, + vocab_min_count=0.03, + vocab_freq_dtype=dtypes.float32) + self._text_vocab_subsample_vocab_helper( + vocab_freq_file=text_vocab_float_file, + vocab_min_count=0.03, + vocab_freq_dtype=dtypes.float32, + corpus_size=1.0) + + # The user-supplied corpus_size should not be less than the sum of all + # the frequency counts of vocab_freq_file, which is 1. + with self.assertRaises(ValueError): + self._text_vocab_subsample_vocab_helper( + vocab_freq_file=text_vocab_float_file, + vocab_min_count=0.03, + vocab_freq_dtype=dtypes.float32, + corpus_size=0.99) + + def test_skip_gram_sample_with_text_vocab_errors(self): + """Tests various errors raised by skip_gram_sample_with_text_vocab().""" + dummy_input = constant_op.constant([""]) + vocab_freq_file = self._make_text_vocab_freq_file() + + invalid_indices = ( + # vocab_token_index can't be negative. + (-1, 0), + # vocab_freq_index can't be negative. + (0, -1), + # vocab_token_index can't be equal to vocab_freq_index. + (0, 0), + (1, 1), + # vocab_freq_file only has two columns. + (0, 2), + (2, 0)) + + for vocab_token_index, vocab_freq_index in invalid_indices: + with self.assertRaises(ValueError): + text.skip_gram_sample_with_text_vocab( + input_tensor=dummy_input, + vocab_freq_file=vocab_freq_file, + vocab_token_index=vocab_token_index, + vocab_freq_index=vocab_freq_index) + + +if __name__ == "__main__": + test.main() diff --git a/tf/BUILD b/tf/BUILD new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tf/BUILD.tpl b/tf/BUILD.tpl new file mode 100644 index 0000000000..bee021f100 --- /dev/null +++ b/tf/BUILD.tpl @@ -0,0 +1,18 @@ +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "tf_header_lib", + hdrs = [":tf_header_include"], + includes = ["include"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "libtensorflow_framework", + srcs = [":libtensorflow_framework.so"], + #data = ["lib/libtensorflow_framework.so"], + visibility = ["//visibility:public"], +) + +%{TF_HEADER_GENRULE} +%{TF_SHARED_LIBRARY_GENRULE} \ No newline at end of file diff --git a/tf/tf_configure.bzl b/tf/tf_configure.bzl new file mode 100644 index 0000000000..7ddd739b81 --- /dev/null +++ b/tf/tf_configure.bzl @@ -0,0 +1,206 @@ +"""Setup TensorFlow as external dependency""" + +_TF_HEADER_DIR = "TF_HEADER_DIR" +_TF_SHARED_LIBRARY_DIR = "TF_SHARED_LIBRARY_DIR" + +def _tpl(repository_ctx, tpl, substitutions = {}, out = None): + if not out: + out = tpl + repository_ctx.template( + out, + Label("//tf:%s.tpl" % tpl), + substitutions, + ) + +def _fail(msg): + """Output failure message when auto configuration fails.""" + red = "\033[0;31m" + no_color = "\033[0m" + fail("%sPython Configuration Error:%s %s\n" % (red, no_color, msg)) + +def _is_windows(repository_ctx): + """Returns true if the host operating system is windows.""" + os_name = repository_ctx.os.name.lower() + if os_name.find("windows") != -1: + return True + return False + +def _execute( + repository_ctx, + cmdline, + error_msg = None, + error_details = None, + empty_stdout_fine = False): + """Executes an arbitrary shell command. + + Helper for executes an arbitrary shell command. + + Args: + repository_ctx: the repository_ctx object. + cmdline: list of strings, the command to execute. + error_msg: string, a summary of the error if the command fails. + error_details: string, details about the error or steps to fix it. 
+ empty_stdout_fine: bool, if True, an empty stdout result is fine, otherwise + it's an error. + + Returns: + The result of repository_ctx.execute(cmdline). + """ + result = repository_ctx.execute(cmdline) + if result.stderr or not (empty_stdout_fine or result.stdout): + _fail("\n".join([ + error_msg.strip() if error_msg else "Repository command failed", + result.stderr.strip(), + error_details if error_details else "", + ])) + return result + +def _read_dir(repository_ctx, src_dir): + """Returns a string with all files in a directory. + + Finds all files inside a directory, traversing subfolders and following + symlinks. The returned string contains the full path of all files + separated by line breaks. + + Args: + repository_ctx: the repository_ctx object. + src_dir: directory to find files from. + + Returns: + A string of all files inside the given dir. + """ + if _is_windows(repository_ctx): + src_dir = src_dir.replace("/", "\\") + find_result = _execute( + repository_ctx, + ["cmd.exe", "/c", "dir", src_dir, "/b", "/s", "/a-d"], + empty_stdout_fine = True, + ) + + # src_files will be used in genrule.outs where the paths must + # use forward slashes. + result = find_result.stdout.replace("\\", "/") + else: + find_result = _execute( + repository_ctx, + ["find", src_dir, "-follow", "-type", "f"], + empty_stdout_fine = True, + ) + result = find_result.stdout + return result + +def _genrule(genrule_name, command, outs): + """Returns a string with a genrule. + + Genrule executes the given command and produces the given outputs. + + Args: + genrule_name: A unique name for genrule target. + command: The command to run. + outs: A list of files generated by this rule. + + Returns: + A genrule target. + """ + return ( + "genrule(\n" + + ' name = "' + + genrule_name + '",\n' + + " outs = [\n" + + outs + + "\n ],\n" + + ' cmd = """\n' + + command + + '\n """,\n' + + ")\n" + ) + +def _norm_path(path): + """Returns a path with '/' and remove the trailing slash.""" + path = path.replace("\\", "/") + if path[-1] == "/": + path = path[:-1] + return path + +def _symlink_genrule_for_dir( + repository_ctx, + src_dir, + dest_dir, + genrule_name, + src_files = [], + dest_files = []): + """Returns a genrule to symlink(or copy if on Windows) a set of files. + + If src_dir is passed, files will be read from the given directory; otherwise + we assume files are in src_files and dest_files. + + Args: + repository_ctx: the repository_ctx object. + src_dir: source directory. + dest_dir: directory to create symlink in. + genrule_name: genrule name. + src_files: list of source files instead of src_dir. + dest_files: list of corresonding destination files. + + Returns: + genrule target that creates the symlinks. + """ + if src_dir != None: + src_dir = _norm_path(src_dir) + dest_dir = _norm_path(dest_dir) + files = "\n".join(sorted(_read_dir(repository_ctx, src_dir).splitlines())) + + # Create a list with the src_dir stripped to use for outputs. + dest_files = files.replace(src_dir, "").splitlines() + src_files = files.splitlines() + command = [] + outs = [] + for i in range(len(dest_files)): + if dest_files[i] != "": + # If we have only one file to link we do not want to use the dest_dir, as + # $(@D) will include the full path to the file. + dest = "$(@D)/" + dest_dir + dest_files[i] if len(dest_files) != 1 else "$(@D)/" + dest_files[i] + + # Copy the headers to create a sandboxable setup. 
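+            # (Assumption: files are copied rather than symlinked so that the
+            # generated headers stay visible inside Bazel's sandbox.)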
+ cmd = "cp -f" + command.append(cmd + ' "%s" "%s"' % (src_files[i], dest)) + outs.append(' "' + dest_dir + dest_files[i] + '",') + genrule = _genrule( + genrule_name, + " && ".join(command), + "\n".join(outs), + ) + return genrule + +def _tf_pip_impl(repository_ctx): + tf_header_dir = repository_ctx.os.environ[_TF_HEADER_DIR] + tf_header_rule = _symlink_genrule_for_dir( + repository_ctx, + tf_header_dir, + "include", + "tf_header_include", + ) + + tf_shared_library_dir = repository_ctx.os.environ[_TF_SHARED_LIBRARY_DIR] + tf_shared_library_path = "%s/libtensorflow_framework.so" % tf_shared_library_dir + tf_shared_library_rule = _symlink_genrule_for_dir( + repository_ctx, + None, + "", + "libtensorflow_framework.so", + [tf_shared_library_path], + ["libtensorflow_framework.so"], + ) + + _tpl(repository_ctx, "BUILD", { + "%{TF_HEADER_GENRULE}": tf_header_rule, + "%{TF_SHARED_LIBRARY_GENRULE}": tf_shared_library_rule, + }) + +tf_configure = repository_rule( + implementation = _tf_pip_impl, + environ = [ + _TF_HEADER_DIR, + _TF_SHARED_LIBRARY_DIR, + ], +) \ No newline at end of file From 3ff830983b7b9cc22487e3714c9be7df310c1ee4 Mon Sep 17 00:00:00 2001 From: Sean Morgan Date: Wed, 9 Jan 2019 12:21:35 -0500 Subject: [PATCH 2/8] Re-org directory structure --- BUILD | 1 - Dockerfile | 3 + README.md | 3 +- tensorflow_addons/crf/BUILD | 3 + tensorflow_addons/crf/README.md | 4 + .../{layers/python/layers => crf}/__init__.py | 0 .../{opt => crf}/python/__init__.py | 0 tensorflow_addons/examples/demo.py | 6 +- tensorflow_addons/image/BUILD | 3 + tensorflow_addons/image/README.md | 4 + .../{opt/python/opt => image}/__init__.py | 0 .../python/ops => image/python}/__init__.py | 0 tensorflow_addons/layers/BUILD | 32 +- tensorflow_addons/layers/README.md | 4 +- tensorflow_addons/layers/__init__.py | 7 +- .../python/layers/poincare_normalize.py | 58 --- .../python/layers/poincare_normalize_test.py | 97 ----- .../layers/python/{layers => }/wrappers.py | 0 .../python/{layers => }/wrappers_test.py | 32 +- tensorflow_addons/losses/BUILD | 3 + tensorflow_addons/losses/README.md | 4 + tensorflow_addons/losses/__init__.py | 0 tensorflow_addons/losses/python/__init__.py | 0 tensorflow_addons/opt/BUILD | 23 -- .../opt/python/opt/lazy_adam_optimizer.py | 118 ------ .../python/opt/lazy_adam_optimizer_test.py | 366 ------------------ tensorflow_addons/optimizers/BUILD | 3 + .../{opt => optimizers}/README.md | 0 .../{opt => optimizers}/__init__.py | 3 - .../optimizers/python/__init__.py | 0 tensorflow_addons/text/BUILD | 34 +- tensorflow_addons/text/__init__.py | 4 +- .../text/python/{ops => }/skip_gram_ops.py | 0 .../python/{ops => }/skip_gram_ops_test.py | 0 34 files changed, 83 insertions(+), 732 deletions(-) create mode 100644 Dockerfile create mode 100644 tensorflow_addons/crf/BUILD create mode 100644 tensorflow_addons/crf/README.md rename tensorflow_addons/{layers/python/layers => crf}/__init__.py (100%) rename tensorflow_addons/{opt => crf}/python/__init__.py (100%) create mode 100644 tensorflow_addons/image/BUILD create mode 100644 tensorflow_addons/image/README.md rename tensorflow_addons/{opt/python/opt => image}/__init__.py (100%) rename tensorflow_addons/{text/python/ops => image/python}/__init__.py (100%) delete mode 100644 tensorflow_addons/layers/python/layers/poincare_normalize.py delete mode 100644 tensorflow_addons/layers/python/layers/poincare_normalize_test.py rename tensorflow_addons/layers/python/{layers => }/wrappers.py (100%) rename tensorflow_addons/layers/python/{layers => }/wrappers_test.py 
(77%) create mode 100644 tensorflow_addons/losses/BUILD create mode 100644 tensorflow_addons/losses/README.md create mode 100644 tensorflow_addons/losses/__init__.py create mode 100644 tensorflow_addons/losses/python/__init__.py delete mode 100644 tensorflow_addons/opt/BUILD delete mode 100644 tensorflow_addons/opt/python/opt/lazy_adam_optimizer.py delete mode 100644 tensorflow_addons/opt/python/opt/lazy_adam_optimizer_test.py create mode 100644 tensorflow_addons/optimizers/BUILD rename tensorflow_addons/{opt => optimizers}/README.md (100%) rename tensorflow_addons/{opt => optimizers}/__init__.py (88%) create mode 100644 tensorflow_addons/optimizers/python/__init__.py rename tensorflow_addons/text/python/{ops => }/skip_gram_ops.py (100%) rename tensorflow_addons/text/python/{ops => }/skip_gram_ops_test.py (100%) diff --git a/BUILD b/BUILD index 702c5dc395..09d2a8a26a 100644 --- a/BUILD +++ b/BUILD @@ -7,7 +7,6 @@ sh_binary( "setup.py", "tensorflow_addons/__init__.py", "//tensorflow_addons/layers:layers_py", - "//tensorflow_addons/opt:opt_py", "//tensorflow_addons/text:text_py", ], ) diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000..475eeb63be --- /dev/null +++ b/Dockerfile @@ -0,0 +1,3 @@ +FROM tensorflow/tensorflow:custom-op + +RUN pip install tf-nightly-2.0-preview diff --git a/README.md b/README.md index 9aba5d7b42..b5a1e6d3ad 100644 --- a/README.md +++ b/README.md @@ -16,12 +16,11 @@ The tensorflow/addons repository, will contain additional functionality fitting * The addon is useful for a large number of users (e.g., an implementation used in widely cited paper, or a utility with broad applicability) ---- - # Developing ## Docker ``` +docker run --rm -it -v ${PWD}:/working_dir -w /working_dir seanpmorgan/addons:tf2-preview ``` ## Packaging diff --git a/tensorflow_addons/crf/BUILD b/tensorflow_addons/crf/BUILD new file mode 100644 index 0000000000..3ad427fd87 --- /dev/null +++ b/tensorflow_addons/crf/BUILD @@ -0,0 +1,3 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:public"]) diff --git a/tensorflow_addons/crf/README.md b/tensorflow_addons/crf/README.md new file mode 100644 index 0000000000..7d57256d25 --- /dev/null +++ b/tensorflow_addons/crf/README.md @@ -0,0 +1,4 @@ +# Addons - Conditional Random Fields (CRF) + + +## Standard API \ No newline at end of file diff --git a/tensorflow_addons/layers/python/layers/__init__.py b/tensorflow_addons/crf/__init__.py similarity index 100% rename from tensorflow_addons/layers/python/layers/__init__.py rename to tensorflow_addons/crf/__init__.py diff --git a/tensorflow_addons/opt/python/__init__.py b/tensorflow_addons/crf/python/__init__.py similarity index 100% rename from tensorflow_addons/opt/python/__init__.py rename to tensorflow_addons/crf/python/__init__.py diff --git a/tensorflow_addons/examples/demo.py b/tensorflow_addons/examples/demo.py index b3e1709ed8..41f3966b86 100644 --- a/tensorflow_addons/examples/demo.py +++ b/tensorflow_addons/examples/demo.py @@ -1,4 +1,4 @@ -# Copyright 2019 The TensorFlow Probability Authors. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,11 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# ============================================================================ +# ============================================================================== # import tensorflow as tf # import tensorflow_addons as tfa -# from tensorflow_addons.opt import LazyAdamOptimizer +# from tensorflow_addons.text import skip_gram_sample # TODO: Build this out diff --git a/tensorflow_addons/image/BUILD b/tensorflow_addons/image/BUILD new file mode 100644 index 0000000000..3ad427fd87 --- /dev/null +++ b/tensorflow_addons/image/BUILD @@ -0,0 +1,3 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:public"]) diff --git a/tensorflow_addons/image/README.md b/tensorflow_addons/image/README.md new file mode 100644 index 0000000000..8e5eedf02e --- /dev/null +++ b/tensorflow_addons/image/README.md @@ -0,0 +1,4 @@ +# Addons - Image + + +## Standard API \ No newline at end of file diff --git a/tensorflow_addons/opt/python/opt/__init__.py b/tensorflow_addons/image/__init__.py similarity index 100% rename from tensorflow_addons/opt/python/opt/__init__.py rename to tensorflow_addons/image/__init__.py diff --git a/tensorflow_addons/text/python/ops/__init__.py b/tensorflow_addons/image/python/__init__.py similarity index 100% rename from tensorflow_addons/text/python/ops/__init__.py rename to tensorflow_addons/image/python/__init__.py diff --git a/tensorflow_addons/layers/BUILD b/tensorflow_addons/layers/BUILD index 797fbc60a4..1d5c07d687 100644 --- a/tensorflow_addons/layers/BUILD +++ b/tensorflow_addons/layers/BUILD @@ -2,32 +2,24 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//visibility:public"]) -py_test( - name = "layers_poincare_py_test", - srcs = [ - "python/layers/poincare_normalize_test.py", - ], - main = "python/layers/poincare_normalize_test.py", +py_library( + name = "layers_py", + srcs = ([ + "__init__.py", + "python/__init__.py", + "python/wrappers.py", + ]), srcs_version = "PY2AND3", ) py_test( name = "layers_wrappers_py_test", srcs = [ - "python/layers/wrappers_test.py", + "python/wrappers_test.py", ], - main = "python/layers/wrappers_test.py", - srcs_version = "PY2AND3", -) - -py_library( - name = "layers_py", - srcs = ([ - "__init__.py", - "python/__init__.py", - "python/layers/__init__.py", - "python/layers/poincare_normalize.py", - "python/layers/wrappers.py", - ]), + main = "python/wrappers_test.py", + deps = [ + ":layers_py", + ], srcs_version = "PY2AND3", ) \ No newline at end of file diff --git a/tensorflow_addons/layers/README.md b/tensorflow_addons/layers/README.md index e3bfb00a0c..3d228c5481 100644 --- a/tensorflow_addons/layers/README.md +++ b/tensorflow_addons/layers/README.md @@ -1,4 +1,6 @@ # Addons - Layers -## Standard API \ No newline at end of file +## Standard API +In order to conform with the current API standard, all layers +must inherit from either `keras.layers.Layer` or it's subclasses. 
\ No newline at end of file diff --git a/tensorflow_addons/layers/__init__.py b/tensorflow_addons/layers/__init__.py index 007cf4d683..de8f5c2d2c 100644 --- a/tensorflow_addons/layers/__init__.py +++ b/tensorflow_addons/layers/__init__.py @@ -19,8 +19,5 @@ from __future__ import division from __future__ import print_function -# Poincare Normalize -from tensorflow_addons.layers.python.layers.poincare_normalize import poincare_normalize - -# Weight Normalization -from tensorflow_addons.layers.python.layers.wrappers import WeightNorm +# Weight Normalization Wrapper +from tensorflow_addons.layers.python.wrappers import WeightNorm diff --git a/tensorflow_addons/layers/python/layers/poincare_normalize.py b/tensorflow_addons/layers/python/layers/poincare_normalize.py deleted file mode 100644 index 851c38c4d4..0000000000 --- a/tensorflow_addons/layers/python/layers/poincare_normalize.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import ops -from tensorflow.python.ops import math_ops - - -def poincare_normalize(x, axis=1, epsilon=1e-5, name=None): - """Project into the Poincare ball with norm <= 1.0 - epsilon. - - https://en.wikipedia.org/wiki/Poincare_ball_model - - Used in - Poincare Embeddings for Learning Hierarchical Representations - Maximilian Nickel, Douwe Kiela - https://arxiv.org/pdf/1705.08039.pdf - - For a 1-D tensor with `axis = 0`, computes - - (x * (1 - epsilon)) / ||x|| if ||x|| > 1 - epsilon - output = - x otherwise - - For `x` with more dimensions, independently normalizes each 1-D slice along - dimension `axis`. - - Args: - x: A `Tensor`. - axis: Axis along which to normalize. A scalar or a vector of - integers. - epsilon: A small deviation from the edge of the unit sphere for numerical - stability. - name: A name for this operation (optional). - - Returns: - A `Tensor` with the same shape as `x`. - """ - with ops.name_scope(name, 'poincare_normalize', [x]) as name: - x = ops.convert_to_tensor(x, name='x') - square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims=True) - x_inv_norm = math_ops.rsqrt(square_sum) - x_inv_norm = math_ops.minimum((1. - epsilon) * x_inv_norm, 1.) - return math_ops.multiply(x, x_inv_norm, name=name) diff --git a/tensorflow_addons/layers/python/layers/poincare_normalize_test.py b/tensorflow_addons/layers/python/layers/poincare_normalize_test.py deleted file mode 100644 index eae709c02b..0000000000 --- a/tensorflow_addons/layers/python/layers/poincare_normalize_test.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.platform import test -from tensorflow.python.framework import test_util -from tensorflow.python.ops import gradient_checker -from tensorflow.python.framework import constant_op - -from tensorflow_addons.layers.python.layers.poincare_normalize import poincare_normalize - - -# TODO: Is this the prefered way to run tests in TF2? -@test_util.run_all_in_graph_and_eager_modes -class PoincareNormalizeTest(test.TestCase): - - def _PoincareNormalize(self, x, dim, epsilon=1e-5): - if isinstance(dim, list): - norm = np.linalg.norm(x, axis=tuple(dim)) - for d in dim: - norm = np.expand_dims(norm, d) - norm_x = ((1. - epsilon) * x) / norm - else: - norm = np.expand_dims(np.apply_along_axis(np.linalg.norm, dim, x), dim) - norm_x = ((1. - epsilon) * x) / norm - return np.where(norm > 1.0 - epsilon, norm_x, x) - - def testPoincareNormalize(self): - x_shape = [20, 7, 3] - epsilon = 1e-5 - tol = 1e-6 - np.random.seed(1) - x_np = np.random.random_sample(x_shape).astype(np.float32) - for dim in range(len(x_shape)): - y_np = self._PoincareNormalize(x_np, dim, epsilon) - with self.cached_session(): - x_tf = constant_op.constant(x_np, name='x') - y_tf = poincare_normalize(x_tf, dim, epsilon) - y_tf_eval = y_tf.numpy() - norm = np.linalg.norm(y_np, axis=dim) - self.assertLessEqual(norm.max(), 1. - epsilon + tol) - norm = np.linalg.norm(y_tf_eval, axis=dim) - self.assertLessEqual(norm.max(), 1. - epsilon + tol) - self.assertAllClose(y_np, y_tf_eval) - - def testPoincareNormalizeDimArray(self): - x_shape = [20, 7, 3] - epsilon = 1e-5 - tol = 1e-6 - np.random.seed(1) - x_np = np.random.random_sample(x_shape).astype(np.float32) - dim = [1, 2] - y_np = self._PoincareNormalize(x_np, dim, epsilon) - with self.cached_session(): - x_tf = constant_op.constant(x_np, name='x') - y_tf = poincare_normalize(x_tf, dim, epsilon) - y_tf_eval = y_tf.numpy() - norm = np.linalg.norm(y_np, axis=tuple(dim)) - self.assertLess(norm.max(), 1. - epsilon + tol) - norm = np.linalg.norm(y_tf_eval, axis=tuple(dim)) - self.assertLess(norm.max(), 1. 
- epsilon + tol) - self.assertAllClose(y_np, y_tf_eval, rtol=1e-6, atol=1e-6) - - def testPoincareNormalizeGradient(self): - x_shape = [20, 7, 3] - np.random.seed(1) - x_np = np.random.random_sample(x_shape).astype(np.float64) - for dim in range(len(x_shape)): - with self.cached_session(): - x_tf = constant_op.constant(x_np, name='x') - y_tf = poincare_normalize(x_tf, dim) - err = gradient_checker.compute_gradient_error(x_tf, - x_shape, - y_tf, - x_shape) - print('PoinCareNormalize gradient err = %g ' % err) - self.assertLess(err, 1e-4) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow_addons/layers/python/layers/wrappers.py b/tensorflow_addons/layers/python/wrappers.py similarity index 100% rename from tensorflow_addons/layers/python/layers/wrappers.py rename to tensorflow_addons/layers/python/wrappers.py diff --git a/tensorflow_addons/layers/python/layers/wrappers_test.py b/tensorflow_addons/layers/python/wrappers_test.py similarity index 77% rename from tensorflow_addons/layers/python/layers/wrappers_test.py rename to tensorflow_addons/layers/python/wrappers_test.py index 8c8dbae5b0..da418fcb3e 100644 --- a/tensorflow_addons/layers/python/layers/wrappers_test.py +++ b/tensorflow_addons/layers/python/wrappers_test.py @@ -18,8 +18,7 @@ from __future__ import print_function import numpy as np - -from tensorflow_addons.layers.python.layers import wrappers +from tensorflow_addons.layers.python import wrappers from tensorflow.python.ops import random_ops from tensorflow.python.platform import test @@ -46,21 +45,20 @@ def test_weightnorm_dense_train(self): batch_size=10) self.assertTrue(hasattr(model.layers[0].layer, 'g')) - # @tf_test_util.run_all_in_graph_and_eager_modes - # def test_weightnorm_conv2d(self): - # with self.test_session(): - # model = keras.models.Sequential() - # model.add(wrappers.WeightNorm( - # keras.layers.Conv2D(5, (2, 2), padding='same'), - # input_shape=(4, 4, 3))) - # - # model.add(keras.layers.Activation('relu')) - # model.compile(optimizer='rmsprop', loss='mse') - # model.train_on_batch( - # np.random.random((2, 4, 4, 3)), - # np.random.random((2, 4, 4, 8))) - # - # self.assertTrue(hasattr(model.layers[0].layer, 'g')) + @tf_test_util.run_all_in_graph_and_eager_modes + def test_weightnorm_conv2d(self): + model = keras.models.Sequential() + model.add(wrappers.WeightNorm( + keras.layers.Conv2D(5, (2, 2), padding='same'), + input_shape=(4, 4, 3))) + + model.add(keras.layers.Activation('relu')) + model.compile(optimizer=RMSPropOptimizer(0.01), loss='mse') + model.train_on_batch( + np.random.random((2, 4, 4, 3)), + np.random.random((2, 4, 4, 5))) + + self.assertTrue(hasattr(model.layers[0].layer, 'g')) @tf_test_util.run_all_in_graph_and_eager_modes def test_weight_norm_tflayers(self): diff --git a/tensorflow_addons/losses/BUILD b/tensorflow_addons/losses/BUILD new file mode 100644 index 0000000000..3ad427fd87 --- /dev/null +++ b/tensorflow_addons/losses/BUILD @@ -0,0 +1,3 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:public"]) diff --git a/tensorflow_addons/losses/README.md b/tensorflow_addons/losses/README.md new file mode 100644 index 0000000000..6ed1e4c0fb --- /dev/null +++ b/tensorflow_addons/losses/README.md @@ -0,0 +1,4 @@ +# Addons - Losses + + +## Standard API \ No newline at end of file diff --git a/tensorflow_addons/losses/__init__.py b/tensorflow_addons/losses/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tensorflow_addons/losses/python/__init__.py 
b/tensorflow_addons/losses/python/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tensorflow_addons/opt/BUILD b/tensorflow_addons/opt/BUILD deleted file mode 100644 index b5b8c5f3fa..0000000000 --- a/tensorflow_addons/opt/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//visibility:public"]) - -py_test( - name = "opt_py_test", - srcs = [ - "python/opt/lazy_adam_optimizer_test.py", - ], - main = "python/opt/lazy_adam_optimizer_test.py", - srcs_version = "PY2AND3", -) - -py_library( - name = "opt_py", - srcs = ([ - "__init__.py", - "python/__init__.py", - "python/opt/__init__.py", - "python/opt/lazy_adam_optimizer.py", - ]), - srcs_version = "PY2AND3", -) \ No newline at end of file diff --git a/tensorflow_addons/opt/python/opt/lazy_adam_optimizer.py b/tensorflow_addons/opt/python/opt/lazy_adam_optimizer.py deleted file mode 100644 index caa46e961f..0000000000 --- a/tensorflow_addons/opt/python/opt/lazy_adam_optimizer.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Variant of the Adam optimizer that handles sparse updates more efficiently. - -Compared with the original Adam optimizer, the one in this file can provide a -large improvement in model training throughput for some applications. However, -it provides slightly different semantics than the original Adam algorithm, and -may lead to different empirical results. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - - -# FIXME: Which way to import? -from tensorflow.python.keras.optimizer_v2.adam import Adam -# from tensorflow.keras.optimizers import Adam (package_hook) - -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import state_ops - - -class LazyAdamOptimizer(Adam): - """Variant of the Adam optimizer that handles sparse updates more efficiently. - - The original Adam algorithm maintains two moving-average accumulators for - each trainable variable; the accumulators are updated at every step. - This class provides lazier handling of gradient updates for sparse variables. - It only updates moving-average accumulators for sparse variable indices that - appear in the current batch, rather than updating the accumulators for all - indices. Compared with the original Adam optimizer, it can provide large - improvements in model training throughput for some applications. However, it - provides slightly different semantics than the original Adam algorithm, and - may lead to different empirical results. 
- """ - - def _apply_sparse(self, grad, var): - beta1_power, beta2_power = self._get_beta_accumulators() - beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) - beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) - lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) - beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) - beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) - epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) - lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) - - # \\(m := beta1 * m + (1 - beta1) * g_t\\) - m = self.get_slot(var, "m") - m_t = state_ops.scatter_update(m, grad.indices, - beta1_t * array_ops.gather(m, grad.indices) + - (1 - beta1_t) * grad.values, - use_locking=self._use_locking) - - # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\) - v = self.get_slot(var, "v") - v_t = state_ops.scatter_update(v, grad.indices, - beta2_t * array_ops.gather(v, grad.indices) + - (1 - beta2_t) * math_ops.square(grad.values), - use_locking=self._use_locking) - - # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\) - m_t_slice = array_ops.gather(m_t, grad.indices) - v_t_slice = array_ops.gather(v_t, grad.indices) - denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t - var_update = state_ops.scatter_sub(var, grad.indices, - lr * m_t_slice / denominator_slice, - use_locking=self._use_locking) - return control_flow_ops.group(var_update, m_t, v_t) - - def _resource_apply_sparse(self, grad, var, indices): - beta1_power, beta2_power = self._get_beta_accumulators() - beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) - beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) - lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) - beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) - beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) - epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) - lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) - - # \\(m := beta1 * m + (1 - beta1) * g_t\\) - m = self.get_slot(var, "m") - m_t_slice = beta1_t * array_ops.gather(m, indices) + (1 - beta1_t) * grad - m_update_op = resource_variable_ops.resource_scatter_update(m.handle, - indices, - m_t_slice) - - # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\) - v = self.get_slot(var, "v") - v_t_slice = (beta2_t * array_ops.gather(v, indices) + - (1 - beta2_t) * math_ops.square(grad)) - v_update_op = resource_variable_ops.resource_scatter_update(v.handle, - indices, - v_t_slice) - - # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\) - var_slice = lr * m_t_slice / (math_ops.sqrt(v_t_slice) + epsilon_t) - var_update_op = resource_variable_ops.resource_scatter_sub(var.handle, - indices, - var_slice) - - return control_flow_ops.group(var_update_op, m_update_op, v_update_op) diff --git a/tensorflow_addons/opt/python/opt/lazy_adam_optimizer_test.py b/tensorflow_addons/opt/python/opt/lazy_adam_optimizer_test.py deleted file mode 100644 index 2ca96a6e5e..0000000000 --- a/tensorflow_addons/opt/python/opt/lazy_adam_optimizer_test.py +++ /dev/null @@ -1,366 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Tests for LazyAdamOptimizer.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from absl.testing import parameterized -from tensorflow_addons.opt.python.opt import lazy_adam_optimizer - -from tensorflow.python.eager import context -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import variables -from tensorflow.python.platform import test - - -def adam_update_numpy(param, - g_t, - t, - m, - v, - alpha=0.001, - beta1=0.9, - beta2=0.999, - epsilon=1e-8): - alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t) - - m_t = beta1 * m + (1 - beta1) * g_t - v_t = beta2 * v + (1 - beta2) * g_t * g_t - - param_t = param - alpha_t * m_t / (np.sqrt(v_t) + epsilon) - return param_t, m_t, v_t - - -class AdamOptimizerTest(test.TestCase, parameterized.TestCase): - - @parameterized.parameters([False, True]) - def testSparse(self, use_resource): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.cached_session(): - # Initialize variables for numpy implementation. 
- m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 - var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) - grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) - var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) - grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) - - if use_resource: - var0 = resource_variable_ops.ResourceVariable(var0_np) - var1 = resource_variable_ops.ResourceVariable(var1_np) - else: - var0 = variables.Variable(var0_np) - var1 = variables.Variable(var1_np) - - grads0_np_indices = np.array([0, 1], dtype=np.int32) - grads0 = ops.IndexedSlices( - constant_op.constant(grads0_np), - constant_op.constant(grads0_np_indices), constant_op.constant([2])) - grads1_np_indices = np.array([0, 1], dtype=np.int32) - grads1 = ops.IndexedSlices( - constant_op.constant(grads1_np), - constant_op.constant(grads1_np_indices), constant_op.constant([2])) - opt = lazy_adam_optimizer.LazyAdamOptimizer() - update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - - # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) - - beta1_power, beta2_power = opt._get_beta_accumulators() - - # Run 3 steps of Adam - for t in range(1, 4): - self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) - self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) - update.run() - - var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) - var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) - - # Validate updated params - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) - - @parameterized.parameters([False, True]) - def testSparseDevicePlacement(self, use_resource): - for index_dtype in [dtypes.int32, dtypes.int64]: - with self.cached_session(force_gpu=test.is_gpu_available()): - # If a GPU is available, tests that all optimizer ops can be placed on - # it (i.e. they have GPU kernels). 
- if use_resource: - var = resource_variable_ops.ResourceVariable([[1.0], [2.0]]) - else: - var = variables.Variable([[1.0], [2.0]]) - - indices = constant_op.constant([0, 1], dtype=index_dtype) - gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices)) - optimizer = lazy_adam_optimizer.LazyAdamOptimizer(3.0) - minimize_op = optimizer.minimize(gathered_sum) - variables.global_variables_initializer().run() - minimize_op.run() - - @parameterized.parameters([False, True]) - def testSparseRepeatedIndices(self, use_resource): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.cached_session(): - if use_resource: - repeated_index_update_var = resource_variable_ops.ResourceVariable( - [[1.0], [2.0]], dtype=dtype) - aggregated_update_var = resource_variable_ops.ResourceVariable( - [[1.0], [2.0]], dtype=dtype) - else: - repeated_index_update_var = variables.Variable( - [[1.0], [2.0]], dtype=dtype) - aggregated_update_var = variables.Variable( - [[1.0], [2.0]], dtype=dtype) - - grad_repeated_index = ops.IndexedSlices( - constant_op.constant( - [0.1, 0.1], shape=[2, 1], dtype=dtype), - constant_op.constant([1, 1]), - constant_op.constant([2, 1])) - grad_aggregated = ops.IndexedSlices( - constant_op.constant( - [0.2], shape=[1, 1], dtype=dtype), - constant_op.constant([1]), - constant_op.constant([2, 1])) - repeated_update_opt = lazy_adam_optimizer.LazyAdamOptimizer() - repeated_update = repeated_update_opt.apply_gradients( - [(grad_repeated_index, repeated_index_update_var)]) - aggregated_update_opt = lazy_adam_optimizer.LazyAdamOptimizer() - aggregated_update = aggregated_update_opt.apply_gradients( - [(grad_aggregated, aggregated_update_var)]) - variables.global_variables_initializer().run() - self.assertAllClose(aggregated_update_var.eval(), - repeated_index_update_var.eval()) - for _ in range(3): - repeated_update.run() - aggregated_update.run() - self.assertAllClose(aggregated_update_var.eval(), - repeated_index_update_var.eval()) - - def doTestBasic(self, use_resource=False, use_callable_params=False): - for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): - with self.session(graph=ops.Graph()): - # Initialize variables for numpy implementation. 
- m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 - var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) - grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) - var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) - grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) - - if use_resource: - var0 = resource_variable_ops.ResourceVariable( - var0_np, name="var0_%d" % i) - var1 = resource_variable_ops.ResourceVariable( - var1_np, name="var1_%d" % i) - else: - var0 = variables.Variable(var0_np) - var1 = variables.Variable(var1_np) - grads0 = constant_op.constant(grads0_np) - grads1 = constant_op.constant(grads1_np) - - learning_rate = lambda: 0.001 - beta1 = lambda: 0.9 - beta2 = lambda: 0.999 - epsilon = lambda: 1e-8 - if not use_callable_params: - learning_rate = learning_rate() - beta1 = beta1() - beta2 = beta2() - epsilon = epsilon() - - opt = lazy_adam_optimizer.LazyAdamOptimizer(learning_rate=learning_rate) - update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - opt_variables = opt.variables() - beta1_power, beta2_power = opt._get_beta_accumulators() - self.assertIsNotNone(beta1_power) - self.assertIsNotNone(beta2_power is not None) - self.assertIn(beta1_power, opt_variables) - self.assertIn(beta2_power, opt_variables) - - if not context.executing_eagerly(): - with ops.Graph().as_default(): - # Shouldn't return non-slot variables from other graphs. - self.assertEqual(0, len(opt.variables())) - self.evaluate(variables.global_variables_initializer()) - # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], self.evaluate(var0)) - self.assertAllClose([3.0, 4.0], self.evaluate(var1)) - - beta1_power, beta2_power = opt._get_beta_accumulators() - - # Run 3 steps of Adam - for t in range(1, 4): - if not context.executing_eagerly(): - self.evaluate(update) - elif t > 1: - opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - - self.assertAllCloseAccordingToType(0.9**(t + 1), - self.evaluate(beta1_power)) - self.assertAllCloseAccordingToType(0.999**(t + 1), - self.evaluate(beta2_power)) - - var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) - var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) - - # Validate updated params - self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) - self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) - if use_resource: - self.assertEqual("var0_%d/Adam:0" % (i,), - opt.get_slot(var=var0, name="m").name) - - def testBasic(self): - with self.cached_session(): - self.doTestBasic(use_resource=False) - - @test_util.run_in_graph_and_eager_modes(reset_test=True) - def testResourceBasic(self): - self.doTestBasic(use_resource=True) - - def testBasicCallableParams(self): - with context.eager_mode(): - self.doTestBasic(use_resource=True, use_callable_params=True) - - def testTensorLearningRate(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.cached_session(): - # Initialize variables for numpy implementation. 
- m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 - var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) - grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) - var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) - grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) - - var0 = variables.Variable(var0_np) - var1 = variables.Variable(var1_np) - grads0 = constant_op.constant(grads0_np) - grads1 = constant_op.constant(grads1_np) - opt = lazy_adam_optimizer.LazyAdamOptimizer(constant_op.constant(0.001)) - update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - - # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) - - beta1_power, beta2_power = opt._get_beta_accumulators() - - # Run 3 steps of Adam - for t in range(1, 4): - self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) - self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) - update.run() - - var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) - var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) - - # Validate updated params - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) - - def testSharing(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.cached_session(): - # Initialize variables for numpy implementation. - m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 - var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) - grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) - var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) - grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) - - var0 = variables.Variable(var0_np) - var1 = variables.Variable(var1_np) - grads0 = constant_op.constant(grads0_np) - grads1 = constant_op.constant(grads1_np) - opt = lazy_adam_optimizer.LazyAdamOptimizer() - update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - - beta1_power, beta2_power = opt._get_beta_accumulators() - - # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) - - # Run 3 steps of intertwined Adam1 and Adam2. 
- for t in range(1, 4): - self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) - self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) - if t % 2 == 0: - update1.run() - else: - update2.run() - - var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) - var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) - - # Validate updated params - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) - - def testTwoSessions(self): - optimizer = lazy_adam_optimizer.LazyAdamOptimizer() - - with context.eager_mode(): - var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") - grads0 = constant_op.constant(np.array([0.1, 0.1])) - optimizer.apply_gradients([(grads0, var0)]) - - g = ops.Graph() - with g.as_default(): - with self.session(graph=g): - var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") - grads0 = constant_op.constant(np.array([0.1, 0.1])) - optimizer.apply_gradients([(grads0, var0)]) - - gg = ops.Graph() - with gg.as_default(): - with self.session(graph=gg): - var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") - grads0 = constant_op.constant(np.array([0.1, 0.1])) - - # If the optimizer saves any state not keyed by graph the following line - # fails. - optimizer.apply_gradients([(grads0, var0)]) - - def testSlotsUniqueEager(self): - with context.eager_mode(): - v1 = resource_variable_ops.ResourceVariable(1.) - v2 = resource_variable_ops.ResourceVariable(1.) - opt = lazy_adam_optimizer.LazyAdamOptimizer(1.) - opt.minimize(lambda: v1 + v2) - # There should be two non-slot variables, and two unique slot variables - # for v1 and v2 respectively. - self.assertEqual(6, len(set(opt.variables()))) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow_addons/optimizers/BUILD b/tensorflow_addons/optimizers/BUILD new file mode 100644 index 0000000000..3ad427fd87 --- /dev/null +++ b/tensorflow_addons/optimizers/BUILD @@ -0,0 +1,3 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:public"]) diff --git a/tensorflow_addons/opt/README.md b/tensorflow_addons/optimizers/README.md similarity index 100% rename from tensorflow_addons/opt/README.md rename to tensorflow_addons/optimizers/README.md diff --git a/tensorflow_addons/opt/__init__.py b/tensorflow_addons/optimizers/__init__.py similarity index 88% rename from tensorflow_addons/opt/__init__.py rename to tensorflow_addons/optimizers/__init__.py index c14aaab5b6..a4e4fe2cc8 100644 --- a/tensorflow_addons/opt/__init__.py +++ b/tensorflow_addons/optimizers/__init__.py @@ -19,6 +19,3 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function - -# Lazy Adam Optimizer -from tensorflow_addons.opt.python.opt.lazy_adam_optimizer import LazyAdamOptimizer diff --git a/tensorflow_addons/optimizers/python/__init__.py b/tensorflow_addons/optimizers/python/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tensorflow_addons/text/BUILD b/tensorflow_addons/text/BUILD index fa71b7ce73..39699f65b7 100644 --- a/tensorflow_addons/text/BUILD +++ b/tensorflow_addons/text/BUILD @@ -4,7 +4,7 @@ package(default_visibility = ["//visibility:public"]) cc_binary( - name = 'python/ops/_skip_gram_ops.so', + name = 'python/_skip_gram_ops.so', srcs = [ "cc/kernels/skip_gram_kernels.cc", "cc/ops/skip_gram_ops.cc", @@ -21,35 +21,37 @@ cc_binary( py_library( name = "text_ops_py", srcs = ([ - "python/ops/skip_gram_ops.py", + 
"python/skip_gram_ops.py", ]), data = [ - ":python/ops/_skip_gram_ops.so" + ":python/_skip_gram_ops.so" ], srcs_version = "PY2AND3", ) -py_test( - name = "text_ops_py_test", - srcs = [ - "python/ops/skip_gram_ops_test.py" - ], - main = "python/ops/skip_gram_ops_test.py", - deps = [ - ":text_ops_py", - ], - srcs_version = "PY2AND3", -) py_library( name = "text_py", srcs = ([ "__init__.py", "python/__init__.py", - "python/ops/__init__.py", ]), deps = [ ":text_ops_py" ], srcs_version = "PY2AND3", -) \ No newline at end of file +) + + +py_test( + name = "text_ops_py_test", + srcs = [ + "python/skip_gram_ops_test.py" + ], + main = "python/skip_gram_ops_test.py", + deps = [ + ":text_py", + ], + srcs_version = "PY2AND3", +) + diff --git a/tensorflow_addons/text/__init__.py b/tensorflow_addons/text/__init__.py index 34bd006be3..41cb89a60e 100644 --- a/tensorflow_addons/text/__init__.py +++ b/tensorflow_addons/text/__init__.py @@ -20,5 +20,5 @@ from __future__ import print_function # Skip Gram Sample -from tensorflow_addons.text.python.ops.skip_gram_ops import skip_gram_sample -from tensorflow_addons.text.python.ops.skip_gram_ops import skip_gram_sample_with_text_vocab +from tensorflow_addons.text.python.skip_gram_ops import skip_gram_sample +from tensorflow_addons.text.python.skip_gram_ops import skip_gram_sample_with_text_vocab diff --git a/tensorflow_addons/text/python/ops/skip_gram_ops.py b/tensorflow_addons/text/python/skip_gram_ops.py similarity index 100% rename from tensorflow_addons/text/python/ops/skip_gram_ops.py rename to tensorflow_addons/text/python/skip_gram_ops.py diff --git a/tensorflow_addons/text/python/ops/skip_gram_ops_test.py b/tensorflow_addons/text/python/skip_gram_ops_test.py similarity index 100% rename from tensorflow_addons/text/python/ops/skip_gram_ops_test.py rename to tensorflow_addons/text/python/skip_gram_ops_test.py From f7f2381cbc2586b99988fb3a7caaeb9cfdf21393 Mon Sep 17 00:00:00 2001 From: Sean Morgan Date: Wed, 9 Jan 2019 12:23:29 -0500 Subject: [PATCH 3/8] Modify tests for TF2 --- .../text/python/skip_gram_ops_test.py | 300 ++++++++---------- 1 file changed, 141 insertions(+), 159 deletions(-) diff --git a/tensorflow_addons/text/python/skip_gram_ops_test.py b/tensorflow_addons/text/python/skip_gram_ops_test.py index c2256dac16..8f3a578c55 100644 --- a/tensorflow_addons/text/python/skip_gram_ops_test.py +++ b/tensorflow_addons/text/python/skip_gram_ops_test.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Skip-gram sampling ops tests.""" +""" +Skip-gram sampling ops tests +""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -20,7 +23,8 @@ import csv import os -from tensorflow_addons.text.python.ops import skip_gram_ops +from tensorflow.python.framework import test_util +from tensorflow_addons.text.python import skip_gram_ops from tensorflow_addons import text from tensorflow.python.framework import constant_op @@ -30,13 +34,12 @@ from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -from tensorflow.python.training import coordinator -from tensorflow.python.training import queue_runner_impl class SkipGramOpsTest(test.TestCase): - def _split_tokens_labels(self, output): + @staticmethod + def _split_tokens_labels(output): tokens = [x[0] for x in output] labels = [x[1] for x in output] return tokens, labels @@ -63,9 +66,8 @@ def test_skip_gram_sample_skips_2(self): (b"jumps", b"brown"), (b"jumps", b"fox"), ]) - with self.cached_session(): - self.assertAllEqual(expected_tokens, tokens.eval()) - self.assertAllEqual(expected_labels, labels.eval()) + self.assertAllEqual(expected_tokens, tokens) + self.assertAllEqual(expected_labels, labels) def test_skip_gram_sample_emit_self(self): """Tests skip-gram with emit_self_as_target = True.""" @@ -94,9 +96,8 @@ def test_skip_gram_sample_emit_self(self): (b"jumps", b"fox"), (b"jumps", b"jumps"), ]) - with self.cached_session(): - self.assertAllEqual(expected_tokens, tokens.eval()) - self.assertAllEqual(expected_labels, labels.eval()) + self.assertAllEqual(expected_tokens, tokens) + self.assertAllEqual(expected_labels, labels) def test_skip_gram_sample_skips_0(self): """Tests skip-gram with min_skips = max_skips = 0.""" @@ -106,8 +107,8 @@ def test_skip_gram_sample_skips_0(self): tokens, labels = text.skip_gram_sample( input_tensor, min_skips=0, max_skips=0, emit_self_as_target=False) with self.cached_session(): - self.assertEqual(0, tokens.eval().size) - self.assertEqual(0, labels.eval().size) + self.assertEqual(0, len(tokens)) + self.assertEqual(0, len(labels)) # If emit_self_as_target is True, each token will be its own label. 
tokens, labels = text.skip_gram_sample( @@ -117,9 +118,8 @@ def test_skip_gram_sample_skips_0(self): (b"quick", b"quick"), (b"brown", b"brown"), ]) - with self.cached_session(): - self.assertAllEqual(expected_tokens, tokens.eval()) - self.assertAllEqual(expected_labels, labels.eval()) + self.assertAllEqual(expected_tokens, tokens) + self.assertAllEqual(expected_labels, labels) def test_skip_gram_sample_skips_exceed_length(self): """Tests skip-gram when min/max_skips exceed length of input.""" @@ -134,9 +134,8 @@ def test_skip_gram_sample_skips_exceed_length(self): (b"brown", b"the"), (b"brown", b"quick"), ]) - with self.cached_session(): - self.assertAllEqual(expected_tokens, tokens.eval()) - self.assertAllEqual(expected_labels, labels.eval()) + self.assertAllEqual(expected_tokens, tokens) + self.assertAllEqual(expected_labels, labels) def test_skip_gram_sample_start_limit(self): """Tests skip-gram over a limited portion of the input.""" @@ -150,13 +149,13 @@ def test_skip_gram_sample_start_limit(self): (b"quick", b"brown"), (b"brown", b"quick"), ]) - with self.cached_session(): - self.assertAllEqual(expected_tokens, tokens.eval()) - self.assertAllEqual(expected_labels, labels.eval()) + self.assertAllEqual(expected_tokens, tokens) + self.assertAllEqual(expected_labels, labels) def test_skip_gram_sample_limit_exceeds(self): """Tests skip-gram when limit exceeds the length of the input.""" - input_tensor = constant_op.constant([b"foo", b"the", b"quick", b"brown"]) + input_tensor = constant_op.constant([b"foo", b"the", + b"quick", b"brown"]) tokens, labels = text.skip_gram_sample( input_tensor, min_skips=1, max_skips=1, start=1, limit=100) expected_tokens, expected_labels = self._split_tokens_labels([ @@ -165,9 +164,8 @@ def test_skip_gram_sample_limit_exceeds(self): (b"quick", b"brown"), (b"brown", b"quick"), ]) - with self.cached_session(): - self.assertAllEqual(expected_tokens, tokens.eval()) - self.assertAllEqual(expected_labels, labels.eval()) + self.assertAllEqual(expected_tokens, tokens) + self.assertAllEqual(expected_labels, labels) def test_skip_gram_sample_random_skips(self): """Tests skip-gram with min_skips != max_skips, with random output.""" @@ -196,24 +194,26 @@ def test_skip_gram_sample_random_skips(self): (b"over", b"fox"), (b"over", b"jumps"), ]) - with self.cached_session() as sess: - tokens_eval, labels_eval = sess.run([tokens, labels]) - self.assertAllEqual(expected_tokens, tokens_eval) - self.assertAllEqual(expected_labels, labels_eval) + self.assertAllEqual(expected_tokens, tokens) + self.assertAllEqual(expected_labels, labels) def test_skip_gram_sample_random_skips_default_seed(self): - """Tests outputs are still random when no op-level seed is specified.""" - # This is needed since tests set a graph-level seed by default. We want to - # explicitly avoid setting both graph-level seed and op-level seed, to - # simulate behavior under non-test settings when the user doesn't provide a - # seed to us. This results in random_seed.get_seed() returning None for both - # seeds, forcing the C++ kernel to execute its default seed logic. + """ + Tests outputs are still random when no op-level seed is specified. + """ + + # This is needed since tests set a graph-level seed by default. We want + # to explicitly avoid setting both graph-level seed and op-level seed, + # to simulate behavior under non-test settings when the user doesn't + # provide a seed to us. 
This results in random_seed.get_seed() returning + # None for both seeds, forcing the C++ kernel to execute its default + # seed logic. random_seed.set_random_seed(None) - # Uses an input tensor with 10 words, with possible skip ranges in [1, - # 5]. Thus, the probability that two random samplings would result in the - # same outputs is 1/5^10 ~ 1e-7 (aka the probability of this test being - # flaky). + # Uses an input tensor with 10 words, with possible skip ranges in + # [1, 5]. Thus, the probability that two random samplings would result + # in the same outputs is 1/5^10 ~ 1e-7 (aka the probability of this test + # being flaky). input_tensor = constant_op.constant([str(x) for x in range(10)]) # Do not provide an op-level seed here! @@ -222,41 +222,10 @@ def test_skip_gram_sample_random_skips_default_seed(self): tokens_2, labels_2 = text.skip_gram_sample( input_tensor, min_skips=1, max_skips=5) - with self.cached_session() as sess: - tokens_1_eval, labels_1_eval, tokens_2_eval, labels_2_eval = sess.run( - [tokens_1, labels_1, tokens_2, labels_2]) - - if len(tokens_1_eval) == len(tokens_2_eval): - self.assertNotEqual(tokens_1_eval.tolist(), tokens_2_eval.tolist()) - if len(labels_1_eval) == len(labels_2_eval): - self.assertNotEqual(labels_1_eval.tolist(), labels_2_eval.tolist()) - - def test_skip_gram_sample_batch(self): - """Tests skip-gram with batching.""" - input_tensor = constant_op.constant([b"the", b"quick", b"brown", b"fox"]) - tokens, labels = text.skip_gram_sample( - input_tensor, min_skips=1, max_skips=1, batch_size=3) - expected_tokens, expected_labels = self._split_tokens_labels([ - (b"the", b"quick"), - (b"quick", b"the"), - (b"quick", b"brown"), - (b"brown", b"quick"), - (b"brown", b"fox"), - (b"fox", b"brown"), - ]) - with self.cached_session() as sess: - coord = coordinator.Coordinator() - threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord) - - tokens_eval, labels_eval = sess.run([tokens, labels]) - self.assertAllEqual(expected_tokens[:3], tokens_eval) - self.assertAllEqual(expected_labels[:3], labels_eval) - tokens_eval, labels_eval = sess.run([tokens, labels]) - self.assertAllEqual(expected_tokens[3:6], tokens_eval) - self.assertAllEqual(expected_labels[3:6], labels_eval) - - coord.request_stop() - coord.join(threads) + if len(tokens_1) == len(tokens_2): + self.assertNotEqual(list(tokens_1), list(tokens_2)) + if len(labels_1) == len(labels_2): + self.assertNotEqual(list(labels_1), list(labels_2)) def test_skip_gram_sample_non_string_input(self): """Tests skip-gram with non-string input.""" @@ -269,9 +238,17 @@ def test_skip_gram_sample_non_string_input(self): (2, 3), (3, 2), ]) - with self.cached_session(): - self.assertAllEqual(expected_tokens, tokens.eval()) - self.assertAllEqual(expected_labels, labels.eval()) + self.assertAllEqual(expected_tokens, tokens) + self.assertAllEqual(expected_labels, labels) + + @test_util.run_deprecated_v1 + def test_skip_gram_sample_errors_v1(self): + """Tests various errors raised by skip_gram_sample().""" + # input_tensor must be of rank 1. + with self.assertRaises(ValueError): + invalid_tensor = constant_op.constant([[b"the"], [b"quick"], + [b"brown"]]) + text.skip_gram_sample(invalid_tensor) def test_skip_gram_sample_errors(self): """Tests various errors raised by skip_gram_sample().""" @@ -284,19 +261,22 @@ def test_skip_gram_sample_errors(self): # min_skips must be <= max_skips. 
(2, 1)) for min_skips, max_skips in invalid_skips: - tokens, labels = text.skip_gram_sample( - input_tensor, min_skips=min_skips, max_skips=max_skips) - with self.cached_session() as sess, self.assertRaises( - errors.InvalidArgumentError): - sess.run([tokens, labels]) + with self.assertRaises(errors.InvalidArgumentError): + text.skip_gram_sample(input_tensor, min_skips=min_skips, + max_skips=max_skips) - # input_tensor must be of rank 1. - with self.assertRaises(ValueError): - invalid_tensor = constant_op.constant([[b"the"], [b"quick"], [b"brown"]]) - text.skip_gram_sample(invalid_tensor) + ######################################### - # vocab_freq_table must be provided if vocab_min_count, vocab_subsampling, - # or corpus_size is specified. + # FIXME: Why is this not failing? + # with self.assertRaises(ValueError): + # invalid_tensor = constant_op.constant([[b"the"], [b"quick"], + # [b"brown"]]) + # text.skip_gram_sample(invalid_tensor) + + ######################################### + + # vocab_freq_table must be provided if vocab_min_count, + # vocab_subsampling, or corpus_size is specified. dummy_input = constant_op.constant([""]) with self.assertRaises(ValueError): text.skip_gram_sample( @@ -305,7 +285,8 @@ def test_skip_gram_sample_errors(self): text.skip_gram_sample( dummy_input, vocab_freq_table=None, vocab_subsampling=1e-5) with self.assertRaises(ValueError): - text.skip_gram_sample(dummy_input, vocab_freq_table=None, corpus_size=100) + text.skip_gram_sample(dummy_input, vocab_freq_table=None, + corpus_size=100) with self.assertRaises(ValueError): text.skip_gram_sample( dummy_input, @@ -330,55 +311,55 @@ def test_skip_gram_sample_errors(self): corpus_size=None) def test_filter_input_filter_vocab(self): - """Tests input filtering based on vocab frequency table and thresholds.""" + """ + Tests input filtering based on vocab frequency table and thresholds. + """ input_tensor = constant_op.constant( [b"the", b"answer", b"to", b"life", b"and", b"universe"]) - keys = constant_op.constant([b"and", b"life", b"the", b"to", b"universe"]) + keys = constant_op.constant([b"and", b"life", b"the", b"to", + b"universe"]) values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64) vocab_freq_table = lookup_ops.HashTable( lookup_ops.KeyValueTensorInitializer(keys, values), -1) - with self.cached_session(): - vocab_freq_table.initializer.run() - - # No vocab_freq_table specified - output should be the same as input. - no_table_output = skip_gram_ops._filter_input( - input_tensor=input_tensor, - vocab_freq_table=None, - vocab_min_count=None, - vocab_subsampling=None, - corpus_size=None, - seed=None) - self.assertAllEqual(input_tensor.eval(), no_table_output.eval()) - - # vocab_freq_table specified, but no vocab_min_count - output should have - # filtered out tokens not in the table (b"answer"). - table_output = skip_gram_ops._filter_input( - input_tensor=input_tensor, - vocab_freq_table=vocab_freq_table, - vocab_min_count=None, - vocab_subsampling=None, - corpus_size=None, - seed=None) - self.assertAllEqual([b"the", b"to", b"life", b"and", b"universe"], - table_output.eval()) - - # vocab_freq_table and vocab_min_count specified - output should have - # filtered out tokens whose frequencies are below the threshold - # (b"and": 0, b"life": 1). 
- threshold_output = skip_gram_ops._filter_input( - input_tensor=input_tensor, - vocab_freq_table=vocab_freq_table, - vocab_min_count=2, - vocab_subsampling=None, - corpus_size=None, - seed=None) - self.assertAllEqual([b"the", b"to", b"universe"], threshold_output.eval()) + # No vocab_freq_table specified - output should be the same as input + no_table_output = skip_gram_ops._filter_input( + input_tensor=input_tensor, + vocab_freq_table=None, + vocab_min_count=None, + vocab_subsampling=None, + corpus_size=None, + seed=None) + self.assertAllEqual(input_tensor, no_table_output) + + # vocab_freq_table specified, but no vocab_min_count - output should + # have filtered out tokens not in the table (b"answer"). + table_output = skip_gram_ops._filter_input( + input_tensor=input_tensor, + vocab_freq_table=vocab_freq_table, + vocab_min_count=None, + vocab_subsampling=None, + corpus_size=None, + seed=None) + self.assertAllEqual([b"the", b"to", b"life", b"and", b"universe"], + table_output) + + # vocab_freq_table and vocab_min_count specified - output should have + # filtered out tokens whose frequencies are below the threshold + # (b"and": 0, b"life": 1). + threshold_output = skip_gram_ops._filter_input( + input_tensor=input_tensor, + vocab_freq_table=vocab_freq_table, + vocab_min_count=2, + vocab_subsampling=None, + corpus_size=None, + seed=None) + self.assertAllEqual([b"the", b"to", b"universe"], threshold_output) def test_filter_input_subsample_vocab(self): """Tests input filtering based on vocab subsampling.""" - # The outputs are non-deterministic, so set random seed to help ensure that - # the outputs remain constant for testing. + # The outputs are non-deterministic, so set random seed to help ensure + # that the outputs remain constant for testing. random_seed.set_random_seed(42) input_tensor = constant_op.constant([ @@ -390,23 +371,23 @@ def test_filter_input_subsample_vocab(self): b"and", # keep_prob = 0.48. b"universe" # Below vocab threshold of 3. 
(Always discarded) ]) - keys = constant_op.constant([b"and", b"life", b"the", b"to", b"universe"]) + keys = constant_op.constant([b"and", b"life", b"the", b"to", + b"universe"]) values = constant_op.constant([40, 8, 30, 20, 2], dtypes.int64) vocab_freq_table = lookup_ops.HashTable( lookup_ops.KeyValueTensorInitializer(keys, values), -1) - with self.cached_session(): - vocab_freq_table.initializer.run() - output = skip_gram_ops._filter_input( - input_tensor=input_tensor, - vocab_freq_table=vocab_freq_table, - vocab_min_count=3, - vocab_subsampling=0.05, - corpus_size=math_ops.reduce_sum(values), - seed=9) - self.assertAllEqual([b"the", b"to", b"life", b"and"], output.eval()) + output = skip_gram_ops._filter_input( + input_tensor=input_tensor, + vocab_freq_table=vocab_freq_table, + vocab_min_count=3, + vocab_subsampling=0.05, + corpus_size=math_ops.reduce_sum(values), + seed=9) + self.assertAllEqual([b"the", b"to", b"life", b"and"], output) - def _make_text_vocab_freq_file(self): + @staticmethod + def _make_text_vocab_freq_file(): filepath = os.path.join(test.get_temp_dir(), "vocab_freq.txt") with open(filepath, "w") as f: writer = csv.writer(f) @@ -419,7 +400,8 @@ def _make_text_vocab_freq_file(self): ]) return filepath - def _make_text_vocab_float_file(self): + @staticmethod + def _make_text_vocab_float_file(): filepath = os.path.join(test.get_temp_dir(), "vocab_freq_float.txt") with open(filepath, "w") as f: writer = csv.writer(f) @@ -433,7 +415,9 @@ def _make_text_vocab_float_file(self): return filepath def test_skip_gram_sample_with_text_vocab_filter_vocab(self): - """Tests skip-gram sampling with text vocab and freq threshold filtering.""" + """ + Tests skip-gram sampling with text vocab and freq threshold filtering. + """ input_tensor = constant_op.constant([ b"the", b"answer", # Will be filtered before candidate generation. @@ -464,15 +448,14 @@ def test_skip_gram_sample_with_text_vocab_filter_vocab(self): (b"life", b"and"), (b"and", b"life"), ]) - with self.cached_session(): - lookup_ops.tables_initializer().run() - self.assertAllEqual(expected_tokens, tokens.eval()) - self.assertAllEqual(expected_labels, labels.eval()) - - def _text_vocab_subsample_vocab_helper(self, vocab_freq_file, vocab_min_count, - vocab_freq_dtype, corpus_size=None): - # The outputs are non-deterministic, so set random seed to help ensure that - # the outputs remain constant for testing. + self.assertAllEqual(expected_tokens, tokens) + self.assertAllEqual(expected_labels, labels) + + def _text_vocab_subsample_vocab_helper(self, vocab_freq_file, + vocab_min_count, vocab_freq_dtype, + corpus_size=None): + # The outputs are non-deterministic, so set random seed to help ensure + # that the outputs remain constant for testing. 
random_seed.set_random_seed(42) input_tensor = constant_op.constant([ @@ -510,11 +493,8 @@ def _text_vocab_subsample_vocab_helper(self, vocab_freq_file, vocab_min_count, (b"to", b"life"), (b"life", b"to"), ]) - with self.cached_session() as sess: - lookup_ops.tables_initializer().run() - tokens_eval, labels_eval = sess.run([tokens, labels]) - self.assertAllEqual(expected_tokens, tokens_eval) - self.assertAllEqual(expected_labels, labels_eval) + self.assertAllEqual(expected_tokens, tokens) + self.assertAllEqual(expected_labels, labels) def test_skip_gram_sample_with_text_vocab_subsample_vocab(self): """Tests skip-gram sampling with text vocab and vocab subsampling.""" @@ -547,7 +527,9 @@ def test_skip_gram_sample_with_text_vocab_subsample_vocab(self): corpus_size=99) def test_skip_gram_sample_with_text_vocab_subsample_vocab_float(self): - """Tests skip-gram sampling with text vocab and subsampling with floats.""" + """ + Tests skip-gram sampling with text vocab and subsampling with floats. + """ # Vocab file frequencies # and: 0.4 # life: 0.08 From eac36b53fd875b223f7fd774495e18ef2d96d515 Mon Sep 17 00:00:00 2001 From: Sean Morgan Date: Wed, 9 Jan 2019 12:32:13 -0500 Subject: [PATCH 4/8] README update --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index b5a1e6d3ad..82d865436f 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,8 @@ The tensorflow/addons repository, will contain additional functionality fitting # Developing ## Docker +**Note:** This docker container is just temporary until we can pull a +tensorflow/tensorflow:custom-op container that reflects nightly changes. ``` docker run --rm -it -v ${PWD}:/working_dir -w /working_dir seanpmorgan/addons:tf2-preview ``` From 0c2edc6a31085fb5541dd4c14f93cf1bbbdc7395 Mon Sep 17 00:00:00 2001 From: Sean Morgan Date: Wed, 9 Jan 2019 14:39:02 -0500 Subject: [PATCH 5/8] Minor fixes --- setup.py | 2 +- tensorflow_addons/layers/README.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 0a379b1b52..4b15ab6e91 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow_addons/layers/README.md b/tensorflow_addons/layers/README.md index 3d228c5481..8b3ed6b52b 100644 --- a/tensorflow_addons/layers/README.md +++ b/tensorflow_addons/layers/README.md @@ -3,4 +3,4 @@ ## Standard API In order to conform with the current API standard, all layers -must inherit from either `keras.layers.Layer` or it's subclasses. \ No newline at end of file +must inherit from either `keras.layers.Layer` or its subclasses. 
\ No newline at end of file From 76c957c45ad50f9842768c39e82a1c208e7ab12b Mon Sep 17 00:00:00 2001 From: Sean Morgan Date: Thu, 10 Jan 2019 10:20:44 -0500 Subject: [PATCH 6/8] Set docker container as nightly custom-op --- Dockerfile | 3 --- README.md | 12 +++++++++--- configure.sh | 2 +- 3 files changed, 10 insertions(+), 7 deletions(-) delete mode 100644 Dockerfile diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index 475eeb63be..0000000000 --- a/Dockerfile +++ /dev/null @@ -1,3 +0,0 @@ -FROM tensorflow/tensorflow:custom-op - -RUN pip install tf-nightly-2.0-preview diff --git a/README.md b/README.md index 82d865436f..005a8ff442 100644 --- a/README.md +++ b/README.md @@ -19,15 +19,21 @@ The tensorflow/addons repository, will contain additional functionality fitting # Developing ## Docker -**Note:** This docker container is just temporary until we can pull a -tensorflow/tensorflow:custom-op container that reflects nightly changes. ``` -docker run --rm -it -v ${PWD}:/working_dir -w /working_dir seanpmorgan/addons:tf2-preview +docker run --rm -it -v ${PWD}:/working_dir -w /working_dir tensorflow/tensorflow:nightly-custom-op /bin/bash ``` ## Packaging ``` +# In docker ./configure.sh bazel build build_pip_pkg bazel-bin/build_pip_pkg artifacts ``` + +## Testing +``` +# In docker +./configure.sh +bazel test //tensorflow_addons/... +``` \ No newline at end of file diff --git a/configure.sh b/configure.sh index a4f4283d6d..bf30d99f27 100755 --- a/configure.sh +++ b/configure.sh @@ -25,7 +25,7 @@ rm .bazelrc if python -c "import tensorflow" &> /dev/null; then echo 'using installed tensorflow' else - pip install tensorflow + pip install tf-nightly-2.0-preview fi TF_CFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') ) From 4779862cf4ef9e24fa1802dea445e9e256ba0e92 Mon Sep 17 00:00:00 2001 From: Sean Morgan Date: Thu, 10 Jan 2019 10:23:01 -0500 Subject: [PATCH 7/8] Wrappers doc fixes --- tensorflow_addons/layers/python/wrappers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow_addons/layers/python/wrappers.py b/tensorflow_addons/layers/python/wrappers.py index 90f2f3c105..e9e5df37c9 100644 --- a/tensorflow_addons/layers/python/wrappers.py +++ b/tensorflow_addons/layers/python/wrappers.py @@ -38,7 +38,7 @@ class WeightNorm(Wrapper): net = WeightNorm(tf.keras.layers.Conv2D(2, 2, activation='relu'), input_shape=(32, 32, 3), data_init=True)(x) net = WeightNorm(tf.keras.layers.Conv2D(16, 5, activation='relu'), - data_init=True) + data_init=True)(net) net = WeightNorm(tf.keras.layers.Dense(120, activation='relu'), data_init=True)(net) net = WeightNorm(tf.keras.layers.Dense(n_classes), @@ -72,7 +72,7 @@ def __init__(self, layer, data_init=False, **kwargs): def _compute_weights(self): """Generate weights by combining the direction of weight vector - with it's norm """ + with its norm """ with name_scope('compute_weights'): self.layer.kernel = nn_impl.l2_normalize( self.layer.v, axis=self.norm_axes) * self.layer.g From 199432bf4074771176c8735310335b501cc1cc89 Mon Sep 17 00:00:00 2001 From: Sean Morgan Date: Thu, 10 Jan 2019 10:30:20 -0500 Subject: [PATCH 8/8] Update tf dependency folder --- README.md | 3 +++ WORKSPACE | 2 +- {tf => tf_dependency}/BUILD | 0 {tf => tf_dependency}/BUILD.tpl | 0 {tf => tf_dependency}/tf_configure.bzl | 2 +- 5 files changed, 5 insertions(+), 2 deletions(-) rename {tf => tf_dependency}/BUILD (100%) rename {tf => tf_dependency}/BUILD.tpl (100%) rename {tf => tf_dependency}/tf_configure.bzl 
(99%) diff --git a/README.md b/README.md index 005a8ff442..dc47a2afe1 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,9 @@ bazel build build_pip_pkg bazel-bin/build_pip_pkg artifacts ``` +A package file artifacts/tensorflow_addons-*.whl will be generated after a build is successful. + + ## Testing ``` # In docker diff --git a/WORKSPACE b/WORKSPACE index ee0f0eb0a8..42ae8ce932 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -1,4 +1,4 @@ -load("//tf:tf_configure.bzl", "tf_configure") +load("//tf_dependency:tf_configure.bzl", "tf_configure") tf_configure( name = "local_config_tf", diff --git a/tf/BUILD b/tf_dependency/BUILD similarity index 100% rename from tf/BUILD rename to tf_dependency/BUILD diff --git a/tf/BUILD.tpl b/tf_dependency/BUILD.tpl similarity index 100% rename from tf/BUILD.tpl rename to tf_dependency/BUILD.tpl diff --git a/tf/tf_configure.bzl b/tf_dependency/tf_configure.bzl similarity index 99% rename from tf/tf_configure.bzl rename to tf_dependency/tf_configure.bzl index 7ddd739b81..c28ae4e91c 100644 --- a/tf/tf_configure.bzl +++ b/tf_dependency/tf_configure.bzl @@ -8,7 +8,7 @@ def _tpl(repository_ctx, tpl, substitutions = {}, out = None): out = tpl repository_ctx.template( out, - Label("//tf:%s.tpl" % tpl), + Label("//tf_dependency:%s.tpl" % tpl), substitutions, )
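
The skip_gram_ops_test changes earlier in this series drop the `cached_session()` / `.eval()` pattern and compare the returned tensors directly. That works because under TF 2.x eager execution `tf.test.TestCase.assertAllEqual` accepts tensors without an explicit session. A minimal sketch of that session-free style, with an illustrative test class and constants that are not part of the patch itself:

```python
import tensorflow as tf


class EagerStyleTest(tf.test.TestCase):
    """Illustrative only: mirrors the session-free assertion style used in the patched tests."""

    def test_direct_tensor_comparison(self):
        # With eager execution there is no cached_session() or .eval();
        # assertAllEqual converts the tensor to a numpy array itself.
        tokens = tf.constant([b"the", b"to", b"life", b"and"])
        self.assertAllEqual([b"the", b"to", b"life", b"and"], tokens)


if __name__ == "__main__":
    tf.test.main()
```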
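The WeightNorm docstring fix above adds the trailing `(net)` call so the second wrapped `Conv2D` is actually applied to the previous output; in the Keras functional API a wrapper that is constructed but never called contributes nothing to the model. A sketch of the corrected pattern, assuming `WeightNorm` is importable from `tensorflow_addons.layers.python.wrappers` as laid out in this series (the input shape and layer sizes are illustrative, taken from the docstring example):

```python
import tensorflow as tf
from tensorflow_addons.layers.python.wrappers import WeightNorm  # import path assumed from this series

inputs = tf.keras.Input(shape=(32, 32, 3))
# Each WeightNorm wrapper must be *called* on the previous tensor; constructing
# it without the trailing (net) would silently drop the layer from the model.
net = WeightNorm(tf.keras.layers.Conv2D(2, 2, activation='relu'), data_init=True)(inputs)
net = WeightNorm(tf.keras.layers.Conv2D(16, 5, activation='relu'), data_init=True)(net)
net = tf.keras.layers.Flatten()(net)
net = WeightNorm(tf.keras.layers.Dense(120, activation='relu'), data_init=True)(net)
outputs = WeightNorm(tf.keras.layers.Dense(10), data_init=True)(net)
model = tf.keras.Model(inputs, outputs)
```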