@@ -3309,6 +3309,75 @@ def testMutableHashTableFindHighRank(self):
33093309 result = self .evaluate (output )
33103310 self .assertAllEqual ([[0 , 1 ], [- 1 , - 1 ]], result )
33113311
3312+ def testMutableHashTableFindWithInvalidShapeDefaultValue (self ):
3313+ with self .cached_session ():
3314+ default_val = [- 1 , - 1 ]
3315+ table = lookup_ops .MutableHashTable (dtypes .string , dtypes .int64 ,
3316+ default_val )
3317+
3318+ input_string = constant_op .constant ([["brain" , "salad" ],
3319+ ["tank" , "tarkus" ]])
3320+
3321+ raised_error = ValueError
3322+ if context .executing_eagerly ():
3323+ raised_error = errors_impl .InvalidArgumentError
3324+
3325+ invalid_default_val = constant_op .constant (
3326+ [[- 2 , - 3 ], [- 4 , - 5 ], [- 6 , - 7 ], [- 8 , - 9 ]], dtypes .int64 )
3327+
3328+ with self .assertRaises (raised_error ):
3329+ _ = table .lookup (input_string , invalid_default_val )
3330+
3331+ invalid_default_val = constant_op .constant ([[[- 2 , - 3 ], [- 4 , - 5 ]]],
3332+ dtypes .int64 )
3333+
3334+ with self .assertRaises (raised_error ):
3335+ _ = table .lookup (input_string , invalid_default_val )
3336+
3337+ def testMutableHashTableFindHighRankScalarWithDynamicDefaultValue (self ):
3338+ with self .cached_session ():
3339+ default_val = - 1
3340+ keys = constant_op .constant (["brain" , "salad" , "surgery" ])
3341+ values = constant_op .constant ([0 , 1 , 2 ], dtypes .int64 )
3342+ table = lookup_ops .MutableHashTable (dtypes .string , dtypes .int64 ,
3343+ default_val )
3344+
3345+ self .evaluate (table .insert (keys , values ))
3346+ self .assertAllEqual (3 , self .evaluate (table .size ()))
3347+
3348+ input_string = constant_op .constant ([["brain" , "salad" ],
3349+ ["tank" , "tarkus" ]])
3350+
3351+ dynamic_default_val = constant_op .constant ([[- 2 , - 3 ], [- 4 , - 5 ]],
3352+ dtypes .int64 )
3353+ output = table .lookup (input_string , dynamic_default_val )
3354+ self .assertAllEqual ([2 , 2 ], output .get_shape ())
3355+
3356+ result = self .evaluate (output )
3357+ self .assertAllEqual ([[0 , 1 ], [- 4 , - 5 ]], result )
3358+
3359+ def testMutableHashTableFindHighRankVactorWithDynamicDefaultValue (self ):
3360+ with self .cached_session ():
3361+ default_val = [- 1 , - 1 ]
3362+ keys = constant_op .constant (["brain" , "salad" , "surgery" ])
3363+ values = constant_op .constant ([[0 , 1 ], [2 , 3 ], [4 , 5 ]], dtypes .int64 )
3364+ table = lookup_ops .MutableHashTable (dtypes .string , dtypes .int64 ,
3365+ default_val )
3366+
3367+ self .evaluate (table .insert (keys , values ))
3368+ self .assertAllEqual (3 , self .evaluate (table .size ()))
3369+
3370+ input_string = constant_op .constant ([["brain" , "salad" ],
3371+ ["tank" , "tarkus" ]])
3372+
3373+ dynamic_default_val = constant_op .constant (
3374+ [[[- 2 , - 3 ], [- 4 , - 5 ]], [[- 6 , - 7 ], [- 8 , - 9 ]]], dtypes .int64 )
3375+ output = table .lookup (input_string , dynamic_default_val )
3376+ self .assertAllEqual ([2 , 2 , 2 ], output .get_shape ())
3377+
3378+ result = self .evaluate (output )
3379+ self .assertAllEqual ([[[0 , 1 ], [2 , 3 ]], [[- 6 , - 7 ], [- 8 , - 9 ]]], result )
3380+
33123381 def testMutableHashTableInsertHighRank (self ):
33133382 with self .cached_session ():
33143383 default_val = - 1
0 commit comments