Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions python/pyspark/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1272,7 +1272,7 @@ def fillna(self, value, subset=None):
"""Replace null values, alias for ``na.fill()``.
:func:`DataFrame.fillna` and :func:`DataFrameNaFunctions.fill` are aliases of each other.

:param value: int, long, float, string, or dict.
:param value: int, long, float, string, bool or dict.
Value to replace null values with.
If the value is a dict, then `subset` is ignored and `value` must be a mapping
from column name (string) to replacement value. The replacement value must be
Expand All @@ -1292,6 +1292,15 @@ def fillna(self, value, subset=None):
| 50| 50| null|
+---+------+-----+

>>> df5.na.fill(False).show()
+----+-------+-----+
| age| name| spy|
+----+-------+-----+
| 10| Alice|false|
| 5| Bob|false|
|null|Mallory| true|
+----+-------+-----+

>>> df4.na.fill({'age': 50, 'name': 'unknown'}).show()
+---+------+-------+
|age|height| name|
Expand All @@ -1302,10 +1311,13 @@ def fillna(self, value, subset=None):
| 50| null|unknown|
+---+------+-------+
"""
if not isinstance(value, (float, int, long, basestring, dict)):
raise ValueError("value should be a float, int, long, string, or dict")
if not isinstance(value, (float, int, long, basestring, bool, dict)):
raise ValueError("value should be a float, int, long, string, bool or dict")

# Note that bool is a subclass of int, so isinstance(value, int) is True for bools,
# but we don't want to convert bools to floats

if isinstance(value, (int, long)):
if not isinstance(value, bool) and isinstance(value, (int, long)):
value = float(value)

if isinstance(value, dict):
Expand Down Expand Up @@ -1802,6 +1814,9 @@ def _test():
Row(name='Bob', age=5, height=None),
Row(name='Tom', age=None, height=None),
Row(name=None, age=None, height=None)]).toDF()
globs['df5'] = sc.parallelize([Row(name='Alice', spy=False, age=10),
Row(name='Bob', spy=None, age=5),
Row(name='Mallory', spy=True, age=None)]).toDF()
globs['sdf'] = sc.parallelize([Row(name='Tom', time=1479441846),
Row(name='Bob', time=1479442946)]).toDF()

Expand Down
34 changes: 26 additions & 8 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1697,40 +1697,58 @@ def test_fillna(self):
schema = StructType([
StructField("name", StringType(), True),
StructField("age", IntegerType(), True),
StructField("height", DoubleType(), True)])
StructField("height", DoubleType(), True),
StructField("spy", BooleanType(), True)])

# fillna shouldn't change non-null values
row = self.spark.createDataFrame([(u'Alice', 10, 80.1)], schema).fillna(50).first()
row = self.spark.createDataFrame([(u'Alice', 10, 80.1, True)], schema).fillna(50).first()
self.assertEqual(row.age, 10)

# fillna with int
row = self.spark.createDataFrame([(u'Alice', None, None)], schema).fillna(50).first()
row = self.spark.createDataFrame([(u'Alice', None, None, None)], schema).fillna(50).first()
self.assertEqual(row.age, 50)
self.assertEqual(row.height, 50.0)

# fillna with double
row = self.spark.createDataFrame([(u'Alice', None, None)], schema).fillna(50.1).first()
row = self.spark.createDataFrame(
[(u'Alice', None, None, None)], schema).fillna(50.1).first()
self.assertEqual(row.age, 50)
self.assertEqual(row.height, 50.1)

# fillna with bool
row = self.spark.createDataFrame(
[(u'Alice', None, None, None)], schema).fillna(True).first()
self.assertEqual(row.age, None)
self.assertEqual(row.spy, True)

# fillna with string
row = self.spark.createDataFrame([(None, None, None)], schema).fillna("hello").first()
row = self.spark.createDataFrame([(None, None, None, None)], schema).fillna("hello").first()
self.assertEqual(row.name, u"hello")
self.assertEqual(row.age, None)

# fillna with subset specified for numeric cols
row = self.spark.createDataFrame(
[(None, None, None)], schema).fillna(50, subset=['name', 'age']).first()
[(None, None, None, None)], schema).fillna(50, subset=['name', 'age']).first()
self.assertEqual(row.name, None)
self.assertEqual(row.age, 50)
self.assertEqual(row.height, None)
self.assertEqual(row.spy, None)

