@@ -21,14 +21,15 @@ import scala.collection.JavaConverters._
2121
2222import org .apache .spark .annotation .Experimental
2323import org .apache .spark .rdd .RDD
24- import org .apache .spark .sql .catalyst .analysis .UnresolvedAlias
2524import org .apache .spark .api .java .function ._
2625
2726import org .apache .spark .sql .catalyst .encoders ._
2827import org .apache .spark .sql .catalyst .expressions ._
28+ import org .apache .spark .sql .catalyst .analysis .UnresolvedAlias
2929import org .apache .spark .sql .catalyst .plans .Inner
3030import org .apache .spark .sql .catalyst .plans .logical ._
3131import org .apache .spark .sql .execution .{Queryable , QueryExecution }
32+ import org .apache .spark .sql .execution .aggregate .TypedAggregateExpression
3233import org .apache .spark .sql .types .StructType
3334
3435/**
@@ -359,7 +360,7 @@ class Dataset[T] private[sql](
359360 * @since 1.6.0
360361 */
361362 def select [U1 : Encoder ](c1 : TypedColumn [T , U1 ]): Dataset [U1 ] = {
362- new Dataset [U1 ](sqlContext, Project (Alias (c1 .expr, " _1" )() :: Nil , logicalPlan))
363+ new Dataset [U1 ](sqlContext, Project (Alias (withEncoder(c1) .expr, " _1" )() :: Nil , logicalPlan))
363364 }
364365
365366 /**
@@ -368,11 +369,12 @@ class Dataset[T] private[sql](
368369 * that cast appropriately for the user facing interface.
369370 */
370371 protected def selectUntyped (columns : TypedColumn [_, _]* ): Dataset [_] = {
371- val aliases = columns.zipWithIndex.map { case (c, i) => Alias (c.expr, s " _ ${i + 1 }" )() }
372+ val withEncoders = columns.map(withEncoder)
373+ val aliases = withEncoders.zipWithIndex.map { case (c, i) => Alias (c.expr, s " _ ${i + 1 }" )() }
372374 val unresolvedPlan = Project (aliases, logicalPlan)
373375 val execution = new QueryExecution (sqlContext, unresolvedPlan)
374376 // Rebind the encoders to the nested schema that will be produced by the select.
375- val encoders = columns .map(_.encoder.asInstanceOf [ExpressionEncoder [_]]).zip(aliases).map {
377+ val encoders = withEncoders .map(_.encoder.asInstanceOf [ExpressionEncoder [_]]).zip(aliases).map {
376378 case (e : ExpressionEncoder [_], a) if ! e.flat =>
377379 e.nested(a.toAttribute).resolve(execution.analyzed.output)
378380 case (e, a) =>
@@ -381,6 +383,16 @@ class Dataset[T] private[sql](
381383 new Dataset (sqlContext, execution, ExpressionEncoder .tuple(encoders))
382384 }
383385
386+ private def withEncoder (c : TypedColumn [_, _]): TypedColumn [_, _] = {
387+ val e = c.expr transform {
388+ case ta : TypedAggregateExpression if ta.aEncoder.isEmpty =>
389+ ta.copy(
390+ aEncoder = Some (encoder.asInstanceOf [ExpressionEncoder [Any ]]),
391+ children = queryExecution.analyzed.output)
392+ }
393+ new TypedColumn (e, c.encoder)
394+ }
395+
384396 /**
385397 * Returns a new [[Dataset ]] by computing the given [[Column ]] expressions for each element.
386398 * @since 1.6.0
0 commit comments