Skip to content

Commit 6cbc61d

Browse files
rberenguel authored and ueshin committed
[SPARK-19732][SQL][PYSPARK] Add fill functions for nulls in bool fields of datasets
## What changes were proposed in this pull request?

Allow fill/replace of NAs with booleans, both in Python and Scala.

## How was this patch tested?

Unit tests, doctests.

This PR is original work from me and I license this work to the Spark project.

Author: Ruben Berenguel Montoro <[email protected]>
Author: Ruben Berenguel <[email protected]>

Closes #18164 from rberenguel/SPARK-19732-fillna-bools.
1 parent 864d94f commit 6cbc61d

File tree

4 files changed

+94
-14
lines changed

4 files changed

+94
-14
lines changed

python/pyspark/sql/dataframe.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,7 +1289,7 @@ def fillna(self, value, subset=None):
12891289
"""Replace null values, alias for ``na.fill()``.
12901290
:func:`DataFrame.fillna` and :func:`DataFrameNaFunctions.fill` are aliases of each other.
12911291
1292-
:param value: int, long, float, string, or dict.
1292+
:param value: int, long, float, string, bool or dict.
12931293
Value to replace null values with.
12941294
If the value is a dict, then `subset` is ignored and `value` must be a mapping
12951295
from column name (string) to replacement value. The replacement value must be
@@ -1309,6 +1309,15 @@ def fillna(self, value, subset=None):
13091309
| 50| 50| null|
13101310
+---+------+-----+
13111311
1312+
>>> df5.na.fill(False).show()
1313+
+----+-------+-----+
1314+
| age| name| spy|
1315+
+----+-------+-----+
1316+
| 10| Alice|false|
1317+
| 5| Bob|false|
1318+
|null|Mallory| true|
1319+
+----+-------+-----+
1320+
13121321
>>> df4.na.fill({'age': 50, 'name': 'unknown'}).show()
13131322
+---+------+-------+
13141323
|age|height| name|
@@ -1319,10 +1328,13 @@ def fillna(self, value, subset=None):
13191328
| 50| null|unknown|
13201329
+---+------+-------+
13211330
"""
1322-
if not isinstance(value, (float, int, long, basestring, dict)):
1323-
raise ValueError("value should be a float, int, long, string, or dict")
1331+
if not isinstance(value, (float, int, long, basestring, bool, dict)):
1332+
raise ValueError("value should be a float, int, long, string, bool or dict")
1333+
1334+
# Note that bool is a subclass of int, so isinstance(True, int) holds, but we don't want to
1335+
# convert bools to floats
13241336

1325-
if isinstance(value, (int, long)):
1337+
if not isinstance(value, bool) and isinstance(value, (int, long)):
13261338
value = float(value)
13271339

13281340
if isinstance(value, dict):
@@ -1819,6 +1831,9 @@ def _test():
18191831
Row(name='Bob', age=5, height=None),
18201832
Row(name='Tom', age=None, height=None),
18211833
Row(name=None, age=None, height=None)]).toDF()
1834+
globs['df5'] = sc.parallelize([Row(name='Alice', spy=False, age=10),
1835+
Row(name='Bob', spy=None, age=5),
1836+
Row(name='Mallory', spy=True, age=None)]).toDF()
18221837
globs['sdf'] = sc.parallelize([Row(name='Tom', time=1479441846),
18231838
Row(name='Bob', time=1479442946)]).toDF()
18241839

python/pyspark/sql/tests.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1697,40 +1697,58 @@ def test_fillna(self):
16971697
schema = StructType([
16981698
StructField("name", StringType(), True),
16991699
StructField("age", IntegerType(), True),
1700-
StructField("height", DoubleType(), True)])
1700+
StructField("height", DoubleType(), True),
1701+
StructField("spy", BooleanType(), True)])
17011702

17021703
# fillna shouldn't change non-null values
1703-
row = self.spark.createDataFrame([(u'Alice', 10, 80.1)], schema).fillna(50).first()
1704+
row = self.spark.createDataFrame([(u'Alice', 10, 80.1, True)], schema).fillna(50).first()
17041705
self.assertEqual(row.age, 10)
17051706

17061707
# fillna with int
1707-
row = self.spark.createDataFrame([(u'Alice', None, None)], schema).fillna(50).first()
1708+
row = self.spark.createDataFrame([(u'Alice', None, None, None)], schema).fillna(50).first()
17081709
self.assertEqual(row.age, 50)
17091710
self.assertEqual(row.height, 50.0)
17101711

17111712
# fillna with double
1712-
row = self.spark.createDataFrame([(u'Alice', None, None)], schema).fillna(50.1).first()
1713+
row = self.spark.createDataFrame(
1714+
[(u'Alice', None, None, None)], schema).fillna(50.1).first()
17131715
self.assertEqual(row.age, 50)
17141716
self.assertEqual(row.height, 50.1)
17151717

1718+
# fillna with bool
1719+
row = self.spark.createDataFrame(
1720+
[(u'Alice', None, None, None)], schema).fillna(True).first()
1721+
self.assertEqual(row.age, None)
1722+
self.assertEqual(row.spy, True)
1723+
17161724
# fillna with string
1717-
row = self.spark.createDataFrame([(None, None, None)], schema).fillna("hello").first()
1725+
row = self.spark.createDataFrame([(None, None, None, None)], schema).fillna("hello").first()
17181726
self.assertEqual(row.name, u"hello")
17191727
self.assertEqual(row.age, None)
17201728

