Skip to content

Commit a40ffc6

Browse files
viiryacloud-fan
authored andcommitted
[SPARK-23711][SQL] Add fallback generator for UnsafeProjection
## What changes were proposed in this pull request? Add fallback logic for `UnsafeProjection`. In production we can try to create unsafe projection using codegen implementation. Once any compile error happens, it fallbacks to interpreted implementation. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh <[email protected]> Closes #21106 from viirya/SPARK-23711.
1 parent 00c13cf commit a40ffc6

File tree

9 files changed

+230
-88
lines changed

9 files changed

+230
-88
lines changed
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions
19+
20+
import org.codehaus.commons.compiler.CompileException
21+
import org.codehaus.janino.InternalCompilerException
22+
23+
import org.apache.spark.TaskContext
24+
import org.apache.spark.sql.internal.SQLConf
25+
import org.apache.spark.util.Utils
26+
27+
/**
28+
* Catches compile error during code generation.
29+
*/
30+
object CodegenError {
31+
def unapply(throwable: Throwable): Option[Exception] = throwable match {
32+
case e: InternalCompilerException => Some(e)
33+
case e: CompileException => Some(e)
34+
case _ => None
35+
}
36+
}
37+
38+
/**
39+
* Defines values for `SQLConf` config of fallback mode. Use for test only.
40+
*/
41+
object CodegenObjectFactoryMode extends Enumeration {
42+
val FALLBACK, CODEGEN_ONLY, NO_CODEGEN = Value
43+
}
44+
45+
/**
46+
* A codegen object generator which creates objects with codegen path first. Once any compile
47+
* error happens, it can fallbacks to interpreted implementation. In tests, we can use a SQL config
48+
* `SQLConf.CODEGEN_FACTORY_MODE` to control fallback behavior.
49+
*/
50+
abstract class CodeGeneratorWithInterpretedFallback[IN, OUT] {
51+
52+
def createObject(in: IN): OUT = {
53+
// We are allowed to choose codegen-only or no-codegen modes if under tests.
54+
val config = SQLConf.get.getConf(SQLConf.CODEGEN_FACTORY_MODE)
55+
val fallbackMode = CodegenObjectFactoryMode.withName(config)
56+
57+
fallbackMode match {
58+
case CodegenObjectFactoryMode.CODEGEN_ONLY if Utils.isTesting =>
59+
createCodeGeneratedObject(in)
60+
case CodegenObjectFactoryMode.NO_CODEGEN if Utils.isTesting =>
61+
createInterpretedObject(in)
62+
case _ =>
63+
try {
64+
createCodeGeneratedObject(in)
65+
} catch {
66+
case CodegenError(_) => createInterpretedObject(in)
67+
}
68+
}
69+
}
70+
71+
protected def createCodeGeneratedObject(in: IN): OUT
72+
protected def createInterpretedObject(in: IN): OUT
73+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,11 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe
8787
/**
8888
* Helper functions for creating an [[InterpretedUnsafeProjection]].
8989
*/
90-
object InterpretedUnsafeProjection extends UnsafeProjectionCreator {
91-
90+
object InterpretedUnsafeProjection {
9291
/**
9392
* Returns an [[UnsafeProjection]] for given sequence of bound Expressions.
9493
*/
95-
override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = {
94+
def createProjection(exprs: Seq[Expression]): UnsafeProjection = {
9695
// We need to make sure that we do not reuse stateful expressions.
9796
val cleanedExpressions = exprs.map(_.transform {
9897
case s: Stateful => s.freshCopy()

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,32 @@ abstract class UnsafeProjection extends Projection {
108108
override def apply(row: InternalRow): UnsafeRow
109109
}
110110

111-
trait UnsafeProjectionCreator {
111+
/**
112+
* The factory object for `UnsafeProjection`.
113+
*/
114+
object UnsafeProjection
115+
extends CodeGeneratorWithInterpretedFallback[Seq[Expression], UnsafeProjection] {
116+
117+
override protected def createCodeGeneratedObject(in: Seq[Expression]): UnsafeProjection = {
118+
GenerateUnsafeProjection.generate(in)
119+
}
120+
121+
override protected def createInterpretedObject(in: Seq[Expression]): UnsafeProjection = {
122+
InterpretedUnsafeProjection.createProjection(in)
123+
}
124+
125+
protected def toBoundExprs(
126+
exprs: Seq[Expression],
127+
inputSchema: Seq[Attribute]): Seq[Expression] = {
128+
exprs.map(BindReferences.bindReference(_, inputSchema))
129+
}
130+
131+
protected def toUnsafeExprs(exprs: Seq[Expression]): Seq[Expression] = {
132+
exprs.map(_ transform {
133+
case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
134+
})
135+
}
136+
112137
/**
113138
* Returns an UnsafeProjection for given StructType.
114139
*
@@ -129,10 +154,7 @@ trait UnsafeProjectionCreator {
129154
* Returns an UnsafeProjection for given sequence of bound Expressions.
130155
*/
131156
def create(exprs: Seq[Expression]): UnsafeProjection = {
132-
val unsafeExprs = exprs.map(_ transform {
133-
case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
134-
})
135-
createProjection(unsafeExprs)
157+
createObject(toUnsafeExprs(exprs))
136158
}
137159

138160
def create(expr: Expression): UnsafeProjection = create(Seq(expr))
@@ -142,34 +164,24 @@ trait UnsafeProjectionCreator {
142164
* `inputSchema`.
143165
*/
144166
def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = {
145-
create(exprs.map(BindReferences.bindReference(_, inputSchema)))
146-
}
147-
148-
/**
149-
* Returns an [[UnsafeProjection]] for given sequence of bound Expressions.
150-
*/
151-
protected def createProjection(exprs: Seq[Expression]): UnsafeProjection
152-
}
153-
154-
object UnsafeProjection extends UnsafeProjectionCreator {
155-
156-
override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = {
157-
GenerateUnsafeProjection.generate(exprs)
167+
create(toBoundExprs(exprs, inputSchema))
158168
}
159169

160170
/**
161171
* Same as other create()'s but allowing enabling/disabling subexpression elimination.
162-
* TODO: refactor the plumbing and clean this up.
172+
* The param `subexpressionEliminationEnabled` doesn't guarantee to work. For example,
173+
* when fallbacking to interpreted execution, it is not supported.
163174
*/
164175
def create(
165176
exprs: Seq[Expression],
166177
inputSchema: Seq[Attribute],
167178
subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
168-
val e = exprs.map(BindReferences.bindReference(_, inputSchema))
169-
.map(_ transform {
170-
case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
171-
})
172-
GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled)
179+
val unsafeExprs = toUnsafeExprs(toBoundExprs(exprs, inputSchema))
180+
try {
181+
GenerateUnsafeProjection.generate(unsafeExprs, subexpressionEliminationEnabled)
182+
} catch {
183+
case CodegenError(_) => InterpretedUnsafeProjection.createProjection(unsafeExprs)
184+
}
173185
}
174186
}
175187

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import org.apache.spark.internal.Logging
3232
import org.apache.spark.internal.config._
3333
import org.apache.spark.network.util.ByteUnit
3434
import org.apache.spark.sql.catalyst.analysis.Resolver
35+
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
3536
import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator
3637
import org.apache.spark.util.Utils
3738

@@ -703,6 +704,17 @@ object SQLConf {
703704
.intConf
704705
.createWithDefault(100)
705706

707+
val CODEGEN_FACTORY_MODE = buildConf("spark.sql.codegen.factoryMode")
708+
.doc("This config determines the fallback behavior of several codegen generators " +
709+
"during tests. `FALLBACK` means trying codegen first and then fallbacking to " +
710+
"interpreted if any compile error happens. Disabling fallback if `CODEGEN_ONLY`. " +
711+
"`NO_CODEGEN` skips codegen and goes interpreted path always. Note that " +
712+
"this config works only for tests.")
713+
.internal()
714+
.stringConf
715+
.checkValues(CodegenObjectFactoryMode.values.map(_.toString))
716+
.createWithDefault(CodegenObjectFactoryMode.FALLBACK.toString)
717+
706718
val CODEGEN_FALLBACK = buildConf("spark.sql.codegen.fallback")
707719
.internal()
708720
.doc("When true, (whole stage) codegen could be temporary disabled for the part of query that" +
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions
19+
20+
import org.apache.spark.SparkFunSuite
21+
import org.apache.spark.sql.catalyst.plans.PlanTestBase
22+
import org.apache.spark.sql.internal.SQLConf
23+
import org.apache.spark.sql.types.{IntegerType, LongType}
24+
25+
class CodeGeneratorWithInterpretedFallbackSuite extends SparkFunSuite with PlanTestBase {
26+
27+
test("UnsafeProjection with codegen factory mode") {
28+
val input = Seq(LongType, IntegerType)
29+
.zipWithIndex.map(x => BoundReference(x._2, x._1, true))
30+
31+
val codegenOnly = CodegenObjectFactoryMode.CODEGEN_ONLY.toString
32+
withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) {
33+
val obj = UnsafeProjection.createObject(input)
34+
assert(obj.getClass.getName.contains("GeneratedClass$SpecificUnsafeProjection"))
35+
}
36+
37+
val noCodegen = CodegenObjectFactoryMode.NO_CODEGEN.toString
38+
withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> noCodegen) {
39+
val obj = UnsafeProjection.createObject(input)
40+
assert(obj.isInstanceOf[InterpretedUnsafeProjection])
41+
}
42+
}
43+
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala

Lines changed: 24 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
3131
import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer}
3232
import org.apache.spark.sql.catalyst.expressions.codegen._
3333
import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer
34+
import org.apache.spark.sql.catalyst.plans.PlanTestBase
3435
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
3536
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
3637
import org.apache.spark.sql.internal.SQLConf
@@ -40,7 +41,7 @@ import org.apache.spark.util.Utils
4041
/**
4142
* A few helper functions for expression evaluation testing. Mixin this trait to use them.
4243
*/
43-
trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
44+
trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBase {
4445
self: SparkFunSuite =>
4546

4647
protected def create_row(values: Any*): InternalRow = {
@@ -205,39 +206,34 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
205206
expression: Expression,
206207
expected: Any,
207208
inputRow: InternalRow = EmptyRow): Unit = {
208-
checkEvaluationWithUnsafeProjection(expression, expected, inputRow, UnsafeProjection)
209-
checkEvaluationWithUnsafeProjection(expression, expected, inputRow, InterpretedUnsafeProjection)
210-
}
211-
212-
protected def checkEvaluationWithUnsafeProjection(
213-
expression: Expression,
214-
expected: Any,
215-
inputRow: InternalRow,
216-
factory: UnsafeProjectionCreator): Unit = {
217-
val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow, factory)
218-
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
219-
220-
if (expected == null) {
221-
if (!unsafeRow.isNullAt(0)) {
222-
val expectedRow = InternalRow(expected, expected)
223-
fail("Incorrect evaluation in unsafe mode: " +
224-
s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
225-
}
226-
} else {
227-
val lit = InternalRow(expected, expected)
228-
val expectedRow =
229-
factory.create(Array(expression.dataType, expression.dataType)).apply(lit)
230-
if (unsafeRow != expectedRow) {
231-
fail("Incorrect evaluation in unsafe mode: " +
232-
s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
209+
val modes = Seq(CodegenObjectFactoryMode.CODEGEN_ONLY, CodegenObjectFactoryMode.NO_CODEGEN)
210+
for (fallbackMode <- modes) {
211+
withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) {
212+
val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow)
213+
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
214+
215+
if (expected == null) {
216+
if (!unsafeRow.isNullAt(0)) {
217+
val expectedRow = InternalRow(expected, expected)
218+
fail("Incorrect evaluation in unsafe mode: " +
219+
s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
220+
}
221+
} else {
222+
val lit = InternalRow(expected, expected)
223+
val expectedRow =
224+
UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit)
225+
if (unsafeRow != expectedRow) {
226+
fail("Incorrect evaluation in unsafe mode: " +
227+
s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
228+
}
229+
}
233230
}
234231
}
235232
}
236233

237234
protected def evaluateWithUnsafeProjection(
238235
expression: Expression,
239-
inputRow: InternalRow = EmptyRow,
240-
factory: UnsafeProjectionCreator = UnsafeProjection): InternalRow = {
236+
inputRow: InternalRow = EmptyRow): InternalRow = {
241237
// SPARK-16489 Explicitly doing code generation twice so code gen will fail if
242238
// some expression is reusing variable names across different instances.
243239
// This behavior is tested in ExpressionEvalHelperSuite.

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -81,21 +81,15 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
8181
val structExpected = new GenericArrayData(
8282
Array(InternalRow.fromSeq(Seq(1, 2)), InternalRow.fromSeq(Seq(3, 4))))
8383
checkEvaluationWithUnsafeProjection(
84-
structEncoder.serializer.head,
85-
structExpected,
86-
structInputRow,
87-
UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
84+
structEncoder.serializer.head, structExpected, structInputRow)
8885

8986
// test UnsafeArray-backed data
9087
val arrayEncoder = ExpressionEncoder[Array[Array[Int]]]
9188
val arrayInputRow = InternalRow.fromSeq(Seq(Array(Array(1, 2), Array(3, 4))))
9289
val arrayExpected = new GenericArrayData(
9390
Array(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(3, 4))))
9491
checkEvaluationWithUnsafeProjection(
95-
arrayEncoder.serializer.head,
96-
arrayExpected,
97-
arrayInputRow,
98-
UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
92+
arrayEncoder.serializer.head, arrayExpected, arrayInputRow)
9993

10094
// test UnsafeMap-backed data
10195
val mapEncoder = ExpressionEncoder[Array[Map[Int, Int]]]
@@ -109,10 +103,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
109103
new GenericArrayData(Array(3, 4)),
110104
new GenericArrayData(Array(300, 400)))))
111105
checkEvaluationWithUnsafeProjection(
112-
mapEncoder.serializer.head,
113-
mapExpected,
114-
mapInputRow,
115-
UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
106+
mapEncoder.serializer.head, mapExpected, mapInputRow)
116107
}
117108

118109
test("SPARK-23582: StaticInvoke should support interpreted execution") {
@@ -286,8 +277,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
286277
checkEvaluationWithUnsafeProjection(
287278
expr,
288279
expected,
289-
inputRow,
290-
UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
280+
inputRow)
291281
}
292282
checkEvaluationWithOptimization(expr, expected, inputRow)
293283
}

0 commit comments

Comments
 (0)