Commit bc76a0f

Davies Liu authored and rxin committed
[SPARK-7184] [SQL] enable codegen by default
In order to have better out-of-the-box performance, this PR turns on codegen by default, so that codegen is tested by sql/test and hive/test. This PR also fixes some corner cases for codegen. Before the 1.5 release we should revisit this and turn it off if it is not stable or is causing regressions.

cc rxin JoshRosen

Author: Davies Liu <[email protected]>

Closes apache#6726 from davies/enable_codegen and squashes the following commits:

f3b25a5 [Davies Liu] fix warning
73750ea [Davies Liu] fix long overflow when compare
3017a47 [Davies Liu] Merge branch 'master' of github.com:apache/spark into enable_codegen
a7d75da [Davies Liu] Merge branch 'master' of github.com:apache/spark into enable_codegen
ff5b75a [Davies Liu] Merge branch 'master' of github.com:apache/spark into enable_codegen
f4cf2c2 [Davies Liu] fix style
99fc139 [Davies Liu] Merge branch 'enable_codegen' of github.com:davies/spark into enable_codegen
91fc7a2 [Davies Liu] disable codegen for ScalaUDF
207e339 [Davies Liu] Update CodeGenerator.scala
44573a3 [Davies Liu] check thread safety of expression
f3886fa [Davies Liu] don't inline primitiveTerm for null literal
c8e7cd2 [Davies Liu] address comment
a8618c9 [Davies Liu] enable codegen by default
1 parent 1a62d61 commit bc76a0f
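Since the commit only flips a default, the behaviour can still be toggled per session. A minimal sketch, assuming the Spark 1.x configuration key "spark.sql.codegen" (the key name and the setup code are illustrative, not part of this diff):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Sketch: run a query with codegen explicitly enabled (the new default), so it can be
// flipped back to "false" quickly if codegen turns out to cause a regression.
object CodegenToggleSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("codegen-toggle").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)

    sqlContext.setConf("spark.sql.codegen", "true")  // assumed 1.x key; "false" disables codegen

    sqlContext.sql("SELECT 1 + 1").collect().foreach(println)
    sc.stop()
  }
}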


21 files changed: +95 -81 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala

Lines changed: 2 additions & 2 deletions
@@ -19,9 +19,9 @@ package org.apache.spark.sql.catalyst.expressions

 import org.apache.spark.Logging
 import org.apache.spark.sql.catalyst.errors.attachTree
-import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
+import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.catalyst.{InternalRow, trees}

 /**
  * A bound reference points to a specific slot in the input tuple, allowing the actual value

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 13 additions & 0 deletions
@@ -60,6 +60,14 @@ abstract class Expression extends TreeNode[Expression] {
   /** Returns the result of evaluating this expression on a given input Row */
   def eval(input: InternalRow = null): Any

+  /**
+   * Return true if this expression is thread-safe, which means it could be used by multiple
+   * threads in the same time.
+   *
+   * An expression that is not thread-safe can not be cached and re-used, especially for codegen.
+   */
+  def isThreadSafe: Boolean = true
+
   /**
    * Returns an [[GeneratedExpressionCode]], which contains Java source code that
    * can be used to generate the result of evaluating the expression on an input row.
@@ -68,6 +76,9 @@ abstract class Expression extends TreeNode[Expression] {
    * @return [[GeneratedExpressionCode]]
    */
   def gen(ctx: CodeGenContext): GeneratedExpressionCode = {
+    if (!isThreadSafe) {
+      throw new Exception(s"$this is not thread-safe, can not be used in codegen")
+    }
     val isNull = ctx.freshName("isNull")
     val primitive = ctx.freshName("primitive")
     val ve = GeneratedExpressionCode("", isNull, primitive)
@@ -169,6 +180,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {

   override def toString: String = s"($left $symbol $right)"

+  override def isThreadSafe: Boolean = left.isThreadSafe && right.isThreadSafe
   /**
    * Short hand for generating binary evaluation code, which depends on two sub-evaluations of
    * the same type. If either of the sub-expressions is null, the result of this computation
@@ -218,6 +230,7 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {

   override def foldable: Boolean = child.foldable
   override def nullable: Boolean = child.nullable
+  override def isThreadSafe: Boolean = child.isThreadSafe

   /**
    * Called by unary expressions to generate a code block that returns null if its parent returns
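The isThreadSafe default and its overrides in BinaryExpression and UnaryExpression form a simple structural rule: a node is thread-safe only if all of its children are. A standalone sketch of that pattern with hypothetical, simplified types (not the real Catalyst classes):

// Hypothetical mini expression tree illustrating the propagation rule added above:
// safety is the conjunction of the children's flags, and an opaque user function opts out.
sealed trait Expr {
  def children: Seq[Expr]
  def isThreadSafe: Boolean = children.forall(_.isThreadSafe)
}

case class Literal(value: Any) extends Expr {
  def children: Seq[Expr] = Nil
}

case class Add(left: Expr, right: Expr) extends Expr {
  def children: Seq[Expr] = Seq(left, right)
}

// Mirrors what ScalaUdf does in this commit: an arbitrary closure may not be safe to share.
case class UserFunc(f: Any => Any, child: Expr) extends Expr {
  def children: Seq[Expr] = Seq(child)
  override def isThreadSafe: Boolean = false
}

object ThreadSafetySketch extends App {
  println(Add(Literal(1), Literal(2)).isThreadSafe)                      // true
  println(Add(Literal(1), UserFunc(identity, Literal(2))).isThreadSafe)  // false
}

In the real change, gen() raises an exception for an unsafe expression rather than silently compiling and caching it, which matches the "check thread safety of expression" squash commit.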

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala

Lines changed: 2 additions & 0 deletions
@@ -958,4 +958,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression])
   private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType)
   override def eval(input: InternalRow): Any = converter(f(input))

+  // TODO(davies): make ScalaUdf work with codegen
+  override def isThreadSafe: Boolean = false
 }
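The diff does not spell out why ScalaUdf opts out beyond the TODO, but the general hazard the isThreadSafe flag guards against is an opaque closure holding per-call mutable state: sharing one such closure across threads (as a cached, generated projection might) can silently corrupt results. A purely illustrative sketch, not Spark code:

// Illustration only: a "UDF"-like closure with mutable state is not safe to share.
object UnsafeClosureSketch extends App {
  var counter = 0L
  val rowId: Any => Any = _ => { counter += 1; counter }  // unsynchronized read-modify-write

  val threads = (1 to 4).map { _ =>
    new Thread(new Runnable {
      def run(): Unit = (1 to 100000).foreach(_ => rowId(()))
    })
  }
  threads.foreach(_.start())
  threads.foreach(_.join())

  // Would be 400000 if the closure were safe to share; lost updates usually make it smaller.
  println(counter)
}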

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions

 import org.apache.spark.sql.catalyst.errors.TreeNodeException
-import org.apache.spark.sql.catalyst.{InternalRow, trees}
+import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.types.DataType

 abstract sealed class SortDirection

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 39 additions & 45 deletions
@@ -341,31 +341,29 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
   }

   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    if (ctx.isNativeType(left.dataType)) {
-      val eval1 = left.gen(ctx)
-      val eval2 = right.gen(ctx)
-      eval1.code + eval2.code + s"""
-        boolean ${ev.isNull} = false;
-        ${ctx.javaType(left.dataType)} ${ev.primitive} =
-          ${ctx.defaultValue(left.dataType)};
-
-        if (${eval1.isNull}) {
-          ${ev.isNull} = ${eval2.isNull};
-          ${ev.primitive} = ${eval2.primitive};
-        } else if (${eval2.isNull}) {
-          ${ev.isNull} = ${eval1.isNull};
+    val eval1 = left.gen(ctx)
+    val eval2 = right.gen(ctx)
+    val compCode = ctx.genComp(dataType, eval1.primitive, eval2.primitive)
+
+    eval1.code + eval2.code + s"""
+      boolean ${ev.isNull} = false;
+      ${ctx.javaType(left.dataType)} ${ev.primitive} =
+        ${ctx.defaultValue(left.dataType)};
+
+      if (${eval1.isNull}) {
+        ${ev.isNull} = ${eval2.isNull};
+        ${ev.primitive} = ${eval2.primitive};
+      } else if (${eval2.isNull}) {
+        ${ev.isNull} = ${eval1.isNull};
+        ${ev.primitive} = ${eval1.primitive};
+      } else {
+        if ($compCode > 0) {
           ${ev.primitive} = ${eval1.primitive};
         } else {
-          if (${eval1.primitive} > ${eval2.primitive}) {
-            ${ev.primitive} = ${eval1.primitive};
-          } else {
-            ${ev.primitive} = ${eval2.primitive};
-          }
+          ${ev.primitive} = ${eval2.primitive};
         }
-      """
-    } else {
-      super.genCode(ctx, ev)
-    }
+      }
+    """
   }
   override def toString: String = s"MaxOf($left, $right)"
 }
@@ -395,33 +393,29 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
   }

   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    if (ctx.isNativeType(left.dataType)) {
-
-      val eval1 = left.gen(ctx)
-      val eval2 = right.gen(ctx)
-
-      eval1.code + eval2.code + s"""
-        boolean ${ev.isNull} = false;
-        ${ctx.javaType(left.dataType)} ${ev.primitive} =
-          ${ctx.defaultValue(left.dataType)};
+    val eval1 = left.gen(ctx)
+    val eval2 = right.gen(ctx)
+    val compCode = ctx.genComp(dataType, eval1.primitive, eval2.primitive)

-        if (${eval1.isNull}) {
-          ${ev.isNull} = ${eval2.isNull};
-          ${ev.primitive} = ${eval2.primitive};
-        } else if (${eval2.isNull}) {
-          ${ev.isNull} = ${eval1.isNull};
+    eval1.code + eval2.code + s"""
+      boolean ${ev.isNull} = false;
+      ${ctx.javaType(left.dataType)} ${ev.primitive} =
+        ${ctx.defaultValue(left.dataType)};
+
+      if (${eval1.isNull}) {
+        ${ev.isNull} = ${eval2.isNull};
+        ${ev.primitive} = ${eval2.primitive};
+      } else if (${eval2.isNull}) {
+        ${ev.isNull} = ${eval1.isNull};
+        ${ev.primitive} = ${eval1.primitive};
+      } else {
+        if ($compCode < 0) {
           ${ev.primitive} = ${eval1.primitive};
         } else {
-          if (${eval1.primitive} < ${eval2.primitive}) {
-            ${ev.primitive} = ${eval1.primitive};
-          } else {
-            ${ev.primitive} = ${eval2.primitive};
-          }
+          ${ev.primitive} = ${eval2.primitive};
         }
-      """
-    } else {
-      super.genCode(ctx, ev)
-    }
+      }
+    """
   }

   override def toString: String = s"MinOf($left, $right)"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 3 additions & 5 deletions
@@ -24,7 +24,6 @@ import com.google.common.cache.{CacheBuilder, CacheLoader}
 import org.codehaus.janino.ClassBodyEvaluator

 import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -176,9 +175,8 @@ class CodeGenContext {
    * Generate code for compare expression in Java
    */
   def genComp(dataType: DataType, c1: String, c2: String): String = dataType match {
-    // Use signum() to keep any small difference bwteen float/double
-    case FloatType | DoubleType => s"(int)java.lang.Math.signum($c1 - $c2)"
-    case dt: DataType if isPrimitiveType(dt) => s"(int)($c1 - $c2)"
+    // use c1 - c2 may overflow
+    case dt: DataType if isPrimitiveType(dt) => s"(int)($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
     case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
     case other => s"$c1.compare($c2)"
   }
@@ -266,7 +264,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging {
    * weak keys/values and thus does not respond to memory pressure.
    */
   protected val cache = CacheBuilder.newBuilder()
-    .maximumSize(1000)
+    .maximumSize(100)
     .build(
       new CacheLoader[InType, OutType]() {
         override def load(in: InType): OutType = {
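The genComp change is the "fix long overflow when compare" squash commit: for primitive types the old template computed (int)(c1 - c2), but subtraction can overflow and report the wrong ordering, so the new template compares directly. A small sketch demonstrating the failure mode (illustrative values, not from the commit):

// Why "(int)($c1 - $c2)" is unsafe as a comparator for longs: the subtraction wraps around.
object CompareOverflowSketch extends App {
  val c1 = Long.MaxValue
  val c2 = Long.MinValue

  val bySubtraction = (c1 - c2).toInt                              // wraps around to -1
  val byComparison  = if (c1 > c2) 1 else if (c1 < c2) -1 else 0   // logic of the new template

  println(s"subtraction says: $bySubtraction")  // -1, i.e. c1 < c2 (wrong)
  println(s"comparison says:  $byComparison")   //  1, i.e. c1 > c2 (correct)
}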

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala

Lines changed: 2 additions & 0 deletions
@@ -117,6 +117,8 @@ case class Alias(child: Expression, name: String)(

   override def eval(input: InternalRow): Any = child.eval(input)

+  override def isThreadSafe: Boolean = child.isThreadSafe
+
   override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx)

   override def dataType: DataType = child.dataType

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala

Lines changed: 4 additions & 3 deletions
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions

 import org.apache.spark.sql.catalyst.analysis.UnresolvedException
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
-import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.types.DataType

 case class Coalesce(children: Seq[Expression]) extends Expression {
@@ -53,6 +52,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
     result
   }

+  override def isThreadSafe: Boolean = children.forall(_.isThreadSafe)
+
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     s"""
       boolean ${ev.isNull} = true;
@@ -73,7 +74,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
   }
 }

-case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
+case class IsNull(child: Expression) extends UnaryExpression with Predicate {
   override def foldable: Boolean = child.foldable
   override def nullable: Boolean = false

@@ -91,7 +92,7 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
   override def toString: String = s"IS NULL $child"
 }

-case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
+case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
   override def foldable: Boolean = child.foldable
   override def nullable: Boolean = false
   override def toString: String = s"IS NOT NULL $child"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala

Lines changed: 1 addition & 2 deletions
@@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

 import org.apache.spark.sql.catalyst.analysis.UnresolvedException
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{NumericType, DataType}
+import org.apache.spark.sql.types.{DataType, NumericType}

 /**
  * The trait of the Window Specification (specified in the OVER clause or WINDOW clause) for

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.plans.logical

 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters, analysis}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, analysis}
 import org.apache.spark.sql.types.{StructField, StructType}

 object LocalRelation {

0 commit comments
