Skip to content

Commit cecea8c

Browse files
committed
Dataset.collect is not threadsafe
1 parent 1051ebe commit cecea8c

File tree

1 file changed

+17
-11
lines changed

1 file changed

+17
-11
lines changed

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.analysis._
3939
import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
4040
import org.apache.spark.sql.catalyst.encoders._
4141
import org.apache.spark.sql.catalyst.expressions._
42+
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection
4243
import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions}
4344
import org.apache.spark.sql.catalyst.optimizer.CombineUnions
4445
import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
@@ -198,15 +199,10 @@ class Dataset[T] private[sql](
198199
*/
199200
private[sql] implicit val exprEnc: ExpressionEncoder[T] = encoderFor(encoder)
200201

201-
/**
202-
* Encoder is used mostly as a container of serde expressions in Dataset. We build logical
203-
* plans by these serde expressions and execute it within the query framework. However, for
204-
* performance reasons we may want to use encoder as a function to deserialize internal rows to
205-
* custom objects, e.g. collect. Here we resolve and bind the encoder so that we can call its
206-
* `fromRow` method later.
207-
*/
208-
private val boundEnc =
209-
exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer)
202+
// The deserializer expression which can be used to build a projection and turn rows to objects
203+
// of type T, after collecting rows to the driver side.
204+
private val deserializer =
205+
exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer).deserializer
210206

211207
private implicit def classTag = exprEnc.clsTag
212208

@@ -2661,7 +2657,12 @@ class Dataset[T] private[sql](
26612657
*/
26622658
def toLocalIterator(): java.util.Iterator[T] = {
26632659
withAction("toLocalIterator", queryExecution) { plan =>
2664-
plan.executeToIterator().map(boundEnc.fromRow).asJava
2660+
val objProj = GenerateSafeProjection.generate(deserializer :: Nil)
2661+
plan.executeToIterator().map { row =>
2662+
// The row returned by SafeProjection is `SpecificInternalRow`, which ignore the data type
2663+
// parameter of its `get` method, so it's safe to use null here.
2664+
objProj(row).get(0, null).asInstanceOf[T]
2665+
}.asJava
26652666
}
26662667
}
26672668

@@ -3102,7 +3103,12 @@ class Dataset[T] private[sql](
31023103
* Collect all elements from a spark plan.
31033104
*/
31043105
private def collectFromPlan(plan: SparkPlan): Array[T] = {
3105-
plan.executeCollect().map(boundEnc.fromRow)
3106+
val objProj = GenerateSafeProjection.generate(deserializer :: Nil)
3107+
plan.executeCollect().map { row =>
3108+
// The row returned by SafeProjection is `SpecificInternalRow`, which ignore the data type
3109+
// parameter of its `get` method, so it's safe to use null here.
3110+
objProj(row).get(0, null).asInstanceOf[T]
3111+
}
31063112
}
31073113

31083114
private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = {

0 commit comments

Comments
 (0)