
Commit a9c1189

amandeep-sharma authored and cloud-fan committed
[SPARK-34649][SQL][DOCS] org.apache.spark.sql.DataFrameNaFunctions.replace() fails for column name having a dot
### What changes were proposed in this pull request?
Use resolved attributes instead of data-frame fields for replacing values.

### Why are the changes needed?
`dataframe.na.replace()` does not work for a column having a dot in its name.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
Added unit tests for the same.

Closes #31769 from amandeep-sharma/master.

Authored-by: Amandeep Sharma <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
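To make the scenario concrete, here is a minimal sketch of what this patch fixes (the object name and local session setup are illustrative, not part of the patch):

```scala
// Minimal repro/usage sketch for SPARK-34649 (illustrative, not from the patch).
import org.apache.spark.sql.SparkSession

object ReplaceDottedColumnExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("SPARK-34649 example")
      .getOrCreate()
    import spark.implicits._

    val df = Seq(("n/a", 0), ("abc", 23)).toDF("Col.1", "Col.2")

    // In Spark 3.1 and earlier this failed because the dotted name was treated
    // as a nested field reference; with resolved attributes, the backtick-escaped
    // name resolves to the top-level column and the value is replaced.
    df.na.replace("`Col.1`", Map("n/a" -> "unknown")).show()

    spark.stop()
  }
}
```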
1 parent b5b1985 · commit a9c1189

File tree: 3 files changed (+67, -20 lines)

docs/sql-migration-guide.md

Lines changed: 2 additions & 0 deletions

@@ -66,6 +66,8 @@ license: |
 - In Spark 3.2, the output schema of `SHOW TBLPROPERTIES` becomes `key: string, value: string` whether you specify the table property key or not. In Spark 3.1 and earlier, the output schema of `SHOW TBLPROPERTIES` is `value: string` when you specify the table property key. To restore the old schema with the builtin catalog, you can set `spark.sql.legacy.keepCommandOutputSchema` to `true`.
 
 - In Spark 3.2, we support typed literals in the partition spec of INSERT and ADD/DROP/RENAME PARTITION. For example, `ADD PARTITION(dt = date'2020-01-01')` adds a partition with date value `2020-01-01`. In Spark 3.1 and earlier, the partition value will be parsed as string value `date '2020-01-01'`, which is an illegal date value, and we add a partition with null value at the end.
+
+- In Spark 3.2, `DataFrameNaFunctions.replace()` no longer uses exact string match for the input column names, to match the SQL syntax and support qualified column names. Input column name having a dot in the name (not nested) needs to be escaped with backtick \`. Now, it throws `AnalysisException` if the column is not found in the data frame schema. It also throws `IllegalArgumentException` if the input column name is a nested column. In Spark 3.1 and earlier, it used to ignore invalid input column name and nested column name.
 
 ## Upgrading from Spark SQL 3.0 to 3.1
 
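As a quick illustration of the new migration note above (a sketch, assuming a spark-shell session where `spark` and its implicits are in scope):

```scala
import spark.implicits._

val df = Seq(("n/a", 0), ("abc", 23)).toDF("Col.1", "Col.2")

// A dotted (but not nested) column name must now be backtick-escaped.
df.na.replace("`Col.1`", Map("n/a" -> "unknown"))

// A name that does not resolve now throws AnalysisException instead of
// being silently ignored as in Spark 3.1 and earlier:
// df.na.replace("aa", Map("n/a" -> "unknown"))   // AnalysisException
```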

sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala

Lines changed: 25 additions & 17 deletions

@@ -327,9 +327,9 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    */
   def replace[T](col: String, replacement: Map[T, T]): DataFrame = {
     if (col == "*") {
-      replace0(df.columns, replacement)
+      replace0(df.logicalPlan.output, replacement)
     } else {
-      replace0(Seq(col), replacement)
+      replace(Seq(col), replacement)
     }
   }
 
@@ -352,10 +352,21 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    *
    * @since 1.3.1
    */
-  def replace[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = replace0(cols, replacement)
+  def replace[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = {
+    val attrs = cols.map { colName =>
+      // Check column name exists
+      val attr = df.resolve(colName) match {
+        case a: Attribute => a
+        case _ => throw new UnsupportedOperationException(
+          s"Nested field ${colName} is not supported.")
+      }
+      attr
+    }
+    replace0(attrs, replacement)
+  }
 
-  private def replace0[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = {
-    if (replacement.isEmpty || cols.isEmpty) {
+  private def replace0[T](attrs: Seq[Attribute], replacement: Map[T, T]): DataFrame = {
+    if (replacement.isEmpty || attrs.isEmpty) {
       return df
     }
 
@@ -379,15 +390,13 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
       case _: String => StringType
     }
 
-    val columnEquals = df.sparkSession.sessionState.analyzer.resolver
-    val projections = df.schema.fields.map { f =>
-      val shouldReplace = cols.exists(colName => columnEquals(colName, f.name))
-      if (f.dataType.isInstanceOf[NumericType] && targetColumnType == DoubleType && shouldReplace) {
-        replaceCol(f, replacementMap)
-      } else if (f.dataType == targetColumnType && shouldReplace) {
-        replaceCol(f, replacementMap)
+    val output = df.queryExecution.analyzed.output
+    val projections = output.map { attr =>
+      if (attrs.contains(attr) && (attr.dataType == targetColumnType ||
+        (attr.dataType.isInstanceOf[NumericType] && targetColumnType == DoubleType))) {
+        replaceCol(attr, replacementMap)
       } else {
-        df.col(f.name)
+        Column(attr)
       }
     }
     df.select(projections : _*)
 
@@ -453,13 +462,12 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    *
    * TODO: This can be optimized to use broadcast join when replacementMap is large.
    */
-  private def replaceCol[K, V](col: StructField, replacementMap: Map[K, V]): Column = {
-    val keyExpr = df.col(col.name).expr
-    def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType)
+  private def replaceCol[K, V](attr: Attribute, replacementMap: Map[K, V]): Column = {
+    def buildExpr(v: Any) = Cast(Literal(v), attr.dataType)
     val branches = replacementMap.flatMap { case (source, target) =>
       Seq(Literal(source), buildExpr(target))
     }.toSeq
-    new Column(CaseKeyWhen(keyExpr, branches :+ keyExpr)).as(col.name)
+    new Column(CaseKeyWhen(attr, branches :+ attr)).as(attr.name)
   }
 
   private def convertToDouble(v: Any): Double = v match {
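For intuition, the rewritten `replaceCol()` keys a `CASE KEY WHEN` expression directly on the resolved attribute instead of re-resolving the column name through `df.col()`, which is what broke for dotted names. A rough public-API equivalent of what it builds for a single string mapping (a sketch, not the internal code; runnable in spark-shell):

```scala
// Roughly equivalent to replaceCol(attr, Map("n/a" -> "unknown")) for a
// string column named "Col.1" (public-API sketch, not the internal code).
import org.apache.spark.sql.functions.{col, when}

val key = col("`Col.1`")
val replaced = when(key === "n/a", "unknown").otherwise(key).as("Col.1")
// i.e. CASE WHEN `Col.1` = 'n/a' THEN 'unknown' ELSE `Col.1` END AS `Col.1`
```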

sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala

Lines changed: 40 additions & 3 deletions

@@ -461,28 +461,65 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
       Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil)
   }
 
-  test("SPARK-34417 - test fillMap() for column with a dot in the name") {
+  test("SPARK-34417: test fillMap() for column with a dot in the name") {
     val na = "n/a"
     checkAnswer(
       Seq(("abc", 23L), ("def", 44L), (null, 0L)).toDF("ColWith.Dot", "Col")
         .na.fill(Map("`ColWith.Dot`" -> na)),
       Row("abc", 23) :: Row("def", 44L) :: Row(na, 0L) :: Nil)
   }
 
-  test("SPARK-34417 - test fillMap() for qualified-column with a dot in the name") {
+  test("SPARK-34417: test fillMap() for qualified-column with a dot in the name") {
     val na = "n/a"
     checkAnswer(
       Seq(("abc", 23L), ("def", 44L), (null, 0L)).toDF("ColWith.Dot", "Col").as("testDF")
         .na.fill(Map("testDF.`ColWith.Dot`" -> na)),
       Row("abc", 23) :: Row("def", 44L) :: Row(na, 0L) :: Nil)
   }
 
-  test("SPARK-34417 - test fillMap() for column without a dot in the name" +
+  test("SPARK-34417: test fillMap() for column without a dot in the name" +
     " and dataframe with another column having a dot in the name") {
     val na = "n/a"
     checkAnswer(
       Seq(("abc", 23L), ("def", 44L), (null, 0L)).toDF("Col", "ColWith.Dot")
         .na.fill(Map("Col" -> na)),
       Row("abc", 23) :: Row("def", 44L) :: Row(na, 0L) :: Nil)
   }
+
+  test("SPARK-34649: replace value of a column with dot in the name") {
+    checkAnswer(
+      Seq(("abc", 23), ("def", 44), ("n/a", 0)).toDF("Col.1", "Col.2")
+        .na.replace("`Col.1`", Map("n/a" -> "unknown")),
+      Row("abc", 23) :: Row("def", 44L) :: Row("unknown", 0L) :: Nil)
+  }
+
+  test("SPARK-34649: replace value of a qualified-column with dot in the name") {
+    checkAnswer(
+      Seq(("abc", 23), ("def", 44), ("n/a", 0)).toDF("Col.1", "Col.2").as("testDf")
+        .na.replace("testDf.`Col.1`", Map("n/a" -> "unknown")),
+      Row("abc", 23) :: Row("def", 44L) :: Row("unknown", 0L) :: Nil)
+  }
+
+  test("SPARK-34649: replace value of a dataframe having dot in the all column names") {
+    checkAnswer(
+      Seq(("abc", 23), ("def", 44), ("n/a", 0)).toDF("Col.1", "Col.2")
+        .na.replace("*", Map("n/a" -> "unknown")),
+      Row("abc", 23) :: Row("def", 44L) :: Row("unknown", 0L) :: Nil)
+  }
+
+  test("SPARK-34649: replace value of a column not present in the dataframe") {
+    val df = Seq(("abc", 23), ("def", 44), ("n/a", 0)).toDF("Col.1", "Col.2")
+    val exception = intercept[AnalysisException] {
+      df.na.replace("aa", Map("n/a" -> "unknown"))
+    }
+    assert(exception.getMessage.equals("Cannot resolve column name \"aa\" among (Col.1, Col.2)"))
+  }
+
+  test("SPARK-34649: replace value of a nested column") {
+    val df = createDFWithNestedColumns
+    val exception = intercept[UnsupportedOperationException] {
+      df.na.replace("c1.c1-1", Map("b1" -> "a1"))
+    }
+    assert(exception.getMessage.equals("Nested field c1.c1-1 is not supported."))
+  }
 }
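The `createDFWithNestedColumns` helper used by the last test is defined elsewhere in the suite and does not appear in this diff. A hypothetical stand-in consistent with that test (the schema and data here are assumptions, not the suite's actual helper; `spark` is the SharedSparkSession):

```scala
// Hypothetical stand-in for the suite's createDFWithNestedColumns helper:
// a DataFrame with a struct column c1 that has a nested string field c1-1.
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types.{StringType, StructType}

def createDFWithNestedColumns: DataFrame = {
  val schema = new StructType()
    .add("c1", new StructType()
      .add("c1-1", StringType)
      .add("c1-2", StringType))
  val data = Seq(Row(Row("b1", "b2")), Row(Row(null, "a2")))
  spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
}
```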