# fillna with subset specified for numeric cols
# fillna with subset specified for string cols
row = self.spark.createDataFrame(
[(None, None, None)], schema).fillna("haha", subset=['name', 'age']).first()
[(None, None, None, None)], schema).fillna("haha", subset=['name', 'age']).first()
self.assertEqual(row.name, "haha")
self.assertEqual(row.age, None)
self.assertEqual(row.height, None)
self.assertEqual(row.spy, None)

# fillna with subset specified for bool cols
row = self.spark.createDataFrame(
[(None, None, None, None)], schema).fillna(True, subset=['name', 'spy']).first()
self.assertEqual(row.name, None)
self.assertEqual(row.age, None)
self.assertEqual(row.height, None)
self.assertEqual(row.spy, True)

# fillna with dictionary for boolean types
row = self.spark.createDataFrame([Row(a=None), Row(a=True)]).fillna({"a": True}).first()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,30 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*/
def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value, cols)

/**
 * Returns a new `DataFrame` in which null values found in boolean columns are
 * replaced with `value`. Every column of the underlying `DataFrame` is considered.
 *
 * @since 2.3.0
 */
def fill(value: Boolean): DataFrame = {
  val everyColumn = df.columns
  fill(value, everyColumn)
}

/**
 * (Scala-specific) Returns a new `DataFrame` in which null values in the specified
 * boolean columns are replaced with `value`. Any specified column that is not a
 * boolean column is ignored.
 *
 * @since 2.3.0
 */
def fill(value: Boolean, cols: Seq[String]): DataFrame = {
  fillValue[Boolean](value, cols)
}

/**
 * Returns a new `DataFrame` in which null values in the specified boolean columns
 * are replaced with `value`. Any specified column that is not a boolean column
 * is ignored.
 *
 * @since 2.3.0
 */
def fill(value: Boolean, cols: Array[String]): DataFrame = {
  val columnSeq: Seq[String] = cols.toSeq
  fill(value, columnSeq)
}


/**
* Returns a new `DataFrame` that replaces null values.
*
Expand Down Expand Up @@ -440,8 +464,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {

/**
* Returns a new `DataFrame` that replaces null or NaN values in specified
* numeric, string columns. If a specified column is not a numeric, string column,
* it is ignored.
* numeric, string or boolean columns. If a specified column is not a numeric, string
* or boolean column, it is ignored.
*/
private def fillValue[T](value: T, cols: Seq[String]): DataFrame = {
// the fill[T] which T is Long/Double,
Expand All @@ -452,6 +476,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
val targetType = value match {
case _: Double | _: Long => NumericType
case _: String => StringType
case _: Boolean => BooleanType
case _ => throw new IllegalArgumentException(
s"Unsupported value type ${value.getClass.getName} ($value).")
}
Expand All @@ -461,6 +486,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
val typeMatches = (targetType, f.dataType) match {
case (NumericType, dt) => dt.isInstanceOf[NumericType]
case (StringType, dt) => dt == StringType
case (BooleanType, dt) => dt == BooleanType
}
// Only fill if the column is part of the cols list.
if (typeMatches && cols.exists(col => columnEquals(f.name, col))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,13 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
test("fill") {
val input = createDF()

val boolInput = Seq[(String, java.lang.Boolean)](
("Bob", false),
("Alice", null),
("Mallory", true),
(null, null)
).toDF("name", "spy")

val fillNumeric = input.na.fill(50.6)
checkAnswer(
fillNumeric,
Expand All @@ -124,6 +131,12 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
Row("Nina") :: Row("Amy") :: Row("unknown") :: Nil)
assert(input.na.fill("unknown").columns.toSeq === input.columns.toSeq)

// boolean
checkAnswer(
boolInput.na.fill(true).select("spy"),
Row(false) :: Row(true) :: Row(true) :: Row(true) :: Nil)
assert(boolInput.na.fill(true).columns.toSeq === boolInput.columns.toSeq)

// fill double with subset columns
checkAnswer(
input.na.fill(50.6, "age" :: Nil).select("name", "age"),
Expand All @@ -134,6 +147,14 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
Row("Amy", 50) ::
Row(null, 50) :: Nil)

// fill boolean with subset columns
checkAnswer(
boolInput.na.fill(true, "spy" :: Nil).select("name", "spy"),
Row("Bob", false) ::
Row("Alice", true) ::
Row("Mallory", true) ::
Row(null, true) :: Nil)

// fill string with subset columns
checkAnswer(
Seq[(String, String)]((null, null)).toDF("col1", "col2").na.fill("test", "col1" :: Nil),
Expand Down