From 16e8c655d868d1606bff8f26cc24d1325c273600 Mon Sep 17 00:00:00 2001 From: Liwei Lin Date: Thu, 25 Aug 2016 11:24:40 +0200 Subject: [PATCH 1/2] Backport of SPARK-17061 & SPARK-17093 --- .../sql/catalyst/expressions/objects/objects.scala | 12 +++++++++++- .../catalyst/expressions/ExpressionEvalHelper.scala | 2 +- .../catalyst/expressions/ObjectExpressionSuite.scala | 8 ++++++++ 3 files changed, 20 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 37ec1a63394c..1cdda53708bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -481,6 +481,16 @@ case class MapObjects private( s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)" } + // Make a copy of the data if it's unsafe-backed + def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) = + s"$value instanceof ${clazz.getSimpleName}? ${value}.copy() : $value" + val genFunctionValue = lambdaFunction.dataType match { + case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value) + case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value) + case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value) + case _ => genFunction.value + } + val loopNullCheck = inputDataType match { case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" // The element of primitive array will never be null. @@ -508,7 +518,7 @@ case class MapObjects private( if (${genFunction.isNull}) { $convertedArray[$loopIndex] = null; } else { - $convertedArray[$loopIndex] = ${genFunction.value}; + $convertedArray[$loopIndex] = $genFunctionValue; } $loopIndex += 1; diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index d6a9672d1f18..668543a28bd3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -136,7 +136,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { // some expression is reusing variable names across different instances. // This behavior is tested in ExpressionEvalHelperSuite. val plan = generateProject( - GenerateUnsafeProjection.generate( + UnsafeProjection.create( Alias(expression, s"Optimized($expression)1")() :: Alias(expression, s"Optimized($expression)2")() :: Nil), expression) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionSuite.scala new file mode 100644 index 000000000000..5769e0169f56 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionSuite.scala @@ -0,0 +1,8 @@ +package org.apache.spark.sql.catalyst.expressions + +/** + * Created by hvanhovell on 8/25/16. + */ +class ObjectExpressionSuite { + +} From 0f245d08abb065efded174ebb2eed38a775b8a08 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 25 Aug 2016 12:16:10 +0200 Subject: [PATCH 2/2] Add test --- .../expressions/ObjectExpressionSuite.scala | 56 +++++++++++++++++-- 1 file changed, 52 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionSuite.scala index 5769e0169f56..b6263e77c141 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionSuite.scala @@ -1,8 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.spark.sql.catalyst.expressions -/** - * Created by hvanhovell on 8/25/16. - */ -class ObjectExpressionSuite { +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} + +class ObjectExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { + test("MapObjects should make copies of unsafe-backed data") { + // test UnsafeRow-backed data + val structEncoder = ExpressionEncoder[Array[(java.lang.Integer, java.lang.Integer)]]() + val structInputRow = InternalRow.fromSeq(Seq(Array((1, 2), (3, 4)))) + val structExpected = new GenericArrayData( + Array(InternalRow.fromSeq(Seq(1, 2)), InternalRow.fromSeq(Seq(3, 4)))) + checkEvalutionWithUnsafeProjection( + structEncoder.serializer.head, structExpected, structInputRow) + + // test UnsafeArray-backed data + val arrayEncoder = ExpressionEncoder[Array[Array[Int]]]() + val arrayInputRow = InternalRow.fromSeq(Seq(Array(Array(1, 2), Array(3, 4)))) + val arrayExpected = new GenericArrayData( + Array(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(3, 4)))) + checkEvalutionWithUnsafeProjection( + arrayEncoder.serializer.head, arrayExpected, arrayInputRow) + // test UnsafeMap-backed data + val mapEncoder = ExpressionEncoder[Array[Map[Int, Int]]]() + val mapInputRow = InternalRow.fromSeq(Seq(Array( + Map(1 -> 100, 2 -> 200), Map(3 -> 300, 4 -> 400)))) + val mapExpected = new GenericArrayData(Seq( + new ArrayBasedMapData( + new GenericArrayData(Array(1, 2)), + new GenericArrayData(Array(100, 200))), + new ArrayBasedMapData( + new GenericArrayData(Array(3, 4)), + new GenericArrayData(Array(300, 400))))) + checkEvalutionWithUnsafeProjection( + mapEncoder.serializer.head, mapExpected, mapInputRow) + } }