Skip to content

Commit 6996bd2

Browse files
zhichao-lirxin
authored andcommitted
[SPARK-8264][SQL]add substring_index function
This PR is based on #7533 , thanks to zhichao-li Closes #7533 Author: zhichao.li <[email protected]> Author: Davies Liu <[email protected]> Closes #7843 from davies/str_index and squashes the following commits: 391347b [Davies Liu] add python api 3ce7802 [Davies Liu] fix substringIndex f2d29a1 [Davies Liu] Merge branch 'master' of github.com:apache/spark into str_index 515519b [zhichao.li] add foldable and remove null checking 9546991 [zhichao.li] scala style 67c253a [zhichao.li] hide some apis and clean code b19b013 [zhichao.li] add codegen and clean code ac863e9 [zhichao.li] reduce the calling of numChars 12e108f [zhichao.li] refine unittest d92951b [zhichao.li] add lastIndexOf 52d7b03 [zhichao.li] add substring_index function
1 parent 03377d2 commit 6996bd2

File tree

8 files changed

+261
-2
lines changed

8 files changed

+261
-2
lines changed

python/pyspark/sql/functions.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -920,6 +920,25 @@ def trunc(date, format):
920920
return Column(sc._jvm.functions.trunc(_to_java_column(date), format))
921921

922922

