Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ public static MapType createMapType(DataType keyType, DataType valueType) {
if (valueType == null) {
throw new IllegalArgumentException("valueType should not be null.");
}
return new MapType(keyType, valueType, true);
return new MapType(keyType, valueType, true, false);
}

/**
Expand All @@ -159,7 +159,7 @@ public static MapType createMapType(
if (valueType == null) {
throw new IllegalArgumentException("valueType should not be null.");
}
return new MapType(keyType, valueType, valueContainsNull);
return new MapType(keyType, valueType, valueContainsNull, false);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ class Analyzer(
ResolveAggregateFunctions ::
TimeWindowing ::
ResolveInlineTables ::
SortMaps ::
TypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*),
Batch("Nondeterministic", Once,
Expand Down Expand Up @@ -2329,3 +2330,49 @@ object ResolveCreateNamedStruct extends Rule[LogicalPlan] {
CreateNamedStruct(children.toList)
}
}

/**
 * Unordered map expressions are not comparable. This rule wraps every resolved expression whose
 * data type contains an unordered map in [[OrderMaps]] at the places where the plan has to
 * compare values:
 *  - the operands of a [[BinaryComparison]],
 *  - the child of a [[SortOrder]],
 *  - the grouping expressions of an [[Aggregate]] (and every semantically equal expression
 *    inside that aggregate),
 *  - the output of the child of a [[Distinct]].
 */
object SortMaps extends Rule[LogicalPlan] {
  /** Returns true when `e` is resolved and its data type contains an unordered map. */
  private def containsUnorderedMap(e: Expression): Boolean =
    e.resolved && MapType.containsUnorderedMap(e.dataType)

  /** Wraps `e` in [[OrderMaps]] when it contains an unordered map; otherwise returns `e` as is. */
  private def orderIfNeeded(e: Expression): Expression =
    if (containsUnorderedMap(e)) OrderMaps(e) else e

  /**
   * Builds an alias over `OrderMaps(attr)` that keeps the name, expression id and qualifier of
   * `attr`.
   */
  private def orderedAlias(attr: Attribute): Alias =
    Alias(OrderMaps(attr), attr.name)(exprId = attr.exprId, qualifier = attr.qualifier)

