From 8de494e63ac02acbc9e5ae494a2641ff1d7eaa71 Mon Sep 17 00:00:00 2001 From: John Ayad Date: Mon, 2 Dec 2019 15:35:53 +0000 Subject: [PATCH 01/12] Do not cast NaN to an Integer, Long, Short or Byte --- .../spark/sql/catalyst/expressions/Cast.scala | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index a871a746d64ff..1471fba55ee91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -492,9 +492,9 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit case TimestampType => buildCast[Long](_, t => timestampToLong(t)) case x: NumericType if ansiEnabled => - b => x.exactNumeric.asInstanceOf[Numeric[Any]].toLong(b) + b => if (checkIfNaN(b)) null else x.exactNumeric.asInstanceOf[Numeric[Any]].toLong(b) case x: NumericType => - b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b) + b => if (checkIfNaN(b)) null else x.numeric.asInstanceOf[Numeric[Any]].toLong(b) } // IntConverter @@ -511,9 +511,9 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit case TimestampType => buildCast[Long](_, t => timestampToLong(t).toInt) case x: NumericType if ansiEnabled => - b => x.exactNumeric.asInstanceOf[Numeric[Any]].toInt(b) + b => if (checkIfNaN(b)) null else x.exactNumeric.asInstanceOf[Numeric[Any]].toInt(b) case x: NumericType => - b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b) + b => if (checkIfNaN(b)) null else x.numeric.asInstanceOf[Numeric[Any]].toInt(b) } // ShortConverter @@ -549,12 +549,12 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit throw new ArithmeticException(s"Casting $b to short causes overflow") } if (intValue == intValue.toShort) { - intValue.toShort + if (checkIfNaN(b)) null else intValue.toShort } else { throw new ArithmeticException(s"Casting $b to short causes overflow") } case x: NumericType => - b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort + b => if (checkIfNaN(b)) null else x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort } // ByteConverter @@ -590,12 +590,12 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit throw new ArithmeticException(s"Casting $b to byte causes overflow") } if (intValue == intValue.toByte) { - intValue.toByte + if (checkIfNaN(b)) null else intValue.toByte } else { throw new ArithmeticException(s"Casting $b to byte causes overflow") } case x: NumericType => - b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte + b => if (checkIfNaN(b)) null else x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte } /** @@ -780,6 +780,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } } + // Check if NaN + private[this] def checkIfNaN(value: Any): Boolean = + (value.isInstanceOf[Double] && value.asInstanceOf[Double].isNaN) || + (value.isInstanceOf[Float] && value.asInstanceOf[Float].isNaN) + private[this] lazy val cast: Any => Any = cast(child.dataType, dataType) protected override def nullSafeEval(input: Any): Any = cast(input) From fad33046a25a568836c79059cb4063e56c578e75 Mon Sep 17 00:00:00 2001 From: John Ayad Date: Mon, 2 Dec 2019 19:46:43 +0000 Subject: [PATCH 02/12] Revert "Do not cast NaN to an Integer, Long, Short or Byte" This reverts commit 8de494e63ac02acbc9e5ae494a2641ff1d7eaa71. --- .../spark/sql/catalyst/expressions/Cast.scala | 21 +++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 1471fba55ee91..a871a746d64ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -492,9 +492,9 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit case TimestampType => buildCast[Long](_, t => timestampToLong(t)) case x: NumericType if ansiEnabled => - b => if (checkIfNaN(b)) null else x.exactNumeric.asInstanceOf[Numeric[Any]].toLong(b) + b => x.exactNumeric.asInstanceOf[Numeric[Any]].toLong(b) case x: NumericType => - b => if (checkIfNaN(b)) null else x.numeric.asInstanceOf[Numeric[Any]].toLong(b) + b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b) } // IntConverter @@ -511,9 +511,9 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit case TimestampType => buildCast[Long](_, t => timestampToLong(t).toInt) case x: NumericType if ansiEnabled => - b => if (checkIfNaN(b)) null else x.exactNumeric.asInstanceOf[Numeric[Any]].toInt(b) + b => x.exactNumeric.asInstanceOf[Numeric[Any]].toInt(b) case x: NumericType => - b => if (checkIfNaN(b)) null else x.numeric.asInstanceOf[Numeric[Any]].toInt(b) + b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b) } // ShortConverter @@ -549,12 +549,12 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit throw new ArithmeticException(s"Casting $b to short causes overflow") } if (intValue == intValue.toShort) { - if (checkIfNaN(b)) null else intValue.toShort + intValue.toShort } else { throw new ArithmeticException(s"Casting $b to short causes overflow") } case x: NumericType => - b => if (checkIfNaN(b)) null else x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort + b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort } // ByteConverter @@ -590,12 +590,12 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit throw new ArithmeticException(s"Casting $b to byte causes overflow") } if (intValue == intValue.toByte) { - if (checkIfNaN(b)) null else intValue.toByte + intValue.toByte } else { throw new ArithmeticException(s"Casting $b to byte causes overflow") } case x: NumericType => - b => if (checkIfNaN(b)) null else x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte + b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte } /** @@ -780,11 +780,6 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } } - // Check if NaN - private[this] def checkIfNaN(value: Any): Boolean = - (value.isInstanceOf[Double] && value.asInstanceOf[Double].isNaN) || - (value.isInstanceOf[Float] && value.asInstanceOf[Float].isNaN) - private[this] lazy val cast: Any => Any = cast(child.dataType, dataType) protected override def nullSafeEval(input: Any): Any = cast(input) From 4a9573c05312e6df1081ef1e6873b11c1c0f9949 Mon Sep 17 00:00:00 2001 From: John Ayad Date: Mon, 2 Dec 2019 19:47:32 +0000 Subject: [PATCH 03/12] Take it back a step Do not do a replace if source/target is NaN and the column is not of type Double/Float --- .../apache/spark/sql/DataFrameNaFunctions.scala | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 8447ada88a704..40844066aa61c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -456,11 +456,25 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { val keyExpr = df.col(col.name).expr def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType) val branches = replacementMap.flatMap { case (source, target) => - Seq(buildExpr(source), buildExpr(target)) + if (!(checkIfNonFractionalNumeric(col.dataType) && + (isNaN(source) || isNaN(target)))) { + Seq(buildExpr(source), buildExpr(target)) + } else { + None + } }.toSeq new Column(CaseKeyWhen(keyExpr, branches :+ keyExpr)).as(col.name) } + // Check if DataType is not Double or Float and is Numeric + private[this] def checkIfNonFractionalNumeric(value: DataType): Boolean = + value == LongType || value == IntegerType || value == ShortType || value == ByteType + + // Check if NaN + private[this] def isNaN(value: Any): Boolean = + (value.isInstanceOf[Double] && value.asInstanceOf[Double].isNaN) || + (value.isInstanceOf[Float] && value.asInstanceOf[Float].isNaN) + private def convertToDouble(v: Any): Double = v match { case v: Float => v.toDouble case v: Double => v From 532f4478112e118597368f8c0373ae26191e0f6d Mon Sep 17 00:00:00 2001 From: John Ayad Date: Mon, 2 Dec 2019 21:15:26 +0000 Subject: [PATCH 04/12] Add tests --- .../spark/sql/DataFrameNaFunctionsSuite.scala | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 6cb35656835af..ebd658423b719 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -404,4 +404,81 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { df.na.drop("any"), Row("5", "6", "6") :: Nil) } + + test("replace nan with float") { + val input = Seq[(java.lang.Integer, java.lang.Long, java.lang.Short, + java.lang.Byte, java.lang.Float, java.lang.Double)]( + (1, new java.lang.Long(1), new java.lang.Short("1"), + new java.lang.Byte("1"), new java.lang.Float(1.0), 1.0), + (0, new java.lang.Long(0), new java.lang.Short("0"), + new java.lang.Byte("0"), java.lang.Float.NaN, java.lang.Double.NaN) + ).toDF("int", "long", "short", "byte", "float", "double") + + checkAnswer( + input.na.replace("*", Map( + Float.NaN -> 10f + )), + Row(1, new java.lang.Long(1), new java.lang.Short("1"), + new java.lang.Byte("1"), new java.lang.Float(1.0), 1.0) :: + Row(0, new java.lang.Long(0), new java.lang.Short("0"), + new java.lang.Byte("0"), new java.lang.Float(10), new java.lang.Double(10)) :: Nil) + } + + test("replace nan with double") { + val input = Seq[(java.lang.Integer, java.lang.Long, java.lang.Short, + java.lang.Byte, java.lang.Float, java.lang.Double)]( + (1, new java.lang.Long(1), new java.lang.Short("1"), + new java.lang.Byte("1"), new java.lang.Float(1.0), 1.0), + (0, new java.lang.Long(0), new java.lang.Short("0"), + new java.lang.Byte("0"), java.lang.Float.NaN, java.lang.Double.NaN) + ).toDF("int", "long", "short", "byte", "float", "double") + + checkAnswer( + input.na.replace("*", Map( + Double.NaN -> 10.toDouble + )), + Row(1, new java.lang.Long(1), new java.lang.Short("1"), + new java.lang.Byte("1"), new java.lang.Float(1.0), 1.0) :: + Row(0, new java.lang.Long(0), new java.lang.Short("0"), + new java.lang.Byte("0"), new java.lang.Float(10), new java.lang.Double(10)) :: Nil) + } + + test("replace float with nan") { + val input = Seq[(java.lang.Integer, java.lang.Long, java.lang.Short, + java.lang.Byte, java.lang.Float, java.lang.Double)]( + (1, new java.lang.Long(1), new java.lang.Short("1"), + new java.lang.Byte("1"), new java.lang.Float(1.0), 1.0), + (0, new java.lang.Long(0), new java.lang.Short("0"), + new java.lang.Byte("0"), java.lang.Float.NaN, java.lang.Double.NaN) + ).toDF("int", "long", "short", "byte", "float", "double") + + checkAnswer( + input.na.replace("*", Map( + 1.0f -> Float.NaN + )), + Row(1, new java.lang.Long(1), new java.lang.Short("1"), + new java.lang.Byte("1"), java.lang.Float.NaN, java.lang.Double.NaN) :: + Row(0, new java.lang.Long(0), new java.lang.Short("0"), + new java.lang.Byte("0"), java.lang.Float.NaN, java.lang.Double.NaN) :: Nil) + } + + test("replace double with nan") { + val input = Seq[(java.lang.Integer, java.lang.Long, java.lang.Short, + java.lang.Byte, java.lang.Float, java.lang.Double)]( + (1, new java.lang.Long(1), new java.lang.Short("1"), + new java.lang.Byte("1"), new java.lang.Float(1.0), 1.0), + (0, new java.lang.Long(0), new java.lang.Short("0"), + new java.lang.Byte("0"), java.lang.Float.NaN, java.lang.Double.NaN) + ).toDF("int", "long", "short", "byte", "float", "double") + + checkAnswer( + input.na.replace("*", Map( + 1.toDouble -> Double.NaN + )), + Row(1, new java.lang.Long(1), new java.lang.Short("1"), + new java.lang.Byte("1"), java.lang.Float.NaN, java.lang.Double.NaN) :: + Row(0, new java.lang.Long(0), new java.lang.Short("0"), + new java.lang.Byte("0"), java.lang.Float.NaN, java.lang.Double.NaN) :: Nil) + + } } From 1c5f1ebd5647b69ab9f9e7c0675cdb43174eba22 Mon Sep 17 00:00:00 2001 From: John Ayad Date: Mon, 2 Dec 2019 21:23:24 +0000 Subject: [PATCH 05/12] Avoid repeating df used in tests --- .../spark/sql/DataFrameNaFunctionsSuite.scala | 55 ++++++------------- 1 file changed, 16 insertions(+), 39 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index ebd658423b719..5efe6d727886e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -37,6 +37,16 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { ).toDF("name", "age", "height") } + def createNaNDF(): DataFrame = { + Seq[(java.lang.Integer, java.lang.Long, java.lang.Short, + java.lang.Byte, java.lang.Float, java.lang.Double)]( + (1, new java.lang.Long(1), new java.lang.Short("1"), + new java.lang.Byte("1"), new java.lang.Float(1.0), 1.0), + (0, new java.lang.Long(0), new java.lang.Short("0"), + new java.lang.Byte("0"), java.lang.Float.NaN, java.lang.Double.NaN) + ).toDF("int", "long", "short", "byte", "float", "double") + } + test("drop") { val input = createDF() val rows = input.collect() @@ -406,35 +416,19 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { } test("replace nan with float") { - val input = Seq[(java.lang.Integer, java.lang.Long, java.lang.Short, - java.lang.Byte, java.lang.Float, java.lang.Double)]( - (1, new java.lang.Long(1), new java.lang.Short("1"), - new java.lang.Byte("1"), new java.lang.Float(1.0), 1.0), - (0, new java.lang.Long(0), new java.lang.Short("0"), - new java.lang.Byte("0"), java.lang.Float.NaN, java.lang.Double.NaN) - ).toDF("int", "long", "short", "byte", "float", "double") - checkAnswer( - input.na.replace("*", Map( + createNaNDF().na.replace("*", Map( Float.NaN -> 10f )), Row(1, new java.lang.Long(1), new java.lang.Short("1"), new java.lang.Byte("1"), new java.lang.Float(1.0), 1.0) :: - Row(0, new java.lang.Long(0), new java.lang.Short("0"), - new java.lang.Byte("0"), new java.lang.Float(10), new java.lang.Double(10)) :: Nil) + Row(0, new java.lang.Long(0), new java.lang.Short("0"), + new java.lang.Byte("0"), new java.lang.Float(10), new java.lang.Double(10)) :: Nil) } test("replace nan with double") { - val input = Seq[(java.lang.Integer, java.lang.Long, java.lang.Short, - java.lang.Byte, java.lang.Float, java.lang.Double)]( - (1, new java.lang.Long(1), new java.lang.Short("1"), - new java.lang.Byte("1"), new java.lang.Float(1.0), 1.0), - (0, new java.lang.Long(0), new java.lang.Short("0"), - new java.lang.Byte("0"), java.lang.Float.NaN, java.lang.Double.NaN) - ).toDF("int", "long", "short", "byte", "float", "double") - checkAnswer( - input.na.replace("*", Map( + createNaNDF().na.replace("*", Map( Double.NaN -> 10.toDouble )), Row(1, new java.lang.Long(1), new java.lang.Short("1"), @@ -444,16 +438,8 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { } test("replace float with nan") { - val input = Seq[(java.lang.Integer, java.lang.Long, java.lang.Short, - java.lang.Byte, java.lang.Float, java.lang.Double)]( - (1, new java.lang.Long(1), new java.lang.Short("1"), - new java.lang.Byte("1"), new java.lang.Float(1.0), 1.0), - (0, new java.lang.Long(0), new java.lang.Short("0"), - new java.lang.Byte("0"), java.lang.Float.NaN, java.lang.Double.NaN) - ).toDF("int", "long", "short", "byte", "float", "double") - checkAnswer( - input.na.replace("*", Map( + createNaNDF().na.replace("*", Map( 1.0f -> Float.NaN )), Row(1, new java.lang.Long(1), new java.lang.Short("1"), @@ -463,22 +449,13 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { } test("replace double with nan") { - val input = Seq[(java.lang.Integer, java.lang.Long, java.lang.Short, - java.lang.Byte, java.lang.Float, java.lang.Double)]( - (1, new java.lang.Long(1), new java.lang.Short("1"), - new java.lang.Byte("1"), new java.lang.Float(1.0), 1.0), - (0, new java.lang.Long(0), new java.lang.Short("0"), - new java.lang.Byte("0"), java.lang.Float.NaN, java.lang.Double.NaN) - ).toDF("int", "long", "short", "byte", "float", "double") - checkAnswer( - input.na.replace("*", Map( + createNaNDF().na.replace("*", Map( 1.toDouble -> Double.NaN )), Row(1, new java.lang.Long(1), new java.lang.Short("1"), new java.lang.Byte("1"), java.lang.Float.NaN, java.lang.Double.NaN) :: Row(0, new java.lang.Long(0), new java.lang.Short("0"), new java.lang.Byte("0"), java.lang.Float.NaN, java.lang.Double.NaN) :: Nil) - } } From ed6f08dbbc0fb4f04e9ae9b59c119351bd4ca038 Mon Sep 17 00:00:00 2001 From: John Ayad Date: Mon, 2 Dec 2019 21:43:25 +0000 Subject: [PATCH 06/12] Improve coding style --- .../apache/spark/sql/DataFrameNaFunctions.scala | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 40844066aa61c..97d510440374c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -456,20 +456,18 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { val keyExpr = df.col(col.name).expr def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType) val branches = replacementMap.flatMap { case (source, target) => - if (!(checkIfNonFractionalNumeric(col.dataType) && - (isNaN(source) || isNaN(target)))) { - Seq(buildExpr(source), buildExpr(target)) - } else { - None + if (isNaN(source) || isNaN(target)) { + col.dataType match { + case IntegerType | LongType | ShortType | ByteType => Seq.empty + case _ => Seq(buildExpr(source), buildExpr(target)) } + } else { + Seq(buildExpr(source), buildExpr(target)) + } }.toSeq new Column(CaseKeyWhen(keyExpr, branches :+ keyExpr)).as(col.name) } - // Check if DataType is not Double or Float and is Numeric - private[this] def checkIfNonFractionalNumeric(value: DataType): Boolean = - value == LongType || value == IntegerType || value == ShortType || value == ByteType - // Check if NaN private[this] def isNaN(value: Any): Boolean = (value.isInstanceOf[Double] && value.asInstanceOf[Double].isNaN) || From 279a9fd4cb5953593d0b69c46a4d0ff229f0a7ee Mon Sep 17 00:00:00 2001 From: John Ayad Date: Tue, 3 Dec 2019 11:03:17 +0000 Subject: [PATCH 07/12] Opt for a more simple fix --- .../apache/spark/sql/DataFrameNaFunctions.scala | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 97d510440374c..319c64ef0b4d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -455,24 +455,12 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { private def replaceCol(col: StructField, replacementMap: Map[_, _]): Column = { val keyExpr = df.col(col.name).expr def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType) - val branches = replacementMap.flatMap { case (source, target) => - if (isNaN(source) || isNaN(target)) { - col.dataType match { - case IntegerType | LongType | ShortType | ByteType => Seq.empty - case _ => Seq(buildExpr(source), buildExpr(target)) - } - } else { - Seq(buildExpr(source), buildExpr(target)) - } + val branches = replacementMap.flatMap { case (src, target) => + Seq(Literal(src), buildExpr(target)) }.toSeq new Column(CaseKeyWhen(keyExpr, branches :+ keyExpr)).as(col.name) } - // Check if NaN - private[this] def isNaN(value: Any): Boolean = - (value.isInstanceOf[Double] && value.asInstanceOf[Double].isNaN) || - (value.isInstanceOf[Float] && value.asInstanceOf[Float].isNaN) - private def convertToDouble(v: Any): Double = v match { case v: Float => v.toDouble case v: Double => v From 6b5d26d04d4efb67d205767148c38c7dc21584fc Mon Sep 17 00:00:00 2001 From: John Ayad Date: Tue, 3 Dec 2019 11:11:06 +0000 Subject: [PATCH 08/12] Adjust unit tests to reflect new behaviour --- .../org/apache/spark/sql/DataFrameNaFunctionsSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 5efe6d727886e..80f1669125654 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -442,8 +442,8 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { createNaNDF().na.replace("*", Map( 1.0f -> Float.NaN )), - Row(1, new java.lang.Long(1), new java.lang.Short("1"), - new java.lang.Byte("1"), java.lang.Float.NaN, java.lang.Double.NaN) :: + Row(0, new java.lang.Long(0), new java.lang.Short("0"), + new java.lang.Byte("0"), java.lang.Float.NaN, java.lang.Double.NaN) :: Row(0, new java.lang.Long(0), new java.lang.Short("0"), new java.lang.Byte("0"), java.lang.Float.NaN, java.lang.Double.NaN) :: Nil) } @@ -453,8 +453,8 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { createNaNDF().na.replace("*", Map( 1.toDouble -> Double.NaN )), - Row(1, new java.lang.Long(1), new java.lang.Short("1"), - new java.lang.Byte("1"), java.lang.Float.NaN, java.lang.Double.NaN) :: + Row(0, new java.lang.Long(0), new java.lang.Short("0"), + new java.lang.Byte("0"), java.lang.Float.NaN, java.lang.Double.NaN) :: Row(0, new java.lang.Long(0), new java.lang.Short("0"), new java.lang.Byte("0"), java.lang.Float.NaN, java.lang.Double.NaN) :: Nil) } From 10c91d6fe1305c70932618583a12681d43513b33 Mon Sep 17 00:00:00 2001 From: John Ayad Date: Tue, 3 Dec 2019 11:14:28 +0000 Subject: [PATCH 09/12] Adjust whitespace --- .../main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 319c64ef0b4d3..46260a18fd6d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -456,7 +456,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { val keyExpr = df.col(col.name).expr def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType) val branches = replacementMap.flatMap { case (src, target) => - Seq(Literal(src), buildExpr(target)) + Seq(Literal(src), buildExpr(target)) }.toSeq new Column(CaseKeyWhen(keyExpr, branches :+ keyExpr)).as(col.name) } From 1744b28a04f08dfd25f10829a6080a0471025856 Mon Sep 17 00:00:00 2001 From: John Ayad Date: Tue, 3 Dec 2019 11:44:13 +0000 Subject: [PATCH 10/12] Do not change param name --- .../scala/org/apache/spark/sql/DataFrameNaFunctions.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 46260a18fd6d7..2a86b65b8f79f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -455,8 +455,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { private def replaceCol(col: StructField, replacementMap: Map[_, _]): Column = { val keyExpr = df.col(col.name).expr def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType) - val branches = replacementMap.flatMap { case (src, target) => - Seq(Literal(src), buildExpr(target)) + val branches = replacementMap.flatMap { case (source, target) => + Seq(Literal(source), buildExpr(target)) }.toSeq new Column(CaseKeyWhen(keyExpr, branches :+ keyExpr)).as(col.name) } From 1295633ce028637974ca97119b2be96f40db71e3 Mon Sep 17 00:00:00 2001 From: John Ayad Date: Tue, 3 Dec 2019 12:59:08 +0000 Subject: [PATCH 11/12] Clean unit tests df generator --- .../org/apache/spark/sql/DataFrameNaFunctionsSuite.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 80f1669125654..fde2fafa7ece2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -40,10 +40,8 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { def createNaNDF(): DataFrame = { Seq[(java.lang.Integer, java.lang.Long, java.lang.Short, java.lang.Byte, java.lang.Float, java.lang.Double)]( - (1, new java.lang.Long(1), new java.lang.Short("1"), - new java.lang.Byte("1"), new java.lang.Float(1.0), 1.0), - (0, new java.lang.Long(0), new java.lang.Short("0"), - new java.lang.Byte("0"), java.lang.Float.NaN, java.lang.Double.NaN) + (1, 1L, 1.toShort, 1.toByte, 1.0f, 1.0), + (0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) ).toDF("int", "long", "short", "byte", "float", "double") } From b3709a16e89284584922995f4877c31b5095aafe Mon Sep 17 00:00:00 2001 From: John Ayad Date: Tue, 3 Dec 2019 13:55:36 +0000 Subject: [PATCH 12/12] Clean tests themselves as well --- .../spark/sql/DataFrameNaFunctionsSuite.scala | 30 +++++++------------ 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index fde2fafa7ece2..fb1ca69b6f73f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -416,23 +416,19 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { test("replace nan with float") { checkAnswer( createNaNDF().na.replace("*", Map( - Float.NaN -> 10f + Float.NaN -> 10.0f )), - Row(1, new java.lang.Long(1), new java.lang.Short("1"), - new java.lang.Byte("1"), new java.lang.Float(1.0), 1.0) :: - Row(0, new java.lang.Long(0), new java.lang.Short("0"), - new java.lang.Byte("0"), new java.lang.Float(10), new java.lang.Double(10)) :: Nil) + Row(1, 1L, 1.toShort, 1.toByte, 1.0f, 1.0) :: + Row(0, 0L, 0.toShort, 0.toByte, 10.0f, 10.0) :: Nil) } test("replace nan with double") { checkAnswer( createNaNDF().na.replace("*", Map( - Double.NaN -> 10.toDouble + Double.NaN -> 10.0 )), - Row(1, new java.lang.Long(1), new java.lang.Short("1"), - new java.lang.Byte("1"), new java.lang.Float(1.0), 1.0) :: - Row(0, new java.lang.Long(0), new java.lang.Short("0"), - new java.lang.Byte("0"), new java.lang.Float(10), new java.lang.Double(10)) :: Nil) + Row(1, 1L, 1.toShort, 1.toByte, 1.0f, 1.0) :: + Row(0, 0L, 0.toShort, 0.toByte, 10.0f, 10.0) :: Nil) } test("replace float with nan") { @@ -440,20 +436,16 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession { createNaNDF().na.replace("*", Map( 1.0f -> Float.NaN )), - Row(0, new java.lang.Long(0), new java.lang.Short("0"), - new java.lang.Byte("0"), java.lang.Float.NaN, java.lang.Double.NaN) :: - Row(0, new java.lang.Long(0), new java.lang.Short("0"), - new java.lang.Byte("0"), java.lang.Float.NaN, java.lang.Double.NaN) :: Nil) + Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: + Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil) } test("replace double with nan") { checkAnswer( createNaNDF().na.replace("*", Map( - 1.toDouble -> Double.NaN + 1.0 -> Double.NaN )), - Row(0, new java.lang.Long(0), new java.lang.Short("0"), - new java.lang.Byte("0"), java.lang.Float.NaN, java.lang.Double.NaN) :: - Row(0, new java.lang.Long(0), new java.lang.Short("0"), - new java.lang.Byte("0"), java.lang.Float.NaN, java.lang.Double.NaN) :: Nil) + Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: + Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil) } }