diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 4d371976364d3..d7ee055f8198c 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2490,6 +2490,28 @@ def arrays_zip(*cols):
     return Column(sc._jvm.functions.arrays_zip(_to_seq(sc, cols, _to_java_column)))


+@since(2.4)
+def map_concat(*cols):
+    """Returns the union of all the given maps.
+
+    :param cols: list of column names (string) or list of :class:`Column` expressions
+
+    >>> from pyspark.sql.functions import map_concat
+    >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as map1, map(3, 'c', 1, 'd') as map2")
+    >>> df.select(map_concat("map1", "map2").alias("map3")).show(truncate=False)
+    +--------------------------------+
+    |map3                            |
+    +--------------------------------+
+    |[1 -> a, 2 -> b, 3 -> c, 1 -> d]|
+    +--------------------------------+
+    """
+    sc = SparkContext._active_spark_context
+    if len(cols) == 1 and isinstance(cols[0], (list, set)):
+        cols = cols[0]
+    jc = sc._jvm.functions.map_concat(_to_seq(sc, cols, _to_java_column))
+    return Column(jc)
+
+
 # ---------------------------- User Defined Function ----------------------------------

 class PandasUDFType(object):
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 93df73ab1eaf6..6f5fbdd79e668 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -431,6 +431,12 @@ object CatalystTypeConverters {
         map,
         (key: Any) => convertToCatalyst(key),
         (value: Any) => convertToCatalyst(value))
+      case (keys: Array[_], values: Array[_]) =>
+        // case for mapdata with duplicate keys
+        new ArrayBasedMapData(
+          new GenericArrayData(keys.map(convertToCatalyst)),
+          new GenericArrayData(values.map(convertToCatalyst))
+        )
       case other => other
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 80a0af672bf74..e7517e8c676e3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -422,6 +422,7 @@ object FunctionRegistry {
     expression[MapValues]("map_values"),
     expression[MapEntries]("map_entries"),
     expression[MapFromEntries]("map_from_entries"),
+    expression[MapConcat]("map_concat"),
     expression[Size]("size"),
     expression[Slice]("slice"),
     expression[Size]("cardinality"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index cf90e6e555fc8..5c449e23d2344 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -563,6 +563,14 @@ object TypeCoercion {
           case None => s
         }

+      case m @ MapConcat(children) if children.forall(c => MapType.acceptsType(c.dataType)) &&
+          !haveSameType(children) =>
+        val types = children.map(_.dataType)
+        findWiderCommonType(types) match {
+          case Some(finalDataType) => MapConcat(children.map(Cast(_, finalDataType)))
+          case None => m
+        }
+
       case m @ CreateMap(children) if m.keys.length == m.values.length &&
           (!haveSameType(m.keys) || !haveSameType(m.values)) =>
         val newKeys = if (haveSameType(m.keys)) {
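
Illustration only, not part of the patch: with the coercion rule above in place, differently typed map arguments are cast to their widest common map type before MapConcat is type-checked. A minimal sketch, assuming a local SparkSession named `spark` built from this branch:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    // map<int,int> and map<bigint,bigint> widen to map<bigint,bigint>,
    // so both arguments are wrapped in Cast before MapConcat binds.
    spark.sql("SELECT map_concat(map(1, 2), map(3L, 4L)) AS m").printSchema()
    // expected: m is a map with bigint keys and bigint values
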
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 8b278f067749e..6b6a4b294d4f5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -503,6 +503,237 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
   override def prettyName: String = "map_entries"
 }

+/**
+ * Returns the union of all the given maps.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(map, ...) - Returns the union of all the given maps",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd'));
+       [[1 -> "a"], [2 -> "b"], [2 -> "c"], [3 -> "d"]]
+  """, since = "2.4.0")
+case class MapConcat(children: Seq[Expression]) extends Expression {
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val funcName = s"function $prettyName"
+    if (children.exists(!_.dataType.isInstanceOf[MapType])) {
+      TypeCheckResult.TypeCheckFailure(
+        s"input to $funcName should all be of type map, but it's " +
+          children.map(_.dataType.simpleString).mkString("[", ", ", "]"))
+    } else {
+      TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), funcName)
+    }
+  }
+
+  override def dataType: MapType = {
+    val dt = children.map(_.dataType.asInstanceOf[MapType]).headOption
+      .getOrElse(MapType(StringType, StringType))
+    val valueContainsNull = children.map(_.dataType.asInstanceOf[MapType])
+      .exists(_.valueContainsNull)
+    if (dt.valueContainsNull != valueContainsNull) {
+      dt.copy(valueContainsNull = valueContainsNull)
+    } else {
+      dt
+    }
+  }
+
+  override def nullable: Boolean = children.exists(_.nullable)
+
+  override def eval(input: InternalRow): Any = {
+    val maps = children.map(_.eval(input))
+    if (maps.contains(null)) {
+      return null
+    }
+    val keyArrayDatas = maps.map(_.asInstanceOf[MapData].keyArray())
+    val valueArrayDatas = maps.map(_.asInstanceOf[MapData].valueArray())
+
+    val numElements = keyArrayDatas.foldLeft(0L)((sum, ad) => sum + ad.numElements())
+    if (numElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+      throw new RuntimeException(s"Unsuccessful attempt to concat maps with $numElements " +
+        s"elements due to exceeding the map size limit " +
+        s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
+    }
+    val finalKeyArray = new Array[AnyRef](numElements.toInt)
+    val finalValueArray = new Array[AnyRef](numElements.toInt)
+    var position = 0
+    for (i <- keyArrayDatas.indices) {
+      val keyArray = keyArrayDatas(i).toObjectArray(dataType.keyType)
+      val valueArray = valueArrayDatas(i).toObjectArray(dataType.valueType)
+      Array.copy(keyArray, 0, finalKeyArray, position, keyArray.length)
+      Array.copy(valueArray, 0, finalValueArray, position, valueArray.length)
+      position += keyArray.length
+    }
+
+    new ArrayBasedMapData(new GenericArrayData(finalKeyArray),
+      new GenericArrayData(finalValueArray))
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val mapCodes = children.map(_.genCode(ctx))
+    val keyType = dataType.keyType
+    val valueType = dataType.valueType
+    val argsName = ctx.freshName("args")
+    val hasNullName = ctx.freshName("hasNull")
+    val mapDataClass = classOf[MapData].getName
+    val arrayBasedMapDataClass = classOf[ArrayBasedMapData].getName
+    val arrayDataClass = classOf[ArrayData].getName
+
+    val init =
+      s"""
+        |$mapDataClass[] $argsName = new $mapDataClass[${mapCodes.size}];
+        |boolean ${ev.isNull}, $hasNullName = false;
+        |$mapDataClass ${ev.value} = null;
+      """.stripMargin
+
+    val assignments = mapCodes.zipWithIndex.map { case (m, i) =>
+      s"""
+        |if (!$hasNullName) {
+        |  ${m.code}
+        |  $argsName[$i] = ${m.value};
+        |  if (${m.isNull}) {
+        |    $hasNullName = true;
+        |  }
+        |}
+      """.stripMargin
+    }
+
+    val codes = ctx.splitExpressionsWithCurrentInputs(
+      expressions = assignments,
+      funcName = "getMapConcatInputs",
+      extraArguments = (s"$mapDataClass[]", argsName) :: ("boolean", hasNullName) :: Nil,
+      returnType = "boolean",
+      makeSplitFunction = body =>
+        s"""
+          |$body
+          |return $hasNullName;
+        """.stripMargin,
+      foldFunctions = _.map(funcCall => s"$hasNullName = $funcCall;").mkString("\n")
+    )
+
+    val idxName = ctx.freshName("idx")
+    val numElementsName = ctx.freshName("numElems")
+    val finKeysName = ctx.freshName("finalKeys")
+    val finValsName = ctx.freshName("finalValues")
+
+    val keyConcatenator = if (CodeGenerator.isPrimitiveType(keyType)) {
+      genCodeForPrimitiveArrays(ctx, keyType, false)
+    } else {
+      genCodeForNonPrimitiveArrays(ctx, keyType)
+    }
+
+    val valueConcatenator = if (CodeGenerator.isPrimitiveType(valueType)) {
+      genCodeForPrimitiveArrays(ctx, valueType, dataType.valueContainsNull)
+    } else {
+      genCodeForNonPrimitiveArrays(ctx, valueType)
+    }
+
+    val keyArgsName = ctx.freshName("keyArgs")
+    val valArgsName = ctx.freshName("valArgs")
+
+    val mapMerge =
+      s"""
+        |${ev.isNull} = $hasNullName;
+        |if (!${ev.isNull}) {
+        |  $arrayDataClass[] $keyArgsName = new $arrayDataClass[${mapCodes.size}];
+        |  $arrayDataClass[] $valArgsName = new $arrayDataClass[${mapCodes.size}];
+        |  long $numElementsName = 0;
+        |  for (int $idxName = 0; $idxName < $argsName.length; $idxName++) {
+        |    $keyArgsName[$idxName] = $argsName[$idxName].keyArray();
+        |    $valArgsName[$idxName] = $argsName[$idxName].valueArray();
+        |    $numElementsName += $argsName[$idxName].numElements();
+        |  }
+        |  if ($numElementsName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+        |    throw new RuntimeException("Unsuccessful attempt to concat maps with " +
+        |      $numElementsName + " elements due to exceeding the map size limit " +
+        |      "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
+        |  }
+        |  $arrayDataClass $finKeysName = $keyConcatenator.concat($keyArgsName,
+        |    (int) $numElementsName);
+        |  $arrayDataClass $finValsName = $valueConcatenator.concat($valArgsName,
+        |    (int) $numElementsName);
+        |  ${ev.value} = new $arrayBasedMapDataClass($finKeysName, $finValsName);
+        |}
+      """.stripMargin
+
+    ev.copy(
+      code = code"""
+        |$init
+        |$codes
+        |$mapMerge
+      """.stripMargin)
+  }
+
+  private def genCodeForPrimitiveArrays(
+      ctx: CodegenContext,
+      elementType: DataType,
+      checkForNull: Boolean): String = {
+    val counter = ctx.freshName("counter")
+    val arrayData = ctx.freshName("arrayData")
+    val argsName = ctx.freshName("args")
+    val numElemName = ctx.freshName("numElements")
+    val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
+
+    val setterCode1 =
+      s"""
+        |$arrayData.set$primitiveValueTypeName(
+        |  $counter,
+        |  ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")}
+        |);""".stripMargin
+
+    val setterCode = if (checkForNull) {
+      s"""
+        |if ($argsName[y].isNullAt(z)) {
+        |  $arrayData.setNullAt($counter);
+        |} else {
+        |  $setterCode1
+        |}""".stripMargin
+    } else {
+      setterCode1
+    }
+
+    s"""
+      |new Object() {
+      |  public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) {
+      |    ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")}
+      |    int $counter = 0;
+      |    for (int y = 0; y < ${children.length}; y++) {
+      |      for (int z = 0; z < $argsName[y].numElements(); z++) {
+      |        $setterCode
+      |        $counter++;
+      |      }
+      |    }
+      |    return $arrayData;
+      |  }
+      |}""".stripMargin.stripPrefix("\n")
+  }
+
+  private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
+    val genericArrayClass = classOf[GenericArrayData].getName
+    val arrayData = ctx.freshName("arrayObjects")
+    val counter = ctx.freshName("counter")
+    val argsName = ctx.freshName("args")
+    val numElemName = ctx.freshName("numElements")
+
+    s"""
+      |new Object() {
+      |  public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) {
+      |    Object[] $arrayData = new Object[$numElemName];
+      |    int $counter = 0;
+      |    for (int y = 0; y < ${children.length}; y++) {
+      |      for (int z = 0; z < $argsName[y].numElements(); z++) {
+      |        $arrayData[$counter] = ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")};
+      |        $counter++;
+      |      }
+      |    }
+      |    return new $genericArrayClass($arrayData);
+      |  }
+      |}""".stripMargin.stripPrefix("\n")
+  }
+
+  override def prettyName: String = "map_concat"
+}
+
 /**
  * Returns a map created from the given array of entries.
  */
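
A note on semantics, sketched outside the patch: unlike a Scala Map, MapConcat does not deduplicate keys. eval simply appends the key and value arrays of each input, which is why the ExpressionDescription example keeps both entries for key 2. A standalone sketch of that merge, using plain Scala collections in place of Spark's MapData:

    // Duplicate keys are retained, not overwritten.
    def concatKeysValues[K, V](maps: Seq[(Seq[K], Seq[V])]): (Seq[K], Seq[V]) =
      (maps.flatMap(_._1), maps.flatMap(_._2))

    val (keys, values) = concatKeysValues(Seq(
      (Seq(1, 2), Seq("a", "b")),
      (Seq(2, 3), Seq("c", "d"))))
    // keys == Seq(1, 2, 2, 3); values == Seq("a", "b", "c", "d")
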
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index d7744eb4c7dc7..12f6626255faf 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -98,6 +98,132 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(MapEntries(ms2), null)
   }

+  test("Map Concat") {
+    val m0 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType,
+      valueContainsNull = false))
+    val m1 = Literal.create(Map("c" -> "3", "a" -> "4"), MapType(StringType, StringType,
+      valueContainsNull = false))
+    val m2 = Literal.create(Map("d" -> "4", "e" -> "5"), MapType(StringType, StringType))
+    val m3 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType))
+    val m4 = Literal.create(Map("a" -> null, "c" -> "3"), MapType(StringType, StringType))
+    val m5 = Literal.create(Map("a" -> 1, "b" -> 2), MapType(StringType, IntegerType))
+    val m6 = Literal.create(Map("a" -> null, "c" -> 3), MapType(StringType, IntegerType))
+    val m7 = Literal.create(Map(List(1, 2) -> 1, List(3, 4) -> 2),
+      MapType(ArrayType(IntegerType), IntegerType))
+    val m8 = Literal.create(Map(List(5, 6) -> 3, List(1, 2) -> 4),
+      MapType(ArrayType(IntegerType), IntegerType))
+    val m9 = Literal.create(Map(Map(1 -> 2, 3 -> 4) -> 1, Map(5 -> 6, 7 -> 8) -> 2),
+      MapType(MapType(IntegerType, IntegerType), IntegerType))
+    val m10 = Literal.create(Map(Map(9 -> 10, 11 -> 12) -> 3, Map(1 -> 2, 3 -> 4) -> 4),
+      MapType(MapType(IntegerType, IntegerType), IntegerType))
+    val m11 = Literal.create(Map(1 -> "1", 2 -> "2"), MapType(IntegerType, StringType,
+      valueContainsNull = false))
+    val m12 = Literal.create(Map(3 -> "3", 4 -> "4"), MapType(IntegerType, StringType,
+      valueContainsNull = false))
+    val mNull = Literal.create(null, MapType(StringType, StringType))
+
+    // overlapping maps
+    checkEvaluation(MapConcat(Seq(m0, m1)),
+      (
+        Array("a", "b", "c", "a"), // keys
+        Array("1", "2", "3", "4") // values
+      )
+    )
+
+    // maps with no overlap
+    checkEvaluation(MapConcat(Seq(m0, m2)),
+      Map("a" -> "1", "b" -> "2", "d" -> "4", "e" -> "5"))
+
+    // 3 maps
+    checkEvaluation(MapConcat(Seq(m0, m1, m2)),
+      (
+        Array("a", "b", "c", "a", "d", "e"), // keys
+        Array("1", "2", "3", "4", "4", "5") // values
+      )
+    )
+
+    // null reference values
+    checkEvaluation(MapConcat(Seq(m3, m4)),
+      (
+        Array("a", "b", "a", "c"), // keys
+        Array("1", "2", null, "3") // values
+      )
+    )
+
+    // null primitive values
+    checkEvaluation(MapConcat(Seq(m5, m6)),
+      (
+        Array("a", "b", "a", "c"), // keys
+        Array(1, 2, null, 3) // values
+      )
+    )
+
+    // keys that are primitive
+    checkEvaluation(MapConcat(Seq(m11, m12)),
+      (
+        Array(1, 2, 3, 4), // keys
+        Array("1", "2", "3", "4") // values
+      )
+    )
+
+    // keys that are arrays, with overlap
+    checkEvaluation(MapConcat(Seq(m7, m8)),
+      (
+        Array(List(1, 2), List(3, 4), List(5, 6), List(1, 2)), // keys
+        Array(1, 2, 3, 4) // values
+      )
+    )
+
+    // keys that are maps, with overlap
+    checkEvaluation(MapConcat(Seq(m9, m10)),
+      (
+        Array(Map(1 -> 2, 3 -> 4), Map(5 -> 6, 7 -> 8), Map(9 -> 10, 11 -> 12),
+          Map(1 -> 2, 3 -> 4)), // keys
+        Array(1, 2, 3, 4) // values
+      )
+    )
+
+    // null map
+    checkEvaluation(MapConcat(Seq(m0, mNull)), null)
+    checkEvaluation(MapConcat(Seq(mNull, m0)), null)
+    checkEvaluation(MapConcat(Seq(mNull, mNull)), null)
+    checkEvaluation(MapConcat(Seq(mNull)), null)
+
+    // single map
+    checkEvaluation(MapConcat(Seq(m0)), Map("a" -> "1", "b" -> "2"))
+
+    // no map
+    checkEvaluation(MapConcat(Seq.empty), Map.empty)
+
+    // force split expressions for input in generated code
+    val expectedKeys = Array.fill(65)(Seq("a", "b")).flatten ++ Array("d", "e")
+    val expectedValues = Array.fill(65)(Seq("1", "2")).flatten ++ Array("4", "5")
+    checkEvaluation(MapConcat(
+      Seq(
+        m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0,
+        m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0,
+        m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m2
+      )),
+      (expectedKeys, expectedValues))
+
+    // argument checking
+    assert(MapConcat(Seq(m0, m1)).checkInputDataTypes().isSuccess)
+    assert(MapConcat(Seq(m5, m6)).checkInputDataTypes().isSuccess)
+    assert(MapConcat(Seq(m0, m5)).checkInputDataTypes().isFailure)
+    assert(MapConcat(Seq(m0, Literal(12))).checkInputDataTypes().isFailure)
+    assert(MapConcat(Seq(m0, m1)).dataType.keyType == StringType)
+    assert(MapConcat(Seq(m0, m1)).dataType.valueType == StringType)
+    assert(!MapConcat(Seq(m0, m1)).dataType.valueContainsNull)
+    assert(MapConcat(Seq(m5, m6)).dataType.keyType == StringType)
+    assert(MapConcat(Seq(m5, m6)).dataType.valueType == IntegerType)
+    assert(MapConcat(Seq.empty).dataType.keyType == StringType)
+    assert(MapConcat(Seq.empty).dataType.valueType == StringType)
+    assert(MapConcat(Seq(m5, m6)).dataType.valueContainsNull)
+    assert(MapConcat(Seq(m6, m5)).dataType.valueContainsNull)
+    assert(!MapConcat(Seq(m1, m2)).nullable)
+    assert(MapConcat(Seq(m1, mNull)).nullable)
+  }
+
   test("MapFromEntries") {
     def arrayType(keyType: DataType, valueType: DataType) : DataType = {
       ArrayType(
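
Why several expected values above are (keys, values) tuples rather than Scala Maps: a Map literal would collapse the duplicate key "a". The new CatalystTypeConverters case converts such a pair of arrays into an ArrayBasedMapData, so the duplicate survives for comparison. A hand-built equivalent, for illustration (the real converter additionally runs convertToCatalyst over each element):

    import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}

    val md = new ArrayBasedMapData(
      new GenericArrayData(Array[Any]("a", "b", "a")),
      new GenericArrayData(Array[Any]("1", "2", "3")))
    assert(md.numElements() == 3) // the duplicate "a" is counted twice
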
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index f2627e69939cd..89dbba10a6bf1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3627,6 +3627,14 @@ object functions {
   @scala.annotation.varargs
   def arrays_zip(e: Column*): Column = withExpr { ArraysZip(e.map(_.expr)) }

+  /**
+   * Returns the union of all the given maps.
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  @scala.annotation.varargs
+  def map_concat(cols: Column*): Column = withExpr { MapConcat(cols.map(_.expr)) }
+
 //////////////////////////////////////////////////////////////////////////////////////////////
 // Mask functions
 //////////////////////////////////////////////////////////////////////////////////////////////
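
DataFrame-side usage of the new column function, mirroring the suite further below; a sketch assuming a SparkSession `spark` with implicits imported:

    import org.apache.spark.sql.functions.map_concat
    import spark.implicits._

    val df = Seq((Map(1 -> 100, 2 -> 200), Map(3 -> 300, 4 -> 400))).toDF("map1", "map2")
    df.select(map_concat($"map1", $"map2")).show(truncate = false)
    // expected single row: [1 -> 100, 2 -> 200, 3 -> 300, 4 -> 400]
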
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql
new file mode 100644
index 0000000000000..fc26397b881b5
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql
@@ -0,0 +1,94 @@
+CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES (
+  map(true, false), map(false, true),
+  map(1Y, 2Y), map(3Y, 4Y),
+  map(1S, 2S), map(3S, 4S),
+  map(4, 6), map(7, 8),
+  map(6L, 7L), map(8L, 9L),
+  map(9223372036854775809, 9223372036854775808), map(9223372036854775808, 9223372036854775809),
+  map(1.0D, 2.0D), map(3.0D, 4.0D),
+  map(float(1.0D), float(2.0D)), map(float(3.0D), float(4.0D)),
+  map(date '2016-03-14', date '2016-03-13'), map(date '2016-03-12', date '2016-03-11'),
+  map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'),
+  map(timestamp '2016-11-11 20:54:00.000', timestamp '2016-11-09 20:54:00.000'),
+  map('a', 'b'), map('c', 'd'),
+  map(array('a', 'b'), array('c', 'd')), map(array('e'), array('f')),
+  map(struct('a', 1), struct('b', 2)), map(struct('c', 3), struct('d', 4)),
+  map(map('a', 1), map('b', 2)), map(map('c', 3), map('d', 4)),
+  map('a', 1), map('c', 2),
+  map(1, 'a'), map(2, 'c')
+) AS various_maps (
+  boolean_map1, boolean_map2,
+  tinyint_map1, tinyint_map2,
+  smallint_map1, smallint_map2,
+  int_map1, int_map2,
+  bigint_map1, bigint_map2,
+  decimal_map1, decimal_map2,
+  double_map1, double_map2,
+  float_map1, float_map2,
+  date_map1, date_map2,
+  timestamp_map1,
+  timestamp_map2,
+  string_map1, string_map2,
+  array_map1, array_map2,
+  struct_map1, struct_map2,
+  map_map1, map_map2,
+  string_int_map1, string_int_map2,
+  int_string_map1, int_string_map2
+);
+
+-- Concatenate maps of the same type
+SELECT
+  map_concat(boolean_map1, boolean_map2) boolean_map,
+  map_concat(tinyint_map1, tinyint_map2) tinyint_map,
+  map_concat(smallint_map1, smallint_map2) smallint_map,
+  map_concat(int_map1, int_map2) int_map,
+  map_concat(bigint_map1, bigint_map2) bigint_map,
+  map_concat(decimal_map1, decimal_map2) decimal_map,
+  map_concat(float_map1, float_map2) float_map,
+  map_concat(double_map1, double_map2) double_map,
+  map_concat(date_map1, date_map2) date_map,
+  map_concat(timestamp_map1, timestamp_map2) timestamp_map,
+  map_concat(string_map1, string_map2) string_map,
+  map_concat(array_map1, array_map2) array_map,
+  map_concat(struct_map1, struct_map2) struct_map,
+  map_concat(map_map1, map_map2) map_map,
+  map_concat(string_int_map1, string_int_map2) string_int_map,
+  map_concat(int_string_map1, int_string_map2) int_string_map
+FROM various_maps;
+
+-- Concatenate maps of different types
+SELECT
+  map_concat(tinyint_map1, smallint_map2) ts_map,
+  map_concat(smallint_map1, int_map2) si_map,
+  map_concat(int_map1, bigint_map2) ib_map,
+  map_concat(decimal_map1, float_map2) df_map,
+  map_concat(string_map1, date_map2) std_map,
+  map_concat(timestamp_map1, string_map2) tst_map,
+  map_concat(string_map1, int_map2) sti_map,
+  map_concat(int_string_map1, tinyint_map2) istt_map
+FROM various_maps;
+
+-- Concatenate map of incompatible types 1
+SELECT
+  map_concat(tinyint_map1, map_map2) tm_map
+FROM various_maps;
+
+-- Concatenate map of incompatible types 2
+SELECT
+  map_concat(boolean_map1, int_map2) bi_map
+FROM various_maps;
+
+-- Concatenate map of incompatible types 3
+SELECT
+  map_concat(int_map1, struct_map2) is_map
+FROM various_maps;
+
+-- Concatenate map of incompatible types 4
+SELECT
+  map_concat(map_map1, array_map2) ma_map
+FROM various_maps;
+
+-- Concatenate map of incompatible types 5
+SELECT
+  map_concat(map_map1, struct_map2) ms_map
+FROM various_maps;
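
The "incompatible types" queries above are expected to fail analysis rather than coerce, since no common wider map type exists. A sketch of the failure mode, assuming a SparkSession `spark` (the exact message text is asserted in the golden file below):

    import org.apache.spark.sql.AnalysisException

    try {
      spark.sql("SELECT map_concat(map(true, false), map(1, 2))").collect()
    } catch {
      case e: AnalysisException =>
        assert(e.getMessage.contains("should all be the same type"))
    }
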
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out
new file mode 100644
index 0000000000000..d352b7284ae87
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out
@@ -0,0 +1,143 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 8
+
+
+-- !query 0
+CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES (
+  map(true, false), map(false, true),
+  map(1Y, 2Y), map(3Y, 4Y),
+  map(1S, 2S), map(3S, 4S),
+  map(4, 6), map(7, 8),
+  map(6L, 7L), map(8L, 9L),
+  map(9223372036854775809, 9223372036854775808), map(9223372036854775808, 9223372036854775809),
+  map(1.0D, 2.0D), map(3.0D, 4.0D),
+  map(float(1.0D), float(2.0D)), map(float(3.0D), float(4.0D)),
+  map(date '2016-03-14', date '2016-03-13'), map(date '2016-03-12', date '2016-03-11'),
+  map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'),
+  map(timestamp '2016-11-11 20:54:00.000', timestamp '2016-11-09 20:54:00.000'),
+  map('a', 'b'), map('c', 'd'),
+  map(array('a', 'b'), array('c', 'd')), map(array('e'), array('f')),
+  map(struct('a', 1), struct('b', 2)), map(struct('c', 3), struct('d', 4)),
+  map(map('a', 1), map('b', 2)), map(map('c', 3), map('d', 4)),
+  map('a', 1), map('c', 2),
+  map(1, 'a'), map(2, 'c')
+) AS various_maps (
+  boolean_map1, boolean_map2,
+  tinyint_map1, tinyint_map2,
+  smallint_map1, smallint_map2,
+  int_map1, int_map2,
+  bigint_map1, bigint_map2,
+  decimal_map1, decimal_map2,
+  double_map1, double_map2,
+  float_map1, float_map2,
+  date_map1, date_map2,
+  timestamp_map1,
+  timestamp_map2,
+  string_map1, string_map2,
+  array_map1, array_map2,
+  struct_map1, struct_map2,
+  map_map1, map_map2,
+  string_int_map1, string_int_map2,
+  int_string_map1, int_string_map2
+)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+SELECT
+  map_concat(boolean_map1, boolean_map2) boolean_map,
+  map_concat(tinyint_map1, tinyint_map2) tinyint_map,
+  map_concat(smallint_map1, smallint_map2) smallint_map,
+  map_concat(int_map1, int_map2) int_map,
+  map_concat(bigint_map1, bigint_map2) bigint_map,
+  map_concat(decimal_map1, decimal_map2) decimal_map,
+  map_concat(float_map1, float_map2) float_map,
+  map_concat(double_map1, double_map2) double_map,
+  map_concat(date_map1, date_map2) date_map,
+  map_concat(timestamp_map1, timestamp_map2) timestamp_map,
+  map_concat(string_map1, string_map2) string_map,
+  map_concat(array_map1, array_map2) array_map,
+  map_concat(struct_map1, struct_map2) struct_map,
+  map_concat(map_map1, map_map2) map_map,
+  map_concat(string_int_map1, string_int_map2) string_int_map,
+  map_concat(int_string_map1, int_string_map2) int_string_map
+FROM various_maps
+-- !query 1 schema
+struct<boolean_map:map<boolean,boolean>,tinyint_map:map<tinyint,tinyint>,smallint_map:map<smallint,smallint>,int_map:map<int,int>,bigint_map:map<bigint,bigint>,decimal_map:map<decimal(19,0),decimal(19,0)>,float_map:map<float,float>,double_map:map<double,double>,date_map:map<date,date>,timestamp_map:map<timestamp,timestamp>,string_map:map<string,string>,array_map:map<array<string>,array<string>>,struct_map:map<struct<col1:string,col2:int>,struct<col1:string,col2:int>>,map_map:map<map<string,int>,map<string,int>>,string_int_map:map<string,int>,int_string_map:map<int,string>>
+-- !query 1 output
+{false:true,true:false}	{1:2,3:4}	{1:2,3:4}	{4:6,7:8}	{6:7,8:9}	{9223372036854775808:9223372036854775809,9223372036854775809:9223372036854775808}	{1.0:2.0,3.0:4.0}	{1.0:2.0,3.0:4.0}	{2016-03-12:2016-03-11,2016-03-14:2016-03-13}	{2016-11-11 20:54:00.0:2016-11-09 20:54:00.0,2016-11-15 20:54:00.0:2016-11-12 20:54:00.0}	{"a":"b","c":"d"}	{["a","b"]:["c","d"],["e"]:["f"]}	{{"col1":"a","col2":1}:{"col1":"b","col2":2},{"col1":"c","col2":3}:{"col1":"d","col2":4}}	{{"a":1}:{"b":2},{"c":3}:{"d":4}}	{"a":1,"c":2}	{1:"a",2:"c"}
+
+
+-- !query 2
+SELECT
+  map_concat(tinyint_map1, smallint_map2) ts_map,
+  map_concat(smallint_map1, int_map2) si_map,
+  map_concat(int_map1, bigint_map2) ib_map,
+  map_concat(decimal_map1, float_map2) df_map,
+  map_concat(string_map1, date_map2) std_map,
+  map_concat(timestamp_map1, string_map2) tst_map,
+  map_concat(string_map1, int_map2) sti_map,
+  map_concat(int_string_map1, tinyint_map2) istt_map
+FROM various_maps
+-- !query 2 schema
+struct<ts_map:map<smallint,smallint>,si_map:map<int,int>,ib_map:map<bigint,bigint>,df_map:map<double,double>,std_map:map<string,string>,tst_map:map<string,string>,sti_map:map<string,string>,istt_map:map<int,string>>
+-- !query 2 output
+{1:2,3:4}	{1:2,7:8}	{4:6,8:9}	{3.0:4.0,9.223372036854776E18:9.223372036854776E18}	{"2016-03-12":"2016-03-11","a":"b"}	{"2016-11-15 20:54:00":"2016-11-12 20:54:00","c":"d"}	{"7":"8","a":"b"}	{1:"a",3:"4"}
+
+
+-- !query 3
+SELECT
+  map_concat(tinyint_map1, map_map2) tm_map
+FROM various_maps
+-- !query 3 schema
+struct<>
+-- !query 3 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'map_concat(various_maps.`tinyint_map1`, various_maps.`map_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map<tinyint,tinyint>, map<map<string,int>,map<string,int>>]; line 2 pos 4
+
+
+-- !query 4
+SELECT
+  map_concat(boolean_map1, int_map2) bi_map
+FROM various_maps
+-- !query 4 schema
+struct<>
+-- !query 4 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'map_concat(various_maps.`boolean_map1`, various_maps.`int_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map<boolean,boolean>, map<int,int>]; line 2 pos 4
+
+
+-- !query 5
+SELECT
+  map_concat(int_map1, struct_map2) is_map
+FROM various_maps
+-- !query 5 schema
+struct<>
+-- !query 5 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'map_concat(various_maps.`int_map1`, various_maps.`struct_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map<int,int>, map<struct<col1:string,col2:int>,struct<col1:string,col2:int>>]; line 2 pos 4
+
+
+-- !query 6
+SELECT
+  map_concat(map_map1, array_map2) ma_map
+FROM various_maps
+-- !query 6 schema
+struct<>
+-- !query 6 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'map_concat(various_maps.`map_map1`, various_maps.`array_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map<map<string,int>,map<string,int>>, map<array<string>,array<string>>]; line 2 pos 4
+
+
+-- !query 7
+SELECT
+  map_concat(map_map1, struct_map2) ms_map
+FROM various_maps
+-- !query 7 schema
+struct<>
+-- !query 7 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'map_concat(various_maps.`map_map1`, various_maps.`struct_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map<map<string,int>,map<string,int>>, map<struct<col1:string,col2:int>,struct<col1:string,col2:int>>]; line 2 pos 4
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 4c28e2f1cd909..d60ed7a5ef0d9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -657,6 +657,84 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     checkAnswer(sdf.filter(dummyFilter('m)).select(map_entries('m)), sExpected)
   }

+  test("map_concat function") {
+    val df1 = Seq(
+      (Map[Int, Int](1 -> 100, 2 -> 200), Map[Int, Int](3 -> 300, 4 -> 400)),
+      (Map[Int, Int](1 -> 100, 2 -> 200), Map[Int, Int](3 -> 300, 1 -> 400)),
+      (null, Map[Int, Int](3 -> 300, 4 -> 400))
+    ).toDF("map1", "map2")
+
+    val expected1a = Seq(
+      Row(Map(1 -> 100, 2 -> 200, 3 -> 300, 4 -> 400)),
+      Row(Map(1 -> 400, 2 -> 200, 3 -> 300)),
+      Row(null)
+    )
+
+    checkAnswer(df1.selectExpr("map_concat(map1, map2)"), expected1a)
+    checkAnswer(df1.select(map_concat('map1, 'map2)), expected1a)
+
+    val expected1b = Seq(
+      Row(Map(1 -> 100, 2 -> 200)),
+      Row(Map(1 -> 100, 2 -> 200)),
+      Row(null)
+    )
+
+    checkAnswer(df1.selectExpr("map_concat(map1)"), expected1b)
+    checkAnswer(df1.select(map_concat('map1)), expected1b)
+
+    val df2 = Seq(
+      (
+        Map[Array[Int], Int](Array(1) -> 100, Array(2) -> 200),
+        Map[String, Int]("3" -> 300, "4" -> 400)
+      )
+    ).toDF("map1", "map2")
+
+    val expected2 = Seq(Row(Map()))
+
+    checkAnswer(df2.selectExpr("map_concat()"), expected2)
+    checkAnswer(df2.select(map_concat()), expected2)
+
+    val df3 = {
+      val schema = StructType(
+        StructField("map1", MapType(StringType, IntegerType, true), false) ::
+        StructField("map2", MapType(StringType, IntegerType, false), false) :: Nil
+      )
+      val data = Seq(
+        Row(Map[String, Any]("a" -> 1, "b" -> null), Map[String, Any]("c" -> 3, "d" -> 4)),
+        Row(Map[String, Any]("a" -> 1, "b" -> 2), Map[String, Any]("c" -> 3, "d" -> 4))
+      )
+      spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+    }
+
+    val expected3 = Seq(
+      Row(Map[String, Any]("a" -> 1, "b" -> null, "c" -> 3, "d" -> 4)),
+      Row(Map[String, Any]("a" -> 1, "b" -> 2, "c" -> 3, "d" -> 4))
+    )
+
+    checkAnswer(df3.selectExpr("map_concat(map1, map2)"), expected3)
+    checkAnswer(df3.select(map_concat('map1, 'map2)), expected3)
+
+    val expectedMessage1 = "input to function map_concat should all be the same type"
+
+    assert(intercept[AnalysisException] {
+      df2.selectExpr("map_concat(map1, map2)").collect()
+    }.getMessage().contains(expectedMessage1))
+
+    assert(intercept[AnalysisException] {
+      df2.select(map_concat('map1, 'map2)).collect()
+    }.getMessage().contains(expectedMessage1))
+
+    val expectedMessage2 = "input to function map_concat should all be of type map"
+
+    assert(intercept[AnalysisException] {
+      df2.selectExpr("map_concat(map1, 12)").collect()
+    }.getMessage().contains(expectedMessage2))
+
+    assert(intercept[AnalysisException] {
+      df2.select(map_concat('map1, lit(12))).collect()
+    }.getMessage().contains(expectedMessage2))
+  }
+
   test("map_from_entries function") {
     def dummyFilter(c: Column): Column = c.isNull || c.isNotNull
     val oneRowDF = Seq(3215).toDF("i")
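
Finally, an end-to-end check of the documented duplicate-key behavior, for illustration only (assumes a SparkSession `spark`; the exact rendering is whatever show() produces on this branch):

    spark.sql("SELECT map_concat(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd')) AS m")
      .show(truncate = false)
    // expected: four entries, 1 -> a, 2 -> b, 2 -> c, 3 -> d
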