Skip to content

Commit ee3c445

Browse files
committed
MutableHashTable lookup: support full-size dynamic default values.
This PR is one part of RFC:tensorflow/community#237
1 parent bd8733d commit ee3c445

File tree

5 files changed

+109
-10
lines changed

5 files changed

+109
-10
lines changed

tensorflow/core/framework/lookup_interface.cc

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,17 @@ Status LookupInterface::CheckFindArguments(const Tensor& key,
8383
const Tensor& default_value) {
8484
TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(key, default_value));
8585
TF_RETURN_IF_ERROR(CheckKeyShape(key.shape()));
86-
if (default_value.shape() != value_shape()) {
86+
TensorShape fullsize_value_shape = key.shape();
87+
for (int i = 0; i < key_shape().dims(); ++i) {
88+
fullsize_value_shape.RemoveDim(fullsize_value_shape.dims() - 1);
89+
}
90+
fullsize_value_shape.AppendShape(value_shape());
91+
if (default_value.shape() != value_shape() &&
92+
default_value.shape() != fullsize_value_shape) {
8793
return errors::InvalidArgument(
88-
"Expected shape ", value_shape().DebugString(),
89-
" for default value, got ", default_value.shape().DebugString());
94+
"Expected shape ", value_shape().DebugString(), " or ",
95+
fullsize_value_shape.DebugString(), " for default value, got ",
96+
default_value.shape().DebugString());
9097
}
9198
return Status::OK();
9299
}

