diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e5c93b5f0e059..f59b325c1fc27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -150,6 +150,7 @@ class Analyzer( ResolveAggregateFunctions :: TimeWindowing :: ResolveInlineTables(conf) :: + SortMaps :: ResolveTimeZone(conf) :: TypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), @@ -2479,6 +2480,60 @@ object ResolveCreateNamedStruct extends Rule[LogicalPlan] { } } +/** + * MapType values are not comparable by default. This rule wraps map-typed expressions that are + * used in ordering, grouping, distinct, or set operations with [[OrderMaps]] so that they become + * orderable. + */ +object SortMaps extends Rule[LogicalPlan] { + private def containsUnorderedMap(e: Expression): Boolean = + e.resolved && MapType.containsUnorderedMap(e.dataType) + + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions { + case expr: Expression with OrderSpecified => + expr.mapChildren { + case child if containsUnorderedMap(child) => OrderMaps(child) + case child => child + } + } transform { + case a: Aggregate if a.resolved && a.groupingExpressions.exists(containsUnorderedMap) => + // Modify the top-level grouping expressions + val replacements = a.groupingExpressions.collect { + case a: Attribute if containsUnorderedMap(a) => + a -> Alias(OrderMaps(a), a.name)() + case e if containsUnorderedMap(e) => + e -> OrderMaps(e) + } + a.transformExpressionsUp { + case e => + replacements + .find(_._1.semanticEquals(e)) + .map(_._2) + .getOrElse(e) + } + case distinct: Distinct => + wrapOrderMaps(distinct) + + case setOperation: SetOperation => + wrapOrderMaps(setOperation) + } + + private[this] def wrapOrderMaps(logicalPlan: LogicalPlan) = { + logicalPlan.mapChildren(child => { + if (child.resolved && child.output.exists(containsUnorderedMap)) { + val projectList = child.output.map { a => + if (containsUnorderedMap(a)) { + Alias(OrderMaps(a), a.name)() + } else { + a + } + } + Project(projectList, child) + } else { + child + } + }) + } +} + /** * The aggregate expressions from subquery referencing outer query block are pushed * down to the outer query block for evaluation. This rule below updates such outer references diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index b5e8bdd79869e..64fbcf196645b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** @@ -202,7 +203,7 @@ trait CheckAnalysis extends PredicateHelper { } // Check if the data type of expr is orderable.
- if (!RowOrdering.isOrderable(expr.dataType)) { + if (!TypeUtils.isOrderable(expr.dataType)) { failAnalysis( s"expression ${expr.sql} cannot be used as a grouping expression " + s"because its data type ${expr.dataType.simpleString} is not an orderable " + @@ -223,7 +224,7 @@ trait CheckAnalysis extends PredicateHelper { case Sort(orders, _, _) => orders.foreach { order => - if (!RowOrdering.isOrderable(order.dataType)) { + if (!TypeUtils.isOrderable(order.dataType)) { failAnalysis( s"sorting is not supported for columns of type ${order.dataType.simpleString}") } @@ -322,14 +323,6 @@ trait CheckAnalysis extends PredicateHelper { |Conflicting attributes: ${conflictingAttributes.mkString(",")} """.stripMargin) - // TODO: although map type is not orderable, technically map type should be able to be - // used in equality comparison, remove this type check once we support it. - case o if mapColumnInSetOperation(o).isDefined => - val mapCol = mapColumnInSetOperation(o).get - failAnalysis("Cannot have map type columns in DataFrame which calls " + - s"set operations(intersect, except, etc.), but the type of column ${mapCol.name} " + - "is " + mapCol.dataType.simpleString) - case o if o.expressions.exists(!_.deterministic) && !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] && !o.isInstanceOf[Aggregate] && !o.isInstanceOf[Window] => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 789750fd408f2..2925bfe6cb832 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -146,7 +146,7 @@ object RowEncoder { ObjectType(classOf[Object])) } - case t @ MapType(kt, vt, valueNullable) => + case t @ MapType(kt, vt, valueNullable, _) => val keys = Invoke( Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]]), @@ -307,7 +307,7 @@ object RowEncoder { arrayData :: Nil, returnNullable = false) - case MapType(kt, vt, valueNullable) => + case MapType(kt, vt, valueNullable, _) => val keyArrayType = ArrayType(kt, false) val keyData = deserializerFor(Invoke(input, "keyArray", keyArrayType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index bc809f559d586..a1310a5ecdb0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -67,7 +67,7 @@ object Cast { canCast(fromType, toType) && resolvableNullability(fn || forceNullable(fromType, toType), tn) - case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => + case (MapType(fromKey, fromValue, fn, _), MapType(toKey, toValue, tn, false)) => canCast(fromKey, toKey) && (!forceNullable(fromKey, toKey)) && canCast(fromValue, toValue) && @@ -103,7 +103,7 @@ object Cast { case (TimestampType, StringType) => true case (TimestampType, DateType) => true case (ArrayType(fromType, _), ArrayType(toType, _)) => needsTimeZone(fromType, toType) - case (MapType(fromKey, fromValue, _), MapType(toKey, toValue, _)) => + case (MapType(fromKey, fromValue, _, _), MapType(toKey, toValue, _, _)) => needsTimeZone(fromKey, toKey) || needsTimeZone(fromValue, toValue) case (StructType(fromFields), StructType(toFields)) => 
fromFields.length == toFields.length && diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index a3b722a47d688..7f33c6396035e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -642,3 +642,9 @@ abstract class TernaryExpression extends Expression { * and Hive function wrappers. */ trait UserDefinedExpression + +/** + * An expression is marked as `OrderSpecified` if its evaluation relies on the ordering defined + * by its children's data types. + */ +trait OrderSpecified diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index ff7c98f714905..88498c9971f96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._ @@ -62,13 +63,13 @@ case class SortOrder( direction: SortDirection, nullOrdering: NullOrdering, sameOrderExpressions: Set[Expression]) - extends UnaryExpression with Unevaluable { + extends UnaryExpression with Unevaluable with OrderSpecified { /** Sort order is not foldable because we don't have an eval for it.
*/ override def foldable: Boolean = false override def checkInputDataTypes(): TypeCheckResult = { - if (RowOrdering.isOrderable(dataType)) { + if (TypeUtils.isOrderable(dataType)) { TypeCheckResult.TypeCheckSuccess } else { TypeCheckResult.TypeCheckFailure(s"cannot sort data type ${dataType.simpleString}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala index 58fd1d8620e16..347fd03a09241 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.types._ @ExpressionDescription( usage = "_FUNC_(expr) - Returns the maximum value of `expr`.") -case class Max(child: Expression) extends DeclarativeAggregate { +case class Max(child: Expression) extends DeclarativeAggregate with OrderSpecified { override def children: Seq[Expression] = child :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala index b2724ee76827c..9ddb2ee88e202 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.types._ @ExpressionDescription( usage = "_FUNC_(expr) - Returns the minimum value of `expr`.") -case class Min(child: Expression) extends DeclarativeAggregate { +case class Min(child: Expression) extends DeclarativeAggregate with OrderSpecified { override def children: Seq[Expression] = child :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 7559852a2ac45..db98a49263c38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -567,7 +567,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { > SELECT _FUNC_(10, 9, 2, 4, 3); 2 """) -case class Least(children: Seq[Expression]) extends Expression { +case class Least(children: Seq[Expression]) extends Expression with OrderSpecified { override def nullable: Boolean = children.forall(_.nullable) override def foldable: Boolean = children.forall(_.foldable) @@ -633,7 +633,7 @@ case class Least(children: Seq[Expression]) extends Expression { > SELECT _FUNC_(10, 9, 2, 4, 3); 10 """) -case class Greatest(children: Seq[Expression]) extends Expression { +case class Greatest(children: Seq[Expression]) extends Expression with OrderSpecified { override def nullable: Boolean = children.forall(_.nullable) override def foldable: Boolean = children.forall(_.foldable) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 3dc3f8e4adac0..89e6e5f5248bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -617,6 +617,7 @@ class CodegenContext { case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)" case array: ArrayType => genComp(array, c1, c2) + " == 0" case struct: StructType => genComp(struct, c1, c2) + " == 0" + case map: MapType if map.ordered => genComp(map, c1, c2) + " == 0" case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2) case NullType => "false" case _ => @@ -687,6 +688,49 @@ class CodegenContext { } """ s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)" + + case mapType @ MapType(keyType, valueType, valueContainsNull, true) => + val compareFunc = freshName("compareMap") + val funcCode: String = + s""" + public int $compareFunc(MapData a, MapData b) { + int lengthA = a.numElements(); + int lengthB = b.numElements(); + ArrayData aKeys = a.keyArray(); + ArrayData aValues = a.valueArray(); + ArrayData bKeys = b.keyArray(); + ArrayData bValues = b.valueArray(); + int minLength = (lengthA > lengthB) ? lengthB : lengthA; + for (int i = 0; i < minLength; i++) { + ${javaType(keyType)} keyA = ${getValue("aKeys", keyType, "i")}; + ${javaType(keyType)} keyB = ${getValue("bKeys", keyType, "i")}; + int comp = ${genComp(keyType, "keyA", "keyB")}; + if (comp != 0) { + return comp; + } + boolean isNullA = aValues.isNullAt(i); + boolean isNullB = bValues.isNullAt(i); + if (isNullA && isNullB) { + // Nothing + } else if (isNullA) { + return -1; + } else if (isNullB) { + return 1; + } else { + ${javaType(valueType)} valueA = ${getValue("aValues", valueType, "i")}; + ${javaType(valueType)} valueB = ${getValue("bValues", valueType, "i")}; + comp = ${genComp(valueType, "valueA", "valueB")}; + if (comp != 0) { + return comp; + } + } + } + return lengthA - lengthB; + } + """ + addNewFunction(compareFunc, funcCode) + s"this.$compareFunc($c1, $c2)" + case schema: StructType => val comparisons = GenerateOrdering.genComparisons(this, schema) val compareFunc = freshName("compareStruct") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 1e4ac3f2afd52..49895826896be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -130,7 +130,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] dataType: DataType): ExprCode = dataType match { case s: StructType => createCodeForStruct(ctx, input, s) case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) - case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType) + case MapType(keyType, valueType, _, _) => createCodeForMap(ctx, input, keyType, valueType) case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType) case _ => ExprCode("", "false", input) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 4bd50aee05514..a2b8726723450 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -38,7 +38,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _: CalendarIntervalType => true case t: StructType => t.toSeq.forall(field => canSupport(field.dataType)) case t: ArrayType if canSupport(t.elementType) => true - case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true + case MapType(kt, vt, _, _) if canSupport(kt) && canSupport(vt) => true case udt: UserDefinedType[_] => canSupport(udt.sqlType) case _ => false } @@ -133,7 +133,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); """ - case m @ MapType(kt, vt, _) => + case m @ MapType(kt, vt, _, _) => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. @@ -216,7 +216,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); """ - case m @ MapType(kt, vt, _) => + case m @ MapType(kt, vt, _, _) => s""" final int $tmpCursor = $bufferHolder.cursor; ${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 4270b987d6de0..cea4a5692c270 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -21,7 +21,7 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ /** @@ -129,7 +129,7 @@ case class MapValues(child: Expression) """) // scalastyle:on line.size.limit case class SortArray(base: Expression, ascendingOrder: Expression) - extends BinaryExpression with ExpectsInputTypes with CodegenFallback { + extends BinaryExpression with ExpectsInputTypes with CodegenFallback with OrderSpecified { def this(e: Expression) = this(e, Literal(true)) @@ -139,7 +139,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType) override def checkInputDataTypes(): TypeCheckResult = base.dataType match { - case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => + case ArrayType(dt, _) if TypeUtils.isOrderable(dt) => ascendingOrder match { case Literal(_: Boolean, BooleanType) => TypeCheckResult.TypeCheckSuccess @@ -159,6 +159,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression) val ordering = base.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(m: MapType, _) => m.interpretedOrdering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] } @@ -182,6 +183,7 @@ case class SortArray(base: 
Expression, ascendingOrder: Expression) val ordering = base.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(m: MapType, _) => m.interpretedOrdering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] } @@ -287,3 +289,143 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + +/** + * This expression orders all maps in an expression's result. This expression enables the use of + * maps in comparisons and equality operations. + */ +case class OrderMaps(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(ArrayType, MapType, StructType)) + + /** Create a data type in which all maps are ordered. */ + private[this] def createDataType(dataType: DataType): DataType = dataType match { + case StructType(fields) => + StructType(fields.map { field => + field.copy(dataType = createDataType(field.dataType)) + }) + case ArrayType(elementType, containsNull) => + ArrayType(createDataType(elementType), containsNull) + case MapType(keyType, valueType, valueContainsNull, false) => + MapType( + createDataType(keyType), + createDataType(valueType), + valueContainsNull, + ordered = true) + case _ => + dataType + } + + override lazy val dataType: DataType = createDataType(child.dataType) + + private[this] val identity = (id: Any) => id + + /** + * Create a function that transforms a Spark SQL datum to a new datum for which all MapData + * elements have been ordered. + */ + private[this] def createTransform(dataType: DataType): Option[Any => Any] = { + dataType match { + case m @ MapType(keyType, valueType, _, false) => + val keyTransform = createTransform(keyType).getOrElse(identity) + val valueTransform = createTransform(valueType).getOrElse(identity) + val ordering = Ordering.Tuple2(m.interpretedKeyOrdering, m.interpretedValueOrdering) + Option((data: Any) => { + val input = data.asInstanceOf[MapData] + val length = input.numElements() + val buffer = Array.ofDim[(Any, Any)](length) + + // Move the entries into a temporary buffer. + var i = 0 + val keys = input.keyArray() + val values = input.valueArray() + while (i < length) { + val key = keyTransform(keys.get(i, keyType)) + val value = if (!values.isNullAt(i)) { + valueTransform(values.get(i, valueType)) + } else { + null + } + buffer(i) = key -> value + i += 1 + } + + // Sort the buffer. + java.util.Arrays.sort(buffer, ordering) + + // Recreate the map data. + i = 0 + val sortedKeys = Array.ofDim[Any](length) + val sortedValues = Array.ofDim[Any](length) + while (i < length) { + sortedKeys(i) = buffer(i)._1 + sortedValues(i) = buffer(i)._2 + i += 1 + } + ArrayBasedMapData(sortedKeys, sortedValues) + }) + case ArrayType(dt, _) => + createTransform(dt).map { transform => + data: Any => { + val input = data.asInstanceOf[ArrayData] + val length = input.numElements() + val output = Array.ofDim[Any](length) + var i = 0 + while (i < length) { + if (!input.isNullAt(i)) { + output(i) = transform(input.get(i, dt)) + } + i += 1 + } + new GenericArrayData(output) + } + } + case StructType(fields) => + val transformOpts = fields.map { field => + createTransform(field.dataType) + } + // Only transform a struct if a meaningful transformation has been defined. 
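+          // At least one field must contain an unordered map; otherwise return None so that callers leave this struct's data untouched.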
+ if (transformOpts.exists(_.isDefined)) { + val transforms = transformOpts.zip(fields).map { case (opt, field) => + val dataType = field.dataType + val transform = opt.getOrElse(identity) + (input: InternalRow, i: Int) => { + transform(input.get(i, dataType)) + } + } + val length = fields.length + val tf = (data: Any) => { + val input = data.asInstanceOf[InternalRow] + val output = Array.ofDim[Any](length) + var i = 0 + while (i < length) { + if (!input.isNullAt(i)) { + output(i) = transforms(i)(input, i) + } + i += 1 + } + new GenericInternalRow(output) + } + Some(tf) + } else { + None + } + case _ => None + } + } + + @transient private[this] lazy val transform = { + createTransform(child.dataType).getOrElse(identity) + } + + override protected def nullSafeEval(input: Any): Any = transform(input) + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + // TODO we should code generate this. + val tf = ctx.addReferenceObj("transform", transform, classOf[Any => Any].getCanonicalName) + nullSafeCodeGen(ctx, ev, eval => { + s"${ev.value} = (${ctx.boxedType(dataType)})$tf.apply($eval);" + }) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 4b6574a31424e..52c4235090c1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -205,7 +205,7 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val mapClass = classOf[ArrayBasedMapData].getName - val MapType(keyDt, valueDt, _) = dataType + val MapType(keyDt, valueDt, _, _) = dataType val evalKeys = keys.map(e => e.genCode(ctx)) val evalValues = values.map(e => e.genCode(ctx)) val (preprocessKeyData, assignKeys, postprocessKeyData, keyArrayData) = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 7e53ca3908905..16506bf791f2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -61,7 +61,7 @@ object ExtractValue { case (_: ArrayType, _) => GetArrayItem(child, extraction) - case (MapType(kt, _, _), _) => GetMapValue(child, extraction) + case (MapType(kt, _, _, _), _) => GetMapValue(child, extraction) case (otherType, _) => val errorMsg = otherType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 8618f49086077..7e46fcdb7a1c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -263,7 +263,7 @@ abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with new StructType() .add("col", et, containsNull) } - case MapType(kt, vt, valueContainsNull) => + case MapType(kt, vt, valueContainsNull, _) => if (position) { new StructType() .add("pos", IntegerType, nullable = false) @@ -289,7 +289,7 @@ 
abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with }) rows } - case MapType(kt, vt, _) => + case MapType(kt, vt, _, _) => val inputMap = child.eval(input).asInstanceOf[MapData] if (inputMap == null) { Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index eb3c49f5cf30e..a8e8f9107cddc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -413,7 +413,7 @@ abstract class HashExpression[E] extends Expression { case BinaryType => genHashBytes(input, result) case StringType => genHashString(input, result) case ArrayType(et, containsNull) => genHashForArray(ctx, input, result, et, containsNull) - case MapType(kt, vt, valueContainsNull) => + case MapType(kt, vt, valueContainsNull, _) => genHashForMap(ctx, input, result, kt, vt, valueContainsNull) case StructType(fields) => genHashForStruct(ctx, input, result, fields) case udt: UserDefinedType[_] => computeHashWithTailRec(input, udt.sqlType, result, ctx) @@ -484,7 +484,7 @@ abstract class InterpretedHashFunction { case udt: UserDefinedType[_] => val mapType = udt.sqlType.asInstanceOf[MapType] mapType.keyType -> mapType.valueType - case MapType(kt, vt, _) => kt -> vt + case MapType(kt, vt, _, _) => kt -> vt } val keys = map.keyArray() val values = map.valueArray() @@ -859,7 +859,7 @@ object HiveHashFunction extends InterpretedHashFunction { case udt: UserDefinedType[_] => val mapType = udt.sqlType.asInstanceOf[MapType] mapType.keyType -> mapType.valueType - case MapType(_kt, _vt, _) => _kt -> _vt + case MapType(_kt, _vt, _, _) => _kt -> _vt } val keys = map.keyArray() val values = map.valueArray() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index f2eee991c9865..83fc3c86c5dfc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -639,7 +639,7 @@ case class MapObjects private( val genFunctionValue = lambdaFunction.dataType match { case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value) case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value) - case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value) + case MapType(_, _, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value) case _ => genFunction.value } @@ -837,7 +837,7 @@ case class CatalystToExternalMap private( lambdaFunction.dataType match { case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value) case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value) - case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value) + case MapType(_, _, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value) case _ => genFunction.value } val genKeyFunctionValue = genFunctionValue(keyLambdaFunction, genKeyFunction) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala index 
e24a3de3cfdbe..42ef6bc5fa10d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala @@ -18,22 +18,32 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** * An interpreted row ordering comparator. */ -class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] { +class InterpretedOrdering(orders: Seq[SortOrder]) extends Ordering[InternalRow] { def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) = this(ordering.map(BindReferences.bindReference(_, inputSchema))) + @transient private[this] lazy val orderings = orders.toIndexedSeq.map { order => + val ordering = TypeUtils.getInterpretedOrdering(order.dataType) + if (order.direction == Ascending) { + ordering + } else { + ordering.reverse + } + } + def compare(a: InternalRow, b: InternalRow): Int = { var i = 0 - val size = ordering.size + val size = orders.size while (i < size) { - val order = ordering(i) + val order = orders(i) val left = order.child.eval(a) val right = order.child.eval(b) @@ -44,29 +54,14 @@ class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow } else if (right == null) { return if (order.nullOrdering == NullsFirst) 1 else -1 } else { - val comparison = order.dataType match { - case dt: AtomicType if order.direction == Ascending => - dt.ordering.asInstanceOf[Ordering[Any]].compare(left, right) - case dt: AtomicType if order.direction == Descending => - dt.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) - case a: ArrayType if order.direction == Ascending => - a.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right) - case a: ArrayType if order.direction == Descending => - a.interpretedOrdering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) - case s: StructType if order.direction == Ascending => - s.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right) - case s: StructType if order.direction == Descending => - s.interpretedOrdering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) - case other => - throw new IllegalArgumentException(s"Type $other does not support ordered operations") - } + val comparison = orderings(i).compare(left, right) if (comparison != 0) { return comparison } } i += 1 } - return 0 + 0 } } @@ -81,23 +76,3 @@ object InterpretedOrdering { }) } } - -object RowOrdering { - - /** - * Returns true iff the data type can be ordered (i.e. can be sorted). - */ - def isOrderable(dataType: DataType): Boolean = dataType match { - case NullType => true - case dt: AtomicType => true - case struct: StructType => struct.fields.forall(f => isOrderable(f.dataType)) - case array: ArrayType => isOrderable(array.elementType) - case udt: UserDefinedType[_] => isOrderable(udt.sqlType) - case _ => false - } - - /** - * Returns true iff outputs from the expressions can be ordered. 
- */ - def isOrderable(exprs: Seq[Expression]): Boolean = exprs.forall(e => isOrderable(e.dataType)) -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 61df5e053a374..8725af952de94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -152,7 +152,7 @@ case class Not(child: Expression) true """) // scalastyle:on line.size.limit -case class In(value: Expression, list: Seq[Expression]) extends Predicate { +case class In(value: Expression, list: Seq[Expression]) extends Predicate with OrderSpecified { require(list != null, "list should not be null") @@ -270,7 +270,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { * Optimized version of In clause, when all filter values of In clause are * static. */ -case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with Predicate { +case class InSet(child: Expression, hset: Set[Any]) + extends UnaryExpression + with Predicate + with OrderSpecified { require(hset != null, "hset could not be null") @@ -541,7 +544,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P } -abstract class BinaryComparison extends BinaryOperator with Predicate { +abstract class BinaryComparison extends BinaryOperator with Predicate with OrderSpecified { // Note that we need to give a superset of allowable input types since orderable types are not // finitely enumerable. The allowable types are checked below by checkInputDataTypes. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 523b53b39d6b5..78e934befd7b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -575,8 +575,8 @@ object SimplifyCasts extends Rule[LogicalPlan] { case Cast(e, dataType, _) if e.dataType == dataType => e case c @ Cast(e, dataType, _) => (e.dataType, dataType) match { case (ArrayType(from, false), ArrayType(to, true)) if from == to => e - case (MapType(fromKey, fromValue, false), MapType(toKey, toValue, true)) - if fromKey == toKey && fromValue == toValue => e + case (MapType(fromKey, fromValue, false, fromOrder), MapType(toKey, toValue, true, toOrder)) + if fromKey == toKey && fromValue == toValue && (!toOrder || fromOrder) => e case _ => c } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 1dcda49a3af6a..55c9f0c7f0c54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.util import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.RowOrdering +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.types._ /** @@ -34,7 +34,7 @@ object TypeUtils { } def checkForOrderingExpr(dt: DataType, caller: String): TypeCheckResult = { - if (RowOrdering.isOrderable(dt)) { + if 
(isOrderable(dt)) { TypeCheckResult.TypeCheckSuccess } else { TypeCheckResult.TypeCheckFailure(s"$caller does not support ordering on type $dt") @@ -64,11 +64,30 @@ object TypeUtils { t match { case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]] + case m: MapType if m.ordered => m.interpretedOrdering.asInstanceOf[Ordering[Any]] case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]] case udt: UserDefinedType[_] => getInterpretedOrdering(udt.sqlType) + case other => + throw new IllegalArgumentException(s"Type $other does not support ordered operations") } } + /** + * Returns true iff the data type can be ordered (i.e. can be sorted). + */ + def isOrderable(dataType: DataType): Boolean = dataType match { + case NullType => true + case dt: AtomicType => true + case struct: StructType => struct.fields.forall(f => isOrderable(f.dataType)) + case array: ArrayType => isOrderable(array.elementType) + case MapType(keyType, valueType, _, true) => isOrderable(keyType) && isOrderable(valueType) + case udt: UserDefinedType[_] => isOrderable(udt.sqlType) + case _ => false + } + + + def isOrderable(exprs: Seq[Expression]): Boolean = exprs.forall(e => isOrderable(e.dataType)) + def compareBinary(x: Array[Byte], y: Array[Byte]): Int = { for (i <- 0 until x.length; if i < y.length) { val v1 = x(i) & 0xff diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index d6e0df12218ad..91e24bd9d63f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -222,7 +222,8 @@ object DataType { (left, right) match { case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) => equalsIgnoreNullability(leftElementType, rightElementType) - case (MapType(leftKeyType, leftValueType, _), MapType(rightKeyType, rightValueType, _)) => + case (MapType(leftKeyType, leftValueType, _, _), + MapType(rightKeyType, rightValueType, _, _)) => equalsIgnoreNullability(leftKeyType, rightKeyType) && equalsIgnoreNullability(leftValueType, rightValueType) case (StructType(leftFields), StructType(rightFields)) => @@ -253,7 +254,7 @@ object DataType { case (ArrayType(fromElement, fn), ArrayType(toElement, tn)) => (tn || !fn) && equalsIgnoreCompatibleNullability(fromElement, toElement) - case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => + case (MapType(fromKey, fromValue, fn, _), MapType(toKey, toValue, tn, _)) => (tn || !fn) && equalsIgnoreCompatibleNullability(fromKey, toKey) && equalsIgnoreCompatibleNullability(fromValue, toValue) @@ -279,7 +280,7 @@ object DataType { case (ArrayType(fromElement, _), ArrayType(toElement, _)) => equalsIgnoreCaseAndNullability(fromElement, toElement) - case (MapType(fromKey, fromValue, _), MapType(toKey, toValue, _)) => + case (MapType(fromKey, fromValue, _, _), MapType(toKey, toValue, _, _)) => equalsIgnoreCaseAndNullability(fromKey, toKey) && equalsIgnoreCaseAndNullability(fromValue, toValue) @@ -307,7 +308,8 @@ object DataType { case (left: MapType, right: MapType) => equalsStructurally(left.keyType, right.keyType) && equalsStructurally(left.valueType, right.valueType) && - left.valueContainsNull == right.valueContainsNull + left.valueContainsNull == right.valueContainsNull && + left.ordered == right.ordered case (StructType(fromFields), StructType(toFields)) => fromFields.length == 
toFields.length && diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala index e0bca937d1d84..596136bae6cda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala @@ -44,7 +44,7 @@ object HiveStringType { def replaceCharType(dt: DataType): DataType = dt match { case ArrayType(et, nullable) => ArrayType(replaceCharType(et), nullable) - case MapType(kt, vt, nullable) => + case MapType(kt, vt, nullable, _) => MapType(replaceCharType(kt), replaceCharType(vt), nullable) case StructType(fields) => StructType(fields.map { field => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index 6691b81dcea8d..2a9d709e85173 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.types +import scala.language.existentials + import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.sql.catalyst.util.{MapData, TypeUtils} /** * The data type for Maps. Keys in a map are not allowed to have `null` values. @@ -30,12 +33,17 @@ import org.apache.spark.annotation.InterfaceStability * @param keyType The data type of map keys. * @param valueType The data type of map values. * @param valueContainsNull Indicates if map values have `null` values. + * @param ordered Indicates if two maps can be compared. */ @InterfaceStability.Stable case class MapType( - keyType: DataType, - valueType: DataType, - valueContainsNull: Boolean) extends DataType { + keyType: DataType, + valueType: DataType, + valueContainsNull: Boolean, + ordered: Boolean = false) extends DataType { + + def this(keyType: DataType, valueType: DataType, valueContainsNull: Boolean) = + this(keyType, valueType, valueContainsNull, false) /** No-arg constructor for kryo. 
*/ def this() = this(null, null, false) @@ -68,11 +76,71 @@ case class MapType( override def sql: String = s"MAP<${keyType.sql}, ${valueType.sql}>" override private[spark] def asNullable: MapType = - MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true) + MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true, ordered) override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { f(this) || keyType.existsRecursively(f) || valueType.existsRecursively(f) } + + @transient + private[sql] lazy val interpretedKeyOrdering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(keyType) + + @transient + private[sql] lazy val interpretedValueOrdering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(valueType) + + @transient + private[sql] lazy val interpretedOrdering: Ordering[MapData] = new Ordering[MapData] { + assert(ordered) + val keyOrdering = interpretedKeyOrdering + val valueOrdering = interpretedValueOrdering + + def compare(left: MapData, right: MapData): Int = { + val leftKeys = left.keyArray() + val leftValues = left.valueArray() + val rightKeys = right.keyArray() + val rightValues = right.valueArray() + val minLength = scala.math.min(leftKeys.numElements(), rightKeys.numElements()) + var i = 0 + while (i < minLength) { + val keyComp = keyOrdering.compare(leftKeys.get(i, keyType), rightKeys.get(i, keyType)) + if (keyComp != 0) { + return keyComp + } + // TODO this has been taken from ArrayData. Perhaps we should factor out the common code. + val isNullLeft = leftValues.isNullAt(i) + val isNullRight = rightValues.isNullAt(i) + if (isNullLeft && isNullRight) { + // Do nothing. + } else if (isNullLeft) { + return -1 + } else if (isNullRight) { + return 1 + } else { + val comp = valueOrdering.compare( + leftValues.get(i, valueType), + rightValues.get(i, valueType)) + if (comp != 0) { + return comp + } + } + i += 1 + } + val diff = left.numElements() - right.numElements() + if (diff < 0) { + -1 + } else if (diff > 0) { + 1 + } else { + 0 + } + } + } + + override def toString: String = { + s"MapType(${keyType.toString},${valueType.toString},${valueContainsNull.toString})" + } } /** @@ -94,5 +162,15 @@ object MapType extends AbstractDataType { * The `valueContainsNull` is true. */ def apply(keyType: DataType, valueType: DataType): MapType = - MapType(keyType: DataType, valueType: DataType, valueContainsNull = true) + new MapType(keyType, valueType, valueContainsNull = true, ordered = false) + + /** + * Check if a dataType contains an unordered map. 
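+   * A map is treated as unordered unless its `ordered` flag is set; the `SortMaps` analyzer rule uses this check to decide which expressions need to be wrapped in `OrderMaps`.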
+ */ + private[sql] def containsUnorderedMap(dataType: DataType): Boolean = { + dataType.existsRecursively { + case m: MapType => !m.ordered + case _ => false + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index e3b0969283a84..5013ea76b2e1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -454,12 +454,13 @@ object StructType extends AbstractDataType { merge(leftElementType, rightElementType), leftContainsNull || rightContainsNull) - case (MapType(leftKeyType, leftValueType, leftContainsNull), - MapType(rightKeyType, rightValueType, rightContainsNull)) => + case (MapType(leftKeyType, leftValueType, leftContainsNull, leftOrdered), + MapType(rightKeyType, rightValueType, rightContainsNull, rightOrdered)) => MapType( merge(leftKeyType, rightKeyType), merge(leftValueType, rightValueType), - leftContainsNull || rightContainsNull) + leftContainsNull || rightContainsNull, + leftOrdered && rightOrdered) case (StructType(leftFields), StructType(rightFields)) => val newFields = ArrayBuffer.empty[StructField] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 8ae3ff5043e68..17a6f9aa982b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -200,7 +200,7 @@ object RandomDataGenerator { forType(elementType, nullable = containsNull, rand).map { elementGenerator => () => Seq.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator()) } - case MapType(keyType, valueType, valueContainsNull) => + case MapType(keyType, valueType, valueContainsNull, false) => for ( keyGenerator <- forType(keyType, nullable = false, rand); valueGenerator <- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 5d2f8e735e3d4..8f925ef4c8b85 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -27,8 +27,9 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.{Cross, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval @BeanInfo private[sql] case class GroupableData(@BeanProperty data: Int) @@ -51,26 +52,23 @@ private[sql] class GroupableUDT extends UserDefinedType[GroupableData] { } @BeanInfo -private[sql] case class UngroupableData(@BeanProperty data: Map[Int, Int]) +private[sql] case class UngroupableData(@BeanProperty data: CalendarInterval) private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] { - override def sqlType: DataType = MapType(IntegerType, IntegerType) + override def sqlType: DataType = CalendarIntervalType - override def 
serialize(ungroupableData: UngroupableData): MapData = { - val keyArray = new GenericArrayData(ungroupableData.data.keys.toSeq) - val valueArray = new GenericArrayData(ungroupableData.data.values.toSeq) - new ArrayBasedMapData(keyArray, valueArray) + override def serialize(ungroupableData: UngroupableData): GenericArrayData = { + val output = new Array[Any](2) + output(0) = ungroupableData.data.months + output(1) = ungroupableData.data.microseconds + new GenericArrayData(output) } override def deserialize(datum: Any): UngroupableData = { datum match { - case data: MapData => - val keyArray = data.keyArray().array - val valueArray = data.valueArray().array - assert(keyArray.length == valueArray.length) - val mapData = keyArray.zip(valueArray).toMap.asInstanceOf[Map[Int, Int]] - UngroupableData(mapData) + case data: ArrayData => + UngroupableData(new CalendarInterval(data.getInt(0), data.getLong(1))) } } @@ -220,11 +218,6 @@ class AnalysisErrorSuite extends AnalysisTest { testRelation.select(Literal(1).cast(BinaryType).as('badCast)), "cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil) - errorTest( - "sorting by unsupported column types", - mapRelation.orderBy('map.asc), - "sort" :: "type" :: "map" :: Nil) - errorTest( "sorting by attributes are not from grouping expressions", testRelation2.groupBy('a, 'c)('a, 'c, count('a).as("a3")).orderBy('b.asc), @@ -462,6 +455,7 @@ class AnalysisErrorSuite extends AnalysisTest { FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), DateType, TimestampType, ArrayType(IntegerType), + MapType(IntegerType, StringType), new StructType() .add("f1", FloatType, nullable = true) .add("f2", StringType, nullable = true), @@ -474,10 +468,10 @@ class AnalysisErrorSuite extends AnalysisTest { } val unsupportedDataTypes = Seq( - MapType(StringType, LongType), + CalendarIntervalType, new StructType() .add("f1", FloatType, nullable = true) - .add("f2", MapType(StringType, LongType), nullable = true), + .add("f2", CalendarIntervalType, nullable = true), new UngroupableUDT()) unsupportedDataTypes.foreach { dataType => checkDataType(dataType, shouldSuccess = false) @@ -499,7 +493,7 @@ class AnalysisErrorSuite extends AnalysisTest { "another aggregate function." 
:: Nil) } - test("Join can work on binary types but can't work on map types") { + test("Join can work on binary types and map types") { val left = LocalRelation('a.binary, 'b.map(StringType, StringType)) val right = LocalRelation('c.binary, 'd.map(StringType, StringType)) @@ -514,7 +508,7 @@ class AnalysisErrorSuite extends AnalysisTest { right, joinType = Cross, condition = Some('b === 'd)) - assertAnalysisError(plan2, "EqualTo does not support ordering on type MapType" :: Nil) + assertAnalysisSuccess(plan2) } test("PredicateSubQuery is used outside of a filter") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 36714bd631b0e..bebd20f5bd92b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -40,8 +40,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { val e = intercept[AnalysisException] { assertSuccess(expr) } - assert(e.getMessage.contains( - s"cannot resolve '${expr.sql}' due to data type mismatch:")) + assert(e.getMessage.contains("cannot resolve")) + assert(e.getMessage.contains("due to data type mismatch:")) assert(e.getMessage.contains(errorMessage)) } @@ -52,7 +52,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { def assertErrorForDifferingTypes(expr: Expression): Unit = { assertError(expr, - s"differing types in '${expr.sql}'") + s"differing types in") } test("check types for unary arithmetic") { @@ -102,6 +102,22 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertSuccess(EqualTo('intField, 'booleanField)) assertSuccess(EqualNullSafe('intField, 'booleanField)) + // Array type is supported for binary comparison. + assertSuccess(EqualTo('arrayField, 'arrayField)) + assertSuccess(EqualNullSafe('arrayField, 'arrayField)) + assertSuccess(LessThan('arrayField, 'arrayField)) + assertSuccess(LessThanOrEqual('arrayField, 'arrayField)) + assertSuccess(GreaterThan('arrayField, 'arrayField)) + assertSuccess(GreaterThanOrEqual('arrayField, 'arrayField)) + + // Map type is supported for binary comparison. 
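+    // (the SortMaps analyzer rule wraps map operands in OrderMaps, which makes them orderable).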
+ assertSuccess(EqualTo('mapField, 'mapField)) + assertSuccess(EqualNullSafe('mapField, 'mapField)) + assertSuccess(LessThan('mapField, 'mapField)) + assertSuccess(LessThanOrEqual('mapField, 'mapField)) + assertSuccess(GreaterThan('mapField, 'mapField)) + assertSuccess(GreaterThanOrEqual('mapField, 'mapField)) + assertErrorForDifferingTypes(EqualTo('intField, 'mapField)) assertErrorForDifferingTypes(EqualNullSafe('intField, 'mapField)) assertErrorForDifferingTypes(LessThan('intField, 'booleanField)) @@ -109,18 +125,6 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) - assertError(EqualTo('mapField, 'mapField), "EqualTo does not support ordering on type MapType") - assertError(EqualNullSafe('mapField, 'mapField), - "EqualNullSafe does not support ordering on type MapType") - assertError(LessThan('mapField, 'mapField), - "LessThan does not support ordering on type MapType") - assertError(LessThanOrEqual('mapField, 'mapField), - "LessThanOrEqual does not support ordering on type MapType") - assertError(GreaterThan('mapField, 'mapField), - "GreaterThan does not support ordering on type MapType") - assertError(GreaterThanOrEqual('mapField, 'mapField), - "GreaterThanOrEqual does not support ordering on type MapType") - assertError(If('intField, 'stringField, 'stringField), "type of predicate expression in If should be boolean") assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField)) @@ -144,9 +148,9 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertSuccess(Sum('stringField)) assertSuccess(Average('stringField)) assertSuccess(Min('arrayField)) + assertSuccess(Min('mapField)) + assertSuccess(Max('mapField)) - assertError(Min('mapField), "min does not support ordering on type") - assertError(Max('mapField), "max does not support ordering on type") assertError(Sum('booleanField), "function sum requires numeric type") assertError(Average('booleanField), "function average requires numeric type") } @@ -210,7 +214,6 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { assertError(operator(Seq('booleanField)), "requires at least two arguments") assertError(operator(Seq('intField, 'stringField)), "should all have the same type") - assertError(operator(Seq('mapField, 'mapField)), "does not support ordering") } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala index e12e272aedffe..85e64fa10c9d7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala @@ -61,7 +61,4 @@ object TestRelations { val listRelation = LocalRelation( AttributeReference("list", ArrayType(IntegerType))()) - - val mapRelation = LocalRelation( - AttributeReference("map", MapType(IntegerType, IntegerType))()) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 793e04f66f0f9..bbd7491dafcc2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -85,7 +85,7 @@ class TypeCoercionSuite extends AnalysisTest { private def default(dataType: DataType): Expression = dataType match { case ArrayType(internalType: DataType, _) => CreateArray(Seq(Literal.default(internalType))) - case MapType(keyDataType: DataType, valueDataType: DataType, _) => + case MapType(keyDataType: DataType, valueDataType: DataType, _, _) => CreateMap(Seq(Literal.default(keyDataType), Literal.default(valueDataType))) case _ => Literal.default(dataType) } @@ -93,7 +93,7 @@ class TypeCoercionSuite extends AnalysisTest { private def createNull(dataType: DataType): Expression = dataType match { case ArrayType(internalType: DataType, _) => CreateArray(Seq(Literal.create(null, internalType))) - case MapType(keyDataType: DataType, valueDataType: DataType, _) => + case MapType(keyDataType: DataType, valueDataType: DataType, _, _) => CreateMap(Seq(Literal.create(null, keyDataType), Literal.create(null, valueDataType))) case _ => Literal.create(null, dataType) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 020687e4b3a27..235955c153913 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -62,6 +62,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType)) val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType)) val a4 = Literal.create(Seq(null, null), ArrayType(NullType)) + val a5 = Literal.create(Seq(Seq(4, 5, 6), Seq(1, 2, 3)), ArrayType(ArrayType(IntegerType))) checkEvaluation(new SortArray(a0), Seq(1, 2, 3)) checkEvaluation(new SortArray(a1), Seq[Integer]()) @@ -78,6 +79,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Literal.create(null, ArrayType(StringType)), null) checkEvaluation(new SortArray(a4), Seq(null, null)) + checkEvaluation(new SortArray(a5), Seq(Seq(1, 2, 3), Seq(4, 5, 6))) val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val arrayStruct = Literal.create(Seq(create_row(2), create_row(1)), typeAS) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala index aa61ba2bff2bb..0dde4daf898fd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala @@ -29,62 +29,95 @@ import org.apache.spark.sql.types._ class OrderingSuite extends SparkFunSuite with ExpressionEvalHelper { - def compareArrays(a: Seq[Any], b: Seq[Any], expected: Int): Unit = { - test(s"compare two arrays: a = $a, b = $b") { - val dataType = ArrayType(IntegerType) - val rowType = StructType(StructField("array", dataType, nullable = true) :: Nil) - val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType) - val rowA = toCatalyst(Row(a)).asInstanceOf[InternalRow] - val rowB = toCatalyst(Row(b)).asInstanceOf[InternalRow] - Seq(Ascending, Descending).foreach { direction => - val sortOrder = 
direction match { - case Ascending => BoundReference(0, dataType, nullable = true).asc - case Descending => BoundReference(0, dataType, nullable = true).desc - } - val expectedCompareResult = direction match { - case Ascending => signum(expected) - case Descending => -1 * signum(expected) - } + def compareDatum(a: Any, b: Any, dataType: DataType, expected: Int): Unit = { + val rowType = StructType(StructField("data", dataType, nullable = true) :: Nil) + val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType) + val rowA = toCatalyst(Row(a)).asInstanceOf[InternalRow] + val rowB = toCatalyst(Row(b)).asInstanceOf[InternalRow] + Seq(Ascending, Descending).foreach { direction => + val sortOrder = direction match { + case Ascending => + if (dataType.isInstanceOf[MapType]) { + OrderMaps(BoundReference(0, dataType, nullable = true)).asc + } else { + BoundReference(0, dataType, nullable = true).asc + } + case Descending => + if (dataType.isInstanceOf[MapType]) { + OrderMaps(BoundReference(0, dataType, nullable = true)).desc + } else { + BoundReference(0, dataType, nullable = true).desc + } + } + val expectedCompareResult = direction match { + case Ascending => signum(expected) + case Descending => -1 * signum(expected) + } - val kryo = new KryoSerializer(new SparkConf).newInstance() - val intOrdering = new InterpretedOrdering(sortOrder :: Nil) - val genOrdering = new LazilyGeneratedOrdering(sortOrder :: Nil) - val kryoIntOrdering = kryo.deserialize[InterpretedOrdering](kryo.serialize(intOrdering)) - val kryoGenOrdering = kryo.deserialize[LazilyGeneratedOrdering](kryo.serialize(genOrdering)) - - Seq(intOrdering, genOrdering, kryoIntOrdering, kryoGenOrdering).foreach { ordering => - assert(ordering.compare(rowA, rowA) === 0) - assert(ordering.compare(rowB, rowB) === 0) - assert(signum(ordering.compare(rowA, rowB)) === expectedCompareResult) - assert(signum(ordering.compare(rowB, rowA)) === -1 * expectedCompareResult) - } + val kryo = new KryoSerializer(new SparkConf).newInstance() + val intOrdering = new InterpretedOrdering(sortOrder :: Nil) + val genOrdering = new LazilyGeneratedOrdering(sortOrder :: Nil) + val kryoIntOrdering = kryo.deserialize[InterpretedOrdering](kryo.serialize(intOrdering)) + val kryoGenOrdering = kryo.deserialize[LazilyGeneratedOrdering](kryo.serialize(genOrdering)) + + val orderings = if (dataType.isInstanceOf[MapType]) { + Seq(genOrdering, kryoGenOrdering) + } else { + Seq(intOrdering, genOrdering, kryoIntOrdering, kryoGenOrdering) + } + orderings.foreach { ordering => + assert(ordering.compare(rowA, rowA) === 0) + assert(ordering.compare(rowB, rowB) === 0) + assert(signum(ordering.compare(rowA, rowB)) === expectedCompareResult) + assert(signum(ordering.compare(rowB, rowA)) === -1 * expectedCompareResult) } } } + def compareArrays(a: Seq[Integer], b: Seq[Integer], expected: Int): Unit = { + test(s"compare two arrays: a = $a, b = $b, expected = $expected") { + compareDatum(a, b, ArrayType(IntegerType), expected) + } + } + + def compareMaps(a: Map[Integer, Integer], b: Map[Integer, Integer], expected: Int): Unit = { + test(s"compare two maps: a = $a, b = $b, expected = $expected") { + compareDatum(a, b, MapType(IntegerType, IntegerType), expected) + } + } + // Two arrays have the same size. 
- compareArrays(Seq[Any](), Seq[Any](), 0) - compareArrays(Seq[Any](1), Seq[Any](1), 0) - compareArrays(Seq[Any](1, 2), Seq[Any](1, 2), 0) - compareArrays(Seq[Any](1, 2, 2), Seq[Any](1, 2, 3), -1) + compareArrays(Seq[Integer](), Seq[Integer](), 0) + compareArrays(Seq[Integer](1), Seq[Integer](1), 0) + compareArrays(Seq[Integer](1, 2), Seq[Integer](1, 2), 0) + compareArrays(Seq[Integer](1, 2, 2), Seq[Integer](1, 2, 3), -1) // Two arrays have different sizes. - compareArrays(Seq[Any](), Seq[Any](1), -1) - compareArrays(Seq[Any](1, 2, 3), Seq[Any](1, 2, 3, 4), -1) - compareArrays(Seq[Any](1, 2, 3), Seq[Any](1, 2, 3, 2), -1) - compareArrays(Seq[Any](1, 2, 3), Seq[Any](1, 2, 2, 2), 1) + compareArrays(Seq[Integer](), Seq[Integer](1), -1) + compareArrays(Seq[Integer](1, 2, 3), Seq[Integer](1, 2, 3, 4), -1) + compareArrays(Seq[Integer](1, 2, 3), Seq[Integer](1, 2, 3, 2), -1) + compareArrays(Seq[Integer](1, 2, 3), Seq[Integer](1, 2, 2, 2), 1) // Arrays having nulls. - compareArrays(Seq[Any](1, 2, 3), Seq[Any](1, 2, 3, null), -1) - compareArrays(Seq[Any](), Seq[Any](null), -1) - compareArrays(Seq[Any](null), Seq[Any](null), 0) - compareArrays(Seq[Any](null, null), Seq[Any](null, null), 0) - compareArrays(Seq[Any](null), Seq[Any](null, null), -1) - compareArrays(Seq[Any](null), Seq[Any](1), -1) - compareArrays(Seq[Any](null), Seq[Any](null, 1), -1) - compareArrays(Seq[Any](null, 1), Seq[Any](1, 1), -1) - compareArrays(Seq[Any](1, null, 1), Seq[Any](1, null, 1), 0) - compareArrays(Seq[Any](1, null, 1), Seq[Any](1, null, 2), -1) + compareArrays(Seq[Integer](1, 2, 3), Seq[Integer](1, 2, 3, null), -1) + compareArrays(Seq[Integer](), Seq[Integer](null), -1) + compareArrays(Seq[Integer](null), Seq[Integer](null), 0) + compareArrays(Seq[Integer](null, null), Seq[Integer](null, null), 0) + compareArrays(Seq[Integer](null), Seq[Integer](null, null), -1) + compareArrays(Seq[Integer](null), Seq[Integer](1), -1) + compareArrays(Seq[Integer](null), Seq[Integer](null, 1), -1) + compareArrays(Seq[Integer](null, 1), Seq[Integer](1, 1), -1) + compareArrays(Seq[Integer](1, null, 1), Seq[Integer](1, null, 1), 0) + compareArrays(Seq[Integer](1, null, 1), Seq[Integer](1, null, 2), -1) + + + // Comparing maps. + compareMaps(null, Map((1, 2)), -1) + compareMaps(Map((1, 2)), Map((1, 2), (0, 4)), 1) + compareMaps(Map((1, 2)), Map((1, 2), (3, 4)), -1) + compareMaps(Map((1, 2), (3, 4)), Map((1, 2), (3, 5)), -1) + compareMaps(Map((1, 2), (3, 4)), Map((1, 2), (3, null)), 1) + compareMaps(Map((1, 2), (3, 4)), Map((1, 2), (4, 4)), -1) // Test GenerateOrdering for all common types. 
For each type, we construct random input rows that // contain two columns of that type, then for pairs of randomly-generated rows we check that diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 1620ab3aa2094..b033bb5bdf2be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -47,6 +47,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload} import org.apache.spark.sql.execution.command._ @@ -1955,7 +1956,7 @@ class Dataset[T] private[sql]( // ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out // from the sort order. val sortOrder = logicalPlan.output - .filter(attr => RowOrdering.isOrderable(attr.dataType)) + .filter(attr => TypeUtils.isOrderable(attr.dataType)) .map(SortOrder(_, Ascending)) val plan = if (sortOrder.nonEmpty) { Sort(sortOrder, global = false, logicalPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index c142d3b5ed4f2..8f1b212b1c0e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -199,7 +199,7 @@ case class GenerateExec( case ArrayType(dataType, nullable) => ("", "", Seq(codeGenAccessor(ctx, data.value, "col", index, dataType, nullable, checks))) - case MapType(keyType, valueType, valueContainsNull) => + case MapType(keyType, valueType, valueContainsNull, _) => // Materialize the key and the value arrays before we enter the loop. 
val keyArray = ctx.freshName("keyArray") val valueArray = ctx.freshName("valueArray") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index f404621399cea..b6c48782c53cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -161,7 +161,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { }.mkString("{", ",", "}") case (seq: Seq[_], ArrayType(typ, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_, _], MapType(kType, vType, _)) => + case (map: Map[_, _], MapType(kType, vType, _, _)) => map.map { case (key, value) => toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) @@ -179,7 +179,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { }.mkString("{", ",", "}") case (seq: Seq[_], ArrayType(typ, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_, _], MapType(kType, vType, _)) => + case (map: Map[_, _], MapType(kType, vType, _, _)) => map.map { case (key, value) => toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 19b858faba6ea..f023b28e00415 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -168,21 +169,21 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) if !conf.preferSortMergeJoin && canBuildRight(joinType) && canBuildLocalHashMap(right) && muchSmaller(right, left) || - !RowOrdering.isOrderable(leftKeys) => + !TypeUtils.isOrderable(leftKeys) => Seq(joins.ShuffledHashJoinExec( leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right))) case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) if !conf.preferSortMergeJoin && canBuildLeft(joinType) && canBuildLocalHashMap(left) && muchSmaller(left, right) || - !RowOrdering.isOrderable(leftKeys) => + !TypeUtils.isOrderable(leftKeys) => Seq(joins.ShuffledHashJoinExec( leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right))) // --- SortMergeJoin ------------------------------------------------------------ case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) - if RowOrdering.isOrderable(leftKeys) => + if TypeUtils.isOrderable(leftKeys) => joins.SortMergeJoinExec( leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index c61be077d309f..41bd1bd7c60d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -494,7 +494,7 @@ class SparkToParquetSchemaConverter( // Spark 1.4.x and prior versions convert MapType into a 3-level group annotated by // MAP_KEY_VALUE. This is covered by `convertGroupField(field: GroupType): DataType`. - case MapType(keyType, valueType, valueContainsNull) if writeLegacyParquetFormat => + case MapType(keyType, valueType, valueContainsNull, _) if writeLegacyParquetFormat => // group (MAP) { // repeated group map (MAP_KEY_VALUE) { // required key; @@ -525,7 +525,7 @@ class SparkToParquetSchemaConverter( .named("list")) .named(field.name) - case MapType(keyType, valueType, valueContainsNull) => + case MapType(keyType, valueType, valueContainsNull, _) => // group (MAP) { // repeated group key_value { // required key; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 60c430bcfece2..e90996f67cc3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -22,9 +22,10 @@ import java.util.Locale import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, RowOrdering} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.InsertableRelation @@ -296,7 +297,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi sparkSession.sessionState.conf.resolver) normalizedBucketSpec.sortColumnNames.map(schema(_)).map(_.dataType).foreach { - case dt if RowOrdering.isOrderable(dt) => // OK + case dt if TypeUtils.isOrderable(dt) => // OK case other => failAnalysis(s"Cannot use ${other.simpleString} for sorting column") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index 9bbfa6018ba77..df008a5b05d9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -39,7 +39,7 @@ object EvaluatePython { case _: StructType => true case _: UserDefinedType[_] => true case ArrayType(elementType, _) => needConversionInPython(elementType) - case MapType(keyType, valueType, _) => + case MapType(keyType, valueType, _, _) => needConversionInPython(keyType) || needConversionInPython(valueType) case _ 
=> false } @@ -136,7 +136,7 @@ object EvaluatePython { case (c, ArrayType(elementType, _)) if c.getClass.isArray => new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType))) - case (javaMap: java.util.Map[_, _], MapType(keyType, valueType, _)) => + case (javaMap: java.util.Map[_, _], MapType(keyType, valueType, _, _)) => ArrayBasedMapData( javaMap, (key: Any) => fromJava(key, keyType), diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index b007093dad84b..fe0525bc1d22a 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -161,7 +161,7 @@ void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) { new StructField("b", new ArrayType(IntegerType$.MODULE$, true), true, Metadata.empty()), schema.apply("b")); ArrayType valueType = new ArrayType(DataTypes.IntegerType, false); - MapType mapType = new MapType(DataTypes.StringType, valueType, true); + MapType mapType = new MapType(DataTypes.StringType, valueType, true, false); Assert.assertEquals( new StructField("c", mapType, true, Metadata.empty()), schema.apply("c")); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 06848e4d2b297..0d83f758a2d73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -666,4 +666,48 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { assert(exchangePlans.length == 1) } } + + test("Do grouping on field of map type.") { + withTempView("aggV") { + val data = Seq[(Map[Integer, Integer], Integer, Integer)]( + (Map((3, 4)), 10, -10), + (Map((1, 2)), -1, null), + (Map((3, 4)), 1, 1), + (Map((5, 6)), null, 1), + (Map((1, null)), 100, -10), + (Map((7, 8)), null, null), + (Map((7, 8)), null, 1), + (Map((9, 99), (10, 100)), 10, -10), + (Map((10, 100), (9, 99)), 10, -10)).toDF("key", "value1", "value2") + .createTempView("aggV") + + checkAnswer( + spark.sql( + """ + |SELECT DISTINCT key + |FROM aggV + """.stripMargin), + Row(Map(3 -> 4)) :: + Row(Map(1 -> 2)) :: + Row(Map(5 -> 6)) :: + Row(Map(1 -> null)) :: + Row(Map(7 -> 8)) :: + Row(Map(9 -> 99, 10 -> 100)) :: Nil) + + checkAnswer( + spark.sql( + """ + |SELECT value1, key + |FROM aggV + |GROUP BY value1, key + """.stripMargin), + Row(10, Map(3 -> 4)) :: + Row(-1, Map(1 -> 2)) :: + Row(1, Map(3 -> 4)) :: + Row(null, Map(5 -> 6)) :: + Row(100, Map(1 -> null)) :: + Row(null, Map(7 -> 8)) :: + Row(10, Map(9 -> 99, 10 -> 100)) :: Nil) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index 1230b921aa279..5207f00c889ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -64,6 +64,123 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { val ds100_5 = Seq(S100_5()).toDS() ds100_5.rdd.count } + + test("Order by field of map type.") { + withTempView("v") { + Seq[Map[Integer, Integer]]( + null, + Map((1, 2)), + Map((1, 2), (0, 4)), + Map((1, 2), (3, 4)), + Map((1, 2), (3, 5)), + Map((1, 2), (3, null)), + Map((1,
2), (4, 5))).toDF("a").createTempView("v") + checkAnswer( + spark.sql( + """ + |SELECT a + |FROM v + |ORDER BY a + """.stripMargin), + Row(null) :: + Row(Map(1 -> 2, 0 -> 4)) :: + Row(Map(1 -> 2)) :: + Row(Map(1 -> 2, 3 -> null)) :: + Row(Map(1 -> 2, 3 -> 4)) :: + Row(Map(1 -> 2, 3 -> 5)) :: + Row(Map(1 -> 2, 4 -> 5)) :: Nil) + } + } + + test("Binary comparison on fields of map type") { + withTempView("vx", "vy") { + val smallValues = Seq[Map[Integer, Integer]]( + Map((1, 2), (0, 4)), + Map((1, 2)), + Map((1, 2), (3, null)), + Map((1, 2), (3, 4)), + Map((1, 2), (3, 5))) + val bigValues = Seq[Map[Integer, Integer]]( + Map((1, 2)), + Map((1, 2), (3, null)), + Map((1, 2), (3, 4)), + Map((1, 2), (3, 5)), + Map((1, 2), (4, 5))) + val equalValues0 = Seq[Map[Integer, Integer]]( + Map((1, 2)), + Map((1, 2), (3, 4)), + Map((1, 2), (3, 4)), + Map((1, 2), (3, null))) + val equalValues1 = Seq[Map[Integer, Integer]]( + Map((1, 2)), + Map((1, 2), (3, 4)), + Map((3, 4), (1, 2)), + Map((1, 2), (3, null))) + + assert(smallValues.length === bigValues.length) + smallValues.zip(bigValues).toDF("a", "b").createTempView("vx") + checkAnswer( + spark.sql( + """ + |SELECT a < b, b < a, a > b, b > a + |FROM vx + """.stripMargin), + Array.fill(smallValues.length)(Row(true, false, false, true))) + equalValues0.zip(equalValues1).toDF("a", "b").createTempView("vy") + assert(equalValues0.length === equalValues1.length) + checkAnswer( + spark.sql( + """ + |SELECT a = b, a != b + |FROM vy + """.stripMargin), + Array.fill(equalValues0.length)(Row(true, false))) + } + } + + test("Run set operations with map type.") { + withTempView("vx", "vy") { + Seq[Map[Integer, Integer]]( + Map((1, 2)), + Map((1, 2), (3, 4)), + Map((1, 2), (33, 44)), + Map((1, 2), (3, null)), + Map((5, 6), (7, 8))).toDF("a").createTempView("vx") + Seq[Map[Integer, Integer]]( + Map((1, 2)), + Map((1, 2), (3, 4)), + Map((33, 44), (1, 2)), + Map((1, 2), (3, null))).toDF("a").createTempView("vy") + + checkAnswer( + spark.sql( + """ + |SELECT a + |FROM vx + |INTERSECT + |SELECT a + |FROM vy + """.stripMargin + ), + Row(Map(1 -> 2)) :: + Row(Map(1 -> 2, 3 -> 4)) :: + Row(Map(1 -> 2, 33 -> 44)) :: + Row(Map(1 -> 2, 3 -> null)) :: Nil + ) + checkAnswer( + spark.sql( + """ + |SELECT a + |FROM vx + |EXCEPT + |SELECT a + |FROM vy + """.stripMargin + ), + Row(Map(5 -> 6, 7 -> 8)) :: Nil + ) + } + } } class S100( @@ -97,5 +214,3 @@ extends DefinedByConstructorParams case class S100_5( s1: S100 = new S100(), s2: S100 = new S100(), s3: S100 = new S100(), s4: S100 = new S100(), s5: S100 = new S100()) - - diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 50e475984f458..17369aa4afd5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -320,9 +320,41 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Seq[Seq[Int]](Seq(2, 4), Seq(2), Seq(1), null))) ) - val df3 = Seq(("xxx", "x")).toDF("a", "b") + Array[Map[Integer, Integer]](null, Map[Integer, Integer]((1, 2))) + val df3 = Seq((Array[Map[Integer, Integer]]( + null, + Map((1, 2)), + Map((1, 2), (0, 4)), + Map((1, 2), (3, 4)), + Map((1, 2), (3, 5)), + Map((1, 2), (3, null)), + Map((1, 2), (4, 5))), + "x")).toDF("a", "b") + + checkAnswer( + df3.selectExpr("sort_array(a, true)", "sort_array(a, false)"), + Seq(Row( + Seq[Map[Integer, Integer]]( + null, + Map((1, 2), (0,
4)), + Map((1, 2)), + Map((1, 2), (3, null)), + Map((1, 2), (3, 4)), + Map((1, 2), (3, 5)), + Map((1, 2), (4, 5))), + Seq[Map[Integer, Integer]]( + Map((1, 2), (4, 5)), + Map((1, 2), (3, 5)), + Map((1, 2), (3, 4)), + Map((1, 2), (3, null)), + Map((1, 2)), + Map((1, 2), (0, 4)), + null) + ))) + + val df4 = Seq(("xxx", "x")).toDF("a", "b") assert(intercept[AnalysisException] { - df3.selectExpr("sort_array(a)").collect() + df4.selectExpr("sort_array(a)").collect() }.getMessage().contains("only supports array input")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 644e72c893ceb..6535e8d4a4bc5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2068,25 +2068,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer(df, Row(BigDecimal(0)) :: Nil) } - test("SPARK-19893: cannot run set operations with map type") { - val df = spark.range(1).select(map(lit("key"), $"id").as("m")) - val e = intercept[AnalysisException](df.intersect(df)) - assert(e.message.contains( - "Cannot have map type columns in DataFrame which calls set operations")) - val e2 = intercept[AnalysisException](df.except(df)) - assert(e2.message.contains( - "Cannot have map type columns in DataFrame which calls set operations")) - val e3 = intercept[AnalysisException](df.distinct()) - assert(e3.message.contains( - "Cannot have map type columns in DataFrame which calls set operations")) - withTempView("v") { - df.createOrReplaceTempView("v") - val e4 = intercept[AnalysisException](sql("SELECT DISTINCT m FROM v")) - assert(e4.message.contains( - "Cannot have map type columns in DataFrame which calls set operations")) - } - } - test("SPARK-20359: catalyst outer join optimization should not throw npe") { val df1 = Seq("a", "b", "c").toDF("x") .withColumn("y", udf{ (x: String) => x.substring(0, 1) + "!" 
}.apply($"x")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 226cc3028b135..b3f2362db010f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -208,6 +208,33 @@ class JoinSuite extends QueryTest with SharedSQLContext { } } + test("inner join ON, key of map type") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val data1 = spark.sparkContext.parallelize( + Map(1 -> "a", 2 -> "b") :: + Map(3 -> "c", 4 -> "d") :: + Map(5 -> "e", 6 -> "d") :: + Map(7 -> "e") :: + Map(8 -> "f") :: + Map(9 -> "g") :: + Nil).toDF("X").as("A") + val data2 = spark.sparkContext.parallelize( + Map(1 -> "a", 2 -> "b") :: + Map(4 -> "d", 3 -> "c") :: + Map(5 -> "e") :: + Map(7 -> "ee") :: + Map(88 -> "f") :: + Map(9 -> null) :: + Nil).toDF("Y").as("B") + checkAnswer( + data1.join(data2, $"A.X" === $"B.Y"), + Seq( + Row(Map(1 -> "a", 2 -> "b"), Map(1 -> "a", 2 -> "b")), + Row(Map(3 -> "c", 4 -> "d"), Map(4 -> "d", 3 -> "c")) + )) + } + } + test("big inner join, 4 matches per row") { val bigData = testData.union(testData).union(testData).union(testData) val bigDataX = bigData.as("x") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 4dec2f71b8a50..746e54970fa10 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -765,7 +765,7 @@ private[hive] trait HiveInspectors { def toInspector(dataType: DataType): ObjectInspector = dataType match { case ArrayType(tpe, _) => ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe)) - case MapType(keyType, valueType, _) => + case MapType(keyType, valueType, _, _) => ObjectInspectorFactory.getStandardMapObjectInspector( toInspector(keyType), toInspector(valueType)) case StringType => PrimitiveObjectInspectorFactory.javaStringObjectInspector @@ -832,7 +832,7 @@ private[hive] trait HiveInspectors { list.add(wrap(e, listObjectInspector, dt))) ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list) } - case Literal(value, MapType(keyType, valueType, _)) => + case Literal(value, MapType(keyType, valueType, _, _)) => val keyOI = toInspector(keyType) val valueOI = toInspector(valueType) if (value == null) { @@ -1032,7 +1032,7 @@ private[hive] trait HiveInspectors { getStructTypeInfo( java.util.Arrays.asList(fields.map(_.name) : _*), java.util.Arrays.asList(fields.map(_.dataType.toTypeInfo) : _*)) - case MapType(keyType, valueType, _) => + case MapType(keyType, valueType, _, _) => getMapTypeInfo(keyType.toTypeInfo, valueType.toTypeInfo) case BinaryType => binaryTypeInfo case BooleanType => booleanTypeInfo diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index f5e6720f6a510..dd08383cfdfae 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -446,7 +446,7 @@ private[spark] object HiveUtils extends Logging { }.mkString("{", ",", "}") case (seq: Seq[_], ArrayType(typ, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_, _], MapType(kType, vType, _)) => + case (map: Map[_, _], MapType(kType, vType, 
_, _)) => map.map { case (key, value) => toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) @@ -469,7 +469,7 @@ private[spark] object HiveUtils extends Logging { }.mkString("{", ",", "}") case (seq: Seq[_], ArrayType(typ, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_, _], MapType(kType, vType, _)) => + case (map: Map[_, _], MapType(kType, vType, _, _)) => map.map { case (key, value) => toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index c300660458fdd..1f67f4504df79 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -107,7 +107,7 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { def toWritableInspector(dataType: DataType): ObjectInspector = dataType match { case ArrayType(tpe, _) => ObjectInspectorFactory.getStandardListObjectInspector(toWritableInspector(tpe)) - case MapType(keyType, valueType, _) => + case MapType(keyType, valueType, _, _) => ObjectInspectorFactory.getStandardMapObjectInspector( toWritableInspector(keyType), toWritableInspector(valueType)) case StringType => PrimitiveObjectInspectorFactory.writableStringObjectInspector
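The new OrderingSuite, DataFrameComplexTypeSuite, DataFrameAggregateSuite and JoinSuite cases above all rely on the same map-ordering semantics: a map's entries are ordered by key and then compared element-wise, much like an array of key/value structs. The standalone Scala sketch below is not Spark code and is not the OrderMaps implementation; MapOrderingSketch, compareMaps, compareOpt and the Option-for-nullable-value encoding are illustrative assumptions. It only reproduces the ordering that the compareMaps expectations and the ORDER BY results appear to encode: compare keys first, then values, a null value sorts before any non-null value, and a map that is a strict prefix of another sorts first.

object MapOrderingSketch {
  // Null values sort before non-null values, mirroring the expectation
  // compareMaps(Map((1, 2), (3, 4)), Map((1, 2), (3, null)), 1).
  private def compareOpt(x: Option[Int], y: Option[Int]): Int = (x, y) match {
    case (None, None) => 0
    case (None, _) => -1
    case (_, None) => 1
    case (Some(a), Some(b)) => Integer.compare(a, b)
  }

  // Order each map's entries by key, then compare the entry sequences
  // pairwise; if one sequence is a prefix of the other, the shorter map
  // sorts first.
  def compareMaps(a: Map[Int, Option[Int]], b: Map[Int, Option[Int]]): Int = {
    val left = a.toSeq.sortBy(_._1)
    val right = b.toSeq.sortBy(_._1)
    var i = 0
    while (i < math.min(left.size, right.size)) {
      val kc = Integer.compare(left(i)._1, right(i)._1)
      if (kc != 0) return kc
      val vc = compareOpt(left(i)._2, right(i)._2)
      if (vc != 0) return vc
      i += 1
    }
    Integer.compare(left.size, right.size)
  }

  def main(args: Array[String]): Unit = {
    // Mirrors compareMaps(Map((1, 2)), Map((1, 2), (0, 4)), 1): the map whose
    // smallest key is 0 sorts before the map whose smallest key is 1.
    assert(compareMaps(Map(1 -> Some(2)), Map(1 -> Some(2), 0 -> Some(4))) > 0)
    // Mirrors compareMaps(Map((1, 2)), Map((1, 2), (3, 4)), -1): prefix sorts first.
    assert(compareMaps(Map(1 -> Some(2)), Map(1 -> Some(2), 3 -> Some(4))) < 0)
    // Mirrors compareMaps(Map((1, 2), (3, 4)), Map((1, 2), (3, null)), 1).
    assert(compareMaps(Map(1 -> Some(2), 3 -> Some(4)), Map(1 -> Some(2), 3 -> None)) > 0)
    println("map ordering sketch: all checks passed")
  }
}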