Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ class Analyzer(
ResolveAggregateFunctions ::
TimeWindowing ::
ResolveInlineTables(conf) ::
SortMaps::
ResolveTimeZone(conf) ::
TypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*),
Expand Down Expand Up @@ -2479,6 +2480,60 @@ object ResolveCreateNamedStruct extends Rule[LogicalPlan] {
}
}

/**
 * Analyzer rule that wraps map-typed expressions in [[OrderMaps]] wherever the plan needs to
 * compare them, because MapType expressions are not comparable as-is.
 *
 * The rule rewrites three kinds of places:
 *  - children of any expression marked with [[OrderSpecified]];
 *  - grouping expressions of a resolved [[Aggregate]] (and every semantically equal
 *    occurrence inside that aggregate's expressions);
 *  - the output of the children of [[Distinct]] and [[SetOperation]] nodes, through an extra
 *    [[Project]].
 */
object SortMaps extends Rule[LogicalPlan] {

  /**
   * True when `e` is resolved and its data type contains an unordered map. The `resolved`
   * check must come first: the data type is only inspected for resolved expressions.
   */
  private def containsUnorderedMap(e: Expression): Boolean =
    e.resolved && MapType.containsUnorderedMap(e.dataType)

  override def apply(plan: LogicalPlan): LogicalPlan = {
    // Step 1: order the map-typed inputs of every order-dependent expression.
    val withOrderedInputs = plan.transformAllExpressions {
      case orderDependent: Expression with OrderSpecified =>
        orderDependent.mapChildren { input =>
          if (containsUnorderedMap(input)) OrderMaps(input) else input
        }
    }

    // Step 2: handle the plan nodes that compare whole rows / grouping keys.
    withOrderedInputs.transform {
      case agg: Aggregate
          if agg.resolved && agg.groupingExpressions.exists(containsUnorderedMap) =>
        // Build (original -> ordered) substitutions from the top level grouping expressions.
        // Attributes are re-aliased under their own name; other expressions are wrapped as is.
        val substitutions: Seq[(Expression, Expression)] = agg.groupingExpressions.collect {
          case attr: Attribute if containsUnorderedMap(attr) =>
            attr -> Alias(OrderMaps(attr), attr.name)()
          case grouping if containsUnorderedMap(grouping) =>
            grouping -> OrderMaps(grouping)
        }
        // Replace every expression in the aggregate that is semantically equal to one of the
        // grouping expressions above; the first matching substitution wins.
        agg.transformExpressionsUp {
          case candidate =>
            substitutions
              .collectFirst { case (from, to) if from.semanticEquals(candidate) => to }
              .getOrElse(candidate)
        }

      case rowComparing @ (_: Distinct | _: SetOperation) =>
        wrapOrderMaps(rowComparing)
    }
  }

  /**
   * Returns `logicalPlan` with each resolved child whose output contains an unordered map
   * replaced by a [[Project]] over that child. The projection keeps every output column, and
   * re-aliases the map-containing ones (under the same name) as `OrderMaps(column)`.
   * Children that are unresolved or have no such column are left untouched.
   */
  private[this] def wrapOrderMaps(logicalPlan: LogicalPlan) = {
    logicalPlan.mapChildren { input =>
      val needsOrdering = input.resolved && input.output.exists(containsUnorderedMap)
      if (!needsOrdering) {
        input
      } else {
        val orderedOutput = input.output.map {
          case column if containsUnorderedMap(column) => Alias(OrderMaps(column), column.name)()
          case column => column
        }
        Project(orderedOutput, input)
      }
    }
  }
}

