Skip to content

Commit e1e0587

Browse files
viiryadavies
authored and committed
[SPARK-9403] [SQL] Add codegen support in In and InSet
This continues tarekauel's work in apache#7778. Author: Liang-Chi Hsieh <[email protected]> Author: Tarek Auel <[email protected]> Closes apache#7893 from viirya/codegen_in and squashes the following commits: 81ff97b [Liang-Chi Hsieh] For comments. 47761c6 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into codegen_in cf4bf41 [Liang-Chi Hsieh] For comments. f532b3c [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into codegen_in 446bbcd [Liang-Chi Hsieh] Fix bug. b3d0ab4 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into codegen_in 4610eff [Liang-Chi Hsieh] Relax the types of references and update optimizer test. 224f18e [Liang-Chi Hsieh] Beef up the test cases for In and InSet to include all primitive data types. 86dc8aa [Liang-Chi Hsieh] Only convert In to InSet when the number of items in set is more than the threshold. b7ded7e [Tarek Auel] [SPARK-9403][SQL] codeGen in / inSet
1 parent 1f8c364 commit e1e0587

File tree

6 files changed

+119
-10
lines changed

6 files changed

+119
-10
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20+
import scala.collection.mutable
21+
22+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2023
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, GeneratedExpressionCode, CodeGenContext}
2124
import org.apache.spark.sql.catalyst.InternalRow
2225
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -97,32 +100,80 @@ case class Not(child: Expression)
97100
/**
 * Evaluates to `true` if `list` contains `value`.
 */
case class In(value: Expression, list: Seq[Expression]) extends Predicate
    with ImplicitCastInputTypes {

  // The value's type followed by each list element's type; ImplicitCastInputTypes
  // uses this to insert casts so all operands end up with a common type.
  override def inputTypes: Seq[AbstractDataType] = value.dataType +: list.map(_.dataType)

  override def checkInputDataTypes(): TypeCheckResult = {
    // After implicit casting, every list element must have exactly the value's type.
    if (list.exists(l => l.dataType != value.dataType)) {
      TypeCheckResult.TypeCheckFailure("Arguments must be same type")
    } else {
      TypeCheckResult.TypeCheckSuccess
    }
  }

  override def children: Seq[Expression] = value +: list

  // NOTE(review): eval below always yields a Boolean (never null), so `false` is
  // consistent with this implementation; SQL three-valued IN semantics are not
  // modeled yet (see TODO).
  override def nullable: Boolean = false // TODO: Figure out correct nullability semantics of IN.

  override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}"

  override def eval(input: InternalRow): Any = {
    val evaluatedValue = value.eval(input)
    list.exists(e => e.eval(input) == evaluatedValue)
  }

  /**
   * Generates Java source that tests each list element in turn, short-circuiting:
   * once `ev.primitive` becomes true, the remaining candidates are not evaluated.
   */
  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val valueGen = value.gen(ctx)
    val listGen = list.map(_.gen(ctx))
    // Guard each candidate's evaluation on "no match found yet".
    val listCode = listGen.map(x =>
      s"""
        if (!${ev.primitive}) {
          ${x.code}
          if (${ctx.genEqual(value.dataType, valueGen.primitive, x.primitive)}) {
            ${ev.primitive} = true;
          }
        }
       """).mkString("\n")
    s"""
      ${valueGen.code}
      boolean ${ev.primitive} = false;
      boolean ${ev.isNull} = false;
      $listCode
     """
  }
}
112147

113148
/**
 * Optimized version of In clause, when all filter values of In clause are
 * static.
 */
case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with Predicate {

  // NOTE(review): eval below always yields a Boolean (never null), so `false` is
  // consistent with this implementation; SQL three-valued IN semantics are not
  // modeled yet (see TODO).
  override def nullable: Boolean = false // TODO: Figure out correct nullability semantics of IN.

  override def toString: String = s"$child INSET ${hset.mkString("(", ",", ")")}"

  override def eval(input: InternalRow): Any = {
    hset.contains(child.eval(input))
  }

  // Accessor invoked from generated Java code (see genCode below); keep the
  // explicit parentheses so the generated call site `getHSet()` stays valid.
  def getHSet(): Set[Any] = hset

  /**
   * Generates Java source that probes a cached copy of `hset`.
   *
   * This expression is appended to `ctx.references` so the generated class can
   * reach back into it at initialization time and pull out the set via
   * `getHSet()`, storing it in a mutable field that is reused for every row.
   */
  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val setName = classOf[Set[Any]].getName
    val inSetName = classOf[InSet].getName
    val childGen = child.gen(ctx)
    // Register this expression and remember its slot so the initializer below
    // indexes the right element of the generated `expressions` array.
    ctx.references += this
    val hsetIndex = ctx.references.size - 1
    val hsetTerm = ctx.freshName("hset")
    ctx.addMutableState(setName, hsetTerm,
      s"$hsetTerm = (($inSetName)expressions[$hsetIndex]).getHSet();")
    s"""
      ${childGen.code}
      boolean ${ev.isNull} = false;
      boolean ${ev.primitive} = $hsetTerm.contains(${childGen.primitive});
     """
  }
}
127178

