Skip to content

Commit 14f2634

Browse files
JoshRosenrxin
authored andcommitted
[SPARK-9464][SQL] Property checks for UTF8String
This PR is based on the original work by JoshRosen in #7780, which adds ScalaCheck property-based tests for UTF8String. Author: Josh Rosen <[email protected]> Author: Yijie Shen <[email protected]> Closes #7830 from yjshen/utf8-property-checks and squashes the following commits: 593da3a [Yijie Shen] resolve comments c0800e6 [Yijie Shen] Finish all todos in suite 52f51a0 [Josh Rosen] Add some more failing tests 49ed069 [Josh Rosen] Rename suite 9209c64 [Josh Rosen] UTF8String Property Checks.
1 parent 6996bd2 commit 14f2634

File tree

4 files changed

+280
-11
lines changed

4 files changed

+280
-11
lines changed

unsafe/pom.xml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,16 @@
7070
<artifactId>mockito-core</artifactId>
7171
<scope>test</scope>
7272
</dependency>
73+
<dependency>
74+
<groupId>org.scalacheck</groupId>
75+
<artifactId>scalacheck_${scala.binary.version}</artifactId>
76+
<scope>test</scope>
77+
</dependency>
78+
<dependency>
79+
<groupId>org.apache.commons</groupId>
80+
<artifactId>commons-lang3</artifactId>
81+
<scope>test</scope>
82+
</dependency>
7383
</dependencies>
7484
<build>
7585
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>

unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -301,10 +301,9 @@ public UTF8String trim() {
301301
int s = 0;
302302
int e = this.numBytes - 1;
303303
// skip all of the space (0x20) in the left side
304-
while (s < this.numBytes && getByte(s) == 0x20) s++;
304+
while (s < this.numBytes && getByte(s) <= 0x20 && getByte(s) >= 0x00) s++;
305305
// skip all of the space (0x20) in the right side
306-
while (e >= 0 && getByte(e) == 0x20) e--;
307-
306+
while (e >= 0 && getByte(e) <= 0x20 && getByte(e) >= 0x00) e--;
308307
if (s > e) {
309308
// empty string
310309
return UTF8String.fromBytes(new byte[0]);
@@ -316,7 +315,7 @@ public UTF8String trim() {
316315
public UTF8String trimLeft() {
317316
int s = 0;
318317
// skip all of the space (0x20) in the left side
319-
while (s < this.numBytes && getByte(s) == 0x20) s++;
318+
while (s < this.numBytes && getByte(s) <= 0x20 && getByte(s) >= 0x00) s++;
320319
if (s == this.numBytes) {
321320
// empty string
322321
return UTF8String.fromBytes(new byte[0]);
@@ -328,7 +327,7 @@ public UTF8String trimLeft() {
328327
public UTF8String trimRight() {
329328
int e = numBytes - 1;
330329
// skip all of the space (0x20) in the right side
331-
while (e >= 0 && getByte(e) == 0x20) e--;
330+
while (e >= 0 && getByte(e) <= 0x20 && getByte(e) >= 0x00) e--;
332331

333332
if (e < 0) {
334333
// empty string
@@ -354,7 +353,7 @@ public UTF8String reverse() {
354353
}
355354

356355
public UTF8String repeat(int times) {
357-
if (times <=0) {
356+
if (times <= 0) {
358357
return EMPTY_UTF8;
359358
}
360359

@@ -492,7 +491,7 @@ public UTF8String subStringIndex(UTF8String delim, int count) {
492491
*/
493492
public UTF8String rpad(int len, UTF8String pad) {
494493
int spaces = len - this.numChars(); // number of char need to pad
495-
if (spaces <= 0) {
494+
if (spaces <= 0 || pad.numBytes() == 0) {
496495
// no padding at all, return the substring of the current string
497496
return substring(0, len);
498497
} else {
@@ -507,7 +506,7 @@ public UTF8String rpad(int len, UTF8String pad) {
507506
int idx = 0;
508507
while (idx < count) {
509508
copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes);
510-
++idx;
509+
++ idx;
511510
offset += pad.numBytes;
512511
}
513512
copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes);
@@ -524,7 +523,7 @@ public UTF8String rpad(int len, UTF8String pad) {
524523
*/
525524
public UTF8String lpad(int len, UTF8String pad) {
526525
int spaces = len - this.numChars(); // number of char need to pad
527-
if (spaces <= 0) {
526+
if (spaces <= 0 || pad.numBytes() == 0) {
528527
// no padding at all, return the substring of the current string
529528
return substring(0, len);
530529
} else {
@@ -539,7 +538,7 @@ public UTF8String lpad(int len, UTF8String pad) {
539538
int idx = 0;
540539
while (idx < count) {
541540
copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes);
542-
++idx;
541+
++ idx;
543542
offset += pad.numBytes;
544543
}
545544
copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes);

unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,6 @@ public void pad() {
309309
assertEquals(fromString("hello?????"), fromString("hello").rpad(10, fromString("?????")));
310310
assertEquals(fromString("???????"), EMPTY_UTF8.rpad(7, fromString("?????")));
311311

312-
313312
assertEquals(fromString("数据砖"), fromString("数据砖头").lpad(3, fromString("????")));
314313
assertEquals(fromString("?数据砖头"), fromString("数据砖头").lpad(5, fromString("????")));
315314
assertEquals(fromString("??数据砖头"), fromString("数据砖头").lpad(6, fromString("????")));
@@ -327,6 +326,18 @@ public void pad() {
327326
assertEquals(
328327
fromString("数据砖头孙行者孙行者孙行"),
329328
fromString("数据砖头").rpad(12, fromString("孙行者")));
329+
330+
assertEquals(EMPTY_UTF8, fromString("数据砖头").lpad(-10, fromString("孙行者")));
331+
assertEquals(EMPTY_UTF8, fromString("数据砖头").lpad(-10, EMPTY_UTF8));
332+
assertEquals(fromString("数据砖头"), fromString("数据砖头").lpad(5, EMPTY_UTF8));
333+
assertEquals(fromString("数据砖"), fromString("数据砖头").lpad(3, EMPTY_UTF8));
334+
assertEquals(EMPTY_UTF8, EMPTY_UTF8.lpad(3, EMPTY_UTF8));
335+
336+
assertEquals(EMPTY_UTF8, fromString("数据砖头").rpad(-10, fromString("孙行者")));
337+
assertEquals(EMPTY_UTF8, fromString("数据砖头").rpad(-10, EMPTY_UTF8));
338+
assertEquals(fromString("数据砖头"), fromString("数据砖头").rpad(5, EMPTY_UTF8));
339+
assertEquals(fromString("数据砖"), fromString("数据砖头").rpad(3, EMPTY_UTF8));
340+
assertEquals(EMPTY_UTF8, EMPTY_UTF8.rpad(3, EMPTY_UTF8));
330341
}
331342

332343
@Test
Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.unsafe.types
19+
20+
import org.apache.commons.lang3.StringUtils
21+
22+
import org.scalacheck.{Arbitrary, Gen}
23+
import org.scalatest.prop.GeneratorDrivenPropertyChecks
24+
// scalastyle:off
25+
import org.scalatest.{FunSuite, Matchers}
26+
27+
import org.apache.spark.unsafe.types.UTF8String.{fromString => toUTF8}
28+
29+
/**
30+
* This TestSuite utilize ScalaCheck to generate randomized inputs for UTF8String testing.
31+
*/
32+
class UTF8StringPropertyCheckSuite extends FunSuite with GeneratorDrivenPropertyChecks with Matchers {
33+
// scalastyle:on
34+
35+
test("toString") {
36+
forAll { (s: String) =>
37+
assert(toUTF8(s).toString() === s)
38+
}
39+
}
40+
41+
test("numChars") {
42+
forAll { (s: String) =>
43+
assert(toUTF8(s).numChars() === s.length)
44+
}
45+
}
46+
47+
test("startsWith") {
48+
forAll { (s: String) =>
49+
val utf8 = toUTF8(s)
50+
assert(utf8.startsWith(utf8))
51+
for (i <- 1 to s.length) {
52+
assert(utf8.startsWith(toUTF8(s.dropRight(i))))
53+
}
54+
}
55+
}
56+
57+
test("endsWith") {
58+
forAll { (s: String) =>
59+
val utf8 = toUTF8(s)
60+
assert(utf8.endsWith(utf8))
61+
for (i <- 1 to s.length) {
62+
assert(utf8.endsWith(toUTF8(s.drop(i))))
63+
}
64+
}
65+
}
66+
67+
test("toUpperCase") {
68+
forAll { (s: String) =>
69+
assert(toUTF8(s).toUpperCase === toUTF8(s.toUpperCase))
70+
}
71+
}
72+
73+
test("toLowerCase") {
74+
forAll { (s: String) =>
75+
assert(toUTF8(s).toLowerCase === toUTF8(s.toLowerCase))
76+
}
77+
}
78+
79+
test("compare") {
80+
forAll { (s1: String, s2: String) =>
81+
assert(Math.signum(toUTF8(s1).compareTo(toUTF8(s2))) === Math.signum(s1.compareTo(s2)))
82+
}
83+
}
84+
85+
test("substring") {
86+
forAll { (s: String) =>
87+
for (start <- 0 to s.length; end <- 0 to s.length; if start <= end) {
88+
assert(toUTF8(s).substring(start, end).toString === s.substring(start, end))
89+
}
90+
}
91+
}
92+
93+
test("contains") {
94+
forAll { (s: String) =>
95+
for (start <- 0 to s.length; end <- 0 to s.length; if start <= end) {
96+
val substring = s.substring(start, end)
97+
assert(toUTF8(s).contains(toUTF8(substring)) === s.contains(substring))
98+
}
99+
}
100+
}
101+
102+
val whitespaceChar: Gen[Char] = Gen.choose(0x00, 0x20).map(_.toChar)
103+
val whitespaceString: Gen[String] = Gen.listOf(whitespaceChar).map(_.mkString)
104+
val randomString: Gen[String] = Arbitrary.arbString.arbitrary
105+
106+
test("trim, trimLeft, trimRight") {
107+
// lTrim and rTrim are both modified from java.lang.String.trim
108+
def lTrim(s: String): String = {
109+
var st = 0
110+
val array: Array[Char] = s.toCharArray
111+
while ((st < s.length) && (array(st) <= ' ')) {
112+
st += 1
113+
}
114+
if (st > 0) s.substring(st, s.length) else s
115+
}
116+
def rTrim(s: String): String = {
117+
var len = s.length
118+
val array: Array[Char] = s.toCharArray
119+
while ((len > 0) && (array(len - 1) <= ' ')) {
120+
len -= 1
121+
}
122+
if (len < s.length) s.substring(0, len) else s
123+
}
124+
125+
forAll(
126+
whitespaceString,
127+
randomString,
128+
whitespaceString
129+
) { (start: String, middle: String, end: String) =>
130+
val s = start + middle + end
131+
assert(toUTF8(s).trim() === toUTF8(s.trim()))
132+
assert(toUTF8(s).trimLeft() === toUTF8(lTrim(s)))
133+
assert(toUTF8(s).trimRight() === toUTF8(rTrim(s)))
134+
}
135+
}
136+
137+
test("reverse") {
138+
forAll { (s: String) =>
139+
assert(toUTF8(s).reverse === toUTF8(s.reverse))
140+
}
141+
}
142+
143+
test("indexOf") {
144+
forAll { (s: String) =>
145+
for (start <- 0 to s.length; end <- 0 to s.length; if start <= end) {
146+
val substring = s.substring(start, end)
147+
assert(toUTF8(s).indexOf(toUTF8(substring), 0) === s.indexOf(substring))
148+
}
149+
}
150+
}
151+
152+
val randomInt = Gen.choose(-100, 100)
153+
154+
test("repeat") {
155+
def repeat(str: String, times: Int): String = {
156+
if (times > 0) str * times else ""
157+
}
158+
// ScalaCheck always generating too large repeat times which might hang the test forever.
159+
forAll(randomString, randomInt) { (s: String, times: Int) =>
160+
assert(toUTF8(s).repeat(times) === toUTF8(repeat(s, times)))
161+
}
162+
}
163+
164+
test("lpad, rpad") {
165+
def padding(origin: String, pad: String, length: Int, isLPad: Boolean): String = {
166+
if (length <= 0) return ""
167+
if (length <= origin.length) {
168+
if (length <= 0) "" else origin.substring(0, length)
169+
} else {
170+
if (pad.length == 0) return origin
171+
val toPad = length - origin.length
172+
val partPad = if (toPad % pad.length == 0) "" else pad.substring(0, toPad % pad.length)
173+
if (isLPad) {
174+
pad * (toPad / pad.length) + partPad + origin
175+
} else {
176+
origin + pad * (toPad / pad.length) + partPad
177+
}
178+
}
179+
}
180+
181+
forAll (
182+
randomString,
183+
randomString,
184+
randomInt
185+
) { (s: String, pad: String, length: Int) =>
186+
assert(toUTF8(s).lpad(length, toUTF8(pad)) ===
187+
toUTF8(padding(s, pad, length, true)))
188+
assert(toUTF8(s).rpad(length, toUTF8(pad)) ===
189+
toUTF8(padding(s, pad, length, false)))
190+
}
191+
}
192+
193+
val nullalbeSeq = Gen.listOf(Gen.oneOf[String](null: String, randomString))
194+
195+
test("concat") {
196+
def concat(orgin: Seq[String]): String =
197+
if (orgin.exists(_ == null)) null else orgin.mkString
198+
199+
forAll { (inputs: Seq[String]) =>
200+
assert(UTF8String.concat(inputs.map(toUTF8): _*) === toUTF8(inputs.mkString))
201+
}
202+
forAll (nullalbeSeq) { (inputs: Seq[String]) =>
203+
assert(UTF8String.concat(inputs.map(toUTF8): _*) === toUTF8(concat(inputs)))
204+
}
205+
}
206+
207+
test("concatWs") {
208+
def concatWs(sep: String, inputs: Seq[String]): String = {
209+
if (sep == null) return null
210+
inputs.filter(_ != null).mkString(sep)
211+
}
212+
213+
forAll { (sep: String, inputs: Seq[String]) =>
214+
assert(UTF8String.concatWs(toUTF8(sep), inputs.map(toUTF8): _*) ===
215+
toUTF8(inputs.mkString(sep)))
216+
}
217+
forAll(randomString, nullalbeSeq) {(sep: String, inputs: Seq[String]) =>
218+
assert(UTF8String.concatWs(toUTF8(sep), inputs.map(toUTF8): _*) ===
219+
toUTF8(concatWs(sep, inputs)))
220+
}
221+
}
222+
223+
// TODO: enable this when we find a proper way to generate valid patterns
224+
ignore("split") {
225+
forAll { (s: String, pattern: String, limit: Int) =>
226+
assert(toUTF8(s).split(toUTF8(pattern), limit) ===
227+
s.split(pattern, limit).map(toUTF8(_)))
228+
}
229+
}
230+
231+
test("levenshteinDistance") {
232+
forAll { (one: String, another: String) =>
233+
assert(toUTF8(one).levenshteinDistance(toUTF8(another)) ===
234+
StringUtils.getLevenshteinDistance(one, another))
235+
}
236+
}
237+
238+
test("hashCode") {
239+
forAll { (s: String) =>
240+
assert(toUTF8(s).hashCode() === toUTF8(s).hashCode())
241+
}
242+
}
243+
244+
test("equals") {
245+
forAll { (one: String, another: String) =>
246+
assert(toUTF8(one).equals(toUTF8(another)) === one.equals(another))
247+
}
248+
}
249+
}

0 commit comments

Comments
 (0)