diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index afea4676893ed..3af08617a6784 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen; +import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; @@ -259,4 +260,158 @@ public void write(int ordinal, CalendarInterval input) { // move the cursor forward. holder.cursor += 16; } + + private void writePrimitiveArray(Object input, int offset, int elementSize, int length) { + Platform.copyMemory(input, offset, holder.buffer, startingOffset + headerInBytes, elementSize * length); + } + + public void writePrimitiveBooleanArray(ArrayData arrayData) { + boolean[] input = arrayData.toBooleanArray(); + int length = input.length; + int offset = Platform.BYTE_ARRAY_OFFSET; + writePrimitiveArray(input, offset, 1, length); + } + + public void writePrimitiveByteArray(ArrayData arrayData) { + byte[] input = arrayData.toByteArray(); + int length = input.length; + int offset = Platform.BYTE_ARRAY_OFFSET; + writePrimitiveArray(input, offset, 1, length); + } + + public void writePrimitiveShortArray(ArrayData arrayData) { + short[] input = arrayData.toShortArray(); + int length = input.length; + int offset = Platform.SHORT_ARRAY_OFFSET; + writePrimitiveArray(input, offset, 2, length); + } + + public void writePrimitiveIntArray(ArrayData arrayData) { + int[] input = arrayData.toIntArray(); + int length = input.length; + int offset = Platform.INT_ARRAY_OFFSET; + writePrimitiveArray(input, offset, 4, length); + } + + public void writePrimitiveLongArray(ArrayData arrayData) { + long[] input = arrayData.toLongArray(); + int length = input.length; + int offset = Platform.LONG_ARRAY_OFFSET; + writePrimitiveArray(input, offset, 8, length); + } + + public void writePrimitiveFloatArray(ArrayData arrayData) { + float[] input = arrayData.toFloatArray(); + int length = input.length; + int offset = Platform.FLOAT_ARRAY_OFFSET; + writePrimitiveArray(input, offset, 4, length); + } + + public void writePrimitiveDoubleArray(ArrayData arrayData) { + double[] input = arrayData.toDoubleArray(); + int length = input.length; + int offset = Platform.DOUBLE_ARRAY_OFFSET; + writePrimitiveArray(input, offset, 8, length); + } + +/** uncomment this if SPARK-16043 is merged + + public void writePrimitiveBooleanArray(ArrayData arrayData) { + if (arrayData instanceof GenericBooleanArrayData) { + boolean[] input = ((GenericBooleanArrayData)arrayData).primitiveArray(); + int length = input.length; + Platform.copyMemory(input, Platform.BOOLEAN_ARRAY_OFFSET, + holder.buffer, startingOffset + headerInBytes, length); + } else { + int length = arrayData.numElements(); + for (int i = 0; i < length; i++) { + Platform.putBoolean(holder.buffer, holder.cursor + i, arrayData.getBoolean(i)); + } + } + } + + public void writePrimitiveByteArray(ArrayData arrayData) { + if (arrayData instanceof GenericByteArrayData) { + byte[] input = ((GenericByteArrayData)arrayData).primitiveArray(); + int length = input.length; + Platform.copyMemory(input, Platform.BYTE_ARRAY_OFFSET, + holder.buffer, startingOffset + headerInBytes, length); + } else { + int length = arrayData.numElements(); + for (int i = 0; i < length; i++) { + Platform.putByte(holder.buffer, holder.cursor + i, arrayData.getByte(i)); + } + } + } + + public void writePrimitiveShortArray(ArrayData arrayData) { + if (arrayData instanceof GenericShortArrayData) { + short[] input = ((GenericShortArrayData)arrayData).primitiveArray(); + int length = input.length; + Platform.copyMemory(input, Platform.SHORT_ARRAY_OFFSET, + holder.buffer, startingOffset + headerInBytes, length); + } else { + int length = arrayData.numElements(); + for (int i = 0; i < length; i++) { + Platform.putShort(holder.buffer, holder.cursor + i, arrayData.getShort(i)); + } + } + } + + public void writePrimitiveIntArray(ArrayData arrayData) { + if (arrayData instanceof GenericIntArrayData) { + int[] input = ((GenericIntArrayData)arrayData).primitiveArray(); + int length = input.length; + Platform.copyMemory(input, Platform.INT_ARRAY_OFFSET, + holder.buffer, startingOffset + headerInBytes, length); + } else { + int length = arrayData.numElements(); + for (int i = 0; i < length; i++) { + Platform.putInt(holder.buffer, holder.cursor + i, arrayData.getInt(i)); + } + } + } + + public void writePrimitiveLongArray(ArrayData arrayData) { + if (arrayData instanceof GenericLongArrayData) { + long[] input = ((GenericLongArrayData)arrayData).primitiveArray(); + int length = input.length; + Platform.copyMemory(input, Platform.LONG_ARRAY_OFFSET, + holder.buffer, startingOffset + headerInBytes, length); + } else { + int length = arrayData.numElements(); + for (int i = 0; i < length; i++) { + Platform.putLong(holder.buffer, holder.cursor + i, arrayData.getLong(i)); + } + } + } + + public void writePrimitiveFloatArray(ArrayData arrayData) { + if (arrayData instanceof GenericFloatArrayData) { + float[] input = ((GenericFloatArrayData)arrayData).primitiveArray(); + int length = input.length; + Platform.copyMemory(input, Platform.FLOAT_ARRAY_OFFSET, + holder.buffer, startingOffset + headerInBytes, length); + } else { + int length = arrayData.numElements(); + for (int i = 0; i < length; i++) { + Platform.putFloat(holder.buffer, holder.cursor + i, arrayData.getFloat(i)); + } + } + } + + public void writePrimitiveDoubleArray(ArrayData arrayData) { + if (arrayData instanceof GenericDoubleArrayData) { + double[] input = ((GenericDoubleArrayData)arrayData).primitiveArray(); + int length = input.length; + Platform.copyMemory(input, Platform.DOUBLE_ARRAY_OFFSET, + holder.buffer, startingOffset + headerInBytes, length); + } else { + int length = arrayData.numElements(); + for (int i = 0; i < length; i++) { + Platform.putFloat(holder.buffer, holder.cursor + i, arrayData.getFloat(i)); + } + } + } +*/ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 7e4c9089a2cb9..aa0a052f45d8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -117,12 +117,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); """ - case a @ ArrayType(et, _) => + case a @ ArrayType(et, cn) => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. final int $tmpCursor = $bufferHolder.cursor; - ${writeArrayToBuffer(ctx, input.value, et, bufferHolder)} + ${writeArrayToBuffer(ctx, input.value, et, cn, bufferHolder)} $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); """ @@ -171,6 +171,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodegenContext, input: String, elementType: DataType, + containsNull: Boolean, bufferHolder: String): String = { val arrayWriterClass = classOf[UnsafeArrayWriter].getName val arrayWriter = ctx.freshName("arrayWriter") @@ -202,10 +203,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); """ - case a @ ArrayType(et, _) => + case a @ ArrayType(et, cn) => s""" final int $tmpCursor = $bufferHolder.cursor; - ${writeArrayToBuffer(ctx, element, et, bufferHolder)} + ${writeArrayToBuffer(ctx, element, et, cn, bufferHolder)} $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); """ @@ -224,7 +225,31 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => s"$arrayWriter.write($index, $element);" } - val primitiveTypeName = if (ctx.isPrimitiveType(jt)) ctx.primitiveTypeName(et) else "" + val typeName = if (ctx.isPrimitiveType(jt)) ctx.primitiveTypeName(et) else "" + val storeElements = if (containsNull) { + s""" + for (int $index = 0; $index < $numElements; $index++) { + if ($input.isNullAt($index)) { + $arrayWriter.setNull${typeName}($index); + } else { + final $jt $element = ${ctx.getValue(input, et, index)}; + $writeElement + } + } + """ + } else { + if (ctx.isPrimitiveType(jt)) { + s"$arrayWriter.writePrimitive${typeName}Array($input);" + } else { + s""" + for (int $index = 0; $index < $numElements; $index++) { + final $jt $element = ${ctx.getValue(input, et, index)}; + $writeElement + } + """ + } + } + s""" if ($input instanceof UnsafeArrayData) { ${writeUnsafeData(ctx, s"((UnsafeArrayData) $input)", bufferHolder)} @@ -232,14 +257,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro final int $numElements = $input.numElements(); $arrayWriter.initialize($bufferHolder, $numElements, $elementOrOffsetSize); - for (int $index = 0; $index < $numElements; $index++) { - if ($input.isNullAt($index)) { - $arrayWriter.setNull$primitiveTypeName($index); - } else { - final $jt $element = ${ctx.getValue(input, et, index)}; - $writeElement - } - } + $storeElements } """ } @@ -271,11 +289,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Remember the current cursor so that we can write numBytes of key array later. final int $tmpCursor = $bufferHolder.cursor; - ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder)} + ${writeArrayToBuffer(ctx, keys, keyType, false, bufferHolder)} // Write the numBytes of key array into the first 8 bytes. Platform.putLong($bufferHolder.buffer, $tmpCursor - 8, $bufferHolder.cursor - $tmpCursor); - ${writeArrayToBuffer(ctx, values, valueType, bufferHolder)} + ${writeArrayToBuffer(ctx, values, valueType, true, bufferHolder)} } """ } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala index c7c386b5b838a..e4a968e2eec8d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeRow} +import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.unsafe.Platform class BufferHolderSuite extends SparkFunSuite { @@ -36,4 +38,190 @@ class BufferHolderSuite extends SparkFunSuite { } assert(e.getMessage.contains("exceeds size limitation")) } + + def performUnsafeArrayWriter(length: Int, elementSize: Int, f: (UnsafeArrayWriter) => Unit): + UnsafeArrayData = { + val unsafeRow = new UnsafeRow(1) + val unsafeArrayWriter = new UnsafeArrayWriter + val bufferHolder = new BufferHolder(unsafeRow, 32) + bufferHolder.reset() + val cursor = bufferHolder.cursor + unsafeArrayWriter.initialize(bufferHolder, length, elementSize) + // execute UnsafeArrayWriter.foo() in f() + f(unsafeArrayWriter) + + val unsafeArray = new UnsafeArrayData + unsafeArray.pointTo(bufferHolder.buffer, cursor.toLong, bufferHolder.cursor - cursor) + assert(unsafeArray.numElements() == length) + unsafeArray + } + + def initializeUnsafeArrayData(data: Seq[Any], elementSize: Int): + UnsafeArrayData = { + val length = data.length + val unsafeArray = new UnsafeArrayData + val headerSize = UnsafeArrayData.calculateHeaderPortionInBytes(length) + val size = headerSize + elementSize * length + val buffer = new Array[Byte](size) + Platform.putInt(buffer, Platform.BYTE_ARRAY_OFFSET, length) + unsafeArray.pointTo(buffer, Platform.BYTE_ARRAY_OFFSET, size) + assert(unsafeArray.numElements == length) + data.zipWithIndex.map { case (e, i) => + val offset = Platform.BYTE_ARRAY_OFFSET + headerSize + elementSize * i + e match { + case _ : Boolean => Platform.putBoolean(buffer, offset, e.asInstanceOf[Boolean]) + case _ : Byte => Platform.putByte(buffer, offset, e.asInstanceOf[Byte]) + case _ : Short => Platform.putShort(buffer, offset, e.asInstanceOf[Short]) + case _ : Int => Platform.putInt(buffer, offset, e.asInstanceOf[Int]) + case _ : Long => Platform.putLong(buffer, offset, e.asInstanceOf[Long]) + case _ : Float => Platform.putFloat(buffer, offset, e.asInstanceOf[Float]) + case _ : Double => Platform.putDouble(buffer, offset, e.asInstanceOf[Double]) + case _ => throw new UnsupportedOperationException() + } + } + unsafeArray + } + + val booleanData = Seq(true, false) + val byteData = Seq(0.toByte, 1.toByte, Byte.MaxValue, Byte.MinValue) + val shortData = Seq(0.toShort, 1.toShort, Short.MaxValue, Short.MinValue) + val intData = Seq(0, 1, -1, Int.MaxValue, Int.MinValue) + val longData = Seq(0.toLong, 1.toLong, -1.toLong, Long.MaxValue, Long.MinValue) + val floatData = Seq(0.toFloat, 1.1.toFloat, -1.1.toFloat, Float.MaxValue, Float.MinValue) + val doubleData = Seq(0.toDouble, 1.1.toDouble, -1.1.toDouble, Double.MaxValue, Double.MinValue) + + test("UnsafeArrayDataWriter write") { + val boolUnsafeArray = performUnsafeArrayWriter(booleanData.length, 1, + (writer: UnsafeArrayWriter) => booleanData.zipWithIndex.map { + case (e, i) => writer.write(i, e) }) + booleanData.zipWithIndex.map { case (e, i) => assert(boolUnsafeArray.getBoolean(i) == e) } + + val byteUnsafeArray = performUnsafeArrayWriter(byteData.length, 1, + (writer: UnsafeArrayWriter) => byteData.zipWithIndex.map { + case (e, i) => writer.write(i, e) }) + byteData.zipWithIndex.map { case (e, i) => assert(byteUnsafeArray.getByte(i) == e) } + + val shortUnsafeArray = performUnsafeArrayWriter(shortData.length, 2, + (writer: UnsafeArrayWriter) => shortData.zipWithIndex.map { + case (e, i) => writer.write(i, e) }) + shortData.zipWithIndex.map { case (e, i) => assert(shortUnsafeArray.getShort(i) == e) } + + val intUnsafeArray = performUnsafeArrayWriter(intData.length, 4, + (writer: UnsafeArrayWriter) => intData.zipWithIndex.map { + case (e, i) => writer.write(i, e) }) + intData.zipWithIndex.map { case (e, i) => assert(intUnsafeArray.getInt(i) == e) } + + val longUnsafeArray = performUnsafeArrayWriter(longData.length, 8, + (writer: UnsafeArrayWriter) => longData.zipWithIndex.map { + case (e, i) => writer.write(i, e) }) + longData.zipWithIndex.map { case (e, i) => assert(longUnsafeArray.getLong(i) == e) } + + val floatUnsafeArray = performUnsafeArrayWriter(floatData.length, 8, + (writer: UnsafeArrayWriter) => floatData.zipWithIndex.map { + case (e, i) => writer.write(i, e) }) + floatData.zipWithIndex.map { case (e, i) => assert(floatUnsafeArray.getFloat(i) == e) } + + val doubleUnsafeArray = performUnsafeArrayWriter(doubleData.length, 8, + (writer: UnsafeArrayWriter) => doubleData.zipWithIndex.map { + case (e, i) => writer.write(i, e) }) + doubleData.zipWithIndex.map { case (e, i) => assert(doubleUnsafeArray.getDouble(i) == e) } + } + + test("toPrimitiveArray") { + val booleanUnsafeArray = initializeUnsafeArrayData(booleanData, 1) + booleanUnsafeArray.toBooleanArray(). + zipWithIndex.map { case (e, i) => assert(e == booleanData(i)) } + + val byteUnsafeArray = initializeUnsafeArrayData(byteData, 1) + byteUnsafeArray.toByteArray().zipWithIndex.map { case (e, i) => assert(e == byteData(i)) } + + val shortUnsafeArray = initializeUnsafeArrayData(shortData, 2) + shortUnsafeArray.toShortArray().zipWithIndex.map { case (e, i) => assert(e == shortData(i)) } + + val intUnsafeArray = initializeUnsafeArrayData(intData, 4) + intUnsafeArray.toIntArray().zipWithIndex.map { case (e, i) => assert(e == intData(i)) } + + val longUnsafeArray = initializeUnsafeArrayData(longData, 8) + longUnsafeArray.toLongArray().zipWithIndex.map { case (e, i) => assert(e == longData(i)) } + + val floatUnsafeArray = initializeUnsafeArrayData(floatData, 4) + floatUnsafeArray.toFloatArray().zipWithIndex.map { case (e, i) => assert(e == floatData(i)) } + + val doubleUnsafeArray = initializeUnsafeArrayData(doubleData, 8) + doubleUnsafeArray.toDoubleArray(). + zipWithIndex.map { case (e, i) => assert(e == doubleData(i)) } + } + + test("fromPrimitiveArray") { + val booleanArray = booleanData.toArray + val booleanUnsafeArray = UnsafeArrayData.fromPrimitiveArray(booleanArray) + booleanArray.zipWithIndex.map { case (e, i) => assert(booleanUnsafeArray.getBoolean(i) == e) } + + val byteArray = byteData.toArray + val byteUnsafeArray = UnsafeArrayData.fromPrimitiveArray(byteArray) + byteArray.zipWithIndex.map { case (e, i) => assert(byteUnsafeArray.getByte(i) == e) } + + val shortArray = shortData.toArray + val shortUnsafeArray = UnsafeArrayData.fromPrimitiveArray(shortArray) + shortArray.zipWithIndex.map { case (e, i) => assert(shortUnsafeArray.getShort(i) == e) } + + val intArray = intData.toArray + val intUnsafeArray = UnsafeArrayData.fromPrimitiveArray(intArray) + intArray.zipWithIndex.map { case (e, i) => assert(intUnsafeArray.getInt(i) == e) } + + val longArray = longData.toArray + val longUnsafeArray = UnsafeArrayData.fromPrimitiveArray(longArray) + longArray.zipWithIndex.map { case (e, i) => assert(longUnsafeArray.getLong(i) == e) } + + val floatArray = floatData.toArray + val floatUnsafeArray = UnsafeArrayData.fromPrimitiveArray(floatArray) + floatArray.zipWithIndex.map { case (e, i) => assert(floatUnsafeArray.getFloat(i) == e) } + + val doubleArray = doubleData.toArray + val doubleUnsafeArray = UnsafeArrayData.fromPrimitiveArray(doubleArray) + doubleArray.zipWithIndex.map { case (e, i) => assert(doubleUnsafeArray.getDouble(i) == e) } + } + + test("writePrimitiveArray") { + val booleanArray = booleanData.toArray + val booleanUnsafeArray = performUnsafeArrayWriter(booleanArray.length, 4, + (writer: UnsafeArrayWriter) => + writer.writePrimitiveBooleanArray(new GenericArrayData(booleanArray))) + booleanArray.zipWithIndex.map { case (e, i) => assert(booleanUnsafeArray.getBoolean(i) == e) } + + val byteArray = byteData.toArray + val byteUnsafeArray = performUnsafeArrayWriter(byteArray.length, 4, + (writer: UnsafeArrayWriter) => + writer.writePrimitiveByteArray(new GenericArrayData(byteArray))) + byteArray.zipWithIndex.map { case (e, i) => assert(byteUnsafeArray.getByte(i) == e) } + + val shortArray = shortData.toArray + val shortUnsafeArray = performUnsafeArrayWriter(shortArray.length, 4, + (writer: UnsafeArrayWriter) => + writer.writePrimitiveShortArray(new GenericArrayData(shortArray))) + shortArray.zipWithIndex.map { case (e, i) => assert(shortUnsafeArray.getShort(i) == e) } + + val intArray = intData.toArray + val intUnsafeArray = performUnsafeArrayWriter(intArray.length, 4, + (writer: UnsafeArrayWriter) => writer.writePrimitiveIntArray(new GenericArrayData(intArray))) + intArray.zipWithIndex.map { case (e, i) => assert(intUnsafeArray.getInt(i) == e) } + + val longArray = longData.toArray + val longUnsafeArray = performUnsafeArrayWriter(longArray.length, 8, + (writer: UnsafeArrayWriter) => + writer.writePrimitiveLongArray(new GenericArrayData(longArray))) + longArray.zipWithIndex.map { case (e, i) => assert(longUnsafeArray.getLong(i) == e) } + + val floatArray = floatData.toArray + val floatUnsafeArray = performUnsafeArrayWriter(floatArray.length, 4, + (writer: UnsafeArrayWriter) => + writer.writePrimitiveFloatArray(new GenericArrayData(floatArray))) + floatArray.zipWithIndex.map { case (e, i) => assert(floatUnsafeArray.getFloat(i) == e) } + + val doubleArray = doubleData.toArray + val doubleUnsafeArray = performUnsafeArrayWriter(doubleArray.length, 8, + (writer: UnsafeArrayWriter) => + writer.writePrimitiveDoubleArray(new GenericArrayData(doubleArray))) + doubleArray.zipWithIndex.map { case (e, i) => assert(doubleUnsafeArray.getDouble(i) == e) } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index 1230b921aa279..a5c201949748d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -27,6 +27,25 @@ import org.apache.spark.sql.test.SharedSQLContext class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { import testImplicits._ + test("primitive type on array") { + val rows = sparkContext.parallelize(Seq(1, 2), 1).toDF("v"). + selectExpr("Array(v + 2, v + 3)") + checkAnswer(rows, Seq(Row(Array(3, 4)), Row(Array(4, 5)))) + } + + test("primitive type and null on array") { + val rows = sparkContext.parallelize(Seq(1, 2), 1).toDF("v"). + selectExpr("Array(v + 2, null, v + 3)") + checkAnswer(rows, Seq(Row(Array(3, null, 4)), Row(Array(4, null, 5)))) + } + + test("array with null on array") { + val rows = sparkContext.parallelize(Seq(1, 2), 1).toDF("v"). + selectExpr("Array(Array(v, v + 1)," + + "null," + + "Array(v, v - 1))").collect + } + test("UDF on struct") { val f = udf((a: String) => a) val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")