From 2861ac2a5136c065ec38cfc24bf9f979d5b7ae07 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 16 Jun 2016 02:31:23 +0000 Subject: [PATCH 01/20] Add vectorized Orc reader support. --- .../apache/spark/sql/internal/SQLConf.scala | 8 + .../ql/io/orc/SparkOrcNewRecordReader.java | 4 +- .../io/orc/SparkOrcNewRecordReaderBase.java | 24 +++ .../orc/SparkVectorizedOrcRecordReader.java | 175 ++++++++++++++++++ .../VectorizedSparkOrcNewRecordReader.java | 161 ++++++++++++++++ .../spark/sql/hive/orc/OrcFileFormat.scala | 26 ++- .../spark/sql/hive/orc/OrcQuerySuite.scala | 2 +- 7 files changed, 393 insertions(+), 7 deletions(-) create mode 100644 sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkOrcNewRecordReaderBase.java create mode 100644 sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java create mode 100644 sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 27b1fffe27a7..5c16b0c545f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -240,6 +240,12 @@ object SQLConf { .booleanConf .createWithDefault(true) + val ORC_VECTORIZED_READER_ENABLED = + SQLConfigBuilder("spark.sql.orc.enableVectorizedReader") + .doc("Enables vectorized orc reader.") + .booleanConf + .createWithDefault(true) + val ORC_FILTER_PUSHDOWN_ENABLED = SQLConfigBuilder("spark.sql.orc.filterPushdown") .doc("When true, enable filter pushdown for ORC files.") .booleanConf @@ -586,6 +592,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def parquetVectorizedReaderEnabled: Boolean = getConf(PARQUET_VECTORIZED_READER_ENABLED) + def orcVectorizedReaderEnabled: Boolean = getConf(ORC_VECTORIZED_READER_ENABLED) + def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE) def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS) diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkOrcNewRecordReader.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkOrcNewRecordReader.java index f093637d412f..cdd50cb668bf 100644 --- a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkOrcNewRecordReader.java +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkOrcNewRecordReader.java @@ -33,7 +33,8 @@ * NameNode calls in OrcRelation. */ public class SparkOrcNewRecordReader extends - org.apache.hadoop.mapreduce.RecordReader { + org.apache.hadoop.mapreduce.RecordReader + implements SparkOrcNewRecordReaderBase { private final org.apache.hadoop.hive.ql.io.orc.RecordReader reader; private final int numColumns; OrcStruct value; @@ -88,6 +89,7 @@ public boolean nextKeyValue() throws IOException, InterruptedException { } } + @Override public ObjectInspector getObjectInspector() { return objectInspector; } diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkOrcNewRecordReaderBase.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkOrcNewRecordReaderBase.java new file mode 100644 index 000000000000..c9af60d67b05 --- /dev/null +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkOrcNewRecordReaderBase.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.io.orc; + +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; + +public interface SparkOrcNewRecordReaderBase { + public ObjectInspector getObjectInspector(); +} diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java new file mode 100644 index 000000000000..415a004e2961 --- /dev/null +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java @@ -0,0 +1,175 @@ +package org.apache.hadoop.hive.ql.io.orc; + +import java.io.IOException; +import java.util.LinkedList; +import java.util.List; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructField; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; +import org.apache.hadoop.io.NullWritable; +import org.apache.hadoop.mapred.FileSplit; +import org.apache.hadoop.mapred.RecordReader; + +/** + * This is based on + * {@link org.apache.hadoop.hive.ql.io.orc.VectorizedOrcInputFormat.VectorizedOrcRecordReader}. + */ +public class SparkVectorizedOrcRecordReader + implements RecordReader { + private final org.apache.hadoop.hive.ql.io.orc.RecordReader reader; + private final long offset; + private final long length; + private float progress = 0.0f; + private ObjectInspector objectInspector; + + SparkVectorizedOrcRecordReader(Reader file, Configuration conf, + FileSplit fileSplit) throws IOException { + this.offset = fileSplit.getStart(); + this.length = fileSplit.getLength(); + this.objectInspector = file.getObjectInspector(); + this.reader = OrcInputFormat.createReaderFromFile(file, conf, this.offset, + this.length); + this.progress = reader.getProgress(); + } + + /** + * Create a ColumnVector based on given ObjectInspector's type info. 
+ * + * @param inspector ObjectInspector + */ + private ColumnVector createColumnVector(ObjectInspector inspector) { + switch(inspector.getCategory()) { + case PRIMITIVE: + { + PrimitiveTypeInfo primitiveTypeInfo = + (PrimitiveTypeInfo) ((PrimitiveObjectInspector)inspector).getTypeInfo(); + switch(primitiveTypeInfo.getPrimitiveCategory()) { + case BOOLEAN: + case BYTE: + case SHORT: + case INT: + case LONG: + case DATE: + case INTERVAL_YEAR_MONTH: + return new LongColumnVector(VectorizedRowBatch.DEFAULT_SIZE); + case FLOAT: + case DOUBLE: + return new DoubleColumnVector(VectorizedRowBatch.DEFAULT_SIZE); + case BINARY: + case STRING: + case CHAR: + case VARCHAR: + BytesColumnVector column = new BytesColumnVector(VectorizedRowBatch.DEFAULT_SIZE); + column.initBuffer(); + return column; + case DECIMAL: + DecimalTypeInfo tInfo = (DecimalTypeInfo) primitiveTypeInfo; + return new DecimalColumnVector(VectorizedRowBatch.DEFAULT_SIZE, + tInfo.precision(), tInfo.scale()); + default: + throw new RuntimeException("Vectorizaton is not supported for datatype:" + + primitiveTypeInfo.getPrimitiveCategory()); + } + } + default: + throw new RuntimeException("Vectorization is not supported for datatype:" + + inspector.getCategory()); + } + } + + /** + * Walk through the object inspector and add column vectors + * + * @param oi StructObjectInspector + * @param cvList ColumnVectors are populated in this list + */ + private void allocateColumnVector(StructObjectInspector oi, + List cvList) throws HiveException { + if (cvList == null) { + throw new HiveException("Null columnvector list"); + } + if (oi == null) { + return; + } + final List fields = oi.getAllStructFieldRefs(); + for(StructField field : fields) { + ObjectInspector fieldObjectInspector = field.getFieldObjectInspector(); + cvList.add(createColumnVector(fieldObjectInspector)); + } + } + + /** + * Create VectorizedRowBatch from ObjectInspector + * + * @param oi + * @return + * @throws HiveException + */ + private VectorizedRowBatch constructVectorizedRowBatch( + StructObjectInspector oi) throws HiveException { + final List cvList = new LinkedList(); + allocateColumnVector(oi, cvList); + final VectorizedRowBatch result = new VectorizedRowBatch(cvList.size()); + int i = 0; + for(ColumnVector cv : cvList) { + result.cols[i++] = cv; + } + return result; + } + + @Override + public boolean next(NullWritable key, VectorizedRowBatch value) throws IOException { + try { + reader.nextBatch(value); + if (value == null || value.endOfFile || value.size == 0) { + return false; + } + } catch (Exception e) { + throw new RuntimeException(e); + } + progress = reader.getProgress(); + return true; + } + + @Override + public NullWritable createKey() { + return NullWritable.get(); + } + + @Override + public VectorizedRowBatch createValue() { + try { + return constructVectorizedRowBatch((StructObjectInspector)this.objectInspector); + } catch (HiveException e) { + } + return null; + } + + @Override + public long getPos() throws IOException { + return offset + (long) (progress * length); + } + + @Override + public void close() throws IOException { + reader.close(); + } + + @Override + public float getProgress() throws IOException { + return progress; + } + } diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java new file mode 100644 index 000000000000..4b77d35ffd7f --- /dev/null +++ 
b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.io.orc; + +import org.apache.hadoop.conf.Configuration; + +import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpressionWriter; +import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpressionWriterFactory; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.io.NullWritable; +import org.apache.hadoop.mapred.JobConf; +import org.apache.hadoop.mapred.Reporter; +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.hadoop.mapreduce.lib.input.FileSplit; + +import java.io.IOException; +import java.util.List; + +/** + * This is based on hive-exec-1.2.1 + * {@link org.apache.hadoop.hive.ql.io.orc.OrcNewInputFormat.OrcRecordReader}. + * This class exposes getObjectInspector which can be used for reducing + * NameNode calls in OrcRelation. + */ +public class VectorizedSparkOrcNewRecordReader + extends org.apache.hadoop.mapreduce.RecordReader + implements SparkOrcNewRecordReaderBase { + private final org.apache.hadoop.mapred.RecordReader reader; + private final int numColumns; + private VectorizedRowBatch internalValue; + OrcStruct value; + private float progress = 0.0f; + private ObjectInspector objectInspector; + private List columnIDs; + + private final VectorExpressionWriter [] valueWriters; + private long numRowsOfBatch = 0; + private int indexOfRow = 0; + + public VectorizedSparkOrcNewRecordReader( + Reader file, + JobConf conf, + FileSplit fileSplit, + List columnIDs) throws IOException { + List types = file.getTypes(); + numColumns = (types.size() == 0) ? 
0 : types.get(0).getSubtypesCount(); + value = new OrcStruct(numColumns); + this.reader = new SparkVectorizedOrcRecordReader(file, conf, + new org.apache.hadoop.mapred.FileSplit(fileSplit)); + this.objectInspector = file.getObjectInspector(); + this.columnIDs = columnIDs; + this.internalValue = this.reader.createValue(); + + try { + valueWriters = VectorExpressionWriterFactory + .getExpressionWriters((StructObjectInspector) this.objectInspector); + } catch (HiveException e) { + throw new RuntimeException(e); + } + this.progress = reader.getProgress(); + } + + @Override + public void close() throws IOException { + reader.close(); + } + + @Override + public NullWritable getCurrentKey() throws IOException, + InterruptedException { + return NullWritable.get(); + } + + @Override + public OrcStruct getCurrentValue() throws IOException, + InterruptedException { + if (indexOfRow >= numRowsOfBatch) { + return null; + } + try { + for (int p = 0; p < internalValue.numCols; p++) { + // Only when this column is a required column, we populate the data. + if (columnIDs.contains(p)) { + if (internalValue.cols[p].isRepeating) { + valueWriters[p].setValue(value, internalValue.cols[p], 0); + } else { + valueWriters[p].setValue(value, internalValue.cols[p], indexOfRow); + } + } + } + } catch (HiveException e) { + throw new RuntimeException(e); + } + indexOfRow++; + + return value; + } + + @Override + public float getProgress() throws IOException, InterruptedException { + return progress; + } + + @Override + public void initialize(InputSplit split, TaskAttemptContext context) + throws IOException, InterruptedException { + } + + @Override + public boolean nextKeyValue() throws IOException, InterruptedException { + if (indexOfRow == numRowsOfBatch && progress < 1.0f) { + if (reader.next(NullWritable.get(), internalValue)) { + if (internalValue.endOfFile) { + progress = 1.0f; + numRowsOfBatch = 0; + indexOfRow = 0; + return false; + } else { + assert internalValue.numCols == numColumns : "Incorrect number of columns in OrcBatch"; + numRowsOfBatch = internalValue.count(); + indexOfRow = 0; + progress = reader.getProgress(); + } + return true; + } else { + return false; + } + } else { + if (indexOfRow < numRowsOfBatch) { + return true; + } else { + return false; + } + } + } + + @Override + public ObjectInspector getObjectInspector() { + return objectInspector; + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index a2c8092e01bb..b153f4336ecd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.hive.orc import java.net.URI import java.util.Properties +import scala.collection.JavaConverters._ + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hive.conf.HiveConf.ConfVars @@ -28,7 +30,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.{SettableStructObjectInspec import org.apache.hadoop.hive.serde2.typeinfo.{StructTypeInfo, TypeInfoUtils} import org.apache.hadoop.io.{NullWritable, Writable} import org.apache.hadoop.mapred.{InputFormat => MapRedInputFormat, JobConf, OutputFormat => MapRedOutputFormat, RecordWriter, Reporter} -import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.{RecordReader => MReduceRecordReader, _} import 
org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit} import org.apache.spark.internal.Logging @@ -40,7 +42,7 @@ import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.hive.{HiveInspectors, HiveShim} import org.apache.spark.sql.sources.{Filter, _} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{AtomicType, DateType, StructType, TimestampType} import org.apache.spark.util.SerializableConfiguration /** @@ -121,6 +123,11 @@ private[sql] class OrcFileFormat val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + val enableVectorizedReader: Boolean = + sparkSession.sessionState.conf.orcVectorizedReaderEnabled && + dataSchema.forall(f => f.dataType.isInstanceOf[AtomicType] && + !f.dataType.isInstanceOf[DateType] && !f.dataType.isInstanceOf[TimestampType]) + (file: PartitionedFile) => { val conf = broadcastedHadoopConf.value.value @@ -134,7 +141,7 @@ private[sql] class OrcFileFormat val physicalSchema = maybePhysicalSchema.get OrcRelation.setRequiredColumns(conf, physicalSchema, requiredSchema) - val orcRecordReader = { + val orcRecordReader: MReduceRecordReader[_, org.apache.hadoop.hive.ql.io.orc.OrcStruct] = { val job = Job.getInstance(conf) FileInputFormat.setInputPaths(job, file.filePath) @@ -147,14 +154,23 @@ private[sql] class OrcFileFormat // Specifically would be helpful for partitioned datasets. val orcReader = OrcFile.createReader( new Path(new URI(file.filePath)), OrcFile.readerOptions(conf)) - new SparkOrcNewRecordReader(orcReader, conf, fileSplit.getStart, fileSplit.getLength) + + if (enableVectorizedReader) { + val conf = job.getConfiguration.asInstanceOf[JobConf] + val columnIDs = + requiredSchema.map(a => physicalSchema.fieldIndex(a.name): Integer).asJava + new VectorizedSparkOrcNewRecordReader(orcReader, conf, fileSplit, columnIDs) + } else { + new SparkOrcNewRecordReader(orcReader, conf, fileSplit.getStart, fileSplit.getLength) + } } // Unwraps `OrcStruct`s to `UnsafeRow`s OrcRelation.unwrapOrcStructs( conf, requiredSchema, - Some(orcRecordReader.getObjectInspector.asInstanceOf[StructObjectInspector]), + Some(orcRecordReader.asInstanceOf[SparkOrcNewRecordReaderBase] + .getObjectInspector.asInstanceOf[StructObjectInspector]), new RecordReaderIterator[OrcStruct](orcRecordReader)) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index e6c9c5d4d9cc..fee8b3a4ad59 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -52,7 +52,6 @@ case class Contact(name: String, phone: String) case class Person(name: String, age: Int, contacts: Seq[Contact]) class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { - test("Read/write All Types") { val data = (0 to 255).map { i => (s"$i", i, i.toLong, i.toFloat, i.toDouble, i.toShort, i.toByte, i % 2 == 0) @@ -204,6 +203,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } test("simple select queries") { + val data = (0 until 10).map(i => (i, i.toString)) withOrcTable((0 until 10).map(i => (i, i.toString)), "t") { checkAnswer( sql("SELECT `_1` FROM t where t.`_1` > 5"), From eee8eca70920d624becb43c8510d217ce4d9820b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 17 Jun 
2016 09:44:11 +0000 Subject: [PATCH 02/20] import. --- .../VectorizedSparkOrcNewRecordReader.java | 158 ++++++++++++++++-- .../spark/sql/hive/orc/OrcFileFormat.scala | 59 +++---- 2 files changed, 179 insertions(+), 38 deletions(-) diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java index 4b77d35ffd7f..4ed99e4bb3c7 100644 --- a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java @@ -17,9 +17,19 @@ package org.apache.hadoop.hive.ql.io.orc; -import org.apache.hadoop.conf.Configuration; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; + +import org.apache.commons.lang.NotImplementedException; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hive.common.type.Decimal128; +import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpressionWriter; import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpressionWriterFactory; @@ -33,8 +43,13 @@ import org.apache.hadoop.mapreduce.TaskAttemptContext; import org.apache.hadoop.mapreduce.lib.input.FileSplit; -import java.io.IOException; -import java.util.List; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; /** * This is based on hive-exec-1.2.1 @@ -43,7 +58,7 @@ * NameNode calls in OrcRelation. */ public class VectorizedSparkOrcNewRecordReader - extends org.apache.hadoop.mapreduce.RecordReader + extends org.apache.hadoop.mapreduce.RecordReader implements SparkOrcNewRecordReaderBase { private final org.apache.hadoop.mapred.RecordReader reader; private final int numColumns; @@ -57,6 +72,8 @@ public class VectorizedSparkOrcNewRecordReader private long numRowsOfBatch = 0; private int indexOfRow = 0; + private final Row row; + public VectorizedSparkOrcNewRecordReader( Reader file, JobConf conf, @@ -78,6 +95,7 @@ public VectorizedSparkOrcNewRecordReader( throw new RuntimeException(e); } this.progress = reader.getProgress(); + this.row = new Row(this.internalValue.cols, columnIDs); } @Override @@ -92,12 +110,14 @@ public NullWritable getCurrentKey() throws IOException, } @Override - public OrcStruct getCurrentValue() throws IOException, + public InternalRow getCurrentValue() throws IOException, InterruptedException { if (indexOfRow >= numRowsOfBatch) { return null; } - try { + // try { + row.rowId = indexOfRow; + /* for (int p = 0; p < internalValue.numCols; p++) { // Only when this column is a required column, we populate the data. 
if (columnIDs.contains(p)) { @@ -108,12 +128,13 @@ public OrcStruct getCurrentValue() throws IOException, } } } - } catch (HiveException e) { - throw new RuntimeException(e); - } + */ + // } catch (HiveException e) { + // throw new RuntimeException(e); + // } indexOfRow++; - return value; + return row; // value; } @Override @@ -158,4 +179,121 @@ public boolean nextKeyValue() throws IOException, InterruptedException { public ObjectInspector getObjectInspector() { return objectInspector; } + + /** + * Adapter class to return an internal row. + */ + public static final class Row extends InternalRow { + protected int rowId; + private List columnIDs; + private final ColumnVector[] columns; + + private Row(ColumnVector[] columns, List columnIDs) { + this.columns = columns; + this.columnIDs = columnIDs; + } + + @Override + public int numFields() { return columns.length; } + + @Override + public boolean anyNull() { + for (int i = 0; i < columns.length; i++) { + if (columnIDs.contains(i) && columns[i].isNull[rowId]) { + return true; + } + } + return false; + } + + @Override + public boolean isNullAt(int ordinal) { return columns[columnIDs.get(ordinal)].isNull[rowId]; } + + @Override + public boolean getBoolean(int ordinal) { + return ((LongColumnVector)columns[columnIDs.get(ordinal)]).vector[rowId] > 0; + } + + @Override + public byte getByte(int ordinal) { + return (byte)((LongColumnVector)columns[columnIDs.get(ordinal)]).vector[rowId]; + } + + @Override + public short getShort(int ordinal) { + return (short)((LongColumnVector)columns[columnIDs.get(ordinal)]).vector[rowId]; + } + + @Override + public int getInt(int ordinal) { + return (int)((LongColumnVector)columns[columnIDs.get(ordinal)]).vector[rowId]; + } + + @Override + public long getLong(int ordinal) { + return (long)((LongColumnVector)columns[columnIDs.get(ordinal)]).vector[rowId]; + } + + @Override + public float getFloat(int ordinal) { + return (float)((DoubleColumnVector)columns[columnIDs.get(ordinal)]).vector[rowId]; + } + + @Override + public double getDouble(int ordinal) { + return (double)((DoubleColumnVector)columns[columnIDs.get(ordinal)]).vector[rowId]; + } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + return Decimal.apply( + ((DecimalColumnVector)columns[columnIDs.get(ordinal)]).vector[rowId].getHiveDecimal() + .bigDecimalValue(), + precision, scale); + } + + @Override + public UTF8String getUTF8String(int ordinal) { + BytesColumnVector bv = ((BytesColumnVector)columns[columnIDs.get(ordinal)]); + String str = new String(bv.vector[rowId], bv.start[rowId], bv.length[rowId], + StandardCharsets.UTF_8); + return UTF8String.fromString(str); + // new String(((BytesColumnVector)columns[columnIDs.get(ordinal)]).vector[rowId])); + } + + @Override + public byte[] getBinary(int ordinal) { + return (byte[])((BytesColumnVector)columns[columnIDs.get(ordinal)]).vector[rowId]; + } + + @Override + public CalendarInterval getInterval(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public InternalRow getStruct(int ordinal, int numFields) { + throw new NotImplementedException(); + } + + @Override + public ArrayData getArray(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public MapData getMap(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public Object get(int ordinal, DataType dataType) { + throw new NotImplementedException(); + } + + @Override + public InternalRow copy() { + throw new NotImplementedException(); + } + } } diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index b153f4336ecd..ab5e959983f6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -141,37 +141,40 @@ private[sql] class OrcFileFormat val physicalSchema = maybePhysicalSchema.get OrcRelation.setRequiredColumns(conf, physicalSchema, requiredSchema) - val orcRecordReader: MReduceRecordReader[_, org.apache.hadoop.hive.ql.io.orc.OrcStruct] = { - val job = Job.getInstance(conf) - FileInputFormat.setInputPaths(job, file.filePath) - - val fileSplit = new FileSplit( - new Path(new URI(file.filePath)), file.start, file.length, Array.empty - ) - // Custom OrcRecordReader is used to get - // ObjectInspector during recordReader creation itself and can - // avoid NameNode call in unwrapOrcStructs per file. - // Specifically would be helpful for partitioned datasets. - val orcReader = OrcFile.createReader( - new Path(new URI(file.filePath)), OrcFile.readerOptions(conf)) - - if (enableVectorizedReader) { - val conf = job.getConfiguration.asInstanceOf[JobConf] - val columnIDs = - requiredSchema.map(a => physicalSchema.fieldIndex(a.name): Integer).asJava + // val orcRecordReader: + // MReduceRecordReader[_, org.apache.hadoop.hive.ql.io.orc.OrcStruct] = { + val job = Job.getInstance(conf) + FileInputFormat.setInputPaths(job, file.filePath) + + val fileSplit = new FileSplit( + new Path(new URI(file.filePath)), file.start, file.length, Array.empty + ) + // Custom OrcRecordReader is used to get + // ObjectInspector during recordReader creation itself and can + // avoid NameNode call in unwrapOrcStructs per file. + // Specifically would be helpful for partitioned datasets. + val orcReader = OrcFile.createReader( + new Path(new URI(file.filePath)), OrcFile.readerOptions(conf)) + + if (enableVectorizedReader) { + val conf = job.getConfiguration.asInstanceOf[JobConf] + val columnIDs = + requiredSchema.map(a => physicalSchema.fieldIndex(a.name): Integer).sorted.asJava + val orcRecordReader = new VectorizedSparkOrcNewRecordReader(orcReader, conf, fileSplit, columnIDs) - } else { + new RecordReaderIterator[InternalRow](orcRecordReader) + } else { + val orcRecordReader = new SparkOrcNewRecordReader(orcReader, conf, fileSplit.getStart, fileSplit.getLength) - } + // Unwraps `OrcStruct`s to `UnsafeRow`s + OrcRelation.unwrapOrcStructs( + conf, + requiredSchema, + Some(orcRecordReader.asInstanceOf[SparkOrcNewRecordReaderBase] + .getObjectInspector.asInstanceOf[StructObjectInspector]), + new RecordReaderIterator[OrcStruct](orcRecordReader)) } - - // Unwraps `OrcStruct`s to `UnsafeRow`s - OrcRelation.unwrapOrcStructs( - conf, - requiredSchema, - Some(orcRecordReader.asInstanceOf[SparkOrcNewRecordReaderBase] - .getObjectInspector.asInstanceOf[StructObjectInspector]), - new RecordReaderIterator[OrcStruct](orcRecordReader)) + // } } } } From b753d09e3e369fc91a17d9632123dbe40d7d9dfb Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 18 Jun 2016 18:00:00 +0800 Subject: [PATCH 03/20] If column is repeating, always using row id 0. 
--- .../VectorizedSparkOrcNewRecordReader.java | 123 +++++++++++------- 1 file changed, 78 insertions(+), 45 deletions(-) diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java index 4ed99e4bb3c7..4a217c5ed1c4 100644 --- a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java @@ -63,12 +63,10 @@ public class VectorizedSparkOrcNewRecordReader private final org.apache.hadoop.mapred.RecordReader reader; private final int numColumns; private VectorizedRowBatch internalValue; - OrcStruct value; private float progress = 0.0f; private ObjectInspector objectInspector; private List columnIDs; - private final VectorExpressionWriter [] valueWriters; private long numRowsOfBatch = 0; private int indexOfRow = 0; @@ -81,19 +79,12 @@ public VectorizedSparkOrcNewRecordReader( List columnIDs) throws IOException { List types = file.getTypes(); numColumns = (types.size() == 0) ? 0 : types.get(0).getSubtypesCount(); - value = new OrcStruct(numColumns); this.reader = new SparkVectorizedOrcRecordReader(file, conf, new org.apache.hadoop.mapred.FileSplit(fileSplit)); + this.objectInspector = file.getObjectInspector(); this.columnIDs = columnIDs; this.internalValue = this.reader.createValue(); - - try { - valueWriters = VectorExpressionWriterFactory - .getExpressionWriters((StructObjectInspector) this.objectInspector); - } catch (HiveException e) { - throw new RuntimeException(e); - } this.progress = reader.getProgress(); this.row = new Row(this.internalValue.cols, columnIDs); } @@ -115,26 +106,10 @@ public InternalRow getCurrentValue() throws IOException, if (indexOfRow >= numRowsOfBatch) { return null; } - // try { row.rowId = indexOfRow; - /* - for (int p = 0; p < internalValue.numCols; p++) { - // Only when this column is a required column, we populate the data. 
- if (columnIDs.contains(p)) { - if (internalValue.cols[p].isRepeating) { - valueWriters[p].setValue(value, internalValue.cols[p], 0); - } else { - valueWriters[p].setValue(value, internalValue.cols[p], indexOfRow); - } - } - } - */ - // } catch (HiveException e) { - // throw new RuntimeException(e); - // } indexOfRow++; - return row; // value; + return row; } @Override @@ -199,71 +174,129 @@ private Row(ColumnVector[] columns, List columnIDs) { @Override public boolean anyNull() { for (int i = 0; i < columns.length; i++) { - if (columnIDs.contains(i) && columns[i].isNull[rowId]) { - return true; + if (columnIDs.contains(i)) { + if (columns[i].isRepeating && columns[i].isNull[0]) { + return true; + } else if (!columns[i].isRepeating && columns[i].isNull[rowId]) { + return true; + } } } return false; } @Override - public boolean isNullAt(int ordinal) { return columns[columnIDs.get(ordinal)].isNull[rowId]; } + public boolean isNullAt(int ordinal) { + ColumnVector col = columns[columnIDs.get(ordinal)]; + if (col.isRepeating) { + return col.isNull[0]; + } else { + return col.isNull[rowId]; + } + } @Override public boolean getBoolean(int ordinal) { - return ((LongColumnVector)columns[columnIDs.get(ordinal)]).vector[rowId] > 0; + LongColumnVector col = (LongColumnVector)columns[columnIDs.get(ordinal)]; + if (col.isRepeating) { + return col.vector[0] > 0; + } else { + return col.vector[rowId] > 0; + } } @Override public byte getByte(int ordinal) { - return (byte)((LongColumnVector)columns[columnIDs.get(ordinal)]).vector[rowId]; + LongColumnVector col = (LongColumnVector)columns[columnIDs.get(ordinal)]; + if (col.isRepeating) { + return (byte)col.vector[0]; + } else { + return (byte)col.vector[rowId]; + } } @Override public short getShort(int ordinal) { - return (short)((LongColumnVector)columns[columnIDs.get(ordinal)]).vector[rowId]; + LongColumnVector col = (LongColumnVector)columns[columnIDs.get(ordinal)]; + if (col.isRepeating) { + return (short)col.vector[0]; + } else { + return (short)col.vector[rowId]; + } } @Override public int getInt(int ordinal) { - return (int)((LongColumnVector)columns[columnIDs.get(ordinal)]).vector[rowId]; + LongColumnVector col = (LongColumnVector)columns[columnIDs.get(ordinal)]; + if (col.isRepeating) { + return (int)col.vector[0]; + } else { + return (int)col.vector[rowId]; + } } @Override public long getLong(int ordinal) { - return (long)((LongColumnVector)columns[columnIDs.get(ordinal)]).vector[rowId]; + LongColumnVector col = (LongColumnVector)columns[columnIDs.get(ordinal)]; + if (col.isRepeating) { + return (long)col.vector[0]; + } else { + return (long)col.vector[rowId]; + } } @Override public float getFloat(int ordinal) { - return (float)((DoubleColumnVector)columns[columnIDs.get(ordinal)]).vector[rowId]; + DoubleColumnVector col = (DoubleColumnVector)columns[columnIDs.get(ordinal)]; + if (col.isRepeating) { + return (float)col.vector[0]; + } else { + return (float)col.vector[rowId]; + } } @Override public double getDouble(int ordinal) { - return (double)((DoubleColumnVector)columns[columnIDs.get(ordinal)]).vector[rowId]; + DoubleColumnVector col = (DoubleColumnVector)columns[columnIDs.get(ordinal)]; + if (col.isRepeating) { + return (double)col.vector[0]; + } else { + return (double)col.vector[rowId]; + } } @Override public Decimal getDecimal(int ordinal, int precision, int scale) { - return Decimal.apply( - ((DecimalColumnVector)columns[columnIDs.get(ordinal)]).vector[rowId].getHiveDecimal() - .bigDecimalValue(), - precision, scale); + DecimalColumnVector 
col = (DecimalColumnVector)columns[columnIDs.get(ordinal)]; + if (col.isRepeating) { + return Decimal.apply(col.vector[0].getHiveDecimal().bigDecimalValue(), precision, scale); + } else { + return Decimal.apply(col.vector[rowId].getHiveDecimal().bigDecimalValue(), + precision, scale); + } } @Override public UTF8String getUTF8String(int ordinal) { BytesColumnVector bv = ((BytesColumnVector)columns[columnIDs.get(ordinal)]); - String str = new String(bv.vector[rowId], bv.start[rowId], bv.length[rowId], - StandardCharsets.UTF_8); + String str = null; + if (bv.isRepeating) { + str = new String(bv.vector[0], bv.start[0], bv.length[0], StandardCharsets.UTF_8); + } else { + str = new String(bv.vector[rowId], bv.start[rowId], bv.length[rowId], + StandardCharsets.UTF_8); + } return UTF8String.fromString(str); - // new String(((BytesColumnVector)columns[columnIDs.get(ordinal)]).vector[rowId])); } @Override public byte[] getBinary(int ordinal) { - return (byte[])((BytesColumnVector)columns[columnIDs.get(ordinal)]).vector[rowId]; + BytesColumnVector col = (BytesColumnVector)columns[columnIDs.get(ordinal)]; + if (col.isRepeating) { + return (byte[])col.vector[0]; + } else { + return (byte[])col.vector[rowId]; + } } @Override From 7d26f5ed785269299b324df8bfc1c64c2d4a2b48 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 19 Jun 2016 12:16:49 +0800 Subject: [PATCH 04/20] Fix bugs of getBinary and numFields. --- .../ql/io/orc/SparkOrcNewRecordReader.java | 4 +-- .../io/orc/SparkOrcNewRecordReaderBase.java | 24 ------------- .../VectorizedSparkOrcNewRecordReader.java | 34 ++++++------------- .../spark/sql/hive/orc/OrcFileFormat.scala | 6 +--- 4 files changed, 13 insertions(+), 55 deletions(-) delete mode 100644 sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkOrcNewRecordReaderBase.java diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkOrcNewRecordReader.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkOrcNewRecordReader.java index cdd50cb668bf..f093637d412f 100644 --- a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkOrcNewRecordReader.java +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkOrcNewRecordReader.java @@ -33,8 +33,7 @@ * NameNode calls in OrcRelation. */ public class SparkOrcNewRecordReader extends - org.apache.hadoop.mapreduce.RecordReader - implements SparkOrcNewRecordReaderBase { + org.apache.hadoop.mapreduce.RecordReader { private final org.apache.hadoop.hive.ql.io.orc.RecordReader reader; private final int numColumns; OrcStruct value; @@ -89,7 +88,6 @@ public boolean nextKeyValue() throws IOException, InterruptedException { } } - @Override public ObjectInspector getObjectInspector() { return objectInspector; } diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkOrcNewRecordReaderBase.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkOrcNewRecordReaderBase.java deleted file mode 100644 index c9af60d67b05..000000000000 --- a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkOrcNewRecordReaderBase.java +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.hadoop.hive.ql.io.orc; - -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; - -public interface SparkOrcNewRecordReaderBase { - public ObjectInspector getObjectInspector(); -} diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java index 4a217c5ed1c4..c7f320733bb5 100644 --- a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java @@ -24,21 +24,14 @@ import org.apache.commons.lang.NotImplementedException; import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.hive.common.type.Decimal128; import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector; import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector; import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; -import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpressionWriter; -import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpressionWriterFactory; -import org.apache.hadoop.hive.ql.metadata.HiveException; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.io.NullWritable; import org.apache.hadoop.mapred.JobConf; -import org.apache.hadoop.mapred.Reporter; import org.apache.hadoop.mapreduce.InputSplit; import org.apache.hadoop.mapreduce.TaskAttemptContext; import org.apache.hadoop.mapreduce.lib.input.FileSplit; @@ -52,19 +45,16 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * This is based on hive-exec-1.2.1 - * {@link org.apache.hadoop.hive.ql.io.orc.OrcNewInputFormat.OrcRecordReader}. - * This class exposes getObjectInspector which can be used for reducing - * NameNode calls in OrcRelation. + * A RecordReader that returns InternalRow for Spark SQL execution. + * This reader uses an internal reader that returns Hive's VectorizedRowBatch. An adapter + * class is used to return internal row by directly accessing data in column vectors. 
*/ public class VectorizedSparkOrcNewRecordReader - extends org.apache.hadoop.mapreduce.RecordReader - implements SparkOrcNewRecordReaderBase { + extends org.apache.hadoop.mapreduce.RecordReader { private final org.apache.hadoop.mapred.RecordReader reader; private final int numColumns; private VectorizedRowBatch internalValue; private float progress = 0.0f; - private ObjectInspector objectInspector; private List columnIDs; private long numRowsOfBatch = 0; @@ -82,7 +72,6 @@ public VectorizedSparkOrcNewRecordReader( this.reader = new SparkVectorizedOrcRecordReader(file, conf, new org.apache.hadoop.mapred.FileSplit(fileSplit)); - this.objectInspector = file.getObjectInspector(); this.columnIDs = columnIDs; this.internalValue = this.reader.createValue(); this.progress = reader.getProgress(); @@ -150,11 +139,6 @@ public boolean nextKeyValue() throws IOException, InterruptedException { } } - @Override - public ObjectInspector getObjectInspector() { - return objectInspector; - } - /** * Adapter class to return an internal row. */ @@ -169,7 +153,7 @@ private Row(ColumnVector[] columns, List columnIDs) { } @Override - public int numFields() { return columns.length; } + public int numFields() { return columnIDs.size(); } @Override public boolean anyNull() { @@ -293,9 +277,13 @@ public UTF8String getUTF8String(int ordinal) { public byte[] getBinary(int ordinal) { BytesColumnVector col = (BytesColumnVector)columns[columnIDs.get(ordinal)]; if (col.isRepeating) { - return (byte[])col.vector[0]; + byte[] binary = new byte[col.length[0]]; + System.arraycopy(col.vector[0], col.start[0], binary, 0, binary.length); + return binary; } else { - return (byte[])col.vector[rowId]; + byte[] binary = new byte[col.length[rowId]]; + System.arraycopy(col.vector[rowId], col.start[rowId], binary, 0, binary.length); + return binary; } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index ab5e959983f6..711731963693 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -141,8 +141,6 @@ private[sql] class OrcFileFormat val physicalSchema = maybePhysicalSchema.get OrcRelation.setRequiredColumns(conf, physicalSchema, requiredSchema) - // val orcRecordReader: - // MReduceRecordReader[_, org.apache.hadoop.hive.ql.io.orc.OrcStruct] = { val job = Job.getInstance(conf) FileInputFormat.setInputPaths(job, file.filePath) @@ -170,11 +168,9 @@ private[sql] class OrcFileFormat OrcRelation.unwrapOrcStructs( conf, requiredSchema, - Some(orcRecordReader.asInstanceOf[SparkOrcNewRecordReaderBase] - .getObjectInspector.asInstanceOf[StructObjectInspector]), + Some(orcRecordReader.getObjectInspector.asInstanceOf[StructObjectInspector]), new RecordReaderIterator[OrcStruct](orcRecordReader)) } - // } } } } From 74fe936e522a827384461e445b9ba44f96ce29fe Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 20 Jun 2016 10:44:07 +0800 Subject: [PATCH 05/20] Remove unnecessary change. 
--- .../test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index fee8b3a4ad59..2dc1cea31b3b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -52,6 +52,7 @@ case class Contact(name: String, phone: String) case class Person(name: String, age: Int, contacts: Seq[Contact]) class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { + test("Read/write All Types") { val data = (0 to 255).map { i => (s"$i", i, i.toLong, i.toFloat, i.toDouble, i.toShort, i.toByte, i % 2 == 0) From 7e7bb6c57860187f391f66ca82cdd715d0b2be43 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 20 Jun 2016 10:48:11 +0800 Subject: [PATCH 06/20] Remove unnecessary change. --- .../hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java | 3 +-- .../scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java index 415a004e2961..e742e887a58f 100644 --- a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java @@ -24,8 +24,7 @@ import org.apache.hadoop.mapred.RecordReader; /** - * This is based on - * {@link org.apache.hadoop.hive.ql.io.orc.VectorizedOrcInputFormat.VectorizedOrcRecordReader}. + * A mapred.RecordReader that returns VectorizedRowBatch. */ public class SparkVectorizedOrcRecordReader implements RecordReader { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 2dc1cea31b3b..e6c9c5d4d9cc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -204,7 +204,6 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } test("simple select queries") { - val data = (0 until 10).map(i => (i, i.toString)) withOrcTable((0 until 10).map(i => (i, i.toString)), "t") { checkAnswer( sql("SELECT `_1` FROM t where t.`_1` > 5"), From 20b832ee4e5ed4e794cc1bc8f2f67cce973759e0 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 20 Jun 2016 11:09:00 +0800 Subject: [PATCH 07/20] Add Apache license headers. --- .../io/orc/SparkVectorizedOrcRecordReader.java | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java index e742e887a58f..8edc9e7b659e 100644 --- a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.hadoop.hive.ql.io.orc; import java.io.IOException; From 855bcfde2067af4bd88d95a6365f976ecf891de9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 22 Jun 2016 22:38:50 +0800 Subject: [PATCH 08/20] Adjust exception. --- .../orc/SparkVectorizedOrcRecordReader.java | 86 +++++++++---------- 1 file changed, 40 insertions(+), 46 deletions(-) diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java index 8edc9e7b659e..2fc8e54661b8 100644 --- a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java @@ -28,7 +28,6 @@ import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector; import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; -import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; @@ -68,41 +67,41 @@ public class SparkVectorizedOrcRecordReader */ private ColumnVector createColumnVector(ObjectInspector inspector) { switch(inspector.getCategory()) { - case PRIMITIVE: - { - PrimitiveTypeInfo primitiveTypeInfo = - (PrimitiveTypeInfo) ((PrimitiveObjectInspector)inspector).getTypeInfo(); - switch(primitiveTypeInfo.getPrimitiveCategory()) { - case BOOLEAN: - case BYTE: - case SHORT: - case INT: - case LONG: - case DATE: - case INTERVAL_YEAR_MONTH: - return new LongColumnVector(VectorizedRowBatch.DEFAULT_SIZE); - case FLOAT: - case DOUBLE: - return new DoubleColumnVector(VectorizedRowBatch.DEFAULT_SIZE); - case BINARY: - case STRING: - case CHAR: - case VARCHAR: - BytesColumnVector column = new BytesColumnVector(VectorizedRowBatch.DEFAULT_SIZE); - column.initBuffer(); - return column; - case DECIMAL: - DecimalTypeInfo tInfo = (DecimalTypeInfo) primitiveTypeInfo; - return new DecimalColumnVector(VectorizedRowBatch.DEFAULT_SIZE, - tInfo.precision(), tInfo.scale()); - default: - throw new RuntimeException("Vectorizaton is not supported for datatype:" - + primitiveTypeInfo.getPrimitiveCategory()); + case PRIMITIVE: + { + PrimitiveTypeInfo primitiveTypeInfo = + (PrimitiveTypeInfo) ((PrimitiveObjectInspector)inspector).getTypeInfo(); + switch(primitiveTypeInfo.getPrimitiveCategory()) { + case BOOLEAN: + case BYTE: + case SHORT: + case INT: + case LONG: + case DATE: + case INTERVAL_YEAR_MONTH: + return new LongColumnVector(VectorizedRowBatch.DEFAULT_SIZE); + case FLOAT: + case DOUBLE: + return new 
DoubleColumnVector(VectorizedRowBatch.DEFAULT_SIZE); + case BINARY: + case STRING: + case CHAR: + case VARCHAR: + BytesColumnVector column = new BytesColumnVector(VectorizedRowBatch.DEFAULT_SIZE); + column.initBuffer(); + return column; + case DECIMAL: + DecimalTypeInfo tInfo = (DecimalTypeInfo) primitiveTypeInfo; + return new DecimalColumnVector(VectorizedRowBatch.DEFAULT_SIZE, + tInfo.precision(), tInfo.scale()); + default: + throw new RuntimeException("Vectorizaton is not supported for datatype:" + + primitiveTypeInfo.getPrimitiveCategory()); + } } - } - default: - throw new RuntimeException("Vectorization is not supported for datatype:" - + inspector.getCategory()); + default: + throw new RuntimeException("Vectorization is not supported for datatype:" + + inspector.getCategory()); } } @@ -113,9 +112,9 @@ private ColumnVector createColumnVector(ObjectInspector inspector) { * @param cvList ColumnVectors are populated in this list */ private void allocateColumnVector(StructObjectInspector oi, - List cvList) throws HiveException { + List cvList) { if (cvList == null) { - throw new HiveException("Null columnvector list"); + throw new RuntimeException("Null columnvector list"); } if (oi == null) { return; @@ -130,12 +129,11 @@ private void allocateColumnVector(StructObjectInspector oi, /** * Create VectorizedRowBatch from ObjectInspector * - * @param oi - * @return - * @throws HiveException + * @param oi StructObjectInspector + * @return VectorizedRowBatch */ private VectorizedRowBatch constructVectorizedRowBatch( - StructObjectInspector oi) throws HiveException { + StructObjectInspector oi) { final List cvList = new LinkedList(); allocateColumnVector(oi, cvList); final VectorizedRowBatch result = new VectorizedRowBatch(cvList.size()); @@ -167,11 +165,7 @@ public NullWritable createKey() { @Override public VectorizedRowBatch createValue() { - try { - return constructVectorizedRowBatch((StructObjectInspector)this.objectInspector); - } catch (HiveException e) { - } - return null; + return constructVectorizedRowBatch((StructObjectInspector)this.objectInspector); } @Override From 66ab632274674ae5b38c84bac8801feab3c9d2e0 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 23 Jun 2016 10:07:54 +0800 Subject: [PATCH 09/20] Avoid creating String in getUTF8String. 
--- .../hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java index c7f320733bb5..47ca4f7fc66b 100644 --- a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java @@ -263,14 +263,11 @@ public Decimal getDecimal(int ordinal, int precision, int scale) { @Override public UTF8String getUTF8String(int ordinal) { BytesColumnVector bv = ((BytesColumnVector)columns[columnIDs.get(ordinal)]); - String str = null; if (bv.isRepeating) { - str = new String(bv.vector[0], bv.start[0], bv.length[0], StandardCharsets.UTF_8); + return UTF8String.fromBytes(bv.vector[0], bv.start[0], bv.length[0]); } else { - str = new String(bv.vector[rowId], bv.start[rowId], bv.length[rowId], - StandardCharsets.UTF_8); + return UTF8String.fromBytes(bv.vector[rowId], bv.start[rowId], bv.length[rowId]); } - return UTF8String.fromString(str); } @Override From b067658c53a3252f0a8a288e09b07feaf0ace8d4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 10 Aug 2016 12:15:58 +0800 Subject: [PATCH 10/20] Address comment. --- .../hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java index 47ca4f7fc66b..019485aecad7 100644 --- a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.util.ArrayList; import java.util.List; import org.apache.commons.lang.NotImplementedException; @@ -72,10 +73,10 @@ public VectorizedSparkOrcNewRecordReader( this.reader = new SparkVectorizedOrcRecordReader(file, conf, new org.apache.hadoop.mapred.FileSplit(fileSplit)); - this.columnIDs = columnIDs; + this.columnIDs = new ArrayList<>(columnIDs); this.internalValue = this.reader.createValue(); this.progress = reader.getProgress(); - this.row = new Row(this.internalValue.cols, columnIDs); + this.row = new Row(this.internalValue.cols, this.columnIDs); } @Override From 06066eb241eb97c4cf363adff2b0160b8a423ab8 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 11 Aug 2016 12:57:27 +0800 Subject: [PATCH 11/20] Don't rely on progress to indicate last batch. 
--- .../orc/SparkVectorizedOrcRecordReader.java | 20 +++++++++++-------- .../VectorizedSparkOrcNewRecordReader.java | 2 +- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java index 2fc8e54661b8..7a77da763863 100644 --- a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java @@ -146,16 +146,20 @@ private VectorizedRowBatch constructVectorizedRowBatch( @Override public boolean next(NullWritable key, VectorizedRowBatch value) throws IOException { - try { - reader.nextBatch(value); - if (value == null || value.endOfFile || value.size == 0) { - return false; + if (reader.hasNext()) { + try { + reader.nextBatch(value); + if (value == null || value.endOfFile || value.size == 0) { + return false; + } + } catch (Exception e) { + throw new RuntimeException(e); } - } catch (Exception e) { - throw new RuntimeException(e); + progress = reader.getProgress(); + return true; + } else { + return false; } - progress = reader.getProgress(); - return true; } @Override diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java index 019485aecad7..26a5629b089a 100644 --- a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java @@ -114,7 +114,7 @@ public void initialize(InputSplit split, TaskAttemptContext context) @Override public boolean nextKeyValue() throws IOException, InterruptedException { - if (indexOfRow == numRowsOfBatch && progress < 1.0f) { + if (indexOfRow == numRowsOfBatch) { if (reader.next(NullWritable.get(), internalValue)) { if (internalValue.endOfFile) { progress = 1.0f; From 7a47360895713c6e1eb8c4b0faef0870555a3be6 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 3 Nov 2016 14:51:52 +0000 Subject: [PATCH 12/20] Address comments. 
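With this change the vectorized path becomes opt-in for the moment, and the improved error messages point at the switch to flip. A short usage sketch in Scala (the session setup is illustrative):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().enableHiveSupport().getOrCreate()
    // Opt in to the vectorized ORC reader for this session.
    spark.conf.set("spark.sql.orc.enableVectorizedReader", "true")
    // Fall back to the row-based reader, e.g. when a data type is not supported.
    spark.conf.set("spark.sql.orc.enableVectorizedReader", "false")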
--- .../org/apache/spark/sql/internal/SQLConf.scala | 2 +- .../ql/io/orc/SparkVectorizedOrcRecordReader.java | 14 +++++++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 9cdec11e58a4..40caf6afd245 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -255,7 +255,7 @@ object SQLConf { SQLConfigBuilder("spark.sql.orc.enableVectorizedReader") .doc("Enables vectorized orc reader.") .booleanConf - .createWithDefault(true) + .createWithDefault(false) val ORC_FILTER_PUSHDOWN_ENABLED = SQLConfigBuilder("spark.sql.orc.filterPushdown") .doc("When true, enable filter pushdown for ORC files.") diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java index 7a77da763863..1ec424dac1ad 100644 --- a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java @@ -50,7 +50,9 @@ public class SparkVectorizedOrcRecordReader private float progress = 0.0f; private ObjectInspector objectInspector; - SparkVectorizedOrcRecordReader(Reader file, Configuration conf, + SparkVectorizedOrcRecordReader( + Reader file, + Configuration conf, FileSplit fileSplit) throws IOException { this.offset = fileSplit.getStart(); this.length = fileSplit.getLength(); @@ -91,17 +93,19 @@ private ColumnVector createColumnVector(ObjectInspector inspector) { column.initBuffer(); return column; case DECIMAL: - DecimalTypeInfo tInfo = (DecimalTypeInfo) primitiveTypeInfo; + DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) primitiveTypeInfo; return new DecimalColumnVector(VectorizedRowBatch.DEFAULT_SIZE, - tInfo.precision(), tInfo.scale()); + decimalTypeInfo.precision(), decimalTypeInfo.scale()); default: throw new RuntimeException("Vectorizaton is not supported for datatype:" - + primitiveTypeInfo.getPrimitiveCategory()); + + primitiveTypeInfo.getPrimitiveCategory() + ". " + + "Please disable spark.sql.orc.enableVectorizedReader."); } } default: throw new RuntimeException("Vectorization is not supported for datatype:" - + inspector.getCategory()); + + inspector.getCategory() + ". " + + "Please disable the config spark.sql.orc.enableVectorizedReader."); } } From 3895a980a2aae2dc7dedbf0797bb8a37d089e683 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 22 Nov 2016 09:48:33 +0000 Subject: [PATCH 13/20] Address comments. 
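The simplified getters below all rely on one convention of Hive's vectorized batches: when isRepeating is set, slot 0 holds the value for every row. A minimal Scala sketch of that access pattern (the object and method names are illustrative):

    import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector

    object RepeatingAccessSketch {
      // With isRepeating set, index 0 stands in for every row of the batch.
      def longAt(col: LongColumnVector, rowId: Int): Long = {
        val index = if (col.isRepeating) 0 else rowId
        col.vector(index)
      }
    }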
--- .../orc/SparkVectorizedOrcRecordReader.java | 46 ++---- .../VectorizedSparkOrcNewRecordReader.java | 132 +++++++----------- .../spark/sql/hive/orc/OrcFileFormat.scala | 1 - 3 files changed, 64 insertions(+), 115 deletions(-) diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java index 1ec424dac1ad..7f0e94db402f 100644 --- a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java @@ -109,41 +109,19 @@ private ColumnVector createColumnVector(ObjectInspector inspector) { } } - /** - * Walk through the object inspector and add column vectors - * - * @param oi StructObjectInspector - * @param cvList ColumnVectors are populated in this list - */ - private void allocateColumnVector(StructObjectInspector oi, - List cvList) { - if (cvList == null) { - throw new RuntimeException("Null columnvector list"); - } - if (oi == null) { - return; - } - final List fields = oi.getAllStructFieldRefs(); - for(StructField field : fields) { - ObjectInspector fieldObjectInspector = field.getFieldObjectInspector(); - cvList.add(createColumnVector(fieldObjectInspector)); - } - } - /** * Create VectorizedRowBatch from ObjectInspector * * @param oi StructObjectInspector * @return VectorizedRowBatch */ - private VectorizedRowBatch constructVectorizedRowBatch( - StructObjectInspector oi) { - final List cvList = new LinkedList(); - allocateColumnVector(oi, cvList); - final VectorizedRowBatch result = new VectorizedRowBatch(cvList.size()); + private VectorizedRowBatch constructVectorizedRowBatch(StructObjectInspector oi) { + List fields = oi.getAllStructFieldRefs(); + VectorizedRowBatch result = new VectorizedRowBatch(fields.size()); int i = 0; - for(ColumnVector cv : cvList) { - result.cols[i++] = cv; + for (StructField field : fields) { + ObjectInspector fieldObjectInspector = field.getFieldObjectInspector(); + result.cols[i++] = createColumnVector(fieldObjectInspector); } return result; } @@ -153,17 +131,13 @@ public boolean next(NullWritable key, VectorizedRowBatch value) throws IOExcepti if (reader.hasNext()) { try { reader.nextBatch(value); - if (value == null || value.endOfFile || value.size == 0) { - return false; - } + progress = reader.getProgress(); + return (value != null && !value.endOfFile && value.size > 0); } catch (Exception e) { throw new RuntimeException(e); } - progress = reader.getProgress(); - return true; - } else { - return false; } + return false; } @Override @@ -173,7 +147,7 @@ public NullWritable createKey() { @Override public VectorizedRowBatch createValue() { - return constructVectorizedRowBatch((StructObjectInspector)this.objectInspector); + return constructVectorizedRowBatch((StructObjectInspector) this.objectInspector); } @Override diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java index 26a5629b089a..f001bd393d5f 100644 --- a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java @@ -32,7 +32,6 @@ import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; 
import org.apache.hadoop.io.NullWritable; -import org.apache.hadoop.mapred.JobConf; import org.apache.hadoop.mapreduce.InputSplit; import org.apache.hadoop.mapreduce.TaskAttemptContext; import org.apache.hadoop.mapreduce.lib.input.FileSplit; @@ -65,7 +64,7 @@ public class VectorizedSparkOrcNewRecordReader public VectorizedSparkOrcNewRecordReader( Reader file, - JobConf conf, + Configuration conf, FileSplit fileSplit, List columnIDs) throws IOException { List types = file.getTypes(); @@ -85,14 +84,12 @@ public void close() throws IOException { } @Override - public NullWritable getCurrentKey() throws IOException, - InterruptedException { + public NullWritable getCurrentKey() throws IOException, InterruptedException { return NullWritable.get(); } @Override - public InternalRow getCurrentValue() throws IOException, - InterruptedException { + public InternalRow getCurrentValue() throws IOException, InterruptedException { if (indexOfRow >= numRowsOfBatch) { return null; } @@ -160,9 +157,10 @@ private Row(ColumnVector[] columns, List columnIDs) { public boolean anyNull() { for (int i = 0; i < columns.length; i++) { if (columnIDs.contains(i)) { - if (columns[i].isRepeating && columns[i].isNull[0]) { + boolean isRepeating = columns[i].isRepeating; + if (isRepeating && columns[i].isNull[0]) { return true; - } else if (!columns[i].isRepeating && columns[i].isNull[rowId]) { + } else if (!isRepeating && columns[i].isNull[rowId]) { return true; } } @@ -170,136 +168,114 @@ public boolean anyNull() { return false; } + private int getColIndex(ColumnVector col) { + return col.isRepeating ? 0 : rowId; + } + @Override public boolean isNullAt(int ordinal) { ColumnVector col = columns[columnIDs.get(ordinal)]; - if (col.isRepeating) { - return col.isNull[0]; - } else { - return col.isNull[rowId]; - } + return col.isNull[getColIndex(col)]; } @Override public boolean getBoolean(int ordinal) { - LongColumnVector col = (LongColumnVector)columns[columnIDs.get(ordinal)]; - if (col.isRepeating) { - return col.vector[0] > 0; - } else { - return col.vector[rowId] > 0; - } + LongColumnVector col = (LongColumnVector) columns[columnIDs.get(ordinal)]; + return col.vector[getColIndex(col)] > 0; } @Override public byte getByte(int ordinal) { - LongColumnVector col = (LongColumnVector)columns[columnIDs.get(ordinal)]; - if (col.isRepeating) { - return (byte)col.vector[0]; - } else { - return (byte)col.vector[rowId]; - } + LongColumnVector col = (LongColumnVector) columns[columnIDs.get(ordinal)]; + return (byte)col.vector[getColIndex(col)]; } @Override public short getShort(int ordinal) { - LongColumnVector col = (LongColumnVector)columns[columnIDs.get(ordinal)]; - if (col.isRepeating) { - return (short)col.vector[0]; - } else { - return (short)col.vector[rowId]; - } + LongColumnVector col = (LongColumnVector) columns[columnIDs.get(ordinal)]; + return (short)col.vector[getColIndex(col)]; } @Override public int getInt(int ordinal) { - LongColumnVector col = (LongColumnVector)columns[columnIDs.get(ordinal)]; - if (col.isRepeating) { - return (int)col.vector[0]; - } else { - return (int)col.vector[rowId]; - } + LongColumnVector col = (LongColumnVector) columns[columnIDs.get(ordinal)]; + return (int)col.vector[getColIndex(col)]; } @Override public long getLong(int ordinal) { - LongColumnVector col = (LongColumnVector)columns[columnIDs.get(ordinal)]; - if (col.isRepeating) { - return (long)col.vector[0]; - } else { - return (long)col.vector[rowId]; - } + LongColumnVector col = (LongColumnVector) columns[columnIDs.get(ordinal)]; + 
return (long)col.vector[getColIndex(col)]; } @Override public float getFloat(int ordinal) { - DoubleColumnVector col = (DoubleColumnVector)columns[columnIDs.get(ordinal)]; - if (col.isRepeating) { - return (float)col.vector[0]; - } else { - return (float)col.vector[rowId]; - } + DoubleColumnVector col = (DoubleColumnVector) columns[columnIDs.get(ordinal)]; + return (float)col.vector[getColIndex(col)]; } @Override public double getDouble(int ordinal) { - DoubleColumnVector col = (DoubleColumnVector)columns[columnIDs.get(ordinal)]; - if (col.isRepeating) { - return (double)col.vector[0]; - } else { - return (double)col.vector[rowId]; - } + DoubleColumnVector col = (DoubleColumnVector) columns[columnIDs.get(ordinal)]; + return (double)col.vector[getColIndex(col)]; } @Override public Decimal getDecimal(int ordinal, int precision, int scale) { - DecimalColumnVector col = (DecimalColumnVector)columns[columnIDs.get(ordinal)]; - if (col.isRepeating) { - return Decimal.apply(col.vector[0].getHiveDecimal().bigDecimalValue(), precision, scale); - } else { - return Decimal.apply(col.vector[rowId].getHiveDecimal().bigDecimalValue(), - precision, scale); - } + DecimalColumnVector col = (DecimalColumnVector) columns[columnIDs.get(ordinal)]; + int index = getColIndex(col); + return Decimal.apply(col.vector[index].getHiveDecimal().bigDecimalValue(), precision, scale); } @Override public UTF8String getUTF8String(int ordinal) { - BytesColumnVector bv = ((BytesColumnVector)columns[columnIDs.get(ordinal)]); - if (bv.isRepeating) { - return UTF8String.fromBytes(bv.vector[0], bv.start[0], bv.length[0]); - } else { - return UTF8String.fromBytes(bv.vector[rowId], bv.start[rowId], bv.length[rowId]); - } + BytesColumnVector col = ((BytesColumnVector) columns[columnIDs.get(ordinal)]); + int index = getColIndex(col); + return UTF8String.fromBytes(col.vector[index], col.start[index], col.length[index]); } @Override public byte[] getBinary(int ordinal) { - BytesColumnVector col = (BytesColumnVector)columns[columnIDs.get(ordinal)]; - if (col.isRepeating) { - byte[] binary = new byte[col.length[0]]; - System.arraycopy(col.vector[0], col.start[0], binary, 0, binary.length); - return binary; - } else { - byte[] binary = new byte[col.length[rowId]]; - System.arraycopy(col.vector[rowId], col.start[rowId], binary, 0, binary.length); - return binary; - } + BytesColumnVector col = (BytesColumnVector) columns[columnIDs.get(ordinal)]; + int index = getColIndex(col); + byte[] binary = new byte[col.length[index]]; + System.arraycopy(col.vector[index], col.start[index], binary, 0, binary.length); + return binary; } + /** + * The data type CalendarInterval is not suppported due to the Hive version used by Spark + * internally. When we upgrade to newer Hive versions in the future, this is possibly to + * support. + */ @Override public CalendarInterval getInterval(int ordinal) { throw new NotImplementedException(); } + /** + * The data type CalendarInterval is not suppported due to the Hive version used by Spark + * internally. When we upgrade to newer Hive versions in the future, this is possibly to + * be supported. + */ @Override public InternalRow getStruct(int ordinal, int numFields) { throw new NotImplementedException(); } + /** + * The data type Array is not suppported due to the Hive version used by Spark internally. + * When we upgrade to newer Hive versions in the future, this is possibly to be supported. 
+ */ @Override public ArrayData getArray(int ordinal) { throw new NotImplementedException(); } + /** + * The data type Map is not suppported due to the Hive version used by Spark internally. + * When we upgrade to newer Hive versions in the future, this is possibly to be supported. + */ @Override public MapData getMap(int ordinal) { throw new NotImplementedException(); diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 9de16f378ade..945e69a5794c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -152,7 +152,6 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable new Path(new URI(file.filePath)), OrcFile.readerOptions(conf)) if (enableVectorizedReader) { - val conf = job.getConfiguration.asInstanceOf[JobConf] val columnIDs = requiredSchema.map(a => physicalSchema.fieldIndex(a.name): Integer).sorted.asJava val orcRecordReader = From c24169d513c53eb9887f53749a4a6a4e51351667 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 22 Nov 2016 11:02:08 +0000 Subject: [PATCH 14/20] Implement few newly added methods of InternalRow. --- .../ql/io/orc/VectorizedSparkOrcNewRecordReader.java | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java index f001bd393d5f..3804ea69b06b 100644 --- a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java @@ -290,5 +290,15 @@ public Object get(int ordinal, DataType dataType) { public InternalRow copy() { throw new NotImplementedException(); } + + @Override + public void setNullAt(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public void update(int ordinal, Object value) { + throw new NotImplementedException(); + } } } From c2976788255588d66ad2527646e0719e32bdf182 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 23 Nov 2016 14:23:07 +0000 Subject: [PATCH 15/20] Support return Spark ColumnarBatch. --- .../execution/vectorized/ColumnVector.java | 14 +- .../execution/vectorized/ColumnarBatch.java | 12 + .../apache/spark/sql/internal/SQLConf.scala | 2 +- .../hive/ql/io/orc/OrcColumnVector.java | 354 ++++++++++++++++++ .../orc/SparkVectorizedOrcRecordReader.java | 13 +- .../VectorizedSparkOrcNewRecordReader.java | 286 +++++--------- .../spark/sql/hive/orc/OrcFileFormat.scala | 57 ++- 7 files changed, 524 insertions(+), 214 deletions(-) create mode 100644 sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/OrcColumnVector.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index ff07940422a0..afbcd8710a9e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -588,7 +588,7 @@ public MapData getMap(int ordinal) { /** * Returns the decimal for rowId. 
*/ - public final Decimal getDecimal(int rowId, int precision, int scale) { + public Decimal getDecimal(int rowId, int precision, int scale) { if (precision <= Decimal.MAX_INT_DIGITS()) { return Decimal.createUnsafe(getInt(rowId), precision, scale); } else if (precision <= Decimal.MAX_LONG_DIGITS()) { @@ -617,7 +617,7 @@ public final void putDecimal(int rowId, Decimal value, int precision) { /** * Returns the UTF8String for rowId. */ - public final UTF8String getUTF8String(int rowId) { + public UTF8String getUTF8String(int rowId) { if (dictionary == null) { ColumnVector.Array a = getByteArray(rowId); return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); @@ -630,7 +630,7 @@ public final UTF8String getUTF8String(int rowId) { /** * Returns the byte array for rowId. */ - public final byte[] getBinary(int rowId) { + public byte[] getBinary(int rowId) { if (dictionary == null) { ColumnVector.Array array = getByteArray(rowId); byte[] bytes = new byte[array.length]; @@ -980,6 +980,14 @@ public ColumnVector getDictionaryIds() { return dictionaryIds; } + public ColumnVector() { + this.capacity = 0; + this.type = null; + this.childColumns = null; + this.resultArray = null; + this.resultStruct = null; + } + /** * Sets up the common state and also handles creating the child columns if this is a nested * type. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index a6ce4c2edc23..a0e31e45cf9d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -466,6 +466,18 @@ public void filterNullsInColumn(int ordinal) { nullFilteredColumns.add(ordinal); } + /** + * A public Ctor which accepts allocated ColumnVectors. + */ + public ColumnarBatch(ColumnVector[] columns, int maxRows) { + this.columns = columns; + this.capacity = maxRows; + this.schema = null; + this.nullFilteredColumns = new HashSet<>(); + this.filteredRows = new boolean[maxRows]; + this.row = new Row(this); + } + private ColumnarBatch(StructType schema, int maxRows, MemoryMode memMode) { this.schema = schema; this.capacity = maxRows; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5cd6ca6c3d6a..9776f8ff7679 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -252,7 +252,7 @@ object SQLConf { SQLConfigBuilder("spark.sql.orc.enableVectorizedReader") .doc("Enables vectorized orc reader.") .booleanConf - .createWithDefault(false) + .createWithDefault(true) val ORC_FILTER_PUSHDOWN_ENABLED = SQLConfigBuilder("spark.sql.orc.filterPushdown") .doc("When true, enable filter pushdown for ORC files.") diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/OrcColumnVector.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/OrcColumnVector.java new file mode 100644 index 000000000000..3bf9c52e9a19 --- /dev/null +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/OrcColumnVector.java @@ -0,0 +1,354 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.ql.io.orc; + +import org.apache.commons.lang.NotImplementedException; +import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; + +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A column wrapping Hive's ColumnVector. This column vector is used to adapt Hive's ColumnVector + * with Spark ColumnBatch. + */ +public class OrcColumnVector extends org.apache.spark.sql.execution.vectorized.ColumnVector { + private ColumnVector col; + + public OrcColumnVector(ColumnVector col) { + this.col = col; + } + + /* A helper method to get the row index in a column. */ + private int getRowIndex(int rowId) { + return this.col.isRepeating ? 0 : rowId; + } + + @Override + public long valuesNativeAddress() { + throw new NotImplementedException(); + } + + @Override + public long nullsNativeAddress() { + throw new NotImplementedException(); + } + + @Override + public void close() { + } + + // + // APIs dealing with nulls + // + + @Override + public void putNotNull(int rowId) { + throw new NotImplementedException(); + } + + @Override + public void putNull(int rowId) { + throw new NotImplementedException(); + } + + @Override + public void putNulls(int rowId, int count) { + throw new NotImplementedException(); + } + + @Override + public void putNotNulls(int rowId, int count) { + throw new NotImplementedException(); + } + + @Override + public boolean isNullAt(int rowId) { + return col.isNull[getRowIndex(rowId)]; + } + + // + // APIs dealing with Booleans + // + + @Override + public void putBoolean(int rowId, boolean value) { + throw new NotImplementedException(); + } + + @Override + public void putBooleans(int rowId, int count, boolean value) { + throw new NotImplementedException(); + } + + @Override + public boolean getBoolean(int rowId) { + LongColumnVector col = (LongColumnVector) this.col; + return col.vector[getRowIndex(rowId)] > 0; + } + + // + // APIs dealing with Bytes + // + + @Override + public void putByte(int rowId, byte value) { + throw new NotImplementedException(); + } + + @Override + public void putBytes(int rowId, int count, byte value) { + throw new NotImplementedException(); + } + + @Override + public void putBytes(int rowId, int count, byte[] src, int srcIndex) { + throw new NotImplementedException(); + } + + @Override + public byte getByte(int rowId) { + LongColumnVector col = (LongColumnVector) this.col; + return (byte) col.vector[getRowIndex(rowId)]; + } + + // + // APIs dealing with Shorts + // + + @Override + public void putShort(int rowId, short value) { + throw new NotImplementedException(); + } + + @Override + public void putShorts(int rowId, int count, short 
value) { + throw new NotImplementedException(); + } + + @Override + public void putShorts(int rowId, int count, short[] src, int srcIndex) { + throw new NotImplementedException(); + } + + @Override + public short getShort(int rowId) { + LongColumnVector col = (LongColumnVector) this.col; + return (short) col.vector[getRowIndex(rowId)]; + } + + // + // APIs dealing with Ints + // + + @Override + public void putInt(int rowId, int value) { + throw new NotImplementedException(); + } + + @Override + public void putInts(int rowId, int count, int value) { + throw new NotImplementedException(); + } + + @Override + public void putInts(int rowId, int count, int[] src, int srcIndex) { + throw new NotImplementedException(); + } + + @Override + public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + throw new NotImplementedException(); + } + + @Override + public int getInt(int rowId) { + LongColumnVector col = (LongColumnVector) this.col; + return (int) col.vector[getRowIndex(rowId)]; + } + + /** + * Returns the dictionary Id for rowId. + */ + @Override + public int getDictId(int rowId) { + throw new NotImplementedException(); + } + + // + // APIs dealing with Longs + // + + @Override + public void putLong(int rowId, long value) { + throw new NotImplementedException(); + } + + @Override + public void putLongs(int rowId, int count, long value) { + throw new NotImplementedException(); + } + + @Override + public void putLongs(int rowId, int count, long[] src, int srcIndex) { + throw new NotImplementedException(); + } + + @Override + public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + throw new NotImplementedException(); + } + + @Override + public long getLong(int rowId) { + LongColumnVector col = (LongColumnVector) this.col; + return (long) col.vector[getRowIndex(rowId)]; + } + + // + // APIs dealing with floats + // + + @Override + public void putFloat(int rowId, float value) { + throw new NotImplementedException(); + } + + @Override + public void putFloats(int rowId, int count, float value) { + throw new NotImplementedException(); + } + + @Override + public void putFloats(int rowId, int count, float[] src, int srcIndex) { + throw new NotImplementedException(); + } + + @Override + public void putFloats(int rowId, int count, byte[] src, int srcIndex) { + throw new NotImplementedException(); + } + + @Override + public float getFloat(int rowId) { + DoubleColumnVector col = (DoubleColumnVector) this.col; + return (float) col.vector[getRowIndex(rowId)]; + } + + // + // APIs dealing with doubles + // + + @Override + public void putDouble(int rowId, double value) { + throw new NotImplementedException(); + } + + @Override + public void putDoubles(int rowId, int count, double value) { + throw new NotImplementedException(); + } + + @Override + public void putDoubles(int rowId, int count, double[] src, int srcIndex) { + throw new NotImplementedException(); + } + + @Override + public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { + throw new NotImplementedException(); + } + + @Override + public double getDouble(int rowId) { + DoubleColumnVector col = (DoubleColumnVector) this.col; + return (double) col.vector[getRowIndex(rowId)]; + } + + // + // APIs dealing with Arrays + // + + @Override + public int getArrayLength(int rowId) { + throw new NotImplementedException(); + } + + @Override + public int getArrayOffset(int rowId) { + throw new NotImplementedException(); + } + + @Override + public void putArray(int rowId, int offset, int length) 
{ + throw new NotImplementedException(); + } + + @Override + public void loadBytes(org.apache.spark.sql.execution.vectorized.ColumnVector.Array array) { + throw new NotImplementedException(); + } + + /** + * Returns the decimal for rowId. + */ + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + DecimalColumnVector col = (DecimalColumnVector) this.col; + int index = getRowIndex(rowId); + return Decimal.apply(col.vector[index].getHiveDecimal().bigDecimalValue(), precision, scale); + } + + /** + * Returns the UTF8String for rowId. + */ + @Override + public UTF8String getUTF8String(int rowId) { + BytesColumnVector col = (BytesColumnVector) this.col; + int index = getRowIndex(rowId); + return UTF8String.fromBytes(col.vector[index], col.start[index], col.length[index]); + } + + /** + * Returns the byte array for rowId. + */ + @Override + public byte[] getBinary(int rowId) { + BytesColumnVector col = (BytesColumnVector) this.col; + int index = getRowIndex(rowId); + byte[] binary = new byte[col.length[index]]; + System.arraycopy(col.vector[index], col.start[index], binary, 0, binary.length); + return binary; + } + + // + // APIs dealing with Byte Arrays + // + @Override + public int putByteArray(int rowId, byte[] value, int offset, int length) { + throw new NotImplementedException(); + } + + @Override + protected void reserveInternal(int newCapacity) { + throw new NotImplementedException(); + } +} diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java index 7f0e94db402f..b840b0fab7a4 100644 --- a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java @@ -49,14 +49,17 @@ public class SparkVectorizedOrcRecordReader private final long length; private float progress = 0.0f; private ObjectInspector objectInspector; + private List columnIDs; SparkVectorizedOrcRecordReader( Reader file, Configuration conf, - FileSplit fileSplit) throws IOException { + FileSplit fileSplit, + List columnIDs) throws IOException { this.offset = fileSplit.getStart(); this.length = fileSplit.getLength(); this.objectInspector = file.getObjectInspector(); + this.columnIDs = columnIDs; this.reader = OrcInputFormat.createReaderFromFile(file, conf, this.offset, this.length); this.progress = reader.getProgress(); @@ -118,10 +121,10 @@ private ColumnVector createColumnVector(ObjectInspector inspector) { private VectorizedRowBatch constructVectorizedRowBatch(StructObjectInspector oi) { List fields = oi.getAllStructFieldRefs(); VectorizedRowBatch result = new VectorizedRowBatch(fields.size()); - int i = 0; - for (StructField field : fields) { - ObjectInspector fieldObjectInspector = field.getFieldObjectInspector(); - result.cols[i++] = createColumnVector(fieldObjectInspector); + for (int i = 0; i < columnIDs.size(); i++) { + int fieldIndex = columnIDs.get(i); + ObjectInspector fieldObjectInspector = fields.get(fieldIndex).getFieldObjectInspector(); + result.cols[fieldIndex] = createColumnVector(fieldObjectInspector); } return result; } diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java index 3804ea69b06b..7b758c7fd8c1 100644 --- 
a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java @@ -25,7 +25,6 @@ import org.apache.commons.lang.NotImplementedException; import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector; import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector; @@ -36,46 +35,88 @@ import org.apache.hadoop.mapreduce.TaskAttemptContext; import org.apache.hadoop.mapreduce.lib.input.FileSplit; +import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.execution.vectorized.ColumnarBatch; +import org.apache.spark.sql.execution.vectorized.ColumnVector; +import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; /** - * A RecordReader that returns InternalRow for Spark SQL execution. - * This reader uses an internal reader that returns Hive's VectorizedRowBatch. An adapter - * class is used to return internal row by directly accessing data in column vectors. + * A RecordReader that returns ColumnarBatch for Spark SQL execution. + * This reader uses an internal reader that returns Hive's VectorizedRowBatch. */ public class VectorizedSparkOrcNewRecordReader - extends org.apache.hadoop.mapreduce.RecordReader { + extends org.apache.hadoop.mapreduce.RecordReader { private final org.apache.hadoop.mapred.RecordReader reader; private final int numColumns; - private VectorizedRowBatch internalValue; + private VectorizedRowBatch hiveBatch; private float progress = 0.0f; private List columnIDs; + private ColumnVector[] orcColumns; + private ColumnarBatch columnarBatch;; + + /** + * If true, this class returns batches instead of rows. + */ + private boolean returnColumnarBatch; + private long numRowsOfBatch = 0; private int indexOfRow = 0; - private final Row row; - public VectorizedSparkOrcNewRecordReader( Reader file, Configuration conf, FileSplit fileSplit, - List columnIDs) throws IOException { + List columnIDs, + StructType partitionColumns, + InternalRow partitionValues) throws IOException { List types = file.getTypes(); numColumns = (types.size() == 0) ? 0 : types.get(0).getSubtypesCount(); this.reader = new SparkVectorizedOrcRecordReader(file, conf, - new org.apache.hadoop.mapred.FileSplit(fileSplit)); + new org.apache.hadoop.mapred.FileSplit(fileSplit), columnIDs); + + this.hiveBatch = this.reader.createValue(); this.columnIDs = new ArrayList<>(columnIDs); - this.internalValue = this.reader.createValue(); + this.orcColumns = new ColumnVector[columnIDs.size() + partitionValues.numFields()]; + + // Allocate Spark ColumnVectors for data columns. + for (int i = 0; i < columnIDs.size(); i++) { + org.apache.hadoop.hive.ql.exec.vector.ColumnVector col = + this.hiveBatch.cols[columnIDs.get(i)]; + this.orcColumns[i] = new OrcColumnVector(col); + } + + // Allocate Spark ColumnVectors for partition columns. 
+ if (partitionValues.numFields() > 0) { + int i = 0; + int base = columnIDs.size(); + for (StructField f : partitionColumns.fields()) { + // Use onheap for partition column vectors. + ColumnVector col = ColumnVector.allocate( + VectorizedRowBatch.DEFAULT_SIZE, + f.dataType(), + MemoryMode.ON_HEAP); + ColumnVectorUtils.populate(col, partitionValues, i); + col.setIsConstant(); + this.orcColumns[base + i] = col; + i++; + } + } + + // Allocate Spark ColumnBatch + this.columnarBatch = new ColumnarBatch(this.orcColumns, VectorizedRowBatch.DEFAULT_SIZE); + this.progress = reader.getProgress(); - this.row = new Row(this.internalValue.cols, this.columnIDs); } @Override @@ -83,20 +124,22 @@ public void close() throws IOException { reader.close(); } + /* + * Can be called before any rows are returned to enable returning columnar batches directly. + */ + public void enableReturningBatches() { + returnColumnarBatch = true; + } + @Override public NullWritable getCurrentKey() throws IOException, InterruptedException { return NullWritable.get(); } @Override - public InternalRow getCurrentValue() throws IOException, InterruptedException { - if (indexOfRow >= numRowsOfBatch) { - return null; - } - row.rowId = indexOfRow; - indexOfRow++; - - return row; + public Object getCurrentValue() throws IOException, InterruptedException { + if (returnColumnarBatch) return this.columnarBatch; + return columnarBatch.getRow(indexOfRow - 1); } @Override @@ -111,194 +154,37 @@ public void initialize(InputSplit split, TaskAttemptContext context) @Override public boolean nextKeyValue() throws IOException, InterruptedException { - if (indexOfRow == numRowsOfBatch) { - if (reader.next(NullWritable.get(), internalValue)) { - if (internalValue.endOfFile) { - progress = 1.0f; - numRowsOfBatch = 0; - indexOfRow = 0; - return false; - } else { - assert internalValue.numCols == numColumns : "Incorrect number of columns in OrcBatch"; - numRowsOfBatch = internalValue.count(); - indexOfRow = 0; - progress = reader.getProgress(); - } - return true; - } else { - return false; - } + if (returnColumnarBatch) return nextBatch(); + + if (indexOfRow >= numRowsOfBatch) { + return nextBatch(); } else { - if (indexOfRow < numRowsOfBatch) { - return true; - } else { - return false; - } + indexOfRow++; + return true; } } /** - * Adapter class to return an internal row. + * Advances to the next batch of rows. Returns false if there are no more. 
*/ - public static final class Row extends InternalRow { - protected int rowId; - private List columnIDs; - private final ColumnVector[] columns; - - private Row(ColumnVector[] columns, List columnIDs) { - this.columns = columns; - this.columnIDs = columnIDs; - } - - @Override - public int numFields() { return columnIDs.size(); } - - @Override - public boolean anyNull() { - for (int i = 0; i < columns.length; i++) { - if (columnIDs.contains(i)) { - boolean isRepeating = columns[i].isRepeating; - if (isRepeating && columns[i].isNull[0]) { - return true; - } else if (!isRepeating && columns[i].isNull[rowId]) { - return true; - } - } + public boolean nextBatch() throws IOException, InterruptedException { + if (reader.next(NullWritable.get(), hiveBatch)) { + if (hiveBatch.endOfFile) { + progress = 1.0f; + numRowsOfBatch = 0; + columnarBatch.setNumRows((int) numRowsOfBatch); + indexOfRow = 0; + return false; + } else { + assert hiveBatch.numCols == numColumns : "Incorrect number of columns in the current batch"; + numRowsOfBatch = hiveBatch.count(); + columnarBatch.setNumRows((int) numRowsOfBatch); + indexOfRow = 0; + progress = reader.getProgress(); + return true; } + } else { return false; } - - private int getColIndex(ColumnVector col) { - return col.isRepeating ? 0 : rowId; - } - - @Override - public boolean isNullAt(int ordinal) { - ColumnVector col = columns[columnIDs.get(ordinal)]; - return col.isNull[getColIndex(col)]; - } - - @Override - public boolean getBoolean(int ordinal) { - LongColumnVector col = (LongColumnVector) columns[columnIDs.get(ordinal)]; - return col.vector[getColIndex(col)] > 0; - } - - @Override - public byte getByte(int ordinal) { - LongColumnVector col = (LongColumnVector) columns[columnIDs.get(ordinal)]; - return (byte)col.vector[getColIndex(col)]; - } - - @Override - public short getShort(int ordinal) { - LongColumnVector col = (LongColumnVector) columns[columnIDs.get(ordinal)]; - return (short)col.vector[getColIndex(col)]; - } - - @Override - public int getInt(int ordinal) { - LongColumnVector col = (LongColumnVector) columns[columnIDs.get(ordinal)]; - return (int)col.vector[getColIndex(col)]; - } - - @Override - public long getLong(int ordinal) { - LongColumnVector col = (LongColumnVector) columns[columnIDs.get(ordinal)]; - return (long)col.vector[getColIndex(col)]; - } - - @Override - public float getFloat(int ordinal) { - DoubleColumnVector col = (DoubleColumnVector) columns[columnIDs.get(ordinal)]; - return (float)col.vector[getColIndex(col)]; - } - - @Override - public double getDouble(int ordinal) { - DoubleColumnVector col = (DoubleColumnVector) columns[columnIDs.get(ordinal)]; - return (double)col.vector[getColIndex(col)]; - } - - @Override - public Decimal getDecimal(int ordinal, int precision, int scale) { - DecimalColumnVector col = (DecimalColumnVector) columns[columnIDs.get(ordinal)]; - int index = getColIndex(col); - return Decimal.apply(col.vector[index].getHiveDecimal().bigDecimalValue(), precision, scale); - } - - @Override - public UTF8String getUTF8String(int ordinal) { - BytesColumnVector col = ((BytesColumnVector) columns[columnIDs.get(ordinal)]); - int index = getColIndex(col); - return UTF8String.fromBytes(col.vector[index], col.start[index], col.length[index]); - } - - @Override - public byte[] getBinary(int ordinal) { - BytesColumnVector col = (BytesColumnVector) columns[columnIDs.get(ordinal)]; - int index = getColIndex(col); - byte[] binary = new byte[col.length[index]]; - System.arraycopy(col.vector[index], col.start[index], binary, 0, 
binary.length); - return binary; - } - - /** - * The data type CalendarInterval is not suppported due to the Hive version used by Spark - * internally. When we upgrade to newer Hive versions in the future, this is possibly to - * support. - */ - @Override - public CalendarInterval getInterval(int ordinal) { - throw new NotImplementedException(); - } - - /** - * The data type CalendarInterval is not suppported due to the Hive version used by Spark - * internally. When we upgrade to newer Hive versions in the future, this is possibly to - * be supported. - */ - @Override - public InternalRow getStruct(int ordinal, int numFields) { - throw new NotImplementedException(); - } - - /** - * The data type Array is not suppported due to the Hive version used by Spark internally. - * When we upgrade to newer Hive versions in the future, this is possibly to be supported. - */ - @Override - public ArrayData getArray(int ordinal) { - throw new NotImplementedException(); - } - - /** - * The data type Map is not suppported due to the Hive version used by Spark internally. - * When we upgrade to newer Hive versions in the future, this is possibly to be supported. - */ - @Override - public MapData getMap(int ordinal) { - throw new NotImplementedException(); - } - - @Override - public Object get(int ordinal, DataType dataType) { - throw new NotImplementedException(); - } - - @Override - public InternalRow copy() { - throw new NotImplementedException(); - } - - @Override - public void setNullAt(int ordinal) { - throw new NotImplementedException(); - } - - @Override - public void update(int ordinal, Object value) { - throw new NotImplementedException(); - } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 7fb212a69070..3d9e0b5e530c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -37,6 +37,7 @@ import org.apache.spark.TaskContext import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.hive.{HiveInspectors, HiveShim} import org.apache.spark.sql.sources.{Filter, _} @@ -109,6 +110,20 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable true } + override def buildReaderWithPartitionValues( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + // For Orc data source, `buildReader` already handles partition values appending. Here we + // simply delegate to `buildReader`. 
+ buildReader( + sparkSession, dataSchema, partitionSchema, requiredSchema, filters, options, hadoopConf) + } + override def buildReader( sparkSession: SparkSession, dataSchema: StructType, @@ -128,11 +143,15 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + val resultSchema = StructType(partitionSchema.fields ++ requiredSchema.fields) val enableVectorizedReader: Boolean = sparkSession.sessionState.conf.orcVectorizedReaderEnabled && - dataSchema.forall(f => f.dataType.isInstanceOf[AtomicType] && + resultSchema.forall(f => f.dataType.isInstanceOf[AtomicType] && !f.dataType.isInstanceOf[DateType] && !f.dataType.isInstanceOf[TimestampType]) + // Whole stage codegen (PhysicalRDD) is able to deal with batches directly + val returningBatch = supportBatch(sparkSession, resultSchema) + (file: PartitionedFile) => { val conf = broadcastedHadoopConf.value.value @@ -163,11 +182,17 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable val columnIDs = requiredSchema.map(a => physicalSchema.fieldIndex(a.name): Integer).sorted.asJava val orcRecordReader = - new VectorizedSparkOrcNewRecordReader(orcReader, conf, fileSplit, columnIDs) - val recordsIterator = new RecordReaderIterator[InternalRow](orcRecordReader) + new VectorizedSparkOrcNewRecordReader( + orcReader, conf, fileSplit, columnIDs, partitionSchema, file.partitionValues) + + if (returningBatch) { + orcRecordReader.enableReturningBatches() + } + val recordsIterator = new RecordReaderIterator(orcRecordReader) Option(TaskContext.get()) .foreach(_.addTaskCompletionListener(_ => recordsIterator.close())) - recordsIterator + // VectorizedSparkOrcNewRecordReader appends the columns internally to avoid another copy. + recordsIterator.asInstanceOf[Iterator[InternalRow]] } else { val orcRecordReader = new SparkOrcNewRecordReader(orcReader, conf, fileSplit.getStart, fileSplit.getLength) @@ -176,15 +201,37 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable .foreach(_.addTaskCompletionListener(_ => recordsIterator.close())) // Unwraps `OrcStruct`s to `UnsafeRow`s - OrcRelation.unwrapOrcStructs( + val iter = OrcRelation.unwrapOrcStructs( conf, requiredSchema, Some(orcRecordReader.getObjectInspector.asInstanceOf[StructObjectInspector]), recordsIterator) + + if (partitionSchema.length == 0) { + // There is no partition columns + iter + } else { + val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes + val joinedRow = new JoinedRow() + val appendPartitionColumns = GenerateUnsafeProjection.generate(fullSchema, fullSchema) + + iter.map(d => appendPartitionColumns(joinedRow(d, file.partitionValues))) + } } } } } + + /** + * Returns whether the reader will return the rows as batch or not. + */ + override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = { + val conf = sparkSession.sessionState.conf + conf.orcVectorizedReaderEnabled && conf.wholeStageEnabled && + schema.length <= conf.wholeStageMaxNumFields && + schema.forall(f => f.dataType.isInstanceOf[AtomicType] && + !f.dataType.isInstanceOf[DateType] && !f.dataType.isInstanceOf[TimestampType]) + } } private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration) From 8638a0e2b98719770bff50804dcc0fc0e83674ad Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 24 Nov 2016 08:36:56 +0000 Subject: [PATCH 16/20] Add test for OrcColumnVector. 
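A minimal usage sketch of the adapter exercised by the new suite: fill a Hive LongColumnVector by hand and read it back through the Spark-side ColumnVector API (the values are arbitrary):

    import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector
    import org.apache.hadoop.hive.ql.io.orc.OrcColumnVector

    val hiveCol = new LongColumnVector(4)
    (0 until 4).foreach(i => hiveCol.vector(i) = i * 10L)

    val sparkCol = new OrcColumnVector(hiveCol)
    assert(sparkCol.getLong(2) == 20L)
    assert(!sparkCol.isNullAt(2))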
--- .../hive/ql/io/orc/OrcColumnVector.java | 8 +- .../orc/SparkVectorizedOrcRecordReader.java | 2 +- .../orc/vectorized/OrcColumnVectorSuite.scala | 281 ++++++++++++++++++ 3 files changed, 287 insertions(+), 4 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/vectorized/OrcColumnVectorSuite.scala diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/OrcColumnVector.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/OrcColumnVector.java index 3bf9c52e9a19..643be2e126ee 100644 --- a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/OrcColumnVector.java +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/OrcColumnVector.java @@ -17,8 +17,8 @@ package org.apache.hadoop.hive.ql.io.orc; import org.apache.commons.lang.NotImplementedException; -import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector; import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector; import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; @@ -27,8 +27,10 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * A column wrapping Hive's ColumnVector. This column vector is used to adapt Hive's ColumnVector - * with Spark ColumnBatch. + * A column vector class wrapping Hive's ColumnVector. Because Spark ColumnarBatch only accepts + * Spark's vectorized.ColumnVector, this column vector is used to adapt Hive ColumnVector with + * Spark ColumnarBatch. This class inherits Spark's vectorized.ColumnVector class, but all data + * setter methods (e.g., putInt) in Spark vectorized.ColumnVector are not implemented. */ public class OrcColumnVector extends org.apache.spark.sql.execution.vectorized.ColumnVector { private ColumnVector col; diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java index b840b0fab7a4..698a26292ff0 100644 --- a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java @@ -22,8 +22,8 @@ import java.util.List; import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector; import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector; import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/vectorized/OrcColumnVectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/vectorized/OrcColumnVectorSuite.scala new file mode 100644 index 000000000000..9a6c38a5ec82 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/vectorized/OrcColumnVectorSuite.scala @@ -0,0 +1,281 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc.vectorized + +import scala.util.Random + +import org.apache.commons.lang.NotImplementedException +import org.apache.hadoop.hive.common.`type`.HiveDecimal +import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector +import org.apache.hadoop.hive.ql.exec.vector.ColumnVector +import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector +import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector +import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector +import org.apache.hadoop.hive.ql.io.orc.OrcColumnVector +import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.RandomDataGenerator +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class OrcColumnVectorSuite extends SparkFunSuite { + // This helper method access the internal vector of Hive's ColumnVector classes. + private def fillColumnVector[T](col: ColumnVector, values: Seq[T]): Unit = { + col match { + case lv: LongColumnVector => + assert(lv.vector.length == values.length) + values.zipWithIndex.map { case (v, idx) => + lv.vector(idx) = v.asInstanceOf[Long] + } + case bv: BytesColumnVector => + assert(bv.vector.length == values.length) + values.zipWithIndex.map { case (v, idx) => + val array = v.asInstanceOf[Seq[Byte]].toArray + bv.vector(idx) = array + bv.start(idx) = 0 + bv.length(idx) = array.length + } + case dv: DoubleColumnVector => + assert(dv.vector.length == values.length) + values.zipWithIndex.map { case (v, idx) => + dv.vector(idx) = v.asInstanceOf[Double] + } + case dv: DecimalColumnVector => + assert(dv.vector.length == values.length) + values.zipWithIndex.map { case (v, idx) => + val writable = new HiveDecimalWritable(v.asInstanceOf[HiveDecimal]) + dv.vector(idx) = writable + } + case _ => + assert(false, s"${col.getClass.getName} is not supported") + } + } + + private def dataGenerator[T](num: Int)(randomize: (Int) => T): Seq[T] = { + (0 until num).map { i => + randomize(i) + } + } + + private def getAllRowsFromColumn[T] + (rowNum: Int, col: OrcColumnVector)(accessor: (OrcColumnVector, Int) => T): Seq[T] = { + (0 until rowNum).map { rowId => + accessor(col, rowId) + } + } + + private def testLongColumnVector[T](num: Int) + (genExpected: (Seq[Long] => Seq[T])) + (genActual: (OrcColumnVector, Int) => Seq[T]): Unit = { + val seed = System.currentTimeMillis() + val random = new Random(seed) + + val data = dataGenerator(num) { _ => + random.nextLong() + } + + val lv = new LongColumnVector(num) + fillColumnVector(lv, data) + assert(data === lv.vector) + + val expected = genExpected(data) + + val orcCol = new OrcColumnVector(lv) + val actual = genActual(orcCol, num) + assert(actual === expected) + } + + private def testDoubleColumnVector[T](num: Int) + (genExpected: (Seq[Double] => Seq[T])) + (genActual: (OrcColumnVector, Int) => Seq[T]): Unit = { + val seed = 
System.currentTimeMillis() + val random = new Random(seed) + + val data = dataGenerator(num) { _ => + random.nextDouble() + } + + val lv = new DoubleColumnVector(num) + fillColumnVector(lv, data) + assert(data === lv.vector) + + val expected = genExpected(data) + + val orcCol = new OrcColumnVector(lv) + val actual = genActual(orcCol, num) + assert(actual === expected) + } + + private def testBytesColumnVector[T](num: Int) + (genExpected: (Seq[Seq[Byte]] => Seq[T])) + (genActual: (OrcColumnVector, Int) => Seq[T]): Unit = { + val seed = System.currentTimeMillis() + val random = new Random(seed) + + val schema = new StructType().add("binary", BinaryType, false) + val data = dataGenerator(num) { _ => + RandomDataGenerator.randomRow(random, schema).getAs[Array[Byte]](0).toSeq + } + + val lv = new BytesColumnVector(num) + fillColumnVector(lv, data) + assert(data === lv.vector) + + val expected = genExpected(data) + + val orcCol = new OrcColumnVector(lv) + val actual = genActual(orcCol, num) + actual.zip(expected).foreach { case (a, e) => + assert(a === e) + } + } + + private def testDecimalColumnVector(num: Int) + (genExpected: (Seq[HiveDecimal] => Seq[java.math.BigDecimal])) + (genActual: (OrcColumnVector, Int, Int, Int) => Seq[java.math.BigDecimal]): Unit = { + val seed = System.currentTimeMillis() + val random = new Random(seed) + + val decimalTypes = Seq(DecimalType.ShortDecimal, DecimalType.IntDecimal, + DecimalType.ByteDecimal, DecimalType.FloatDecimal, DecimalType.LongDecimal) + + decimalTypes.foreach { decimalType => + val schema = new StructType().add("decimal", decimalType, false) + val data = dataGenerator(num) { _ => + val javaDecimal = RandomDataGenerator.randomRow(random, schema).getDecimal(0) + HiveDecimal.create(javaDecimal) + } + + val lv = new DecimalColumnVector(num, decimalType.precision, decimalType.scale) + fillColumnVector(lv, data) + assert(data === lv.vector.map(_.getHiveDecimal(decimalType.precision, decimalType.scale))) + + val expected = genExpected(data) + + val orcCol = new OrcColumnVector(lv) + val actual = genActual(orcCol, num, decimalType.precision, decimalType.scale) + actual.zip(expected).foreach { case (a, e) => + assert(a.compareTo(e) == 0) + } + } + } + + test("Hive LongColumnVector: Boolean") { + val genExpected = (data: Seq[Long]) => data.map(_ > 0) + val genActual = (orcCol: OrcColumnVector, num: Int) => { + getAllRowsFromColumn(num, orcCol) { (col, rowId) => + col.getBoolean(rowId) + } + } + testLongColumnVector(100)(genExpected)(genActual) + } + + test("Hive LongColumnVector: Int") { + val genExpected = (data: Seq[Long]) => data.map(_.toInt) + val genActual = (orcCol: OrcColumnVector, num: Int) => { + getAllRowsFromColumn(num, orcCol) { (col, rowId) => + col.getInt(rowId) + } + } + testLongColumnVector(100)(genExpected)(genActual) + } + + test("Hive LongColumnVector: Byte") { + val genExpected = (data: Seq[Long]) => data.map(_.toByte) + val genActual = (orcCol: OrcColumnVector, num: Int) => { + getAllRowsFromColumn(num, orcCol) { (col, rowId) => + col.getByte(rowId) + } + } + testLongColumnVector(100)(genExpected)(genActual) + } + + test("Hive LongColumnVector: Short") { + val genExpected = (data: Seq[Long]) => data.map(_.toShort) + val genActual = (orcCol: OrcColumnVector, num: Int) => { + getAllRowsFromColumn(num, orcCol) { (col, rowId) => + col.getShort(rowId) + } + } + testLongColumnVector(100)(genExpected)(genActual) + } + + test("Hive LongColumnVector: Long") { + val genExpected = (data: Seq[Long]) => data + val genActual = (orcCol: 
OrcColumnVector, num: Int) => { + getAllRowsFromColumn(num, orcCol) { (col, rowId) => + col.getLong(rowId) + } + } + testLongColumnVector(100)(genExpected)(genActual) + } + + test("Hive DoubleColumnVector: Float") { + val genExpected = (data: Seq[Double]) => data.map(_.toFloat) + val genActual = (orcCol: OrcColumnVector, num: Int) => { + getAllRowsFromColumn(num, orcCol) { (col, rowId) => + col.getFloat(rowId) + } + } + testDoubleColumnVector(100)(genExpected)(genActual) + } + + test("Hive DoubleColumnVector: Double") { + val genExpected = (data: Seq[Double]) => data + val genActual = (orcCol: OrcColumnVector, num: Int) => { + getAllRowsFromColumn(num, orcCol) { (col, rowId) => + col.getDouble(rowId) + } + } + testDoubleColumnVector(100)(genExpected)(genActual) + } + + test("Hive BytesColumnVector: Binary") { + val genExpected = (data: Seq[Seq[Byte]]) => data + val genActual = (orcCol: OrcColumnVector, num: Int) => { + getAllRowsFromColumn(num, orcCol) { (col, rowId) => + col.getBinary(rowId).toSeq + } + } + testBytesColumnVector(100)(genExpected)(genActual) + } + + test("Hive BytesColumnVector: String") { + val genExpected = (data: Seq[Seq[Byte]]) => { + data.map(bytes => UTF8String.fromBytes(bytes.toArray, 0, bytes.length)) + } + + val genActual = (orcCol: OrcColumnVector, num: Int) => { + getAllRowsFromColumn(num, orcCol) { (col, rowId) => + col.getUTF8String(rowId) + } + } + testBytesColumnVector(100)(genExpected)(genActual) + } + + test("Hive DecimalColumnVector") { + val genExpected = (data: Seq[HiveDecimal]) => data.map(_.bigDecimalValue()) + val genActual = (orcCol: OrcColumnVector, num: Int, precision: Int, scale: Int) => { + getAllRowsFromColumn(num, orcCol) { (col, rowId) => + col.getDecimal(rowId, precision, scale).toJavaBigDecimal + } + } + testDecimalColumnVector(100)(genExpected)(genActual) + } +} From 55bb19f91658767acf08e06ee7e64db27a7222aa Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 24 Nov 2016 13:02:45 +0000 Subject: [PATCH 17/20] Expand OrcQuerySuite to test vectorized Orc reader. 
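
This keeps the existing OrcQuerySuite cases and runs them against both
code paths: the suite body moves into an abstract OrcQueryBase, and two
concrete suites (OrcQuerySuite and OrcQueryVectorizedSuite) execute it
with spark.sql.orc.enableVectorizedReader set to "false" and "true"
respectively, restoring the previous value in afterAll(). The default of
the flag is also flipped to false so the vectorized path stays opt-in.

For reference, a minimal sketch of opting in from an application (the
input path below is only illustrative):

    spark.conf.set("spark.sql.orc.enableVectorizedReader", "true")
    val df = spark.read.orc("/tmp/example_orc")  // hypothetical path
    df.count()  // the scan goes through the vectorized reader when the schema is supported
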
--- .../apache/spark/sql/internal/SQLConf.scala | 2 +- .../spark/sql/hive/orc/OrcQuerySuite.scala | 40 ++++++++++++++++++- 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 9776f8ff7679..5cd6ca6c3d6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -252,7 +252,7 @@ object SQLConf { SQLConfigBuilder("spark.sql.orc.enableVectorizedReader") .doc("Enables vectorized orc reader.") .booleanConf - .createWithDefault(true) + .createWithDefault(false) val ORC_FILTER_PUSHDOWN_ENABLED = SQLConfigBuilder("spark.sql.orc.filterPushdown") .doc("When true, enable filter pushdown for ORC files.") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index b8761e9de288..63f1022063d9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.hive.orc import java.nio.charset.StandardCharsets import java.sql.Timestamp +import scala.util.Try + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.ql.io.orc.{OrcStruct, SparkOrcNewRecordReader} import org.scalatest.BeforeAndAfterAll @@ -54,7 +56,43 @@ case class Contact(name: String, phone: String) case class Person(name: String, age: Int, contacts: Seq[Contact]) -class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { +class OrcQuerySuite extends OrcQueryBase { + override protected val value = "false" + private var currentValue: Option[String] = None + + override protected def beforeAll(): Unit = { + currentValue = Try(spark.conf.get(key)).toOption + spark.conf.set(key, value) + } + + override protected def afterAll(): Unit = { + currentValue match { + case Some(value) => spark.conf.set(key, value) + case None => spark.conf.unset(key) + } + } +} + +class OrcQueryVectorizedSuite extends OrcQueryBase { + override protected val value = "true" + private var currentValue: Option[String] = None + + override protected def beforeAll(): Unit = { + currentValue = Try(spark.conf.get(key)).toOption + spark.conf.set(key, value) + } + + override protected def afterAll(): Unit = { + currentValue match { + case Some(value) => spark.conf.set(key, value) + case None => spark.conf.unset(key) + } + } +} + +abstract class OrcQueryBase extends QueryTest with BeforeAndAfterAll with OrcTest { + protected val key = SQLConf.ORC_VECTORIZED_READER_ENABLED.key + protected val value: String test("Read/write All Types") { val data = (0 to 255).map { i => From 160e92470136282ae3e94dc82ed41571a601017f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 25 Nov 2016 05:38:39 +0000 Subject: [PATCH 18/20] Add test for VectorizedSparkOrcNewRecordReader. 
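
Spark's ColumnVector base class now takes its Spark DataType in the
constructor instead of leaving it null, and OrcColumnVector forwards the
data type of the corresponding field from the required schema, which
VectorizedSparkOrcNewRecordReader now receives as an extra argument. The
batch-support check is factored out into OrcRelation.supportBatch so the
new test suite can reuse it. VectorizedSparkOrcNewRecordReaderSuite
reads ORC files through the vectorized reader directly and verifies both
the ColumnarBatch path and the ColumnarBatch.Row fallback used when the
column count exceeds SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.

A minimal sketch of the new wrapping, with made-up values for
illustration:

    import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector
    import org.apache.hadoop.hive.ql.io.orc.OrcColumnVector
    import org.apache.spark.sql.types.LongType

    val hiveCol = new LongColumnVector(3)                   // Hive-side vector with 3 slots
    hiveCol.vector(0) = 1L
    hiveCol.vector(1) = 2L
    hiveCol.vector(2) = 3L
    val sparkCol = new OrcColumnVector(hiveCol, LongType)   // now carries its Spark DataType
    sparkCol.getLong(1)                                     // read back via Spark's ColumnVector API, returns 2
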
--- .../execution/vectorized/ColumnVector.java | 4 +- .../hive/ql/io/orc/OrcColumnVector.java | 4 +- .../VectorizedSparkOrcNewRecordReader.java | 3 +- .../spark/sql/hive/orc/OrcFileFormat.scala | 28 ++- .../orc/vectorized/OrcColumnVectorSuite.scala | 32 ++-- ...ctorizedSparkOrcNewRecordReaderSuite.scala | 180 ++++++++++++++++++ 6 files changed, 223 insertions(+), 28 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/vectorized/VectorizedSparkOrcNewRecordReaderSuite.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index afbcd8710a9e..fa1f27a15ef1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -980,9 +980,9 @@ public ColumnVector getDictionaryIds() { return dictionaryIds; } - public ColumnVector() { + public ColumnVector(DataType type) { this.capacity = 0; - this.type = null; + this.type = type; this.childColumns = null; this.resultArray = null; this.resultStruct = null; diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/OrcColumnVector.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/OrcColumnVector.java index 643be2e126ee..e80ae2ee8499 100644 --- a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/OrcColumnVector.java +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/OrcColumnVector.java @@ -23,6 +23,7 @@ import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector; import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; +import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.types.UTF8String; @@ -35,7 +36,8 @@ public class OrcColumnVector extends org.apache.spark.sql.execution.vectorized.ColumnVector { private ColumnVector col; - public OrcColumnVector(ColumnVector col) { + public OrcColumnVector(ColumnVector col, DataType type) { + super(type); this.col = col; } diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java index 7b758c7fd8c1..df4295d76bd6 100644 --- a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java @@ -77,6 +77,7 @@ public VectorizedSparkOrcNewRecordReader( Configuration conf, FileSplit fileSplit, List columnIDs, + StructType requiredSchema, StructType partitionColumns, InternalRow partitionValues) throws IOException { List types = file.getTypes(); @@ -93,7 +94,7 @@ public VectorizedSparkOrcNewRecordReader( for (int i = 0; i < columnIDs.size(); i++) { org.apache.hadoop.hive.ql.exec.vector.ColumnVector col = this.hiveBatch.cols[columnIDs.get(i)]; - this.orcColumns[i] = new OrcColumnVector(col); + this.orcColumns[i] = new OrcColumnVector(col, requiredSchema.fields()[i].dataType()); } // Allocate Spark ColumnVectors for partition columns. 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 3d9e0b5e530c..02946425ae01 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -181,9 +181,14 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable if (enableVectorizedReader) { val columnIDs = requiredSchema.map(a => physicalSchema.fieldIndex(a.name): Integer).sorted.asJava - val orcRecordReader = - new VectorizedSparkOrcNewRecordReader( - orcReader, conf, fileSplit, columnIDs, partitionSchema, file.partitionValues) + val orcRecordReader = new VectorizedSparkOrcNewRecordReader( + orcReader, + conf, + fileSplit, + columnIDs, + requiredSchema, + partitionSchema, + file.partitionValues) if (returningBatch) { orcRecordReader.enableReturningBatches() @@ -226,11 +231,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable * Returns whether the reader will return the rows as batch or not. */ override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = { - val conf = sparkSession.sessionState.conf - conf.orcVectorizedReaderEnabled && conf.wholeStageEnabled && - schema.length <= conf.wholeStageMaxNumFields && - schema.forall(f => f.dataType.isInstanceOf[AtomicType] && - !f.dataType.isInstanceOf[DateType] && !f.dataType.isInstanceOf[TimestampType]) + OrcRelation.supportBatch(sparkSession, schema) } } @@ -374,4 +375,15 @@ private[orc] object OrcRelation extends HiveInspectors { val (sortedIDs, sortedNames) = ids.zip(requestedSchema.fieldNames).sorted.unzip HiveShim.appendReadColumns(conf, sortedIDs, sortedNames) } + + /** + * Returns whether the reader will return the rows as batch or not. 
+ */ + def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = { + val conf = sparkSession.sessionState.conf + conf.orcVectorizedReaderEnabled && conf.wholeStageEnabled && + schema.length <= conf.wholeStageMaxNumFields && + schema.forall(f => f.dataType.isInstanceOf[AtomicType] && + !f.dataType.isInstanceOf[DateType] && !f.dataType.isInstanceOf[TimestampType]) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/vectorized/OrcColumnVectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/vectorized/OrcColumnVectorSuite.scala index 9a6c38a5ec82..c61956c80924 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/vectorized/OrcColumnVectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/vectorized/OrcColumnVectorSuite.scala @@ -80,7 +80,7 @@ class OrcColumnVectorSuite extends SparkFunSuite { } } - private def testLongColumnVector[T](num: Int) + private def testLongColumnVector[T](num: Int, dt: DataType) (genExpected: (Seq[Long] => Seq[T])) (genActual: (OrcColumnVector, Int) => Seq[T]): Unit = { val seed = System.currentTimeMillis() @@ -96,12 +96,12 @@ class OrcColumnVectorSuite extends SparkFunSuite { val expected = genExpected(data) - val orcCol = new OrcColumnVector(lv) + val orcCol = new OrcColumnVector(lv, dt) val actual = genActual(orcCol, num) assert(actual === expected) } - private def testDoubleColumnVector[T](num: Int) + private def testDoubleColumnVector[T](num: Int, dt: DataType) (genExpected: (Seq[Double] => Seq[T])) (genActual: (OrcColumnVector, Int) => Seq[T]): Unit = { val seed = System.currentTimeMillis() @@ -117,12 +117,12 @@ class OrcColumnVectorSuite extends SparkFunSuite { val expected = genExpected(data) - val orcCol = new OrcColumnVector(lv) + val orcCol = new OrcColumnVector(lv, dt) val actual = genActual(orcCol, num) assert(actual === expected) } - private def testBytesColumnVector[T](num: Int) + private def testBytesColumnVector[T](num: Int, dt: DataType) (genExpected: (Seq[Seq[Byte]] => Seq[T])) (genActual: (OrcColumnVector, Int) => Seq[T]): Unit = { val seed = System.currentTimeMillis() @@ -139,7 +139,7 @@ class OrcColumnVectorSuite extends SparkFunSuite { val expected = genExpected(data) - val orcCol = new OrcColumnVector(lv) + val orcCol = new OrcColumnVector(lv, dt) val actual = genActual(orcCol, num) actual.zip(expected).foreach { case (a, e) => assert(a === e) @@ -168,7 +168,7 @@ class OrcColumnVectorSuite extends SparkFunSuite { val expected = genExpected(data) - val orcCol = new OrcColumnVector(lv) + val orcCol = new OrcColumnVector(lv, decimalType) val actual = genActual(orcCol, num, decimalType.precision, decimalType.scale) actual.zip(expected).foreach { case (a, e) => assert(a.compareTo(e) == 0) @@ -183,7 +183,7 @@ class OrcColumnVectorSuite extends SparkFunSuite { col.getBoolean(rowId) } } - testLongColumnVector(100)(genExpected)(genActual) + testLongColumnVector(100, BooleanType)(genExpected)(genActual) } test("Hive LongColumnVector: Int") { @@ -193,7 +193,7 @@ class OrcColumnVectorSuite extends SparkFunSuite { col.getInt(rowId) } } - testLongColumnVector(100)(genExpected)(genActual) + testLongColumnVector(100, IntegerType)(genExpected)(genActual) } test("Hive LongColumnVector: Byte") { @@ -203,7 +203,7 @@ class OrcColumnVectorSuite extends SparkFunSuite { col.getByte(rowId) } } - testLongColumnVector(100)(genExpected)(genActual) + testLongColumnVector(100, ByteType)(genExpected)(genActual) } test("Hive LongColumnVector: Short") { @@ -213,7 +213,7 @@ class 
OrcColumnVectorSuite extends SparkFunSuite { col.getShort(rowId) } } - testLongColumnVector(100)(genExpected)(genActual) + testLongColumnVector(100, ShortType)(genExpected)(genActual) } test("Hive LongColumnVector: Long") { @@ -223,7 +223,7 @@ class OrcColumnVectorSuite extends SparkFunSuite { col.getLong(rowId) } } - testLongColumnVector(100)(genExpected)(genActual) + testLongColumnVector(100, LongType)(genExpected)(genActual) } test("Hive DoubleColumnVector: Float") { @@ -233,7 +233,7 @@ class OrcColumnVectorSuite extends SparkFunSuite { col.getFloat(rowId) } } - testDoubleColumnVector(100)(genExpected)(genActual) + testDoubleColumnVector(100, FloatType)(genExpected)(genActual) } test("Hive DoubleColumnVector: Double") { @@ -243,7 +243,7 @@ class OrcColumnVectorSuite extends SparkFunSuite { col.getDouble(rowId) } } - testDoubleColumnVector(100)(genExpected)(genActual) + testDoubleColumnVector(100, DoubleType)(genExpected)(genActual) } test("Hive BytesColumnVector: Binary") { @@ -253,7 +253,7 @@ class OrcColumnVectorSuite extends SparkFunSuite { col.getBinary(rowId).toSeq } } - testBytesColumnVector(100)(genExpected)(genActual) + testBytesColumnVector(100, BinaryType)(genExpected)(genActual) } test("Hive BytesColumnVector: String") { @@ -266,7 +266,7 @@ class OrcColumnVectorSuite extends SparkFunSuite { col.getUTF8String(rowId) } } - testBytesColumnVector(100)(genExpected)(genActual) + testBytesColumnVector(100, StringType)(genExpected)(genActual) } test("Hive DecimalColumnVector") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/vectorized/VectorizedSparkOrcNewRecordReaderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/vectorized/VectorizedSparkOrcNewRecordReaderSuite.scala new file mode 100644 index 000000000000..c020bfeb8e03 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/vectorized/VectorizedSparkOrcNewRecordReaderSuite.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.orc.vectorized + +import java.io.File +import java.net.URI +import java.nio.charset.StandardCharsets +import java.sql.Timestamp + +import scala.collection.JavaConverters._ +import scala.util.Try + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.ql.io.orc.{OrcStruct, SparkOrcNewRecordReader, VectorizedSparkOrcNewRecordReader} +import org.apache.hadoop.mapreduce.lib.input.FileSplit +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.execution.datasources.{LogicalRelation, RecordReaderIterator} +import org.apache.spark.sql.execution.vectorized.ColumnarBatch +import org.apache.spark.sql.hive.{HiveUtils, MetastoreRelation} +import org.apache.spark.sql.hive.orc._ +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class VectorizedSparkOrcNewRecordReaderSuite extends QueryTest with BeforeAndAfterAll with OrcTest { + val key = SQLConf.ORC_VECTORIZED_READER_ENABLED.key + val value = "true" + private var currentValue: Option[String] = None + + override protected def beforeAll(): Unit = { + currentValue = Try(spark.conf.get(key)).toOption + spark.conf.set(key, value) + } + + override protected def afterAll(): Unit = { + currentValue match { + case Some(value) => spark.conf.set(key, value) + case None => spark.conf.unset(key) + } + } + + private def getVectorizedOrcReader( + filepath: String, + requiredSchema: StructType, + partitionSchema: StructType, + partitionValues: InternalRow): VectorizedSparkOrcNewRecordReader = { + val conf = new Configuration() + val physicalSchema = OrcFileOperator.readSchema(Seq(filepath), Some(conf)).get + OrcRelation.setRequiredColumns(conf, physicalSchema, requiredSchema) + val orcReader = OrcFileOperator.getFileReader(filepath, Some(conf)).get + + val resultSchema = StructType(partitionSchema.fields ++ requiredSchema.fields) + + val file = new File(filepath) + val fileSplit = new FileSplit(new Path(new URI(filepath)), 0, file.length(), Array.empty) + val columnIDs = + requiredSchema.map(a => physicalSchema.fieldIndex(a.name): Integer).sorted.asJava + val orcRecordReader = + new VectorizedSparkOrcNewRecordReader( + orcReader, conf, fileSplit, columnIDs, requiredSchema, partitionSchema, partitionValues) + + val returningBatch: Boolean = OrcRelation.supportBatch(spark, resultSchema) + if (returningBatch) { + orcRecordReader.enableReturningBatches() + } + orcRecordReader + } + + test("Read/write types: batch processing") { + val data = (0 to 255).map { i => + (s"$i", i, i.toLong, i.toFloat, i.toDouble, i.toShort, i.toByte, i % 2 == 0, + s"$i".getBytes(StandardCharsets.UTF_8)) + } + val dataRows = data.map { x => + InternalRow(UTF8String.fromString(x._1), x._2, x._3, x._4, x._5, x._6, x._7, x._8, x._9) + } + + withOrcFile(data) { file => + val requiredSchema = new StructType() + .add("_1", StringType) + .add("_2", IntegerType) + .add("_3", LongType) + .add("_4", FloatType) + .add("_5", DoubleType) + .add("_6", ShortType) + .add("_7", ByteType) + .add("_8", BooleanType) + .add("_9", BinaryType) + val partitionSchema = StructType(Nil) + val partitionValues = InternalRow.empty + val reader = getVectorizedOrcReader(file, requiredSchema, partitionSchema, partitionValues) 
+ assert(reader.nextKeyValue()) + + // The schema is supported by ColumnarBatch. + val nextValue = reader.getCurrentValue() + assert(nextValue.isInstanceOf[ColumnarBatch]) + + val batch = nextValue.asInstanceOf[ColumnarBatch] + + assert(batch.numCols() == 9) + assert(batch.numRows() == 256) + assert(batch.numValidRows() == 256) + assert(batch.capacity() > 0) + assert(batch.rowIterator().hasNext == true) + + assert(batch.column(0).getUTF8String(0).toString() == "0") + assert(batch.column(0).isNullAt(0) == false) + assert(batch.column(1).getInt(0) == 0) + assert(batch.column(1).isNullAt(0) == false) + assert(batch.column(4).getDouble(0) == 0.0) + assert(batch.column(4).isNullAt(0) == false) + + val it = batch.rowIterator() + dataRows.map { row => + assert(it.hasNext()) + assert(it.next().copy() == row) + } + } + } + + test("Read/write types: no batch processing") { + val colNum = spark.conf.get(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key).toInt + 1 + val data = (0 to 255).map { i => + Row.fromSeq((i to colNum + i - 1).toSeq) + } + + withTempPath { file => + val fields = (1 to colNum).map { idx => + StructField(s"_$idx", IntegerType) + } + val requiredSchema = StructType(fields.toArray) + spark.createDataFrame(sparkContext.parallelize(data), requiredSchema) + .write.orc(file.getCanonicalPath) + val path = file.getCanonicalPath + + val partitionSchema = StructType(Nil) + val partitionValues = InternalRow.empty + val reader = getVectorizedOrcReader(path, requiredSchema, partitionSchema, partitionValues) + assert(reader.nextKeyValue()) + + // Column number exceeds SQLConf.WHOLESTAGE_MAX_NUM_FIELDS, + // so batch processing is not supported. + val nextValue = reader.getCurrentValue() + assert(nextValue.isInstanceOf[ColumnarBatch.Row]) + + val batchRow = nextValue.asInstanceOf[ColumnarBatch.Row] + + assert(batchRow.numFields() == colNum) + + var idx = 0 + while (reader.nextKeyValue()) { + val row = data(idx) + val batchRow = reader.getCurrentValue().asInstanceOf[ColumnarBatch.Row].copy() + assert(batchRow.toSeq(requiredSchema) === row.toSeq) + idx += 1 + } + } + } +} From bd15842e7b146cf292d9c29b896362412b22b8c1 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 25 Nov 2016 07:06:26 +0000 Subject: [PATCH 19/20] Add partition column test. 
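
The reader tests are now parameterized over partition settings: every
case runs once with no partition columns and once with an IntegerType
and a LongType partition column appended to the result schema, so the
partition ColumnVectors filled from partitionValues are verified along
with the data columns.

A rough sketch of the parameterization, mirroring the updated suite
(the schemas and values are fixed test constants):

    val partitionSchemas = Seq(
      StructType(Nil),
      new StructType().add("p1", IntegerType).add("p2", LongType))

    val partitionValues = Seq(
      InternalRow.empty,
      InternalRow(1, 2L))

    partitionSchemas.zip(partitionValues).foreach { case (partitionSchema, partitionValue) =>
      // each pair generates one batch-processing test and one no-batch-processing test
    }
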
--- ...ctorizedSparkOrcNewRecordReaderSuite.scala | 202 ++++++++++-------- 1 file changed, 116 insertions(+), 86 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/vectorized/VectorizedSparkOrcNewRecordReaderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/vectorized/VectorizedSparkOrcNewRecordReaderSuite.scala index c020bfeb8e03..b0300c61efe4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/vectorized/VectorizedSparkOrcNewRecordReaderSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/vectorized/VectorizedSparkOrcNewRecordReaderSuite.scala @@ -32,13 +32,9 @@ import org.apache.hadoop.mapreduce.lib.input.FileSplit import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} -import org.apache.spark.sql.execution.datasources.{LogicalRelation, RecordReaderIterator} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.vectorized.ColumnarBatch -import org.apache.spark.sql.hive.{HiveUtils, MetastoreRelation} import org.apache.spark.sql.hive.orc._ -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -87,93 +83,127 @@ class VectorizedSparkOrcNewRecordReaderSuite extends QueryTest with BeforeAndAft orcRecordReader } - test("Read/write types: batch processing") { - val data = (0 to 255).map { i => - (s"$i", i, i.toLong, i.toFloat, i.toDouble, i.toShort, i.toByte, i % 2 == 0, - s"$i".getBytes(StandardCharsets.UTF_8)) - } - val dataRows = data.map { x => - InternalRow(UTF8String.fromString(x._1), x._2, x._3, x._4, x._5, x._6, x._7, x._8, x._9) - } + val partitionSchemas = Seq( + StructType(Nil), + new StructType().add("p1", IntegerType).add("p2", LongType)) + + val partitionValues = Seq( + InternalRow.empty, + InternalRow(1, 2L)) + + val partitionSettings = partitionSchemas.zip(partitionValues) - withOrcFile(data) { file => - val requiredSchema = new StructType() - .add("_1", StringType) - .add("_2", IntegerType) - .add("_3", LongType) - .add("_4", FloatType) - .add("_5", DoubleType) - .add("_6", ShortType) - .add("_7", ByteType) - .add("_8", BooleanType) - .add("_9", BinaryType) - val partitionSchema = StructType(Nil) - val partitionValues = InternalRow.empty - val reader = getVectorizedOrcReader(file, requiredSchema, partitionSchema, partitionValues) - assert(reader.nextKeyValue()) - - // The schema is supported by ColumnarBatch. 
- val nextValue = reader.getCurrentValue() - assert(nextValue.isInstanceOf[ColumnarBatch]) - - val batch = nextValue.asInstanceOf[ColumnarBatch] - - assert(batch.numCols() == 9) - assert(batch.numRows() == 256) - assert(batch.numValidRows() == 256) - assert(batch.capacity() > 0) - assert(batch.rowIterator().hasNext == true) - - assert(batch.column(0).getUTF8String(0).toString() == "0") - assert(batch.column(0).isNullAt(0) == false) - assert(batch.column(1).getInt(0) == 0) - assert(batch.column(1).isNullAt(0) == false) - assert(batch.column(4).getDouble(0) == 0.0) - assert(batch.column(4).isNullAt(0) == false) - - val it = batch.rowIterator() - dataRows.map { row => - assert(it.hasNext()) - assert(it.next().copy() == row) + partitionSettings.map { case (partitionSchema, partitionValue) => + val doPartition = partitionValue != InternalRow.empty + val partitionTitle = if (doPartition) "with partition" else "" + + test(s"Read/write types: batch processing $partitionTitle") { + val colNum = if (doPartition) 11 else 9 + val data = (0 to 255).map { i => + (s"$i", i, i.toLong, i.toFloat, i.toDouble, i.toShort, i.toByte, i % 2 == 0, + s"$i".getBytes(StandardCharsets.UTF_8)) + } + val expectedRows = data.map { x => + val data = Seq(UTF8String.fromString(x._1), x._2, x._3, x._4, x._5, x._6, x._7, x._8, x._9) + val dataWithPartition = if (doPartition) { + data ++ Seq(1, 2L) + } else { + data + } + InternalRow.fromSeq(dataWithPartition) } - } - } - test("Read/write types: no batch processing") { - val colNum = spark.conf.get(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key).toInt + 1 - val data = (0 to 255).map { i => - Row.fromSeq((i to colNum + i - 1).toSeq) + withOrcFile(data) { file => + val requiredSchema = new StructType() + .add("_1", StringType) + .add("_2", IntegerType) + .add("_3", LongType) + .add("_4", FloatType) + .add("_5", DoubleType) + .add("_6", ShortType) + .add("_7", ByteType) + .add("_8", BooleanType) + .add("_9", BinaryType) + val reader = getVectorizedOrcReader(file, requiredSchema, partitionSchema, partitionValue) + assert(reader.nextKeyValue()) + + // The schema is supported by ColumnarBatch. 
+ val nextValue = reader.getCurrentValue() + assert(nextValue.isInstanceOf[ColumnarBatch]) + + val batch = nextValue.asInstanceOf[ColumnarBatch] + + assert(batch.numCols() == colNum) + assert(batch.numRows() == 256) + assert(batch.numValidRows() == 256) + assert(batch.capacity() > 0) + assert(batch.rowIterator().hasNext == true) + + assert(batch.column(0).getUTF8String(0).toString() == "0") + assert(batch.column(0).isNullAt(0) == false) + assert(batch.column(1).getInt(0) == 0) + assert(batch.column(1).isNullAt(0) == false) + assert(batch.column(4).getDouble(0) == 0.0) + assert(batch.column(4).isNullAt(0) == false) + + val it = batch.rowIterator() + expectedRows.map { row => + assert(it.hasNext()) + assert(it.next().copy() == row) + } + } } - withTempPath { file => - val fields = (1 to colNum).map { idx => - StructField(s"_$idx", IntegerType) + test(s"Read/write types: no batch processing $partitionTitle") { + val dataColNum = spark.conf.get(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key).toInt + 1 + val colNum = if (doPartition) { + dataColNum + 2 + } else { + dataColNum + } + + val data = (0 to 255).map { i => + Row.fromSeq((i to dataColNum + i - 1).toSeq) } - val requiredSchema = StructType(fields.toArray) - spark.createDataFrame(sparkContext.parallelize(data), requiredSchema) - .write.orc(file.getCanonicalPath) - val path = file.getCanonicalPath - - val partitionSchema = StructType(Nil) - val partitionValues = InternalRow.empty - val reader = getVectorizedOrcReader(path, requiredSchema, partitionSchema, partitionValues) - assert(reader.nextKeyValue()) - - // Column number exceeds SQLConf.WHOLESTAGE_MAX_NUM_FIELDS, - // so batch processing is not supported. - val nextValue = reader.getCurrentValue() - assert(nextValue.isInstanceOf[ColumnarBatch.Row]) - - val batchRow = nextValue.asInstanceOf[ColumnarBatch.Row] - - assert(batchRow.numFields() == colNum) - - var idx = 0 - while (reader.nextKeyValue()) { - val row = data(idx) - val batchRow = reader.getCurrentValue().asInstanceOf[ColumnarBatch.Row].copy() - assert(batchRow.toSeq(requiredSchema) === row.toSeq) - idx += 1 + + val expectedRows = data.map { x => + val data = x.toSeq + val dataWithPartition = if (doPartition) { + data ++ Seq(1, 2L) + } else { + data + } + InternalRow.fromSeq(dataWithPartition) + } + + withTempPath { file => + val fields = (1 to dataColNum).map { idx => + StructField(s"_$idx", IntegerType) + } + val requiredSchema = StructType(fields.toArray) + spark.createDataFrame(sparkContext.parallelize(data), requiredSchema) + .write.orc(file.getCanonicalPath) + val path = file.getCanonicalPath + + val reader = getVectorizedOrcReader(path, requiredSchema, partitionSchema, partitionValue) + assert(reader.nextKeyValue()) + + // Column number exceeds SQLConf.WHOLESTAGE_MAX_NUM_FIELDS, + // so batch processing is not supported. + val nextValue = reader.getCurrentValue() + assert(nextValue.isInstanceOf[ColumnarBatch.Row]) + + val batchRow = nextValue.asInstanceOf[ColumnarBatch.Row] + + assert(batchRow.numFields() == colNum) + + var idx = 0 + while (reader.nextKeyValue()) { + val row = expectedRows(idx) + val batchRow = reader.getCurrentValue().asInstanceOf[ColumnarBatch.Row].copy() + assert(batchRow === row) + idx += 1 + } } } } From 0ac61b794146634887d184076aababfd25a22ff5 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 25 Nov 2016 08:59:51 +0000 Subject: [PATCH 20/20] Expand tests. 
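
Besides widening the coverage, this makes a few small reader changes:
SparkVectorizedOrcRecordReader's constructor becomes public so the test
suite can construct it directly, the "Vectorizaton" typo in its error
message is fixed, and DateType is no longer excluded from the vectorized
path (TimestampType still is). The batch test therefore also covers
Decimal and Date columns, and new tests assert that unsupported types
(arrays, maps, structs, timestamps) fail with a clear error.

A minimal sketch of the unsupported-type behaviour the new tests assert
(the path and the single-column schema are illustrative only;
getOrcRecordReader is the helper added in the suite):

    val schema = new StructType().add("_1", TimestampType)               // not vectorizable yet
    val reader = getOrcRecordReader("/tmp/orc_with_timestamp", schema)   // hypothetical ORC file
    val e = intercept[RuntimeException] {
      reader.createValue()
    }
    assert(e.getMessage.contains("Vectorization is not supported for datatype"))
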
--- .../orc/SparkVectorizedOrcRecordReader.java | 4 +- .../spark/sql/hive/orc/OrcFileFormat.scala | 6 +- ...ctorizedSparkOrcNewRecordReaderSuite.scala | 131 +++++++++++++++--- 3 files changed, 119 insertions(+), 22 deletions(-) diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java index 698a26292ff0..f220d4d6e6fa 100644 --- a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java +++ b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkVectorizedOrcRecordReader.java @@ -51,7 +51,7 @@ public class SparkVectorizedOrcRecordReader private ObjectInspector objectInspector; private List columnIDs; - SparkVectorizedOrcRecordReader( + public SparkVectorizedOrcRecordReader( Reader file, Configuration conf, FileSplit fileSplit, @@ -100,7 +100,7 @@ private ColumnVector createColumnVector(ObjectInspector inspector) { return new DecimalColumnVector(VectorizedRowBatch.DEFAULT_SIZE, decimalTypeInfo.precision(), decimalTypeInfo.scale()); default: - throw new RuntimeException("Vectorizaton is not supported for datatype:" + throw new RuntimeException("Vectorization is not supported for datatype:" + primitiveTypeInfo.getPrimitiveCategory() + ". " + "Please disable spark.sql.orc.enableVectorizedReader."); } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 02946425ae01..c572c7f8da9f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjectio import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.hive.{HiveInspectors, HiveShim} import org.apache.spark.sql.sources.{Filter, _} -import org.apache.spark.sql.types.{AtomicType, DateType, StructType, TimestampType} +import org.apache.spark.sql.types.{AtomicType, StructType, TimestampType} import org.apache.spark.util.SerializableConfiguration /** @@ -147,7 +147,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable val enableVectorizedReader: Boolean = sparkSession.sessionState.conf.orcVectorizedReaderEnabled && resultSchema.forall(f => f.dataType.isInstanceOf[AtomicType] && - !f.dataType.isInstanceOf[DateType] && !f.dataType.isInstanceOf[TimestampType]) + !f.dataType.isInstanceOf[TimestampType]) // Whole stage codegen (PhysicalRDD) is able to deal with batches directly val returningBatch = supportBatch(sparkSession, resultSchema) @@ -384,6 +384,6 @@ private[orc] object OrcRelation extends HiveInspectors { conf.orcVectorizedReaderEnabled && conf.wholeStageEnabled && schema.length <= conf.wholeStageMaxNumFields && schema.forall(f => f.dataType.isInstanceOf[AtomicType] && - !f.dataType.isInstanceOf[DateType] && !f.dataType.isInstanceOf[TimestampType]) + !f.dataType.isInstanceOf[TimestampType]) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/vectorized/VectorizedSparkOrcNewRecordReaderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/vectorized/VectorizedSparkOrcNewRecordReaderSuite.scala index b0300c61efe4..73ce68a8aeba 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/vectorized/VectorizedSparkOrcNewRecordReaderSuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/vectorized/VectorizedSparkOrcNewRecordReaderSuite.scala @@ -20,19 +20,22 @@ package org.apache.spark.sql.hive.orc.vectorized import java.io.File import java.net.URI import java.nio.charset.StandardCharsets -import java.sql.Timestamp +import java.sql.Date import scala.collection.JavaConverters._ -import scala.util.Try +import scala.util.{Random, Try} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.ql.io.orc.{OrcStruct, SparkOrcNewRecordReader, VectorizedSparkOrcNewRecordReader} +import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch +import org.apache.hadoop.hive.ql.io.orc.{Reader, SparkVectorizedOrcRecordReader, VectorizedSparkOrcNewRecordReader} +import org.apache.hadoop.io.NullWritable import org.apache.hadoop.mapreduce.lib.input.FileSplit import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.vectorized.ColumnarBatch import org.apache.spark.sql.hive.orc._ import org.apache.spark.sql.internal.SQLConf @@ -56,33 +59,57 @@ class VectorizedSparkOrcNewRecordReaderSuite extends QueryTest with BeforeAndAft } } - private def getVectorizedOrcReader( + private def prepareParametersForReader( filepath: String, - requiredSchema: StructType, - partitionSchema: StructType, - partitionValues: InternalRow): VectorizedSparkOrcNewRecordReader = { + requiredSchema: StructType): (Configuration, Reader, FileSplit, java.util.List[Integer]) = { val conf = new Configuration() val physicalSchema = OrcFileOperator.readSchema(Seq(filepath), Some(conf)).get OrcRelation.setRequiredColumns(conf, physicalSchema, requiredSchema) val orcReader = OrcFileOperator.getFileReader(filepath, Some(conf)).get - val resultSchema = StructType(partitionSchema.fields ++ requiredSchema.fields) - val file = new File(filepath) val fileSplit = new FileSplit(new Path(new URI(filepath)), 0, file.length(), Array.empty) val columnIDs = requiredSchema.map(a => physicalSchema.fieldIndex(a.name): Integer).sorted.asJava - val orcRecordReader = + + (conf, orcReader, fileSplit, columnIDs) + } + + private def getOrcRecordReader( + filepath: String, + requiredSchema: StructType): SparkVectorizedOrcRecordReader = { + val (conf, orcReader, fileSplit, columnIDs) = + prepareParametersForReader(filepath, requiredSchema) + new SparkVectorizedOrcRecordReader( + orcReader, + conf, + new org.apache.hadoop.mapred.FileSplit(fileSplit), + columnIDs) + } + + private def getVectorizedOrcReader( + filepath: String, + requiredSchema: StructType, + partitionSchema: StructType, + partitionValues: InternalRow): VectorizedSparkOrcNewRecordReader = { + val (conf, orcReader, fileSplit, columnIDs) = + prepareParametersForReader(filepath, requiredSchema) + val resultSchema = StructType(partitionSchema.fields ++ requiredSchema.fields) + val reader = new VectorizedSparkOrcNewRecordReader( orcReader, conf, fileSplit, columnIDs, requiredSchema, partitionSchema, partitionValues) val returningBatch: Boolean = OrcRelation.supportBatch(spark, resultSchema) if (returningBatch) { - orcRecordReader.enableReturningBatches() + reader.enableReturningBatches() } - orcRecordReader + reader } + // Test data reading with VectorizedSparkOrcNewRecordReader: + // VectorizedSparkOrcNewRecordReader supports batch processing with Spark's ColumnarBatch. + // We test it with/without partitions. 
+ val partitionSchemas = Seq( StructType(Nil), new StructType().add("p1", IntegerType).add("p2", LongType)) @@ -97,14 +124,18 @@ class VectorizedSparkOrcNewRecordReaderSuite extends QueryTest with BeforeAndAft val doPartition = partitionValue != InternalRow.empty val partitionTitle = if (doPartition) "with partition" else "" - test(s"Read/write types: batch processing $partitionTitle") { - val colNum = if (doPartition) 11 else 9 + test(s"Read types: batch processing $partitionTitle") { + val colNum = if (doPartition) 13 else 11 val data = (0 to 255).map { i => + val dateString = "2015-08-20" + val milliseconds = Date.valueOf(dateString).getTime + i * 3600 (s"$i", i, i.toLong, i.toFloat, i.toDouble, i.toShort, i.toByte, i % 2 == 0, - s"$i".getBytes(StandardCharsets.UTF_8)) + s"$i".getBytes(StandardCharsets.UTF_8), Decimal(i.toDouble).toJavaBigDecimal, + new Date(milliseconds)) } val expectedRows = data.map { x => - val data = Seq(UTF8String.fromString(x._1), x._2, x._3, x._4, x._5, x._6, x._7, x._8, x._9) + val data = Seq(UTF8String.fromString(x._1), x._2, x._3, x._4, x._5, x._6, x._7, x._8, x._9, + Decimal(x._10), DateTimeUtils.fromJavaDate(x._11)) val dataWithPartition = if (doPartition) { data ++ Seq(1, 2L) } else { @@ -124,6 +155,8 @@ class VectorizedSparkOrcNewRecordReaderSuite extends QueryTest with BeforeAndAft .add("_7", ByteType) .add("_8", BooleanType) .add("_9", BinaryType) + .add("_10", DecimalType.LongDecimal) + .add("_11", DateType) val reader = getVectorizedOrcReader(file, requiredSchema, partitionSchema, partitionValue) assert(reader.nextKeyValue()) @@ -154,7 +187,7 @@ class VectorizedSparkOrcNewRecordReaderSuite extends QueryTest with BeforeAndAft } } - test(s"Read/write types: no batch processing $partitionTitle") { + test(s"Read types: no batch processing $partitionTitle") { val dataColNum = spark.conf.get(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key).toInt + 1 val colNum = if (doPartition) { dataColNum + 2 @@ -207,4 +240,68 @@ class VectorizedSparkOrcNewRecordReaderSuite extends QueryTest with BeforeAndAft } } } + + // Test SparkVectorizedOrcRecordReader: + // SparkVectorizedOrcRecordReader is only used by VectorizedSparkOrcNewRecordReader. + // We test it to see if it correctly constructs Hive's ColumnVector. 
+ + test("Read Orc file with SparkVectorizedOrcRecordReader") { + val colNum = 9 + val data = (0 to 255).map { i => + (s"$i", i, i.toLong, i.toFloat, i.toDouble, i.toShort, i.toByte, i % 2 == 0, + s"$i".getBytes(StandardCharsets.UTF_8)) + } + + withOrcFile(data) { file => + val requiredSchema = new StructType() + .add("_1", StringType) + .add("_2", IntegerType) + .add("_3", LongType) + .add("_4", FloatType) + .add("_5", DoubleType) + .add("_6", ShortType) + .add("_7", ByteType) + .add("_8", BooleanType) + .add("_9", BinaryType) + val reader = getOrcRecordReader(file, requiredSchema) + val hiveBatch = reader.createValue() + assert(hiveBatch.isInstanceOf[VectorizedRowBatch]) + assert(hiveBatch.cols.length == colNum) + + var allRowCount = 0L + while (reader.next(NullWritable.get(), hiveBatch)) { + allRowCount += hiveBatch.count() + } + assert(allRowCount == 256) + } + } + + val notSupportDataTypes = Seq( + ArrayType(IntegerType, true), + MapType(IntegerType, IntegerType, true), + new StructType().add("_1", IntegerType), + TimestampType) + + notSupportDataTypes.map { notSupportDataType => + val seed = System.currentTimeMillis() + val random = new Random(seed) + + test(s"SparkVectorizedOrcRecordReader does not support: $notSupportDataType") { + val requiredSchema = new StructType() + .add("_1", notSupportDataType) + val data = (0 to 255).map { i => + RandomDataGenerator.randomRow(random, requiredSchema) + } + withTempPath { file => + spark.createDataFrame(sparkContext.parallelize(data), requiredSchema) + .write.orc(file.getCanonicalPath) + val path = file.getCanonicalPath + val reader = getOrcRecordReader(path, requiredSchema) + val exception = intercept[RuntimeException] { + reader.createValue() + } + assert(exception.getMessage.contains("Vectorization is not supported for datatype")) + } + } + } }