Skip to content

Commit f9aeb96

Browse files
cloud-fan authored and marmbrus committed
[SPARK-11656][SQL] support typed aggregate in project list
Insert `aEncoder` like we do in `agg`.

Author: Wenchen Fan <[email protected]>

Closes #9630 from cloud-fan/select.

(cherry picked from commit 9c57bc0)
Signed-off-by: Michael Armbrust <[email protected]>
1 parent a83ce04 commit f9aeb96

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,15 @@ import scala.collection.JavaConverters._
2121

2222
import org.apache.spark.annotation.Experimental
2323
import org.apache.spark.rdd.RDD
24-
import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
2524
import org.apache.spark.api.java.function._
2625

2726
import org.apache.spark.sql.catalyst.encoders._
2827
import org.apache.spark.sql.catalyst.expressions._
28+
import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
2929
import org.apache.spark.sql.catalyst.plans.Inner
3030
import org.apache.spark.sql.catalyst.plans.logical._
3131
import org.apache.spark.sql.execution.{Queryable, QueryExecution}
32+
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
3233
import org.apache.spark.sql.types.StructType
3334

3435
/**
@@ -359,7 +360,7 @@ class Dataset[T] private[sql](
359360
* @since 1.6.0
360361
*/
361362
def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = {
362-
new Dataset[U1](sqlContext, Project(Alias(c1.expr, "_1")() :: Nil, logicalPlan))
363+
new Dataset[U1](sqlContext, Project(Alias(withEncoder(c1).expr, "_1")() :: Nil, logicalPlan))
363364
}
364365

365366
/**
@@ -368,11 +369,12 @@ class Dataset[T] private[sql](
368369
* that cast appropriately for the user facing interface.
369370
*/
370371
protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
371-
val aliases = columns.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() }
372+
val withEncoders = columns.map(withEncoder)
373+
val aliases = withEncoders.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() }
372374
val unresolvedPlan = Project(aliases, logicalPlan)
373375
val execution = new QueryExecution(sqlContext, unresolvedPlan)
374376
// Rebind the encoders to the nested schema that will be produced by the select.
375-
val encoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]).zip(aliases).map {
377+
val encoders = withEncoders.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]).zip(aliases).map {
376378
case (e: ExpressionEncoder[_], a) if !e.flat =>
377379
e.nested(a.toAttribute).resolve(execution.analyzed.output)
378380
case (e, a) =>
@@ -381,6 +383,16 @@ class Dataset[T] private[sql](
381383
new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
382384
}
383385

386+
private def withEncoder(c: TypedColumn[_, _]): TypedColumn[_, _] = {
387+
val e = c.expr transform {
388+
case ta: TypedAggregateExpression if ta.aEncoder.isEmpty =>
389+
ta.copy(
390+
aEncoder = Some(encoder.asInstanceOf[ExpressionEncoder[Any]]),
391+
children = queryExecution.analyzed.output)
392+
}
393+
new TypedColumn(e, c.encoder)
394+
}
395+
384396
/**
385397
* Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
386398
* @since 1.6.0

sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,4 +114,15 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
114114
ComplexResultAgg.toColumn),
115115
("a", 2.0, (2L, 4L)), ("b", 3.0, (1L, 3L)))
116116
}
117+
118+
test("typed aggregation: in project list") {
119+
val ds = Seq(1, 3, 2, 5).toDS()
120+
121+
checkAnswer(
122+
ds.select(sum((i: Int) => i)),
123+
11)
124+
checkAnswer(
125+
ds.select(sum((i: Int) => i), sum((i: Int) => i * 2)),
126+
11 -> 22)
127+
}
117128
}

0 commit comments

Comments (0)