Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,9 @@ case class ArrayContains(left: Expression, right: Expression)

override def dataType: DataType = BooleanType

// Interpreted ordering for the data type of the searched value (`right`). The element scan
// further down compares each array element to the value through `ordering.equiv`.
@transient private lazy val ordering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(right.dataType)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then `checkInputDataTypes` should verify that an ordering exists for `right.dataType`; otherwise a type without one, for example MapType, will throw a match error.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, thanks


override def inputTypes: Seq[AbstractDataType] = right.dataType match {
case NullType => Seq.empty
case _ => left.dataType match {
Expand All @@ -520,7 +523,7 @@ case class ArrayContains(left: Expression, right: Expression)
TypeCheckResult.TypeCheckFailure(
"Arguments must be an array followed by a value of same type as the array members")
} else {
TypeCheckResult.TypeCheckSuccess
TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName")
}
}

Expand All @@ -533,7 +536,7 @@ case class ArrayContains(left: Expression, right: Expression)
arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) =>
if (v == null) {
hasNull = true
} else if (v == value) {
} else if (ordering.equiv(v, value)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did this work for Map previously? No?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MapType is not supported in comparisons, not even `=`.

return true
}
)
Expand Down Expand Up @@ -582,11 +585,7 @@ case class ArraysOverlap(left: Expression, right: Expression)

override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckSuccess =>
if (RowOrdering.isOrderable(elementType)) {
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure(s"${elementType.simpleString} cannot be used in comparison.")
}
TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, a general suggestion: refactorings like this one should be done in a separate PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, I'll keep this in mind for the future, thanks.

case failure => failure
}

Expand Down Expand Up @@ -1238,13 +1237,24 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast
case class ArrayPosition(left: Expression, right: Expression)
extends BinaryExpression with ImplicitCastInputTypes {

// Interpreted ordering for the data type of the searched value (`right`); `nullSafeEval`
// uses its `equiv` to compare non-null array elements against that value.
@transient private lazy val ordering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(right.dataType)

override def dataType: DataType = LongType
override def inputTypes: Seq[AbstractDataType] =
Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType)

// Runs the inherited input-type check first; only when that succeeds does it delegate to
// TypeUtils.checkForOrderingExpr for the type of the searched value. A failure from the
// inherited check is returned unchanged.
override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
  case TypeCheckResult.TypeCheckSuccess =>
    TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName")
  case failure: TypeCheckResult.TypeCheckFailure => failure
}

override def nullSafeEval(arr: Any, value: Any): Any = {
arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) =>
if (v == value) {
if (v != null && ordering.equiv(v, value)) {
return (i + 1).toLong
}
)
Expand Down Expand Up @@ -1293,6 +1303,9 @@ case class ArrayPosition(left: Expression, right: Expression)
since = "2.4.0")
case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil {

// Interpreted ordering for the map's key type. This has to stay lazy: `left` may also be an
// array (see `dataType`), in which case the cast to MapType would fail. The value is only
// passed to `getValueEval` from the MapType branch of `nullSafeEval`.
@transient private lazy val ordering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(left.dataType.asInstanceOf[MapType].keyType)

override def dataType: DataType = left.dataType match {
case ArrayType(elementType, _) => elementType
case MapType(_, valueType, _) => valueType
Expand All @@ -1307,6 +1320,16 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
)
}

// Runs the inherited input-type check first and returns its failure unchanged. On success,
// a map input additionally has its key type passed to TypeUtils.checkForOrderingExpr;
// any other input type is accepted as is.
override def checkInputDataTypes(): TypeCheckResult = {
  super.checkInputDataTypes() match {
    case f: TypeCheckResult.TypeCheckFailure => f
    case TypeCheckResult.TypeCheckSuccess =>
      // Destructure the map type instead of testing and casting it twice.
      left.dataType match {
        case MapType(keyType, _, _) =>
          TypeUtils.checkForOrderingExpr(keyType, s"function $prettyName")
        case _ => TypeCheckResult.TypeCheckSuccess
      }
  }
}

override def nullable: Boolean = true

override def nullSafeEval(value: Any, ordinal: Any): Any = {
Expand All @@ -1331,7 +1354,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
}
}
case _: MapType =>
getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType)
getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType, ordering)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.types._

////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -273,7 +273,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression)

abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
// todo: current search is O(n), improve it.
def getValueEval(value: Any, ordinal: Any, keyType: DataType): Any = {
def getValueEval(value: Any, ordinal: Any, keyType: DataType, ordering: Ordering[Any]): Any = {
val map = value.asInstanceOf[MapData]
val length = map.numElements()
val keys = map.keyArray()
Expand All @@ -282,7 +282,7 @@ abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTy
var i = 0
var found = false
while (i < length && !found) {
if (keys.get(i, keyType) == ordinal) {
if (ordering.equiv(keys.get(i, keyType), ordinal)) {
found = true
} else {
i += 1
Expand Down Expand Up @@ -345,8 +345,19 @@ abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTy
case class GetMapValue(child: Expression, key: Expression)
extends GetMapValueUtil with ExtractValue with NullIntolerant {

// Interpreted ordering for the map's key type; handed to `getValueEval` by `nullSafeEval`.
@transient private lazy val ordering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(keyType)

// Key type of the child map. The cast assumes `child.dataType` is a MapType.
private def keyType = child.dataType.asInstanceOf[MapType].keyType

// Runs the inherited input-type check first; only when that succeeds does it delegate to
// TypeUtils.checkForOrderingExpr for the map's key type. A failure from the inherited check
// is returned unchanged.
override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
  case TypeCheckResult.TypeCheckSuccess =>
    TypeUtils.checkForOrderingExpr(keyType, s"function $prettyName")
  case failure: TypeCheckResult.TypeCheckFailure => failure
}

// We have done type checking for child in `ExtractValue`, so only need to check the `key`.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType)

Expand All @@ -363,7 +374,7 @@ case class GetMapValue(child: Expression, key: Expression)

// todo: current search is O(n), improve it.
override def nullSafeEval(value: Any, ordinal: Any): Any = {
getValueEval(value, ordinal, keyType)
getValueEval(value, ordinal, keyType, ordering)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,33 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper

checkEvaluation(ArrayContains(a3, Literal("")), null)
checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)

// binary
val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)),
ArrayType(BinaryType))
val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)),
ArrayType(BinaryType))
val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null),
ArrayType(BinaryType))
val b3 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)),
ArrayType(BinaryType))
val be = Literal.create(Array[Byte](1, 2), BinaryType)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto: binary type is not a complex type.

val nullBinary = Literal.create(null, BinaryType)

checkEvaluation(ArrayContains(b0, be), true)
checkEvaluation(ArrayContains(b1, be), false)
checkEvaluation(ArrayContains(b0, nullBinary), null)
checkEvaluation(ArrayContains(b2, be), null)
checkEvaluation(ArrayContains(b3, be), true)

// complex data types
val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)),
ArrayType(ArrayType(IntegerType)))
val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)),
ArrayType(ArrayType(IntegerType)))
val aae = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType))
checkEvaluation(ArrayContains(aa0, aae), true)
checkEvaluation(ArrayContains(aa1, aae), false)
}

test("ArraysOverlap") {
Expand Down Expand Up @@ -349,6 +376,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper

checkEvaluation(ArrayPosition(a3, Literal("")), null)
checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null)

val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)),
ArrayType(ArrayType(IntegerType)))
val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)),
ArrayType(ArrayType(IntegerType)))
val aae = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType))
checkEvaluation(ArrayPosition(aa0, aae), 1L)
checkEvaluation(ArrayPosition(aa1, aae), 0L)
}

test("elementAt") {
Expand Down Expand Up @@ -386,7 +421,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
val m2 = Literal.create(null, MapType(StringType, StringType))

checkEvaluation(ElementAt(m0, Literal(1.0)), null)
assert(ElementAt(m0, Literal(1.0)).checkInputDataTypes().isFailure)

checkEvaluation(ElementAt(m0, Literal("d")), null)

Expand All @@ -397,6 +432,18 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ElementAt(m0, Literal("c")), null)

checkEvaluation(ElementAt(m2, Literal("a")), null)

// test binary type as keys
val mb0 = Literal.create(
Map(Array[Byte](1, 2) -> "1", Array[Byte](3, 4) -> null, Array[Byte](2, 1) -> "2"),
MapType(BinaryType, StringType))
val mb1 = Literal.create(Map[Array[Byte], String](), MapType(BinaryType, StringType))

checkEvaluation(ElementAt(mb0, Literal(Array[Byte](1, 2, 3))), null)

checkEvaluation(ElementAt(mb1, Literal(Array[Byte](1, 2))), null)
checkEvaluation(ElementAt(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2")
checkEvaluation(ElementAt(mb0, Literal(Array[Byte](3, 4))), null)
}

test("Concat") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -439,4 +439,17 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
.select('c as 'sCol2, 'a as 'sCol1)
checkRule(originalQuery, correctAnswer)
}

test("SPARK-24313: support binary type as map keys in GetMapValue") {
  val binaryKeyedMapType = MapType(BinaryType, StringType)
  val populatedMap = Literal.create(
    Map(Array[Byte](1, 2) -> "1", Array[Byte](3, 4) -> null, Array[Byte](2, 1) -> "2"),
    binaryKeyedMapType)
  val emptyMap = Literal.create(Map[Array[Byte], String](), binaryKeyedMapType)

  // A key that is not in the map evaluates to null.
  checkEvaluation(GetMapValue(populatedMap, Literal(Array[Byte](1, 2, 3))), null)

  // Any lookup in an empty map evaluates to null.
  checkEvaluation(GetMapValue(emptyMap, Literal(Array[Byte](1, 2))), null)
  // A key with the same bytes, but in a different array instance, finds its value.
  checkEvaluation(GetMapValue(populatedMap, Literal(Array[Byte](2, 1), BinaryType)), "2")
  // A key that is present but stored with a null value evaluates to null.
  checkEvaluation(GetMapValue(populatedMap, Literal(Array[Byte](3, 4))), null)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2265,4 +2265,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val df = spark.range(1).select($"id", new Column(Uuid()))
checkAnswer(df, df.collect())
}

test("SPARK-24313: access map with binary keys") {
  // Single-entry map column whose key is the one-byte array [1].
  val binaryKeyedMap = map(lit(Array[Byte](1.toByte)), lit(1))
  // The probe key is a separate array instance holding the same byte.
  val lookedUpValue = binaryKeyedMap.getItem(Array[Byte](1.toByte))
  checkAnswer(spark.range(1).select(lookedUpValue), Row(1))
}
}