Skip to content

Commit ac863e9

Browse files
committed
reduce the calling of numChars
1 parent 12e108f commit ac863e9

File tree

3 files changed

+148
-71
lines changed

3 files changed

+148
-71
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala

Lines changed: 11 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,7 @@ import java.text.DecimalFormat
2121
import java.util.Locale
2222
import java.util.regex.{MatchResult, Pattern}
2323

24-
import org.apache.commons.lang.StringUtils
25-
2624
import org.apache.spark.sql.catalyst.InternalRow
27-
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
2825
import org.apache.spark.sql.catalyst.expressions.codegen._
2926
import org.apache.spark.sql.types._
3027
import org.apache.spark.unsafe.types.UTF8String
@@ -371,75 +368,22 @@ case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr
371368
override def nullable: Boolean = strExpr.nullable || delimExpr.nullable || countExpr.nullable
372369
override def children: Seq[Expression] = Seq(strExpr, delimExpr, countExpr)
373370
override def prettyName: String = "substring_index"
374-
override def toString: String = s"substring_index($strExpr, $delimExpr, $countExpr)"
375371

376372
override def eval(input: InternalRow): Any = {
377373
val str = strExpr.eval(input)
378-
val delim = delimExpr.eval(input)
379-
val count = countExpr.eval(input)
380-
if (str == null || delim == null || count == null) {
381-
null
382-
} else {
383-
subStrIndex(
384-
str.asInstanceOf[UTF8String],
385-
delim.asInstanceOf[UTF8String],
386-
count.asInstanceOf[Int])
387-
}
388-
}
389-
390-
private def lastOrdinalIndexOf(
391-
str: UTF8String, searchStr: UTF8String, ordinal: Int, lastIndex: Boolean = false): Int = {
392-
ordinalIndexOf(str, searchStr, ordinal, true)
393-
}
394-
395-
private def ordinalIndexOf(
396-
str: UTF8String, searchStr: UTF8String, ordinal: Int, lastIndex: Boolean = false): Int = {
397-
if (str == null || searchStr == null || ordinal <= 0) {
398-
return -1
399-
}
400-
val strNumChars = str.numChars()
401-
if (searchStr.numBytes() == 0) {
402-
return if (lastIndex) {strNumChars} else {0}
403-
}
404-
var found = 0
405-
var index = if (lastIndex) {strNumChars} else {0}
406-
do {
407-
if (lastIndex) {
408-
index = str.lastIndexOf(searchStr, index - 1)
409-
} else {
410-
index = str.indexOf(searchStr, index + 1)
411-
}
412-
if (index < 0) {
413-
return index
414-
}
415-
found += 1
416-
} while (found < ordinal)
417-
index
418-
}
419-
420-
private def subStrIndex(strUtf8: UTF8String, delimUtf8: UTF8String, count: Int): UTF8String = {
421-
if (strUtf8 == null || delimUtf8 == null || count == null) {
422-
return null
423-
}
424-
if (strUtf8.numBytes() == 0 || delimUtf8.numBytes() == 0 || count == 0) {
425-
return UTF8String.fromString("")
426-
}
427-
val res = if (count > 0) {
428-
val idx = ordinalIndexOf(strUtf8, delimUtf8, count)
429-
if (idx != -1) {
430-
strUtf8.substring(0, idx)
431-
} else {
432-
strUtf8
433-
}
434-
} else {
435-
val idx = lastOrdinalIndexOf(strUtf8, delimUtf8, -count)
436-
if (idx != -1) {
437-
strUtf8.substring(idx + delimUtf8.numChars(), strUtf8.numChars())
438-
} else {
439-
strUtf8
374+
if (str != null) {
375+
val delim = delimExpr.eval(input)
376+
if (delim != null) {
377+
val count = countExpr.eval(input)
378+
if (count != null) {
379+
return UTF8String.subStringIndex(
380+
str.asInstanceOf[UTF8String],
381+
delim.asInstanceOf[UTF8String],
382+
count.asInstanceOf[Int])
383+
}
440384
}
441385
}
442-
res
386+
null
443387
}
444388
}
445389

unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java

Lines changed: 136 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,27 @@ public UTF8String substring(final int start, final int until) {
165165
return fromBytes(bytes);
166166
}
167167

168+
/**
169+
* Returns a substring of this from start to end.
170+
* @param start the position of first code point
171+
*/
172+
public UTF8String substring(final int start) {
173+
if (start >= numBytes) {
174+
return fromBytes(new byte[0]);
175+
}
176+
177+
int i = 0;
178+
int c = 0;
179+
while (i < numBytes && c < start) {
180+
i += numBytesForFirstByte(getByte(i));
181+
c += 1;
182+
}
183+
184+
byte[] bytes = new byte[numBytes - i];
185+
copyMemory(base, offset + i, bytes, BYTE_ARRAY_OFFSET, numBytes - i);
186+
return fromBytes(bytes);
187+
}
188+
168189
public UTF8String substringSQL(int pos, int length) {
169190
// Information regarding the pos calculation:
170191
// Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and
@@ -391,30 +412,141 @@ private int indexEnd(int startCodePoint) {
391412
return i;
392413
}
393414

415+
/**
416+
* Returns the index within this string of the last occurrence of the
417+
* specified substring, searching backward starting at the specified index.
418+
* @param v the substring to search for.
419+
* @param startCodePoint the index to start search from
420+
* @return the index of the last occurrence of the specified substring,
421+
* searching backward from the specified index,
422+
* or {@code -1} if there is no such occurrence.
423+
*/
394424
public int lastIndexOf(UTF8String v, int startCodePoint) {
425+
return lastIndexOf(v, v.numChars(), startCodePoint);
426+
}
427+
public int lastIndexOf(UTF8String v, int vNumChars, int startCodePoint) {
395428
if (v.numBytes == 0) {
396429
return 0;
397430
}
398431
if (numBytes == 0) {
399432
return -1;
400433
}
401434
int fromIndexEnd = indexEnd(startCodePoint);
402-
int count = startCodePoint;
403-
int vNumChars = v.numChars();
404435
do {
405436
if (fromIndexEnd - v.numBytes + 1 < 0 ) {
406437
return -1;
407438
}
408439
if (ByteArrayMethods.arrayEquals(
409440
base, offset + fromIndexEnd - v.numBytes + 1, v.base, v.offset, v.numBytes)) {
410-
return count - vNumChars + 1;
441+
int count = 0; // count from right most to the match end in byte.
442+
while (fromIndexEnd >= 0) {
443+
count++;
444+
fromIndexEnd = firstOfCurrentCodePoint(fromIndexEnd) - 1;
445+
}
446+
return count - vNumChars;
411447
}
412448
fromIndexEnd = firstOfCurrentCodePoint(fromIndexEnd) - 1;
413-
count--;
414449
} while (fromIndexEnd >= 0);
415450
return -1;
416451
}
417452

453+
/**
454+
* Finds the n-th last index within a String.
455+
* This method uses {@link #lastIndexOf(UTF8String, int, int)}.
456+
*
457+
* @param str the String to check, may be null
458+
* @param searchStr the String to find, may be null
459+
* @param searchStrNumChars num of code points of the searchStr
460+
* @param ordinal the n-th last <code>searchStr</code> to find
461+
* @return the n-th last index of the search String,
462+
* <code>-1</code> if no match or <code>null</code> string input
463+
*/
464+
public static int lastOrdinalIndexOf(
465+
UTF8String str,
466+
UTF8String searchStr,
467+
int searchStrNumChars,
468+
int ordinal) {
469+
return doOrdinalIndexOf(str, searchStr, searchStrNumChars, ordinal, true);
470+
}
471+
/**
472+
* Finds the n-th index within a String, handling <code>null</code>.
473+
* A <code>null</code> String will return <code>-1</code>
474+
*
475+
* @param str the String to check, may be null
476+
* @param searchStr the String to find, may be null
477+
* @param searchStrNumChars num of code points of searchStr
478+
* @param ordinal the n-th <code>searchStr</code> to find
479+
* @return the n-th index of the search String,
480+
* <code>-1</code> if no match or <code>null</code> string input
481+
*/
482+
public static int ordinalIndexOf(
483+
UTF8String str,
484+
UTF8String searchStr,
485+
int searchStrNumChars,
486+
int ordinal) {
487+
return doOrdinalIndexOf(str, searchStr, searchStrNumChars, ordinal, false);
488+
}
489+
490+
private static int doOrdinalIndexOf(
491+
UTF8String str,
492+
UTF8String searchStr,
493+
int searchStrNumChars,
494+
int ordinal,
495+
boolean lastIndex) {
496+
if (str == null || searchStr == null || ordinal <= 0) {
497+
return -1;
498+
}
499+
// Only calc numChars when lastIndex == true since the calculation is expensive
500+
int strNumChars = 0;
501+
if (lastIndex) {
502+
strNumChars = str.numChars();
503+
}
504+
if (searchStr.numBytes == 0) {
505+
return lastIndex ? strNumChars : 0;
506+
}
507+
int found = 0;
508+
int index = lastIndex ? strNumChars : 0;
509+
do {
510+
if (lastIndex) {
511+
index = str.lastIndexOf(searchStr, searchStrNumChars, index - 1);
512+
} else {
513+
index = str.indexOf(searchStr, index + 1);
514+
}
515+
if (index < 0) {
516+
return index;
517+
}
518+
found += 1;
519+
} while (found < ordinal);
520+
return index;
521+
}
522+
/**
523+
* Returns the substring from string str before count occurrences of the delimiter delim.
524+
* If count is positive, everything to the left of the final delimiter (counting from left) is
525+
* returned. If count is negative, everything to the right of the final delimiter (counting from the
526+
* right) is returned. substring_index performs a case-sensitive match when searching for delim.
527+
*/
528+
public static UTF8String subStringIndex(UTF8String str, UTF8String delim, int count) {
529+
if (str.numBytes == 0 || delim.numBytes == 0 || count == 0) {
530+
return UTF8String.EMPTY_UTF8;
531+
}
532+
int delimNumChars = delim.numChars();
533+
if (count > 0) {
534+
int idx = ordinalIndexOf(str, delim, delimNumChars, count);
535+
if (idx != -1) {
536+
return str.substring(0, idx);
537+
} else {
538+
return str;
539+
}
540+
} else {
541+
int idx = lastOrdinalIndexOf(str, delim, delimNumChars, -count);
542+
if (idx != -1) {
543+
return str.substring(idx + delimNumChars);
544+
} else {
545+
return str;
546+
}
547+
}
548+
}
549+
418550
/**
419551
* Returns str, right-padded with pad to a length of len
420552
* For example:

unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ public void lastIndexOf() {
226226
assertEquals(0, fromString("").lastIndexOf(fromString(""), 0));
227227
assertEquals(-1, fromString("").lastIndexOf(fromString("l"), 0));
228228
assertEquals(0, fromString("hello").lastIndexOf(fromString(""), 0));
229+
assertEquals(0, fromString("hello").lastIndexOf(fromString("h"), 4));
229230
assertEquals(-1, fromString("hello").lastIndexOf(fromString("l"), 0));
230231
assertEquals(3, fromString("hello").lastIndexOf(fromString("l"), 3));
231232
assertEquals(-1, fromString("hello").lastIndexOf(fromString("a"), 4));

0 commit comments

Comments
 (0)