Commit e8af7e8

maropu authored and gatorsmile committed
[SPARK-22937][SQL] SQL elt output binary for binary inputs
## What changes were proposed in this pull request?

This PR modifies `elt` to output binary for binary inputs. In the current master, `elt` always outputs data as a string. But in some databases (e.g., MySQL), if all inputs are binary, `elt` also outputs binary, so the current behavior might come as a small surprise. This PR is related to #19977.

## How was this patch tested?

Added tests in `SQLQueryTestSuite` and `TypeCoercionSuite`.

Author: Takeshi Yamamuro <[email protected]>

Closes #20135 from maropu/SPARK-22937.
1 parent ea95683 commit e8af7e8
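As an illustration of the new semantics, here is a minimal sketch against a Spark 2.3 session; the `spark` value and the column alias are assumptions for the example, not part of this commit:

```scala
// Both inputs are binary (encode() returns binary), so elt now yields BinaryType;
// before this change the result type was always StringType.
val binaryElt = spark.sql(
  "SELECT elt(1, encode('scala', 'utf-8'), encode('java', 'utf-8')) AS col")
println(binaryElt.schema("col").dataType)  // BinaryType

// With any non-binary input, elt still returns a string.
val stringElt = spark.sql("SELECT elt(1, 'scala', encode('java', 'utf-8')) AS col")
println(stringElt.schema("col").dataType)  // StringType
```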

File tree

7 files changed: +281 -17 lines changed

docs/sql-programming-guide.md

Lines changed: 2 additions & 0 deletions
@@ -1783,6 +1783,8 @@ options.
 
   - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`.
 
+  - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`.
+
 ## Upgrading From Spark SQL 2.1 to 2.2
 
   - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access.
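For users who need the pre-2.3 behavior, the flag named above can also be flipped on a live session; a minimal sketch, assuming an active `SparkSession` named `spark`:

```scala
// Restore the pre-2.3 behavior: elt returns a string even for all-binary inputs.
spark.conf.set("spark.sql.function.eltOutputAsString", "true")

val df = spark.sql("SELECT elt(1, encode('a', 'utf-8'), encode('b', 'utf-8')) AS col")
println(df.schema("col").dataType)  // StringType, since the flag forces string output
```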

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 29 additions & 0 deletions
@@ -54,6 +54,7 @@ object TypeCoercion {
       BooleanEquality ::
       FunctionArgumentConversion ::
       ConcatCoercion(conf) ::
+      EltCoercion(conf) ::
       CaseWhenCoercion ::
       IfCoercion ::
       StackCoercion ::
@@ -684,6 +685,34 @@
     }
   }
 
+  /**
+   * Coerces the types of [[Elt]] children to expected ones.
+   *
+   * If `spark.sql.function.eltOutputAsString` is false and all children types are binary,
+   * the expected types are binary. Otherwise, the expected ones are strings.
+   */
+  case class EltCoercion(conf: SQLConf) extends TypeCoercionRule {
+
+    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p =>
+      p transformExpressionsUp {
+        // Skip nodes if unresolved or not enough children
+        case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c
+        case c @ Elt(children) =>
+          val index = children.head
+          val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index)
+          val newInputs = if (conf.eltOutputAsString ||
+              !children.tail.map(_.dataType).forall(_ == BinaryType)) {
+            children.tail.map { e =>
+              ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e)
+            }
+          } else {
+            children.tail
+          }
+          c.copy(children = newIndex +: newInputs)
+      }
+    }
+  }
+
   /**
    * Turns Add/Subtract of DateType/TimestampType/StringType and CalendarIntervalType
    * to TimeAdd/TimeSub
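In short, the rule casts the non-index children to string unless the flag is off and every one of them is already binary. The helper below is an illustrative restatement of that decision, not part of the Catalyst API:

```scala
import org.apache.spark.sql.types.{BinaryType, DataType, StringType}

// Hypothetical helper mirroring EltCoercion's choice of target type
// for the non-index children of Elt.
def eltTargetType(eltOutputAsString: Boolean, inputTypes: Seq[DataType]): DataType =
  if (!eltOutputAsString && inputTypes.nonEmpty && inputTypes.forall(_ == BinaryType)) {
    BinaryType // children stay binary, so Elt.dataType resolves to binary
  } else {
    StringType // everything is cast to string, matching the pre-2.3 output
  }
```

For example, `eltTargetType(false, Seq(BinaryType, BinaryType))` is `BinaryType`, while flipping the flag or mixing in a `StringType` input yields `StringType`.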

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala

Lines changed: 29 additions & 17 deletions
@@ -271,33 +271,45 @@ case class ConcatWs(children: Seq[Expression])
   }
 }
 
+/**
+ * An expression that returns the `n`-th input in given inputs.
+ * If all inputs are binary, `elt` returns an output as binary. Otherwise, it returns as string.
+ * If any input is null, `elt` returns null.
+ */
 // scalastyle:off line.size.limit
 @ExpressionDescription(
-  usage = "_FUNC_(n, str1, str2, ...) - Returns the `n`-th string, e.g., returns `str2` when `n` is 2.",
+  usage = "_FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2.",
   examples = """
     Examples:
       > SELECT _FUNC_(1, 'scala', 'java');
       scala
   """)
 // scalastyle:on line.size.limit
-case class Elt(children: Seq[Expression])
-  extends Expression with ImplicitCastInputTypes {
+case class Elt(children: Seq[Expression]) extends Expression {
 
   private lazy val indexExpr = children.head
-  private lazy val stringExprs = children.tail.toArray
+  private lazy val inputExprs = children.tail.toArray
 
   /** This expression is always nullable because it returns null if index is out of range. */
   override def nullable: Boolean = true
 
-  override def dataType: DataType = StringType
-
-  override def inputTypes: Seq[DataType] = IntegerType +: Seq.fill(children.size - 1)(StringType)
+  override def dataType: DataType = inputExprs.map(_.dataType).headOption.getOrElse(StringType)
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (children.size < 2) {
       TypeCheckResult.TypeCheckFailure("elt function requires at least two arguments")
     } else {
-      super[ImplicitCastInputTypes].checkInputDataTypes()
+      val (indexType, inputTypes) = (indexExpr.dataType, inputExprs.map(_.dataType))
+      if (indexType != IntegerType) {
+        return TypeCheckResult.TypeCheckFailure(s"first input to function $prettyName should " +
+          s"have IntegerType, but it's $indexType")
+      }
+      if (inputTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) {
+        return TypeCheckResult.TypeCheckFailure(
+          s"input to function $prettyName should have StringType or BinaryType, but it's " +
+            inputTypes.map(_.simpleString).mkString("[", ", ", "]"))
+      }
+      TypeUtils.checkForSameTypeInputExpr(inputTypes, s"function $prettyName")
     }
   }
 
@@ -307,35 +319,35 @@ case class Elt(children: Seq[Expression])
       null
     } else {
       val index = indexObj.asInstanceOf[Int]
-      if (index <= 0 || index > stringExprs.length) {
+      if (index <= 0 || index > inputExprs.length) {
         null
       } else {
-        stringExprs(index - 1).eval(input)
+        inputExprs(index - 1).eval(input)
       }
     }
   }
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val index = indexExpr.genCode(ctx)
-    val strings = stringExprs.map(_.genCode(ctx))
+    val inputs = inputExprs.map(_.genCode(ctx))
     val indexVal = ctx.freshName("index")
     val indexMatched = ctx.freshName("eltIndexMatched")
 
-    val stringVal = ctx.addMutableState(ctx.javaType(dataType), "stringVal")
+    val inputVal = ctx.addMutableState(ctx.javaType(dataType), "inputVal")
 
-    val assignStringValue = strings.zipWithIndex.map { case (eval, index) =>
+    val assignInputValue = inputs.zipWithIndex.map { case (eval, index) =>
       s"""
          |if ($indexVal == ${index + 1}) {
          |  ${eval.code}
-         |  $stringVal = ${eval.isNull} ? null : ${eval.value};
+         |  $inputVal = ${eval.isNull} ? null : ${eval.value};
          |  $indexMatched = true;
         |  continue;
         |}
      """.stripMargin
     }
 
     val codes = ctx.splitExpressionsWithCurrentInputs(
-      expressions = assignStringValue,
+      expressions = assignInputValue,
       funcName = "eltFunc",
       extraArguments = ("int", indexVal) :: Nil,
       returnType = ctx.JAVA_BOOLEAN,
@@ -361,11 +373,11 @@ case class Elt(children: Seq[Expression])
        |${index.code}
        |final int $indexVal = ${index.value};
        |${ctx.JAVA_BOOLEAN} $indexMatched = false;
-       |$stringVal = null;
+       |$inputVal = null;
        |do {
        |  $codes
        |} while (false);
-       |final UTF8String ${ev.value} = $stringVal;
+       |final ${ctx.javaType(dataType)} ${ev.value} = $inputVal;
       |final boolean ${ev.isNull} = ${ev.value} == null;
     """.stripMargin)
   }
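Note that `nullable` stays hard-wired to `true` because `eval` returns null whenever the index falls outside `1..inputExprs.length`, regardless of the children's nullability. A quick check, assuming an active `spark` session:

```scala
// Index 0 and index 3 are both out of range for two inputs, so both yield NULL.
val row = spark.sql("SELECT elt(0, 'a', 'b') AS under, elt(3, 'a', 'b') AS over").head()
assert(row.isNullAt(0) && row.isNullAt(1))
```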

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 8 additions & 0 deletions
@@ -1052,6 +1052,12 @@
     .booleanConf
     .createWithDefault(false)
 
+  val ELT_OUTPUT_AS_STRING = buildConf("spark.sql.function.eltOutputAsString")
+    .doc("When this option is set to false and all inputs are binary, `elt` returns " +
+      "an output as binary. Otherwise, it returns as a string. ")
+    .booleanConf
+    .createWithDefault(false)
+
   val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE =
     buildConf("spark.sql.streaming.continuous.executorQueueSize")
       .internal()
@@ -1412,6 +1418,8 @@
 
   def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING)
 
+  def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING)
+
   def partitionOverwriteMode: PartitionOverwriteMode.Value =
     PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE))
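The new entry mirrors `concatBinaryAsString` directly above it: internal code reads it through the typed `eltOutputAsString` getter, while users toggle it by key. A sketch of the user-facing path, assuming an active session:

```scala
// Defaults to false: all-binary elt inputs produce a binary result.
println(spark.conf.get("spark.sql.function.eltOutputAsString"))  // "false"

// Flip it at runtime; queries analyzed afterwards coerce binary elt inputs to string.
spark.sql("SET spark.sql.function.eltOutputAsString=true")
```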

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala

Lines changed: 54 additions & 0 deletions
@@ -923,6 +923,60 @@ class TypeCoercionSuite extends AnalysisTest {
     }
   }
 
+  test("type coercion for Elt") {
+    val rule = TypeCoercion.EltCoercion(conf)
+
+    ruleTest(rule,
+      Elt(Seq(Literal(1), Literal("ab"), Literal("cde"))),
+      Elt(Seq(Literal(1), Literal("ab"), Literal("cde"))))
+    ruleTest(rule,
+      Elt(Seq(Literal(1.toShort), Literal("ab"), Literal("cde"))),
+      Elt(Seq(Cast(Literal(1.toShort), IntegerType), Literal("ab"), Literal("cde"))))
+    ruleTest(rule,
+      Elt(Seq(Literal(2), Literal(null), Literal("abc"))),
+      Elt(Seq(Literal(2), Cast(Literal(null), StringType), Literal("abc"))))
+    ruleTest(rule,
+      Elt(Seq(Literal(2), Literal(1), Literal("234"))),
+      Elt(Seq(Literal(2), Cast(Literal(1), StringType), Literal("234"))))
+    ruleTest(rule,
+      Elt(Seq(Literal(3), Literal(1L), Literal(2.toByte), Literal(0.1))),
+      Elt(Seq(Literal(3), Cast(Literal(1L), StringType), Cast(Literal(2.toByte), StringType),
+        Cast(Literal(0.1), StringType))))
+    ruleTest(rule,
+      Elt(Seq(Literal(2), Literal(true), Literal(0.1f), Literal(3.toShort))),
+      Elt(Seq(Literal(2), Cast(Literal(true), StringType), Cast(Literal(0.1f), StringType),
+        Cast(Literal(3.toShort), StringType))))
+    ruleTest(rule,
+      Elt(Seq(Literal(1), Literal(1L), Literal(0.1))),
+      Elt(Seq(Literal(1), Cast(Literal(1L), StringType), Cast(Literal(0.1), StringType))))
+    ruleTest(rule,
+      Elt(Seq(Literal(1), Literal(Decimal(10)))),
+      Elt(Seq(Literal(1), Cast(Literal(Decimal(10)), StringType))))
+    ruleTest(rule,
+      Elt(Seq(Literal(1), Literal(BigDecimal.valueOf(10)))),
+      Elt(Seq(Literal(1), Cast(Literal(BigDecimal.valueOf(10)), StringType))))
+    ruleTest(rule,
+      Elt(Seq(Literal(1), Literal(java.math.BigDecimal.valueOf(10)))),
+      Elt(Seq(Literal(1), Cast(Literal(java.math.BigDecimal.valueOf(10)), StringType))))
+    ruleTest(rule,
+      Elt(Seq(Literal(2), Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))),
+      Elt(Seq(Literal(2), Cast(Literal(new java.sql.Date(0)), StringType),
+        Cast(Literal(new Timestamp(0)), StringType))))
+
+    withSQLConf("spark.sql.function.eltOutputAsString" -> "true") {
+      ruleTest(rule,
+        Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))),
+        Elt(Seq(Literal(1), Cast(Literal("123".getBytes), StringType),
+          Cast(Literal("456".getBytes), StringType))))
+    }
+
+    withSQLConf("spark.sql.function.eltOutputAsString" -> "false") {
+      ruleTest(rule,
+        Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))),
+        Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))))
+    }
+  }
+
   test("BooleanEquality type cast") {
     val be = TypeCoercion.BooleanEquality
     // Use something more than a literal to avoid triggering the simplification rules.
Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+-- Mixed inputs (output type is string)
+SELECT elt(2, col1, col2, col3, col4, col5) col
+FROM (
+  SELECT
+    'prefix_' col1,
+    id col2,
+    string(id + 1) col3,
+    encode(string(id + 2), 'utf-8') col4,
+    CAST(id AS DOUBLE) col5
+  FROM range(10)
+);
+
+SELECT elt(3, col1, col2, col3, col4) col
+FROM (
+  SELECT
+    string(id) col1,
+    string(id + 1) col2,
+    encode(string(id + 2), 'utf-8') col3,
+    encode(string(id + 3), 'utf-8') col4
+  FROM range(10)
+);
+
+-- turn on eltOutputAsString
+set spark.sql.function.eltOutputAsString=true;
+
+SELECT elt(1, col1, col2) col
+FROM (
+  SELECT
+    encode(string(id), 'utf-8') col1,
+    encode(string(id + 1), 'utf-8') col2
+  FROM range(10)
+);
+
+-- turn off eltOutputAsString
+set spark.sql.function.eltOutputAsString=false;
+
+-- Elt binary inputs (output type is binary)
+SELECT elt(2, col1, col2) col
+FROM (
+  SELECT
+    encode(string(id), 'utf-8') col1,
+    encode(string(id + 1), 'utf-8') col2
+  FROM range(10)
+);