tensorflow/core/framework/lookup_interface.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ class LookupInterface : public ResourceBase {
128128
// requirements are satisfied, otherwise it returns InvalidArgument:
129129
// - DataType of the tensor keys equals to the table key_dtype
130130
// - DataType of the tensor default_value equals to the table value_dtype
131-
// - the default_value tensor shape matches the table's value shape.
131+
// - the default_value tensor has the required shape given keys and the
132+
// table's value shape.
132133
Status CheckFindArguments(const Tensor& keys, const Tensor& default_value);
133134

134135
string DebugString() const override {

tensorflow/core/kernels/lookup_table_op.cc

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,19 @@ class MutableHashTableOfScalars final : public LookupInterface {
6262

6363
int64 total = value_values.size();
6464
int64 default_total = default_flat.size();
65-
bool is_full_default = (total == default_total);
65+
bool is_full_size_default = (total == default_total);
6666

6767
tf_shared_lock l(mu_);
6868
for (int64 i = 0; i < key_values.size(); ++i) {
69+
// is_full_size_default is true:
70+
// Each key has an independent default value, key_values(i)
71+
// uses the corresponding default_flat(i) as its default value.
72+
//
73+
// is_full_size_default is false:
74+
// All keys will share the default_flat(0) as default value.
6975
value_values(i) = gtl::FindWithDefault(
7076
table_, SubtleMustCopyIfIntegral(key_values(i)),
71-
is_full_default ? default_flat(i) : default_flat(0));
77+
is_full_size_default ? default_flat(i) : default_flat(0));
7278
}
7379

7480
return Status::OK();
@@ -185,7 +191,7 @@ class MutableHashTableOfTensors final : public LookupInterface {
185191

186192
int64 total = value_values.size();
187193
int64 default_total = default_flat.size();
188-
bool is_full_default = (total == default_total);
194+
bool is_full_size_default = (total == default_total);
189195

190196
tf_shared_lock l(mu_);
191197
for (int64 i = 0; i < key_values.size(); ++i) {
@@ -196,9 +202,15 @@ class MutableHashTableOfTensors final : public LookupInterface {
196202
value_values(i, j) = value_vec->at(j);
197203
}
198204
} else {
205+
// is_full_size_default is true:
206+
// Each key has an independent default value, key_values(i)
207+
// uses the corresponding default_flat(i) as its default value.
208+
//
209+
// is_full_size_default is false:
210+
// All keys will share the default_flat(0) as default value.
199211
for (int64 j = 0; j < value_dim; j++) {
200212
value_values(i, j) =
201-
is_full_default ? default_flat(i, j) : default_flat(0, j);
213+
is_full_size_default ? default_flat(i, j) : default_flat(0, j);
202214
}
203215
}
204216
}

tensorflow/python/kernel_tests/lookup_ops_test.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3309,6 +3309,71 @@ def testMutableHashTableFindHighRank(self):
33093309
result = self.evaluate(output)
33103310
self.assertAllEqual([[0, 1], [-1, -1]], result)
33113311

3312+
def testMutableHashTableFindWithInvalidShapeDefaultValue(self):
3313+
default_val = [-1, -1]
3314+
table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
3315+
default_val)
3316+
3317+
input_string = constant_op.constant([["brain", "salad"],
3318+
["tank", "tarkus"]])
3319+
3320+
invalid_default_val = constant_op.constant(
3321+
[[-2, -3], [-4, -5], [-6, -7], [-8, -9]], dtypes.int64)
3322+
3323+
with self.assertRaisesRegex(
3324+
(ValueError, errors_impl.InvalidArgumentError),
3325+
"Expected shape \[2\] or \[2,2,2\] for default value, got \[4,2]"):
3326+
self.evaluate(table.lookup(input_string, invalid_default_val))
3327+
3328+
invalid_default_val = constant_op.constant([[[-2, -3], [-4, -5]]],
3329+
dtypes.int64)
3330+
with self.assertRaisesRegex(
3331+
(ValueError, errors_impl.InvalidArgumentError),
3332+
"Expected shape \[2\] or \[2,2,2\] for default value, got \[1,2,2\]"):
3333+
self.evaluate(table.lookup(input_string, invalid_default_val))
3334+
3335+
def testMutableHashTableFindHighRankScalarWithDynamicDefaultValue(self):
3336+
default_val = -1
3337+
keys = constant_op.constant(["brain", "salad", "surgery"])
3338+
values = constant_op.constant([0, 1, 2], dtypes.int64)
3339+
table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
3340+
default_val)
3341+
3342+
self.evaluate(table.insert(keys, values))
3343+
self.assertAllEqual(3, self.evaluate(table.size()))
3344+
3345+
input_string = constant_op.constant([["brain", "salad"],
3346+
["tank", "tarkus"]])
3347+
3348+
dynamic_default_val = constant_op.constant([[-2, -3], [-4, -5]],
3349+
dtypes.int64)
3350+
output = table.lookup(input_string, dynamic_default_val)
3351+
self.assertAllEqual([2, 2], output.get_shape())
3352+
3353+
result = self.evaluate(output)
3354+
self.assertAllEqual([[0, 1], [-4, -5]], result)
3355+
3356+
def testMutableHashTableFindHighRankVectorWithDynamicDefaultValue(self):
3357+
default_val = [-1, -1]
3358+
keys = constant_op.constant(["brain", "salad", "surgery"])
3359+
values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
3360+
table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
3361+
default_val)
3362+
3363+
self.evaluate(table.insert(keys, values))
3364+
self.assertAllEqual(3, self.evaluate(table.size()))
3365+
3366+
input_string = constant_op.constant([["brain", "salad"],
3367+
["tank", "tarkus"]])
3368+
3369+
dynamic_default_val = constant_op.constant(
3370+
[[[-2, -3], [-4, -5]], [[-6, -7], [-8, -9]]], dtypes.int64)
3371+
output = table.lookup(input_string, dynamic_default_val)
3372+
self.assertAllEqual([2, 2, 2], output.get_shape())
3373+
3374+
result = self.evaluate(output)
3375+
self.assertAllEqual([[[0, 1], [2, 3]], [[-6, -7], [-8, -9]]], result)
3376+
33123377
def testMutableHashTableInsertHighRank(self):
33133378
with self.cached_session():
33143379
default_val = -1

tensorflow/python/ops/lookup_ops.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1819,8 +1819,22 @@ def lookup(self, keys, dynamic_default_values=None, name=None):
18191819
keys: Keys to look up. Can be a tensor of any shape. Must match the
18201820
table's key_dtype.
18211821
dynamic_default_values: The values to use if a key is missing in the
1822-
table. If None (by default), the static default_value
1823-
`self._default_value` will be used.
1822+
table. If None (by default), the `table.default_value` will be used.
1823+
The shape of `dynamic_default_values` must be the same as that of
1824+
`table.default_value` or the lookup result tensor.
1825+
In the latter case, each key will have a different default value.
1826+
1827+
For example:
1828+
1829+
```python
1830+
keys = [0, 1, 3]
1831+
dynamic_default_values = [[1, 3, 4], [2, 3, 9], [8, 3, 0]]
1832+
1833+
# The key '0' will use [1, 3, 4] as default value.
1834+
# The key '1' will use [2, 3, 9] as default value.
1835+
# The key '3' will use [8, 3, 0] as default value.
1836+
```
1837+
18241838
name: A name for the operation (optional).
18251839
18261840
Returns:

0 commit comments

Comments
 (0)