Skip to content

Commit 9343aea

Browse files
committed
Mutablehashtable lookup support full size dynamic default values.
This PR is one part of RFC:tensorflow/community#237
1 parent 5c42efe commit 9343aea

File tree

6 files changed

+116
-15
lines changed

6 files changed

+116
-15
lines changed

tensorflow/core/framework/lookup_interface.cc

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,17 @@ Status LookupInterface::CheckFindArguments(const Tensor& key,
8383
const Tensor& default_value) {
8484
TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(key, default_value));
8585
TF_RETURN_IF_ERROR(CheckKeyShape(key.shape()));
86-
if (default_value.shape() != value_shape()) {
86+
TensorShape fullsize_value_shape = key.shape();
87+
for (int i = 0; i < key_shape().dims(); ++i) {
88+
fullsize_value_shape.RemoveDim(fullsize_value_shape.dims() - 1);
89+
}
90+
fullsize_value_shape.AppendShape(value_shape());
91+
if (default_value.shape() != value_shape() &&
92+
default_value.shape() != fullsize_value_shape) {
8793
return errors::InvalidArgument(
88-
"Expected shape ", value_shape().DebugString(),
89-
" for default value, got ", default_value.shape().DebugString());
94+
"Expected shape ", value_shape().DebugString(), " or ",
95+
fullsize_value_shape.DebugString(), " for default value, got ",
96+
default_value.shape().DebugString());
9097
}
9198
return Status::OK();
9299
}

tensorflow/core/framework/lookup_interface.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ class LookupInterface : public ResourceBase {
128128
// requirements are satisfied, otherwise it returns InvalidArgument:
129129
// - DataType of the tensor keys equals to the table key_dtype
130130
// - DataType of the tensor default_value equals to the table value_dtype
131-
// - the default_value tensor shape matches the table's value shape.
131+
// - the default_value tensor has the required shape given keys and the
132+
// tables's value shape.
132133
Status CheckFindArguments(const Tensor& keys, const Tensor& default_value);
133134

134135
string DebugString() const override {

tensorflow/core/kernels/lookup_table_op.cc

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,19 @@ class MutableHashTableOfScalars final : public LookupInterface {
5656

5757
Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
5858
const Tensor& default_value) override {
59-
const V default_val = default_value.flat<V>()(0);
6059
const auto key_values = key.flat<K>();
6160
auto value_values = value->flat<V>();
61+
const auto default_flat = default_value.flat<V>();
62+
63+
int64 total = value_values.size();
64+
int64 default_total = default_flat.size();
65+
bool is_full_default = (total == default_total);
6266

6367
tf_shared_lock l(mu_);
6468
for (int64 i = 0; i < key_values.size(); ++i) {
6569
value_values(i) = gtl::FindWithDefault(
66-
table_, SubtleMustCopyIfIntegral(key_values(i)), default_val);
70+
table_, SubtleMustCopyIfIntegral(key_values(i)),
71+
is_full_default ? default_flat(i) : default_flat(0));
6772
}
6873

6974
return Status::OK();
@@ -173,11 +178,15 @@ class MutableHashTableOfTensors final : public LookupInterface {
173178

174179
Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
175180
const Tensor& default_value) override {
176-
const auto default_flat = default_value.flat<V>();
181+
const auto default_flat = default_value.flat_inner_dims<V, 2>();
177182
const auto key_values = key.flat<K>();
178183
auto value_values = value->flat_inner_dims<V, 2>();
179184
int64 value_dim = value_shape_.dim_size(0);
180185

186+
int64 total = value_values.size();
187+
int64 default_total = default_flat.size();
188+
bool is_full_default = (total == default_total);
189+
181190
tf_shared_lock l(mu_);
182191
for (int64 i = 0; i < key_values.size(); ++i) {
183192
ValueArray* value_vec =
@@ -188,7 +197,8 @@ class MutableHashTableOfTensors final : public LookupInterface {
188197
}
189198
} else {
190199
for (int64 j = 0; j < value_dim; j++) {
191-
value_values(i, j) = default_flat(j);
200+
value_values(i, j) =
201+
is_full_default ? default_flat(i, j) : default_flat(0, j);
192202
}
193203
}
194204
}

tensorflow/core/ops/lookup_ops.cc

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -166,10 +166,6 @@ REGISTER_OP("LookupTableFindV2")
166166
ShapeHandle handle;
167167
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
168168

169-
// Default value must be scalar or vector.
170-
ShapeHandle keys;
171-
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &keys));
172-
173169
ShapeAndType value_shape_and_type;
174170
TF_RETURN_IF_ERROR(ValidateTableResourceHandle(
175171
c,

tensorflow/python/kernel_tests/lookup_ops_test.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3309,6 +3309,75 @@ def testMutableHashTableFindHighRank(self):
33093309
result = self.evaluate(output)
33103310
self.assertAllEqual([[0, 1], [-1, -1]], result)
33113311

3312+
def testMutableHashTableFindWithInvalidShapeDefaultValue(self):
3313+
with self.cached_session():
3314+
default_val = [-1, -1]
3315+
table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
3316+
default_val)
3317+
3318+
input_string = constant_op.constant([["brain", "salad"],
3319+
["tank", "tarkus"]])
3320+
3321+
raised_error = ValueError
3322+
if context.executing_eagerly():
3323+
raised_error = errors_impl.InvalidArgumentError
3324+
3325+
invalid_default_val = constant_op.constant(
3326+
[[-2, -3], [-4, -5], [-6, -7], [-8, -9]], dtypes.int64)
3327+
3328+
with self.assertRaises(raised_error):
3329+
_ = table.lookup(input_string, invalid_default_val)
3330+
3331+
invalid_default_val = constant_op.constant([[[-2, -3], [-4, -5]]],
3332+
dtypes.int64)
3333+
3334+
with self.assertRaises(raised_error):
3335+
_ = table.lookup(input_string, invalid_default_val)
3336+
3337+
def testMutableHashTableFindHighRankScalarWithDynamicDefaultValue(self):
3338+
with self.cached_session():
3339+
default_val = -1
3340+
keys = constant_op.constant(["brain", "salad", "surgery"])
3341+
values = constant_op.constant([0, 1, 2], dtypes.int64)
3342+
table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
3343+
default_val)
3344+
3345+
self.evaluate(table.insert(keys, values))
3346+
self.assertAllEqual(3, self.evaluate(table.size()))
3347+
3348+
input_string = constant_op.constant([["brain", "salad"],
3349+
["tank", "tarkus"]])
3350+
3351+
dynamic_default_val = constant_op.constant([[-2, -3], [-4, -5]],
3352+
dtypes.int64)
3353+
output = table.lookup(input_string, dynamic_default_val)
3354+
self.assertAllEqual([2, 2], output.get_shape())
3355+
3356+
result = self.evaluate(output)
3357+
self.assertAllEqual([[0, 1], [-4, -5]], result)
3358+
3359+
def testMutableHashTableFindHighRankVactorWithDynamicDefaultValue(self):
3360+
with self.cached_session():
3361+
default_val = [-1, -1]
3362+
keys = constant_op.constant(["brain", "salad", "surgery"])
3363+
values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
3364+
table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
3365+
default_val)
3366+
3367+
self.evaluate(table.insert(keys, values))
3368+
self.assertAllEqual(3, self.evaluate(table.size()))
3369+
3370+
input_string = constant_op.constant([["brain", "salad"],
3371+
["tank", "tarkus"]])
3372+
3373+
dynamic_default_val = constant_op.constant(
3374+
[[[-2, -3], [-4, -5]], [[-6, -7], [-8, -9]]], dtypes.int64)
3375+
output = table.lookup(input_string, dynamic_default_val)
3376+
self.assertAllEqual([2, 2, 2], output.get_shape())
3377+
3378+
result = self.evaluate(output)
3379+
self.assertAllEqual([[[0, 1], [2, 3]], [[-6, -7], [-8, -9]]], result)
3380+
33123381
def testMutableHashTableInsertHighRank(self):
33133382
with self.cached_session():
33143383
default_val = -1

