From 628165acb02e0045d8669262b6cfebd7314ca247 Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Thu, 14 Dec 2017 17:39:43 +0900
Subject: [PATCH 01/16] Concat binary as binary

---
 .../apache/spark/unsafe/types/ByteArray.java  | 25 +++++++++
 .../expressions/stringExpressions.scala       | 31 ++++++++---
 .../org/apache/spark/sql/functions.scala      |  3 +-
 .../sql-tests/inputs/string-functions.sql     | 22 ++++++++
 .../results/string-functions.sql.out          | 52 ++++++++++++++++++-
 5 files changed, 123 insertions(+), 10 deletions(-)

diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java
index 7ced13d357237..c03caf0076f61 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java
@@ -74,4 +74,29 @@ public static byte[] subStringSQL(byte[] bytes, int pos, int len) {
     }
     return Arrays.copyOfRange(bytes, start, end);
   }
+
+  public static byte[] concat(byte[]... inputs) {
+    // Compute the total length of the result; a single null input makes the
+    // whole result null, matching the SQL semantics of concat
+    int totalLength = 0;
+    for (int i = 0; i < inputs.length; i++) {
+      if (inputs[i] != null) {
+        totalLength += inputs[i].length;
+      } else {
+        return null;
+      }
+    }
+
+    // Allocate a new byte array, and copy the inputs one by one into it
+    final byte[] result = new byte[totalLength];
+    int offset = 0;
+    for (int i = 0; i < inputs.length; i++) {
+      int len = inputs[i].length;
+      Platform.copyMemory(
+        inputs[i], Platform.BYTE_ARRAY_OFFSET,
+        result, Platform.BYTE_ARRAY_OFFSET + offset,
+        len);
+      offset += len;
+    }
+    return result;
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index c02c41db1668e..e49729e35d856 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -38,7 +38,8 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
 
 
 /**
- * An expression that concatenates multiple input strings into a single string.
+ * An expression that concatenates multiple inputs into a single output.
+ * If all inputs are binary, concat returns an output as binary. Otherwise, it returns as a string.
  * If any input is null, concat returns null.
 */
 @ExpressionDescription(
@@ -50,15 +51,23 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
   """)
 case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes {
 
-  override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType)
-  override def dataType: DataType = StringType
+  private lazy val isBinaryMode = children.forall(_.dataType == BinaryType)
+
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq.fill(children.size)(if (isBinaryMode) BinaryType else StringType)
+  override def dataType: DataType = if (isBinaryMode) BinaryType else StringType
 
   override def nullable: Boolean = children.exists(_.nullable)
   override def foldable: Boolean = children.forall(_.foldable)
 
   override def eval(input: InternalRow): Any = {
-    val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
-    UTF8String.concat(inputs : _*)
+    if (isBinaryMode) {
+      val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]])
+      ByteArray.concat(inputs: _*)
+    } else {
+      val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
+      UTF8String.concat(inputs : _*)
+    }
   }
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -73,14 +82,20 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas
       }
     """
     }
+
+    val (javaClass, initCode) = if (isBinaryMode) {
+      (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];")
+    } else {
+      ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];")
+    }
     val codes = ctx.splitExpressionsWithCurrentInputs(
       expressions = inputs,
       funcName = "valueConcat",
-      extraArguments = ("UTF8String[]", args) :: Nil)
+      extraArguments = (s"${ctx.javaType(dataType)}[]", args) :: Nil)
     ev.copy(s"""
-      UTF8String[] $args = new UTF8String[${evals.length}];
+      $initCode
       $codes
-      UTF8String ${ev.value} = UTF8String.concat($args);
+      ${ctx.javaType(dataType)} ${ev.value} = $javaClass.concat($args);
       boolean ${ev.isNull} = ${ev.value} == null;
     """)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 052a3f533da5a..530a525a01dec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2171,7 +2171,8 @@ object functions {
   def base64(e: Column): Column = withExpr { Base64(e.expr) }
 
   /**
-   * Concatenates multiple input string columns together into a single string column.
+   * Concatenates multiple input columns together into a single column.
+   * If all inputs are binary, concat returns an output as binary. Otherwise, it returns as a string.
* * @group string_funcs * @since 1.5.0 diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index 40d0c064f5c44..e1e2e4e8853d0 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -24,3 +24,25 @@ select left("abcd", 2), left("abcd", 5), left("abcd", '2'), left("abcd", null); select left(null, -2), left("abcd", -2), left("abcd", 0), left("abcd", 'a'); select right("abcd", 2), right("abcd", 5), right("abcd", '2'), right("abcd", null); select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a'); + +-- Concatenate binary inputs +SELECT (col1 || col2 || col3 || col4) col +FROM ( + SELECT + CAST('a' || id AS BINARY) col1, + CAST('b' || id AS BINARY) col2, + CAST('c' || id AS BINARY) col3, + CAST('d' || id AS BINARY) col4 + FROM range(10) +); + +-- Concatenate mixed inputs between strings and binary +SELECT (col1 || col2 || col3 || col4) col +FROM ( + SELECT + 'a' || id col1, + 'b' || id col2, + CAST('c' || id AS BINARY) col3, + CAST('d' || id AS BINARY) col4 + FROM range(10) +); diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 2d9b3d7d2ca33..4ea4be2cf19bd 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 12 +-- Number of queries: 14 -- !query 0 @@ -118,3 +118,53 @@ select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a') struct -- !query 11 output NULL NULL + + +-- !query 12 +SELECT (col1 || col2 || col3 || col4) col +FROM ( + SELECT + CAST('a' || id AS BINARY) col1, + CAST('b' || id AS BINARY) col2, + CAST('c' || id AS BINARY) col3, + CAST('d' || id AS BINARY) col4 + FROM range(10) +) +-- !query 12 schema +struct +-- !query 12 output +a0b0c0d0 +a1b1c1d1 +a2b2c2d2 +a3b3c3d3 +a4b4c4d4 +a5b5c5d5 +a6b6c6d6 +a7b7c7d7 +a8b8c8d8 +a9b9c9d9 + + +-- !query 13 +SELECT (col1 || col2 || col3 || col4) col +FROM ( + SELECT + 'a' || id col1, + 'b' || id col2, + CAST('c' || id AS BINARY) col3, + CAST('d' || id AS BINARY) col4 + FROM range(10) +) +-- !query 13 schema +struct +-- !query 13 output +a0b0c0d0 +a1b1c1d1 +a2b2c2d2 +a3b3c3d3 +a4b4c4d4 +a5b5c5d5 +a6b6c6d6 +a7b7c7d7 +a8b8c8d8 +a9b9c9d9 From b87e61e3d61d9c6ad30f2ec331481756dc4a4ed9 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 14 Dec 2017 21:04:43 +0900 Subject: [PATCH 02/16] Fix --- .../spark/sql/catalyst/expressions/stringExpressions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index e49729e35d856..ed0e475bfd677 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -51,7 +51,7 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} """) case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes { - private lazy val isBinaryMode = children.forall(_.dataType == BinaryType) + private lazy val 
isBinaryMode = children.nonEmpty && children.forall(_.dataType == BinaryType) override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(if (isBinaryMode) BinaryType else StringType) From 4f8a7621a0aa56c1669b9515aac1bba9fc493338 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 15 Dec 2017 14:34:28 +0900 Subject: [PATCH 03/16] Add entry in Migration Guide --- docs/sql-programming-guide.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index f02f46236e2b0..b5d0df4cf33e7 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1780,6 +1780,8 @@ options. - Since Spark 2.3, when either broadcast hash join or broadcast nested loop join is applicable, we prefer to broadcasting the table that is explicitly specified in a broadcast hint. For details, see the section [Broadcast Hint](#broadcast-hint-for-sql-queries) and [SPARK-22489](https://issues.apache.org/jira/browse/SPARK-22489). + - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as string. + ## Upgrading From Spark SQL 2.1 to 2.2 - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. 
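
At this point in the series, any `concat` call whose inputs are all binary produces a binary result unconditionally. For reference, a minimal spark-shell sketch of the behavior the three patches above describe, assuming a build that includes them (the DataFrame and column names are illustrative, not taken from the patches):

    // All inputs are BinaryType, so the concatenated column is expected to
    // come out as BinaryType rather than StringType.
    val df = spark.range(3).selectExpr(
      "encode(string(id), 'utf-8') AS a",
      "encode(string(id + 1), 'utf-8') AS b")
    df.selectExpr("concat(a, b) AS ab").schema("ab").dataType   // BinaryType

    // Mixing in a string input falls back to the old behavior: every input
    // is coerced to string and the result is StringType.
    df.selectExpr("concat(string(a), b) AS s").schema("s").dataType   // StringType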
From 583bc5dd2c6ab4cd9922a15fe56192d07bb25886 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 15 Dec 2017 17:25:47 +0900 Subject: [PATCH 04/16] Add option to disable binary mode --- .../sql/catalyst/analysis/Analyzer.scala | 2 +- .../sql/catalyst/analysis/TypeCoercion.scala | 19 +++++-- .../expressions/stringExpressions.scala | 8 ++- .../sql/catalyst/optimizer/expressions.scala | 6 +- .../apache/spark/sql/internal/SQLConf.scala | 8 +++ .../catalyst/analysis/TypeCoercionSuite.scala | 52 +++++++++-------- .../sql-tests/inputs/string-functions.sql | 16 +++--- .../results/string-functions.sql.out | 56 +++++++++---------- 8 files changed, 99 insertions(+), 68 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 10b237fb22b96..fba5abe48fa13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -151,7 +151,7 @@ class Analyzer( TimeWindowing :: ResolveInlineTables(conf) :: ResolveTimeZone(conf) :: - TypeCoercion.typeCoercionRules ++ + TypeCoercion.typeCoercionRules(conf) ++ extendedResolutionRules : _*), Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*), Batch("View", Once, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 2f306f58b7b80..2620c1d2d75e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -45,13 +46,13 @@ import org.apache.spark.sql.types._ */ object TypeCoercion { - val typeCoercionRules = + def typeCoercionRules(conf: SQLConf): List[Rule[LogicalPlan]] = InConversion :: WidenSetOperationTypes :: PromoteStrings :: DecimalPrecision :: BooleanEquality :: - FunctionArgumentConversion :: + new FunctionArgumentConversion(conf) :: CaseWhenCoercion :: IfCoercion :: StackCoercion :: @@ -479,9 +480,9 @@ object TypeCoercion { /** * This ensure that the types for various functions are as expected. */ - object FunctionArgumentConversion extends TypeCoercionRule { + class FunctionArgumentConversion(conf: SQLConf) extends TypeCoercionRule { override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions { // Skip nodes who's children have not been resolved yet. 
case e if !e.childrenResolved => e @@ -564,6 +565,16 @@ object TypeCoercion { NaNvl(Cast(l, DoubleType), r) case NaNvl(l, r) if r.dataType == NullType => NaNvl(l, Cast(r, l.dataType)) } + // This group needs to be transformed in a post order + .transformExpressionsUp { + // When all inputs in [[Concat]] are binary, coerces an output type to binary + case c @ Concat(children, _) + if conf.concatBinaryModeEnabled && + c.childrenResolved && + children.nonEmpty && + children.forall(_.dataType == BinaryType) => + c.copy(children, isBinaryMode = true) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index ed0e475bfd677..7bb2a8b0bf07b 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -24,7 +24,6 @@ import java.util.regex.Pattern import scala.collection.mutable.ArrayBuffer -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -49,9 +48,10 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} > SELECT _FUNC_('Spark', 'SQL'); SparkSQL """) -case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes { +case class Concat(children: Seq[Expression], isBinaryMode: Boolean = false) + extends Expression with ImplicitCastInputTypes { - private lazy val isBinaryMode = children.nonEmpty && children.forall(_.dataType == BinaryType) + def this(children: Seq[Expression]) = this(children, false) override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(if (isBinaryMode) BinaryType else StringType) @@ -99,6 +99,8 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas boolean ${ev.isNull} = ${ev.value} == null; """) } + + override def toString: String = s"concat(${children.mkString(", ")})" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 85295aff19808..6174f63ebb80d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -641,15 +641,17 @@ object CombineConcats extends Rule[LogicalPlan] { private def flattenConcats(concat: Concat): Concat = { val stack = Stack[Expression](concat) val flattened = ArrayBuffer.empty[Expression] + var isBinaryMode = concat.isBinaryMode while (stack.nonEmpty) { stack.pop() match { - case Concat(children) => + case Concat(children, binary) => + isBinaryMode &= binary stack.pushAll(children.reverse) case child => flattened += child } } - Concat(flattened) + Concat(flattened, isBinaryMode) } def apply(plan: LogicalPlan): LogicalPlan = plan.transformExpressionsDown { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 84fe4bb711a4e..f8bace1f174e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1043,6 +1043,12 @@ 
object SQLConf { "dummy value. This is currently used to redact the output of SQL explain commands. " + "When this conf is not set, the value from `spark.redaction.string.regex` is used.") .fallbackConf(org.apache.spark.internal.config.STRING_REDACTION_PATTERN) + + val CONCAT_BINARY_MODE_ENABLED = buildConf("spark.sql.expression.concat.binaryMode.enabled") + .doc("When this option is set to true and all inputs are binary, `functions.concat` returns " + + "an output as binary. Otherwise, it returns as string. ") + .booleanConf + .createWithDefault(true) val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE = buildConf("spark.sql.streaming.continuous.executorQueueSize") @@ -1378,6 +1384,8 @@ class SQLConf extends Serializable with Logging { def continuousStreamingExecutorPollIntervalMs: Long = getConf(CONTINUOUS_STREAMING_EXECUTOR_POLL_INTERVAL_MS) + def concatBinaryModeEnabled: Boolean = getConf(CONCAT_BINARY_MODE_ENABLED) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 5dcd653e9b341..67273619b9a71 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -528,7 +528,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("coalesce casts") { - val rule = TypeCoercion.FunctionArgumentConversion + val rule = new TypeCoercion.FunctionArgumentConversion(conf) val intLit = Literal(1) val longLit = Literal.create(1L) @@ -575,7 +575,9 @@ class TypeCoercionSuite extends AnalysisTest { } test("CreateArray casts") { - ruleTest(TypeCoercion.FunctionArgumentConversion, + val rule = new TypeCoercion.FunctionArgumentConversion(conf) + + ruleTest(rule, CreateArray(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -585,7 +587,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(rule, CreateArray(Literal(1.0) :: Literal(1) :: Literal("a") @@ -595,7 +597,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal("a"), StringType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(rule, CreateArray(Literal.create(null, DecimalType(5, 3)) :: Literal(1) :: Nil), @@ -603,7 +605,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(1).cast(DecimalType(13, 3)) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(rule, CreateArray(Literal.create(null, DecimalType(5, 3)) :: Literal.create(null, DecimalType(22, 10)) :: Literal.create(null, DecimalType(38, 38)) @@ -615,8 +617,10 @@ class TypeCoercionSuite extends AnalysisTest { } test("CreateMap casts") { + val rule = new TypeCoercion.FunctionArgumentConversion(conf) + // type coercion for map keys - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(rule, CreateMap(Literal(1) :: Literal("a") :: Literal.create(2.0, FloatType) @@ -627,7 +631,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal.create(2.0, FloatType), FloatType) :: Literal("b") :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(rule, CreateMap(Literal.create(null, DecimalType(5, 3)) :: Literal("a") :: Literal.create(2.0, FloatType) @@ -639,7 +643,7 @@ class 
TypeCoercionSuite extends AnalysisTest { :: Literal("b") :: Nil)) // type coercion for map values - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(rule, CreateMap(Literal(1) :: Literal("a") :: Literal(2) @@ -650,7 +654,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(2) :: Cast(Literal(3.0), StringType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(rule, CreateMap(Literal(1) :: Literal.create(null, DecimalType(38, 0)) :: Literal(2) @@ -662,7 +666,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38)) :: Nil)) // type coercion for both map keys and values - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(rule, CreateMap(Literal(1) :: Literal("a") :: Literal(2.0) @@ -676,8 +680,10 @@ class TypeCoercionSuite extends AnalysisTest { } test("greatest/least cast") { + val rule = new TypeCoercion.FunctionArgumentConversion(conf) + for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(rule, operator(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -686,7 +692,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(rule, operator(Literal(1L) :: Literal(1) :: Literal(new java.math.BigDecimal("1000000000000000000000")) @@ -695,7 +701,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal(1), DecimalType(22, 0)) :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(rule, operator(Literal(1.0) :: Literal.create(null, DecimalType(10, 5)) :: Literal(1) @@ -704,7 +710,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal.create(null, DecimalType(10, 5)).cast(DoubleType) :: Literal(1).cast(DoubleType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(rule, operator(Literal.create(null, DecimalType(15, 0)) :: Literal.create(null, DecimalType(10, 5)) :: Literal(1) @@ -713,7 +719,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal.create(null, DecimalType(10, 5)).cast(DecimalType(20, 5)) :: Literal(1).cast(DecimalType(20, 5)) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(rule, operator(Literal.create(2L, LongType) :: Literal(1) :: Literal.create(null, DecimalType(10, 5)) @@ -726,19 +732,21 @@ class TypeCoercionSuite extends AnalysisTest { } test("nanvl casts") { - ruleTest(TypeCoercion.FunctionArgumentConversion, + val rule = new TypeCoercion.FunctionArgumentConversion(conf) + + ruleTest(rule, NaNvl(Literal.create(1.0f, FloatType), Literal.create(1.0, DoubleType)), NaNvl(Cast(Literal.create(1.0f, FloatType), DoubleType), Literal.create(1.0, DoubleType))) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(rule, NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0f, FloatType)), NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0f, FloatType), DoubleType))) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(rule, NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)), NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType))) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(rule, NaNvl(Literal.create(1.0f, FloatType), Literal.create(null, NullType)), 
NaNvl(Literal.create(1.0f, FloatType), Cast(Literal.create(null, NullType), FloatType))) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(rule, NaNvl(Literal.create(1.0, DoubleType), Literal.create(null, NullType)), NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(null, NullType), DoubleType))) } @@ -1117,7 +1125,7 @@ class TypeCoercionSuite extends AnalysisTest { test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " + "in aggregation function like sum") { - val rules = Seq(FunctionArgumentConversion, Division) + val rules = Seq(new FunctionArgumentConversion(conf), Division) // Casts Integer to Double ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType)))) // Left expression is Double, right expression is Int. Another rule ImplicitTypeCasts will @@ -1136,7 +1144,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("SPARK-17117 null type coercion in divide") { - val rules = Seq(FunctionArgumentConversion, Division, ImplicitTypeCasts) + val rules = Seq(new FunctionArgumentConversion(conf), Division, ImplicitTypeCasts) val nullLit = Literal.create(null, NullType) ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType))) ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType))) diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index e1e2e4e8853d0..b86b7bb9f87fe 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -29,10 +29,10 @@ select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a'); SELECT (col1 || col2 || col3 || col4) col FROM ( SELECT - CAST('a' || id AS BINARY) col1, - CAST('b' || id AS BINARY) col2, - CAST('c' || id AS BINARY) col3, - CAST('d' || id AS BINARY) col4 + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 FROM range(10) ); @@ -40,9 +40,9 @@ FROM ( SELECT (col1 || col2 || col3 || col4) col FROM ( SELECT - 'a' || id col1, - 'b' || id col2, - CAST('c' || id AS BINARY) col3, - CAST('d' || id AS BINARY) col4 + string(id) col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 FROM range(10) ); diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 4ea4be2cf19bd..63b1d7279b246 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -124,47 +124,47 @@ NULL NULL SELECT (col1 || col2 || col3 || col4) col FROM ( SELECT - CAST('a' || id AS BINARY) col1, - CAST('b' || id AS BINARY) col2, - CAST('c' || id AS BINARY) col3, - CAST('d' || id AS BINARY) col4 + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 FROM range(10) ) -- !query 12 schema struct -- !query 12 output -a0b0c0d0 -a1b1c1d1 -a2b2c2d2 -a3b3c3d3 -a4b4c4d4 -a5b5c5d5 -a6b6c6d6 -a7b7c7d7 -a8b8c8d8 -a9b9c9d9 +0123 +1234 +2345 +3456 +4567 +5678 +6789 +78910 +891011 +9101112 -- !query 13 SELECT (col1 || col2 || col3 || col4) col FROM ( SELECT - 'a' || id col1, - 'b' || id col2, - 
CAST('c' || id AS BINARY) col3,
-    CAST('d' || id AS BINARY) col4
+    string(id) col1,
+    string(id + 1) col2,
+    encode(string(id + 2), 'utf-8') col3,
+    encode(string(id + 3), 'utf-8') col4
   FROM range(10)
 )
 -- !query 13 schema
 struct
 -- !query 13 output
-a0b0c0d0
-a1b1c1d1
-a2b2c2d2
-a3b3c3d3
-a4b4c4d4
-a5b5c5d5
-a6b6c6d6
-a7b7c7d7
-a8b8c8d8
-a9b9c9d9
+0123
+1234
+2345
+3456
+4567
+5678
+6789
+78910
+891011
+9101112

From 1b57428ef4614b49bf910ac320a1c49660b03d9d Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Wed, 20 Dec 2017 12:39:01 +0900
Subject: [PATCH 05/16] Address reviews

---
 docs/sql-programming-guide.md                 |  2 +-
 .../sql/catalyst/analysis/TypeCoercion.scala  | 18 ++++---
 .../expressions/stringExpressions.scala       |  9 +++-
 .../sql/catalyst/optimizer/expressions.scala  | 11 ++--
 .../apache/spark/sql/internal/SQLConf.scala   |  2 +-
 .../catalyst/analysis/TypeCoercionSuite.scala | 52 ++++++++-----------
 .../sql-tests/inputs/string-functions.sql     | 22 --------
 .../inputs/typeCoercion/native/concat.sql     | 21 ++++++++
 .../results/string-functions.sql.out          | 52 +------------------
 .../typeCoercion/native/concat.sql.out        | 52 +++++++++++++++++++
 10 files changed, 125 insertions(+), 116 deletions(-)
 create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql
 create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out

diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index b5d0df4cf33e7..0a4e249fe123e 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1780,7 +1780,7 @@ options.
 
  - Since Spark 2.3, when either broadcast hash join or broadcast nested loop join is applicable, we prefer to broadcasting the table that is explicitly specified in a broadcast hint. For details, see the section [Broadcast Hint](#broadcast-hint-for-sql-queries) and [SPARK-22489](https://issues.apache.org/jira/browse/SPARK-22489).
 
- - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as string.
+ - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string regardless of input types.
 
 ## Upgrading From Spark SQL 2.1 to 2.2
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 2620c1d2d75e1..ecf51be84f446 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -52,7 +52,8 @@ object TypeCoercion {
       PromoteStrings ::
       DecimalPrecision ::
       BooleanEquality ::
-      new FunctionArgumentConversion(conf) ::
+      FunctionArgumentConversion ::
+      ConcatCoercion(conf) ::
       CaseWhenCoercion ::
       IfCoercion ::
       StackCoercion ::
@@ -480,9 +481,9 @@ object TypeCoercion {
   /**
    * This ensure that the types for various functions are as expected.
    */
-  class FunctionArgumentConversion(conf: SQLConf) extends TypeCoercionRule {
+  object FunctionArgumentConversion extends TypeCoercionRule {
     override protected def coerceTypes(
-        plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
+        plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
       // Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e @@ -565,9 +566,14 @@ object TypeCoercion { NaNvl(Cast(l, DoubleType), r) case NaNvl(l, r) if r.dataType == NullType => NaNvl(l, Cast(r, l.dataType)) } - // This group needs to be transformed in a post order - .transformExpressionsUp { - // When all inputs in [[Concat]] are binary, coerces an output type to binary + } + + /** + * When all inputs in [[Concat]] are binary, coerces an output type to binary + */ + case class ConcatCoercion(conf: SQLConf) extends TypeCoercionRule { + override protected def coerceTypes( + plan: LogicalPlan): LogicalPlan = plan transformExpressionsUp { case c @ Concat(children, _) if conf.concatBinaryModeEnabled && c.childrenResolved && diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 7bb2a8b0bf07b..204b04833d4da 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -54,7 +54,12 @@ case class Concat(children: Seq[Expression], isBinaryMode: Boolean = false) def this(children: Seq[Expression]) = this(children, false) override def inputTypes: Seq[AbstractDataType] = - Seq.fill(children.size)(if (isBinaryMode) BinaryType else StringType) + if (isBinaryMode) { + Seq.fill(children.size)(BinaryType) + } else { + Seq.fill(children.size)(StringType) + } + override def dataType: DataType = if (isBinaryMode) BinaryType else StringType override def nullable: Boolean = children.exists(_.nullable) @@ -101,6 +106,8 @@ case class Concat(children: Seq[Expression], isBinaryMode: Boolean = false) } override def toString: String = s"concat(${children.mkString(", ")})" + + override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 6174f63ebb80d..d09fa3c42e7d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -635,23 +635,26 @@ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] { /** * Combine nested [[Concat]] expressions. + * + * If `spark.sql.expression.concat.binaryMode.enabled` is true and all inputs are binary, + * the type coercion rule `ConcatCoercion` sets true at `isBinaryMode`s in all the nested concat + * expressions. So, this optimizer rule just passes a given concat `isBinaryMode` + * into a combined concat. 
*/ object CombineConcats extends Rule[LogicalPlan] { private def flattenConcats(concat: Concat): Concat = { val stack = Stack[Expression](concat) val flattened = ArrayBuffer.empty[Expression] - var isBinaryMode = concat.isBinaryMode while (stack.nonEmpty) { stack.pop() match { - case Concat(children, binary) => - isBinaryMode &= binary + case Concat(children, _) => stack.pushAll(children.reverse) case child => flattened += child } } - Concat(flattened, isBinaryMode) + Concat(flattened, concat.isBinaryMode) } def apply(plan: LogicalPlan): LogicalPlan = plan.transformExpressionsDown { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index f8bace1f174e0..0be3b61cbdc79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1043,7 +1043,7 @@ object SQLConf { "dummy value. This is currently used to redact the output of SQL explain commands. " + "When this conf is not set, the value from `spark.redaction.string.regex` is used.") .fallbackConf(org.apache.spark.internal.config.STRING_REDACTION_PATTERN) - + val CONCAT_BINARY_MODE_ENABLED = buildConf("spark.sql.expression.concat.binaryMode.enabled") .doc("When this option is set to true and all inputs are binary, `functions.concat` returns " + "an output as binary. Otherwise, it returns as string. ") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 67273619b9a71..5dcd653e9b341 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -528,7 +528,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("coalesce casts") { - val rule = new TypeCoercion.FunctionArgumentConversion(conf) + val rule = TypeCoercion.FunctionArgumentConversion val intLit = Literal(1) val longLit = Literal.create(1L) @@ -575,9 +575,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("CreateArray casts") { - val rule = new TypeCoercion.FunctionArgumentConversion(conf) - - ruleTest(rule, + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateArray(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -587,7 +585,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest(rule, + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateArray(Literal(1.0) :: Literal(1) :: Literal("a") @@ -597,7 +595,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal("a"), StringType) :: Nil)) - ruleTest(rule, + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateArray(Literal.create(null, DecimalType(5, 3)) :: Literal(1) :: Nil), @@ -605,7 +603,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(1).cast(DecimalType(13, 3)) :: Nil)) - ruleTest(rule, + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateArray(Literal.create(null, DecimalType(5, 3)) :: Literal.create(null, DecimalType(22, 10)) :: Literal.create(null, DecimalType(38, 38)) @@ -617,10 +615,8 @@ class TypeCoercionSuite extends AnalysisTest { } test("CreateMap casts") { - val rule = new TypeCoercion.FunctionArgumentConversion(conf) - // type coercion for map keys - ruleTest(rule, + 
ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal(1) :: Literal("a") :: Literal.create(2.0, FloatType) @@ -631,7 +627,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal.create(2.0, FloatType), FloatType) :: Literal("b") :: Nil)) - ruleTest(rule, + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal.create(null, DecimalType(5, 3)) :: Literal("a") :: Literal.create(2.0, FloatType) @@ -643,7 +639,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal("b") :: Nil)) // type coercion for map values - ruleTest(rule, + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal(1) :: Literal("a") :: Literal(2) @@ -654,7 +650,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(2) :: Cast(Literal(3.0), StringType) :: Nil)) - ruleTest(rule, + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal(1) :: Literal.create(null, DecimalType(38, 0)) :: Literal(2) @@ -666,7 +662,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38)) :: Nil)) // type coercion for both map keys and values - ruleTest(rule, + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal(1) :: Literal("a") :: Literal(2.0) @@ -680,10 +676,8 @@ class TypeCoercionSuite extends AnalysisTest { } test("greatest/least cast") { - val rule = new TypeCoercion.FunctionArgumentConversion(conf) - for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { - ruleTest(rule, + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -692,7 +686,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest(rule, + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal(1L) :: Literal(1) :: Literal(new java.math.BigDecimal("1000000000000000000000")) @@ -701,7 +695,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal(1), DecimalType(22, 0)) :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) :: Nil)) - ruleTest(rule, + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal(1.0) :: Literal.create(null, DecimalType(10, 5)) :: Literal(1) @@ -710,7 +704,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal.create(null, DecimalType(10, 5)).cast(DoubleType) :: Literal(1).cast(DoubleType) :: Nil)) - ruleTest(rule, + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal.create(null, DecimalType(15, 0)) :: Literal.create(null, DecimalType(10, 5)) :: Literal(1) @@ -719,7 +713,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal.create(null, DecimalType(10, 5)).cast(DecimalType(20, 5)) :: Literal(1).cast(DecimalType(20, 5)) :: Nil)) - ruleTest(rule, + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal.create(2L, LongType) :: Literal(1) :: Literal.create(null, DecimalType(10, 5)) @@ -732,21 +726,19 @@ class TypeCoercionSuite extends AnalysisTest { } test("nanvl casts") { - val rule = new TypeCoercion.FunctionArgumentConversion(conf) - - ruleTest(rule, + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0f, FloatType), Literal.create(1.0, DoubleType)), NaNvl(Cast(Literal.create(1.0f, FloatType), DoubleType), Literal.create(1.0, DoubleType))) - ruleTest(rule, + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0f, FloatType)), 
NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0f, FloatType), DoubleType))) - ruleTest(rule, + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)), NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType))) - ruleTest(rule, + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0f, FloatType), Literal.create(null, NullType)), NaNvl(Literal.create(1.0f, FloatType), Cast(Literal.create(null, NullType), FloatType))) - ruleTest(rule, + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, DoubleType), Literal.create(null, NullType)), NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(null, NullType), DoubleType))) } @@ -1125,7 +1117,7 @@ class TypeCoercionSuite extends AnalysisTest { test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " + "in aggregation function like sum") { - val rules = Seq(new FunctionArgumentConversion(conf), Division) + val rules = Seq(FunctionArgumentConversion, Division) // Casts Integer to Double ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType)))) // Left expression is Double, right expression is Int. Another rule ImplicitTypeCasts will @@ -1144,7 +1136,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("SPARK-17117 null type coercion in divide") { - val rules = Seq(new FunctionArgumentConversion(conf), Division, ImplicitTypeCasts) + val rules = Seq(FunctionArgumentConversion, Division, ImplicitTypeCasts) val nullLit = Literal.create(null, NullType) ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType))) ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType))) diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index b86b7bb9f87fe..40d0c064f5c44 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -24,25 +24,3 @@ select left("abcd", 2), left("abcd", 5), left("abcd", '2'), left("abcd", null); select left(null, -2), left("abcd", -2), left("abcd", 0), left("abcd", 'a'); select right("abcd", 2), right("abcd", 5), right("abcd", '2'), right("abcd", null); select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a'); - --- Concatenate binary inputs -SELECT (col1 || col2 || col3 || col4) col -FROM ( - SELECT - encode(string(id), 'utf-8') col1, - encode(string(id + 1), 'utf-8') col2, - encode(string(id + 2), 'utf-8') col3, - encode(string(id + 3), 'utf-8') col4 - FROM range(10) -); - --- Concatenate mixed inputs between strings and binary -SELECT (col1 || col2 || col3 || col4) col -FROM ( - SELECT - string(id) col1, - string(id + 1) col2, - encode(string(id + 2), 'utf-8') col3, - encode(string(id + 3), 'utf-8') col4 - FROM range(10) -); diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql new file mode 100644 index 0000000000000..160f44bd8cd35 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql @@ -0,0 +1,21 @@ +-- Concatenate binary inputs +SELECT (col1 || col2 || col3 || col4) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2, + encode(string(id + 2), 'utf-8') col3, 
+ encode(string(id + 3), 'utf-8') col4 + FROM range(10) +); + +-- Concatenate mixed inputs between strings and binary +SELECT (col1 || col2 || col3 || col4) col +FROM ( + SELECT + string(id) col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +); diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 63b1d7279b246..2d9b3d7d2ca33 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 14 +-- Number of queries: 12 -- !query 0 @@ -118,53 +118,3 @@ select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a') struct -- !query 11 output NULL NULL - - --- !query 12 -SELECT (col1 || col2 || col3 || col4) col -FROM ( - SELECT - encode(string(id), 'utf-8') col1, - encode(string(id + 1), 'utf-8') col2, - encode(string(id + 2), 'utf-8') col3, - encode(string(id + 3), 'utf-8') col4 - FROM range(10) -) --- !query 12 schema -struct --- !query 12 output -0123 -1234 -2345 -3456 -4567 -5678 -6789 -78910 -891011 -9101112 - - --- !query 13 -SELECT (col1 || col2 || col3 || col4) col -FROM ( - SELECT - string(id) col1, - string(id + 1) col2, - encode(string(id + 2), 'utf-8') col3, - encode(string(id + 3), 'utf-8') col4 - FROM range(10) -) --- !query 13 schema -struct --- !query 13 output -0123 -1234 -2345 -3456 -4567 -5678 -6789 -78910 -891011 -9101112 diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out new file mode 100644 index 0000000000000..a1280a09d2353 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out @@ -0,0 +1,52 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 2 + + +-- !query 0 +SELECT (col1 || col2 || col3 || col4) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +) +-- !query 0 schema +struct +-- !query 0 output +0123 +1234 +2345 +3456 +4567 +5678 +6789 +78910 +891011 +9101112 + + +-- !query 1 +SELECT (col1 || col2 || col3 || col4) col +FROM ( + SELECT + string(id) col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +) +-- !query 1 schema +struct +-- !query 1 output +0123 +1234 +2345 +3456 +4567 +5678 +6789 +78910 +891011 +9101112 From 5a771f00ff9fa7b01dcbca855844f71a43ac039f Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 20 Dec 2017 15:22:28 +0900 Subject: [PATCH 06/16] Fix --- docs/sql-programming-guide.md | 2 +- .../sql/catalyst/analysis/TypeCoercion.scala | 5 +- .../sql/catalyst/optimizer/expressions.scala | 7 +- .../apache/spark/sql/internal/SQLConf.scala | 10 +- .../inputs/typeCoercion/native/concat.sql | 48 ++++++- .../typeCoercion/native/concat.sql.out | 118 ++++++++++++++++-- 6 files changed, 163 insertions(+), 27 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 0a4e249fe123e..4b5f56c44444d 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1780,7 +1780,7 @@ options. 
 - Since Spark 2.3, when either broadcast hash join or broadcast nested loop join is applicable, we prefer to broadcasting the table that is explicitly specified in a broadcast hint. For details, see the section [Broadcast Hint](#broadcast-hint-for-sql-queries) and [SPARK-22489](https://issues.apache.org/jira/browse/SPARK-22489).
 
- - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string regardless of input types.
+ - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string regardless of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`.
 
 ## Upgrading From Spark SQL 2.1 to 2.2
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index ecf51be84f446..8b91460ac6e51 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -574,8 +574,9 @@ object TypeCoercion {
   case class ConcatCoercion(conf: SQLConf) extends TypeCoercionRule {
     override protected def coerceTypes(
         plan: LogicalPlan): LogicalPlan = plan transformExpressionsUp {
-      case c @ Concat(children, _)
-        if conf.concatBinaryModeEnabled &&
+      case c @ Concat(children, isBinaryMode)
+        if !conf.concatBinaryAsString &&
+          !isBinaryMode &&
           c.childrenResolved &&
           children.nonEmpty &&
           children.forall(_.dataType == BinaryType) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index d09fa3c42e7d3..403ee3f82b08c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -636,10 +636,9 @@ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] {
 
 /**
  * Combine nested [[Concat]] expressions.
  *
- * If `spark.sql.expression.concat.binaryMode.enabled` is true and all inputs are binary,
- * the type coercion rule `ConcatCoercion` sets true at `isBinaryMode`s in all the nested concat
- * expressions. So, this optimizer rule just passes a given concat `isBinaryMode`
- * into a combined concat.
+ * If `spark.sql.function.concatBinaryAsString` is false and all inputs are binary, the type
+ * coercion rule `ConcatCoercion` sets true at `isBinaryMode`s in all the nested concat expressions.
+ * So, this optimizer rule just passes a given concat `isBinaryMode` into a combined concat.
*/ object CombineConcats extends Rule[LogicalPlan] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 0be3b61cbdc79..a21e01de5ec99 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1044,11 +1044,11 @@ object SQLConf { "When this conf is not set, the value from `spark.redaction.string.regex` is used.") .fallbackConf(org.apache.spark.internal.config.STRING_REDACTION_PATTERN) - val CONCAT_BINARY_MODE_ENABLED = buildConf("spark.sql.expression.concat.binaryMode.enabled") - .doc("When this option is set to true and all inputs are binary, `functions.concat` returns " + - "an output as binary. Otherwise, it returns as string. ") + val CONCAT_BINARY_AS_STRING = buildConf("spark.sql.function.concatBinaryAsString") + .doc("When this option is set to false and all inputs are binary, `functions.concat` returns " + + "an output as binary. Otherwise, it returns as a string. ") .booleanConf - .createWithDefault(true) + .createWithDefault(false) val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE = buildConf("spark.sql.streaming.continuous.executorQueueSize") @@ -1384,7 +1384,7 @@ class SQLConf extends Serializable with Logging { def continuousStreamingExecutorPollIntervalMs: Long = getConf(CONTINUOUS_STREAMING_EXECUTOR_POLL_INTERVAL_MS) - def concatBinaryModeEnabled: Boolean = getConf(CONCAT_BINARY_MODE_ENABLED) + def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING) /** ********************** SQLConf functionality methods ************ */ diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql index 160f44bd8cd35..380d4cc2ed954 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql @@ -1,4 +1,43 @@ --- Concatenate binary inputs +-- Concatenate mixed inputs (output type is string) +SELECT (col1 || col2 || col3) col +FROM ( + SELECT + id col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3 + FROM range(10) +); + +SELECT ((col1 || col2) || (col3 || col4) || col5) col +FROM ( + SELECT + 'prefix_' col1, + id col2, + string(id + 1) col3, + encode(string(id + 2), 'utf-8') col4, + CAST(id AS DOUBLE) col5 + FROM range(10) +); + +SELECT ((col1 || col2) || (col3 || col4)) col +FROM ( + SELECT + string(id) col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +); + +-- Concatenate binary inputs (output type is binary) +SELECT (col1 || col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +); + SELECT (col1 || col2 || col3 || col4) col FROM ( SELECT @@ -9,12 +48,11 @@ FROM ( FROM range(10) ); --- Concatenate mixed inputs between strings and binary -SELECT (col1 || col2 || col3 || col4) col +SELECT ((col1 || col2) || (col3 || col4)) col FROM ( SELECT - string(id) col1, - string(id + 1) col2, + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2, encode(string(id + 2), 'utf-8') col3, encode(string(id + 3), 'utf-8') col4 FROM range(10) diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out 
b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out index a1280a09d2353..ea9db2b42bc6a 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out @@ -1,8 +1,106 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 2 +-- Number of queries: 6 -- !query 0 +SELECT (col1 || col2 || col3) col +FROM ( + SELECT + id col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3 + FROM range(10) +) +-- !query 0 schema +struct +-- !query 0 output +012 +123 +234 +345 +456 +567 +678 +789 +8910 +91011 + + +-- !query 1 +SELECT ((col1 || col2) || (col3 || col4) || col5) col +FROM ( + SELECT + 'prefix_' col1, + id col2, + string(id + 1) col3, + encode(string(id + 2), 'utf-8') col4, + CAST(id AS DOUBLE) col5 + FROM range(10) +) +-- !query 1 schema +struct +-- !query 1 output +prefix_0120.0 +prefix_1231.0 +prefix_2342.0 +prefix_3453.0 +prefix_4564.0 +prefix_5675.0 +prefix_6786.0 +prefix_7897.0 +prefix_89108.0 +prefix_910119.0 + + +-- !query 2 +SELECT ((col1 || col2) || (col3 || col4)) col +FROM ( + SELECT + string(id) col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +) +-- !query 2 schema +struct +-- !query 2 output +0123 +1234 +2345 +3456 +4567 +5678 +6789 +78910 +891011 +9101112 + + +-- !query 3 +SELECT (col1 || col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +) +-- !query 3 schema +struct +-- !query 3 output +01 +12 +23 +34 +45 +56 +67 +78 +89 +910 + + +-- !query 4 SELECT (col1 || col2 || col3 || col4) col FROM ( SELECT @@ -12,9 +110,9 @@ FROM ( encode(string(id + 3), 'utf-8') col4 FROM range(10) ) --- !query 0 schema +-- !query 4 schema struct --- !query 0 output +-- !query 4 output 0123 1234 2345 @@ -27,19 +125,19 @@ struct 9101112 --- !query 1 -SELECT (col1 || col2 || col3 || col4) col +-- !query 5 +SELECT ((col1 || col2) || (col3 || col4)) col FROM ( SELECT - string(id) col1, - string(id + 1) col2, + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2, encode(string(id + 2), 'utf-8') col3, encode(string(id + 3), 'utf-8') col4 FROM range(10) ) --- !query 1 schema -struct --- !query 1 output +-- !query 5 schema +struct +-- !query 5 output 0123 1234 2345 From 8faebdb3accafde7b80a77a854f09dff8a404599 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 21 Dec 2017 17:59:50 +0900 Subject: [PATCH 07/16] Apply more comments --- .../sql/catalyst/analysis/TypeCoercion.scala | 40 +++++++++++-------- .../expressions/stringExpressions.scala | 24 ++++++----- .../sql/catalyst/optimizer/expressions.scala | 8 +--- 3 files changed, 41 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 8b91460ac6e51..a0b4018162d71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -568,22 +568,6 @@ object TypeCoercion { } } - /** - * When all inputs in [[Concat]] are binary, coerces an output type to binary - */ - case class ConcatCoercion(conf: SQLConf) extends TypeCoercionRule { - override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan 
transformExpressionsUp { - case c @ Concat(children, isBinaryMode) - if !conf.concatBinaryAsString && - !isBinaryMode && - c.childrenResolved && - children.nonEmpty && - children.forall(_.dataType == BinaryType) => - c.copy(children, isBinaryMode = true) - } - } - /** * Hive only performs integral division with the DIV operator. The arguments to / are always * converted to fractional types. @@ -676,6 +660,30 @@ object TypeCoercion { } } + /** + * Coerces the types of [[Concat]] children to expected ones. + * + * If `spark.sql.function.concatBinaryAsString` is false and all children types are binary, + * the expected types are binary. Otherwise, the expected ones are strings. + */ + case class ConcatCoercion(conf: SQLConf) extends TypeCoercionRule { + override protected def coerceTypes( + plan: LogicalPlan): LogicalPlan = plan transformExpressionsUp { + // Skip nodes if unresolved or empty children + case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c + + case c @ Concat(children) if !children.map(_.dataType).forall(_ == BinaryType) => + val newChildren = children.map { e => + ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) + } + c.copy(children = newChildren) + + case c @ Concat(children) if conf.concatBinaryAsString => + val newChildren = children.map(Cast(_, StringType)) + c.copy(children = newChildren) + } + } + /** * Turns Add/Subtract of DateType/TimestampType/StringType and CalendarIntervalType * to TimeAdd/TimeSub diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 204b04833d4da..91c8fcf1b58c2 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -27,7 +27,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -48,19 +48,25 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} > SELECT _FUNC_('Spark', 'SQL'); SparkSQL """) -case class Concat(children: Seq[Expression], isBinaryMode: Boolean = false) - extends Expression with ImplicitCastInputTypes { +case class Concat(children: Seq[Expression]) extends Expression { - def this(children: Seq[Expression]) = this(children, false) + private lazy val isBinaryMode: Boolean = dataType == BinaryType - override def inputTypes: Seq[AbstractDataType] = - if (isBinaryMode) { - Seq.fill(children.size)(BinaryType) + override def checkInputDataTypes(): TypeCheckResult = { + if (children.isEmpty) { + TypeCheckResult.TypeCheckSuccess } else { - Seq.fill(children.size)(StringType) + val childTypes = children.map(_.dataType) + if (childTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { + return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have StringType or BinaryType, but it's " + + childTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") } + } - override def 
dataType: DataType = if (isBinaryMode) BinaryType else StringType + override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) override def nullable: Boolean = children.exists(_.nullable) override def foldable: Boolean = children.forall(_.foldable) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 403ee3f82b08c..85295aff19808 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -635,10 +635,6 @@ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] { /** * Combine nested [[Concat]] expressions. - * - * If `spark.sql.function.concatBinaryAsString` is false and all inputs are binary, the type - * coercion rule `ConcatCoercion` sets true at `isBinaryMode`s in all the nested concat expressions. - * So, this optimizer rule just passes a given concat `isBinaryMode` into a combined concat. */ object CombineConcats extends Rule[LogicalPlan] { @@ -647,13 +643,13 @@ object CombineConcats extends Rule[LogicalPlan] { val flattened = ArrayBuffer.empty[Expression] while (stack.nonEmpty) { stack.pop() match { - case Concat(children, _) => + case Concat(children) => stack.pushAll(children.reverse) case child => flattened += child } } - Concat(flattened, concat.isBinaryMode) + Concat(flattened) } def apply(plan: LogicalPlan): LogicalPlan = plan.transformExpressionsDown { From de2f80881f521c2af877c538c0546a6fbb66ec9b Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 22 Dec 2017 13:25:56 +0900 Subject: [PATCH 08/16] Fix --- .../sql/catalyst/analysis/TypeCoercion.scala | 18 +-- .../expressions/stringExpressions.scala | 6 +- .../catalyst/analysis/TypeCoercionSuite.scala | 54 +++++++++ .../inputs/typeCoercion/native/concat.sql | 34 ++++++ .../results/string-functions.sql.out | 2 +- .../typeCoercion/native/concat.sql.out | 107 ++++++++++++++++-- 6 files changed, 201 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index a0b4018162d71..3388c3fbe1223 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -667,20 +667,24 @@ object TypeCoercion { * the expected types are binary. Otherwise, the expected ones are strings. 
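+   * For example, `concat` over two binary columns keeps its children as-is under
+   * the default (false), while `concat` over a binary and a string column has
+   * both children cast to StringType.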
*/ case class ConcatCoercion(conf: SQLConf) extends TypeCoercionRule { + + private def typeCastToString(c: Concat): Concat = { + val newChildren = c.children.map { e => + ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) + } + c.copy(children = newChildren) + } + override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformExpressionsUp { + plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes if unresolved or empty children case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c case c @ Concat(children) if !children.map(_.dataType).forall(_ == BinaryType) => - val newChildren = children.map { e => - ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) - } - c.copy(children = newChildren) + typeCastToString(c) case c @ Concat(children) if conf.concatBinaryAsString => - val newChildren = children.map(Cast(_, StringType)) - c.copy(children = newChildren) + typeCastToString(c) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 91c8fcf1b58c2..b0da55a4a961b 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -58,7 +58,7 @@ case class Concat(children: Seq[Expression]) extends Expression { } else { val childTypes = children.map(_.dataType) if (childTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { - return TypeCheckResult.TypeCheckFailure( + TypeCheckResult.TypeCheckFailure( s"input to function $prettyName should have StringType or BinaryType, but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) } @@ -94,7 +94,7 @@ case class Concat(children: Seq[Expression]) extends Expression { """ } - val (javaClass, initCode) = if (isBinaryMode) { + val (concatenator, initCode) = if (isBinaryMode) { (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") } else { ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") @@ -106,7 +106,7 @@ case class Concat(children: Seq[Expression]) extends Expression { ev.copy(s""" $initCode $codes - ${ctx.javaType(dataType)} ${ev.value} = $javaClass.concat($args); + ${ctx.javaType(dataType)} ${ev.value} = $concatenator.concat($args); boolean ${ev.isNull} = ${ev.value} == null; """) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 5dcd653e9b341..3661530cd622b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -869,6 +869,60 @@ class TypeCoercionSuite extends AnalysisTest { Literal.create(null, IntegerType), Literal.create(null, StringType)))) } + test("type coercion for Concat") { + val rule = TypeCoercion.ConcatCoercion(conf) + + ruleTest(rule, + Concat(Seq(Literal("ab"), Literal("cde"))), + Concat(Seq(Literal("ab"), Literal("cde")))) + ruleTest(rule, + Concat(Seq(Literal(null), Literal("abc"))), + Concat(Seq(Cast(Literal(null), StringType), Literal("abc")))) + ruleTest(rule, + Concat(Seq(Literal(1), Literal("234"))), + Concat(Seq(Cast(Literal(1), StringType), Literal("234")))) + ruleTest(rule, + 
Concat(Seq(Literal("1"), Literal("234".getBytes()))), + Concat(Seq(Literal("1"), Cast(Literal("234".getBytes()), StringType)))) + ruleTest(rule, + Concat(Seq(Literal(1L), Literal(2.toByte), Literal(0.1))), + Concat(Seq(Cast(Literal(1L), StringType), Cast(Literal(2.toByte), StringType), + Cast(Literal(0.1), StringType)))) + ruleTest(rule, + Concat(Seq(Literal(true), Literal(0.1f), Literal(3.toShort))), + Concat(Seq(Cast(Literal(true), StringType), Cast(Literal(0.1f), StringType), + Cast(Literal(3.toShort), StringType)))) + ruleTest(rule, + Concat(Seq(Literal(1L), Literal(0.1))), + Concat(Seq(Cast(Literal(1L), StringType), Cast(Literal(0.1), StringType)))) + ruleTest(rule, + Concat(Seq(Literal(Decimal(10)))), + Concat(Seq(Cast(Literal(Decimal(10)), StringType)))) + ruleTest(rule, + Concat(Seq(Literal(BigDecimal.valueOf(10)))), + Concat(Seq(Cast(Literal(BigDecimal.valueOf(10)), StringType)))) + ruleTest(rule, + Concat(Seq(Literal(java.math.BigDecimal.valueOf(10)))), + Concat(Seq(Cast(Literal(java.math.BigDecimal.valueOf(10)), StringType)))) + ruleTest(rule, + Concat(Seq(Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))), + Concat(Seq(Cast(Literal(new java.sql.Date(0)), StringType), + Cast(Literal(new Timestamp(0)), StringType)))) + + withSQLConf("spark.sql.function.concatBinaryAsString" -> "true") { + ruleTest(rule, + Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))), + Concat(Seq(Cast(Literal("123".getBytes), StringType), + Cast(Literal("456".getBytes), StringType)))) + } + + withSQLConf("spark.sql.function.concatBinaryAsString" -> "false") { + ruleTest(rule, + Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))), + Concat(Seq(Literal("123".getBytes), Literal("456".getBytes)))) + } + } + test("BooleanEquality type cast") { val be = TypeCoercion.BooleanEquality // Use something more than a literal to avoid triggering the simplification rules. 
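The coercion pinned down by these rule tests can also be observed end-to-end from SQL. A minimal spark-shell sketch, assuming a build with this patch series applied and the usual `spark` session (the queries mirror the concat.sql inputs; nothing below is part of the patch itself):

    import org.apache.spark.sql.types.{BinaryType, StringType}

    // Default in this series: concatBinaryAsString=false, so an all-binary
    // concat keeps BinaryType.
    spark.conf.set("spark.sql.function.concatBinaryAsString", "false")
    val bin = spark.sql("SELECT encode('a', 'utf-8') || encode('b', 'utf-8') AS c")
    assert(bin.schema("c").dataType == BinaryType)

    // Any non-binary child falls back to a string result.
    val mixed = spark.sql("SELECT 'a' || encode('b', 'utf-8') AS c")
    assert(mixed.schema("c").dataType == StringType)

    // The flag restores the old always-string behaviour for all-binary inputs.
    spark.conf.set("spark.sql.function.concatBinaryAsString", "true")
    val legacy = spark.sql("SELECT encode('a', 'utf-8') || encode('b', 'utf-8') AS c")
    assert(legacy.schema("c").dataType == StringType)

The default of false makes the binary-preserving behaviour the norm, with the flag kept only as an escape hatch back to the old always-string result.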
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql index 380d4cc2ed954..0beebec5702fd 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql @@ -29,6 +29,40 @@ FROM ( FROM range(10) ); +-- turn on concatBinaryAsString +set spark.sql.function.concatBinaryAsString=true; + +SELECT (col1 || col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +); + +SELECT (col1 || col2 || col3 || col4) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +); + +SELECT ((col1 || col2) || (col3 || col4)) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +); + +-- turn off concatBinaryAsString +set spark.sql.function.concatBinaryAsString=false; + -- Concatenate binary inputs (output type is binary) SELECT (col1 || col2) col FROM ( diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 2d9b3d7d2ca33..708e01e41a651 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -42,7 +42,7 @@ struct == Analyzed Logical Plan == col: string -Project [concat(concat(concat(cast(col1#xL as string), cast(col2#xL as string)), cast(col3#xL as string)), cast(col4#xL as string)) AS col#x] +Project [concat(cast(concat(cast(concat(cast(col1#xL as string), cast(col2#xL as string)) as string), cast(col3#xL as string)) as string), cast(col4#xL as string)) AS col#x] +- SubqueryAlias __auto_generated_subquery_name +- Project [id#xL AS col1#xL, id#xL AS col2#xL, id#xL AS col3#xL, id#xL AS col4#xL] +- Range (0, 10, step=1, splits=None) diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out index ea9db2b42bc6a..09729fdc2ec32 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 6 +-- Number of queries: 11 -- !query 0 @@ -78,6 +78,14 @@ struct -- !query 3 +set spark.sql.function.concatBinaryAsString=true +-- !query 3 schema +struct +-- !query 3 output +spark.sql.function.concatBinaryAsString true + + +-- !query 4 SELECT (col1 || col2) col FROM ( SELECT @@ -85,9 +93,90 @@ FROM ( encode(string(id + 1), 'utf-8') col2 FROM range(10) ) --- !query 3 schema +-- !query 4 schema +struct +-- !query 4 output +01 +12 +23 +34 +45 +56 +67 +78 +89 +910 + + +-- !query 5 +SELECT (col1 || col2 || col3 || col4) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +) +-- !query 5 schema +struct +-- !query 5 output +0123 +1234 +2345 +3456 +4567 +5678 +6789 +78910 +891011 +9101112 + + +-- !query 6 +SELECT ((col1 || 
col2) || (col3 || col4)) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +) +-- !query 6 schema +struct +-- !query 6 output +0123 +1234 +2345 +3456 +4567 +5678 +6789 +78910 +891011 +9101112 + + +-- !query 7 +set spark.sql.function.concatBinaryAsString=false +-- !query 7 schema +struct +-- !query 7 output +spark.sql.function.concatBinaryAsString false + + +-- !query 8 +SELECT (col1 || col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +) +-- !query 8 schema struct --- !query 3 output +-- !query 8 output 01 12 23 @@ -100,7 +189,7 @@ struct 910 --- !query 4 +-- !query 9 SELECT (col1 || col2 || col3 || col4) col FROM ( SELECT @@ -110,9 +199,9 @@ FROM ( encode(string(id + 3), 'utf-8') col4 FROM range(10) ) --- !query 4 schema +-- !query 9 schema struct --- !query 4 output +-- !query 9 output 0123 1234 2345 @@ -125,7 +214,7 @@ struct 9101112 --- !query 5 +-- !query 10 SELECT ((col1 || col2) || (col3 || col4)) col FROM ( SELECT @@ -135,9 +224,9 @@ FROM ( encode(string(id + 3), 'utf-8') col4 FROM range(10) ) --- !query 5 schema +-- !query 10 schema struct --- !query 5 output +-- !query 10 output 0123 1234 2345 From 9ddb231e4ed9723a4125dbd8dc0518b13421a197 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 22 Dec 2017 14:06:06 +0900 Subject: [PATCH 09/16] Fix --- .../sql/catalyst/analysis/TypeCoercion.scala | 20 +++++++++---------- .../results/string-functions.sql.out | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 3388c3fbe1223..8e82dac0eb631 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -675,16 +675,16 @@ object TypeCoercion { c.copy(children = newChildren) } - override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - // Skip nodes if unresolved or empty children - case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c - - case c @ Concat(children) if !children.map(_.dataType).forall(_ == BinaryType) => - typeCastToString(c) - - case c @ Concat(children) if conf.concatBinaryAsString => - typeCastToString(c) + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p => + p transformExpressionsUp { + // Skip nodes if unresolved or empty children + case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c + + case c @ Concat(children) if !children.map(_.dataType).forall(_ == BinaryType) => + typeCastToString(c) + case c @ Concat(children) if conf.concatBinaryAsString => + typeCastToString(c) + } } } diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 708e01e41a651..2d9b3d7d2ca33 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -42,7 +42,7 @@ struct == Analyzed Logical Plan == col: string -Project [concat(cast(concat(cast(concat(cast(col1#xL as string), cast(col2#xL as string)) as string), cast(col3#xL 
as string)) as string), cast(col4#xL as string)) AS col#x] +Project [concat(concat(concat(cast(col1#xL as string), cast(col2#xL as string)), cast(col3#xL as string)), cast(col4#xL as string)) AS col#x] +- SubqueryAlias __auto_generated_subquery_name +- Project [id#xL AS col1#xL, id#xL AS col2#xL, id#xL AS col3#xL, id#xL AS col4#xL] +- Range (0, 10, step=1, splits=None) From 766e0e6a3851c170f993a3b58f54ff57f334afd6 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sat, 23 Dec 2017 09:52:10 +0900 Subject: [PATCH 10/16] Update docs --- R/pkg/R/functions.R | 3 ++- python/pyspark/sql/functions.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 237ef061e8071..a41162ccdc1b4 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -2088,7 +2088,8 @@ setMethod("countDistinct", }) #' @details -#' \code{concat}: Concatenates multiple input string columns together into a single string column. +#' \code{concat}: Concatenates multiple input columns together into a single column. +#' If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string. #' #' @rdname column_string_functions #' @aliases concat concat,Column-method diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ddd8df3b15bf6..ae6d00e22b8ef 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1374,7 +1374,8 @@ def hash(*cols): @ignore_unicode_prefix def concat(*cols): """ - Concatenates multiple input string columns together into a single string column. + Concatenates multiple input columns together into a single column. + If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string. >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) >>> df.select(concat(df.s, df.d).alias('s')).collect() From fbe266cfd6ac4b9fff7c6a7c82939b9e00a7ec8c Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sat, 23 Dec 2017 18:12:00 +0900 Subject: [PATCH 11/16] Fix optimizer issues --- .../sql/catalyst/optimizer/expressions.scala | 8 +++- .../sql-tests/inputs/string-functions.sql | 14 ++++++ .../results/string-functions.sql.out | 45 ++++++++++++++++++- 3 files changed, 65 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 85295aff19808..293fa16b84dda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -21,6 +21,7 @@ import scala.collection.immutable.HashSet import scala.collection.mutable.{ArrayBuffer, Stack} import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -645,11 +646,16 @@ object CombineConcats extends Rule[LogicalPlan] { stack.pop() match { case Concat(children) => stack.pushAll(children.reverse) + case Cast(Concat(children), StringType, _) => + stack.pushAll(children.reverse) case child => flattened += child } } - Concat(flattened) + val newChildren = flattened.map { e => + ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) + } + Concat(newChildren) } def 
apply(plan: LogicalPlan): LogicalPlan = plan.transformExpressionsDown { diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index 40d0c064f5c44..0439b2a142dc7 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -24,3 +24,17 @@ select left("abcd", 2), left("abcd", 5), left("abcd", '2'), left("abcd", null); select left(null, -2), left("abcd", -2), left("abcd", 0), left("abcd", 'a'); select right("abcd", 2), right("abcd", 5), right("abcd", '2'), right("abcd", null); select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a'); + +-- turn on concatBinaryAsString +set spark.sql.function.concatBinaryAsString=false; + +-- Check if catalyst combine nested `Concat`s if concatBinaryAsString=false +EXPLAIN EXTENDED SELECT ((col1 || col2) || (col3 || col4)) col +FROM ( + SELECT + string(id) col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +); diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 2d9b3d7d2ca33..8d84b4ab4d253 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 12 +-- Number of queries: 14 -- !query 0 @@ -118,3 +118,46 @@ select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a') struct -- !query 11 output NULL NULL + + +-- !query 12 +set spark.sql.function.concatBinaryAsString=false +-- !query 12 schema +struct +-- !query 12 output +spark.sql.function.concatBinaryAsString false + + +-- !query 13 +EXPLAIN EXTENDED SELECT ((col1 || col2) || (col3 || col4)) col +FROM ( + SELECT + string(id) col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +) +-- !query 13 schema +struct +-- !query 13 output +== Parsed Logical Plan == +'Project [concat(concat('col1, 'col2), concat('col3, 'col4)) AS col#x] ++- 'SubqueryAlias __auto_generated_subquery_name + +- 'Project ['string('id) AS col1#x, 'string(('id + 1)) AS col2#x, 'encode('string(('id + 2)), utf-8) AS col3#x, 'encode('string(('id + 3)), utf-8) AS col4#x] + +- 'UnresolvedTableValuedFunction range, [10] + +== Analyzed Logical Plan == +col: string +Project [concat(concat(col1#x, col2#x), cast(concat(col3#x, col4#x) as string)) AS col#x] ++- SubqueryAlias __auto_generated_subquery_name + +- Project [cast(id#xL as string) AS col1#x, cast((id#xL + cast(1 as bigint)) as string) AS col2#x, encode(cast((id#xL + cast(2 as bigint)) as string), utf-8) AS col3#x, encode(cast((id#xL + cast(3 as bigint)) as string), utf-8) AS col4#x] + +- Range (0, 10, step=1, splits=None) + +== Optimized Logical Plan == +Project [concat(cast(id#xL as string), cast((id#xL + 1) as string), cast(encode(cast((id#xL + 2) as string), utf-8) as string), cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x] ++- Range (0, 10, step=1, splits=None) + +== Physical Plan == +*Project [concat(cast(id#xL as string), cast((id#xL + 1) as string), cast(encode(cast((id#xL + 2) as string), utf-8) as string), cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x] ++- *Range (0, 10, step=1, 
splits=2) From 179c6fdf261d3392d4d3477a68f7fde60d190435 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 28 Dec 2017 11:36:42 +0900 Subject: [PATCH 12/16] Fix --- .../sql/catalyst/analysis/TypeCoercion.scala | 5 ++--- .../sql/catalyst/optimizer/expressions.scala | 13 +++++++------ .../optimizer/CombineConcatsSuite.scala | 14 ++++++++++++-- .../sql-tests/inputs/string-functions.sql | 2 +- .../results/string-functions.sql.out | 19 +------------------ 5 files changed, 23 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 8e82dac0eb631..dab3b05a65f7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -680,9 +680,8 @@ object TypeCoercion { // Skip nodes if unresolved or empty children case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c - case c @ Concat(children) if !children.map(_.dataType).forall(_ == BinaryType) => - typeCastToString(c) - case c @ Concat(children) if conf.concatBinaryAsString => + case c @ Concat(children) if conf.concatBinaryAsString || + !children.map(_.dataType).forall(_ == BinaryType) => typeCastToString(c) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 293fa16b84dda..64fa3cf3726d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -646,16 +646,17 @@ object CombineConcats extends Rule[LogicalPlan] { stack.pop() match { case Concat(children) => stack.pushAll(children.reverse) - case Cast(Concat(children), StringType, _) => - stack.pushAll(children.reverse) + // If `spark.sql.function.concatBinaryAsString` is false, nested `Concat` exprs possibly + // have `Concat`s with binary output. Since `TypeCoercion` casts them into strings, + // we need to handle the case to combine all nested `Concat`s. 
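+        // For example, (b1 || b2) || s1 with binary b1/b2 is analyzed as
+        //   concat(cast(concat(b1, b2) as string), s1),
+        // so the cast is pushed down onto each child of the inner concat
+        // before flattening.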
+ case c @ Cast(Concat(children), StringType, _) => + val newChildren = children.map { e => c.copy(child = e) } + stack.pushAll(newChildren.reverse) case child => flattened += child } } - val newChildren = flattened.map { e => - ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) - } - Concat(newChildren) + Concat(flattened) } def apply(plan: LogicalPlan): LogicalPlan = plan.transformExpressionsDown { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala index 412e199dfaae3..441c15340a778 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.types.StringType class CombineConcatsSuite extends PlanTest { @@ -37,8 +36,10 @@ class CombineConcatsSuite extends PlanTest { comparePlans(actual, correctAnswer) } + def str(s: String): Literal = Literal(s) + def binary(s: String): Literal = Literal(s.getBytes) + test("combine nested Concat exprs") { - def str(s: String): Literal = Literal(s, StringType) assertEquivalent( Concat( Concat(str("a") :: str("b") :: Nil) :: @@ -72,4 +73,13 @@ class CombineConcatsSuite extends PlanTest { Nil), Concat(str("a") :: str("b") :: str("c") :: str("d") :: Nil)) } + + test("combine string and binary exprs") { + assertEquivalent( + Concat( + Concat(str("a") :: str("b") :: Nil) :: + Concat(binary("c") :: binary("d") :: Nil) :: + Nil), + Concat(str("a") :: str("b") :: binary("c") :: binary("d") :: Nil)) + } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index 0439b2a142dc7..3ed3db8c85134 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -29,7 +29,7 @@ select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a'); set spark.sql.function.concatBinaryAsString=false; -- Check if catalyst combine nested `Concat`s if concatBinaryAsString=false -EXPLAIN EXTENDED SELECT ((col1 || col2) || (col3 || col4)) col +EXPLAIN SELECT ((col1 || col2) || (col3 || col4)) col FROM ( SELECT string(id) col1, diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 8d84b4ab4d253..3f182c5c50c39 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -129,7 +129,7 @@ spark.sql.function.concatBinaryAsString false -- !query 13 -EXPLAIN EXTENDED SELECT ((col1 || col2) || (col3 || col4)) col +EXPLAIN SELECT ((col1 || col2) || (col3 || col4)) col FROM ( SELECT string(id) col1, @@ -141,23 +141,6 @@ FROM ( -- !query 13 schema struct -- !query 13 output -== Parsed Logical Plan == -'Project [concat(concat('col1, 'col2), concat('col3, 'col4)) AS col#x] -+- 'SubqueryAlias __auto_generated_subquery_name - +- 'Project ['string('id) AS col1#x, 'string(('id + 1)) AS col2#x, 'encode('string(('id + 2)), utf-8) AS col3#x, 
'encode('string(('id + 3)), utf-8) AS col4#x] - +- 'UnresolvedTableValuedFunction range, [10] - -== Analyzed Logical Plan == -col: string -Project [concat(concat(col1#x, col2#x), cast(concat(col3#x, col4#x) as string)) AS col#x] -+- SubqueryAlias __auto_generated_subquery_name - +- Project [cast(id#xL as string) AS col1#x, cast((id#xL + cast(1 as bigint)) as string) AS col2#x, encode(cast((id#xL + cast(2 as bigint)) as string), utf-8) AS col3#x, encode(cast((id#xL + cast(3 as bigint)) as string), utf-8) AS col4#x] - +- Range (0, 10, step=1, splits=None) - -== Optimized Logical Plan == -Project [concat(cast(id#xL as string), cast((id#xL + 1) as string), cast(encode(cast((id#xL + 2) as string), utf-8) as string), cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x] -+- Range (0, 10, step=1, splits=None) - == Physical Plan == *Project [concat(cast(id#xL as string), cast((id#xL + 1) as string), cast(encode(cast((id#xL + 2) as string), utf-8) as string), cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x] +- *Range (0, 10, step=1, splits=2) From 1c94418c3aa5fe6610914a88b3b2ef3919b56ac4 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 28 Dec 2017 12:19:41 +0900 Subject: [PATCH 13/16] Fix --- .../spark/sql/catalyst/analysis/TypeCoercion.scala | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index dab3b05a65f7d..a25c232462f47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -668,13 +668,6 @@ object TypeCoercion { */ case class ConcatCoercion(conf: SQLConf) extends TypeCoercionRule { - private def typeCastToString(c: Concat): Concat = { - val newChildren = c.children.map { e => - ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) - } - c.copy(children = newChildren) - } - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p => p transformExpressionsUp { // Skip nodes if unresolved or empty children @@ -682,7 +675,10 @@ object TypeCoercion { case c @ Concat(children) if conf.concatBinaryAsString || !children.map(_.dataType).forall(_ == BinaryType) => - typeCastToString(c) + val newChildren = c.children.map { e => + ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) + } + c.copy(children = newChildren) } } } From 1e13b702185246dc509847fc152a68cfe8f2c954 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 28 Dec 2017 22:19:45 +0900 Subject: [PATCH 14/16] Fix --- .../org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala | 1 - .../src/test/resources/sql-tests/inputs/string-functions.sql | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index a25c232462f47..f1eccadd34e58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -672,7 +672,6 @@ object TypeCoercion { p transformExpressionsUp { // Skip nodes if unresolved or empty children case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c - case c @ 
Concat(children) if conf.concatBinaryAsString || !children.map(_.dataType).forall(_ == BinaryType) => val newChildren = c.children.map { e => diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index 3ed3db8c85134..e896ece437458 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -25,7 +25,7 @@ select left(null, -2), left("abcd", -2), left("abcd", 0), left("abcd", 'a'); select right("abcd", 2), right("abcd", 5), right("abcd", '2'), right("abcd", null); select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a'); --- turn on concatBinaryAsString +-- turn off concatBinaryAsString set spark.sql.function.concatBinaryAsString=false; -- Check if catalyst combine nested `Concat`s if concatBinaryAsString=false From 57a9d1e9da21d56873c97eac08797499199a0c7b Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 29 Dec 2017 13:15:24 +0900 Subject: [PATCH 15/16] Fix --- .../sql/catalyst/optimizer/expressions.scala | 6 +++++- .../sql-tests/inputs/string-functions.sql | 9 +++++++++ .../results/string-functions.sql.out | 19 ++++++++++++++++++- 3 files changed, 32 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 64fa3cf3726d1..af646ed3cc32e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -660,7 +660,11 @@ object CombineConcats extends Rule[LogicalPlan] { } def apply(plan: LogicalPlan): LogicalPlan = plan.transformExpressionsDown { - case concat: Concat if concat.children.exists(_.isInstanceOf[Concat]) => + case concat: Concat if concat.children.exists { + case c: Concat => true + case c @ Cast(Concat(children), StringType, _) => true + case _ => false + } => flattenConcats(concat) } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index e896ece437458..4113734e1707e 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -38,3 +38,12 @@ FROM ( encode(string(id + 3), 'utf-8') col4 FROM range(10) ); + +EXPLAIN SELECT (col1 || (col3 || col4)) col +FROM ( + SELECT + string(id) col1, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +) diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 3f182c5c50c39..d5f8705a35ed6 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 14 +-- Number of queries: 15 -- !query 0 @@ -144,3 +144,20 @@ struct == Physical Plan == *Project [concat(cast(id#xL as string), cast((id#xL + 1) as string), cast(encode(cast((id#xL + 2) as string), utf-8) as string), cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x] +- *Range (0, 10, step=1, splits=2) + + +-- !query 14 +EXPLAIN SELECT (col1 || (col3 || 
col4)) col +FROM ( + SELECT + string(id) col1, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +) +-- !query 14 schema +struct +-- !query 14 output +== Physical Plan == +*Project [concat(cast(id#xL as string), cast(encode(cast((id#xL + 2) as string), utf-8) as string), cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x] ++- *Range (0, 10, step=1, splits=2) From b9febbdd928c4fa2cba29eaaef85ffcc173c1b44 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 29 Dec 2017 23:15:27 +0900 Subject: [PATCH 16/16] Fix --- .../spark/sql/catalyst/optimizer/expressions.scala | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index af646ed3cc32e..7d830bbb7dc32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -659,12 +659,14 @@ object CombineConcats extends Rule[LogicalPlan] { Concat(flattened) } + private def hasNestedConcats(concat: Concat): Boolean = concat.children.exists { + case c: Concat => true + case c @ Cast(Concat(children), StringType, _) => true + case _ => false + } + def apply(plan: LogicalPlan): LogicalPlan = plan.transformExpressionsDown { - case concat: Concat if concat.children.exists { - case c: Concat => true - case c @ Cast(Concat(children), StringType, _) => true - case _ => false - } => + case concat: Concat if hasNestedConcats(concat) => flattenConcats(concat) } }
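For completeness, the `Cast`-aware flattening that `hasNestedConcats` now detects could be pinned down with one more test in the style of CombineConcatsSuite. This is a hypothetical addition, not part of the series, and it assumes `Cast` and `org.apache.spark.sql.types.StringType` are imported back into the suite:

    test("combine concat exprs nested under a string cast") {
      assertEquivalent(
        Concat(
          Cast(Concat(binary("a") :: binary("b") :: Nil), StringType) ::
          str("c") ::
          Nil),
        // flattenConcats pushes the string cast down onto each child of the
        // inner concat, so the combined concat casts each binary literal.
        Concat(
          Cast(binary("a"), StringType) ::
          Cast(binary("b"), StringType) ::
          str("c") ::
          Nil))
    }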