
Commit 02f2031

mengxr authored and rxin committed
[SPARK-14393][SQL] values generated by non-deterministic functions shouldn't change after coalesce or union
## What changes were proposed in this pull request?

When a user appends a column to a DataFrame using a "nondeterministic" function, e.g., `rand`, `randn`, or `monotonically_increasing_id`, the expected semantics are the following:

- The value in each row should remain unchanged, as if we materialize the column immediately, regardless of later DataFrame operations.

However, since we use `TaskContext.getPartitionId` to get the partition index from the current thread, the values from nondeterministic columns might change if we call `union` or `coalesce` afterwards. `TaskContext.getPartitionId` returns the partition index of the current Spark task, which might not be the corresponding partition index of the DataFrame where we defined the column. See the unit tests below or the JIRA for examples.

This PR uses the partition index from `RDD.mapPartitionsWithIndex` instead of `TaskContext` and fixes the partition initialization logic in whole-stage codegen, normal codegen, and codegen fallback. `initializeStatesForPartition(partitionIndex: Int)` was added to `Projection`, `Nondeterministic`, and `Predicate` (codegen) and is called right after object creation in `mapPartitionsWithIndex`. `newPredicate` now returns a `Predicate` instance rather than a function for proper initialization.

## How was this patch tested?

Unit tests. (Actually I'm not very confident that this PR fixed all issues without introducing new ones ...)

cc: rxin davies

Author: Xiangrui Meng <[email protected]>

Closes #15567 from mengxr/SPARK-14393.
1 parent 742e0fe commit 02f2031

32 files changed: +231 lines, -78 lines
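To make the expected semantics concrete, here is a minimal, illustrative sketch (not part of this commit; the object and app names are made up, and it assumes a local SparkSession and only the public DataFrame API). Before this fix, the values in the appended column could change after `coalesce` or `union` because the expression re-read the task's partition id; after it, they stay fixed as if the column had been materialized:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.monotonically_increasing_id

object NondeterministicColumnDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("SPARK-14393-demo").getOrCreate()

    // Two partitions, so the partition index matters for the generated ids.
    val df = spark.range(0, 10, 1, numPartitions = 2).toDF("x")
      .withColumn("id", monotonically_increasing_id())

    // The "id" values here should be identical to the ones in df for the same rows.
    // Before this patch, coalescing to one partition could recompute them with a
    // different partition id and therefore different values.
    val coalesced = df.coalesce(1)

    df.show()
    coalesced.show()
    spark.stop()
  }
}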

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 14 additions & 2 deletions
@@ -788,14 +788,26 @@ abstract class RDD[T: ClassTag](
   }

   /**
-   * [performance] Spark's internal mapPartitions method which skips closure cleaning. It is a
-   * performance API to be used carefully only if we are sure that the RDD elements are
+   * [performance] Spark's internal mapPartitionsWithIndex method that skips closure cleaning.
+   * It is a performance API to be used carefully only if we are sure that the RDD elements are
    * serializable and don't require closure cleaning.
    *
    * @param preservesPartitioning indicates whether the input function preserves the partitioner,
    * which should be `false` unless this is a pair RDD and the input function doesn't modify
    * the keys.
    */
+  private[spark] def mapPartitionsWithIndexInternal[U: ClassTag](
+      f: (Int, Iterator[T]) => Iterator[U],
+      preservesPartitioning: Boolean = false): RDD[U] = withScope {
+    new MapPartitionsRDD(
+      this,
+      (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter),
+      preservesPartitioning)
+  }
+
+  /**
+   * [performance] Spark's internal mapPartitions method that skips closure cleaning.
+   */
   private[spark] def mapPartitionsInternal[U: ClassTag](
       f: Iterator[T] => Iterator[U],
       preservesPartitioning: Boolean = false): RDD[U] = withScope {
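The new internal method differs from the public `mapPartitionsWithIndex` only in skipping closure cleaning. A minimal sketch of the pattern it enables, written against the public API (the object name and the per-partition state are assumptions for illustration): the closure receives the partition index directly from the RDD, so per-partition state no longer needs `TaskContext.getPartitionId`.

import org.apache.spark.sql.SparkSession

object PartitionIndexPatternDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("partition-index").getOrCreate()
    val rdd = spark.sparkContext.parallelize(1 to 10, numSlices = 2)

    // Public counterpart of mapPartitionsWithIndexInternal: the index passed in is the
    // partition index of this RDD, which is what partition-dependent state should use.
    val tagged = rdd.mapPartitionsWithIndex { (index, iter) =>
      // e.g. a generated projection would call project.initialize(index) at this point
      iter.map(x => (index, x))
    }

    tagged.collect().foreach(println)
    spark.stop()
  }
}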

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 15 additions & 4 deletions
@@ -272,17 +272,28 @@ trait Nondeterministic extends Expression {
   final override def deterministic: Boolean = false
   final override def foldable: Boolean = false

+  @transient
   private[this] var initialized = false

-  final def setInitialValues(): Unit = {
-    initInternal()
+  /**
+   * Initializes internal states given the current partition index and mark this as initialized.
+   * Subclasses should override [[initializeInternal()]].
+   */
+  final def initialize(partitionIndex: Int): Unit = {
+    initializeInternal(partitionIndex)
     initialized = true
   }

-  protected def initInternal(): Unit
+  protected def initializeInternal(partitionIndex: Int): Unit

+  /**
+   * @inheritdoc
+   * Throws an exception if [[initialize()]] is not called yet.
+   * Subclasses should override [[evalInternal()]].
+   */
   final override def eval(input: InternalRow = null): Any = {
-    require(initialized, "nondeterministic expression should be initialized before evaluate")
+    require(initialized,
+      s"Nondeterministic expression ${this.getClass.getName} should be initialized before eval.")
     evalInternal(input)
   }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic {

   override def prettyName: String = "input_file_name"

-  override protected def initInternal(): Unit = {}
+  override protected def initializeInternal(partitionIndex: Int): Unit = {}

   override protected def evalInternal(input: InternalRow): UTF8String = {
     InputFileNameHolder.getInputFileName()
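A small sketch of the contract introduced in Expression.scala above (the object name is made up; it uses catalyst classes touched by this commit, so it reflects the post-patch API): `eval` fails fast until `initialize(partitionIndex)` has been called.

import org.apache.spark.sql.catalyst.expressions.MonotonicallyIncreasingID

object InitializeGuardDemo {
  def main(args: Array[String]): Unit = {
    val expr = MonotonicallyIncreasingID()

    try {
      expr.eval() // not initialized yet: the require(...) guard throws with the class name
    } catch {
      case e: IllegalArgumentException => println(e.getMessage)
    }

    expr.initialize(partitionIndex = 1)
    println(expr.eval()) // 1L << 33 = 8589934592, then counts up within the partition
  }
}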

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala

Lines changed: 6 additions & 5 deletions
@@ -50,9 +50,9 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis

   @transient private[this] var partitionMask: Long = _

-  override protected def initInternal(): Unit = {
+  override protected def initializeInternal(partitionIndex: Int): Unit = {
     count = 0L
-    partitionMask = TaskContext.getPartitionId().toLong << 33
+    partitionMask = partitionIndex.toLong << 33
   }

   override def nullable: Boolean = false
@@ -68,9 +68,10 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val countTerm = ctx.freshName("count")
     val partitionMaskTerm = ctx.freshName("partitionMask")
-    ctx.addMutableState(ctx.JAVA_LONG, countTerm, s"$countTerm = 0L;")
-    ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm,
-      s"$partitionMaskTerm = ((long) org.apache.spark.TaskContext.getPartitionId()) << 33;")
+    ctx.addMutableState(ctx.JAVA_LONG, countTerm, "")
+    ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, "")
+    ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
+    ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")

     ev.copy(code = s"""
       final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
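For reference, the shift above packs the partition index into the upper bits and the per-partition counter into the lower 33 bits, so ids are unique and increasing within a partition. A small worked example (a standalone sketch, not Spark code; the object name is made up):

object MonotonicIdLayoutDemo {
  // Mirrors the expression's formula: (partitionIndex << 33) + count.
  def id(partitionIndex: Int, count: Long): Long = (partitionIndex.toLong << 33) + count

  def main(args: Array[String]): Unit = {
    println(id(0, 0)) // 0
    println(id(0, 1)) // 1
    println(id(1, 0)) // 8589934592 = 1L << 33
    println(id(1, 2)) // 8589934594
  }
}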

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala

Lines changed: 14 additions & 8 deletions
@@ -23,17 +23,20 @@ import org.apache.spark.sql.types.{DataType, StructType}

 /**
  * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions.
+ *
  * @param expressions a sequence of expressions that determine the value of each column of the
  * output row.
  */
 class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
   def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
     this(expressions.map(BindReferences.bindReference(_, inputSchema)))

-  expressions.foreach(_.foreach {
-    case n: Nondeterministic => n.setInitialValues()
-    case _ =>
-  })
+  override def initialize(partitionIndex: Int): Unit = {
+    expressions.foreach(_.foreach {
+      case n: Nondeterministic => n.initialize(partitionIndex)
+      case _ =>
+    })
+  }

   // null check is required for when Kryo invokes the no-arg constructor.
   protected val exprArray = if (expressions != null) expressions.toArray else null
@@ -54,6 +57,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
 /**
  * A [[MutableProjection]] that is calculated by calling `eval` on each of the specified
  * expressions.
+ *
  * @param expressions a sequence of expressions that determine the value of each column of the
  * output row.
  */
@@ -63,10 +67,12 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu

   private[this] val buffer = new Array[Any](expressions.size)

-  expressions.foreach(_.foreach {
-    case n: Nondeterministic => n.setInitialValues()
-    case _ =>
-  })
+  override def initialize(partitionIndex: Int): Unit = {
+    expressions.foreach(_.foreach {
+      case n: Nondeterministic => n.initialize(partitionIndex)
+      case _ =>
+    })
+  }

   private[this] val exprArray = expressions.toArray
   private[this] var mutableRow: InternalRow = new GenericInternalRow(exprArray.length)
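With eager initialization gone, callers must invoke `initialize` on the projection once per partition before using it. A hedged sketch with the classes above (the object name is made up and output formatting may differ):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{InterpretedProjection, MonotonicallyIncreasingID}

object ProjectionInitializeDemo {
  def main(args: Array[String]): Unit = {
    val proj = new InterpretedProjection(Seq(MonotonicallyIncreasingID()))
    proj.initialize(partitionIndex = 2) // would be the index from mapPartitionsWithIndexInternal
    println(proj(InternalRow.empty))    // expected: [17179869184], i.e. 2L << 33
  }
}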

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala

Lines changed: 6 additions & 7 deletions
@@ -17,16 +17,15 @@

 package org.apache.spark.sql.catalyst.expressions

-import org.apache.spark.TaskContext
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.types.{DataType, IntegerType}

 /**
- * Expression that returns the current partition id of the Spark task.
+ * Expression that returns the current partition id.
  */
 @ExpressionDescription(
-  usage = "_FUNC_() - Returns the current partition id of the Spark task",
+  usage = "_FUNC_() - Returns the current partition id",
   extended = "> SELECT _FUNC_();\n 0")
 case class SparkPartitionID() extends LeafExpression with Nondeterministic {

@@ -38,16 +37,16 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic {

   override val prettyName = "SPARK_PARTITION_ID"

-  override protected def initInternal(): Unit = {
-    partitionId = TaskContext.getPartitionId()
+  override protected def initializeInternal(partitionIndex: Int): Unit = {
+    partitionId = partitionIndex
   }

   override protected def evalInternal(input: InternalRow): Int = partitionId

   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val idTerm = ctx.freshName("partitionId")
-    ctx.addMutableState(ctx.JAVA_INT, idTerm,
-      s"$idTerm = org.apache.spark.TaskContext.getPartitionId();")
+    ctx.addMutableState(ctx.JAVA_INT, idTerm, "")
+    ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
     ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false")
   }
 }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 14 additions & 0 deletions
@@ -184,6 +184,20 @@ class CodegenContext {
     splitExpressions(initCodes, "init", Nil)
   }

+  /**
+   * Code statements to initialize states that depend on the partition index.
+   * An integer `partitionIndex` will be made available within the scope.
+   */
+  val partitionInitializationStatements: mutable.ArrayBuffer[String] = mutable.ArrayBuffer.empty
+
+  def addPartitionInitializationStatement(statement: String): Unit = {
+    partitionInitializationStatements += statement
+  }
+
+  def initPartition(): String = {
+    partitionInitializationStatements.mkString("\n")
+  }
+
   /**
    * Holding all the functions those will be added into generated class.
    */

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala

Lines changed: 13 additions & 5 deletions
@@ -25,15 +25,23 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, No
 trait CodegenFallback extends Expression {

   protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    foreach {
-      case n: Nondeterministic => n.setInitialValues()
-      case _ =>
-    }
-
     // LeafNode does not need `input`
     val input = if (this.isInstanceOf[LeafExpression]) "null" else ctx.INPUT_ROW
     val idx = ctx.references.length
     ctx.references += this
+    var childIndex = idx
+    this.foreach {
+      case n: Nondeterministic =>
+        // This might add the current expression twice, but it won't hurt.
+        ctx.references += n
+        childIndex += 1
+        ctx.addPartitionInitializationStatement(
+          s"""
+             |((Nondeterministic) references[$childIndex])
+             |  .initialize(partitionIndex);
+           """.stripMargin)
+      case _ =>
+    }
     val objectTerm = ctx.freshName("obj")
     val placeHolder = ctx.registerComment(this.toString)
     if (nullable) {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala

Lines changed: 4 additions & 0 deletions
@@ -111,6 +111,10 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
         ${ctx.initMutableStates()}
       }

+      public void initialize(int partitionIndex) {
+        ${ctx.initPartition()}
+      }
+
       ${ctx.declareAddedFunctions()}

       public ${classOf[BaseMutableProjection].getName} target(InternalRow row) {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala

Lines changed: 14 additions & 4 deletions
@@ -25,19 +25,26 @@ import org.apache.spark.sql.catalyst.expressions._
  */
 abstract class Predicate {
   def eval(r: InternalRow): Boolean
+
+  /**
+   * Initializes internal states given the current partition index.
+   * This is used by nondeterministic expressions to set initial states.
+   * The default implementation does nothing.
+   */
+  def initialize(partitionIndex: Int): Unit = {}
 }

 /**
  * Generates bytecode that evaluates a boolean [[Expression]] on a given input [[InternalRow]].
  */
-object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Boolean] {
+object GeneratePredicate extends CodeGenerator[Expression, Predicate] {

   protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer.execute(in)

   protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression =
     BindReferences.bindReference(in, inputSchema)

-  protected def create(predicate: Expression): ((InternalRow) => Boolean) = {
+  protected def create(predicate: Expression): Predicate = {
     val ctx = newCodeGenContext()
     val eval = predicate.genCode(ctx)

@@ -55,6 +62,10 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool
         ${ctx.initMutableStates()}
       }

+      public void initialize(int partitionIndex) {
+        ${ctx.initPartition()}
+      }
+
       ${ctx.declareAddedFunctions()}

       public boolean eval(InternalRow ${ctx.INPUT_ROW}) {
@@ -67,7 +78,6 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool
       new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()))
     logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}")

-    val p = CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate]
-    (r: InternalRow) => p.eval(r)
+    CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate]
   }
 }
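Under the new contract, callers of the generated predicate receive an object that must be initialized per partition instead of a plain `(InternalRow) => Boolean`. A hedged, self-contained sketch of that calling convention (the object name and the example condition are made up):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, GreaterThan, Literal}
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
import org.apache.spark.sql.types.IntegerType

object PredicateInitializeDemo {
  def main(args: Array[String]): Unit = {
    // col0 > 5, with col0 bound to ordinal 0 of the input row
    val condition = GreaterThan(BoundReference(0, IntegerType, nullable = false), Literal(5))

    val predicate = GeneratePredicate.generate(condition)
    predicate.initialize(0) // a no-op here; it matters when the condition is nondeterministic

    println(predicate.eval(InternalRow(7))) // true
    println(predicate.eval(InternalRow(3))) // false
  }
}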
