Skip to content

Commit dc280a4

Browse files
committed
Codegen for StringLocate
1 parent dd0614f commit dc280a4

File tree

2 files changed

+125
-6
lines changed

2 files changed

+125
-6
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,7 @@ case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr:
550550
* in given string after position pos.
551551
*/
552552
case class StringLocate(substr: Expression, str: Expression, start: Expression)
553-
extends TernaryExpression with ImplicitCastInputTypes with CodegenFallback {
553+
extends TernaryExpression with ImplicitCastInputTypes {
554554

555555
def this(substr: Expression, str: Expression) = {
556556
this(substr, str, Literal(0))
@@ -582,6 +582,109 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)
582582
}
583583
}
584584

585+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
586+
if (start.foldable) {
587+
val startValue = start.eval()
588+
if (startValue != null) {
589+
if (str.foldable) { // both start and str are foldable
590+
val strValue = str.eval()
591+
if (strValue != null) {
592+
val strValueString = strValue.asInstanceOf[UTF8String].toString
593+
val strUTF8 = ctx.freshName("strUTF8")
594+
val substrGen = substr.gen(ctx)
595+
s"""
596+
${substrGen.code}
597+
boolean ${ev.isNull} = ${substrGen.isNull};
598+
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
599+
if (!${ev.isNull}) {
600+
UTF8String $strUTF8 = UTF8String.fromString("$strValueString");
601+
${ev.primitive} = $strUTF8.indexOf(${substrGen.primitive}, $startValue) + 1;
602+
}
603+
"""
604+
} else { // strValue == null
605+
s"""
606+
boolean ${ev.isNull} = true;
607+
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
608+
"""
609+
}
610+
} else { // only start is foldable
611+
val strGen = str.gen(ctx)
612+
val substrGen = substr.gen(ctx)
613+
s"""
614+
${substrGen.code}
615+
boolean ${ev.isNull} = ${substrGen.isNull};
616+
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
617+
if (!${ev.isNull}) {
618+
${strGen.code}
619+
if (!${strGen.isNull}) {
620+
${ev.primitive} =
621+
${strGen.primitive}.indexOf(${substrGen.primitive}, $startValue) + 1;
622+
} else {
623+
${ev.isNull} = true;
624+
}
625+
}
626+
"""
627+
}
628+
} else { // startValue == null
629+
s"""
630+
boolean ${ev.isNull} = false;
631+
${ctx.javaType(dataType)} ${ev.primitive} = 0;
632+
"""
633+
}
634+
} else if (str.foldable) { // only str is foldable
635+
val strValue = str.eval()
636+
if (strValue != null) {
637+
val substrGen = substr.gen(ctx)
638+
val startGen = start.gen(ctx)
639+
val strValueString = strValue.asInstanceOf[UTF8String].toString
640+
val strUTF8 = ctx.freshName("strUTF8")
641+
s"""
642+
${substrGen.code}
643+
boolean ${ev.isNull} = ${substrGen.isNull};
644+
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
645+
if (!${ev.isNull}) {
646+
${startGen.code}
647+
if (!${startGen.isNull}) {
648+
UTF8String $strUTF8 = UTF8String.fromString("$strValueString");
649+
${ev.primitive} =
650+
$strUTF8.indexOf(${substrGen.primitive}, ${startGen.primitive}) + 1;
651+
} else {
652+
${ev.primitive} = 0;
653+
}
654+
}
655+
"""
656+
} else {
657+
s"""
658+
boolean ${ev.isNull} = true;
659+
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
660+
"""
661+
}
662+
} else { // neither start nor str is foldable
663+
val substrGen = substr.gen(ctx)
664+
val strGen = str.gen(ctx)
665+
val startGen = start.gen(ctx)
666+
s"""
667+
${startGen.code}
668+
boolean ${ev.isNull} = true;
669+
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
670+
if (!${startGen.isNull}) {
671+
${substrGen.code}
672+
if (!${substrGen.isNull}) {
673+
${strGen.code}
674+
if (!${strGen.isNull}) {
675+
${ev.isNull} = false;
676+
${ev.primitive} =
677+
${strGen.primitive}.indexOf(${substrGen.primitive}, ${startGen.primitive}) + 1;
678+
}
679+
}
680+
} else {
681+
${ev.isNull} = false;
682+
${ev.primitive} = 0;
683+
}
684+
"""
685+
}
686+
}
687+
585688
override def prettyName: String = "locate"
586689
}
587690

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -512,11 +512,11 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
512512
val row3 = create_row("aaads", null, "zz", 0)
513513
val row4 = create_row(null, null, null, 0)
514514

515-
checkEvaluation(new StringLocate(Literal("aa"), Literal("aaads")), 1, row1)
516-
checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(1)), 2, row1)
517-
checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(2)), 0, row1)
518-
checkEvaluation(new StringLocate(Literal("de"), Literal("aaads")), 0, row1)
519-
checkEvaluation(StringLocate(Literal("de"), Literal("aaads"), 1), 0, row1)
515+
checkEvaluation(new StringLocate(Literal("aa"), Literal("aaads")), 1)
516+
checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(1)), 2)
517+
checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(2)), 0)
518+
checkEvaluation(new StringLocate(Literal("de"), Literal("aaads")), 0)
519+
checkEvaluation(StringLocate(Literal("de"), Literal("aaads"), 1), 0)
520520

521521
checkEvaluation(new StringLocate(s2, s1), 1, row1)
522522
checkEvaluation(StringLocate(s2, s1, s4), 2, row1)
@@ -525,6 +525,22 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
525525
checkEvaluation(new StringLocate(s2, s1), null, row2)
526526
checkEvaluation(new StringLocate(s2, s1), null, row3)
527527
checkEvaluation(new StringLocate(s2, s1, Literal.create(null, IntegerType)), 0, row4)
528+
529+
// str is non-foldable
530+
checkEvaluation(StringLocate(Literal("aa"), NonFoldableLiteral("aaads"), Literal(1)), 2)
531+
checkEvaluation(
532+
StringLocate(Literal("aa"), NonFoldableLiteral.create(null, StringType), Literal(1)), null)
533+
// start is non-foldable
534+
checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), NonFoldableLiteral(1)), 2)
535+
checkEvaluation(
536+
StringLocate(Literal("aa"), Literal("aaads"),
537+
NonFoldableLiteral.create(null, IntegerType)), 0)
538+
// str and start are both non-foldable
539+
checkEvaluation(
540+
StringLocate(Literal("aa"), NonFoldableLiteral("aaads"), NonFoldableLiteral(1)), 2)
541+
checkEvaluation(
542+
StringLocate(Literal("aa"), NonFoldableLiteral.create(null, StringType),
543+
NonFoldableLiteral.create(null, IntegerType)), 0)
528544
}
529545

530546
test("LPAD/RPAD") {

0 commit comments

Comments
 (0)