Skip to content

Commit 52fc5c1

Browse files
maropucloud-fan
authored andcommitted
[SPARK-22825][SQL] Fix incorrect results of Casting Array to String
## What changes were proposed in this pull request? This pr fixed the issue when casting arrays into strings; ``` scala> val df = spark.range(10).select('id.cast("integer")).agg(collect_list('id).as('ids)) scala> df.write.saveAsTable("t") scala> sql("SELECT cast(ids as String) FROM t").show(false) +------------------------------------------------------------------+ |ids | +------------------------------------------------------------------+ |org.apache.spark.sql.catalyst.expressions.UnsafeArrayData8bc285df| +------------------------------------------------------------------+ ``` This pr modified the result into; ``` +------------------------------+ |ids | +------------------------------+ |[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]| +------------------------------+ ``` ## How was this patch tested? Added tests in `CastSuite` and `SQLQuerySuite`. Author: Takeshi Yamamuro <[email protected]> Closes #20024 from maropu/SPARK-22825.
1 parent df7fc3e commit 52fc5c1

File tree

4 files changed

+171
-2
lines changed

4 files changed

+171
-2
lines changed
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions.codegen;
19+
20+
import org.apache.spark.unsafe.Platform;
21+
import org.apache.spark.unsafe.array.ByteArrayMethods;
22+
import org.apache.spark.unsafe.types.UTF8String;
23+
24+
/**
25+
* A helper class to write {@link UTF8String}s to an internal buffer and build the concatenated
26+
* {@link UTF8String} at the end.
27+
*/
28+
public class UTF8StringBuilder {
29+
30+
private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH;
31+
32+
private byte[] buffer;
33+
private int cursor = Platform.BYTE_ARRAY_OFFSET;
34+
35+
public UTF8StringBuilder() {
36+
// Since initial buffer size is 16 in `StringBuilder`, we set the same size here
37+
this.buffer = new byte[16];
38+
}
39+
40+
// Grows the buffer by at least `neededSize`
41+
private void grow(int neededSize) {
42+
if (neededSize > ARRAY_MAX - totalSize()) {
43+
throw new UnsupportedOperationException(
44+
"Cannot grow internal buffer by size " + neededSize + " because the size after growing " +
45+
"exceeds size limitation " + ARRAY_MAX);
46+
}
47+
final int length = totalSize() + neededSize;
48+
if (buffer.length < length) {
49+
int newLength = length < ARRAY_MAX / 2 ? length * 2 : ARRAY_MAX;
50+
final byte[] tmp = new byte[newLength];
51+
Platform.copyMemory(
52+
buffer,
53+
Platform.BYTE_ARRAY_OFFSET,
54+
tmp,
55+
Platform.BYTE_ARRAY_OFFSET,
56+
totalSize());
57+
buffer = tmp;
58+
}
59+
}
60+
61+
private int totalSize() {
62+
return cursor - Platform.BYTE_ARRAY_OFFSET;
63+
}
64+
65+
public void append(UTF8String value) {
66+
grow(value.numBytes());
67+
value.writeToMemory(buffer, cursor);
68+
cursor += value.numBytes();
69+
}
70+
71+
public void append(String value) {
72+
append(UTF8String.fromString(value));
73+
}
74+
75+
public UTF8String build() {
76+
return UTF8String.fromBytes(buffer, 0, totalSize());
77+
}
78+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,28 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
206206
case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.dateToString(d)))
207207
case TimestampType => buildCast[Long](_,
208208
t => UTF8String.fromString(DateTimeUtils.timestampToString(t, timeZone)))
209+
case ArrayType(et, _) =>
210+
buildCast[ArrayData](_, array => {
211+
val builder = new UTF8StringBuilder
212+
builder.append("[")
213+
if (array.numElements > 0) {
214+
val toUTF8String = castToString(et)
215+
if (!array.isNullAt(0)) {
216+
builder.append(toUTF8String(array.get(0, et)).asInstanceOf[UTF8String])
217+
}
218+
var i = 1
219+
while (i < array.numElements) {
220+
builder.append(",")
221+
if (!array.isNullAt(i)) {
222+
builder.append(" ")
223+
builder.append(toUTF8String(array.get(i, et)).asInstanceOf[UTF8String])
224+
}
225+
i += 1
226+
}
227+
}
228+
builder.append("]")
229+
builder.build()
230+
})
209231
case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString))
210232
}
211233

