
Commit 91becf1

stefankandic authored and MaxGekk committed
[SPARK-49962][SQL] Simplify AbstractStringTypes class hierarchy
### What changes were proposed in this pull request?
Simplify the AbstractStringType class hierarchy.

### Why are the changes needed?
The addition of trim-sensitive collation (#48336) highlighted the complexity of extending the existing AbstractStringType structure. Besides adding a new parameter to every type inheriting from AbstractStringType, it meant changing the logic of every subclass and renaming the derived class StringTypeAnyCollation to StringTypeWithCaseAccentSensitivity, a name that could change again if we keep adding new specifiers. Looking ahead, the introduction of support for indeterminate collation would further complicate these types.

To address this, the proposed changes simplify the design by consolidating common logic into a single base class. The base class handles core functionality such as trim or indeterminate collation, while a derived class, StringTypeWithCollation (previously awkwardly named StringTypeWithCaseAccentSensitivity), manages collation specifiers. This approach allows for easier future extensions: fundamental checks can be handled in the base class, while any new specifiers can be added as optional fields of StringTypeWithCollation.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
With existing tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48459 from stefankandic/refactorStringTypes.

Authored-by: Stefan Kandic <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
1 parent 6362e0c commit 91becf1
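
To make the call-site impact concrete, here is a minimal Scala sketch (assuming the post-commit classpath; the object name is mine) of how an expression's accepted input types read after the rename to StringTypeWithCollation:

```scala
import org.apache.spark.sql.internal.types.StringTypeWithCollation
import org.apache.spark.sql.types.AbstractDataType

// Illustrative wrapper only; not part of this commit.
object CallSiteSketch {
  // Before this commit a call site read:
  //   Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true))
  // After it, the same intent is spelled with the consolidated type and a named
  // specifier; case and accent specifiers stay enabled through their defaults.
  val inputTypes: Seq[AbstractDataType] =
    Seq(StringTypeWithCollation(supportsTrimCollation = true))
}
```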

File tree

26 files changed: +226 −203 lines


common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java

Lines changed: 6 additions & 8 deletions
```diff
@@ -1157,15 +1157,13 @@ public static int collationNameToId(String collationName) throws SparkException
     return Collation.CollationSpec.collationNameToId(collationName);
   }
 
-  /**
-   * Returns whether the ICU collation is not Case Sensitive Accent Insensitive
-   * for the given collation id.
-   * This method is used in expressions which do not support CS_AI collations.
-   */
-  public static boolean isCaseSensitiveAndAccentInsensitive(int collationId) {
+  public static boolean isCaseInsensitive(int collationId) {
     return Collation.CollationSpecICU.fromCollationId(collationId).caseSensitivity ==
-      Collation.CollationSpecICU.CaseSensitivity.CS &&
-      Collation.CollationSpecICU.fromCollationId(collationId).accentSensitivity ==
+      Collation.CollationSpecICU.CaseSensitivity.CI;
+  }
+
+  public static boolean isAccentInsensitive(int collationId) {
+    return Collation.CollationSpecICU.fromCollationId(collationId).accentSensitivity ==
       Collation.CollationSpecICU.AccentSensitivity.AI;
   }
 
```

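The removed predicate can still be expressed by composing the two new ones. A minimal Scala sketch, not part of this commit, assuming the post-commit CollationFactory is on the classpath and that UNICODE_AI names a case-sensitive, accent-insensitive ICU collation:

```scala
import org.apache.spark.sql.catalyst.util.CollationFactory

object CsAiCheckSketch {
  // Illustrative helper (hypothetical, not in the PR): the removed
  // isCaseSensitiveAndAccentInsensitive check expressed via the two new predicates.
  def isCsAi(collationId: Int): Boolean =
    !CollationFactory.isCaseInsensitive(collationId) &&
      CollationFactory.isAccentInsensitive(collationId)

  def main(args: Array[String]): Unit = {
    // UNICODE_AI keeps the default case sensitivity (CS) but drops accent sensitivity (AI).
    val id = CollationFactory.collationNameToId("UNICODE_AI")
    assert(isCsAi(id))
  }
}
```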
sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala

Lines changed: 52 additions & 30 deletions
```diff
@@ -21,25 +21,34 @@ import org.apache.spark.sql.internal.SqlApiConf
 import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType}
 
 /**
- * AbstractStringType is an abstract class for StringType with collation support. As every type of
- * collation can support trim specifier this class is parametrized with it.
+ * AbstractStringType is an abstract class for StringType with collation support.
  */
-abstract class AbstractStringType(private[sql] val supportsTrimCollation: Boolean = false)
+abstract class AbstractStringType(supportsTrimCollation: Boolean = false)
     extends AbstractDataType {
   override private[sql] def defaultConcreteType: DataType = SqlApiConf.get.defaultStringType
   override private[sql] def simpleString: String = "string"
-  private[sql] def canUseTrimCollation(other: DataType): Boolean =
-    supportsTrimCollation || !other.asInstanceOf[StringType].usesTrimCollation
+
+  override private[sql] def acceptsType(other: DataType): Boolean = other match {
+    case st: StringType =>
+      canUseTrimCollation(st) && acceptsStringType(st)
+    case _ =>
+      false
+  }
+
+  private[sql] def canUseTrimCollation(other: StringType): Boolean =
+    supportsTrimCollation || !other.usesTrimCollation
+
+  def acceptsStringType(other: StringType): Boolean
 }
 
 /**
- * Use StringTypeBinary for expressions supporting only binary collation.
+ * Used for expressions supporting only binary collation.
  */
-case class StringTypeBinary(override val supportsTrimCollation: Boolean = false)
+case class StringTypeBinary(supportsTrimCollation: Boolean)
     extends AbstractStringType(supportsTrimCollation) {
-  override private[sql] def acceptsType(other: DataType): Boolean =
-    other.isInstanceOf[StringType] && other.asInstanceOf[StringType].supportsBinaryEquality &&
-      canUseTrimCollation(other)
+
+  override def acceptsStringType(other: StringType): Boolean =
+    other.supportsBinaryEquality
 }
 
 object StringTypeBinary extends StringTypeBinary(false) {
@@ -49,13 +58,13 @@ object StringTypeBinary extends StringTypeBinary(false) {
 }
 
 /**
- * Use StringTypeBinaryLcase for expressions supporting only binary and lowercase collation.
+ * Used for expressions supporting only binary and lowercase collation.
  */
-case class StringTypeBinaryLcase(override val supportsTrimCollation: Boolean = false)
+case class StringTypeBinaryLcase(supportsTrimCollation: Boolean)
     extends AbstractStringType(supportsTrimCollation) {
-  override private[sql] def acceptsType(other: DataType): Boolean =
-    other.isInstanceOf[StringType] && (other.asInstanceOf[StringType].supportsBinaryEquality ||
-      other.asInstanceOf[StringType].isUTF8LcaseCollation) && canUseTrimCollation(other)
+
+  override def acceptsStringType(other: StringType): Boolean =
+    other.supportsBinaryEquality || other.isUTF8LcaseCollation
 }
 
 object StringTypeBinaryLcase extends StringTypeBinaryLcase(false) {
@@ -65,31 +74,44 @@ object StringTypeBinaryLcase extends StringTypeBinaryLcase(false) {
 }
 
 /**
- * Use StringTypeWithCaseAccentSensitivity for expressions supporting all collation types (binary
- * and ICU) but limited to using case and accent sensitivity specifiers.
+ * Used for expressions supporting collation types with optional case, accent, and trim
+ * sensitivity specifiers.
+ *
+ * Case and accent sensitivity specifiers are supported by default.
  */
-case class StringTypeWithCaseAccentSensitivity(
-    override val supportsTrimCollation: Boolean = false)
+case class StringTypeWithCollation(
+    supportsTrimCollation: Boolean,
+    supportsCaseSpecifier: Boolean,
+    supportsAccentSpecifier: Boolean)
     extends AbstractStringType(supportsTrimCollation) {
-  override private[sql] def acceptsType(other: DataType): Boolean =
-    other.isInstanceOf[StringType] && canUseTrimCollation(other)
+
+  override def acceptsStringType(other: StringType): Boolean = {
+    (supportsCaseSpecifier || !other.isCaseInsensitive) &&
+    (supportsAccentSpecifier || !other.isAccentInsensitive)
+  }
 }
 
-object StringTypeWithCaseAccentSensitivity extends StringTypeWithCaseAccentSensitivity(false) {
-  def apply(supportsTrimCollation: Boolean): StringTypeWithCaseAccentSensitivity = {
-    new StringTypeWithCaseAccentSensitivity(supportsTrimCollation)
+object StringTypeWithCollation extends StringTypeWithCollation(false, true, true) {
+  def apply(
+      supportsTrimCollation: Boolean = false,
+      supportsCaseSpecifier: Boolean = true,
+      supportsAccentSpecifier: Boolean = true): StringTypeWithCollation = {
+    new StringTypeWithCollation(
+      supportsTrimCollation,
+      supportsCaseSpecifier,
+      supportsAccentSpecifier)
   }
 }
 
 /**
- * Use StringTypeNonCSAICollation for expressions supporting all possible collation types except
- * CS_AI collation types.
+ * Used for expressions supporting all possible collation types except those that are
+ * case-sensitive but accent insensitive (CS_AI).
  */
-case class StringTypeNonCSAICollation(override val supportsTrimCollation: Boolean = false)
+case class StringTypeNonCSAICollation(supportsTrimCollation: Boolean)
     extends AbstractStringType(supportsTrimCollation) {
-  override private[sql] def acceptsType(other: DataType): Boolean =
-    other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isNonCSAI &&
-      canUseTrimCollation(other)
+
+  override def acceptsStringType(other: StringType): Boolean =
+    other.isCaseInsensitive || !other.isAccentInsensitive
 }
 
 object StringTypeNonCSAICollation extends StringTypeNonCSAICollation(false) {
```

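For orientation, here is a minimal sketch of how an ExpectsInputTypes expression declares collated string inputs against the refactored hierarchy. The expression itself is hypothetical (not from this commit); the StringTypeWithCollation call follows the apply signature shown in the diff above:

```scala
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.internal.types.StringTypeWithCollation
import org.apache.spark.sql.types.{AbstractDataType, DataType}

// Hypothetical expression, used only to illustrate the new input-type API.
case class MyCollatedIdentity(child: Expression)
  extends UnaryExpression with ExpectsInputTypes with CodegenFallback {

  // Accept any collated string, including trim collations; case and accent
  // specifiers are allowed because their parameters default to true.
  override def inputTypes: Seq[AbstractDataType] =
    Seq(StringTypeWithCollation(supportsTrimCollation = true))

  override def dataType: DataType = child.dataType

  // Pass the value through unchanged; evaluation is not the point of the sketch.
  override protected def nullSafeEval(input: Any): Any = input

  override protected def withNewChildInternal(newChild: Expression): Expression =
    copy(child = newChild)
}
```

Only the acceptsStringType logic differs between the concrete types above, which is the point of pushing acceptsType and the trim check into the base class.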
sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala

Lines changed: 5 additions & 2 deletions
```diff
@@ -44,8 +44,11 @@ class StringType private (val collationId: Int) extends AtomicType with Serializ
   private[sql] def supportsLowercaseEquality: Boolean =
     CollationFactory.fetchCollation(collationId).supportsLowercaseEquality
 
-  private[sql] def isNonCSAI: Boolean =
-    !CollationFactory.isCaseSensitiveAndAccentInsensitive(collationId)
+  private[sql] def isCaseInsensitive: Boolean =
+    CollationFactory.isCaseInsensitive(collationId)
+
+  private[sql] def isAccentInsensitive: Boolean =
+    CollationFactory.isAccentInsensitive(collationId)
 
   private[sql] def usesTrimCollation: Boolean =
     CollationFactory.fetchCollation(collationId).supportsSpaceTrimming
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -33,7 +33,7 @@ import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractMapType, AbstractStringType,
-  StringTypeWithCaseAccentSensitivity}
+  StringTypeWithCollation}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.types.UpCastRule.numericPrecedence
 
@@ -439,7 +439,7 @@
       }
 
       case aj @ ArrayJoin(arr, d, nr)
-          if !AbstractArrayType(StringTypeWithCaseAccentSensitivity).acceptsType(arr.dataType) &&
+          if !AbstractArrayType(StringTypeWithCollation).acceptsType(arr.dataType) &&
             ArrayType.acceptsType(arr.dataType) =>
         val containsNull = arr.dataType.asInstanceOf[ArrayType].containsNull
         implicitCast(arr, ArrayType(StringType, containsNull)) match {
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala

Lines changed: 4 additions & 4 deletions
```diff
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch,
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity
+import org.apache.spark.sql.internal.types.StringTypeWithCollation
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.ArrayImplicits._
@@ -84,7 +84,7 @@ case class CallMethodViaReflection(
           errorSubClass = "NON_FOLDABLE_INPUT",
           messageParameters = Map(
             "inputName" -> toSQLId("class"),
-            "inputType" -> toSQLType(StringTypeWithCaseAccentSensitivity),
+            "inputType" -> toSQLType(StringTypeWithCollation),
             "inputExpr" -> toSQLExpr(children.head)
           )
         )
@@ -97,7 +97,7 @@ case class CallMethodViaReflection(
           errorSubClass = "NON_FOLDABLE_INPUT",
           messageParameters = Map(
             "inputName" -> toSQLId("method"),
-            "inputType" -> toSQLType(StringTypeWithCaseAccentSensitivity),
+            "inputType" -> toSQLType(StringTypeWithCollation),
             "inputExpr" -> toSQLExpr(children(1))
           )
         )
@@ -115,7 +115,7 @@ case class CallMethodViaReflection(
             "requiredType" -> toSQLType(
               TypeCollection(BooleanType, ByteType, ShortType,
                 IntegerType, LongType, FloatType, DoubleType,
-                StringTypeWithCaseAccentSensitivity)),
+                StringTypeWithCollation)),
             "inputSql" -> toSQLExpr(e),
             "inputType" -> toSQLType(e.dataType))
         )
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -19,13 +19,13 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.util.CollationFactory
-import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity
+import org.apache.spark.sql.internal.types.StringTypeWithCollation
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 case class CollationKey(expr: Expression) extends UnaryExpression with ExpectsInputTypes {
   override def inputTypes: Seq[AbstractDataType] =
-    Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true))
+    Seq(StringTypeWithCollation(supportsTrimCollation = true))
   override def dataType: DataType = BinaryType
 
   final lazy val collationId: Int = expr.dataType match {
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.logical.Aggregate
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharUtils}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors}
-import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeWithCaseAccentSensitivity}
+import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeWithCollation}
 import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType, VariantType}
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -61,7 +61,7 @@ object ExprUtils extends EvalHelper with QueryErrorsBase {
 
   def convertToMapData(exp: Expression): Map[String, String] = exp match {
     case m: CreateMap
-      if AbstractMapType(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity)
+      if AbstractMapType(StringTypeWithCollation, StringTypeWithCollation)
         .acceptsType(m.dataType) =>
       val arrayMap = m.eval().asInstanceOf[ArrayBasedMapData]
       ArrayBasedMapData.toScalaMap(arrayMap).map { case (key, value) =>
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression,
 import org.apache.spark.sql.catalyst.trees.BinaryLike
 import org.apache.spark.sql.catalyst.util.CollationFactory
 import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity
+import org.apache.spark.sql.internal.types.StringTypeWithCollation
 import org.apache.spark.sql.types.{AbstractDataType, BinaryType, BooleanType, DataType, IntegerType, LongType, StringType, TypeCollection}
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -109,7 +109,7 @@ case class HllSketchAgg(
       TypeCollection(
         IntegerType,
         LongType,
-        StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true),
+        StringTypeWithCollation(supportsTrimCollation = true),
         BinaryType),
       IntegerType)
 
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala

Lines changed: 3 additions & 3 deletions
```diff
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.CollationFactory
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity
+import org.apache.spark.sql.internal.types.StringTypeWithCollation
 import org.apache.spark.sql.types._
 
 // scalastyle:off line.contains.tab
@@ -78,7 +78,7 @@ case class Collate(child: Expression, collationName: String)
   private val collationId = CollationFactory.collationNameToId(collationName)
   override def dataType: DataType = StringType(collationId)
   override def inputTypes: Seq[AbstractDataType] =
-    Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true))
+    Seq(StringTypeWithCollation(supportsTrimCollation = true))
 
   override protected def withNewChildInternal(
     newChild: Expression): Expression = copy(newChild)
@@ -117,5 +117,5 @@ case class Collation(child: Expression)
     Literal.create(collationName, SQLConf.get.defaultStringType)
   }
   override def inputTypes: Seq[AbstractDataType] =
-    Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true))
+    Seq(StringTypeWithCollation(supportsTrimCollation = true))
 }
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 8 additions & 8 deletions
```diff
@@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils._
 import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeWithCaseAccentSensitivity}
+import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeWithCollation}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SQLOpenHashSet
 import org.apache.spark.unsafe.UTF8StringBuilder
@@ -1349,7 +1349,7 @@ case class Reverse(child: Expression)
 
   // Input types are utilized by type coercion in ImplicitTypeCasts.
   override def inputTypes: Seq[AbstractDataType] =
-    Seq(TypeCollection(StringTypeWithCaseAccentSensitivity, ArrayType))
+    Seq(TypeCollection(StringTypeWithCollation, ArrayType))
 
   override def dataType: DataType = child.dataType
 
@@ -2135,12 +2135,12 @@ case class ArrayJoin(
     this(array, delimiter, Some(nullReplacement))
 
   override def inputTypes: Seq[AbstractDataType] = if (nullReplacement.isDefined) {
-    Seq(AbstractArrayType(StringTypeWithCaseAccentSensitivity),
-      StringTypeWithCaseAccentSensitivity,
-      StringTypeWithCaseAccentSensitivity)
+    Seq(AbstractArrayType(StringTypeWithCollation),
+      StringTypeWithCollation,
+      StringTypeWithCollation)
   } else {
-    Seq(AbstractArrayType(StringTypeWithCaseAccentSensitivity),
-      StringTypeWithCaseAccentSensitivity)
+    Seq(AbstractArrayType(StringTypeWithCollation),
+      StringTypeWithCollation)
   }
 
   override def children: Seq[Expression] = if (nullReplacement.isDefined) {
@@ -2861,7 +2861,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
     with QueryErrorsBase {
 
   private def allowedTypes: Seq[AbstractDataType] =
-    Seq(StringTypeWithCaseAccentSensitivity, BinaryType, ArrayType)
+    Seq(StringTypeWithCollation, BinaryType, ArrayType)
 
   final override val nodePatterns: Seq[TreePattern] = Seq(CONCAT)
 
```