Skip to content

Commit 0461745

Browse files
uros-dbcloud-fan
authored andcommitted
[SPARK-48281][SQL] Alter string search logic for UTF8_BINARY_LCASE collation (StringInStr, SubstringIndex)
### What changes were proposed in this pull request? String searching in UTF8_BINARY_LCASE now works on character-level, rather than on byte-level. For example: `instr("İ", "i")`; now returns 0, because there exists no `start, len` such that `lowercase(substring("İ", start, len)) == "i"`. ### Why are the changes needed? Fix functions that give unusable results due to one-to-many case mapping when performing string search under UTF8_BINARY_LCASE (see example above). ### Does this PR introduce _any_ user-facing change? Yes, behaviour of `instr` and `substring_index` expressions is changed for edge cases with one-to-many case mapping. ### How was this patch tested? New unit tests in `CollationSupportSuite`. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46589 from uros-db/alter-lcase-vol2. Authored-by: Uros Bojanic <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 49204b1 commit 0461745

File tree

4 files changed

+75
-38
lines changed

4 files changed

+75
-38
lines changed

common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java

Lines changed: 14 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -345,14 +345,14 @@ public static int findInSet(final UTF8String match, final UTF8String set, int co
345345
*/
346346
public static int lowercaseIndexOf(final UTF8String target, final UTF8String pattern,
347347
final int start) {
348-
if (pattern.numChars() == 0) return 0;
348+
if (pattern.numChars() == 0) return target.indexOfEmpty(start);
349349
return lowercaseFind(target, pattern.toLowerCase(), start);
350350
}
351351

352352
public static int indexOf(final UTF8String target, final UTF8String pattern,
353353
final int start, final int collationId) {
354354
if (pattern.numBytes() == 0) {
355-
return 0;
355+
return target.indexOfEmpty(start);
356356
}
357357

358358
StringSearch stringSearch = CollationFactory.getStringSearch(target, pattern, collationId);
@@ -444,47 +444,27 @@ public static UTF8String lowercaseSubStringIndex(final UTF8String string,
444444
return UTF8String.EMPTY_UTF8;
445445
}
446446

447-
UTF8String lowercaseString = string.toLowerCase();
448447
UTF8String lowercaseDelimiter = delimiter.toLowerCase();
449448

450449
if (count > 0) {
451-
int idx = -1;
450+
// Search left to right (note: the start code point is inclusive).
451+
int matchLength = -1;
452452
while (count > 0) {
453-
idx = lowercaseString.find(lowercaseDelimiter, idx + 1);
454-
if (idx >= 0) {
455-
count--;
456-
} else {
457-
// can not find enough delim
458-
return string;
459-
}
460-
}
461-
if (idx == 0) {
462-
return UTF8String.EMPTY_UTF8;
453+
matchLength = lowercaseFind(string, lowercaseDelimiter, matchLength + 1);
454+
if (matchLength > MATCH_NOT_FOUND) --count; // Found a delimiter.
455+
else return string; // Cannot find enough delimiters in the string.
463456
}
464-
byte[] bytes = new byte[idx];
465-
copyMemory(string.getBaseObject(), string.getBaseOffset(), bytes, BYTE_ARRAY_OFFSET, idx);
466-
return UTF8String.fromBytes(bytes);
467-
457+
return string.substring(0, matchLength);
468458
} else {
469-
int idx = string.numBytes() - delimiter.numBytes() + 1;
459+
// Search right to left (note: the end code point is exclusive).
460+
int matchLength = string.numChars() + 1;
470461
count = -count;
471462
while (count > 0) {
472-
idx = lowercaseString.rfind(lowercaseDelimiter, idx - 1);
473-
if (idx >= 0) {
474-
count--;
475-
} else {
476-
// can not find enough delim
477-
return string;
478-
}
463+
matchLength = lowercaseRFind(string, lowercaseDelimiter, matchLength - 1);
464+
if (matchLength > MATCH_NOT_FOUND) --count; // Found a delimiter.
465+
else return string; // Cannot find enough delimiters in the string.
479466
}
480-
if (idx + delimiter.numBytes() == string.numBytes()) {
481-
return UTF8String.EMPTY_UTF8;
482-
}
483-
int size = string.numBytes() - delimiter.numBytes() - idx;
484-
byte[] bytes = new byte[size];
485-
copyMemory(string.getBaseObject(), string.getBaseOffset() + idx + delimiter.numBytes(),
486-
bytes, BYTE_ARRAY_OFFSET, size);
487-
return UTF8String.fromBytes(bytes);
467+
return string.substring(matchLength, string.numChars());
488468
}
489469
}
490470

common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ public static int execBinary(final UTF8String string, final UTF8String substring
354354
return string.indexOf(substring, 0);
355355
}
356356
public static int execLowercase(final UTF8String string, final UTF8String substring) {
357-
return string.toLowerCase().indexOf(substring.toLowerCase(), 0);
357+
return CollationAwareUTF8String.lowercaseIndexOf(string, substring, 0);
358358
}
359359
public static int execICU(final UTF8String string, final UTF8String substring,
360360
final int collationId) {

common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -773,6 +773,17 @@ public UTF8String repeat(int times) {
773773
return UTF8String.fromBytes(newBytes);
774774
}
775775

776+
/**
777+
* Returns the (default) position of the first occurrence of an empty substr in the current
778+
* string from the specified position (0-based index).
779+
*
780+
* @param start the start position of the current string for searching
781+
* @return the position of the first occurrence of the empty substr (now, always 0)
782+
*/
783+
public int indexOfEmpty(int start) {
784+
return 0; // TODO: Fix this behaviour (SPARK-48284)
785+
}
786+
776787
/**
777788
* Returns the position of the first occurrence of substr in
778789
* current string from the specified position (0-based index).
@@ -783,7 +794,7 @@ public UTF8String repeat(int times) {
783794
*/
784795
public int indexOf(UTF8String v, int start) {
785796
if (v.numBytes() == 0) {
786-
return 0;
797+
return indexOfEmpty(start);
787798
}
788799

789800
// locate to the start position.

common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -635,8 +635,28 @@ public void testStringInstr() throws SparkException {
635635
assertStringInstr("aaads", "dS", "UNICODE_CI", 4);
636636
assertStringInstr("test大千世界X大千世界", "界y", "UNICODE_CI", 0);
637637
assertStringInstr("test大千世界X大千世界", "界x", "UNICODE_CI", 8);
638-
assertStringInstr("abİo12", "i̇o", "UNICODE_CI", 3);
639-
assertStringInstr("abi̇o12", "İo", "UNICODE_CI", 3);
638+
assertStringInstr("i̇", "i", "UNICODE_CI", 0);
639+
assertStringInstr("i̇", "\u0307", "UNICODE_CI", 0);
640+
assertStringInstr("i̇", "İ", "UNICODE_CI", 1);
641+
assertStringInstr("İ", "i", "UNICODE_CI", 0);
642+
assertStringInstr("İoi̇o12", "i̇o", "UNICODE_CI", 1);
643+
assertStringInstr("i̇oİo12", "İo", "UNICODE_CI", 1);
644+
assertStringInstr("abİoi̇o", "i̇o", "UNICODE_CI", 3);
645+
assertStringInstr("abi̇oİo", "İo", "UNICODE_CI", 3);
646+
assertStringInstr("ai̇oxXİo", "Xx", "UNICODE_CI", 5);
647+
assertStringInstr("aİoi̇oxx", "XX", "UNICODE_CI", 7);
648+
assertStringInstr("i̇", "i", "UTF8_BINARY_LCASE", 1); // != UNICODE_CI
649+
assertStringInstr("i̇", "\u0307", "UTF8_BINARY_LCASE", 2); // != UNICODE_CI
650+
assertStringInstr("i̇", "İ", "UTF8_BINARY_LCASE", 1);
651+
assertStringInstr("İ", "i", "UTF8_BINARY_LCASE", 0);
652+
assertStringInstr("İoi̇o12", "i̇o", "UTF8_BINARY_LCASE", 1);
653+
assertStringInstr("i̇oİo12", "İo", "UTF8_BINARY_LCASE", 1);
654+
assertStringInstr("abİoi̇o", "i̇o", "UTF8_BINARY_LCASE", 3);
655+
assertStringInstr("abi̇oİo", "İo", "UTF8_BINARY_LCASE", 3);
656+
assertStringInstr("abI\u0307oi̇o", "İo", "UTF8_BINARY_LCASE", 3);
657+
assertStringInstr("ai̇oxXİo", "Xx", "UTF8_BINARY_LCASE", 5);
658+
assertStringInstr("abİoi̇o", "\u0307o", "UTF8_BINARY_LCASE", 6);
659+
assertStringInstr("aİoi̇oxx", "XX", "UTF8_BINARY_LCASE", 7);
640660
}
641661

642662
private void assertFindInSet(String word, String set, String collationName,
@@ -878,6 +898,32 @@ public void testSubstringIndex() throws SparkException {
878898
assertSubstringIndex("ai̇bi̇oİo12İoi̇o", "i̇o", -4, "UNICODE_CI", "İo12İoi̇o");
879899
assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "İo", -4, "UNICODE_CI", "i̇o12i̇oİo");
880900
assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "i̇o", -4, "UNICODE_CI", "i̇o12i̇oİo");
901+
assertSubstringIndex("abi̇12", "i", 1, "UNICODE_CI", "abi̇12");
902+
assertSubstringIndex("abi̇12", "\u0307", 1, "UNICODE_CI", "abi̇12");
903+
assertSubstringIndex("abi̇12", "İ", 1, "UNICODE_CI", "ab");
904+
assertSubstringIndex("abİ12", "i", 1, "UNICODE_CI", "abİ12");
905+
assertSubstringIndex("ai̇bi̇oİo12İoi̇o", "İo", -4, "UNICODE_CI", "İo12İoi̇o");
906+
assertSubstringIndex("ai̇bi̇oİo12İoi̇o", "i̇o", -4, "UNICODE_CI", "İo12İoi̇o");
907+
assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "İo", -4, "UNICODE_CI", "i̇o12i̇oİo");
908+
assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "i̇o", -4, "UNICODE_CI", "i̇o12i̇oİo");
909+
assertSubstringIndex("ai̇bi̇oİo12İoi̇o", "İo", 3, "UNICODE_CI", "ai̇bi̇oİo12");
910+
assertSubstringIndex("ai̇bi̇oİo12İoi̇o", "i̇o", 3, "UNICODE_CI", "ai̇bi̇oİo12");
911+
assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "İo", 3, "UNICODE_CI", "ai̇bİoi̇o12");
912+
assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "i̇o", 3, "UNICODE_CI", "ai̇bİoi̇o12");
913+
assertSubstringIndex("abi̇12", "i", 1, "UTF8_BINARY_LCASE", "ab"); // != UNICODE_CI
914+
assertSubstringIndex("abi̇12", "\u0307", 1, "UTF8_BINARY_LCASE", "abi"); // != UNICODE_CI
915+
assertSubstringIndex("abi̇12", "İ", 1, "UTF8_BINARY_LCASE", "ab");
916+
assertSubstringIndex("abİ12", "i", 1, "UTF8_BINARY_LCASE", "abİ12");
917+
assertSubstringIndex("ai̇bi̇oİo12İoi̇o", "İo", -4, "UTF8_BINARY_LCASE", "İo12İoi̇o");
918+
assertSubstringIndex("ai̇bi̇oİo12İoi̇o", "i̇o", -4, "UTF8_BINARY_LCASE", "İo12İoi̇o");
919+
assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "İo", -4, "UTF8_BINARY_LCASE", "i̇o12i̇oİo");
920+
assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "i̇o", -4, "UTF8_BINARY_LCASE", "i̇o12i̇oİo");
921+
assertSubstringIndex("bİoi̇o12i̇o", "\u0307oi", 1, "UTF8_BINARY_LCASE", "bİoi̇o12i̇o");
922+
assertSubstringIndex("ai̇bi̇oİo12İoi̇o", "İo", 3, "UTF8_BINARY_LCASE", "ai̇bi̇oİo12");
923+
assertSubstringIndex("ai̇bi̇oİo12İoi̇o", "i̇o", 3, "UTF8_BINARY_LCASE", "ai̇bi̇oİo12");
924+
assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "İo", 3, "UTF8_BINARY_LCASE", "ai̇bİoi̇o12");
925+
assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "i̇o", 3, "UTF8_BINARY_LCASE", "ai̇bİoi̇o12");
926+
assertSubstringIndex("bİoi̇o12i̇o", "\u0307oi", 1, "UTF8_BINARY_LCASE", "bİoi̇o12i̇o");
881927
}
882928

883929
private void assertStringTrim(

0 commit comments

Comments
 (0)