Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.codehaus.commons.compiler.CompileException
import org.codehaus.janino.InternalCompilerException

import org.apache.spark.TaskContext
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.Utils

/**
* Catches compile error during code generation.
*/
object CodegenError {
def unapply(throwable: Throwable): Option[Exception] = throwable match {
case e: InternalCompilerException => Some(e)
case e: CompileException => Some(e)
case _ => None
}
}

/**
* Defines values for `SQLConf` config of fallback mode. Use for test only.
*/
object CodegenObjectFactoryMode extends Enumeration {
val FALLBACK, CODEGEN_ONLY, NO_CODEGEN = Value
}

/**
* A codegen object generator which creates objects with codegen path first. Once any compile
* error happens, it can fallbacks to interpreted implementation. In tests, we can use a SQL config
* `SQLConf.CODEGEN_FACTORY_MODE` to control fallback behavior.
*/
abstract class CodeGeneratorWithInterpretedFallback[IN, OUT] {

def createObject(in: IN): OUT = {
// We are allowed to choose codegen-only or no-codegen modes if under tests.
val config = SQLConf.get.getConf(SQLConf.CODEGEN_FACTORY_MODE)
val fallbackMode = CodegenObjectFactoryMode.withName(config)

fallbackMode match {
case CodegenObjectFactoryMode.CODEGEN_ONLY if Utils.isTesting =>
createCodeGeneratedObject(in)
case CodegenObjectFactoryMode.NO_CODEGEN if Utils.isTesting =>
createInterpretedObject(in)
case _ =>
try {
createCodeGeneratedObject(in)
} catch {
case CodegenError(_) => createInterpretedObject(in)
}
}
}

protected def createCodeGeneratedObject(in: IN): OUT
protected def createInterpretedObject(in: IN): OUT
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,11 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe
/**
* Helper functions for creating an [[InterpretedUnsafeProjection]].
*/
object InterpretedUnsafeProjection extends UnsafeProjectionCreator {

object InterpretedUnsafeProjection {
/**
* Returns an [[UnsafeProjection]] for given sequence of bound Expressions.
*/
override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = {
def createProjection(exprs: Seq[Expression]): UnsafeProjection = {
// We need to make sure that we do not reuse stateful expressions.
val cleanedExpressions = exprs.map(_.transform {
case s: Stateful => s.freshCopy()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,32 @@ abstract class UnsafeProjection extends Projection {
override def apply(row: InternalRow): UnsafeRow
}

trait UnsafeProjectionCreator {
/**
* The factory object for `UnsafeProjection`.
*/
object UnsafeProjection
extends CodeGeneratorWithInterpretedFallback[Seq[Expression], UnsafeProjection] {

override protected def createCodeGeneratedObject(in: Seq[Expression]): UnsafeProjection = {
GenerateUnsafeProjection.generate(in)
}

override protected def createInterpretedObject(in: Seq[Expression]): UnsafeProjection = {
InterpretedUnsafeProjection.createProjection(in)
}

protected def toBoundExprs(
exprs: Seq[Expression],
inputSchema: Seq[Attribute]): Seq[Expression] = {
exprs.map(BindReferences.bindReference(_, inputSchema))
}

protected def toUnsafeExprs(exprs: Seq[Expression]): Seq[Expression] = {
exprs.map(_ transform {
case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
})
}

/**
* Returns an UnsafeProjection for given StructType.
*
Expand All @@ -129,10 +154,7 @@ trait UnsafeProjectionCreator {
* Returns an UnsafeProjection for given sequence of bound Expressions.
*/
def create(exprs: Seq[Expression]): UnsafeProjection = {
val unsafeExprs = exprs.map(_ transform {
case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
})
createProjection(unsafeExprs)
createObject(toUnsafeExprs(exprs))
}

def create(expr: Expression): UnsafeProjection = create(Seq(expr))
Expand All @@ -142,34 +164,24 @@ trait UnsafeProjectionCreator {
* `inputSchema`.
*/
def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = {
create(exprs.map(BindReferences.bindReference(_, inputSchema)))
}

/**
* Returns an [[UnsafeProjection]] for given sequence of bound Expressions.
*/
protected def createProjection(exprs: Seq[Expression]): UnsafeProjection
}

object UnsafeProjection extends UnsafeProjectionCreator {

override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = {
GenerateUnsafeProjection.generate(exprs)
create(toBoundExprs(exprs, inputSchema))
}

/**
* Same as other create()'s but allowing enabling/disabling subexpression elimination.
* TODO: refactor the plumbing and clean this up.
* The param `subexpressionEliminationEnabled` doesn't guarantee to work. For example,
* when fallbacking to interpreted execution, it is not supported.
*/
def create(
exprs: Seq[Expression],
inputSchema: Seq[Attribute],
subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
val e = exprs.map(BindReferences.bindReference(_, inputSchema))
.map(_ transform {
case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
})
GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled)
val unsafeExprs = toUnsafeExprs(toBoundExprs(exprs, inputSchema))
try {
GenerateUnsafeProjection.generate(unsafeExprs, subexpressionEliminationEnabled)
} catch {
case CodegenError(_) => InterpretedUnsafeProjection.createProjection(unsafeExprs)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator
import org.apache.spark.util.Utils

Expand Down Expand Up @@ -686,6 +687,17 @@ object SQLConf {
.intConf
.createWithDefault(100)

val CODEGEN_FACTORY_MODE = buildConf("spark.sql.codegen.factoryMode")
.doc("This config determines the fallback behavior of several codegen generators " +
"during tests. `FALLBACK` means trying codegen first and then fallbacking to " +
"interpreted if any compile error happens. Disabling fallback if `CODEGEN_ONLY`. " +
"`NO_CODEGEN` skips codegen and goes interpreted path always. Note that " +
"this config works only for tests.")
.internal()
.stringConf
.checkValues(CodegenObjectFactoryMode.values.map(_.toString))
.createWithDefault(CodegenObjectFactoryMode.FALLBACK.toString)

val CODEGEN_FALLBACK = buildConf("spark.sql.codegen.fallback")
.internal()
.doc("When true, (whole stage) codegen could be temporary disabled for the part of query that" +
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.plans.PlanTestBase
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{IntegerType, LongType}

class CodeGeneratorWithInterpretedFallbackSuite extends SparkFunSuite with PlanTestBase {

test("UnsafeProjection with codegen factory mode") {
val input = Seq(LongType, IntegerType)
.zipWithIndex.map(x => BoundReference(x._2, x._1, true))

val codegenOnly = CodegenObjectFactoryMode.CODEGEN_ONLY.toString
withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) {
val obj = UnsafeProjection.createObject(input)
assert(obj.getClass.getName.contains("GeneratedClass$SpecificUnsafeProjection"))
}

val noCodegen = CodegenObjectFactoryMode.NO_CODEGEN.toString
withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> noCodegen) {
val obj = UnsafeProjection.createObject(input)
assert(obj.isInstanceOf[InterpretedUnsafeProjection])
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer
import org.apache.spark.sql.catalyst.plans.PlanTestBase
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.internal.SQLConf
Expand All @@ -40,7 +41,7 @@ import org.apache.spark.util.Utils
/**
* A few helper functions for expression evaluation testing. Mixin this trait to use them.
*/
trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBase {
self: SparkFunSuite =>

protected def create_row(values: Any*): InternalRow = {
Expand Down Expand Up @@ -202,39 +203,34 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
expression: Expression,
expected: Any,
inputRow: InternalRow = EmptyRow): Unit = {
checkEvaluationWithUnsafeProjection(expression, expected, inputRow, UnsafeProjection)
checkEvaluationWithUnsafeProjection(expression, expected, inputRow, InterpretedUnsafeProjection)
}

protected def checkEvaluationWithUnsafeProjection(
expression: Expression,
expected: Any,
inputRow: InternalRow,
factory: UnsafeProjectionCreator): Unit = {
val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow, factory)
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"

if (expected == null) {
if (!unsafeRow.isNullAt(0)) {
val expectedRow = InternalRow(expected, expected)
fail("Incorrect evaluation in unsafe mode: " +
s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
}
} else {
val lit = InternalRow(expected, expected)
val expectedRow =
factory.create(Array(expression.dataType, expression.dataType)).apply(lit)
if (unsafeRow != expectedRow) {
fail("Incorrect evaluation in unsafe mode: " +
s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
val modes = Seq(CodegenObjectFactoryMode.CODEGEN_ONLY, CodegenObjectFactoryMode.NO_CODEGEN)
for (fallbackMode <- modes) {
withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) {
val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow)
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"

if (expected == null) {
if (!unsafeRow.isNullAt(0)) {
val expectedRow = InternalRow(expected, expected)
fail("Incorrect evaluation in unsafe mode: " +
s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
}
} else {
val lit = InternalRow(expected, expected)
val expectedRow =
UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit)
if (unsafeRow != expectedRow) {
fail("Incorrect evaluation in unsafe mode: " +
s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
}
}
}
}
}

protected def evaluateWithUnsafeProjection(
expression: Expression,
inputRow: InternalRow = EmptyRow,
factory: UnsafeProjectionCreator = UnsafeProjection): InternalRow = {
inputRow: InternalRow = EmptyRow): InternalRow = {
// SPARK-16489 Explicitly doing code generation twice so code gen will fail if
// some expression is reusing variable names across different instances.
// This behavior is tested in ExpressionEvalHelperSuite.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,21 +81,15 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val structExpected = new GenericArrayData(
Array(InternalRow.fromSeq(Seq(1, 2)), InternalRow.fromSeq(Seq(3, 4))))
checkEvaluationWithUnsafeProjection(
structEncoder.serializer.head,
structExpected,
structInputRow,
UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
structEncoder.serializer.head, structExpected, structInputRow)

// test UnsafeArray-backed data
val arrayEncoder = ExpressionEncoder[Array[Array[Int]]]
val arrayInputRow = InternalRow.fromSeq(Seq(Array(Array(1, 2), Array(3, 4))))
val arrayExpected = new GenericArrayData(
Array(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(3, 4))))
checkEvaluationWithUnsafeProjection(
arrayEncoder.serializer.head,
arrayExpected,
arrayInputRow,
UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
arrayEncoder.serializer.head, arrayExpected, arrayInputRow)

// test UnsafeMap-backed data
val mapEncoder = ExpressionEncoder[Array[Map[Int, Int]]]
Expand All @@ -109,10 +103,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
new GenericArrayData(Array(3, 4)),
new GenericArrayData(Array(300, 400)))))
checkEvaluationWithUnsafeProjection(
mapEncoder.serializer.head,
mapExpected,
mapInputRow,
UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
mapEncoder.serializer.head, mapExpected, mapInputRow)
}

test("SPARK-23582: StaticInvoke should support interpreted execution") {
Expand Down Expand Up @@ -286,8 +277,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluationWithUnsafeProjection(
expr,
expected,
inputRow,
UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
inputRow)
}
checkEvaluationWithOptimization(expr, expected, inputRow)
}
Expand Down
Loading