From dc280a44dcfdb338bc8f0505acd977faf994df4a Mon Sep 17 00:00:00 2001
From: Yijie Shen
Date: Tue, 18 Aug 2015 19:30:11 +0800
Subject: [PATCH] Codegen for StringLocate

Generate code for StringLocate by compiling all three children through
their own gen(ctx). Constant children are never spliced into the generated
Java source as text, and a null start position yields 0 for every
combination of foldable and non-foldable children.

---
 .../expressions/stringOperations.scala        | 38 ++++++++++++++++++-
 .../expressions/StringExpressionsSuite.scala  | 32 +++++++++++++---
 2 files changed, 64 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 134f1aa2af9a..c540bcd9bc26 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -550,7 +550,7 @@ case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr:
  * in given string after position pos.
  */
 case class StringLocate(substr: Expression, str: Expression, start: Expression)
-  extends TernaryExpression with ImplicitCastInputTypes with CodegenFallback {
+  extends TernaryExpression with ImplicitCastInputTypes {
 
   def this(substr: Expression, str: Expression) = {
     this(substr, str, Literal(0))
@@ -582,6 +582,42 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)
     }
   }
 
+  /**
+   * Generates Java code computing the 1-based position of `substr` in `str`, searching from
+   * position `start`.
+   *
+   * Every child, foldable or not, is compiled through its own `gen(ctx)`. A constant `str` is
+   * never spliced into the generated source as text, so quotes, backslashes or line breaks in
+   * it cannot break the generated Java. Null handling is identical for every combination of
+   * foldable and non-foldable children:
+   *  - `start` is null: the result is 0 (not null); `substr` and `str` are not evaluated;
+   *  - `start` is not null and `substr` or `str` is null: the result is null.
+   */
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val substrGen = substr.gen(ctx)
+    val strGen = str.gen(ctx)
+    val startGen = start.gen(ctx)
+    s"""
+      ${startGen.code}
+      boolean ${ev.isNull} = true;
+      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+      if (!${startGen.isNull}) {
+        ${substrGen.code}
+        if (!${substrGen.isNull}) {
+          ${strGen.code}
+          if (!${strGen.isNull}) {
+            ${ev.isNull} = false;
+            ${ev.primitive} =
+              ${strGen.primitive}.indexOf(${substrGen.primitive}, ${startGen.primitive}) + 1;
+          }
+        }
+      } else {
+        ${ev.isNull} = false;
+        ${ev.primitive} = 0;
+      }
+    """
+  }
+
   override def prettyName: String = "locate"
 }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 426dc272471a..f056f5183af3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -512,11 +512,11 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     val row3 = create_row("aaads", null, "zz", 0)
     val row4 = create_row(null, null, null, 0)
 
-    checkEvaluation(new StringLocate(Literal("aa"), Literal("aaads")), 1, row1)
-    checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(1)), 2, row1)
-    checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(2)), 0, row1)
-    checkEvaluation(new StringLocate(Literal("de"), Literal("aaads")), 0, row1)
-    checkEvaluation(StringLocate(Literal("de"), Literal("aaads"), 1), 0, row1)
+    checkEvaluation(new StringLocate(Literal("aa"), Literal("aaads")), 1)
+    checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(1)), 2)
+    checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(2)), 0)
+    checkEvaluation(new StringLocate(Literal("de"), Literal("aaads")), 0)
+    checkEvaluation(StringLocate(Literal("de"), Literal("aaads"), 1), 0)
 
     checkEvaluation(new StringLocate(s2, s1), 1, row1)
     checkEvaluation(StringLocate(s2, s1, s4), 2, row1)
@@ -525,6 +525,28 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(new StringLocate(s2, s1), null, row2)
     checkEvaluation(new StringLocate(s2, s1), null, row3)
     checkEvaluation(new StringLocate(s2, s1, Literal.create(null, IntegerType)), 0, row4)
+
+    // str is non-foldable
+    checkEvaluation(StringLocate(Literal("aa"), NonFoldableLiteral("aaads"), Literal(1)), 2)
+    checkEvaluation(
+      StringLocate(Literal("aa"), NonFoldableLiteral.create(null, StringType), Literal(1)), null)
+    // start is non-foldable
+    checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), NonFoldableLiteral(1)), 2)
+    checkEvaluation(
+      StringLocate(Literal("aa"), Literal("aaads"),
+        NonFoldableLiteral.create(null, IntegerType)), 0)
+    // str and start are both non-foldable
+    checkEvaluation(
+      StringLocate(Literal("aa"), NonFoldableLiteral("aaads"), NonFoldableLiteral(1)), 2)
+    checkEvaluation(
+      StringLocate(Literal("aa"), NonFoldableLiteral.create(null, StringType),
+        NonFoldableLiteral.create(null, IntegerType)), 0)
+    // a quote in a foldable str must not break the generated Java source
+    checkEvaluation(StringLocate(Literal("\""), Literal("a\"b"), NonFoldableLiteral(0)), 2)
+    // a null start yields 0 even when substr is null and str is foldable
+    checkEvaluation(
+      StringLocate(Literal.create(null, StringType), Literal("aaads"),
+        NonFoldableLiteral.create(null, IntegerType)), 0)
   }
 
   test("LPAD/RPAD") {