Skip to content

Commit c09b50f

Browse files
committed
[SPARK-9448][SQL] GenerateUnsafeProjection should not share expressions across instances.
1 parent b715933 commit c09b50f

File tree

2 files changed

+96
-6
lines changed

2 files changed

+96
-6
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -256,18 +256,18 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
256256
eval.code = createCode(ctx, eval, expressions)
257257

258258
val code = s"""
259-
private $exprType[] expressions;
260-
261-
public Object generate($exprType[] expr) {
262-
this.expressions = expr;
263-
return new SpecificProjection();
259+
public Object generate($exprType[] exprs) {
260+
return new SpecificProjection(exprs);
264261
}
265262

266263
class SpecificProjection extends ${classOf[UnsafeProjection].getName} {
267264

265+
private $exprType[] expressions;
266+
268267
${declareMutableStates(ctx)}
269268

270-
public SpecificProjection() {
269+
public SpecificProjection($exprType[] expressions) {
270+
this.expressions = expressions;
271271
${initMutableStates(ctx)}
272272
}
273273

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions.codegen
19+
20+
import org.apache.spark.SparkFunSuite
21+
import org.apache.spark.sql.catalyst.InternalRow
22+
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, LeafExpression}
23+
import org.apache.spark.sql.types.{BooleanType, DataType}
24+
25+
/**
26+
* A test suite that makes sure code generation handles expression internally states correctly.
27+
*/
28+
class CodegenExpressionCachingSuite extends SparkFunSuite {
29+
30+
test("GenerateUnsafeProjection") {
31+
val expr1 = MutableExpression()
32+
val instance1 = UnsafeProjection.create(Seq(expr1))
33+
assert(instance1.apply(null).getBoolean(0) === false)
34+
35+
val expr2 = MutableExpression()
36+
expr2.mutableState = true
37+
val instance2 = UnsafeProjection.create(Seq(expr2))
38+
assert(instance1.apply(null).getBoolean(0) === false)
39+
assert(instance2.apply(null).getBoolean(0) === true)
40+
}
41+
42+
test("GenerateProjection") {
43+
val expr1 = MutableExpression()
44+
val instance1 = GenerateProjection.generate(Seq(expr1))
45+
assert(instance1.apply(null).getBoolean(0) === false)
46+
47+
val expr2 = MutableExpression()
48+
expr2.mutableState = true
49+
val instance2 = GenerateProjection.generate(Seq(expr2))
50+
assert(instance1.apply(null).getBoolean(0) === false)
51+
assert(instance2.apply(null).getBoolean(0) === true)
52+
}
53+
54+
test("GenerateMutableProjection") {
55+
val expr1 = MutableExpression()
56+
val instance1 = GenerateMutableProjection.generate(Seq(expr1))()
57+
assert(instance1.apply(null).getBoolean(0) === false)
58+
59+
val expr2 = MutableExpression()
60+
expr2.mutableState = true
61+
val instance2 = GenerateMutableProjection.generate(Seq(expr2))()
62+
assert(instance1.apply(null).getBoolean(0) === false)
63+
assert(instance2.apply(null).getBoolean(0) === true)
64+
}
65+
66+
test("GeneratePredicate") {
67+
val expr1 = MutableExpression()
68+
val instance1 = GeneratePredicate.generate(expr1)
69+
assert(instance1.apply(null) === false)
70+
71+
val expr2 = MutableExpression()
72+
expr2.mutableState = true
73+
val instance2 = GeneratePredicate.generate(expr2)
74+
assert(instance1.apply(null) === false)
75+
assert(instance2.apply(null) === true)
76+
}
77+
78+
}
79+
80+
81+
/**
82+
* An expression with mutable state so we can change it freely in our test suite.
83+
*/
84+
case class MutableExpression() extends LeafExpression with CodegenFallback {
85+
var mutableState: Boolean = false
86+
override def eval(input: InternalRow): Any = mutableState
87+
88+
override def nullable: Boolean = false
89+
override def dataType: DataType = BooleanType
90+
}

0 commit comments

Comments
 (0)