923+
@since(1.5)
924+
@ignore_unicode_prefix
925+
def substring_index(str, delim, count):
926+
"""
927+
Returns the substring from string str before count occurrences of the delimiter delim.
928+
If count is positive, everything the left of the final delimiter (counting from left) is
929+
returned. If count is negative, every to the right of the final delimiter (counting from the
930+
right) is returned. substring_index performs a case-sensitive match when searching for delim.
931+
932+
>>> df = sqlContext.createDataFrame([('a.b.c.d',)], ['s'])
933+
>>> df.select(substring_index(df.s, '.', 2).alias('s')).collect()
934+
[Row(s=u'a.b')]
935+
>>> df.select(substring_index(df.s, '.', -3).alias('s')).collect()
936+
[Row(s=u'b.c.d')]
937+
"""
938+
sc = SparkContext._active_spark_context
939+
return Column(sc._jvm.functions.substring_index(_to_java_column(str), delim, count))
940+
941+
923942
@since(1.5)
924943
def size(col):
925944
"""

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ object FunctionRegistry {
199199
expression[StringSplit]("split"),
200200
expression[Substring]("substr"),
201201
expression[Substring]("substring"),
202+
expression[SubstringIndex]("substring_index"),
202203
expression[StringTrim]("trim"),
203204
expression[UnBase64]("unbase64"),
204205
expression[Upper]("ucase"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,31 @@ case class StringInstr(str: Expression, substr: Expression)
421421
}
422422
}
423423

424+
/**
425+
* Returns the substring from string str before count occurrences of the delimiter delim.
426+
* If count is positive, everything the left of the final delimiter (counting from left) is
427+
* returned. If count is negative, every to the right of the final delimiter (counting from the
428+
* right) is returned. substring_index performs a case-sensitive match when searching for delim.
429+
*/
430+
case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: Expression)
431+
extends TernaryExpression with ImplicitCastInputTypes {
432+
433+
override def dataType: DataType = StringType
434+
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
435+
override def children: Seq[Expression] = Seq(strExpr, delimExpr, countExpr)
436+
override def prettyName: String = "substring_index"
437+
438+
override def nullSafeEval(str: Any, delim: Any, count: Any): Any = {
439+
str.asInstanceOf[UTF8String].subStringIndex(
440+
delim.asInstanceOf[UTF8String],
441+
count.asInstanceOf[Int])
442+
}
443+
444+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
445+
defineCodeGen(ctx, ev, (str, delim, count) => s"$str.subStringIndex($delim, $count)")
446+
}
447+
}
448+
424449
/**
425450
* A function that returns the position of the first occurrence of substr
426451
* in given string after position pos.

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
2020
import org.apache.spark.SparkFunSuite
2121
import org.apache.spark.sql.catalyst.dsl.expressions._
2222
import org.apache.spark.sql.types._
23+
import org.apache.spark.unsafe.types.UTF8String
2324

2425

2526
class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -187,6 +188,36 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
187188
checkEvaluation(s.substring(0), "example", row)
188189
}
189190

191+
test("string substring_index function") {
192+
checkEvaluation(
193+
SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(3)), "www.apache.org")
194+
checkEvaluation(
195+
SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(2)), "www.apache")
196+
checkEvaluation(
197+
SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(1)), "www")
198+
checkEvaluation(
199+
SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(0)), "")
200+
checkEvaluation(
201+
SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-3)), "www.apache.org")
202+
checkEvaluation(
203+
SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-2)), "apache.org")
204+
checkEvaluation(
205+
SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-1)), "org")
206+
checkEvaluation(
207+
SubstringIndex(Literal(""), Literal("."), Literal(-2)), "")
208+
checkEvaluation(
209+
SubstringIndex(Literal.create(null, StringType), Literal("."), Literal(-2)), null)
210+
checkEvaluation(SubstringIndex(
211+
Literal("www.apache.org"), Literal.create(null, StringType), Literal(-2)), null)
212+
// non ascii chars
213+
// scalastyle:off
214+
checkEvaluation(
215+
SubstringIndex(Literal("大千世界大千世界"), Literal( ""), Literal(2)), "大千世界大")
216+
// scalastyle:on
217+
checkEvaluation(
218+
SubstringIndex(Literal("www||apache||org"), Literal( "||"), Literal(2)), "www||apache")
219+
}
220+
190221
test("LIKE literal Regular Expression") {
191222
checkEvaluation(Literal.create(null, StringType).like("a"), null)
192223
checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null)

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1788,8 +1788,18 @@ object functions {
17881788
def instr(str: Column, substring: String): Column = StringInstr(str.expr, lit(substring).expr)
17891789

17901790
/**
1791-
* Locate the position of the first occurrence of substr in a string column.
1791+
* Returns the substring from string str before count occurrences of the delimiter delim.
1792+
* If count is positive, everything the left of the final delimiter (counting from left) is
1793+
* returned. If count is negative, every to the right of the final delimiter (counting from the
1794+
* right) is returned. substring_index performs a case-sensitive match when searching for delim.
17921795
*
1796+
* @group string_funcs
1797+
*/
1798+
def substring_index(str: Column, delim: String, count: Int): Column =
1799+
SubstringIndex(str.expr, lit(delim).expr, lit(count).expr)
1800+
1801+
/**
1802+
* Locate the position of the first occurrence of substr.
17931803
* NOTE: The position is not zero based, but 1 based index, returns 0 if substr
17941804
* could not be found in str.
17951805
*

sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,63 @@ class StringFunctionsSuite extends QueryTest {
163163
Row(1))
164164
}
165165

166+
test("string substring_index function") {
167+
val df = Seq(("www.apache.org", ".", "zz")).toDF("a", "b", "c")
168+
checkAnswer(
169+
df.select(substring_index($"a", ".", 3)),
170+
Row("www.apache.org"))
171+
checkAnswer(
172+
df.select(substring_index($"a", ".", 2)),
173+
Row("www.apache"))
174+
checkAnswer(
175+
df.select(substring_index($"a", ".", 1)),
176+
Row("www"))
177+
checkAnswer(
178+
df.select(substring_index($"a", ".", 0)),
179+
Row(""))
180+
checkAnswer(
181+
df.select(substring_index(lit("www.apache.org"), ".", -1)),
182+
Row("org"))
183+
checkAnswer(
184+
df.select(substring_index(lit("www.apache.org"), ".", -2)),
185+
Row("apache.org"))
186+
checkAnswer(
187+
df.select(substring_index(lit("www.apache.org"), ".", -3)),
188+
Row("www.apache.org"))
189+
// str is empty string
190+
checkAnswer(
191+
df.select(substring_index(lit(""), ".", 1)),
192+
Row(""))
193+
// empty string delim
194+
checkAnswer(
195+
df.select(substring_index(lit("www.apache.org"), "", 1)),
196+
Row(""))
197+
// delim does not exist in str
198+
checkAnswer(
199+
df.select(substring_index(lit("www.apache.org"), "#", 1)),
200+
Row("www.apache.org"))
201+
// delim is 2 chars
202+
checkAnswer(
203+
df.select(substring_index(lit("www||apache||org"), "||", 2)),
204+
Row("www||apache"))
205+
checkAnswer(
206+
df.select(substring_index(lit("www||apache||org"), "||", -2)),
207+
Row("apache||org"))
208+
// null
209+
checkAnswer(
210+
df.select(substring_index(lit(null), "||", 2)),
211+
Row(null))
212+
checkAnswer(
213+
df.select(substring_index(lit("www.apache.org"), null, 2)),
214+
Row(null))
215+
// non ascii chars
216+
// scalastyle:off
217+
checkAnswer(
218+
df.selectExpr("""substring_index("大千世界大千世界", "千", 2)"""),
219+
Row("大千世界大"))
220+
// scalastyle:on
221+
}
222+
166223
test("string locate function") {
167224
val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d")
168225

unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ public byte[] getBytes() {
198198
*/
199199
public UTF8String substring(final int start, final int until) {
200200
if (until <= start || start >= numBytes) {
201-
return fromBytes(new byte[0]);
201+
return UTF8String.EMPTY_UTF8;
202202
}
203203

204204
int i = 0;
@@ -406,6 +406,84 @@ public int indexOf(UTF8String v, int start) {
406406
return -1;
407407
}
408408

409+
/**
410+
* Find the `str` from left to right.
411+
*/
412+
private int find(UTF8String str, int start) {
413+
assert (str.numBytes > 0);
414+
while (start <= numBytes - str.numBytes) {
415+
if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) {
416+
return start;
417+
}
418+
start += 1;
419+
}
420+
return -1;
421+
}
422+
423+
/**
424+
* Find the `str` from right to left.
425+
*/
426+
private int rfind(UTF8String str, int start) {
427+
assert (str.numBytes > 0);
428+
while (start >= 0) {
429+
if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) {
430+
return start;
431+
}
432+
start -= 1;
433+
}
434+
return -1;
435+
}
436+
437+
/**
438+
* Returns the substring from string str before count occurrences of the delimiter delim.
439+
* If count is positive, everything the left of the final delimiter (counting from left) is
440+
* returned. If count is negative, every to the right of the final delimiter (counting from the
441+
* right) is returned. subStringIndex performs a case-sensitive match when searching for delim.
442+
*/
443+
public UTF8String subStringIndex(UTF8String delim, int count) {
444+
if (delim.numBytes == 0 || count == 0) {
445+
return EMPTY_UTF8;
446+
}
447+
if (count > 0) {
448+
int idx = -1;
449+
while (count > 0) {
450+
idx = find(delim, idx + 1);
451+
if (idx >= 0) {
452+
count --;
453+
} else {
454+
// can not find enough delim
455+
return this;
456+
}
457+
}
458+
if (idx == 0) {
459+
return EMPTY_UTF8;
460+
}
461+
byte[] bytes = new byte[idx];
462+
copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, idx);
463+
return fromBytes(bytes);
464+
465+
} else {
466+
int idx = numBytes - delim.numBytes + 1;
467+
count = -count;
468+
while (count > 0) {
469+
idx = rfind(delim, idx - 1);
470+
if (idx >= 0) {
471+
count --;
472+
} else {
473+
// can not find enough delim
474+
return this;
475+
}
476+
}
477+
if (idx + delim.numBytes == numBytes) {
478+
return EMPTY_UTF8;
479+
}
480+
int size = numBytes - delim.numBytes - idx;
481+
byte[] bytes = new byte[size];
482+
copyMemory(base, offset + idx + delim.numBytes, bytes, BYTE_ARRAY_OFFSET, size);
483+
return fromBytes(bytes);
484+
}
485+
}
486+
409487
/**
410488
* Returns str, right-padded with pad to a length of len
411489
* For example:

unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,44 @@ public void indexOf() {
240240
assertEquals(3, fromString("数据砖头").indexOf(fromString("头"), 0));
241241
}
242242

243+
@Test
244+
public void substring_index() {
245+
assertEquals(fromString("www.apache.org"),
246+
fromString("www.apache.org").subStringIndex(fromString("."), 3));
247+
assertEquals(fromString("www.apache"),
248+
fromString("www.apache.org").subStringIndex(fromString("."), 2));
249+
assertEquals(fromString("www"),
250+
fromString("www.apache.org").subStringIndex(fromString("."), 1));
251+
assertEquals(fromString(""),
252+
fromString("www.apache.org").subStringIndex(fromString("."), 0));
253+
assertEquals(fromString("org"),
254+
fromString("www.apache.org").subStringIndex(fromString("."), -1));
255+
assertEquals(fromString("apache.org"),
256+
fromString("www.apache.org").subStringIndex(fromString("."), -2));
257+
assertEquals(fromString("www.apache.org"),
258+
fromString("www.apache.org").subStringIndex(fromString("."), -3));
259+
// str is empty string
260+
assertEquals(fromString(""),
261+
fromString("").subStringIndex(fromString("."), 1));
262+
// empty string delim
263+
assertEquals(fromString(""),
264+
fromString("www.apache.org").subStringIndex(fromString(""), 1));
265+
// delim does not exist in str
266+
assertEquals(fromString("www.apache.org"),
267+
fromString("www.apache.org").subStringIndex(fromString("#"), 2));
268+
// delim is 2 chars
269+
assertEquals(fromString("www||apache"),
270+
fromString("www||apache||org").subStringIndex(fromString("||"), 2));
271+
assertEquals(fromString("apache||org"),
272+
fromString("www||apache||org").subStringIndex(fromString("||"), -2));
273+
// non ascii chars
274+
assertEquals(fromString("大千世界大"),
275+
fromString("大千世界大千世界").subStringIndex(fromString("千"), 2));
276+
// overlapped delim
277+
assertEquals(fromString("||"), fromString("||||||").subStringIndex(fromString("|||"), 3));
278+
assertEquals(fromString("|||"), fromString("||||||").subStringIndex(fromString("|||"), -4));
279+
}
280+
243281
@Test
244282
public void reverse() {
245283
assertEquals(fromString("olleh"), fromString("hello").reverse());

0 commit comments

Comments
 (0)