@@ -3309,6 +3309,71 @@ def testMutableHashTableFindHighRank(self):
33093309 result = self .evaluate (output )
33103310 self .assertAllEqual ([[0 , 1 ], [- 1 , - 1 ]], result )
33113311
3312+ def testMutableHashTableFindWithInvalidShapeDefaultValue (self ):
3313+ default_val = [- 1 , - 1 ]
3314+ table = lookup_ops .MutableHashTable (dtypes .string , dtypes .int64 ,
3315+ default_val )
3316+
3317+ input_string = constant_op .constant ([["brain" , "salad" ],
3318+ ["tank" , "tarkus" ]])
3319+
3320+ invalid_default_val = constant_op .constant (
3321+ [[- 2 , - 3 ], [- 4 , - 5 ], [- 6 , - 7 ], [- 8 , - 9 ]], dtypes .int64 )
3322+
3323+ with self .assertRaisesRegex (
3324+ (ValueError , errors_impl .InvalidArgumentError ),
3325+ "Expected shape \[2\] or \[2,2,2\] for default value, got \[4,2]" ):
3326+ self .evaluate (table .lookup (input_string , invalid_default_val ))
3327+
3328+ invalid_default_val = constant_op .constant ([[[- 2 , - 3 ], [- 4 , - 5 ]]],
3329+ dtypes .int64 )
3330+ with self .assertRaisesRegex (
3331+ (ValueError , errors_impl .InvalidArgumentError ),
3332+ "Expected shape \[2\] or \[2,2,2\] for default value, got \[1,2,2\]" ):
3333+ self .evaluate (table .lookup (input_string , invalid_default_val ))
3334+
3335+ def testMutableHashTableFindHighRankScalarWithDynamicDefaultValue (self ):
3336+ default_val = - 1
3337+ keys = constant_op .constant (["brain" , "salad" , "surgery" ])
3338+ values = constant_op .constant ([0 , 1 , 2 ], dtypes .int64 )
3339+ table = lookup_ops .MutableHashTable (dtypes .string , dtypes .int64 ,
3340+ default_val )
3341+
3342+ self .evaluate (table .insert (keys , values ))
3343+ self .assertAllEqual (3 , self .evaluate (table .size ()))
3344+
3345+ input_string = constant_op .constant ([["brain" , "salad" ],
3346+ ["tank" , "tarkus" ]])
3347+
3348+ dynamic_default_val = constant_op .constant ([[- 2 , - 3 ], [- 4 , - 5 ]],
3349+ dtypes .int64 )
3350+ output = table .lookup (input_string , dynamic_default_val )
3351+ self .assertAllEqual ([2 , 2 ], output .get_shape ())
3352+
3353+ result = self .evaluate (output )
3354+ self .assertAllEqual ([[0 , 1 ], [- 4 , - 5 ]], result )
3355+
3356+ def testMutableHashTableFindHighRankVectorWithDynamicDefaultValue (self ):
3357+ default_val = [- 1 , - 1 ]
3358+ keys = constant_op .constant (["brain" , "salad" , "surgery" ])
3359+ values = constant_op .constant ([[0 , 1 ], [2 , 3 ], [4 , 5 ]], dtypes .int64 )
3360+ table = lookup_ops .MutableHashTable (dtypes .string , dtypes .int64 ,
3361+ default_val )
3362+
3363+ self .evaluate (table .insert (keys , values ))
3364+ self .assertAllEqual (3 , self .evaluate (table .size ()))
3365+
3366+ input_string = constant_op .constant ([["brain" , "salad" ],
3367+ ["tank" , "tarkus" ]])
3368+
3369+ dynamic_default_val = constant_op .constant (
3370+ [[[- 2 , - 3 ], [- 4 , - 5 ]], [[- 6 , - 7 ], [- 8 , - 9 ]]], dtypes .int64 )
3371+ output = table .lookup (input_string , dynamic_default_val )
3372+ self .assertAllEqual ([2 , 2 , 2 ], output .get_shape ())
3373+
3374+ result = self .evaluate (output )
3375+ self .assertAllEqual ([[[0 , 1 ], [2 , 3 ]], [[- 6 , - 7 ], [- 8 , - 9 ]]], result )
3376+
33123377 def testMutableHashTableInsertHighRank (self ):
33133378 with self .cached_session ():
33143379 default_val = - 1
0 commit comments