Skip to content

Commit cdc7be6

Browse files
committed
Added code generation for pure string mode.
1 parent a61c4e4 commit cdc7be6

File tree

1 file changed

+27
-6
lines changed

1 file changed

+27
-6
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas
6666

6767

6868
/**
69-
* An expression that concatenates multiple input strings into a single string, using a given
70-
* separator (the first child).
69+
* An expression that concatenates multiple input strings or arrays of strings into a single
70+
* string, using a given separator (the first child).
7171
*
7272
* Returns null if the separator is null. Otherwise, concat_ws skips all null values.
7373
*/
@@ -78,11 +78,10 @@ case class ConcatWs(children: Seq[Expression])
7878

7979
override def prettyName: String = "concat_ws"
8080

81+
/** The first child (the separator) is a string; each remaining child is a string or an array of strings. */
8182
override def inputTypes: Seq[AbstractDataType] = {
82-
Seq.fill(children.size)(TypeCollection(
83-
ArrayType(StringType, true),
84-
ArrayType(StringType, false),
85-
StringType))
83+
val arrayOrStr = TypeCollection(ArrayType(StringType), StringType)
84+
StringType +: Seq.fill(children.size - 1)(arrayOrStr)
8685
}
8786

8887
override def dataType: DataType = StringType
@@ -100,6 +99,28 @@ case class ConcatWs(children: Seq[Expression])
10099
}
101100
UTF8String.concatWs(flatInputs.head, flatInputs.tail : _*)
102101
}
102+
103+
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
104+
if (children.forall(_.dataType == StringType)) {
105+
// All children are strings. In that case we can construct a fixed size array.
106+
val evals = children.map(_.gen(ctx))
107+
108+
val inputs = evals.map { eval =>
109+
s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}"
110+
}.mkString(", ")
111+
112+
evals.map(_.code).mkString("\n") + s"""
113+
boolean ${ev.isNull} = false;
114+
UTF8String ${ev.primitive} = UTF8String.concatWs($inputs);
115+
if (${ev.primitive} == null) {
116+
${ev.isNull} = true;
117+
}
118+
"""
119+
} else {
120+
// Contains a mix of strings and array<string>s. Fall back to interpreted mode for now.
121+
super.genCode(ctx, ev)
122+
}
123+
}
103124
}
104125

105126

0 commit comments

Comments
 (0)