Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,12 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
/**
 * Returns a new `DataFrame` in which null or NaN values of numeric columns are replaced
 * with `value`. All columns of the underlying `DataFrame` (`df.columns`) are passed on to
 * the two-argument `fill`.
 *
 * @since 2.2.0
 */
def fill(value: Long): DataFrame = {
  val everyColumn = df.columns
  fill(value, everyColumn)
}

/**
 * Returns a new `DataFrame` in which null or NaN values of numeric columns are replaced
 * with `value`. All columns of the underlying `DataFrame` (`df.columns`) are passed on to
 * the two-argument `fill`.
 *
 * @since 1.3.1
 */
def fill(value: Double): DataFrame = {
  val everyColumn = df.columns
  fill(value, everyColumn)
}
Expand All @@ -139,6 +145,14 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*/
def fill(value: String): DataFrame = {
  // No column restriction: hand every column name of `df` to the two-argument overload.
  val everyColumn = df.columns
  fill(value, everyColumn)
}

/**
 * Returns a new `DataFrame` that replaces null or NaN values in the specified numeric
 * columns. A specified column that is not numeric is ignored.
 *
 * The array of names is converted to a `Seq` and forwarded to the `Seq`-based overload.
 *
 * @since 2.2.0
 */
def fill(value: Long, cols: Array[String]): DataFrame = {
  val columnNames: Seq[String] = cols.toSeq
  fill(value, columnNames)
}

/**
* Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns.
* If a specified column is not a numeric column, it is ignored.
Expand All @@ -147,24 +161,22 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*/
def fill(value: Double, cols: Array[String]): DataFrame = {
  // Convert the array of names to a Seq and forward to the Seq-based overload.
  val columnNames: Seq[String] = cols.toSeq
  fill(value, columnNames)
}

/**
 * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in the
 * specified numeric columns. A specified column that is not numeric is ignored.
 *
 * The actual work is done by the private `fillValue` helper.
 *
 * @since 2.2.0
 */
def fill(value: Long, cols: Seq[String]): DataFrame = {
  fillValue[Long](value, cols)
}

/**
 * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in the
 * specified numeric columns. A specified column that is not numeric is ignored.
 *
 * @since 1.3.1
 */
def fill(value: Double, cols: Seq[String]): DataFrame = {
  val resolver = df.sparkSession.sessionState.analyzer.resolver

  // True when `name` matches one of the requested column names under the session resolver.
  def requested(name: String): Boolean = cols.exists(col => resolver(name, col))

  val selected = df.schema.fields.map {
    // Fill only numeric fields that were explicitly listed in `cols`.
    case field if field.dataType.isInstanceOf[NumericType] && requested(field.name) =>
      fillCol[Double](field, value)
    // Every other field is carried over unchanged.
    case field =>
      df.col(field.name)
  }
  df.select(selected : _*)
}
def fill(value: Double, cols: Seq[String]): DataFrame = {
  // Delegate to the shared helper, with the value type pinned to Double.
  fillValue[Double](value, cols)
}


/**
* Returns a new `DataFrame` that replaces null values in specified string columns.
Expand All @@ -180,18 +192,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*
* @since 1.3.1
*/
def fill(value: String, cols: Seq[String]): DataFrame = {
  val resolver = df.sparkSession.sessionState.analyzer.resolver
  val selected = for (field <- df.schema.fields) yield {
    // A field is filled only when it is a string column that was named in `cols`.
    val fillable =
      field.dataType.isInstanceOf[StringType] && cols.exists(resolver(field.name, _))
    if (fillable) fillCol[String](field, value) else df.col(field.name)
  }
  df.select(selected : _*)
}
def fill(value: String, cols: Seq[String]): DataFrame = {
  // Delegate to the shared helper, with the value type pinned to String.
  fillValue[String](value, cols)
}

/**
* Returns a new `DataFrame` that replaces null values.
Expand All @@ -210,7 +211,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*
* @since 1.3.1
*/
def fill(valueMap: java.util.Map[String, Any]): DataFrame = {
  // View the Java map as a Scala collection and flatten it to (key, value) pairs.
  val replacements: Seq[(String, Any)] = valueMap.asScala.toSeq
  fill0(replacements)
}
def fill(valueMap: java.util.Map[String, Any]): DataFrame = {
  // View the Java map as a Scala collection and flatten it to (key, value) pairs.
  val replacements: Seq[(String, Any)] = valueMap.asScala.toSeq
  fillMap(replacements)
}