/**
* The aggregate expressions from subquery referencing outer query block are pushed
* down to the outer query block for evaluation. This rule below updates such outer references
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

/**
Expand Down Expand Up @@ -202,7 +203,7 @@ trait CheckAnalysis extends PredicateHelper {
}

// Check if the data type of expr is orderable.
if (!RowOrdering.isOrderable(expr.dataType)) {
if (!TypeUtils.isOrderable(expr.dataType)) {
failAnalysis(
s"expression ${expr.sql} cannot be used as a grouping expression " +
s"because its data type ${expr.dataType.simpleString} is not an orderable " +
Expand All @@ -223,7 +224,7 @@ trait CheckAnalysis extends PredicateHelper {

case Sort(orders, _, _) =>
orders.foreach { order =>
if (!RowOrdering.isOrderable(order.dataType)) {
if (!TypeUtils.isOrderable(order.dataType)) {
failAnalysis(
s"sorting is not supported for columns of type ${order.dataType.simpleString}")
}
Expand Down Expand Up @@ -322,14 +323,6 @@ trait CheckAnalysis extends PredicateHelper {
|Conflicting attributes: ${conflictingAttributes.mkString(",")}
""".stripMargin)

// TODO: although map type is not orderable, technically map type should be able to be
// used in equality comparison, remove this type check once we support it.
case o if mapColumnInSetOperation(o).isDefined =>
val mapCol = mapColumnInSetOperation(o).get
failAnalysis("Cannot have map type columns in DataFrame which calls " +
s"set operations(intersect, except, etc.), but the type of column ${mapCol.name} " +
"is " + mapCol.dataType.simpleString)

case o if o.expressions.exists(!_.deterministic) &&
!o.isInstanceOf[Project] && !o.isInstanceOf[Filter] &&
!o.isInstanceOf[Aggregate] && !o.isInstanceOf[Window] =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ object RowEncoder {
ObjectType(classOf[Object]))
}

case t @ MapType(kt, vt, valueNullable) =>
case t @ MapType(kt, vt, valueNullable, _) =>
val keys =
Invoke(
Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]]),
Expand Down Expand Up @@ -307,7 +307,7 @@ object RowEncoder {
arrayData :: Nil,
returnNullable = false)

case MapType(kt, vt, valueNullable) =>
case MapType(kt, vt, valueNullable, _) =>
val keyArrayType = ArrayType(kt, false)
val keyData = deserializerFor(Invoke(input, "keyArray", keyArrayType))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ object Cast {
canCast(fromType, toType) &&
resolvableNullability(fn || forceNullable(fromType, toType), tn)

case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
case (MapType(fromKey, fromValue, fn, _), MapType(toKey, toValue, tn, false)) =>
canCast(fromKey, toKey) &&
(!forceNullable(fromKey, toKey)) &&
canCast(fromValue, toValue) &&
Expand Down Expand Up @@ -103,7 +103,7 @@ object Cast {
case (TimestampType, StringType) => true
case (TimestampType, DateType) => true
case (ArrayType(fromType, _), ArrayType(toType, _)) => needsTimeZone(fromType, toType)
case (MapType(fromKey, fromValue, _), MapType(toKey, toValue, _)) =>
case (MapType(fromKey, fromValue, _, _), MapType(toKey, toValue, _, _)) =>
needsTimeZone(fromKey, toKey) || needsTimeZone(fromValue, toValue)
case (StructType(fromFields), StructType(toFields)) =>
fromFields.length == toFields.length &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -642,3 +642,9 @@ abstract class TernaryExpression extends Expression {
* and Hive function wrappers.
*/
trait UserDefinedExpression

/**
 * Marker trait for expressions whose evaluation depends on the ordering derived from the
 * data type of their children (it is mixed in by `SortOrder`, `Max`, `Min`, `Least` and
 * `Greatest`). The `SortMaps` analyzer rule matches on this trait and wraps every child
 * whose data type contains an unordered map in `OrderMaps`.
 */
trait OrderSpecified
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._

