From f22d3df35ec97fba791d2968c369ce2dbaf41731 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Sat, 14 Apr 2018 16:52:37 -0700 Subject: [PATCH 01/31] Initial commit --- python/pyspark/sql/functions.py | 22 ++++++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 56 +++++++++++++++++++ .../CollectionExpressionsSuite.scala | 21 +++++++ .../org/apache/spark/sql/functions.scala | 8 +++ .../spark/sql/DataFrameFunctionsSuite.scala | 29 ++++++++++ 6 files changed, 137 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 4d371976364d3..64e5bc4c2c688 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2489,6 +2489,28 @@ def arrays_zip(*cols): sc = SparkContext._active_spark_context return Column(sc._jvm.functions.arrays_zip(_to_seq(sc, cols, _to_java_column))) +@since(2.4) +def map_concat(*cols): + """Returns the union of all the given maps. If a key is found in multiple given maps, + that key's value in the resulting map comes from the last one of those maps. + + :param cols: list of column names (string) or list of :class:`Column` expressions + + >>> from pyspark.sql.functions import map_concat + >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as map1, map(3, 'c', 1, 'd') as map2") + >>> df.select(map_concat("map1", "map2").alias("map3")).show(truncate=False) + +------------------------+ + |map3 | + +------------------------+ + |[1 -> d, 2 -> b, 3 -> c]| + +------------------------+ + """ + sc = SparkContext._active_spark_context + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + jc = sc._jvm.functions.map_concat(_to_seq(sc, cols, _to_java_column)) + return Column(jc) + # ---------------------------- User Defined Function ---------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 80a0af672bf74..e7517e8c676e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -422,6 +422,7 @@ object FunctionRegistry { expression[MapValues]("map_values"), expression[MapEntries]("map_entries"), expression[MapFromEntries]("map_from_entries"), + expression[MapConcat]("map_concat"), expression[Size]("size"), expression[Slice]("slice"), expression[Size]("cardinality"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 8b278f067749e..5917b2adee16e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.{Comparator, TimeZone} +import java.util import scala.collection.mutable import scala.reflect.ClassTag @@ -501,6 +502,61 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp } override def prettyName: String = "map_entries" + +/** + * Returns the union of all the given maps. + */ +@ExpressionDescription( +usage = "_FUNC_(map, ...) 
- Returns the union of all the given maps", +examples = """ + Examples: + > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd')); + [[1 -> "a"], [2 -> "c"], [3 -> "d"] + """) +case class MapConcat(children: Seq[Expression]) extends Expression + with CodegenFallback { + + override def checkInputDataTypes(): TypeCheckResult = { + // this check currently does not allow valueContainsNull to vary, + // and unfortunately none of the MapType toString methods include + // valueContainsNull for the error message + if (children.exists(!_.dataType.isInstanceOf[MapType])) { + TypeCheckResult.TypeCheckFailure( + s"The given input of function $prettyName should all be of type map, " + + "but they are " + children.map(_.dataType.simpleString).mkString("[", ", ", "]")) + } else if (children.map(_.dataType).distinct.length > 1) { + TypeCheckResult.TypeCheckFailure( + s"The given input maps of function $prettyName should all be the same type, " + + "but they are " + children.map(_.dataType.simpleString).mkString("[", ", ", "]")) + } else { + TypeCheckResult.TypeCheckSuccess + } + } + override def dataType: MapType = { + children.headOption.map(_.dataType.asInstanceOf[MapType]) + .getOrElse(MapType(keyType = StringType, valueType = StringType)) + } + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = { + val union = new util.LinkedHashMap[Any, Any]() + children.map(_.eval(input)).foreach { raw => + if (raw != null) { + val map = raw.asInstanceOf[MapData] + map.foreach(dataType.keyType, dataType.valueType, (k, v) => + union.put(k, v) + ) + } + } + val (keyArray, valueArray) = union.entrySet().toArray().map { e => + val e2 = e.asInstanceOf[java.util.Map.Entry[Any, Any]] + (e2.getKey, e2.getValue) + }.unzip + new ArrayBasedMapData(new GenericArrayData(keyArray), new GenericArrayData(valueArray)) + } + + override def prettyName: String = "map_concat" } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index d7744eb4c7dc7..78dae0a9aa0bb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import java.util.TimeZone +import scala.collection.mutable import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row @@ -96,6 +97,26 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapEntries(ms0), Seq(r("a", "c"), r("b", null))) checkEvaluation(MapEntries(ms1), Seq.empty) checkEvaluation(MapEntries(ms2), null) + + test("Map Concat") { + val m0 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType)) + val m1 = Literal.create(Map("c" -> "3", "a" -> "4"), MapType(StringType, StringType)) + val m2 = Literal.create(Map("d" -> "4", "e" -> "5"), MapType(StringType, StringType)) + val mNull = Literal.create(null, MapType(StringType, StringType)) + val i1 = Literal.create(1, IntegerType) + + // overlapping maps + checkEvaluation(MapConcat(Seq(m0, m1)), Map("a" -> "4", "b" -> "2", "c" -> "3")) + // maps with no overlap + checkEvaluation(MapConcat(Seq(m0, m2)), + mutable.LinkedHashMap("a" -> "1", "b" -> "2", "d" -> "4", "e" -> "5")) + // 3 maps + 
checkEvaluation(MapConcat(Seq(m0, m1, m2)), + mutable.LinkedHashMap("a" -> "4", "b" -> "2", "c" -> "3", "d" -> "4", "e" -> "5")) + // no input + checkEvaluation(MapConcat(Seq()), Map()) + // null map + checkEvaluation(MapConcat(Seq(m0, mNull)), Map("a" -> "1", "b" -> "2")) } test("MapFromEntries") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f2627e69939cd..89dbba10a6bf1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3627,6 +3627,14 @@ object functions { @scala.annotation.varargs def arrays_zip(e: Column*): Column = withExpr { ArraysZip(e.map(_.expr)) } + /** + * Returns the union of all the given maps. + * @group collection_funcs + * @since 2.4.0 + */ + @scala.annotation.varargs + def map_concat(cols: Column*): Column = withExpr { MapConcat(cols.map(_.expr)) } + ////////////////////////////////////////////////////////////////////////////////////////////// // Mask functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 4c28e2f1cd909..420440ffe4440 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -655,6 +655,35 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(sdf.select(map_entries('m)), sExpected) checkAnswer(sdf.selectExpr("map_entries(m)"), sExpected) checkAnswer(sdf.filter(dummyFilter('m)).select(map_entries('m)), sExpected) + + test("map_concat function") { + val df1 = Seq( + (Map[Int, Int](1 -> 100, 2 -> 200), Map[Int, Int](3 -> 300, 4 -> 400)), + (Map[Int, Int](1 -> 100, 2 -> 200), Map[Int, Int](3 -> 300, 1 -> 400)), + (null, Map[Int, Int](3 -> 300, 4 -> 400)) + ).toDF("map1", "map2") + checkAnswer( + df1.selectExpr("map_concat(map1, map2)"), + Seq( + Row(Map(1 -> 100, 2 -> 200, 3 -> 300, 4 -> 400)), + Row(Map(1 -> 400, 2 -> 200, 3 -> 300)), + Row(Map(3 -> 300, 4 -> 400)) + ) + ) + + val df2 = Seq( + (Map[Int, Int](1 -> 100, 2 -> 200), Map[String, Int]("3" -> 300, "4" -> 400)) + ).toDF("map1", "map2") + checkAnswer( + df2.selectExpr("map_concat()"), + Seq(Row(Map())) + ) + assert(intercept[AnalysisException] { + df2.selectExpr("map_concat(map1, map2)").collect() + }.getMessage().contains("input maps of function map_concat should all be the same type")) + assert(intercept[AnalysisException] { + df2.selectExpr("map_concat(map1, 12)").collect() + }.getMessage().contains("input of function map_concat should all be of type map")) } test("map_from_entries function") { From 84eeec57bea119a6e74bceff2efa9db6461d1f17 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Sat, 14 Apr 2018 19:04:45 -0700 Subject: [PATCH 02/31] Remove unused variable in test --- .../sql/catalyst/expressions/CollectionExpressionsSuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 78dae0a9aa0bb..301e0bf6b738e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -103,7 +103,6 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val m1 = Literal.create(Map("c" -> "3", "a" -> "4"), MapType(StringType, StringType)) val m2 = Literal.create(Map("d" -> "4", "e" -> "5"), MapType(StringType, StringType)) val mNull = Literal.create(null, MapType(StringType, StringType)) - val i1 = Literal.create(1, IntegerType) // overlapping maps checkEvaluation(MapConcat(Seq(m0, m1)), Map("a" -> "4", "b" -> "2", "c" -> "3")) From f6fbbc83773fbdc16509918abd167daf0a102cd1 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Sat, 14 Apr 2018 20:35:47 -0700 Subject: [PATCH 03/31] Cleanup --- .../catalyst/expressions/CollectionExpressionsSuite.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 301e0bf6b738e..be6570796cfaa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -105,7 +105,8 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val mNull = Literal.create(null, MapType(StringType, StringType)) // overlapping maps - checkEvaluation(MapConcat(Seq(m0, m1)), Map("a" -> "4", "b" -> "2", "c" -> "3")) + checkEvaluation(MapConcat(Seq(m0, m1)), + mutable.LinkedHashMap("a" -> "4", "b" -> "2", "c" -> "3")) // maps with no overlap checkEvaluation(MapConcat(Seq(m0, m2)), mutable.LinkedHashMap("a" -> "1", "b" -> "2", "d" -> "4", "e" -> "5")) @@ -115,7 +116,8 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper // no input checkEvaluation(MapConcat(Seq()), Map()) // null map - checkEvaluation(MapConcat(Seq(m0, mNull)), Map("a" -> "1", "b" -> "2")) + checkEvaluation(MapConcat(Seq(m0, mNull)), + mutable.LinkedHashMap("a" -> "1", "b" -> "2")) } test("MapFromEntries") { From 4ed8627783c001812cddc0780531532de511314a Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Mon, 16 Apr 2018 21:29:26 -0700 Subject: [PATCH 04/31] Checkpoint non-working codegen --- .../expressions/codegen/CodeGenerator.scala | 2 +- .../expressions/collectionOperations.scala | 84 +++++++++++++++++++ 2 files changed, 85 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 838c045d5bcce..da8c39b87a741 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1337,7 +1337,7 @@ object CodeGenerator extends Logging { )) evaluator.setExtendedClass(classOf[GeneratedClass]) - logDebug({ + logWarning({ // Only add extra debugging info to byte code when we are going to print the source code. 
evaluator.setDebuggingInformation(true, true, false) s"\n${CodeFormatter.format(code)}" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 5917b2adee16e..d112ed1fd7139 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -556,6 +556,90 @@ case class MapConcat(children: Seq[Expression]) extends Expression new ArrayBasedMapData(new GenericArrayData(keyArray), new GenericArrayData(valueArray)) } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val mapCodes = children.map(c => c.genCode(ctx)) + val keyTypes = children.map(c => c.dataType.asInstanceOf[MapType].keyType) + val valueTypes = children.map(c => c.dataType.asInstanceOf[MapType].valueType) + val mapRefArrayName = ctx.freshName("mapRefArray") + val mapNullArrayName = ctx.freshName("mapNullArray") + val unionMapName = ctx.freshName("union") + + val mapDataClass = classOf[MapData].getName + val arrayBasedMapDataClass = classOf[ArrayBasedMapData].getName + val arrayDataClass = classOf[GenericArrayData].getName + val hashMapClass = classOf[util.LinkedHashMap[Any, Any]].getName + val entryClass = classOf[util.Map.Entry[Any, Any]].getName + + val init = + s""" + |boolean[] $mapNullArrayName = new boolean[${mapCodes.size}]; + |Object[] $mapRefArrayName = new Object[${mapCodes.size}]; + """.stripMargin + + val assignments = mapCodes.zipWithIndex.map { case (m, i) => + val initCode = mapCodes(i).code + val isNullVarname = mapCodes(i).isNull + val valueVarname = mapCodes(i).value.code + s""" + |$initCode + |$mapNullArrayName[$i] = $isNullVarname; + | if (!$mapNullArrayName[$i]) { + | $mapRefArrayName[$i] = $valueVarname; + | } else { + | $mapRefArrayName[$i] = null; + | } + """.stripMargin + }.mkString("\n") + + val index1Name = ctx.freshName("idx1") + val index2Name = ctx.freshName("idx2") + val mapDataName = ctx.freshName("m") + val kaName = ctx.freshName("ka") + val vaName = ctx.freshName("va") + + val mapMerge = + s""" + |$hashMapClass $unionMapName = new $hashMapClass(); + |for (int $index1Name = 0; $index1Name < $mapRefArrayName.size; $index1Name++) { + | boolean isNull = $mapNullArrayName[$index1Name]; + | if (isNull) { + | continue; + | } + | MapData $mapDataName = ($mapDataClass) $mapRefArrayName[$index1Name]; + | Object[] $kaName = (Object[]) $mapDataName.keyArray().toObjectArray(); + | Object[] $vaName = (Object[]) $mapDataName.valueArray().toObjectArray(); + | for (int $index2Name = 0; $index2Name < $kaName.length; $index2Name++) { + | $unionMapName.put($kaName[$index2Name], $vaName[$index2Name]); + | } + |} + """.stripMargin + + val mergedKeyArrayName = ctx.freshName("keyArray") + val mergedValueArrayName = ctx.freshName("valueArray") + val entrySetName = ctx.freshName("entrySet") + val createMapData = + s""" + |$entryClass[] entries = $unionMapName.entrySet().toArray(); + |Object[] $mergedKeyArrayName = new Object[$unionMapName.size]; + |Object[] $mergedValueArrayName = new Object[$unionMapName.size]; + |for (int $index1Name = 0; $index1Name < $entrySetName.length(); $index1Name++) { + | $entryClass entry = $entrySetName[$index1Name]; + | $mergedKeyArrayName[$index1Name] = (Object) entry.getKey(); + | $mergedValueArrayName[$index1Name] = (Object) entry.getValue(); + |} + |${ev.value} = 
new $arrayBasedMapDataClass(new $arrayDataClass($mergedKeyArrayName), + | new $arrayDataClass($mergedValueArrayName)); + """.stripMargin + val code = + s""" + |$init + |$assignments + |$mapMerge + $createMapData + """.stripMargin + ev.copy(code = code) + } + override def prettyName: String = "map_concat" } From aaee5b81817a75edb0ada90524db885180476a00 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Tue, 17 Apr 2018 10:14:08 -0700 Subject: [PATCH 05/31] Checkpoint somewhat working codegen --- .../expressions/collectionOperations.scala | 34 ++++++++++++------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index d112ed1fd7139..5103567652df1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -566,7 +566,8 @@ case class MapConcat(children: Seq[Expression]) extends Expression val mapDataClass = classOf[MapData].getName val arrayBasedMapDataClass = classOf[ArrayBasedMapData].getName - val arrayDataClass = classOf[GenericArrayData].getName + val arrayDataClass = classOf[ArrayData].getName + val genericArrayDataClass = classOf[GenericArrayData].getName val hashMapClass = classOf[util.LinkedHashMap[Any, Any]].getName val entryClass = classOf[util.Map.Entry[Any, Any]].getName @@ -574,6 +575,7 @@ case class MapConcat(children: Seq[Expression]) extends Expression s""" |boolean[] $mapNullArrayName = new boolean[${mapCodes.size}]; |Object[] $mapRefArrayName = new Object[${mapCodes.size}]; + |boolean ${ev.isNull} = false; """.stripMargin val assignments = mapCodes.zipWithIndex.map { case (m, i) => @@ -596,20 +598,24 @@ case class MapConcat(children: Seq[Expression]) extends Expression val mapDataName = ctx.freshName("m") val kaName = ctx.freshName("ka") val vaName = ctx.freshName("va") + val keyName = ctx.freshName("key") + val valueName = ctx.freshName("value") val mapMerge = s""" |$hashMapClass $unionMapName = new $hashMapClass(); - |for (int $index1Name = 0; $index1Name < $mapRefArrayName.size; $index1Name++) { + |for (int $index1Name = 0; $index1Name < $mapRefArrayName.length; $index1Name++) { | boolean isNull = $mapNullArrayName[$index1Name]; | if (isNull) { | continue; | } | MapData $mapDataName = ($mapDataClass) $mapRefArrayName[$index1Name]; - | Object[] $kaName = (Object[]) $mapDataName.keyArray().toObjectArray(); - | Object[] $vaName = (Object[]) $mapDataName.valueArray().toObjectArray(); - | for (int $index2Name = 0; $index2Name < $kaName.length; $index2Name++) { - | $unionMapName.put($kaName[$index2Name], $vaName[$index2Name]); + | $arrayDataClass $kaName = $mapDataName.keyArray(); + | $arrayDataClass $vaName = $mapDataName.valueArray(); + | for (int $index2Name = 0; $index2Name < $kaName.numElements(); $index2Name++) { + | Object $keyName = ${CodeGenerator.getValue(kaName, keyTypes.head, index2Name)}; + | Object $valueName = ${CodeGenerator.getValue(vaName, valueTypes.head, index2Name)}; + | $unionMapName.put($keyName, $valueName); | } |} """.stripMargin @@ -619,16 +625,18 @@ case class MapConcat(children: Seq[Expression]) extends Expression val entrySetName = ctx.freshName("entrySet") val createMapData = s""" - |$entryClass[] entries = $unionMapName.entrySet().toArray(); - |Object[] $mergedKeyArrayName = new 
Object[$unionMapName.size]; - |Object[] $mergedValueArrayName = new Object[$unionMapName.size]; - |for (int $index1Name = 0; $index1Name < $entrySetName.length(); $index1Name++) { - | $entryClass entry = $entrySetName[$index1Name]; + |Object[] $entrySetName = $unionMapName.entrySet().toArray(); + |Object[] $mergedKeyArrayName = new Object[$unionMapName.size()]; + |Object[] $mergedValueArrayName = new Object[$unionMapName.size()]; + |for (int $index1Name = 0; $index1Name < $entrySetName.length; $index1Name++) { + | $entryClass entry = + | ($entryClass) $entrySetName[$index1Name]; | $mergedKeyArrayName[$index1Name] = (Object) entry.getKey(); | $mergedValueArrayName[$index1Name] = (Object) entry.getValue(); |} - |${ev.value} = new $arrayBasedMapDataClass(new $arrayDataClass($mergedKeyArrayName), - | new $arrayDataClass($mergedValueArrayName)); + |$mapDataClass ${ev.value} = + | new $arrayBasedMapDataClass(new $genericArrayDataClass($mergedKeyArrayName), + | new $genericArrayDataClass($mergedValueArrayName)); """.stripMargin val code = s""" From e08362a495c5473c46e13a224de0da57e8af29d0 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Tue, 17 Apr 2018 11:12:17 -0700 Subject: [PATCH 06/31] Checkpoint better working codegen --- .../sql/catalyst/expressions/codegen/CodeGenerator.scala | 2 +- .../sql/catalyst/expressions/collectionOperations.scala | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index da8c39b87a741..838c045d5bcce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1337,7 +1337,7 @@ object CodeGenerator extends Logging { )) evaluator.setExtendedClass(classOf[GeneratedClass]) - logWarning({ + logDebug({ // Only add extra debugging info to byte code when we are going to print the source code. 
evaluator.setDebuggingInformation(true, true, false) s"\n${CodeFormatter.format(code)}" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 5103567652df1..dcf0cefed32eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -600,13 +600,14 @@ case class MapConcat(children: Seq[Expression]) extends Expression val vaName = ctx.freshName("va") val keyName = ctx.freshName("key") val valueName = ctx.freshName("value") + val isNullCheckName = ctx.freshName("isNull") val mapMerge = s""" |$hashMapClass $unionMapName = new $hashMapClass(); |for (int $index1Name = 0; $index1Name < $mapRefArrayName.length; $index1Name++) { - | boolean isNull = $mapNullArrayName[$index1Name]; - | if (isNull) { + | boolean $isNullCheckName = $mapNullArrayName[$index1Name]; + | if ($isNullCheckName) { | continue; | } | MapData $mapDataName = ($mapDataClass) $mapRefArrayName[$index1Name]; From 2032801970ded87ba5d55443c94f332889579b1a Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Tue, 17 Apr 2018 11:29:22 -0700 Subject: [PATCH 07/31] Require at least two input maps --- .../sql/catalyst/expressions/collectionOperations.scala | 6 +++++- .../catalyst/expressions/CollectionExpressionsSuite.scala | 2 -- .../org/apache/spark/sql/DataFrameFunctionsSuite.scala | 7 +++---- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index dcf0cefed32eb..a035afa155206 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -520,7 +520,10 @@ case class MapConcat(children: Seq[Expression]) extends Expression // this check currently does not allow valueContainsNull to vary, // and unfortunately none of the MapType toString methods include // valueContainsNull for the error message - if (children.exists(!_.dataType.isInstanceOf[MapType])) { + if (children.size < 2) { + TypeCheckResult.TypeCheckFailure( + s"$prettyName expects at least two input maps.") + } else if (children.exists(!_.dataType.isInstanceOf[MapType])) { TypeCheckResult.TypeCheckFailure( s"The given input of function $prettyName should all be of type map, " + "but they are " + children.map(_.dataType.simpleString).mkString("[", ", ", "]")) @@ -532,6 +535,7 @@ case class MapConcat(children: Seq[Expression]) extends Expression TypeCheckResult.TypeCheckSuccess } } + override def dataType: MapType = { children.headOption.map(_.dataType.asInstanceOf[MapType]) .getOrElse(MapType(keyType = StringType, valueType = StringType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index be6570796cfaa..ec30dd9d48d5b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ 
-113,8 +113,6 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper // 3 maps checkEvaluation(MapConcat(Seq(m0, m1, m2)), mutable.LinkedHashMap("a" -> "4", "b" -> "2", "c" -> "3", "d" -> "4", "e" -> "5")) - // no input - checkEvaluation(MapConcat(Seq()), Map()) // null map checkEvaluation(MapConcat(Seq(m0, mNull)), mutable.LinkedHashMap("a" -> "1", "b" -> "2")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 420440ffe4440..b3174222f649c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -674,16 +674,15 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { val df2 = Seq( (Map[Int, Int](1 -> 100, 2 -> 200), Map[String, Int]("3" -> 300, "4" -> 400)) ).toDF("map1", "map2") - checkAnswer( - df2.selectExpr("map_concat()"), - Seq(Row(Map())) - ) assert(intercept[AnalysisException] { df2.selectExpr("map_concat(map1, map2)").collect() }.getMessage().contains("input maps of function map_concat should all be the same type")) assert(intercept[AnalysisException] { df2.selectExpr("map_concat(map1, 12)").collect() }.getMessage().contains("input of function map_concat should all be of type map")) + assert(intercept[AnalysisException] { + df2.selectExpr("map_concat()").collect() + }.getMessage().contains("expects at least two input maps")) } test("map_from_entries function") { From e149d060d3ad5fee2b13faca03a06fb2a319ca7a Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Tue, 17 Apr 2018 12:05:15 -0700 Subject: [PATCH 08/31] Small cleanup --- .../expressions/collectionOperations.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index a035afa155206..70acc2f2a7490 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -562,8 +562,8 @@ case class MapConcat(children: Seq[Expression]) extends Expression override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val mapCodes = children.map(c => c.genCode(ctx)) - val keyTypes = children.map(c => c.dataType.asInstanceOf[MapType].keyType) - val valueTypes = children.map(c => c.dataType.asInstanceOf[MapType].valueType) + val keyType = children.head.dataType.asInstanceOf[MapType].keyType + val valueType = children.head.dataType.asInstanceOf[MapType].valueType val mapRefArrayName = ctx.freshName("mapRefArray") val mapNullArrayName = ctx.freshName("mapNullArray") val unionMapName = ctx.freshName("union") @@ -578,7 +578,7 @@ case class MapConcat(children: Seq[Expression]) extends Expression val init = s""" |boolean[] $mapNullArrayName = new boolean[${mapCodes.size}]; - |Object[] $mapRefArrayName = new Object[${mapCodes.size}]; + |$mapDataClass[] $mapRefArrayName = new $mapDataClass[${mapCodes.size}]; |boolean ${ev.isNull} = false; """.stripMargin @@ -614,12 +614,12 @@ case class MapConcat(children: Seq[Expression]) extends Expression | if ($isNullCheckName) { | continue; | } - | MapData $mapDataName = ($mapDataClass) $mapRefArrayName[$index1Name]; + | $mapDataClass $mapDataName 
= $mapRefArrayName[$index1Name]; | $arrayDataClass $kaName = $mapDataName.keyArray(); | $arrayDataClass $vaName = $mapDataName.valueArray(); | for (int $index2Name = 0; $index2Name < $kaName.numElements(); $index2Name++) { - | Object $keyName = ${CodeGenerator.getValue(kaName, keyTypes.head, index2Name)}; - | Object $valueName = ${CodeGenerator.getValue(vaName, valueTypes.head, index2Name)}; + | Object $keyName = ${CodeGenerator.getValue(kaName, keyType, index2Name)}; + | Object $valueName = ${CodeGenerator.getValue(vaName, valueType, index2Name)}; | $unionMapName.put($keyName, $valueName); | } |} @@ -631,8 +631,8 @@ case class MapConcat(children: Seq[Expression]) extends Expression val createMapData = s""" |Object[] $entrySetName = $unionMapName.entrySet().toArray(); - |Object[] $mergedKeyArrayName = new Object[$unionMapName.size()]; - |Object[] $mergedValueArrayName = new Object[$unionMapName.size()]; + |Object[] $mergedKeyArrayName = new Object[$entrySetName.length]; + |Object[] $mergedValueArrayName = new Object[$entrySetName.length]; |for (int $index1Name = 0; $index1Name < $entrySetName.length; $index1Name++) { | $entryClass entry = | ($entryClass) $entrySetName[$index1Name]; From e4170cf5bf7b191eb85b60361ad0ad1ca66e121b Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Tue, 17 Apr 2018 13:06:11 -0700 Subject: [PATCH 09/31] Remove redundant null check --- .../expressions/collectionOperations.scala | 17 ++++------------- .../CollectionExpressionsSuite.scala | 3 +++ .../spark/sql/DataFrameFunctionsSuite.scala | 10 ++++++++++ 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 70acc2f2a7490..0e0e74cb72fb4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -565,7 +565,6 @@ case class MapConcat(children: Seq[Expression]) extends Expression val keyType = children.head.dataType.asInstanceOf[MapType].keyType val valueType = children.head.dataType.asInstanceOf[MapType].valueType val mapRefArrayName = ctx.freshName("mapRefArray") - val mapNullArrayName = ctx.freshName("mapNullArray") val unionMapName = ctx.freshName("union") val mapDataClass = classOf[MapData].getName @@ -577,23 +576,16 @@ case class MapConcat(children: Seq[Expression]) extends Expression val init = s""" - |boolean[] $mapNullArrayName = new boolean[${mapCodes.size}]; |$mapDataClass[] $mapRefArrayName = new $mapDataClass[${mapCodes.size}]; |boolean ${ev.isNull} = false; """.stripMargin val assignments = mapCodes.zipWithIndex.map { case (m, i) => val initCode = mapCodes(i).code - val isNullVarname = mapCodes(i).isNull - val valueVarname = mapCodes(i).value.code + val valueVarName = mapCodes(i).value.code s""" |$initCode - |$mapNullArrayName[$i] = $isNullVarname; - | if (!$mapNullArrayName[$i]) { - | $mapRefArrayName[$i] = $valueVarname; - | } else { - | $mapRefArrayName[$i] = null; - | } + |$mapRefArrayName[$i] = $valueVarName; """.stripMargin }.mkString("\n") @@ -610,11 +602,10 @@ case class MapConcat(children: Seq[Expression]) extends Expression s""" |$hashMapClass $unionMapName = new $hashMapClass(); |for (int $index1Name = 0; $index1Name < $mapRefArrayName.length; $index1Name++) { - | boolean $isNullCheckName = $mapNullArrayName[$index1Name]; - | if 
($isNullCheckName) { + | $mapDataClass $mapDataName = $mapRefArrayName[$index1Name]; + | if ($mapDataName == null) { | continue; | } - | $mapDataClass $mapDataName = $mapRefArrayName[$index1Name]; | $arrayDataClass $kaName = $mapDataName.keyArray(); | $arrayDataClass $vaName = $mapDataName.valueArray(); | for (int $index2Name = 0; $index2Name < $kaName.numElements(); $index2Name++) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index ec30dd9d48d5b..914b8f522ac4e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -116,6 +116,9 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper // null map checkEvaluation(MapConcat(Seq(m0, mNull)), mutable.LinkedHashMap("a" -> "1", "b" -> "2")) + // Only null maps + checkEvaluation(MapConcat(Seq(mNull, mNull)), + Map()) } test("MapFromEntries") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index b3174222f649c..9ecdb615d9877 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -683,6 +683,16 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(intercept[AnalysisException] { df2.selectExpr("map_concat()").collect() }.getMessage().contains("expects at least two input maps")) + + val df3 = Seq( + (null.asInstanceOf[Map[Int, Int]], null.asInstanceOf[Map[Int, Int]]) + ).toDF("map1", "map2") + checkAnswer( + df3.selectExpr("map_concat(map1, map2)"), + Seq( + Row(Map()) + ) + ) } test("map_from_entries function") { From b3085f09e8f411ecf55694a13639513a4679d662 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Thu, 19 Apr 2018 07:08:57 -0700 Subject: [PATCH 10/31] Any null input means null result (ala Presto) --- .../expressions/collectionOperations.scala | 25 ++++++++++++------- .../CollectionExpressionsSuite.scala | 5 +--- .../spark/sql/DataFrameFunctionsSuite.scala | 12 +-------- 3 files changed, 18 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 0e0e74cb72fb4..28930c69a3db6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -541,17 +541,18 @@ case class MapConcat(children: Seq[Expression]) extends Expression .getOrElse(MapType(keyType = StringType, valueType = StringType)) } - override def nullable: Boolean = false + override def nullable: Boolean = true override def eval(input: InternalRow): Any = { val union = new util.LinkedHashMap[Any, Any]() children.map(_.eval(input)).foreach { raw => - if (raw != null) { - val map = raw.asInstanceOf[MapData] - map.foreach(dataType.keyType, dataType.valueType, (k, v) => - union.put(k, v) - ) + if (raw == null) { + return null } + val map = raw.asInstanceOf[MapData] + map.foreach(dataType.keyType, dataType.valueType, (k, v) 
=> + union.put(k, v) + ) } val (keyArray, valueArray) = union.entrySet().toArray().map { e => val e2 = e.asInstanceOf[java.util.Map.Entry[Any, Any]] @@ -578,6 +579,7 @@ case class MapConcat(children: Seq[Expression]) extends Expression s""" |$mapDataClass[] $mapRefArrayName = new $mapDataClass[${mapCodes.size}]; |boolean ${ev.isNull} = false; + |$mapDataClass ${ev.value} = null; """.stripMargin val assignments = mapCodes.zipWithIndex.map { case (m, i) => @@ -586,6 +588,9 @@ case class MapConcat(children: Seq[Expression]) extends Expression s""" |$initCode |$mapRefArrayName[$i] = $valueVarName; + |if ($valueVarName == null) { + | ${ev.isNull} = true; + |} """.stripMargin }.mkString("\n") @@ -630,7 +635,7 @@ case class MapConcat(children: Seq[Expression]) extends Expression | $mergedKeyArrayName[$index1Name] = (Object) entry.getKey(); | $mergedValueArrayName[$index1Name] = (Object) entry.getValue(); |} - |$mapDataClass ${ev.value} = + |${ev.value} = | new $arrayBasedMapDataClass(new $genericArrayDataClass($mergedKeyArrayName), | new $genericArrayDataClass($mergedValueArrayName)); """.stripMargin @@ -638,8 +643,10 @@ case class MapConcat(children: Seq[Expression]) extends Expression s""" |$init |$assignments - |$mapMerge - $createMapData + | if (!${ev.isNull}) { + | $mapMerge + | $createMapData + |} """.stripMargin ev.copy(code = code) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 914b8f522ac4e..4b57314acafe2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -115,10 +115,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper mutable.LinkedHashMap("a" -> "4", "b" -> "2", "c" -> "3", "d" -> "4", "e" -> "5")) // null map checkEvaluation(MapConcat(Seq(m0, mNull)), - mutable.LinkedHashMap("a" -> "1", "b" -> "2")) - // Only null maps - checkEvaluation(MapConcat(Seq(mNull, mNull)), - Map()) + null) } test("MapFromEntries") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 9ecdb615d9877..d3b5a353d0f0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -667,7 +667,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Seq( Row(Map(1 -> 100, 2 -> 200, 3 -> 300, 4 -> 400)), Row(Map(1 -> 400, 2 -> 200, 3 -> 300)), - Row(Map(3 -> 300, 4 -> 400)) + Row(null) ) ) @@ -683,16 +683,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(intercept[AnalysisException] { df2.selectExpr("map_concat()").collect() }.getMessage().contains("expects at least two input maps")) - - val df3 = Seq( - (null.asInstanceOf[Map[Int, Int]], null.asInstanceOf[Map[Int, Int]]) - ).toDF("map1", "map2") - checkAnswer( - df3.selectExpr("map_concat(map1, map2)"), - Seq( - Row(Map()) - ) - ) } test("map_from_entries function") { From 71f01513ead7aa6aa19b939ce56a536b80337e20 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Thu, 19 Apr 2018 13:04:51 -0700 Subject: [PATCH 11/31] Remove redundant null check --- 
.../spark/sql/catalyst/expressions/collectionOperations.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 28930c69a3db6..c77bd9b966056 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -608,9 +608,6 @@ case class MapConcat(children: Seq[Expression]) extends Expression |$hashMapClass $unionMapName = new $hashMapClass(); |for (int $index1Name = 0; $index1Name < $mapRefArrayName.length; $index1Name++) { | $mapDataClass $mapDataName = $mapRefArrayName[$index1Name]; - | if ($mapDataName == null) { - | continue; - | } | $arrayDataClass $kaName = $mapDataName.keyArray(); | $arrayDataClass $vaName = $mapDataName.valueArray(); | for (int $index2Name = 0; $index2Name < $kaName.numElements(); $index2Name++) { From 006835dfccedd02a1d7d7cd4106013c0f5a0e9e7 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Mon, 23 Apr 2018 20:57:17 -0700 Subject: [PATCH 12/31] Review feedback --- .../sql/catalyst/expressions/collectionOperations.scala | 7 +++---- .../catalyst/expressions/CollectionExpressionsSuite.scala | 2 ++ .../org/apache/spark/sql/DataFrameFunctionsSuite.scala | 4 ++++ 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c77bd9b966056..e08e6f71e5825 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -601,7 +601,6 @@ case class MapConcat(children: Seq[Expression]) extends Expression val vaName = ctx.freshName("va") val keyName = ctx.freshName("key") val valueName = ctx.freshName("value") - val isNullCheckName = ctx.freshName("isNull") val mapMerge = s""" @@ -629,8 +628,8 @@ case class MapConcat(children: Seq[Expression]) extends Expression |for (int $index1Name = 0; $index1Name < $entrySetName.length; $index1Name++) { | $entryClass entry = | ($entryClass) $entrySetName[$index1Name]; - | $mergedKeyArrayName[$index1Name] = (Object) entry.getKey(); - | $mergedValueArrayName[$index1Name] = (Object) entry.getValue(); + | $mergedKeyArrayName[$index1Name] = entry.getKey(); + | $mergedValueArrayName[$index1Name] = entry.getValue(); |} |${ev.value} = | new $arrayBasedMapDataClass(new $genericArrayDataClass($mergedKeyArrayName), @@ -640,7 +639,7 @@ case class MapConcat(children: Seq[Expression]) extends Expression s""" |$init |$assignments - | if (!${ev.isNull}) { + |if (!${ev.isNull}) { | $mapMerge | $createMapData |} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 4b57314acafe2..0e817ed279e38 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -116,6 +116,8 @@ class CollectionExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper // null map checkEvaluation(MapConcat(Seq(m0, mNull)), null) + checkEvaluation(MapConcat(Seq(mNull, m0)), + null) } test("MapFromEntries") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index d3b5a353d0f0d..61f46c1747e5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -662,6 +662,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { (Map[Int, Int](1 -> 100, 2 -> 200), Map[Int, Int](3 -> 300, 1 -> 400)), (null, Map[Int, Int](3 -> 300, 4 -> 400)) ).toDF("map1", "map2") + checkAnswer( df1.selectExpr("map_concat(map1, map2)"), Seq( @@ -674,12 +675,15 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { val df2 = Seq( (Map[Int, Int](1 -> 100, 2 -> 200), Map[String, Int]("3" -> 300, "4" -> 400)) ).toDF("map1", "map2") + assert(intercept[AnalysisException] { df2.selectExpr("map_concat(map1, map2)").collect() }.getMessage().contains("input maps of function map_concat should all be the same type")) + assert(intercept[AnalysisException] { df2.selectExpr("map_concat(map1, 12)").collect() }.getMessage().contains("input of function map_concat should all be of type map")) + assert(intercept[AnalysisException] { df2.selectExpr("map_concat()").collect() }.getMessage().contains("expects at least two input maps")) From cf64d83e3306f20eef0bf6d9889b778f64be09b4 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Fri, 27 Apr 2018 10:16:26 -0700 Subject: [PATCH 13/31] Check for null value in generated code --- .../expressions/collectionOperations.scala | 10 ++++++---- .../expressions/CollectionExpressionsSuite.scala | 15 +++++++++++++++ 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e08e6f71e5825..08d4b5216d9ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -513,8 +513,7 @@ examples = """ > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd')); [[1 -> "a"], [2 -> "c"], [3 -> "d"] """) -case class MapConcat(children: Seq[Expression]) extends Expression - with CodegenFallback { +case class MapConcat(children: Seq[Expression]) extends Expression { override def checkInputDataTypes(): TypeCheckResult = { // this check currently does not allow valueContainsNull to vary, @@ -541,7 +540,7 @@ case class MapConcat(children: Seq[Expression]) extends Expression .getOrElse(MapType(keyType = StringType, valueType = StringType)) } - override def nullable: Boolean = true + override def nullable: Boolean = children.exists(_.nullable) override def eval(input: InternalRow): Any = { val union = new util.LinkedHashMap[Any, Any]() @@ -611,7 +610,10 @@ case class MapConcat(children: Seq[Expression]) extends Expression | $arrayDataClass $vaName = $mapDataName.valueArray(); | for (int $index2Name = 0; $index2Name < $kaName.numElements(); $index2Name++) { | Object $keyName = ${CodeGenerator.getValue(kaName, keyType, index2Name)}; - | Object $valueName = ${CodeGenerator.getValue(vaName, valueType, index2Name)}; + | Object $valueName = null; + | if 
(!${vaName}.isNullAt($index2Name)) { + | $valueName = ${CodeGenerator.getValue(vaName, valueType, index2Name)}; + | } | $unionMapName.put($keyName, $valueName); | } |} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 0e817ed279e38..d66cf52d2f230 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -102,17 +102,32 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val m0 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType)) val m1 = Literal.create(Map("c" -> "3", "a" -> "4"), MapType(StringType, StringType)) val m2 = Literal.create(Map("d" -> "4", "e" -> "5"), MapType(StringType, StringType)) + val m3 = Literal.create(Map("a" -> "1", "b" -> null), MapType(StringType, StringType)) + val m4 = Literal.create(Map("a" -> null, "b" -> "2"), MapType(StringType, StringType)) + val m5 = Literal.create(Map("a" -> 1, "b" -> null), MapType(StringType, IntegerType)) + val m6 = Literal.create(Map("a" -> null, "b" -> 2), MapType(StringType, IntegerType)) val mNull = Literal.create(null, MapType(StringType, StringType)) // overlapping maps checkEvaluation(MapConcat(Seq(m0, m1)), mutable.LinkedHashMap("a" -> "4", "b" -> "2", "c" -> "3")) + // maps with no overlap checkEvaluation(MapConcat(Seq(m0, m2)), mutable.LinkedHashMap("a" -> "1", "b" -> "2", "d" -> "4", "e" -> "5")) + // 3 maps checkEvaluation(MapConcat(Seq(m0, m1, m2)), mutable.LinkedHashMap("a" -> "4", "b" -> "2", "c" -> "3", "d" -> "4", "e" -> "5")) + + // null reference values + checkEvaluation(MapConcat(Seq(m3, m4)), + mutable.LinkedHashMap("a" -> null, "b" -> "2")) + + // null primitive values + checkEvaluation(MapConcat(Seq(m5, m6)), + mutable.LinkedHashMap("a" -> null, "b" -> 2)) + // null map checkEvaluation(MapConcat(Seq(m0, mNull)), null) From fbe00b2f7b6e40da47ee6aca439738bd9ac59b72 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Fri, 27 Apr 2018 10:26:31 -0700 Subject: [PATCH 14/31] Add since to expression description --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 08d4b5216d9ac..78c7629add0bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -512,7 +512,7 @@ examples = """ Examples: > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd')); [[1 -> "a"], [2 -> "c"], [3 -> "d"] - """) + """, since = "2.4.0") case class MapConcat(children: Seq[Expression]) extends Expression { override def checkInputDataTypes(): TypeCheckResult = { From 83784cc92c51238eea1843ddf7fe27c7f6ea758f Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Fri, 27 Apr 2018 12:56:15 -0700 Subject: [PATCH 15/31] Allow valueContainsNull to vary; Make checkInputDataTypes more in line with other expressions --- .../expressions/collectionOperations.scala | 30 ++++++++------ .../CollectionExpressionsSuite.scala | 40 
+++++++++++++++---- .../spark/sql/DataFrameFunctionsSuite.scala | 20 ++++++++-- 3 files changed, 66 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 78c7629add0bf..ffa09c446854a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -516,17 +516,16 @@ examples = """ case class MapConcat(children: Seq[Expression]) extends Expression { override def checkInputDataTypes(): TypeCheckResult = { - // this check currently does not allow valueContainsNull to vary, - // and unfortunately none of the MapType toString methods include - // valueContainsNull for the error message - if (children.size < 2) { - TypeCheckResult.TypeCheckFailure( - s"$prettyName expects at least two input maps.") - } else if (children.exists(!_.dataType.isInstanceOf[MapType])) { + // check key types and value types separately to allow valueContainsNull to vary + if (children.exists(!_.dataType.isInstanceOf[MapType])) { TypeCheckResult.TypeCheckFailure( s"The given input of function $prettyName should all be of type map, " + "but they are " + children.map(_.dataType.simpleString).mkString("[", ", ", "]")) - } else if (children.map(_.dataType).distinct.length > 1) { + } else if (children.map(_.dataType.asInstanceOf[MapType].keyType).distinct.length > 1) { + TypeCheckResult.TypeCheckFailure( + s"The given input maps of function $prettyName should all be the same type, " + + "but they are " + children.map(_.dataType.simpleString).mkString("[", ", ", "]")) + } else if (children.map(_.dataType.asInstanceOf[MapType].valueType).distinct.length > 1) { TypeCheckResult.TypeCheckFailure( s"The given input maps of function $prettyName should all be the same type, " + "but they are " + children.map(_.dataType.simpleString).mkString("[", ", ", "]")) @@ -536,8 +535,15 @@ case class MapConcat(children: Seq[Expression]) extends Expression { } override def dataType: MapType = { - children.headOption.map(_.dataType.asInstanceOf[MapType]) - .getOrElse(MapType(keyType = StringType, valueType = StringType)) + MapType( + keyType = children.headOption + .map(_.dataType.asInstanceOf[MapType].keyType).getOrElse(StringType), + valueType = children.headOption + .map(_.dataType.asInstanceOf[MapType].valueType).getOrElse(StringType), + valueContainsNull = children.map { c => + c.dataType.asInstanceOf[MapType] + }.exists(_.valueContainsNull) + ) } override def nullable: Boolean = children.exists(_.nullable) @@ -562,8 +568,8 @@ case class MapConcat(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val mapCodes = children.map(c => c.genCode(ctx)) - val keyType = children.head.dataType.asInstanceOf[MapType].keyType - val valueType = children.head.dataType.asInstanceOf[MapType].valueType + val keyType = dataType.keyType + val valueType = dataType.valueType val mapRefArrayName = ctx.freshName("mapRefArray") val unionMapName = ctx.freshName("union") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index d66cf52d2f230..043fec1c60a08 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -99,13 +99,15 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapEntries(ms2), null) test("Map Concat") { - val m0 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType)) - val m1 = Literal.create(Map("c" -> "3", "a" -> "4"), MapType(StringType, StringType)) + val m0 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType, + valueContainsNull = false)) + val m1 = Literal.create(Map("c" -> "3", "a" -> "4"), MapType(StringType, StringType, + valueContainsNull = false)) val m2 = Literal.create(Map("d" -> "4", "e" -> "5"), MapType(StringType, StringType)) - val m3 = Literal.create(Map("a" -> "1", "b" -> null), MapType(StringType, StringType)) - val m4 = Literal.create(Map("a" -> null, "b" -> "2"), MapType(StringType, StringType)) - val m5 = Literal.create(Map("a" -> 1, "b" -> null), MapType(StringType, IntegerType)) - val m6 = Literal.create(Map("a" -> null, "b" -> 2), MapType(StringType, IntegerType)) + val m3 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType)) + val m4 = Literal.create(Map("a" -> null, "c" -> "3"), MapType(StringType, StringType)) + val m5 = Literal.create(Map("a" -> 1, "b" -> 2), MapType(StringType, IntegerType)) + val m6 = Literal.create(Map("a" -> null, "c" -> 3), MapType(StringType, IntegerType)) val mNull = Literal.create(null, MapType(StringType, StringType)) // overlapping maps @@ -122,17 +124,39 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper // null reference values checkEvaluation(MapConcat(Seq(m3, m4)), - mutable.LinkedHashMap("a" -> null, "b" -> "2")) + mutable.LinkedHashMap("a" -> null, "b" -> "2", "c" -> "3")) // null primitive values checkEvaluation(MapConcat(Seq(m5, m6)), - mutable.LinkedHashMap("a" -> null, "b" -> 2)) + mutable.LinkedHashMap("a" -> null, "b" -> 2, "c" -> 3)) // null map checkEvaluation(MapConcat(Seq(m0, mNull)), null) checkEvaluation(MapConcat(Seq(mNull, m0)), null) + checkEvaluation(MapConcat(Seq(mNull, mNull)), + null) + + // single map + checkEvaluation(MapConcat(Seq(m0)), + mutable.LinkedHashMap("a" -> "1", "b" -> "2")) + + // no map + checkEvaluation(MapConcat(Seq()), + Map()) + + // argument checking + assert(MapConcat(Seq(m0, m1)).checkInputDataTypes().isSuccess) + assert(MapConcat(Seq(m5, m6)).checkInputDataTypes().isSuccess) + assert(MapConcat(Seq(m0, m5)).checkInputDataTypes().isFailure) + assert(MapConcat(Seq(m0, m1)).dataType.keyType == StringType) + assert(MapConcat(Seq(m0, m1)).dataType.valueType == StringType) + assert(MapConcat(Seq(m0, m1)).dataType.valueContainsNull == false) + assert(MapConcat(Seq(m5, m6)).dataType.keyType == StringType) + assert(MapConcat(Seq(m5, m6)).dataType.valueType == IntegerType) + assert(MapConcat(Seq(m5, m6)).dataType.valueContainsNull == true) + assert(MapConcat(Seq(m6, m5)).dataType.valueContainsNull == true) } test("MapFromEntries") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 61f46c1747e5e..531eb3417629f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -672,10 +672,26 @@ class 
DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) ) + checkAnswer( + df1.selectExpr("map_concat(map1)"), + Seq( + Row(Map(1 -> 100, 2 -> 200)), + Row(Map(1 -> 100, 2 -> 200)), + Row(null) + ) + ) + val df2 = Seq( (Map[Int, Int](1 -> 100, 2 -> 200), Map[String, Int]("3" -> 300, "4" -> 400)) ).toDF("map1", "map2") + checkAnswer( + df2.selectExpr("map_concat()"), + Seq( + Row(Map()) + ) + ) + assert(intercept[AnalysisException] { df2.selectExpr("map_concat(map1, map2)").collect() }.getMessage().contains("input maps of function map_concat should all be the same type")) @@ -683,10 +699,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(intercept[AnalysisException] { df2.selectExpr("map_concat(map1, 12)").collect() }.getMessage().contains("input of function map_concat should all be of type map")) - - assert(intercept[AnalysisException] { - df2.selectExpr("map_concat()").collect() - }.getMessage().contains("expects at least two input maps")) } test("map_from_entries function") { From 79f93043c5d4bd5aedfe72b3586fa4def021232f Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Fri, 27 Apr 2018 13:48:22 -0700 Subject: [PATCH 16/31] Add a few more tests --- .../catalyst/expressions/CollectionExpressionsSuite.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 043fec1c60a08..84ad20ff6786b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -146,15 +146,18 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapConcat(Seq()), Map()) - // argument checking + // argument checking assert(MapConcat(Seq(m0, m1)).checkInputDataTypes().isSuccess) assert(MapConcat(Seq(m5, m6)).checkInputDataTypes().isSuccess) assert(MapConcat(Seq(m0, m5)).checkInputDataTypes().isFailure) + assert(MapConcat(Seq(m0, Literal(12))).checkInputDataTypes().isFailure) assert(MapConcat(Seq(m0, m1)).dataType.keyType == StringType) assert(MapConcat(Seq(m0, m1)).dataType.valueType == StringType) assert(MapConcat(Seq(m0, m1)).dataType.valueContainsNull == false) assert(MapConcat(Seq(m5, m6)).dataType.keyType == StringType) assert(MapConcat(Seq(m5, m6)).dataType.valueType == IntegerType) + assert(MapConcat(Seq()).dataType.keyType == StringType) + assert(MapConcat(Seq()).dataType.valueType == StringType) assert(MapConcat(Seq(m5, m6)).dataType.valueContainsNull == true) assert(MapConcat(Seq(m6, m5)).dataType.valueContainsNull == true) } From cb0f57ff789c7a253fe9239c73372781d2ee9989 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Fri, 27 Apr 2018 13:58:57 -0700 Subject: [PATCH 17/31] One more test --- .../sql/catalyst/expressions/CollectionExpressionsSuite.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 84ad20ff6786b..245f1b3b32722 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ 
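The two checks above fix the degenerate arities: a single-map call is the identity, and a zero-argument call yields an empty map on every row. From a spark-shell session (a sketch; assumes spark.implicits._ is in scope):

    // Assumes an active SparkSession `spark` with spark.implicits._ imported.
    val df = Seq(Map(1 -> 100, 2 -> 200)).toDF("m")
    df.selectExpr("map_concat(m)").show(false)   // identity: [1 -> 100, 2 -> 200]
    df.selectExpr("map_concat()").show(false)    // zero arguments: [] (empty map)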
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -137,6 +137,8 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper null) checkEvaluation(MapConcat(Seq(mNull, mNull)), null) + checkEvaluation(MapConcat(Seq(mNull)), + null) // single map checkEvaluation(MapConcat(Seq(m0)), From 57d10cb73da224b1b8cbaa35d3b89b8d8e0ea60e Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Mon, 30 Apr 2018 06:57:48 -0700 Subject: [PATCH 18/31] Fix import statement; Add two small tests --- .../catalyst/expressions/CollectionExpressionsSuite.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 245f1b3b32722..b0263126a64e7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -160,8 +160,10 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper assert(MapConcat(Seq(m5, m6)).dataType.valueType == IntegerType) assert(MapConcat(Seq()).dataType.keyType == StringType) assert(MapConcat(Seq()).dataType.valueType == StringType) - assert(MapConcat(Seq(m5, m6)).dataType.valueContainsNull == true) - assert(MapConcat(Seq(m6, m5)).dataType.valueContainsNull == true) + assert(MapConcat(Seq(m5, m6)).dataType.valueContainsNull) + assert(MapConcat(Seq(m6, m5)).dataType.valueContainsNull) + assert(MapConcat(Seq(m1, m2)).nullable == false) + assert(MapConcat(Seq(m1, mNull)).nullable) } test("MapFromEntries") { From cda11581abfb73fb587400524f4b06b12ad8264d Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Wed, 2 May 2018 21:06:03 -0700 Subject: [PATCH 19/31] As per SPARK-9415, cannot compare Maps, therefore cannot support them as keys --- .../catalyst/expressions/collectionOperations.scala | 8 ++++++++ .../expressions/CollectionExpressionsSuite.scala | 13 +++++++++++++ 2 files changed, 21 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index ffa09c446854a..546be65cf697e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -521,6 +521,14 @@ case class MapConcat(children: Seq[Expression]) extends Expression { TypeCheckResult.TypeCheckFailure( s"The given input of function $prettyName should all be of type map, " + "but they are " + children.map(_.dataType.simpleString).mkString("[", ", ", "]")) + } else if (children.map(_.dataType.asInstanceOf[MapType].keyType) + .exists(_.isInstanceOf[MapType])) { + // map_concat needs to pick a winner when multiple maps contain the same key. map_concat + // can do that only if it can detect when two keys are the same. SPARK-9415 states "map type + // should not support equality, hash". 
As a result, map_concat does not support a map type + // as a key + TypeCheckResult.TypeCheckFailure( + s"The given input maps of function $prettyName cannot have a map type as a key") } else if (children.map(_.dataType.asInstanceOf[MapType].keyType).distinct.length > 1) { TypeCheckResult.TypeCheckFailure( s"The given input maps of function $prettyName should all be the same type, " + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index b0263126a64e7..41d723185be82 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -108,6 +108,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val m4 = Literal.create(Map("a" -> null, "c" -> "3"), MapType(StringType, StringType)) val m5 = Literal.create(Map("a" -> 1, "b" -> 2), MapType(StringType, IntegerType)) val m6 = Literal.create(Map("a" -> null, "c" -> 3), MapType(StringType, IntegerType)) + val m7 = Literal.create(Map(List(1, 2) -> 1, List(3, 4) -> 2), + MapType(ArrayType(IntegerType), IntegerType)) + val m8 = Literal.create(Map(List(5, 6) -> 3, List(1, 2) -> 4), + MapType(ArrayType(IntegerType), IntegerType)) + val m9 = Literal.create(Map(Map(1 -> 2, 3 -> 4) -> 1, Map(5 -> 6, 7 -> 8) -> 2), + MapType(MapType(IntegerType, IntegerType), IntegerType)) + val m10 = Literal.create(Map(Map(9 -> 10, 11 -> 12) -> 3, Map(1 -> 2, 3 -> 4) -> 4), + MapType(MapType(IntegerType, IntegerType), IntegerType)) val mNull = Literal.create(null, MapType(StringType, StringType)) // overlapping maps @@ -130,6 +138,10 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapConcat(Seq(m5, m6)), mutable.LinkedHashMap("a" -> null, "b" -> 2, "c" -> 3)) + // keys that are arrays, with overlap + checkEvaluation(MapConcat(Seq(m7, m8)), + mutable.LinkedHashMap(List(1, 2) -> 4, List(3, 4) -> 2, List(5, 6) -> 3)) + // null map checkEvaluation(MapConcat(Seq(m0, mNull)), null) @@ -151,6 +163,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper // argument checking assert(MapConcat(Seq(m0, m1)).checkInputDataTypes().isSuccess) assert(MapConcat(Seq(m5, m6)).checkInputDataTypes().isSuccess) + assert(MapConcat(Seq(m9, m10)).checkInputDataTypes().isFailure) assert(MapConcat(Seq(m0, m5)).checkInputDataTypes().isFailure) assert(MapConcat(Seq(m0, Literal(12))).checkInputDataTypes().isFailure) assert(MapConcat(Seq(m0, m1)).dataType.keyType == StringType) From 83deda47a842d61467b76102c210afae06b47ba9 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Sun, 6 May 2018 16:09:19 -0700 Subject: [PATCH 20/31] Updates for some review comments --- .../catalyst/expressions/collectionOperations.scala | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 546be65cf697e..768766dcbc89f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -567,11 +567,7 @@ case class 
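The restriction introduced here reduces to key equality: overwriting a duplicate requires deciding when two keys are the same, and SPARK-9415 deliberately leaves equality undefined for map values. A rough illustration of the hazard, with a stand-in for Catalyst's array-backed map representation:

    // Stand-in for ArrayBasedMapData: a map stored as parallel key/value arrays.
    case class ArrayMap(keys: Array[Int], values: Array[Int])

    val m1 = ArrayMap(Array(1, 3), Array(2, 4))  // {1 -> 2, 3 -> 4}
    val m2 = ArrayMap(Array(3, 1), Array(4, 2))  // same entries, different order

    // m1 == m2 is false here, and any by-value comparison would first have to
    // canonicalize entry order -- exactly the equality contract that SPARK-9415
    // declines to define for maps, so map-typed keys cannot be deduplicated.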
MapConcat(children: Seq[Expression]) extends Expression { union.put(k, v) ) } - val (keyArray, valueArray) = union.entrySet().toArray().map { e => - val e2 = e.asInstanceOf[java.util.Map.Entry[Any, Any]] - (e2.getKey, e2.getValue) - }.unzip - new ArrayBasedMapData(new GenericArrayData(keyArray), new GenericArrayData(valueArray)) + ArrayBasedMapData(union, (k: Any) => k, (v: Any) => v) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -596,8 +592,8 @@ case class MapConcat(children: Seq[Expression]) extends Expression { """.stripMargin val assignments = mapCodes.zipWithIndex.map { case (m, i) => - val initCode = mapCodes(i).code - val valueVarName = mapCodes(i).value.code + val initCode = m.code + val valueVarName = m.value.code s""" |$initCode |$mapRefArrayName[$i] = $valueVarName; From 370151e1f005d42f3199ff62b5d126de9d93dbce Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Sun, 6 May 2018 18:01:32 -0700 Subject: [PATCH 21/31] Initial commit of use of splitExpressionsWithCurrentInputs --- .../expressions/collectionOperations.scala | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 768766dcbc89f..8cda0f32f44f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -597,11 +597,13 @@ case class MapConcat(children: Seq[Expression]) extends Expression { s""" |$initCode |$mapRefArrayName[$i] = $valueVarName; - |if ($valueVarName == null) { - | ${ev.isNull} = true; - |} """.stripMargin - }.mkString("\n") + } + + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = assignments, + funcName = "mapConcat", + extraArguments = (s"${mapDataClass}[]", mapRefArrayName) :: Nil) val index1Name = ctx.freshName("idx1") val index2Name = ctx.freshName("idx2") @@ -616,6 +618,10 @@ case class MapConcat(children: Seq[Expression]) extends Expression { |$hashMapClass $unionMapName = new $hashMapClass(); |for (int $index1Name = 0; $index1Name < $mapRefArrayName.length; $index1Name++) { | $mapDataClass $mapDataName = $mapRefArrayName[$index1Name]; + | if ($mapDataName == null) { + | ${ev.isNull} = true; + | break; + | } | $arrayDataClass $kaName = $mapDataName.keyArray(); | $arrayDataClass $vaName = $mapDataName.valueArray(); | for (int $index2Name = 0; $index2Name < $kaName.numElements(); $index2Name++) { @@ -650,9 +656,9 @@ case class MapConcat(children: Seq[Expression]) extends Expression { val code = s""" |$init - |$assignments - |if (!${ev.isNull}) { + |$codes | $mapMerge + | if (!${ev.isNull}) { | $createMapData |} """.stripMargin From 330582718a14db7c491fab387dae384435f68dc0 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Sun, 6 May 2018 18:14:22 -0700 Subject: [PATCH 22/31] Add test --- .../catalyst/expressions/CollectionExpressionsSuite.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 41d723185be82..1ace30958e923 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ 
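PATCH 20 above also swaps the manual entry-set unzip for the ArrayBasedMapData factory that accepts a java.util.Map plus identity key and value converters. The last-wins, insertion-ordered union it relies on, sketched with ordinary Scala values:

    import java.util.LinkedHashMap
    import scala.collection.JavaConverters._

    // Last-wins union that preserves first-insertion key order -- map_concat's
    // semantics at this point in the series, before duplicates were allowed.
    def union(maps: Seq[Map[String, String]]): (Array[String], Array[String]) = {
      val u = new LinkedHashMap[String, String]()
      maps.foreach(_.foreach { case (k, v) => u.put(k, v) })
      (u.keySet().asScala.toArray, u.values().asScala.toArray)
    }

    // union(Seq(Map("a" -> "1", "b" -> "2"), Map("a" -> "4")))
    // => (Array(a, b), Array(4, 2)): "a" keeps its slot but takes the last value.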
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -160,6 +160,12 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapConcat(Seq()), Map()) + // force split expressions for input in generated code + checkEvaluation(MapConcat(Seq(m0, m2, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, + m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, + m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0)), + mutable.LinkedHashMap("a" -> "1", "b" -> "2", "d" -> "4", "e" -> "5")) + // argument checking assert(MapConcat(Seq(m0, m1)).checkInputDataTypes().isSuccess) assert(MapConcat(Seq(m5, m6)).checkInputDataTypes().isSuccess) From f967483d30ebe6eb7d99c3a70d08632a370c10f2 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Mon, 7 May 2018 21:41:48 +0800 Subject: [PATCH 23/31] Fix indentation --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 4 ++-- .../sql/catalyst/expressions/CollectionExpressionsSuite.scala | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 8cda0f32f44f4..51d530d2b62e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -657,8 +657,8 @@ case class MapConcat(children: Seq[Expression]) extends Expression { s""" |$init |$codes - | $mapMerge - | if (!${ev.isNull}) { + |$mapMerge + |if (!${ev.isNull}) { | $createMapData |} """.stripMargin diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 1ace30958e923..0afff33a53088 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -161,9 +161,9 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper Map()) // force split expressions for input in generated code - checkEvaluation(MapConcat(Seq(m0, m2, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, + checkEvaluation(MapConcat(Seq(m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, - m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0)), + m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m2, m0)), mutable.LinkedHashMap("a" -> "1", "b" -> "2", "d" -> "4", "e" -> "5")) // argument checking From d4371991702dbd445490c856bac212a566d6e5fb Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Thu, 24 May 2018 12:32:58 -0700 Subject: [PATCH 24/31] Fix after rebase --- .../sql/catalyst/expressions/collectionOperations.scala | 8 ++++---- .../catalyst/expressions/CollectionExpressionsSuite.scala | 1 + .../org/apache/spark/sql/DataFrameFunctionsSuite.scala | 1 + 3 files changed, 6 insertions(+), 4 deletions(-) diff --git 
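The long repeated-argument call in the test above exists to push the generated Java past the point where it must be split: the JVM rejects methods larger than 64KB of bytecode, so splitExpressionsWithCurrentInputs carves the child evaluations into helper methods. Reproducing that pressure from user code might look like this (hypothetical column name m; the repeat count just needs to exceed the split threshold):

    import org.apache.spark.sql.functions.{col, map_concat}

    // Enough child expressions that whole-stage codegen splits their evaluation
    // into separate helper methods rather than one oversized method.
    val wide = map_concat(Seq.fill(66)(col("m")): _*)
    // df.select(wide) then exercises the splitExpressionsWithCurrentInputs path.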
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 51d530d2b62e2..bfe63ae80a2f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -502,6 +502,7 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp } override def prettyName: String = "map_entries" +} /** * Returns the union of all the given maps. @@ -653,16 +654,15 @@ case class MapConcat(children: Seq[Expression]) extends Expression { | new $arrayBasedMapDataClass(new $genericArrayDataClass($mergedKeyArrayName), | new $genericArrayDataClass($mergedValueArrayName)); """.stripMargin - val code = - s""" + ev.copy( + code = code""" |$init |$codes |$mapMerge |if (!${ev.isNull}) { | $createMapData |} - """.stripMargin - ev.copy(code = code) + """.stripMargin) } override def prettyName: String = "map_concat" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 0afff33a53088..49c810fce784c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -97,6 +97,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapEntries(ms0), Seq(r("a", "c"), r("b", null))) checkEvaluation(MapEntries(ms1), Seq.empty) checkEvaluation(MapEntries(ms2), null) + } test("Map Concat") { val m0 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 531eb3417629f..958760814de6a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -655,6 +655,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(sdf.select(map_entries('m)), sExpected) checkAnswer(sdf.selectExpr("map_entries(m)"), sExpected) checkAnswer(sdf.filter(dummyFilter('m)).select(map_entries('m)), sExpected) + } test("map_concat function") { val df1 = Seq( From 206db97b6e7d26a47f4d6f4f042fafdc4878b822 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Thu, 31 May 2018 14:17:24 -0700 Subject: [PATCH 25/31] Review feedback: use pre-existing empty collections --- .../catalyst/expressions/CollectionExpressionsSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 49c810fce784c..242a204cd341c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -158,8 +158,8 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper 
mutable.LinkedHashMap("a" -> "1", "b" -> "2")) // no map - checkEvaluation(MapConcat(Seq()), - Map()) + checkEvaluation(MapConcat(Seq.empty), + Map.empty) // force split expressions for input in generated code checkEvaluation(MapConcat(Seq(m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, @@ -178,8 +178,8 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper assert(MapConcat(Seq(m0, m1)).dataType.valueContainsNull == false) assert(MapConcat(Seq(m5, m6)).dataType.keyType == StringType) assert(MapConcat(Seq(m5, m6)).dataType.valueType == IntegerType) - assert(MapConcat(Seq()).dataType.keyType == StringType) - assert(MapConcat(Seq()).dataType.valueType == StringType) + assert(MapConcat(Seq.empty).dataType.keyType == StringType) + assert(MapConcat(Seq.empty).dataType.valueType == StringType) assert(MapConcat(Seq(m5, m6)).dataType.valueContainsNull) assert(MapConcat(Seq(m6, m5)).dataType.valueContainsNull) assert(MapConcat(Seq(m1, m2)).nullable == false) From 549300f182d3901a0b5e0ecf6712b9b347616715 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Wed, 20 Jun 2018 14:02:15 -0700 Subject: [PATCH 26/31] Allow duplicate keys --- python/pyspark/sql/functions.py | 13 +- .../sql/catalyst/CatalystTypeConverters.scala | 6 + .../expressions/collectionOperations.scala | 205 +++++++++++------- .../CollectionExpressionsSuite.scala | 73 +++++-- 4 files changed, 199 insertions(+), 98 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 64e5bc4c2c688..5d74b702d9615 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2491,19 +2491,18 @@ def arrays_zip(*cols): @since(2.4) def map_concat(*cols): - """Returns the union of all the given maps. If a key is found in multiple given maps, - that key's value in the resulting map comes from the last one of those maps. + """Returns the union of all the given maps. 
:param cols: list of column names (string) or list of :class:`Column` expressions >>> from pyspark.sql.functions import map_concat >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as map1, map(3, 'c', 1, 'd') as map2") >>> df.select(map_concat("map1", "map2").alias("map3")).show(truncate=False) - +------------------------+ - |map3 | - +------------------------+ - |[1 -> d, 2 -> b, 3 -> c]| - +------------------------+ + +--------------------------------+ + |map3 | + +--------------------------------+ + |[1 -> a, 2 -> b, 3 -> c, 1 -> d]| + +--------------------------------+ """ sc = SparkContext._active_spark_context if len(cols) == 1 and isinstance(cols[0], (list, set)): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 93df73ab1eaf6..6f5fbdd79e668 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -431,6 +431,12 @@ object CatalystTypeConverters { map, (key: Any) => convertToCatalyst(key), (value: Any) => convertToCatalyst(value)) + case (keys: Array[_], values: Array[_]) => + // case for mapdata with duplicate keys + new ArrayBasedMapData( + new GenericArrayData(keys.map(convertToCatalyst)), + new GenericArrayData(values.map(convertToCatalyst)) + ) case other => other } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index bfe63ae80a2f5..bce9022b365d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -508,28 +508,22 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp * Returns the union of all the given maps. */ @ExpressionDescription( -usage = "_FUNC_(map, ...) - Returns the union of all the given maps", -examples = """ + usage = "_FUNC_(map, ...) - Returns the union of all the given maps", + examples = """ Examples: > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd')); - [[1 -> "a"], [2 -> "c"], [3 -> "d"] + [[1 -> "a"], [2 -> "b"], [2 -> "c"], [3 -> "d"]] """, since = "2.4.0") case class MapConcat(children: Seq[Expression]) extends Expression { + private val MAX_MAP_SIZE: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + override def checkInputDataTypes(): TypeCheckResult = { // check key types and value types separately to allow valueContainsNull to vary if (children.exists(!_.dataType.isInstanceOf[MapType])) { TypeCheckResult.TypeCheckFailure( s"The given input of function $prettyName should all be of type map, " + "but they are " + children.map(_.dataType.simpleString).mkString("[", ", ", "]")) - } else if (children.map(_.dataType.asInstanceOf[MapType].keyType) - .exists(_.isInstanceOf[MapType])) { - // map_concat needs to pick a winner when multiple maps contain the same key. map_concat - // can do that only if it can detect when two keys are the same. SPARK-9415 states "map type - // should not support equality, hash". 
As a result, map_concat does not support a map type - // as a key - TypeCheckResult.TypeCheckFailure( - s"The given input maps of function $prettyName cannot have a map type as a key") } else if (children.map(_.dataType.asInstanceOf[MapType].keyType).distinct.length > 1) { TypeCheckResult.TypeCheckFailure( s"The given input maps of function $prettyName should all be the same type, " + @@ -549,122 +543,179 @@ case class MapConcat(children: Seq[Expression]) extends Expression { .map(_.dataType.asInstanceOf[MapType].keyType).getOrElse(StringType), valueType = children.headOption .map(_.dataType.asInstanceOf[MapType].valueType).getOrElse(StringType), - valueContainsNull = children.map { c => - c.dataType.asInstanceOf[MapType] - }.exists(_.valueContainsNull) + valueContainsNull = children.map(_.dataType.asInstanceOf[MapType]) + .exists(_.valueContainsNull) ) } override def nullable: Boolean = children.exists(_.nullable) override def eval(input: InternalRow): Any = { - val union = new util.LinkedHashMap[Any, Any]() - children.map(_.eval(input)).foreach { raw => - if (raw == null) { - return null - } - val map = raw.asInstanceOf[MapData] - map.foreach(dataType.keyType, dataType.valueType, (k, v) => - union.put(k, v) - ) + val maps = children.map(_.eval(input)) + if (maps.contains(null)) { + return null + } + val keyArrayDatas = maps.map(_.asInstanceOf[MapData].keyArray()) + val valueArrayDatas = maps.map(_.asInstanceOf[MapData].valueArray()) + + val numElements = keyArrayDatas.foldLeft(0L)((sum, ad) => sum + ad.numElements()) + if (numElements > MAX_MAP_SIZE) { + throw new RuntimeException(s"Unsuccessful attempt to concat maps with $numElements" + + s" elements due to exceeding the map size limit" + + s" $MAX_MAP_SIZE.") + } + val finalKeyArray = new Array[AnyRef](numElements.toInt) + val finalValueArray = new Array[AnyRef](numElements.toInt) + var position = 0 + for (i <- keyArrayDatas.indices) { + val keyArray = keyArrayDatas(i).toObjectArray(dataType.keyType) + val valueArray = valueArrayDatas(i).toObjectArray(dataType.valueType) + Array.copy(keyArray, 0, finalKeyArray, position, keyArray.length) + Array.copy(valueArray, 0, finalValueArray, position, valueArray.length) + position += keyArray.length } - ArrayBasedMapData(union, (k: Any) => k, (v: Any) => v) + + new ArrayBasedMapData(new GenericArrayData(finalKeyArray), + new GenericArrayData(finalValueArray)) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val mapCodes = children.map(c => c.genCode(ctx)) + val mapCodes = children.map(_.genCode(ctx)) val keyType = dataType.keyType val valueType = dataType.valueType - val mapRefArrayName = ctx.freshName("mapRefArray") - val unionMapName = ctx.freshName("union") + val argsName = ctx.freshName("args") + val keyArgsName = ctx.freshName("keyArgs") + val valArgsName = ctx.freshName("valArgs") val mapDataClass = classOf[MapData].getName val arrayBasedMapDataClass = classOf[ArrayBasedMapData].getName val arrayDataClass = classOf[ArrayData].getName - val genericArrayDataClass = classOf[GenericArrayData].getName - val hashMapClass = classOf[util.LinkedHashMap[Any, Any]].getName - val entryClass = classOf[util.Map.Entry[Any, Any]].getName val init = s""" - |$mapDataClass[] $mapRefArrayName = new $mapDataClass[${mapCodes.size}]; + |$mapDataClass[] $argsName = new $mapDataClass[${mapCodes.size}]; + |$arrayDataClass[] $keyArgsName = new $arrayDataClass[${mapCodes.size}]; + |$arrayDataClass[] $valArgsName = new $arrayDataClass[${mapCodes.size}]; |boolean ${ev.isNull} = false; 
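    // The eval path above, reduced to plain arrays (a sketch, not Catalyst code):
    // key arrays and value arrays are concatenated in input order, duplicates and
    // all, after checking the combined size. The cap is assumed here to be
    // Int.MaxValue - 15, mirroring ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.
    def concatMaps(maps: Seq[(Array[Any], Array[Any])]): (Array[Any], Array[Any]) = {
      val total = maps.foldLeft(0L)((sum, m) => sum + m._1.length)  // Long: safe sum
      require(total <= Int.MaxValue - 15, s"map size $total exceeds the limit")
      val keys = new Array[Any](total.toInt)
      val values = new Array[Any](total.toInt)
      var pos = 0
      for ((k, v) <- maps) {
        Array.copy(k, 0, keys, pos, k.length)
        Array.copy(v, 0, values, pos, v.length)
        pos += k.length
      }
      (keys, values)
    }
    // concatMaps(Seq((Array[Any]("a"), Array[Any](1)), (Array[Any]("a"), Array[Any](2))))
    // => keys (a, a), values (1, 2): duplicates survive, order is preserved.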
|$mapDataClass ${ev.value} = null; """.stripMargin val assignments = mapCodes.zipWithIndex.map { case (m, i) => - val initCode = m.code - val valueVarName = m.value.code s""" - |$initCode - |$mapRefArrayName[$i] = $valueVarName; + |${m.code} + |$argsName[$i] = ${m.value.code}; """.stripMargin } val codes = ctx.splitExpressionsWithCurrentInputs( expressions = assignments, funcName = "mapConcat", - extraArguments = (s"${mapDataClass}[]", mapRefArrayName) :: Nil) + extraArguments = (s"${mapDataClass}[]", argsName) :: Nil) + + val idxName = ctx.freshName("idx") + val numElementsName = ctx.freshName("numElems") + val finKeysName = ctx.freshName("finalKeys") + val finValsName = ctx.freshName("finalValues") + + val keyConcatenator = if (CodeGenerator.isPrimitiveType(keyType)) { + genCodeForPrimitiveArrays(ctx, keyType) + } else { + genCodeForNonPrimitiveArrays(ctx, keyType) + } - val index1Name = ctx.freshName("idx1") - val index2Name = ctx.freshName("idx2") - val mapDataName = ctx.freshName("m") - val kaName = ctx.freshName("ka") - val vaName = ctx.freshName("va") - val keyName = ctx.freshName("key") - val valueName = ctx.freshName("value") + val valueConcatenator = if (CodeGenerator.isPrimitiveType(valueType)) { + genCodeForPrimitiveArrays(ctx, valueType) + } else { + genCodeForNonPrimitiveArrays(ctx, valueType) + } val mapMerge = s""" - |$hashMapClass $unionMapName = new $hashMapClass(); - |for (int $index1Name = 0; $index1Name < $mapRefArrayName.length; $index1Name++) { - | $mapDataClass $mapDataName = $mapRefArrayName[$index1Name]; - | if ($mapDataName == null) { + |long $numElementsName = 0; + |for (int $idxName = 0; $idxName < $argsName.length; $idxName++) { + | if ($argsName[$idxName] == null) { | ${ev.isNull} = true; | break; | } - | $arrayDataClass $kaName = $mapDataName.keyArray(); - | $arrayDataClass $vaName = $mapDataName.valueArray(); - | for (int $index2Name = 0; $index2Name < $kaName.numElements(); $index2Name++) { - | Object $keyName = ${CodeGenerator.getValue(kaName, keyType, index2Name)}; - | Object $valueName = null; - | if (!${vaName}.isNullAt($index2Name)) { - | $valueName = ${CodeGenerator.getValue(vaName, valueType, index2Name)}; - | } - | $unionMapName.put($keyName, $valueName); + | $keyArgsName[$idxName] = $argsName[$idxName].keyArray(); + | $valArgsName[$idxName] = $argsName[$idxName].valueArray(); + | $numElementsName += $argsName[$idxName].numElements(); + |} + | + |if (!${ev.isNull}) { + | if ($numElementsName > $MAX_MAP_SIZE) { + | throw new RuntimeException("Unsuccessful attempt to concat maps with " + + | $numElementsName + " elements due to exceeding the map size limit $MAX_MAP_SIZE."); | } + | $arrayDataClass $finKeysName = $keyConcatenator.concat($keyArgsName, + | (int) $numElementsName); + | $arrayDataClass $finValsName = $valueConcatenator.concat($valArgsName, + | (int) $numElementsName); + | ${ev.value} = new $arrayBasedMapDataClass($finKeysName, $finValsName); |} """.stripMargin - val mergedKeyArrayName = ctx.freshName("keyArray") - val mergedValueArrayName = ctx.freshName("valueArray") - val entrySetName = ctx.freshName("entrySet") - val createMapData = - s""" - |Object[] $entrySetName = $unionMapName.entrySet().toArray(); - |Object[] $mergedKeyArrayName = new Object[$entrySetName.length]; - |Object[] $mergedValueArrayName = new Object[$entrySetName.length]; - |for (int $index1Name = 0; $index1Name < $entrySetName.length; $index1Name++) { - | $entryClass entry = - | ($entryClass) $entrySetName[$index1Name]; - | $mergedKeyArrayName[$index1Name] = 
entry.getKey(); - | $mergedValueArrayName[$index1Name] = entry.getValue(); - |} - |${ev.value} = - | new $arrayBasedMapDataClass(new $genericArrayDataClass($mergedKeyArrayName), - | new $genericArrayDataClass($mergedValueArrayName)); - """.stripMargin ev.copy( code = code""" |$init |$codes |$mapMerge - |if (!${ev.isNull}) { - | $createMapData - |} """.stripMargin) } + private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { + val counter = ctx.freshName("counter") + val arrayData = ctx.freshName("arrayData") + val argsName = ctx.freshName("args") + val numElemName = ctx.freshName("numElements") + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + + s""" + |new Object() { + | public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) { + | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")} + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < $argsName[y].numElements(); z++) { + | if ($argsName[y].isNullAt(z)) { + | $arrayData.setNullAt($counter); + | } else { + | $arrayData.set$primitiveValueTypeName( + | $counter, + | ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")} + | ); + | } + | $counter++; + | } + | } + | return $arrayData; + | } + |}""".stripMargin.stripPrefix("\n") + } + + private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val arrayData = ctx.freshName("arrayObjects") + val counter = ctx.freshName("counter") + val argsName = ctx.freshName("args") + val numElemName = ctx.freshName("numElements") + + s""" + |new Object() { + | public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) {; + | Object[] $arrayData = new Object[$numElemName]; + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < $argsName[y].numElements(); z++) { + | $arrayData[$counter] = ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")}; + | $counter++; + | } + | } + | return new $genericArrayClass($arrayData); + | } + |}""".stripMargin.stripPrefix("\n") + } + override def prettyName: String = "map_concat" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 242a204cd341c..4b1cc0ff75691 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -117,31 +117,72 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper MapType(MapType(IntegerType, IntegerType), IntegerType)) val m10 = Literal.create(Map(Map(9 -> 10, 11 -> 12) -> 3, Map(1 -> 2, 3 -> 4) -> 4), MapType(MapType(IntegerType, IntegerType), IntegerType)) + val m11 = Literal.create(Map(1 -> "1", 2 -> "2"), MapType(IntegerType, StringType, + valueContainsNull = false)) + val m12 = Literal.create(Map(3 -> "3", 4 -> "4"), MapType(IntegerType, StringType, + valueContainsNull = false)) val mNull = Literal.create(null, MapType(StringType, StringType)) // overlapping maps checkEvaluation(MapConcat(Seq(m0, m1)), - mutable.LinkedHashMap("a" -> "4", "b" -> "2", "c" -> "3")) + ( + Array("a", "b", "c", "a"), // keys + Array("1", "2", "3", "4") // 
values + ) + ) // maps with no overlap checkEvaluation(MapConcat(Seq(m0, m2)), - mutable.LinkedHashMap("a" -> "1", "b" -> "2", "d" -> "4", "e" -> "5")) + Map("a" -> "1", "b" -> "2", "d" -> "4", "e" -> "5")) // 3 maps checkEvaluation(MapConcat(Seq(m0, m1, m2)), - mutable.LinkedHashMap("a" -> "4", "b" -> "2", "c" -> "3", "d" -> "4", "e" -> "5")) + ( + Array("a", "b", "c", "a", "d", "e"), // keys + Array("1", "2", "3", "4", "4", "5") // values + ) + ) // null reference values checkEvaluation(MapConcat(Seq(m3, m4)), - mutable.LinkedHashMap("a" -> null, "b" -> "2", "c" -> "3")) + ( + Array("a", "b", "a", "c"), // keys + Array("1", "2", null, "3") // values + ) + ) // null primitive values checkEvaluation(MapConcat(Seq(m5, m6)), - mutable.LinkedHashMap("a" -> null, "b" -> 2, "c" -> 3)) + ( + Array("a", "b", "a", "c"), // keys + Array(1, 2, null, 3) // values + ) + ) + + // keys that are primitive + checkEvaluation(MapConcat(Seq(m11, m12)), + ( + Array(1, 2, 3, 4), // keys + Array("1", "2", "3", "4") // values + ) + ) // keys that are arrays, with overlap checkEvaluation(MapConcat(Seq(m7, m8)), - mutable.LinkedHashMap(List(1, 2) -> 4, List(3, 4) -> 2, List(5, 6) -> 3)) + ( + Array(List(1, 2), List(3, 4), List(5, 6), List(1, 2)), // keys + Array(1, 2, 3, 4) // values + ) + ) + + // keys that are maps, with overlap + checkEvaluation(MapConcat(Seq(m9, m10)), + ( + Array(Map(1 -> 2, 3 -> 4), Map(5 -> 6, 7 -> 8), Map(9 -> 10, 11 -> 12), + Map(1 -> 2, 3 -> 4)), // keys + Array(1, 2, 3, 4) // values + ) + ) // null map checkEvaluation(MapConcat(Seq(m0, mNull)), @@ -155,34 +196,38 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper // single map checkEvaluation(MapConcat(Seq(m0)), - mutable.LinkedHashMap("a" -> "1", "b" -> "2")) + Map("a" -> "1", "b" -> "2")) // no map checkEvaluation(MapConcat(Seq.empty), Map.empty) // force split expressions for input in generated code - checkEvaluation(MapConcat(Seq(m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, - m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, - m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m2, m0)), - mutable.LinkedHashMap("a" -> "1", "b" -> "2", "d" -> "4", "e" -> "5")) + val expectedKeys = Array.fill(65)(Seq("a", "b")).flatten ++ Array("d", "e") + val expectedValues = Array.fill(65)(Seq("1", "2")).flatten ++ Array("4", "5") + checkEvaluation(MapConcat( + Seq( + m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, + m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, + m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m2 + )), + (expectedKeys, expectedValues)) // argument checking assert(MapConcat(Seq(m0, m1)).checkInputDataTypes().isSuccess) assert(MapConcat(Seq(m5, m6)).checkInputDataTypes().isSuccess) - assert(MapConcat(Seq(m9, m10)).checkInputDataTypes().isFailure) assert(MapConcat(Seq(m0, m5)).checkInputDataTypes().isFailure) assert(MapConcat(Seq(m0, Literal(12))).checkInputDataTypes().isFailure) assert(MapConcat(Seq(m0, m1)).dataType.keyType == StringType) assert(MapConcat(Seq(m0, m1)).dataType.valueType == StringType) - assert(MapConcat(Seq(m0, m1)).dataType.valueContainsNull == false) + assert(!MapConcat(Seq(m0, m1)).dataType.valueContainsNull) assert(MapConcat(Seq(m5, m6)).dataType.keyType == StringType) assert(MapConcat(Seq(m5, m6)).dataType.valueType == IntegerType) assert(MapConcat(Seq.empty).dataType.keyType 
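    // The pair-of-arrays expectations above rely on the CatalystTypeConverters
    // case added earlier in this patch: (Array(keys), Array(values)) converts to
    // an ArrayBasedMapData, which is the only way to spell an expected map that
    // contains duplicate keys (a Scala Map would silently collapse them). Sketch:
    val expectedPair: (Array[Any], Array[Any]) =
      (Array[Any]("a", "b", "c", "a"), Array[Any]("1", "2", "3", "4"))
    // Zipping recovers the logical entries, duplicates included:
    // expectedPair._1.zip(expectedPair._2) => ((a,1), (b,2), (c,3), (a,4))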
== StringType) assert(MapConcat(Seq.empty).dataType.valueType == StringType) assert(MapConcat(Seq(m5, m6)).dataType.valueContainsNull) assert(MapConcat(Seq(m6, m5)).dataType.valueContainsNull) - assert(MapConcat(Seq(m1, m2)).nullable == false) + assert(!MapConcat(Seq(m1, m2)).nullable) assert(MapConcat(Seq(m1, mNull)).nullable) } From 1b52dd11b89f108bbb8fa5ba248e9293db543307 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Wed, 20 Jun 2018 18:35:15 -0700 Subject: [PATCH 27/31] Remove extra line added during rebase --- python/pyspark/sql/functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 5d74b702d9615..d7ee055f8198c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2489,6 +2489,7 @@ def arrays_zip(*cols): sc = SparkContext._active_spark_context return Column(sc._jvm.functions.arrays_zip(_to_seq(sc, cols, _to_java_column))) + @since(2.4) def map_concat(*cols): """Returns the union of all the given maps. From 969c66e91140a7d255aa6ead0240cf830df2a97f Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Sun, 24 Jun 2018 10:38:30 -0700 Subject: [PATCH 28/31] Review comments --- .../expressions/collectionOperations.scala | 116 ++++++++++-------- .../CollectionExpressionsSuite.scala | 18 +-- .../spark/sql/DataFrameFunctionsSuite.scala | 70 ++++++++--- 3 files changed, 119 insertions(+), 85 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index bce9022b365d0..391960a22dc3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -516,36 +516,27 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp """, since = "2.4.0") case class MapConcat(children: Seq[Expression]) extends Expression { - private val MAX_MAP_SIZE: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - override def checkInputDataTypes(): TypeCheckResult = { - // check key types and value types separately to allow valueContainsNull to vary + var funcName = s"function $prettyName" if (children.exists(!_.dataType.isInstanceOf[MapType])) { TypeCheckResult.TypeCheckFailure( - s"The given input of function $prettyName should all be of type map, " + - "but they are " + children.map(_.dataType.simpleString).mkString("[", ", ", "]")) - } else if (children.map(_.dataType.asInstanceOf[MapType].keyType).distinct.length > 1) { - TypeCheckResult.TypeCheckFailure( - s"The given input maps of function $prettyName should all be the same type, " + - "but they are " + children.map(_.dataType.simpleString).mkString("[", ", ", "]")) - } else if (children.map(_.dataType.asInstanceOf[MapType].valueType).distinct.length > 1) { - TypeCheckResult.TypeCheckFailure( - s"The given input maps of function $prettyName should all be the same type, " + - "but they are " + children.map(_.dataType.simpleString).mkString("[", ", ", "]")) + s"input to $funcName should all be of type map, but it's " + + children.map(_.dataType.simpleString).mkString("[", ", ", "]")) } else { - TypeCheckResult.TypeCheckSuccess + TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), funcName) } } override def dataType: MapType = { - MapType( - keyType = children.headOption - 
.map(_.dataType.asInstanceOf[MapType].keyType).getOrElse(StringType), - valueType = children.headOption - .map(_.dataType.asInstanceOf[MapType].valueType).getOrElse(StringType), - valueContainsNull = children.map(_.dataType.asInstanceOf[MapType]) - .exists(_.valueContainsNull) - ) + val dt = children.map(_.dataType.asInstanceOf[MapType]).headOption + .getOrElse(MapType(StringType, StringType)) + val valueContainsNull = children.map(_.dataType.asInstanceOf[MapType]) + .exists(_.valueContainsNull) + if (dt.valueContainsNull != valueContainsNull) { + dt.copy(valueContainsNull = valueContainsNull) + } else { + dt + } } override def nullable: Boolean = children.exists(_.nullable) @@ -559,10 +550,10 @@ case class MapConcat(children: Seq[Expression]) extends Expression { val valueArrayDatas = maps.map(_.asInstanceOf[MapData].valueArray()) val numElements = keyArrayDatas.foldLeft(0L)((sum, ad) => sum + ad.numElements()) - if (numElements > MAX_MAP_SIZE) { - throw new RuntimeException(s"Unsuccessful attempt to concat maps with $numElements" + - s" elements due to exceeding the map size limit" + - s" $MAX_MAP_SIZE.") + if (numElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful attempt to concat maps with $numElements " + + s"elements due to exceeding the map size limit " + + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") } val finalKeyArray = new Array[AnyRef](numElements.toInt) val finalValueArray = new Array[AnyRef](numElements.toInt) @@ -603,14 +594,25 @@ case class MapConcat(children: Seq[Expression]) extends Expression { val assignments = mapCodes.zipWithIndex.map { case (m, i) => s""" |${m.code} - |$argsName[$i] = ${m.value.code}; + |$argsName[$i] = ${m.value}; + |if (${m.isNull}) { + | ${ev.isNull} = true; + |} """.stripMargin } val codes = ctx.splitExpressionsWithCurrentInputs( expressions = assignments, - funcName = "mapConcat", - extraArguments = (s"${mapDataClass}[]", argsName) :: Nil) + funcName = "getMapConcatInputs", + extraArguments = (s"$mapDataClass[]", argsName) :: ("boolean", ev.isNull.code) :: Nil, + returnType = "boolean", + makeSplitFunction = body => + s""" + |$body + |return ${ev.isNull}; + """.stripMargin, + foldFunctions = _.map(funcCall => s"${ev.isNull} = $funcCall;").mkString("\n") + ) val idxName = ctx.freshName("idx") val numElementsName = ctx.freshName("numElems") @@ -618,34 +620,30 @@ case class MapConcat(children: Seq[Expression]) extends Expression { val finValsName = ctx.freshName("finalValues") val keyConcatenator = if (CodeGenerator.isPrimitiveType(keyType)) { - genCodeForPrimitiveArrays(ctx, keyType) + genCodeForPrimitiveArrays(ctx, keyType, false) } else { genCodeForNonPrimitiveArrays(ctx, keyType) } val valueConcatenator = if (CodeGenerator.isPrimitiveType(valueType)) { - genCodeForPrimitiveArrays(ctx, valueType) + genCodeForPrimitiveArrays(ctx, valueType, dataType.valueContainsNull) } else { genCodeForNonPrimitiveArrays(ctx, valueType) } val mapMerge = s""" - |long $numElementsName = 0; - |for (int $idxName = 0; $idxName < $argsName.length; $idxName++) { - | if ($argsName[$idxName] == null) { - | ${ev.isNull} = true; - | break; - | } - | $keyArgsName[$idxName] = $argsName[$idxName].keyArray(); - | $valArgsName[$idxName] = $argsName[$idxName].valueArray(); - | $numElementsName += $argsName[$idxName].numElements(); - |} - | |if (!${ev.isNull}) { - | if ($numElementsName > $MAX_MAP_SIZE) { + | long $numElementsName = 0; + | for (int $idxName = 0; $idxName < $argsName.length; $idxName++) { + | 
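    // The split above threads a single boolean through the generated helper
    // methods: each helper evaluates one batch of children, records whether a
    // null map was produced, and returns the flag; foldFunctions chains the
    // returns so the flag survives across helpers. The same pattern in plain
    // Scala (a sketch; each batch() returns "saw a null"):
    def anyNullAcrossBatches(batches: Seq[() => Boolean]): Boolean =
      batches.foldLeft(false)((hasNull, batch) => if (hasNull) true else batch())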
$keyArgsName[$idxName] = $argsName[$idxName].keyArray(); + | $valArgsName[$idxName] = $argsName[$idxName].valueArray(); + | $numElementsName += $argsName[$idxName].numElements(); + | } + | if ($numElementsName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | throw new RuntimeException("Unsuccessful attempt to concat maps with " + - | $numElementsName + " elements due to exceeding the map size limit $MAX_MAP_SIZE."); + | $numElementsName + " elements due to exceeding the map size limit " + + | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); | } | $arrayDataClass $finKeysName = $keyConcatenator.concat($keyArgsName, | (int) $numElementsName); @@ -663,13 +661,32 @@ case class MapConcat(children: Seq[Expression]) extends Expression { """.stripMargin) } - private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { + private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType, + checkForNull: Boolean): String = { val counter = ctx.freshName("counter") val arrayData = ctx.freshName("arrayData") val argsName = ctx.freshName("args") val numElemName = ctx.freshName("numElements") val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + val setterCode1 = + s""" + |$arrayData.set$primitiveValueTypeName( + | $counter, + | ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")} + |);""".stripMargin.stripPrefix("\n") + + val setterCode = if (checkForNull) { + s""" + |if ($argsName[y].isNullAt(z)) { + | $arrayData.setNullAt($counter); + |} else { + | $setterCode1 + |}""".stripMargin.stripPrefix("\n") + } else { + setterCode1 + } + s""" |new Object() { | public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) { @@ -677,14 +694,7 @@ case class MapConcat(children: Seq[Expression]) extends Expression { | int $counter = 0; | for (int y = 0; y < ${children.length}; y++) { | for (int z = 0; z < $argsName[y].numElements(); z++) { - | if ($argsName[y].isNullAt(z)) { - | $arrayData.setNullAt($counter); - | } else { - | $arrayData.set$primitiveValueTypeName( - | $counter, - | ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")} - | ); - | } + | $setterCode | $counter++; | } | } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 4b1cc0ff75691..2129fa25734e1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -185,22 +185,16 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper ) // null map - checkEvaluation(MapConcat(Seq(m0, mNull)), - null) - checkEvaluation(MapConcat(Seq(mNull, m0)), - null) - checkEvaluation(MapConcat(Seq(mNull, mNull)), - null) - checkEvaluation(MapConcat(Seq(mNull)), - null) + checkEvaluation(MapConcat(Seq(m0, mNull)), null) + checkEvaluation(MapConcat(Seq(mNull, m0)), null) + checkEvaluation(MapConcat(Seq(mNull, mNull)), null) + checkEvaluation(MapConcat(Seq(mNull)), null) // single map - checkEvaluation(MapConcat(Seq(m0)), - Map("a" -> "1", "b" -> "2")) + checkEvaluation(MapConcat(Seq(m0)), Map("a" -> "1", "b" -> "2")) // no map - checkEvaluation(MapConcat(Seq.empty), - Map.empty) + checkEvaluation(MapConcat(Seq.empty), Map.empty) // force split expressions for input in generated code val 
expectedKeys = Array.fill(65)(Seq("a", "b")).flatten ++ Array("d", "e") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 958760814de6a..50192002b1c1c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -664,42 +664,72 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { (null, Map[Int, Int](3 -> 300, 4 -> 400)) ).toDF("map1", "map2") - checkAnswer( - df1.selectExpr("map_concat(map1, map2)"), - Seq( - Row(Map(1 -> 100, 2 -> 200, 3 -> 300, 4 -> 400)), - Row(Map(1 -> 400, 2 -> 200, 3 -> 300)), - Row(null) - ) + val expected1 = Seq( + Row(Map(1 -> 100, 2 -> 200, 3 -> 300, 4 -> 400)), + Row(Map(1 -> 400, 2 -> 200, 3 -> 300)), + Row(null) ) - checkAnswer( - df1.selectExpr("map_concat(map1)"), - Seq( - Row(Map(1 -> 100, 2 -> 200)), - Row(Map(1 -> 100, 2 -> 200)), - Row(null) - ) + checkAnswer(df1.selectExpr("map_concat(map1, map2)"), expected1) + checkAnswer(df1.select(map_concat('map1, 'map2)), expected1) + + val expected2 = Seq( + Row(Map(1 -> 100, 2 -> 200)), + Row(Map(1 -> 100, 2 -> 200)), + Row(null) ) + checkAnswer(df1.selectExpr("map_concat(map1)"), expected2) + checkAnswer(df1.select(map_concat('map1)), expected2) + val df2 = Seq( (Map[Int, Int](1 -> 100, 2 -> 200), Map[String, Int]("3" -> 300, "4" -> 400)) ).toDF("map1", "map2") - checkAnswer( - df2.selectExpr("map_concat()"), - Seq( - Row(Map()) + val expected3 = Seq(Row(Map())) + + checkAnswer(df2.selectExpr("map_concat()"), expected3) + checkAnswer(df2.select(map_concat()), expected3) + + val df3 = { + val schema = StructType( + StructField("map1", MapType(StringType, IntegerType, true), false) :: + StructField("map2", MapType(StringType, IntegerType, false), false) :: Nil + ) + val data = Seq( + Row(Map[String, Any]("a" -> 1, "b" -> null), Map[String, Any]("c" -> 3, "d" -> 4)), + Row(Map[String, Any]("a" -> 1, "b" -> 2), Map[String, Any]("c" -> 3, "d" -> 4)) ) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + + val expected4 = Seq( + Row(Map[String, Any]("a" -> 1, "b" -> null, "c" -> 3, "d" -> 4)), + Row(Map[String, Any]("a" -> 1, "b" -> 2, "c" -> 3, "d" -> 4)) ) + checkAnswer(df3.selectExpr("map_concat(map1, map2)"), expected4) + checkAnswer(df3.select(map_concat('map1, 'map2)), expected4) + + val expectedMessage1 = "input to function map_concat should all be the same type" + assert(intercept[AnalysisException] { df2.selectExpr("map_concat(map1, map2)").collect() - }.getMessage().contains("input maps of function map_concat should all be the same type")) + }.getMessage().contains(expectedMessage1)) + + assert(intercept[AnalysisException] { + df2.select(map_concat('map1, 'map2)).collect() + }.getMessage().contains(expectedMessage1)) + + val expectedMessage2 = "input to function map_concat should all be of type map" assert(intercept[AnalysisException] { df2.selectExpr("map_concat(map1, 12)").collect() - }.getMessage().contains("input of function map_concat should all be of type map")) + }.getMessage().contains(expectedMessage2)) + + assert(intercept[AnalysisException] { + df2.select(map_concat('map1, lit(12))).collect() + }.getMessage().contains(expectedMessage2)) } test("map_from_entries function") { From 47f0cf5b03619bc5f4be146fd16d21baf1cbf4bd Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Tue, 26 Jun 2018 14:26:59 -0700 Subject: [PATCH 29/31] 
Review comments --- .../expressions/collectionOperations.scala | 40 +++++++++++-------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 391960a22dc3e..f6c8eefa9e33c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -575,9 +575,7 @@ case class MapConcat(children: Seq[Expression]) extends Expression { val keyType = dataType.keyType val valueType = dataType.valueType val argsName = ctx.freshName("args") - val keyArgsName = ctx.freshName("keyArgs") - val valArgsName = ctx.freshName("valArgs") - + val hasNullName = ctx.freshName("hasNull") val mapDataClass = classOf[MapData].getName val arrayBasedMapDataClass = classOf[ArrayBasedMapData].getName val arrayDataClass = classOf[ArrayData].getName @@ -585,18 +583,18 @@ case class MapConcat(children: Seq[Expression]) extends Expression { val init = s""" |$mapDataClass[] $argsName = new $mapDataClass[${mapCodes.size}]; - |$arrayDataClass[] $keyArgsName = new $arrayDataClass[${mapCodes.size}]; - |$arrayDataClass[] $valArgsName = new $arrayDataClass[${mapCodes.size}]; - |boolean ${ev.isNull} = false; + |boolean ${ev.isNull}, $hasNullName = false; |$mapDataClass ${ev.value} = null; """.stripMargin val assignments = mapCodes.zipWithIndex.map { case (m, i) => s""" - |${m.code} - |$argsName[$i] = ${m.value}; - |if (${m.isNull}) { - | ${ev.isNull} = true; + |if (!$hasNullName) { + | ${m.code} + | $argsName[$i] = ${m.value}; + | if (${m.isNull}) { + | $hasNullName = true; + | } |} """.stripMargin } @@ -604,14 +602,14 @@ case class MapConcat(children: Seq[Expression]) extends Expression { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = assignments, funcName = "getMapConcatInputs", - extraArguments = (s"$mapDataClass[]", argsName) :: ("boolean", ev.isNull.code) :: Nil, + extraArguments = (s"$mapDataClass[]", argsName) :: ("boolean", hasNullName) :: Nil, returnType = "boolean", makeSplitFunction = body => s""" |$body - |return ${ev.isNull}; + |return $hasNullName; """.stripMargin, - foldFunctions = _.map(funcCall => s"${ev.isNull} = $funcCall;").mkString("\n") + foldFunctions = _.map(funcCall => s"$hasNullName = $funcCall;").mkString("\n") ) val idxName = ctx.freshName("idx") @@ -631,9 +629,15 @@ case class MapConcat(children: Seq[Expression]) extends Expression { genCodeForNonPrimitiveArrays(ctx, valueType) } + val keyArgsName = ctx.freshName("keyArgs") + val valArgsName = ctx.freshName("valArgs") + val mapMerge = s""" + |${ev.isNull} = $hasNullName; |if (!${ev.isNull}) { + | $arrayDataClass[] $keyArgsName = new $arrayDataClass[${mapCodes.size}]; + | $arrayDataClass[] $valArgsName = new $arrayDataClass[${mapCodes.size}]; | long $numElementsName = 0; | for (int $idxName = 0; $idxName < $argsName.length; $idxName++) { | $keyArgsName[$idxName] = $argsName[$idxName].keyArray(); @@ -661,8 +665,10 @@ case class MapConcat(children: Seq[Expression]) extends Expression { """.stripMargin) } - private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType, - checkForNull: Boolean): String = { + private def genCodeForPrimitiveArrays( + ctx: CodegenContext, + elementType: DataType, + checkForNull: Boolean): String = { val counter = ctx.freshName("counter") val 
     val arrayData = ctx.freshName("arrayData")
     val argsName = ctx.freshName("args")
@@ -674,7 +680,7 @@ case class MapConcat(children: Seq[Expression]) extends Expression {
       |$arrayData.set$primitiveValueTypeName(
       |  $counter,
       |  ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")}
-      |);""".stripMargin.stripPrefix("\n")
+      |);""".stripMargin
 
     val setterCode = if (checkForNull) {
       s"""
       |if ($argsName[y].isNullAt(z)) {
       |  $arrayData.setNullAt($counter);
       |} else {
       |  $setterCode1
-      |}""".stripMargin.stripPrefix("\n")
+      |}""".stripMargin
     } else {
       setterCode1
     }

From 3c0da039a0f11e0bc4f18342c96cf1dc100d2060 Mon Sep 17 00:00:00 2001
From: Bruce Robbins
Date: Sat, 30 Jun 2018 11:25:24 -0700
Subject: [PATCH 30/31] Initial implementation of type coercion for map_concat
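
When every argument of map_concat is a map but the map types disagree, widen
the key types and the value types separately to their wider common type (when
one exists) and cast each argument to the resulting map type; when no common
type exists, leave the expression unchanged so the type check still reports
the mismatch. The df2 test data switches to an array-keyed map because int
and string keys would otherwise widen to string and the mismatch assertions
would stop failing. An illustrative sketch of the observable behavior,
assuming a SparkSession named `spark` is in scope:

    // map<tinyint,tinyint> and map<smallint,smallint> widen to map<smallint,smallint>
    spark.sql("SELECT map_concat(map(1Y, 2Y), map(3S, 4S))").printSchema()

    // map<boolean,boolean> and map<int,int> share no wider type, so this
    // still fails analysis:
    // spark.sql("SELECT map_concat(map(true, false), map(1, 2))")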
---
 .../sql/catalyst/analysis/TypeCoercion.scala  |  30 ++++
 .../CollectionExpressionsSuite.scala          |   1 +
 .../inputs/typeCoercion/native/mapconcat.sql  |  94 ++++++++++++
 .../typeCoercion/native/mapconcat.sql.out     | 143 ++++++++++++++++++
 .../spark/sql/DataFrameFunctionsSuite.scala   |  29 ++--
 5 files changed, 284 insertions(+), 13 deletions(-)
 create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql
 create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index cf90e6e555fc8..9b1cf035dec5b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -563,6 +563,36 @@ object TypeCoercion {
         case None => s
       }
 
+      case m @ MapConcat(children) if children.forall(c => MapType.acceptsType(c.dataType)) &&
+        !haveSameType(children) =>
+        val keyTypes = children.map(_.dataType.asInstanceOf[MapType].keyType)
+        val valueTypes = children.map(_.dataType.asInstanceOf[MapType].valueType)
+
+        val newKeyType = if (keyTypes.distinct.size > 1) {
+          findWiderCommonType(keyTypes) match {
+            case s @ Some(_) => s
+            case None => None
+          }
+        } else {
+          keyTypes.headOption
+        }
+
+        val newValueType = if (valueTypes.distinct.size > 1) {
+          findWiderCommonType(valueTypes) match {
+            case s @ Some(_) => s
+            case None => None
+          }
+        } else {
+          valueTypes.headOption
+        }
+
+        if (newKeyType == None || newValueType == None) {
+          m
+        } else {
+          MapConcat(children
+            .map(Cast(_, MapType(newKeyType.head, newValueType.head))))
+        }
+
       case m @ CreateMap(children) if m.keys.length == m.values.length &&
         (!haveSameType(m.keys) || !haveSameType(m.values)) =>
         val newKeys = if (haveSameType(m.keys)) {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 2129fa25734e1..480aa4d8c3075 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.sql.{Date, Timestamp}
 import java.util.TimeZone
+
 import scala.collection.mutable
 
 import org.apache.spark.SparkFunSuite
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql
new file mode 100644
index 0000000000000..fc26397b881b5
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql
@@ -0,0 +1,94 @@
+CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES (
+  map(true, false), map(false, true),
+  map(1Y, 2Y), map(3Y, 4Y),
+  map(1S, 2S), map(3S, 4S),
+  map(4, 6), map(7, 8),
+  map(6L, 7L), map(8L, 9L),
+  map(9223372036854775809, 9223372036854775808), map(9223372036854775808, 9223372036854775809),
+  map(1.0D, 2.0D), map(3.0D, 4.0D),
+  map(float(1.0D), float(2.0D)), map(float(3.0D), float(4.0D)),
+  map(date '2016-03-14', date '2016-03-13'), map(date '2016-03-12', date '2016-03-11'),
+  map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'),
+  map(timestamp '2016-11-11 20:54:00.000', timestamp '2016-11-09 20:54:00.000'),
+  map('a', 'b'), map('c', 'd'),
+  map(array('a', 'b'), array('c', 'd')), map(array('e'), array('f')),
+  map(struct('a', 1), struct('b', 2)), map(struct('c', 3), struct('d', 4)),
+  map(map('a', 1), map('b', 2)), map(map('c', 3), map('d', 4)),
+  map('a', 1), map('c', 2),
+  map(1, 'a'), map(2, 'c')
+) AS various_maps (
+  boolean_map1, boolean_map2,
+  tinyint_map1, tinyint_map2,
+  smallint_map1, smallint_map2,
+  int_map1, int_map2,
+  bigint_map1, bigint_map2,
+  decimal_map1, decimal_map2,
+  double_map1, double_map2,
+  float_map1, float_map2,
+  date_map1, date_map2,
+  timestamp_map1,
+  timestamp_map2,
+  string_map1, string_map2,
+  array_map1, array_map2,
+  struct_map1, struct_map2,
+  map_map1, map_map2,
+  string_int_map1, string_int_map2,
+  int_string_map1, int_string_map2
+);
+
+-- Concatenate maps of the same type
+SELECT
+  map_concat(boolean_map1, boolean_map2) boolean_map,
+  map_concat(tinyint_map1, tinyint_map2) tinyint_map,
+  map_concat(smallint_map1, smallint_map2) smallint_map,
+  map_concat(int_map1, int_map2) int_map,
+  map_concat(bigint_map1, bigint_map2) bigint_map,
+  map_concat(decimal_map1, decimal_map2) decimal_map,
+  map_concat(float_map1, float_map2) float_map,
+  map_concat(double_map1, double_map2) double_map,
+  map_concat(date_map1, date_map2) date_map,
+  map_concat(timestamp_map1, timestamp_map2) timestamp_map,
+  map_concat(string_map1, string_map2) string_map,
+  map_concat(array_map1, array_map2) array_map,
+  map_concat(struct_map1, struct_map2) struct_map,
+  map_concat(map_map1, map_map2) map_map,
+  map_concat(string_int_map1, string_int_map2) string_int_map,
+  map_concat(int_string_map1, int_string_map2) int_string_map
+FROM various_maps;
+
+-- Concatenate maps of different types
+SELECT
+  map_concat(tinyint_map1, smallint_map2) ts_map,
+  map_concat(smallint_map1, int_map2) si_map,
+  map_concat(int_map1, bigint_map2) ib_map,
+  map_concat(decimal_map1, float_map2) df_map,
+  map_concat(string_map1, date_map2) std_map,
+  map_concat(timestamp_map1, string_map2) tst_map,
+  map_concat(string_map1, int_map2) sti_map,
+  map_concat(int_string_map1, tinyint_map2) istt_map
+FROM various_maps;
+
+-- Concatenate map of incompatible types 1
+SELECT
+  map_concat(tinyint_map1, map_map2) tm_map
+FROM various_maps;
+
+-- Concatenate map of incompatible types 2
+SELECT
+  map_concat(boolean_map1, int_map2) bi_map
+FROM various_maps;
+
+-- Concatenate map of incompatible types 3
+SELECT
+  map_concat(int_map1, struct_map2) is_map
+FROM various_maps;
+
+-- Concatenate map of incompatible types 4
+SELECT
+  map_concat(map_map1, array_map2) ma_map
+FROM various_maps;
+
+-- Concatenate map of incompatible types 5
+SELECT
+  map_concat(map_map1, struct_map2) ms_map
+FROM various_maps;
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out
new file mode 100644
index 0000000000000..d352b7284ae87
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out
@@ -0,0 +1,143 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 8
+
+
+-- !query 0
+CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES (
+  map(true, false), map(false, true),
+  map(1Y, 2Y), map(3Y, 4Y),
+  map(1S, 2S), map(3S, 4S),
+  map(4, 6), map(7, 8),
+  map(6L, 7L), map(8L, 9L),
+  map(9223372036854775809, 9223372036854775808), map(9223372036854775808, 9223372036854775809),
+  map(1.0D, 2.0D), map(3.0D, 4.0D),
+  map(float(1.0D), float(2.0D)), map(float(3.0D), float(4.0D)),
+  map(date '2016-03-14', date '2016-03-13'), map(date '2016-03-12', date '2016-03-11'),
+  map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'),
+  map(timestamp '2016-11-11 20:54:00.000', timestamp '2016-11-09 20:54:00.000'),
+  map('a', 'b'), map('c', 'd'),
+  map(array('a', 'b'), array('c', 'd')), map(array('e'), array('f')),
+  map(struct('a', 1), struct('b', 2)), map(struct('c', 3), struct('d', 4)),
+  map(map('a', 1), map('b', 2)), map(map('c', 3), map('d', 4)),
+  map('a', 1), map('c', 2),
+  map(1, 'a'), map(2, 'c')
+) AS various_maps (
+  boolean_map1, boolean_map2,
+  tinyint_map1, tinyint_map2,
+  smallint_map1, smallint_map2,
+  int_map1, int_map2,
+  bigint_map1, bigint_map2,
+  decimal_map1, decimal_map2,
+  double_map1, double_map2,
+  float_map1, float_map2,
+  date_map1, date_map2,
+  timestamp_map1,
+  timestamp_map2,
+  string_map1, string_map2,
+  array_map1, array_map2,
+  struct_map1, struct_map2,
+  map_map1, map_map2,
+  string_int_map1, string_int_map2,
+  int_string_map1, int_string_map2
+)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+SELECT
+  map_concat(boolean_map1, boolean_map2) boolean_map,
+  map_concat(tinyint_map1, tinyint_map2) tinyint_map,
+  map_concat(smallint_map1, smallint_map2) smallint_map,
+  map_concat(int_map1, int_map2) int_map,
+  map_concat(bigint_map1, bigint_map2) bigint_map,
+  map_concat(decimal_map1, decimal_map2) decimal_map,
+  map_concat(float_map1, float_map2) float_map,
+  map_concat(double_map1, double_map2) double_map,
+  map_concat(date_map1, date_map2) date_map,
+  map_concat(timestamp_map1, timestamp_map2) timestamp_map,
+  map_concat(string_map1, string_map2) string_map,
+  map_concat(array_map1, array_map2) array_map,
+  map_concat(struct_map1, struct_map2) struct_map,
+  map_concat(map_map1, map_map2) map_map,
+  map_concat(string_int_map1, string_int_map2) string_int_map,
+  map_concat(int_string_map1, int_string_map2) int_string_map
+FROM various_maps
+-- !query 1 schema
+struct<boolean_map:map<boolean,boolean>,tinyint_map:map<tinyint,tinyint>,smallint_map:map<smallint,smallint>,int_map:map<int,int>,bigint_map:map<bigint,bigint>,decimal_map:map<decimal(19,0),decimal(19,0)>,float_map:map<float,float>,double_map:map<double,double>,date_map:map<date,date>,timestamp_map:map<timestamp,timestamp>,string_map:map<string,string>,array_map:map<array<string>,array<string>>,struct_map:map<struct<col1:string,col2:int>,struct<col1:string,col2:int>>,map_map:map<map<string,int>,map<string,int>>,string_int_map:map<string,int>,int_string_map:map<int,string>>
+-- !query 1 output
+{false:true,true:false}	{1:2,3:4}	{1:2,3:4}	{4:6,7:8}	{6:7,8:9}	{9223372036854775808:9223372036854775809,9223372036854775809:9223372036854775808}	{1.0:2.0,3.0:4.0}	{1.0:2.0,3.0:4.0}	{2016-03-12:2016-03-11,2016-03-14:2016-03-13}	{2016-11-11 20:54:00.0:2016-11-09 20:54:00.0,2016-11-15 20:54:00.0:2016-11-12 20:54:00.0}	{"a":"b","c":"d"}	{["a","b"]:["c","d"],["e"]:["f"]}	{{"col1":"a","col2":1}:{"col1":"b","col2":2},{"col1":"c","col2":3}:{"col1":"d","col2":4}}	{{"a":1}:{"b":2},{"c":3}:{"d":4}}	{"a":1,"c":2}	{1:"a",2:"c"}
+
+
+-- !query 2
+SELECT
+  map_concat(tinyint_map1, smallint_map2) ts_map,
+  map_concat(smallint_map1, int_map2) si_map,
+  map_concat(int_map1, bigint_map2) ib_map,
+  map_concat(decimal_map1, float_map2) df_map,
+  map_concat(string_map1, date_map2) std_map,
+  map_concat(timestamp_map1, string_map2) tst_map,
+  map_concat(string_map1, int_map2) sti_map,
+  map_concat(int_string_map1, tinyint_map2) istt_map
+FROM various_maps
+-- !query 2 schema
+struct<ts_map:map<smallint,smallint>,si_map:map<int,int>,ib_map:map<bigint,bigint>,df_map:map<double,double>,std_map:map<string,string>,tst_map:map<string,string>,sti_map:map<string,string>,istt_map:map<int,string>>
+-- !query 2 output
+{1:2,3:4}	{1:2,7:8}	{4:6,8:9}	{3.0:4.0,9.223372036854776E18:9.223372036854776E18}	{"2016-03-12":"2016-03-11","a":"b"}	{"2016-11-15 20:54:00":"2016-11-12 20:54:00","c":"d"}	{"7":"8","a":"b"}	{1:"a",3:"4"}
+
+
+-- !query 3
+SELECT
+  map_concat(tinyint_map1, map_map2) tm_map
+FROM various_maps
+-- !query 3 schema
+struct<>
+-- !query 3 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'map_concat(various_maps.`tinyint_map1`, various_maps.`map_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map<tinyint,tinyint>, map<map<string,int>,map<string,int>>]; line 2 pos 4
+
+
+-- !query 4
+SELECT
+  map_concat(boolean_map1, int_map2) bi_map
+FROM various_maps
+-- !query 4 schema
+struct<>
+-- !query 4 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'map_concat(various_maps.`boolean_map1`, various_maps.`int_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map<boolean,boolean>, map<int,int>]; line 2 pos 4
+
+
+-- !query 5
+SELECT
+  map_concat(int_map1, struct_map2) is_map
+FROM various_maps
+-- !query 5 schema
+struct<>
+-- !query 5 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'map_concat(various_maps.`int_map1`, various_maps.`struct_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map<int,int>, map<struct<col1:string,col2:int>,struct<col1:string,col2:int>>]; line 2 pos 4
+
+
+-- !query 6
+SELECT
+  map_concat(map_map1, array_map2) ma_map
+FROM various_maps
+-- !query 6 schema
+struct<>
+-- !query 6 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'map_concat(various_maps.`map_map1`, various_maps.`array_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map<map<string,int>,map<string,int>>, map<array<string>,array<string>>]; line 2 pos 4
+
+
+-- !query 7
+SELECT
+  map_concat(map_map1, struct_map2) ms_map
+FROM various_maps
+-- !query 7 schema
+struct<>
+-- !query 7 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'map_concat(various_maps.`map_map1`, various_maps.`struct_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map<map<string,int>,map<string,int>>, map<struct<col1:string,col2:int>,struct<col1:string,col2:int>>]; line 2 pos 4
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 50192002b1c1c..d60ed7a5ef0d9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -664,32 +664,35 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
       (null, Map[Int, Int](3 -> 300, 4 -> 400))
     ).toDF("map1", "map2")
 
-    val expected1 = Seq(
+    val expected1a = Seq(
       Row(Map(1 -> 100, 2 -> 200, 3 -> 300, 4 -> 400)),
       Row(Map(1 -> 400, 2 -> 200, 3 -> 300)),
       Row(null)
     )
 
-    checkAnswer(df1.selectExpr("map_concat(map1, map2)"), expected1)
-    checkAnswer(df1.select(map_concat('map1, 'map2)), expected1)
+    checkAnswer(df1.selectExpr("map_concat(map1, map2)"), expected1a)
+    checkAnswer(df1.select(map_concat('map1, 'map2)), expected1a)
 
-    val expected2 = Seq(
+    val expected1b = Seq(
       Row(Map(1 -> 100, 2 -> 200)),
       Row(Map(1 -> 100, 2 -> 200)),
       Row(null)
     )
 
-    checkAnswer(df1.selectExpr("map_concat(map1)"), expected2)
-    checkAnswer(df1.select(map_concat('map1)), expected2)
+    checkAnswer(df1.selectExpr("map_concat(map1)"), expected1b)
+    checkAnswer(df1.select(map_concat('map1)), expected1b)
 
     val df2 = Seq(
-      (Map[Int, Int](1 -> 100, 2 -> 200), Map[String, Int]("3" -> 300, "4" -> 400))
+      (
+        Map[Array[Int], Int](Array(1) -> 100, Array(2) -> 200),
+        Map[String, Int]("3" -> 300, "4" -> 400)
+      )
     ).toDF("map1", "map2")
 
-    val expected3 = Seq(Row(Map()))
+    val expected2 = Seq(Row(Map()))
 
-    checkAnswer(df2.selectExpr("map_concat()"), expected3)
-    checkAnswer(df2.select(map_concat()), expected3)
+    checkAnswer(df2.selectExpr("map_concat()"), expected2)
+    checkAnswer(df2.select(map_concat()), expected2)
 
     val df3 = {
       val schema = StructType(
@@ -703,13 +706,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
       spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
     }
 
-    val expected4 = Seq(
+    val expected3 = Seq(
       Row(Map[String, Any]("a" -> 1, "b" -> null, "c" -> 3, "d" -> 4)),
       Row(Map[String, Any]("a" -> 1, "b" -> 2, "c" -> 3, "d" -> 4))
     )
 
-    checkAnswer(df3.selectExpr("map_concat(map1, map2)"), expected4)
-    checkAnswer(df3.select(map_concat('map1, 'map2)), expected4)
+    checkAnswer(df3.selectExpr("map_concat(map1, map2)"), expected3)
+    checkAnswer(df3.select(map_concat('map1, 'map2)), expected3)
 
     val expectedMessage1 = "input to function map_concat should all be the same type"

From 03328a417ea04722c1497cf09583dff909afe979 Mon Sep 17 00:00:00 2001
From: Bruce Robbins
Date: Thu, 5 Jul 2018 20:28:02 -0700
Subject: [PATCH 31/31] Simplify type coercion for map_concat parameters
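
Per review: apply findWiderCommonType to the map types as a whole rather than
widening keys and values by hand and reassembling the MapType; the None case
simply leaves the expression for the type checker to reject. The rule body
reduces to roughly:

    findWiderCommonType(children.map(_.dataType)) match {
      case Some(widerType) => MapConcat(children.map(Cast(_, widerType)))
      case None => m
    }

Also drop the now-unused `java.util` import from collectionOperations.scala
and the unused `scala.collection.mutable` import from the expression suite.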
---
 .../sql/catalyst/analysis/TypeCoercion.scala  | 30 +++----------------
 .../expressions/collectionOperations.scala    |  1 -
 .../CollectionExpressionsSuite.scala          |  2 --
 3 files changed, 4 insertions(+), 29 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 9b1cf035dec5b..5c449e23d2344 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -565,32 +565,10 @@ object TypeCoercion {
 
       case m @ MapConcat(children) if children.forall(c => MapType.acceptsType(c.dataType)) &&
         !haveSameType(children) =>
-        val keyTypes = children.map(_.dataType.asInstanceOf[MapType].keyType)
-        val valueTypes = children.map(_.dataType.asInstanceOf[MapType].valueType)
-
-        val newKeyType = if (keyTypes.distinct.size > 1) {
-          findWiderCommonType(keyTypes) match {
-            case s @ Some(_) => s
-            case None => None
-          }
-        } else {
-          keyTypes.headOption
-        }
-
-        val newValueType = if (valueTypes.distinct.size > 1) {
-          findWiderCommonType(valueTypes) match {
-            case s @ Some(_) => s
-            case None => None
-          }
-        } else {
-          valueTypes.headOption
-        }
-
-        if (newKeyType == None || newValueType == None) {
-          m
-        } else {
-          MapConcat(children
-            .map(Cast(_, MapType(newKeyType.head, newValueType.head))))
+        val types = children.map(_.dataType)
+        findWiderCommonType(types) match {
+          case Some(finalDataType) => MapConcat(children.map(Cast(_, finalDataType)))
+          case None => m
         }
 
       case m @ CreateMap(children) if m.keys.length == m.values.length &&
         (!haveSameType(m.keys) || !haveSameType(m.values)) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index f6c8eefa9e33c..6b6a4b294d4f5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -17,7 +17,6 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import java.util.{Comparator, TimeZone}
-import java.util
 
 import scala.collection.mutable
 import scala.reflect.ClassTag
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 480aa4d8c3075..12f6626255faf 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -20,8 +20,6 @@ package org.apache.spark.sql.catalyst.expressions
 import java.sql.{Date, Timestamp}
 import java.util.TimeZone
 
-import scala.collection.mutable
-
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow