@@ -30,6 +30,7 @@ import org.apache.spark.sql.types._
3030 * Window function testing for DataFrame API.
3131 */
3232class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
33+
3334 import testImplicits ._
3435
3536 test(" reuse window partitionBy" ) {
@@ -72,9 +73,9 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
7273 cume_dist().over(Window .partitionBy(" value" ).orderBy(" key" )),
7374 percent_rank().over(Window .partitionBy(" value" ).orderBy(" key" ))),
7475 Row (1 , 1 , 1 , 1.0d , 1 , 1 , 1 , 1 , 1 , 1 , 1.0d , 0.0d ) ::
75- Row (1 , 1 , 1 , 1.0d , 1 , 1 , 1 , 1 , 1 , 1 , 1.0d / 3.0d , 0.0d ) ::
76- Row (2 , 2 , 1 , 5.0d / 3.0d , 3 , 5 , 1 , 2 , 2 , 2 , 1.0d , 0.5d ) ::
77- Row (2 , 2 , 1 , 5.0d / 3.0d , 3 , 5 , 2 , 3 , 2 , 2 , 1.0d , 0.5d ) :: Nil )
76+ Row (1 , 1 , 1 , 1.0d , 1 , 1 , 1 , 1 , 1 , 1 , 1.0d / 3.0d , 0.0d ) ::
77+ Row (2 , 2 , 1 , 5.0d / 3.0d , 3 , 5 , 1 , 2 , 2 , 2 , 1.0d , 0.5d ) ::
78+ Row (2 , 2 , 1 , 5.0d / 3.0d , 3 , 5 , 2 , 3 , 2 , 2 , 1.0d , 0.5d ) :: Nil )
7879 }
7980
8081 test(" window function should fail if order by clause is not specified" ) {
@@ -162,12 +163,12 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
162163 Seq (
163164 Row (" a" , - 50.0 , 50.0 , 50.0 , 7.0710678118654755 , 7.0710678118654755 ),
164165 Row (" b" , - 50.0 , 50.0 , 50.0 , 7.0710678118654755 , 7.0710678118654755 ),
165- Row (" c" , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ),
166- Row (" d" , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ),
167- Row (" e" , 24.0 , 12.0 , 12.0 , 3.4641016151377544 , 3.4641016151377544 ),
168- Row (" f" , 24.0 , 12.0 , 12.0 , 3.4641016151377544 , 3.4641016151377544 ),
169- Row (" g" , 24.0 , 12.0 , 12.0 , 3.4641016151377544 , 3.4641016151377544 ),
170- Row (" h" , 24.0 , 12.0 , 12.0 , 3.4641016151377544 , 3.4641016151377544 ),
166+ Row (" c" , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ),
167+ Row (" d" , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ),
168+ Row (" e" , 24.0 , 12.0 , 12.0 , 3.4641016151377544 , 3.4641016151377544 ),
169+ Row (" f" , 24.0 , 12.0 , 12.0 , 3.4641016151377544 , 3.4641016151377544 ),
170+ Row (" g" , 24.0 , 12.0 , 12.0 , 3.4641016151377544 , 3.4641016151377544 ),
171+ Row (" h" , 24.0 , 12.0 , 12.0 , 3.4641016151377544 , 3.4641016151377544 ),
171172 Row (" i" , Double .NaN , Double .NaN , Double .NaN , Double .NaN , Double .NaN )))
172173 }
173174
@@ -326,7 +327,7 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
326327 var_samp($" value" ).over(window),
327328 approx_count_distinct($" value" ).over(window)),
328329 Seq .fill(4 )(Row (" a" , 1.0d / 4.0d , 1.0d / 3.0d , 2 ))
329- ++ Seq .fill(3 )(Row (" b" , 2.0d / 3.0d , 1.0d , 3 )))
330+ ++ Seq .fill(3 )(Row (" b" , 2.0d / 3.0d , 1.0d , 3 )))
330331 }
331332
332333 test(" window function with aggregates" ) {
@@ -624,7 +625,7 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
624625
625626 test(" SPARK-24575: Window functions inside WHERE and HAVING clauses" ) {
626627 def checkAnalysisError (df : => DataFrame ): Unit = {
627- val thrownException = the [AnalysisException ] thrownBy {
628+ val thrownException = the[AnalysisException ] thrownBy {
628629 df.queryExecution.analyzed
629630 }
630631 assert(thrownException.message.contains(" window functions inside WHERE and HAVING clauses" ))
@@ -658,4 +659,26 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
658659 |GROUP BY a
659660 |HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1 """ .stripMargin))
660661 }
662+
// Checks two window sums over different partitionings computed in two chained selects:
// the window column added by the first select is carried through the second select,
// which adds another window column of its own.
663+ test(" window functions in multiple selects" ) {
// Four input rows, exposed as columns sno, pno and qty.
664+ val df = Seq (
665+ (" S1" , " P1" , 100 ),
666+ (" S1" , " P1" , 700 ),
667+ (" S2" , " P1" , 200 ),
668+ (" S2" , " P2" , 300 )
669+ ).toDF(" sno" , " pno" , " qty" )
670+
// w1 partitions by sno alone; w2 partitions by the finer (sno, pno) pair.
671+ val w1 = Window .partitionBy(" sno" )
672+ val w2 = Window .partitionBy(" sno" , " pno" )
673+
674+ checkAnswer(
// First select adds sum(qty) over w2 as sum_qty_2; the second select keeps that
// column and adds sum(qty) over w1 as sum_qty_1.
675+ df.select($" sno" , $" pno" , $" qty" , sum($" qty" ).over(w2).alias(" sum_qty_2" ))
676+ .select($" sno" , $" pno" , $" qty" , col(" sum_qty_2" ), sum(" qty" ).over(w1).alias(" sum_qty_1" )),
// Expected column order: sno, pno, qty, sum_qty_2, sum_qty_1.
// Both S1 rows share (S1, P1), so both sums are 100 + 700 = 800. The S2 rows differ in
// pno, so sum_qty_2 equals each row's own qty while sum_qty_1 is 200 + 300 = 500.
677+ Seq (
678+ Row (" S1" , " P1" , 100 , 800 , 800 ),
679+ Row (" S1" , " P1" , 700 , 800 , 800 ),
680+ Row (" S2" , " P1" , 200 , 200 , 500 ),
681+ Row (" S2" , " P2" , 300 , 300 , 500 )))
682+
683+ }
661684}
0 commit comments