diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/CachedBatchColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/CachedBatchColumnVector.java
new file mode 100644
index 000000000000..fe172e308560
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/CachedBatchColumnVector.java
@@ -0,0 +1,263 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.vectorized;
+
+import java.nio.ByteBuffer;
+
+import org.apache.spark.memory.MemoryMode;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
+import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
+import org.apache.spark.sql.execution.columnar.*;
+import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A column vector backed by column data that was compressed through a ColumnAccessor.
+ * This is a read-only wrapper used to read the compressed data of a table cache.
+ */
+public final class CachedBatchColumnVector extends ReadOnlyColumnVector {
+
+  // buffer holding the compressed data for a column
+  private byte[] buffer;
+
+  // accessor for the column
+  private ColumnAccessor columnAccessor;
+
+  // a row into which the compressed data is extracted
+  private UnsafeRow unsafeRow;
+  private BufferHolder bufferHolder;
+  private UnsafeRowWriter rowWriter;
+  private MutableUnsafeRow mutableRow;
+
+  // the accessor always reads column 0
+  private final int ORDINAL = 0;
+
+  // the row id that was accessed most recently
+  private int previousRowId = -1;
+
+
+  public CachedBatchColumnVector(byte[] buffer, int numRows, DataType type) {
+    super(numRows, type, MemoryMode.ON_HEAP);
+    this.buffer = buffer;
+    initialize();
+    initializeRowAccessor(type);
+  }
+
+  @Override
+  public long valuesNativeAddress() {
+    throw new RuntimeException("Cannot get native address for on heap column");
+  }
+  @Override
+  public long nullsNativeAddress() {
+    throw new RuntimeException("Cannot get native address for on heap column");
+  }
+
+  @Override
+  public void close() {
+  }
+
+  // calls extractTo() exactly once per row up to rowId before reading the actual data
+  private void prepareAccess(int rowId) {
+    if (previousRowId == rowId) {
+      // do nothing
+    } else if (previousRowId < rowId) {
+      for (; previousRowId < rowId; previousRowId++) {
+        assert (columnAccessor.hasNext());
+        bufferHolder.reset();
+        rowWriter.zeroOutNullBytes();
+        columnAccessor.extractTo(mutableRow, ORDINAL);
+      }
+    } else {
+      throw new UnsupportedOperationException("Row access order must be equal or ascending."
+ + " Row " + rowId + " is accessed after row "+ previousRowId + " was accessed."); + } + } + + // + // APIs dealing with nulls + // + + @Override + public boolean isNullAt(int rowId) { + prepareAccess(rowId); + return unsafeRow.isNullAt(ORDINAL); + } + + // + // APIs dealing with Booleans + // + + @Override + public boolean getBoolean(int rowId) { + prepareAccess(rowId); + return unsafeRow.getBoolean(ORDINAL); + } + + @Override + public boolean[] getBooleans(int rowId, int count) { + throw new UnsupportedOperationException(); + } + + // + + // + // APIs dealing with Bytes + // + + @Override + public byte getByte(int rowId) { + prepareAccess(rowId); + return unsafeRow.getByte(ORDINAL); + } + + @Override + public byte[] getBytes(int rowId, int count) { + throw new UnsupportedOperationException(); + } + + // + // APIs dealing with Shorts + // + + @Override + public short getShort(int rowId) { + prepareAccess(rowId); + return unsafeRow.getShort(ORDINAL); + } + + @Override + public short[] getShorts(int rowId, int count) { + throw new UnsupportedOperationException(); + } + + // + // APIs dealing with Ints + // + + @Override + public int getInt(int rowId) { + prepareAccess(rowId); + return unsafeRow.getInt(ORDINAL); + } + + @Override + public int[] getInts(int rowId, int count) { + throw new UnsupportedOperationException(); + } + + public int getDictId(int rowId) { + throw new UnsupportedOperationException(); + } + + // + // APIs dealing with Longs + // + + @Override + public long getLong(int rowId) { + prepareAccess(rowId); + return unsafeRow.getLong(ORDINAL); + } + + @Override + public long[] getLongs(int rowId, int count) { + throw new UnsupportedOperationException(); + } + + // + // APIs dealing with floats + // + + @Override + public float getFloat(int rowId) { + prepareAccess(rowId); + return unsafeRow.getFloat(ORDINAL); + } + + @Override + public float[] getFloats(int rowId, int count) { + throw new UnsupportedOperationException(); + } + + // + // APIs dealing with doubles + // + + @Override + public double getDouble(int rowId) { + prepareAccess(rowId); + return unsafeRow.getDouble(ORDINAL); + } + + @Override + public double[] getDoubles(int rowId, int count) { + throw new UnsupportedOperationException(); + } + + // + // APIs dealing with Arrays + // + + @Override + public int getArrayLength(int rowId) { + throw new UnsupportedOperationException(); + } + @Override + public int getArrayOffset(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public void loadBytes(ColumnVector.Array array) { + throw new UnsupportedOperationException(); + } + + // + // APIs dealing with Byte Arrays + // + + public final UTF8String getUTF8String(int rowId) { + prepareAccess(rowId); + return unsafeRow.getUTF8String(ORDINAL); + } + + public void initialize() { + ByteBuffer byteBuffer = ByteBuffer.wrap(buffer); + columnAccessor = ColumnAccessor$.MODULE$.apply(type, byteBuffer); + previousRowId = -1; + } + + private void initializeRowAccessor(DataType type) { + unsafeRow = new UnsafeRow(1); + bufferHolder = new BufferHolder(unsafeRow); + rowWriter = new UnsafeRowWriter(bufferHolder, 1); + mutableRow = new MutableUnsafeRow(rowWriter); + + if (type instanceof ArrayType) { + throw new UnsupportedOperationException(); + } else if (type instanceof BinaryType) { + throw new UnsupportedOperationException(); + } else if (type instanceof StructType) { + throw new UnsupportedOperationException(); + } else if (type instanceof MapType) { + throw new UnsupportedOperationException(); + } else if 
+        ((DecimalType) type).precision() > Decimal.MAX_LONG_DIGITS()) {
+      throw new UnsupportedOperationException();
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala
index 5e078f251375..310cb0be5f5a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.columnar.compression
 
 import org.apache.spark.sql.execution.columnar._
-import org.apache.spark.sql.types.AtomicType
+import org.apache.spark.sql.types.{AtomicType, DataType}
 
 class TestCompressibleColumnBuilder[T <: AtomicType](
     override val columnStats: ColumnStats,
@@ -42,3 +42,10 @@ object TestCompressibleColumnBuilder {
     builder
   }
 }
+
+object ColumnBuilderHelper {
+  def apply(
+      dataType: DataType, batchSize: Int, name: String, useCompression: Boolean): ColumnBuilder = {
+    ColumnBuilder(dataType, batchSize, name, useCompression)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index ccf7aa7022a2..0429d94405ac 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -27,11 +27,14 @@ import scala.util.Random
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.memory.MemoryMode
+import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.sql.{RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeProjection}
+import org.apache.spark.sql.execution.columnar.compression.ColumnBuilderHelper
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.Platform
-import org.apache.spark.unsafe.types.CalendarInterval
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
 class ColumnarBatchSuite extends SparkFunSuite {
   test("Null Apis") {
@@ -1248,4 +1251,260 @@ class ColumnarBatchSuite extends SparkFunSuite {
         s"vectorized reader"))
     }
   }
+
+  test("CachedBatch boolean Apis") {
+    val dataType = BooleanType
+    val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true)
+    val row = new SpecificInternalRow(Array(dataType))
+
+    row.setNullAt(0)
+    columnBuilder.appendFrom(row, 0)
+    for (i <- 1 until 16) {
+      row.setBoolean(0, i % 2 == 0)
+      columnBuilder.appendFrom(row, 0)
+    }
+
+    val column = new CachedBatchColumnVector(
+      JavaUtils.bufferToArray(columnBuilder.build), 1024, dataType)
+
+    // reuse CachedBatchColumnVector
+    for (j <- 0 to 1) {
+      column.initialize
+      assert(column.isNullAt(0) == true)
+      for (i <- 1 until 16) {
+        assert(column.isNullAt(i) == false)
+        assert(column.getBoolean(i) == (i % 2 == 0))
+      }
+    }
+    column.close
+  }
+
+  test("CachedBatch byte Apis") {
+    val dataType = ByteType
+    val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true)
+    val row = new SpecificInternalRow(Array(dataType))
+
+    row.setNullAt(0)
+    columnBuilder.appendFrom(row, 0)
+    for (i <- 1 until 16) {
+      row.setByte(0, i.toByte)
+      columnBuilder.appendFrom(row, 0)
+    }
+
+    val column = new CachedBatchColumnVector(
+      JavaUtils.bufferToArray(columnBuilder.build), 1024, dataType)
+
+    // reuse CachedBatchColumnVector
+    for (j <- 0 to 1) {
+      column.initialize
+      assert(column.isNullAt(0) == true)
+      for (i <- 1 until 16) {
+        assert(column.isNullAt(i) == false)
+        assert(column.getByte(i) == i)
+      }
+    }
+    column.close
+  }
+
+  test("CachedBatch short Apis") {
+    val dataType = ShortType
+    val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true)
+    val row = new SpecificInternalRow(Array(dataType))
+
+    row.setNullAt(0)
+    columnBuilder.appendFrom(row, 0)
+    for (i <- 1 until 16) {
+      row.setShort(0, i.toShort)
+      columnBuilder.appendFrom(row, 0)
+    }
+
+    val column = new CachedBatchColumnVector(
+      JavaUtils.bufferToArray(columnBuilder.build), 1024, dataType)
+
+    // reuse CachedBatchColumnVector
+    for (j <- 0 to 1) {
+      column.initialize
+      assert(column.isNullAt(0) == true)
+      for (i <- 1 until 16) {
+        assert(column.isNullAt(i) == false)
+        assert(column.getShort(i) == i)
+      }
+    }
+    column.close
+  }
+
+  test("CachedBatch int Apis") {
+    val dataType = IntegerType
+    val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true)
+    val row = new SpecificInternalRow(Array(dataType))
+
+    row.setNullAt(0)
+    columnBuilder.appendFrom(row, 0)
+    for (i <- 1 until 16) {
+      row.setInt(0, i)
+      columnBuilder.appendFrom(row, 0)
+    }
+
+    val column = new CachedBatchColumnVector(
+      JavaUtils.bufferToArray(columnBuilder.build), 1024, dataType)
+
+    // reuse CachedBatchColumnVector
+    for (j <- 0 to 1) {
+      column.initialize
+      assert(column.isNullAt(0) == true)
+      for (i <- 1 until 16) {
+        assert(column.isNullAt(i) == false)
+        assert(column.getInt(i) == i)
+      }
+    }
+    column.close
+  }
+
+  test("CachedBatch long Apis") {
+    val dataType = LongType
+    val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true)
+    val row = new SpecificInternalRow(Array(dataType))
+
+    row.setNullAt(0)
+    columnBuilder.appendFrom(row, 0)
+    for (i <- 1 until 16) {
+      row.setLong(0, i.toLong)
+      columnBuilder.appendFrom(row, 0)
+    }
+
+    val column = new CachedBatchColumnVector(
+      JavaUtils.bufferToArray(columnBuilder.build), 1024, dataType)
+
+    // reuse CachedBatchColumnVector
+    for (j <- 0 to 1) {
+      column.initialize
+      assert(column.isNullAt(0) == true)
+      for (i <- 1 until 16) {
+        assert(column.isNullAt(i) == false)
+        assert(column.getLong(i) == i.toLong)
+      }
+    }
+    column.close
+  }
+
+  test("CachedBatch float Apis") {
+    val dataType = FloatType
+    val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true)
+    val row = new SpecificInternalRow(Array(dataType))
+
+    row.setNullAt(0)
+    columnBuilder.appendFrom(row, 0)
+    for (i <- 1 until 16) {
+      row.setFloat(0, i.toFloat)
+      columnBuilder.appendFrom(row, 0)
+    }
+
+    val column = new CachedBatchColumnVector(
+      JavaUtils.bufferToArray(columnBuilder.build), 1024, dataType)
+
+    // reuse CachedBatchColumnVector
+    for (j <- 0 to 1) {
+      column.initialize
+      assert(column.isNullAt(0) == true)
+      for (i <- 1 until 16) {
+        assert(column.isNullAt(i) == false)
+        assert(column.getFloat(i) == i.toFloat)
+      }
+    }
+    column.close
+  }
+
+  test("CachedBatch double Apis") {
+    val dataType = DoubleType
+    val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true)
+    val row = new SpecificInternalRow(Array(dataType))
+
+    row.setNullAt(0)
+    columnBuilder.appendFrom(row, 0)
+    for (i <- 1 until 16) {
+      row.setDouble(0, i.toDouble)
+      columnBuilder.appendFrom(row, 0)
+    }
+
+    val column = new CachedBatchColumnVector(
+      JavaUtils.bufferToArray(columnBuilder.build), 1024, dataType)
+
+    // reuse CachedBatchColumnVector
+    for (j <- 0 to 1) {
+      column.initialize
+      assert(column.isNullAt(0) == true)
+      for (i <- 1 until 16) {
+        assert(column.isNullAt(i) == false)
+        assert(column.getDouble(i) == i.toDouble)
+      }
+    }
+    column.close
+  }
+
+  test("CachedBatch String type Apis") {
+    val dataType = StringType
+    val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true)
+    val row = new SpecificInternalRow(Array(dataType))
+
+    row.setNullAt(0)
+    columnBuilder.appendFrom(row, 0)
+    for (i <- 1 until 16) {
+      row.update(0, UTF8String.fromString((i % 4).toString))
+      columnBuilder.appendFrom(row, 0)
+    }
+
+    val column = new CachedBatchColumnVector(
+      JavaUtils.bufferToArray(columnBuilder.build), 1024, dataType)
+
+    // reuse CachedBatchColumnVector
+    for (j <- 0 to 1) {
+      column.initialize
+      assert(column.isNullAt(0) == true)
+      for (i <- 1 until 16) {
+        assert(column.isNullAt(i) == false)
+        assert(column.getUTF8String(i).toString == (i % 4).toString)
+      }
+    }
+    column.close
+  }
+
+  test("CachedBatch access order") {
+    (BooleanType :: IntegerType :: DoubleType :: Nil).foreach { dataType => {
+      val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true)
+      val row = new SpecificInternalRow(Array(dataType))
+      for (i <- 0 until 16) {
+        dataType match {
+          case _: BooleanType => row.setBoolean(0, i % 2 == 0)
+          case _: IntegerType => row.setInt(0, i)
+          case _: DoubleType => row.setDouble(0, i)
+        }
+        columnBuilder.appendFrom(row, 0)
+      }
+
+      /* check row access order */
+      val column = new CachedBatchColumnVector(
+        JavaUtils.bufferToArray(columnBuilder.build), 1024, dataType)
+
+      dataType match {
+        case _: BooleanType =>
+          /* row access may start at a non-zero row id */
+          assert(column.getBoolean(1) == false)
+        case _: IntegerType =>
+          /* row access order must be equal or ascending, but need not be sequential */
+          assert(column.getInt(0) == 0)
+          assert(column.getInt(2) == 2)
+          assert(column.getInt(2) == 2)
+          assert(column.getInt(5) == 5)
+        case _: DoubleType =>
+          /* descending row access must throw */
+          assert(column.getDouble(0) == 0)
+          assert(column.getDouble(1) == 1.0)
+          val e = intercept[UnsupportedOperationException] {
+            column.getDouble(0)
+          }
+          assert(e.getMessage.startsWith("Row access order must be equal or ascending."))
+      }
+      column.close
+    }}
+  }
 }
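
For reference, a minimal usage sketch of the new CachedBatchColumnVector, mirroring the tests above; the batch size 1024 and the column name "col" are arbitrary values carried over from the tests, not requirements of the API:

import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
import org.apache.spark.sql.execution.columnar.compression.ColumnBuilderHelper
import org.apache.spark.sql.execution.vectorized.CachedBatchColumnVector
import org.apache.spark.sql.types.IntegerType

// Build a compressed single-column batch holding the values 0..15.
val builder = ColumnBuilderHelper(IntegerType, 1024, "col", true)
val row = new SpecificInternalRow(Array(IntegerType))
for (i <- 0 until 16) {
  row.setInt(0, i)
  builder.appendFrom(row, 0)
}

// Wrap the compressed bytes; rows must then be read with equal or ascending row ids.
val vector = new CachedBatchColumnVector(
  JavaUtils.bufferToArray(builder.build), 1024, IntegerType)
assert(vector.getInt(0) == 0)
assert(vector.getInt(5) == 5)

// initialize() rewinds the accessor so the same vector can be read again from row 0.
vector.initialize()
assert(vector.getInt(3) == 3)
vector.close()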