Skip to content

Commit 160e924

Browse files
committed
Add test for VectorizedSparkOrcNewRecordReader.
1 parent 55bb19f commit 160e924

File tree

6 files changed

+223
-28
lines changed

6 files changed

+223
-28
lines changed

sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -980,9 +980,9 @@ public ColumnVector getDictionaryIds() {
980980
return dictionaryIds;
981981
}
982982

983-
public ColumnVector() {
983+
public ColumnVector(DataType type) {
984984
this.capacity = 0;
985-
this.type = null;
985+
this.type = type;
986986
this.childColumns = null;
987987
this.resultArray = null;
988988
this.resultStruct = null;

sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/OrcColumnVector.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector;
2424
import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;
2525

26+
import org.apache.spark.sql.types.DataType;
2627
import org.apache.spark.sql.types.Decimal;
2728
import org.apache.spark.unsafe.types.UTF8String;
2829

@@ -35,7 +36,8 @@
3536
public class OrcColumnVector extends org.apache.spark.sql.execution.vectorized.ColumnVector {
3637
private ColumnVector col;
3738

38-
public OrcColumnVector(ColumnVector col) {
39+
public OrcColumnVector(ColumnVector col, DataType type) {
40+
super(type);
3941
this.col = col;
4042
}
4143

sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/VectorizedSparkOrcNewRecordReader.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ public VectorizedSparkOrcNewRecordReader(
7777
Configuration conf,
7878
FileSplit fileSplit,
7979
List<Integer> columnIDs,
80+
StructType requiredSchema,
8081
StructType partitionColumns,
8182
InternalRow partitionValues) throws IOException {
8283
List<OrcProto.Type> types = file.getTypes();
@@ -93,7 +94,7 @@ public VectorizedSparkOrcNewRecordReader(
9394
for (int i = 0; i < columnIDs.size(); i++) {
9495
org.apache.hadoop.hive.ql.exec.vector.ColumnVector col =
9596
this.hiveBatch.cols[columnIDs.get(i)];
96-
this.orcColumns[i] = new OrcColumnVector(col);
97+
this.orcColumns[i] = new OrcColumnVector(col, requiredSchema.fields()[i].dataType());
9798
}
9899

99100
// Allocate Spark ColumnVectors for partition columns.

sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -181,9 +181,14 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable
181181
if (enableVectorizedReader) {
182182
val columnIDs =
183183
requiredSchema.map(a => physicalSchema.fieldIndex(a.name): Integer).sorted.asJava
184-
val orcRecordReader =
185-
new VectorizedSparkOrcNewRecordReader(
186-
orcReader, conf, fileSplit, columnIDs, partitionSchema, file.partitionValues)
184+
val orcRecordReader = new VectorizedSparkOrcNewRecordReader(
185+
orcReader,
186+
conf,
187+
fileSplit,
188+
columnIDs,
189+
requiredSchema,
190+
partitionSchema,
191+
file.partitionValues)
187192

188193
if (returningBatch) {
189194
orcRecordReader.enableReturningBatches()
@@ -226,11 +231,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable
226231
* Returns whether the reader will return the rows as batch or not.
227232
*/
228233
override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = {
229-
val conf = sparkSession.sessionState.conf
230-
conf.orcVectorizedReaderEnabled && conf.wholeStageEnabled &&
231-
schema.length <= conf.wholeStageMaxNumFields &&
232-
schema.forall(f => f.dataType.isInstanceOf[AtomicType] &&
233-
!f.dataType.isInstanceOf[DateType] && !f.dataType.isInstanceOf[TimestampType])
234+
OrcRelation.supportBatch(sparkSession, schema)
234235
}
235236
}
236237

@@ -374,4 +375,15 @@ private[orc] object OrcRelation extends HiveInspectors {
374375
val (sortedIDs, sortedNames) = ids.zip(requestedSchema.fieldNames).sorted.unzip
375376
HiveShim.appendReadColumns(conf, sortedIDs, sortedNames)
376377
}
378+
379+
/**
380+
* Returns whether the reader will return the rows as batch or not.
381+
*/
382+
def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = {
383+
val conf = sparkSession.sessionState.conf
384+
conf.orcVectorizedReaderEnabled && conf.wholeStageEnabled &&
385+
schema.length <= conf.wholeStageMaxNumFields &&
386+
schema.forall(f => f.dataType.isInstanceOf[AtomicType] &&
387+
!f.dataType.isInstanceOf[DateType] && !f.dataType.isInstanceOf[TimestampType])
388+
}
377389
}

sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/vectorized/OrcColumnVectorSuite.scala

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ class OrcColumnVectorSuite extends SparkFunSuite {
8080
}
8181
}
8282

83-
private def testLongColumnVector[T](num: Int)
83+
private def testLongColumnVector[T](num: Int, dt: DataType)
8484
(genExpected: (Seq[Long] => Seq[T]))
8585
(genActual: (OrcColumnVector, Int) => Seq[T]): Unit = {
8686
val seed = System.currentTimeMillis()
@@ -96,12 +96,12 @@ class OrcColumnVectorSuite extends SparkFunSuite {
9696

9797
val expected = genExpected(data)
9898

99-
val orcCol = new OrcColumnVector(lv)
99+
val orcCol = new OrcColumnVector(lv, dt)
100100
val actual = genActual(orcCol, num)
101101
assert(actual === expected)
102102
}
103103

104-
private def testDoubleColumnVector[T](num: Int)
104+
private def testDoubleColumnVector[T](num: Int, dt: DataType)
105105
(genExpected: (Seq[Double] => Seq[T]))
106106
(genActual: (OrcColumnVector, Int) => Seq[T]): Unit = {
107107
val seed = System.currentTimeMillis()
@@ -117,12 +117,12 @@ class OrcColumnVectorSuite extends SparkFunSuite {
117117

118118
val expected = genExpected(data)
119119

120-
val orcCol = new OrcColumnVector(lv)
120+
val orcCol = new OrcColumnVector(lv, dt)
121121
val actual = genActual(orcCol, num)
122122
assert(actual === expected)
123123
}
124124

125-
private def testBytesColumnVector[T](num: Int)
125+
private def testBytesColumnVector[T](num: Int, dt: DataType)
126126
(genExpected: (Seq[Seq[Byte]] => Seq[T]))
127127
(genActual: (OrcColumnVector, Int) => Seq[T]): Unit = {
128128
val seed = System.currentTimeMillis()
@@ -139,7 +139,7 @@ class OrcColumnVectorSuite extends SparkFunSuite {
139139

140140
val expected = genExpected(data)
141141

142-
val orcCol = new OrcColumnVector(lv)
142+
val orcCol = new OrcColumnVector(lv, dt)
143143
val actual = genActual(orcCol, num)
144144
actual.zip(expected).foreach { case (a, e) =>
145145
assert(a === e)
@@ -168,7 +168,7 @@ class OrcColumnVectorSuite extends SparkFunSuite {
168168

169169
val expected = genExpected(data)
170170

171-
val orcCol = new OrcColumnVector(lv)
171+
val orcCol = new OrcColumnVector(lv, decimalType)
172172
val actual = genActual(orcCol, num, decimalType.precision, decimalType.scale)
173173
actual.zip(expected).foreach { case (a, e) =>
174174
assert(a.compareTo(e) == 0)
@@ -183,7 +183,7 @@ class OrcColumnVectorSuite extends SparkFunSuite {
183183
col.getBoolean(rowId)
184184
}
185185
}
186-
testLongColumnVector(100)(genExpected)(genActual)
186+
testLongColumnVector(100, BooleanType)(genExpected)(genActual)
187187
}
188188

189189
test("Hive LongColumnVector: Int") {
@@ -193,7 +193,7 @@ class OrcColumnVectorSuite extends SparkFunSuite {
193193
col.getInt(rowId)
194194
}
195195
}
196-
testLongColumnVector(100)(genExpected)(genActual)
196+
testLongColumnVector(100, IntegerType)(genExpected)(genActual)
197197
}
198198

199199
test("Hive LongColumnVector: Byte") {
@@ -203,7 +203,7 @@ class OrcColumnVectorSuite extends SparkFunSuite {
203203
col.getByte(rowId)
204204
}
205205
}
206-
testLongColumnVector(100)(genExpected)(genActual)
206+
testLongColumnVector(100, ByteType)(genExpected)(genActual)
207207
}
208208

209209
test("Hive LongColumnVector: Short") {
@@ -213,7 +213,7 @@ class OrcColumnVectorSuite extends SparkFunSuite {
213213
col.getShort(rowId)
214214
}
215215
}
216-
testLongColumnVector(100)(genExpected)(genActual)
216+
testLongColumnVector(100, ShortType)(genExpected)(genActual)
217217
}
218218

219219
test("Hive LongColumnVector: Long") {
@@ -223,7 +223,7 @@ class OrcColumnVectorSuite extends SparkFunSuite {
223223
col.getLong(rowId)
224224
}
225225
}
226-
testLongColumnVector(100)(genExpected)(genActual)
226+
testLongColumnVector(100, LongType)(genExpected)(genActual)
227227
}
228228

229229
test("Hive DoubleColumnVector: Float") {
@@ -233,7 +233,7 @@ class OrcColumnVectorSuite extends SparkFunSuite {
233233
col.getFloat(rowId)
234234
}
235235
}
236-
testDoubleColumnVector(100)(genExpected)(genActual)
236+
testDoubleColumnVector(100, FloatType)(genExpected)(genActual)
237237
}
238238

239239
test("Hive DoubleColumnVector: Double") {
@@ -243,7 +243,7 @@ class OrcColumnVectorSuite extends SparkFunSuite {
243243
col.getDouble(rowId)
244244
}
245245
}
246-
testDoubleColumnVector(100)(genExpected)(genActual)
246+
testDoubleColumnVector(100, DoubleType)(genExpected)(genActual)
247247
}
248248

249249
test("Hive BytesColumnVector: Binary") {
@@ -253,7 +253,7 @@ class OrcColumnVectorSuite extends SparkFunSuite {
253253
col.getBinary(rowId).toSeq
254254
}
255255
}
256-
testBytesColumnVector(100)(genExpected)(genActual)
256+
testBytesColumnVector(100, BinaryType)(genExpected)(genActual)
257257
}
258258

259259
test("Hive BytesColumnVector: String") {
@@ -266,7 +266,7 @@ class OrcColumnVectorSuite extends SparkFunSuite {
266266
col.getUTF8String(rowId)
267267
}
268268
}
269-
testBytesColumnVector(100)(genExpected)(genActual)
269+
testBytesColumnVector(100, StringType)(genExpected)(genActual)
270270
}
271271

272272
test("Hive DecimalColumnVector") {

0 commit comments

Comments
 (0)