Skip to content

Commit 10e0b61

Browse files
jovanpavl-db and MaxGekk
authored and committed
[SPARK-49670][SQL] Enable trim collation for all passthrough expressions
### What changes were proposed in this pull request?
Enable the use of trim collation in all passthrough expressions. With this change, more expressions will support trim collation than will not, so in a follow-up the default for trim collation support will be set to true.

**NOTE: this looks like a lot of changes, but the only changes are: for each expression, set supportsTrimCollation=true and add tests.**

### Why are the changes needed?
So that all expressions can be used with trim collation.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Added tests to CollationSqlExpressionsSuite.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#48739 from jovanpavl-db/implement_passthrough_functions.

Authored-by: Jovan Pavlovic <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
1 parent 74c3757 commit 10e0b61

File tree

22 files changed

+634
-110
lines changed

22 files changed

+634
-110
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,8 @@ abstract class TypeCoercionHelper {
318318
}
319319

320320
case aj @ ArrayJoin(arr, d, nr)
321-
if !AbstractArrayType(StringTypeWithCollation).acceptsType(arr.dataType) &&
321+
if !AbstractArrayType(StringTypeWithCollation(supportsTrimCollation = true)).
322+
acceptsType(arr.dataType) &&
322323
ArrayType.acceptsType(arr.dataType) =>
323324
val containsNull = arr.dataType.asInstanceOf[ArrayType].containsNull
324325
implicitCast(arr, ArrayType(StringType, containsNull)) match {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ case class CallMethodViaReflection(
115115
"requiredType" -> toSQLType(
116116
TypeCollection(BooleanType, ByteType, ShortType,
117117
IntegerType, LongType, FloatType, DoubleType,
118-
StringTypeWithCollation)),
118+
StringTypeWithCollation(supportsTrimCollation = true))),
119119
"inputSql" -> toSQLExpr(e),
120120
"inputType" -> toSQLType(e.dataType))
121121
)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ object ExprUtils extends EvalHelper with QueryErrorsBase {
6161

6262
def convertToMapData(exp: Expression): Map[String, String] = exp match {
6363
case m: CreateMap
64-
if AbstractMapType(StringTypeWithCollation, StringTypeWithCollation)
64+
if AbstractMapType(
65+
StringTypeWithCollation(supportsTrimCollation = true),
66+
StringTypeWithCollation(supportsTrimCollation = true))
6567
.acceptsType(m.dataType) =>
6668
val arrayMap = m.eval().asInstanceOf[ArrayBasedMapData]
6769
ArrayBasedMapData.toScalaMap(arrayMap).map { case (key, value) =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1354,7 +1354,7 @@ case class Reverse(child: Expression)
13541354
override def nullIntolerant: Boolean = true
13551355
// Input types are utilized by type coercion in ImplicitTypeCasts.
13561356
override def inputTypes: Seq[AbstractDataType] =
1357-
Seq(TypeCollection(StringTypeWithCollation, ArrayType))
1357+
Seq(TypeCollection(StringTypeWithCollation(supportsTrimCollation = true), ArrayType))
13581358

13591359
override def dataType: DataType = child.dataType
13601360

@@ -2127,12 +2127,12 @@ case class ArrayJoin(
21272127
this(array, delimiter, Some(nullReplacement))
21282128

21292129
override def inputTypes: Seq[AbstractDataType] = if (nullReplacement.isDefined) {
2130-
Seq(AbstractArrayType(StringTypeWithCollation),
2131-
StringTypeWithCollation,
2132-
StringTypeWithCollation)
2130+
Seq(AbstractArrayType(StringTypeWithCollation(supportsTrimCollation = true)),
2131+
StringTypeWithCollation(supportsTrimCollation = true),
2132+
StringTypeWithCollation(supportsTrimCollation = true))
21332133
} else {
2134-
Seq(AbstractArrayType(StringTypeWithCollation),
2135-
StringTypeWithCollation)
2134+
Seq(AbstractArrayType(StringTypeWithCollation(supportsTrimCollation = true)),
2135+
StringTypeWithCollation(supportsTrimCollation = true))
21362136
}
21372137

21382138
override def children: Seq[Expression] = if (nullReplacement.isDefined) {
@@ -2855,7 +2855,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
28552855
with QueryErrorsBase {
28562856

28572857
private def allowedTypes: Seq[AbstractDataType] =
2858-
Seq(StringTypeWithCollation, BinaryType, ArrayType)
2858+
Seq(StringTypeWithCollation(supportsTrimCollation = true), BinaryType, ArrayType)
28592859

28602860
final override val nodePatterns: Seq[TreePattern] = Seq(CONCAT)
28612861

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ case class CsvToStructs(
8787
copy(timeZoneId = Option(timeZoneId))
8888
}
8989

90-
override def inputTypes: Seq[AbstractDataType] = StringTypeWithCollation :: Nil
90+
override def inputTypes: Seq[AbstractDataType] =
91+
StringTypeWithCollation(supportsTrimCollation = true) :: Nil
9192

9293
override def prettyName: String = "from_csv"
9394

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -971,7 +971,7 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti
971971
override def dataType: DataType = SQLConf.get.defaultStringType
972972

973973
override def inputTypes: Seq[AbstractDataType] =
974-
Seq(TimestampType, StringTypeWithCollation)
974+
Seq(TimestampType, StringTypeWithCollation(supportsTrimCollation = true))
975975

976976
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
977977
copy(timeZoneId = Option(timeZoneId))
@@ -1279,10 +1279,13 @@ abstract class ToTimestamp
12791279
override def forTimestampNTZ: Boolean = left.dataType == TimestampNTZType
12801280

12811281
override def inputTypes: Seq[AbstractDataType] =
1282-
Seq(TypeCollection(
1283-
StringTypeWithCollation, DateType, TimestampType, TimestampNTZType
1284-
),
1285-
StringTypeWithCollation)
1282+
Seq(
1283+
TypeCollection(
1284+
StringTypeWithCollation(supportsTrimCollation = true),
1285+
DateType,
1286+
TimestampType,
1287+
TimestampNTZType),
1288+
StringTypeWithCollation(supportsTrimCollation = true))
12861289

12871290
override def dataType: DataType = LongType
12881291
override def nullable: Boolean = if (failOnError) children.exists(_.nullable) else true
@@ -1454,7 +1457,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[
14541457
override def nullable: Boolean = true
14551458

14561459
override def inputTypes: Seq[AbstractDataType] =
1457-
Seq(LongType, StringTypeWithCollation)
1460+
Seq(LongType, StringTypeWithCollation(supportsTrimCollation = true))
14581461

14591462
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
14601463
copy(timeZoneId = Option(timeZoneId))
@@ -1566,7 +1569,7 @@ case class NextDay(
15661569
def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled)
15671570

15681571
override def inputTypes: Seq[AbstractDataType] =
1569-
Seq(DateType, StringTypeWithCollation)
1572+
Seq(DateType, StringTypeWithCollation(supportsTrimCollation = true))
15701573

15711574
override def dataType: DataType = DateType
15721575
override def nullable: Boolean = true
@@ -1781,7 +1784,7 @@ sealed trait UTCTimestamp extends BinaryExpression with ImplicitCastInputTypes {
17811784
val funcName: String
17821785

17831786
override def inputTypes: Seq[AbstractDataType] =
1784-
Seq(TimestampType, StringTypeWithCollation)
1787+
Seq(TimestampType, StringTypeWithCollation(supportsTrimCollation = true))
17851788
override def dataType: DataType = TimestampType
17861789

17871790
override def nullSafeEval(time: Any, timezone: Any): Any = {
@@ -2123,8 +2126,11 @@ case class ParseToDate(
21232126
// Note: ideally this function should only take string input, but we allow more types here to
21242127
// be backward compatible.
21252128
TypeCollection(
2126-
StringTypeWithCollation, DateType, TimestampType, TimestampNTZType) +:
2127-
format.map(_ => StringTypeWithCollation).toSeq
2129+
StringTypeWithCollation(supportsTrimCollation = true),
2130+
DateType,
2131+
TimestampType,
2132+
TimestampNTZType) +:
2133+
format.map(_ => StringTypeWithCollation(supportsTrimCollation = true)).toSeq
21282134
}
21292135

21302136
override protected def withNewChildrenInternal(
@@ -2195,10 +2201,15 @@ case class ParseToTimestamp(
21952201
override def inputTypes: Seq[AbstractDataType] = {
21962202
// Note: ideally this function should only take string input, but we allow more types here to
21972203
// be backward compatible.
2198-
val types = Seq(StringTypeWithCollation, DateType, TimestampType, TimestampNTZType)
2204+
val types = Seq(
2205+
StringTypeWithCollation(
2206+
supportsTrimCollation = true),
2207+
DateType,
2208+
TimestampType,
2209+
TimestampNTZType)
21992210
TypeCollection(
22002211
(if (dataType.isInstanceOf[TimestampType]) types :+ NumericType else types): _*
2201-
) +: format.map(_ => StringTypeWithCollation).toSeq
2212+
) +: format.map(_ => StringTypeWithCollation(supportsTrimCollation = true)).toSeq
22022213
}
22032214

22042215
override protected def withNewChildrenInternal(
@@ -2329,7 +2340,7 @@ case class TruncDate(date: Expression, format: Expression)
23292340
override def right: Expression = format
23302341

23312342
override def inputTypes: Seq[AbstractDataType] =
2332-
Seq(DateType, StringTypeWithCollation)
2343+
Seq(DateType, StringTypeWithCollation(supportsTrimCollation = true))
23332344
override def dataType: DataType = DateType
23342345
override def prettyName: String = "trunc"
23352346
override val instant = date
@@ -2399,7 +2410,7 @@ case class TruncTimestamp(
23992410
override def right: Expression = timestamp
24002411

24012412
override def inputTypes: Seq[AbstractDataType] =
2402-
Seq(StringTypeWithCollation, TimestampType)
2413+
Seq(StringTypeWithCollation(supportsTrimCollation = true), TimestampType)
24032414
override def dataType: TimestampType = TimestampType
24042415
override def prettyName: String = "date_trunc"
24052416
override val instant = timestamp
@@ -2800,7 +2811,7 @@ case class MakeTimestamp(
28002811
// casted into decimal safely, we use DecimalType(16, 6) which is wider than DecimalType(10, 0).
28012812
override def inputTypes: Seq[AbstractDataType] =
28022813
Seq(IntegerType, IntegerType, IntegerType, IntegerType, IntegerType, DecimalType(16, 6)) ++
2803-
timezone.map(_ => StringTypeWithCollation)
2814+
timezone.map(_ => StringTypeWithCollation(supportsTrimCollation = true))
28042815
override def nullable: Boolean = if (failOnError) children.exists(_.nullable) else true
28052816

28062817
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
@@ -3333,7 +3344,10 @@ case class ConvertTimezone(
33333344
override def third: Expression = sourceTs
33343345

33353346
override def inputTypes: Seq[AbstractDataType] =
3336-
Seq(StringTypeWithCollation, StringTypeWithCollation, TimestampNTZType)
3347+
Seq(
3348+
StringTypeWithCollation(supportsTrimCollation = true),
3349+
StringTypeWithCollation(supportsTrimCollation = true),
3350+
TimestampNTZType)
33373351
override def dataType: DataType = TimestampNTZType
33383352

33393353
override def nullSafeEval(srcTz: Any, tgtTz: Any, micros: Any): Any = {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,9 @@ case class GetJsonObject(json: Expression, path: Expression)
133133
override def left: Expression = json
134134
override def right: Expression = path
135135
override def inputTypes: Seq[AbstractDataType] =
136-
Seq(StringTypeWithCollation, StringTypeWithCollation)
136+
Seq(
137+
StringTypeWithCollation(supportsTrimCollation = true),
138+
StringTypeWithCollation(supportsTrimCollation = true))
137139
override def dataType: DataType = SQLConf.get.defaultStringType
138140
override def nullable: Boolean = true
139141
override def prettyName: String = "get_json_object"
@@ -490,7 +492,8 @@ case class JsonTuple(children: Seq[Expression])
490492
)
491493
} else if (
492494
children.forall(
493-
child => StringTypeWithCollation.acceptsType(child.dataType))) {
495+
child => StringTypeWithCollation(supportsTrimCollation = true)
496+
.acceptsType(child.dataType))) {
494497
TypeCheckResult.TypeCheckSuccess
495498
} else {
496499
DataTypeMismatch(
@@ -709,7 +712,8 @@ case class JsonToStructs(
709712
|""".stripMargin)
710713
}
711714

712-
override def inputTypes: Seq[AbstractDataType] = StringTypeWithCollation :: Nil
715+
override def inputTypes: Seq[AbstractDataType] =
716+
StringTypeWithCollation(supportsTrimCollation = true) :: Nil
713717

714718
override def sql: String = schema match {
715719
case _: MapType => "entries"
@@ -922,7 +926,8 @@ case class LengthOfJsonArray(child: Expression)
922926
with ExpectsInputTypes
923927
with RuntimeReplaceable {
924928

925-
override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation)
929+
override def inputTypes: Seq[AbstractDataType] =
930+
Seq(StringTypeWithCollation(supportsTrimCollation = true))
926931
override def dataType: DataType = IntegerType
927932
override def nullable: Boolean = true
928933
override def prettyName: String = "json_array_length"
@@ -967,7 +972,8 @@ case class JsonObjectKeys(child: Expression)
967972
with ExpectsInputTypes
968973
with RuntimeReplaceable {
969974

970-
override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation)
975+
override def inputTypes: Seq[AbstractDataType] =
976+
Seq(StringTypeWithCollation(supportsTrimCollation = true))
971977
override def dataType: DataType = ArrayType(SQLConf.get.defaultStringType)
972978
override def nullable: Boolean = true
973979
override def prettyName: String = "json_object_keys"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -193,11 +193,11 @@ case class Mask(
193193
*/
194194
override def inputTypes: Seq[AbstractDataType] =
195195
Seq(
196-
StringTypeWithCollation,
197-
StringTypeWithCollation,
198-
StringTypeWithCollation,
199-
StringTypeWithCollation,
200-
StringTypeWithCollation)
196+
StringTypeWithCollation(supportsTrimCollation = true),
197+
StringTypeWithCollation(supportsTrimCollation = true),
198+
StringTypeWithCollation(supportsTrimCollation = true),
199+
StringTypeWithCollation(supportsTrimCollation = true),
200+
StringTypeWithCollation(supportsTrimCollation = true))
201201

202202
override def nullable: Boolean = true
203203

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ case class Conv(
455455
override def second: Expression = fromBaseExpr
456456
override def third: Expression = toBaseExpr
457457
override def inputTypes: Seq[AbstractDataType] =
458-
Seq(StringTypeWithCollation, IntegerType, IntegerType)
458+
Seq(StringTypeWithCollation(supportsTrimCollation = true), IntegerType, IntegerType)
459459
override def dataType: DataType = first.dataType
460460
override def nullable: Boolean = true
461461

@@ -1118,7 +1118,7 @@ case class Hex(child: Expression)
11181118
override def nullIntolerant: Boolean = true
11191119

11201120
override def inputTypes: Seq[AbstractDataType] =
1121-
Seq(TypeCollection(LongType, BinaryType, StringTypeWithCollation))
1121+
Seq(TypeCollection(LongType, BinaryType, StringTypeWithCollation(supportsTrimCollation = true)))
11221122

11231123
override def dataType: DataType = child.dataType match {
11241124
case st: StringType => st
@@ -1163,7 +1163,8 @@ case class Unhex(child: Expression, failOnError: Boolean = false)
11631163

11641164
def this(expr: Expression) = this(expr, false)
11651165

1166-
override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCollation)
1166+
override def inputTypes: Seq[AbstractDataType] =
1167+
Seq(StringTypeWithCollation(supportsTrimCollation = true))
11671168

11681169
override def nullable: Boolean = true
11691170
override def dataType: DataType = BinaryType

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,12 @@ case class RaiseError(errorClass: Expression, errorParms: Expression, dataType:
8585
override def foldable: Boolean = false
8686
override def nullable: Boolean = true
8787
override def inputTypes: Seq[AbstractDataType] =
88-
Seq(StringTypeWithCollation, AbstractMapType(StringTypeWithCollation, StringTypeWithCollation))
88+
Seq(
89+
StringTypeWithCollation(supportsTrimCollation = true),
90+
AbstractMapType(
91+
StringTypeWithCollation(supportsTrimCollation = true),
92+
StringTypeWithCollation(supportsTrimCollation = true)
93+
))
8994

9095
override def left: Expression = errorClass
9196
override def right: Expression = errorParms
@@ -416,8 +421,8 @@ case class AesEncrypt(
416421

417422
override def inputTypes: Seq[AbstractDataType] =
418423
Seq(BinaryType, BinaryType,
419-
StringTypeWithCollation,
420-
StringTypeWithCollation,
424+
StringTypeWithCollation(supportsTrimCollation = true),
425+
StringTypeWithCollation(supportsTrimCollation = true),
421426
BinaryType, BinaryType)
422427

423428
override def children: Seq[Expression] = Seq(input, key, mode, padding, iv, aad)
@@ -493,8 +498,8 @@ case class AesDecrypt(
493498
override def inputTypes: Seq[AbstractDataType] = {
494499
Seq(BinaryType,
495500
BinaryType,
496-
StringTypeWithCollation,
497-
StringTypeWithCollation, BinaryType)
501+
StringTypeWithCollation(supportsTrimCollation = true),
502+
StringTypeWithCollation(supportsTrimCollation = true), BinaryType)
498503
}
499504

500505
override def prettyName: String = "aes_decrypt"

0 commit comments

Comments
 (0)