2 changes: 2 additions & 0 deletions docs/sql-programming-guide.md
@@ -1783,6 +1783,8 @@ options.

- Since Spark 2.3, when all inputs are binary, `functions.concat()` returns its output as binary; otherwise it returns a string. Before Spark 2.3, it always returned a string regardless of the input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`.

- Since Spark 2.3, when all inputs are binary, SQL `elt()` returns its output as binary; otherwise it returns a string. Before Spark 2.3, it always returned a string regardless of the input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true` (see the sketch below).
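A minimal sketch of the new behavior, assuming a spark-shell session where `spark` is the SparkSession and using `encode` to build binary inputs:

```scala
// Spark 2.3 default: with all-binary inputs the result column is BinaryType.
// (concat behaves the same way under spark.sql.function.concatBinaryAsString.)
spark.sql("SELECT elt(1, encode('ab', 'utf-8'), encode('cd', 'utf-8')) AS c").schema("c").dataType

// Restore the pre-2.3 string result for elt in the current session:
spark.conf.set("spark.sql.function.eltOutputAsString", "true")
spark.sql("SELECT elt(1, encode('ab', 'utf-8'), encode('cd', 'utf-8')) AS c").schema("c").dataType
// Now back to StringType.
```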

## Upgrading From Spark SQL 2.1 to 2.2

- Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access.
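If schema inference is not needed, the setting can be changed for the current session; a minimal sketch, assuming a SparkSession named `spark`:

```scala
// Skip the (potentially expensive) schema inference when mixed-case
// column names in the underlying files are not a concern.
spark.conf.set("spark.sql.hive.caseSensitiveInferenceMode", "NEVER_INFER")
```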
@@ -54,6 +54,7 @@ object TypeCoercion {
BooleanEquality ::
FunctionArgumentConversion ::
ConcatCoercion(conf) ::
EltCoercion(conf) ::
CaseWhenCoercion ::
IfCoercion ::
StackCoercion ::
@@ -684,6 +685,34 @@ object TypeCoercion {
}
}

/**
* Coerces the types of [[Elt]] children to expected ones.
*
* If `spark.sql.function.eltOutputAsString` is false and all children types are binary,
* the expected types are binary. Otherwise, the expected ones are strings.
*/
case class EltCoercion(conf: SQLConf) extends TypeCoercionRule {

override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p =>
p transformExpressionsUp {
// Skip nodes if unresolved or not enough children
case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c
case c @ Elt(children) =>
val index = children.head
val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index)
val newInputs = if (conf.eltOutputAsString ||
!children.tail.map(_.dataType).forall(_ == BinaryType)) {
children.tail.map { e =>
ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e)
}
} else {
children.tail
}
c.copy(children = newIndex +: newInputs)
}
}
}
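An illustrative sketch of the rule's effect from the user side (not part of the patch; it assumes a spark-shell session where `spark` is the SparkSession):

```scala
import org.apache.spark.sql.types.{BinaryType, StringType}

// Mixed inputs: EltCoercion casts the non-string inputs to string, so the
// resolved expression ends up with a StringType result.
val mixed = spark.sql("SELECT elt(2, 10, 'abc') AS c")
assert(mixed.schema("c").dataType == StringType)

// All-binary inputs with spark.sql.function.eltOutputAsString at its default
// (false): the children are left as-is and the result is BinaryType.
val allBinary = spark.sql("SELECT elt(1, encode('a', 'utf-8'), encode('b', 'utf-8')) AS c")
assert(allBinary.schema("c").dataType == BinaryType)
```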

/**
* Turns Add/Subtract of DateType/TimestampType/StringType and CalendarIntervalType
* to TimeAdd/TimeSub
@@ -271,33 +271,45 @@ case class ConcatWs(children: Seq[Expression])
}
}

/**
* An expression that returns the `n`-th input in given inputs.
* If all inputs are binary, `elt` returns an output as binary. Otherwise, it returns as string.
* If any input is null, `elt` returns null.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(n, str1, str2, ...) - Returns the `n`-th string, e.g., returns `str2` when `n` is 2.",
usage = "_FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2.",
examples = """
Examples:
> SELECT _FUNC_(1, 'scala', 'java');
scala
""")
// scalastyle:on line.size.limit
case class Elt(children: Seq[Expression])
extends Expression with ImplicitCastInputTypes {
case class Elt(children: Seq[Expression]) extends Expression {

private lazy val indexExpr = children.head
private lazy val stringExprs = children.tail.toArray
private lazy val inputExprs = children.tail.toArray

/** This expression is always nullable because it returns null if index is out of range. */
override def nullable: Boolean = true

override def dataType: DataType = StringType

override def inputTypes: Seq[DataType] = IntegerType +: Seq.fill(children.size - 1)(StringType)
override def dataType: DataType = inputExprs.map(_.dataType).headOption.getOrElse(StringType)
Reviewer:
Should we return a null of BinaryType when eltOutputAsString is false and there's only one parameter for Elt?

Member Author:
Sorry, but I'm missing your point. What is the type of that single parameter?

Reviewer:
I meant the expression elt(1) with eltOutputAsString set to false, since the index is out of range. Would it be better to make the result a null of BinaryType? Now I think your solution makes more sense.

Member:
We raise an analysis exception when there is only one input argument.


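To make the two cases discussed in this thread concrete, a minimal sketch, assuming a spark-shell session where `spark` is the SparkSession:

```scala
// Only the index argument: analysis fails; checkInputDataTypes reports
// "elt function requires at least two arguments".
// spark.sql("SELECT elt(1)")   // throws AnalysisException

// An out-of-range index is not an error: eval simply returns null.
assert(spark.sql("SELECT elt(4, 'a', 'b')").head().isNullAt(0))
```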
override def checkInputDataTypes(): TypeCheckResult = {
if (children.size < 2) {
TypeCheckResult.TypeCheckFailure("elt function requires at least two arguments")
} else {
super[ImplicitCastInputTypes].checkInputDataTypes()
val (indexType, inputTypes) = (indexExpr.dataType, inputExprs.map(_.dataType))
if (indexType != IntegerType) {
return TypeCheckResult.TypeCheckFailure(s"first input to function $prettyName should " +
s"have IntegerType, but it's $indexType")
}
if (inputTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) {
return TypeCheckResult.TypeCheckFailure(
s"input to function $prettyName should have StringType or BinaryType, but it's " +
inputTypes.map(_.simpleString).mkString("[", ", ", "]"))
}
TypeUtils.checkForSameTypeInputExpr(inputTypes, s"function $prettyName")
}
}
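A hedged illustration of what the new check accepts and rejects; the second query assumes an array cannot be implicitly cast to string, which is what makes the check fire:

```scala
// Strings and binaries can be mixed because the binary input is cast to
// string during coercion, so this analyzes fine with a string result.
spark.sql("SELECT elt(2, 'a', encode('b', 'utf-8')) AS c").schema

// A type that can be coerced to neither string nor binary is rejected:
// spark.sql("SELECT elt(1, 'a', array(1, 2))")
//   // AnalysisException: "input to function elt should have StringType or BinaryType ..."
```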

@@ -307,35 +319,35 @@ case class Elt(children: Seq[Expression])
null
} else {
val index = indexObj.asInstanceOf[Int]
if (index <= 0 || index > stringExprs.length) {
if (index <= 0 || index > inputExprs.length) {
null
} else {
stringExprs(index - 1).eval(input)
inputExprs(index - 1).eval(input)
}
}
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val index = indexExpr.genCode(ctx)
val strings = stringExprs.map(_.genCode(ctx))
val inputs = inputExprs.map(_.genCode(ctx))
val indexVal = ctx.freshName("index")
val indexMatched = ctx.freshName("eltIndexMatched")

val stringVal = ctx.addMutableState(ctx.javaType(dataType), "stringVal")
val inputVal = ctx.addMutableState(ctx.javaType(dataType), "inputVal")

val assignStringValue = strings.zipWithIndex.map { case (eval, index) =>
val assignInputValue = inputs.zipWithIndex.map { case (eval, index) =>
s"""
|if ($indexVal == ${index + 1}) {
| ${eval.code}
| $stringVal = ${eval.isNull} ? null : ${eval.value};
| $inputVal = ${eval.isNull} ? null : ${eval.value};
| $indexMatched = true;
| continue;
|}
""".stripMargin
}

val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = assignStringValue,
expressions = assignInputValue,
funcName = "eltFunc",
extraArguments = ("int", indexVal) :: Nil,
returnType = ctx.JAVA_BOOLEAN,
@@ -361,11 +373,11 @@ case class Elt(children: Seq[Expression])
|${index.code}
|final int $indexVal = ${index.value};
|${ctx.JAVA_BOOLEAN} $indexMatched = false;
|$stringVal = null;
|$inputVal = null;
|do {
| $codes
|} while (false);
|final UTF8String ${ev.value} = $stringVal;
|final ${ctx.javaType(dataType)} ${ev.value} = $inputVal;
|final boolean ${ev.isNull} = ${ev.value} == null;
""".stripMargin)
}
@@ -1052,6 +1052,12 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val ELT_OUTPUT_AS_STRING = buildConf("spark.sql.function.eltOutputAsString")
.doc("When this option is set to false and all inputs are binary, `elt` returns " +
"an output as binary. Otherwise, it returns as a string. ")
.booleanConf
.createWithDefault(false)
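A small usage sketch, assuming a SparkSession named `spark`; it mirrors the `SET` statement used in the new elt.sql test:

```scala
// Equivalent to the SQL statement `SET spark.sql.function.eltOutputAsString=true`:
// force string output for elt in the current session.
spark.conf.set("spark.sql.function.eltOutputAsString", "true")
```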

val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE =
buildConf("spark.sql.streaming.continuous.executorQueueSize")
.internal()
@@ -1412,6 +1418,8 @@ class SQLConf extends Serializable with Logging {

def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING)

def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING)

def partitionOverwriteMode: PartitionOverwriteMode.Value =
PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE))

@@ -923,6 +923,60 @@ class TypeCoercionSuite extends AnalysisTest {
}
}

test("type coercion for Elt") {
val rule = TypeCoercion.EltCoercion(conf)

ruleTest(rule,
Elt(Seq(Literal(1), Literal("ab"), Literal("cde"))),
Elt(Seq(Literal(1), Literal("ab"), Literal("cde"))))
ruleTest(rule,
Elt(Seq(Literal(1.toShort), Literal("ab"), Literal("cde"))),
Elt(Seq(Cast(Literal(1.toShort), IntegerType), Literal("ab"), Literal("cde"))))
ruleTest(rule,
Elt(Seq(Literal(2), Literal(null), Literal("abc"))),
Elt(Seq(Literal(2), Cast(Literal(null), StringType), Literal("abc"))))
ruleTest(rule,
Elt(Seq(Literal(2), Literal(1), Literal("234"))),
Elt(Seq(Literal(2), Cast(Literal(1), StringType), Literal("234"))))
ruleTest(rule,
Elt(Seq(Literal(3), Literal(1L), Literal(2.toByte), Literal(0.1))),
Elt(Seq(Literal(3), Cast(Literal(1L), StringType), Cast(Literal(2.toByte), StringType),
Cast(Literal(0.1), StringType))))
ruleTest(rule,
Elt(Seq(Literal(2), Literal(true), Literal(0.1f), Literal(3.toShort))),
Elt(Seq(Literal(2), Cast(Literal(true), StringType), Cast(Literal(0.1f), StringType),
Cast(Literal(3.toShort), StringType))))
ruleTest(rule,
Elt(Seq(Literal(1), Literal(1L), Literal(0.1))),
Elt(Seq(Literal(1), Cast(Literal(1L), StringType), Cast(Literal(0.1), StringType))))
ruleTest(rule,
Elt(Seq(Literal(1), Literal(Decimal(10)))),
Elt(Seq(Literal(1), Cast(Literal(Decimal(10)), StringType))))
ruleTest(rule,
Elt(Seq(Literal(1), Literal(BigDecimal.valueOf(10)))),
Elt(Seq(Literal(1), Cast(Literal(BigDecimal.valueOf(10)), StringType))))
ruleTest(rule,
Elt(Seq(Literal(1), Literal(java.math.BigDecimal.valueOf(10)))),
Elt(Seq(Literal(1), Cast(Literal(java.math.BigDecimal.valueOf(10)), StringType))))
ruleTest(rule,
Elt(Seq(Literal(2), Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))),
Elt(Seq(Literal(2), Cast(Literal(new java.sql.Date(0)), StringType),
Cast(Literal(new Timestamp(0)), StringType))))

withSQLConf("spark.sql.function.eltOutputAsString" -> "true") {
ruleTest(rule,
Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))),
Elt(Seq(Literal(1), Cast(Literal("123".getBytes), StringType),
Cast(Literal("456".getBytes), StringType))))
}

withSQLConf("spark.sql.function.eltOutputAsString" -> "false") {
ruleTest(rule,
Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))),
Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))))
}
}

test("BooleanEquality type cast") {
val be = TypeCoercion.BooleanEquality
// Use something more than a literal to avoid triggering the simplification rules.
@@ -0,0 +1,44 @@
-- Mixed inputs (output type is string)
SELECT elt(2, col1, col2, col3, col4, col5) col
FROM (
SELECT
'prefix_' col1,
id col2,
string(id + 1) col3,
encode(string(id + 2), 'utf-8') col4,
CAST(id AS DOUBLE) col5
FROM range(10)
);

SELECT elt(3, col1, col2, col3, col4) col
FROM (
SELECT
string(id) col1,
string(id + 1) col2,
encode(string(id + 2), 'utf-8') col3,
encode(string(id + 3), 'utf-8') col4
FROM range(10)
);

-- turn on eltOutputAsString
set spark.sql.function.eltOutputAsString=true;

SELECT elt(1, col1, col2) col
FROM (
SELECT
encode(string(id), 'utf-8') col1,
encode(string(id + 1), 'utf-8') col2
FROM range(10)
);

-- turn off eltOutputAsString
set spark.sql.function.eltOutputAsString=false;

-- Elt binary inputs (output type is binary)
SELECT elt(2, col1, col2) col
FROM (
SELECT
encode(string(id), 'utf-8') col1,
encode(string(id + 1), 'utf-8') col2
FROM range(10)
);