From 6b7a9f0c2bde4cb67152c89b9abc5e61bda170ca Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 25 May 2016 13:57:20 -0700 Subject: [PATCH 1/3] support null object in outer join --- .../sql/catalyst/analysis/Analyzer.scala | 10 ++-- .../catalyst/encoders/ExpressionEncoder.scala | 33 ++++++++++-- .../sql/catalyst/plans/logical/object.scala | 17 ++----- .../org/apache/spark/sql/types/Metadata.scala | 2 + .../encoders/ExpressionEncoderSuite.scala | 2 +- .../EliminateSerializationSuite.scala | 4 +- .../scala/org/apache/spark/sql/Dataset.scala | 51 +++++++++++++++---- .../spark/sql/KeyValueGroupedDataset.scala | 8 ++- .../aggregate/TypedAggregateExpression.scala | 8 +-- .../sources/JavaDatasetAggregatorSuite.java | 24 ++++++--- .../org/apache/spark/sql/DatasetSuite.scala | 16 ++++++ .../execution/WholeStageCodegenSuite.scala | 2 +- 12 files changed, 131 insertions(+), 46 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index bf221e0d7cfc4..ba6c6b1bacd7e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -23,7 +23,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf} import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogRelation, InMemoryCatalog, SessionCatalog} -import org.apache.spark.sql.catalyst.encoders.OuterScopes +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, OuterScopes} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.NewInstance @@ -1887,7 +1887,12 @@ class Analyzer( val unbound = deserializer transform { case b: BoundReference => inputs(b.ordinal) } - resolveExpression(unbound, LocalRelation(inputs), throws = true) + + val resolved = resolveExpression(unbound, LocalRelation(inputs), throws = true) + + inputs.find(ExpressionEncoder.isNullFlagColumn).map { a => + If(Or(IsNull(a), a), Literal.create(null, resolved.dataType), resolved) + }.getOrElse(resolved) } } } @@ -2129,4 +2134,3 @@ object TimeWindowing extends Rule[LogicalPlan] { } } } - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index f21a39a2d4730..a28bc0a5270f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, NewInstance} import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} -import org.apache.spark.sql.types.{ObjectType, StructField, StructType} +import org.apache.spark.sql.types.{MetadataBuilder, ObjectType, StructField, StructType} import org.apache.spark.util.Utils /** @@ -107,7 +107,10 @@ object ExpressionEncoder { val serializer = encoders.map { case e if e.flat => e.serializer.head - case other => 
CreateStruct(other.serializer) + case other => + val inputObject = other.serializer.head.find(_.isInstanceOf[BoundReference]).get + val struct = CreateStruct(other.serializer) + If(IsNull(inputObject), Literal.create(null, struct.dataType), struct) }.zipWithIndex.map { case (expr, index) => expr.transformUp { case BoundReference(0, t, _) => @@ -125,12 +128,13 @@ object ExpressionEncoder { } } else { val input = BoundReference(index, enc.schema, nullable = true) - enc.deserializer.transformUp { + val deserializer = enc.deserializer.transformUp { case UnresolvedAttribute(nameParts) => assert(nameParts.length == 1) UnresolvedExtractValue(input, Literal(nameParts.head)) case BoundReference(ordinal, dt, _) => GetStructField(input, ordinal) } + If(IsNull(input), Literal.create(null, deserializer.dataType), deserializer) } } @@ -170,6 +174,17 @@ object ExpressionEncoder { e4: ExpressionEncoder[T4], e5: ExpressionEncoder[T5]): ExpressionEncoder[(T1, T2, T3, T4, T5)] = tuple(Seq(e1, e2, e3, e4, e5)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]] + + private val nullFlagName = "is_null_obj" + private val nullFlagMeta = new MetadataBuilder().putNull(nullFlagName).build() + + def nullFlagColumn(inputObject: Expression): NamedExpression = { + Alias(IsNull(inputObject), nullFlagName)(explicitMetadata = Some(nullFlagMeta)) + } + + def isNullFlagColumn(f: StructField): Boolean = f.metadata.contains(nullFlagName) + + def isNullFlagColumn(a: Attribute): Boolean = a.metadata.contains(nullFlagName) } /** @@ -209,16 +224,24 @@ case class ExpressionEncoder[T]( resolve(attrs, OuterScopes.outerScopes).bind(attrs) } - /** * Returns a new set (with unique ids) of [[NamedExpression]] that represent the serialized form * of this object. */ - def namedExpressions: Seq[NamedExpression] = schema.map(_.name).zip(serializer).map { + def namedSerializer: Seq[NamedExpression] = schema.map(_.name).zip(serializer).map { case (_, ne: NamedExpression) => ne.newInstance() case (name, e) => Alias(e, name)() } + def serializerWithNullFlag: Seq[NamedExpression] = { + if (flat) { + namedSerializer + } else { + val inputObject = serializer.head.find(_.isInstanceOf[BoundReference]).get + namedSerializer :+ ExpressionEncoder.nullFlagColumn(inputObject) + } + } + /** * Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to * toRow are allowed to return the same actual [[InternalRow]] object. 
Thus, the caller should diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 98ce5dd2efd91..8192b6d671cf7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -30,17 +30,8 @@ object CatalystSerde { DeserializeToObject(deserializer, generateObjAttr[T], child) } - def deserialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): DeserializeToObject = { - val deserializer = UnresolvedDeserializer(encoder.deserializer) - DeserializeToObject(deserializer, generateObjAttrForRow(encoder), child) - } - def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = { - SerializeFromObject(encoderFor[T].namedExpressions, child) - } - - def serialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): SerializeFromObject = { - SerializeFromObject(encoder.namedExpressions, child) + SerializeFromObject(encoderFor[T].serializerWithNullFlag, child) } def generateObjAttr[T : Encoder]: Attribute = { @@ -128,7 +119,7 @@ object MapPartitionsInR { schema: StructType, encoder: ExpressionEncoder[Row], child: LogicalPlan): LogicalPlan = { - val deserialized = CatalystSerde.deserialize(child, encoder) + val deserialized = CatalystSerde.deserialize(child)(encoder) val mapped = MapPartitionsInR( func, packageNames, @@ -137,7 +128,7 @@ object MapPartitionsInR { schema, CatalystSerde.generateObjAttrForRow(RowEncoder(schema)), deserialized) - CatalystSerde.serialize(mapped, RowEncoder(schema)) + CatalystSerde.serialize(mapped)(RowEncoder(schema)) } } @@ -185,7 +176,7 @@ object AppendColumns { new AppendColumns( func.asInstanceOf[Any => Any], UnresolvedDeserializer(encoderFor[T].deserializer), - encoderFor[U].namedExpressions, + encoderFor[U].serializerWithNullFlag, child) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala index 1fb2e2404cc42..89e0e921bf0d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -185,6 +185,7 @@ object Metadata { JString(x) case x: Metadata => toJsonValue(x.map) + case null => JNull case other => throw new RuntimeException(s"Do not support type ${other.getClass}.") } @@ -208,6 +209,7 @@ object Metadata { x.## case x: Metadata => hash(x.map) + case null => 0 case other => throw new RuntimeException(s"Do not support type ${other.getClass}.") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 3d97113b52e39..8d98ed638075e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -354,7 +354,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { val inputPlan = LocalRelation(attr) val plan = Project(Alias(encoder.deserializer, "obj")() :: Nil, - Project(encoder.namedExpressions, + Project(encoder.serializerWithNullFlag, inputPlan)) assertAnalysisSuccess(plan) diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala index 3c033ddc374cf..c2eed80410c99 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala @@ -62,8 +62,8 @@ class EliminateSerializationSuite extends PlanTest { val expected = AppendColumnsWithObject( func.asInstanceOf[Any => Any], - productEncoder[(Int, Int)].namedExpressions, - intEncoder.namedExpressions, + productEncoder[(Int, Int)].serializerWithNullFlag, + intEncoder.serializerWithNullFlag, input).analyze comparePlans(optimized, expected) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 85f0cf8a60415..e98fd81223dbe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -198,7 +198,7 @@ class Dataset[T] private[sql]( * same object type (that will be possibly resolved to a different schema). */ private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(encoder) - unresolvedTEncoder.validate(logicalPlan.output) + unresolvedTEncoder.validate(logicalPlan.output.filterNot(ExpressionEncoder.isNullFlagColumn)) /** The encoder for this [[Dataset]] that has been resolved to its output schema. */ private[sql] val resolvedTEncoder: ExpressionEncoder[T] = @@ -391,7 +391,8 @@ class Dataset[T] private[sql]( * @group basic * @since 1.6.0 */ - def schema: StructType = queryExecution.analyzed.schema + def schema: StructType = + StructType(queryExecution.analyzed.schema.filterNot(ExpressionEncoder.isNullFlagColumn)) /** * Prints the schema to the console in a nice tree format. 
@@ -753,16 +754,41 @@ class Dataset[T] private[sql]( val joined = sparkSession.sessionState.executePlan(Join(left, right, joinType = JoinType(joinType), Some(condition.expr))) - val leftOutput = joined.analyzed.output.take(left.output.length) - val rightOutput = joined.analyzed.output.takeRight(right.output.length) + val (leftNullColumn, leftOutput) = joined.analyzed.output.take(left.output.length) + .partition(ExpressionEncoder.isNullFlagColumn) + val (rightNullColumn, rightOutput) = joined.analyzed.output.takeRight(right.output.length) + .partition(ExpressionEncoder.isNullFlagColumn) val leftData = this.unresolvedTEncoder match { - case e if e.flat => Alias(leftOutput.head, "_1")() - case _ => Alias(CreateStruct(leftOutput), "_1")() + case e if e.flat => + assert(leftNullColumn.isEmpty) + assert(leftOutput.length == 1) + Alias(leftOutput.head, "_1")() + case _ => + if (leftNullColumn.isEmpty) { + Alias(CreateStruct(leftOutput), "_1")() + } else { + assert(leftNullColumn.length == 1) + val struct = CreateStruct(leftOutput) + val isObjectNull = Or(IsNull(leftNullColumn.head), leftNullColumn.head) + Alias(If(isObjectNull, Literal.create(null, struct.dataType), struct), "_1")() + } + } val rightData = other.unresolvedTEncoder match { - case e if e.flat => Alias(rightOutput.head, "_2")() - case _ => Alias(CreateStruct(rightOutput), "_2")() + case e if e.flat => + assert(rightNullColumn.isEmpty) + assert(rightOutput.length == 1) + Alias(rightOutput.head, "_2")() + case _ => + if (rightNullColumn.isEmpty) { + Alias(CreateStruct(rightOutput), "_2")() + } else { + assert(rightNullColumn.length == 1) + val struct = CreateStruct(rightOutput) + val isObjectNull = Or(IsNull(rightNullColumn.head), rightNullColumn.head) + Alias(If(isObjectNull, Literal.create(null, struct.dataType), struct), "_2")() + } } implicit val tuple2Encoder: Encoder[(T, U)] = @@ -1910,7 +1936,14 @@ class Dataset[T] private[sql]( val function = Literal.create(func, ObjectType(classOf[T => Boolean])) val condition = Invoke(function, "apply", BooleanType, deserialized.output) val filter = Filter(condition, deserialized) - withTypedPlan(CatalystSerde.serialize[T](filter)) + + val serializer = if (logicalPlan.output.exists(ExpressionEncoder.isNullFlagColumn)) { + unresolvedTEncoder.serializerWithNullFlag + } else { + unresolvedTEncoder.namedSerializer + } + + withTypedPlan(SerializeFromObject(serializer, filter)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 3a5ea19b8ad14..219f8d8232065 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -22,7 +22,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, OuterScopes} -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution @@ -214,7 +214,11 @@ class KeyValueGroupedDataset[K, V] private[sql]( assert(groupingAttributes.length == 1) groupingAttributes.head } else { - Alias(CreateStruct(groupingAttributes), "key")() + val (nullColumn, groupAttr) = 
groupingAttributes.partition(ExpressionEncoder.isNullFlagColumn) + assert(nullColumn.length == 1) + val struct = CreateStruct(groupAttr) + val isObjectNull = Or(IsNull(nullColumn.head), nullColumn.head) + Alias(If(isObjectNull, Literal.create(null, struct.dataType), struct), "key")() } val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan) val execution = new QueryExecution(sparkSession, aggregate) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 8f94184764c0f..6e05ac4d693a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -32,10 +32,10 @@ object TypedAggregateExpression { def apply[BUF : Encoder, OUT : Encoder]( aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = { val bufferEncoder = encoderFor[BUF] - val bufferSerializer = bufferEncoder.namedExpressions - val bufferDeserializer = bufferEncoder.deserializer.transform { - case b: BoundReference => bufferSerializer(b.ordinal).toAttribute - } + val bufferSerializer = bufferEncoder.serializerWithNullFlag + val bufferDeserializer = UnresolvedDeserializer( + bufferEncoder.deserializer, + bufferSerializer.map(_.toAttribute)) val outputEncoder = encoderFor[OUT] val outputType = if (outputEncoder.flat) { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java index f9842e130b5d0..fc048b65c1b31 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java @@ -18,6 +18,8 @@ package test.org.apache.spark.sql.sources; import java.util.Arrays; +import java.util.List; +import java.util.HashSet; import scala.Tuple2; @@ -36,16 +38,26 @@ * Suite for testing the aggregate functionality of Datasets in Java. 
*/ public class JavaDatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase { + private void checkResult(List expected, List actual) { + HashSet s1 = new HashSet(); + s1.addAll(expected); + HashSet s2 = new HashSet(); + s2.addAll(actual); + + Assert.assertEquals(s1, s2); + } + @Test public void testTypedAggregationAnonClass() { KeyValueGroupedDataset> grouped = generateGroupedDataset(); Dataset> agged = grouped.agg(new IntSumOf().toColumn()); - Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList()); + checkResult(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList()); Dataset> agged2 = grouped.agg(new IntSumOf().toColumn()) .as(Encoders.tuple(Encoders.STRING(), Encoders.INT())); - Assert.assertEquals( + + checkResult( Arrays.asList( new Tuple2<>("a", 3), new Tuple2<>("b", 3)), @@ -93,7 +105,7 @@ public Double call(Tuple2 value) throws Exception { return (double)(value._2() * 2); } })); - Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 6.0)), agged.collectAsList()); + checkResult(Arrays.asList(tuple2("a", 3.0), tuple2("b", 6.0)), agged.collectAsList()); } @Test @@ -105,7 +117,7 @@ public Object call(Tuple2 value) throws Exception { return value; } })); - Assert.assertEquals(Arrays.asList(tuple2("a", 2), tuple2("b", 1)), agged.collectAsList()); + checkResult(Arrays.asList(tuple2("a", 2L), tuple2("b", 1L)), agged.collectAsList()); } @Test @@ -117,7 +129,7 @@ public Double call(Tuple2 value) throws Exception { return (double)value._2(); } })); - Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 3.0)), agged.collectAsList()); + checkResult(Arrays.asList(tuple2("a", 3.0), tuple2("b", 3.0)), agged.collectAsList()); } @Test @@ -129,6 +141,6 @@ public Long call(Tuple2 value) throws Exception { return (long)value._2(); } })); - Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList()); + checkResult(Arrays.asList(tuple2("a", 3L), tuple2("b", 3L)), agged.collectAsList()); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 32320a6435acb..9cb96285c9c0b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -763,6 +763,22 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkShowString(ds, expected) } + + ignore("SPARK-15140: encoder should support null input object") { + val ds = Seq(1 -> "a", null).toDS() + val result = ds.collect() + assert(result.length == 2) + assert(result(0) == 1 -> "a") + assert(result(1) == null) + } + + test("SPARK-15441: Dataset outer join") { + val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDS().as("left") + val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDS().as("right") + val joined = left.joinWith(right, $"left.b" === $"right.b", "left") + joined.explain(true) + joined.show() + } } case class Generic[T](id: T, value: Double) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index f86955e5a5bc4..58e65806dff97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -111,6 +111,6 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { 
assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[TungstenAggregate]).isDefined) - assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0))) + assert(ds.collect().toSet === Set(("a", 10.0), ("b", 3.0), ("c", 1.0))) } } From bd76b18b4f672c10d56d053c8ff17ef42560d60d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 27 May 2016 12:14:11 -0700 Subject: [PATCH 2/3] improve --- .../sql/catalyst/analysis/Analyzer.scala | 5 +- .../catalyst/encoders/ExpressionEncoder.scala | 35 ++++++-- .../scala/org/apache/spark/sql/Dataset.scala | 83 ++++++++----------- .../spark/sql/KeyValueGroupedDataset.scala | 7 +- .../org/apache/spark/sql/SparkSession.scala | 10 ++- .../org/apache/spark/sql/DatasetSuite.scala | 12 +-- 6 files changed, 77 insertions(+), 75 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ba6c6b1bacd7e..068ca2b65bd8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1889,10 +1889,7 @@ class Analyzer( } val resolved = resolveExpression(unbound, LocalRelation(inputs), throws = true) - - inputs.find(ExpressionEncoder.isNullFlagColumn).map { a => - If(Or(IsNull(a), a), Literal.create(null, resolved.dataType), resolved) - }.getOrElse(resolved) + ExpressionEncoder.checkNullFlag(inputs, resolved) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index a28bc0a5270f9..51390b82f4e8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, NewInstance} import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} -import org.apache.spark.sql.types.{MetadataBuilder, ObjectType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils /** @@ -176,15 +176,31 @@ object ExpressionEncoder { tuple(Seq(e1, e2, e3, e4, e5)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]] private val nullFlagName = "is_null_obj" - private val nullFlagMeta = new MetadataBuilder().putNull(nullFlagName).build() + private val nullFlagMeta = new MetadataBuilder().putBoolean(nullFlagName, true).build() def nullFlagColumn(inputObject: Expression): NamedExpression = { Alias(IsNull(inputObject), nullFlagName)(explicitMetadata = Some(nullFlagMeta)) } - def isNullFlagColumn(f: StructField): Boolean = f.metadata.contains(nullFlagName) + def isNullFlagColumn(f: StructField): Boolean = { + f.dataType == BooleanType && f.metadata.contains(nullFlagName) + } + + def isNullFlagColumn(a: Attribute): Boolean = { + a.dataType == BooleanType && a.metadata.contains(nullFlagName) + } - def isNullFlagColumn(a: Attribute): Boolean = a.metadata.contains(nullFlagName) + def checkNullFlag(input: Seq[Attribute], output: Expression): Expression = { + val nullFlag = input.filter(isNullFlagColumn) + if (nullFlag.isEmpty) { + 
output + } else if (nullFlag.length == 1) { + val objIsNull = Or(IsNull(nullFlag.head), nullFlag.head) + If(objIsNull, Literal.create(null, output.dataType), output) + } else { + throw new IllegalStateException("more than one null flag columns are found.") + } + } } /** @@ -209,6 +225,9 @@ case class ExpressionEncoder[T]( @transient private lazy val extractProjection = GenerateUnsafeProjection.generate(serializer) + @transient + private lazy val serProjWithNullFlag = GenerateUnsafeProjection.generate(serializerWithNullFlag) + @transient private lazy val inputRow = new GenericMutableRow(1) @@ -256,6 +275,11 @@ case class ExpressionEncoder[T]( s"Error while encoding: $e\n${serializer.map(_.treeString).mkString("\n")}", e) } + def toRowWithNullFlag(t: T): InternalRow = { + inputRow(0) = t + serProjWithNullFlag(inputRow) + } + /** * Returns an object of type `T`, extracting the required values from the provided row. Note that * you must `resolve` and `bind` an encoder to a specific schema before you can call this @@ -349,7 +373,8 @@ case class ExpressionEncoder[T]( LocalRelation(schema)) val analyzedPlan = SimpleAnalyzer.execute(plan) SimpleAnalyzer.checkAnalysis(analyzedPlan) - copy(deserializer = SimplifyCasts(analyzedPlan).expressions.head.children.head) + val newDeserializer = SimplifyCasts(analyzedPlan).expressions.head.children.head + copy(deserializer = ExpressionEncoder.checkNullFlag(schema, newDeserializer)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index e98fd81223dbe..5f7492af53983 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -172,7 +172,7 @@ class Dataset[T] private[sql]( this(sqlContext.sparkSession, logicalPlan, encoder) } - @transient private[sql] val logicalPlan: LogicalPlan = { + @transient private val withSideEffects: LogicalPlan = { def hasSideEffects(plan: LogicalPlan): Boolean = plan match { case _: Command | _: InsertIntoTable | @@ -192,13 +192,22 @@ class Dataset[T] private[sql]( } } + @transient private[sql] val logicalPlan: LogicalPlan = { + val output = withSideEffects.output + if (output.exists(ExpressionEncoder.isNullFlagColumn)) { + Project(output.filterNot(ExpressionEncoder.isNullFlagColumn), withSideEffects) + } else { + withSideEffects + } + } + /** * An unresolved version of the internal encoder for the type of this [[Dataset]]. This one is * marked implicit so that we can use it when constructing new [[Dataset]] objects that have the * same object type (that will be possibly resolved to a different schema). */ private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(encoder) - unresolvedTEncoder.validate(logicalPlan.output.filterNot(ExpressionEncoder.isNullFlagColumn)) + unresolvedTEncoder.validate(logicalPlan.output) /** The encoder for this [[Dataset]] that has been resolved to its output schema. */ private[sql] val resolvedTEncoder: ExpressionEncoder[T] = @@ -357,7 +366,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, logicalPlan) + def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, withSideEffects) /** * Converts this strongly typed collection of data to generic `DataFrame` with columns renamed. 
@@ -391,8 +400,7 @@ class Dataset[T] private[sql]( * @group basic * @since 1.6.0 */ - def schema: StructType = - StructType(queryExecution.analyzed.schema.filterNot(ExpressionEncoder.isNullFlagColumn)) + def schema: StructType = logicalPlan.schema /** * Prints the schema to the console in a nice tree format. @@ -749,46 +757,27 @@ class Dataset[T] private[sql]( */ @Experimental def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = { - val left = this.logicalPlan - val right = other.logicalPlan + val left = this.withSideEffects + val right = other.withSideEffects val joined = sparkSession.sessionState.executePlan(Join(left, right, joinType = JoinType(joinType), Some(condition.expr))) - val (leftNullColumn, leftOutput) = joined.analyzed.output.take(left.output.length) - .partition(ExpressionEncoder.isNullFlagColumn) - val (rightNullColumn, rightOutput) = joined.analyzed.output.takeRight(right.output.length) - .partition(ExpressionEncoder.isNullFlagColumn) - - val leftData = this.unresolvedTEncoder match { - case e if e.flat => - assert(leftNullColumn.isEmpty) - assert(leftOutput.length == 1) - Alias(leftOutput.head, "_1")() - case _ => - if (leftNullColumn.isEmpty) { - Alias(CreateStruct(leftOutput), "_1")() - } else { - assert(leftNullColumn.length == 1) - val struct = CreateStruct(leftOutput) - val isObjectNull = Or(IsNull(leftNullColumn.head), leftNullColumn.head) - Alias(If(isObjectNull, Literal.create(null, struct.dataType), struct), "_1")() - } + val leftOutput = joined.analyzed.output.take(left.output.length) + val rightOutput = joined.analyzed.output.takeRight(right.output.length) + val leftData = if (this.unresolvedTEncoder.flat) { + assert(leftOutput.length == 1) + Alias(leftOutput.head, "_1")() + } else { + val struct = CreateStruct(leftOutput.filterNot(ExpressionEncoder.isNullFlagColumn)) + Alias(ExpressionEncoder.checkNullFlag(leftOutput, struct), "_1")() } - val rightData = other.unresolvedTEncoder match { - case e if e.flat => - assert(rightNullColumn.isEmpty) - assert(rightOutput.length == 1) - Alias(rightOutput.head, "_2")() - case _ => - if (rightNullColumn.isEmpty) { - Alias(CreateStruct(rightOutput), "_2")() - } else { - assert(rightNullColumn.length == 1) - val struct = CreateStruct(rightOutput) - val isObjectNull = Or(IsNull(rightNullColumn.head), rightNullColumn.head) - Alias(If(isObjectNull, Literal.create(null, struct.dataType), struct), "_2")() - } + val rightData = if (other.unresolvedTEncoder.flat) { + assert(rightOutput.length == 1) + Alias(rightOutput.head, "_2")() + } else { + val struct = CreateStruct(rightOutput.filterNot(ExpressionEncoder.isNullFlagColumn)) + Alias(ExpressionEncoder.checkNullFlag(rightOutput, struct), "_2")() } implicit val tuple2Encoder: Encoder[(T, U)] = @@ -925,7 +914,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def as(alias: String): Dataset[T] = withTypedPlan { - SubqueryAlias(alias, logicalPlan) + SubqueryAlias(alias, withSideEffects) } /** @@ -1281,7 +1270,7 @@ class Dataset[T] private[sql]( */ @Experimental def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { - val inputPlan = logicalPlan + val inputPlan = withSideEffects val withGroupingKey = AppendColumns(func, inputPlan) val executed = sparkSession.sessionState.executePlan(withGroupingKey) @@ -1932,12 +1921,12 @@ class Dataset[T] private[sql]( */ @Experimental def filter(func: T => Boolean): Dataset[T] = { - val deserialized = CatalystSerde.deserialize[T](logicalPlan) + val deserialized = 
CatalystSerde.deserialize[T](withSideEffects) val function = Literal.create(func, ObjectType(classOf[T => Boolean])) val condition = Invoke(function, "apply", BooleanType, deserialized.output) val filter = Filter(condition, deserialized) - val serializer = if (logicalPlan.output.exists(ExpressionEncoder.isNullFlagColumn)) { + val serializer = if (withSideEffects.output.exists(ExpressionEncoder.isNullFlagColumn)) { unresolvedTEncoder.serializerWithNullFlag } else { unresolvedTEncoder.namedSerializer @@ -1973,7 +1962,7 @@ class Dataset[T] private[sql]( */ @Experimental def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan { - MapElements[T, U](func, logicalPlan) + MapElements[T, U](func, withSideEffects) } /** @@ -1987,7 +1976,7 @@ class Dataset[T] private[sql]( @Experimental def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { implicit val uEnc = encoder - withTypedPlan(MapElements[T, U](func, logicalPlan)) + withTypedPlan(MapElements[T, U](func, withSideEffects)) } /** @@ -2002,7 +1991,7 @@ class Dataset[T] private[sql]( def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { new Dataset[U]( sparkSession, - MapPartitions[T, U](func, logicalPlan), + MapPartitions[T, U](func, withSideEffects), implicitly[Encoder[U]]) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 219f8d8232065..7d7bcc21582e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -214,11 +214,8 @@ class KeyValueGroupedDataset[K, V] private[sql]( assert(groupingAttributes.length == 1) groupingAttributes.head } else { - val (nullColumn, groupAttr) = groupingAttributes.partition(ExpressionEncoder.isNullFlagColumn) - assert(nullColumn.length == 1) - val struct = CreateStruct(groupAttr) - val isObjectNull = Or(IsNull(nullColumn.head), nullColumn.head) - Alias(If(isObjectNull, Literal.create(null, struct.dataType), struct), "key")() + val struct = CreateStruct(groupingAttributes.filterNot(ExpressionEncoder.isNullFlagColumn)) + Alias(ExpressionEncoder.checkNullFlag(groupingAttributes, struct), "key")() } val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan) val execution = new QueryExecution(sparkSession, aggregate) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 5dabe0e83c1cf..9457bbb7bf17e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -378,16 +378,18 @@ class SparkSession private( def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = { val enc = encoderFor[T] - val attributes = enc.schema.toAttributes - val encoded = data.map(d => enc.toRow(d).copy()) + val attributes = enc.serializerWithNullFlag.map(_.toAttribute) + val encoded = data.map(d => enc.toRowWithNullFlag(d).copy()) val plan = new LocalRelation(attributes, encoded) Dataset[T](self, plan) } def createDataset[T : Encoder](data: RDD[T]): Dataset[T] = { val enc = encoderFor[T] - val attributes = enc.schema.toAttributes - val encoded = data.map(d => enc.toRow(d)) + val attributes = enc.serializerWithNullFlag.map(_.toAttribute) + val encoded = data.mapPartitions { it => + it.map(enc.toRowWithNullFlag) + } val plan = LogicalRDD(attributes, 
encoded)(self) Dataset[T](self, plan) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 9cb96285c9c0b..a0636b17e85d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -764,20 +764,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkShowString(ds, expected) } - ignore("SPARK-15140: encoder should support null input object") { - val ds = Seq(1 -> "a", null).toDS() - val result = ds.collect() - assert(result.length == 2) - assert(result(0) == 1 -> "a") - assert(result(1) == null) - } - test("SPARK-15441: Dataset outer join") { val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDS().as("left") val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDS().as("right") val joined = left.joinWith(right, $"left.b" === $"right.b", "left") - joined.explain(true) - joined.show() + val result = joined.collect().toSet + assert(result == Set(ClassData("a", 1) -> null, ClassData("b", 2) -> ClassData("x", 2))) } } From d63912218a050bd87ab265bb8910f0e77b211244 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 31 May 2016 11:52:37 -0700 Subject: [PATCH 3/3] update --- .../spark/sql/catalyst/ScalaReflection.scala | 4 +-- .../catalyst/encoders/ExpressionEncoder.scala | 6 +---- .../scala/org/apache/spark/sql/Dataset.scala | 22 ++++++++-------- .../org/apache/spark/sql/DatasetSuite.scala | 25 +++++++------------ .../apache/spark/sql/test/SQLTestData.scala | 5 ++-- 5 files changed, 26 insertions(+), 36 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index bdd40f340235b..052cc486e8cb2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -568,7 +568,7 @@ object ScalaReflection extends ScalaReflection { udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), Nil, dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) - Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil) + Invoke(obj, "serialize", udt, inputObject :: Nil) case t if UDTRegistration.exists(getClassNameFromType(t)) => val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() @@ -577,7 +577,7 @@ object ScalaReflection extends ScalaReflection { udt.getClass, Nil, dataType = ObjectType(udt.getClass)) - Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil) + Invoke(obj, "serialize", udt, inputObject :: Nil) case t if definedByConstructorParams(t) => val params = getConstructorParameters(t) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 51390b82f4e8e..2c7ca13d7a397 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -52,7 +52,7 @@ object ExpressionEncoder { val cls = mirror.runtimeClass(tpe) val flat = !ScalaReflection.definedByConstructorParams(tpe) - val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = false) + val inputObject = BoundReference(0, 
ScalaReflection.dataTypeFor[T], nullable = true) val serializer = ScalaReflection.serializerFor[T](inputObject) val deserializer = ScalaReflection.deserializerFor[T] @@ -182,10 +182,6 @@ object ExpressionEncoder { Alias(IsNull(inputObject), nullFlagName)(explicitMetadata = Some(nullFlagMeta)) } - def isNullFlagColumn(f: StructField): Boolean = { - f.dataType == BooleanType && f.metadata.contains(nullFlagName) - } - def isNullFlagColumn(a: Attribute): Boolean = { a.dataType == BooleanType && a.metadata.contains(nullFlagName) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c31d456f6b6f5..b25ea2aa15e57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -209,13 +209,13 @@ class Dataset[T] private[sql]( /** The encoder for this [[Dataset]] that has been resolved to its output schema. */ private[sql] val resolvedTEncoder: ExpressionEncoder[T] = - unresolvedTEncoder.resolve(logicalPlan.output, OuterScopes.outerScopes) + unresolvedTEncoder.resolve(withSideEffects.output, OuterScopes.outerScopes) /** * The encoder where the expressions used to construct an object from an input row have been * bound to the ordinals of this [[Dataset]]'s output schema. */ - private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output) + private[sql] val boundTEncoder = resolvedTEncoder.bind(withSideEffects.output) private implicit def classTag = unresolvedTEncoder.clsTag @@ -223,7 +223,7 @@ class Dataset[T] private[sql]( @transient lazy val sqlContext: SQLContext = sparkSession.sqlContext private[sql] def resolve(colName: String): NamedExpression = { - queryExecution.analyzed.resolveQuoted(colName, sparkSession.sessionState.analyzer.resolver) + logicalPlan.resolveQuoted(colName, sparkSession.sessionState.analyzer.resolver) .getOrElse { throw new AnalysisException( s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""") @@ -232,7 +232,7 @@ class Dataset[T] private[sql]( private[sql] def numericColumns: Seq[Expression] = { schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n => - queryExecution.analyzed.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver).get + logicalPlan.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver).get } } @@ -344,7 +344,7 @@ class Dataset[T] private[sql]( */ // This is declared with parentheses to prevent the Scala compiler from treating // `ds.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. 
- def toDF(): DataFrame = new Dataset[Row](sparkSession, queryExecution, RowEncoder(schema)) + def toDF(): DataFrame = Dataset.ofRows(sparkSession, logicalPlan) /** * :: Experimental :: @@ -899,7 +899,7 @@ class Dataset[T] private[sql]( */ def col(colName: String): Column = colName match { case "*" => - Column(ResolvedStar(queryExecution.analyzed.output)) + Column(ResolvedStar(logicalPlan.output)) case _ => val expr = resolve(colName) Column(expr) @@ -1655,7 +1655,7 @@ class Dataset[T] private[sql]( */ def withColumn(colName: String, col: Column): DataFrame = { val resolver = sparkSession.sessionState.analyzer.resolver - val output = queryExecution.analyzed.output + val output = logicalPlan.output val shouldReplace = output.exists(f => resolver(f.name, colName)) if (shouldReplace) { val columns = output.map { field => @@ -1676,7 +1676,7 @@ class Dataset[T] private[sql]( */ private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = { val resolver = sparkSession.sessionState.analyzer.resolver - val output = queryExecution.analyzed.output + val output = logicalPlan.output val shouldReplace = output.exists(f => resolver(f.name, colName)) if (shouldReplace) { val columns = output.map { field => @@ -1701,7 +1701,7 @@ class Dataset[T] private[sql]( */ def withColumnRenamed(existingName: String, newName: String): DataFrame = { val resolver = sparkSession.sessionState.analyzer.resolver - val output = queryExecution.analyzed.output + val output = logicalPlan.output val shouldRename = output.exists(f => resolver(f.name, existingName)) if (shouldRename) { val columns = output.map { col => @@ -1759,11 +1759,11 @@ class Dataset[T] private[sql]( def drop(col: Column): DataFrame = { val expression = col match { case Column(u: UnresolvedAttribute) => - queryExecution.analyzed.resolveQuoted( + logicalPlan.resolveQuoted( u.name, sparkSession.sessionState.analyzer.resolver).getOrElse(u) case Column(expr: Expression) => expr } - val attrs = this.logicalPlan.output + val attrs = logicalPlan.output val colsAfterDrop = attrs.filter { attr => attr != expression }.map(attr => Column(attr)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index d4468fe6a102a..9b7d37df8caa9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -253,21 +253,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext { (1, 1), (2, 2)) } - test("joinWith, expression condition, outer join") { - val nullInteger = null.asInstanceOf[Integer] - val nullString = null.asInstanceOf[String] - val ds1 = Seq(ClassNullableData("a", 1), - ClassNullableData("c", 3)).toDS() - val ds2 = Seq(("a", new Integer(1)), - ("b", new Integer(2))).toDS() - - checkDataset( - ds1.joinWith(ds2, $"_1" === $"a", "outer"), - (ClassNullableData("a", 1), ("a", new Integer(1))), - (ClassNullableData("c", 3), (nullString, nullInteger)), - (ClassNullableData(nullString, nullInteger), ("b", new Integer(2)))) - } - test("joinWith tuple with primitive, expression") { val ds1 = Seq(1, 1, 2).toDS() val ds2 = Seq(("a", 1), ("b", 2)).toDS() @@ -506,7 +491,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val e = intercept[AnalysisException] { ds.as[ClassData2] } - assert(e.getMessage.contains("cannot resolve '`c`' given input columns: [a, b]"), e.getMessage) + assert(e.getMessage.contains("cannot resolve '`c`'")) } test("runtime nullability 
check") { @@ -784,6 +769,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } } + test("SPARK-15140: encoder should support null object") { + val ds = Seq(1 -> "a", null).toDS() + val result = ds.collect() + assert(result.length == 2) + assert(result(0) == 1 -> "a") + assert(result(1) == null) + } + test("SPARK-15441: Dataset outer join") { val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDS().as("left") val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDS().as("right") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index 0cfe260e52152..eca6a73d8adcf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -220,9 +220,10 @@ private[sql] trait SQLTestData { self => } protected lazy val person: DataFrame = { - val df = spark.sparkContext.parallelize( + val rdd = spark.sparkContext.parallelize( Person(0, "mike", 30) :: - Person(1, "jim", 20) :: Nil).toDF() + Person(1, "jim", 20) :: Nil) + val df = spark.createDataFrame(rdd) df.createOrReplaceTempView("person") df }