Commit 148a84b

kiszk authored and davies committed
[SPARK-17912] [SQL] Refactor code generation to get data for ColumnVector/ColumnarBatch
## What changes were proposed in this pull request? This PR refactors the code generation part to get data from `ColumnarVector` and `ColumnarBatch` by using a trait `ColumnarBatchScan` for ease of reuse. This is because this part will be reused by several components (e.g. parquet reader, Dataset.cache, and others) since `ColumnarBatch` will be first citizen. This PR is a part of #15219. In advance, this PR makes the code generation for `ColumnarVector` and `ColumnarBatch` reuseable as a trait. In general, this is very useful for other components from the reuseability view, too. ## How was this patch tested? tested existing test suites Author: Kazuaki Ishizaki <[email protected]> Closes #15467 from kiszk/columnarrefactor.
1 parent 63d8390 commit 148a84b
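The reuse claim in the commit message is easiest to see with a sketch: any leaf physical operator whose input RDD yields `ColumnarBatch` objects can mix in the trait and inherit its whole-stage codegen. The operator below (`MyColumnarScanExec` and its `batchRDD` parameter) is hypothetical and not part of this patch; it is a minimal sketch assuming the APIs of this Spark snapshot.

```scala
package org.apache.spark.sql.execution  // ColumnarBatchScan is private[sql]

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute

// Hypothetical operator reusing ColumnarBatchScan (not in this patch).
case class MyColumnarScanExec(
    output: Seq[Attribute],
    batchRDD: RDD[InternalRow])  // rows are really ColumnarBatch instances
  extends LeafExecNode with ColumnarBatchScan {

  // The trait's doProduce() reads inputs[0] as an iterator of ColumnarBatch,
  // so the operator only has to hand over that RDD.
  override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(batchRDD)

  // Non-codegen fallback, out of scope for this sketch.
  override protected def doExecute(): RDD[InternalRow] =
    throw new UnsupportedOperationException("requires whole-stage codegen")
}
```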

File tree

sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala

2 files changed: +135 −84 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala

Lines changed: 133 additions & 0 deletions

@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
+import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnVector}
+import org.apache.spark.sql.types.DataType
+
+
+/**
+ * Helper trait for abstracting scan functionality using
+ * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]es.
+ */
+private[sql] trait ColumnarBatchScan extends CodegenSupport {
+
+  val inMemoryTableScan: InMemoryTableScanExec = null
+
+  override lazy val metrics = Map(
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
+    "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time"))
+
+  /**
+   * Generate [[ColumnVector]] expressions for our parent to consume as rows.
+   * This is called once per [[ColumnarBatch]].
+   */
+  private def genCodeColumnVector(
+      ctx: CodegenContext,
+      columnVar: String,
+      ordinal: String,
+      dataType: DataType,
+      nullable: Boolean): ExprCode = {
+    val javaType = ctx.javaType(dataType)
+    val value = ctx.getValue(columnVar, dataType, ordinal)
+    val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" }
+    val valueVar = ctx.freshName("value")
+    val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]"
+    val code = s"${ctx.registerComment(str)}\n" + (if (nullable) {
+      s"""
+        boolean $isNullVar = $columnVar.isNullAt($ordinal);
+        $javaType $valueVar = $isNullVar ? ${ctx.defaultValue(dataType)} : ($value);
+      """
+    } else {
+      s"$javaType $valueVar = $value;"
+    }).trim
+    ExprCode(code, isNullVar, valueVar)
+  }
+
+  /**
+   * Produce code to process the input iterator as [[ColumnarBatch]]es.
+   * This produces an [[UnsafeRow]] for each row in each batch.
+   */
+  // TODO: return ColumnarBatch.Rows instead
+  override protected def doProduce(ctx: CodegenContext): String = {
+    val input = ctx.freshName("input")
+    // PhysicalRDD always just has one input
+    ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
+
+    // metrics
+    val numOutputRows = metricTerm(ctx, "numOutputRows")
+    val scanTimeMetric = metricTerm(ctx, "scanTime")
+    val scanTimeTotalNs = ctx.freshName("scanTime")
+    ctx.addMutableState("long", scanTimeTotalNs, s"$scanTimeTotalNs = 0;")
+
+    val columnarBatchClz = "org.apache.spark.sql.execution.vectorized.ColumnarBatch"
+    val batch = ctx.freshName("batch")
+    ctx.addMutableState(columnarBatchClz, batch, s"$batch = null;")
+
+    val columnVectorClz = "org.apache.spark.sql.execution.vectorized.ColumnVector"
+    val idx = ctx.freshName("batchIdx")
+    ctx.addMutableState("int", idx, s"$idx = 0;")
+    val colVars = output.indices.map(i => ctx.freshName("colInstance" + i))
+    val columnAssigns = colVars.zipWithIndex.map { case (name, i) =>
+      ctx.addMutableState(columnVectorClz, name, s"$name = null;")
+      s"$name = $batch.column($i);"
+    }
+
+    val nextBatch = ctx.freshName("nextBatch")
+    ctx.addNewFunction(nextBatch,
+      s"""
+         |private void $nextBatch() throws java.io.IOException {
+         |  long getBatchStart = System.nanoTime();
+         |  if ($input.hasNext()) {
+         |    $batch = ($columnarBatchClz)$input.next();
+         |    $numOutputRows.add($batch.numRows());
+         |    $idx = 0;
+         |    ${columnAssigns.mkString("", "\n", "\n")}
+         |  }
+         |  $scanTimeTotalNs += System.nanoTime() - getBatchStart;
+         |}""".stripMargin)
+
+    ctx.currentVars = null
+    val rowidx = ctx.freshName("rowIdx")
+    val columnsBatchInput = (output zip colVars).map { case (attr, colVar) =>
+      genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable)
+    }
+    s"""
+       |if ($batch == null) {
+       |  $nextBatch();
+       |}
+       |while ($batch != null) {
+       |  int numRows = $batch.numRows();
+       |  while ($idx < numRows) {
+       |    int $rowidx = $idx++;
+       |    ${consume(ctx, columnsBatchInput).trim}
+       |    if (shouldStop()) return;
+       |  }
+       |  $batch = null;
+       |  $nextBatch();
+       |}
+       |$scanTimeMetric.add($scanTimeTotalNs / (1000 * 1000));
+       |$scanTimeTotalNs = 0;
+     """.stripMargin
+  }
+
+}
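For readers who do not want to mentally expand the string interpolations: below is a hand-written Scala analogue of the Java source that `doProduce` and `genCodeColumnVector` assemble for a single nullable int column. `scanBatches` and `consumeRow` are illustrative stand-ins for the codegen plumbing, not names from the patch.

```scala
import org.apache.spark.sql.execution.vectorized.ColumnarBatch

// Analogue of the generated loop: iterate batches, then rows within a batch.
def scanBatches(
    batchIter: Iterator[ColumnarBatch])(
    consumeRow: (Boolean, Int) => Unit): Unit = {
  while (batchIter.hasNext) {
    val batch = batchIter.next()          // $nextBatch(): pull the next batch
    val col0 = batch.column(0)            // columnAssigns: cache each ColumnVector
    var rowIdx = 0                        // $idx, reset to 0 per batch
    while (rowIdx < batch.numRows()) {
      // genCodeColumnVector for a nullable column: null check first, then the
      // typed getter, with the type's default value (0 for int) when null.
      val isNull = col0.isNullAt(rowIdx)
      val value = if (isNull) 0 else col0.getInt(rowIdx)
      consumeRow(isNull, value)           // consume(ctx, columnsBatchInput)
      rowIdx += 1
    }
  }
}
```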

sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala

Lines changed: 2 additions & 84 deletions

@@ -145,7 +145,7 @@ case class FileSourceScanExec(
     partitionFilters: Seq[Expression],
     dataFilters: Seq[Filter],
     override val metastoreTableIdentifier: Option[TableIdentifier])
-  extends DataSourceScanExec {
+  extends DataSourceScanExec with ColumnarBatchScan {
 
   val supportsBatch: Boolean = relation.fileFormat.supportBatch(
     relation.sparkSession, StructType.fromAttributes(output))
@@ -312,7 +312,7 @@ case class FileSourceScanExec(
 
   override protected def doProduce(ctx: CodegenContext): String = {
     if (supportsBatch) {
-      return doProduceVectorized(ctx)
+      return super.doProduce(ctx)
     }
     val numOutputRows = metricTerm(ctx, "numOutputRows")
     // PhysicalRDD always just has one input
@@ -336,88 +336,6 @@ case class FileSourceScanExec(
     """.stripMargin
   }
 
-  // Support codegen so that we can avoid the UnsafeRow conversion in all cases. Codegen
-  // never requires UnsafeRow as input.
-  private def doProduceVectorized(ctx: CodegenContext): String = {
-    val input = ctx.freshName("input")
-    // PhysicalRDD always just has one input
-    ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
-
-    // metrics
-    val numOutputRows = metricTerm(ctx, "numOutputRows")
-    val scanTimeMetric = metricTerm(ctx, "scanTime")
-    val scanTimeTotalNs = ctx.freshName("scanTime")
-    ctx.addMutableState("long", scanTimeTotalNs, s"$scanTimeTotalNs = 0;")
-
-    val columnarBatchClz = "org.apache.spark.sql.execution.vectorized.ColumnarBatch"
-    val batch = ctx.freshName("batch")
-    ctx.addMutableState(columnarBatchClz, batch, s"$batch = null;")
-
-    val columnVectorClz = "org.apache.spark.sql.execution.vectorized.ColumnVector"
-    val idx = ctx.freshName("batchIdx")
-    ctx.addMutableState("int", idx, s"$idx = 0;")
-    val colVars = output.indices.map(i => ctx.freshName("colInstance" + i))
-    val columnAssigns = colVars.zipWithIndex.map { case (name, i) =>
-      ctx.addMutableState(columnVectorClz, name, s"$name = null;")
-      s"$name = $batch.column($i);"
-    }
-
-    val nextBatch = ctx.freshName("nextBatch")
-    ctx.addNewFunction(nextBatch,
-      s"""
-         |private void $nextBatch() throws java.io.IOException {
-         |  long getBatchStart = System.nanoTime();
-         |  if ($input.hasNext()) {
-         |    $batch = ($columnarBatchClz)$input.next();
-         |    $numOutputRows.add($batch.numRows());
-         |    $idx = 0;
-         |    ${columnAssigns.mkString("", "\n", "\n")}
-         |  }
-         |  $scanTimeTotalNs += System.nanoTime() - getBatchStart;
-         |}""".stripMargin)
-
-    ctx.currentVars = null
-    val rowidx = ctx.freshName("rowIdx")
-    val columnsBatchInput = (output zip colVars).map { case (attr, colVar) =>
-      genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable)
-    }
-    s"""
-       |if ($batch == null) {
-       |  $nextBatch();
-       |}
-       |while ($batch != null) {
-       |  int numRows = $batch.numRows();
-       |  while ($idx < numRows) {
-       |    int $rowidx = $idx++;
-       |    ${consume(ctx, columnsBatchInput).trim}
-       |    if (shouldStop()) return;
-       |  }
-       |  $batch = null;
-       |  $nextBatch();
-       |}
-       |$scanTimeMetric.add($scanTimeTotalNs / (1000 * 1000));
-       |$scanTimeTotalNs = 0;
-     """.stripMargin
-  }
-
-  private def genCodeColumnVector(ctx: CodegenContext, columnVar: String, ordinal: String,
-      dataType: DataType, nullable: Boolean): ExprCode = {
-    val javaType = ctx.javaType(dataType)
-    val value = ctx.getValue(columnVar, dataType, ordinal)
-    val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" }
-    val valueVar = ctx.freshName("value")
-    val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]"
-    val code = s"${ctx.registerComment(str)}\n" + (if (nullable) {
-      s"""
-        boolean ${isNullVar} = ${columnVar}.isNullAt($ordinal);
-        $javaType ${valueVar} = ${isNullVar} ? ${ctx.defaultValue(dataType)} : ($value);
-      """
-    } else {
-      s"$javaType ${valueVar} = $value;"
-    }).trim
-    ExprCode(code, isNullVar, valueVar)
-  }
-
   /**
    * Create an RDD for bucketed reads.
    * The non-bucketed variant of this function is [[createNonBucketedReadRDD]].

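One subtlety in the `doProduce` change above: `super.doProduce(ctx)` reaches the trait's implementation because `with ColumnarBatchScan` places the trait after `DataSourceScanExec` in the class linearization. A minimal, self-contained sketch of that dispatch follows; the `*Sketch` names are illustrative stand-ins, not Spark's types.

```scala
trait CodegenSupportSketch {
  def doProduce(): String = "row-at-a-time produce"
}

trait ColumnarBatchScanSketch extends CodegenSupportSketch {
  override def doProduce(): String = "columnar-batch produce"
}

class FileSourceScanSketch(supportsBatch: Boolean)
    extends CodegenSupportSketch with ColumnarBatchScanSketch {
  // Linearization: FileSourceScanSketch -> ColumnarBatchScanSketch ->
  // CodegenSupportSketch, so `super` resolves to the trait's override.
  override def doProduce(): String =
    if (supportsBatch) super.doProduce()
    else "this operator's own row-based codegen"
}

object LinearizationDemo extends App {
  println(new FileSourceScanSketch(supportsBatch = true).doProduce())
  // -> columnar-batch produce
  println(new FileSourceScanSketch(supportsBatch = false).doProduce())
  // -> this operator's own row-based codegen
}
```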