Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
34abc22
Support wholestage codegen for reducing expression codes to prevent 6…
viirya Nov 24, 2017
e0d111e
Merge remote-tracking branch 'upstream/master' into reduce-expr-code-…
viirya Nov 25, 2017
65d07d5
Assert the added test is under wholestage codegen.
viirya Nov 25, 2017
9f848be
Put input rows and evaluated columns referred by deferred expressions…
viirya Nov 27, 2017
57b1add
Revert unnecessary changes.
viirya Nov 27, 2017
d051f9e
Fix subexpression isNull for non nullable case. Fix columnar batch sc…
viirya Nov 28, 2017
6368702
Let rowidx as global variable instead of early evaluation of column o…
viirya Nov 28, 2017
8c7f749
Fix the problematic case.
viirya Nov 28, 2017
7f00515
Fix duplicate parameters.
viirya Nov 29, 2017
777eb7a
Address comments.
viirya Nov 30, 2017
7230997
Polish the patch.
viirya Nov 30, 2017
fd87e9b
Add test for new APIs.
viirya Nov 30, 2017
57a9fb7
Generate function parameters if needed.
viirya Nov 30, 2017
0d358d6
Address comments.
viirya Dec 1, 2017
aa3db2e
Address comments.
viirya Dec 1, 2017
429afba
Rename variable.
viirya Dec 4, 2017
48add65
Address comments.
viirya Dec 5, 2017
9443011
Address comments.
viirya Dec 8, 2017
2f4014f
Address comments again.
viirya Dec 11, 2017
655917c
Remove redundant optimization.
viirya Dec 12, 2017
c083a79
Use utility method.
viirya Dec 12, 2017
1251dfa
Address comments.
viirya Dec 12, 2017
c4f15f7
Move isLiteral and isEvaluated into ExpressionCodegen.
viirya Dec 12, 2017
f35974e
Remove useless isLiteral and isEvaluted. Add one more test.
viirya Dec 12, 2017
e413043
Merge remote-tracking branch 'upstream/master' into reduce-expr-code-…
viirya Apr 24, 2018
ae25004
Fix check of literal.
viirya Apr 24, 2018
7dc6ccc
Don't parameterize a global variable.
viirya Apr 25, 2018
85568e7
Fix some tests in TPCDS.
viirya Apr 25, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,16 @@ abstract class Expression extends TreeNode[Expression] {
}.getOrElse {
val isNull = ctx.freshName("isNull")
val value = ctx.freshName("value")

val eval = doGenCode(ctx, ExprCode(
JavaCode.isNullVariable(isNull),
JavaCode.variable(value, dataType)))
eval.isNull = if (this.nullable) eval.isNull else FalseLiteral

// Records current input row and variables of this expression.
eval.inputRow = ctx.INPUT_ROW
eval.inputVars = findInputVars(ctx, eval)

reduceCodeSize(ctx, eval)
if (eval.code.nonEmpty) {
// Add `this` in the comment.
Expand All @@ -117,9 +124,29 @@ abstract class Expression extends TreeNode[Expression] {
}
}

/**
 * Collects the input variables that this expression reads from `ctx.currentVars`.
 *
 * Every [[BoundReference]] in this expression tree whose slot in `ctx.currentVars` is non-null
 * produces one [[ExprInputVar]]. The result is empty when `ctx.currentVars` itself is null.
 * The `eval` argument is not consulted by this method.
 */
private def findInputVars(ctx: CodegenContext, eval: ExprCode): Seq[ExprInputVar] = {
  Option(ctx.currentVars) match {
    case Some(vars) =>
      this.collect {
        case ref @ BoundReference(ordinal, _, _) if vars(ordinal) != null =>
          ExprInputVar(exprCode = vars(ordinal), dataType = ref.dataType, nullable = ref.nullable)
      }
    case None =>
      Seq.empty
  }
}

/**
* To prevent the 64kb compile error, reduces the size of the generated code by
* splitting it into a separate function when its size exceeds a threshold.
*/
private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = {
// TODO: support whole stage codegen too
if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
lazy val funcParams = ExpressionCodegen.getExpressionInputParams(ctx, this)

if (eval.code.trim.length > 1024 && funcParams.isDefined) {
val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) {
val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull")
val localIsNull = eval.isNull
Expand All @@ -133,17 +160,20 @@ abstract class Expression extends TreeNode[Expression] {
val newValue = ctx.freshName("value")

val funcName = ctx.freshName(nodeName)
val callParams = funcParams.map(_._1.mkString(", ")).get
val declParams = funcParams.map(_._2.mkString(", ")).get

val funcFullName = ctx.addNewFunction(funcName,
s"""
|private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) {
|private $javaType $funcName($declParams) {
| ${eval.code.trim}
| $setIsNull
| return ${eval.value};
|}
""".stripMargin)

eval.value = JavaCode.variable(newValue, dataType)
eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
eval.code = s"$javaType $newValue = $funcFullName($callParams);"
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,15 @@ import org.apache.spark.util.{ParentClassLoader, Utils}
* to null.
* @param value A term for a (possibly primitive) value of the result of the evaluation. Not
* valid if `isNull` is set to `true`.
* @param inputRow A term that holds the input row name when generating this code.
* @param inputVars A list of [[ExprInputVar]] that holds input variables when generating this code.
*/
case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue)
case class ExprCode(
var code: String,
var isNull: ExprValue,
var value: ExprValue,
var inputRow: String = null,
var inputVars: Seq[ExprInputVar] = Seq.empty)

object ExprCode {
def apply(isNull: ExprValue, value: ExprValue): ExprCode = {
Expand All @@ -72,6 +79,15 @@ object ExprCode {
}
}

/**
 * Describes one input variable used while evaluating an [[Expression]]: the [[ExprCode]] that
 * holds the variable together with its data type and nullability.
 *
 * @param exprCode The [[ExprCode]] that represents the evaluation result for the input variable.
 * @param dataType The data type of the input variable.
 * @param nullable Whether the input variable can be null or not.
 */
case class ExprInputVar(exprCode: ExprCode, dataType: DataType, nullable: Boolean)

/**
* State used for subexpression elimination.
*
Expand Down Expand Up @@ -1006,16 +1022,25 @@ class CodegenContext {
commonExprs.foreach { e =>
val expr = e.head
val fnName = freshName("subExpr")
val isNull = addMutableState(JAVA_BOOLEAN, "subExprIsNull")
val isNull = if (expr.nullable) {
addMutableState(JAVA_BOOLEAN, "subExprIsNull")
} else {
""
}
val value = addMutableState(javaType(expr.dataType), "subExprValue")

// Generate the code for this expression tree and wrap it in a function.
val eval = expr.genCode(this)
val assignIsNull = if (expr.nullable) {
s"$isNull = ${eval.isNull};"
} else {
""
}
val fn =
s"""
|private void $fnName(InternalRow $INPUT_ROW) {
| ${eval.code.trim}
| $isNull = ${eval.isNull};
| $assignIsNull
| $value = ${eval.value};
|}
""".stripMargin
Expand All @@ -1035,9 +1060,15 @@ class CodegenContext {
// at least two nodes) as the cost of doing it is expected to be low.

subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);"
val state = SubExprEliminationState(
JavaCode.isNullGlobal(isNull),
JavaCode.global(value, expr.dataType))
val state = if (expr.nullable) {
SubExprEliminationState(
JavaCode.isNullGlobal(isNull),
JavaCode.global(value, expr.dataType))
} else {
SubExprEliminationState(
FalseLiteral,
JavaCode.global(value, expr.dataType))
}
subExprEliminationExprs ++= e.map(_ -> state).toMap
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions.codegen

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.DataType

/**
* Defines util methods used in expression code generation.
*/
object ExpressionCodegen {

/**
 * Computes all parameters needed to evaluate the given expression inside a standalone function,
 * so that the generated code of the expression can be split out into its own method.
 *
 * The returned tuple contains two lists of strings:
 *  - the 1st holds the arguments used to call the function;
 *  - the 2nd holds the parameter declarations used to define the function.
 *
 * Returns `None` when the computed parameter length exceeds 255, the maximum allowed for a
 * Java method descriptor.
 *
 * The parameters are gathered from:
 * 1. Rows referred by this, children or deferred expressions.
 * 2. Evaluated columns referred by this, children or deferred expressions.
 * 3. Eliminated subexpressions referred by children expressions.
 */
def getExpressionInputParams(
    ctx: CodegenContext,
    expr: Expression): Option[(Seq[String], Seq[String])] = {
  // Eliminated subexpressions, wrapped as input variables.
  val eliminated = getSubExprInChildren(ctx, expr)
  val subExprVars = eliminated.zip(getSubExprCodes(ctx, eliminated)).map { case (e, code) =>
    ExprInputVar(code, e.dataType, e.nullable)
  }
  val subExprParams = prepareFunctionParams(ctx, subExprVars)

  // Previously evaluated columns.
  val columnVars = getInputVarsForChildren(ctx, expr)
  val columnParams = prepareFunctionParams(ctx, columnVars)

  // Input rows: the current one plus any referred by children; null row names are dropped.
  val rowParams = (ctx.INPUT_ROW +: getInputRowsForChildren(ctx, expr))
    .distinct
    .filter(_ != null)
    .map(row => (row, s"InternalRow $row"))

  val totalLength = getParamLength(ctx, columnVars ++ subExprVars) + rowParams.length
  // Maximum allowed parameter number for Java's method descriptor.
  if (totalLength <= 255) {
    val (callArgs, declArgs) = (rowParams ++ columnParams ++ subExprParams).unzip
    Some((callArgs.distinct, declArgs.distinct))
  } else {
    None
  }
}

/**
 * Finds, without duplicates, the expressions inside the children of `expr` that are registered
 * in `ctx.subExprEliminationExprs`, i.e. the eliminated subexpressions.
 */
def getSubExprInChildren(ctx: CodegenContext, expr: Expression): Seq[Expression] = {
  val eliminated = for {
    child <- expr.children
    subExpr <- child.collect { case e if ctx.subExprEliminationExprs.contains(e) => e }
  } yield subExpr
  eliminated.distinct
}

/**
 * Builds, for each given subexpression, an `ExprCode` with empty code whose `isNull` and `value`
 * come from the state recorded in `ctx.subExprEliminationExprs`.
 */
def getSubExprCodes(ctx: CodegenContext, subExprs: Seq[Expression]): Seq[ExprCode] = {
  for (subExpr <- subExprs) yield {
    val eliminated = ctx.subExprEliminationExprs(subExpr)
    ExprCode(code = "", isNull = eliminated.isNull, value = eliminated.value)
  }
}

/**
 * Gathers, without duplicates, the input rows referred by the children of `expr` and by the
 * deferred expressions they depend on.
 */
def getInputRowsForChildren(ctx: CodegenContext, expr: Expression): Seq[String] = {
  val rows = for {
    child <- expr.children
    row <- getInputRows(ctx, child)
  } yield row
  rows.distinct
}

/**
 * For a child expression, finds the input rows referred by it or by the deferred expressions
 * needed to evaluate it. The result contains no duplicates.
 */
def getInputRows(ctx: CodegenContext, child: Expression): Seq[String] = {
  val rows = child.flatMap {
    case BoundReference(ordinal, _, _) =>
      Option(ctx.currentVars).flatMap(vars => Option(vars(ordinal))) match {
        // No variable is bound: the expression evaluates directly on the current input row.
        case None => Seq(ctx.INPUT_ROW)
        // Already evaluated: no row is needed.
        case Some(input) if isEvaluated(input) => Seq.empty[String]
        // Not evaluated yet: track down to find the input rows.
        case Some(input) => trackDownRow(ctx, input)
      }

    case _ => Seq.empty[String]
  }
  rows.distinct
}

/**
 * Walks breadth-first from the given generated code snippet through the input variables that are
 * not evaluated yet, and collects the non-null input row of every visited `ExprCode`.
 */
def trackDownRow(ctx: CodegenContext, exprCode: ExprCode): Seq[String] = {
  @scala.annotation.tailrec
  def visit(pending: List[ExprCode], rows: Vector[String]): Seq[String] = pending match {
    case Nil =>
      rows
    case current :: rest =>
      val collected = if (current.inputRow != null) rows :+ current.inputRow else rows
      // Only inputs that are not evaluated yet need to be followed further.
      val deferred = current.inputVars.map(_.exprCode).filterNot(isEvaluated)
      visit(rest ++ deferred, collected)
  }

  visit(List(exprCode), Vector.empty)
}

/**
 * Gathers, without duplicates, the previously evaluated columns referred by the children of
 * `expr` and by the deferred expressions they depend on.
 */
def getInputVarsForChildren(
    ctx: CodegenContext,
    expr: Expression): Seq[ExprInputVar] = {
  val vars = for {
    child <- expr.children
    inputVar <- getInputVars(ctx, child)
  } yield inputVar
  vars.distinct
}

/**
 * For a child expression, finds the previously evaluated columns referred by it or by the
 * deferred expressions needed to evaluate it. The result contains no duplicates and is empty
 * when `ctx.currentVars` is null.
 */
def getInputVars(ctx: CodegenContext, child: Expression): Seq[ExprInputVar] = {
  Option(ctx.currentVars) match {
    case None =>
      Seq.empty

    case Some(vars) =>
      val found = child.flatMap {
        case ref @ BoundReference(ordinal, _, _) =>
          Option(vars(ordinal)) match {
            // An evaluated variable.
            case Some(input) if isEvaluated(input) =>
              Seq(ExprInputVar(input, ref.dataType, ref.nullable))

            // An input variable which is not evaluated yet. Track down to find any evaluated
            // variables in the expression path.
            // E.g., if this expression is "d = c + 1" and "c" is not evaluated, we track to
            // "c = a + b" and check whether "a" and "b" are evaluated. If they are, they are
            // returned so they become parameters; if not, we track down further.
            case Some(input) =>
              trackDownVar(ctx, input)

            case None =>
              Seq.empty[ExprInputVar]
          }

        case _ => Seq.empty[ExprInputVar]
      }
      found.distinct
  }
}

/**
 * Walks breadth-first from the given generated code snippet and collects the input variables
 * that are already evaluated; input variables that are not evaluated yet are followed further.
 */
def trackDownVar(ctx: CodegenContext, exprCode: ExprCode): Seq[ExprInputVar] = {
  @scala.annotation.tailrec
  def visit(pending: List[ExprCode], found: Vector[ExprInputVar]): Seq[ExprInputVar] = {
    pending match {
      case Nil =>
        found
      case current :: rest =>
        val (evaluated, deferred) = current.inputVars.partition(v => isEvaluated(v.exprCode))
        visit(rest ++ deferred.map(_.exprCode), found ++ evaluated)
    }
  }

  visit(List(exprCode), Vector.empty)
}

/**
 * Computes the parameter length of a Java method that takes the given inputs as parameters.
 */
def getParamLength(ctx: CodegenContext, inputs: Seq[ExprInputVar]): Int = {
  // The length only depends on data type and nullability, so placeholder catalyst expressions
  // carrying just that information are enough for the calculation.
  val placeholders =
    for (input <- inputs) yield BoundReference(1, input.dataType, input.nullable)
  CodeGenerator.calculateParamLength(placeholders)
}

/**
 * Turns the given input variables into function parameters. Each returned pair holds the
 * variable name used to call the function and the parameter declaration used to define the
 * function in generated code. The result contains no duplicates.
 */
def prepareFunctionParams(
    ctx: CodegenContext,
    inputVars: Seq[ExprInputVar]): Seq[(String, String)] = {
  val params = inputVars.flatMap { inputVar =>
    val ev = inputVar.exprCode

    // The value is passed only if it can't be accessed globally from within a method.
    val valueParam: Seq[(String, String)] =
      if (ev.value.canGlobalAccess) {
        Nil
      } else {
        val argType = CodeGenerator.javaType(inputVar.dataType)
        val valueName: String = ev.value
        Seq((valueName, s"$argType ${ev.value}"))
      }

    // `isNull` is passed only for a nullable input whose `isNull` can't be accessed globally.
    val isNullParam: Seq[(String, String)] =
      if (inputVar.nullable && !ev.isNull.canGlobalAccess) {
        val isNullName: String = ev.isNull
        Seq((isNullName, s"boolean ${ev.isNull}"))
      } else {
        Nil
      }

    valueParam ++ isNullParam
  }
  params.distinct
}

/**
 * Tells whether the given `ExprCode` is already evaluated, i.e. its code is the empty string.
 * Only meaningful for the `ExprCode` in `ctx.currentVars`, whose code is emptied after
 * evaluation.
 */
def isEvaluated(exprCode: ExprCode): Boolean = exprCode.code match {
  case "" => true
  case _ => false
}
}
Loading