@@ -597,6 +619,41 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
597619
"""
598620
}
599621

622+
private def writeArrayToStringBuilder(
623+
et: DataType,
624+
array: String,
625+
buffer: String,
626+
ctx: CodegenContext): String = {
627+
val elementToStringCode = castToStringCode(et, ctx)
628+
val funcName = ctx.freshName("elementToString")
629+
val elementToStringFunc = ctx.addNewFunction(funcName,
630+
s"""
631+
|private UTF8String $funcName(${ctx.javaType(et)} element) {
632+
| UTF8String elementStr = null;
633+
| ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)}
634+
| return elementStr;
635+
|}
636+
""".stripMargin)
637+
638+
val loopIndex = ctx.freshName("loopIndex")
639+
s"""
640+
|$buffer.append("[");
641+
|if ($array.numElements() > 0) {
642+
| if (!$array.isNullAt(0)) {
643+
| $buffer.append($elementToStringFunc(${ctx.getValue(array, et, "0")}));
644+
| }
645+
| for (int $loopIndex = 1; $loopIndex < $array.numElements(); $loopIndex++) {
646+
| $buffer.append(",");
647+
| if (!$array.isNullAt($loopIndex)) {
648+
| $buffer.append(" ");
649+
| $buffer.append($elementToStringFunc(${ctx.getValue(array, et, loopIndex)}));
650+
| }
651+
| }
652+
|}
653+
|$buffer.append("]");
654+
""".stripMargin
655+
}
656+
600657
private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = {
601658
from match {
602659
case BinaryType =>
@@ -608,6 +665,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
608665
val tz = ctx.addReferenceObj("timeZone", timeZone)
609666
(c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString(
610667
org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));"""
668+
case ArrayType(et, _) =>
669+
(c, evPrim, evNull) => {
670+
val buffer = ctx.freshName("buffer")
671+
val bufferClass = classOf[UTF8StringBuilder].getName
672+
val writeArrayElemCode = writeArrayToStringBuilder(et, c, buffer, ctx)
673+
s"""
674+
|$bufferClass $buffer = new $bufferClass();
675+
|$writeArrayElemCode;
676+
|$evPrim = $buffer.build();
677+
""".stripMargin
678+
}
611679
case _ =>
612680
(c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));"
613681
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -853,4 +853,29 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
853853
cast("2", LongType).genCode(ctx)
854854
assert(ctx.inlinedMutableStates.length == 0)
855855
}
856+
857+
test("SPARK-22825 Cast array to string") {
858+
val ret1 = cast(Literal.create(Array(1, 2, 3, 4, 5)), StringType)
859+
checkEvaluation(ret1, "[1, 2, 3, 4, 5]")
860+
val ret2 = cast(Literal.create(Array("ab", "cde", "f")), StringType)
861+
checkEvaluation(ret2, "[ab, cde, f]")
862+
val ret3 = cast(Literal.create(Array("ab", null, "c")), StringType)
863+
checkEvaluation(ret3, "[ab,, c]")
864+
val ret4 = cast(Literal.create(Array("ab".getBytes, "cde".getBytes, "f".getBytes)), StringType)
865+
checkEvaluation(ret4, "[ab, cde, f]")
866+
val ret5 = cast(
867+
Literal.create(Array("2014-12-03", "2014-12-04", "2014-12-06").map(Date.valueOf)),
868+
StringType)
869+
checkEvaluation(ret5, "[2014-12-03, 2014-12-04, 2014-12-06]")
870+
val ret6 = cast(
871+
Literal.create(Array("2014-12-03 13:01:00", "2014-12-04 15:05:00").map(Timestamp.valueOf)),
872+
StringType)
873+
checkEvaluation(ret6, "[2014-12-03 13:01:00, 2014-12-04 15:05:00]")
874+
val ret7 = cast(Literal.create(Array(Array(1, 2, 3), Array(4, 5))), StringType)
875+
checkEvaluation(ret7, "[[1, 2, 3], [4, 5]]")
876+
val ret8 = cast(
877+
Literal.create(Array(Array(Array("a"), Array("b", "c")), Array(Array("d")))),
878+
StringType)
879+
checkEvaluation(ret8, "[[[a], [b, c]], [[d]]]")
880+
}
856881
}

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
2828
import org.apache.spark.sql.catalyst.util.StringUtils
2929
import org.apache.spark.sql.execution.aggregate
3030
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
31-
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
32-
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
3331
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec}
3432
import org.apache.spark.sql.functions._
3533
import org.apache.spark.sql.internal.SQLConf

0 commit comments

Comments
 (0)