Skip to content

Commit b19b013

Browse files
committed
add codegen and clean code
1 parent ac863e9 commit b19b013

File tree

5 files changed

+192
-46
lines changed

5 files changed

+192
-46
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ case class StringInstr(str: Expression, substr: Expression)
361361
* right) is returned. substring_index performs a case-sensitive match when searching for delim.
362362
*/
363363
case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr: Expression)
364-
extends Expression with ImplicitCastInputTypes with CodegenFallback {
364+
extends Expression with ImplicitCastInputTypes {
365365

366366
override def dataType: DataType = StringType
367367
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
@@ -385,6 +385,30 @@ case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr
385385
}
386386
null
387387
}
388+
389+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
390+
val str = strExpr.gen(ctx)
391+
val delim = delimExpr.gen(ctx)
392+
val count = countExpr.gen(ctx)
393+
val resultCode =
394+
s"""org.apache.spark.unsafe.types.UTF8String.subStringIndex(
395+
|${str.primitive}, ${delim.primitive}, ${count.primitive})""".stripMargin
396+
s"""
397+
${str.code}
398+
boolean ${ev.isNull} = true;
399+
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
400+
if (!${str.isNull}) {
401+
${delim.code}
402+
if (!${delim.isNull}) {
403+
${count.code}
404+
if (!${count.isNull}) {
405+
${ev.isNull} = false;
406+
${ev.primitive} = $resultCode;
407+
}
408+
}
409+
}
410+
"""
411+
}
388412
}
389413

390414
/**

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
2020
import org.apache.spark.SparkFunSuite
2121
import org.apache.spark.sql.catalyst.dsl.expressions._
2222
import org.apache.spark.sql.types._
23+
import org.apache.spark.unsafe.types.UTF8String
2324

2425

2526
class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -187,6 +188,36 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
187188
checkEvaluation(s.substring(0), "example", row)
188189
}
189190

191+
test("string substring_index function") {
192+
checkEvaluation(
193+
Substring_index(Literal("www.apache.org"), Literal("."), Literal(3)), "www.apache.org")
194+
checkEvaluation(
195+
Substring_index(Literal("www.apache.org"), Literal("."), Literal(2)), "www.apache")
196+
checkEvaluation(
197+
Substring_index(Literal("www.apache.org"), Literal("."), Literal(1)), "www")
198+
checkEvaluation(
199+
Substring_index(Literal("www.apache.org"), Literal("."), Literal(0)), "")
200+
checkEvaluation(
201+
Substring_index(Literal("www.apache.org"), Literal("."), Literal(-3)), "www.apache.org")
202+
checkEvaluation(
203+
Substring_index(Literal("www.apache.org"), Literal("."), Literal(-2)), "apache.org")
204+
checkEvaluation(
205+
Substring_index(Literal("www.apache.org"), Literal("."), Literal(-1)), "org")
206+
checkEvaluation(
207+
Substring_index(Literal(""), Literal("."), Literal(-2)), "")
208+
checkEvaluation(
209+
Substring_index(Literal.create(null, StringType), Literal("."), Literal(-2)), null)
210+
checkEvaluation(
211+
Substring_index(Literal("www.apache.org"), Literal.create(null, StringType), Literal(-2)), null)
212+
// non ascii chars
213+
// scalastyle:off
214+
checkEvaluation(
215+
Substring_index(Literal("大千世界大千世界"), Literal( ""), Literal(2)), "大千世界大")
216+
// scalastyle:on
217+
checkEvaluation(
218+
Substring_index(Literal("www||apache||org"), Literal( "||"), Literal(2)), "www||apache")
219+
}
220+
190221
test("LIKE literal Regular Expression") {
191222
checkEvaluation(Literal.create(null, StringType).like("a"), null)
192223
checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null)

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1783,7 +1783,6 @@ object functions {
17831783
* right) is returned. substring_index performs a case-sensitive match when searching for delim.
17841784
*
17851785
* @group string_funcs
1786-
* @since 1.5.0
17871786
*/
17881787
def substring_index(str: String, delim: String, count: Int): Column =
17891788
substring_index(Column(str), delim, count)
@@ -1795,7 +1794,6 @@ object functions {
17951794
* right) is returned. substring_index performs a case-sensitive match when searching for delim.
17961795
*
17971796
* @group string_funcs
1798-
* @since 1.5.0
17991797
*/
18001798
def substring_index(str: Column, delim: String, count: Int): Column =
18011799
Substring_index(str.expr, lit(delim).expr, lit(count).expr)

unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java

Lines changed: 37 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -146,20 +146,8 @@ public UTF8String substring(final int start, final int until) {
146146
if (until <= start || start >= numBytes) {
147147
return fromBytes(new byte[0]);
148148
}
149-
150-
int i = 0;
151-
int c = 0;
152-
while (i < numBytes && c < start) {
153-
i += numBytesForFirstByte(getByte(i));
154-
c += 1;
155-
}
156-
157-
int j = i;
158-
while (i < numBytes && c < until) {
159-
i += numBytesForFirstByte(getByte(i));
160-
c += 1;
161-
}
162-
149+
int j = firstByteIndex(start);
150+
int i = firstByteIndex(until);
163151
byte[] bytes = new byte[i - j];
164152
copyMemory(base, offset + j, bytes, BYTE_ARRAY_OFFSET, i - j);
165153
return fromBytes(bytes);
@@ -174,12 +162,7 @@ public UTF8String substring(final int start) {
174162
return fromBytes(new byte[0]);
175163
}
176164

177-
int i = 0;
178-
int c = 0;
179-
while (i < numBytes && c < start) {
180-
i += numBytesForFirstByte(getByte(i));
181-
c += 1;
182-
}
165+
int i = firstByteIndex(start);
183166

184167
byte[] bytes = new byte[numBytes - i];
185168
copyMemory(base, offset + i, bytes, BYTE_ARRAY_OFFSET, numBytes - i);
@@ -351,13 +334,8 @@ public int indexOf(UTF8String v, int start) {
351334
return 0;
352335
}
353336

354-
// locate to the start position.
355-
int i = 0; // position in byte
356-
int c = 0; // position in character
357-
while (i < numBytes && c < start) {
358-
i += numBytesForFirstByte(getByte(i));
359-
c += 1;
360-
}
337+
int i = firstByteIndex(start); // position in byte
338+
int c = start; // position in character
361339

362340
do {
363341
if (i + v.numBytes > numBytes) {
@@ -399,19 +377,29 @@ private int firstOfCurrentCodePoint(int bytePos) {
399377
}
400378
bytePos--;
401379
}
402-
throw new RuntimeException("Invalid utf8 string");
380+
throw new RuntimeException("Invalid UTF8 string");
403381
}
404382

405-
private int indexEnd(int startCodePoint) {
406-
int i = numBytes -1; // position in byte
407-
int c = numChars() - 1; // position in character
408-
while (i >=0 && c > startCodePoint) {
409-
i = firstOfCurrentCodePoint(i) - 1;
410-
c -= 1;
383+
// Locate to the start position in byte for a given code point
384+
private int firstByteIndex(int codePoint) {
385+
int i = 0; // position in byte
386+
int c = 0; // position in character
387+
while (i < numBytes && c < codePoint) {
388+
i += numBytesForFirstByte(getByte(i));
389+
c += 1;
390+
}
391+
if (i > numBytes) {
392+
throw new StringIndexOutOfBoundsException(codePoint);
411393
}
412394
return i;
413395
}
414396

397+
// Locate to the last position in byte for a given code point
398+
private int lastByteIndex(int codePoint) {
399+
int i = firstByteIndex(codePoint);
400+
return i + numBytesForFirstByte(getByte(i)) - 1;
401+
}
402+
415403
/**
416404
* Returns the index within this string of the last occurrence of the
417405
* specified substring, searching backward starting at the specified index.
@@ -431,7 +419,7 @@ public int lastIndexOf(UTF8String v, int vNumChars, int startCodePoint) {
431419
if (numBytes == 0) {
432420
return -1;
433421
}
434-
int fromIndexEnd = indexEnd(startCodePoint);
422+
int fromIndexEnd = lastByteIndex(startCodePoint);
435423
do {
436424
if (fromIndexEnd - v.numBytes + 1 < 0 ) {
437425
return -1;
@@ -456,35 +444,38 @@ public int lastIndexOf(UTF8String v, int vNumChars, int startCodePoint) {
456444
*
457445
* @param str the String to check, may be null
458446
* @param searchStr the String to find, may be null
459-
* @param searchStrNumChars num of code ponts of the searchStr
460447
* @param ordinal the n-th last <code>searchStr</code> to find
461448
* @return the n-th last index of the search String,
462449
* <code>-1</code> if no match or <code>null</code> string input
463450
*/
464451
public static int lastOrdinalIndexOf(
465452
UTF8String str,
466453
UTF8String searchStr,
467-
int searchStrNumChars,
468454
int ordinal) {
469-
return doOrdinalIndexOf(str, searchStr, searchStrNumChars, ordinal, true);
455+
if (str == null || searchStr == null) {
456+
return -1;
457+
}
458+
return doOrdinalIndexOf(str, searchStr, searchStr.numChars(), ordinal, true);
470459
}
460+
471461
/**
472462
* Finds the n-th index within a String, handling <code>null</code>.
473463
* A <code>null</code> String will return <code>-1</code>
474464
*
475465
* @param str the String to check, may be null
476466
* @param searchStr the String to find, may be null
477-
* @param searchStrNumChars num of code points of searchStr
478467
* @param ordinal the n-th <code>searchStr</code> to find
479468
* @return the n-th index of the search String,
480469
* <code>-1</code> if no match or <code>null</code> string input
481470
*/
482471
public static int ordinalIndexOf(
483472
UTF8String str,
484473
UTF8String searchStr,
485-
int searchStrNumChars,
486474
int ordinal) {
487-
return doOrdinalIndexOf(str, searchStr, searchStrNumChars, ordinal, false);
475+
if (str == null || searchStr == null) {
476+
return -1;
477+
}
478+
return doOrdinalIndexOf(str, searchStr, searchStr.numChars(), ordinal, false);
488479
}
489480

490481
private static int doOrdinalIndexOf(
@@ -526,19 +517,22 @@ private static int doOrdinalIndexOf(
526517
* right) is returned. substring_index performs a case-sensitive match when searching for delim.
527518
*/
528519
public static UTF8String subStringIndex(UTF8String str, UTF8String delim, int count) {
520+
if (str == null || delim == null) {
521+
return null;
522+
}
529523
if (str.numBytes == 0 || delim.numBytes == 0 || count == 0) {
530524
return UTF8String.EMPTY_UTF8;
531525
}
532526
int delimNumChars = delim.numChars();
533527
if (count > 0) {
534-
int idx = ordinalIndexOf(str, delim, delimNumChars, count);
528+
int idx = doOrdinalIndexOf(str, delim, delimNumChars, count, false);
535529
if (idx != -1) {
536530
return str.substring(0, idx);
537531
} else {
538532
return str;
539533
}
540534
} else {
541-
int idx = lastOrdinalIndexOf(str, delim, delimNumChars, -count);
535+
int idx = doOrdinalIndexOf(str, delim, delimNumChars, -count, true);
542536
if (idx != -1) {
543537
return str.substring(idx + delimNumChars);
544538
} else {

unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@
2121
import java.util.Arrays;
2222

2323
import org.junit.Test;
24+
import org.junit.rules.ExpectedException;
2425

2526
import static junit.framework.Assert.*;
2627

2728
import static org.apache.spark.unsafe.types.UTF8String.*;
29+
import static org.apache.spark.unsafe.types.UTF8String.fromString;
2830

2931
public class UTF8StringSuite {
3032

@@ -184,6 +186,63 @@ public void substring() {
184186
assertEquals(fromString("据砖"), fromString("数据砖头").substring(1, 3));
185187
assertEquals(fromString("头"), fromString("数据砖头").substring(3, 5));
186188
assertEquals(fromString("ߵ梷"), fromString("ߵ梷").substring(0, 2));
189+
190+
assertEquals(fromString("hello"), fromString("hello").substring(0));
191+
assertEquals(fromString("ello"), fromString("hello").substring(1));
192+
assertEquals(fromString("砖头"), fromString("数据砖头").substring(2));
193+
assertEquals(fromString("头"), fromString("数据砖头").substring(3));
194+
ExpectedException exception = ExpectedException.none();
195+
fromString("数据砖头").substring(4);
196+
exception.expect(java.lang.StringIndexOutOfBoundsException.class);
197+
assertEquals(fromString("ߵ梷"), fromString("ߵ梷").substring(0));
198+
}
199+
200+
@Test
201+
public void ordinalIndexOf() {
202+
assertEquals(-1,
203+
UTF8String.ordinalIndexOf(fromString("www.apache.org"), fromString("."), 0));
204+
assertEquals(3,
205+
UTF8String.ordinalIndexOf(fromString("www.apache.org"), fromString("."), 1));
206+
assertEquals(10,
207+
UTF8String.ordinalIndexOf(fromString("www.apache.org"), fromString("."), 2));
208+
assertEquals(-1,
209+
UTF8String.ordinalIndexOf(fromString("www.apache.org"), fromString("."), 3));
210+
assertEquals(-1,
211+
UTF8String.ordinalIndexOf(fromString("www.apache.org"), fromString("#"), 0));
212+
assertEquals(12,
213+
UTF8String.ordinalIndexOf(fromString("www|||apache|||org"), fromString("|||"), 2));
214+
assertEquals(-1,
215+
UTF8String.ordinalIndexOf(null, fromString("|||"), 1));
216+
assertEquals(-1,
217+
UTF8String.ordinalIndexOf(fromString("www|||apache|||org"), null, 1));
218+
assertEquals(2,
219+
UTF8String.ordinalIndexOf(fromString("数据砖砖头"), fromString("砖"), 1));
220+
assertEquals(-1,
221+
UTF8String.ordinalIndexOf(fromString("砖头数据砖头"), fromString("砖"), -2));
222+
}
223+
224+
@Test
225+
public void lastOrdinalIndexOf() {
226+
assertEquals(-1,
227+
UTF8String.lastOrdinalIndexOf(fromString("www.apache.org"), fromString("."), 0));
228+
assertEquals(10,
229+
UTF8String.lastOrdinalIndexOf(fromString("www.apache.org"), fromString("."), 1));
230+
assertEquals(3,
231+
UTF8String.lastOrdinalIndexOf(fromString("www.apache.org"), fromString("."), 2));
232+
assertEquals(-1,
233+
UTF8String.lastOrdinalIndexOf(fromString("www.apache.org"), fromString("."), 3));
234+
assertEquals(-1,
235+
UTF8String.lastOrdinalIndexOf(fromString("www.apache.org"), fromString("#"), 0));
236+
assertEquals(3,
237+
UTF8String.lastOrdinalIndexOf(fromString("www|||apache|||org"), fromString("|||"), 2));
238+
assertEquals(-1,
239+
UTF8String.lastOrdinalIndexOf(null, fromString("|||"), 1));
240+
assertEquals(-1,
241+
UTF8String.lastOrdinalIndexOf(fromString("www|||apache|||org"), null, 1));
242+
assertEquals(3,
243+
UTF8String.lastOrdinalIndexOf(fromString("数据砖砖头"), fromString("砖"), 1));
244+
assertEquals(-1,
245+
UTF8String.lastOrdinalIndexOf(fromString("砖头数据砖头"), fromString("砖"), -2));
187246
}
188247

189248
@Test
@@ -238,6 +297,46 @@ public void lastIndexOf() {
238297
assertEquals(3, fromString("数据砖头").lastIndexOf(fromString("头"), 3));
239298
}
240299

300+
@Test
301+
public void substring_index() {
302+
assertEquals(fromString("www.apache.org"),
303+
UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), 3));
304+
assertEquals(fromString("www.apache"),
305+
UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), 2));
306+
assertEquals(fromString("www"),
307+
UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), 1));
308+
assertEquals(fromString(""),
309+
UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), 0));
310+
assertEquals(fromString("org"),
311+
UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), -1));
312+
assertEquals(fromString("apache.org"),
313+
UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), -2));
314+
assertEquals(fromString("www.apache.org"),
315+
UTF8String.subStringIndex(fromString("www.apache.org"), fromString("."), -3));
316+
// str is empty string
317+
assertEquals(fromString(""),
318+
UTF8String.subStringIndex(fromString(""), fromString("."), 1));
319+
// empty string delim
320+
assertEquals(fromString(""),
321+
UTF8String.subStringIndex(fromString("www.apache.org"), fromString(""), 1));
322+
// delim does not exist in str
323+
assertEquals(fromString("www.apache.org"),
324+
UTF8String.subStringIndex(fromString("www.apache.org"), fromString("#"), 2));
325+
// delim is 2 chars
326+
assertEquals(fromString("www||apache"),
327+
UTF8String.subStringIndex(fromString("www||apache||org"), fromString("||"), 2));
328+
assertEquals(fromString("apache||org"),
329+
UTF8String.subStringIndex(fromString("www||apache||org"), fromString("||"), -2));
330+
// null
331+
assertEquals(null,
332+
UTF8String.subStringIndex(null, fromString("."), -2));
333+
assertEquals(null,
334+
UTF8String.subStringIndex(fromString("www.apache.org"), null, -2));
335+
// non ascii chars
336+
assertEquals(fromString("大千世界大"),
337+
UTF8String.subStringIndex(fromString("大千世界大千世界"), fromString("千"), 2));
338+
}
339+
241340
@Test
242341
public void reverse() {
243342
assertEquals(fromString("olleh"), fromString("hello").reverse());

0 commit comments

Comments
 (0)