diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 58612f65c1a53..abd6c88d3d985 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -67,37 +67,61 @@ trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression
 
 /**
- * Given an array or map, returns its size. Returns -1 if null.
+ * Given an array or map, returns the total number of elements in it.
  */
 @ExpressionDescription(
-  usage = "_FUNC_(expr) - Returns the size of an array or a map. Returns -1 if null.",
+  usage = """
+    _FUNC_(expr) - Returns the size of an array or a map.
+    The function returns -1 if its input is null and spark.sql.legacy.sizeOfNull is set to true.
+    If spark.sql.legacy.sizeOfNull is set to false, the function returns null for null input.
+    By default, the spark.sql.legacy.sizeOfNull parameter is set to true.
+  """,
   examples = """
     Examples:
       > SELECT _FUNC_(array('b', 'd', 'c', 'a'));
        4
+      > SELECT _FUNC_(map('a', 1, 'b', 2));
+       2
+      > SELECT _FUNC_(NULL);
+       -1
   """)
-case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class Size(
+    child: Expression,
+    legacySizeOfNull: Boolean)
+  extends UnaryExpression with ExpectsInputTypes {
+
+  def this(child: Expression) =
+    this(
+      child,
+      legacySizeOfNull = SQLConf.get.getConf(SQLConf.LEGACY_SIZE_OF_NULL))
+
   override def dataType: DataType = IntegerType
 
   override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType))
-  override def nullable: Boolean = false
+  override def nullable: Boolean = if (legacySizeOfNull) false else super.nullable
 
   override def eval(input: InternalRow): Any = {
     val value = child.eval(input)
     if (value == null) {
-      -1
+      if (legacySizeOfNull) -1 else null
     } else child.dataType match {
       case _: ArrayType => value.asInstanceOf[ArrayData].numElements()
       case _: MapType => value.asInstanceOf[MapData].numElements()
+      case other => throw new UnsupportedOperationException(
+        s"The size function doesn't support the operand type ${other.getClass.getCanonicalName}")
     }
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val childGen = child.genCode(ctx)
-    ev.copy(code = code"""
+    if (legacySizeOfNull) {
+      val childGen = child.genCode(ctx)
+      ev.copy(code = code"""
       boolean ${ev.isNull} = false;
       ${childGen.code}
       ${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 :
         (${childGen.value}).numElements();""", isNull = FalseLiteral)
+    } else {
+      defineCodeGen(ctx, ev, c => s"($c).numElements()")
+    }
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index e768416f257c9..239c8266351ae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1324,6 +1324,12 @@ object SQLConf {
       "Other column values can be ignored during parsing even if they are malformed.")
     .booleanConf
     .createWithDefault(true)
+
+  val LEGACY_SIZE_OF_NULL = buildConf("spark.sql.legacy.sizeOfNull")
+    .doc("If it is set to true, size of null returns -1. This behavior was inherited from Hive. " +
+      "The size function returns null for null input if the flag is disabled.")
+    .booleanConf
+    .createWithDefault(true)
 }
 
 /**
@@ -1686,6 +1692,8 @@ class SQLConf extends Serializable with Logging {
 
   def csvColumnPruning: Boolean = getConf(SQLConf.CSV_PARSER_COLUMN_PRUNING)
 
+  def legacySizeOfNull: Boolean = getConf(SQLConf.LEGACY_SIZE_OF_NULL)
+
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 5b8cf5128fe21..caea4fb25ff7e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -24,25 +24,37 @@ import org.apache.spark.sql.types._
 
 class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
-  test("Array and Map Size") {
+  def testSize(legacySizeOfNull: Boolean, sizeOfNull: Any): Unit = {
     val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
     val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
     val a2 = Literal.create(Seq(1, 2), ArrayType(IntegerType))
 
-    checkEvaluation(Size(a0), 3)
-    checkEvaluation(Size(a1), 0)
-    checkEvaluation(Size(a2), 2)
+    checkEvaluation(Size(a0, legacySizeOfNull), 3)
+    checkEvaluation(Size(a1, legacySizeOfNull), 0)
+    checkEvaluation(Size(a2, legacySizeOfNull), 2)
 
     val m0 = Literal.create(Map("a" -> "a", "b" -> "b"), MapType(StringType, StringType))
     val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
     val m2 = Literal.create(Map("a" -> "a"), MapType(StringType, StringType))
 
-    checkEvaluation(Size(m0), 2)
-    checkEvaluation(Size(m1), 0)
-    checkEvaluation(Size(m2), 1)
+    checkEvaluation(Size(m0, legacySizeOfNull), 2)
+    checkEvaluation(Size(m1, legacySizeOfNull), 0)
+    checkEvaluation(Size(m2, legacySizeOfNull), 1)
+
+    checkEvaluation(
+      Size(Literal.create(null, MapType(StringType, StringType)), legacySizeOfNull),
+      expected = sizeOfNull)
+    checkEvaluation(
+      Size(Literal.create(null, ArrayType(StringType)), legacySizeOfNull),
+      expected = sizeOfNull)
+  }
+
+  test("Array and Map Size - legacy") {
+    testSize(legacySizeOfNull = true, sizeOfNull = -1)
+  }
 
-    checkEvaluation(Size(Literal.create(null, MapType(StringType, StringType))), -1)
-    checkEvaluation(Size(Literal.create(null, ArrayType(StringType))), -1)
+  test("Array and Map Size") {
+    testSize(legacySizeOfNull = false, sizeOfNull = null)
   }
 
   test("MapKeys/MapValues") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 40c40e7083d1c..ef99ce3ad69d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3431,7 +3431,7 @@ object functions {
    * @group collection_funcs
    * @since 1.5.0
    */
-  def size(e: Column): Column = withExpr { Size(e.expr) }
+  def size(e: Column): Column = withExpr { new Size(e.expr) }
 
   /**
    * Sorts the input array for the given column in ascending order,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 5d6a6c0832c96..b109898b5bfb3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -487,26 +487,29 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     }.getMessage().contains("only supports array input"))
   }
 
-  test("array size function") {
+  def testSizeOfArray(sizeOfNull: Any): Unit = {
     val df = Seq(
       (Seq[Int](1, 2), "x"),
       (Seq[Int](), "y"),
       (Seq[Int](1, 2, 3), "z"),
       (null, "empty")
     ).toDF("a", "b")
-    checkAnswer(
-      df.select(size($"a")),
-      Seq(Row(2), Row(0), Row(3), Row(-1))
-    )
-    checkAnswer(
-      df.selectExpr("size(a)"),
-      Seq(Row(2), Row(0), Row(3), Row(-1))
-    )
-    checkAnswer(
-      df.selectExpr("cardinality(a)"),
-      Seq(Row(2L), Row(0L), Row(3L), Row(-1L))
-    )
+
+    checkAnswer(df.select(size($"a")), Seq(Row(2), Row(0), Row(3), Row(sizeOfNull)))
+    checkAnswer(df.selectExpr("size(a)"), Seq(Row(2), Row(0), Row(3), Row(sizeOfNull)))
+    checkAnswer(df.selectExpr("cardinality(a)"), Seq(Row(2L), Row(0L), Row(3L), Row(sizeOfNull)))
+  }
+
+  test("array size function - legacy") {
+    withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") {
+      testSizeOfArray(sizeOfNull = -1)
+    }
+  }
+
+  test("array size function") {
+    withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "false") {
+      testSizeOfArray(sizeOfNull = null)
+    }
   }
 
   test("dataframe arrays_zip function") {
@@ -567,21 +570,28 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     }
   }
 
-  test("map size function") {
+  def testSizeOfMap(sizeOfNull: Any): Unit = {
     val df = Seq(
       (Map[Int, Int](1 -> 1, 2 -> 2), "x"),
       (Map[Int, Int](), "y"),
       (Map[Int, Int](1 -> 1, 2 -> 2, 3 -> 3), "z"),
       (null, "empty")
     ).toDF("a", "b")
-    checkAnswer(
-      df.select(size($"a")),
-      Seq(Row(2), Row(0), Row(3), Row(-1))
-    )
-    checkAnswer(
-      df.selectExpr("size(a)"),
-      Seq(Row(2), Row(0), Row(3), Row(-1))
-    )
+
+    checkAnswer(df.select(size($"a")), Seq(Row(2), Row(0), Row(3), Row(sizeOfNull)))
+    checkAnswer(df.selectExpr("size(a)"), Seq(Row(2), Row(0), Row(3), Row(sizeOfNull)))
+  }
+
+  test("map size function - legacy") {
+    withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") {
+      testSizeOfMap(sizeOfNull = -1: Int)
+    }
+  }
+
+  test("map size function") {
+    withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "false") {
+      testSizeOfMap(sizeOfNull = null)
+    }
   }
 
   test("map_keys/map_values function") {
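
Illustration (not part of the patch): a minimal spark-shell sketch of the behavior change, assuming this patch is applied. The session setup and app name are hypothetical boilerplate; the expected results mirror the ExpressionDescription examples and the tests above.

  import org.apache.spark.sql.SparkSession

  // Illustrative local session (hypothetical setup, not from the patch).
  val spark = SparkSession.builder()
    .master("local[*]")
    .appName("sizeOfNull-demo")
    .getOrCreate()

  // Default: spark.sql.legacy.sizeOfNull=true, so size(NULL) evaluates to -1
  // (the Hive-compatible legacy behavior).
  spark.sql("SELECT size(NULL)").collect()   // Array([-1])

  // With the legacy flag disabled, size returns null for null input.
  spark.conf.set("spark.sql.legacy.sizeOfNull", "false")
  spark.sql("SELECT size(NULL)").collect()   // Array([null])

  // Non-null inputs are unaffected by the flag.
  spark.sql("SELECT size(array('b', 'd', 'c', 'a'))").collect()   // Array([4])
  spark.sql("SELECT size(map('a', 1, 'b', 2))").collect()         // Array([2])

Note that Size captures the flag at expression-construction time (via SQLConf.get in the auxiliary constructor), so the conf change above takes effect for queries analyzed after the set, not for already-resolved plans.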