Skip to content

Commit 84454d7

Browse files
jiayue-zhanggatorsmile
authored andcommitted
[SPARK-14932][SQL] Allow DataFrame.replace() to replace values with None
## What changes were proposed in this pull request? Currently `df.na.replace("*", Map[String, String]("NULL" -> null))` will produce exception. This PR enables passing null/None as value in the replacement map in DataFrame.replace(). Note that the replacement map keys and values should still be the same type, while the values can have a mix of null/None and that type. This PR enables following operations for example: `df.na.replace("*", Map[String, String]("NULL" -> null))`(scala) `df.na.replace("*", Map[Any, Any](60 -> null, 70 -> 80))`(scala) `df.na.replace('Alice', None)`(python) `df.na.replace([10, 20])`(python, replacing with None is by default) One use case could be: I want to replace all the empty strings with null/None because they were incorrectly generated and then drop all null/None data `df.na.replace("*", Map("" -> null)).na.drop()`(scala) `df.replace(u'', None).dropna()`(python) ## How was this patch tested? Scala unit test. Python doctest and unit test. Author: bravo-zhang <[email protected]> Closes #18820 from bravo-zhang/spark-14932.
1 parent c06f3f5 commit 84454d7

File tree

4 files changed

+113
-37
lines changed

4 files changed

+113
-37
lines changed

