@@ -417,7 +417,7 @@ case class EqualTo(left: Expression, right: Expression)
417417 case TypeCheckResult .TypeCheckSuccess =>
418418 // TODO: although map type is not orderable, technically map type should be able to be used
419419 // in equality comparison, remove this type check once we support it.
420- if (hasMapType( left.dataType)) {
420+ if (left.dataType.existsRecursively(_. isInstanceOf [ MapType ] )) {
421421 TypeCheckResult .TypeCheckFailure (" Cannot use map type in EqualTo, but the actual " +
422422 s " input type is ${left.dataType.catalogString}. " )
423423 } else {
@@ -427,14 +427,6 @@ case class EqualTo(left: Expression, right: Expression)
427427 }
428428 }
429429
430- private def hasMapType (dt : DataType ): Boolean = dt match {
431- case _ : MapType => true
432- case st : StructType => st.map(_.dataType).exists(hasMapType)
433- case a : ArrayType => hasMapType(a.elementType)
434- case udt : UserDefinedType [_] => hasMapType(udt.sqlType)
435- case _ => false
436- }
437-
438430 override def symbol : String = " ="
439431
440432 protected override def nullSafeEval (input1 : Any , input2 : Any ): Any = {
@@ -468,7 +460,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
468460 case TypeCheckResult .TypeCheckSuccess =>
469461 // TODO: although map type is not orderable, technically map type should be able to be used
470462 // in equality comparison, remove this type check once we support it.
471- if (hasMapType( left.dataType)) {
463+ if (left.dataType.existsRecursively(_. isInstanceOf [ MapType ] )) {
472464 TypeCheckResult .TypeCheckFailure (" Cannot use map type in EqualNullSafe, but the actual " +
473465 s " input type is ${left.dataType.catalogString}. " )
474466 } else {
@@ -478,14 +470,6 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
478470 }
479471 }
480472
481- private def hasMapType (dt : DataType ): Boolean = dt match {
482- case _ : MapType => true
483- case st : StructType => st.map(_.dataType).exists(hasMapType)
484- case a : ArrayType => hasMapType(a.elementType)
485- case udt : UserDefinedType [_] => hasMapType(udt.sqlType)
486- case _ => false
487- }
488-
489473 override def symbol : String = " <=>"
490474
491475 override def nullable : Boolean = false
0 commit comments