From 821b42d1d430194bdc9e1856572cd41c5ba4d506 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vladan=20Vasi=C4=87?= Date: Mon, 16 Sep 2024 16:32:07 +0200 Subject: [PATCH 1/5] Disallowed cs_ai collators with expressions that use StringSearch --- .../internal/types/AbstractStringType.scala | 11 +++ .../expressions/complexTypeCreator.scala | 4 +- .../expressions/stringExpressions.scala | 33 +++++++-- .../org/apache/spark/sql/CollationSuite.scala | 74 +++++++++++++++++++ 4 files changed, 113 insertions(+), 9 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala index 05d1701eff74..aed643824d54 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala @@ -51,3 +51,14 @@ case object StringTypeBinaryLcase extends AbstractStringType { case object StringTypeAnyCollation extends AbstractStringType { override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType] } + +/** + * Use StringTypeNonCSAICollation for expressions supporting all possible collation types + * except CS_AI collation types. + */ +case object StringTypeNonCSAICollation extends AbstractStringType { + override private[sql] def acceptsType(other: DataType): Boolean = + other.isInstanceOf[StringType] && + (!other.asInstanceOf[StringType].typeName.contains("_AI") || + other.asInstanceOf[StringType].typeName.contains("_CI")) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index ba1beab28d9a..b8b47f2763f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeNonCSAICollation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ @@ -579,7 +579,7 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E override def third: Expression = keyValueDelim override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation) + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, StringTypeNonCSAICollation) override def dataType: DataType = MapType(first.dataType, first.dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index e75df87994f0..da6d786efb4e 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LO import org.apache.spark.sql.catalyst.util.{ArrayData, CharsetProvider, CollationFactory, CollationSupport, GenericArrayData, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation} +import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation, StringTypeNonCSAICollation} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.UTF8StringBuilder import org.apache.spark.unsafe.array.ByteArrayMethods @@ -609,6 +609,8 @@ case class Contains(left: Expression, right: Expression) extends StringPredicate defineCodeGen(ctx, ev, (c1, c2) => CollationSupport.Contains.genCode(c1, c2, collationId)) } + override def inputTypes : Seq[AbstractDataType] = + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation) override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): Contains = copy(left = newLeft, right = newRight) } @@ -650,6 +652,10 @@ case class StartsWith(left: Expression, right: Expression) extends StringPredica defineCodeGen(ctx, ev, (c1, c2) => CollationSupport.StartsWith.genCode(c1, c2, collationId)) } + + override def inputTypes : Seq[AbstractDataType] = + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, StringTypeNonCSAICollation) + override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): StartsWith = copy(left = newLeft, right = newRight) } @@ -691,6 +697,10 @@ case class EndsWith(left: Expression, right: Expression) extends StringPredicate defineCodeGen(ctx, ev, (c1, c2) => CollationSupport.EndsWith.genCode(c1, c2, collationId)) } + + override def inputTypes : Seq[AbstractDataType] = + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, StringTypeNonCSAICollation) + override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): EndsWith = copy(left = newLeft, right = newRight) } @@ -919,7 +929,7 @@ case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExp override def dataType: DataType = srcExpr.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation) + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, StringTypeNonCSAICollation) override def first: Expression = srcExpr override def second: Expression = searchExpr override def third: Expression = replaceExpr @@ -1167,7 +1177,7 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac override def dataType: DataType = srcExpr.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation) + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, StringTypeNonCSAICollation) override def first: Expression = srcExpr override def second: Expression = matchingExpr override def third: Expression = replaceExpr @@ -1394,6 +1404,9 @@ case class StringTrim(srcStr: Expression, trimStr: Option[Expression] = None) override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String = CollationSupport.StringTrim.exec(srcString, trimString, collationId) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation) + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy( srcStr = newChildren.head, @@ -1501,6 +1514,9 @@ case class StringTrimLeft(srcStr: Expression, trimStr: Option[Expression] = None override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String = CollationSupport.StringTrimLeft.exec(srcString, trimString, collationId) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation) + override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): StringTrimLeft = copy( @@ -1561,6 +1577,9 @@ case class StringTrimRight(srcStr: Expression, trimStr: Option[Expression] = Non override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String = CollationSupport.StringTrimRight.exec(srcString, trimString, collationId) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation) + override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): StringTrimRight = copy( @@ -1595,7 +1614,7 @@ case class StringInstr(str: Expression, substr: Expression) override def right: Expression = substr override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation) + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation) override def nullSafeEval(string: Any, sub: Any): Any = { CollationSupport.StringInstr. @@ -1643,7 +1662,7 @@ case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: override def dataType: DataType = strExpr.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation, IntegerType) + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, IntegerType) override def first: Expression = strExpr override def second: Expression = delimExpr override def third: Expression = countExpr @@ -1701,7 +1720,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) override def nullable: Boolean = substr.nullable || str.nullable override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation, IntegerType) + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, IntegerType) override def eval(input: InternalRow): Any = { val s = start.eval(input) @@ -3463,7 +3482,7 @@ case class SplitPart ( false) override def nodeName: String = "split_part" override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation, IntegerType) + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, IntegerType) def children: Seq[Expression] = Seq(str, delimiter, partNum) protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = { copy(str = newChildren.apply(0), delimiter = newChildren.apply(1), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 489a990d3e1c..bed937aafa72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import scala.jdk.CollectionConverters.MapHasAsJava +import scala.util.Try import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ExtendedAnalysisException @@ -1625,6 +1626,79 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } + test("Expressions not supporting CS_AI collators") { + val unsupportedExpressions: Seq[Any] = Seq( + "ltrim", + "rtrim", + "trim", + "startswith", + "endswith", + "locate", + "instr", + "str_to_map", + "contains", + "replace", + ("translate", "efg"), + ("split_part", "2"), + ("substring_index", "2")) + + val unsupportedCollator = "unicode_ai" + val supportedCollators: Seq[String] = Seq( + "unicode", + "unicode_ci", + "unicode_ci_ai" + ) + + unsupportedExpressions.foreach { + case expression: String => + val analysisException = intercept[AnalysisException] { + sql(s"select $expression('bcd' collate $unsupportedCollator, 'abc')").collect() + } + assert(analysisException.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") + + // Tests that expression works properly with non-cs_ai collation + supportedCollators.foreach { + collator => + val result = Try { + sql(s"select $expression('bcd' collate $collator, 'abc')").collect() + } + assert(result.isSuccess) + } + + case (expression: String, parameter: String) => + val analysisException = intercept[AnalysisException] { + sql(s"select $expression('bcd' collate $unsupportedCollator, 'abc', '$parameter')") + .collect() + } + assert(analysisException.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") + + // Tests that expression works properly with non-cs_ai collation + supportedCollators.foreach { + collator => + val result = Try { + sql(s"select $expression('bcd' collate $collator, 'abc', '$parameter')").collect() + } + assert(result.isSuccess) + } + + case (expression: String, parameter: Integer) => + val analysisException = intercept[AnalysisException] { + sql(s"select $expression('bcd' collate $unsupportedCollator, 'abc', $parameter)") + .collect() + } + assert(analysisException.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") + + // Tests that expression works properly with non-cs_ai collation + supportedCollators.foreach { + collator => + val result = Try { + sql(s"select $expression('bcd' collate $collator, 'abc', $parameter)").collect() + } + assert(result.isSuccess) + } + } + } + test("TVF collations()") { assert(sql("SELECT * FROM collations()").collect().length >= 562) From 8854def3d1da9759f352403de10aebbf31114f4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vladan=20Vasi=C4=87?= Date: Tue, 17 Sep 2024 15:08:51 +0200 Subject: [PATCH 2/5] Refactored test and removed duplicate code --- .../org/apache/spark/sql/CollationSuite.scala | 66 ++++++++----------- 1 file changed, 27 insertions(+), 39 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index bed937aafa72..fc491ac98600 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -1650,52 +1650,40 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) unsupportedExpressions.foreach { - case expression: String => - val analysisException = intercept[AnalysisException] { - sql(s"select $expression('bcd' collate $unsupportedCollator, 'abc')").collect() - } - assert(analysisException.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") - - // Tests that expression works properly with non-cs_ai collation - supportedCollators.foreach { - collator => - val result = Try { - sql(s"select $expression('bcd' collate $collator, 'abc')").collect() - } - assert(result.isSuccess) + expression: Any => + val unsupportedQuery: String = { + expression match { + case expression: String => + s"select $expression('bcd' collate $unsupportedCollator, 'abc')" + case (expression: String, parameter: String) => + s"select $expression('bcd' collate $unsupportedCollator, 'abc', '$parameter')" + case (expression: String, parameter: Integer) => + s"select $expression('bcd' collate $unsupportedCollator, 'abc', $parameter)" + } } - case (expression: String, parameter: String) => val analysisException = intercept[AnalysisException] { - sql(s"select $expression('bcd' collate $unsupportedCollator, 'abc', '$parameter')") - .collect() + sql(unsupportedQuery).collect() } assert(analysisException.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") - // Tests that expression works properly with non-cs_ai collation - supportedCollators.foreach { - collator => - val result = Try { - sql(s"select $expression('bcd' collate $collator, 'abc', '$parameter')").collect() - } - assert(result.isSuccess) - } - - case (expression: String, parameter: Integer) => - val analysisException = intercept[AnalysisException] { - sql(s"select $expression('bcd' collate $unsupportedCollator, 'abc', $parameter)") - .collect() - } - assert(analysisException.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") + supportedCollators.foreach { + collator => + val supportedQuery = expression match { + case expression: String => + s"select $expression('bcd' collate $collator, 'abc')" + case (expression: String, parameter: String) => + s"select $expression('bcd' collate $collator, 'abc', '$parameter')" + case (expression: String, parameter: Integer) => + s"select $expression('bcd' collate $collator, 'abc', $parameter)" + } - // Tests that expression works properly with non-cs_ai collation - supportedCollators.foreach { - collator => - val result = Try { - sql(s"select $expression('bcd' collate $collator, 'abc', $parameter)").collect() - } - assert(result.isSuccess) - } + // Tests that expression works properly with non-cs_ai collation + val result = Try { + sql(supportedQuery).collect() + } + assert(result.isSuccess) + } } } From 686275717d9a361e1ae033195ac94b4f8dce03fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vladan=20Vasi=C4=87?= Date: Wed, 18 Sep 2024 09:23:39 +0200 Subject: [PATCH 3/5] Refactored method for checking whether collation is CS_AI --- .../sql/catalyst/util/CollationFactory.java | 16 ++++++++++++++++ .../sql/internal/types/AbstractStringType.scala | 4 +--- .../org/apache/spark/sql/types/StringType.scala | 3 +++ .../org/apache/spark/sql/CollationSuite.scala | 6 +----- 4 files changed, 21 insertions(+), 8 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 4b88e15e8ed7..ea5f9665dc88 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -921,6 +921,22 @@ public static int collationNameToId(String collationName) throws SparkException return Collation.CollationSpec.collationNameToId(collationName); } + /** + * Returns whether the ICU collation is Case Insensitive for the given collation id. + */ + public static Boolean isCI(int collationId) { + return Collation.CollationSpecICU.fromCollationId(collationId).caseSensitivity + == Collation.CollationSpecICU.CaseSensitivity.CI; + } + + /** + * Returns whether the ICU collation is Accent Insensitive for the given collation id. + */ + public static Boolean isAI(int collationId) { + return Collation.CollationSpecICU.fromCollationId(collationId).accentSensitivity + == Collation.CollationSpecICU.AccentSensitivity.AI; + } + public static void assertValidProvider(String provider) throws SparkException { if (!SUPPORTED_PROVIDERS.contains(provider.toLowerCase())) { Map params = Map.of( diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala index aed643824d54..21d82d83d5c2 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala @@ -58,7 +58,5 @@ case object StringTypeAnyCollation extends AbstractStringType { */ case object StringTypeNonCSAICollation extends AbstractStringType { override private[sql] def acceptsType(other: DataType): Boolean = - other.isInstanceOf[StringType] && - (!other.asInstanceOf[StringType].typeName.contains("_AI") || - other.asInstanceOf[StringType].typeName.contains("_CI")) + other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isNonCSAICollation } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index eba12c4ff487..b5ecf779f6cb 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -44,6 +44,9 @@ class StringType private (val collationId: Int) extends AtomicType with Serializ private[sql] def supportsLowercaseEquality: Boolean = CollationFactory.fetchCollation(collationId).supportsLowercaseEquality + private[sql] def isNonCSAICollation: Boolean = + !CollationFactory.isAI(collationId) || CollationFactory.isCI(collationId) + private[sql] def isUTF8BinaryCollation: Boolean = collationId == CollationFactory.UTF8_BINARY_COLLATION_ID diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index fc491ac98600..0df32e4c0bb0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -1643,11 +1643,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ("substring_index", "2")) val unsupportedCollator = "unicode_ai" - val supportedCollators: Seq[String] = Seq( - "unicode", - "unicode_ci", - "unicode_ci_ai" - ) + val supportedCollators: Seq[String] = Seq("unicode", "unicode_ci", "unicode_ci_ai") unsupportedExpressions.foreach { expression: Any => From 3d70783b4e0fd3c204915d1af41c2064b4c5282a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vladan=20Vasi=C4=87?= Date: Wed, 18 Sep 2024 22:52:09 +0200 Subject: [PATCH 4/5] Refactored tests and added additional e2e tests in collations.sql golden file --- .../sql/catalyst/util/CollationFactory.java | 19 +- .../apache/spark/sql/types/StringType.scala | 2 +- .../analyzer-results/collations.sql.out | 336 ++++++++++++++++ .../resources/sql-tests/inputs/collations.sql | 14 + .../sql-tests/results/collations.sql.out | 364 ++++++++++++++++++ .../sql/CollationSQLExpressionsSuite.scala | 24 ++ .../sql/CollationStringExpressionsSuite.scala | 251 ++++++++++++ .../org/apache/spark/sql/CollationSuite.scala | 58 --- 8 files changed, 997 insertions(+), 71 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index ea5f9665dc88..448df1b741cc 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -922,19 +922,14 @@ public static int collationNameToId(String collationName) throws SparkException } /** - * Returns whether the ICU collation is Case Insensitive for the given collation id. + * Returns whether the ICU collation is not Case Sensitive Accent Insensitive + * for the given collation id. */ - public static Boolean isCI(int collationId) { - return Collation.CollationSpecICU.fromCollationId(collationId).caseSensitivity - == Collation.CollationSpecICU.CaseSensitivity.CI; - } - - /** - * Returns whether the ICU collation is Accent Insensitive for the given collation id. - */ - public static Boolean isAI(int collationId) { - return Collation.CollationSpecICU.fromCollationId(collationId).accentSensitivity - == Collation.CollationSpecICU.AccentSensitivity.AI; + public static Boolean isNonCSAI(int collationId) { + return Collation.CollationSpecICU.fromCollationId(collationId).caseSensitivity == + Collation.CollationSpecICU.CaseSensitivity.CI || + Collation.CollationSpecICU.fromCollationId(collationId).accentSensitivity != + Collation.CollationSpecICU.AccentSensitivity.AI; } public static void assertValidProvider(String provider) throws SparkException { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index b5ecf779f6cb..c3dd64d9189e 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -45,7 +45,7 @@ class StringType private (val collationId: Int) extends AtomicType with Serializ CollationFactory.fetchCollation(collationId).supportsLowercaseEquality private[sql] def isNonCSAICollation: Boolean = - !CollationFactory.isAI(collationId) || CollationFactory.isCI(collationId) + CollationFactory.isNonCSAI(collationId) private[sql] def isUTF8BinaryCollation: Boolean = collationId == CollationFactory.UTF8_BINARY_COLLATION_ID diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out index 83c9ebfef4b2..eed7fa73ab69 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out @@ -436,6 +436,30 @@ Project [str_to_map(collate(text#x, utf8_binary), collate(pairDelim#x, utf8_bina +- Relation spark_catalog.default.t4[text#x,pairDelim#x,keyValueDelim#x] parquet +-- !query +select str_to_map(text collate unicode_ai, pairDelim collate unicode_ai, keyValueDelim collate unicode_ai) from t4 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(text, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"str_to_map(collate(text, unicode_ai), collate(pairDelim, unicode_ai), collate(keyValueDelim, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 106, + "fragment" : "str_to_map(text collate unicode_ai, pairDelim collate unicode_ai, keyValueDelim collate unicode_ai)" + } ] +} + + -- !query drop table t4 -- !query analysis @@ -820,6 +844,30 @@ Project [split_part(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, ut +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select split_part(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"split_part(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 2)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 83, + "fragment" : "split_part(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2)" + } ] +} + + -- !query select split_part(utf8_binary, 'a', 3), split_part(utf8_lcase, 'a', 3) from t5 -- !query analysis @@ -883,6 +931,30 @@ Project [Contains(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8 +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select contains(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"contains(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 78, + "fragment" : "contains(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select contains(utf8_binary, 'a'), contains(utf8_lcase, 'a') from t5 -- !query analysis @@ -946,6 +1018,30 @@ Project [substring_index(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase# +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select substring_index(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"substring_index(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 2)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 88, + "fragment" : "substring_index(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2)" + } ] +} + + -- !query select substring_index(utf8_binary, 'a', 2), substring_index(utf8_lcase, 'a', 2) from t5 -- !query analysis @@ -1009,6 +1105,30 @@ Project [instr(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lc +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select instr(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"instr(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "instr(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select instr(utf8_binary, 'a'), instr(utf8_lcase, 'a') from t5 -- !query analysis @@ -1135,6 +1255,30 @@ Project [StartsWith(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, ut +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select startswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"startswith(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 80, + "fragment" : "startswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select startswith(utf8_binary, 'aaAaaAaA'), startswith(utf8_lcase, 'aaAaaAaA') from t5 -- !query analysis @@ -1190,6 +1334,30 @@ Project [translate(cast(utf8_binary#x as string collate UTF8_LCASE), collate(SQL +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select translate(utf8_binary, 'SQL' collate unicode_ai, '12345' collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"utf8_binary\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"translate(utf8_binary, collate(SQL, unicode_ai), collate(12345, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 83, + "fragment" : "translate(utf8_binary, 'SQL' collate unicode_ai, '12345' collate unicode_ai)" + } ] +} + + -- !query select translate(utf8_lcase, 'aaAaaAaA', '12345'), translate(utf8_binary, 'aaAaaAaA', '12345') from t5 -- !query analysis @@ -1253,6 +1421,30 @@ Project [replace(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select replace(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 'abc') from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"replace(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), abc)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 84, + "fragment" : "replace(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 'abc')" + } ] +} + + -- !query select replace(utf8_binary, 'aaAaaAaA', 'abc'), replace(utf8_lcase, 'aaAaaAaA', 'abc') from t5 -- !query analysis @@ -1316,6 +1508,30 @@ Project [EndsWith(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8 +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select endswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"endswith(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 78, + "fragment" : "endswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select endswith(utf8_binary, 'aaAaaAaA'), endswith(utf8_lcase, 'aaAaaAaA') from t5 -- !query analysis @@ -2039,6 +2255,30 @@ Project [locate(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_l +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select locate(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 3) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"locate(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 3)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 79, + "fragment" : "locate(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 3)" + } ] +} + + -- !query select locate(utf8_binary, 'a'), locate(utf8_lcase, 'a') from t5 -- !query analysis @@ -2102,6 +2342,30 @@ Project [trim(collate(utf8_lcase#x, utf8_lcase), Some(collate(utf8_binary#x, utf +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select TRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_lcase, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(BOTH collate(utf8_binary, unicode_ai) FROM collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 74, + "fragment" : "TRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select TRIM('ABc', utf8_binary), TRIM('ABc', utf8_lcase) from t5 -- !query analysis @@ -2165,6 +2429,30 @@ Project [btrim(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lc +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select BTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(BOTH collate(utf8_lcase, unicode_ai) FROM collate(utf8_binary, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "BTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select BTRIM('ABc', utf8_binary), BTRIM('ABc', utf8_lcase) from t5 -- !query analysis @@ -2228,6 +2516,30 @@ Project [ltrim(collate(utf8_lcase#x, utf8_lcase), Some(collate(utf8_binary#x, ut +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select LTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_lcase, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(LEADING collate(utf8_binary, unicode_ai) FROM collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "LTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select LTRIM('ABc', utf8_binary), LTRIM('ABc', utf8_lcase) from t5 -- !query analysis @@ -2291,6 +2603,30 @@ Project [rtrim(collate(utf8_lcase#x, utf8_lcase), Some(collate(utf8_binary#x, ut +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select RTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_lcase, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(TRAILING collate(utf8_binary, unicode_ai) FROM collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "RTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select RTRIM('ABc', utf8_binary), RTRIM('ABc', utf8_lcase) from t5 -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/inputs/collations.sql b/sql/core/src/test/resources/sql-tests/inputs/collations.sql index 183577b83971..f3a42fd3e1f1 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/collations.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/collations.sql @@ -99,6 +99,7 @@ insert into t4 values('a:1,b:2,c:3', ',', ':'); select str_to_map(text, pairDelim, keyValueDelim) from t4; select str_to_map(text collate utf8_binary, pairDelim collate utf8_lcase, keyValueDelim collate utf8_binary) from t4; select str_to_map(text collate utf8_binary, pairDelim collate utf8_binary, keyValueDelim collate utf8_binary) from t4; +select str_to_map(text collate unicode_ai, pairDelim collate unicode_ai, keyValueDelim collate unicode_ai) from t4; drop table t4; @@ -159,6 +160,7 @@ select split_part(s, utf8_binary, 1) from t5; select split_part(utf8_binary collate utf8_binary, s collate utf8_lcase, 1) from t5; select split_part(utf8_binary, utf8_lcase collate utf8_binary, 2) from t5; select split_part(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 2) from t5; +select split_part(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2) from t5; select split_part(utf8_binary, 'a', 3), split_part(utf8_lcase, 'a', 3) from t5; select split_part(utf8_binary, 'a' collate utf8_lcase, 3), split_part(utf8_lcase, 'a' collate utf8_binary, 3) from t5; @@ -168,6 +170,7 @@ select contains(s, utf8_binary) from t5; select contains(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select contains(utf8_binary, utf8_lcase collate utf8_binary) from t5; select contains(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select contains(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; select contains(utf8_binary, 'a'), contains(utf8_lcase, 'a') from t5; select contains(utf8_binary, 'AaAA' collate utf8_lcase), contains(utf8_lcase, 'AAa' collate utf8_binary) from t5; @@ -177,6 +180,7 @@ select substring_index(s, utf8_binary,1) from t5; select substring_index(utf8_binary collate utf8_binary, s collate utf8_lcase, 3) from t5; select substring_index(utf8_binary, utf8_lcase collate utf8_binary, 2) from t5; select substring_index(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 2) from t5; +select substring_index(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2) from t5; select substring_index(utf8_binary, 'a', 2), substring_index(utf8_lcase, 'a', 2) from t5; select substring_index(utf8_binary, 'AaAA' collate utf8_lcase, 2), substring_index(utf8_lcase, 'AAa' collate utf8_binary, 2) from t5; @@ -186,6 +190,7 @@ select instr(s, utf8_binary) from t5; select instr(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select instr(utf8_binary, utf8_lcase collate utf8_binary) from t5; select instr(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select instr(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; select instr(utf8_binary, 'a'), instr(utf8_lcase, 'a') from t5; select instr(utf8_binary, 'AaAA' collate utf8_lcase), instr(utf8_lcase, 'AAa' collate utf8_binary) from t5; @@ -204,6 +209,7 @@ select startswith(s, utf8_binary) from t5; select startswith(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select startswith(utf8_binary, utf8_lcase collate utf8_binary) from t5; select startswith(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select startswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; select startswith(utf8_binary, 'aaAaaAaA'), startswith(utf8_lcase, 'aaAaaAaA') from t5; select startswith(utf8_binary, 'aaAaaAaA' collate utf8_lcase), startswith(utf8_lcase, 'aaAaaAaA' collate utf8_binary) from t5; @@ -212,6 +218,7 @@ select translate(utf8_lcase, utf8_lcase, '12345') from t5; select translate(utf8_binary, utf8_lcase, '12345') from t5; select translate(utf8_binary, 'aBc' collate utf8_lcase, '12345' collate utf8_binary) from t5; select translate(utf8_binary, 'SQL' collate utf8_lcase, '12345' collate utf8_lcase) from t5; +select translate(utf8_binary, 'SQL' collate unicode_ai, '12345' collate unicode_ai) from t5; select translate(utf8_lcase, 'aaAaaAaA', '12345'), translate(utf8_binary, 'aaAaaAaA', '12345') from t5; select translate(utf8_lcase, 'aBc' collate utf8_binary, '12345'), translate(utf8_binary, 'aBc' collate utf8_lcase, '12345') from t5; @@ -221,6 +228,7 @@ select replace(s, utf8_binary, 'abc') from t5; select replace(utf8_binary collate utf8_binary, s collate utf8_lcase, 'abc') from t5; select replace(utf8_binary, utf8_lcase collate utf8_binary, 'abc') from t5; select replace(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 'abc') from t5; +select replace(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 'abc') from t5; select replace(utf8_binary, 'aaAaaAaA', 'abc'), replace(utf8_lcase, 'aaAaaAaA', 'abc') from t5; select replace(utf8_binary, 'aaAaaAaA' collate utf8_lcase, 'abc'), replace(utf8_lcase, 'aaAaaAaA' collate utf8_binary, 'abc') from t5; @@ -230,6 +238,7 @@ select endswith(s, utf8_binary) from t5; select endswith(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select endswith(utf8_binary, utf8_lcase collate utf8_binary) from t5; select endswith(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select endswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; select endswith(utf8_binary, 'aaAaaAaA'), endswith(utf8_lcase, 'aaAaaAaA') from t5; select endswith(utf8_binary, 'aaAaaAaA' collate utf8_lcase), endswith(utf8_lcase, 'aaAaaAaA' collate utf8_binary) from t5; @@ -364,6 +373,7 @@ select locate(s, utf8_binary) from t5; select locate(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select locate(utf8_binary, utf8_lcase collate utf8_binary) from t5; select locate(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 3) from t5; +select locate(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 3) from t5; select locate(utf8_binary, 'a'), locate(utf8_lcase, 'a') from t5; select locate(utf8_binary, 'AaAA' collate utf8_lcase, 4), locate(utf8_lcase, 'AAa' collate utf8_binary, 4) from t5; @@ -373,6 +383,7 @@ select TRIM(s, utf8_binary) from t5; select TRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select TRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5; select TRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select TRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; select TRIM('ABc', utf8_binary), TRIM('ABc', utf8_lcase) from t5; select TRIM('ABc' collate utf8_lcase, utf8_binary), TRIM('AAa' collate utf8_binary, utf8_lcase) from t5; -- StringTrimBoth @@ -381,6 +392,7 @@ select BTRIM(s, utf8_binary) from t5; select BTRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select BTRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5; select BTRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select BTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; select BTRIM('ABc', utf8_binary), BTRIM('ABc', utf8_lcase) from t5; select BTRIM('ABc' collate utf8_lcase, utf8_binary), BTRIM('AAa' collate utf8_binary, utf8_lcase) from t5; -- StringTrimLeft @@ -389,6 +401,7 @@ select LTRIM(s, utf8_binary) from t5; select LTRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select LTRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5; select LTRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select LTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; select LTRIM('ABc', utf8_binary), LTRIM('ABc', utf8_lcase) from t5; select LTRIM('ABc' collate utf8_lcase, utf8_binary), LTRIM('AAa' collate utf8_binary, utf8_lcase) from t5; -- StringTrimRight @@ -397,6 +410,7 @@ select RTRIM(s, utf8_binary) from t5; select RTRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select RTRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5; select RTRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select RTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; select RTRIM('ABc', utf8_binary), RTRIM('ABc', utf8_lcase) from t5; select RTRIM('ABc' collate utf8_lcase, utf8_binary), RTRIM('AAa' collate utf8_binary, utf8_lcase) from t5; diff --git a/sql/core/src/test/resources/sql-tests/results/collations.sql.out b/sql/core/src/test/resources/sql-tests/results/collations.sql.out index ea5564aafe96..5999bf20f688 100644 --- a/sql/core/src/test/resources/sql-tests/results/collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/collations.sql.out @@ -480,6 +480,32 @@ struct +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(text, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"str_to_map(collate(text, unicode_ai), collate(pairDelim, unicode_ai), collate(keyValueDelim, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 106, + "fragment" : "str_to_map(text collate unicode_ai, pairDelim collate unicode_ai, keyValueDelim collate unicode_ai)" + } ] +} + + -- !query drop table t4 -- !query schema @@ -1021,6 +1047,32 @@ struct +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"split_part(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 2)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 83, + "fragment" : "split_part(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2)" + } ] +} + + -- !query select split_part(utf8_binary, 'a', 3), split_part(utf8_lcase, 'a', 3) from t5 -- !query schema @@ -1148,6 +1200,32 @@ true true +-- !query +select contains(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"contains(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 78, + "fragment" : "contains(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select contains(utf8_binary, 'a'), contains(utf8_lcase, 'a') from t5 -- !query schema @@ -1275,6 +1353,32 @@ kitten İo +-- !query +select substring_index(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"substring_index(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 2)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 88, + "fragment" : "substring_index(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2)" + } ] +} + + -- !query select substring_index(utf8_binary, 'a', 2), substring_index(utf8_lcase, 'a', 2) from t5 -- !query schema @@ -1402,6 +1506,32 @@ struct +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"instr(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "instr(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select instr(utf8_binary, 'a'), instr(utf8_lcase, 'a') from t5 -- !query schema @@ -1656,6 +1786,32 @@ true true +-- !query +select startswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"startswith(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 80, + "fragment" : "startswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select startswith(utf8_binary, 'aaAaaAaA'), startswith(utf8_lcase, 'aaAaaAaA') from t5 -- !query schema @@ -1763,6 +1919,32 @@ kitten İo +-- !query +select translate(utf8_binary, 'SQL' collate unicode_ai, '12345' collate unicode_ai) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"utf8_binary\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"translate(utf8_binary, collate(SQL, unicode_ai), collate(12345, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 83, + "fragment" : "translate(utf8_binary, 'SQL' collate unicode_ai, '12345' collate unicode_ai)" + } ] +} + + -- !query select translate(utf8_lcase, 'aaAaaAaA', '12345'), translate(utf8_binary, 'aaAaaAaA', '12345') from t5 -- !query schema @@ -1890,6 +2072,32 @@ bbabcbabcabcbabc kitten +-- !query +select replace(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 'abc') from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"replace(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), abc)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 84, + "fragment" : "replace(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 'abc')" + } ] +} + + -- !query select replace(utf8_binary, 'aaAaaAaA', 'abc'), replace(utf8_lcase, 'aaAaaAaA', 'abc') from t5 -- !query schema @@ -2017,6 +2225,32 @@ true true +-- !query +select endswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"endswith(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 78, + "fragment" : "endswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select endswith(utf8_binary, 'aaAaaAaA'), endswith(utf8_lcase, 'aaAaaAaA') from t5 -- !query schema @@ -3570,6 +3804,32 @@ struct +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"locate(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 3)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 79, + "fragment" : "locate(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 3)" + } ] +} + + -- !query select locate(utf8_binary, 'a'), locate(utf8_lcase, 'a') from t5 -- !query schema @@ -3685,6 +3945,32 @@ QL sitTing +-- !query +select TRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_lcase, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(BOTH collate(utf8_binary, unicode_ai) FROM collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 74, + "fragment" : "TRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select TRIM('ABc', utf8_binary), TRIM('ABc', utf8_lcase) from t5 -- !query schema @@ -3812,6 +4098,32 @@ park İ +-- !query +select BTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(BOTH collate(utf8_lcase, unicode_ai) FROM collate(utf8_binary, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "BTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select BTRIM('ABc', utf8_binary), BTRIM('ABc', utf8_lcase) from t5 -- !query schema @@ -3927,6 +4239,32 @@ QL sitTing +-- !query +select LTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_lcase, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(LEADING collate(utf8_binary, unicode_ai) FROM collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "LTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select LTRIM('ABc', utf8_binary), LTRIM('ABc', utf8_lcase) from t5 -- !query schema @@ -4042,6 +4380,32 @@ SQL sitTing +-- !query +select RTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_lcase, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(TRAILING collate(utf8_binary, unicode_ai) FROM collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "RTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select RTRIM('ABc', utf8_binary), RTRIM('ABc', utf8_lcase) from t5 -- !query schema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index f8cd840ecdbb..941d5cd31db4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -982,6 +982,7 @@ class CollationSQLExpressionsSuite StringToMapTestCase("1/AX2/BX3/C", "x", "/", "UNICODE_CI", Map("1" -> "A", "2" -> "B", "3" -> "C")) ) + val unsupportedTestCase = StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UNICODE_AI", null) testCases.foreach(t => { // Unit test. val text = Literal.create(t.text, StringType(t.collation)) @@ -996,6 +997,29 @@ class CollationSQLExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(dataType)) } }) + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select str_to_map('${unsupportedTestCase.text}', '${unsupportedTestCase.pairDelim}', " + + s"'${unsupportedTestCase.keyValueDelim}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> ("\"str_to_map('a:1,b:2,c:3' collate UNICODE_AI, " + + "'?' collate UNICODE_AI, '?' collate UNICODE_AI)\""), + "paramIndex" -> "first", + "inputSql" -> "\"'a:1,b:2,c:3' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext( + fragment = "str_to_map('a:1,b:2,c:3', '?', '?')", + start = 7, + stop = 41)) + } } test("Support RaiseError misc expression with collation") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 6804411d470b..fe9872ddaf57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -98,6 +98,7 @@ class CollationStringExpressionsSuite SplitPartTestCase("1a2", "A", 2, "UTF8_LCASE", "2"), SplitPartTestCase("1a2", "A", 2, "UNICODE_CI", "2") ) + val unsupportedTestCase = SplitPartTestCase("1a2", "a", 2, "UNICODE_AI", "2") testCases.foreach(t => { // Unit test. val str = Literal.create(t.str, StringType(t.collation)) @@ -111,6 +112,26 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collation))) } }) + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select split_part('${unsupportedTestCase.str}', '${unsupportedTestCase.delimiter}', " + + s"${unsupportedTestCase.partNum})" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"split_part('1a2' collate UNICODE_AI, 'a' collate UNICODE_AI, 2)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'1a2' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "split_part('1a2', 'a', 2)", start = 7, stop = 31) + ) + } } test("Support `StringSplitSQL` string expression with collation") { @@ -166,6 +187,7 @@ class CollationStringExpressionsSuite ContainsTestCase("abcde", "FGH", "UTF8_LCASE", false), ContainsTestCase("abcde", "BCD", "UNICODE_CI", true) ) + val unsupportedTestCase = ContainsTestCase("abcde", "A", "UNICODE_AI", false) testCases.foreach(t => { // Unit test. val left = Literal.create(t.left, StringType(t.collation)) @@ -178,6 +200,25 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(BooleanType)) } }) + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select contains('${unsupportedTestCase.left}', '${unsupportedTestCase.right}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"contains('abcde' collate UNICODE_AI, 'A' collate UNICODE_AI)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'abcde' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "contains('abcde', 'A')", start = 7, stop = 28) + ) + } } test("Support `SubstringIndex` expression with collation") { @@ -194,6 +235,7 @@ class CollationStringExpressionsSuite SubstringIndexTestCase("aaaaaaaaaa", "aa", 2, "UNICODE", "a"), SubstringIndexTestCase("wwwmapacheMorg", "M", -2, "UNICODE_CI", "apacheMorg") ) + val unsupportedTestCase = SubstringIndexTestCase("abacde", "a", 2, "UNICODE_AI", "cde") testCases.foreach(t => { // Unit test. val strExpr = Literal.create(t.strExpr, StringType(t.collation)) @@ -207,6 +249,29 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collation))) } }) + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select substring_index('${unsupportedTestCase.strExpr}', " + + s"'${unsupportedTestCase.delimExpr}', ${unsupportedTestCase.countExpr})" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> ("\"substring_index('abacde' collate UNICODE_AI, " + + "'a' collate UNICODE_AI, 2)\""), + "paramIndex" -> "first", + "inputSql" -> "\"'abacde' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext( + fragment = "substring_index('abacde', 'a', 2)", + start = 7, + stop = 39)) + } } test("Support `StringInStr` string expression with collation") { @@ -219,6 +284,7 @@ class CollationStringExpressionsSuite StringInStrTestCase("test大千世界X大千世界", "界x", "UNICODE_CI", 8), StringInStrTestCase("abİo12", "i̇o", "UNICODE_CI", 3) ) + val unsupportedTestCase = StringInStrTestCase("a", "abcde", "UNICODE_AI", 0) testCases.foreach(t => { // Unit test. val str = Literal.create(t.str, StringType(t.collation)) @@ -231,6 +297,25 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(IntegerType)) } }) + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select instr('${unsupportedTestCase.str}', '${unsupportedTestCase.substr}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"instr('a' collate UNICODE_AI, 'abcde' collate UNICODE_AI)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'a' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "instr('a', 'abcde')", start = 7, stop = 25) + ) + } } test("Support `FindInSet` string expression with collation") { @@ -264,6 +349,7 @@ class CollationStringExpressionsSuite StartsWithTestCase("abcde", "FGH", "UTF8_LCASE", false), StartsWithTestCase("abcde", "ABC", "UNICODE_CI", true) ) + val unsupportedTestCase = StartsWithTestCase("abcde", "A", "UNICODE_AI", false) testCases.foreach(t => { // Unit test. val left = Literal.create(t.left, StringType(t.collation)) @@ -276,6 +362,25 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(BooleanType)) } }) + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select startswith('${unsupportedTestCase.left}', '${unsupportedTestCase.right}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"startswith('abcde' collate UNICODE_AI, 'A' collate UNICODE_AI)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'abcde' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "startswith('abcde', 'A')", start = 7, stop = 30) + ) + } } test("Support `StringTranslate` string expression with collation") { @@ -291,6 +396,7 @@ class CollationStringExpressionsSuite StringTranslateTestCase("Translate", "Rn", "\u0000\u0000", "UNICODE", "Traslate"), StringTranslateTestCase("Translate", "Rn", "1234", "UNICODE_CI", "T1a2slate") ) + val unsupportedTestCase = StringTranslateTestCase("ABC", "AB", "12", "UNICODE_AI", "12C") testCases.foreach(t => { // Unit test. val srcExpr = Literal.create(t.srcExpr, StringType(t.collation)) @@ -304,6 +410,27 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collation))) } }) + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select translate('${unsupportedTestCase.srcExpr}', " + + s"'${unsupportedTestCase.matchingExpr}', '${unsupportedTestCase.replaceExpr}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> ("\"translate('ABC' collate UNICODE_AI, 'AB' collate UNICODE_AI, " + + "'12' collate UNICODE_AI)\""), + "paramIndex" -> "first", + "inputSql" -> "\"'ABC' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "translate('ABC', 'AB', '12')", start = 7, stop = 34) + ) + } } test("Support `StringReplace` string expression with collation") { @@ -321,6 +448,7 @@ class CollationStringExpressionsSuite StringReplaceTestCase("abi̇o12i̇o", "İo", "yy", "UNICODE_CI", "abyy12yy"), StringReplaceTestCase("abİo12i̇o", "i̇o", "xx", "UNICODE_CI", "abxx12xx") ) + val unsupportedTestCase = StringReplaceTestCase("abcde", "A", "B", "UNICODE_AI", "abcde") testCases.foreach(t => { // Unit test. val srcExpr = Literal.create(t.srcExpr, StringType(t.collation)) @@ -334,6 +462,27 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collation))) } }) + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select replace('${unsupportedTestCase.srcExpr}', '${unsupportedTestCase.searchExpr}', " + + s"'${unsupportedTestCase.replaceExpr}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> ("\"replace('abcde' collate UNICODE_AI, 'A' collate UNICODE_AI, " + + "'B' collate UNICODE_AI)\""), + "paramIndex" -> "first", + "inputSql" -> "\"'abcde' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "replace('abcde', 'A', 'B')", start = 7, stop = 32) + ) + } } test("Support `EndsWith` string expression with collation") { @@ -344,6 +493,7 @@ class CollationStringExpressionsSuite EndsWithTestCase("abcde", "FGH", "UTF8_LCASE", false), EndsWithTestCase("abcde", "CDE", "UNICODE_CI", true) ) + val unsupportedTestCase = EndsWithTestCase("abcde", "A", "UNICODE_AI", false) testCases.foreach(t => { // Unit test. val left = Literal.create(t.left, StringType(t.collation)) @@ -355,6 +505,25 @@ class CollationStringExpressionsSuite checkAnswer(sql(query), Row(t.result)) assert(sql(query).schema.fields.head.dataType.sameType(BooleanType)) } + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select endswith('${unsupportedTestCase.left}', '${unsupportedTestCase.right}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"endswith('abcde' collate UNICODE_AI, 'A' collate UNICODE_AI)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'abcde' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "endswith('abcde', 'A')", start = 7, stop = 28) + ) + } }) } @@ -1097,6 +1266,7 @@ class CollationStringExpressionsSuite StringLocateTestCase("aa", "Aaads", 0, "UNICODE_CI", 0), StringLocateTestCase("界x", "test大千世界X大千世界", 1, "UNICODE_CI", 8) ) + val unsupportedTestCase = StringLocateTestCase("aa", "Aaads", 0, "UNICODE_AI", 1) testCases.foreach(t => { // Unit test. val substr = Literal.create(t.substr, StringType(t.collation)) @@ -1110,6 +1280,26 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(IntegerType)) } }) + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select locate('${unsupportedTestCase.substr}', '${unsupportedTestCase.str}', " + + s"${unsupportedTestCase.start})" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"locate('aa' collate UNICODE_AI, 'Aaads' collate UNICODE_AI, 0)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'aa' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "locate('aa', 'Aaads', 0)", start = 7, stop = 30) + ) + } } test("Support `StringTrimLeft` string expression with collation") { @@ -1124,6 +1314,7 @@ class CollationStringExpressionsSuite StringTrimLeftTestCase("xxasdxx", Some("y"), "UNICODE", "xxasdxx"), StringTrimLeftTestCase(" asd ", None, "UNICODE_CI", "asd ") ) + val unsupportedTestCase = StringTrimLeftTestCase("xxasdxx", Some("x"), "UNICODE_AI", null) testCases.foreach(t => { // Unit test. val srcStr = Literal.create(t.srcStr, StringType(t.collation)) @@ -1137,6 +1328,25 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collation))) } }) + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val trimString = s"'${unsupportedTestCase.trimStr.get}', " + val query = s"select ltrim($trimString'${unsupportedTestCase.srcStr}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"TRIM(LEADING 'x' collate UNICODE_AI FROM 'xxasdxx' collate UNICODE_AI)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'xxasdxx' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "ltrim('x', 'xxasdxx')", start = 7, stop = 27) + ) + } } test("Support `StringTrimRight` string expression with collation") { @@ -1151,6 +1361,7 @@ class CollationStringExpressionsSuite StringTrimRightTestCase("xxasdxx", Some("y"), "UNICODE", "xxasdxx"), StringTrimRightTestCase(" asd ", None, "UNICODE_CI", " asd") ) + val unsupportedTestCase = StringTrimRightTestCase("xxasdxx", Some("x"), "UNICODE_AI", "xxasd") testCases.foreach(t => { // Unit test. val srcStr = Literal.create(t.srcStr, StringType(t.collation)) @@ -1164,6 +1375,26 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collation))) } }) + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val trimString = s"'${unsupportedTestCase.trimStr.get}', " + val query = s"select rtrim($trimString'${unsupportedTestCase.srcStr}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> ("\"TRIM(TRAILING 'x' collate UNICODE_AI FROM 'xxasdxx'" + + " collate UNICODE_AI)\""), + "paramIndex" -> "first", + "inputSql" -> "\"'xxasdxx' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "rtrim('x', 'xxasdxx')", start = 7, stop = 27) + ) + } } test("Support `StringTrim` string expression with collation") { @@ -1178,6 +1409,7 @@ class CollationStringExpressionsSuite StringTrimTestCase("xxasdxx", Some("y"), "UNICODE", "xxasdxx"), StringTrimTestCase(" asd ", None, "UNICODE_CI", "asd") ) + val unsupportedTestCase = StringTrimTestCase("xxasdxx", Some("x"), "UNICODE_AI", "asd") testCases.foreach(t => { // Unit test. val srcStr = Literal.create(t.srcStr, StringType(t.collation)) @@ -1191,6 +1423,25 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collation))) } }) + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val trimString = s"'${unsupportedTestCase.trimStr.get}', " + val query = s"select trim($trimString'${unsupportedTestCase.srcStr}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"TRIM(BOTH 'x' collate UNICODE_AI FROM 'xxasdxx' collate UNICODE_AI)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'xxasdxx' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "trim('x', 'xxasdxx')", start = 7, stop = 26) + ) + } } test("Support `StringTrimBoth` string expression with collation") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 0df32e4c0bb0..489a990d3e1c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql import scala.jdk.CollectionConverters.MapHasAsJava -import scala.util.Try import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ExtendedAnalysisException @@ -1626,63 +1625,6 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } - test("Expressions not supporting CS_AI collators") { - val unsupportedExpressions: Seq[Any] = Seq( - "ltrim", - "rtrim", - "trim", - "startswith", - "endswith", - "locate", - "instr", - "str_to_map", - "contains", - "replace", - ("translate", "efg"), - ("split_part", "2"), - ("substring_index", "2")) - - val unsupportedCollator = "unicode_ai" - val supportedCollators: Seq[String] = Seq("unicode", "unicode_ci", "unicode_ci_ai") - - unsupportedExpressions.foreach { - expression: Any => - val unsupportedQuery: String = { - expression match { - case expression: String => - s"select $expression('bcd' collate $unsupportedCollator, 'abc')" - case (expression: String, parameter: String) => - s"select $expression('bcd' collate $unsupportedCollator, 'abc', '$parameter')" - case (expression: String, parameter: Integer) => - s"select $expression('bcd' collate $unsupportedCollator, 'abc', $parameter)" - } - } - - val analysisException = intercept[AnalysisException] { - sql(unsupportedQuery).collect() - } - assert(analysisException.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") - - supportedCollators.foreach { - collator => - val supportedQuery = expression match { - case expression: String => - s"select $expression('bcd' collate $collator, 'abc')" - case (expression: String, parameter: String) => - s"select $expression('bcd' collate $collator, 'abc', '$parameter')" - case (expression: String, parameter: Integer) => - s"select $expression('bcd' collate $collator, 'abc', $parameter)" - } - - // Tests that expression works properly with non-cs_ai collation - val result = Try { - sql(supportedQuery).collect() - } - assert(result.isSuccess) - } - } - } - test("TVF collations()") { assert(sql("SELECT * FROM collations()").collect().length >= 562) From d8bf03f2497f28bdb40703463727b2479393e78d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vladan=20Vasi=C4=87?= Date: Thu, 19 Sep 2024 08:49:36 +0200 Subject: [PATCH 5/5] Refactored methods --- .../apache/spark/sql/catalyst/util/CollationFactory.java | 7 ++++--- .../spark/sql/internal/types/AbstractStringType.scala | 6 +++--- .../main/scala/org/apache/spark/sql/types/StringType.scala | 4 ++-- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 448df1b741cc..f7769d917d6c 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -924,11 +924,12 @@ public static int collationNameToId(String collationName) throws SparkException /** * Returns whether the ICU collation is not Case Sensitive Accent Insensitive * for the given collation id. + * This method is used in expressions which do not support CS_AI collations. */ - public static Boolean isNonCSAI(int collationId) { + public static boolean isCaseSensitiveAndAccentInsensitive(int collationId) { return Collation.CollationSpecICU.fromCollationId(collationId).caseSensitivity == - Collation.CollationSpecICU.CaseSensitivity.CI || - Collation.CollationSpecICU.fromCollationId(collationId).accentSensitivity != + Collation.CollationSpecICU.CaseSensitivity.CS && + Collation.CollationSpecICU.fromCollationId(collationId).accentSensitivity == Collation.CollationSpecICU.AccentSensitivity.AI; } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala index 21d82d83d5c2..dc4ee013fd18 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala @@ -53,10 +53,10 @@ case object StringTypeAnyCollation extends AbstractStringType { } /** - * Use StringTypeNonCSAICollation for expressions supporting all possible collation types - * except CS_AI collation types. + * Use StringTypeNonCSAICollation for expressions supporting all possible collation types except + * CS_AI collation types. */ case object StringTypeNonCSAICollation extends AbstractStringType { override private[sql] def acceptsType(other: DataType): Boolean = - other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isNonCSAICollation + other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isNonCSAI } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index c3dd64d9189e..c2dd6cec7ba7 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -44,8 +44,8 @@ class StringType private (val collationId: Int) extends AtomicType with Serializ private[sql] def supportsLowercaseEquality: Boolean = CollationFactory.fetchCollation(collationId).supportsLowercaseEquality - private[sql] def isNonCSAICollation: Boolean = - CollationFactory.isNonCSAI(collationId) + private[sql] def isNonCSAI: Boolean = + !CollationFactory.isCaseSensitiveAndAccentInsensitive(collationId) private[sql] def isUTF8BinaryCollation: Boolean = collationId == CollationFactory.UTF8_BINARY_COLLATION_ID