python/pyspark/sql/dataframe.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1362,8 +1362,8 @@ def replace(self, to_replace, value=None, subset=None):
13621362
"""Returns a new :class:`DataFrame` replacing a value with another value.
13631363
:func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are
13641364
aliases of each other.
1365-
Values to_replace and value should contain either all numerics, all booleans,
1366-
or all strings. When replacing, the new value will be cast
1365+
Values to_replace and value must have the same type and can only be numerics, booleans,
1366+
or strings. Value can have None. When replacing, the new value will be cast
13671367
to the type of the existing column.
13681368
For numeric replacements all values to be replaced should have unique
13691369
floating point representation. In case of conflicts (for example with `{42: -1, 42.0: 1}`)
@@ -1373,8 +1373,8 @@ def replace(self, to_replace, value=None, subset=None):
13731373
Value to be replaced.
13741374
If the value is a dict, then `value` is ignored and `to_replace` must be a
13751375
mapping between a value and a replacement.
1376-
:param value: int, long, float, string, or list.
1377-
The replacement value must be an int, long, float, or string. If `value` is a
1376+
:param value: bool, int, long, float, string, list or None.
1377+
The replacement value must be a bool, int, long, float, string or None. If `value` is a
13781378
list, `value` should be of the same length and type as `to_replace`.
13791379
If `value` is a scalar and `to_replace` is a sequence, then `value` is
13801380
used as a replacement for each item in `to_replace`.
@@ -1393,6 +1393,16 @@ def replace(self, to_replace, value=None, subset=None):
13931393
|null| null| null|
13941394
+----+------+-----+
13951395
1396+
>>> df4.na.replace('Alice', None).show()
1397+
+----+------+----+
1398+
| age|height|name|
1399+
+----+------+----+
1400+
| 10| 80|null|
1401+
| 5| null| Bob|
1402+
|null| null| Tom|
1403+
|null| null|null|
1404+
+----+------+----+
1405+
13961406
>>> df4.na.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show()
13971407
+----+------+----+
13981408
| age|height|name|
@@ -1425,12 +1435,13 @@ def all_of_(xs):
14251435
valid_types = (bool, float, int, long, basestring, list, tuple)
14261436
if not isinstance(to_replace, valid_types + (dict, )):
14271437
raise ValueError(
1428-
"to_replace should be a float, int, long, string, list, tuple, or dict. "
1438+
"to_replace should be a bool, float, int, long, string, list, tuple, or dict. "
14291439
"Got {0}".format(type(to_replace)))
14301440

1431-
if not isinstance(value, valid_types) and not isinstance(to_replace, dict):
1441+
if not isinstance(value, valid_types) and value is not None \
1442+
and not isinstance(to_replace, dict):
14321443
raise ValueError("If to_replace is not a dict, value should be "
1433-
"a float, int, long, string, list, or tuple. "
1444+
"a bool, float, int, long, string, list, tuple or None. "
14341445
"Got {0}".format(type(value)))
14351446

14361447
if isinstance(to_replace, (list, tuple)) and isinstance(value, (list, tuple)):
@@ -1446,21 +1457,21 @@ def all_of_(xs):
14461457
if isinstance(to_replace, (float, int, long, basestring)):
14471458
to_replace = [to_replace]
14481459

1449-
if isinstance(value, (float, int, long, basestring)):
1450-
value = [value for _ in range(len(to_replace))]
1451-
14521460
if isinstance(to_replace, dict):
14531461
rep_dict = to_replace
14541462
if value is not None:
14551463
warnings.warn("to_replace is a dict and value is not None. value will be ignored.")
14561464
else:
1465+
if isinstance(value, (float, int, long, basestring)) or value is None:
1466+
value = [value for _ in range(len(to_replace))]
14571467
rep_dict = dict(zip(to_replace, value))
14581468

14591469
if isinstance(subset, basestring):
14601470
subset = [subset]
14611471

1462-
# Verify we were not passed in mixed type generics."
1463-
if not any(all_of_type(rep_dict.keys()) and all_of_type(rep_dict.values())
1472+
# Verify we were not passed in mixed type generics.
1473+
if not any(all_of_type(rep_dict.keys())
1474+
and all_of_type(x for x in rep_dict.values() if x is not None)
14641475
for all_of_type in [all_of_bool, all_of_str, all_of_numeric]):
14651476
raise ValueError("Mixed type replacements are not supported")
14661477

python/pyspark/sql/tests.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1964,6 +1964,21 @@ def test_replace(self):
19641964
.replace(False, True).first())
19651965
self.assertTupleEqual(row, (True, True))
19661966

1967+
# replace list while value is not given (default to None)
1968+
row = self.spark.createDataFrame(
1969+
[(u'Alice', 10, 80.0)], schema).replace(["Alice", "Bob"]).first()
1970+
self.assertTupleEqual(row, (None, 10, 80.0))
1971+
1972+
# replace string with None and then drop None rows
1973+
row = self.spark.createDataFrame(
1974+
[(u'Alice', 10, 80.0)], schema).replace(u'Alice', None).dropna()
1975+
self.assertEqual(row.count(), 0)
1976+
1977+
# replace with number and None
1978+
row = self.spark.createDataFrame(
1979+
[(u'Alice', 10, 80.0)], schema).replace([10, 80], [20, None]).first()
1980+
self.assertTupleEqual(row, (u'Alice', 20, None))
1981+
19671982
# should fail if subset is not list, tuple or None
19681983
with self.assertRaises(ValueError):
19691984
self.spark.createDataFrame(

sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -260,9 +260,6 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
260260

261261
/**
262262
* Replaces values matching keys in `replacement` map with the corresponding values.
263-
* Key and value of `replacement` map must have the same type, and
264-
* can only be doubles, strings or booleans.
265-
* If `col` is "*", then the replacement is applied on all string columns or numeric columns.
266263
*
267264
* {{{
268265
* import com.google.common.collect.ImmutableMap;
@@ -277,8 +274,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
277274
* df.na.replace("*", ImmutableMap.of("UNKNOWN", "unnamed"));
278275
* }}}
279276
*
280-
* @param col name of the column to apply the value replacement
281-
* @param replacement value replacement map, as explained above
277+
* @param col name of the column to apply the value replacement. If `col` is "*",
278+
* replacement is applied on all string, numeric or boolean columns.
279+
* @param replacement value replacement map. Key and value of `replacement` map must have
280+
* the same type, and can only be doubles, strings or booleans.
281+
* The map value can have nulls.
282282
*
283283
* @since 1.3.1
284284
*/
@@ -288,8 +288,6 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
288288

289289
/**
290290
* Replaces values matching keys in `replacement` map with the corresponding values.
291-
* Key and value of `replacement` map must have the same type, and
292-
* can only be doubles, strings or booleans.
293291
*
294292
* {{{
295293
* import com.google.common.collect.ImmutableMap;
@@ -301,8 +299,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
301299
* df.na.replace(new String[] {"firstname", "lastname"}, ImmutableMap.of("UNKNOWN", "unnamed"));
302300
* }}}
303301
*
304-
* @param cols list of columns to apply the value replacement
305-
* @param replacement value replacement map, as explained above
302+
* @param cols list of columns to apply the value replacement. If `col` is "*",
303+
* replacement is applied on all string, numeric or boolean columns.
304+
* @param replacement value replacement map. Key and value of `replacement` map must have
305+
* the same type, and can only be doubles, strings or booleans.
306+
* The map value can have nulls.
306307
*
307308
* @since 1.3.1
308309
*/
@@ -312,10 +313,6 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
312313

313314
/**
314315
* (Scala-specific) Replaces values matching keys in `replacement` map.
315-
* Key and value of `replacement` map must have the same type, and
316-
* can only be doubles, strings or booleans.
317-
* If `col` is "*",
318-
* then the replacement is applied on all string columns , numeric columns or boolean columns.
319316
*
320317
* {{{
321318
* // Replaces all occurrences of 1.0 with 2.0 in column "height".
@@ -328,8 +325,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
328325
* df.na.replace("*", Map("UNKNOWN" -> "unnamed"));
329326
* }}}
330327
*
331-
* @param col name of the column to apply the value replacement
332-
* @param replacement value replacement map, as explained above
328+
* @param col name of the column to apply the value replacement. If `col` is "*",
329+
* replacement is applied on all string, numeric or boolean columns.
330+
* @param replacement value replacement map. Key and value of `replacement` map must have
331+
* the same type, and can only be doubles, strings or booleans.
332+
* The map value can have nulls.
333333
*
334334
* @since 1.3.1
335335
*/
@@ -343,8 +343,6 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
343343

344344
/**
345345
* (Scala-specific) Replaces values matching keys in `replacement` map.
346-
* Key and value of `replacement` map must have the same type, and
347-
* can only be doubles , strings or booleans.
348346
*
349347
* {{{
350348
* // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight".
@@ -354,8 +352,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
354352
* df.na.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed"));
355353
* }}}
356354
*
357-
* @param cols list of columns to apply the value replacement
358-
* @param replacement value replacement map, as explained above
355+
* @param cols list of columns to apply the value replacement. If `col` is "*",
356+
* replacement is applied on all string, numeric or boolean columns.
357+
* @param replacement value replacement map. Key and value of `replacement` map must have
358+
* the same type, and can only be doubles, strings or booleans.
359+
* The map value can have nulls.
359360
*
360361
* @since 1.3.1
361362
*/
@@ -366,14 +367,20 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
366367
return df
367368
}
368369

369-
// replacementMap is either Map[String, String] or Map[Double, Double] or Map[Boolean,Boolean]
370-
val replacementMap: Map[_, _] = replacement.head._2 match {
371-
case v: String => replacement
372-
case v: Boolean => replacement
373-
case _ => replacement.map { case (k, v) => (convertToDouble(k), convertToDouble(v)) }
370+
// Convert the NumericType in replacement map to DoubleType,
371+
// while leaving StringType, BooleanType and null untouched.
372+
val replacementMap: Map[_, _] = replacement.map {
373+
case (k, v: String) => (k, v)
374+
case (k, v: Boolean) => (k, v)
375+
case (k: String, null) => (k, null)
376+
case (k: Boolean, null) => (k, null)
377+
case (k, null) => (convertToDouble(k), null)
378+
case (k, v) => (convertToDouble(k), convertToDouble(v))
374379
}
375380

376-
// targetColumnType is either DoubleType or StringType or BooleanType
381+
// targetColumnType is either DoubleType, StringType or BooleanType,
382+
// depending on the type of first key in replacement map.
383+
// Only fields of targetColumnType will perform replacement.
377384
val targetColumnType = replacement.head._1 match {
378385
case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long => DoubleType
379386
case _: jl.Boolean => BooleanType

sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,4 +262,47 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
262262
assert(out1(4) === Row("Amy", null, null))
263263
assert(out1(5) === Row(null, null, null))
264264
}
265+
266+
test("replace with null") {
267+
val input = Seq[(String, java.lang.Double, java.lang.Boolean)](
268+
("Bob", 176.5, true),
269+
("Alice", 164.3, false),
270+
("David", null, true)
271+
).toDF("name", "height", "married")
272+
273+
// Replace String with String and null
274+
checkAnswer(
275+
input.na.replace("name", Map(
276+
"Bob" -> "Bravo",
277+
"Alice" -> null
278+
)),
279+
Row("Bravo", 176.5, true) ::
280+
Row(null, 164.3, false) ::
281+
Row("David", null, true) :: Nil)
282+
283+
// Replace Double with null
284+
checkAnswer(
285+
input.na.replace("height", Map[Any, Any](
286+
164.3 -> null
287+
)),
288+
Row("Bob", 176.5, true) ::
289+
Row("Alice", null, false) ::
290+
Row("David", null, true) :: Nil)
291+
292+
// Replace Boolean with null
293+
checkAnswer(
294+
input.na.replace("*", Map[Any, Any](
295+
false -> null
296+
)),
297+
Row("Bob", 176.5, true) ::
298+
Row("Alice", 164.3, null) ::
299+
Row("David", null, true) :: Nil)
300+
301+
// Replace String with null and then drop rows containing null
302+
checkAnswer(
303+
input.na.replace("name", Map(
304+
"Bob" -> null
305+
)).na.drop("name" :: Nil).select("name"),
306+
Row("Alice") :: Row("David") :: Nil)
307+
}
265308
}

0 commit comments

Comments
 (0)