diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
index b4b5f0a26593..051d8e38a980 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
@@ -93,28 +93,6 @@ public static void populate(WritableColumnVector col, InternalRow row, int field
     }
   }
 
-  /**
-   * Returns the array data as the java primitive array.
-   * For example, an array of IntegerType will return an int[].
-   * Throws exceptions for unhandled schemas.
-   */
-  public static Object toPrimitiveJavaArray(ColumnarArray array) {
-    DataType dt = array.data.dataType();
-    if (dt instanceof IntegerType) {
-      int[] result = new int[array.length];
-      ColumnVector data = array.data;
-      for (int i = 0; i < result.length; i++) {
-        if (data.isNullAt(array.offset + i)) {
-          throw new RuntimeException("Cannot handle NULL values.");
-        }
-        result[i] = data.getInt(array.offset + i);
-      }
-      return result;
-    } else {
-      throw new UnsupportedOperationException();
-    }
-  }
-
   private static void appendValue(WritableColumnVector dst, DataType t, Object o) {
     if (o == null) {
       if (t instanceof CalendarIntervalType) {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/UnsafeColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/UnsafeColumnVector.java
new file mode 100644
index 000000000000..f5db99f26889
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/UnsafeColumnVector.java
@@ -0,0 +1,512 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.vectorized;
+
+import java.nio.ByteBuffer;
+
+import org.apache.commons.lang.NotImplementedException;
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
+import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A column backed by UnsafeArrayData on byte[].
+ */
+public final class UnsafeColumnVector extends WritableColumnVector {
+  // This is faster than a boolean array and we optimize this over memory footprint.
+  private byte[] nulls;
+
+  // Arrays are stored back to back as serialized UnsafeArrayData in this byte array.
+  private byte[] data;
+  private long offset;
+
+  private int lastArrayRow;
+  private int lastArrayPos;
+  private UnsafeArrayData unsafeArray = new UnsafeArrayData();
+
+  public UnsafeColumnVector(int capacity, DataType type) {
+    super(capacity, type);
+
+    reserveInternal(capacity);
+    reset();
+    nulls = new byte[capacity];
+  }
+
+  @Override
+  public void close() {
+  }
+
+  //
+  // APIs dealing with nulls
+  //
+
+  @Override
+  public void putNotNull(int rowId) {
+    nulls[rowId] = (byte)0;
+  }
+
+  @Override
+  public void putNull(int rowId) {
+    nulls[rowId] = (byte)1;
+    ++numNulls;
+    anyNullsSet = true;
+  }
+
+  @Override
+  public void putNulls(int rowId, int count) {
+    for (int i = 0; i < count; ++i) {
+      nulls[rowId + i] = (byte)1;
+    }
+    anyNullsSet = true;
+    numNulls += count;
+  }
+
+  @Override
+  public void putNotNulls(int rowId, int count) {
+    if (!anyNullsSet) return;
+    for (int i = 0; i < count; ++i) {
+      nulls[rowId + i] = (byte)0;
+    }
+  }
+
+  @Override
+  public boolean isNullAt(int rowId) {
+    if (nulls == null) return false;
+    if (data != null) {
+      // Top-level vector: the serialized buffer is attached, so `nulls` tracks
+      // row-level nulls.
+      return nulls[rowId] == 1;
+    } else {
+      // Child vector (resultArray.data): delegate element-level null checks to the
+      // UnsafeArrayData we currently point at.
+      return unsafeArray.isNullAt(rowId);
+    }
+  }
+
+  //
+  // APIs dealing with Booleans
+  //
+
+  @Override
+  public void putBoolean(int rowId, boolean value) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putBooleans(int rowId, int count, boolean value) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public boolean getBoolean(int rowId) {
+    assert(dictionary == null);
+    return unsafeArray.getBoolean(rowId);
+  }
+
+  @Override
+  public boolean[] getBooleans(int rowId, int count) {
+    assert(dictionary == null);
+    boolean[] array = unsafeArray.toBooleanArray();
+    if (rowId == 0 && array.length == count) {
+      // The caller asked for the whole array; share it instead of copying.
+      return array;
+    } else {
+      assert(count < array.length);
+      boolean[] newArray = new boolean[count];
+      System.arraycopy(array, rowId, newArray, 0, count);
+      return newArray;
+    }
+  }
+
+  //
+  // APIs dealing with Bytes
+  //
+
+  @Override
+  public void putByte(int rowId, byte value) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putBytes(int rowId, int count, byte value) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putBytes(int rowId, int count, byte[] src, int srcIndex) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public byte getByte(int rowId) {
+    assert(dictionary == null);
+    return unsafeArray.getByte(rowId);
+  }
+
+  @Override
+  public byte[] getBytes(int rowId, int count) {
+    assert(dictionary == null);
+    byte[] array = unsafeArray.toByteArray();
+    if (rowId == 0 && array.length == count) {
+      return array;
+    } else {
+      assert(count < array.length);
+      byte[] newArray = new byte[count];
+      System.arraycopy(array, rowId, newArray, 0, count);
+      return newArray;
+    }
+  }
+
+  @Override
+  protected UTF8String getBytesAsUTF8String(int rowId, int count) {
+    return UTF8String.fromAddress(null, unsafeArray.getBaseOffset() + rowId, count);
+  }
+
+  //
+  // APIs dealing with Shorts
+  //
+
+  @Override
+  public void putShort(int rowId, short value) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putShorts(int rowId, int count, short value) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putShorts(int rowId, int count, short[] src, int srcIndex) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putShorts(int rowId, int count, byte[] src, int srcIndex) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public short getShort(int rowId) {
+    assert(dictionary == null);
+    return unsafeArray.getShort(rowId);
+  }
+
+  @Override
+  public short[] getShorts(int rowId, int count) {
+    assert(dictionary == null);
+    short[] array = unsafeArray.toShortArray();
+    if (rowId == 0 && array.length == count) {
+      return array;
+    } else {
+      assert(count < array.length);
+      short[] newArray = new short[count];
+      System.arraycopy(array, rowId, newArray, 0, count);
+      return newArray;
+    }
+  }
+
+  //
+  // APIs dealing with Ints
+  //
+
+  @Override
+  public void putInt(int rowId, int value) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putInts(int rowId, int count, int value) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putInts(int rowId, int count, int[] src, int srcIndex) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putInts(int rowId, int count, byte[] src, int srcIndex) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public int getInt(int rowId) {
+    assert(dictionary == null);
+    return unsafeArray.getInt(rowId);
+  }
+
+  @Override
+  public int[] getInts(int rowId, int count) {
+    assert(dictionary == null);
+    int[] array = unsafeArray.toIntArray();
+    if (rowId == 0 && array.length == count) {
+      return array;
+    } else {
+      assert(count < array.length);
+      int[] newArray = new int[count];
+      System.arraycopy(array, rowId, newArray, 0, count);
+      return newArray;
+    }
+  }
+
+  public int getDictId(int rowId) {
+    throw new NotImplementedException();
+  }
+
+  //
+  // APIs dealing with Longs
+  //
+
+  @Override
+  public void putLong(int rowId, long value) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putLongs(int rowId, int count, long value) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putLongs(int rowId, int count, long[] src, int srcIndex) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putLongs(int rowId, int count, byte[] src, int srcIndex) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public long getLong(int rowId) {
+    assert(dictionary == null);
+    return unsafeArray.getLong(rowId);
+  }
+
+  @Override
+  public long[] getLongs(int rowId, int count) {
+    assert(dictionary == null);
+    long[] array = unsafeArray.toLongArray();
+    if (rowId == 0 && array.length == count) {
+      return array;
+    } else {
+      assert(count < array.length);
+      long[] newArray = new long[count];
+      System.arraycopy(array, rowId, newArray, 0, count);
+      return newArray;
+    }
+  }
+
+  //
+  // APIs dealing with floats
+  //
+
+  @Override
+  public void putFloat(int rowId, float value) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putFloats(int rowId, int count, float value) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putFloats(int rowId, int count, float[] src, int srcIndex) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putFloats(int rowId, int count, byte[] src, int srcIndex) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public float getFloat(int rowId) {
+    assert(dictionary == null);
+    return unsafeArray.getFloat(rowId);
+  }
+
+  @Override
+  public float[] getFloats(int rowId, int count) {
+    assert(dictionary == null);
+    float[] array = unsafeArray.toFloatArray();
+    if (rowId == 0 && array.length == count) {
+      return array;
+    } else {
+      assert(count < array.length);
+      float[] newArray = new float[count];
+      System.arraycopy(array, rowId, newArray, 0, count);
+      return newArray;
+    }
+  }
+
+  //
+  // APIs dealing with doubles
+  //
+
+  @Override
+  public void putDouble(int rowId, double value) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putDoubles(int rowId, int count, double value) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putDoubles(int rowId, int count, double[] src, int srcIndex) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public void putDoubles(int rowId, int count, byte[] src, int srcIndex) {
+    throw new NotImplementedException();
+  }
+
+  @Override
+  public double getDouble(int rowId) {
+    assert(dictionary == null);
+    return unsafeArray.getDouble(rowId);
+  }
+
+  @Override
+  public double[] getDoubles(int rowId, int count) {
+    assert(dictionary == null);
+    double[] array = unsafeArray.toDoubleArray();
+    if (rowId == 0 && array.length == count) {
+      return array;
+    } else {
+      assert(count < array.length);
+      double[] newArray = new double[count];
+      System.arraycopy(array, rowId, newArray, 0, count);
+      return newArray;
+    }
+  }
+
+  //
+  // APIs dealing with Arrays
+  //
+
+  // Layout note: non-null rows are laid out back to back in `data` as
+  // [4-byte totalBytes][UnsafeArrayData bytes] records; null rows occupy no bytes.
+  // Records are variable-length, so random access needs a scan. lastArrayRow and
+  // lastArrayPos cache the byte position of the most recently visited row, and we
+  // only rescan from position 0 when the caller moves backwards.
+  private void updateLastArrayPos(int rowId) {
+    int relative = rowId - lastArrayRow;
+    if (relative == 1 && !anyNullsSet()) {
+      int totalBytesLastArray = Platform.getInt(data, offset + lastArrayPos);
+      lastArrayPos += totalBytesLastArray + 4;  // 4 bytes for the totalBytes field
+    } else if (relative == 0) {
+      // return the same position
+      return;
+    } else if (relative > 0) {
+      for (int i = 0; i < relative; i++) {
+        if (isNullAt(lastArrayRow + i)) continue;
+        int totalBytesLastArray = Platform.getInt(data, offset + lastArrayPos);
+        lastArrayPos += totalBytesLastArray + 4;  // 4 bytes for the totalBytes field
+      }
+    } else {
+      // recalculate pos from the first entry
+      lastArrayPos = 0;
+      for (int i = 0; i < rowId; i++) {
+        if (isNullAt(i)) continue;
+        int totalBytesLastArray = Platform.getInt(data, offset + lastArrayPos);
+        lastArrayPos += totalBytesLastArray + 4;  // 4 bytes for the totalBytes field
+      }
+    }
+    lastArrayRow = rowId;
+  }
+
+  private int setUnsafeArray(int rowId) {
+    assert(data != null);
+    int length;
+    if (rowId - lastArrayRow == 1 && !anyNullsSet()) {
+      // inlined frequently-executed path (access an array in the next row)
+      lastArrayRow = rowId;
+      long localOffset = offset;
+      int localLastArrayPos = lastArrayPos;
+      int totalBytesLastArray = Platform.getInt(data, localOffset + localLastArrayPos);
+      localLastArrayPos += totalBytesLastArray + 4;  // 4 bytes for the totalBytes field
+      length = Platform.getInt(data, localOffset + localLastArrayPos);
+      ((UnsafeColumnVector)(this.resultArray.data))
+        .unsafeArray.pointTo(data, localOffset + localLastArrayPos + 4, length);
+      lastArrayPos = localLastArrayPos;
+    } else {
+      updateLastArrayPos(rowId);
+      length = Platform.getInt(data, offset + lastArrayPos);  // inline getArrayLength()
+      ((UnsafeColumnVector)(this.resultArray.data))
+        .unsafeArray.pointTo(data, offset + lastArrayPos + 4, length);
+    }
+    return ((UnsafeColumnVector)(this.resultArray.data)).unsafeArray.numElements();
+  }
+
+  // Note: as a side effect this points the child vector at rowId's array.
+  @Override
+  public int getArrayLength(int rowId) {
+    return setUnsafeArray(rowId);
+  }
+
+  @Override
+  public int getArrayOffset(int rowId) {
+    return 0;
+  }
+
+  @Override
+  public void putArray(int rowId, int offset, int length) {
+    throw new NotImplementedException();
+  }
+
+  //
+  // APIs dealing with Byte Arrays
+  //
+
+  // This method stores the given byte[], which holds every array of this column
+  // serialized as UnsafeArrayData, in `data` without copying it, to avoid an
+  // inefficient per-row copy. For UnsafeColumnVector, this method should be called
+  // only once, before any reads.
+  @Override
+  public int putByteArray(int rowId, byte[] value, int offset, int length) {
+    assert(this.resultArray != null);
+    data = value;
+    this.offset = Platform.BYTE_ARRAY_OFFSET + offset;
+
+    lastArrayRow = Integer.MAX_VALUE;
+    lastArrayPos = 0;
+
+    setIsConstant();
+
+    return value.length - offset;
+  }
+
+  // Split this function out since it is the slow path.
+  @Override
+  protected void reserveInternal(int newCapacity) {
+    byte[] newNulls = new byte[newCapacity];
+    if (nulls != null) System.arraycopy(nulls, 0, newNulls, 0, capacity);
+    nulls = newNulls;
+    capacity = newCapacity;
+  }
+
+  @Override
+  protected UnsafeColumnVector reserveNewColumn(int capacity, DataType type) {
+    return new UnsafeColumnVector(capacity, type);
+  }
+}
\ No newline at end of file
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
index 85c36b7da949..05e23c3f219b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
@@ -156,6 +156,9 @@ private[sql] object ColumnAccessor {
     if (columnAccessor.isInstanceOf[NativeColumnAccessor[_]]) {
       val nativeAccessor = columnAccessor.asInstanceOf[NativeColumnAccessor[_]]
       nativeAccessor.decompress(columnVector, numRows)
+    } else if (columnAccessor.isInstanceOf[ArrayColumnAccessor]) {
+      val arrayAccessor = columnAccessor.asInstanceOf[ArrayColumnAccessor]
+      arrayAccessor.extract(columnVector, numRows)
     } else {
       throw new RuntimeException("Not support non-primitive type now")
     }
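Reviewer note, not part of the patch: `putByteArray` above and `extract` (added to NullableColumnAccessor further down) form a zero-copy handoff. The CachedBatch's backing byte[] is attached to the UnsafeColumnVector exactly once, and every later `getArray` call resolves a row lazily against that single buffer, so decompression never copies array payloads. A minimal Scala sketch of the expected call order, assuming a vector and a cached column buffer (names are illustrative):

    // Sketch only: attach a cached column's buffer once, then read rows front to back.
    def attachAndScan(vector: UnsafeColumnVector, buffer: java.nio.ByteBuffer, numRows: Int): Unit = {
      // Must happen exactly once, before any reads (see putByteArray above).
      vector.putByteArray(0, buffer.array, buffer.position, buffer.limit)
      var i = 0
      while (i < numRows) {
        // Sequential access hits the cached-position fast path in setUnsafeArray.
        if (!vector.isNullAt(i)) vector.getArray(i)
        i += 1
      }
    }

Reading front to back matters: each step forward costs one `Platform.getInt`, while moving backwards forces a rescan from byte position 0.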
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 3e73393b1285..ac555af244ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -37,15 +37,6 @@ case class InMemoryTableScanExec(
 
   override protected def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren
 
-  override def vectorTypes: Option[Seq[String]] =
-    Option(Seq.fill(attributes.length)(
-      if (!conf.offHeapColumnVectorEnabled) {
-        classOf[OnHeapColumnVector].getName
-      } else {
-        classOf[OffHeapColumnVector].getName
-      }
-    ))
-
   /**
    * If true, get data from ColumnVector in ColumnarBatch, which are generally faster.
    * If false, get data from UnsafeRow build from ColumnVector
@@ -56,6 +47,8 @@ case class InMemoryTableScanExec(
     relation.schema.fields.forall(f => f.dataType match {
       case BooleanType | ByteType | ShortType | IntegerType | LongType |
         FloatType | DoubleType => true
+      case ArrayType(dt, _) if (dt == BooleanType || dt == ByteType || dt == ShortType ||
+        dt == IntegerType || dt == LongType || dt == FloatType || dt == DoubleType) => true
       case _ => false
     }) && !WholeStageCodegenExec.isTooManyFields(conf, relation.schema)
   }
@@ -67,14 +60,32 @@ case class InMemoryTableScanExec(
 
   private lazy val columnarBatchSchema = new StructType(columnIndices.map(i => relationSchema(i)))
 
+  override def vectorTypes: Option[Seq[String]] = {
+    val fields = columnarBatchSchema.fields
+    Option((0 until fields.length).map { i =>
+      if (fields(i).dataType.isInstanceOf[ArrayType]) {
+        classOf[UnsafeColumnVector].getName
+      } else if (!conf.offHeapColumnVectorEnabled) {
+        classOf[OnHeapColumnVector].getName
+      } else {
+        classOf[OffHeapColumnVector].getName
+      }
+    })
+  }
+
   private def createAndDecompressColumn(cachedColumnarBatch: CachedBatch): ColumnarBatch = {
+    val fields = columnarBatchSchema.fields
     val rowCount = cachedColumnarBatch.numRows
     val taskContext = Option(TaskContext.get())
-    val columnVectors = if (!conf.offHeapColumnVectorEnabled || taskContext.isEmpty) {
-      OnHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema)
-    } else {
-      OffHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema)
-    }
+    val columnVectors = (0 until fields.length).map { i =>
+      if (fields(i).dataType.isInstanceOf[ArrayType]) {
+        new UnsafeColumnVector(rowCount, fields(i).dataType)
+      } else if (!conf.offHeapColumnVectorEnabled || taskContext.isEmpty) {
+        new OnHeapColumnVector(rowCount, fields(i).dataType)
+      } else {
+        new OffHeapColumnVector(rowCount, fields(i).dataType)
+      }
+    }.toArray
     val columnarBatch = new ColumnarBatch(
       columnarBatchSchema, columnVectors.asInstanceOf[Array[ColumnVector]], rowCount)
     columnarBatch.setNumRows(rowCount)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala
index 2f09757aa341..a947ca7e8b0f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.columnar
 import java.nio.{ByteBuffer, ByteOrder}
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.vectorized.WritableColumnVector
 
 private[columnar] trait NullableColumnAccessor extends ColumnAccessor {
   private var nullsBuffer: ByteBuffer = _
@@ -56,4 +57,16 @@ private[columnar] trait NullableColumnAccessor extends ColumnAccessor {
   }
 
   abstract override def hasNext: Boolean = seenNulls < nullCount || super.hasNext
+
+  def extract(columnVector: WritableColumnVector, capacity: Int): Unit = {
+    if (nextNullIndex != -1) {
+      columnVector.putNull(nextNullIndex)
+      for (_ <- 1 until nullCount) {
+        val ordinal = ByteBufferHelper.getInt(nullsBuffer)
+        columnVector.putNull(ordinal)
+      }
+    }
+    columnVector.putByteArray(
+      0, underlyingBuffer.array, underlyingBuffer.position, underlyingBuffer.limit)
+  }
 }
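For context on `extract` above: by the time it runs, the accessor's initializer has already consumed the null count and the first null ordinal (`nextNullIndex`) from `nullsBuffer`, so the method records that first ordinal, reads the remaining `nullCount - 1` ordinals, and only then hands the whole underlying buffer to `putByteArray`. A hedged Scala sketch of decoding a null header of the shape the nullable builders write, assumed here to be an int count followed by that many int ordinals:

    import java.nio.{ByteBuffer, ByteOrder}

    // Sketch only: decode [nullCount: Int][ordinal: Int]* and report each null row.
    def readNullOrdinals(nulls: ByteBuffer, putNull: Int => Unit): Unit = {
      val buf = nulls.duplicate().order(ByteOrder.nativeOrder())
      val nullCount = buf.getInt()
      (0 until nullCount).foreach(_ => putNull(buf.getInt()))
    }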
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
index 0881212a64de..7d9760ba0a86 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
@@ -76,9 +76,9 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext {
 
   test("primitive data type accesses in persist data") {
     val data = Seq(true, 1.toByte, 3.toShort, 7, 15.toLong,
-      31.25.toFloat, 63.75, null)
+      31.25.toFloat, 63.75, null, Array(1.2, 2.3), Array[java.lang.Double](1.2, null))
     val dataTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType,
-      FloatType, DoubleType, IntegerType)
+      FloatType, DoubleType, IntegerType, ArrayType(DoubleType, false), ArrayType(DoubleType, true))
     val schemas = dataTypes.zipWithIndex.map { case (dataType, index) =>
       StructField(s"col$index", dataType, true)
     }
@@ -109,4 +109,11 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext {
     df.count
     assert(df.filter("d < 3").count == 1)
   }
+
+  test("primitive array for Dataset") {
+    val ds = sparkContext.parallelize(Seq(Array(6, 7), Array(8, 9, 10)), 1).toDS.cache
+    ds.count
+    val ds1 = ds.map(p => p).collect
+    assert(ds1(0) === Array(6, 7) && ds1(1) === Array(8, 9, 10))
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index bc05dca578c4..7e6d55256145 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -135,6 +135,20 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
     )
     assert(dsIntFilter.collect() === Array(1, 2))
 
+    val dsIntArray = sparkContext.parallelize(Seq(Array(1, 2), Array(-1, -2)), 1).toDS.cache
+    dsIntArray.count
+    val dsIntArrayFilter = dsIntArray.filter(a => a(0) > 0)
+    val planIntArray = dsIntArrayFilter.queryExecution.executedPlan
+    assert(planIntArray.find(p =>
+      p.isInstanceOf[WholeStageCodegenExec] &&
+        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] &&
+        p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child
+          .isInstanceOf[InMemoryTableScanExec] &&
+        p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child
+          .asInstanceOf[InMemoryTableScanExec].supportCodegen).isDefined
+    )
+    assert(dsIntArrayFilter.collect() === Array(Array(1, 2)))
+
     // cache for string type is not supported for InMemoryTableScanExec
     val dsString = spark.range(3).map(_.toString).cache
     dsString.count
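The ColumnVectorSuite additions below all follow one skeleton: append N rows to a ColumnBuilderHelper, marking a fixed set of ordinals null and storing `UnsafeArrayData.fromPrimitiveArray(...)` for the rest, then decompress into an UnsafeColumnVector and verify sequential, forward-strided, and backward-strided access (the strided passes exercise the position cache's rescan path). A condensed sketch of that shared skeleton, using the suite's imports; the tests below inline it once per element type:

    // Sketch only: the build-then-extract skeleton shared by the tests below.
    def buildAndExtract(
        dataType: ArrayType,
        rows: Seq[Option[UnsafeArrayData]],
        testVector: WritableColumnVector): Unit = {
      val builder = ColumnBuilderHelper(dataType, 4096, "col", true)
      val row = new GenericInternalRow(1)
      rows.foreach { r =>
        r match {
          case Some(arr) => row.update(0, arr)
          case None => row.setNullAt(0)
        }
        builder.appendFrom(row, 0)
      }
      val accessor = ColumnAccessor(dataType, builder.build)
      ColumnAccessor.decompress(accessor, testVector, rows.length)
    }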
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
index 3c76ca79f5dd..b8296e4da813 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.vectorized
 import org.scalatest.BeforeAndAfterEach
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, SpecificInternalRow, UnsafeArrayData}
 import org.apache.spark.sql.catalyst.util.ArrayData
 import org.apache.spark.sql.execution.columnar.ColumnAccessor
 import org.apache.spark.sql.execution.columnar.compression.ColumnBuilderHelper
@@ -52,6 +52,13 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
     }
   }
 
+  private def withColumnArrayVectors(
+      size: Int,
+      dt: DataType)(
+      block: WritableColumnVector => Unit): Unit = {
+    withVector(new UnsafeColumnVector(size, dt))(block)
+  }
+
   testVectors("boolean", 10, BooleanType) { testVector =>
     (0 until 10).foreach { i =>
       testVector.appendBoolean(i % 2 == 0)
@@ -396,5 +403,369 @@
       }
     }
   }
+
+  test("CachedBatch boolean array APIs") {
+    val N = 16
+    val dataType = ArrayType(BooleanType, false)
+    val columnBuilder = ColumnBuilderHelper(dataType, 4096, "col", true)
+    val row = new GenericInternalRow(N)
+    val data = new Array[Array[Boolean]](N)
+    val nulls = Seq(0, 6, 11)
+
+    for (i <- 0 until N) {
+      if (nulls.contains(i)) {
+        row.setNullAt(0)
+      } else {
+        data(i) = Array.tabulate(i)(i => i % 2 == 0)
+        row.update(0, UnsafeArrayData.fromPrimitiveArray(data(i)))
+      }
+      columnBuilder.appendFrom(row, 0)
+    }
+
+    withColumnArrayVectors(N, dataType) { testVector =>
+      val columnAccessor = ColumnAccessor(dataType, columnBuilder.build)
+      ColumnAccessor.decompress(columnAccessor, testVector, 16)
+
+      for (i <- 0 until N) {
+        if (nulls.contains(i)) {
+          assert(testVector.isNullAt(i) == true)
+        } else {
+          assert(testVector.isNullAt(i) == false)
+          assert(testVector.getArray(i).toBooleanArray() === data(i))
+          for (j <- 0 until data(i).length) {
+            assert(testVector.getArray(i).getBoolean(j) == data(i)(j))
+          }
+        }
+      }
+      for (i <- 0 to N / 3) {
+        if (nulls.contains(i * 3)) {
+          assert(testVector.isNullAt(i * 3) == true)
+        } else {
+          assert(testVector.isNullAt(i * 3) == false)
+          assert(testVector.getArray(i * 3).toBooleanArray() === data(i * 3))
+        }
+      }
+      for (i <- 1 to N / 3) {
+        if (nulls.contains(N - i * 3)) {
+          assert(testVector.isNullAt(N - i * 3) == true)
+        } else {
+          assert(testVector.isNullAt(N - i * 3) == false)
+          assert(testVector.getArray(N - i * 3).toBooleanArray() === data(N - i * 3))
+        }
+      }
+    }
+  }
+
+  test("CachedBatch byte array APIs") {
+    val N = 16
+    val dataType = ArrayType(ByteType, false)
+    val columnBuilder = ColumnBuilderHelper(dataType, 4096, "col", true)
+    val row = new GenericInternalRow(N)
+    val data = new Array[Array[Byte]](N)
+    val nulls = Seq(0, 6, 11)
+
+    for (i <- 0 until N) {
+      if (nulls.contains(i)) {
+        row.setNullAt(0)
+      } else {
+        data(i) = Array.tabulate(i)(i => i.toByte)
+        row.update(0, UnsafeArrayData.fromPrimitiveArray(data(i)))
+      }
+      columnBuilder.appendFrom(row, 0)
+    }
+
+    withColumnArrayVectors(N, dataType) { testVector =>
+      val columnAccessor = ColumnAccessor(dataType, columnBuilder.build)
+      ColumnAccessor.decompress(columnAccessor, testVector, 16)
+
+      for (i <- 0 until N) {
+        if (nulls.contains(i)) {
+          assert(testVector.isNullAt(i) == true)
+        } else {
+          assert(testVector.isNullAt(i) == false)
+          assert(testVector.getArray(i).toByteArray() === data(i))
+          for (j <- 0 until data(i).length) {
+            assert(testVector.getArray(i).getByte(j) == data(i)(j))
+          }
+        }
+      }
+      for (i <- 0 to N / 3) {
+        if (nulls.contains(i * 3)) {
+          assert(testVector.isNullAt(i * 3) == true)
+        } else {
+          assert(testVector.isNullAt(i * 3) == false)
+          assert(testVector.getArray(i * 3).toByteArray() === data(i * 3))
+        }
+      }
+      for (i <- 1 to N / 3) {
+        if (nulls.contains(N - i * 3)) {
+          assert(testVector.isNullAt(N - i * 3) == true)
+        } else {
+          assert(testVector.isNullAt(N - i * 3) == false)
+          assert(testVector.getArray(N - i * 3).toByteArray() === data(N - i * 3))
+        }
+      }
+    }
+  }
Apis") { + val N = 16 + val dataType = ArrayType(ShortType, false) + val columnBuilder = ColumnBuilderHelper(dataType, 4096, "col", true) + val row = new GenericInternalRow(N) + val data = new Array[Array[Short]](N) + val nulls = Seq(0, 6, 11) + + for (i <- 0 until N) { + if (nulls.contains(i)) { + row.setNullAt(0) + } else { + data(i) = Array.tabulate(i)(i => i.toShort) + row.update(0, UnsafeArrayData.fromPrimitiveArray(data(i))) + } + columnBuilder.appendFrom(row, 0) + } + + withColumnArrayVectors(N, dataType) { testVector => + val columnAccessor = ColumnAccessor(dataType, columnBuilder.build) + ColumnAccessor.decompress(columnAccessor, testVector, 16) + + for (i <- 0 until N) { + if (nulls.contains(i)) { + assert(testVector.isNullAt(i) == true) + } else { + assert(testVector.isNullAt(i) == false) + assert(testVector.getArray(i).toShortArray() === data(i)) + for (j <- 0 until data(i).length) { + assert(testVector.getArray(i).getShort(j) == data(i)(j)) + } + } + } + for (i <- 0 to N / 3) { + if (nulls.contains(i * 3)) { + assert(testVector.isNullAt(i * 3) == true) + } else { + assert(testVector.isNullAt(i * 3) == false) + assert(testVector.getArray(i * 3).toShortArray() === data(i * 3)) + } + } + for (i <- 1 to N / 3) { + if (nulls.contains(N - i * 3)) { + assert(testVector.isNullAt(N - i * 3) == true) + } else { + assert(testVector.isNullAt(N - i * 3) == false) + assert(testVector.getArray(N - i * 3).toShortArray() === data(N - i * 3)) + } + } + } + } + + test("CachedBatch int array Apis") { + val N = 16 + val dataType = ArrayType(IntegerType, false) + val columnBuilder = ColumnBuilderHelper(dataType, 4096, "col", true) + val row = new GenericInternalRow(N) + val data = new Array[Array[Int]](N) + val nulls = Seq(0, 6, 11) + + for (i <- 0 until N) { + if (nulls.contains(i)) { + row.setNullAt(0) + } else { + data(i) = Array.range(0, i) + row.update(0, UnsafeArrayData.fromPrimitiveArray(data(i))) + } + columnBuilder.appendFrom(row, 0) + } + + withColumnArrayVectors(N, dataType) { testVector => + val columnAccessor = ColumnAccessor(dataType, columnBuilder.build) + ColumnAccessor.decompress(columnAccessor, testVector, 16) + + for (i <- 0 until N) { + if (nulls.contains(i)) { + assert(testVector.isNullAt(i) == true) + } else { + assert(testVector.isNullAt(i) == false) + assert(testVector.getArray(i).toIntArray() === data(i)) + } + } + for (i <- 0 to N / 3) { + if (nulls.contains(i * 3)) { + assert(testVector.isNullAt(i * 3) == true) + } else { + assert(testVector.isNullAt(i * 3) == false) + assert(testVector.getArray(i * 3).toIntArray() === data(i * 3)) + for (j <- 0 until data(i).length) { + assert(testVector.getArray(i).getInt(j) == data(i)(j)) + } + } + } + for (i <- 1 to N / 3) { + if (nulls.contains(N - i * 3)) { + assert(testVector.isNullAt(N - i * 3) == true) + } else { + assert(testVector.isNullAt(N - i * 3) == false) + assert(testVector.getArray(N - i * 3).toIntArray() === data(N - i * 3)) + } + } + } + } + + test("CachedBatch long array Apis") { + val N = 16 + val dataType = ArrayType(LongType, false) + val columnBuilder = ColumnBuilderHelper(dataType, 4096, "col", true) + val row = new GenericInternalRow(N) + val data = new Array[Array[Long]](N) + val nulls = Seq(0, 6, 11) + + for (i <- 0 until N) { + if (nulls.contains(i)) { + row.setNullAt(0) + } else { + data(i) = Array.tabulate(i)(i => i.toLong) + row.update(0, UnsafeArrayData.fromPrimitiveArray(data(i))) + } + columnBuilder.appendFrom(row, 0) + } + + withColumnArrayVectors(N, dataType) { testVector => + val columnAccessor = 
+      ColumnAccessor.decompress(columnAccessor, testVector, 16)
+
+      for (i <- 0 until N) {
+        if (nulls.contains(i)) {
+          assert(testVector.isNullAt(i) == true)
+        } else {
+          assert(testVector.isNullAt(i) == false)
+          assert(testVector.getArray(i).toLongArray() === data(i))
+          for (j <- 0 until data(i).length) {
+            assert(testVector.getArray(i).getLong(j) == data(i)(j))
+          }
+        }
+      }
+      for (i <- 0 to N / 3) {
+        if (nulls.contains(i * 3)) {
+          assert(testVector.isNullAt(i * 3) == true)
+        } else {
+          assert(testVector.isNullAt(i * 3) == false)
+          assert(testVector.getArray(i * 3).toLongArray() === data(i * 3))
+        }
+      }
+      for (i <- 1 to N / 3) {
+        if (nulls.contains(N - i * 3)) {
+          assert(testVector.isNullAt(N - i * 3) == true)
+        } else {
+          assert(testVector.isNullAt(N - i * 3) == false)
+          assert(testVector.getArray(N - i * 3).toLongArray() === data(N - i * 3))
+        }
+      }
+    }
+  }
+
+  test("CachedBatch float array APIs") {
+    val N = 16
+    val dataType = ArrayType(FloatType, false)
+    val columnBuilder = ColumnBuilderHelper(dataType, 4096, "col", true)
+    val row = new GenericInternalRow(N)
+    val data = new Array[Array[Float]](N)
+    val nulls = Seq(0, 6, 11)
+
+    for (i <- 0 until N) {
+      if (nulls.contains(i)) {
+        row.setNullAt(0)
+      } else {
+        data(i) = Array.tabulate(i)(i => i.toFloat)
+        row.update(0, UnsafeArrayData.fromPrimitiveArray(data(i)))
+      }
+      columnBuilder.appendFrom(row, 0)
+    }
+
+    withColumnArrayVectors(N, dataType) { testVector =>
+      val columnAccessor = ColumnAccessor(dataType, columnBuilder.build)
+      ColumnAccessor.decompress(columnAccessor, testVector, 16)
+
+      for (i <- 0 until N) {
+        if (nulls.contains(i)) {
+          assert(testVector.isNullAt(i) == true)
+        } else {
+          assert(testVector.isNullAt(i) == false)
+          assert(testVector.getArray(i).toFloatArray() === data(i))
+          for (j <- 0 until data(i).length) {
+            assert(testVector.getArray(i).getFloat(j) == data(i)(j))
+          }
+        }
+      }
+      for (i <- 0 to N / 3) {
+        if (nulls.contains(i * 3)) {
+          assert(testVector.isNullAt(i * 3) == true)
+        } else {
+          assert(testVector.isNullAt(i * 3) == false)
+          assert(testVector.getArray(i * 3).toFloatArray() === data(i * 3))
+        }
+      }
+      for (i <- 1 to N / 3) {
+        if (nulls.contains(N - i * 3)) {
+          assert(testVector.isNullAt(N - i * 3) == true)
+        } else {
+          assert(testVector.isNullAt(N - i * 3) == false)
+          assert(testVector.getArray(N - i * 3).toFloatArray() === data(N - i * 3))
+        }
+      }
+    }
+  }
+
+  test("CachedBatch double array APIs") {
+    val N = 16
+    val dataType = ArrayType(DoubleType, false)
+    val columnBuilder = ColumnBuilderHelper(dataType, 4096, "col", true)
+    val row = new GenericInternalRow(N)
+    val data = new Array[Array[Double]](N)
+    val nulls = Seq(0, 6, 11)
+
+    for (i <- 0 until N) {
+      if (nulls.contains(i)) {
+        row.setNullAt(0)
+      } else {
+        data(i) = Array.tabulate(i)(i => i.toDouble)
+        row.update(0, UnsafeArrayData.fromPrimitiveArray(data(i)))
+      }
+      columnBuilder.appendFrom(row, 0)
+    }
+
+    withColumnArrayVectors(N, dataType) { testVector =>
+      val columnAccessor = ColumnAccessor(dataType, columnBuilder.build)
+      ColumnAccessor.decompress(columnAccessor, testVector, 16)
+
+      for (i <- 0 until N) {
+        if (nulls.contains(i)) {
+          assert(testVector.isNullAt(i) == true)
+        } else {
+          assert(testVector.isNullAt(i) == false)
+          assert(testVector.getArray(i).toDoubleArray() === data(i))
+          for (j <- 0 until data(i).length) {
+            assert(testVector.getArray(i).getDouble(j) == data(i)(j))
+          }
+        }
+      }
+      for (i <- 0 to N / 3) {
+        if (nulls.contains(i * 3)) {
+          assert(testVector.isNullAt(i * 3) == true)
+        } else {
+          assert(testVector.isNullAt(i * 3) == false)
+          assert(testVector.getArray(i * 3).toDoubleArray() === data(i * 3))
+        }
+      }
+      for (i <- 1 to N / 3) {
+        if (nulls.contains(N - i * 3)) {
+          assert(testVector.isNullAt(N - i * 3) == true)
+        } else {
+          assert(testVector.isNullAt(N - i * 3) == false)
+          assert(testVector.getArray(N - i * 3).toDoubleArray() === data(N - i * 3))
+        }
+      }
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index 80a50866aa50..8f9c9776b51e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -645,26 +645,26 @@ class ColumnarBatchSuite extends SparkFunSuite {
     column.putArray(2, 2, 0)
     column.putArray(3, 3, 3)
 
-    val a1 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(0)).asInstanceOf[Array[Int]]
-    val a2 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(1)).asInstanceOf[Array[Int]]
-    val a3 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(2)).asInstanceOf[Array[Int]]
-    val a4 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(3)).asInstanceOf[Array[Int]]
+    val a1 = column.getArray(0).toIntArray()
+    val a2 = column.getArray(1).toIntArray()
+    val a3 = column.getArray(2).toIntArray()
+    val a4 = column.getArray(3).toIntArray()
     assert(a1 === Array(0))
     assert(a2 === Array(1, 2))
     assert(a3 === Array.empty[Int])
     assert(a4 === Array(3, 4, 5))
 
     // Verify the ArrayData APIs
-    assert(column.getArray(0).length == 1)
+    assert(column.getArray(0).numElements == 1)
     assert(column.getArray(0).getInt(0) == 0)
 
-    assert(column.getArray(1).length == 2)
+    assert(column.getArray(1).numElements == 2)
     assert(column.getArray(1).getInt(0) == 1)
     assert(column.getArray(1).getInt(1) == 2)
 
-    assert(column.getArray(2).length == 0)
+    assert(column.getArray(2).numElements == 0)
 
-    assert(column.getArray(3).length == 3)
+    assert(column.getArray(3).numElements == 3)
     assert(column.getArray(3).getInt(0) == 3)
     assert(column.getArray(3).getInt(1) == 4)
     assert(column.getArray(3).getInt(2) == 5)
@@ -677,8 +677,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
     assert(data.capacity == array.length * 2)
     data.putInts(0, array.length, array, 0)
     column.putArray(0, 0, array.length)
-    assert(ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(0)).asInstanceOf[Array[Int]]
-      === array)
+    assert(column.getArray(0).toIntArray === array)
  }
 
   test("toArray for primitive types") {
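Two closing notes. The ColumnarBatchSuite changes work because ColumnarArray implements the ArrayData interface, which is why `toPrimitiveJavaArray` could be deleted at the top of this patch: `numElements` and the typed `toXxxArray` converters cover its only use. And for trying the feature end to end, a minimal self-contained sketch of what the new Dataset tests verify, assuming a local build with this patch applied:

    import org.apache.spark.sql.SparkSession

    // Sketch only: cache a Dataset of primitive arrays and read it back through
    // the columnar in-memory scan this patch enables.
    object PrimitiveArrayCacheExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").appName("example").getOrCreate()
        import spark.implicits._
        val ds = spark.sparkContext.parallelize(Seq(Array(1, 2), Array(-1, -2)), 1).toDS.cache()
        ds.count() // materializes the cache as CachedBatches
        val positive = ds.filter(a => a(0) > 0).collect() // expected to scan UnsafeColumnVector
        assert(positive.length == 1 && positive(0).sameElements(Array(1, 2)))
        spark.stop()
      }
    }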