@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution
1919
2020import java .util .Locale
2121
22+ import scala .collection .mutable
23+
2224import org .apache .spark .broadcast
2325import org .apache .spark .rdd .RDD
2426import org .apache .spark .sql .catalyst .InternalRow
@@ -106,6 +108,31 @@ trait CodegenSupport extends SparkPlan {
106108 */
107109 protected def doProduce (ctx : CodegenContext ): String
108110
/**
 * Builds the `ExprCode` describing the row that is handed to the parent's `doConsume()`.
 *
 *  - If `row` is not null, it is referenced directly and no code is generated.
 *  - Otherwise, if `colVars` is non-empty, code projecting the columns into an UnsafeRow is
 *    generated. This branch overwrites `ctx.INPUT_ROW` and `ctx.currentVars`.
 *  - Otherwise the variable name `unsafeRow` is referenced without generating any code.
 *
 * @param ctx     the codegen context
 * @param row     name of the variable holding the input row, or null if there is none
 * @param colVars the generated column variables
 */
private def prepareRowVar(ctx: CodegenContext, row: String, colVars: Seq[ExprCode]): ExprCode = {
  if (row != null) {
    // A row variable already exists, so just point at it.
    ExprCode("", "false", row)
  } else if (colVars.isEmpty) {
    // There are no columns
    ExprCode("", "false", "unsafeRow")
  } else {
    val boundRefs = output.zipWithIndex.map { case (attr, ordinal) =>
      BoundReference(ordinal, attr.dataType, attr.nullable)
    }
    val inputEvalCode = evaluateVariables(colVars)
    // Generate the code that creates an UnsafeRow from the column variables.
    // `row` is null in this branch, so the projection reads from `currentVars` only.
    ctx.INPUT_ROW = null
    ctx.currentVars = colVars
    val projection = GenerateUnsafeProjection.createCode(ctx, boundRefs, false)
    val projectionCode = s"""
      |$inputEvalCode
      |${projection.code.trim}
    """.stripMargin.trim
    ExprCode(projectionCode, "false", projection.value)
  }
}
135+
109136 /**
110137 * Consume the generated columns or row from current SparkPlan, call its parent's `doConsume()`.
111138 *
@@ -126,28 +153,7 @@ trait CodegenSupport extends SparkPlan {
126153 }
127154 }
128155
129- val rowVar = if (row != null ) {
130- ExprCode (" " , " false" , row)
131- } else {
132- if (outputVars.nonEmpty) {
133- val colExprs = output.zipWithIndex.map { case (attr, i) =>
134- BoundReference (i, attr.dataType, attr.nullable)
135- }
136- val evaluateInputs = evaluateVariables(outputVars)
137- // generate the code to create a UnsafeRow
138- ctx.INPUT_ROW = row
139- ctx.currentVars = outputVars
140- val ev = GenerateUnsafeProjection .createCode(ctx, colExprs, false )
141- val code = s """
142- | $evaluateInputs
143- | ${ev.code.trim}
144- """ .stripMargin.trim
145- ExprCode (code, " false" , ev.value)
146- } else {
147- // There is no columns
148- ExprCode (" " , " false" , " unsafeRow" )
149- }
150- }
156+ val rowVar = prepareRowVar(ctx, row, outputVars)
151157
152158 // Set up the `currentVars` in the codegen context, as we generate the code of `inputVars`
153159 // before calling `parent.doConsume`. We can't set up `INPUT_ROW`, because parent needs to
@@ -156,13 +162,96 @@ trait CodegenSupport extends SparkPlan {
156162 ctx.INPUT_ROW = null
157163 ctx.freshNamePrefix = parent.variablePrefix
158164 val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs)
165+
166+ // Under certain conditions, we can put the logic to consume the rows of this operator into
167+ // another function, which keeps the generated function short enough to be optimized by JIT.
168+ // The conditions:
169+ // 1. The config "spark.sql.codegen.splitConsumeFuncByOperator" is enabled.
170+ // 2. `inputVars` are all materialized. That is guaranteed to be true if the parent plan uses
171+ // all variables in output (see `requireAllOutput`).
172+ // 3. The number of output variables must be less than the maximum number of parameters in a
173+ // Java method declaration.
174+ val confEnabled = SQLConf .get.wholeStageSplitConsumeFuncByOperator
175+ val requireAllOutput = output.forall(parent.usedInputs.contains(_))
176+ val paramLength = ctx.calculateParamLength(output) + (if (row != null ) 1 else 0 )
177+ val consumeFunc = if (confEnabled && requireAllOutput && ctx.isValidParamLength(paramLength)) {
178+ constructDoConsumeFunction(ctx, inputVars, row)
179+ } else {
180+ parent.doConsume(ctx, inputVars, rowVar)
181+ }
159182 s """
160183 | ${ctx.registerComment(s " CONSUME: ${parent.simpleString}" )}
161184 | $evaluated
162- | ${parent.doConsume(ctx, inputVars, rowVar)}
185+ | $consumeFunc
186+ """ .stripMargin
187+ }
188+
/**
 * Moves the parent's `doConsume` code for this operator into a separate generated method and
 * returns the statement that calls it, so the concatenated code does not grow into a single
 * function too long to be optimized by JIT.
 *
 * Side effect: `ctx.currentVars` is set to the variables of the new method's parameters and
 * `ctx.INPUT_ROW` is set to null before the parent's code is generated.
 *
 * @param ctx       the codegen context
 * @param inputVars the variables produced by this operator
 * @param row       name of the variable holding the input row, or null if there is none
 * @return the code that invokes the newly registered method
 */
private def constructDoConsumeFunction(
    ctx: CodegenContext,
    inputVars: Seq[ExprCode],
    row: String): String = {
  val (callArgs, paramDecls, varsInFunc) =
    constructConsumeParameters(ctx, output, inputVars, row)
  val rowVarInFunc = prepareRowVar(ctx, row, varsInFunc)

  val funcName = ctx.freshName("doConsume")
  // The parent's code must refer to the method parameters rather than the caller's variables.
  ctx.currentVars = varsInFunc
  ctx.INPUT_ROW = null

  val paramList = paramDecls.mkString(", ")
  val parentConsumeCode = parent.doConsume(ctx, varsInFunc, rowVarInFunc)
  val funcCode =
    s"""
       | private void $funcName($paramList) throws java.io.IOException {
       |   $parentConsumeCode
       | }
     """.stripMargin
  // Call through the name returned by `addNewFunction`, not the requested one.
  val registeredName = ctx.addNewFunction(funcName, funcCode)

  val argList = callArgs.mkString(", ")
  s"""
     | $registeredName($argList);
   """.stripMargin
}
165215
/**
 * Builds the pieces needed to declare and call the consume function.
 *
 * @param ctx        the codegen context, used for fresh names and Java type names
 * @param attributes attributes indexed in step with `variables`
 * @param variables  the variables to pass to the function
 * @param row        name of the variable holding the input row, or null if there is none
 * @return a triple of (call-site arguments, method parameter declarations, `ExprCode`s that
 *         refer to the parameters inside the method)
 */
private def constructConsumeParameters(
    ctx: CodegenContext,
    attributes: Seq[Attribute],
    variables: Seq[ExprCode],
    row: String): (Seq[String], Seq[String], Seq[ExprCode]) = {
  val callArgs = mutable.ArrayBuffer.empty[String]
  val paramDecls = mutable.ArrayBuffer.empty[String]
  val varsInFunc = mutable.ArrayBuffer.empty[ExprCode]

  // When a row is available it is passed first, under its existing variable name.
  Option(row).foreach { rowName =>
    callArgs += rowName
    paramDecls += s"InternalRow $rowName"
  }

  // A strict `foreach` keeps the order of the side-effecting `freshName` calls deterministic.
  variables.zipWithIndex.foreach { case (variable, idx) =>
    val valueParam = ctx.freshName(s"expr_$idx")
    val attr = attributes(idx)
    val javaTypeName = ctx.javaType(attr.dataType)

    callArgs += variable.value
    paramDecls += s"$javaTypeName $valueParam"

    val isNullParam =
      if (attr.nullable) {
        // Nullable columns take an extra boolean parameter carrying the null flag.
        val nullFlag = ctx.freshName(s"exprIsNull_$idx")
        callArgs += variable.isNull
        paramDecls += s"boolean $nullFlag"
        nullFlag
      } else {
        // Use constant `false` without passing `isNull` for a non-nullable variable.
        "false"
      }

    varsInFunc += ExprCode("", isNullParam, valueParam)
  }

  (callArgs, paramDecls, varsInFunc)
}
254+
166255 /**
167256 * Returns source code to evaluate all the variables, and clear the code of them, to prevent
168257 * them to be evaluated twice.
0 commit comments