Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,13 @@ class EquivalentExpressions(
// There are some special expressions that we should not recurse into all of its children.
// 1. CodegenFallback: it's children will not be used to generate code (call eval() instead)
// 2. ConditionalExpression: use its children that will always be evaluated.
// 3. HigherOrderFunction: lambda functions operate in the context of local lambdas and can't
// be called outside of that scope, only the arguments can be evaluated ahead of
// time.
private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match {
case _: CodegenFallback => Nil
case c: ConditionalExpression => c.alwaysEvaluatedInputs.map(skipForShortcut)
case h: HigherOrderFunction => h.arguments
case other => skipForShortcut(other).children
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,41 @@ class CodegenContext extends Logging {
*/
var currentVars: Seq[ExprCode] = null

/**
* Holding a map of current lambda variables.
*/
var currentLambdaVars: mutable.Map[Long, ExprCode] = mutable.HashMap.empty

def withLambdaVars(
namedLambdas: Seq[NamedLambdaVariable],
f: Seq[ExprCode] => ExprCode): ExprCode = {
val lambdaVars = namedLambdas.map { lambda =>
val id = lambda.exprId.id
if (currentLambdaVars.get(id).nonEmpty) {
throw QueryExecutionErrors.lambdaVariableAlreadyDefinedError(id)
}
val isNull = if (lambda.nullable) {
JavaCode.isNullGlobal(addMutableState(JAVA_BOOLEAN, "lambdaIsNull"))
} else {
FalseLiteral
}
val value = addMutableState(javaType(lambda.dataType), "lambdaValue")
val lambdaVar = ExprCode(isNull, JavaCode.global(value, lambda.dataType))
currentLambdaVars.put(id, lambdaVar)
lambdaVar
}

val result = f(lambdaVars)
namedLambdas.map(_.exprId.id).foreach(currentLambdaVars.remove)
result
}

def getLambdaVar(id: Long): ExprCode = {
currentLambdaVars.getOrElse(
id,
throw QueryExecutionErrors.lambdaVariableNotDefinedError(id))
}

/**
* Holding expressions' inlined mutable states like `MonotonicallyIncreasingID.count` as a
* 2-tuple: java type, variable name.
Expand Down
Loading