@@ -66,8 +66,8 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas
6666
6767
6868/**
69- * An expression that concatenates multiple input strings into a single string, using a given
70- * separator (the first child).
69+ * An expression that concatenates multiple input strings or array of strings into a single string,
70+ * using a given separator (the first child).
7171 *
7272 * Returns null if the separator is null. Otherwise, concat_ws skips all null values.
7373 */
@@ -78,11 +78,10 @@ case class ConcatWs(children: Seq[Expression])
7878
7979 override def prettyName : String = " concat_ws"
8080
81+ /** The 1st child (separator) is str, and rest are either str or array of str. */
8182 override def inputTypes : Seq [AbstractDataType ] = {
82- Seq .fill(children.size)(TypeCollection (
83- ArrayType (StringType , true ),
84- ArrayType (StringType , false ),
85- StringType ))
83+ val arrayOrStr = TypeCollection (ArrayType (StringType ), StringType )
84+ StringType +: Seq .fill(children.size - 1 )(arrayOrStr)
8685 }
8786
8887 override def dataType : DataType = StringType
@@ -100,6 +99,28 @@ case class ConcatWs(children: Seq[Expression])
10099 }
101100 UTF8String .concatWs(flatInputs.head, flatInputs.tail : _* )
102101 }
102+
103+ override protected def genCode (ctx : CodeGenContext , ev : GeneratedExpressionCode ): String = {
104+ if (children.forall(_.dataType == StringType )) {
105+ // All children are strings. In that case we can construct a fixed size array.
106+ val evals = children.map(_.gen(ctx))
107+
108+ val inputs = evals.map { eval =>
109+ s " ${eval.isNull} ? (UTF8String)null : ${eval.primitive}"
110+ }.mkString(" , " )
111+
112+ evals.map(_.code).mkString(" \n " ) + s """
113+ boolean ${ev.isNull} = false;
114+ UTF8String ${ev.primitive} = UTF8String.concatWs( $inputs);
115+ if ( ${ev.primitive} == null) {
116+ ${ev.isNull} = true;
117+ }
118+ """
119+ } else {
120+ // Contains a mix of strings and array<string>s. Fall back to interpreted mode for now.
121+ super .genCode(ctx, ev)
122+ }
123+ }
103124}
104125
105126
0 commit comments