From 12b591ed18d5e4cbbeac142dc5b3bbe10ee05fa5 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Sun, 26 Jun 2016 13:02:49 +0900
Subject: [PATCH 01/39] remove unboxing operations when an array is a
 primitive type array

---
 .../expressions/complexTypeCreator.scala      | 57 ++++++++++++++-----
 .../spark/sql/DataFrameComplexTypeSuite.scala | 17 ++++++
 2 files changed, 59 insertions(+), 15 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 599fb638db32a..e3c7b69c88f80 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -58,26 +58,53 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val arrayClass = classOf[GenericArrayData].getName
     val values = ctx.freshName("values")
-    ctx.addMutableState("Object[]", values, s"this.$values = null;")
-
-    ev.copy(code = s"""
-      this.$values = new Object[${children.size}];""" +
-      ctx.splitExpressions(
-        ctx.INPUT_ROW,
-        children.zipWithIndex.map { case (e, i) =>
-          val eval = e.genCode(ctx)
-          eval.code + s"""
+    val dt = dataType match {
+      case a @ ArrayType(et, _) => et
+    }
+    val isPrimitive = ctx.isPrimitiveType(dt)
+    val evals = children.map(e => e.genCode(ctx))
+    val allNonNull = evals.find(_.isNull != "false").isEmpty
+    if (isPrimitive && allNonNull) {
+      val javaDataType = ctx.javaType(dt)
+      ctx.addMutableState(s"${javaDataType}[]", values,
+        s"this.$values = new ${javaDataType}[${children.size}];")
+
+      ev.copy(code =
+        ctx.splitExpressions(
+          ctx.INPUT_ROW,
+          evals.zipWithIndex.map { case (eval, i) =>
+            eval.code +
+              s"\n$values[$i] = ${eval.value};"
+          }) +
+        s"""
+          /* final ArrayData ${ev.value} = $arrayClass.allocate($values); */
+          final ArrayData ${ev.value} = new $arrayClass($values);
+        """,
+        isNull = "false")
+    } else {
+      ctx.addMutableState("Object[]", values, s"this.$values = null;")
+
+      ev.copy(code = s"""
+        final boolean ${ev.isNull} = false;
+        this.$values = new Object[${children.size}];""" +
+        ctx.splitExpressions(
+          ctx.INPUT_ROW,
+          children.zipWithIndex.map { case (e, i) =>
+            val eval = e.genCode(ctx)
+            eval.code + s"""
             if (${eval.isNull}) {
               $values[$i] = null;
             } else {
               $values[$i] = ${eval.value};
             }
-          """
-        }) +
-      s"""
-        final ArrayData ${ev.value} = new $arrayClass($values);
-        this.$values = null;
-      """, isNull = "false")
+          """
+          }) +
+        s"""
+          /* final ArrayData ${ev.value} = $arrayClass.allocate($values); */
+          final ArrayData ${ev.value} = new $arrayClass($values);
+          this.$values = null;
+        """)
+    }
   }
 
   override def prettyName: String = "array"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
index 1230b921aa279..256233d5fc856 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
@@ -27,6 +27,23 @@ import org.apache.spark.sql.test.SharedSQLContext
 class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 
+  test("primitive type on array") {
+    val df = sparkContext.parallelize(Seq(1, 2), 1).toDF("v")
+    val resDF = df.selectExpr("Array(v + 2, v + 3)")
+    checkAnswer(resDF,
+      Seq(Row(Array(3, 4)), Row(Array(4, 5))))
+  }
+
+  test("primitive array or null on array") {
+    val df = sparkContext.parallelize(Seq(1, 2), 1).toDF("v")
+    val resDF = df.selectExpr("Array(Array(v, v + 1, v + 2)," +
+      "null," +
+      "Array(v, v - 1, v - 2))")
+    QueryTest.checkAnswer(resDF,
+      Seq(Row(Array(Array(1, 2, 3), null, Array(1, 0, -1))),
+        Row(Array(Array(2, 3, 4), null, Array(2, 1, 0)))))
+  }
+
   test("UDF on struct") {
     val f = udf((a: String) => a)
     val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")

From 909d21030f1e82f706165f6a092944ca246c5fc2 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Mon, 27 Jun 2016 12:38:59 +0900
Subject: [PATCH 02/39] addressed comments

---
 .../spark/sql/catalyst/expressions/complexTypeCreator.scala | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index e3c7b69c88f80..00b2cae1c9dc0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -58,12 +58,10 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val arrayClass = classOf[GenericArrayData].getName
     val values = ctx.freshName("values")
-    val dt = dataType match {
-      case a @ ArrayType(et, _) => et
-    }
+    val ArrayType(dt, _) = dataType
     val isPrimitive = ctx.isPrimitiveType(dt)
     val evals = children.map(e => e.genCode(ctx))
-    val allNonNull = evals.find(_.isNull != "false").isEmpty
+    val allNonNull = evals.forall(_.isNull == "false")
     if (isPrimitive && allNonNull) {
       val javaDataType = ctx.javaType(dt)
       ctx.addMutableState(s"${javaDataType}[]", values,

From d481cb067719581fe173ffdf74dc28b7ddc51ed4 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Mon, 27 Jun 2016 13:35:04 +0900
Subject: [PATCH 03/39] fix test failure

---
 .../spark/sql/catalyst/expressions/complexTypeCreator.scala | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 00b2cae1c9dc0..ec821c18d6d0f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -87,8 +87,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
         this.$values = new Object[${children.size}];""" +
         ctx.splitExpressions(
           ctx.INPUT_ROW,
-          children.zipWithIndex.map { case (e, i) =>
-            val eval = e.genCode(ctx)
+          evals.zipWithIndex.map { case (eval, i) =>
             eval.code + s"""
             if (${eval.isNull}) {
               $values[$i] = null;
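For context on the intent of PATCH 01-03: storing primitive values into an Object[] boxes every element, and that boxing is exactly the overhead the specialized code path above removes. A minimal, self-contained Scala sketch of the difference (illustration only, not Spark's generated code):

object BoxingSketch {
  def main(args: Array[String]): Unit = {
    val n = 5
    // Generic path: each Int is boxed to a java.lang.Integer on store.
    val boxed = new Array[AnyRef](n)
    // Specialized path: values stay as raw ints, no per-element allocation.
    val primitive = new Array[Int](n)
    var i = 0
    while (i < n) { boxed(i) = Int.box(i); primitive(i) = i; i += 1 }
    println(boxed.mkString(",") + " / " + primitive.mkString(","))
  }
}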
From 66454f310d7240f39c48876bf814a72412783d3f Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Mon, 14 Nov 2016 04:04:56 +0900
Subject: [PATCH 04/39] revert mistaken git operation

---
 .../expressions/complexTypeCreator.scala      | 43 +++++++-------
 .../expressions/ComplexTypeSuite.scala        | 56 ++++++++++---------
 2 files changed, 49 insertions(+), 50 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index ec821c18d6d0f..33bbc14e94cfa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -56,32 +56,15 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val unsafeArrayClass = classOf[UnsafeArrayData].getName
     val arrayClass = classOf[GenericArrayData].getName
     val values = ctx.freshName("values")
+    ctx.addMutableState("Object[]", values, s"this.$values = null;")
+
     val ArrayType(dt, _) = dataType
-    val isPrimitive = ctx.isPrimitiveType(dt)
     val evals = children.map(e => e.genCode(ctx))
-    val allNonNull = evals.forall(_.isNull == "false")
-    if (isPrimitive && allNonNull) {
-      val javaDataType = ctx.javaType(dt)
-      ctx.addMutableState(s"${javaDataType}[]", values,
-        s"this.$values = new ${javaDataType}[${children.size}];")
-
-      ev.copy(code =
-        ctx.splitExpressions(
-          ctx.INPUT_ROW,
-          evals.zipWithIndex.map { case (eval, i) =>
-            eval.code +
-              s"\n$values[$i] = ${eval.value};"
-          }) +
-        s"""
-          /* final ArrayData ${ev.value} = $arrayClass.allocate($values); */
-          final ArrayData ${ev.value} = new $arrayClass($values);
-        """,
-        isNull = "false")
-    } else {
-      ctx.addMutableState("Object[]", values, s"this.$values = null;")
-
+    val isPrimitiveArray = ctx.isPrimitiveType(dt) && evals.forall(_.isNull == "false")
+    if (!isPrimitiveArray) {
       ev.copy(code = s"""
         final boolean ${ev.isNull} = false;
         this.$values = new Object[${children.size}];""" +
@@ -97,10 +80,24 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
           """
           }) +
         s"""
-          /* final ArrayData ${ev.value} = $arrayClass.allocate($values); */
           final ArrayData ${ev.value} = new $arrayClass($values);
           this.$values = null;
         """)
+    } else {
+      val javaDataType = ctx.javaType(dt)
+      ctx.addMutableState(s"${javaDataType}[]", values,
+        s"this.$values = new ${javaDataType}[${children.size}];")
+      ev.copy(code =
+        ctx.splitExpressions(
+          ctx.INPUT_ROW,
+          evals.zipWithIndex.map { case (eval, i) =>
+            eval.code +
+              s"\n$values[$i] = ${eval.value};"
+          }) +
+        s"""
+          final ArrayData ${ev.value} = $unsafeArrayClass.fromPrimitiveArray($values);
+        """,
+        isNull = "false")
     }
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index c21c6de32c0ba..9e25f13905881 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -118,19 +118,20 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("CreateArray") {
-    val intSeq = Seq(5, 10, 15, 20, 25)
-    val longSeq = intSeq.map(_.toLong)
-    val strSeq = intSeq.map(_.toString)
-    checkEvaluation(CreateArray(intSeq.map(Literal(_))), intSeq, EmptyRow)
-    checkEvaluation(CreateArray(longSeq.map(Literal(_))), longSeq, EmptyRow)
-    checkEvaluation(CreateArray(strSeq.map(Literal(_))), strSeq, EmptyRow)
-
-    val intWithNull = intSeq.map(Literal(_)) :+ Literal.create(null, IntegerType)
-    val longWithNull = longSeq.map(Literal(_)) :+ Literal.create(null, LongType)
-    val strWithNull = strSeq.map(Literal(_)) :+ Literal.create(null, StringType)
-    checkEvaluation(CreateArray(intWithNull), intSeq :+ null, EmptyRow)
-    checkEvaluation(CreateArray(longWithNull), longSeq :+ null, EmptyRow)
-    checkEvaluation(CreateArray(strWithNull), strSeq :+ null, EmptyRow)
+    // Array is required to pass Array(_, containsNull = false) as type information
+    val intArray = Array(5, 10, 15, 20, 25)
+    val longArray = intArray.map(_.toLong)
+    val strArray = intArray.map(_.toString)
+    checkEvaluation(CreateArray(intArray.map(Literal(_))), intArray, EmptyRow)
+    checkEvaluation(CreateArray(longArray.map(Literal(_))), longArray, EmptyRow)
+    checkEvaluation(CreateArray(strArray.map(Literal(_))), strArray, EmptyRow)
+
+    val intWithNull = intArray.map(Literal(_)) :+ Literal.create(null, IntegerType)
+    val longWithNull = longArray.map(Literal(_)) :+ Literal.create(null, LongType)
+    val strWithNull = strArray.map(Literal(_)) :+ Literal.create(null, StringType)
+    checkEvaluation(CreateArray(intWithNull), intArray :+ null, EmptyRow)
+    checkEvaluation(CreateArray(longWithNull), longArray :+ null, EmptyRow)
+    checkEvaluation(CreateArray(strWithNull), strArray :+ null, EmptyRow)
     checkEvaluation(CreateArray(Literal.create(null, IntegerType) :: Nil), null :: Nil)
   }
 
@@ -144,32 +145,33 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
       scala.collection.immutable.ListMap(keys.zip(values): _*)
     }
 
-    val intSeq = Seq(5, 10, 15, 20, 25)
-    val longSeq = intSeq.map(_.toLong)
-    val strSeq = intSeq.map(_.toString)
+    // Array is required to pass Array(_, containsNull = false) as type information
+    val intArray = Array(5, 10, 15, 20, 25)
+    val longArray = intArray.map(_.toLong)
+    val strArray = intArray.map(_.toString)
     checkEvaluation(CreateMap(Nil), Map.empty)
     checkEvaluation(
-      CreateMap(interlace(intSeq.map(Literal(_)), longSeq.map(Literal(_)))),
-      createMap(intSeq, longSeq))
+      CreateMap(interlace(intArray.map(Literal(_)), longArray.map(Literal(_)))),
+      createMap(intArray, longArray))
     checkEvaluation(
-      CreateMap(interlace(strSeq.map(Literal(_)), longSeq.map(Literal(_)))),
-      createMap(strSeq, longSeq))
+      CreateMap(interlace(strArray.map(Literal(_)), longArray.map(Literal(_)))),
+      createMap(strArray, longArray))
     checkEvaluation(
-      CreateMap(interlace(longSeq.map(Literal(_)), strSeq.map(Literal(_)))),
-      createMap(longSeq, strSeq))
+      CreateMap(interlace(longArray.map(Literal(_)), strArray.map(Literal(_)))),
+      createMap(longArray, strArray))
 
-    val strWithNull = strSeq.drop(1).map(Literal(_)) :+ Literal.create(null, StringType)
+    val strWithNull = strArray.drop(1).map(Literal(_)) :+ Literal.create(null, StringType)
     checkEvaluation(
-      CreateMap(interlace(intSeq.map(Literal(_)), strWithNull)),
-      createMap(intSeq, strWithNull.map(_.value)))
+      CreateMap(interlace(intArray.map(Literal(_)), strWithNull)),
+      createMap(intArray, strWithNull.map(_.value)))
     intercept[RuntimeException] {
       checkEvaluationWithoutCodegen(
-        CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))),
+        CreateMap(interlace(strWithNull, intArray.map(Literal(_)))),
         null, null)
     }
     intercept[RuntimeException] {
       checkEvalutionWithUnsafeProjection(
-        CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))),
+        CreateMap(interlace(strWithNull, intArray.map(Literal(_)))),
         null, null)
     }
   }
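The conversion PATCH 04 switches to is UnsafeArrayData.fromPrimitiveArray, which wraps a primitive array without boxing its elements. A small usage sketch (values illustrative; runnable only with the spark-catalyst module on the classpath):

import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData

val data = Array(5, 10, 15, 20, 25)
val unsafe = UnsafeArrayData.fromPrimitiveArray(data)
assert(unsafe.numElements() == 5 && unsafe.getInt(2) == 15)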
From 03e0cfa1247f42b067e7729e2a5d3b231cf605a6 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Mon, 14 Nov 2016 04:05:18 +0900
Subject: [PATCH 05/39] addressed review comments

---
 .../expressions/ExpressionEvalHelper.scala    | 51 ++++++++++++++++---
 1 file changed, 45 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index f83650424a964..83b8c072dd14b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -21,15 +21,14 @@ import org.scalacheck.Gen
 import org.scalactic.TripleEqualsSupport.Spread
 import org.scalatest.exceptions.TestFailedException
 import org.scalatest.prop.GeneratorDrivenPropertyChecks
-
 import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer
 import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
-import org.apache.spark.sql.catalyst.util.MapData
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
+import org.apache.spark.sql.types.{BinaryType, DataType}
 import org.apache.spark.util.Utils
 
 /**
@@ -42,15 +41,55 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
     InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst))
   }
 
+  protected def convertToCatalystUnsafe(a: Any): Any = a match {
+    case arr: Array[Boolean] => UnsafeArrayData.fromPrimitiveArray(arr)
+    case arr: Array[Byte] => UnsafeArrayData.fromPrimitiveArray(arr)
+    case arr: Array[Short] => UnsafeArrayData.fromPrimitiveArray(arr)
+    case arr: Array[Int] => UnsafeArrayData.fromPrimitiveArray(arr)
+    case arr: Array[Long] => UnsafeArrayData.fromPrimitiveArray(arr)
+    case arr: Array[Float] => UnsafeArrayData.fromPrimitiveArray(arr)
+    case arr: Array[Double] => UnsafeArrayData.fromPrimitiveArray(arr)
+    case other => CatalystTypeConverters.convertToCatalyst(other)
+  }
+
   protected def checkEvaluation(
       expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = {
     val serializer = new JavaSerializer(new SparkConf()).newInstance
     val expr: Expression = serializer.deserialize(serializer.serialize(expression))
-    val catalystValue = CatalystTypeConverters.convertToCatalyst(expected)
+    // No codegen version expects GenericArrayData
+    val catalystValue = expected match {
+      case arr: Array[Byte] if expression.dataType == BinaryType => arr
+      case arr: Array[_] => new GenericArrayData(arr.map(CatalystTypeConverters.convertToCatalyst))
+      case _ => CatalystTypeConverters.convertToCatalyst(expected)
+    }
+    // Codegen version expects UnsafeArrayData for arrays except Array(BinaryType)
+    val catalystValueUnsafe = expected match {
+      case arr: Array[Byte] if expression.dataType == BinaryType => arr
+      case _ => convertToCatalystUnsafe(expected)
+    }
+    checkEvaluationWithoutCodegen(expr, catalystValue, inputRow)
+    checkEvaluationWithGeneratedMutableProjection(expr, catalystValueUnsafe, inputRow)
+    if (GenerateUnsafeProjection.canSupport(expr.dataType)) {
+      checkEvalutionWithUnsafeProjection(expr, catalystValueUnsafe, inputRow)
+    }
+    checkEvaluationWithOptimization(expr, catalystValue, inputRow)
+  }
+
+  protected def checkEvaluationMap(expression: => Expression, expectedMap: Any,
+      expectedKey: Any, expectedValue: Any, inputRow: InternalRow = EmptyRow): Unit = {
+    val serializer = new JavaSerializer(new SparkConf()).newInstance
+    val expr: Expression = serializer.deserialize(serializer.serialize(expression))
+    // No codegen version expects GenericArrayData for map
+    val catalystValue = CatalystTypeConverters.convertToCatalyst(expectedMap)
+    // Codegen version expects UnsafeArrayData for map
+    val catalystValueUnsafe = new ArrayBasedMapData(
+      convertToCatalystUnsafe(expectedKey).asInstanceOf[ArrayData],
+      convertToCatalystUnsafe(expectedValue).asInstanceOf[ArrayData])
+
     checkEvaluationWithoutCodegen(expr, catalystValue, inputRow)
-    checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow)
+    checkEvaluationWithGeneratedMutableProjection(expr, catalystValueUnsafe, inputRow)
     if (GenerateUnsafeProjection.canSupport(expr.dataType)) {
-      checkEvalutionWithUnsafeProjection(expr, catalystValue, inputRow)
+      checkEvalutionWithUnsafeProjection(expr, catalystValueUnsafe, inputRow)
     }
     checkEvaluationWithOptimization(expr, catalystValue, inputRow)
   }

From 562073357058848793bf2ae6acf9176a45669b81 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Mon, 14 Nov 2016 04:10:06 +0900
Subject: [PATCH 06/39] fix scala style error

---
 .../spark/sql/catalyst/expressions/ExpressionEvalHelper.scala | 1 +
 1 file changed, 1 insertion(+)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 83b8c072dd14b..3ed7e607a548e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -21,6 +21,7 @@ import org.scalacheck.Gen
 import org.scalactic.TripleEqualsSupport.Spread
 import org.scalatest.exceptions.TestFailedException
 import org.scalatest.prop.GeneratorDrivenPropertyChecks
+
 import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
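PATCH 05 makes the expected value's shape depend on the evaluation path: the interpreted path is compared against a GenericArrayData, while the codegen path is compared against an UnsafeArrayData built from the same primitive values. A sketch of that equivalence (names illustrative; runnable only with spark-catalyst on the classpath):

import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
import org.apache.spark.sql.catalyst.util.GenericArrayData

val values = Array(5, 10, 15)
val interpreted = new GenericArrayData(values.map(v => v: Any))  // no-codegen shape
val codegen = UnsafeArrayData.fromPrimitiveArray(values)         // codegen shape
assert(interpreted.toIntArray().sameElements(codegen.toIntArray()))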
From 2906c74363edcfe0382c5cac20f00de295403484 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Mon, 14 Nov 2016 11:17:27 +0900
Subject: [PATCH 07/39] support CreateMap

---
 .../expressions/complexTypeCreator.scala      | 90 ++++++++++++-------
 .../expressions/CodeGenerationSuite.scala     |  7 +-
 .../expressions/ComplexTypeSuite.scala        | 16 ++--
 3 files changed, 73 insertions(+), 40 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 33bbc14e94cfa..a8d914ab89314 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -153,50 +153,80 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
     new ArrayBasedMapData(new GenericArrayData(keyArray), new GenericArrayData(valueArray))
   }
 
+  private def getAccessors(ctx: CodegenContext, dt: DataType, array: String,
+      isPrimitive : Boolean, size: Int): (String, String, String) = {
+    if (!isPrimitive) {
+      val arrayClass = classOf[GenericArrayData].getName
+      ctx.addMutableState("Object[]", array, s"this.$array = null;")
+      (s"new $arrayClass($array)",
+        s"$array = new Object[${size}];", s"this.$array = null;")
+    } else {
+      val unsafeArrayClass = classOf[UnsafeArrayData].getName
+      val javaDataType = ctx.javaType(dt)
+      ctx.addMutableState(s"${javaDataType}[]", array,
+        s"this.$array = new ${javaDataType}[${size}];")
+      (s"$unsafeArrayClass.fromPrimitiveArray($array)", "", "")
+    }
+  }
+
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val arrayClass = classOf[GenericArrayData].getName
     val mapClass = classOf[ArrayBasedMapData].getName
     val keyArray = ctx.freshName("keyArray")
     val valueArray = ctx.freshName("valueArray")
-    ctx.addMutableState("Object[]", keyArray, s"this.$keyArray = null;")
-    ctx.addMutableState("Object[]", valueArray, s"this.$valueArray = null;")
-    val keyData = s"new $arrayClass($keyArray)"
-    val valueData = s"new $arrayClass($valueArray)"
+    val MapType(keyDt, valueDt, _) = dataType
+    val evalKeys = keys.map(e => e.genCode(ctx))
+    val isPrimitiveArrayKey = ctx.isPrimitiveType(keyDt)
+    val isNonNullKey = evalKeys.forall(_.isNull == "false")
+    val evalValues = values.map(e => e.genCode(ctx))
+    val isPrimitiveArrayValue =
+      ctx.isPrimitiveType(valueDt) && evalValues.forall(_.isNull == "false")
+    val (keyData, keyArrayAllocate, keyArrayNullify) =
+      getAccessors(ctx, keyDt, keyArray, isPrimitiveArrayKey, keys.size)
+    val (valueData, valueArrayAllocate, valueArrayNullify) =
+      getAccessors(ctx, valueDt, valueArray, isPrimitiveArrayValue, values.size)
+
     ev.copy(code = s"""
-      $keyArray = new Object[${keys.size}];
-      $valueArray = new Object[${values.size}];""" +
+      final boolean ${ev.isNull} = false;
+      $keyArrayAllocate
+      $valueArrayAllocate""" +
       ctx.splitExpressions(
         ctx.INPUT_ROW,
-        keys.zipWithIndex.map { case (key, i) =>
-          val eval = key.genCode(ctx)
-          s"""
-            ${eval.code}
-            if (${eval.isNull}) {
-              throw new RuntimeException("Cannot use null as map key!");
-            } else {
-              $keyArray[$i] = ${eval.value};
-            }
-          """
+        evalKeys.zipWithIndex.map { case (eval, i) =>
+          eval.code +
+            (if (isNonNullKey) {
+              s"$keyArray[$i] = ${eval.value};"
+            } else {
+              s"""
+                if (${eval.isNull}) {
+                  throw new RuntimeException("Cannot use null as map key!");
+                } else {
+                  $keyArray[$i] = ${eval.value};
+                }
+              """
+            })
         }) +
       ctx.splitExpressions(
        ctx.INPUT_ROW,
-        values.zipWithIndex.map { case (value, i) =>
-          val eval = value.genCode(ctx)
-          s"""
-            ${eval.code}
-            if (${eval.isNull}) {
-              $valueArray[$i] = null;
-            } else {
-              $valueArray[$i] = ${eval.value};
-            }
-          """
+        evalValues.zipWithIndex.map { case (eval, i) =>
+          eval.code +
+            (if (isPrimitiveArrayValue) {
+              s"$valueArray[$i] = ${eval.value};"
+            } else {
+              s"""
+                if (${eval.isNull}) {
+                  $valueArray[$i] = null;
+                } else {
+                  $valueArray[$i] = ${eval.value};
+                }
+              """
+            })
         }) +
       s"""
-        final MapData ${ev.value} = new $mapClass($keyData, $valueData);
-        this.$keyArray = null;
-        this.$valueArray = null;
-      """, isNull = "false")
+        final MapData ${ev.value} = new $mapClass($keyData, $valueData);
+        $keyArrayNullify
+        $valueArrayNullify;
+      """)
   }
 
   override def prettyName: String = "map"
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index ee5d1f637374e..1cbb1b3a36059 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -118,7 +118,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     val expressions = Seq(CreateArray(List.fill(length)(EqualTo(Literal(1), Literal(1)))))
     val plan = GenerateMutableProjection.generate(expressions)
     val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
-    val expected = Seq(new GenericArrayData(Seq.fill(length)(true)))
+    val expected = Seq(UnsafeArrayData.fromPrimitiveArray(Array.fill(length)(true)))
 
     if (!checkResult(actual, expected)) {
       fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
@@ -133,7 +133,10 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     }))
     val plan = GenerateMutableProjection.generate(expressions)
     val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)).map {
-      case m: ArrayBasedMapData => ArrayBasedMapData.toScalaMap(m)
+      case m: ArrayBasedMapData =>
+        val keys = m.keyArray.asInstanceOf[UnsafeArrayData].toIntArray
+        val values = m.valueArray.asInstanceOf[UnsafeArrayData].toBooleanArray
+        keys.zip(values).toMap
     }
     val expected = (0 until length).map((_, true)).toMap :: Nil
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 9e25f13905881..e0d9756080fad 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -150,20 +150,20 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
     val longArray = intArray.map(_.toLong)
     val strArray = intArray.map(_.toString)
     checkEvaluation(CreateMap(Nil), Map.empty)
-    checkEvaluation(
+    checkEvaluationMap(
       CreateMap(interlace(intArray.map(Literal(_)), longArray.map(Literal(_)))),
-      createMap(intArray, longArray))
+      createMap(intArray, longArray), intArray, longArray)
-    checkEvaluation(
+    checkEvaluationMap(
       CreateMap(interlace(strArray.map(Literal(_)), longArray.map(Literal(_)))),
-      createMap(strArray, longArray))
+      createMap(strArray, longArray), strArray, longArray)
-    checkEvaluation(
+    checkEvaluationMap(
       CreateMap(interlace(longArray.map(Literal(_)), strArray.map(Literal(_)))),
-      createMap(longArray, strArray))
+      createMap(longArray, strArray), longArray, strArray)
 
     val strWithNull = strArray.drop(1).map(Literal(_)) :+ Literal.create(null, StringType)
-    checkEvaluation(
+    checkEvaluationMap(
       CreateMap(interlace(intArray.map(Literal(_)), strWithNull)),
-      createMap(intArray, strWithNull.map(_.value)))
+      createMap(intArray, strWithNull.map(_.value)), intArray, strWithNull.map(_.value))
     intercept[RuntimeException] {
       checkEvaluationWithoutCodegen(
         CreateMap(interlace(strWithNull, intArray.map(Literal(_)))),
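The runtime effect of the CreateMap specialization in PATCH 07, as a standalone Scala sketch (illustration only, not the generated Java): when keys and values are non-null primitives, each element store is a plain array write, with no null-key exception branch and no null-to-slot branch.

def buildPrimitiveMapArrays(keys: Array[Int], values: Array[Long]): (Array[Int], Array[Long]) = {
  val keyArray = new Array[Int](keys.length)
  val valueArray = new Array[Long](values.length)
  var i = 0
  while (i < keys.length) {
    keyArray(i) = keys(i)      // no "Cannot use null as map key!" check needed
    valueArray(i) = values(i)  // no null-to-slot check needed
    i += 1
  }
  (keyArray, valueArray)
}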
From d5b3a8ac806763998b8dbafc11ac4b3e88a65271 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Tue, 15 Nov 2016 21:35:10 +0900
Subject: [PATCH 08/39] addressed review comments

---
 .../spark/sql/catalyst/expressions/ComplexTypeSuite.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index e0d9756080fad..3d3947a26af2c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -118,7 +118,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("CreateArray") {
-    // Array is required to pass Array(_, containsNull = false) as type information
+    // Array is required to pass Array[primitiveType] as type information for expected
     val intArray = Array(5, 10, 15, 20, 25)
     val longArray = intArray.map(_.toLong)
     val strArray = intArray.map(_.toString)
@@ -145,7 +145,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
       scala.collection.immutable.ListMap(keys.zip(values): _*)
     }
 
-    // Array is required to pass Array(_, containsNull = false) as type information
+    // Array is required to pass Array[primitiveType] as type information
     val intArray = Array(5, 10, 15, 20, 25)
     val longArray = intArray.map(_.toLong)
     val strArray = intArray.map(_.toString)

From d29bb97e621c23aa6a6c8187f559ae6d8c1d7028 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Thu, 17 Nov 2016 22:43:54 +0900
Subject: [PATCH 09/39] addressed review comments

---
 .../catalyst/expressions/complexTypeCreator.scala       |  6 +++---
 .../catalyst/expressions/ExpressionEvalHelper.scala     | 12 ++++++------
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index a8d914ab89314..1a40fc232f981 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -63,7 +63,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
     val ArrayType(dt, _) = dataType
     val evals = children.map(e => e.genCode(ctx))
-    val isPrimitiveArray = ctx.isPrimitiveType(dt) && evals.forall(_.isNull == "false")
+    val isPrimitiveArray = ctx.isPrimitiveType(dt) && children.forall(!_.nullable)
     if (!isPrimitiveArray) {
       ev.copy(code = s"""
         final boolean ${ev.isNull} = false;
@@ -177,10 +177,10 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
     val MapType(keyDt, valueDt, _) = dataType
     val evalKeys = keys.map(e => e.genCode(ctx))
     val isPrimitiveArrayKey = ctx.isPrimitiveType(keyDt)
-    val isNonNullKey = evalKeys.forall(_.isNull == "false")
+    val isNonNullKey = keys.forall(!_.nullable)
     val evalValues = values.map(e => e.genCode(ctx))
     val isPrimitiveArrayValue =
-      ctx.isPrimitiveType(valueDt) && evalValues.forall(_.isNull == "false")
+      ctx.isPrimitiveType(valueDt) && values.forall(!_.nullable)
     val (keyData, keyArrayAllocate, keyArrayNullify) =
       getAccessors(ctx, keyDt, keyArray, isPrimitiveArrayKey, keys.size)
     val (valueData, valueArrayAllocate, valueArrayNullify) =
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 3ed7e607a548e..1a9b4668eb4d8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -64,14 +64,14 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
       case _ => CatalystTypeConverters.convertToCatalyst(expected)
     }
     // Codegen version expects UnsafeArrayData for arrays except Array(BinaryType)
-    val catalystValueUnsafe = expected match {
+    val catalystValueForCodegen = expected match {
       case arr: Array[Byte] if expression.dataType == BinaryType => arr
       case _ => convertToCatalystUnsafe(expected)
     }
     checkEvaluationWithoutCodegen(expr, catalystValue, inputRow)
-    checkEvaluationWithGeneratedMutableProjection(expr, catalystValueUnsafe, inputRow)
+    checkEvaluationWithGeneratedMutableProjection(expr, catalystValueForCodegen, inputRow)
     if (GenerateUnsafeProjection.canSupport(expr.dataType)) {
-      checkEvalutionWithUnsafeProjection(expr, catalystValueUnsafe, inputRow)
+      checkEvalutionWithUnsafeProjection(expr, catalystValueForCodegen, inputRow)
     }
     checkEvaluationWithOptimization(expr, catalystValue, inputRow)
   }
@@ -83,14 +83,14 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
     // No codegen version expects GenericArrayData for map
     val catalystValue = CatalystTypeConverters.convertToCatalyst(expectedMap)
     // Codegen version expects UnsafeArrayData for map
-    val catalystValueUnsafe = new ArrayBasedMapData(
+    val catalystValueForCodegen = new ArrayBasedMapData(
       convertToCatalystUnsafe(expectedKey).asInstanceOf[ArrayData],
       convertToCatalystUnsafe(expectedValue).asInstanceOf[ArrayData])
 
     checkEvaluationWithoutCodegen(expr, catalystValue, inputRow)
-    checkEvaluationWithGeneratedMutableProjection(expr, catalystValueUnsafe, inputRow)
+    checkEvaluationWithGeneratedMutableProjection(expr, catalystValueForCodegen, inputRow)
     if (GenerateUnsafeProjection.canSupport(expr.dataType)) {
-      checkEvalutionWithUnsafeProjection(expr, catalystValueUnsafe, inputRow)
+      checkEvalutionWithUnsafeProjection(expr, catalystValueForCodegen, inputRow)
     }
     checkEvaluationWithOptimization(expr, catalystValue, inputRow)
   }
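The review point behind PATCH 09: nullability is a static property of the expression tree, whereas ExprCode.isNull is a codegen artifact that is not always the literal "false" even for non-nullable children. A minimal sketch of the static check (runnable only with spark-catalyst on the classpath):

import org.apache.spark.sql.catalyst.expressions.Literal

val children = Seq(Literal(1), Literal(2))
// Literal(1) is statically non-nullable, so the static check holds here.
assert(children.forall(!_.nullable))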
From 88daf42c8a816fb1a246f4a3d84564419d73d63a Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Sat, 19 Nov 2016 00:08:55 +0900
Subject: [PATCH 10/39] addressed review comments

---
 .../spark/sql/catalyst/expressions/complexTypeCreator.scala  | 2 +-
 .../org/apache/spark/sql/DataFrameComplexTypeSuite.scala     | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 1a40fc232f981..35a50469649e3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -59,12 +59,12 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
     val unsafeArrayClass = classOf[UnsafeArrayData].getName
     val arrayClass = classOf[GenericArrayData].getName
     val values = ctx.freshName("values")
-    ctx.addMutableState("Object[]", values, s"this.$values = null;")
 
     val ArrayType(dt, _) = dataType
     val evals = children.map(e => e.genCode(ctx))
     val isPrimitiveArray = ctx.isPrimitiveType(dt) && children.forall(!_.nullable)
     if (!isPrimitiveArray) {
+      ctx.addMutableState("Object[]", values, s"this.$values = null;")
       ev.copy(code = s"""
         final boolean ${ev.isNull} = false;
         this.$values = new Object[${children.size}];""" +
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
index 256233d5fc856..292f873cb876b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
@@ -28,14 +28,14 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 
   test("primitive type on array") {
-    val df = sparkContext.parallelize(Seq(1, 2), 1).toDF("v")
+    val df = sparkContext.parallelize(Seq(1, 2)).toDF("v")
     val resDF = df.selectExpr("Array(v + 2, v + 3)")
     checkAnswer(resDF,
       Seq(Row(Array(3, 4)), Row(Array(4, 5))))
   }
 
   test("primitive array or null on array") {
-    val df = sparkContext.parallelize(Seq(1, 2), 1).toDF("v")
+    val df = sparkContext.parallelize(Seq(1, 2)).toDF("v")
     val resDF = df.selectExpr("Array(Array(v, v + 1, v + 2)," +
       "null," +
       "Array(v, v - 1, v - 2))")

From da82efeac4b106a12e77248e259341b2a2d5b4d2 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Sat, 19 Nov 2016 00:44:04 +0900
Subject: [PATCH 11/39] addressed review comment

---
 .../sql/catalyst/expressions/CodeGenerationSuite.scala | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 1cbb1b3a36059..a68b1e173e12f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.objects.{CreateExternalRow, GetExternalRowField, ValidateExternalType}
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.ThreadUtils
@@ -134,8 +134,8 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     val plan = GenerateMutableProjection.generate(expressions)
     val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)).map {
       case m: ArrayBasedMapData =>
-        val keys = m.keyArray.asInstanceOf[UnsafeArrayData].toIntArray
-        val values = m.valueArray.asInstanceOf[UnsafeArrayData].toBooleanArray
+        val keys = m.keyArray.asInstanceOf[ArrayData].toIntArray
+        val values = m.valueArray.asInstanceOf[ArrayData].toBooleanArray
         keys.zip(values).toMap
     }
     val expected = (0 until length).map((_, true)).toMap :: Nil
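PATCH 11 loosens the casts because toIntArray/toBooleanArray are defined on the ArrayData base class, so the test keeps working whichever concrete ArrayData implementation the codegen returns. A small sketch (runnable only with spark-catalyst on the classpath):

import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}

val arr: ArrayData = new GenericArrayData(Array[Any](1, 2, 3))
// Works the same if arr were an UnsafeArrayData instead.
assert(arr.toIntArray().sameElements(Array(1, 2, 3)))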
---
 .../apache/spark/sql/catalyst/CatalystTypeConverters.scala  | 3 ++-
 .../sql/catalyst/expressions/ExpressionEvalHelper.scala     | 6 +-----
 2 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 5b9161551a7af..5b435174796e1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -412,7 +412,8 @@ object CatalystTypeConverters {
     case d: JavaBigDecimal => new DecimalConverter(DecimalType(d.precision, d.scale)).toCatalyst(d)
     case seq: Seq[Any] => new GenericArrayData(seq.map(convertToCatalyst).toArray)
     case r: Row => InternalRow(r.toSeq.map(convertToCatalyst): _*)
-    case arr: Array[Any] => new GenericArrayData(arr.map(convertToCatalyst))
+    case arr: Array[Byte] => arr
+    case arr: Array[_] => new GenericArrayData(arr.map(convertToCatalyst))
     case map: Map[_, _] =>
       ArrayBasedMapData(
         map,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 1a9b4668eb4d8..0a6bc4310efe5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -58,11 +58,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
     val serializer = new JavaSerializer(new SparkConf()).newInstance
     val expr: Expression = serializer.deserialize(serializer.serialize(expression))
     // No codegen version expects GenericArrayData
-    val catalystValue = expected match {
-      case arr: Array[Byte] if expression.dataType == BinaryType => arr
-      case arr: Array[_] => new GenericArrayData(arr.map(CatalystTypeConverters.convertToCatalyst))
-      case _ => CatalystTypeConverters.convertToCatalyst(expected)
-    }
+    val catalystValue = CatalystTypeConverters.convertToCatalyst(expected)
     // Codegen version expects UnsafeArrayData for arrays except Array(BinaryType)
     val catalystValueForCodegen = expected match {
       case arr: Array[Byte] if expression.dataType == BinaryType => arr

From 69e0eed64232e0b9fa03fff51cf02c5ae0dc6326 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Tue, 22 Nov 2016 19:50:29 +0900
Subject: [PATCH 13/39] address review comment

---
 .../sql/catalyst/expressions/ExpressionEvalHelper.scala | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 0a6bc4310efe5..ca483a07d74cb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -43,8 +43,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
   }
 
   protected def convertToCatalystUnsafe(a: Any): Any = a match {
+    case arr: Array[Byte] => arr
     case arr: Array[Boolean] => UnsafeArrayData.fromPrimitiveArray(arr)
-    case arr: Array[Byte] => UnsafeArrayData.fromPrimitiveArray(arr)
     case arr: Array[Short] => UnsafeArrayData.fromPrimitiveArray(arr)
     case arr: Array[Int] => UnsafeArrayData.fromPrimitiveArray(arr)
     case arr: Array[Long] => UnsafeArrayData.fromPrimitiveArray(arr)
@@ -60,10 +60,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
     // No codegen version expects GenericArrayData
     val catalystValue = CatalystTypeConverters.convertToCatalyst(expected)
     // Codegen version expects UnsafeArrayData for arrays except Array(BinaryType)
-    val catalystValueForCodegen = expected match {
-      case arr: Array[Byte] if expression.dataType == BinaryType => arr
-      case _ => convertToCatalystUnsafe(expected)
-    }
+    val catalystValueForCodegen = convertToCatalystUnsafe(expected)
     checkEvaluationWithoutCodegen(expr, catalystValue, inputRow)
     checkEvaluationWithGeneratedMutableProjection(expr, catalystValueForCodegen, inputRow)
     if (GenerateUnsafeProjection.canSupport(expr.dataType)) {
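The reordering in PATCH 13 matters because a Scala match takes the first case that applies, so Array[Byte] must be handled before the other primitive-array cases to keep BinaryType values as raw byte arrays. A self-contained sketch of that rule:

def convert(a: Any): String = a match {
  case _: Array[Byte] => "kept as Array[Byte] (BinaryType)"
  case _: Array[Int] => "wrapped as UnsafeArrayData"
  case _ => "converted via CatalystTypeConverters"
}
assert(convert(Array[Byte](1, 2)) == "kept as Array[Byte] (BinaryType)")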
From 597dc72fed9eeadcbec736f25887a6e202184acc Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Sat, 3 Dec 2016 17:30:20 +0900
Subject: [PATCH 14/39] add a test for Byte array to ComplexTypeSuite

make checkResult type-aware for UnsafeArrayData and ArrayBasedMap
---
 .../sql/catalyst/CatalystTypeConverters.scala |  3 +-
 .../expressions/ComplexTypeSuite.scala        | 70 +++++++++---------
 .../expressions/ExpressionEvalHelper.scala    | 73 ++++++++-----------
 3 files changed, 69 insertions(+), 77 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 5b435174796e1..5b9161551a7af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -412,8 +412,7 @@ object CatalystTypeConverters {
     case d: JavaBigDecimal => new DecimalConverter(DecimalType(d.precision, d.scale)).toCatalyst(d)
     case seq: Seq[Any] => new GenericArrayData(seq.map(convertToCatalyst).toArray)
     case r: Row => InternalRow(r.toSeq.map(convertToCatalyst): _*)
-    case arr: Array[Byte] => arr
-    case arr: Array[_] => new GenericArrayData(arr.map(convertToCatalyst))
+    case arr: Array[Any] => new GenericArrayData(arr.map(convertToCatalyst))
     case map: Map[_, _] =>
       ArrayBasedMapData(
         map,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 3d3947a26af2c..abe1d2b2c99e1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -118,20 +118,23 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("CreateArray") {
-    // Array is required to pass Array[primitiveType] as type information for expected
-    val intArray = Array(5, 10, 15, 20, 25)
-    val longArray = intArray.map(_.toLong)
-    val strArray = intArray.map(_.toString)
-    checkEvaluation(CreateArray(intArray.map(Literal(_))), intArray, EmptyRow)
-    checkEvaluation(CreateArray(longArray.map(Literal(_))), longArray, EmptyRow)
-    checkEvaluation(CreateArray(strArray.map(Literal(_))), strArray, EmptyRow)
-
-    val intWithNull = intArray.map(Literal(_)) :+ Literal.create(null, IntegerType)
-    val longWithNull = longArray.map(Literal(_)) :+ Literal.create(null, LongType)
-    val strWithNull = strArray.map(Literal(_)) :+ Literal.create(null, StringType)
-    checkEvaluation(CreateArray(intWithNull), intArray :+ null, EmptyRow)
-    checkEvaluation(CreateArray(longWithNull), longArray :+ null, EmptyRow)
-    checkEvaluation(CreateArray(strWithNull), strArray :+ null, EmptyRow)
+    val intSeq = Seq(5, 10, 15, 20, 25)
+    val longSeq = intSeq.map(_.toLong)
+    val byteSeq = intSeq.map(_.toByte)
+    val strSeq = intSeq.map(_.toString)
+    checkEvaluation(CreateArray(intSeq.map(Literal(_))), intSeq, EmptyRow)
+    checkEvaluation(CreateArray(longSeq.map(Literal(_))), longSeq, EmptyRow)
+    checkEvaluation(CreateArray(byteSeq.map(Literal(_))), byteSeq, EmptyRow)
+    checkEvaluation(CreateArray(strSeq.map(Literal(_))), strSeq, EmptyRow)
+
+    val intWithNull = intSeq.map(Literal(_)) :+ Literal.create(null, IntegerType)
+    val longWithNull = longSeq.map(Literal(_)) :+ Literal.create(null, LongType)
+    val byteWithNull = byteSeq.map(Literal(_)) :+ Literal.create(null, ByteType)
+    val strWithNull = strSeq.map(Literal(_)) :+ Literal.create(null, StringType)
+    checkEvaluation(CreateArray(intWithNull), intSeq :+ null, EmptyRow)
+    checkEvaluation(CreateArray(longWithNull), longSeq :+ null, EmptyRow)
+    checkEvaluation(CreateArray(byteWithNull), byteSeq :+ null, EmptyRow)
+    checkEvaluation(CreateArray(strWithNull), strSeq :+ null, EmptyRow)
     checkEvaluation(CreateArray(Literal.create(null, IntegerType) :: Nil), null :: Nil)
   }
 
@@ -145,32 +148,32 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
       scala.collection.immutable.ListMap(keys.zip(values): _*)
     }
 
-    // Array is required to pass Array[primitiveType] as type information
-    val intArray = Array(5, 10, 15, 20, 25)
-    val longArray = intArray.map(_.toLong)
-    val strArray = intArray.map(_.toString)
+    val intSeq = Seq(5, 10, 15, 20, 25)
+    val longSeq = intSeq.map(_.toLong)
+    val strSeq = intSeq.map(_.toString)
     checkEvaluation(CreateMap(Nil), Map.empty)
-    checkEvaluationMap(
-      CreateMap(interlace(intArray.map(Literal(_)), longArray.map(Literal(_)))),
-      createMap(intArray, longArray), intArray, longArray)
-    checkEvaluationMap(
-      CreateMap(interlace(strArray.map(Literal(_)), longArray.map(Literal(_)))),
-      createMap(strArray, longArray), strArray, longArray)
-    checkEvaluationMap(
-      CreateMap(interlace(longArray.map(Literal(_)), strArray.map(Literal(_)))),
-      createMap(longArray, strArray), longArray, strArray)
+    checkEvaluation(
+      CreateMap(interlace(intSeq.map(Literal(_)), longSeq.map(Literal(_)))),
+      createMap(intSeq, longSeq))
+    checkEvaluation(
+      CreateMap(interlace(strSeq.map(Literal(_)), longSeq.map(Literal(_)))),
+      createMap(strSeq, longSeq))
+    checkEvaluation(
+      CreateMap(interlace(longSeq.map(Literal(_)), strSeq.map(Literal(_)))),
+      createMap(longSeq, strSeq))
 
-    val strWithNull = strArray.drop(1).map(Literal(_)) :+ Literal.create(null, StringType)
+    val strWithNull = strSeq.drop(1).map(Literal(_)) :+ Literal.create(null, StringType)
-    checkEvaluationMap(
-      CreateMap(interlace(intArray.map(Literal(_)), strWithNull)),
-      createMap(intArray, strWithNull.map(_.value)), intArray, strWithNull.map(_.value))
+    checkEvaluation(
+      CreateMap(interlace(intSeq.map(Literal(_)), strWithNull)),
+      createMap(intSeq, strWithNull.map(_.value)))
     intercept[RuntimeException] {
       checkEvaluationWithoutCodegen(
-        CreateMap(interlace(strWithNull, intArray.map(Literal(_)))),
+        CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))),
         null, null)
     }
     intercept[RuntimeException] {
       checkEvalutionWithUnsafeProjection(
-        CreateMap(interlace(strWithNull, intArray.map(Literal(_)))),
+        CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))),
         null, null)
     }
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index ca483a07d74cb..d95f5ac9552b2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer
 import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
-import org.apache.spark.sql.types.{BinaryType, DataType}
+import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
 /**
@@ -42,48 +42,15 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
     InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst))
   }
 
-  protected def convertToCatalystUnsafe(a: Any): Any = a match {
-    case arr: Array[Byte] => arr
-    case arr: Array[Boolean] => UnsafeArrayData.fromPrimitiveArray(arr)
-    case arr: Array[Short] => UnsafeArrayData.fromPrimitiveArray(arr)
-    case arr: Array[Int] => UnsafeArrayData.fromPrimitiveArray(arr)
-    case arr: Array[Long] => UnsafeArrayData.fromPrimitiveArray(arr)
-    case arr: Array[Float] => UnsafeArrayData.fromPrimitiveArray(arr)
-    case arr: Array[Double] => UnsafeArrayData.fromPrimitiveArray(arr)
-    case other => CatalystTypeConverters.convertToCatalyst(other)
-  }
-
   protected def checkEvaluation(
       expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = {
     val serializer = new JavaSerializer(new SparkConf()).newInstance
     val expr: Expression = serializer.deserialize(serializer.serialize(expression))
-    // No codegen version expects GenericArrayData
     val catalystValue = CatalystTypeConverters.convertToCatalyst(expected)
-    // Codegen version expects UnsafeArrayData for arrays except Array(BinaryType)
-    val catalystValueForCodegen = convertToCatalystUnsafe(expected)
     checkEvaluationWithoutCodegen(expr, catalystValue, inputRow)
-    checkEvaluationWithGeneratedMutableProjection(expr, catalystValueForCodegen, inputRow)
+    checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow)
     if (GenerateUnsafeProjection.canSupport(expr.dataType)) {
-      checkEvalutionWithUnsafeProjection(expr, catalystValueForCodegen, inputRow)
+      checkEvalutionWithUnsafeProjection(expr, catalystValue, inputRow)
     }
     checkEvaluationWithOptimization(expr, catalystValue, inputRow)
   }
-
-  protected def checkEvaluationMap(expression: => Expression, expectedMap: Any,
-      expectedKey: Any, expectedValue: Any, inputRow: InternalRow = EmptyRow): Unit = {
-    val serializer = new JavaSerializer(new SparkConf()).newInstance
-    val expr: Expression = serializer.deserialize(serializer.serialize(expression))
-    // No codegen version expects GenericArrayData for map
-    val catalystValue = CatalystTypeConverters.convertToCatalyst(expectedMap)
-    // Codegen version expects UnsafeArrayData for map
-    val catalystValueForCodegen = new ArrayBasedMapData(
-      convertToCatalystUnsafe(expectedKey).asInstanceOf[ArrayData],
-      convertToCatalystUnsafe(expectedValue).asInstanceOf[ArrayData])
-
-    checkEvaluationWithoutCodegen(expr, catalystValue, inputRow)
-    checkEvaluationWithGeneratedMutableProjection(expr, catalystValueForCodegen, inputRow)
-    if (GenerateUnsafeProjection.canSupport(expr.dataType)) {
-      checkEvalutionWithUnsafeProjection(expr, catalystValueForCodegen, inputRow)
-    }
-    checkEvaluationWithOptimization(expr, catalystValue, inputRow)
-  }
 
@@ -92,12 +59,36 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
    * Check the equality between result of expression and expected value, it will handle
    * Array[Byte], Spread[Double], and MapData.
    */
-  protected def checkResult(result: Any, expected: Any): Boolean = {
+  protected def checkResult(result: Any, expected: Any, expr: Any = null): Boolean = {
     (result, expected) match {
       case (result: Array[Byte], expected: Array[Byte]) =>
         java.util.Arrays.equals(result, expected)
       case (result: Double, expected: Spread[Double @unchecked]) =>
         expected.asInstanceOf[Spread[Double]].isWithin(result)
+      case (result: UnsafeArrayData, expected: GenericArrayData) =>
+        val dataType = if (expr.isInstanceOf[DataType]) expr.asInstanceOf[DataType]
+          else expr.asInstanceOf[Expression].dataType
+        dataType match {
+          case ArrayType(BooleanType, false) =>
+            result == UnsafeArrayData.fromPrimitiveArray(expected.toBooleanArray())
+          case ArrayType(ByteType, false) =>
+            result == UnsafeArrayData.fromPrimitiveArray(expected.toByteArray())
+          case ArrayType(ShortType, false) =>
+            result == UnsafeArrayData.fromPrimitiveArray(expected.toShortArray())
+          case ArrayType(IntegerType, false) =>
+            result == UnsafeArrayData.fromPrimitiveArray(expected.toIntArray())
+          case ArrayType(LongType, false) =>
+            result == UnsafeArrayData.fromPrimitiveArray(expected.toLongArray())
+          case ArrayType(FloatType, false) =>
+            result == UnsafeArrayData.fromPrimitiveArray(expected.toFloatArray())
+          case ArrayType(DoubleType, false) =>
+            result == UnsafeArrayData.fromPrimitiveArray(expected.toDoubleArray())
+          case _ => result == expected
+        }
+      case (result: ArrayBasedMapData, expected: ArrayBasedMapData) =>
+        val MapType(keyType, valueType, containsNull) = expr.asInstanceOf[Expression].dataType
+        checkResult(result.keyArray, expected.keyArray, ArrayType(keyType, false)) &&
+          checkResult(result.valueArray, expected.valueArray, ArrayType(valueType, containsNull))
       case (result: MapData, expected: MapData) =>
         result.keyArray() == expected.keyArray() && result.valueArray() == expected.valueArray()
       case (result: Double, expected: Double) =>
@@ -141,7 +132,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
     val actual = try evaluate(expression, inputRow) catch {
       case e: Exception => fail(s"Exception evaluating $expression", e)
     }
-    if (!checkResult(actual, expected)) {
+    if (!checkResult(actual, expected, expression)) {
       val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
       fail(s"Incorrect evaluation (codegen off): $expression, " +
         s"actual: $actual, " +
@@ -160,7 +151,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
     plan.initialize(0)
     val actual = plan(inputRow).get(0, expression.dataType)
 
-    if (!checkResult(actual, expected)) {
+    if (!checkResult(actual, expected, expression)) {
       val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
      fail(s"Incorrect evaluation: $expression, actual: $actual, expected: $expected$input")
     }
@@ -221,7 +212,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
       expression)
     plan.initialize(0)
     var actual = plan(inputRow).get(0, expression.dataType)
-    assert(checkResult(actual, expected))
+    assert(checkResult(actual, expected, expression))
 
     plan = generateProject(
       GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil),
       expression)
     plan.initialize(0)
     actual = FromUnsafeProjection(expression.dataType :: Nil)(
       plan(inputRow)).get(0, expression.dataType)
-    assert(checkResult(actual, expected))
+    assert(checkResult(actual, expected, expression))
   }
@@ -153,19 +141,25 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
     new ArrayBasedMapData(new GenericArrayData(keyArray), new GenericArrayData(valueArray))
   }
 
-  private def getAccessors(ctx: CodegenContext, dt: DataType, array: String,
-      isPrimitive : Boolean, size: Int): (String, String, String) = {
+  // This function returns Java code pieces based on DataType and isPrimitive
+  // for allocation of ArrayData class
+  private def getArrayData(
+      ctx: CodegenContext,
+      dt: DataType,
+      array: String,
+      isPrimitive : Boolean,
+      size: Int): String = {
     if (!isPrimitive) {
       val arrayClass = classOf[GenericArrayData].getName
-      ctx.addMutableState("Object[]", array, s"this.$array = null;")
-      (s"new $arrayClass($array)",
-        s"$array = new Object[${size}];", s"this.$array = null;")
+      ctx.addMutableState("Object[]", array,
+        s"this.$array = new Object[${size}];")
+      s"new $arrayClass($array)"
     } else {
       val unsafeArrayClass = classOf[UnsafeArrayData].getName
       val javaDataType = ctx.javaType(dt)
       ctx.addMutableState(s"${javaDataType}[]", array,
         s"this.$array = new ${javaDataType}[${size}];")
-      (s"$unsafeArrayClass.fromPrimitiveArray($array)", "", "")
+      s"$unsafeArrayClass.fromPrimitiveArray($array)"
     }
   }
 
@@ -181,15 +175,10 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
     val evalValues = values.map(e => e.genCode(ctx))
     val isPrimitiveArrayValue =
       ctx.isPrimitiveType(valueDt) && values.forall(!_.nullable)
-    val (keyData, keyArrayAllocate, keyArrayNullify) =
-      getAccessors(ctx, keyDt, keyArray, isPrimitiveArrayKey, keys.size)
-    val (valueData, valueArrayAllocate, valueArrayNullify) =
-      getAccessors(ctx, valueDt, valueArray, isPrimitiveArrayValue, values.size)
+    val keyData = getArrayData(ctx, keyDt, keyArray, isPrimitiveArrayKey, keys.size)
+    val valueData = getArrayData(ctx, valueDt, valueArray, isPrimitiveArrayValue, values.size)
 
-    ev.copy(code = s"""
-      final boolean ${ev.isNull} = false;
-      $keyArrayAllocate
-      $valueArrayAllocate""" +
+    ev.copy(code = s"final boolean ${ev.isNull} = false;" +
       ctx.splitExpressions(
         ctx.INPUT_ROW,
         evalKeys.zipWithIndex.map { case (eval, i) =>
@@ -222,11 +211,7 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
             """
           })
       }) +
-      s"""
-        final MapData ${ev.value} = new $mapClass($keyData, $valueData);
-        $keyArrayNullify
-        $valueArrayNullify;
-      """)
+      s"final MapData ${ev.value} = new $mapClass($keyData, $valueData);")
   }
 
   override def prettyName: String = "map"
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index a68b1e173e12f..587022f0a2275 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -71,7 +71,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
     val expected = Seq.fill(length)(true)
 
-    if (!checkResult(actual, expected)) {
+    if (actual != expected) {
       fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
     }
   }
@@ -106,9 +106,10 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     val expressions = Seq(If(EqualTo(strExpr, strExpr), strExpr, strExpr))
     val plan = GenerateMutableProjection.generate(expressions)
     val actual = plan(null).toSeq(expressions.map(_.dataType))
-    val expected = Seq(UTF8String.fromString("abc"))
+    assert(actual.length == 1)
+    val expected = UTF8String.fromString("abc")
 
-    if (!checkResult(actual, expected)) {
+    if (!checkResult(actual.head, expected, expressions.head.dataType)) {
       fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
     }
   }
@@ -118,9 +119,10 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     val expressions = Seq(CreateArray(List.fill(length)(EqualTo(Literal(1), Literal(1)))))
     val plan = GenerateMutableProjection.generate(expressions)
     val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
-    val expected = Seq(UnsafeArrayData.fromPrimitiveArray(Array.fill(length)(true)))
+    assert(actual.length == 1)
+    val expected = UnsafeArrayData.fromPrimitiveArray(Array.fill(length)(true))
 
-    if (!checkResult(actual, expected)) {
+    if (!checkResult(actual.head, expected, expressions.head.dataType)) {
       fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
     }
   }
@@ -132,15 +134,11 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
       case (expr, i) => Seq(Literal(i), expr)
     }))
     val plan = GenerateMutableProjection.generate(expressions)
-    val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType)).map {
-      case m: ArrayBasedMapData =>
-        val keys = m.keyArray.asInstanceOf[ArrayData].toIntArray
-        val values = m.valueArray.asInstanceOf[ArrayData].toBooleanArray
-        keys.zip(values).toMap
-    }
-    val expected = (0 until length).map((_, true)).toMap :: Nil
+    val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
+    assert(actual.length == 1)
+    val expected = ArrayBasedMapData((0 until length).toArray, Array.fill(length)(true))
 
-    if (!checkResult(actual, expected)) {
+    if (!checkResult(actual.head, expected, expressions.head.dataType)) {
       fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
     }
   }
@@ -152,7 +150,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
     val expected = Seq(InternalRow(Seq.fill(length)(true): _*))
 
-    if (!checkResult(actual, expected)) {
+    if (!checkResult(actual, expected, expressions.head.dataType)) {
       fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
     }
   }
@@ -165,9 +163,10 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     }))
     val plan = GenerateMutableProjection.generate(expressions)
     val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
-    val expected = Seq(InternalRow(Seq.fill(length)(true): _*))
+    assert(actual.length == 1)
+    val expected = InternalRow(Seq.fill(length)(true): _*)
 
-    if (!checkResult(actual, expected)) {
+    if (!checkResult(actual.head, expected, expressions.head.dataType)) {
       fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
     }
   }
@@ -180,7 +179,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
     val expected = Seq(Row.fromSeq(Seq.fill(length)(1)))
 
-    if (!checkResult(actual, expected)) {
+    if (actual != expected) {
       fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
     }
   }
@@ -197,7 +196,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     val expected = Seq.fill(length)(
       DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-07-24 07:00:00")))
 
-    if (!checkResult(actual, expected)) {
+    if (actual != expected) {
       fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
     }
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index d95f5ac9552b2..1ba6dd1c5e8ca 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -59,38 +59,28 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
    * Check the equality between result of expression and expected value, it will handle
    * Array[Byte], Spread[Double], and MapData.
    */
-  protected def checkResult(result: Any, expected: Any, expr: Any = null): Boolean = {
+  protected def checkResult(result: Any, expected: Any, dataType: DataType): Boolean = {
     (result, expected) match {
       case (result: Array[Byte], expected: Array[Byte]) =>
         java.util.Arrays.equals(result, expected)
       case (result: Double, expected: Spread[Double @unchecked]) =>
         expected.asInstanceOf[Spread[Double]].isWithin(result)
-      case (result: UnsafeArrayData, expected: GenericArrayData) =>
-        val dataType = if (expr.isInstanceOf[DataType]) expr.asInstanceOf[DataType]
-          else expr.asInstanceOf[Expression].dataType
-        dataType match {
-          case ArrayType(BooleanType, false) =>
-            result == UnsafeArrayData.fromPrimitiveArray(expected.toBooleanArray())
-          case ArrayType(ByteType, false) =>
-            result == UnsafeArrayData.fromPrimitiveArray(expected.toByteArray())
-          case ArrayType(ShortType, false) =>
-            result == UnsafeArrayData.fromPrimitiveArray(expected.toShortArray())
-          case ArrayType(IntegerType, false) =>
-            result == UnsafeArrayData.fromPrimitiveArray(expected.toIntArray())
-          case ArrayType(LongType, false) =>
-            result == UnsafeArrayData.fromPrimitiveArray(expected.toLongArray())
-          case ArrayType(FloatType, false) =>
-            result == UnsafeArrayData.fromPrimitiveArray(expected.toFloatArray())
-          case ArrayType(DoubleType, false) =>
-            result == UnsafeArrayData.fromPrimitiveArray(expected.toDoubleArray())
-          case _ => result == expected
+      case (result: ArrayData, expected: ArrayData) =>
+        result.numElements == expected.numElements && {
+          val et = dataType.asInstanceOf[ArrayType].elementType
+          var isSame = true
+          var i = 0
+          while (isSame && i < result.numElements) {
+            isSame = checkResult(result.get(i, et), expected.get(i, et), et)
+            i += 1
+          }
+          isSame
         }
-      case (result: ArrayBasedMapData, expected: ArrayBasedMapData) =>
-        val MapType(keyType, valueType, containsNull) = expr.asInstanceOf[Expression].dataType
-        checkResult(result.keyArray, expected.keyArray, ArrayType(keyType, false)) &&
-          checkResult(result.valueArray, expected.valueArray, ArrayType(valueType, containsNull))
       case (result: MapData, expected: MapData) =>
-        result.keyArray() == expected.keyArray() && result.valueArray() == expected.valueArray()
+        val kt = dataType.asInstanceOf[MapType].keyType
+        val vt = dataType.asInstanceOf[MapType].valueType
+        checkResult(result.keyArray, expected.keyArray, ArrayType(kt)) &&
+          checkResult(result.valueArray, expected.valueArray, ArrayType(vt))
       case (result: Double, expected: Double) =>
         if (expected.isNaN) result.isNaN else expected == result
       case (result: Float, expected: Float) =>
@@ -132,7 +122,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
     val actual = try evaluate(expression, inputRow) catch {
       case e: Exception => fail(s"Exception evaluating $expression", e)
     }
-    if (!checkResult(actual, expected, expression)) {
+    if (!checkResult(actual, expected, expression.dataType)) {
       val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
       fail(s"Incorrect evaluation (codegen off): $expression, " +
         s"actual: $actual, " +
@@ -151,7 +141,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
     plan.initialize(0)
 
     val actual = plan(inputRow).get(0, expression.dataType)
-    if (!checkResult(actual, expected, expression)) {
+    if (!checkResult(actual, expected, expression.dataType)) {
       val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
       fail(s"Incorrect evaluation: $expression, actual: $actual, expected: $expected$input")
     }
@@ -212,7 +202,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
       expression)
     plan.initialize(0)
     var actual = plan(inputRow).get(0, expression.dataType)
-    assert(checkResult(actual, expected, expression))
+    assert(checkResult(actual, expected, expression.dataType))
 
     plan = generateProject(
       GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil),
@@ -220,7 +210,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
     plan.initialize(0)
     actual = FromUnsafeProjection(expression.dataType :: Nil)(
       plan(inputRow)).get(0, expression.dataType)
-    assert(checkResult(actual, expected, expression))
+    assert(checkResult(actual, expected, expression.dataType))
   }
 
   /**
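[Note] The reworked checkResult above walks arrays and maps element by element instead of relying on equals, because an UnsafeArrayData and a GenericArrayData holding identical values are different classes and never compare equal. A small illustration, assuming a test class that mixes in ExpressionEvalHelper so checkResult is in scope (everything else here is from the diff or standard Spark types):

    import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
    import org.apache.spark.sql.catalyst.util.GenericArrayData
    import org.apache.spark.sql.types.{ArrayType, IntegerType}

    val unsafe = UnsafeArrayData.fromPrimitiveArray(Array(1, 2, 3))
    val generic = new GenericArrayData(Array[Any](1, 2, 3))

    // Class-based equality fails even though the contents match ...
    assert(!unsafe.equals(generic))

    // ... while the element-wise comparison, driven by the element type, succeeds.
    assert(checkResult(unsafe, generic, ArrayType(IntegerType)))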
From be01d913a39908f1273261d899ef0e39459bfecf Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Tue, 6 Dec 2016 20:32:55 +0900
Subject: [PATCH 16/39] commit a file

---
 .../spark/sql/hive/execution/ObjectHashAggregateSuite.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
index 9a8d4498bba2f..9eaf44c043c71 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
@@ -411,8 +411,8 @@ class ObjectHashAggregateSuite
       actual.zip(expected).foreach { case (lhs: Row, rhs: Row) =>
         assert(lhs.length == rhs.length)
         lhs.toSeq.zip(rhs.toSeq).foreach {
-          case (a: Double, b: Double) => checkResult(a, b +- tolerance)
-          case (a, b) => checkResult(a, b)
+          case (a: Double, b: Double) => checkResult(a, b +- tolerance, DoubleType)
+          case (a, b) => a == b
         }
       }
     }

From 1c7f972c1da166fe85c873189090c5770ce1dc7c Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Sun, 11 Dec 2016 02:45:03 +0900
Subject: [PATCH 17/39] remove two test suites that I newly added

---
 .../spark/sql/DataFrameComplexTypeSuite.scala | 17 -----------------
 1 file changed, 17 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
index 292f873cb876b..1230b921aa279 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
@@ -27,23 +27,6 @@ import org.apache.spark.sql.test.SharedSQLContext
 class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 
-  test("primitive type on array") {
-    val df = sparkContext.parallelize(Seq(1, 2)).toDF("v")
-    val resDF = df.selectExpr("Array(v + 2, v + 3)")
-    checkAnswer(resDF,
-      Seq(Row(Array(3, 4)), Row(Array(4, 5))))
-  }
-
-  test("primitive array or null on array") {
-    val df = sparkContext.parallelize(Seq(1, 2)).toDF("v")
-    val resDF = df.selectExpr("Array(Array(v, v + 1, v + 2)," +
-      "null," +
-      "Array(v, v - 1, v - 2))")
-    QueryTest.checkAnswer(resDF,
-      Seq(Row(Array(Array(1, 2, 3), null, Array(1, 0, -1))),
-        Row(Array(Array(2, 3, 4), null, Array(2, 1, 0)))))
-  }
-
   test("UDF on struct") {
     val f = udf((a: String) => a)
     val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")

From 360d13952bf8d5e2dd9620c2b163294996822306 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Sun, 11 Dec 2016 02:45:56 +0900
Subject: [PATCH 18/39] Create UnsafeArrayData by using UnsafeArrayWriter

---
 .../expressions/codegen/BufferHolder.java     | 35 ++++++++++---
 .../expressions/complexTypeCreator.scala      | 49 ++++++++++++++-----
 2 files changed, 64 insertions(+), 20 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
index 0e4264fe8dfb5..66d16ca1a1ef5 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions.codegen;
 
+import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
 import org.apache.spark.unsafe.Platform;
 
@@ -38,7 +39,9 @@ public class BufferHolder {
   public byte[] buffer;
   public int cursor = Platform.BYTE_ARRAY_OFFSET;
   private final UnsafeRow row;
+  private final UnsafeArrayData array;
   private final int fixedSize;
+  private long numElements;
 
   public BufferHolder(UnsafeRow row) {
     this(row, 64);
@@ -55,6 +58,15 @@ public BufferHolder(UnsafeRow row, int initialSize) {
     this.buffer = new byte[fixedSize + initialSize];
     this.row = row;
     this.row.pointTo(buffer, buffer.length);
+    this.array = null;
+  }
+
+  public BufferHolder(UnsafeArrayData array, long numElements) {
+    this.fixedSize = 0;
+    this.buffer = null;
+    this.array = array;
+    this.numElements = numElements;
+    this.row = null;
   }
 
   /**
@@ -67,18 +79,25 @@ public void grow(int neededSize) {
         "exceeds size limitation " + Integer.MAX_VALUE);
     }
     final int length = totalSize() + neededSize;
-    if (buffer.length < length) {
+    if (buffer == null || buffer.length < length) {
       // This will not happen frequently, because the buffer is re-used.
       int newLength = length < Integer.MAX_VALUE / 2 ? length * 2 : Integer.MAX_VALUE;
       final byte[] tmp = new byte[newLength];
-      Platform.copyMemory(
-        buffer,
-        Platform.BYTE_ARRAY_OFFSET,
-        tmp,
-        Platform.BYTE_ARRAY_OFFSET,
-        totalSize());
+      if (buffer != null) {
+        Platform.copyMemory(
+          buffer,
+          Platform.BYTE_ARRAY_OFFSET,
+          tmp,
+          Platform.BYTE_ARRAY_OFFSET,
+          totalSize());
+      }
       buffer = tmp;
-      row.pointTo(buffer, buffer.length);
+      if (row != null)
+        row.pointTo(buffer, buffer.length);
+      else {
+        Platform.putLong(buffer, Platform.BYTE_ARRAY_OFFSET, numElements);
+        array.pointTo(buffer, Platform.BYTE_ARRAY_OFFSET, buffer.length);
+      }
     }
   }
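[Note] The sizing rule in grow() above doubles the buffer until doubling would overflow an int, then clamps at Integer.MAX_VALUE. A self-contained Scala mirror of that rule (a hypothetical helper written for illustration; it is not part of the patch):

    // Mirrors BufferHolder.grow()'s length computation.
    def nextBufferLength(totalSize: Int, neededSize: Int): Int = {
      require(neededSize <= Integer.MAX_VALUE - totalSize,
        "Cannot grow BufferHolder: exceeds size limitation")
      val length = totalSize + neededSize
      if (length < Integer.MAX_VALUE / 2) length * 2 else Integer.MAX_VALUE
    }

    // e.g. a 64-byte buffer that needs 8 more bytes is reallocated to 144 bytes
    assert(nextBufferLength(64, 8) == 144)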
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 43ad53083afa6..24d4fda4d25fd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils}
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.types.UTF8String
 
 /**
@@ -60,12 +61,13 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
     val et = dataType.elementType
     val evals = children.map(e => e.genCode(ctx))
 
-    val isPrimitiveArray = ctx.isPrimitiveType(et) && children.forall(!_.nullable)
-    val (assigns, allocate) = if (!isPrimitiveArray) {
+    val isPrimitiveArray = ctx.isPrimitiveType(et)
+    val (preprocess, assigns, postprocess) = if (!isPrimitiveArray) {
       val arrayClass = classOf[GenericArrayData].getName
       ctx.addMutableState("Object[]", values,
         s"this.$values = new Object[${children.size}];")
-      (evals.zipWithIndex.map { case (eval, i) =>
+      ("",
+       evals.zipWithIndex.map { case (eval, i) =>
         eval.code + s"""
           if (${eval.isNull}) {
            $values[$i] = null;
@@ -74,19 +76,42 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
          }
         """
       },
-      s"final ArrayData ${ev.value} = new $arrayClass($values);")
+      s"\nfinal ArrayData ${ev.value} = new $arrayClass($values);\n")
     } else {
+      val holder = ctx.freshName("holder")
+      val arrayWriter = ctx.freshName("createArrayWriter")
       val unsafeArrayClass = classOf[UnsafeArrayData].getName
-      val javaDataType = ctx.javaType(et)
-      ctx.addMutableState(s"${javaDataType}[]", values,
-        s"this.$values = new ${javaDataType}[${children.size}];")
-      (evals.zipWithIndex.map { case (eval, i) =>
-        eval.code +
-          s"\n$values[$i] = ${eval.value};"
+      val holderClass = classOf[BufferHolder].getName
+      val arrayWriterClass = classOf[UnsafeArrayWriter].getName
+      ctx.addMutableState(unsafeArrayClass, ev.value,
+        s"${ev.value} = new $unsafeArrayClass();")
+      ctx.addMutableState(holderClass, holder,
+        s"$holder = new $holderClass(${ev.value}, ${children.size});")
+      ctx.addMutableState(arrayWriterClass, arrayWriter,
+        s"$arrayWriter = new $arrayWriterClass();")
+      val primitiveTypeName = ctx.primitiveTypeName(et)
+
+      (s"""
+        $holder.reset();
+        $arrayWriter.initialize($holder, ${children.size}, ${et.defaultSize});
+      """,
+      evals.zipWithIndex.map { case (eval, i) =>
+        eval.code + (if (!children(i).nullable) {
+          s"\n$arrayWriter.write($i, ${eval.value});"
+        } else {
+          s"""
+          if (${eval.isNull}) {
+            $arrayWriter.setNull$primitiveTypeName($i);
+          } else {
+            $arrayWriter.write($i, ${eval.value});
+          }
+          """
+        })
       },
-      s"final ArrayData ${ev.value} = $unsafeArrayClass.fromPrimitiveArray($values);")
+      "")
     }
-    ev.copy(code = ctx.splitExpressions(ctx.INPUT_ROW, assigns) + allocate, isNull = "false")
+    ev.copy(code = preprocess + ctx.splitExpressions(ctx.INPUT_ROW, assigns) + postprocess,
+      isNull = "false")
   }
 
   override def prettyName: String = "array"

From b66d0f6c018226b35247feb654653ea0586bfa74 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Sun, 11 Dec 2016 19:15:11 +0900
Subject: [PATCH 19/39] fix test failure - DataFrameSuite.Star Expansion

---
 .../catalyst/expressions/complexTypeCreator.scala | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 24d4fda4d25fd..25b549e3b761f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -83,15 +83,15 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
       val unsafeArrayClass = classOf[UnsafeArrayData].getName
       val holderClass = classOf[BufferHolder].getName
       val arrayWriterClass = classOf[UnsafeArrayWriter].getName
-      ctx.addMutableState(unsafeArrayClass, ev.value,
-        s"${ev.value} = new $unsafeArrayClass();")
-      ctx.addMutableState(holderClass, holder,
-        s"$holder = new $holderClass(${ev.value}, ${children.size});")
-      ctx.addMutableState(arrayWriterClass, arrayWriter,
-        s"$arrayWriter = new $arrayWriterClass();")
+      ctx.addMutableState(unsafeArrayClass, ev.value, "")
+      ctx.addMutableState(holderClass, holder, "")
+      ctx.addMutableState(arrayWriterClass, arrayWriter, "")
       val primitiveTypeName = ctx.primitiveTypeName(et)
 
       (s"""
+        ${ev.value} = new $unsafeArrayClass();
+        $holder = new $holderClass(${ev.value}, ${children.size});
+        $arrayWriter = new $arrayWriterClass();
        $holder.reset();
        $arrayWriter.initialize($holder, ${children.size}, ${et.defaultSize});
       """,
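[Note] What patch 19 changes is when the mutable state gets initialized: the third argument of addMutableState becomes part of the generated class's one-time init code, whereas strings emitted through the preprocess block run on every evaluation. Restating the diff side by side (a fragment from inside doGenCode, shown for contrast; it is not standalone runnable code):

    // before: allocated once, in the generated class's init code
    ctx.addMutableState(holderClass, holder,
      s"$holder = new $holderClass(${ev.value}, ${children.size});")

    // after: declared only; the assignment is emitted into the per-evaluation
    // code, so each evaluation starts from freshly constructed objects
    ctx.addMutableState(holderClass, holder, "")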
From 08262b1f6abb3f8b85b276006a36d5fcbadf504f Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Mon, 12 Dec 2016 00:36:28 +0900
Subject: [PATCH 20/39] support createMap

address review comments
---
 .../expressions/codegen/BufferHolder.java     |  39 ++--
 .../codegen/UnsafeArrayWriter.java            |  11 ++
 .../expressions/complexTypeCreator.scala      | 182 ++++++++--------
 3 files changed, 131 insertions(+), 101 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
index 66d16ca1a1ef5..db51d3452e325 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
@@ -17,10 +17,12 @@
 
 package org.apache.spark.sql.catalyst.expressions.codegen;
 
-import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.Platform;
 
+import static org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.calculateHeaderPortionInBytes;
+
 /**
  * A helper class to manage the data buffer for an unsafe row. The data buffer can grow and
  * automatically re-point the unsafe row to it.
@@ -39,9 +41,7 @@ public class BufferHolder {
   public byte[] buffer;
   public int cursor = Platform.BYTE_ARRAY_OFFSET;
   private final UnsafeRow row;
-  private final UnsafeArrayData array;
   private final int fixedSize;
-  private long numElements;
 
   public BufferHolder(UnsafeRow row) {
     this(row, 64);
@@ -58,14 +58,14 @@ public BufferHolder(UnsafeRow row, int initialSize) {
     this.buffer = new byte[fixedSize + initialSize];
     this.row = row;
     this.row.pointTo(buffer, buffer.length);
-    this.array = null;
   }
 
-  public BufferHolder(UnsafeArrayData array, long numElements) {
+  public BufferHolder(int numElements, int elementSize) {
+    int headerInBytes = calculateHeaderPortionInBytes(numElements);
+    int fixedPartInBytes =
+      ByteArrayMethods.roundNumberOfBytesToNearestWord(elementSize * numElements);
     this.fixedSize = 0;
-    this.buffer = null;
-    this.array = array;
-    this.numElements = numElements;
+    this.buffer = new byte[headerInBytes + fixedPartInBytes];
    this.row = null;
   }
 
@@ -79,25 +79,18 @@ public void grow(int neededSize) {
         "exceeds size limitation " + Integer.MAX_VALUE);
     }
     final int length = totalSize() + neededSize;
-    if (buffer == null || buffer.length < length) {
+    if (buffer.length < length) {
       // This will not happen frequently, because the buffer is re-used.
       int newLength = length < Integer.MAX_VALUE / 2 ? length * 2 : Integer.MAX_VALUE;
       final byte[] tmp = new byte[newLength];
-      if (buffer != null) {
-        Platform.copyMemory(
-          buffer,
-          Platform.BYTE_ARRAY_OFFSET,
-          tmp,
-          Platform.BYTE_ARRAY_OFFSET,
-          totalSize());
-      }
+      Platform.copyMemory(
+        buffer,
+        Platform.BYTE_ARRAY_OFFSET,
+        tmp,
+        Platform.BYTE_ARRAY_OFFSET,
+        totalSize());
       buffer = tmp;
-      if (row != null)
-        row.pointTo(buffer, buffer.length);
-      else {
-        Platform.putLong(buffer, Platform.BYTE_ARRAY_OFFSET, numElements);
-        array.pointTo(buffer, Platform.BYTE_ARRAY_OFFSET, buffer.length);
-      }
+      row.pointTo(buffer, buffer.length);
     }
   }
 
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
index afea4676893ed..b458ae6c80255 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
@@ -73,6 +73,17 @@ public void initialize(BufferHolder holder, int numElements, int elementSize) {
     holder.cursor += (headerInBytes + fixedPartInBytes);
   }
 
+  public void initialize(BufferHolder holder, int numElements) {
+    this.numElements = numElements;
+    this.headerInBytes = calculateHeaderPortionInBytes(numElements);
+
+    this.holder = holder;
+    this.startingOffset = holder.cursor;
+
+    Platform.putLong(holder.buffer, startingOffset, numElements);
+    /* avoid to fill 0 since we ensure all elements in holder.buffer are 0 */
+  }
+
   private void zeroOutPaddingBytes(int numBytes) {
     if ((numBytes & 0x07) > 0) {
       Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L);
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 25b549e3b761f..14f2f4c2a19a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -57,64 +57,88 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val values = ctx.freshName("values")
+    val array = ctx.freshName("array")
     val et = dataType.elementType
     val evals = children.map(e => e.genCode(ctx))
 
     val isPrimitiveArray = ctx.isPrimitiveType(et)
-    val (preprocess, assigns, postprocess) = if (!isPrimitiveArray) {
+    val primitiveTypeName = if (isPrimitiveArray) ctx.primitiveTypeName(et) else ""
+    val (preprocess, arrayData, arrayWriter) =
+      genArrayData.getCodeArrayData(ctx, et, children.size, isPrimitiveArray, array)
+
+    ev.copy(code =
+      preprocess +
+      ctx.splitExpressions(
+        ctx.INPUT_ROW,
+        evals.zipWithIndex.map { case (eval, i) =>
+          eval.code +
+            (if (isPrimitiveArray) {
+              (if (!children(i).nullable) {
+                s"\n$arrayWriter.write($i, ${eval.value});"
+              } else {
+                s"""
+                if (${eval.isNull}) {
+                  $arrayWriter.setNull$primitiveTypeName($i);
+                } else {
+                  $arrayWriter.write($i, ${eval.value});
+                }
+                """
+              })
+            } else {
+              s"""
+              if (${eval.isNull}) {
+                $array[$i] = null;
+              } else {
+                $array[$i] = ${eval.value};
+              }
+              """
+            })
        }) +
+      s"\nfinal ArrayData ${ev.value} = $arrayData;\n",
+      isNull = "false")
+  }
+
+  override def prettyName: String = "array"
+}
+
+private [sql] object genArrayData {
+  // This function returns Java code pieces based on DataType and isPrimitive
+  // for allocation of ArrayData class
+  def getCodeArrayData(
+      ctx: CodegenContext,
+      dt: DataType,
+      size: Int,
+      isPrimitive : Boolean,
+      array: String): (String, String, String) = {
+    if (!isPrimitive) {
       val arrayClass = classOf[GenericArrayData].getName
-      ctx.addMutableState("Object[]", values,
-        s"this.$values = new Object[${children.size}];")
-      ("",
-       evals.zipWithIndex.map { case (eval, i) =>
-        eval.code + s"""
-          if (${eval.isNull}) {
-           $values[$i] = null;
-          } else {
-           $values[$i] = ${eval.value};
-          }
-         """
-      },
-      s"\nfinal ArrayData ${ev.value} = new $arrayClass($values);\n")
+      ctx.addMutableState("Object[]", array,
+        s"this.$array = new Object[${size}];")
+      ("", s"new $arrayClass($array)", null)
     } else {
       val holder = ctx.freshName("holder")
       val arrayWriter = ctx.freshName("createArrayWriter")
       val unsafeArrayClass = classOf[UnsafeArrayData].getName
       val holderClass = classOf[BufferHolder].getName
       val arrayWriterClass = classOf[UnsafeArrayWriter].getName
-      ctx.addMutableState(unsafeArrayClass, ev.value, "")
+      ctx.addMutableState(unsafeArrayClass, array, "")
       ctx.addMutableState(holderClass, holder, "")
       ctx.addMutableState(arrayWriterClass, arrayWriter, "")
-      val primitiveTypeName = ctx.primitiveTypeName(et)
+      val baseOffset = Platform.BYTE_ARRAY_OFFSET
 
       (s"""
-        ${ev.value} = new $unsafeArrayClass();
-        $holder = new $holderClass(${ev.value}, ${children.size});
+        $array = new $unsafeArrayClass();
+        $holder = new $holderClass(${size}, ${dt.defaultSize});
         $arrayWriter = new $arrayWriterClass();
         $holder.reset();
-        $arrayWriter.initialize($holder, ${children.size}, ${et.defaultSize});
-      """,
-      evals.zipWithIndex.map { case (eval, i) =>
-        eval.code + (if (!children(i).nullable) {
-          s"\n$arrayWriter.write($i, ${eval.value});"
-        } else {
-          s"""
-          if (${eval.isNull}) {
-            $arrayWriter.setNull$primitiveTypeName($i);
-          } else {
-            $arrayWriter.write($i, ${eval.value});
-          }
-          """
-        })
-      },
-      "")
+        $arrayWriter.initialize($holder, ${size});
+        $array.pointTo($holder.buffer, $baseOffset, $holder.buffer.length);
+      """,
+      array,
+      arrayWriter
+      )
     }
-    ev.copy(code = preprocess + ctx.splitExpressions(ctx.INPUT_ROW, assigns) + postprocess,
-      isNull = "false")
   }
-
-  override def prettyName: String = "array"
 }
 
 /**
@@ -166,28 +190,6 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
     new ArrayBasedMapData(new GenericArrayData(keyArray), new GenericArrayData(valueArray))
   }
 
-  // This function returns Java code pieces based on DataType and isPrimitive
-  // for allocation of ArrayData class
-  private def getArrayData(
-      ctx: CodegenContext,
-      dt: DataType,
-      array: String,
-      isPrimitive : Boolean,
-      size: Int): String = {
-    if (!isPrimitive) {
-      val arrayClass = classOf[GenericArrayData].getName
-      ctx.addMutableState("Object[]", array,
-        s"this.$array = new Object[${size}];")
-      s"new $arrayClass($array)"
-    } else {
-      val unsafeArrayClass = classOf[UnsafeArrayData].getName
-      val javaDataType = ctx.javaType(dt)
-      ctx.addMutableState(s"${javaDataType}[]", array,
        s"this.$array = new ${javaDataType}[${size}];")
-      s"$unsafeArrayClass.fromPrimitiveArray($array)"
-    }
-  }
-
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val mapClass = classOf[ArrayBasedMapData].getName
     val keyArray = ctx.freshName("keyArray")
@@ -196,36 +198,60 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
     val MapType(keyDt, valueDt, _) = dataType
     val evalKeys = keys.map(e => e.genCode(ctx))
     val isPrimitiveArrayKey = ctx.isPrimitiveType(keyDt)
-    val isNonNullKey = keys.forall(!_.nullable)
+    val primitiveKeyTypeName = if (isPrimitiveArrayKey) ctx.primitiveTypeName(keyDt) else ""
     val evalValues = values.map(e => e.genCode(ctx))
-    val isPrimitiveArrayValue =
-      ctx.isPrimitiveType(valueDt) && values.forall(!_.nullable)
-    val keyData = getArrayData(ctx, keyDt, keyArray, isPrimitiveArrayKey, keys.size)
-    val valueData = getArrayData(ctx, valueDt, valueArray, isPrimitiveArrayValue, values.size)
+    val isPrimitiveArrayValue = ctx.isPrimitiveType(valueDt)
+    val primitiveValueTypeName = if (isPrimitiveArrayKey) ctx.primitiveTypeName(keyDt) else ""
+    val (preprocessKeyData, keyData, keyDataArrayWriter) =
+      genArrayData.getCodeArrayData(ctx, keyDt, keys.size, isPrimitiveArrayKey, keyArray)
+    val (preprocessValueData, valueData, valueDataArrayWriter) =
+      genArrayData.getCodeArrayData(ctx, valueDt, values.size, isPrimitiveArrayValue, valueArray)
 
     ev.copy(code = s"final boolean ${ev.isNull} = false;" +
+      preprocessKeyData +
       ctx.splitExpressions(
         ctx.INPUT_ROW,
        evalKeys.zipWithIndex.map { case (eval, i) =>
          eval.code +
-           (if (isNonNullKey) {
-             s"$keyArray[$i] = ${eval.value};"
-           } else {
-             s"""
-              if (${eval.isNull}) {
-                throw new RuntimeException("Cannot use null as map key!");
-              } else {
-                $keyArray[$i] = ${eval.value};
-              }
-             """
-           })
+           (if (isPrimitiveArrayKey) {
+             (if (!keys(i).nullable) {
+               s"\n$keyDataArrayWriter.write($i, ${eval.value});"
+             } else {
+               s"""
+                if (${eval.isNull}) {
+                  $keyDataArrayWriter.setNull$primitiveKeyTypeName($i);
+                } else {
+                  $keyDataArrayWriter.write($i, ${eval.value});
+                }
+               """
+             })
+           } else {
+             s"""
+              if (${eval.isNull}) {
+                throw new RuntimeException("Cannot use null as map key!");
+              } else {
+                $keyArray[$i] = ${eval.value};
+              }
+             """
+           })
        }) +
+      preprocessValueData +
       ctx.splitExpressions(
         ctx.INPUT_ROW,
        evalValues.zipWithIndex.map { case (eval, i) =>
          eval.code +
            (if (isPrimitiveArrayValue) {
-             s"$valueArray[$i] = ${eval.value};"
+             (if (!values(i).nullable) {
+               s"\n$valueDataArrayWriter.write($i, ${eval.value});"
+             } else {
+               s"""
+                if (${eval.isNull}) {
+                  $valueDataArrayWriter.setNull$primitiveValueTypeName($i);
+                } else {
+                  $valueDataArrayWriter.write($i, ${eval.value});
+                }
+               """
+             })
           } else {
             s"""
               if (${eval.isNull}) {
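[Note] One behavior the rewrite above keeps is the fail-fast null-key check: the non-primitive key branch still throws rather than storing a null key. A usage-level sketch of what that looks like from the SQL side (assumes a SparkSession named spark with this patch series applied; the exact stack trace will vary):

    // Expected to fail at runtime with:
    //   java.lang.RuntimeException: Cannot use null as map key!
    spark.range(1).selectExpr("map(null, 'a')").collect()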
From 438944b0cc79d824898d44032674cb77395b59fb Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Mon, 12 Dec 2016 04:13:13 +0900
Subject: [PATCH 21/39] calculate initial buffer size at compilation time

---
 .../sql/catalyst/expressions/codegen/BufferHolder.java | 7 ++-----
 .../sql/catalyst/expressions/complexTypeCreator.scala  | 6 +++++-
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
index db51d3452e325..1e36dfca03407 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
@@ -60,12 +60,9 @@ public BufferHolder(UnsafeRow row, int initialSize) {
     this.row.pointTo(buffer, buffer.length);
   }
 
-  public BufferHolder(int numElements, int elementSize) {
-    int headerInBytes = calculateHeaderPortionInBytes(numElements);
-    int fixedPartInBytes =
-      ByteArrayMethods.roundNumberOfBytesToNearestWord(elementSize * numElements);
+  public BufferHolder(int initialSizeInBytes) {
     this.fixedSize = 0;
-    this.buffer = new byte[headerInBytes + fixedPartInBytes];
+    this.buffer = new byte[initialSizeInBytes];
     this.row = null;
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 14f2f4c2a19a9..f7a4d8118115f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.array.ByteArrayMethods
 import org.apache.spark.unsafe.types.UTF8String
 
 /**
@@ -125,10 +126,13 @@ private [sql] object genArrayData {
       ctx.addMutableState(holderClass, holder, "")
       ctx.addMutableState(arrayWriterClass, arrayWriter, "")
       val baseOffset = Platform.BYTE_ARRAY_OFFSET
+      val unsafeArraySizeInBytes =
+        UnsafeArrayData.calculateHeaderPortionInBytes(size) +
+        ByteArrayMethods.roundNumberOfBytesToNearestWord(dt.defaultSize * size)
 
       (s"""
         $array = new $unsafeArrayClass();
-        $holder = new $holderClass(${size}, ${dt.defaultSize});
+        $holder = new $holderClass($unsafeArraySizeInBytes);
         $arrayWriter = new $arrayWriterClass();
         $holder.reset();
         $arrayWriter.initialize($holder, ${size});
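[Note] To make the constant concrete: calculateHeaderPortionInBytes covers an 8-byte numElements word plus a word-aligned null bitset, and the fixed-width value region is rounded up to whole 8-byte words. A small worked example of the size that is now computed once at codegen time (the two helpers are the ones named in the diff; the numbers assume 4-byte int elements):

    import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
    import org.apache.spark.unsafe.array.ByteArrayMethods

    val numElements = 10
    val header = UnsafeArrayData.calculateHeaderPortionInBytes(numElements) // 8 + 8 = 16
    val fixed = ByteArrayMethods.roundNumberOfBytesToNearestWord(4 * numElements) // 40
    assert(header + fixed == 56) // the literal baked into the generated code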
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index 1e36dfca03407..df8bbf26c59b8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -18,11 +18,8 @@ package org.apache.spark.sql.catalyst.expressions.codegen; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; -import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.Platform; -import static org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.calculateHeaderPortionInBytes; - /** * A helper class to manage the data buffer for an unsafe row. The data buffer can grow and * automatically re-point the unsafe row to it. @@ -60,6 +57,8 @@ public BufferHolder(UnsafeRow row, int initialSize) { this.row.pointTo(buffer, buffer.length); } + // This is a special constructor for writing data to UnsafeArray for a primitive array + // that do not require to grow buffer (not to call grow()) during write operations public BufferHolder(int initialSizeInBytes) { this.fixedSize = 0; this.buffer = new byte[initialSizeInBytes]; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index b458ae6c80255..69842218699a9 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -73,6 +73,9 @@ public void initialize(BufferHolder holder, int numElements, int elementSize) { holder.cursor += (headerInBytes + fixedPartInBytes); } + // This is a special constructor for writing data to UnsafeArray for a primitive array + // that writes regions only for null bits and values + // This assumes that all elements in holder.buffer are 0 public void initialize(BufferHolder holder, int numElements) { this.numElements = numElements; this.headerInBytes = calculateHeaderPortionInBytes(numElements); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index f7a4d8118115f..f9b23703867d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -65,7 +65,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { val isPrimitiveArray = ctx.isPrimitiveType(et) val primitiveTypeName = if (isPrimitiveArray) ctx.primitiveTypeName(et) else "" val (preprocess, arrayData, arrayWriter) = - genArrayData.getCodeArrayData(ctx, et, children.size, isPrimitiveArray, array) + GenArrayData.getCodeArrayData(ctx, et, children.size, isPrimitiveArray, array) ev.copy(code = preprocess + @@ -102,7 +102,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def prettyName: String = "array" } -private [sql] object genArrayData { +private [sql] object GenArrayData { // This function returns Java code pieces based on DataType and isPrimitive // for allocation of ArrayData class def getCodeArrayData( @@ -207,9 +207,9 @@ case class CreateMap(children: Seq[Expression]) extends Expression { val 
isPrimitiveArrayValue = ctx.isPrimitiveType(valueDt) val primitiveValueTypeName = if (isPrimitiveArrayKey) ctx.primitiveTypeName(keyDt) else "" val (preprocessKeyData, keyData, keyDataArrayWriter) = - genArrayData.getCodeArrayData(ctx, keyDt, keys.size, isPrimitiveArrayKey, keyArray) + GenArrayData.getCodeArrayData(ctx, keyDt, keys.size, isPrimitiveArrayKey, keyArray) val (preprocessValueData, valueData, valueDataArrayWriter) = - genArrayData.getCodeArrayData(ctx, valueDt, values.size, isPrimitiveArrayValue, valueArray) + GenArrayData.getCodeArrayData(ctx, valueDt, values.size, isPrimitiveArrayValue, valueArray) ev.copy(code = s"final boolean ${ev.isNull} = false;" + preprocessKeyData + From d24c7b1e13c00b722cafb7230e1e59021f42eee4 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 19 Dec 2016 10:45:53 +0900 Subject: [PATCH 23/39] address review comment --- .../expressions/complexTypeCreator.scala | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index f9b23703867d9..bd84fb5bb8ac9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -74,17 +74,13 @@ case class CreateArray(children: Seq[Expression]) extends Expression { evals.zipWithIndex.map { case (eval, i) => eval.code + (if (isPrimitiveArray) { - (if (!children(i).nullable) { - s"\n$arrayWriter.write($i, ${eval.value});" + s""" + if (${eval.isNull}) { + $arrayWriter.setNull$primitiveTypeName($i); } else { - s""" - if (${eval.isNull}) { - $arrayWriter.setNull$primitiveTypeName($i); - } else { - $arrayWriter.write($i, ${eval.value}); - } - """ - }) + $arrayWriter.write($i, ${eval.value}); + } + """ } else { s""" if (${eval.isNull}) { From c159f0361c472fd4e0e780a46f1b95a6a5f4f4d7 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 22 Dec 2016 02:00:08 +0900 Subject: [PATCH 24/39] Create UnsafeArrayData by using UnsafeRow and UnsafeArrayWriter --- .../expressions/codegen/BufferHolder.java | 8 --- .../codegen/UnsafeArrayWriter.java | 14 ---- .../expressions/complexTypeCreator.scala | 67 +++++++++++-------- 3 files changed, 39 insertions(+), 50 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index df8bbf26c59b8..0e4264fe8dfb5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -57,14 +57,6 @@ public BufferHolder(UnsafeRow row, int initialSize) { this.row.pointTo(buffer, buffer.length); } - // This is a special constructor for writing data to UnsafeArray for a primitive array - // that do not require to grow buffer (not to call grow()) during write operations - public BufferHolder(int initialSizeInBytes) { - this.fixedSize = 0; - this.buffer = new byte[initialSizeInBytes]; - this.row = null; - } - /** * Grows the buffer by at least neededSize and points the row to the buffer. 
*/ diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index 69842218699a9..afea4676893ed 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -73,20 +73,6 @@ public void initialize(BufferHolder holder, int numElements, int elementSize) { holder.cursor += (headerInBytes + fixedPartInBytes); } - // This is a special constructor for writing data to UnsafeArray for a primitive array - // that writes regions only for null bits and values - // This assumes that all elements in holder.buffer are 0 - public void initialize(BufferHolder holder, int numElements) { - this.numElements = numElements; - this.headerInBytes = calculateHeaderPortionInBytes(numElements); - - this.holder = holder; - this.startingOffset = holder.cursor; - - Platform.putLong(holder.buffer, startingOffset, numElements); - /* avoid to fill 0 since we ensure all elements in holder.buffer are 0 */ - } - private void zeroOutPaddingBytes(int numBytes) { if ((numBytes & 0x07) > 0) { Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index bd84fb5bb8ac9..e9d02daf7fbed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -67,30 +67,30 @@ case class CreateArray(children: Seq[Expression]) extends Expression { val (preprocess, arrayData, arrayWriter) = GenArrayData.getCodeArrayData(ctx, et, children.size, isPrimitiveArray, array) + val assigns = if (isPrimitiveArray) { + evals.zipWithIndex.map { case (eval, i) => + eval.code + s""" + if (${eval.isNull}) { + $arrayWriter.setNull$primitiveTypeName($i); + } else { + $arrayWriter.write($i, ${eval.value}); + } + """ + } + } else { + evals.zipWithIndex.map { case (eval, i) => + eval.code + s""" + if (${eval.isNull}) { + $array[$i] = null; + } else { + $array[$i] = ${eval.value}; + } + """ + } + } ev.copy(code = preprocess + - ctx.splitExpressions( - ctx.INPUT_ROW, - evals.zipWithIndex.map { case (eval, i) => - eval.code + - (if (isPrimitiveArray) { - s""" - if (${eval.isNull}) { - $arrayWriter.setNull$primitiveTypeName($i); - } else { - $arrayWriter.write($i, ${eval.value}); - } - """ - } else { - s""" - if (${eval.isNull}) { - $array[$i] = null; - } else { - $array[$i] = ${eval.value}; - } - """ - }) - }) + + ctx.splitExpressions(ctx.INPUT_ROW, assigns) + s"\nfinal ArrayData ${ev.value} = $arrayData;\n", isNull = "false") } @@ -113,27 +113,38 @@ private [sql] object GenArrayData { s"this.$array = new Object[${size}];") ("", s"new $arrayClass($array)", null) } else { + val row = ctx.freshName("row") val holder = ctx.freshName("holder") + val rowWriter = ctx.freshName("createRowWriter") val arrayWriter = ctx.freshName("createArrayWriter") + val unsafeRowClass = classOf[UnsafeRow].getName val unsafeArrayClass = classOf[UnsafeArrayData].getName val holderClass = classOf[BufferHolder].getName + val rowWriterClass = classOf[UnsafeRowWriter].getName val arrayWriterClass = 
classOf[UnsafeArrayWriter].getName + ctx.addMutableState(unsafeRowClass, row, "") ctx.addMutableState(unsafeArrayClass, array, "") ctx.addMutableState(holderClass, holder, "") + ctx.addMutableState(rowWriterClass, rowWriter, "") ctx.addMutableState(arrayWriterClass, arrayWriter, "") - val baseOffset = Platform.BYTE_ARRAY_OFFSET val unsafeArraySizeInBytes = UnsafeArrayData.calculateHeaderPortionInBytes(size) + ByteArrayMethods.roundNumberOfBytesToNearestWord(dt.defaultSize * size) + // To write data to UnsafeArrayData, we create UnsafeRow with a single array field + // and then prepare BufferHolder for the array. + // In summary, this does not use UnsafeRow and wastes some bits in an byte array (s""" - $array = new $unsafeArrayClass(); - $holder = new $holderClass($unsafeArraySizeInBytes); + $row = new $unsafeRowClass(1); + $holder = new $holderClass($row, $unsafeArraySizeInBytes); + $rowWriter = new $rowWriterClass($holder, 1); $arrayWriter = new $arrayWriterClass(); + $rowWriter.reset(); + $rowWriter.setOffsetAndSize(0, $unsafeArraySizeInBytes); $holder.reset(); - $arrayWriter.initialize($holder, ${size}); - $array.pointTo($holder.buffer, $baseOffset, $holder.buffer.length); - """, + $arrayWriter.initialize($holder, ${size}, ${dt.defaultSize}); + $array = $row.getArray(0); + """, array, arrayWriter ) From 0af08282f4f1d72d205442ba66d6964cd1ac0599 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 23 Dec 2016 14:37:30 +0900 Subject: [PATCH 25/39] Create UnsafeArrayData by making ArrayData mutable refectoring createMap fix bugs in createMap --- .../catalyst/expressions/UnsafeArrayData.java | 51 ++++++ .../expressions/complexTypeCreator.scala | 153 ++++++++---------- .../spark/sql/catalyst/util/ArrayData.scala | 13 ++ .../sql/catalyst/util/GenericArrayData.scala | 4 + .../execution/vectorized/ColumnVector.java | 6 + 5 files changed, 138 insertions(+), 89 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index e8c33871f97bc..c402275c1096a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -287,6 +287,57 @@ public UnsafeMapData getMap(int ordinal) { return map; } + @Override + public void update(int ordinal, Object value) { throw new UnsupportedOperationException(); } + + public void setNullAt(int ordinal) { + assertIndexIsValid(ordinal); + BitSetMethods.set(baseObject, baseOffset + 8, ordinal); + + /* we assume the corrresponding column was already 0 */ + } + + public void setBoolean(int ordinal, boolean value) { + assertIndexIsValid(ordinal); + Platform.putBoolean(baseObject, getElementOffset(ordinal, 1), value); + } + + public void setByte(int ordinal, byte value) { + assertIndexIsValid(ordinal); + Platform.putByte(baseObject, getElementOffset(ordinal, 1), value); + } + + public void setShort(int ordinal, short value) { + assertIndexIsValid(ordinal); + Platform.putShort(baseObject, getElementOffset(ordinal, 2), value); + } + + public void setInt(int ordinal, int value) { + assertIndexIsValid(ordinal); + Platform.putInt(baseObject, getElementOffset(ordinal, 4), value); + } + + public void setLong(int ordinal, long value) { + assertIndexIsValid(ordinal); + Platform.putLong(baseObject, getElementOffset(ordinal, 8), value); + } + + public void setFloat(int ordinal, float value) { + if 
(Float.isNaN(value)) { + value = Float.NaN; + } + assertIndexIsValid(ordinal); + Platform.putFloat(baseObject, getElementOffset(ordinal, 4), value); + } + + public void setDouble(int ordinal, double value) { + if (Double.isNaN(value)) { + value = Double.NaN; + } + assertIndexIsValid(ordinal); + Platform.putDouble(baseObject, getElementOffset(ordinal, 8), value); + } + // This `hashCode` computation could consume much processor time for large data. // If the computation becomes a bottleneck, we can use a light-weight logic; the first fixed bytes // are used to compute `hashCode` (See `Vector.hashCode`). diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index e9d02daf7fbed..36d1bdb876778 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -63,17 +63,17 @@ case class CreateArray(children: Seq[Expression]) extends Expression { val et = dataType.elementType val evals = children.map(e => e.genCode(ctx)) val isPrimitiveArray = ctx.isPrimitiveType(et) - val primitiveTypeName = if (isPrimitiveArray) ctx.primitiveTypeName(et) else "" - val (preprocess, arrayData, arrayWriter) = + val (preprocess, arrayData) = GenArrayData.getCodeArrayData(ctx, et, children.size, isPrimitiveArray, array) val assigns = if (isPrimitiveArray) { + val primitiveTypeName = ctx.primitiveTypeName(et) evals.zipWithIndex.map { case (eval, i) => eval.code + s""" if (${eval.isNull}) { - $arrayWriter.setNull$primitiveTypeName($i); + $arrayData.setNullAt($i); } else { - $arrayWriter.write($i, ${eval.value}); + $arrayData.set$primitiveTypeName($i, ${eval.value}); } """ } @@ -106,47 +106,28 @@ private [sql] object GenArrayData { dt: DataType, size: Int, isPrimitive : Boolean, - array: String): (String, String, String) = { + array: String): (String, String) = { if (!isPrimitive) { val arrayClass = classOf[GenericArrayData].getName ctx.addMutableState("Object[]", array, s"this.$array = new Object[${size}];") - ("", s"new $arrayClass($array)", null) + ("", s"new $arrayClass($array)") } else { - val row = ctx.freshName("row") - val holder = ctx.freshName("holder") - val rowWriter = ctx.freshName("createRowWriter") - val arrayWriter = ctx.freshName("createArrayWriter") - val unsafeRowClass = classOf[UnsafeRow].getName + val baseArray = ctx.freshName("baseArray") val unsafeArrayClass = classOf[UnsafeArrayData].getName - val holderClass = classOf[BufferHolder].getName - val rowWriterClass = classOf[UnsafeRowWriter].getName - val arrayWriterClass = classOf[UnsafeArrayWriter].getName - ctx.addMutableState(unsafeRowClass, row, "") ctx.addMutableState(unsafeArrayClass, array, "") - ctx.addMutableState(holderClass, holder, "") - ctx.addMutableState(rowWriterClass, rowWriter, "") - ctx.addMutableState(arrayWriterClass, arrayWriter, "") val unsafeArraySizeInBytes = UnsafeArrayData.calculateHeaderPortionInBytes(size) + ByteArrayMethods.roundNumberOfBytesToNearestWord(dt.defaultSize * size) + val baseOffset = Platform.BYTE_ARRAY_OFFSET - // To write data to UnsafeArrayData, we create UnsafeRow with a single array field - // and then prepare BufferHolder for the array. 
- // In summary, this does not use UnsafeRow and wastes some bits in an byte array (s""" - $row = new $unsafeRowClass(1); - $holder = new $holderClass($row, $unsafeArraySizeInBytes); - $rowWriter = new $rowWriterClass($holder, 1); - $arrayWriter = new $arrayWriterClass(); - $rowWriter.reset(); - $rowWriter.setOffsetAndSize(0, $unsafeArraySizeInBytes); - $holder.reset(); - $arrayWriter.initialize($holder, ${size}, ${dt.defaultSize}); - $array = $row.getArray(0); + byte[] $baseArray = new byte[$unsafeArraySizeInBytes]; + $array = new $unsafeArrayClass(); + Platform.putLong($baseArray, $baseOffset, $size); + $array.pointTo($baseArray, $baseOffset, $unsafeArraySizeInBytes); """, - array, - arrayWriter + array ) } } @@ -209,71 +190,65 @@ case class CreateMap(children: Seq[Expression]) extends Expression { val MapType(keyDt, valueDt, _) = dataType val evalKeys = keys.map(e => e.genCode(ctx)) val isPrimitiveArrayKey = ctx.isPrimitiveType(keyDt) - val primitiveKeyTypeName = if (isPrimitiveArrayKey) ctx.primitiveTypeName(keyDt) else "" val evalValues = values.map(e => e.genCode(ctx)) val isPrimitiveArrayValue = ctx.isPrimitiveType(valueDt) - val primitiveValueTypeName = if (isPrimitiveArrayKey) ctx.primitiveTypeName(keyDt) else "" - val (preprocessKeyData, keyData, keyDataArrayWriter) = + val (preprocessKeyData, keyDataArray) = GenArrayData.getCodeArrayData(ctx, keyDt, keys.size, isPrimitiveArrayKey, keyArray) - val (preprocessValueData, valueData, valueDataArrayWriter) = + val (preprocessValueData, valueDataArray) = GenArrayData.getCodeArrayData(ctx, valueDt, values.size, isPrimitiveArrayValue, valueArray) + val assignKeys = if (isPrimitiveArrayKey) { + val primitiveKeyTypeName = ctx.primitiveTypeName(keyDt) + evalKeys.zipWithIndex.map { case (eval, i) => + eval.code + s""" + if (${eval.isNull}) { + $keyDataArray.setNullAt($i); + } else { + $keyDataArray.set$primitiveKeyTypeName($i, ${eval.value}); + } + """ + } + } else { + evalKeys.zipWithIndex.map { case (eval, i) => + eval.code + s""" + if (${eval.isNull}) { + throw new RuntimeException("Cannot use null as map key!"); + } else { + $keyArray[$i] = ${eval.value}; + } + """ + } + } + + val assignValues = if (isPrimitiveArrayValue) { + val primitiveValueTypeName = ctx.primitiveTypeName(valueDt) + evalValues.zipWithIndex.map { case (eval, i) => + eval.code + s""" + if (${eval.isNull}) { + $valueDataArray.setNullAt($i); + } else { + $valueDataArray.set$primitiveValueTypeName($i, ${eval.value}); + } + """ + } + } else { + evalValues.zipWithIndex.map { case (eval, i) => + eval.code + s""" + if (${eval.isNull}) { + $valueArray[$i] = null; + } else { + $valueArray[$i] = ${eval.value}; + } + """ + } + } + ev.copy(code = s"final boolean ${ev.isNull} = false;" + preprocessKeyData + - ctx.splitExpressions( - ctx.INPUT_ROW, - evalKeys.zipWithIndex.map { case (eval, i) => - eval.code + - (if (isPrimitiveArrayKey) { - (if (!keys(i).nullable) { - s"\n$keyDataArrayWriter.write($i, ${eval.value});" - } else { - s""" - if (${eval.isNull}) { - $keyDataArrayWriter.setNull$primitiveKeyTypeName($i); - } else { - $keyDataArrayWriter.write($i, ${eval.value}); - } - """ - }) - } else { - s""" - if (${eval.isNull}) { - throw new RuntimeException("Cannot use null as map key!"); - } else { - $keyArray[$i] = ${eval.value}; - } - """ - }) - }) + + ctx.splitExpressions(ctx.INPUT_ROW, assignKeys) + preprocessValueData + - ctx.splitExpressions( - ctx.INPUT_ROW, - evalValues.zipWithIndex.map { case (eval, i) => - eval.code + - (if (isPrimitiveArrayValue) { - (if 
(!values(i).nullable) {
-              s"\n$valueDataArrayWriter.write($i, ${eval.value});"
-            } else {
-              s"""
-              if (${eval.isNull}) {
-                $valueDataArrayWriter.setNull$primitiveValueTypeName($i);
-              } else {
-                $valueDataArrayWriter.write($i, ${eval.value});
-              }
-              """
-            })
-          } else {
-            s"""
-            if (${eval.isNull}) {
-              $valueArray[$i] = null;
-            } else {
-              $valueArray[$i] = ${eval.value};
-            }
-            """
-          })
-        }) +
-      s"final MapData ${ev.value} = new $mapClass($keyData, $valueData);")
+      ctx.splitExpressions(ctx.INPUT_ROW, assignValues) +
+      s"final MapData ${ev.value} = new $mapClass($keyDataArray, $valueDataArray);")
   }

   override def prettyName: String = "map"

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
index 140e86d670a5b..9beef41d639f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
@@ -42,6 +42,19 @@ abstract class ArrayData extends SpecializedGetters with Serializable {

   def array: Array[Any]

+  def setNullAt(i: Int): Unit
+
+  def update(i: Int, value: Any): Unit
+
+  // default implementation (slow)
+  def setBoolean(i: Int, value: Boolean): Unit = update(i, value)
+  def setByte(i: Int, value: Byte): Unit = update(i, value)
+  def setShort(i: Int, value: Short): Unit = update(i, value)
+  def setInt(i: Int, value: Int): Unit = update(i, value)
+  def setLong(i: Int, value: Long): Unit = update(i, value)
+  def setFloat(i: Int, value: Float): Unit = update(i, value)
+  def setDouble(i: Int, value: Double): Unit = update(i, value)
+
   def toBooleanArray(): Array[Boolean] = {
     val size = numElements()
     val values = new Array[Boolean](size)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
index 7ee9581b63af5..07f4acdea7ccb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
@@ -71,6 +71,10 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData {
   override def getArray(ordinal: Int): ArrayData = getAs(ordinal)
   override def getMap(ordinal: Int): MapData = getAs(ordinal)

+  override def setNullAt(ordinal: Int): Unit = throw new UnsupportedOperationException();
+
+  override def update(ordinal: Int, value: Any): Unit = throw new UnsupportedOperationException();
+
   override def toString(): String = array.mkString("[", ",", "]")

   override def equals(o: Any): Boolean = {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
index ff07940422a0b..354c878aca000 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
@@ -246,6 +246,12 @@ public MapData getMap(int ordinal) {
     public Object get(int ordinal, DataType dataType) {
       throw new UnsupportedOperationException();
     }
+
+    @Override
+    public void update(int ordinal, Object value) { throw new UnsupportedOperationException(); }
+
+    @Override
+    public void setNullAt(int ordinal) { throw new UnsupportedOperationException(); }
   }

   /**
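For reference, a quick worked example of the unsafeArraySizeInBytes computation this patch allocates with. This is a sketch of the arithmetic only: the two helper formulas are restated inline from UnsafeArrayData.calculateHeaderPortionInBytes and ByteArrayMethods.roundNumberOfBytesToNearestWord as this series uses them, and the 3-element int array is an assumed, illustrative input.

    // Size of the byte[] that backs an UnsafeArrayData holding 3 ints.
    val numElements = 3
    // 8 bytes for the element count plus one 8-byte word of null bits per 64 elements.
    val headerInBytes = 8 + ((numElements + 63) / 64) * 8            // = 16
    // IntegerType.defaultSize (4) * 3 = 12 bytes, rounded up to a full 8-byte word.
    val valueRegionInBytes = ((4 * numElements + 7) / 8) * 8         // = 16
    val unsafeArraySizeInBytes = headerInBytes + valueRegionInBytes  // = 32

Platform.putLong then writes the element count into the first 8 bytes of that buffer and pointTo makes the UnsafeArrayData read it in place, so the null bitset and the values follow the count in the same byte[].

From f6e9a832a983a805f91e5d65d28cb95cb6d89d99 Mon Sep 17 00:00:00 2001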
From: Kazuaki Ishizaki
Date: Sat, 24 Dec 2016 04:16:29 +0900
Subject: [PATCH 26/39] addressed review comments

---
 .../spark/sql/catalyst/expressions/UnsafeArrayData.java       | 3 ++-
 .../org/apache/spark/sql/catalyst/util/GenericArrayData.scala | 4 ++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index c402275c1096a..64ab01ca57403 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -294,7 +294,8 @@ public void setNullAt(int ordinal) {
     assertIndexIsValid(ordinal);
     BitSetMethods.set(baseObject, baseOffset + 8, ordinal);

-    /* we assume the corrresponding column was already 0 */
+    /* we assume the corresponding column was already 0 or
+       will be set to 0 later by the caller side */
   }

   public void setBoolean(int ordinal, boolean value) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
index 07f4acdea7ccb..dd660c80a9c3c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
@@ -71,9 +71,9 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData {
   override def getArray(ordinal: Int): ArrayData = getAs(ordinal)
   override def getMap(ordinal: Int): MapData = getAs(ordinal)

-  override def setNullAt(ordinal: Int): Unit = throw new UnsupportedOperationException();
+  override def setNullAt(ordinal: Int): Unit = array(ordinal) = null

-  override def update(ordinal: Int, value: Any): Unit = throw new UnsupportedOperationException();
+  override def update(ordinal: Int, value: Any): Unit = array(ordinal) = value

   override def toString(): String = array.mkString("[", ",", "]")
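With the change above, GenericArrayData becomes genuinely mutable, so the setters that ArrayData gained earlier in this series are now usable on it. A minimal sketch (the values are illustrative; the classes and methods are exactly the ones touched by this patch, on a Spark 2.x catalyst classpath):

    import org.apache.spark.sql.catalyst.util.GenericArrayData

    val arrayData = new GenericArrayData(new Array[Any](3))
    arrayData.update(0, 42)  // stores the boxed value in the backing Object[]
    arrayData.setNullAt(1)   // now stores null instead of throwing
    arrayData.setInt(2, 7)   // the default ArrayData.setInt delegates to update()

From 327c8acc3045b3a96b893ad221e8379d0403b3a9 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Sat, 24 Dec 2016 17:36:51 +0900
Subject: [PATCH 27/39] addressed review comments

---
 .../expressions/complexTypeCreator.scala      | 237 +++++++++---------
 1 file changed, 123 insertions(+), 114 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 36d1bdb876778..74bb7e177bc6f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -58,40 +58,24 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
   }

   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val array = ctx.freshName("array")
-
     val et = dataType.elementType
     val evals = children.map(e => e.genCode(ctx))
     val isPrimitiveArray = ctx.isPrimitiveType(et)
-    val (preprocess, arrayData) =
-      GenArrayData.getCodeArrayData(ctx, et, children.size, isPrimitiveArray, array)
-
-    val assigns = if (isPrimitiveArray) {
-      val primitiveTypeName = ctx.primitiveTypeName(et)
-      evals.zipWithIndex.map { case (eval, i) =>
-        eval.code + s"""
-        if (${eval.isNull}) {
-          $arrayData.setNullAt($i);
-        } else {
-          $arrayData.set$primitiveTypeName($i, 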
${eval.value}); - } - """ - } - } else { - evals.zipWithIndex.map { case (eval, i) => - eval.code + s""" - if (${eval.isNull}) { - $array[$i] = null; - } else { - $array[$i] = ${eval.value}; - } - """ - } - } - ev.copy(code = - preprocess + - ctx.splitExpressions(ctx.INPUT_ROW, assigns) + - s"\nfinal ArrayData ${ev.value} = $arrayData;\n", + val (preprocess, postprocess, arrayData, array) = + GenArrayData.genCodeToCreateArrayData(ctx, et, children.size, isPrimitiveArray) + val assigns = GenArrayData.genCodeToAssignArrayElements( + ctx, evals, et, isPrimitiveArray, arrayData, array, true) + /* + TODO: When we generate simpler code, we have to solve the following exception + https://github.com/apache/spark/pull/13909/files#r93813725 + ev.copy( + code = preprocess + ctx.splitExpressions(ctx.INPUT_ROW, assigns) + postprocess + value = arrayData, + isNull = "false") + */ + ev.copy( + code = preprocess + ctx.splitExpressions(ctx.INPUT_ROW, assigns) + postprocess + + s"\nfinal ArrayData ${ev.value} = $arrayData;\n", isNull = "false") } @@ -99,36 +83,101 @@ case class CreateArray(children: Seq[Expression]) extends Expression { } private [sql] object GenArrayData { - // This function returns Java code pieces based on DataType and isPrimitive - // for allocation of ArrayData class - def getCodeArrayData( + /** + * Return Java code pieces based on DataType and isPrimitive to allocate ArrayData class + * + * @param ctx a [[CodegenContext]] + * @param elementType data type of an underlying array + * @param numElements the number of array elements + * @param isPrimitive Are all of the elements of an underlying array primitive type + * @return (code pre-assignments, code post-assignments, underlying array name, arrayData name) + */ + def genCodeToCreateArrayData( ctx: CodegenContext, - dt: DataType, - size: Int, - isPrimitive : Boolean, - array: String): (String, String) = { + elementType: DataType, + numElements: Int, + isPrimitive : Boolean): (String, String, String, String) = { + val arrayName = ctx.freshName("array") + val arrayDataName = ctx.freshName("arrayData") if (!isPrimitive) { val arrayClass = classOf[GenericArrayData].getName - ctx.addMutableState("Object[]", array, - s"this.$array = new Object[${size}];") - ("", s"new $arrayClass($array)") + ctx.addMutableState("Object[]", arrayName, + s"this.$arrayName = new Object[${numElements}];") + ("", + s"$arrayClass $arrayDataName = new $arrayClass($arrayName);", + arrayDataName, + arrayName) } else { - val baseArray = ctx.freshName("baseArray") val unsafeArrayClass = classOf[UnsafeArrayData].getName - ctx.addMutableState(unsafeArrayClass, array, "") + val baseObject = ctx.freshName("baseObject") val unsafeArraySizeInBytes = - UnsafeArrayData.calculateHeaderPortionInBytes(size) + - ByteArrayMethods.roundNumberOfBytesToNearestWord(dt.defaultSize * size) + UnsafeArrayData.calculateHeaderPortionInBytes(numElements) + + ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements) val baseOffset = Platform.BYTE_ARRAY_OFFSET (s""" - byte[] $baseArray = new byte[$unsafeArraySizeInBytes]; - $array = new $unsafeArrayClass(); - Platform.putLong($baseArray, $baseOffset, $size); - $array.pointTo($baseArray, $baseOffset, $unsafeArraySizeInBytes); - """, - array - ) + byte[] $arrayName = new byte[$unsafeArraySizeInBytes]; + $unsafeArrayClass $arrayDataName = new $unsafeArrayClass(); + Platform.putLong($arrayName, $baseOffset, $numElements); + $arrayDataName.pointTo($arrayName, $baseOffset, $unsafeArraySizeInBytes); + """, + "", + 
arrayDataName, + arrayName) + } + } + + /** + * Return Java code pieces to assign values to each element of an array + * + * @param ctx a [[CodegenContext]] + * @param evals a set of [[ExprCode]] for each element of an underlying array + * @param elementType data type of an underlying array + * @param isPrimitive Are all of the elements of an underlying array primitive type + * @param arrayDataName arrayData name + * @param arrayName underlying array name + * @param allowNull Is an assignment of null to an array element allowed + * @return a set of Strings for assignments to each element of an array + */ + def genCodeToAssignArrayElements( + ctx: CodegenContext, + evals: Seq[ExprCode], + elementType: DataType, + isPrimitive: Boolean, + arrayDataName: String, + arrayName: String, + allowNull: Boolean): Seq[String] = { + if (isPrimitive) { + val primitiveValueTypeName = ctx.primitiveTypeName(elementType) + evals.zipWithIndex.map { case (eval, i) => + val isNullAssignment = if (allowNull) { + s"$arrayDataName.setNullAt($i);" + } else { + "throw new RuntimeException(\"Cannot use null as map key!\");" + } + eval.code + s""" + if (${eval.isNull}) { + $isNullAssignment + } else { + $arrayDataName.set$primitiveValueTypeName($i, ${eval.value}); + } + """ + } + } else { + evals.zipWithIndex.map { case (eval, i) => + val isNullAssignment = if (allowNull) { + s"$arrayName[$i] = null;" + } else { + "throw new RuntimeException(\"Cannot use null as map key!\");" + } + eval.code + s""" + if (${eval.isNull}) { + $isNullAssignment + } else { + $arrayName[$i] = ${eval.value}; + } + """ + } } } } @@ -184,71 +233,31 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val mapClass = classOf[ArrayBasedMapData].getName - val keyArray = ctx.freshName("keyArray") - val valueArray = ctx.freshName("valueArray") - val MapType(keyDt, valueDt, _) = dataType val evalKeys = keys.map(e => e.genCode(ctx)) - val isPrimitiveArrayKey = ctx.isPrimitiveType(keyDt) + val isPrimitiveKey = ctx.isPrimitiveType(keyDt) val evalValues = values.map(e => e.genCode(ctx)) - val isPrimitiveArrayValue = ctx.isPrimitiveType(valueDt) - val (preprocessKeyData, keyDataArray) = - GenArrayData.getCodeArrayData(ctx, keyDt, keys.size, isPrimitiveArrayKey, keyArray) - val (preprocessValueData, valueDataArray) = - GenArrayData.getCodeArrayData(ctx, valueDt, values.size, isPrimitiveArrayValue, valueArray) - - val assignKeys = if (isPrimitiveArrayKey) { - val primitiveKeyTypeName = ctx.primitiveTypeName(keyDt) - evalKeys.zipWithIndex.map { case (eval, i) => - eval.code + s""" - if (${eval.isNull}) { - $keyDataArray.setNullAt($i); - } else { - $keyDataArray.set$primitiveKeyTypeName($i, ${eval.value}); - } - """ - } - } else { - evalKeys.zipWithIndex.map { case (eval, i) => - eval.code + s""" - if (${eval.isNull}) { - throw new RuntimeException("Cannot use null as map key!"); - } else { - $keyArray[$i] = ${eval.value}; - } - """ - } - } - - val assignValues = if (isPrimitiveArrayValue) { - val primitiveValueTypeName = ctx.primitiveTypeName(valueDt) - evalValues.zipWithIndex.map { case (eval, i) => - eval.code + s""" - if (${eval.isNull}) { - $valueDataArray.setNullAt($i); - } else { - $valueDataArray.set$primitiveValueTypeName($i, ${eval.value}); - } - """ - } - } else { - evalValues.zipWithIndex.map { case (eval, i) => - eval.code + s""" - if (${eval.isNull}) { - $valueArray[$i] = null; - } else { - $valueArray[$i] = ${eval.value}; - } - """ - } - } - - 
ev.copy(code = s"final boolean ${ev.isNull} = false;" + - preprocessKeyData + - ctx.splitExpressions(ctx.INPUT_ROW, assignKeys) + - preprocessValueData + - ctx.splitExpressions(ctx.INPUT_ROW, assignValues) + - s"final MapData ${ev.value} = new $mapClass($keyDataArray, $valueDataArray);") + val isPrimitiveValue = ctx.isPrimitiveType(valueDt) + val (preprocessKeyData, postprocessKeyData, keyArrayData, keyArray) = + GenArrayData.genCodeToCreateArrayData(ctx, keyDt, keys.size, isPrimitiveKey) + val (preprocessValueData, postprocessValueData, valueArrayData, valueArray) = + GenArrayData.genCodeToCreateArrayData(ctx, valueDt, values.size, isPrimitiveValue) + val assignKeys = GenArrayData.genCodeToAssignArrayElements( + ctx, evalKeys, keyDt, isPrimitiveKey, keyArrayData, keyArray, false) + val assignValues = GenArrayData.genCodeToAssignArrayElements( + ctx, evalValues, valueDt, isPrimitiveValue, valueArrayData, valueArray, true) + val code = + s""" + final boolean ${ev.isNull} = false; + $preprocessKeyData + ${ctx.splitExpressions(ctx.INPUT_ROW, assignKeys)} + $postprocessKeyData + $preprocessValueData + ${ctx.splitExpressions(ctx.INPUT_ROW, assignValues)} + $postprocessValueData + final MapData ${ev.value} = new $mapClass($keyArrayData, $valueArrayData); + """ + ev.copy(code = code) } override def prettyName: String = "map" From 7a7e9c3c0fcdd76de4536bf8ff14e2243da06c93 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 24 Dec 2016 18:02:50 +0900 Subject: [PATCH 28/39] addressed review comment --- .../spark/sql/catalyst/expressions/complexTypeCreator.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 74bb7e177bc6f..88e1e14b85593 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -69,7 +69,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { TODO: When we generate simpler code, we have to solve the following exception https://github.com/apache/spark/pull/13909/files#r93813725 ev.copy( - code = preprocess + ctx.splitExpressions(ctx.INPUT_ROW, assigns) + postprocess + code = preprocess + ctx.splitExpressions(ctx.INPUT_ROW, assigns) + postprocess, value = arrayData, isNull = "false") */ @@ -90,7 +90,7 @@ private [sql] object GenArrayData { * @param elementType data type of an underlying array * @param numElements the number of array elements * @param isPrimitive Are all of the elements of an underlying array primitive type - * @return (code pre-assignments, code post-assignments, underlying array name, arrayData name) + * @return (code pre-assignments, code post-assignments, arrayData name, underlying array name) */ def genCodeToCreateArrayData( ctx: CodegenContext, From ee237b405d9065fe571ff401657da560afb88cbb Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 24 Dec 2016 23:10:37 +0900 Subject: [PATCH 29/39] addressed review comments --- .../expressions/complexTypeCreator.scala | 139 +++++++----------- 1 file changed, 54 insertions(+), 85 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 
88e1e14b85593..a9eefd04199c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -19,10 +19,9 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.analysis.Star import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods @@ -60,22 +59,11 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val et = dataType.elementType val evals = children.map(e => e.genCode(ctx)) - val isPrimitiveArray = ctx.isPrimitiveType(et) - val (preprocess, postprocess, arrayData, array) = - GenArrayData.genCodeToCreateArrayData(ctx, et, children.size, isPrimitiveArray) - val assigns = GenArrayData.genCodeToAssignArrayElements( - ctx, evals, et, isPrimitiveArray, arrayData, array, true) - /* - TODO: When we generate simpler code, we have to solve the following exception - https://github.com/apache/spark/pull/13909/files#r93813725 - ev.copy( - code = preprocess + ctx.splitExpressions(ctx.INPUT_ROW, assigns) + postprocess, - value = arrayData, - isNull = "false") - */ + val (preprocess, assigns, postprocess, arrayData, array) = + GenArrayData.genCodeToCreateArrayData(ctx, et, evals, true) ev.copy( - code = preprocess + ctx.splitExpressions(ctx.INPUT_ROW, assigns) + postprocess + - s"\nfinal ArrayData ${ev.value} = $arrayData;\n", + code = preprocess + ctx.splitExpressions(ctx.INPUT_ROW, assigns) + postprocess, + value = arrayData, isNull = "false") } @@ -88,70 +76,28 @@ private [sql] object GenArrayData { * * @param ctx a [[CodegenContext]] * @param elementType data type of an underlying array - * @param numElements the number of array elements - * @param isPrimitive Are all of the elements of an underlying array primitive type - * @return (code pre-assignments, code post-assignments, arrayData name, underlying array name) + * @param elementsCode a set of [[ExprCode]] for each element of an underlying array + * @return (code pre-assignments, assignments to each array elements, code post-assignments, + * arrayData name, underlying array name) */ def genCodeToCreateArrayData( ctx: CodegenContext, elementType: DataType, - numElements: Int, - isPrimitive : Boolean): (String, String, String, String) = { + elementsCode: Seq[ExprCode], + allowNull: Boolean): (String, Seq[String], String, String, String) = { val arrayName = ctx.freshName("array") val arrayDataName = ctx.freshName("arrayData") - if (!isPrimitive) { - val arrayClass = classOf[GenericArrayData].getName + val numElements = elementsCode.length + + if (!ctx.isPrimitiveType(elementType)) { + val arrayClass = classOf[ArrayData].getName + val genericArrayClass = classOf[GenericArrayData].getName ctx.addMutableState("Object[]", arrayName, s"this.$arrayName = new Object[${numElements}];") - ("", - s"$arrayClass $arrayDataName = new $arrayClass($arrayName);", - arrayDataName, - 
arrayName) - } else { - val unsafeArrayClass = classOf[UnsafeArrayData].getName - val baseObject = ctx.freshName("baseObject") - val unsafeArraySizeInBytes = - UnsafeArrayData.calculateHeaderPortionInBytes(numElements) + - ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements) - val baseOffset = Platform.BYTE_ARRAY_OFFSET - (s""" - byte[] $arrayName = new byte[$unsafeArraySizeInBytes]; - $unsafeArrayClass $arrayDataName = new $unsafeArrayClass(); - Platform.putLong($arrayName, $baseOffset, $numElements); - $arrayDataName.pointTo($arrayName, $baseOffset, $unsafeArraySizeInBytes); - """, - "", - arrayDataName, - arrayName) - } - } - - /** - * Return Java code pieces to assign values to each element of an array - * - * @param ctx a [[CodegenContext]] - * @param evals a set of [[ExprCode]] for each element of an underlying array - * @param elementType data type of an underlying array - * @param isPrimitive Are all of the elements of an underlying array primitive type - * @param arrayDataName arrayData name - * @param arrayName underlying array name - * @param allowNull Is an assignment of null to an array element allowed - * @return a set of Strings for assignments to each element of an array - */ - def genCodeToAssignArrayElements( - ctx: CodegenContext, - evals: Seq[ExprCode], - elementType: DataType, - isPrimitive: Boolean, - arrayDataName: String, - arrayName: String, - allowNull: Boolean): Seq[String] = { - if (isPrimitive) { - val primitiveValueTypeName = ctx.primitiveTypeName(elementType) - evals.zipWithIndex.map { case (eval, i) => + val assignments = elementsCode.zipWithIndex.map { case (eval, i) => val isNullAssignment = if (allowNull) { - s"$arrayDataName.setNullAt($i);" + s"$arrayName[$i] = null;" } else { "throw new RuntimeException(\"Cannot use null as map key!\");" } @@ -159,14 +105,32 @@ private [sql] object GenArrayData { if (${eval.isNull}) { $isNullAssignment } else { - $arrayDataName.set$primitiveValueTypeName($i, ${eval.value}); + $arrayName[$i] = ${eval.value}; } """ } + + /* + TODO: When we declare arrayDataName as GenericArrayData, + we have to solve the following exception + https://github.com/apache/spark/pull/13909/files#r93813725 + */ + ("", + assignments, + s"final $arrayClass $arrayDataName = new $genericArrayClass($arrayName);", + arrayDataName, + arrayName) } else { - evals.zipWithIndex.map { case (eval, i) => + val unsafeArrayClass = classOf[UnsafeArrayData].getName + val unsafeArraySizeInBytes = + UnsafeArrayData.calculateHeaderPortionInBytes(numElements) + + ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements) + val baseOffset = Platform.BYTE_ARRAY_OFFSET + + val primitiveValueTypeName = ctx.primitiveTypeName(elementType) + val assignments = elementsCode.zipWithIndex.map { case (eval, i) => val isNullAssignment = if (allowNull) { - s"$arrayName[$i] = null;" + s"$arrayDataName.setNullAt($i);" } else { "throw new RuntimeException(\"Cannot use null as map key!\");" } @@ -174,10 +138,21 @@ private [sql] object GenArrayData { if (${eval.isNull}) { $isNullAssignment } else { - $arrayName[$i] = ${eval.value}; + $arrayDataName.set$primitiveValueTypeName($i, ${eval.value}); } """ } + + (s""" + byte[] $arrayName = new byte[$unsafeArraySizeInBytes]; + final $unsafeArrayClass $arrayDataName = new $unsafeArrayClass(); + Platform.putLong($arrayName, $baseOffset, $numElements); + $arrayDataName.pointTo($arrayName, $baseOffset, $unsafeArraySizeInBytes); + """, + assignments, + "", + arrayDataName, + 
arrayName) } } } @@ -235,17 +210,11 @@ case class CreateMap(children: Seq[Expression]) extends Expression { val mapClass = classOf[ArrayBasedMapData].getName val MapType(keyDt, valueDt, _) = dataType val evalKeys = keys.map(e => e.genCode(ctx)) - val isPrimitiveKey = ctx.isPrimitiveType(keyDt) val evalValues = values.map(e => e.genCode(ctx)) - val isPrimitiveValue = ctx.isPrimitiveType(valueDt) - val (preprocessKeyData, postprocessKeyData, keyArrayData, keyArray) = - GenArrayData.genCodeToCreateArrayData(ctx, keyDt, keys.size, isPrimitiveKey) - val (preprocessValueData, postprocessValueData, valueArrayData, valueArray) = - GenArrayData.genCodeToCreateArrayData(ctx, valueDt, values.size, isPrimitiveValue) - val assignKeys = GenArrayData.genCodeToAssignArrayElements( - ctx, evalKeys, keyDt, isPrimitiveKey, keyArrayData, keyArray, false) - val assignValues = GenArrayData.genCodeToAssignArrayElements( - ctx, evalValues, valueDt, isPrimitiveValue, valueArrayData, valueArray, true) + val (preprocessKeyData, assignKeys, postprocessKeyData, keyArrayData, keyArray) = + GenArrayData.genCodeToCreateArrayData(ctx, keyDt, evalKeys, false) + val (preprocessValueData, assignValues, postprocessValueData, valueArrayData, valueArray) = + GenArrayData.genCodeToCreateArrayData(ctx, valueDt, evalValues, true) val code = s""" final boolean ${ev.isNull} = false; From 28df09fb569149470dcdb45e36e2d8b993c99a17 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 24 Dec 2016 23:18:27 +0900 Subject: [PATCH 30/39] fixed test failure --- .../spark/sql/catalyst/expressions/complexTypeCreator.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index a9eefd04199c6..893af13f5dbff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -126,6 +126,7 @@ private [sql] object GenArrayData { UnsafeArrayData.calculateHeaderPortionInBytes(numElements) + ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements) val baseOffset = Platform.BYTE_ARRAY_OFFSET + ctx.addMutableState(unsafeArrayClass, arrayDataName, ""); val primitiveValueTypeName = ctx.primitiveTypeName(elementType) val assignments = elementsCode.zipWithIndex.map { case (eval, i) => @@ -145,7 +146,7 @@ private [sql] object GenArrayData { (s""" byte[] $arrayName = new byte[$unsafeArraySizeInBytes]; - final $unsafeArrayClass $arrayDataName = new $unsafeArrayClass(); + $arrayDataName = new $unsafeArrayClass(); Platform.putLong($arrayName, $baseOffset, $numElements); $arrayDataName.pointTo($arrayName, $baseOffset, $unsafeArraySizeInBytes); """, From 293b344e761bc4b9c04891c02c702a374472345a Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 26 Dec 2016 13:38:53 +0900 Subject: [PATCH 31/39] addressed review comments --- .../spark/sql/catalyst/expressions/complexTypeCreator.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 893af13f5dbff..471538c2a11b2 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -59,7 +59,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val et = dataType.elementType val evals = children.map(e => e.genCode(ctx)) - val (preprocess, assigns, postprocess, arrayData, array) = + val (preprocess, assigns, postprocess, arrayData, _) = GenArrayData.genCodeToCreateArrayData(ctx, et, evals, true) ev.copy( code = preprocess + ctx.splitExpressions(ctx.INPUT_ROW, assigns) + postprocess, @@ -75,8 +75,9 @@ private [sql] object GenArrayData { * Return Java code pieces based on DataType and isPrimitive to allocate ArrayData class * * @param ctx a [[CodegenContext]] - * @param elementType data type of an underlying array + * @param elementType data type of underlying array elements * @param elementsCode a set of [[ExprCode]] for each element of an underlying array + * @param allowNull if to assign null value to an array element is allowed * @return (code pre-assignments, assignments to each array elements, code post-assignments, * arrayData name, underlying array name) */ From 69d5e33d2035fc5f6f4dfec65bde60c7dfc39548 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 26 Dec 2016 16:21:19 +0900 Subject: [PATCH 32/39] addressed review comments --- .../expressions/complexTypeCreator.scala | 24 ++++++++----------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 471538c2a11b2..e37530218ec13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, TypeUtils} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods @@ -59,7 +59,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val et = dataType.elementType val evals = children.map(e => e.genCode(ctx)) - val (preprocess, assigns, postprocess, arrayData, _) = + val (preprocess, assigns, postprocess, arrayData) = GenArrayData.genCodeToCreateArrayData(ctx, et, evals, true) ev.copy( code = preprocess + ctx.splitExpressions(ctx.INPUT_ROW, assigns) + postprocess, @@ -85,13 +85,12 @@ private [sql] object GenArrayData { ctx: CodegenContext, elementType: DataType, elementsCode: Seq[ExprCode], - allowNull: Boolean): (String, Seq[String], String, String, String) = { + allowNull: Boolean): (String, Seq[String], String, String) = { val arrayName = ctx.freshName("array") val arrayDataName = ctx.freshName("arrayData") val numElements = elementsCode.length if 
(!ctx.isPrimitiveType(elementType)) { - val arrayClass = classOf[ArrayData].getName val genericArrayClass = classOf[GenericArrayData].getName ctx.addMutableState("Object[]", arrayName, s"this.$arrayName = new Object[${numElements}];") @@ -118,16 +117,14 @@ private [sql] object GenArrayData { */ ("", assignments, - s"final $arrayClass $arrayDataName = new $genericArrayClass($arrayName);", - arrayDataName, - arrayName) + s"final ArrayClass $arrayDataName = new $genericArrayClass($arrayName);", + arrayDataName) } else { - val unsafeArrayClass = classOf[UnsafeArrayData].getName val unsafeArraySizeInBytes = UnsafeArrayData.calculateHeaderPortionInBytes(numElements) + ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements) val baseOffset = Platform.BYTE_ARRAY_OFFSET - ctx.addMutableState(unsafeArrayClass, arrayDataName, ""); + ctx.addMutableState("UnsafeArrayData", arrayDataName, ""); val primitiveValueTypeName = ctx.primitiveTypeName(elementType) val assignments = elementsCode.zipWithIndex.map { case (eval, i) => @@ -147,14 +144,13 @@ private [sql] object GenArrayData { (s""" byte[] $arrayName = new byte[$unsafeArraySizeInBytes]; - $arrayDataName = new $unsafeArrayClass(); + $arrayDataName = new UnsafeArrayData(); Platform.putLong($arrayName, $baseOffset, $numElements); $arrayDataName.pointTo($arrayName, $baseOffset, $unsafeArraySizeInBytes); """, assignments, "", - arrayDataName, - arrayName) + arrayDataName) } } } @@ -213,9 +209,9 @@ case class CreateMap(children: Seq[Expression]) extends Expression { val MapType(keyDt, valueDt, _) = dataType val evalKeys = keys.map(e => e.genCode(ctx)) val evalValues = values.map(e => e.genCode(ctx)) - val (preprocessKeyData, assignKeys, postprocessKeyData, keyArrayData, keyArray) = + val (preprocessKeyData, assignKeys, postprocessKeyData, keyArrayData) = GenArrayData.genCodeToCreateArrayData(ctx, keyDt, evalKeys, false) - val (preprocessValueData, assignValues, postprocessValueData, valueArrayData, valueArray) = + val (preprocessValueData, assignValues, postprocessValueData, valueArrayData) = GenArrayData.genCodeToCreateArrayData(ctx, valueDt, evalValues, true) val code = s""" From 2556ba5482c516e496676cc3a649a3c5f7f75d41 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 26 Dec 2016 23:13:13 +0900 Subject: [PATCH 33/39] fix a test failure --- .../spark/sql/catalyst/expressions/complexTypeCreator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index e37530218ec13..9e9de8fb126bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -117,7 +117,7 @@ private [sql] object GenArrayData { */ ("", assignments, - s"final ArrayClass $arrayDataName = new $genericArrayClass($arrayName);", + s"final ArrayData $arrayDataName = new $genericArrayClass($arrayName);", arrayDataName) } else { val unsafeArraySizeInBytes = From 34bff15d09e1200660141f104c225356590bc860 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 26 Dec 2016 23:38:18 +0900 Subject: [PATCH 34/39] fix a test failure --- .../spark/sql/catalyst/expressions/complexTypeCreator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 9e9de8fb126bd..b6dac12e6cade 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -79,7 +79,7 @@ private [sql] object GenArrayData { * @param elementsCode a set of [[ExprCode]] for each element of an underlying array * @param allowNull if to assign null value to an array element is allowed * @return (code pre-assignments, assignments to each array elements, code post-assignments, - * arrayData name, underlying array name) + * arrayData name) */ def genCodeToCreateArrayData( ctx: CodegenContext, From 4a0409a42e784cbf132cabeb81786c3d13963599 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 27 Dec 2016 00:05:18 +0900 Subject: [PATCH 35/39] addressed review comments --- .../spark/sql/catalyst/expressions/complexTypeCreator.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index b6dac12e6cade..0f1a89a00c695 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -99,7 +99,7 @@ private [sql] object GenArrayData { val isNullAssignment = if (allowNull) { s"$arrayName[$i] = null;" } else { - "throw new RuntimeException(\"Cannot use null as map key!\");" + "throw new RuntimeException(\"Cannot use null!\");" } eval.code + s""" if (${eval.isNull}) { @@ -131,7 +131,7 @@ private [sql] object GenArrayData { val isNullAssignment = if (allowNull) { s"$arrayDataName.setNullAt($i);" } else { - "throw new RuntimeException(\"Cannot use null as map key!\");" + "throw new RuntimeException(\"Cannot use null!\");" } eval.code + s""" if (${eval.isNull}) { From dcce4c57be03df1b263b77600876398962490748 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 27 Dec 2016 09:41:08 +0900 Subject: [PATCH 36/39] addressed review comment --- .../spark/sql/catalyst/expressions/complexTypeCreator.scala | 5 ----- 1 file changed, 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 0f1a89a00c695..fa352771c3c03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -110,11 +110,6 @@ private [sql] object GenArrayData { """ } - /* - TODO: When we declare arrayDataName as GenericArrayData, - we have to solve the following exception - https://github.com/apache/spark/pull/13909/files#r93813725 - */ ("", assignments, s"final ArrayData $arrayDataName = new $genericArrayClass($arrayName);", From 2f67ac2805fc59ef872be8a923ab6a877812309d Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 27 Dec 2016 23:10:13 +0900 Subject: [PATCH 37/39] address a review comment --- .../catalyst/expressions/complexTypeCreator.scala | 14 +++++++------- 1 file changed, 7 insertions(+), 7 
deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index fa352771c3c03..474627fd1c8e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -60,7 +60,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { val et = dataType.elementType val evals = children.map(e => e.genCode(ctx)) val (preprocess, assigns, postprocess, arrayData) = - GenArrayData.genCodeToCreateArrayData(ctx, et, evals, true) + GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false) ev.copy( code = preprocess + ctx.splitExpressions(ctx.INPUT_ROW, assigns) + postprocess, value = arrayData, @@ -77,7 +77,7 @@ private [sql] object GenArrayData { * @param ctx a [[CodegenContext]] * @param elementType data type of underlying array elements * @param elementsCode a set of [[ExprCode]] for each element of an underlying array - * @param allowNull if to assign null value to an array element is allowed + * @param isMapKey if throw an exception when to assign a null value to an array element * @return (code pre-assignments, assignments to each array elements, code post-assignments, * arrayData name) */ @@ -85,7 +85,7 @@ private [sql] object GenArrayData { ctx: CodegenContext, elementType: DataType, elementsCode: Seq[ExprCode], - allowNull: Boolean): (String, Seq[String], String, String) = { + isMapKey: Boolean): (String, Seq[String], String, String) = { val arrayName = ctx.freshName("array") val arrayDataName = ctx.freshName("arrayData") val numElements = elementsCode.length @@ -96,7 +96,7 @@ private [sql] object GenArrayData { s"this.$arrayName = new Object[${numElements}];") val assignments = elementsCode.zipWithIndex.map { case (eval, i) => - val isNullAssignment = if (allowNull) { + val isNullAssignment = if (!isMapKey) { s"$arrayName[$i] = null;" } else { "throw new RuntimeException(\"Cannot use null!\");" @@ -123,7 +123,7 @@ private [sql] object GenArrayData { val primitiveValueTypeName = ctx.primitiveTypeName(elementType) val assignments = elementsCode.zipWithIndex.map { case (eval, i) => - val isNullAssignment = if (allowNull) { + val isNullAssignment = if (!isMapKey) { s"$arrayDataName.setNullAt($i);" } else { "throw new RuntimeException(\"Cannot use null!\");" @@ -205,9 +205,9 @@ case class CreateMap(children: Seq[Expression]) extends Expression { val evalKeys = keys.map(e => e.genCode(ctx)) val evalValues = values.map(e => e.genCode(ctx)) val (preprocessKeyData, assignKeys, postprocessKeyData, keyArrayData) = - GenArrayData.genCodeToCreateArrayData(ctx, keyDt, evalKeys, false) + GenArrayData.genCodeToCreateArrayData(ctx, keyDt, evalKeys, true) val (preprocessValueData, assignValues, postprocessValueData, valueArrayData) = - GenArrayData.genCodeToCreateArrayData(ctx, valueDt, evalValues, true) + GenArrayData.genCodeToCreateArrayData(ctx, valueDt, evalValues, false) val code = s""" final boolean ${ev.isNull} = false; From c986361da6a5edaf0d3aef24340c2d3b2d1fdd42 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 27 Dec 2016 23:32:34 +0900 Subject: [PATCH 38/39] address a review comment --- .../spark/sql/catalyst/expressions/complexTypeCreator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 474627fd1c8e9..a86b5e8d7f754 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -77,7 +77,7 @@ private [sql] object GenArrayData {
    * @param ctx a [[CodegenContext]]
    * @param elementType data type of underlying array elements
    * @param elementsCode a set of [[ExprCode]] for each element of an underlying array
-   * @param isMapKey if throw an exception when to assign a null value to an array element
+   * @param isMapKey if true, throw an exception when the element is null
    * @return (code pre-assignments, assignments to each array elements, code post-assignments,
    *          arrayData name)
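The isMapKey flag documented above is what makes map keys and map values behave differently: a null value is stored, a null key is rejected. A short sketch of the visible behavior (assuming a SparkSession named spark; the exception message is the one restored by the final patch below):

    // A null map value is fine:
    spark.sql("SELECT map(1, CAST(NULL AS STRING), 2, 'b')").collect()
    // A null map key is rejected at runtime by the generated guard:
    spark.sql("SELECT map(CAST(NULL AS INT), 'a')").collect()
    // => java.lang.RuntimeException: Cannot use null as map key!

From cfe2e3d9defd7c04e921e886dc129725ce06fc67 Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Wed, 28 Dec 2016 16:00:30 +0900
Subject: [PATCH 39/39] revert a change of an exception message

---
 .../spark/sql/catalyst/expressions/complexTypeCreator.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index a86b5e8d7f754..22277ad8d56ee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -99,7 +99,7 @@ private [sql] object GenArrayData {
       val isNullAssignment = if (!isMapKey) {
         s"$arrayName[$i] = null;"
       } else {
-        "throw new RuntimeException(\"Cannot use null!\");"
+        "throw new RuntimeException(\"Cannot use null as map key!\");"
       }
       eval.code + s"""
       if (${eval.isNull}) {
@@ -126,7 +126,7 @@ private [sql] object GenArrayData {
       val isNullAssignment = if (!isMapKey) {
         s"$arrayDataName.setNullAt($i);"
       } else {
-        "throw new RuntimeException(\"Cannot use null!\");"
+        "throw new RuntimeException(\"Cannot use null as map key!\");"
       }
       eval.code + s"""
       if (${eval.isNull}) {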