Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr:
* in given string after position pos.
*/
case class StringLocate(substr: Expression, str: Expression, start: Expression)
extends TernaryExpression with ImplicitCastInputTypes with CodegenFallback {
extends TernaryExpression with ImplicitCastInputTypes {

def this(substr: Expression, str: Expression) = {
this(substr, str, Literal(0))
Expand Down Expand Up @@ -582,6 +582,109 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)
}
}

/**
 * Generates Java code that locates `substr` inside `str`, searching from `start`.
 *
 * The generated code yields:
 *  - 0 (non-null) when `start` evaluates to null, regardless of the other inputs;
 *  - null when `start` is non-null but `substr` or `str` evaluates to null;
 *  - otherwise `str.indexOf(substr, start) + 1`, so 0 means "not found".
 *
 * @param ctx codegen context used to generate code for the child expressions
 * @param ev  holder naming the null flag and the result variable of this expression
 * @return the Java source fragment evaluating this expression
 */
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
  // A single code path is used for every combination of foldable / non-foldable children.
  // Calling UTF8String.indexOf(substr, start) needs no per-expression preparation, so
  // special-casing foldable inputs saves nothing. It also required splicing the folded
  // string value into the generated source as an unescaped Java string literal (which breaks
  // on quotes, backslashes or newlines) and made the null precedence of `start` differ
  // between branches.
  val substrGen = substr.gen(ctx)
  val strGen = str.gen(ctx)
  val startGen = start.gen(ctx)
  s"""
    ${startGen.code}
    boolean ${ev.isNull} = true;
    ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
    if (!${startGen.isNull}) {
      ${substrGen.code}
      if (!${substrGen.isNull}) {
        ${strGen.code}
        if (!${strGen.isNull}) {
          ${ev.isNull} = false;
          ${ev.primitive} =
            ${strGen.primitive}.indexOf(${substrGen.primitive}, ${startGen.primitive}) + 1;
        }
      }
    } else {
      // A null start position always produces 0 rather than null.
      ${ev.isNull} = false;
      ${ev.primitive} = 0;
    }
  """
}

override def prettyName: String = "locate"
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -512,11 +512,11 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val row3 = create_row("aaads", null, "zz", 0)
val row4 = create_row(null, null, null, 0)

checkEvaluation(new StringLocate(Literal("aa"), Literal("aaads")), 1, row1)
checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(1)), 2, row1)
checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(2)), 0, row1)
checkEvaluation(new StringLocate(Literal("de"), Literal("aaads")), 0, row1)
checkEvaluation(StringLocate(Literal("de"), Literal("aaads"), 1), 0, row1)
checkEvaluation(new StringLocate(Literal("aa"), Literal("aaads")), 1)
checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(1)), 2)
checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(2)), 0)
checkEvaluation(new StringLocate(Literal("de"), Literal("aaads")), 0)
checkEvaluation(StringLocate(Literal("de"), Literal("aaads"), 1), 0)

checkEvaluation(new StringLocate(s2, s1), 1, row1)
checkEvaluation(StringLocate(s2, s1, s4), 2, row1)
Expand All @@ -525,6 +525,22 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(new StringLocate(s2, s1), null, row2)
checkEvaluation(new StringLocate(s2, s1), null, row3)
checkEvaluation(new StringLocate(s2, s1, Literal.create(null, IntegerType)), 0, row4)

// str is non-foldable
checkEvaluation(StringLocate(Literal("aa"), NonFoldableLiteral("aaads"), Literal(1)), 2)
checkEvaluation(
StringLocate(Literal("aa"), NonFoldableLiteral.create(null, StringType), Literal(1)), null)
// start is non-foldable
checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), NonFoldableLiteral(1)), 2)
checkEvaluation(
StringLocate(Literal("aa"), Literal("aaads"),
NonFoldableLiteral.create(null, IntegerType)), 0)
// str and start are both non-foldable
checkEvaluation(
StringLocate(Literal("aa"), NonFoldableLiteral("aaads"), NonFoldableLiteral(1)), 2)
checkEvaluation(
StringLocate(Literal("aa"), NonFoldableLiteral.create(null, StringType),
NonFoldableLiteral.create(null, IntegerType)), 0)
}

test("LPAD/RPAD") {
Expand Down