From 204e33884c1b67f3e5d518dd6ed955d6e96ab466 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Mon, 27 May 2024 22:32:56 +0200 Subject: [PATCH 01/13] Initial commit --- .../spark/unsafe/types/CollationSupportSuite.java | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index 7fc3c4e349c3b..7ffed6f65c779 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -21,6 +21,8 @@ import org.apache.spark.sql.catalyst.util.CollationSupport; import org.junit.jupiter.api.Test; +import java.util.Map; + import static org.junit.jupiter.api.Assertions.*; @@ -571,6 +573,19 @@ public void testStringInstr() throws SparkException { assertStringInstr("abi̇o12", "İo", "UNICODE_CI", 3); } + private void assertStringTranslate(final String source, final Map dict, + final String collationName, final String result) throws SparkException { + int collationId = CollationFactory.collationNameToId(collationName); + UTF8String str = UTF8String.fromString(source); + UTF8String res = UTF8String.fromString(result); + assertEquals(res, CollationSupport.StringTranslate.exec(str, dict, collationId)); + } + + @Test + public void testStringTranslate() throws SparkException { + assertStringTranslate("abc", Map.of("a", "A", "b", "B", "c", "C"), "UTF8_BINARY", "ABC"); + } + private void assertFindInSet(String word, String set, String collationName, Integer expected) throws SparkException { UTF8String w = UTF8String.fromString(word); From b366630369b98e39c4c858a7bad72547fb4aa0ea Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Tue, 28 May 2024 15:41:45 +0200 Subject: [PATCH 02/13] Fix StringTranslate --- .../util/CollationAwareUTF8String.java | 77 ++++++++++++------- .../sql/catalyst/util/CollationFactory.java | 14 +++- .../sql/catalyst/util/CollationSupport.java | 30 +------- .../unsafe/types/CollationSupportSuite.java | 13 ---- .../expressions/stringExpressions.scala | 14 ++-- .../sql/CollationStringExpressionsSuite.scala | 31 ++++++-- 6 files changed, 96 insertions(+), 83 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java index ee0d611d7e652..6e7fe3d6ed305 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java @@ -141,6 +141,28 @@ public static String toUpperCase(final String target, final int collationId) { return UCharacter.toUpperCase(locale, target); } + /** + * Convert the input string to lowercase using the ICU root locale rules. + * + * @param target the input string + * @return the lowercase string + */ + public static UTF8String toLowerCase(final UTF8String target) { + return UTF8String.fromString(toLowerCase(target.toString())); + } + public static String toLowerCase(final String target) { + return UCharacter.toLowerCase(target); + } + + /** + * Convert the input string to lowercase using the specified ICU collation rules. + * + * @param target the input string + * @return the lowercase string + */ + public static UTF8String toLowerCase(final UTF8String target, final int collationId) { + return UTF8String.fromString(toLowerCase(target.toString(), collationId)); + } public static String toLowerCase(final String target, final int collationId) { ULocale locale = CollationFactory.fetchCollation(collationId) .collator.getLocale(ULocale.ACTUAL_LOCALE); @@ -322,37 +344,40 @@ public static UTF8String lowercaseSubStringIndex(final UTF8String string, } } - public static Map getCollationAwareDict(UTF8String string, - Map dict, int collationId) { - String srcStr = string.toString(); - + private static Map getCollationAwareDict(final Map dict, + int collationId) { + // replace all the keys in the dict with collation keys Map collationAwareDict = new HashMap<>(); - for (String key : dict.keySet()) { - StringSearch stringSearch = - CollationFactory.getStringSearch(string, UTF8String.fromString(key), collationId); - - int pos = 0; - while ((pos = stringSearch.next()) != StringSearch.DONE) { - int codePoint = srcStr.codePointAt(pos); - int charCount = Character.charCount(codePoint); - String newKey = srcStr.substring(pos, pos + charCount); - - boolean exists = false; - for (String existingKey : collationAwareDict.keySet()) { - if (stringSearch.getCollator().compare(existingKey, newKey) == 0) { - collationAwareDict.put(newKey, collationAwareDict.get(existingKey)); - exists = true; - break; - } - } + for (Map.Entry entry : dict.entrySet()) { + String collationKey = CollationFactory.getCollationKey(entry.getKey(), collationId); + collationAwareDict.putIfAbsent(collationKey, entry.getValue()); + } + return collationAwareDict; + } - if (!exists) { - collationAwareDict.put(newKey, dict.get(key)); - } + private static String translate(final String input, final Map dict, + final int collationId) { + StringBuilder sb = new StringBuilder(); + int charCount = 0; + for (int k = 0; k < input.length(); k += charCount) { + int codePoint = input.codePointAt(k); + charCount = Character.charCount(codePoint); + String subStr = input.substring(k, k + charCount); + String collationKey = CollationFactory.getCollationKey(subStr, collationId); + String translated = dict.get(collationKey); + if (null == translated) { + sb.append(subStr); + } else if (!"\0".equals(translated)) { + sb.append(translated); } } + return sb.toString(); + } - return collationAwareDict; + public static UTF8String translate(final UTF8String input, final Map dict, + final int collationId) { + Map collationAwareDict = getCollationAwareDict(dict,collationId); + return UTF8String.fromString(translate(input.toString(), collationAwareDict, collationId)); } public static UTF8String lowercaseTrim( diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 0133c3feb611a..ba4a973b49bfd 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -372,11 +372,23 @@ public static UTF8String getCollationKey(UTF8String input, int collationId) { if (collation.supportsBinaryEquality) { return input; } else if (collation.supportsLowercaseEquality) { - return input.toLowerCase(); + return CollationAwareUTF8String.toLowerCase(input); } else { CollationKey collationKey = collation.collator.getCollationKey(input.toString()); return UTF8String.fromBytes(collationKey.toByteArray()); } } + public static String getCollationKey(String input, int collationId) { + Collation collation = fetchCollation(collationId); + if (collation.supportsBinaryEquality) { + return input; + } else if (collation.supportsLowercaseEquality) { + return CollationAwareUTF8String.toLowerCase(input); + } else { + CollationKey collationKey = collation.collator.getCollationKey(input); + return Arrays.toString(collationKey.toByteArray()); + } + } + } diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index bea3dc08b4489..0788967bdee12 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -483,10 +483,8 @@ public static UTF8String exec(final UTF8String source, Map dict, CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); if (collation.supportsBinaryEquality) { return execBinary(source, dict); - } else if (collation.supportsLowercaseEquality) { - return execLowercase(source, dict); } else { - return execICU(source, dict, collationId); + return execNonBinary(source, dict, collationId); } } public static String genCode(final String source, final String dict, final int collationId) { @@ -494,36 +492,16 @@ public static String genCode(final String source, final String dict, final int c String expr = "CollationSupport.EndsWith.exec"; if (collation.supportsBinaryEquality) { return String.format(expr + "Binary(%s, %s)", source, dict); - } else if (collation.supportsLowercaseEquality) { - return String.format(expr + "Lowercase(%s, %s)", source, dict); } else { - return String.format(expr + "ICU(%s, %s, %d)", source, dict, collationId); + return String.format(expr + "NonBinary(%s, %s, %d)", source, dict, collationId); } } public static UTF8String execBinary(final UTF8String source, Map dict) { return source.translate(dict); } - public static UTF8String execLowercase(final UTF8String source, Map dict) { - String srcStr = source.toString(); - StringBuilder sb = new StringBuilder(); - int charCount = 0; - for (int k = 0; k < srcStr.length(); k += charCount) { - int codePoint = srcStr.codePointAt(k); - charCount = Character.charCount(codePoint); - String subStr = srcStr.substring(k, k + charCount); - String translated = dict.get(subStr.toLowerCase()); - if (null == translated) { - sb.append(subStr); - } else if (!"\0".equals(translated)) { - sb.append(translated); - } - } - return UTF8String.fromString(sb.toString()); - } - public static UTF8String execICU(final UTF8String source, Map dict, + public static UTF8String execNonBinary(final UTF8String source, Map dict, final int collationId) { - return source.translate(CollationAwareUTF8String.getCollationAwareDict( - source, dict, collationId)); + return CollationAwareUTF8String.translate(source, dict, collationId); } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index 7ffed6f65c779..5e231b5ed94e5 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -573,19 +573,6 @@ public void testStringInstr() throws SparkException { assertStringInstr("abi̇o12", "İo", "UNICODE_CI", 3); } - private void assertStringTranslate(final String source, final Map dict, - final String collationName, final String result) throws SparkException { - int collationId = CollationFactory.collationNameToId(collationName); - UTF8String str = UTF8String.fromString(source); - UTF8String res = UTF8String.fromString(result); - assertEquals(res, CollationSupport.StringTranslate.exec(str, dict, collationId)); - } - - @Test - public void testStringTranslate() throws SparkException { - assertStringTranslate("abc", Map.of("a", "A", "b", "B", "c", "C"), "UTF8_BINARY", "ABC"); - } - private void assertFindInSet(String word, String set, String collationName, Integer expected) throws SparkException { UTF8String w = UTF8String.fromString(word); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 09ec501311ade..86f684b12b594 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LOWER} -import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, CollationSupport, GenericArrayData, TypeUtils} +import org.apache.spark.sql.catalyst.util.{ArrayData, CollationSupport, GenericArrayData, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation, StringTypeBinaryLcase} @@ -859,13 +859,9 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len: object StringTranslate { - def buildDict(matchingString: UTF8String, replaceString: UTF8String, collationId: Int) + def buildDict(matchingString: UTF8String, replaceString: UTF8String) : JMap[String, String] = { - val matching = if (CollationFactory.fetchCollation(collationId).supportsLowercaseEquality) { - matchingString.toString().toLowerCase() - } else { - matchingString.toString() - } + val matching = matchingString.toString() val replace = replaceString.toString() val dict = new HashMap[String, String]() @@ -923,7 +919,7 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac if (matchingEval != lastMatching || replaceEval != lastReplace) { lastMatching = matchingEval.asInstanceOf[UTF8String].clone() lastReplace = replaceEval.asInstanceOf[UTF8String].clone() - dict = StringTranslate.buildDict(lastMatching, lastReplace, collationId) + dict = StringTranslate.buildDict(lastMatching, lastReplace) } CollationSupport.StringTranslate.exec(srcEval.asInstanceOf[UTF8String], dict, collationId) @@ -947,7 +943,7 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac $termLastMatching = $matching.clone(); $termLastReplace = $replace.clone(); $termDict = org.apache.spark.sql.catalyst.expressions.StringTranslate - .buildDict($termLastMatching, $termLastReplace, $collationId); + .buildDict($termLastMatching, $termLastReplace); } ${ev.value} = CollationSupport.StringTranslate. exec($src, $termDict, $collationId); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 9cc123b708aff..3e068eaae4a60 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -250,7 +250,8 @@ class CollationStringExpressionsSuite } assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") } - test("TRANSLATE check result on explicitly collated string") { + + test("Support StringTranslate string expression with collation") { // Supported collations case class TranslateTestCase[R](input: String, matchExpression: String, replaceExpression: String, collation: String, result: R) @@ -260,33 +261,27 @@ class CollationStringExpressionsSuite TranslateTestCase("TRanslate", "rnlt", "XxXx", "UTF8_BINARY_LCASE", "xXaxsXaxe"), TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UTF8_BINARY_LCASE", "xxaxsXaxex"), TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UTF8_BINARY_LCASE", "xXaxsXaxeX"), - // scalastyle:off TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UTF8_BINARY_LCASE", "test大千世AB大千世A"), TranslateTestCase("大千世界test大千世界", "TEST", "abcd", "UTF8_BINARY_LCASE", "大千世界abca大千世界"), TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UTF8_BINARY_LCASE", "oeso大千世界大千世界"), TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UTF8_BINARY_LCASE", "大千世界大千世界OesO"), TranslateTestCase("大千世界大千世界tesT", "大千", "世世", "UTF8_BINARY_LCASE", "世世世界世世世界tesT"), - // scalastyle:on TranslateTestCase("Translate", "Rnlt", "1234", "UNICODE", "Tra2s3a4e"), TranslateTestCase("TRanslate", "rnlt", "XxXx", "UNICODE", "TRaxsXaxe"), TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UNICODE", "TxaxsXaxeX"), TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UNICODE", "TXaxsXaxex"), - // scalastyle:off TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UNICODE", "test大千世AX大千世A"), TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UNICODE", "Oeso大千世界大千世界"), TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UNICODE", "大千世界大千世界oesO"), - // scalastyle:on TranslateTestCase("Translate", "Rnlt", "1234", "UNICODE_CI", "41a2s3a4e"), TranslateTestCase("TRanslate", "rnlt", "XxXx", "UNICODE_CI", "xXaxsXaxe"), TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UNICODE_CI", "xxaxsXaxex"), TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UNICODE_CI", "xXaxsXaxeX"), - // scalastyle:off TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UNICODE_CI", "test大千世AB大千世A"), TranslateTestCase("大千世界test大千世界", "TEST", "abcd", "UNICODE_CI", "大千世界abca大千世界"), TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UNICODE_CI", "oeso大千世界大千世界"), TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UNICODE_CI", "大千世界大千世界OesO"), TranslateTestCase("大千世界大千世界tesT", "大千", "世世", "UNICODE_CI", "世世世界世世世界tesT"), - // scalastyle:on TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UTF8_BINARY_LCASE", "14234e"), TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UNICODE_CI", "14234e"), TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UNICODE", "Tr4234e"), @@ -298,7 +293,20 @@ class CollationStringExpressionsSuite TranslateTestCase("abcdef", "abcde", "123", "UTF8_BINARY", "123f"), TranslateTestCase("abcdef", "abcde", "123", "UTF8_BINARY_LCASE", "123f"), TranslateTestCase("abcdef", "abcde", "123", "UNICODE", "123f"), - TranslateTestCase("abcdef", "abcde", "123", "UNICODE_CI", "123f") + TranslateTestCase("abcdef", "abcde", "123", "UNICODE_CI", "123f"), + // Case mapping edge cases + TranslateTestCase("İi\u0307", "İi\u0307", "123", "UTF8_BINARY", "123"), + TranslateTestCase("İi\u0307", "İyz", "123", "UTF8_BINARY", "1i\u0307"), + TranslateTestCase("İi\u0307", "xi\u0307", "123", "UTF8_BINARY", "İ23"), + TranslateTestCase("İi\u0307", "İi\u0307", "123", "UTF8_BINARY_LCASE", "123"), + TranslateTestCase("İi\u0307", "İyz", "123", "UTF8_BINARY_LCASE", "1i\u0307"), + TranslateTestCase("İi\u0307", "xi\u0307", "123", "UTF8_BINARY_LCASE", "İ23"), + TranslateTestCase("İi\u0307", "İi\u0307", "123", "UNICODE", "123"), + TranslateTestCase("İi\u0307", "İyz", "123", "UNICODE", "1i\u0307"), + TranslateTestCase("İi\u0307", "xi\u0307", "123", "UNICODE", "İ23"), + TranslateTestCase("İi\u0307", "İi\u0307", "123", "UNICODE_CI", "123"), + TranslateTestCase("İi\u0307", "İyz", "123", "UNICODE_CI", "1i\u0307"), + TranslateTestCase("İi\u0307", "xi\u0307", "123", "UNICODE_CI", "İ23") ) testCases.foreach(t => { @@ -325,6 +333,13 @@ class CollationStringExpressionsSuite assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") } + test("Handling invalid UTF-8 sequences in StringTranslate") { + Seq("UTF8_BINARY", "UTF8_BINARY_LCASE", "UNICODE", "UNICODE_CI").foreach { collation => + val query = s"SELECT translate(cast(unhex('C22C41') as string collate $collation), ',', 'X')" + checkAnswer(sql(query), Row("�XA")) + } + } + test("Support Replace string expression with collation") { case class ReplaceTestCase[R](source: String, search: String, replace: String, c: String, result: R) From cee88c571e60a0603417749eea89eff41e819a2f Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Tue, 28 May 2024 18:03:12 +0200 Subject: [PATCH 03/13] Fix Java lint --- .../org/apache/spark/unsafe/types/CollationSupportSuite.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index 5e231b5ed94e5..7fc3c4e349c3b 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -21,8 +21,6 @@ import org.apache.spark.sql.catalyst.util.CollationSupport; import org.junit.jupiter.api.Test; -import java.util.Map; - import static org.junit.jupiter.api.Assertions.*; From f436ade415a61c368daad1bef73011f849b7562c Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Fri, 31 May 2024 15:52:40 +0200 Subject: [PATCH 04/13] Fix LCASE implementation --- .../util/CollationAwareUTF8String.java | 47 ++++++++++++++++++- .../sql/catalyst/util/CollationFactory.java | 2 +- .../sql/catalyst/util/CollationSupport.java | 15 ++++-- 3 files changed, 58 insertions(+), 6 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java index da83175264e1f..6ba3c97ccb618 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java @@ -318,6 +318,22 @@ public static String toLowerCase(final String target, final int collationId) { return UCharacter.toLowerCase(locale, target); } + private final static int COMBINED_LOWERCASE_I_DOT = 0x69 << 16 | 0x307; + private static int getLowercaseCodePoint(final int codePoint) { + if (codePoint == 0x0130) { + // Latin capital letter I with dot above is mapped to 2 lowercase characters. + return COMBINED_LOWERCASE_I_DOT; + } + else if (codePoint == 0x03C2) { + // Greek final and non-final capital letter sigma should be mapped the same. + return 0x03C3; + } + else { + // All other characters should follow context-unaware ICU single-code point case mapping. + return UCharacter.toLowerCase(codePoint); + } + } + public static String toTitleCase(final String target, final int collationId) { ULocale locale = CollationFactory.fetchCollation(collationId) .collator.getLocale(ULocale.ACTUAL_LOCALE); @@ -490,9 +506,18 @@ public static UTF8String lowercaseSubStringIndex(final UTF8String string, } } + private static Map getLowercaseDict(final Map dict) { + // Replace all the keys in the dict with lowercased code points. + Map lowercaseDict = new HashMap<>(); + for (Map.Entry entry : dict.entrySet()) { + int codePoint = entry.getKey().codePointAt(0); + lowercaseDict.putIfAbsent(getLowercaseCodePoint(codePoint), entry.getValue()); + } + return lowercaseDict; + } private static Map getCollationAwareDict(final Map dict, int collationId) { - // replace all the keys in the dict with collation keys + // Replace all the keys in the dict with collation keys. Map collationAwareDict = new HashMap<>(); for (Map.Entry entry : dict.entrySet()) { String collationKey = CollationFactory.getCollationKey(entry.getKey(), collationId); @@ -501,6 +526,21 @@ private static Map getCollationAwareDict(final Map dict) { + StringBuilder sb = new StringBuilder(); + int charCount = 0; + for (int k = 0; k < input.length(); k += charCount) { + int codePoint = input.codePointAt(k); + charCount = Character.charCount(codePoint); + String translated = dict.get(getLowercaseCodePoint(codePoint)); + if (null == translated) { + sb.appendCodePoint(codePoint); + } else if (!"\0".equals(translated)) { + sb.append(translated); + } + } + return sb.toString(); + } private static String translate(final String input, final Map dict, final int collationId) { StringBuilder sb = new StringBuilder(); @@ -520,6 +560,11 @@ private static String translate(final String input, final Map di return sb.toString(); } + public static UTF8String lowercaseTranslate(final UTF8String input, + final Map dict) { + Map lowercaseDict = getLowercaseDict(dict); + return UTF8String.fromString(lowercaseTranslate(input.toString(), lowercaseDict)); + } public static UTF8String translate(final UTF8String input, final Map dict, final int collationId) { Map collationAwareDict = getCollationAwareDict(dict,collationId); diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 76ca73029c4b6..6c84e877c2b90 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -810,7 +810,7 @@ public static String getCollationKey(String input, int collationId) { if (collation.supportsBinaryEquality) { return input; } else if (collation.supportsLowercaseEquality) { - return CollationAwareUTF8String.toLowerCase(input); + return input.toLowerCase(); } else { CollationKey collationKey = collation.collator.getCollationKey(input); return Arrays.toString(collationKey.toByteArray()); diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index ed2f157092fd2..d040123dd125b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -483,8 +483,10 @@ public static UTF8String exec(final UTF8String source, Map dict, CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); if (collation.supportsBinaryEquality) { return execBinary(source, dict); + } else if (collation.supportsLowercaseEquality) { + return execLowercase(source, dict); } else { - return execNonBinary(source, dict, collationId); + return execICU(source, dict, collationId); } } public static String genCode(final String source, final String dict, final int collationId) { @@ -492,14 +494,19 @@ public static String genCode(final String source, final String dict, final int c String expr = "CollationSupport.EndsWith.exec"; if (collation.supportsBinaryEquality) { return String.format(expr + "Binary(%s, %s)", source, dict); - } else { - return String.format(expr + "NonBinary(%s, %s, %d)", source, dict, collationId); + } else if (collation.supportsLowercaseEquality) { + return String.format(expr + "Lowercase(%s, %s)", source, dict); + } else { + return String.format(expr + "ICU(%s, %s, %d)", source, dict, collationId); } } public static UTF8String execBinary(final UTF8String source, Map dict) { return source.translate(dict); } - public static UTF8String execNonBinary(final UTF8String source, Map dict, + public static UTF8String execLowercase(final UTF8String source, Map dict) { + return CollationAwareUTF8String.lowercaseTranslate(source, dict); + } + public static UTF8String execICU(final UTF8String source, Map dict, final int collationId) { return CollationAwareUTF8String.translate(source, dict, collationId); } From 659b1ebf755400040c0d6553548d1f59720e4511 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Fri, 31 May 2024 17:48:22 +0200 Subject: [PATCH 05/13] Fix lint --- .../spark/sql/catalyst/util/CollationAwareUTF8String.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java index 6ba3c97ccb618..1c5581da09339 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java @@ -318,7 +318,7 @@ public static String toLowerCase(final String target, final int collationId) { return UCharacter.toLowerCase(locale, target); } - private final static int COMBINED_LOWERCASE_I_DOT = 0x69 << 16 | 0x307; + private static final int COMBINED_LOWERCASE_I_DOT = 0x69 << 16 | 0x307; private static int getLowercaseCodePoint(final int codePoint) { if (codePoint == 0x0130) { // Latin capital letter I with dot above is mapped to 2 lowercase characters. From a70b651603d17c1f6a41d8a9a51667fc3ce90aef Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Mon, 10 Jun 2024 16:21:53 +0200 Subject: [PATCH 06/13] Add tests --- .../spark/sql/CollationStringExpressionsSuite.scala | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 3e068eaae4a60..bb73cfebc1921 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -295,9 +295,11 @@ class CollationStringExpressionsSuite TranslateTestCase("abcdef", "abcde", "123", "UNICODE", "123f"), TranslateTestCase("abcdef", "abcde", "123", "UNICODE_CI", "123f"), // Case mapping edge cases + TranslateTestCase("İ", "i\u0307", "xy", "UTF8_BINARY", "İ"), TranslateTestCase("İi\u0307", "İi\u0307", "123", "UTF8_BINARY", "123"), TranslateTestCase("İi\u0307", "İyz", "123", "UTF8_BINARY", "1i\u0307"), TranslateTestCase("İi\u0307", "xi\u0307", "123", "UTF8_BINARY", "İ23"), + TranslateTestCase("İ", "i\u0307", "xy", "UTF8_BINARY_LCASE", "İ"), TranslateTestCase("İi\u0307", "İi\u0307", "123", "UTF8_BINARY_LCASE", "123"), TranslateTestCase("İi\u0307", "İyz", "123", "UTF8_BINARY_LCASE", "1i\u0307"), TranslateTestCase("İi\u0307", "xi\u0307", "123", "UTF8_BINARY_LCASE", "İ23"), @@ -333,13 +335,6 @@ class CollationStringExpressionsSuite assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") } - test("Handling invalid UTF-8 sequences in StringTranslate") { - Seq("UTF8_BINARY", "UTF8_BINARY_LCASE", "UNICODE", "UNICODE_CI").foreach { collation => - val query = s"SELECT translate(cast(unhex('C22C41') as string collate $collation), ',', 'X')" - checkAnswer(sql(query), Row("�XA")) - } - } - test("Support Replace string expression with collation") { case class ReplaceTestCase[R](source: String, search: String, replace: String, c: String, result: R) From 542fee3f889403a1ae72ad352dc567bf5b46c913 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Thu, 4 Jul 2024 15:24:49 +0200 Subject: [PATCH 07/13] Update CollationStringExpressionsSuite.scala --- .../sql/CollationStringExpressionsSuite.scala | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index ed0558c74b8c6..8d3586c7bae81 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -258,16 +258,16 @@ class CollationStringExpressionsSuite case class TranslateTestCase[R](input: String, matchExpression: String, replaceExpression: String, collation: String, result: R) val testCases = Seq( - TranslateTestCase("Translate", "Rnlt", "1234", "UTF8_BINARY_LCASE", "41a2s3a4e"), - TranslateTestCase("Translate", "Rnlt", "1234", "UTF8_BINARY_LCASE", "41a2s3a4e"), - TranslateTestCase("TRanslate", "rnlt", "XxXx", "UTF8_BINARY_LCASE", "xXaxsXaxe"), - TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UTF8_BINARY_LCASE", "xxaxsXaxex"), - TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UTF8_BINARY_LCASE", "xXaxsXaxeX"), - TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UTF8_BINARY_LCASE", "test大千世AB大千世A"), - TranslateTestCase("大千世界test大千世界", "TEST", "abcd", "UTF8_BINARY_LCASE", "大千世界abca大千世界"), - TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UTF8_BINARY_LCASE", "oeso大千世界大千世界"), - TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UTF8_BINARY_LCASE", "大千世界大千世界OesO"), - TranslateTestCase("大千世界大千世界tesT", "大千", "世世", "UTF8_BINARY_LCASE", "世世世界世世世界tesT"), + TranslateTestCase("Translate", "Rnlt", "1234", "UTF8_LCASE", "41a2s3a4e"), + TranslateTestCase("Translate", "Rnlt", "1234", "UTF8_LCASE", "41a2s3a4e"), + TranslateTestCase("TRanslate", "rnlt", "XxXx", "UTF8_LCASE", "xXaxsXaxe"), + TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UTF8_LCASE", "xxaxsXaxex"), + TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UTF8_LCASE", "xXaxsXaxeX"), + TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UTF8_LCASE", "test大千世AB大千世A"), + TranslateTestCase("大千世界test大千世界", "TEST", "abcd", "UTF8_LCASE", "大千世界abca大千世界"), + TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UTF8_LCASE", "oeso大千世界大千世界"), + TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UTF8_LCASE", "大千世界大千世界OesO"), + TranslateTestCase("大千世界大千世界tesT", "大千", "世世", "UTF8_LCASE", "世世世界世世世界tesT"), TranslateTestCase("Translate", "Rnlt", "1234", "UNICODE", "Tra2s3a4e"), TranslateTestCase("TRanslate", "rnlt", "XxXx", "UNICODE", "TRaxsXaxe"), TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UNICODE", "TxaxsXaxeX"), @@ -284,16 +284,16 @@ class CollationStringExpressionsSuite TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UNICODE_CI", "oeso大千世界大千世界"), TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UNICODE_CI", "大千世界大千世界OesO"), TranslateTestCase("大千世界大千世界tesT", "大千", "世世", "UNICODE_CI", "世世世界世世世界tesT"), - TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UTF8_BINARY_LCASE", "14234e"), + TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UTF8_LCASE", "14234e"), TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UNICODE_CI", "14234e"), TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UNICODE", "Tr4234e"), TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UTF8_BINARY", "Tr4234e"), - TranslateTestCase("Translate", "Rnlt", "123495834634", "UTF8_BINARY_LCASE", "41a2s3a4e"), + TranslateTestCase("Translate", "Rnlt", "123495834634", "UTF8_LCASE", "41a2s3a4e"), TranslateTestCase("Translate", "Rnlt", "123495834634", "UNICODE", "Tra2s3a4e"), TranslateTestCase("Translate", "Rnlt", "123495834634", "UNICODE_CI", "41a2s3a4e"), TranslateTestCase("Translate", "Rnlt", "123495834634", "UTF8_BINARY", "Tra2s3a4e"), TranslateTestCase("abcdef", "abcde", "123", "UTF8_BINARY", "123f"), - TranslateTestCase("abcdef", "abcde", "123", "UTF8_BINARY_LCASE", "123f"), + TranslateTestCase("abcdef", "abcde", "123", "UTF8_LCASE", "123f"), TranslateTestCase("abcdef", "abcde", "123", "UNICODE", "123f"), TranslateTestCase("abcdef", "abcde", "123", "UNICODE_CI", "123f"), // Case mapping edge cases @@ -301,10 +301,10 @@ class CollationStringExpressionsSuite TranslateTestCase("İi\u0307", "İi\u0307", "123", "UTF8_BINARY", "123"), TranslateTestCase("İi\u0307", "İyz", "123", "UTF8_BINARY", "1i\u0307"), TranslateTestCase("İi\u0307", "xi\u0307", "123", "UTF8_BINARY", "İ23"), - TranslateTestCase("İ", "i\u0307", "xy", "UTF8_BINARY_LCASE", "İ"), - TranslateTestCase("İi\u0307", "İi\u0307", "123", "UTF8_BINARY_LCASE", "123"), - TranslateTestCase("İi\u0307", "İyz", "123", "UTF8_BINARY_LCASE", "1i\u0307"), - TranslateTestCase("İi\u0307", "xi\u0307", "123", "UTF8_BINARY_LCASE", "İ23"), + TranslateTestCase("İ", "i\u0307", "xy", "UTF8_LCASE", "İ"), + TranslateTestCase("İi\u0307", "İi\u0307", "123", "UTF8_LCASE", "123"), + TranslateTestCase("İi\u0307", "İyz", "123", "UTF8_LCASE", "1i\u0307"), + TranslateTestCase("İi\u0307", "xi\u0307", "123", "UTF8_LCASE", "İ23"), TranslateTestCase("İi\u0307", "İi\u0307", "123", "UNICODE", "123"), TranslateTestCase("İi\u0307", "İyz", "123", "UNICODE", "1i\u0307"), TranslateTestCase("İi\u0307", "xi\u0307", "123", "UNICODE", "İ23"), @@ -331,7 +331,7 @@ class CollationStringExpressionsSuite }) // Collation mismatch val collationMismatch = intercept[AnalysisException] { - sql(s"SELECT translate(collate('Translate', 'UTF8_BINARY_LCASE')," + + sql(s"SELECT translate(collate('Translate', 'UTF8_LCASE')," + s"collate('Rnlt', 'UNICODE'), '1234')") } assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") From be50416bff2ca90f15f65d11e6940b299c23fdc6 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Fri, 5 Jul 2024 17:45:45 +0200 Subject: [PATCH 08/13] Refactor translate --- .../util/CollationAwareUTF8String.java | 167 ++++++++++++------ .../expressions/stringExpressions.scala | 37 +++- .../sql/CollationStringExpressionsSuite.scala | 28 ++- 3 files changed, 160 insertions(+), 72 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java index 073b661be85e7..7a5018d67c5a1 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java @@ -18,6 +18,8 @@ import com.ibm.icu.lang.UCharacter; import com.ibm.icu.text.BreakIterator; +import com.ibm.icu.text.Collator; +import com.ibm.icu.text.RuleBasedCollator; import com.ibm.icu.text.StringSearch; import com.ibm.icu.util.ULocale; @@ -27,6 +29,8 @@ import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET; import static org.apache.spark.unsafe.Platform.copyMemory; +import java.text.CharacterIterator; +import java.text.StringCharacterIterator; import java.util.HashMap; import java.util.Map; @@ -424,28 +428,40 @@ private static UTF8String toLowerCaseSlow(final UTF8String target, final int col * @param codePoint The code point to convert to lowercase. * @param sb The StringBuilder to append the lowercase character to. */ - private static void lowercaseCodePoint(final int codePoint, final StringBuilder sb) { - if (codePoint == 0x0130) { + private static void appendLowercaseCodePoint(final int codePoint, final StringBuilder sb) { + int lowercaseCodePoint = getLowercaseCodePoint(codePoint); + if (lowercaseCodePoint == CODE_POINT_COMBINED_LOWERCASE_I_DOT) { // Latin capital letter I with dot above is mapped to 2 lowercase characters. sb.appendCodePoint(0x0069); sb.appendCodePoint(0x0307); - } - else if (codePoint == 0x03C2) { - // Greek final and non-final capital letter sigma should be mapped the same. - sb.appendCodePoint(0x03C3); - } - else { + } else { // All other characters should follow context-unaware ICU single-code point case mapping. - sb.appendCodePoint(UCharacter.toLowerCase(codePoint)); + sb.appendCodePoint(lowercaseCodePoint); } } - private static final int COMBINED_LOWERCASE_I_DOT = 0x69 << 16 | 0x307; + /** + * `CODE_POINT_COMBINED_LOWERCASE_I_DOT` is an internal representation of the combined lowercase + * code point for ASCII lowercase letter i with an additional combining dot character (U+0307). + * This integer value is not a valid code point itself, but rather an artificial code point + * marker used to represent the two lowercase characters that are the result of converting the + * uppercase Turkish dotted letter I with a combining dot character (U+0130) to lowercase. + */ + private static final int CODE_POINT_LOWERCASE_I = 0x69; + private static final int CODE_POINT_COMBINING_DOT = 0x307; + private static final int CODE_POINT_COMBINED_LOWERCASE_I_DOT = + CODE_POINT_LOWERCASE_I << 16 | CODE_POINT_COMBINING_DOT; + /** + * Returns the lowercase version of the provided code point, with special handling for + * one-to-many case mappings (i.e. characters that map to multiple characters in lowercase) and + * context-insensitive case mappings (i.e. characters that map to different characters based on + * the position in the string relative to other characters in lowercase). + */ private static int getLowercaseCodePoint(final int codePoint) { if (codePoint == 0x0130) { // Latin capital letter I with dot above is mapped to 2 lowercase characters. - return COMBINED_LOWERCASE_I_DOT; + return CODE_POINT_COMBINED_LOWERCASE_I_DOT; } else if (codePoint == 0x03C2) { // Greek final and non-final capital letter sigma should be mapped the same. @@ -461,7 +477,7 @@ else if (codePoint == 0x03C2) { * Converts an entire string to lowercase using ICU rules, code point by code point, with * special handling for one-to-many case mappings (i.e. characters that map to multiple * characters in lowercase). Also, this method omits information about context-sensitive case - * mappings using special handling in the `lowercaseCodePoint` method. + * mappings using special handling in the `appendLowercaseCodePoint` method. * * @param target The target string to convert to lowercase. * @return The string converted to lowercase in a context-unaware manner. @@ -475,7 +491,7 @@ private static UTF8String lowerCaseCodePointsSlow(final UTF8String target) { String targetString = target.toValidString(); StringBuilder sb = new StringBuilder(); for (int i = 0; i < targetString.length(); ++i) { - lowercaseCodePoint(targetString.codePointAt(i), sb); + appendLowercaseCodePoint(targetString.codePointAt(i), sb); } return UTF8String.fromString(sb.toString()); } @@ -672,6 +688,17 @@ public static UTF8String lowercaseSubStringIndex(final UTF8String string, } } + /** + * Converts the original translation dictionary (`dict`) to a dictionary with lowercased keys. + * This method is used to create a dictionary that can be used for the UTF8_LCASE collation. + * Note that `StringTranslate.buildDict` will ensure that all strings are validated properly. + * + * The method returns a map with lowercased code points as keys, while the values remain + * unchanged. Note that `dict` is constructed on a character by character basis, and the + * original keys are stored as strings. Keys in the resulting lowercase dictionary are stored + * as integers, which correspond only to single characters from the original `dict`. Also, + * there is special handling for the Turkish dotted uppercase letter I (U+0130). + */ private static Map getLowercaseDict(final Map dict) { // Replace all the keys in the dict with lowercased code points. Map lowercaseDict = new HashMap<>(); @@ -680,63 +707,87 @@ private static Map getLowercaseDict(final Map d lowercaseDict.putIfAbsent(getLowercaseCodePoint(codePoint), entry.getValue()); } return lowercaseDict; - // TODO(SPARK-48715): All UTF8String -> String conversions should use `makeValid` - } - private static Map getCollationAwareDict(final Map dict, - int collationId) { - // Replace all the keys in the dict with collation keys. - Map collationAwareDict = new HashMap<>(); - for (Map.Entry entry : dict.entrySet()) { - String collationKey = CollationFactory.getCollationKey(entry.getKey(), collationId); - collationAwareDict.putIfAbsent(collationKey, entry.getValue()); - } - return collationAwareDict; - // TODO(SPARK-48715): All UTF8String -> String conversions should use `makeValid` } - private static String lowercaseTranslate(final String input, final Map dict) { + /** + * Translates the `input` string using the translation map `dict`, for UTF8_LCASE collation. + * String translation is performed by iterating over the input string, from left to right, and + * repeatedly translating the longest possible substring that matches a key in the dictionary. + * For UTF8_LCASE, the method uses the lowercased substring to perform the lookup in the + * lowercase version of the translation map. + * + * @param input the string to be translated + * @param dict the lowercase translation dictionary + * @return the translated string + */ + public static UTF8String lowercaseTranslate(final UTF8String input, + final Map dict) { + Map lowercaseDict = getLowercaseDict(dict); StringBuilder sb = new StringBuilder(); - int charCount = 0; - for (int k = 0; k < input.length(); k += charCount) { - int codePoint = input.codePointAt(k); - charCount = Character.charCount(codePoint); - String translated = dict.get(getLowercaseCodePoint(codePoint)); - if (null == translated) { + for (int charIndex = 0; charIndex < input.numChars(); ++charIndex) { + int codePoint = input.getChar(charIndex); + if (lowercaseDict.containsKey(CODE_POINT_COMBINED_LOWERCASE_I_DOT) && + codePoint == CODE_POINT_LOWERCASE_I && charIndex + 1 < input.numChars() && + input.getChar(charIndex + 1) == CODE_POINT_COMBINING_DOT) { + // Special handling for letter i (U+0069) followed by a combining dot (U+0307) + codePoint = CODE_POINT_COMBINED_LOWERCASE_I_DOT; + ++charIndex; + } + String translated = lowercaseDict.get(getLowercaseCodePoint(codePoint)); + if (translated == null) { sb.appendCodePoint(codePoint); } else if (!"\0".equals(translated)) { sb.append(translated); } } - return sb.toString(); + return UTF8String.fromString(sb.toString()); } - private static String translate(final String input, final Map dict, - final int collationId) { + + /** + * Translates the `input` string using the translation map `dict`, for all ICU collations. + * String translation is performed by iterating over the input string, from left to right, and + * repeatedly translating the longest possible substring that matches a key in the dictionary. + * For ICU collations, the method uses the collation key of the substring to perform the lookup + * in the collation aware version of the translation map. + * + * @param input the string to be translated + * @param dict the collation aware translation dictionary + * @param collationId the collation ID to use for string translation + * @return the translated string + */ + public static UTF8String translate(final UTF8String input, + final Map dict, final int collationId) { + String inputString = input.toValidString(); + CharacterIterator target = new StringCharacterIterator(inputString); + Collator collator = CollationFactory.fetchCollation(collationId).collator; StringBuilder sb = new StringBuilder(); - int charCount = 0; - for (int k = 0; k < input.length(); k += charCount) { - int codePoint = input.codePointAt(k); - charCount = Character.charCount(codePoint); - String subStr = input.substring(k, k + charCount); - String collationKey = CollationFactory.getCollationKey(subStr, collationId); - String translated = dict.get(collationKey); - if (null == translated) { - sb.append(subStr); - } else if (!"\0".equals(translated)) { - sb.append(translated); + int charIndex = 0; + while (charIndex < inputString.length()) { + int longestMatchLen = 0; + String longestMatch = ""; + for (String key : dict.keySet()) { + StringSearch stringSearch = new StringSearch(key, target, (RuleBasedCollator) collator); + stringSearch.setIndex(charIndex); + int matchIndex = stringSearch.next(); + if (matchIndex == charIndex) { + int matchLen = stringSearch.getMatchLength(); + if (matchLen > longestMatchLen) { + longestMatchLen = matchLen; + longestMatch = key; + } + } + } + if (longestMatchLen == 0) { + sb.append(inputString.charAt(charIndex)); + charIndex++; + } else { + if (!"\0".equals(dict.get(longestMatch))) { + sb.append(dict.get(longestMatch)); + } + charIndex += longestMatchLen; } } - return sb.toString(); - } - - public static UTF8String lowercaseTranslate(final UTF8String input, - final Map dict) { - Map lowercaseDict = getLowercaseDict(dict); - return UTF8String.fromString(lowercaseTranslate(input.toString(), lowercaseDict)); - } - public static UTF8String translate(final UTF8String input, final Map dict, - final int collationId) { - Map collationAwareDict = getCollationAwareDict(dict,collationId); - return UTF8String.fromString(translate(input.toString(), collationAwareDict, collationId)); + return UTF8String.fromString(sb.toString()); } public static UTF8String lowercaseTrim( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 2871c540c5e18..a31949f7290a6 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke} import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LOWER} -import org.apache.spark.sql.catalyst.util.{ArrayData, CollationSupport, GenericArrayData, TypeUtils} +import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, CollationSupport, GenericArrayData, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation, StringTypeBinaryLcase} @@ -1051,11 +1051,35 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len: object StringTranslate { - def buildDict(matchingString: UTF8String, replaceString: UTF8String) + /** + * Build a translation dictionary from UTF8Strings. First, this method converts the input strings + * to valid Java Strings. However, we avoid any behavior changes for the UTF8_BINARY collation, + * but ensure that all other collations use `UTF8String.toValidString` to achieve this step. + */ + def buildDict(matchingString: UTF8String, replaceString: UTF8String, collationId: Integer) : JMap[String, String] = { - val matching = matchingString.toString() + val isCollationAware = collationId == CollationFactory.UTF8_BINARY_COLLATION_ID + val matching: String = if (isCollationAware) { + matchingString.toString + } else { + matchingString.toValidString + } + val replace: String = if (isCollationAware) { + replaceString.toString + } else { + replaceString.toValidString + } + buildDict(matching, replace) + } - val replace = replaceString.toString() + /** + * Build a translation dictionary from Strings. This method assumes that the input strings are + * already valid. The result dictionary maps each character in `matching` to the corresponding + * character in `replace`. If `replace` is shorter than `matching`, the extra characters in + * `matching` will be mapped to null terminator, which causes characters to get deleted during + * translation. If `replace` is longer than `matching`, the extra characters will be ignored. + */ + private def buildDict(matching: String, replace: String): JMap[String, String] = { val dict = new HashMap[String, String]() var i = 0 var j = 0 @@ -1079,6 +1103,7 @@ object StringTranslate { } dict } + } /** @@ -1111,7 +1136,7 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac if (matchingEval != lastMatching || replaceEval != lastReplace) { lastMatching = matchingEval.asInstanceOf[UTF8String].clone() lastReplace = replaceEval.asInstanceOf[UTF8String].clone() - dict = StringTranslate.buildDict(lastMatching, lastReplace) + dict = StringTranslate.buildDict(lastMatching, lastReplace, collationId) } CollationSupport.StringTranslate.exec(srcEval.asInstanceOf[UTF8String], dict, collationId) @@ -1135,7 +1160,7 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac $termLastMatching = $matching.clone(); $termLastReplace = $replace.clone(); $termDict = org.apache.spark.sql.catalyst.expressions.StringTranslate - .buildDict($termLastMatching, $termLastReplace); + .buildDict($termLastMatching, $termLastReplace, $collationId); } ${ev.value} = CollationSupport.StringTranslate. exec($src, $termDict, $collationId); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 8d3586c7bae81..30dce116aa29c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -256,7 +256,7 @@ class CollationStringExpressionsSuite test("Support StringTranslate string expression with collation") { // Supported collations case class TranslateTestCase[R](input: String, matchExpression: String, - replaceExpression: String, collation: String, result: R) + replaceExpression: String, collation: String, result: R) val testCases = Seq( TranslateTestCase("Translate", "Rnlt", "1234", "UTF8_LCASE", "41a2s3a4e"), TranslateTestCase("Translate", "Rnlt", "1234", "UTF8_LCASE", "41a2s3a4e"), @@ -301,16 +301,28 @@ class CollationStringExpressionsSuite TranslateTestCase("İi\u0307", "İi\u0307", "123", "UTF8_BINARY", "123"), TranslateTestCase("İi\u0307", "İyz", "123", "UTF8_BINARY", "1i\u0307"), TranslateTestCase("İi\u0307", "xi\u0307", "123", "UTF8_BINARY", "İ23"), + TranslateTestCase("a\u030Abcå", "a\u030Aå", "123", "UTF8_BINARY", "12bc3"), + TranslateTestCase("a\u030Abcå", "A\u030AÅ", "123", "UTF8_BINARY", "a2bcå"), + TranslateTestCase("a\u030AβφδI\u0307", "Iİaå", "1234", "UTF8_BINARY", "3\u030Aβφδ1\u0307"), TranslateTestCase("İ", "i\u0307", "xy", "UTF8_LCASE", "İ"), - TranslateTestCase("İi\u0307", "İi\u0307", "123", "UTF8_LCASE", "123"), - TranslateTestCase("İi\u0307", "İyz", "123", "UTF8_LCASE", "1i\u0307"), + TranslateTestCase("İi\u0307", "İi\u0307", "123", "UTF8_LCASE", "11"), + TranslateTestCase("İi\u0307", "İyz", "123", "UTF8_LCASE", "11"), TranslateTestCase("İi\u0307", "xi\u0307", "123", "UTF8_LCASE", "İ23"), - TranslateTestCase("İi\u0307", "İi\u0307", "123", "UNICODE", "123"), + TranslateTestCase("a\u030Abcå", "a\u030Aå", "123", "UTF8_LCASE", "12bc3"), + TranslateTestCase("a\u030Abcå", "A\u030AÅ", "123", "UTF8_LCASE", "12bc3"), + TranslateTestCase("A\u030Aβφδi\u0307", "Iİaå", "1234", "UTF8_LCASE", "3\u030Aβφδ2"), + TranslateTestCase("İi\u0307", "İi\u0307", "123", "UNICODE", "1i\u0307"), TranslateTestCase("İi\u0307", "İyz", "123", "UNICODE", "1i\u0307"), - TranslateTestCase("İi\u0307", "xi\u0307", "123", "UNICODE", "İ23"), - TranslateTestCase("İi\u0307", "İi\u0307", "123", "UNICODE_CI", "123"), - TranslateTestCase("İi\u0307", "İyz", "123", "UNICODE_CI", "1i\u0307"), - TranslateTestCase("İi\u0307", "xi\u0307", "123", "UNICODE_CI", "İ23") + TranslateTestCase("İi\u0307", "xi\u0307", "123", "UNICODE", "İi\u0307"), + TranslateTestCase("a\u030Abcå", "a\u030Aå", "123", "UNICODE", "3bc3"), + TranslateTestCase("a\u030Abcå", "A\u030AÅ", "123", "UNICODE", "a\u030Abcå"), + TranslateTestCase("a\u030AβφδI\u0307", "Iİaå", "1234", "UNICODE", "4βφδ2"), + TranslateTestCase("İi\u0307", "İi\u0307", "123", "UNICODE_CI", "11"), + TranslateTestCase("İi\u0307", "İyz", "123", "UNICODE_CI", "11"), + TranslateTestCase("İi\u0307", "xi\u0307", "123", "UNICODE_CI", "İi\u0307"), + TranslateTestCase("a\u030Abcå", "a\u030Aå", "123", "UNICODE_CI", "3bc3"), + TranslateTestCase("a\u030Abcå", "A\u030AÅ", "123", "UNICODE_CI", "3bc3"), + TranslateTestCase("A\u030Aβφδi\u0307", "Iİaå", "1234", "UNICODE_CI", "4βφδ2") ) testCases.foreach(t => { From aec007797e03f1609b24fee41aa3220f11f54a4c Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Mon, 8 Jul 2024 17:02:45 +0200 Subject: [PATCH 09/13] Fixes --- .../util/CollationAwareUTF8String.java | 49 ++++++++++++++----- .../sql/catalyst/util/CollationFactory.java | 12 ----- .../sql/catalyst/util/CollationSupport.java | 8 +-- .../sql/CollationStringExpressionsSuite.scala | 35 ++++++++++++- 4 files changed, 76 insertions(+), 28 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java index 7a5018d67c5a1..184f75f4132f9 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java @@ -28,10 +28,12 @@ import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET; import static org.apache.spark.unsafe.Platform.copyMemory; +import static org.apache.spark.unsafe.types.UTF8String.CodePointIteratorType; import java.text.CharacterIterator; import java.text.StringCharacterIterator; import java.util.HashMap; +import java.util.Iterator; import java.util.Map; /** @@ -488,10 +490,11 @@ public static UTF8String lowerCaseCodePoints(final UTF8String target) { } private static UTF8String lowerCaseCodePointsSlow(final UTF8String target) { - String targetString = target.toValidString(); + Iterator targetIter = target.codePointIterator( + CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID); StringBuilder sb = new StringBuilder(); - for (int i = 0; i < targetString.length(); ++i) { - appendLowercaseCodePoint(targetString.codePointAt(i), sb); + while (targetIter.hasNext()) { + appendLowercaseCodePoint(targetIter.next(), sb); } return UTF8String.fromString(sb.toString()); } @@ -714,7 +717,7 @@ private static Map getLowercaseDict(final Map d * String translation is performed by iterating over the input string, from left to right, and * repeatedly translating the longest possible substring that matches a key in the dictionary. * For UTF8_LCASE, the method uses the lowercased substring to perform the lookup in the - * lowercase version of the translation map. + * lowercased version of the translation map. * * @param input the string to be translated * @param dict the lowercase translation dictionary @@ -722,24 +725,48 @@ private static Map getLowercaseDict(final Map d */ public static UTF8String lowercaseTranslate(final UTF8String input, final Map dict) { + // Iterator for the input string. + Iterator inputIter = input.codePointIterator( + CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID); + // Lowercased translation dictionary. Map lowercaseDict = getLowercaseDict(dict); + // StringBuilder to store the translated string. StringBuilder sb = new StringBuilder(); - for (int charIndex = 0; charIndex < input.numChars(); ++charIndex) { - int codePoint = input.getChar(charIndex); + + // Buffered code point iteration to handle one-to-many case mappings. + int codePointBuffer = -1, codePoint; + while (inputIter.hasNext()) { + if (codePointBuffer != -1) { + codePoint = codePointBuffer; + codePointBuffer = -1; + } else { + codePoint = inputIter.next(); + } + // Special handling for letter i (U+0069) followed by a combining dot (U+0307). if (lowercaseDict.containsKey(CODE_POINT_COMBINED_LOWERCASE_I_DOT) && - codePoint == CODE_POINT_LOWERCASE_I && charIndex + 1 < input.numChars() && - input.getChar(charIndex + 1) == CODE_POINT_COMBINING_DOT) { - // Special handling for letter i (U+0069) followed by a combining dot (U+0307) - codePoint = CODE_POINT_COMBINED_LOWERCASE_I_DOT; - ++charIndex; + codePoint == CODE_POINT_LOWERCASE_I && inputIter.hasNext()) { + int nextCodePoint = inputIter.next(); + if (nextCodePoint == CODE_POINT_COMBINING_DOT) { + codePoint = CODE_POINT_COMBINED_LOWERCASE_I_DOT; + } else { + codePointBuffer = nextCodePoint; + } } + // Translate the code point using the lowercased dictionary. String translated = lowercaseDict.get(getLowercaseCodePoint(codePoint)); if (translated == null) { + // Append the original code point if no translation is found. sb.appendCodePoint(codePoint); } else if (!"\0".equals(translated)) { + // Append the translated code point if the translation is not the null character. sb.append(translated); } + // Skip the code point if it maps to the null character. } + // Append the last code point if it was buffered. + if (codePointBuffer != -1) sb.appendCodePoint(codePointBuffer); + + // Return the translated string. return UTF8String.fromString(sb.toString()); } diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 19a03baef07f5..51627e1dd6c98 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -813,18 +813,6 @@ public static String[] getICULocaleNames() { return Collation.CollationSpecICU.ICULocaleNames; } - public static String getCollationKey(String input, int collationId) { - Collation collation = fetchCollation(collationId); - if (collation.supportsBinaryEquality) { - return input; - } else if (collation.supportsLowercaseEquality) { - return input.toLowerCase(); - } else { - CollationKey collationKey = collation.collator.getCollationKey(input); - return Arrays.toString(collationKey.toByteArray()); - } - } - public static UTF8String getCollationKey(UTF8String input, int collationId) { Collation collation = fetchCollation(collationId); if (collation.supportsBinaryEquality) { diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index d084e1e247c22..f9ccd22f3f5c6 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -212,7 +212,7 @@ public static UTF8String exec(final UTF8String v, final int collationId, boolean return useICU ? execBinaryICU(v) : execBinary(v); } else if (collation.supportsLowercaseEquality) { return execLowercase(v); - } else { + } else { return execICU(v, collationId); } } @@ -224,7 +224,7 @@ public static String genCode(final String v, final int collationId, boolean useI return String.format(expr + "%s(%s)", funcName, v); } else if (collation.supportsLowercaseEquality) { return String.format(expr + "Lowercase(%s)", v); - } else { + } else { return String.format(expr + "ICU(%s, %d)", v, collationId); } } @@ -261,7 +261,7 @@ public static String genCode(final String v, final int collationId, boolean useI return String.format(expr + "%s(%s)", funcName, v); } else if (collation.supportsLowercaseEquality) { return String.format(expr + "Lowercase(%s)", v); - } else { + } else { return String.format(expr + "ICU(%s, %d)", v, collationId); } } @@ -514,7 +514,7 @@ public static String genCode(final String source, final String dict, final int c return String.format(expr + "Binary(%s, %s)", source, dict); } else if (collation.supportsLowercaseEquality) { return String.format(expr + "Lowercase(%s, %s)", source, dict); - } else { + } else { return String.format(expr + "ICU(%s, %s, %d)", source, dict, collationId); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 30dce116aa29c..efb1910d1c29e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -258,6 +258,19 @@ class CollationStringExpressionsSuite case class TranslateTestCase[R](input: String, matchExpression: String, replaceExpression: String, collation: String, result: R) val testCases = Seq( + // Basic tests - UTF8_BINARY. + TranslateTestCase("Translate", "Rnlt", "12", "UTF8_BINARY", "Tra2sae"), + TranslateTestCase("Translate", "Rn", "1234", "UTF8_BINARY", "Tra2slate"), + TranslateTestCase("Translate", "Rnlt", "1234", "UTF8_BINARY", "Tra2s3a4e"), + TranslateTestCase("TRanslate", "rnlt", "XxXx", "UTF8_BINARY", "TRaxsXaxe"), + TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UTF8_BINARY", "TxaxsXaxeX"), + TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UTF8_BINARY", "TXaxsXaxex"), + TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UTF8_BINARY", "test大千世AX大千世A"), + TranslateTestCase("大千世界test大千世界", "TEST", "abcd", "UTF8_BINARY", "大千世界test大千世界"), + TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UTF8_BINARY", "Oeso大千世界大千世界"), + TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UTF8_BINARY", "大千世界大千世界oesO"), + TranslateTestCase("大千世界大千世界tesT", "大千", "世世", "UTF8_BINARY", "世世世界世世世界tesT"), + // Basic tests - UTF8_LCASE. TranslateTestCase("Translate", "Rnlt", "1234", "UTF8_LCASE", "41a2s3a4e"), TranslateTestCase("Translate", "Rnlt", "1234", "UTF8_LCASE", "41a2s3a4e"), TranslateTestCase("TRanslate", "rnlt", "XxXx", "UTF8_LCASE", "xXaxsXaxe"), @@ -268,6 +281,7 @@ class CollationStringExpressionsSuite TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UTF8_LCASE", "oeso大千世界大千世界"), TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UTF8_LCASE", "大千世界大千世界OesO"), TranslateTestCase("大千世界大千世界tesT", "大千", "世世", "UTF8_LCASE", "世世世界世世世界tesT"), + // Basic tests - UNICODE. TranslateTestCase("Translate", "Rnlt", "1234", "UNICODE", "Tra2s3a4e"), TranslateTestCase("TRanslate", "rnlt", "XxXx", "UNICODE", "TRaxsXaxe"), TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UNICODE", "TxaxsXaxeX"), @@ -275,6 +289,7 @@ class CollationStringExpressionsSuite TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UNICODE", "test大千世AX大千世A"), TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UNICODE", "Oeso大千世界大千世界"), TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UNICODE", "大千世界大千世界oesO"), + // Basic tests - UNICODE_CI. TranslateTestCase("Translate", "Rnlt", "1234", "UNICODE_CI", "41a2s3a4e"), TranslateTestCase("TRanslate", "rnlt", "XxXx", "UNICODE_CI", "xXaxsXaxe"), TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UNICODE_CI", "xxaxsXaxex"), @@ -296,27 +311,45 @@ class CollationStringExpressionsSuite TranslateTestCase("abcdef", "abcde", "123", "UTF8_LCASE", "123f"), TranslateTestCase("abcdef", "abcde", "123", "UNICODE", "123f"), TranslateTestCase("abcdef", "abcde", "123", "UNICODE_CI", "123f"), - // Case mapping edge cases + + // One-to-many case mapping - UTF8_BINARY. TranslateTestCase("İ", "i\u0307", "xy", "UTF8_BINARY", "İ"), + TranslateTestCase("i\u0307", "İ", "xy", "UTF8_BINARY", "i\u0307"), + TranslateTestCase("i\u030A", "İ", "x", "UTF8_BINARY", "i\u030A"), + TranslateTestCase("i\u030A", "İi", "xy", "UTF8_BINARY", "y\u030A"), TranslateTestCase("İi\u0307", "İi\u0307", "123", "UTF8_BINARY", "123"), TranslateTestCase("İi\u0307", "İyz", "123", "UTF8_BINARY", "1i\u0307"), TranslateTestCase("İi\u0307", "xi\u0307", "123", "UTF8_BINARY", "İ23"), TranslateTestCase("a\u030Abcå", "a\u030Aå", "123", "UTF8_BINARY", "12bc3"), TranslateTestCase("a\u030Abcå", "A\u030AÅ", "123", "UTF8_BINARY", "a2bcå"), TranslateTestCase("a\u030AβφδI\u0307", "Iİaå", "1234", "UTF8_BINARY", "3\u030Aβφδ1\u0307"), + // One-to-many case mapping - UTF8_LCASE. TranslateTestCase("İ", "i\u0307", "xy", "UTF8_LCASE", "İ"), + TranslateTestCase("i\u0307", "İ", "xy", "UTF8_LCASE", "x"), + TranslateTestCase("i\u030A", "İ", "x", "UTF8_LCASE", "i\u030A"), + TranslateTestCase("i\u030A", "İi", "xy", "UTF8_LCASE", "y\u030A"), TranslateTestCase("İi\u0307", "İi\u0307", "123", "UTF8_LCASE", "11"), TranslateTestCase("İi\u0307", "İyz", "123", "UTF8_LCASE", "11"), TranslateTestCase("İi\u0307", "xi\u0307", "123", "UTF8_LCASE", "İ23"), TranslateTestCase("a\u030Abcå", "a\u030Aå", "123", "UTF8_LCASE", "12bc3"), TranslateTestCase("a\u030Abcå", "A\u030AÅ", "123", "UTF8_LCASE", "12bc3"), TranslateTestCase("A\u030Aβφδi\u0307", "Iİaå", "1234", "UTF8_LCASE", "3\u030Aβφδ2"), + // One-to-many case mapping - UNICODE. + TranslateTestCase("İ", "i\u0307", "xy", "UNICODE", "İ"), + TranslateTestCase("i\u0307", "İ", "xy", "UNICODE", "i\u0307"), + TranslateTestCase("i\u030A", "İ", "x", "UNICODE", "i\u030A"), + TranslateTestCase("i\u030A", "İi", "xy", "UNICODE", "i\u030A"), TranslateTestCase("İi\u0307", "İi\u0307", "123", "UNICODE", "1i\u0307"), TranslateTestCase("İi\u0307", "İyz", "123", "UNICODE", "1i\u0307"), TranslateTestCase("İi\u0307", "xi\u0307", "123", "UNICODE", "İi\u0307"), TranslateTestCase("a\u030Abcå", "a\u030Aå", "123", "UNICODE", "3bc3"), TranslateTestCase("a\u030Abcå", "A\u030AÅ", "123", "UNICODE", "a\u030Abcå"), TranslateTestCase("a\u030AβφδI\u0307", "Iİaå", "1234", "UNICODE", "4βφδ2"), + // One-to-many case mapping - UNICODE_CI. + TranslateTestCase("İ", "i\u0307", "xy", "UNICODE_CI", "İ"), + TranslateTestCase("i\u0307", "İ", "xy", "UNICODE_CI", "x"), + TranslateTestCase("i\u030A", "İ", "x", "UNICODE_CI", "i\u030A"), + TranslateTestCase("i\u030A", "İi", "xy", "UNICODE_CI", "i\u030A"), TranslateTestCase("İi\u0307", "İi\u0307", "123", "UNICODE_CI", "11"), TranslateTestCase("İi\u0307", "İyz", "123", "UNICODE_CI", "11"), TranslateTestCase("İi\u0307", "xi\u0307", "123", "UNICODE_CI", "İi\u0307"), From 4202538cc1aed738a79f65d527bc01da714e7824 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Mon, 8 Jul 2024 21:55:07 +0200 Subject: [PATCH 10/13] Update tests --- .../unsafe/types/CollationSupportSuite.java | 144 ++++++++++++++++-- .../sql/CollationStringExpressionsSuite.scala | 98 +----------- 2 files changed, 135 insertions(+), 107 deletions(-) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index 9438484344d62..0ebdfe45887ab 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -22,6 +22,9 @@ import org.apache.spark.sql.catalyst.util.CollationSupport; import org.junit.jupiter.api.Test; +import java.util.HashMap; +import java.util.Map; + import static org.junit.jupiter.api.Assertions.*; // checkstyle.off: AvoidEscapedUnicodeCharacters @@ -1378,19 +1381,138 @@ public void testStringTrim() throws SparkException { assertStringTrimRight("UTF8_LCASE", "Ëaaaẞ", "Ëẞ", "Ëaaa"); } - // TODO: Test more collation-aware string expressions. - - /** - * Collation-aware regexp expressions. - */ - - // TODO: Test more collation-aware regexp expressions. + private void assertStringTranslate( + String inputString, + String matchingString, + String replaceString, + String collationName, + String expectedResultString) throws SparkException { + int collationId = CollationFactory.collationNameToId(collationName); + Map dict = buildDict(matchingString, replaceString); + UTF8String source = UTF8String.fromString(inputString); + UTF8String result = CollationSupport.StringTranslate.exec(source, dict, collationId); + assertEquals(expectedResultString, result.toString()); + } - /** - * Other collation-aware expressions. - */ + @Test + public void testStringTranslate() throws SparkException { + // Basic tests - UTF8_BINARY. + assertStringTranslate("Translate", "Rnlt", "12", "UTF8_BINARY", "Tra2sae"); + assertStringTranslate("Translate", "Rn", "1234", "UTF8_BINARY", "Tra2slate"); + assertStringTranslate("Translate", "Rnlt", "1234", "UTF8_BINARY", "Tra2s3a4e"); + assertStringTranslate("TRanslate", "rnlt", "XxXx", "UTF8_BINARY", "TRaxsXaxe"); + assertStringTranslate("TRanslater", "Rrnlt", "xXxXx", "UTF8_BINARY", "TxaxsXaxeX"); + assertStringTranslate("TRanslater", "Rrnlt", "XxxXx", "UTF8_BINARY", "TXaxsXaxex"); + assertStringTranslate("test大千世界X大千世界", "界x", "AB", "UTF8_BINARY", "test大千世AX大千世A"); + assertStringTranslate("大千世界test大千世界", "TEST", "abcd", "UTF8_BINARY", "大千世界test大千世界"); + assertStringTranslate("Test大千世界大千世界", "tT", "oO", "UTF8_BINARY", "Oeso大千世界大千世界"); + assertStringTranslate("大千世界大千世界tesT", "Tt", "Oo", "UTF8_BINARY", "大千世界大千世界oesO"); + assertStringTranslate("大千世界大千世界tesT", "大千", "世世", "UTF8_BINARY", "世世世界世世世界tesT"); + // Basic tests - UTF8_LCASE. + assertStringTranslate("Translate", "Rnlt", "1234", "UTF8_LCASE", "41a2s3a4e"); + assertStringTranslate("Translate", "Rnlt", "1234", "UTF8_LCASE", "41a2s3a4e"); + assertStringTranslate("TRanslate", "rnlt", "XxXx", "UTF8_LCASE", "xXaxsXaxe"); + assertStringTranslate("TRanslater", "Rrnlt", "xXxXx", "UTF8_LCASE", "xxaxsXaxex"); + assertStringTranslate("TRanslater", "Rrnlt", "XxxXx", "UTF8_LCASE", "xXaxsXaxeX"); + assertStringTranslate("test大千世界X大千世界", "界x", "AB", "UTF8_LCASE", "test大千世AB大千世A"); + assertStringTranslate("大千世界test大千世界", "TEST", "abcd", "UTF8_LCASE", "大千世界abca大千世界"); + assertStringTranslate("Test大千世界大千世界", "tT", "oO", "UTF8_LCASE", "oeso大千世界大千世界"); + assertStringTranslate("大千世界大千世界tesT", "Tt", "Oo", "UTF8_LCASE", "大千世界大千世界OesO"); + assertStringTranslate("大千世界大千世界tesT", "大千", "世世", "UTF8_LCASE", "世世世界世世世界tesT"); + // Basic tests - UNICODE. + assertStringTranslate("Translate", "Rnlt", "1234", "UNICODE", "Tra2s3a4e"); + assertStringTranslate("TRanslate", "rnlt", "XxXx", "UNICODE", "TRaxsXaxe"); + assertStringTranslate("TRanslater", "Rrnlt", "xXxXx", "UNICODE", "TxaxsXaxeX"); + assertStringTranslate("TRanslater", "Rrnlt", "XxxXx", "UNICODE", "TXaxsXaxex"); + assertStringTranslate("test大千世界X大千世界", "界x", "AB", "UNICODE", "test大千世AX大千世A"); + assertStringTranslate("Test大千世界大千世界", "tT", "oO", "UNICODE", "Oeso大千世界大千世界"); + assertStringTranslate("大千世界大千世界tesT", "Tt", "Oo", "UNICODE", "大千世界大千世界oesO"); + // Basic tests - UNICODE_CI. + assertStringTranslate("Translate", "Rnlt", "1234", "UNICODE_CI", "41a2s3a4e"); + assertStringTranslate("TRanslate", "rnlt", "XxXx", "UNICODE_CI", "xXaxsXaxe"); + assertStringTranslate("TRanslater", "Rrnlt", "xXxXx", "UNICODE_CI", "xxaxsXaxex"); + assertStringTranslate("TRanslater", "Rrnlt", "XxxXx", "UNICODE_CI", "xXaxsXaxeX"); + assertStringTranslate("test大千世界X大千世界", "界x", "AB", "UNICODE_CI", "test大千世AB大千世A"); + assertStringTranslate("大千世界test大千世界", "TEST", "abcd", "UNICODE_CI", "大千世界abca大千世界"); + assertStringTranslate("Test大千世界大千世界", "tT", "oO", "UNICODE_CI", "oeso大千世界大千世界"); + assertStringTranslate("大千世界大千世界tesT", "Tt", "Oo", "UNICODE_CI", "大千世界大千世界OesO"); + assertStringTranslate("大千世界大千世界tesT", "大千", "世世", "UNICODE_CI", "世世世界世世世界tesT"); + assertStringTranslate("Translate", "Rnlasdfjhgadt", "1234", "UTF8_LCASE", "14234e"); + assertStringTranslate("Translate", "Rnlasdfjhgadt", "1234", "UNICODE_CI", "14234e"); + assertStringTranslate("Translate", "Rnlasdfjhgadt", "1234", "UNICODE", "Tr4234e"); + assertStringTranslate("Translate", "Rnlasdfjhgadt", "1234", "UTF8_BINARY", "Tr4234e"); + assertStringTranslate("Translate", "Rnlt", "123495834634", "UTF8_LCASE", "41a2s3a4e"); + assertStringTranslate("Translate", "Rnlt", "123495834634", "UNICODE", "Tra2s3a4e"); + assertStringTranslate("Translate", "Rnlt", "123495834634", "UNICODE_CI", "41a2s3a4e"); + assertStringTranslate("Translate", "Rnlt", "123495834634", "UTF8_BINARY", "Tra2s3a4e"); + assertStringTranslate("abcdef", "abcde", "123", "UTF8_BINARY", "123f"); + assertStringTranslate("abcdef", "abcde", "123", "UTF8_LCASE", "123f"); + assertStringTranslate("abcdef", "abcde", "123", "UNICODE", "123f"); + assertStringTranslate("abcdef", "abcde", "123", "UNICODE_CI", "123f"); + + // One-to-many case mapping - UTF8_BINARY. + assertStringTranslate("İ", "i\u0307", "xy", "UTF8_BINARY", "İ"); + assertStringTranslate("i\u0307", "İ", "xy", "UTF8_BINARY", "i\u0307"); + assertStringTranslate("i\u030A", "İ", "x", "UTF8_BINARY", "i\u030A"); + assertStringTranslate("i\u030A", "İi", "xy", "UTF8_BINARY", "y\u030A"); + assertStringTranslate("İi\u0307", "İi\u0307", "123", "UTF8_BINARY", "123"); + assertStringTranslate("İi\u0307", "İyz", "123", "UTF8_BINARY", "1i\u0307"); + assertStringTranslate("İi\u0307", "xi\u0307", "123", "UTF8_BINARY", "İ23"); + assertStringTranslate("a\u030Abcå", "a\u030Aå", "123", "UTF8_BINARY", "12bc3"); + assertStringTranslate("a\u030Abcå", "A\u030AÅ", "123", "UTF8_BINARY", "a2bcå"); + assertStringTranslate("a\u030AβφδI\u0307", "Iİaå", "1234", "UTF8_BINARY", "3\u030Aβφδ1\u0307"); + // One-to-many case mapping - UTF8_LCASE. + assertStringTranslate("İ", "i\u0307", "xy", "UTF8_LCASE", "İ"); + assertStringTranslate("i\u0307", "İ", "xy", "UTF8_LCASE", "x"); + assertStringTranslate("i\u030A", "İ", "x", "UTF8_LCASE", "i\u030A"); + assertStringTranslate("i\u030A", "İi", "xy", "UTF8_LCASE", "y\u030A"); + assertStringTranslate("İi\u0307", "İi\u0307", "123", "UTF8_LCASE", "11"); + assertStringTranslate("İi\u0307", "İyz", "123", "UTF8_LCASE", "11"); + assertStringTranslate("İi\u0307", "xi\u0307", "123", "UTF8_LCASE", "İ23"); + assertStringTranslate("a\u030Abcå", "a\u030Aå", "123", "UTF8_LCASE", "12bc3"); + assertStringTranslate("a\u030Abcå", "A\u030AÅ", "123", "UTF8_LCASE", "12bc3"); + assertStringTranslate("A\u030Aβφδi\u0307", "Iİaå", "1234", "UTF8_LCASE", "3\u030Aβφδ2"); + // One-to-many case mapping - UNICODE. + assertStringTranslate("İ", "i\u0307", "xy", "UNICODE", "İ"); + assertStringTranslate("i\u0307", "İ", "xy", "UNICODE", "i\u0307"); + assertStringTranslate("i\u030A", "İ", "x", "UNICODE", "i\u030A"); + assertStringTranslate("i\u030A", "İi", "xy", "UNICODE", "i\u030A"); + assertStringTranslate("İi\u0307", "İi\u0307", "123", "UNICODE", "1i\u0307"); + assertStringTranslate("İi\u0307", "İyz", "123", "UNICODE", "1i\u0307"); + assertStringTranslate("İi\u0307", "xi\u0307", "123", "UNICODE", "İi\u0307"); + assertStringTranslate("a\u030Abcå", "a\u030Aå", "123", "UNICODE", "3bc3"); + assertStringTranslate("a\u030Abcå", "A\u030AÅ", "123", "UNICODE", "a\u030Abcå"); + assertStringTranslate("a\u030AβφδI\u0307", "Iİaå", "1234", "UNICODE", "4βφδ2"); + // One-to-many case mapping - UNICODE_CI. + assertStringTranslate("İ", "i\u0307", "xy", "UNICODE_CI", "İ"); + assertStringTranslate("i\u0307", "İ", "xy", "UNICODE_CI", "x"); + assertStringTranslate("i\u030A", "İ", "x", "UNICODE_CI", "i\u030A"); + assertStringTranslate("i\u030A", "İi", "xy", "UNICODE_CI", "i\u030A"); + assertStringTranslate("İi\u0307", "İi\u0307", "123", "UNICODE_CI", "11"); + assertStringTranslate("İi\u0307", "İyz", "123", "UNICODE_CI", "11"); + assertStringTranslate("İi\u0307", "xi\u0307", "123", "UNICODE_CI", "İi\u0307"); + assertStringTranslate("a\u030Abcå", "a\u030Aå", "123", "UNICODE_CI", "3bc3"); + assertStringTranslate("a\u030Abcå", "A\u030AÅ", "123", "UNICODE_CI", "3bc3"); + assertStringTranslate("A\u030Aβφδi\u0307", "Iİaå", "1234", "UNICODE_CI", "4βφδ2"); + } - // TODO: Test other collation-aware expressions. + private Map buildDict(String matching, String replace) { + Map dict = new HashMap<>(); + int i = 0, j = 0; + while (i < matching.length()) { + String rep = "\u0000"; + if (j < replace.length()) { + int repCharCount = Character.charCount(replace.codePointAt(j)); + rep = replace.substring(j, j + repCharCount); + j += repCharCount; + } + int matchCharCount = Character.charCount(matching.codePointAt(i)); + String matchStr = matching.substring(i, i + matchCharCount); + dict.putIfAbsent(matchStr, rep); + i += matchCharCount; + } + return dict; + } } // checkstyle.on: AvoidEscapedUnicodeCharacters diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index efb1910d1c29e..5f722b2f01fb5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -258,104 +258,10 @@ class CollationStringExpressionsSuite case class TranslateTestCase[R](input: String, matchExpression: String, replaceExpression: String, collation: String, result: R) val testCases = Seq( - // Basic tests - UTF8_BINARY. TranslateTestCase("Translate", "Rnlt", "12", "UTF8_BINARY", "Tra2sae"), - TranslateTestCase("Translate", "Rn", "1234", "UTF8_BINARY", "Tra2slate"), - TranslateTestCase("Translate", "Rnlt", "1234", "UTF8_BINARY", "Tra2s3a4e"), - TranslateTestCase("TRanslate", "rnlt", "XxXx", "UTF8_BINARY", "TRaxsXaxe"), - TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UTF8_BINARY", "TxaxsXaxeX"), - TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UTF8_BINARY", "TXaxsXaxex"), - TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UTF8_BINARY", "test大千世AX大千世A"), - TranslateTestCase("大千世界test大千世界", "TEST", "abcd", "UTF8_BINARY", "大千世界test大千世界"), - TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UTF8_BINARY", "Oeso大千世界大千世界"), - TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UTF8_BINARY", "大千世界大千世界oesO"), - TranslateTestCase("大千世界大千世界tesT", "大千", "世世", "UTF8_BINARY", "世世世界世世世界tesT"), - // Basic tests - UTF8_LCASE. TranslateTestCase("Translate", "Rnlt", "1234", "UTF8_LCASE", "41a2s3a4e"), - TranslateTestCase("Translate", "Rnlt", "1234", "UTF8_LCASE", "41a2s3a4e"), - TranslateTestCase("TRanslate", "rnlt", "XxXx", "UTF8_LCASE", "xXaxsXaxe"), - TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UTF8_LCASE", "xxaxsXaxex"), - TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UTF8_LCASE", "xXaxsXaxeX"), - TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UTF8_LCASE", "test大千世AB大千世A"), - TranslateTestCase("大千世界test大千世界", "TEST", "abcd", "UTF8_LCASE", "大千世界abca大千世界"), - TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UTF8_LCASE", "oeso大千世界大千世界"), - TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UTF8_LCASE", "大千世界大千世界OesO"), - TranslateTestCase("大千世界大千世界tesT", "大千", "世世", "UTF8_LCASE", "世世世界世世世界tesT"), - // Basic tests - UNICODE. - TranslateTestCase("Translate", "Rnlt", "1234", "UNICODE", "Tra2s3a4e"), - TranslateTestCase("TRanslate", "rnlt", "XxXx", "UNICODE", "TRaxsXaxe"), - TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UNICODE", "TxaxsXaxeX"), - TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UNICODE", "TXaxsXaxex"), - TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UNICODE", "test大千世AX大千世A"), - TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UNICODE", "Oeso大千世界大千世界"), - TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UNICODE", "大千世界大千世界oesO"), - // Basic tests - UNICODE_CI. - TranslateTestCase("Translate", "Rnlt", "1234", "UNICODE_CI", "41a2s3a4e"), - TranslateTestCase("TRanslate", "rnlt", "XxXx", "UNICODE_CI", "xXaxsXaxe"), - TranslateTestCase("TRanslater", "Rrnlt", "xXxXx", "UNICODE_CI", "xxaxsXaxex"), - TranslateTestCase("TRanslater", "Rrnlt", "XxxXx", "UNICODE_CI", "xXaxsXaxeX"), - TranslateTestCase("test大千世界X大千世界", "界x", "AB", "UNICODE_CI", "test大千世AB大千世A"), - TranslateTestCase("大千世界test大千世界", "TEST", "abcd", "UNICODE_CI", "大千世界abca大千世界"), - TranslateTestCase("Test大千世界大千世界", "tT", "oO", "UNICODE_CI", "oeso大千世界大千世界"), - TranslateTestCase("大千世界大千世界tesT", "Tt", "Oo", "UNICODE_CI", "大千世界大千世界OesO"), - TranslateTestCase("大千世界大千世界tesT", "大千", "世世", "UNICODE_CI", "世世世界世世世界tesT"), - TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UTF8_LCASE", "14234e"), - TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UNICODE_CI", "14234e"), - TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UNICODE", "Tr4234e"), - TranslateTestCase("Translate", "Rnlasdfjhgadt", "1234", "UTF8_BINARY", "Tr4234e"), - TranslateTestCase("Translate", "Rnlt", "123495834634", "UTF8_LCASE", "41a2s3a4e"), - TranslateTestCase("Translate", "Rnlt", "123495834634", "UNICODE", "Tra2s3a4e"), - TranslateTestCase("Translate", "Rnlt", "123495834634", "UNICODE_CI", "41a2s3a4e"), - TranslateTestCase("Translate", "Rnlt", "123495834634", "UTF8_BINARY", "Tra2s3a4e"), - TranslateTestCase("abcdef", "abcde", "123", "UTF8_BINARY", "123f"), - TranslateTestCase("abcdef", "abcde", "123", "UTF8_LCASE", "123f"), - TranslateTestCase("abcdef", "abcde", "123", "UNICODE", "123f"), - TranslateTestCase("abcdef", "abcde", "123", "UNICODE_CI", "123f"), - - // One-to-many case mapping - UTF8_BINARY. - TranslateTestCase("İ", "i\u0307", "xy", "UTF8_BINARY", "İ"), - TranslateTestCase("i\u0307", "İ", "xy", "UTF8_BINARY", "i\u0307"), - TranslateTestCase("i\u030A", "İ", "x", "UTF8_BINARY", "i\u030A"), - TranslateTestCase("i\u030A", "İi", "xy", "UTF8_BINARY", "y\u030A"), - TranslateTestCase("İi\u0307", "İi\u0307", "123", "UTF8_BINARY", "123"), - TranslateTestCase("İi\u0307", "İyz", "123", "UTF8_BINARY", "1i\u0307"), - TranslateTestCase("İi\u0307", "xi\u0307", "123", "UTF8_BINARY", "İ23"), - TranslateTestCase("a\u030Abcå", "a\u030Aå", "123", "UTF8_BINARY", "12bc3"), - TranslateTestCase("a\u030Abcå", "A\u030AÅ", "123", "UTF8_BINARY", "a2bcå"), - TranslateTestCase("a\u030AβφδI\u0307", "Iİaå", "1234", "UTF8_BINARY", "3\u030Aβφδ1\u0307"), - // One-to-many case mapping - UTF8_LCASE. - TranslateTestCase("İ", "i\u0307", "xy", "UTF8_LCASE", "İ"), - TranslateTestCase("i\u0307", "İ", "xy", "UTF8_LCASE", "x"), - TranslateTestCase("i\u030A", "İ", "x", "UTF8_LCASE", "i\u030A"), - TranslateTestCase("i\u030A", "İi", "xy", "UTF8_LCASE", "y\u030A"), - TranslateTestCase("İi\u0307", "İi\u0307", "123", "UTF8_LCASE", "11"), - TranslateTestCase("İi\u0307", "İyz", "123", "UTF8_LCASE", "11"), - TranslateTestCase("İi\u0307", "xi\u0307", "123", "UTF8_LCASE", "İ23"), - TranslateTestCase("a\u030Abcå", "a\u030Aå", "123", "UTF8_LCASE", "12bc3"), - TranslateTestCase("a\u030Abcå", "A\u030AÅ", "123", "UTF8_LCASE", "12bc3"), - TranslateTestCase("A\u030Aβφδi\u0307", "Iİaå", "1234", "UTF8_LCASE", "3\u030Aβφδ2"), - // One-to-many case mapping - UNICODE. - TranslateTestCase("İ", "i\u0307", "xy", "UNICODE", "İ"), - TranslateTestCase("i\u0307", "İ", "xy", "UNICODE", "i\u0307"), - TranslateTestCase("i\u030A", "İ", "x", "UNICODE", "i\u030A"), - TranslateTestCase("i\u030A", "İi", "xy", "UNICODE", "i\u030A"), - TranslateTestCase("İi\u0307", "İi\u0307", "123", "UNICODE", "1i\u0307"), - TranslateTestCase("İi\u0307", "İyz", "123", "UNICODE", "1i\u0307"), - TranslateTestCase("İi\u0307", "xi\u0307", "123", "UNICODE", "İi\u0307"), - TranslateTestCase("a\u030Abcå", "a\u030Aå", "123", "UNICODE", "3bc3"), - TranslateTestCase("a\u030Abcå", "A\u030AÅ", "123", "UNICODE", "a\u030Abcå"), - TranslateTestCase("a\u030AβφδI\u0307", "Iİaå", "1234", "UNICODE", "4βφδ2"), - // One-to-many case mapping - UNICODE_CI. - TranslateTestCase("İ", "i\u0307", "xy", "UNICODE_CI", "İ"), - TranslateTestCase("i\u0307", "İ", "xy", "UNICODE_CI", "x"), - TranslateTestCase("i\u030A", "İ", "x", "UNICODE_CI", "i\u030A"), - TranslateTestCase("i\u030A", "İi", "xy", "UNICODE_CI", "i\u030A"), - TranslateTestCase("İi\u0307", "İi\u0307", "123", "UNICODE_CI", "11"), - TranslateTestCase("İi\u0307", "İyz", "123", "UNICODE_CI", "11"), - TranslateTestCase("İi\u0307", "xi\u0307", "123", "UNICODE_CI", "İi\u0307"), - TranslateTestCase("a\u030Abcå", "a\u030Aå", "123", "UNICODE_CI", "3bc3"), - TranslateTestCase("a\u030Abcå", "A\u030AÅ", "123", "UNICODE_CI", "3bc3"), - TranslateTestCase("A\u030Aβφδi\u0307", "Iİaå", "1234", "UNICODE_CI", "4βφδ2") + TranslateTestCase("Translate", "Rn", "\u0000\u0000", "UNICODE", "Traslate"), + TranslateTestCase("Translate", "Rn", "1234", "UNICODE_CI", "T1a2slate") ) testCases.foreach(t => { From b50994159c6a9359e522c3b5723ffa91093b64ea Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Tue, 9 Jul 2024 20:23:44 +0200 Subject: [PATCH 11/13] Fixes --- .../util/CollationAwareUTF8String.java | 32 +++++++-- .../sql/catalyst/util/CollationFactory.java | 2 +- .../unsafe/types/CollationSupportSuite.java | 68 ++++++++++++++++--- .../expressions/stringExpressions.scala | 1 - 4 files changed, 87 insertions(+), 16 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java index 184f75f4132f9..3492f64b419dd 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java @@ -466,7 +466,9 @@ private static int getLowercaseCodePoint(final int codePoint) { return CODE_POINT_COMBINED_LOWERCASE_I_DOT; } else if (codePoint == 0x03C2) { - // Greek final and non-final capital letter sigma should be mapped the same. + // Greek final and non-final letter sigma should be mapped the same. This is achieved by + // mapping Greek small final sigma (U+03C2) to Greek small non-final sigma (U+03C3). Capital + // letter sigma (U+03A3) is mapped to small non-final sigma (U+03C3) in the `else` branch. return 0x03C3; } else { @@ -733,7 +735,10 @@ public static UTF8String lowercaseTranslate(final UTF8String input, // StringBuilder to store the translated string. StringBuilder sb = new StringBuilder(); - // Buffered code point iteration to handle one-to-many case mappings. + // We use buffered code point iteration to handle one-to-many case mappings. We need to handle + // at most two code points at a time (for `CODE_POINT_COMBINED_LOWERCASE_I_DOT`), a buffer of + // size 1 enables us to match two codepoints in the input string with a single codepoint in + // the lowercase translation dictionary. int codePointBuffer = -1, codePoint; while (inputIter.hasNext()) { if (codePointBuffer != -1) { @@ -742,7 +747,8 @@ public static UTF8String lowercaseTranslate(final UTF8String input, } else { codePoint = inputIter.next(); } - // Special handling for letter i (U+0069) followed by a combining dot (U+0307). + // Special handling for letter i (U+0069) followed by a combining dot (U+0307). By ensuring + // that `CODE_POINT_LOWERCASE_I` is buffered, we guarantee finding a max-length match. if (lowercaseDict.containsKey(CODE_POINT_COMBINED_LOWERCASE_I_DOT) && codePoint == CODE_POINT_LOWERCASE_I && inputIter.hasNext()) { int nextCodePoint = inputIter.next(); @@ -784,19 +790,32 @@ public static UTF8String lowercaseTranslate(final UTF8String input, */ public static UTF8String translate(final UTF8String input, final Map dict, final int collationId) { + // Replace invalid UTF-8 sequences with the Unicode replacement character U+FFFD. String inputString = input.toValidString(); + // Create a character iterator for the validated input string. This will be used for searching + // inside the string using ICU `StringSearch` class. We only need to do it once before the + // main loop of the translate algorithm. CharacterIterator target = new StringCharacterIterator(inputString); Collator collator = CollationFactory.fetchCollation(collationId).collator; StringBuilder sb = new StringBuilder(); + // Index for the current character in the (validated) input string. This is the character we + // want to determine if we need to replace or not. int charIndex = 0; while (charIndex < inputString.length()) { + // We search the replacement dictionary to find a match. If there are more than one matches + // (which is possible for collated strings), we want to choose the match of largest length. int longestMatchLen = 0; String longestMatch = ""; for (String key : dict.keySet()) { StringSearch stringSearch = new StringSearch(key, target, (RuleBasedCollator) collator); + // Point `stringSearch` to start at the current character. stringSearch.setIndex(charIndex); int matchIndex = stringSearch.next(); if (matchIndex == charIndex) { + // We have found a match (that is the current position matches with one of the characters + // in the dictionary). However, there might be other matches of larger length, so we need + // to continue searching against the characters in the dictionary and keep track of the + // match of largest length. int matchLen = stringSearch.getMatchLength(); if (matchLen > longestMatchLen) { longestMatchLen = matchLen; @@ -805,15 +824,20 @@ public static UTF8String translate(final UTF8String input, } } if (longestMatchLen == 0) { + // No match was found, so output the current character. sb.append(inputString.charAt(charIndex)); - charIndex++; + // Move on to the next character in the input string. + ++charIndex; } else { + // We have found at least one match. Append the match of longest match length to the output. if (!"\0".equals(dict.get(longestMatch))) { sb.append(dict.get(longestMatch)); } + // Skip as many characters as the longest match. charIndex += longestMatchLen; } } + // Return the translated string. return UTF8String.fromString(sb.toString()); } diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 51627e1dd6c98..f13f66e384e0f 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -818,7 +818,7 @@ public static UTF8String getCollationKey(UTF8String input, int collationId) { if (collation.supportsBinaryEquality) { return input; } else if (collation.supportsLowercaseEquality) { - return CollationAwareUTF8String.toLowerCase(input); + return input.toLowerCase(); } else { CollationKey collationKey = collation.collator.getCollationKey(input.toString()); return UTF8String.fromBytes(collationKey.toByteArray()); diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index 0ebdfe45887ab..ced0d4d9fb79c 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -1408,8 +1408,12 @@ public void testStringTranslate() throws SparkException { assertStringTranslate("Test大千世界大千世界", "tT", "oO", "UTF8_BINARY", "Oeso大千世界大千世界"); assertStringTranslate("大千世界大千世界tesT", "Tt", "Oo", "UTF8_BINARY", "大千世界大千世界oesO"); assertStringTranslate("大千世界大千世界tesT", "大千", "世世", "UTF8_BINARY", "世世世界世世世界tesT"); + assertStringTranslate("Translate", "Rnlasdfjhgadt", "1234", "UTF8_BINARY", "Tr4234e"); + assertStringTranslate("Translate", "Rnlt", "123495834634", "UTF8_BINARY", "Tra2s3a4e"); + assertStringTranslate("abcdef", "abcde", "123", "UTF8_BINARY", "123f"); // Basic tests - UTF8_LCASE. - assertStringTranslate("Translate", "Rnlt", "1234", "UTF8_LCASE", "41a2s3a4e"); + assertStringTranslate("Translate", "Rnlt", "12", "UTF8_LCASE", "1a2sae"); + assertStringTranslate("Translate", "Rn", "1234", "UTF8_LCASE", "T1a2slate"); assertStringTranslate("Translate", "Rnlt", "1234", "UTF8_LCASE", "41a2s3a4e"); assertStringTranslate("TRanslate", "rnlt", "XxXx", "UTF8_LCASE", "xXaxsXaxe"); assertStringTranslate("TRanslater", "Rrnlt", "xXxXx", "UTF8_LCASE", "xxaxsXaxex"); @@ -1419,15 +1423,27 @@ public void testStringTranslate() throws SparkException { assertStringTranslate("Test大千世界大千世界", "tT", "oO", "UTF8_LCASE", "oeso大千世界大千世界"); assertStringTranslate("大千世界大千世界tesT", "Tt", "Oo", "UTF8_LCASE", "大千世界大千世界OesO"); assertStringTranslate("大千世界大千世界tesT", "大千", "世世", "UTF8_LCASE", "世世世界世世世界tesT"); + assertStringTranslate("Translate", "Rnlasdfjhgadt", "1234", "UTF8_LCASE", "14234e"); + assertStringTranslate("Translate", "Rnlt", "123495834634", "UTF8_LCASE", "41a2s3a4e"); + assertStringTranslate("abcdef", "abcde", "123", "UTF8_LCASE", "123f"); // Basic tests - UNICODE. + assertStringTranslate("Translate", "Rnlt", "12", "UNICODE", "Tra2sae"); + assertStringTranslate("Translate", "Rn", "1234", "UNICODE", "Tra2slate"); assertStringTranslate("Translate", "Rnlt", "1234", "UNICODE", "Tra2s3a4e"); assertStringTranslate("TRanslate", "rnlt", "XxXx", "UNICODE", "TRaxsXaxe"); assertStringTranslate("TRanslater", "Rrnlt", "xXxXx", "UNICODE", "TxaxsXaxeX"); assertStringTranslate("TRanslater", "Rrnlt", "XxxXx", "UNICODE", "TXaxsXaxex"); assertStringTranslate("test大千世界X大千世界", "界x", "AB", "UNICODE", "test大千世AX大千世A"); + assertStringTranslate("大千世界test大千世界", "TEST", "abcd", "UNICODE", "大千世界test大千世界"); assertStringTranslate("Test大千世界大千世界", "tT", "oO", "UNICODE", "Oeso大千世界大千世界"); assertStringTranslate("大千世界大千世界tesT", "Tt", "Oo", "UNICODE", "大千世界大千世界oesO"); + assertStringTranslate("大千世界大千世界tesT", "大千", "世世", "UNICODE", "世世世界世世世界tesT"); + assertStringTranslate("Translate", "Rnlasdfjhgadt", "1234", "UNICODE", "Tr4234e"); + assertStringTranslate("Translate", "Rnlt", "123495834634", "UNICODE", "Tra2s3a4e"); + assertStringTranslate("abcdef", "abcde", "123", "UNICODE", "123f"); // Basic tests - UNICODE_CI. + assertStringTranslate("Translate", "Rnlt", "12", "UNICODE_CI", "1a2sae"); + assertStringTranslate("Translate", "Rn", "1234", "UNICODE_CI", "T1a2slate"); assertStringTranslate("Translate", "Rnlt", "1234", "UNICODE_CI", "41a2s3a4e"); assertStringTranslate("TRanslate", "rnlt", "XxXx", "UNICODE_CI", "xXaxsXaxe"); assertStringTranslate("TRanslater", "Rrnlt", "xXxXx", "UNICODE_CI", "xxaxsXaxex"); @@ -1437,17 +1453,8 @@ public void testStringTranslate() throws SparkException { assertStringTranslate("Test大千世界大千世界", "tT", "oO", "UNICODE_CI", "oeso大千世界大千世界"); assertStringTranslate("大千世界大千世界tesT", "Tt", "Oo", "UNICODE_CI", "大千世界大千世界OesO"); assertStringTranslate("大千世界大千世界tesT", "大千", "世世", "UNICODE_CI", "世世世界世世世界tesT"); - assertStringTranslate("Translate", "Rnlasdfjhgadt", "1234", "UTF8_LCASE", "14234e"); assertStringTranslate("Translate", "Rnlasdfjhgadt", "1234", "UNICODE_CI", "14234e"); - assertStringTranslate("Translate", "Rnlasdfjhgadt", "1234", "UNICODE", "Tr4234e"); - assertStringTranslate("Translate", "Rnlasdfjhgadt", "1234", "UTF8_BINARY", "Tr4234e"); - assertStringTranslate("Translate", "Rnlt", "123495834634", "UTF8_LCASE", "41a2s3a4e"); - assertStringTranslate("Translate", "Rnlt", "123495834634", "UNICODE", "Tra2s3a4e"); assertStringTranslate("Translate", "Rnlt", "123495834634", "UNICODE_CI", "41a2s3a4e"); - assertStringTranslate("Translate", "Rnlt", "123495834634", "UTF8_BINARY", "Tra2s3a4e"); - assertStringTranslate("abcdef", "abcde", "123", "UTF8_BINARY", "123f"); - assertStringTranslate("abcdef", "abcde", "123", "UTF8_LCASE", "123f"); - assertStringTranslate("abcdef", "abcde", "123", "UNICODE", "123f"); assertStringTranslate("abcdef", "abcde", "123", "UNICODE_CI", "123f"); // One-to-many case mapping - UTF8_BINARY. @@ -1494,6 +1501,47 @@ public void testStringTranslate() throws SparkException { assertStringTranslate("a\u030Abcå", "a\u030Aå", "123", "UNICODE_CI", "3bc3"); assertStringTranslate("a\u030Abcå", "A\u030AÅ", "123", "UNICODE_CI", "3bc3"); assertStringTranslate("A\u030Aβφδi\u0307", "Iİaå", "1234", "UNICODE_CI", "4βφδ2"); + + // Conditional case mapping - UTF8_BINARY. + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "σιι", "UTF8_BINARY", "σΥσΤΗΜΑΤΙΚΟσ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "σιι", "UTF8_BINARY", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "σιι", "UTF8_BINARY", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "ςιι", "UTF8_BINARY", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "ςιι", "UTF8_BINARY", "ςΥςΤΗΜΑΤΙΚΟς"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "ςιι", "UTF8_BINARY", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); + assertStringTranslate("συστηματικος", "Συη", "σιι", "UTF8_BINARY", "σιστιματικος"); + assertStringTranslate("συστηματικος", "συη", "σιι", "UTF8_BINARY", "σιστιματικος"); + assertStringTranslate("συστηματικος", "ςυη", "σιι", "UTF8_BINARY", "σιστιματικοσ"); + // Conditional case mapping - UTF8_LCASE. + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "σιι", "UTF8_LCASE", "σισΤιΜΑΤΙΚΟσ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "σιι", "UTF8_LCASE", "σισΤιΜΑΤΙΚΟσ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "σιι", "UTF8_LCASE", "σισΤιΜΑΤΙΚΟσ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "ςιι", "UTF8_LCASE", "ςιςΤιΜΑΤΙΚΟς"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "ςιι", "UTF8_LCASE", "ςιςΤιΜΑΤΙΚΟς"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "ςιι", "UTF8_LCASE", "ςιςΤιΜΑΤΙΚΟς"); + assertStringTranslate("συστηματικος", "Συη", "σιι", "UTF8_LCASE", "σιστιματικοσ"); + assertStringTranslate("συστηματικος", "συη", "σιι", "UTF8_LCASE", "σιστιματικοσ"); + assertStringTranslate("συστηματικος", "ςυη", "σιι", "UTF8_LCASE", "σιστιματικοσ"); + // Conditional case mapping - UNICODE. + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "σιι", "UNICODE", "σΥσΤΗΜΑΤΙΚΟσ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "σιι", "UNICODE", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "σιι", "UNICODE", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "ςιι", "UNICODE", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "ςιι", "UNICODE", "ςΥςΤΗΜΑΤΙΚΟς"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "ςιι", "UNICODE", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); + assertStringTranslate("συστηματικος", "Συη", "σιι", "UNICODE", "σιστιματικος"); + assertStringTranslate("συστηματικος", "συη", "σιι", "UNICODE", "σιστιματικος"); + assertStringTranslate("συστηματικος", "ςυη", "σιι", "UNICODE", "σιστιματικοσ"); + // Conditional case mapping - UNICODE_CI. + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "σιι", "UNICODE_CI", "σισΤιΜΑΤΙΚΟσ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "σιι", "UNICODE_CI", "σισΤιΜΑΤΙΚΟσ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "σιι", "UNICODE_CI", "σισΤιΜΑΤΙΚΟσ"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "ςιι", "UNICODE_CI", "ςιςΤιΜΑΤΙΚΟς"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "ςιι", "UNICODE_CI", "ςιςΤιΜΑΤΙΚΟς"); + assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "ςιι", "UNICODE_CI", "ςιςΤιΜΑΤΙΚΟς"); + assertStringTranslate("συστηματικος", "Συη", "σιι", "UNICODE_CI", "σιστιματικοσ"); + assertStringTranslate("συστηματικος", "συη", "σιι", "UNICODE_CI", "σιστιματικοσ"); + assertStringTranslate("συστηματικος", "ςυη", "σιι", "UNICODE_CI", "σιστιματικοσ"); } private Map buildDict(String matching, String replace) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index a31949f7290a6..bbda2dd8c6aec 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1103,7 +1103,6 @@ object StringTranslate { } dict } - } /** From 2cdda662a012387f59a4e826fdb7d541c45033c0 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Wed, 10 Jul 2024 04:57:02 +0200 Subject: [PATCH 12/13] Update comment --- .../spark/sql/catalyst/util/CollationAwareUTF8String.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java index 3492f64b419dd..af152c87f88ce 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java @@ -780,8 +780,8 @@ public static UTF8String lowercaseTranslate(final UTF8String input, * Translates the `input` string using the translation map `dict`, for all ICU collations. * String translation is performed by iterating over the input string, from left to right, and * repeatedly translating the longest possible substring that matches a key in the dictionary. - * For ICU collations, the method uses the collation key of the substring to perform the lookup - * in the collation aware version of the translation map. + * For ICU collations, the method uses the ICU `StringSearch` class to perform the lookup in + * the translation map, while respecting the rules of the specified ICU collation. * * @param input the string to be translated * @param dict the collation aware translation dictionary From ae718aadb173cc92dedb3df6c2468751c84e7dce Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Fri, 12 Jul 2024 07:48:46 +0200 Subject: [PATCH 13/13] Fix comments --- .../apache/spark/unsafe/types/CollationSupportSuite.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index ced0d4d9fb79c..ce0cef3fef307 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -1502,7 +1502,7 @@ public void testStringTranslate() throws SparkException { assertStringTranslate("a\u030Abcå", "A\u030AÅ", "123", "UNICODE_CI", "3bc3"); assertStringTranslate("A\u030Aβφδi\u0307", "Iİaå", "1234", "UNICODE_CI", "4βφδ2"); - // Conditional case mapping - UTF8_BINARY. + // Greek sigmas - UTF8_BINARY. assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "σιι", "UTF8_BINARY", "σΥσΤΗΜΑΤΙΚΟσ"); assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "σιι", "UTF8_BINARY", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "σιι", "UTF8_BINARY", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); @@ -1512,7 +1512,7 @@ public void testStringTranslate() throws SparkException { assertStringTranslate("συστηματικος", "Συη", "σιι", "UTF8_BINARY", "σιστιματικος"); assertStringTranslate("συστηματικος", "συη", "σιι", "UTF8_BINARY", "σιστιματικος"); assertStringTranslate("συστηματικος", "ςυη", "σιι", "UTF8_BINARY", "σιστιματικοσ"); - // Conditional case mapping - UTF8_LCASE. + // Greek sigmas - UTF8_LCASE. assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "σιι", "UTF8_LCASE", "σισΤιΜΑΤΙΚΟσ"); assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "σιι", "UTF8_LCASE", "σισΤιΜΑΤΙΚΟσ"); assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "σιι", "UTF8_LCASE", "σισΤιΜΑΤΙΚΟσ"); @@ -1522,7 +1522,7 @@ public void testStringTranslate() throws SparkException { assertStringTranslate("συστηματικος", "Συη", "σιι", "UTF8_LCASE", "σιστιματικοσ"); assertStringTranslate("συστηματικος", "συη", "σιι", "UTF8_LCASE", "σιστιματικοσ"); assertStringTranslate("συστηματικος", "ςυη", "σιι", "UTF8_LCASE", "σιστιματικοσ"); - // Conditional case mapping - UNICODE. + // Greek sigmas - UNICODE. assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "σιι", "UNICODE", "σΥσΤΗΜΑΤΙΚΟσ"); assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "σιι", "UNICODE", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "σιι", "UNICODE", "ΣΥΣΤΗΜΑΤΙΚΟΣ"); @@ -1532,7 +1532,7 @@ public void testStringTranslate() throws SparkException { assertStringTranslate("συστηματικος", "Συη", "σιι", "UNICODE", "σιστιματικος"); assertStringTranslate("συστηματικος", "συη", "σιι", "UNICODE", "σιστιματικος"); assertStringTranslate("συστηματικος", "ςυη", "σιι", "UNICODE", "σιστιματικοσ"); - // Conditional case mapping - UNICODE_CI. + // Greek sigmas - UNICODE_CI. assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "Συη", "σιι", "UNICODE_CI", "σισΤιΜΑΤΙΚΟσ"); assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "συη", "σιι", "UNICODE_CI", "σισΤιΜΑΤΙΚΟσ"); assertStringTranslate("ΣΥΣΤΗΜΑΤΙΚΟΣ", "ςυη", "σιι", "UNICODE_CI", "σισΤιΜΑΤΙΚΟσ");