Expand Down Expand Up @@ -62,13 +63,13 @@ case class SortOrder(
direction: SortDirection,
nullOrdering: NullOrdering,
sameOrderExpressions: Set[Expression])
extends UnaryExpression with Unevaluable {
extends UnaryExpression with Unevaluable with OrderSpecified {

/** Sort order is not foldable because we don't have an eval for it. */
override def foldable: Boolean = false

override def checkInputDataTypes(): TypeCheckResult = {
if (RowOrdering.isOrderable(dataType)) {
if (TypeUtils.isOrderable(dataType)) {
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure(s"cannot sort data type ${dataType.simpleString}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.sql.types._

@ExpressionDescription(
usage = "_FUNC_(expr) - Returns the maximum value of `expr`.")
case class Max(child: Expression) extends DeclarativeAggregate {
case class Max(child: Expression) extends DeclarativeAggregate with OrderSpecified {

override def children: Seq[Expression] = child :: Nil

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.sql.types._

@ExpressionDescription(
usage = "_FUNC_(expr) - Returns the minimum value of `expr`.")
case class Min(child: Expression) extends DeclarativeAggregate {
case class Min(child: Expression) extends DeclarativeAggregate with OrderSpecified {

override def children: Seq[Expression] = child :: Nil

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
> SELECT _FUNC_(10, 9, 2, 4, 3);
2
""")
case class Least(children: Seq[Expression]) extends Expression {
case class Least(children: Seq[Expression]) extends Expression with OrderSpecified {

override def nullable: Boolean = children.forall(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)
Expand Down Expand Up @@ -633,7 +633,7 @@ case class Least(children: Seq[Expression]) extends Expression {
> SELECT _FUNC_(10, 9, 2, 4, 3);
10
""")
case class Greatest(children: Seq[Expression]) extends Expression {
case class Greatest(children: Seq[Expression]) extends Expression with OrderSpecified {

override def nullable: Boolean = children.forall(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,7 @@ class CodegenContext {
case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)"
case array: ArrayType => genComp(array, c1, c2) + " == 0"
case struct: StructType => genComp(struct, c1, c2) + " == 0"
case map: MapType if map.ordered => genComp(map, c1, c2) + " == 0"
case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
case NullType => "false"
case _ =>
Expand Down Expand Up @@ -687,6 +688,49 @@ class CodegenContext {
}
"""
s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)"

// Comparison of two *ordered* maps (the pattern only matches when the `ordered` flag is true).
// Entries are compared pairwise in stored order: first by key, then by value, where a null
// value sorts before a non-null one. If all shared positions are equal, the shorter map
// sorts first.
case MapType(keyType, valueType, _, true) =>
  val compareFunc = freshName("compareMap")
  val funcCode: String =
    s"""
      public int $compareFunc(MapData a, MapData b) {
        int lengthA = a.numElements();
        int lengthB = b.numElements();
        ArrayData aKeys = a.keyArray();
        ArrayData aValues = a.valueArray();
        ArrayData bKeys = b.keyArray();
        ArrayData bValues = b.valueArray();
        int minLength = (lengthA > lengthB) ? lengthB : lengthA;
        for (int i = 0; i < minLength; i++) {
          ${javaType(keyType)} keyA = ${getValue("aKeys", keyType, "i")};
          ${javaType(keyType)} keyB = ${getValue("bKeys", keyType, "i")};
          int comp = ${genComp(keyType, "keyA", "keyB")};
          if (comp != 0) {
            return comp;
          }
          boolean isNullA = aValues.isNullAt(i);
          boolean isNullB = bValues.isNullAt(i);
          if (isNullA && isNullB) {
            // Nothing
          } else if (isNullA) {
            return -1;
          } else if (isNullB) {
            return 1;
          } else {
            ${javaType(valueType)} valueA = ${getValue("aValues", valueType, "i")};
            ${javaType(valueType)} valueB = ${getValue("bValues", valueType, "i")};
            comp = ${genComp(valueType, "valueA", "valueB")};
            if (comp != 0) {
              return comp;
            }
          }
        }
        return lengthA - lengthB;
      }
    """
  // Call the function through the name returned by addNewFunction, exactly as the preceding
  // case does, instead of hard-coding `this.<name>`: the returned name is what the code
  // generator expects callers to use, so the call keeps resolving wherever the function
  // ends up being emitted.
  s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)"

case schema: StructType =>
val comparisons = GenerateOrdering.genComparisons(this, schema)
val compareFunc = freshName("compareStruct")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
dataType: DataType): ExprCode = dataType match {
case s: StructType => createCodeForStruct(ctx, input, s)
case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType)
case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType)
case MapType(keyType, valueType, _, _) => createCodeForMap(ctx, input, keyType, valueType)
case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType)
case _ => ExprCode("", "false", input)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case _: CalendarIntervalType => true
case t: StructType => t.toSeq.forall(field => canSupport(field.dataType))
case t: ArrayType if canSupport(t.elementType) => true
case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true
case MapType(kt, vt, _, _) if canSupport(kt) && canSupport(vt) => true
case udt: UserDefinedType[_] => canSupport(udt.sqlType)
case _ => false
}
Expand Down Expand Up @@ -133,7 +133,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
$rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
"""

case m @ MapType(kt, vt, _) =>
case m @ MapType(kt, vt, _, _) =>
s"""
// Remember the current cursor so that we can calculate how many bytes are
// written later.
Expand Down Expand Up @@ -216,7 +216,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
$arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
"""

case m @ MapType(kt, vt, _) =>
case m @ MapType(kt, vt, _, _) =>
s"""
final int $tmpCursor = $bufferHolder.cursor;
${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)}
Expand Down
Loading