From 4e2c8ed7414c2781c47fd7fe41457a575e29cf87 Mon Sep 17 00:00:00 2001 From: Sean Morgan Date: Mon, 14 Jan 2019 21:50:55 -0500 Subject: [PATCH 1/2] Issue #12: Fix skip_gram input_tensor check --- .../text/cc/kernels/skip_gram_kernels.cc | 9 +++++++-- .../text/python/skip_gram_ops_test.py | 14 +++++--------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/tensorflow_addons/text/cc/kernels/skip_gram_kernels.cc b/tensorflow_addons/text/cc/kernels/skip_gram_kernels.cc index c75b98a924..57da725c92 100644 --- a/tensorflow_addons/text/cc/kernels/skip_gram_kernels.cc +++ b/tensorflow_addons/text/cc/kernels/skip_gram_kernels.cc @@ -47,11 +47,16 @@ class SkipGramGenerateCandidatesOp : public OpKernel { OP_REQUIRES_OK(context, context->input("max_skips", &max_skips_tensor)); const int max_skips = *(max_skips_tensor->scalar().data()); + const Tensor& input_check = context->input(0); + OP_REQUIRES(context, TensorShapeUtils::IsVector(input_check.shape()), + errors::InvalidArgument("input_tensor must be of rank 1")); + OP_REQUIRES( context, min_skips >= 0 && max_skips >= 0, errors::InvalidArgument("Both min_skips and max_skips must be >= 0.")); - OP_REQUIRES(context, min_skips <= max_skips, - errors::InvalidArgument("min_skips must be <= max_skips.")); + OP_REQUIRES( + context, min_skips <= max_skips, + errors::InvalidArgument("min_skips must be <= max_skips.")); const Tensor* start_tensor; OP_REQUIRES_OK(context, context->input("start", &start_tensor)); diff --git a/tensorflow_addons/text/python/skip_gram_ops_test.py b/tensorflow_addons/text/python/skip_gram_ops_test.py index 8f3a578c55..e15e29b0a2 100644 --- a/tensorflow_addons/text/python/skip_gram_ops_test.py +++ b/tensorflow_addons/text/python/skip_gram_ops_test.py @@ -265,15 +265,11 @@ def test_skip_gram_sample_errors(self): text.skip_gram_sample(input_tensor, min_skips=min_skips, max_skips=max_skips) - ######################################### - - # FIXME: Why is this not failing? - # with self.assertRaises(ValueError): - # invalid_tensor = constant_op.constant([[b"the"], [b"quick"], - # [b"brown"]]) - # text.skip_gram_sample(invalid_tensor) - - ######################################### + # Eager tensor must be rank 1 + with self.assertRaises(errors.InvalidArgumentError): + invalid_tensor = constant_op.constant([[b"the"], [b"quick"], + [b"brown"]]) + text.skip_gram_sample(invalid_tensor) # vocab_freq_table must be provided if vocab_min_count, # vocab_subsampling, or corpus_size is specified. From 8075a392363e0b1406e3ffa35360af849d463788 Mon Sep 17 00:00:00 2001 From: Sean Morgan Date: Mon, 14 Jan 2019 21:57:17 -0500 Subject: [PATCH 2/2] Update skip_gram_kernels.cc --- tensorflow_addons/text/cc/kernels/skip_gram_kernels.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow_addons/text/cc/kernels/skip_gram_kernels.cc b/tensorflow_addons/text/cc/kernels/skip_gram_kernels.cc index 57da725c92..7480177985 100644 --- a/tensorflow_addons/text/cc/kernels/skip_gram_kernels.cc +++ b/tensorflow_addons/text/cc/kernels/skip_gram_kernels.cc @@ -48,8 +48,9 @@ class SkipGramGenerateCandidatesOp : public OpKernel { const int max_skips = *(max_skips_tensor->scalar().data()); const Tensor& input_check = context->input(0); - OP_REQUIRES(context, TensorShapeUtils::IsVector(input_check.shape()), - errors::InvalidArgument("input_tensor must be of rank 1")); + OP_REQUIRES( + context, TensorShapeUtils::IsVector(input_check.shape()), + errors::InvalidArgument("input_tensor must be of rank 1")); OP_REQUIRES( context, min_skips >= 0 && max_skips >= 0,