Skip to content

Commit 12e108f

Browse files
committed
refine unittest
1 parent d92951b commit 12e108f

File tree

3 files changed

+47
-20
lines changed

3 files changed

+47
-20
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,7 @@ case class Substring_index(strExpr: Expression, delimExpr: Expression, countExpr
       } else {
         val idx = lastOrdinalIndexOf(strUtf8, delimUtf8, -count)
         if (idx != -1) {
-          strUtf8.substring(idx + 1, strUtf8.numChars())
+          strUtf8.substring(idx + delimUtf8.numChars(), strUtf8.numChars())
         } else {
           strUtf8
         }

sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -157,33 +157,60 @@ class StringFunctionsSuite extends QueryTest {
   }
 
   test("string substring_index function") {
-    val df = Seq(("ac,ab,ad,ab,cc", "aa", "zz")).toDF("a", "b", "c")
+    val df = Seq(("www.apache.org", ".", "zz")).toDF("a", "b", "c")
     checkAnswer(
-      df.select(substring_index($"a", ",", 2)),
-      Row("ac,ab"))
+      df.select(substring_index($"a", ".", 3)),
+      Row("www.apache.org"))
     checkAnswer(
-      df.select(substring_index($"a", "ab", 2)),
-      Row("ac,ab,ad,"))
+      df.select(substring_index($"a", ".", 2)),
+      Row("www.apache"))
     checkAnswer(
-      df.select(substring_index(lit(""), "ab", 2)),
+      df.select(substring_index($"a", ".", 1)),
+      Row("www"))
+    checkAnswer(
+      df.select(substring_index($"a", ".", 0)),
+      Row(""))
+    checkAnswer(
+      df.select(substring_index(lit("www.apache.org"), ".", -1)),
+      Row("org"))
+    checkAnswer(
+      df.select(substring_index(lit("www.apache.org"), ".", -2)),
+      Row("apache.org"))
+    checkAnswer(
+      df.select(substring_index(lit("www.apache.org"), ".", -3)),
+      Row("www.apache.org"))
+    // str is empty string
+    checkAnswer(
+      df.select(substring_index(lit(""), ".", 1)),
+      Row(""))
+    // empty string delim
+    checkAnswer(
+      df.select(substring_index(lit("www.apache.org"), "", 1)),
       Row(""))
+    // delim does not exist in str
     checkAnswer(
-      df.select(substring_index(lit(null), "ab", 2)),
+      df.select(substring_index(lit("www.apache.org"), "#", 1)),
+      Row("www.apache.org"))
+    // delim is 2 chars
+    checkAnswer(
+      df.select(substring_index(lit("www||apache||org"), "||", 2)),
+      Row("www||apache"))
+    checkAnswer(
+      df.select(substring_index(lit("www||apache||org"), "||", -2)),
+      Row("apache||org"))
+    // null
+    checkAnswer(
+      df.select(substring_index(lit(null), "||", 2)),
+      Row(null))
+    checkAnswer(
+      df.select(substring_index(lit("www.apache.org"), null, 2)),
       Row(null))
+    // non ascii chars
     // scalastyle:off
     checkAnswer(
-      df.select(substring_index(lit("大千世界大千世界"), "", 2)),
+      df.selectExpr("""substring_index("大千世界大千世界", "千", 2)"""),
       Row("大千世界大"))
     // scalastyle:on
-    checkAnswer(
-      df.selectExpr("""substring_index(a, ",", 2)"""),
-      Row("ac,ab"))
-    checkAnswer(
-      df.selectExpr("""substring_index(a, ",", -2)"""),
-      Row("ab,cc"))
-    checkAnswer(
-      df.selectExpr("""substring_index(a, ",", 10)"""),
-      Row("ac,ab,ad,ab,cc"))
   }
 
   test("string locate function") {

unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ private int firstOfCurrentCodePoint(int bytePos) {
     throw new RuntimeException("Invalid utf8 string");
   }
 
-  private int endByte(int startCodePoint) {
+  private int indexEnd(int startCodePoint) {
     int i = numBytes -1; // position in byte
     int c = numChars() - 1; // position in character
     while (i >=0 && c > startCodePoint) {
@@ -398,7 +398,7 @@ public int lastIndexOf(UTF8String v, int startCodePoint) {
     if (numBytes == 0) {
       return -1;
     }
-    int fromIndexEnd = endByte(startCodePoint);
+    int fromIndexEnd = indexEnd(startCodePoint);
     int count = startCodePoint;
     int vNumChars = v.numChars();
     do {

0 commit comments

Comments
 (0)