tensorflow/python/ops/lookup_ops.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1810,14 +1810,31 @@ def remove(self, keys, name=None):
18101810

18111811
return op
18121812

1813-
def lookup(self, keys, name=None):
1813+
def lookup(self, keys, dynamic_default_values=None, name=None):
18141814
"""Looks up `keys` in a table, outputs the corresponding values.
18151815
18161816
The `default_value` is used for keys not present in the table.
18171817
18181818
Args:
18191819
keys: Keys to look up. Can be a tensor of any shape. Must match the
18201820
table's key_dtype.
1821+
dynamic_default_values: The values to use if a key is missing in the
1822+
table. If None (by default), the `self._default_value` will be used.
1823+
Shape of `dynamic_default_values` must be same with
1824+
`self._default_value` or the lookup result tensor.
1825+
In the latter case, each key will have a different default value.
1826+
1827+
For example:
1828+
1829+
```python
1830+
keys = [0, 1, 3]
1831+
dynamic_default_values = [[1, 3, 4], [2, 3, 9], [8, 3, 0]]
1832+
1833+
# The key '0' will use [1, 3, 4] as default value.
1834+
# The key '1' will use [2, 3, 9] as default value.
1835+
# The key '3' will use [8, 3, 0] as default value.
1836+
```
1837+
18211838
name: A name for the operation (optional).
18221839
18231840
Returns:
@@ -1831,8 +1848,9 @@ def lookup(self, keys, name=None):
18311848
(self.resource_handle, keys, self._default_value)):
18321849
keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
18331850
with ops.colocate_with(self.resource_handle):
1834-
values = gen_lookup_ops.lookup_table_find_v2(self.resource_handle, keys,
1835-
self._default_value)
1851+
values = gen_lookup_ops.lookup_table_find_v2(
1852+
self.resource_handle, keys, dynamic_default_values
1853+
if dynamic_default_values is not None else self._default_value)
18361854
return values
18371855

18381856
def insert(self, keys, values, name=None):

0 commit comments

Comments
 (0)