From 36028079d6bdd413a306c69abdb7af515e78cc1e Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Thu, 12 Sep 2019 11:18:07 +0800 Subject: [PATCH 1/3] Modify fillValue approach to support joined dataframe --- .../spark/sql/DataFrameNaFunctions.scala | 12 +++++------ .../spark/sql/DataFrameNaFunctionsSuite.scala | 20 +++++++++++++++++++ 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 53e9f810d7c85..f9e7abd0ae47b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -488,7 +488,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { } val columnEquals = df.sparkSession.sessionState.analyzer.resolver - val projections = df.schema.fields.map { f => + val fillColumnsInfo = df.schema.fields.filter { f => val typeMatches = (targetType, f.dataType) match { case (NumericType, dt) => dt.isInstanceOf[NumericType] case (StringType, dt) => dt == StringType @@ -497,12 +497,10 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { throw new IllegalArgumentException(s"$targetType is not matched at fillValue") } // Only fill if the column is part of the cols list. - if (typeMatches && cols.exists(col => columnEquals(f.name, col))) { - fillCol[T](f, value) - } else { - df.col(f.name) - } + typeMatches && cols.exists(col => columnEquals(f.name, col)) + }.map { col => + (col.name, fillCol[T](col, value)) } - df.select(projections : _*) + df.withColumns(fillColumnsInfo.map(_._1), fillColumnsInfo.map(_._2)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index aeee4577d3483..534c290643e5c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -231,6 +231,26 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { } } + test("fill for join operation") { + val df1 = Seq( + ("f1-1", "f2", null), + ("f1-2", null, null), + ("f1-3", "f2", "f3-1"), + ("f1-4", "f2", "f3-1") + ).toDF("f1", "f2", "f3") + val df2 = Seq( + ("f1-1", null, null), + ("f1-2", "f2", null), + ("f1-3", "f2", "f4-1") + ).toDF("f1", "f2", "f4") + val joined_df = df1.join(df2, Seq("f1"), joinType = "left_outer") + checkAnswer(joined_df.na.fill("", cols = Seq("f4")), + Row("f1-1", "f2", null, null, "") :: + Row("f1-2", null, null, "f2", "") :: + Row("f1-3", "f2", "f3-1", "f2", "f4-1") :: + Row("f1-4", "f2", "f3-1", null, "") :: Nil) + } + test("replace") { val input = createDF() From 03305bea4e8663f7bcd12963b489b1e8755c382f Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Tue, 17 Sep 2019 20:44:11 +0800 Subject: [PATCH 2/3] Add test for checking ambiguous field --- .../spark/sql/DataFrameNaFunctionsSuite.scala | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 534c290643e5c..75642a0bd9325 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -231,7 +231,7 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { } } - test("fill for join operation") { + def createDFsWithSameFieldsName(): (DataFrame, DataFrame) = { val df1 = Seq( ("f1-1", "f2", null), ("f1-2", null, null), @@ -243,6 +243,11 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { ("f1-2", "f2", null), ("f1-3", "f2", "f4-1") ).toDF("f1", "f2", "f4") + (df1, df2) + } + + test("fill unambiguous field for join operation") { + val (df1, df2) = createDFsWithSameFieldsName() val joined_df = df1.join(df2, Seq("f1"), joinType = "left_outer") checkAnswer(joined_df.na.fill("", cols = Seq("f4")), Row("f1-1", "f2", null, null, "") :: @@ -251,6 +256,16 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { Row("f1-4", "f2", "f3-1", null, "") :: Nil) } + test("fill ambiguous field for join operation") { + val (df1, df2) = createDFsWithSameFieldsName() + val joined_df = df1.join(df2, Seq("f1"), joinType = "left_outer") + + val message = intercept[AnalysisException] { + joined_df.na.fill("", cols = Seq("f2")) + }.getMessage + assert(message.contains("Reference 'f2' is ambiguous")) + } + test("replace") { val input = createDF() From 59106dce05b97125c064950c67763ece4e3e19b6 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Fri, 20 Sep 2019 10:10:26 -0700 Subject: [PATCH 3/3] simplify --- .../scala/org/apache/spark/sql/DataFrameNaFunctions.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index f9e7abd0ae47b..6dd21f114c902 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -488,7 +488,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { } val columnEquals = df.sparkSession.sessionState.analyzer.resolver - val fillColumnsInfo = df.schema.fields.filter { f => + val filledColumns = df.schema.fields.filter { f => val typeMatches = (targetType, f.dataType) match { case (NumericType, dt) => dt.isInstanceOf[NumericType] case (StringType, dt) => dt == StringType @@ -498,9 +498,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { } // Only fill if the column is part of the cols list. typeMatches && cols.exists(col => columnEquals(f.name, col)) - }.map { col => - (col.name, fillCol[T](col, value)) } - df.withColumns(fillColumnsInfo.map(_._1), fillColumnsInfo.map(_._2)) + df.withColumns(filledColumns.map(_.name), filledColumns.map(fillCol[T](_, value))) } }