diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index 53e9f810d7c85..6dd21f114c902 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -488,7 +488,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
     }
 
     val columnEquals = df.sparkSession.sessionState.analyzer.resolver
-    val projections = df.schema.fields.map { f =>
+    val filledColumns = df.schema.fields.filter { f =>
       val typeMatches = (targetType, f.dataType) match {
         case (NumericType, dt) => dt.isInstanceOf[NumericType]
         case (StringType, dt) => dt == StringType
@@ -497,12 +497,9 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
           throw new IllegalArgumentException(s"$targetType is not matched at fillValue")
       }
       // Only fill if the column is part of the cols list.
-      if (typeMatches && cols.exists(col => columnEquals(f.name, col))) {
-        fillCol[T](f, value)
-      } else {
-        df.col(f.name)
-      }
+      typeMatches && cols.exists(col => columnEquals(f.name, col))
     }
-    df.select(projections : _*)
+    // Hand only the matching columns to withColumns rather than re-projecting every field.
+    df.withColumns(filledColumns.map(_.name), filledColumns.map(fillCol[T](_, value)))
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
index aeee4577d3483..75642a0bd9325 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
@@ -231,6 +231,47 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
     }
   }
 
+  /**
+   * Builds two DataFrames of strings that both have columns named "f1" and "f2", so a join
+   * on "f1" yields a result containing two columns named "f2".
+   */
+  def createDFsWithSameFieldsName(): (DataFrame, DataFrame) = {
+    val df1 = Seq(
+      ("f1-1", "f2", null),
+      ("f1-2", null, null),
+      ("f1-3", "f2", "f3-1"),
+      ("f1-4", "f2", "f3-1")
+    ).toDF("f1", "f2", "f3")
+    val df2 = Seq(
+      ("f1-1", null, null),
+      ("f1-2", "f2", null),
+      ("f1-3", "f2", "f4-1")
+    ).toDF("f1", "f2", "f4")
+    (df1, df2)
+  }
+
+  test("fill unambiguous field for join operation") {
+    val (df1, df2) = createDFsWithSameFieldsName()
+    val joinedDF = df1.join(df2, Seq("f1"), joinType = "left_outer")
+    // "f4" comes only from df2, so filling it is unambiguous; both "f2" columns keep nulls.
+    checkAnswer(joinedDF.na.fill("", cols = Seq("f4")),
+      Row("f1-1", "f2", null, null, "") ::
+      Row("f1-2", null, null, "f2", "") ::
+      Row("f1-3", "f2", "f3-1", "f2", "f4-1") ::
+      Row("f1-4", "f2", "f3-1", null, "") :: Nil)
+  }
+
+  test("fill ambiguous field for join operation") {
+    val (df1, df2) = createDFsWithSameFieldsName()
+    val joinedDF = df1.join(df2, Seq("f1"), joinType = "left_outer")
+
+    // "f2" comes from both sides of the join, so filling it by name must be rejected.
+    val message = intercept[AnalysisException] {
+      joinedDF.na.fill("", cols = Seq("f2"))
+    }.getMessage
+    assert(message.contains("Reference 'f2' is ambiguous"))
+  }
+
   test("replace") {
     val input = createDF()
 