17211729
# fillna with subset specified for numeric cols
17221730
row = self.spark.createDataFrame(
1723-
[(None, None, None)], schema).fillna(50, subset=['name', 'age']).first()
1731+
[(None, None, None, None)], schema).fillna(50, subset=['name', 'age']).first()
17241732
self.assertEqual(row.name, None)
17251733
self.assertEqual(row.age, 50)
17261734
self.assertEqual(row.height, None)
1735+
self.assertEqual(row.spy, None)
17271736

1728-
# fillna with subset specified for numeric cols
1737+
# fillna with subset specified for string cols
17291738
row = self.spark.createDataFrame(
1730-
[(None, None, None)], schema).fillna("haha", subset=['name', 'age']).first()
1739+
[(None, None, None, None)], schema).fillna("haha", subset=['name', 'age']).first()
17311740
self.assertEqual(row.name, "haha")
17321741
self.assertEqual(row.age, None)
17331742
self.assertEqual(row.height, None)
1743+
self.assertEqual(row.spy, None)
1744+
1745+
# fillna with subset specified for bool cols
1746+
row = self.spark.createDataFrame(
1747+
[(None, None, None, None)], schema).fillna(True, subset=['name', 'spy']).first()
1748+
self.assertEqual(row.name, None)
1749+
self.assertEqual(row.age, None)
1750+
self.assertEqual(row.height, None)
1751+
self.assertEqual(row.spy, True)
17341752

17351753
# fillna with dictionary for boolean types
17361754
row = self.spark.createDataFrame([Row(a=None), Row(a=True)]).fillna({"a": True}).first()

sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,30 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
195195
*/
196196
def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value, cols)
197197

198+
/**
199+
* Returns a new `DataFrame` that replaces null values in boolean columns with `value`.
200+
*
201+
* @since 2.3.0
202+
*/
203+
def fill(value: Boolean): DataFrame = fill(value, df.columns)
204+
205+
/**
206+
* (Scala-specific) Returns a new `DataFrame` that replaces null values in specified
207+
* boolean columns. If a specified column is not a boolean column, it is ignored.
208+
*
209+
* @since 2.3.0
210+
*/
211+
def fill(value: Boolean, cols: Seq[String]): DataFrame = fillValue(value, cols)
212+
213+
/**
214+
* Returns a new `DataFrame` that replaces null values in specified boolean columns.
215+
* If a specified column is not a boolean column, it is ignored.
216+
*
217+
* @since 2.3.0
218+
*/
219+
def fill(value: Boolean, cols: Array[String]): DataFrame = fill(value, cols.toSeq)
220+
221+
198222
/**
199223
* Returns a new `DataFrame` that replaces null values.
200224
*
@@ -440,8 +464,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
440464

441465
/**
442466
* Returns a new `DataFrame` that replaces null or NaN values in specified
443-
* numeric, string columns. If a specified column is not a numeric, string column,
444-
* it is ignored.
467+
* numeric, string or boolean columns. If a specified column is not a numeric, string
468+
* or boolean column it is ignored.
445469
*/
446470
private def fillValue[T](value: T, cols: Seq[String]): DataFrame = {
447471
// the fill[T] which T is Long/Double,
@@ -452,6 +476,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
452476
val targetType = value match {
453477
case _: Double | _: Long => NumericType
454478
case _: String => StringType
479+
case _: Boolean => BooleanType
455480
case _ => throw new IllegalArgumentException(
456481
s"Unsupported value type ${value.getClass.getName} ($value).")
457482
}
@@ -461,6 +486,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
461486
val typeMatches = (targetType, f.dataType) match {
462487
case (NumericType, dt) => dt.isInstanceOf[NumericType]
463488
case (StringType, dt) => dt == StringType
489+
case (BooleanType, dt) => dt == BooleanType
464490
}
465491
// Only fill if the column is part of the cols list.
466492
if (typeMatches && cols.exists(col => columnEquals(f.name, col))) {

sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,13 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
104104
test("fill") {
105105
val input = createDF()
106106

107+
val boolInput = Seq[(String, java.lang.Boolean)](
108+
("Bob", false),
109+
("Alice", null),
110+
("Mallory", true),
111+
(null, null)
112+
).toDF("name", "spy")
113+
107114
val fillNumeric = input.na.fill(50.6)
108115
checkAnswer(
109116
fillNumeric,
@@ -124,6 +131,12 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
124131
Row("Nina") :: Row("Amy") :: Row("unknown") :: Nil)
125132
assert(input.na.fill("unknown").columns.toSeq === input.columns.toSeq)
126133

134+
// boolean
135+
checkAnswer(
136+
boolInput.na.fill(true).select("spy"),
137+
Row(false) :: Row(true) :: Row(true) :: Row(true) :: Nil)
138+
assert(boolInput.na.fill(true).columns.toSeq === boolInput.columns.toSeq)
139+
127140
// fill double with subset columns
128141
checkAnswer(
129142
input.na.fill(50.6, "age" :: Nil).select("name", "age"),
@@ -134,6 +147,14 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
134147
Row("Amy", 50) ::
135148
Row(null, 50) :: Nil)
136149

150+
// fill boolean with subset columns
151+
checkAnswer(
152+
boolInput.na.fill(true, "spy" :: Nil).select("name", "spy"),
153+
Row("Bob", false) ::
154+
Row("Alice", true) ::
155+
Row("Mallory", true) ::
156+
Row(null, true) :: Nil)
157+
137158
// fill string with subset columns
138159
checkAnswer(
139160
Seq[(String, String)]((null, null)).toDF("col1", "col2").na.fill("test", "col1" :: Nil),

0 commit comments

Comments
 (0)