/**
* (Scala-specific) Returns a new `DataFrame` that replaces null values.
Expand All @@ -230,7 +231,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*
* @since 1.3.1
*/
def fill(valueMap: Map[String, Any]): DataFrame = {
  // Flatten the map to (key, value) pairs before handing it to the shared implementation.
  val replacements: Seq[(String, Any)] = valueMap.toSeq
  fill0(replacements)
}
def fill(valueMap: Map[String, Any]): DataFrame = {
  // Flatten the map to (key, value) pairs before handing it to the shared implementation.
  val replacements: Seq[(String, Any)] = valueMap.toSeq
  fillMap(replacements)
}

/**
* Replaces values matching keys in `replacement` map with the corresponding values.
Expand Down Expand Up @@ -368,7 +369,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
df.select(projections : _*)
}

private def fill0(values: Seq[(String, Any)]): DataFrame = {
private def fillMap(values: Seq[(String, Any)]): DataFrame = {
// Error handling
values.foreach { case (colName, replaceValue) =>
// Check column name exists
Expand Down Expand Up @@ -435,4 +436,38 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
case v => throw new IllegalArgumentException(
s"Unsupported value type ${v.getClass.getName} ($v).")
}

/**
 * Returns a new `DataFrame` that replaces null or NaN values in the specified numeric or
 * string columns with `value`. A specified column whose type does not match the kind of
 * `value` is ignored and carried over unchanged.
 *
 * @param value replacement value; a `Double` or `Long` targets numeric columns, a `String`
 *              targets string columns
 * @param cols  names of the candidate columns, matched against the schema with the
 *              session's resolver
 * @throws IllegalArgumentException if `value` is null or of any other type
 */
private def fillValue[T](value: T, cols: Seq[String]): DataFrame = {
  // A fill[T] where T is Long/Double applies to every NumericType column, not only to
  // columns whose type is exactly T. For example:
  //   val input = Seq[(java.lang.Integer, java.lang.Double)]((null, 164.3)).toDF("a","b")
  //   input.na.fill(3.1)
  // yields (3, 164.3), not (null, 164.3).
  val targetType = value match {
    case _: Double | _: Long => NumericType
    case _: String => StringType
    case _ =>
      // Type patterns never match null, so a null `value` ends up here. Look the class
      // name up null-safely; otherwise building the message would itself throw a
      // NullPointerException instead of the intended IllegalArgumentException.
      val typeName = Option(value).map(_.getClass.getName).getOrElse("null")
      throw new IllegalArgumentException(
        s"Unsupported value type $typeName ($value).")
  }

  val columnEquals = df.sparkSession.sessionState.analyzer.resolver
  val projections = df.schema.fields.map { f =>
    val typeMatches = (targetType, f.dataType) match {
      case (NumericType, dt) => dt.isInstanceOf[NumericType]
      case (StringType, dt) => dt == StringType
    }
    // Only fill if the column has a matching type and is part of the cols list.
    if (typeMatches && cols.exists(col => columnEquals(f.name, col))) {
      fillCol[T](f, value)
    } else {
      df.col(f.name)
    }
  }
  df.select(projections : _*)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,24 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
checkAnswer(
Seq[(String, String)]((null, null)).toDF("col1", "col2").na.fill("test", "col1" :: Nil),
Row("test", null))

checkAnswer(
Seq[(Long, Long)]((1, 2), (-1, -2), (9123146099426677101L, 9123146560113991650L))
.toDF("a", "b").na.fill(0),
Row(1, 2) :: Row(-1, -2) :: Row(9123146099426677101L, 9123146560113991650L) :: Nil
)

checkAnswer(
Seq[(java.lang.Long, java.lang.Double)]((null, 1.23), (3L, null), (4L, 3.45))
.toDF("a", "b").na.fill(2.34),
Row(2, 1.23) :: Row(3, 2.34) :: Row(4, 3.45) :: Nil
)

checkAnswer(
Seq[(java.lang.Long, java.lang.Double)]((null, 1.23), (3L, null), (4L, 3.45))
.toDF("a", "b").na.fill(5),
Row(5, 1.23) :: Row(3, 5.0) :: Row(4, 3.45) :: Nil
)
}

test("fill with map") {
Expand Down