128179
case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ object ConstantFolding extends Rule[LogicalPlan] {
393393
/**
 * Replaces [[In]] with an [[InSet]] (hash-set membership test) when every list
 * element is a literal and the list is large enough for the conversion to pay off.
 */
object OptimizeIn extends Rule[LogicalPlan] {
  // Below this many literals, a linear scan through the list is cheap enough
  // that building a hash set is not worthwhile (and In now supports codegen).
  private val insetConversionThreshold = 10

  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case q: LogicalPlan => q transformExpressionsDown {
      case In(v, list)
          if list.forall(_.isInstanceOf[Literal]) && list.size > insetConversionThreshold =>
        // All elements are literals, so they can be evaluated once, eagerly,
        // against an empty row.
        val hSet = list.map(e => e.eval(EmptyRow))
        InSet(v, HashSet() ++ hSet)
    }
  }
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ import scala.collection.immutable.HashSet
2121

2222
import org.apache.spark.SparkFunSuite
2323
import org.apache.spark.sql.catalyst.dsl.expressions._
24-
import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType, BooleanType}
24+
import org.apache.spark.sql.RandomDataGenerator
25+
import org.apache.spark.sql.types._
2526

2627

2728
class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -118,6 +119,23 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
118119
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"))), true)
119120
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true)
120121
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false)
122+
123+
val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType,
124+
LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
125+
primitiveTypes.map { t =>
126+
val dataGen = RandomDataGenerator.forType(t, nullable = false).get
127+
val inputData = Seq.fill(10) {
128+
val value = dataGen.apply()
129+
value match {
130+
case d: Double if d.isNaN => 0.0d
131+
case f: Float if f.isNaN => 0.0f
132+
case _ => value
133+
}
134+
}
135+
val input = inputData.map(Literal(_))
136+
checkEvaluation(In(input(0), input.slice(1, 10)),
137+
inputData.slice(1, 10).contains(inputData(0)))
138+
}
121139
}
122140

123141
test("INSET") {
@@ -134,6 +152,23 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
134152
checkEvaluation(InSet(three, hS), false)
135153
checkEvaluation(InSet(three, nS), false)
136154
checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true)
155+
156+
val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType,
157+
LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
158+
primitiveTypes.map { t =>
159+
val dataGen = RandomDataGenerator.forType(t, nullable = false).get
160+
val inputData = Seq.fill(10) {
161+
val value = dataGen.apply()
162+
value match {
163+
case d: Double if d.isNaN => 0.0d
164+
case f: Float if f.isNaN => 0.0f
165+
case _ => value
166+
}
167+
}
168+
val input = inputData.map(Literal(_))
169+
checkEvaluation(InSet(input(0), inputData.slice(1, 10).toSet),
170+
inputData.slice(1, 10).contains(inputData(0)))
171+
}
137172
}
138173

139174
private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d, false).map(Literal(_))

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,26 @@ class OptimizeInSuite extends PlanTest {
4343

4444
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
4545

46-
test("OptimizedIn test: In clause not optimized to InSet when less than 10 items") {
  // Only two literal items: below the conversion threshold, so the plan must
  // come back unchanged.
  val originalQuery =
    testRelation
      .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2))))
      .analyze

  // originalQuery is already analyzed above; no need to re-analyze it here.
  val optimized = Optimize.execute(originalQuery)
  comparePlans(optimized, originalQuery)
}
55+
56+
test("OptimizedIn test: In clause optimized to InSet when more than 10 items") {
57+
val originalQuery =
58+
testRelation
59+
.where(In(UnresolvedAttribute("a"), (1 to 11).map(Literal(_))))
60+
.analyze
61+
5262
val optimized = Optimize.execute(originalQuery.analyze)
5363
val correctAnswer =
5464
testRelation
55-
.where(InSet(UnresolvedAttribute("a"), HashSet[Any]() + 1 + 2))
65+
.where(InSet(UnresolvedAttribute("a"), (1 to 11).toSet))
5666
.analyze
5767

5868
comparePlans(optimized, correctAnswer)

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,13 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
366366
case expressions.InSet(a: Attribute, set) =>
367367
Some(sources.In(a.name, set.toArray))
368368

369+
// Because we only convert In to InSet in Optimizer when there are more than certain
370+
// items. So it is possible we still get an In expression here that needs to be pushed
371+
// down.
372+
case expressions.In(a: Attribute, list) if !list.exists(!_.isInstanceOf[Literal]) =>
373+
val hSet = list.map(e => e.eval(EmptyRow))
374+
Some(sources.In(a.name, hSet.toArray))
375+
369376
case expressions.IsNull(a: Attribute) =>
370377
Some(sources.IsNull(a.name))
371378
case expressions.IsNotNull(a: Attribute) =>

sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,12 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
357357
df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x"))
358358
checkAnswer(df.filter($"b".in("z", "y")),
359359
df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y"))
360+
361+
val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b")
362+
363+
intercept[AnalysisException] {
364+
df2.filter($"a".in($"b"))
365+
}
360366
}
361367

362368
val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize(

0 commit comments

Comments
 (0)