diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 53d732403f979..25d71aa825018 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -388,6 +388,16 @@ class RelationalGroupedDataset protected[sql]( pivot(pivotColumn, values.asScala) } + /** + * Returns all grouping column names as an array. + * + * @since 2.1.0 + */ + def columns: Array[String] = { + val groupingNamedExpressions = groupingExprs.map(alias) + groupingNamedExpressions.map(_.name).toArray + } + /** * Applies the given serialized R function `func` to each group of data. For each unique group, * the function will be passed the group key and an iterator that contains all of the elements in diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 69a3b5f278fd8..043822363c35a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -85,6 +85,9 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(new java.math.BigDecimal(3.0), new java.math.BigDecimal(3.0)), Row(null, new java.math.BigDecimal(2.0))) ) + + assert(df1.groupBy("key").columns.sameElements(Array("key"))) + assert(df1.groupBy("value1", "value2").columns.sameElements(Array("value1", "value2"))) } test("SPARK-17124 agg should be ordering preserving") {