diff --git a/assembly/pom.xml b/assembly/pom.xml
index 464af16e46f6..cd8366a17552 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -220,6 +220,12 @@
        <hive.deps.scope>provided</hive.deps.scope>
+    <profile>
+      <id>orc-provided</id>
+      <properties>
+        <orc.deps.scope>provided</orc.deps.scope>
+      </properties>
+    </profile>
       <id>parquet-provided</id>
diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6
index 9287bd47cf11..b8535b6d9222 100644
--- a/dev/deps/spark-deps-hadoop-2.6
+++ b/dev/deps/spark-deps-hadoop-2.6
@@ -2,6 +2,7 @@ JavaEWAH-0.3.2.jar
RoaringBitmap-0.5.11.jar
ST4-4.0.4.jar
activation-1.1.1.jar
+aircompressor-0.3.jar
antlr-2.7.7.jar
antlr-runtime-3.4.jar
antlr4-runtime-4.5.3.jar
@@ -143,6 +144,8 @@ netty-3.9.9.Final.jar
netty-all-4.0.43.Final.jar
objenesis-2.1.jar
opencsv-2.3.jar
+orc-core-1.4.0-nohive.jar
+orc-mapreduce-1.4.0-nohive.jar
oro-2.0.8.jar
osgi-resource-locator-1.0.1.jar
paranamer-2.6.jar
diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7
index ab1de3d3dd8a..6db54ef74e61 100644
--- a/dev/deps/spark-deps-hadoop-2.7
+++ b/dev/deps/spark-deps-hadoop-2.7
@@ -2,6 +2,7 @@ JavaEWAH-0.3.2.jar
RoaringBitmap-0.5.11.jar
ST4-4.0.4.jar
activation-1.1.1.jar
+aircompressor-0.3.jar
antlr-2.7.7.jar
antlr-runtime-3.4.jar
antlr4-runtime-4.5.3.jar
@@ -144,6 +145,8 @@ netty-3.9.9.Final.jar
netty-all-4.0.43.Final.jar
objenesis-2.1.jar
opencsv-2.3.jar
+orc-core-1.4.0-nohive.jar
+orc-mapreduce-1.4.0-nohive.jar
oro-2.0.8.jar
osgi-resource-locator-1.0.1.jar
paranamer-2.6.jar
diff --git a/pom.xml b/pom.xml
index 0533a8dcf2e0..6a6252b8fc44 100644
--- a/pom.xml
+++ b/pom.xml
@@ -131,6 +131,8 @@
     <hive.version.short>1.2.1</hive.version.short>
     <derby.version>10.12.1.1</derby.version>
     <parquet.version>1.8.2</parquet.version>
+    <orc.version>1.4.0</orc.version>
+    <orc.classifier>nohive</orc.classifier>
     <hive.parquet.version>1.6.0</hive.parquet.version>
     <jetty.version>9.3.11.v20160721</jetty.version>
     <javaxservlet.version>3.1.0</javaxservlet.version>
@@ -205,6 +207,7 @@
     <flume.deps.scope>compile</flume.deps.scope>
     <hadoop.deps.scope>compile</hadoop.deps.scope>
     <hive.deps.scope>compile</hive.deps.scope>
+    <orc.deps.scope>compile</orc.deps.scope>
     <parquet.deps.scope>compile</parquet.deps.scope>
     <parquet.test.deps.scope>test</parquet.test.deps.scope>
@@ -1665,6 +1668,44 @@
+      <dependency>
+        <groupId>org.apache.orc</groupId>
+        <artifactId>orc-core</artifactId>
+        <version>${orc.version}</version>
+        <classifier>${orc.classifier}</classifier>
+        <scope>${orc.deps.scope}</scope>
+        <exclusions>
+          <exclusion>
+            <groupId>org.apache.hadoop</groupId>
+            <artifactId>hadoop-common</artifactId>
+          </exclusion>
+          <exclusion>
+            <groupId>org.apache.hive</groupId>
+            <artifactId>hive-storage-api</artifactId>
+          </exclusion>
+        </exclusions>
+      </dependency>
+      <dependency>
+        <groupId>org.apache.orc</groupId>
+        <artifactId>orc-mapreduce</artifactId>
+        <version>${orc.version}</version>
+        <classifier>${orc.classifier}</classifier>
+        <scope>${orc.deps.scope}</scope>
+        <exclusions>
+          <exclusion>
+            <groupId>org.apache.hadoop</groupId>
+            <artifactId>hadoop-common</artifactId>
+          </exclusion>
+          <exclusion>
+            <groupId>org.apache.orc</groupId>
+            <artifactId>orc-core</artifactId>
+          </exclusion>
+          <exclusion>
+            <groupId>org.apache.hive</groupId>
+            <artifactId>hive-storage-api</artifactId>
+          </exclusion>
+        </exclusions>
+      </dependency>
         <groupId>org.apache.parquet</groupId>
         <artifactId>parquet-column</artifactId>
@@ -2701,6 +2742,9 @@
       <id>hive-provided</id>
+    <profile>
+      <id>orc-provided</id>
+    </profile>
       <id>parquet-provided</id>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index b24419a41edb..61f8e06774f2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -272,6 +272,12 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val ORC_VECTORIZED_READER_ENABLED =
+ buildConf("spark.sql.orc.vectorizedReader.enabled")
+ .doc("Enables vectorized orc decoding.")
+ .booleanConf
+ .createWithDefault(true)
+
val ORC_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.filterPushdown")
.doc("When true, enable filter pushdown for ORC files.")
.booleanConf
@@ -867,6 +873,8 @@ class SQLConf extends Serializable with Logging {
def parquetVectorizedReaderEnabled: Boolean = getConf(PARQUET_VECTORIZED_READER_ENABLED)
+ def orcVectorizedReaderEnabled: Boolean = getConf(ORC_VECTORIZED_READER_ENABLED)
+
def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE)
def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS)
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index fe4be963e818..c6af15006d17 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -86,6 +86,16 @@
       <scope>test</scope>
+    <dependency>
+      <groupId>org.apache.orc</groupId>
+      <artifactId>orc-core</artifactId>
+      <classifier>${orc.classifier}</classifier>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.orc</groupId>
+      <artifactId>orc-mapreduce</artifactId>
+      <classifier>${orc.classifier}</classifier>
+    </dependency>
       <groupId>org.apache.parquet</groupId>
       <artifactId>parquet-column</artifactId>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.scala
new file mode 100644
index 000000000000..3171acdae8ab
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.scala
@@ -0,0 +1,407 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc
+
+import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.lib.input.FileSplit
+import org.apache.orc._
+import org.apache.orc.mapred.OrcInputFormat
+import org.apache.orc.storage.ql.exec.vector._
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.memory.MemoryMode
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnVectorUtils}
+import org.apache.spark.sql.types._
+
+
+/**
+ * To support vectorization with whole-stage codegen, this reader returns rows as ColumnarBatch.
+ */
+private[orc] class OrcColumnarBatchReader extends RecordReader[Void, ColumnarBatch] with Logging {
+ import OrcColumnarBatchReader._
+
+ /**
+ * ORC File Reader.
+ */
+ private var reader: Reader = _
+
+ /**
+ * ORC Data Schema.
+ */
+ private var schema: TypeDescription = _
+
+ /**
+ * Vectorized Row Batch.
+ */
+ private var batch: VectorizedRowBatch = _
+
+ /**
+   * ORC batch record reader.
+ */
+ private var rows: org.apache.orc.RecordReader = _
+
+ /**
+ * Spark Schema.
+ */
+ private var sparkSchema: StructType = _
+
+ /**
+ * Required Schema.
+ */
+ private var requiredSchema: StructType = _
+
+ /**
+ * Partition Column.
+ */
+ private var partitionColumns: StructType = _
+
+ private var useIndex: Boolean = false
+
+ /**
+ * Full Schema: requiredSchema + partition schema.
+ */
+ private var fullSchema: StructType = _
+
+ /**
+ * ColumnarBatch for vectorized execution by whole-stage codegen.
+ */
+ private var columnarBatch: ColumnarBatch = _
+
+ /**
+   * The number of rows that have been read and returned so far.
+ */
+ private var rowsReturned: Long = 0L
+
+ /**
+ * Total number of rows.
+ */
+ private var totalRowCount: Long = 0L
+
+ override def getCurrentKey: Void = null
+
+ override def getCurrentValue: ColumnarBatch = columnarBatch
+
+ override def getProgress: Float = rowsReturned.toFloat / totalRowCount
+
+ override def nextKeyValue(): Boolean = nextBatch()
+
+ override def close(): Unit = {
+ if (columnarBatch != null) {
+ columnarBatch.close()
+ columnarBatch = null
+ }
+ if (rows != null) {
+ rows.close()
+ rows = null
+ }
+ }
+
+ /**
+ * Initialize ORC file reader and batch record reader.
+   * Note that `setRequiredSchema` must be called after this.
+ */
+ override def initialize(inputSplit: InputSplit, taskAttemptContext: TaskAttemptContext): Unit = {
+ val fileSplit = inputSplit.asInstanceOf[FileSplit]
+ val conf = taskAttemptContext.getConfiguration
+ reader = OrcFile.createReader(
+ fileSplit.getPath,
+ OrcFile.readerOptions(conf)
+ .maxLength(OrcConf.MAX_FILE_LENGTH.getLong(conf))
+ .filesystem(fileSplit.getPath.getFileSystem(conf)))
+ schema = reader.getSchema
+ sparkSchema = CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType]
+
+ batch = schema.createRowBatch(DEFAULT_SIZE)
+ totalRowCount = reader.getNumberOfRows
+ logDebug(s"totalRowCount = $totalRowCount")
+
+ val options = OrcInputFormat.buildOptions(conf, reader, fileSplit.getStart, fileSplit.getLength)
+ rows = reader.rows(options)
+ }
+
+ /**
+ * Set required schema and partition information.
+   * With this information, this creates a ColumnarBatch with the full schema.
+ */
+ def setRequiredSchema(
+ requiredSchema: StructType,
+ partitionColumns: StructType,
+ partitionValues: InternalRow,
+ useIndex: Boolean): Unit = {
+ this.requiredSchema = requiredSchema
+ this.partitionColumns = partitionColumns
+ this.useIndex = useIndex
+ fullSchema = new StructType(requiredSchema.fields ++ partitionColumns.fields)
+
+ columnarBatch = ColumnarBatch.allocate(fullSchema, DEFAULT_MEMORY_MODE, DEFAULT_SIZE)
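+    // Partition values are constant within a file, so populate each partition column
+    // vector once and mark it as constant.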
+ if (partitionColumns != null) {
+ val partitionIdx = requiredSchema.fields.length
+ for (i <- partitionColumns.fields.indices) {
+ ColumnVectorUtils.populate(columnarBatch.column(i + partitionIdx), partitionValues, i)
+ columnarBatch.column(i + partitionIdx).setIsConstant()
+ }
+ }
+ }
+
+ /**
+   * Return true if more data exists for the next batch. If so, prepare the next batch
+   * by copying ORC VectorizedRowBatch columns into Spark ColumnarBatch columns.
+ */
+ private def nextBatch(): Boolean = {
+ if (rowsReturned >= totalRowCount) {
+ return false
+ }
+
+ rows.nextBatch(batch)
+ val batchSize = batch.size
+ if (batchSize == 0) {
+ return false
+ }
+ rowsReturned += batchSize
+ columnarBatch.reset()
+ columnarBatch.setNumRows(batchSize)
+
+ for (i <- 0 until requiredSchema.length) {
+ val field = requiredSchema(i)
+ val schemaIndex = if (useIndex) i else schema.getFieldNames.indexOf(field.name)
+ assert(schemaIndex >= 0)
+
+ val fromColumn = batch.cols(schemaIndex)
+ val toColumn = columnarBatch.column(i)
+
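+      // An ORC column vector can be repeating (a single value for the whole batch),
+      // entirely non-null, or carry a null mask; each case is copied into the Spark
+      // column vector by one of the branches below.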
+ if (fromColumn.isRepeating) {
+ if (fromColumn.isNull(0)) {
+ toColumn.appendNulls(batchSize)
+ } else {
+ field.dataType match {
+ case BooleanType =>
+ val data = fromColumn.asInstanceOf[LongColumnVector].vector(0) == 1
+ toColumn.appendBooleans(batchSize, data)
+
+ case ByteType =>
+ val data = fromColumn.asInstanceOf[LongColumnVector].vector(0).toByte
+ toColumn.appendBytes(batchSize, data)
+ case ShortType =>
+ val data = fromColumn.asInstanceOf[LongColumnVector].vector(0).toShort
+ toColumn.appendShorts(batchSize, data)
+ case IntegerType =>
+ val data = fromColumn.asInstanceOf[LongColumnVector].vector(0).toInt
+ toColumn.appendInts(batchSize, data)
+ case LongType =>
+ val data = fromColumn.asInstanceOf[LongColumnVector].vector(0)
+ toColumn.appendLongs(batchSize, data)
+
+ case DateType =>
+ val data = fromColumn.asInstanceOf[LongColumnVector].vector(0).toInt
+ toColumn.appendInts(batchSize, data)
+
+ case TimestampType =>
+ val data = fromColumn.asInstanceOf[TimestampColumnVector].getTimestampAsLong(0)
+ toColumn.appendLongs(batchSize, data)
+
+ case FloatType =>
+ val data = fromColumn.asInstanceOf[DoubleColumnVector].vector(0).toFloat
+ toColumn.appendFloats(batchSize, data)
+ case DoubleType =>
+ val data = fromColumn.asInstanceOf[DoubleColumnVector].vector(0)
+ toColumn.appendDoubles(batchSize, data)
+
+ case StringType =>
+ val data = fromColumn.asInstanceOf[BytesColumnVector]
+ for (index <- 0 until batchSize) {
+ toColumn.appendByteArray(data.vector(0), data.start(0), data.length(0))
+ }
+ case BinaryType =>
+ val data = fromColumn.asInstanceOf[BytesColumnVector]
+ for (index <- 0 until batchSize) {
+ toColumn.appendByteArray(data.vector(0), data.start(0), data.length(0))
+ }
+
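+            // Spark keeps decimals with precision <= Decimal.MAX_INT_DIGITS in an int
+            // column and those up to Decimal.MAX_LONG_DIGITS in a long column; wider
+            // decimals are stored as unscaled byte arrays.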
+ case DecimalType.Fixed(precision, _) =>
+ val d = fromColumn.asInstanceOf[DecimalColumnVector].vector(0)
+ val value = Decimal(d.getHiveDecimal.bigDecimalValue, d.precision(), d.scale)
+ if (precision <= Decimal.MAX_INT_DIGITS) {
+ toColumn.appendInts(batchSize, value.toUnscaledLong.toInt)
+ } else if (precision <= Decimal.MAX_LONG_DIGITS) {
+ toColumn.appendLongs(batchSize, value.toUnscaledLong)
+ } else {
+ val bytes = value.toJavaBigDecimal.unscaledValue.toByteArray
+ for (index <- 0 until batchSize) {
+ toColumn.appendByteArray(bytes, 0, bytes.length)
+ }
+ }
+
+ case dt =>
+ throw new UnsupportedOperationException(s"Unsupported Data Type: $dt")
+ }
+ }
+ } else if (!field.nullable || fromColumn.noNulls) {
+ field.dataType match {
+ case BooleanType =>
+ val data = fromColumn.asInstanceOf[LongColumnVector].vector
+ data.foreach { x => toColumn.appendBoolean(x == 1) }
+
+ case ByteType =>
+ val data = fromColumn.asInstanceOf[LongColumnVector].vector
+ toColumn.appendBytes(batchSize, data.map(_.toByte), 0)
+ case ShortType =>
+ val data = fromColumn.asInstanceOf[LongColumnVector].vector
+ toColumn.appendShorts(batchSize, data.map(_.toShort), 0)
+ case IntegerType =>
+ val data = fromColumn.asInstanceOf[LongColumnVector].vector
+ toColumn.appendInts(batchSize, data.map(_.toInt), 0)
+ case LongType =>
+ val data = fromColumn.asInstanceOf[LongColumnVector].vector
+ toColumn.appendLongs(batchSize, data, 0)
+
+ case DateType =>
+ val data = fromColumn.asInstanceOf[LongColumnVector].vector
+ toColumn.appendInts(batchSize, data.map(_.toInt), 0)
+
+ case TimestampType =>
+ val data = fromColumn.asInstanceOf[TimestampColumnVector]
+ for (index <- 0 until batchSize) {
+ toColumn.appendLong(data.getTimestampAsLong(index))
+ }
+
+ case FloatType =>
+ val data = fromColumn.asInstanceOf[DoubleColumnVector].vector
+ toColumn.appendFloats(batchSize, data.map(_.toFloat), 0)
+ case DoubleType =>
+ val data = fromColumn.asInstanceOf[DoubleColumnVector].vector
+ toColumn.appendDoubles(batchSize, data, 0)
+
+ case StringType =>
+ val data = fromColumn.asInstanceOf[BytesColumnVector]
+ for (index <- 0 until batchSize) {
+ toColumn.appendByteArray(data.vector(index), data.start(index), data.length(index))
+ }
+ case BinaryType =>
+ val data = fromColumn.asInstanceOf[BytesColumnVector]
+ for (index <- 0 until batchSize) {
+ toColumn.appendByteArray(data.vector(index), data.start(index), data.length(index))
+ }
+
+ case DecimalType.Fixed(precision, _) =>
+ val data = fromColumn.asInstanceOf[DecimalColumnVector]
+ for (index <- 0 until batchSize) {
+ val d = data.vector(index)
+ val value = Decimal(d.getHiveDecimal.bigDecimalValue, d.precision(), d.scale)
+ if (precision <= Decimal.MAX_INT_DIGITS) {
+ toColumn.appendInt(value.toUnscaledLong.toInt)
+ } else if (precision <= Decimal.MAX_LONG_DIGITS) {
+ toColumn.appendLong(value.toUnscaledLong)
+ } else {
+ val bytes = value.toJavaBigDecimal.unscaledValue.toByteArray
+ toColumn.appendByteArray(bytes, 0, bytes.length)
+ }
+ }
+
+ case dt =>
+ throw new UnsupportedOperationException(s"Unsupported Data Type: $dt")
+ }
+ } else {
+ for (index <- 0 until batchSize) {
+ if (fromColumn.isNull(index)) {
+ toColumn.appendNull()
+ } else {
+ field.dataType match {
+ case BooleanType =>
+ val data = fromColumn.asInstanceOf[LongColumnVector].vector(index) == 1
+ toColumn.appendBoolean(data)
+ case ByteType =>
+ val data = fromColumn.asInstanceOf[LongColumnVector].vector(index).toByte
+ toColumn.appendByte(data)
+ case ShortType =>
+ val data = fromColumn.asInstanceOf[LongColumnVector].vector(index).toShort
+ toColumn.appendShort(data)
+ case IntegerType =>
+ val data = fromColumn.asInstanceOf[LongColumnVector].vector(index).toInt
+ toColumn.appendInt(data)
+ case LongType =>
+ val data = fromColumn.asInstanceOf[LongColumnVector].vector(index)
+ toColumn.appendLong(data)
+
+ case DateType =>
+ val data = fromColumn.asInstanceOf[LongColumnVector].vector(index).toInt
+ toColumn.appendInt(data)
+
+ case TimestampType =>
+ val data = fromColumn.asInstanceOf[TimestampColumnVector]
+ .getTimestampAsLong(index)
+ toColumn.appendLong(data)
+
+ case FloatType =>
+ val data = fromColumn.asInstanceOf[DoubleColumnVector].vector(index).toFloat
+ toColumn.appendFloat(data)
+ case DoubleType =>
+ val data = fromColumn.asInstanceOf[DoubleColumnVector].vector(index)
+ toColumn.appendDouble(data)
+
+ case StringType =>
+ val v = fromColumn.asInstanceOf[BytesColumnVector]
+ toColumn.appendByteArray(v.vector(index), v.start(index), v.length(index))
+
+ case BinaryType =>
+ val v = fromColumn.asInstanceOf[BytesColumnVector]
+ toColumn.appendByteArray(v.vector(index), v.start(index), v.length(index))
+
+ case DecimalType.Fixed(precision, _) =>
+ val d = fromColumn.asInstanceOf[DecimalColumnVector].vector(index)
+ val value = Decimal(d.getHiveDecimal.bigDecimalValue, d.precision(), d.scale)
+ if (precision <= Decimal.MAX_INT_DIGITS) {
+ toColumn.appendInt(value.toUnscaledLong.toInt)
+ } else if (precision <= Decimal.MAX_LONG_DIGITS) {
+ toColumn.appendLong(value.toUnscaledLong)
+ } else {
+ val bytes = value.toJavaBigDecimal.unscaledValue.toByteArray
+ toColumn.appendByteArray(bytes, 0, bytes.length)
+ }
+
+ case dt =>
+ throw new UnsupportedOperationException(s"Unsupported Data Type: $dt")
+ }
+ }
+ }
+ }
+ }
+ true
+ }
+}
+
+/**
+ * Constants for OrcColumnarBatchReader.
+ */
+object OrcColumnarBatchReader {
+ /**
+ * Default memory mode for ColumnarBatch.
+ */
+ val DEFAULT_MEMORY_MODE = MemoryMode.ON_HEAP
+
+ /**
+   * The default batch size. We use a single value consistently for both ORC and Spark
+   * because their defaults differ:
+ *
+ * - ORC's VectorizedRowBatch.DEFAULT_SIZE = 1024
+ * - Spark's ColumnarBatch.DEFAULT_BATCH_SIZE = 4 * 1024
+ */
+ val DEFAULT_SIZE: Int = 4 * 1024
+}
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
new file mode 100644
index 000000000000..97d2ac5f3843
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -0,0 +1,502 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc
+
+import java.io._
+import java.net.URI
+
+import scala.collection.JavaConverters._
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
+import org.apache.hadoop.io._
+import org.apache.hadoop.mapreduce._
+import org.apache.hadoop.mapreduce.lib.input.FileSplit
+import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
+import org.apache.orc._
+import org.apache.orc.OrcFile.ReaderOptions
+import org.apache.orc.mapred.{OrcList, OrcMap, OrcStruct, OrcTimestamp}
+import org.apache.orc.mapreduce._
+import org.apache.orc.storage.common.`type`.HiveDecimal
+import org.apache.orc.storage.serde2.io.{DateWritable, HiveDecimalWritable}
+
+import org.apache.spark.TaskContext
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.SerializableConfiguration
+
+class DefaultSource extends OrcFileFormat
+
+/**
+ * New ORC File Format based on Apache ORC 1.4.x and above.
+ */
+class OrcFileFormat
+ extends FileFormat
+ with DataSourceRegister
+ with Logging
+ with Serializable {
+
+ override def shortName(): String = "orc"
+
+ override def toString: String = "ORC"
+
+ override def hashCode(): Int = getClass.hashCode()
+
+ override def equals(other: Any): Boolean = other.isInstanceOf[OrcFileFormat]
+
+ override def inferSchema(
+ spark: SparkSession,
+ options: Map[String, String],
+ files: Seq[FileStatus]): Option[StructType] = {
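+    // The schema is read from the first listed file only; the remaining files are
+    // assumed to share it.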
+ val conf = spark.sparkContext.hadoopConfiguration
+ val fs = FileSystem.getLocal(conf)
+ val schema = OrcFileFormat.readSchema(
+ files.map(_.getPath).head,
+ OrcFile.readerOptions(conf).filesystem(fs))
+    logDebug(s"Reading ORC schema from the first of $files, got schema: $schema")
+ Some(CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType])
+ }
+
+ override def prepareWrite(
+ sparkSession: SparkSession,
+ job: Job,
+ options: Map[String, String],
+ dataSchema: StructType): OutputWriterFactory = {
+
+ val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf)
+
+ val conf = job.getConfiguration
+
+ val writerOptions = OrcFile.writerOptions(conf)
+
+ conf.set(
+ OrcConf.MAPRED_OUTPUT_SCHEMA.getAttribute,
+ OrcFileFormat.getSchemaString(dataSchema))
+
+ conf.set(
+ OrcConf.COMPRESS.getAttribute,
+ orcOptions.compressionCodecClassName)
+
+ new OutputWriterFactory {
+ override def newInstance(
+ path: String,
+ dataSchema: StructType,
+ context: TaskAttemptContext): OutputWriter = {
+ new OrcOutputWriter(path, dataSchema, context)
+ }
+
+ override def getFileExtension(context: TaskAttemptContext): String = {
+ val compressionExtension: String = {
+ val name = context.getConfiguration.get(OrcConf.COMPRESS.getAttribute)
+ OrcOptions.extensionsForCompressionCodecNames.getOrElse(name, "")
+ }
+
+ compressionExtension + ".orc"
+ }
+ }
+ }
+
+ /**
+   * Returns whether the reader returns rows as a ColumnarBatch or not.
+ */
+ override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = {
+ val conf = sparkSession.sessionState.conf
+ conf.orcVectorizedReaderEnabled &&
+ conf.wholeStageEnabled &&
+ schema.length <= conf.wholeStageMaxNumFields &&
+ schema.forall(_.dataType.isInstanceOf[AtomicType])
+ }
+
+ override def isSplitable(
+ sparkSession: SparkSession,
+ options: Map[String, String],
+ path: Path): Boolean = {
+ true
+ }
+
+ override def buildReaderWithPartitionValues(
+ sparkSession: SparkSession,
+ dataSchema: StructType,
+ partitionSchema: StructType,
+ requiredSchema: StructType,
+ filters: Seq[Filter],
+ options: Map[String, String],
+ hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
+
+ // Predicate Push Down
+ if (sparkSession.sessionState.conf.orcFilterPushDown) {
+ val sarg = OrcFilters.createFilter(dataSchema, filters)
+ if (sarg.isDefined) {
+ OrcInputFormat.setSearchArgument(
+ hadoopConf,
+ sarg.get,
+ dataSchema.fieldNames)
+ }
+ }
+
+ // Column Selection
+ if (requiredSchema.nonEmpty) {
+ hadoopConf.set(
+ OrcConf.INCLUDE_COLUMNS.getAttribute,
+ requiredSchema.map(f => dataSchema.fieldIndex(f.name)).mkString(","))
+ logDebug(s"${OrcConf.INCLUDE_COLUMNS.getAttribute}=" +
+ s"${requiredSchema.map(f => dataSchema.fieldIndex(f.name)).mkString(",")}")
+ }
+
+ val resultSchema = StructType(requiredSchema.fields ++ partitionSchema.fields)
+ val enableVectorizedReader =
+ sparkSession.sessionState.conf.orcVectorizedReaderEnabled &&
+ resultSchema.forall(_.dataType.isInstanceOf[AtomicType])
+ val useColumnarBatchReader = supportBatch(sparkSession, resultSchema)
+
+ val broadcastedConf =
+ sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
+
+ (file: PartitionedFile) => {
+ assert(file.partitionValues.numFields == partitionSchema.size)
+
+ val hdfsPath = new Path(file.filePath)
+ val conf = broadcastedConf.value.value
+ val fs = hdfsPath.getFileSystem(conf)
+ val orcSchema = OrcFileFormat.readSchema(hdfsPath, OrcFile.readerOptions(conf).filesystem(fs))
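+      // If every required field sits at the same position as in the physical ORC schema,
+      // columns can be matched by index instead of by name.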
+ val useIndex = requiredSchema.fieldNames.zipWithIndex.forall { case (name, index) =>
+ name.equals(orcSchema.getFieldNames.get(index))
+ }
+
+ if (orcSchema.getFieldNames.isEmpty) {
+ Iterator.empty
+ } else {
+ val split =
+ new FileSplit(new Path(new URI(file.filePath)), file.start, file.length, Array.empty)
+ val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
+ val taskAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
+ val partitionValues = file.partitionValues
+
+ createIterator(
+ split,
+ taskAttemptContext,
+ orcSchema,
+ requiredSchema,
+ partitionSchema,
+ partitionValues,
+ useIndex,
+ useColumnarBatchReader,
+ enableVectorizedReader)
+ }
+ }
+ }
+
+ /**
+ * Create one of the following iterators.
+ *
+ * - An iterator with ColumnarBatch.
+ * This is used when supportBatch(sparkSession, resultSchema) is true.
+ * Whole-stage codegen understands ColumnarBatch which offers significant
+ * performance gains.
+ *
+ * - An iterator with InternalRow based on ORC RowBatch.
+ * This is used when ORC_VECTORIZED_READER_ENABLED is true and
+ * the schema has only atomic fields.
+ *
+ * - An iterator with InternalRow based on ORC OrcMapreduceRecordReader.
+ * This is the default iterator for the other cases.
+ */
+ private def createIterator(
+ split: FileSplit,
+ taskAttemptContext: TaskAttemptContext,
+ orcSchema: TypeDescription,
+ requiredSchema: StructType,
+ partitionSchema: StructType,
+ partitionValues: InternalRow,
+ useIndex: Boolean,
+ columnarBatch: Boolean,
+ enableVectorizedReader: Boolean): Iterator[InternalRow] = {
+ val resultSchema = StructType(requiredSchema.fields ++ partitionSchema.fields)
+
+ if (columnarBatch) {
+ assert(enableVectorizedReader)
+ val reader = new OrcColumnarBatchReader
+ reader.initialize(split, taskAttemptContext)
+ reader.setRequiredSchema(requiredSchema, partitionSchema, partitionValues, useIndex)
+ val iter = new RecordReaderIterator(reader)
+ Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close()))
+ iter.asInstanceOf[Iterator[InternalRow]]
+ } else if (enableVectorizedReader) {
+ val iter = new OrcRecordIterator
+ iter.initialize(
+ split,
+ taskAttemptContext,
+ orcSchema,
+ requiredSchema,
+ partitionSchema,
+ partitionValues,
+ useIndex)
+ Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close()))
+
+ val unsafeProjection = UnsafeProjection.create(resultSchema)
+ iter.map(unsafeProjection)
+ } else {
+ val orcRecordReader = OrcFileFormat.ORC_INPUT_FORMAT
+ .createRecordReader(split, taskAttemptContext)
+ val iter = new RecordReaderIterator[OrcStruct](orcRecordReader)
+ Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close()))
+
+ val mutableRow = new SpecificInternalRow(resultSchema.map(_.dataType))
+ val unsafeProjection = UnsafeProjection.create(resultSchema)
+
+ // Initialize the partition column values once.
+ for (i <- requiredSchema.length until resultSchema.length) {
+ val value = partitionValues.get(i - requiredSchema.length, resultSchema(i).dataType)
+ mutableRow.update(i, value)
+ }
+
+ iter.map { value =>
+ unsafeProjection(OrcFileFormat.convertOrcStructToInternalRow(
+ value, requiredSchema, useIndex, Some(mutableRow)))
+ }
+ }
+ }
+}
+
+/**
+ * New ORC File Format companion object.
+ */
+object OrcFileFormat extends Logging {
+
+ lazy val ORC_INPUT_FORMAT = new OrcInputFormat[OrcStruct]
+
+ /**
+   * Read the ORC file schema as a TypeDescription.
+ */
+ private[orc] def readSchema(file: Path, conf: ReaderOptions): TypeDescription = {
+ val reader = OrcFile.createReader(file, conf)
+ reader.getSchema
+ }
+
+ /**
+ * Convert Apache ORC OrcStruct to Apache Spark InternalRow.
+   * If internalRow is defined, fill into it; otherwise, create and use a new SpecificInternalRow.
+ */
+ private[orc] def convertOrcStructToInternalRow(
+ orcStruct: OrcStruct,
+ schema: StructType,
+ useIndex: Boolean = false,
+ internalRow: Option[InternalRow] = None): InternalRow = {
+
+ val mutableRow = internalRow.getOrElse(new SpecificInternalRow(schema.map(_.dataType)))
+
+ for (schemaIndex <- 0 until schema.length) {
+ val writable = if (useIndex) {
+ orcStruct.getFieldValue(schemaIndex)
+ } else {
+ orcStruct.getFieldValue(schema(schemaIndex).name)
+ }
+ if (writable == null) {
+ mutableRow.setNullAt(schemaIndex)
+ } else {
+ mutableRow(schemaIndex) = getCatalystValue(writable, schema(schemaIndex).dataType)
+ }
+ }
+
+ mutableRow
+ }
+
+
+ private[orc] def getTypeDescription(dataType: DataType) = dataType match {
+ case st: StructType => TypeDescription.fromString(getSchemaString(st))
+ case _ => TypeDescription.fromString(dataType.catalogString)
+ }
+
+ /**
+   * Return an ORC schema string for an OrcStruct.
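+   * For example, a schema with an int column `a` and a string column `b` becomes
+   * "struct<a:int,b:string>".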
+ */
+ private[orc] def getSchemaString(schema: StructType): String = {
+ schema.fields.map(f => s"${f.name}:${f.dataType.catalogString}").mkString("struct<", ",", ">")
+ }
+
+ /**
+   * Return an ORC value object for the given Spark data type.
+ */
+ private[orc] def createOrcValue(dataType: DataType) =
+ OrcStruct.createValue(getTypeDescription(dataType))
+
+ /**
+ * Convert Apache Spark InternalRow to Apache ORC OrcStruct.
+ */
+ private[orc] def convertInternalRowToOrcStruct(
+ row: InternalRow,
+ schema: StructType,
+ struct: Option[OrcStruct] = None): OrcStruct = {
+
+ val orcStruct = struct.getOrElse(createOrcValue(schema).asInstanceOf[OrcStruct])
+
+ for (schemaIndex <- 0 until schema.length) {
+ val fieldType = schema(schemaIndex).dataType
+ val fieldValue = if (row.isNullAt(schemaIndex)) {
+ null
+ } else {
+ getWritable(row.get(schemaIndex, fieldType), fieldType)
+ }
+ orcStruct.setFieldValue(schemaIndex, fieldValue)
+ }
+
+ orcStruct
+ }
+
+ /**
+   * Return a WritableComparable for the given Spark Catalyst value.
+ */
+ private[orc] def getWritable(value: Object, dataType: DataType): WritableComparable[_] = {
+ if (value == null) {
+ null
+ } else {
+ dataType match {
+ case NullType => null
+
+ case BooleanType => new BooleanWritable(value.asInstanceOf[Boolean])
+
+ case ByteType => new ByteWritable(value.asInstanceOf[Byte])
+ case ShortType => new ShortWritable(value.asInstanceOf[Short])
+ case IntegerType => new IntWritable(value.asInstanceOf[Int])
+ case LongType => new LongWritable(value.asInstanceOf[Long])
+
+ case FloatType => new FloatWritable(value.asInstanceOf[Float])
+ case DoubleType => new DoubleWritable(value.asInstanceOf[Double])
+
+ case StringType => new Text(value.asInstanceOf[UTF8String].getBytes)
+
+ case BinaryType => new BytesWritable(value.asInstanceOf[Array[Byte]])
+
+ case DateType => new DateWritable(DateTimeUtils.toJavaDate(value.asInstanceOf[Int]))
+ case TimestampType => new OrcTimestamp(value.asInstanceOf[Long])
+
+ case _: DecimalType =>
+ new HiveDecimalWritable(HiveDecimal.create(value.asInstanceOf[Decimal].toJavaBigDecimal))
+
+ case st: StructType =>
+ convertInternalRowToOrcStruct(value.asInstanceOf[InternalRow], st)
+
+ case ArrayType(et, _) =>
+ val data = value.asInstanceOf[ArrayData]
+ val list = createOrcValue(dataType)
+ for (i <- 0 until data.numElements()) {
+ list.asInstanceOf[OrcList[WritableComparable[_]]]
+ .add(getWritable(data.get(i, et), et))
+ }
+ list
+
+ case MapType(keyType, valueType, _) =>
+ val data = value.asInstanceOf[MapData]
+ val map = createOrcValue(dataType)
+ .asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]]
+ data.foreach(keyType, valueType, { case (k, v) =>
+ map.put(
+ getWritable(k.asInstanceOf[Object], keyType),
+ getWritable(v.asInstanceOf[Object], valueType))
+ })
+ map
+
+ case udt: UserDefinedType[_] =>
+ val udtRow = new SpecificInternalRow(Seq(udt.sqlType))
+ udtRow(0) = value
+ convertInternalRowToOrcStruct(udtRow,
+ StructType(Seq(StructField("tmp", udt.sqlType)))).getFieldValue(0)
+
+ case _ =>
+ throw new UnsupportedOperationException(s"$dataType is not supported yet.")
+ }
+ }
+
+ }
+
+ /**
+   * Return a Spark Catalyst value for the given WritableComparable object.
+ */
+ private[orc] def getCatalystValue(value: WritableComparable[_], dataType: DataType): Any = {
+ if (value == null) {
+ null
+ } else {
+ dataType match {
+ case NullType => null
+
+ case BooleanType => value.asInstanceOf[BooleanWritable].get
+
+ case ByteType => value.asInstanceOf[ByteWritable].get
+ case ShortType => value.asInstanceOf[ShortWritable].get
+ case IntegerType => value.asInstanceOf[IntWritable].get
+ case LongType => value.asInstanceOf[LongWritable].get
+
+ case FloatType => value.asInstanceOf[FloatWritable].get
+ case DoubleType => value.asInstanceOf[DoubleWritable].get
+
+ case StringType => UTF8String.fromBytes(value.asInstanceOf[Text].getBytes)
+
+ case BinaryType =>
+ val binary = value.asInstanceOf[BytesWritable]
+ val bytes = new Array[Byte](binary.getLength)
+ System.arraycopy(binary.getBytes, 0, bytes, 0, binary.getLength)
+ bytes
+
+ case DateType => DateTimeUtils.fromJavaDate(value.asInstanceOf[DateWritable].get)
+ case TimestampType => DateTimeUtils.fromJavaTimestamp(value.asInstanceOf[OrcTimestamp])
+
+ case _: DecimalType =>
+ val decimal = value.asInstanceOf[HiveDecimalWritable].getHiveDecimal()
+ Decimal(decimal.bigDecimalValue, decimal.precision(), decimal.scale())
+
+ case _: StructType =>
+ val structValue = convertOrcStructToInternalRow(
+ value.asInstanceOf[OrcStruct],
+ dataType.asInstanceOf[StructType])
+ structValue
+
+ case ArrayType(elementType, _) =>
+ val data = new scala.collection.mutable.ArrayBuffer[Any]
+ value.asInstanceOf[OrcList[WritableComparable[_]]].asScala.foreach { x =>
+ data += getCatalystValue(x, elementType)
+ }
+ new GenericArrayData(data.toArray)
+
+ case MapType(keyType, valueType, _) =>
+ val map = new java.util.TreeMap[Any, Any]
+ value
+ .asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]]
+ .entrySet().asScala.foreach { entry =>
+ val k = getCatalystValue(entry.getKey, keyType)
+ val v = getCatalystValue(entry.getValue, valueType)
+ map.put(k, v)
+ }
+ ArrayBasedMapData(map.asScala)
+
+ case udt: UserDefinedType[_] =>
+ getCatalystValue(value, udt.sqlType)
+
+ case _ =>
+ throw new UnsupportedOperationException(s"$dataType is not supported yet.")
+ }
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
new file mode 100644
index 000000000000..6a8d3739892b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc
+
+import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument, SearchArgumentFactory}
+import org.apache.orc.storage.ql.io.sarg.SearchArgument.Builder
+import org.apache.orc.storage.serde2.io.HiveDecimalWritable
+
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types._
+
+/**
+ * Utility functions to convert Spark data source filters to ORC filters.
+ */
+private[orc] object OrcFilters {
+
+ /**
+   * Create an ORC filter as a SearchArgument instance.
+ */
+ def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = {
+ val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap
+
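+    // First pass: check each filter against a fresh builder and keep only the ones that
+    // are fully convertible, so that an unconvertible filter never ends up inside the
+    // final conjunction.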
+ val convertibleFilters = for {
+ filter <- filters
+ _ <- buildSearchArgument(dataTypeMap, filter, SearchArgumentFactory.newBuilder())
+ } yield filter
+
+ for {
+ conjunction <- convertibleFilters.reduceOption(org.apache.spark.sql.sources.And)
+ builder <- buildSearchArgument(dataTypeMap, conjunction, SearchArgumentFactory.newBuilder())
+ } yield builder.build()
+ }
+
+ /**
+ * Return true if this is a searchable type in ORC.
+ */
+ private def isSearchableType(dataType: DataType) = dataType match {
+ case ByteType | ShortType | FloatType | DoubleType => true
+ case IntegerType | LongType | StringType | BooleanType => true
+ case TimestampType | _: DecimalType => true
+ case _ => false
+ }
+
+ /**
+   * Get the PredicateLeaf.Type corresponding to the given DataType.
+ */
+ private def getPredicateLeafType(dataType: DataType) = dataType match {
+ case BooleanType => PredicateLeaf.Type.BOOLEAN
+ case ByteType | ShortType | IntegerType | LongType => PredicateLeaf.Type.LONG
+ case FloatType | DoubleType => PredicateLeaf.Type.FLOAT
+ case StringType => PredicateLeaf.Type.STRING
+ case DateType => PredicateLeaf.Type.DATE
+ case TimestampType => PredicateLeaf.Type.TIMESTAMP
+ case _: DecimalType => PredicateLeaf.Type.DECIMAL
+ case _ => throw new UnsupportedOperationException(s"DataType: $dataType")
+ }
+
+ /**
+ * Cast literal values for filters.
+ *
+   * We need to cast integral values to Long and floating-point values to Double because
+   * ORC raises exceptions in `checkLiteralType` of SearchArgumentImpl.java otherwise.
+ */
+ private def castLiteralValue(value: Any, dataType: DataType): Any = dataType match {
+ case ByteType | ShortType | IntegerType | LongType =>
+ value.asInstanceOf[Number].longValue
+ case FloatType | DoubleType =>
+ value.asInstanceOf[Number].doubleValue()
+ case _: DecimalType =>
+ val decimal = value.asInstanceOf[java.math.BigDecimal]
+ val decimalWritable = new HiveDecimalWritable(decimal.longValue)
+ decimalWritable.mutateEnforcePrecisionScale(decimal.precision, decimal.scale)
+ decimalWritable
+ case _ => value
+ }
+
+ /**
+   * Build a SearchArgument for the given filter and return the builder so far, or None
+   * if the filter cannot be fully converted.
+ */
+ private def buildSearchArgument(
+ dataTypeMap: Map[String, DataType],
+ expression: Filter,
+ builder: Builder): Option[Builder] = {
+ def newBuilder = SearchArgumentFactory.newBuilder()
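+    // Trial conversions of child filters use a fresh builder, because a Builder is
+    // mutable and must not be polluted before the real conversion below.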
+
+ def getType(attribute: String): PredicateLeaf.Type =
+ getPredicateLeafType(dataTypeMap(attribute))
+
+ import org.apache.spark.sql.sources._
+
+ expression match {
+ case And(left, right) =>
+        // Here it is not safe to convert just one side if we cannot convert the other.
+        // For example, say we have NOT(a = 2 AND b in ('1')) and we cannot convert
+        // b in ('1'). If we convert only a = 2, we end up with the filter NOT(a = 2),
+        // which produces wrong results.
+ // Pushing one side of AND down is only safe to do at the top level.
+ // You can see ParquetRelation's initializeLocalJobFunc method as an example.
+ for {
+ _ <- buildSearchArgument(dataTypeMap, left, newBuilder)
+ _ <- buildSearchArgument(dataTypeMap, right, newBuilder)
+ lhs <- buildSearchArgument(dataTypeMap, left, builder.startAnd())
+ rhs <- buildSearchArgument(dataTypeMap, right, lhs)
+ } yield rhs.end()
+
+ case Or(left, right) =>
+ for {
+ _ <- buildSearchArgument(dataTypeMap, left, newBuilder)
+ _ <- buildSearchArgument(dataTypeMap, right, newBuilder)
+ lhs <- buildSearchArgument(dataTypeMap, left, builder.startOr())
+ rhs <- buildSearchArgument(dataTypeMap, right, lhs)
+ } yield rhs.end()
+
+ case Not(child) =>
+ for {
+ _ <- buildSearchArgument(dataTypeMap, child, newBuilder)
+ negate <- buildSearchArgument(dataTypeMap, child, builder.startNot())
+ } yield negate.end()
+
+ // NOTE: For all case branches dealing with leaf predicates below, the additional `startAnd()`
+ // call is mandatory. ORC `SearchArgument` builder requires that all leaf predicates must be
+ // wrapped by a "parent" predicate (`And`, `Or`, or `Not`).
+
+ case EqualTo(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(attribute))
+ Some(builder.startAnd().equals(attribute, getType(attribute), castedValue).end())
+
+ case EqualNullSafe(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(attribute))
+ Some(builder.startAnd().nullSafeEquals(attribute, getType(attribute), castedValue).end())
+
+ case LessThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(attribute))
+ Some(builder.startAnd().lessThan(attribute, getType(attribute), castedValue).end())
+
+ case LessThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(attribute))
+ Some(builder.startAnd().lessThanEquals(attribute, getType(attribute), castedValue).end())
+
+ case GreaterThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(attribute))
+ Some(builder.startNot().lessThanEquals(attribute, getType(attribute), castedValue).end())
+
+ case GreaterThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(attribute))
+ Some(builder.startNot().lessThan(attribute, getType(attribute), castedValue).end())
+
+ case IsNull(attribute) if isSearchableType(dataTypeMap(attribute)) =>
+ Some(builder.startAnd().isNull(attribute, getType(attribute)).end())
+
+ case IsNotNull(attribute) if isSearchableType(dataTypeMap(attribute)) =>
+ Some(builder.startNot().isNull(attribute, getType(attribute)).end())
+
+ case In(attribute, values) if isSearchableType(dataTypeMap(attribute)) =>
+ val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(attribute)))
+ Some(builder.startAnd().in(attribute, getType(attribute),
+ castedValues.map(_.asInstanceOf[AnyRef]): _*).end())
+
+ case _ => None
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala
new file mode 100644
index 000000000000..5d13b5368cb6
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc
+
+import java.util.Locale
+
+import org.apache.orc.OrcConf
+
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Options for ORC data source.
+ */
+private[orc] class OrcOptions(
+ @transient private val parameters: CaseInsensitiveMap[String],
+ @transient private val sqlConf: SQLConf)
+ extends Serializable {
+
+ import OrcOptions._
+
+ def this(parameters: Map[String, String], sqlConf: SQLConf) =
+ this(CaseInsensitiveMap(parameters), sqlConf)
+
+  /**
+   * Compression codec to use. The `compression` option takes precedence over the ORC
+   * option `orc.compress`; if neither is set, `snappy` is used.
+   * Acceptable values are defined in [[shortOrcCompressionCodecNames]].
+   */
+ val compressionCodecClassName: String = {
+ val codecName = parameters
+ .get("compression")
+ .orElse(parameters.get(OrcConf.COMPRESS.getAttribute))
+ .getOrElse("snappy").toLowerCase(Locale.ROOT)
+ OrcOptions.shortOrcCompressionCodecNames(codecName)
+ }
+}
+
+object OrcOptions {
+  // Short names for ORC compression codecs
+ private[orc] val shortOrcCompressionCodecNames = Map(
+ "none" -> "NONE",
+ "uncompressed" -> "NONE",
+ "snappy" -> "SNAPPY",
+ "zlib" -> "ZLIB",
+ "lzo" -> "LZO")
+
+ private[orc] val extensionsForCompressionCodecNames = Map(
+ "NONE" -> "",
+ "SNAPPY" -> ".snappy",
+ "ZLIB" -> ".zlib",
+ "LZO" -> ".lzo")
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala
new file mode 100644
index 000000000000..796da712bdea
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.NullWritable
+import org.apache.hadoop.mapreduce._
+import org.apache.orc.mapred.OrcStruct
+import org.apache.orc.mapreduce.OrcOutputFormat
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources.OutputWriter
+import org.apache.spark.sql.types.StructType
+
+private[orc] class OrcOutputWriter(
+ path: String,
+ dataSchema: StructType,
+ context: TaskAttemptContext)
+ extends OutputWriter {
+
+ private val recordWriter = {
+ new OrcOutputFormat[OrcStruct]() {
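+      // Spark already decides the final file name, so point the work file directly at
+      // the given path instead of letting the OutputFormat choose one.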
+ override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
+ new Path(path)
+ }
+ }.getRecordWriter(context)
+ }
+
+ private lazy val orcStruct: OrcStruct =
+ OrcFileFormat.createOrcValue(dataSchema).asInstanceOf[OrcStruct]
+
+ override def write(row: InternalRow): Unit = {
+ recordWriter.write(
+ NullWritable.get,
+ OrcFileFormat.convertInternalRowToOrcStruct(row, dataSchema, Some(orcStruct)))
+ }
+
+ override def close(): Unit = recordWriter.close(context)
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcRecordIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcRecordIterator.scala
new file mode 100644
index 000000000000..9ff088a424e6
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcRecordIterator.scala
@@ -0,0 +1,247 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc
+
+import org.apache.hadoop.mapreduce.{InputSplit, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.lib.input.FileSplit
+import org.apache.orc._
+import org.apache.orc.mapred.OrcInputFormat
+import org.apache.orc.storage.ql.exec.vector._
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLDate
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * An iterator that returns InternalRows read from an ORC data source.
+ */
+private[orc] class OrcRecordIterator extends Iterator[InternalRow] with Logging {
+
+ /**
+ * ORC File Reader.
+ */
+ private var reader: Reader = _
+
+ /**
+ * ORC Data Schema.
+ */
+ private var schema: TypeDescription = _
+
+ /**
+ * Use index to find corresponding fields.
+ */
+ private var useIndex: Boolean = _
+
+ /**
+ * Spark Schema.
+ */
+ private var sparkSchema: StructType = _
+
+ /**
+ * Required Schema.
+ */
+ private var requiredSchema: StructType = _
+
+ /**
+ * ORC Batch Record Reader.
+ */
+ private var rows: org.apache.orc.RecordReader = _
+
+ /**
+   * The total number of rows.
+ */
+ private var totalRowCount: Long = 0L
+
+ /**
+ * The number of rows that have been returned.
+ */
+ private var rowsReturned: Long = 0L
+
+ /**
+ * Vectorized Row Batch.
+ */
+ private var batch: VectorizedRowBatch = _
+
+ /**
+ * Current index in the batch.
+ */
+ private var batchIdx = -1
+
+ /**
+ * The number of rows in the current batch.
+ */
+ private var numBatched = 0
+
+ /**
+ * The current row.
+ */
+ private var mutableRow: InternalRow = _
+
+ def initialize(
+ inputSplit: InputSplit,
+ taskAttemptContext: TaskAttemptContext,
+ orcSchema: TypeDescription,
+ requiredSchema: StructType,
+ partitionColumns: StructType,
+ partitionValues: InternalRow,
+ useIndex: Boolean): Unit = {
+ val fileSplit = inputSplit.asInstanceOf[FileSplit]
+ val conf = taskAttemptContext.getConfiguration
+
+ reader = OrcFile.createReader(
+ fileSplit.getPath,
+ OrcFile.readerOptions(conf).maxLength(OrcConf.MAX_FILE_LENGTH.getLong(conf)))
+ schema = orcSchema
+ sparkSchema = CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType]
+ totalRowCount = reader.getNumberOfRows
+
+ // Create batch and load the first batch.
+ val options = OrcInputFormat.buildOptions(conf, reader, fileSplit.getStart, fileSplit.getLength)
+ batch = schema.createRowBatch
+ rows = reader.rows(options)
+ rows.nextBatch(batch)
+ batchIdx = 0
+ numBatched = batch.size
+
+ // Create a mutableRow for the full schema which is
+ // requiredSchema.toAttributes ++ partitionSchema.toAttributes
+ this.requiredSchema = requiredSchema
+ val fullSchema = new StructType(this.requiredSchema.fields ++ partitionColumns)
+ mutableRow = new SpecificInternalRow(fullSchema.map(_.dataType))
+
+ this.useIndex = useIndex
+
+ // Initialize the partition column values once.
+ for (i <- requiredSchema.length until fullSchema.length) {
+ mutableRow.update(i, partitionValues.get(i - requiredSchema.length, fullSchema(i).dataType))
+ }
+ }
+
+ private def updateRow(): Unit = {
+ // Fill the required fields into mutableRow.
+ for (index <- 0 until requiredSchema.length) {
+ val field = requiredSchema(index)
+ val fieldType = field.dataType
+ val vector = if (useIndex) {
+ batch.cols(index)
+ } else {
+ batch.cols(sparkSchema.fieldIndex(field.name))
+ }
+ updateField(fieldType, vector, mutableRow, index)
+ }
+ }
+
+ private def updateField(
+ fieldType: DataType,
+ vector: ColumnVector,
+ mutableRow: InternalRow,
+ index: Int) = {
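+    // ORC stores booleans, bytes, shorts, ints, longs and dates in LongColumnVector and
+    // floats/doubles in DoubleColumnVector, hence the casts below.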
+ if (vector.noNulls || !vector.isNull(batchIdx)) {
+ fieldType match {
+ case BooleanType =>
+ val fieldValue = vector.asInstanceOf[LongColumnVector].vector(batchIdx) == 1
+ mutableRow.setBoolean(index, fieldValue)
+ case ByteType =>
+ val fieldValue = vector.asInstanceOf[LongColumnVector].vector(batchIdx)
+ mutableRow.setByte(index, fieldValue.asInstanceOf[Byte])
+ case ShortType =>
+ val fieldValue = vector.asInstanceOf[LongColumnVector].vector(batchIdx)
+ mutableRow.setShort(index, fieldValue.asInstanceOf[Short])
+ case IntegerType =>
+ val fieldValue = vector.asInstanceOf[LongColumnVector].vector(batchIdx)
+ mutableRow.setInt(index, fieldValue.asInstanceOf[Int])
+ case LongType =>
+ val fieldValue = vector.asInstanceOf[LongColumnVector].vector(batchIdx)
+ mutableRow.setLong(index, fieldValue)
+
+ case FloatType =>
+ val fieldValue = vector.asInstanceOf[DoubleColumnVector].vector(batchIdx)
+ mutableRow.setFloat(index, fieldValue.asInstanceOf[Float])
+ case DoubleType =>
+ val fieldValue = vector.asInstanceOf[DoubleColumnVector].vector(batchIdx)
+ mutableRow.setDouble(index, fieldValue.asInstanceOf[Double])
+ case _: DecimalType =>
+ val fieldValue = vector.asInstanceOf[DecimalColumnVector].vector(batchIdx)
+ mutableRow.update(index, OrcFileFormat.getCatalystValue(fieldValue, fieldType))
+
+ case _: DateType =>
+ val fieldValue = vector.asInstanceOf[LongColumnVector].vector(batchIdx)
+ mutableRow.update(index, fieldValue.asInstanceOf[SQLDate])
+
+ case _: TimestampType =>
+ val fieldValue =
+ vector.asInstanceOf[TimestampColumnVector].asScratchTimestamp(batchIdx)
+ mutableRow.update(index, DateTimeUtils.fromJavaTimestamp(fieldValue))
+
+ case StringType =>
+ val v = vector.asInstanceOf[BytesColumnVector]
+ val fieldValue =
+ UTF8String.fromBytes(v.vector(batchIdx), v.start(batchIdx), v.length(batchIdx))
+ mutableRow.update(index, fieldValue)
+
+ case BinaryType =>
+ val fieldVector = vector.asInstanceOf[BytesColumnVector]
+ val fieldValue = java.util.Arrays.copyOfRange(
+ fieldVector.vector(batchIdx),
+ fieldVector.start(batchIdx),
+ fieldVector.start(batchIdx) + fieldVector.length(batchIdx))
+ mutableRow.update(index, fieldValue)
+
+ case dt => throw new UnsupportedOperationException(s"Unknown Data Type: $dt")
+
+ }
+ } else {
+ fieldType match {
+ case dt: DecimalType => mutableRow.setDecimal(index, null, dt.precision)
+ case _ => mutableRow.setNullAt(index)
+ }
+ }
+ }
+
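+  // hasNext only checks the current position; next() fills mutableRow from the current
+  // batch slot, then advances and loads the next ORC batch when the current one is
+  // exhausted.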
+ def hasNext: Boolean = {
+ 0 <= batchIdx && batchIdx < numBatched && rowsReturned < totalRowCount
+ }
+
+ def next: InternalRow = {
+ updateRow()
+
+ if (rowsReturned == totalRowCount) {
+ close()
+ } else {
+ batchIdx += 1
+ rowsReturned += 1
+ if (batchIdx == numBatched && rowsReturned < totalRowCount) {
+ rows.nextBatch(batch)
+ batchIdx = 0
+ numBatched = batch.size
+ }
+ }
+
+ mutableRow
+ }
+
+ def close(): Unit = {
+ rows.close()
+ }
+}
+
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
new file mode 100644
index 000000000000..10e4936bb4f5
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
@@ -0,0 +1,335 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc
+
+import java.nio.charset.StandardCharsets
+import java.sql.{Date, Timestamp}
+
+import scala.collection.JavaConverters._
+
+import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument}
+
+import org.apache.spark.sql.{Column, DataFrame, QueryTest}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.planning.PhysicalOperation
+import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation}
+
+/**
+ * A test suite that tests ORC filter API based filter pushdown optimization.
+ * This is a port of org.apache.spark.sql.hive.orc.OrcFilterSuite.
+ */
+class OrcFilterSuite extends QueryTest with OrcTest {
+ private def checkFilterPredicate(
+ df: DataFrame,
+ predicate: Predicate,
+ checker: (SearchArgument) => Unit): Unit = {
+ val output = predicate.collect { case a: Attribute => a }.distinct
+ val query = df
+ .select(output.map(e => Column(e)): _*)
+ .where(Column(predicate))
+
+ var maybeRelation: Option[HadoopFsRelation] = None
+ val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect {
+ case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _)) =>
+ maybeRelation = Some(orcRelation)
+ filters
+ }.flatten.reduceLeftOption(_ && _)
+ assert(maybeAnalyzedPredicate.isDefined, "No filter is analyzed from the given query")
+
+ val (_, selectedFilters, _) =
+ DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq)
+ assert(selectedFilters.nonEmpty, "No filter is pushed down")
+
+ val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters)
+ assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $selectedFilters")
+ checker(maybeFilter.get)
+ }
+
+ private def checkFilterPredicate
+ (predicate: Predicate, filterOperator: PredicateLeaf.Operator)
+ (implicit df: DataFrame): Unit = {
+ def checkComparisonOperator(filter: SearchArgument) = {
+ val operator = filter.getLeaves.asScala
+ assert(operator.map(_.getOperator).contains(filterOperator))
+ }
+ checkFilterPredicate(df, predicate, checkComparisonOperator)
+ }
+
+ private def checkFilterPredicate
+ (predicate: Predicate, stringExpr: String)
+ (implicit df: DataFrame): Unit = {
+ def checkLogicalOperator(filter: SearchArgument) = {
+ assert(filter.toString == stringExpr)
+ }
+ checkFilterPredicate(df, predicate, checkLogicalOperator)
+ }
+
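+ // Asserts that, although a filter is analyzed and selected for pushdown, no ORC
+ // filter predicate can be generated for it (e.g. for unsupported data types).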
+ private def checkNoFilterPredicate
+ (predicate: Predicate)
+ (implicit df: DataFrame): Unit = {
+ val output = predicate.collect { case a: Attribute => a }.distinct
+ val query = df
+ .select(output.map(e => Column(e)): _*)
+ .where(Column(predicate))
+
+ var maybeRelation: Option[HadoopFsRelation] = None
+ val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect {
+ case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _)) =>
+ maybeRelation = Some(orcRelation)
+ filters
+ }.flatten.reduceLeftOption(_ && _)
+ assert(maybeAnalyzedPredicate.isDefined, "No filter is analyzed from the given query")
+
+ val (_, selectedFilters, _) =
+ DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq)
+ assert(selectedFilters.nonEmpty, "No filter is pushed down")
+
+ val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters)
+ assert(maybeFilter.isEmpty, s"Unexpectedly generated a filter predicate for $selectedFilters")
+ }
+
+ test("filter pushdown - integer") {
+ withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df =>
+ checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL)
+
+ checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+
+ checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN)
+
+ checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+ checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN)
+ }
+ }
+
+ test("filter pushdown - long") {
+ withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit df =>
+ checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL)
+
+ checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+
+ checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN)
+
+ checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+ checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN)
+ }
+ }
+
+ test("filter pushdown - float") {
+ withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit df =>
+ checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL)
+
+ checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+
+ checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN)
+
+ checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+ checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN)
+ }
+ }
+
+ test("filter pushdown - double") {
+ withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit df =>
+ checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL)
+
+ checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+
+ checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN)
+
+ checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+ checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN)
+ }
+ }
+
+ test("filter pushdown - string") {
+ withOrcDataFrame((1 to 4).map(i => Tuple1(i.toString))) { implicit df =>
+ checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL)
+
+ checkFilterPredicate('_1 === "1", PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate('_1 <=> "1", PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+
+ checkFilterPredicate('_1 < "2", PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate('_1 > "3", PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 <= "1", PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 >= "4", PredicateLeaf.Operator.LESS_THAN)
+
+ checkFilterPredicate(Literal("1") === '_1, PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate(Literal("1") <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+ checkFilterPredicate(Literal("2") > '_1, PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate(Literal("3") < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(Literal("1") >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(Literal("4") <= '_1, PredicateLeaf.Operator.LESS_THAN)
+ }
+ }
+
+ test("filter pushdown - boolean") {
+ withOrcDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df =>
+ checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL)
+
+ checkFilterPredicate('_1 === true, PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate('_1 <=> true, PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+
+ checkFilterPredicate('_1 < true, PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate('_1 > false, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 <= false, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 >= false, PredicateLeaf.Operator.LESS_THAN)
+
+ checkFilterPredicate(Literal(false) === '_1, PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate(Literal(false) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+ checkFilterPredicate(Literal(false) > '_1, PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate(Literal(true) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(Literal(true) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(Literal(true) <= '_1, PredicateLeaf.Operator.LESS_THAN)
+ }
+ }
+
+ test("filter pushdown - decimal") {
+ withOrcDataFrame((1 to 4).map(i => Tuple1.apply(BigDecimal.valueOf(i)))) { implicit df =>
+ checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL)
+
+ checkFilterPredicate('_1 === BigDecimal.valueOf(1), PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate('_1 <=> BigDecimal.valueOf(1), PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+
+ checkFilterPredicate('_1 < BigDecimal.valueOf(2), PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate('_1 > BigDecimal.valueOf(3), PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 <= BigDecimal.valueOf(1), PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 >= BigDecimal.valueOf(4), PredicateLeaf.Operator.LESS_THAN)
+
+ checkFilterPredicate(
+ Literal(BigDecimal.valueOf(1)) === '_1, PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate(
+ Literal(BigDecimal.valueOf(1)) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+ checkFilterPredicate(
+ Literal(BigDecimal.valueOf(2)) > '_1, PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate(
+ Literal(BigDecimal.valueOf(3)) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(
+ Literal(BigDecimal.valueOf(1)) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(
+ Literal(BigDecimal.valueOf(4)) <= '_1, PredicateLeaf.Operator.LESS_THAN)
+ }
+ }
+
+ test("filter pushdown - timestamp") {
+ val timeString = "2015-08-20 14:57:00"
+ val timestamps = (1 to 4).map { i =>
+ val milliseconds = Timestamp.valueOf(timeString).getTime + i * 3600
+ new Timestamp(milliseconds)
+ }
+ withOrcDataFrame(timestamps.map(Tuple1(_))) { implicit df =>
+ checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL)
+
+ checkFilterPredicate('_1 === timestamps(0), PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate('_1 <=> timestamps(0), PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+
+ checkFilterPredicate('_1 < timestamps(1), PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate('_1 > timestamps(2), PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 <= timestamps(0), PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate('_1 >= timestamps(3), PredicateLeaf.Operator.LESS_THAN)
+
+ checkFilterPredicate(Literal(timestamps(0)) === '_1, PredicateLeaf.Operator.EQUALS)
+ checkFilterPredicate(Literal(timestamps(0)) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+ checkFilterPredicate(Literal(timestamps(1)) > '_1, PredicateLeaf.Operator.LESS_THAN)
+ checkFilterPredicate(Literal(timestamps(2)) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(Literal(timestamps(0)) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+ checkFilterPredicate(Literal(timestamps(3)) <= '_1, PredicateLeaf.Operator.LESS_THAN)
+ }
+ }
+
+ test("filter pushdown - combinations with logical operators") {
+ withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df =>
+ checkFilterPredicate(
+ '_1.isNotNull,
+ "leaf-0 = (IS_NULL _1), expr = (not leaf-0)"
+ )
+ checkFilterPredicate(
+ '_1 =!= 1,
+ "leaf-0 = (IS_NULL _1), leaf-1 = (EQUALS _1 1), expr = (and (not leaf-0) (not leaf-1))"
+ )
+ checkFilterPredicate(
+ !('_1 < 4),
+ "leaf-0 = (IS_NULL _1), leaf-1 = (LESS_THAN _1 4), expr = (and (not leaf-0) (not leaf-1))"
+ )
+ checkFilterPredicate(
+ '_1 < 2 || '_1 > 3,
+ "leaf-0 = (LESS_THAN _1 2), leaf-1 = (LESS_THAN_EQUALS _1 3), " +
+ "expr = (or leaf-0 (not leaf-1))"
+ )
+ checkFilterPredicate(
+ '_1 < 2 && '_1 > 3,
+ "leaf-0 = (IS_NULL _1), leaf-1 = (LESS_THAN _1 2), " +
+ "leaf-2 = (LESS_THAN_EQUALS _1 3), expr = (and (not leaf-0) leaf-1 (not leaf-2))"
+ )
+ }
+ }
+
+ test("no filter pushdown - non-supported types") {
+ implicit class IntToBinary(int: Int) {
+ def b: Array[Byte] = int.toString.getBytes(StandardCharsets.UTF_8)
+ }
+ // ArrayType
+ withOrcDataFrame((1 to 4).map(i => Tuple1(Array(i)))) { implicit df =>
+ checkNoFilterPredicate('_1.isNull)
+ }
+ // BinaryType
+ withOrcDataFrame((1 to 4).map(i => Tuple1(i.b))) { implicit df =>
+ checkNoFilterPredicate('_1 <=> 1.b)
+ }
+ // DateType
+ val stringDate = "2015-01-01"
+ withOrcDataFrame(Seq(Tuple1(Date.valueOf(stringDate)))) { implicit df =>
+ checkNoFilterPredicate('_1 === Date.valueOf(stringDate))
+ }
+ // MapType
+ withOrcDataFrame((1 to 4).map(i => Tuple1(Map(i -> i)))) { implicit df =>
+ checkNoFilterPredicate('_1.isNotNull)
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcHadoopFsRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcHadoopFsRelationSuite.scala
new file mode 100644
index 000000000000..53a7c89afd31
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcHadoopFsRelationSuite.scala
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc
+
+import java.io.File
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.orc.OrcFile
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.sql.{DataFrame, QueryTest, Row}
+import org.apache.spark.sql.catalyst.catalog.CatalogUtils
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types._
+
+/**
+ * This test suite is a port of org.apache.spark.sql.hive.orc.OrcHadoopFsRelationSuite.
+ */
+class OrcHadoopFsRelationSuite extends QueryTest with SharedSQLContext with OrcTest {
+ import testImplicits._
+
+ val dataSourceName: String = classOf[OrcFileFormat].getCanonicalName
+
+ val dataSchema =
+ StructType(
+ Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField("b", StringType, nullable = false)))
+
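+ // Runs a set of representative queries (full scan, filtering, partition pruning,
+ // projection and a self-join) against `df` and checks the results.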
+ def checkQueries(df: DataFrame): Unit = {
+ // Selects everything
+ checkAnswer(
+ df,
+ for (i <- 1 to 3; p1 <- 1 to 2; p2 <- Seq("foo", "bar")) yield Row(i, s"val_$i", p1, p2))
+
+ // Simple filtering and partition pruning
+ checkAnswer(
+ df.filter('a > 1 && 'p1 === 2),
+ for (i <- 2 to 3; p2 <- Seq("foo", "bar")) yield Row(i, s"val_$i", 2, p2))
+
+ // Simple projection and filtering
+ checkAnswer(
+ df.filter('a > 1).select('b, 'a + 1),
+ for (i <- 2 to 3; _ <- 1 to 2; _ <- Seq("foo", "bar")) yield Row(s"val_$i", i + 1))
+
+ // Simple projection and partition pruning
+ checkAnswer(
+ df.filter('a > 1 && 'p1 < 2).select('b, 'p1),
+ for (i <- 2 to 3; _ <- Seq("foo", "bar")) yield Row(s"val_$i", 1))
+
+ // Project many copies of columns with different types (reproduction for SPARK-7858)
+ checkAnswer(
+ df.filter('a > 1 && 'p1 < 2).select('b, 'b, 'b, 'b, 'p1, 'p1, 'p1, 'p1),
+ for (i <- 2 to 3; _ <- Seq("foo", "bar"))
+ yield Row(s"val_$i", s"val_$i", s"val_$i", s"val_$i", 1, 1, 1, 1))
+
+ // Self-join
+ df.createOrReplaceTempView("t")
+ withTempView("t") {
+ checkAnswer(
+ sql(
+ """SELECT l.a, r.b, l.p1, r.p2
+ |FROM t l JOIN t r
+ |ON l.a = r.a AND l.p1 = r.p1 AND l.p2 = r.p2
+ """.stripMargin),
+ for (i <- 1 to 3; p1 <- 1 to 2; p2 <- Seq("foo", "bar")) yield Row(i, s"val_$i", p1, p2))
+ }
+ }
+
+ test("save()/load() - partitioned table - simple queries - partition columns in data") {
+ Seq("false", "true").foreach { value =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> value) {
+ withTempDir { file =>
+ for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) {
+ val partitionDir = new Path(
+ CatalogUtils.URIToString(makeQualifiedPath(file.getCanonicalPath)), s"p1=$p1/p2=$p2")
+ sparkContext
+ .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1))
+ .toDF("a", "b", "p1")
+ .write
+ .format(ORC_FILE_FORMAT)
+ .save(partitionDir.toString)
+ }
+
+ val dataSchemaWithPartition =
+ StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true))
+
+ checkQueries(
+ spark.read.options(Map(
+ "path" -> file.getCanonicalPath,
+ "dataSchema" -> dataSchemaWithPartition.json)).format(dataSourceName).load())
+ }
+ }
+ }
+ }
+
+ test("SPARK-12218: 'Not' is included in ORC filter pushdown") {
+ import testImplicits._
+
+ Seq("false", "true").foreach { value =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> value) {
+ withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") {
+ withTempPath { dir =>
+ val path = s"${dir.getCanonicalPath}/table1"
+ (1 to 5).map(i => (i, (i % 2).toString)).toDF("a", "b")
+ .write.format(ORC_FILE_FORMAT).save(path)
+
+ checkAnswer(
+ spark.read.format(ORC_FILE_FORMAT).load(path).where("not (a = 2) or not(b in ('1'))"),
+ (1 to 5).map(i => Row(i, (i % 2).toString)))
+
+ checkAnswer(
+ spark.read.format(ORC_FILE_FORMAT).load(path).where("not (a = 2 and b in ('1'))"),
+ (1 to 5).map(i => Row(i, (i % 2).toString)))
+ }
+ }
+ }
+ }
+ }
+
+ test("SPARK-13543: Support for specifying compression codec for ORC via option()") {
+ Seq("false", "true").foreach { value =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> value) {
+ withTempPath { dir =>
+ val path = s"${dir.getCanonicalPath}/table1"
+ val df = (1 to 5).map(i => (i, (i % 2).toString)).toDF("a", "b")
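+ // The codec name is intentionally mixed-case; it should be handled case-insensitively
+ // and written out as ZLIB.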
+ df.write
+ .option("compression", "ZlIb")
+ .format(ORC_FILE_FORMAT)
+ .save(path)
+
+ val maybeOrcFile = new File(path).listFiles().find(_.getName.endsWith(".zlib.orc"))
+ assert(maybeOrcFile.isDefined)
+
+ val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath)
+ val conf = OrcFile.readerOptions(new Configuration())
+ assert("ZLIB" === OrcFile.createReader(orcFilePath, conf).getCompressionKind.name)
+
+ val copyDf = spark
+ .read
+ .format(ORC_FILE_FORMAT)
+ .load(path)
+ checkAnswer(df, copyDf)
+ }
+ }
+ }
+ }
+
+ test("Default compression codec is snappy for ORC compression") {
+ Seq("false", "true").foreach { value =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> value) {
+ withTempPath { path =>
+ spark.range(0, 10).write.format(ORC_FILE_FORMAT).save(path.getAbsolutePath)
+
+ assert(path.listFiles().exists(f => f.getAbsolutePath.endsWith(".snappy.orc")))
+
+ val conf = OrcFile.readerOptions(new Configuration())
+ assert(path.listFiles().forall { f =>
+ val filePath = new Path(f.getAbsolutePath)
+ !f.getAbsolutePath.endsWith(".snappy.orc") ||
+ "SNAPPY" === OrcFile.createReader(filePath, conf).getCompressionKind.name
+ })
+ }
+ }
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala
new file mode 100644
index 000000000000..38312ce4e0ca
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala
@@ -0,0 +1,257 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc
+
+import java.io.File
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSQLContext
+
+// The data where the partitioning key exists only in the directory structure.
+case class OrcParData(intField: Int, stringField: String)
+
+// The data that also includes the partitioning key
+case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String)
+
+/**
+ * This test suite is a port of org.apache.spark.sql.hive.orc.OrcPartitionDiscoverySuite.
+ */
+class OrcPartitionDiscoverySuite
+ extends QueryTest with SharedSQLContext with OrcTest with BeforeAndAfterAll {
+
+ val defaultPartitionName = ExternalCatalogUtils.DEFAULT_PARTITION_NAME
+
+ protected def withTempTable(tableName: String)(f: => Unit): Unit = {
+ try f finally spark.catalog.dropTempView(tableName)
+ }
+
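+ // Creates a Hive-style partition directory such as `pi=1/ps=foo` under `basePath`,
+ // substituting `defaultPartitionName` for null or empty partition values.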
+ protected def makePartitionDir(
+ basePath: File,
+ defaultPartitionName: String,
+ partitionCols: (String, Any)*): File = {
+ val partNames = partitionCols.map { case (k, v) =>
+ val valueString = if (v == null || v == "") defaultPartitionName else v.toString
+ s"$k=$valueString"
+ }
+
+ val partDir = partNames.foldLeft(basePath) { (parent, child) =>
+ new File(parent, child)
+ }
+
+ assert(partDir.mkdirs(), s"Couldn't create directory $partDir")
+ partDir
+ }
+
+ test("read partitioned table - normal case") {
+ Seq("false", "true").foreach { value =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> value) {
+ withTempDir { base =>
+ for {
+ pi <- Seq(1, 2)
+ ps <- Seq("foo", "bar")
+ } {
+ makeOrcFile(
+ (1 to 10).map(i => OrcParData(i, i.toString)),
+ makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
+ }
+
+ spark.read.format(ORC_FILE_FORMAT).load(base.getCanonicalPath)
+ .createOrReplaceTempView("t")
+
+ withTempTable("t") {
+ checkAnswer(
+ sql("SELECT * FROM t"),
+ for {
+ i <- 1 to 10
+ pi <- Seq(1, 2)
+ ps <- Seq("foo", "bar")
+ } yield Row(i, i.toString, pi, ps))
+
+ checkAnswer(
+ sql("SELECT intField, pi FROM t"),
+ for {
+ i <- 1 to 10
+ pi <- Seq(1, 2)
+ _ <- Seq("foo", "bar")
+ } yield Row(i, pi))
+
+ checkAnswer(
+ sql("SELECT * FROM t WHERE pi = 1"),
+ for {
+ i <- 1 to 10
+ ps <- Seq("foo", "bar")
+ } yield Row(i, i.toString, 1, ps))
+
+ checkAnswer(
+ sql("SELECT * FROM t WHERE ps = 'foo'"),
+ for {
+ i <- 1 to 10
+ pi <- Seq(1, 2)
+ } yield Row(i, i.toString, pi, "foo"))
+ }
+ }
+ }
+ }
+ }
+
+ test("read partitioned table - partition key included in orc file") {
+ Seq("false", "true").foreach { value =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> value) {
+ withTempDir { base =>
+ for {
+ pi <- Seq(1, 2)
+ ps <- Seq("foo", "bar")
+ } {
+ makeOrcFile(
+ (1 to 10).map(i => OrcParDataWithKey(i, pi, i.toString, ps)),
+ makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
+ }
+
+ spark.read.format(ORC_FILE_FORMAT).load(base.getCanonicalPath)
+ .createOrReplaceTempView("t")
+
+ withTempTable("t") {
+ checkAnswer(
+ sql("SELECT * FROM t"),
+ for {
+ i <- 1 to 10
+ pi <- Seq(1, 2)
+ ps <- Seq("foo", "bar")
+ } yield Row(i, pi, i.toString, ps))
+
+ checkAnswer(
+ sql("SELECT intField, pi FROM t"),
+ for {
+ i <- 1 to 10
+ pi <- Seq(1, 2)
+ _ <- Seq("foo", "bar")
+ } yield Row(i, pi))
+
+ checkAnswer(
+ sql("SELECT * FROM t WHERE pi = 1"),
+ for {
+ i <- 1 to 10
+ ps <- Seq("foo", "bar")
+ } yield Row(i, 1, i.toString, ps))
+
+ checkAnswer(
+ sql("SELECT * FROM t WHERE ps = 'foo'"),
+ for {
+ i <- 1 to 10
+ pi <- Seq(1, 2)
+ } yield Row(i, pi, i.toString, "foo"))
+ }
+ }
+ }
+ }
+ }
+
+ test("read partitioned table - with nulls") {
+ Seq("false", "true").foreach { value =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> value) {
+ withTempDir { base =>
+ for {
+ // Must be `Integer` rather than `Int` here. `null.asInstanceOf[Int]` results in a zero...
+ pi <- Seq(1, null.asInstanceOf[Integer])
+ ps <- Seq("foo", null.asInstanceOf[String])
+ } {
+ makeOrcFile(
+ (1 to 10).map(i => OrcParData(i, i.toString)),
+ makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
+ }
+
+ spark.read
+ .option("hive.exec.default.partition.name", defaultPartitionName)
+ .format(ORC_FILE_FORMAT)
+ .load(base.getCanonicalPath)
+ .createOrReplaceTempView("t")
+
+ withTempTable("t") {
+ checkAnswer(
+ sql("SELECT * FROM t"),
+ for {
+ i <- 1 to 10
+ pi <- Seq(1, null.asInstanceOf[Integer])
+ ps <- Seq("foo", null.asInstanceOf[String])
+ } yield Row(i, i.toString, pi, ps))
+
+ checkAnswer(
+ sql("SELECT * FROM t WHERE pi IS NULL"),
+ for {
+ i <- 1 to 10
+ ps <- Seq("foo", null.asInstanceOf[String])
+ } yield Row(i, i.toString, null, ps))
+
+ checkAnswer(
+ sql("SELECT * FROM t WHERE ps IS NULL"),
+ for {
+ i <- 1 to 10
+ pi <- Seq(1, null.asInstanceOf[Integer])
+ } yield Row(i, i.toString, pi, null))
+ }
+ }
+ }
+ }
+ }
+
+ test("read partitioned table - with nulls and partition keys are included in Orc file") {
+ Seq("false", "true").foreach { value =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> value) {
+ withTempDir { base =>
+ for {
+ pi <- Seq(1, 2)
+ ps <- Seq("foo", null.asInstanceOf[String])
+ } {
+ makeOrcFile(
+ (1 to 10).map(i => OrcParDataWithKey(i, pi, i.toString, ps)),
+ makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
+ }
+
+ spark.read
+ .option("hive.exec.default.partition.name", defaultPartitionName)
+ .format(ORC_FILE_FORMAT)
+ .load(base.getCanonicalPath)
+ .createOrReplaceTempView("t")
+
+ withTempTable("t") {
+ checkAnswer(
+ sql("SELECT * FROM t"),
+ for {
+ i <- 1 to 10
+ pi <- Seq(1, 2)
+ ps <- Seq("foo", null.asInstanceOf[String])
+ } yield Row(i, pi, i.toString, ps))
+
+ checkAnswer(
+ sql("SELECT * FROM t WHERE ps IS NULL"),
+ for {
+ i <- 1 to 10
+ pi <- Seq(1, 2)
+ } yield Row(i, pi, i.toString, null))
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala
new file mode 100644
index 000000000000..2beee3b45636
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala
@@ -0,0 +1,680 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc
+
+import java.io.File
+import java.nio.charset.StandardCharsets
+import java.sql.Timestamp
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapreduce.{JobID, TaskAttemptID, TaskID, TaskType}
+import org.apache.hadoop.mapreduce.lib.input.FileSplit
+import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
+import org.apache.orc.{OrcConf, OrcFile}
+import org.apache.orc.mapred.OrcStruct
+import org.apache.orc.mapreduce.OrcInputFormat
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.execution.datasources.RecordReaderIterator
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.util.Utils
+
+case class AllDataTypesWithNonPrimitiveType(
+ stringField: String,
+ intField: Int,
+ longField: Long,
+ floatField: Float,
+ doubleField: Double,
+ shortField: Short,
+ byteField: Byte,
+ booleanField: Boolean,
+ array: Seq[Int],
+ arrayContainsNull: Seq[Option[Int]],
+ map: Map[Int, Long],
+ mapValueContainsNull: Map[Int, Option[Long]],
+ data: (Seq[Int], (Int, String)))
+
+case class BinaryData(binaryData: Array[Byte])
+
+case class Contact(name: String, phone: String)
+
+case class Person(name: String, age: Int, contacts: Seq[Contact])
+
+/**
+ * This test suite is a port of org.apache.spark.sql.hive.orc.OrcQuerySuite.
+ * Please note the following differences:
+ *
+ * - The LZO test case is enabled.
+ * - "Empty schema does not read data from ORC file" is ignored due to an ORC 1.3 bug.
+ *
+ * Since RelationConversions lives inside HiveStrategies.scala, the following test cases
+ * using CONVERT_METASTORE_ORC are omitted:
+ *
+ * - test("Verify the ORC conversion parameter: CONVERT_METASTORE_ORC")
+ * - test("converted ORC table supports resolving mixed case field")
+ */
+class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
+ import testImplicits._
+
+ test("Read/write All Types") {
+ val data = (0 to 255).map { i =>
+ (s"$i", i, i.toLong, i.toFloat, i.toDouble, i.toShort, i.toByte, i % 2 == 0)
+ }
+
+ withOrcFile(data) { file =>
+ Seq("false", "true").foreach { value =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> value) {
+ checkAnswer(spark.read.format(ORC_FILE_FORMAT).load(file), data.toDF().collect())
+ }
+ }
+ }
+ }
+
+ test("Read/write binary data") {
+ Seq("false", "true").foreach { value =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> value) {
+ withOrcFile(BinaryData("test".getBytes(StandardCharsets.UTF_8)) :: Nil) { file =>
+ val bytes = spark.read.format(ORC_FILE_FORMAT).load(file).head().getAs[Array[Byte]](0)
+ assert(new String(bytes, StandardCharsets.UTF_8) === "test")
+ }
+ }
+ }
+ }
+
+ test("Read/write all types with non-primitive type") {
+ val data: Seq[AllDataTypesWithNonPrimitiveType] = (0 to 255).map { i =>
+ AllDataTypesWithNonPrimitiveType(
+ s"$i", i, i.toLong, i.toFloat, i.toDouble, i.toShort, i.toByte, i % 2 == 0,
+ 0 until i,
+ (0 until i).map(Option(_).filter(_ % 3 == 0)),
+ (0 until i).map(i => i -> i.toLong).toMap,
+ (0 until i).map(i => i -> Option(i.toLong)).toMap + (i -> None),
+ (0 until i, (i, s"$i")))
+ }
+
+ Seq("false", "true").foreach { value =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> value) {
+ withOrcFile(data) { file =>
+ checkAnswer(
+ spark.read.format(ORC_FILE_FORMAT).load(file),
+ data.toDF().collect())
+ }
+ }
+ }
+ }
+
+ test("Read/write UserDefinedType") {
+ val data = Seq((1, new UDT.MyDenseVector(Array(0.25, 2.25, 4.25))))
+ val udtDF = data.toDF("id", "vectors")
+ Seq("false", "true").foreach { value =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> value) {
+ withTempPath { path =>
+ udtDF.write.format(ORC_FILE_FORMAT).save(path.getAbsolutePath)
+ val readBack =
+ spark.read.schema(udtDF.schema).format(ORC_FILE_FORMAT).load(path.getAbsolutePath)
+ checkAnswer(udtDF, readBack)
+ }
+ }
+ }
+ }
+
+ test("Creating case class RDD table") {
+ val data = (1 to 100).map(i => (i, s"val_$i"))
+
+ Seq("false", "true").foreach { value =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> value) {
+ sparkContext.parallelize(data).toDF().createOrReplaceTempView("t")
+ withTempView("t") {
+ checkAnswer(sql("SELECT * FROM t"), data.toDF().collect())
+ }
+ }
+ }
+ }
+
+ test("Simple selection form ORC table") {
+ val data = (1 to 10).map { i =>
+ Person(s"name_$i", i, (0 to 1).map { m => Contact(s"contact_$m", s"phone_$m") })
+ }
+ Seq("false", "true").foreach { value =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> value) {
+ withOrcTable(data, "t") {
+ // ppd:
+ // leaf-0 = (LESS_THAN_EQUALS age 5)
+ // expr = leaf-0
+ assert(sql("SELECT name FROM t WHERE age <= 5").count() === 5)
+
+ // ppd:
+ // leaf-0 = (LESS_THAN_EQUALS age 5)
+ // expr = (not leaf-0)
+ assertResult(10) {
+ sql("SELECT name, contacts FROM t where age > 5")
+ .rdd
+ .flatMap(_.getAs[Seq[_]]("contacts"))
+ .count()
+ }
+
+ // ppd:
+ // leaf-0 = (LESS_THAN_EQUALS age 5)
+ // leaf-1 = (LESS_THAN age 8)
+ // expr = (and (not leaf-0) leaf-1)
+ {
+ val df = sql("SELECT name, contacts FROM t WHERE age > 5 AND age < 8")
+ assert(df.count() === 2)
+ assertResult(4) {
+ df.rdd.flatMap(_.getAs[Seq[_]]("contacts")).count()
+ }
+ }
+
+ // ppd:
+ // leaf-0 = (LESS_THAN age 2)
+ // leaf-1 = (LESS_THAN_EQUALS age 8)
+ // expr = (or leaf-0 (not leaf-1))
+ {
+ val df = sql("SELECT name, contacts FROM t WHERE age < 2 OR age > 8")
+ assert(df.count() === 3)
+ assertResult(6) {
+ df.rdd.flatMap(_.getAs[Seq[_]]("contacts")).count()
+ }
+ }
+ }
+ }
+ }
+ }
+
+ test("save and load case class RDD with `None`s as orc") {
+ val data = (
+ Option.empty[Int],
+ Option.empty[Long],
+ Option.empty[Float],
+ Option.empty[Double],
+ Option.empty[Boolean]
+ ) :: Nil
+
+ Seq("false", "true").foreach { value =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> value) {
+ withOrcFile(data) { file =>
+ checkAnswer(
+ spark.read.format(ORC_FILE_FORMAT).load(file),
+ Row(Seq.fill(5)(null): _*))
+ }
+ }
+ }
+ }
+
+ test("SPARK-16610: Respect orc.compress option when compression is unset") {
+ // Respect `orc.compress`.
+ withTempPath { file =>
+ spark.range(0, 10).write
+ .option("orc.compress", "ZLIB")
+ .format(ORC_FILE_FORMAT).save(file.getCanonicalPath)
+
+ val maybeOrcFile = file.listFiles().find(_.getName.endsWith(".zlib.orc"))
+ assert(maybeOrcFile.isDefined)
+
+ val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath)
+ val conf = OrcFile.readerOptions(new Configuration())
+ assert("ZLIB" === OrcFile.createReader(orcFilePath, conf).getCompressionKind.name)
+ }
+
+ // `compression` overrides `orc.compress`.
+ withTempPath { file =>
+ spark.range(0, 10).write
+ .option("compression", "ZLIB")
+ .option("orc.compress", "SNAPPY")
+ .format(ORC_FILE_FORMAT).save(file.getCanonicalPath)
+
+ val maybeOrcFile = file.listFiles().find(_.getName.endsWith(".zlib.orc"))
+ assert(maybeOrcFile.isDefined)
+
+ val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath)
+ val conf = OrcFile.readerOptions(new Configuration())
+ assert("ZLIB" === OrcFile.createReader(orcFilePath, conf).getCompressionKind.name)
+ }
+ }
+
+ // Hive 1.2.1 supports zlib, snappy and none.
+ test("Compression options for writing to an ORC file (SNAPPY, ZLIB and NONE)") {
+ withTempPath { file =>
+ spark.range(0, 10).write
+ .option("compression", "ZLIB")
+ .format(ORC_FILE_FORMAT).save(file.getCanonicalPath)
+
+ val maybeOrcFile = file.listFiles().find(_.getName.endsWith(".zlib.orc"))
+ assert(maybeOrcFile.isDefined)
+
+ val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath)
+ val conf = OrcFile.readerOptions(new Configuration())
+ assert("ZLIB" === OrcFile.createReader(orcFilePath, conf).getCompressionKind.name)
+ }
+
+ withTempPath { file =>
+ spark.range(0, 10).write
+ .option("compression", "SNAPPY")
+ .format(ORC_FILE_FORMAT).save(file.getCanonicalPath)
+
+ val maybeOrcFile = file.listFiles().find(_.getName.endsWith(".snappy.orc"))
+ assert(maybeOrcFile.isDefined)
+
+ val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath)
+ val conf = OrcFile.readerOptions(new Configuration())
+ assert("SNAPPY" === OrcFile.createReader(orcFilePath, conf).getCompressionKind.name)
+ }
+
+ withTempPath { file =>
+ spark.range(0, 10).write
+ .option("compression", "NONE")
+ .format(ORC_FILE_FORMAT).save(file.getCanonicalPath)
+
+ val maybeOrcFile = file.listFiles().find(_.getName.endsWith(".orc"))
+ assert(maybeOrcFile.isDefined)
+
+ val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath)
+ val conf = OrcFile.readerOptions(new Configuration())
+ assert("NONE" === OrcFile.createReader(orcFilePath, conf).getCompressionKind.name)
+ }
+ }
+
+ // Previously, this test case was ignored because LZO is not supported by Hive 1.2.1.
+ // It is enabled here because ORC 1.3.x supports LZO.
+ test("LZO compression options for writing to an ORC file (not supported in Hive 1.2.1)") {
+ withTempPath { file =>
+ spark.range(0, 10).write
+ .option("compression", "LZO")
+ .format(ORC_FILE_FORMAT).save(file.getCanonicalPath)
+
+ val maybeOrcFile = file.listFiles().find(_.getName.endsWith(".lzo.orc"))
+ assert(maybeOrcFile.isDefined)
+
+ val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath)
+ val conf = OrcFile.readerOptions(new Configuration())
+ assert("LZO" === OrcFile.createReader(orcFilePath, conf).getCompressionKind.name)
+ }
+ }
+
+ test("simple select queries") {
+ Seq("false", "true").foreach { value =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> value) {
+ withOrcTable((0 until 10).map(i => (i, i.toString)), "t") {
+ checkAnswer(
+ sql("SELECT `_1` FROM t where t.`_1` > 5"),
+ (6 until 10).map(Row.apply(_)))
+
+ checkAnswer(
+ sql("SELECT `_1` FROM t as tmp where tmp.`_1` < 5"),
+ (0 until 5).map(Row.apply(_)))
+ }
+ }
+ }
+ }
+
+ test("appending") {
+ val data = (0 until 10).map(i => (i, i.toString))
+ spark.createDataFrame(data).toDF("c1", "c2").createOrReplaceTempView("tmp")
+ withOrcTable(data, "t") {
+ sql("INSERT INTO TABLE t SELECT * FROM tmp")
+ checkAnswer(spark.table("t"), (data ++ data).map(Row.fromTuple))
+ }
+ spark.sessionState.catalog.dropTable(
+ TableIdentifier("tmp"),
+ ignoreIfNotExists = true,
+ purge = false)
+ }
+
+ test("overwriting") {
+ val data = (0 until 10).map(i => (i, i.toString))
+ spark.createDataFrame(data).toDF("c1", "c2").createOrReplaceTempView("tmp")
+ withOrcTable(data, "t") {
+ sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp")
+ checkAnswer(spark.table("t"), data.map(Row.fromTuple))
+ }
+ spark.sessionState.catalog.dropTable(
+ TableIdentifier("tmp"),
+ ignoreIfNotExists = true,
+ purge = false)
+ }
+
+ test("self-join") {
+ // 4 rows, cells of column 1 of row 2 and row 4 are null
+ val data = (1 to 4).map { i =>
+ val maybeInt = if (i % 2 == 0) None else Some(i)
+ (maybeInt, i.toString)
+ }
+
+ Seq("false", "true").foreach { value =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> value) {
+
+ withOrcTable(data, "t") {
+ val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x.`_1` = y.`_1`")
+ val queryOutput = selfJoin.queryExecution.analyzed.output
+
+ assertResult(4, "Field count mismatches")(queryOutput.size)
+ assertResult(2, s"Duplicated expression ID in query plan:\n $selfJoin") {
+ queryOutput.filter(_.name == "_1").map(_.exprId).size
+ }
+
+ checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3")))
+ }
+ }
+ }
+ }
+
+ test("nested data - struct with array field") {
+ val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i"))))
+ Seq("false", "true").foreach { value =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> value) {
+ withOrcTable(data, "t") {
+ checkAnswer(sql("SELECT `_1`.`_2`[0] FROM t"), data.map {
+ case Tuple1((_, Seq(string))) => Row(string)
+ })
+ }
+ }
+ }
+ }
+
+ test("nested data - array of struct") {
+ Seq("false", "true").foreach { value =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> value) {
+ val data = (1 to 10).map(i => Tuple1(Seq(i -> s"val_$i")))
+ withOrcTable(data, "t") {
+ checkAnswer(sql("SELECT `_1`[0].`_2` FROM t"), data.map {
+ case Tuple1(Seq((_, string))) => Row(string)
+ })
+ }
+ }
+ }
+ }
+
+ test("columns only referenced by pushed down filters should remain") {
+ Seq("false", "true").foreach { value =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> value) {
+ withOrcTable((1 to 10).map(Tuple1.apply), "t") {
+ checkAnswer(sql("SELECT `_1` FROM t WHERE `_1` < 10"), (1 to 9).map(Row.apply(_)))
+ }
+ }
+ }
+ }
+
+ test("SPARK-5309 strings stored using dictionary compression in orc") {
+ withOrcTable((0 until 1000).map(i => ("same", "run_" + i / 100, 1)), "t") {
+ checkAnswer(
+ sql("SELECT `_1`, `_2`, SUM(`_3`) FROM t GROUP BY `_1`, `_2`"),
+ (0 until 10).map(i => Row("same", "run_" + i, 100)))
+
+ checkAnswer(
+ sql("SELECT `_1`, `_2`, SUM(`_3`) FROM t WHERE `_2` = 'run_5' GROUP BY `_1`, `_2`"),
+ List(Row("same", "run_5", 100)))
+ }
+ }
+
+ test("SPARK-9170: Don't implicitly lowercase of user-provided columns") {
+ Seq("false", "true").foreach { value =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> value) {
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+
+ spark.range(0, 10).select('id as "Acol").write.format(ORC_FILE_FORMAT).save(path)
+ spark.read.format(ORC_FILE_FORMAT).load(path).schema("Acol")
+ intercept[IllegalArgumentException] {
+ spark.read.format(ORC_FILE_FORMAT).load(path).schema("acol")
+ }
+ checkAnswer(spark.read.format(ORC_FILE_FORMAT).load(path).select("acol").sort("acol"),
+ (0 until 10).map(Row(_)))
+ }
+ }
+ }
+ }
+
+ test("Schema discovery on empty ORC files") {
+ // Schema discovery on empty ORC files works now that SPARK-8501 is fixed.
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+
+ withTable("empty_orc") {
+ withTempView("empty", "single") {
+ spark.sql(
+ s"""CREATE TABLE empty_orc(key INT, value STRING)
+ |USING $ORC_FILE_FORMAT
+ |LOCATION '${dir.toURI}'
+ """.stripMargin)
+
+ val emptyDF = Seq.empty[(Int, String)].toDF("key", "value").coalesce(1)
+ emptyDF.createOrReplaceTempView("empty")
+
+ // This creates one empty ORC file with the Hive ORC SerDe. We use this trick because the
+ // Spark SQL ORC data source always avoids writing empty ORC files.
+ spark.sql(
+ s"""INSERT INTO TABLE empty_orc
+ |SELECT key, value FROM empty
+ """.stripMargin)
+
+ Seq("false", "true").foreach { value =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> value) {
+ val df = spark.read.format(ORC_FILE_FORMAT).load(path)
+ assert(df.schema === emptyDF.schema.asNullable)
+ checkAnswer(df, emptyDF)
+ }
+ }
+ }
+ }
+ }
+ }
+
+ test("SPARK-10623 Enable ORC PPD") {
+ Seq("false", "true").foreach { value =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> value) {
+ withTempPath { dir =>
+ withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") {
+ import testImplicits._
+ val path = dir.getCanonicalPath
+
+ // The non-null values of field "a" are odd integers; this lets us check the filtered
+ // count when `isNull` is applied. For field "b", ORC's `isNotNull` filter drops rows
+ // only when all the values are null (this may behave differently for more complicated
+ // data or queries), so a column containing only `null` values is added here.
+ val data = (0 until 10).map { i =>
+ val maybeInt = if (i % 2 == 0) None else Some(i)
+ val nullValue: Option[String] = None
+ (maybeInt, nullValue)
+ }
+ // Repartition the data so that several ORC files are written, giving ORC
+ // a chance to skip stripes.
+ spark.createDataFrame(data).toDF("a", "b").repartition(10)
+ .write.format(ORC_FILE_FORMAT).save(path)
+ val df = spark.read.format(ORC_FILE_FORMAT).load(path)
+
+ def checkPredicate(pred: Column, answer: Seq[Row]): Unit = {
+ val sourceDf = stripSparkFilter(df.where(pred))
+ val data = sourceDf.collect().toSet
+ val expectedData = answer.toSet
+
+ // When a filter is pushed down to ORC, ORC can apply it while reading rows, so we
+ // check the number of rows returned by ORC to make sure the filter pushdown works.
+ // The tricky part is that ORC does not evaluate the filter exactly; it may return
+ // extra candidate rows. So this checks that fewer rows than the original count are
+ // returned and that the result contains the expected data.
+ assert(
+ sourceDf.count < 10 && expectedData.subsetOf(data),
+ s"No data was filtered for predicate: $pred")
+ }
+
+ checkPredicate('a === 5, List(5).map(Row(_, null)))
+ checkPredicate('a <=> 5, List(5).map(Row(_, null)))
+ checkPredicate('a < 5, List(1, 3).map(Row(_, null)))
+ checkPredicate('a <= 5, List(1, 3, 5).map(Row(_, null)))
+ checkPredicate('a > 5, List(7, 9).map(Row(_, null)))
+ checkPredicate('a >= 5, List(5, 7, 9).map(Row(_, null)))
+ checkPredicate('a.isNull, List(null).map(Row(_, null)))
+ checkPredicate('b.isNotNull, List())
+ checkPredicate('a.isin(3, 5, 7), List(3, 5, 7).map(Row(_, null)))
+ checkPredicate('a > 0 && 'a < 3, List(1).map(Row(_, null)))
+ checkPredicate('a < 1 || 'a > 8, List(9).map(Row(_, null)))
+ checkPredicate(!('a > 3), List(1, 3).map(Row(_, null)))
+ checkPredicate(!('a > 0 && 'a < 3), List(3, 5, 7, 9).map(Row(_, null)))
+ }
+ }
+ }
+ }
+ }
+
+ test("SPARK-14962 Produce correct results on array type with isnotnull") {
+ Seq("false", "true").foreach { value =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> value) {
+ withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") {
+ val data = (0 until 10).map(i => Tuple1(Array(i)))
+ withOrcFile(data) { file =>
+ val actual = spark
+ .read
+ .format(ORC_FILE_FORMAT)
+ .load(file)
+ .where("_1 is not null")
+ val expected = data.toDF()
+ checkAnswer(actual, expected)
+ }
+ }
+ }
+ }
+ }
+
+ test("SPARK-15198 Support for pushing down filters for boolean types") {
+ Seq("false", "true").foreach { value =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> value) {
+ withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") {
+ val data = (0 until 10).map(_ => (true, false))
+ withOrcFile(data) { file =>
+ val df = spark.read.format(ORC_FILE_FORMAT).load(file).where("_2 == true")
+ val actual = stripSparkFilter(df).count()
+
+ // ORC filter should be applied and the total count should be 0.
+ assert(actual === 0)
+ }
+ }
+ }
+ }
+ }
+
+ test("Support for pushing down filters for decimal types") {
+ Seq("false", "true").foreach { value =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> value) {
+ withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") {
+ val data = (0 until 10).map(i => Tuple1(BigDecimal.valueOf(i)))
+ withTempPath { file =>
+ // Repartition the data so that several ORC files are written, giving ORC
+ // a chance to skip stripes.
+ spark.createDataFrame(data).toDF("a").repartition(10)
+ .write.format(ORC_FILE_FORMAT).save(file.getCanonicalPath)
+ val df = spark.read.format(ORC_FILE_FORMAT).load(file.getCanonicalPath).where("a == 2")
+ val actual = stripSparkFilter(df).count()
+
+ assert(actual < 10)
+ }
+ }
+ }
+ }
+ }
+
+ test("Support for pushing down filters for timestamp types") {
+ Seq("false", "true").foreach { value =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> value) {
+ withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") {
+ val timeString = "2015-08-20 14:57:00"
+ val data = (0 until 10).map { i =>
+ val milliseconds = Timestamp.valueOf(timeString).getTime + i * 3600
+ Tuple1(new Timestamp(milliseconds))
+ }
+ withTempPath { file =>
+ // Repartition the data so that several ORC files are written, giving ORC
+ // a chance to skip stripes.
+ spark.createDataFrame(data).toDF("a").repartition(10)
+ .write.format(ORC_FILE_FORMAT).save(file.getCanonicalPath)
+ val df = spark.read.format(ORC_FILE_FORMAT)
+ .load(file.getCanonicalPath).where(s"a == '$timeString'")
+ val actual = stripSparkFilter(df).count()
+
+ assert(actual < 10)
+ }
+ }
+ }
+ }
+ }
+
+ test("column nullability and comment - write and then read") {
+ val schema = (new StructType)
+ .add("cl1", IntegerType, nullable = false, comment = "test")
+ .add("cl2", IntegerType, nullable = true)
+ .add("cl3", IntegerType, nullable = true)
+ val row = Row(3, null, 4)
+ val df = spark.createDataFrame(sparkContext.parallelize(row :: Nil), schema)
+
+ val tableName = "tab"
+ withTable(tableName) {
+ df.write.format(ORC_FILE_FORMAT).mode("overwrite").saveAsTable(tableName)
+ // Verify the DDL command result: DESCRIBE TABLE
+ checkAnswer(
+ sql(s"desc $tableName").select("col_name", "comment").where($"comment" === "test"),
+ Row("cl1", "test") :: Nil)
+ // Verify the schema
+ val expectedFields = schema.fields.map(f => f.copy(nullable = true))
+ assert(spark.table(tableName).schema == schema.copy(fields = expectedFields))
+ }
+ }
+
+ ignore("Empty schema does not read data from ORC file") {
+ val data = Seq((1, 1), (2, 2))
+ withOrcFile(data) { path =>
+ val conf = new Configuration()
+ conf.set(OrcConf.INCLUDE_COLUMNS.getAttribute, "")
+ conf.setBoolean("hive.io.file.read.all.columns", false)
+
+ val orcRecordReader = {
+ val file = new File(path).listFiles().find(_.getName.endsWith(".snappy.orc")).head
+ val split = new FileSplit(new Path(file.toURI), 0, file.length, Array.empty[String])
+ val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
+ val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
+ val oif = new OrcInputFormat[OrcStruct]
+ oif.createRecordReader(split, hadoopAttemptContext)
+ }
+
+ val recordsIterator = new RecordReaderIterator[OrcStruct](orcRecordReader)
+ try {
+ assert(recordsIterator.next().toString == "{null, null}")
+ } finally {
+ recordsIterator.close()
+ }
+ }
+ }
+
+ test("read from multiple orc input paths") {
+ Seq("false", "true").foreach { value =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> value) {
+ val path1 = Utils.createTempDir()
+ val path2 = Utils.createTempDir()
+ makeOrcFile((1 to 10).map(Tuple1.apply), path1)
+ makeOrcFile((1 to 10).map(Tuple1.apply), path2)
+ val df = spark
+ .read
+ .format(ORC_FILE_FORMAT)
+ .load(path1.getCanonicalPath, path2.getCanonicalPath)
+ assertResult(20)(df.count())
+ }
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala
new file mode 100644
index 000000000000..0b5e33caa94b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc
+
+import java.io.File
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
+
+case class OrcData(intField: Int, stringField: String)
+
+/**
+ * This test suite is a port of org.apache.spark.sql.hive.orc.OrcSuite.
+ */
+abstract class OrcSuite extends QueryTest with SharedSQLContext with OrcTest {
+ import testImplicits._
+
+ var orcTableDir: File = null
+ var orcTableAsDir: File = null
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+
+ orcTableAsDir = Utils.createTempDir("orctests", "sparksql")
+
+ // Hack: to prepare orc data files using hive external tables
+ orcTableDir = Utils.createTempDir("orctests", "sparksql")
+
+ spark.sparkContext
+ .makeRDD(1 to 10)
+ .map(i => OrcData(i, s"part-$i"))
+ .toDF()
+ .createOrReplaceTempView("orc_temp_table")
+
+ sql(
+ s"""CREATE TABLE normal_orc(
+ | intField INT,
+ | stringField STRING
+ |)
+ |USING $ORC_FILE_FORMAT
+ |LOCATION '${orcTableAsDir.toURI}'
+ """.stripMargin)
+
+ sql(
+ s"""INSERT INTO TABLE normal_orc
+ |SELECT intField, stringField FROM orc_temp_table
+ """.stripMargin)
+ }
+
+ test("create temporary orc table") {
+ checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_source"), Row(10))
+
+ checkAnswer(
+ sql("SELECT * FROM normal_orc_source"),
+ (1 to 10).map(i => Row(i, s"part-$i")))
+
+ checkAnswer(
+ sql("SELECT * FROM normal_orc_source where intField > 5"),
+ (6 to 10).map(i => Row(i, s"part-$i")))
+
+ checkAnswer(
+ sql("SELECT COUNT(intField), stringField FROM normal_orc_source GROUP BY stringField"),
+ (1 to 10).map(i => Row(1, s"part-$i")))
+ }
+
+ test("create temporary orc table as") {
+ checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_as_source"), Row(10))
+
+ checkAnswer(
+ sql("SELECT * FROM normal_orc_source"),
+ (1 to 10).map(i => Row(i, s"part-$i")))
+
+ checkAnswer(
+ sql("SELECT * FROM normal_orc_source WHERE intField > 5"),
+ (6 to 10).map(i => Row(i, s"part-$i")))
+
+ checkAnswer(
+ sql("SELECT COUNT(intField), stringField FROM normal_orc_source GROUP BY stringField"),
+ (1 to 10).map(i => Row(1, s"part-$i")))
+ }
+
+ test("appending insert") {
+ sql("INSERT INTO TABLE normal_orc_source SELECT * FROM orc_temp_table WHERE intField > 5")
+
+ checkAnswer(
+ sql("SELECT * FROM normal_orc_source"),
+ (1 to 5).map(i => Row(i, s"part-$i")) ++ (6 to 10).flatMap { i =>
+ Seq.fill(2)(Row(i, s"part-$i"))
+ })
+ }
+
+ test("overwrite insert") {
+ sql(
+ """INSERT OVERWRITE TABLE normal_orc_as_source
+ |SELECT * FROM orc_temp_table WHERE intField > 5
+ """.stripMargin)
+
+ checkAnswer(
+ sql("SELECT * FROM normal_orc_as_source"),
+ (6 to 10).map(i => Row(i, s"part-$i")))
+ }
+
+ test("write null values") {
+
+ sql("DROP TABLE IF EXISTS orcNullValues")
+
+ val df = sql(
+ """
+ |SELECT
+ | CAST(null as TINYINT) as c0,
+ | CAST(null as SMALLINT) as c1,
+ | CAST(null as INT) as c2,
+ | CAST(null as BIGINT) as c3,
+ | CAST(null as FLOAT) as c4,
+ | CAST(null as DOUBLE) as c5,
+ | CAST(null as DECIMAL(7,2)) as c6,
+ | CAST(null as TIMESTAMP) as c7,
+ | CAST(null as DATE) as c8,
+ | CAST(null as STRING) as c9,
+ | CAST(null as VARCHAR(10)) as c10
+ |FROM orc_temp_table limit 1
+ """.stripMargin)
+
+ df.write.format(ORC_FILE_FORMAT).saveAsTable("orcNullValues")
+
+ checkAnswer(
+ sql("SELECT * FROM orcNullValues"),
+ Row.fromSeq(Seq.fill(11)(null)))
+
+ sql("DROP TABLE IF EXISTS orcNullValues")
+ }
+
+ test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") {
+ val options = new OrcOptions(Map("Orc.Compress" -> "NONE"), spark.sessionState.conf)
+ assert(options.compressionCodecClassName == "NONE")
+ }
+}
+
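+// Concrete suite that registers the `normal_orc_source` and `normal_orc_as_source`
+// temporary views on top of the ORC data source and adds SearchArgument conversion tests.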
+class OrcSourceSuite extends OrcSuite {
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+
+ spark.sql(
+ s"""CREATE TEMPORARY VIEW normal_orc_source
+ |USING $ORC_FILE_FORMAT
+ |OPTIONS (
+ | PATH '${new File(orcTableAsDir.getAbsolutePath).toURI}'
+ |)
+ """.stripMargin)
+
+ spark.sql(
+ s"""CREATE TEMPORARY VIEW normal_orc_as_source
+ |USING $ORC_FILE_FORMAT
+ |OPTIONS (
+ | PATH '${new File(orcTableAsDir.getAbsolutePath).toURI}'
+ |)
+ """.stripMargin)
+ }
+
+ test("SPARK-12218 Converting conjunctions into ORC SearchArguments") {
+ Seq("false", "true").foreach { value =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> value) {
+ // The `LessThan` should be converted while the `StringContains` shouldn't
+ val schema = new StructType(
+ Array(
+ StructField("a", IntegerType, nullable = true),
+ StructField("b", StringType, nullable = true)))
+ assertResult("leaf-0 = (LESS_THAN a 10), expr = leaf-0") {
+ OrcFilters.createFilter(schema, Array(
+ LessThan("a", 10),
+ StringContains("b", "prefix")
+ )).get.toString
+ }
+
+ // The `LessThan` should be converted while the whole inner `And` shouldn't
+ assertResult("leaf-0 = (LESS_THAN a 10), expr = leaf-0") {
+ OrcFilters.createFilter(schema, Array(
+ LessThan("a", 10),
+ Not(And(
+ GreaterThan("a", 1),
+ StringContains("b", "prefix")
+ ))
+ )).get.toString
+ }
+ }
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala
new file mode 100644
index 000000000000..b17bd0f50d0e
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc
+
+import java.io.File
+
+import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
+
+/**
+ * This helper trait is a port of org.apache.spark.sql.hive.orc.OrcTest.
+ */
+private[sql] trait OrcTest extends SQLTestUtils with SharedSQLContext {
+ import testImplicits._
+
+ val ORC_FILE_FORMAT = classOf[OrcFileFormat].getCanonicalName
+
+ /**
+ * Writes `data` to an ORC file, which is then passed to `f` and will be deleted after `f`
+ * returns.
+ */
+ protected def withOrcFile[T <: Product: ClassTag: TypeTag]
+ (data: Seq[T])
+ (f: String => Unit): Unit = {
+ withTempPath { file =>
+ sparkContext.parallelize(data).toDF().write.format(ORC_FILE_FORMAT)
+ .save(file.getCanonicalPath)
+ f(file.getCanonicalPath)
+ }
+ }
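+
+ // Illustrative usage (a sketch, not exercised by this suite):
+ //   withOrcFile((1 to 3).map(i => (i, s"val_$i"))) { path =>
+ //     assert(spark.read.format(ORC_FILE_FORMAT).load(path).count() == 3)
+ //   }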
+
+ /**
+ * Writes `data` to an ORC file and reads it back as a `DataFrame`,
+ * which is then passed to `f`. The ORC file will be deleted after `f` returns.
+ */
+ protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag]
+ (data: Seq[T])
+ (f: DataFrame => Unit): Unit = {
+ withOrcFile(data)(path => f(spark.read.format(ORC_FILE_FORMAT).load(path)))
+ }
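+
+ // Illustrative usage (a sketch): like `withOrcFile`, but hands back the loaded `DataFrame`.
+ //   withOrcDataFrame(Seq((1, "a"), (2, "b"))) { df => assert(df.count() == 2) }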
+
+ /**
+ * Writes `data` to an ORC file, reads it back as a `DataFrame`, and registers it as a
+ * temporary view named `tableName`, then calls `f`. The temporary view and the ORC file
+ * are dropped/deleted after `f` returns.
+ */
+ protected def withOrcTable[T <: Product: ClassTag: TypeTag]
+ (data: Seq[T], tableName: String)
+ (f: => Unit): Unit = {
+ withOrcDataFrame(data) { df =>
+ df.createOrReplaceTempView(tableName)
+ withTempView(tableName)(f)
+ }
+ }
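+
+ // Illustrative usage (a sketch); the tuple columns are exposed as `_1` and `_2`:
+ //   withOrcTable(Seq((1, "a"), (2, "b")), "t") {
+ //     assert(spark.sql("SELECT _2 FROM t WHERE _1 = 1").collect().toSeq == Seq(Row("a")))
+ //   }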
+
+ protected def makeOrcFile[T <: Product: ClassTag: TypeTag](
+ data: Seq[T], path: File): Unit = {
+ data.toDF().write.mode(SaveMode.Overwrite).format(ORC_FILE_FORMAT).save(path.getCanonicalPath)
+ }
+
+ protected def makeOrcFile[T <: Product: ClassTag: TypeTag](
+ df: DataFrame, path: File): Unit = {
+ df.write.mode(SaveMode.Overwrite).format(ORC_FILE_FORMAT).save(path.getCanonicalPath)
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala
new file mode 100644
index 000000000000..450898b16c27
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala
@@ -0,0 +1,415 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.orc
+
+import java.io.File
+
+import scala.util.Try
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.IntWritable
+import org.apache.hadoop.mapreduce.{JobID, TaskAttemptID, TaskID, TaskType}
+import org.apache.hadoop.mapreduce.lib.input.FileSplit
+import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
+import org.apache.orc.OrcFile
+import org.apache.orc.mapred.OrcStruct
+import org.apache.orc.mapreduce.OrcInputFormat
+import org.apache.orc.storage.ql.exec.vector.{BytesColumnVector, LongColumnVector}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.{Benchmark, Utils}
+
+
+/**
+ * Benchmark to measure ORC read performance.
+ *
+ * This is in the `sql/hive` module in order to compare the `sql/core` and `sql/hive` ORC data sources.
+ * Once the `sql/hive` ORC data source is removed, this benchmark should be moved into the
+ * `sql/core` module like the other ORC test suites.
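+ *
+ * A typical way to run it (illustrative; the exact command depends on the local build setup):
+ *   build/sbt "hive/test:runMain org.apache.spark.sql.hive.orc.OrcReadBenchmark"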
+ */
+object OrcReadBenchmark {
+ val conf = new SparkConf()
+ conf.set("orc.compression", "snappy")
+
+ private val spark = SparkSession.builder()
+ .master("local[1]")
+ .appName("OrcReadBenchmark")
+ .config(conf)
+ .getOrCreate()
+
+ // Set default configs. Individual cases will change them if necessary.
+ spark.conf.set(SQLConf.ORC_VECTORIZED_READER_ENABLED.key, "true")
+ spark.conf.set(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key, "true")
+ spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true")
+
+ def withTempPath(f: File => Unit): Unit = {
+ val path = Utils.createTempDir()
+ path.delete()
+ try f(path) finally Utils.deleteRecursively(path)
+ }
+
+ def withTempTable(tableNames: String*)(f: => Unit): Unit = {
+ try f finally tableNames.foreach(spark.catalog.dropTempView)
+ }
+
+ def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
+ val (keys, values) = pairs.unzip
+ val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption)
+ (keys, values).zipped.foreach(spark.conf.set)
+ try f finally {
+ keys.zip(currentValues).foreach {
+ case (key, Some(value)) => spark.conf.set(key, value)
+ case (key, None) => spark.conf.unset(key)
+ }
+ }
+ }
+
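+ // Fully-qualified data source class names passed to `spark.read.format(...)`, letting each
+ // benchmark case explicitly pick either the new `sql/core` ORC reader or the Hive-based one.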
+ private val SQL_ORC_FILE_FORMAT = "org.apache.spark.sql.execution.datasources.orc.OrcFileFormat"
+ private val HIVE_ORC_FILE_FORMAT = "org.apache.spark.sql.hive.orc.OrcFileFormat"
+
+ // scalastyle:off line.size.limit
+ def intScanBenchmark(values: Int): Unit = {
+ // Benchmark cases that run through Spark SQL.
+ val sqlBenchmark = new Benchmark("SQL Single Int Column Scan", values)
+ // Benchmark cases that drive the ORC reader component directly.
+ val orcReaderBenchmark = new Benchmark("ORC Reader Single Int Column Scan", values)
+
+ withTempPath { dir =>
+ withTempTable("t1", "coreOrcTable", "hiveOrcTable") {
+ spark.range(values).createOrReplaceTempView("t1")
+ spark.sql("select cast(id as INT) as id from t1")
+ .write.orc(dir.getCanonicalPath)
+ spark.read.format(SQL_ORC_FILE_FORMAT)
+ .load(dir.getCanonicalPath).createOrReplaceTempView("coreOrcTable")
+ spark.read.format(HIVE_ORC_FILE_FORMAT)
+ .load(dir.getCanonicalPath).createOrReplaceTempView("hiveOrcTable")
+
+ sqlBenchmark.addCase("SQL ORC Vectorized") { _ =>
+ spark.sql("select sum(id) from coreOrcTable").collect
+ }
+
+ sqlBenchmark.addCase("SQL ORC MR") { _ =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") {
+ spark.sql("select sum(id) from coreOrcTable").collect
+ }
+ }
+
+ sqlBenchmark.addCase("HIVE ORC MR") { _ =>
+ spark.sql("select sum(id) from hiveOrcTable").collect
+ }
+
+ val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray
+ // Drives the ORC reader in batch mode directly.
+ val conf = new Configuration
+ orcReaderBenchmark.addCase("OrcReader Vectorized") { _ =>
+ var sum = 0L
+ files.map(_.asInstanceOf[String]).foreach { p =>
+ val reader = OrcFile.createReader(new Path(p), OrcFile.readerOptions(conf))
+ val rows = reader.rows()
+ try {
+ val batch = reader.getSchema.createRowBatch
+ val longColumnVector = batch.cols(0).asInstanceOf[LongColumnVector]
+
+ while (rows.nextBatch(batch)) {
+ for (r <- 0 until batch.size) {
+ if (longColumnVector.noNulls || !longColumnVector.isNull(r)) {
+ val record = longColumnVector.vector(r)
+ sum += record
+ }
+ }
+ }
+ } finally {
+ rows.close()
+ }
+ }
+ }
+
+ orcReaderBenchmark.addCase("OrcReader") { _ =>
+ var sum = 0L
+ files.map(_.asInstanceOf[String]).foreach { p =>
+ val oif = new OrcInputFormat[OrcStruct]
+ val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
+ val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
+ val fileSplit = new FileSplit(new Path(p), 0L, Long.MaxValue, new Array[String](0))
+ val reader = oif.createRecordReader(fileSplit, hadoopAttemptContext)
+ try {
+ while (reader.nextKeyValue()) {
+ sum += reader.getCurrentValue.getFieldValue(0).asInstanceOf[IntWritable].get
+ }
+ } finally {
+ reader.close()
+ }
+ }
+ }
+
+ /*
+ Java HotSpot(TM) 64-Bit Server VM 1.8.0_131-b11 on Mac OS X 10.12.4
+ Intel(R) Core(TM) i7-3615QM CPU @ 2.30GHz
+
+ SQL Single Int Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ ------------------------------------------------------------------------------------------------
+ SQL ORC Vectorized 170 / 194 92.5 10.8 1.0X
+ SQL ORC MR 388 / 396 40.5 24.7 0.4X
+ HIVE ORC MR 488 / 496 32.3 31.0 0.3X
+ */
+ sqlBenchmark.run()
+
+ /*
+ Java HotSpot(TM) 64-Bit Server VM 1.8.0_131-b11 on Mac OS X 10.12.4
+ Intel(R) Core(TM) i7-3615QM CPU @ 2.30GHz
+
+ ORC Reader Single Int Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ ------------------------------------------------------------------------------------------------
+ OrcReader Vectorized 119 / 124 132.7 7.5 1.0X
+ OrcReader 369 / 377 42.6 23.5 0.3X
+ */
+ orcReaderBenchmark.run()
+ }
+ }
+ }
+
+ def intStringScanBenchmark(values: Int): Unit = {
+ val benchmark = new Benchmark("Int and String Scan", values)
+
+ withTempPath { dir =>
+ withTempTable("t1", "coreOrcTable", "hiveOrcTable") {
+ spark.range(values).createOrReplaceTempView("t1")
+ spark.sql("select cast(id as INT) as c1, cast(id as STRING) as c2 from t1")
+ .write.orc(dir.getCanonicalPath)
+ spark.read.format(SQL_ORC_FILE_FORMAT)
+ .load(dir.getCanonicalPath).createOrReplaceTempView("coreOrcTable")
+ spark.read.format(HIVE_ORC_FILE_FORMAT)
+ .load(dir.getCanonicalPath).createOrReplaceTempView("hiveOrcTable")
+
+ benchmark.addCase("SQL ORC Vectorized") { _ =>
+ spark.sql("select sum(c1), sum(length(c2)) from coreOrcTable").collect
+ }
+
+ benchmark.addCase("SQL ORC MR") { _ =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") {
+ spark.sql("select sum(c1), sum(length(c2)) from coreOrcTable").collect
+ }
+ }
+
+ benchmark.addCase("HIVE ORC MR") { _ =>
+ spark.sql("select sum(c1), sum(length(c2)) from hiveOrcTable").collect
+ }
+
+ /*
+ Java HotSpot(TM) 64-Bit Server VM 1.8.0_131-b11 on Mac OS X 10.12.4
+ Intel(R) Core(TM) i7-3615QM CPU @ 2.30GHz
+
+ Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ ------------------------------------------------------------------------------------------------
+ SQL ORC Vectorized 310 / 373 33.8 29.5 1.0X
+ SQL ORC MR 580 / 617 18.1 55.4 0.5X
+ HIVE ORC MR 881 / 938 11.9 84.0 0.4X
+ */
+ benchmark.run()
+ }
+ }
+ }
+
+ def stringDictionaryScanBenchmark(values: Int): Unit = {
+ val benchmark = new Benchmark("String Dictionary", values)
+
+ withTempPath { dir =>
+ withTempTable("t1", "coreOrcTable", "hiveOrcTable") {
+ spark.range(values).createOrReplaceTempView("t1")
+ spark.sql("select cast((id % 200) + 10000 as STRING) as c1 from t1")
+ .write.orc(dir.getCanonicalPath)
+ spark.read.format(SQL_ORC_FILE_FORMAT)
+ .load(dir.getCanonicalPath).createOrReplaceTempView("coreOrcTable")
+ spark.read.format(HIVE_ORC_FILE_FORMAT)
+ .load(dir.getCanonicalPath).createOrReplaceTempView("hiveOrcTable")
+
+ benchmark.addCase("SQL ORC Vectorized") { _ =>
+ spark.sql("select sum(length(c1)) from coreOrcTable").collect
+ }
+
+ benchmark.addCase("SQL ORC MR") { _ =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") {
+ spark.sql("select sum(length(c1)) from coreOrcTable").collect
+ }
+ }
+
+ benchmark.addCase("HIVE ORC MR") { _ =>
+ spark.sql("select sum(length(c1)) from hiveOrcTable").collect
+ }
+
+ /*
+ Java HotSpot(TM) 64-Bit Server VM 1.8.0_131-b11 on Mac OS X 10.12.4
+ Intel(R) Core(TM) i7-3615QM CPU @ 2.30GHz
+
+ String Dictionary: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ ------------------------------------------------------------------------------------------------
+ SQL ORC Vectorized 165 / 173 63.7 15.7 1.0X
+ SQL ORC MR 401 / 406 26.2 38.2 0.4X
+ HIVE ORC MR 620 / 629 16.9 59.1 0.3X
+ */
+ benchmark.run()
+ }
+ }
+ }
+
+ def partitionTableScanBenchmark(values: Int): Unit = {
+ val benchmark = new Benchmark("Partitioned Table", values)
+
+ withTempPath { dir =>
+ withTempTable("t1", "coreOrcTable", "hiveOrcTable") {
+ spark.range(values).createOrReplaceTempView("t1")
+ spark.sql("select id % 2 as p, cast(id as INT) as id from t1")
+ .write.partitionBy("p").orc(dir.getCanonicalPath)
+ spark.read.format(SQL_ORC_FILE_FORMAT)
+ .load(dir.getCanonicalPath).createOrReplaceTempView("coreOrcTable")
+ spark.read.format(HIVE_ORC_FILE_FORMAT)
+ .load(dir.getCanonicalPath).createOrReplaceTempView("hiveOrcTable")
+
+ benchmark.addCase("SQL Read data column") { _ =>
+ spark.sql("select sum(id) from coreOrcTable").collect
+ }
+
+ benchmark.addCase("SQL Read partition column") { _ =>
+ spark.sql("select sum(p) from coreOrcTable").collect
+ }
+
+ benchmark.addCase("SQL Read both columns") { _ =>
+ spark.sql("select sum(p), sum(id) from coreOrcTable").collect
+ }
+
+ benchmark.addCase("HIVE Read data column") { _ =>
+ spark.sql("select sum(id) from hiveOrcTable").collect
+ }
+
+ benchmark.addCase("HIVE Read partition column") { _ =>
+ spark.sql("select sum(p) from hiveOrcTable").collect
+ }
+
+ benchmark.addCase("HIVE Read both columns") { _ =>
+ spark.sql("select sum(p), sum(id) from hiveOrcTable").collect
+ }
+ /*
+ Java HotSpot(TM) 64-Bit Server VM 1.8.0_131-b11 on Mac OS X 10.12.4
+ Intel(R) Core(TM) i7-3615QM CPU @ 2.30GHz
+
+ Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ ------------------------------------------------------------------------------------------------
+ SQL Read data column 188 / 227 83.6 12.0 1.0X
+ SQL Read partition column 98 / 109 161.2 6.2 1.9X
+ SQL Read both columns 193 / 227 81.5 12.3 1.0X
+ HIVE Read data column 530 / 530 29.7 33.7 0.4X
+ HIVE Read partition column 420 / 423 37.4 26.7 0.4X
+ HIVE Read both columns 558 / 562 28.2 35.5 0.3X
+ */
+ benchmark.run()
+ }
+ }
+ }
+
+ def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = {
+ withTempPath { dir =>
+ withTempTable("t1", "coreOrcTable", "hiveOrcTable") {
+ spark.range(values).createOrReplaceTempView("t1")
+ spark.sql(s"select IF(rand(1) < $fractionOfNulls, NULL, cast(id as STRING)) as c1, " +
+ s"IF(rand(2) < $fractionOfNulls, NULL, cast(id as STRING)) as c2 from t1")
+ .write.orc(dir.getCanonicalPath)
+ spark.read.format(SQL_ORC_FILE_FORMAT)
+ .load(dir.getCanonicalPath).createOrReplaceTempView("coreOrcTable")
+ spark.read.format(HIVE_ORC_FILE_FORMAT)
+ .load(dir.getCanonicalPath).createOrReplaceTempView("hiveOrcTable")
+
+ val benchmark = new Benchmark("String with Nulls Scan", values)
+
+ benchmark.addCase(s"SQL ORC Vectorized ($fractionOfNulls%)") { iter =>
+ spark.sql("select sum(length(c2)) from coreOrcTable where c1 is " +
+ "not NULL and c2 is not NULL").collect()
+ }
+
+ benchmark.addCase(s"HIVE ORC ($fractionOfNulls%)") { iter =>
+ spark.sql("select sum(length(c2)) from hiveOrcTable where c1 is " +
+ "not NULL and c2 is not NULL").collect()
+ }
+
+ val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray
+ // Drives the ORC reader in batch mode directly.
+ val conf = new Configuration
+ benchmark.addCase("OrcReader Vectorized") { _ =>
+ var sum = 0L
+ files.map(_.asInstanceOf[String]).foreach { p =>
+ val reader = OrcFile.createReader(new Path(p), OrcFile.readerOptions(conf))
+ val rows = reader.rows()
+ try {
+ val batch = reader.getSchema.createRowBatch
+ val col = batch.cols(0).asInstanceOf[BytesColumnVector]
+
+ while (rows.nextBatch(batch)) {
+ for (r <- 0 until batch.size) {
+ // Skip null entries before materializing the value; their byte buffers may be unset.
+ if (!col.isNull(r)) {
+ sum += UTF8String.fromBytes(col.vector(r), col.start(r), col.length(r)).numBytes()
+ }
+ }
+ }
+ } finally {
+ rows.close()
+ }
+ }
+ }
+
+ /*
+ Java HotSpot(TM) 64-Bit Server VM 1.8.0_131-b11 on Mac OS X 10.12.4
+ Intel(R) Core(TM) i7-3615QM CPU @ 2.30GHz
+
+ String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ ------------------------------------------------------------------------------------------------
+ SQL ORC Vectorized (0.0%) 501 / 596 20.9 47.7 1.0X
+ HIVE ORC (0.0%) 1225 / 1322 8.6 116.8 0.4X
+ OrcReader Vectorized 757 / 761 13.9 72.2 0.7X
+
+ String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ ------------------------------------------------------------------------------------------------
+ SQL ORC Vectorized (0.5%) 415 / 453 25.3 39.6 1.0X
+ HIVE ORC (0.5%) 884 / 940 11.9 84.3 0.5X
+ OrcReader Vectorized 905 / 940 11.6 86.3 0.5X
+
+ String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ ------------------------------------------------------------------------------------------------
+ SQL ORC Vectorized (0.95%) 221 / 239 47.4 21.1 1.0X
+ HIVE ORC (0.95%) 517 / 527 20.3 49.3 0.4X
+ OrcReader Vectorized 358 / 365 29.3 34.1 0.6X
+ */
+
+ benchmark.run()
+ }
+ }
+ }
+ // scalastyle:on line.size.limit
+
+ def main(args: Array[String]): Unit = {
+ intScanBenchmark(1024 * 1024 * 15)
+ intStringScanBenchmark(1024 * 1024 * 10)
+ stringDictionaryScanBenchmark(1024 * 1024 * 10)
+ partitionTableScanBenchmark(1024 * 1024 * 15)
+ for (fractionOfNulls <- List(0.0, 0.50, 0.95)) {
+ stringWithNullsScanBenchmark(1024 * 1024 * 10, fractionOfNulls)
+ }
+ }
+}