  override def apply(plan: LogicalPlan): LogicalPlan = {
    // Step 1: rewrite comparisons and sort orders wherever they occur in the plan.
    val withOrderedExpressions = plan.transformAllExpressions {
      // Wrap every operand that needs it in one go. Wrapping only one side per pass would leave
      // the comparison with one wrapped and one unwrapped operand until the rule runs again.
      case cmp @ BinaryComparison(left, right)
          if containsUnorderedMap(left) || containsUnorderedMap(right) =>
        cmp.withNewChildren(orderIfNeeded(left) :: orderIfNeeded(right) :: Nil)
      case sort: SortOrder if containsUnorderedMap(sort.child) =>
        sort.copy(child = OrderMaps(sort.child))
    }

    // Step 2: rewrite the operators that compare whole grouping keys / rows.
    withOrderedExpressions.transform {
      case agg: Aggregate
          if agg.resolved && agg.groupingExpressions.exists(containsUnorderedMap) =>
        // Pair each top level grouping expression that needs ordering with its replacement.
        // Attributes are re-exposed through an alias; any other expression is wrapped directly.
        val replacements = agg.groupingExpressions.collect {
          case attr: Attribute if containsUnorderedMap(attr) =>
            attr -> orderedAlias(attr)
          case e if containsUnorderedMap(e) =>
            e -> OrderMaps(e)
        }

        // Transform the expression tree: every expression that is semantically equal to one of
        // the grouping expressions above is swapped for its replacement.
        agg.transformExpressionsUp {
          case e =>
            // TODO create an expression map! This is a linear scan per visited expression.
            replacements
              .find(_._1.semanticEquals(e))
              .map(_._2)
              .getOrElse(e)
        }

      case Distinct(child) if child.resolved && child.output.exists(containsUnorderedMap) =>
        // Insert a projection below the Distinct that replaces each affected output attribute
        // with its ordered alias and passes the remaining attributes through untouched.
        val projectList = child.output.map { attr =>
          if (containsUnorderedMap(attr)) {
            orderedAlias(attr)
          } else {
            attr
          }
        }
        Distinct(Project(projectList, child))
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.UsingJoin
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

/**
Expand Down Expand Up @@ -244,7 +245,7 @@ trait CheckAnalysis extends PredicateHelper {

def checkValidGroupingExprs(expr: Expression): Unit = {
// Check if the data type of expr is orderable.
if (!RowOrdering.isOrderable(expr.dataType)) {
if (!TypeUtils.isOrderable(expr.dataType)) {
failAnalysis(
s"expression ${expr.sql} cannot be used as a grouping expression " +
s"because its data type ${expr.dataType.simpleString} is not an orderable " +
Expand All @@ -265,7 +266,7 @@ trait CheckAnalysis extends PredicateHelper {

case Sort(orders, _, _) =>
orders.foreach { order =>
if (!RowOrdering.isOrderable(order.dataType)) {
if (!TypeUtils.isOrderable(order.dataType)) {
failAnalysis(
s"sorting is not supported for columns of type ${order.dataType.simpleString}")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ object RowEncoder {
ObjectType(classOf[Object]))
}

case t @ MapType(kt, vt, valueNullable) =>
case t @ MapType(kt, vt, valueNullable, _) =>
val keys =
Invoke(
Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]])),
Expand Down Expand Up @@ -279,7 +279,7 @@ object RowEncoder {
"make",
arrayData :: Nil)

case MapType(kt, vt, valueNullable) =>
case MapType(kt, vt, valueNullable, _) =>
val keyArrayType = ArrayType(kt, false)
val keyData = deserializerFor(Invoke(input, "keyArray", keyArrayType))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ object Cast {
canCast(fromType, toType) &&
resolvableNullability(fn || forceNullable(fromType, toType), tn)

case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
case (MapType(fromKey, fromValue, fn, _), MapType(toKey, toValue, tn, false)) =>
canCast(fromKey, toKey) &&
(!forceNullable(fromKey, toKey)) &&
canCast(fromValue, toValue) &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._

Expand Down Expand Up @@ -61,7 +62,7 @@ case class SortOrder(child: Expression, direction: SortDirection, nullOrdering:
override def foldable: Boolean = false

override def checkInputDataTypes(): TypeCheckResult = {
if (RowOrdering.isOrderable(dataType)) {
if (TypeUtils.isOrderable(dataType)) {
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure(s"cannot sort data type ${dataType.simpleString}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,7 @@ class CodegenContext {
case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)"
case array: ArrayType => genComp(array, c1, c2) + " == 0"
case struct: StructType => genComp(struct, c1, c2) + " == 0"
case map: MapType if map.ordered => genComp(map, c1, c2) + " == 0"
case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
case _ =>
throw new IllegalArgumentException(
Expand Down Expand Up @@ -554,6 +555,47 @@ class CodegenContext {
"""
addNewFunction(compareFunc, funcCode)
s"this.$compareFunc($c1, $c2)"
case MapType(keyType, valueType, _, true) =>
val compareFunc = freshName("compareMap")
val funcCode: String =
s"""
public int $compareFunc(MapData a, MapData b) {
int lengthA = a.numElements();
int lengthB = b.numElements();
ArrayData aKeys = a.keyArray();
ArrayData aValues = a.valueArray();
ArrayData bKeys = b.keyArray();
ArrayData bValues = b.valueArray();
int minLength = (lengthA > lengthB) ? lengthB : lengthA;
for (int i = 0; i < minLength; i++) {
${javaType(keyType)} keyA = ${getValue("aKeys", keyType, "i")};
${javaType(keyType)} keyB = ${getValue("bKeys", keyType, "i")};
int comp = ${genComp(keyType, "keyA", "keyB")};
if (comp != 0) {
return comp;
}
boolean isNullA = aValues.isNullAt(i);
boolean isNullB = bValues.isNullAt(i);
if (isNullA && isNullB) {
// Nothing
} else if (isNullA) {
return -1;
} else if (isNullB) {
return 1;
} else {
${javaType(valueType)} valueA = ${getValue("aValues", valueType, "i")};
${javaType(valueType)} valueB = ${getValue("bValues", valueType, "i")};
comp = ${genComp(valueType, "valueA", "valueB")};
if (comp != 0) {
return comp;
}
}
}
return lengthA - lengthB;
}
"""
addNewFunction(compareFunc, funcCode)
s"this.$compareFunc($c1, $c2)"
case schema: StructType =>
INPUT_ROW = "i"
val comparisons = GenerateOrdering.genComparisons(this, schema)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
dataType: DataType): ExprCode = dataType match {
case s: StructType => createCodeForStruct(ctx, input, s)
case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType)
case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType)
case MapType(keyType, valueType, _, _) => createCodeForMap(ctx, input, keyType, valueType)
// UTF8String act as a pointer if it's inside UnsafeRow, so copy it to make it safe.
case StringType => ExprCode("", "false", s"$input.clone()")
case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case _: CalendarIntervalType => true
case t: StructType => t.toSeq.forall(field => canSupport(field.dataType))
case t: ArrayType if canSupport(t.elementType) => true
case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true
case MapType(kt, vt, _, _) if canSupport(kt) && canSupport(vt) => true
case udt: UserDefinedType[_] => canSupport(udt.sqlType)
case _ => false
}
Expand Down Expand Up @@ -126,7 +126,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
$rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
"""

case m @ MapType(kt, vt, _) =>
case m @ MapType(kt, vt, _, _) =>
s"""
// Remember the current cursor so that we can calculate how many bytes are
// written later.
Expand Down Expand Up @@ -209,7 +209,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
$arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
"""

case m @ MapType(kt, vt, _) =>
case m @ MapType(kt, vt, _, _) =>
s"""
final int $tmpCursor = $bufferHolder.cursor;
${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)}
Expand Down
Loading