@@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection
 import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions}
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
@@ -198,15 +199,10 @@ class Dataset[T] private[sql](
    */
   private[sql] implicit val exprEnc: ExpressionEncoder[T] = encoderFor(encoder)
 
-  /**
-   * Encoder is used mostly as a container of serde expressions in Dataset. We build logical
-   * plans with these serde expressions and execute them within the query framework. However,
-   * for performance reasons we may want to use the encoder as a function to deserialize
-   * internal rows to custom objects, e.g. in collect. Here we resolve and bind the encoder so
-   * that we can call its `fromRow` method later.
-   */
-  private val boundEnc =
-    exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer)
+  // The deserializer expression, which can be used to build a projection and turn rows into
+  // objects of type T after collecting rows to the driver side.
+  private val deserializer =
+    exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer).deserializer
 
   private implicit def classTag = exprEnc.clsTag
 
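The hunk above replaces the fully bound encoder with just its deserializer expression. A minimal standalone sketch of the resulting pattern, not part of the PR: `Person` and the sample row are made up, while `ExpressionEncoder`, `GenerateSafeProjection`, and `InternalRow` are the catalyst APIs the diff itself uses. The sketch binds against the encoder's own schema; the diff binds against `logicalPlan.output` and the session's analyzer instead.

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection
    import org.apache.spark.unsafe.types.UTF8String

    case class Person(name: String, age: Int)

    val enc = ExpressionEncoder[Person]()
    // Resolve and bind the encoder, then keep only the deserializer expression.
    val deserializer = enc.resolveAndBind().deserializer
    // Compile the deserializer into a generated projection, once.
    val objProj = GenerateSafeProjection.generate(deserializer :: Nil)
    // A row in the encoder's serialized format: (UTF8String, Int).
    val row = InternalRow(UTF8String.fromString("alice"), 30)
    // The projection wraps the object in a single-column row; get(0, null) unwraps it.
    val person = objProj(row).get(0, null).asInstanceOf[Person]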
@@ -2661,7 +2657,12 @@ class Dataset[T] private[sql](
    */
   def toLocalIterator(): java.util.Iterator[T] = {
     withAction("toLocalIterator", queryExecution) { plan =>
-      plan.executeToIterator().map(boundEnc.fromRow).asJava
+      val objProj = GenerateSafeProjection.generate(deserializer :: Nil)
+      plan.executeToIterator().map { row =>
+        // The row returned by SafeProjection is a `SpecificInternalRow`, which ignores the
+        // data type parameter of its `get` method, so it's safe to pass null here.
+        objProj(row).get(0, null).asInstanceOf[T]
+      }.asJava
     }
   }
 
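For context, the user-facing path this hunk changes, assuming a local `SparkSession`: each element pulled from the iterator is now deserialized by the generated projection rather than by `boundEnc.fromRow`.

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
    import spark.implicits._

    val ds = Seq("a", "b", "c").toDS()
    // Each next() deserializes one InternalRow via objProj(row).get(0, null).
    val it = ds.toLocalIterator()
    while (it.hasNext) println(it.next())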
@@ -3102,7 +3103,12 @@ class Dataset[T] private[sql](
    * Collect all elements from a spark plan.
    */
   private def collectFromPlan(plan: SparkPlan): Array[T] = {
-    plan.executeCollect().map(boundEnc.fromRow)
+    val objProj = GenerateSafeProjection.generate(deserializer :: Nil)
+    plan.executeCollect().map { row =>
+      // The row returned by SafeProjection is a `SpecificInternalRow`, which ignores the
+      // data type parameter of its `get` method, so it's safe to pass null here.
+      objProj(row).get(0, null).asInstanceOf[T]
+    }
   }
 
   private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = {
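`collectFromPlan` backs `collect()` and `collectAsList()`. Note that in both changed call sites the projection is generated once per action and reused for every row, so the codegen cost is paid once rather than per element. A user-level example, reusing the session from the previous sketch and an assumed case class:

    case class Point(x: Int, y: Int)
    import spark.implicits._

    // collect() routes through collectFromPlan above: one generated projection,
    // applied to every collected row.
    val points: Array[Point] = Seq(Point(1, 2), Point(3, 4)).toDS().collect()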