Skip to content

Commit c0800e6

Browse files
committed
Finish all todos in suite
1 parent 52f51a0 commit c0800e6

File tree

4 files changed

+173
-44
lines changed

4 files changed

+173
-44
lines changed

unsafe/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,12 @@
7373
<dependency>
7474
<groupId>org.scalacheck</groupId>
7575
<artifactId>scalacheck_${scala.binary.version}</artifactId>
76+
<scope>test</scope>
77+
</dependency>
78+
<dependency>
79+
<groupId>org.apache.commons</groupId>
80+
<artifactId>commons-lang3</artifactId>
81+
<scope>test</scope>
7682
</dependency>
7783
</dependencies>
7884
<build>

unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -301,10 +301,9 @@ public UTF8String trim() {
301301
int s = 0;
302302
int e = this.numBytes - 1;
303303
// skip all of the space (0x20) in the left side
304-
while (s < this.numBytes && getByte(s) == 0x20) s++;
304+
while (s < this.numBytes && getByte(s) <= 0x20 && getByte(s) >= 0x00) s++;
305305
// skip all of the space (0x20) in the right side
306-
while (e >= 0 && getByte(e) == 0x20) e--;
307-
306+
while (e >= 0 && getByte(e) <= 0x20 && getByte(e) >= 0x00) e--;
308307
if (s > e) {
309308
// empty string
310309
return UTF8String.fromBytes(new byte[0]);
@@ -316,7 +315,7 @@ public UTF8String trim() {
316315
public UTF8String trimLeft() {
317316
int s = 0;
318317
// skip all of the space (0x20) in the left side
319-
while (s < this.numBytes && getByte(s) == 0x20) s++;
318+
while (s < this.numBytes && getByte(s) <= 0x20 && getByte(s) >= 0x00) s++;
320319
if (s == this.numBytes) {
321320
// empty string
322321
return UTF8String.fromBytes(new byte[0]);
@@ -328,7 +327,7 @@ public UTF8String trimLeft() {
328327
public UTF8String trimRight() {
329328
int e = numBytes - 1;
330329
// skip all of the space (0x20) in the right side
331-
while (e >= 0 && getByte(e) == 0x20) e--;
330+
while (e >= 0 && getByte(e) <= 0x20 && getByte(e) >= 0x00) e--;
332331

333332
if (e < 0) {
334333
// empty string
@@ -354,7 +353,7 @@ public UTF8String reverse() {
354353
}
355354

356355
public UTF8String repeat(int times) {
357-
if (times <=0) {
356+
if (times <= 0) {
358357
return EMPTY_UTF8;
359358
}
360359

@@ -414,7 +413,7 @@ public int indexOf(UTF8String v, int start) {
414413
*/
415414
public UTF8String rpad(int len, UTF8String pad) {
416415
int spaces = len - this.numChars(); // number of char need to pad
417-
if (spaces <= 0) {
416+
if (spaces <= 0 || pad.numChars() == 0) {
418417
// no padding at all, return the substring of the current string
419418
return substring(0, len);
420419
} else {
@@ -429,7 +428,7 @@ public UTF8String rpad(int len, UTF8String pad) {
429428
int idx = 0;
430429
while (idx < count) {
431430
copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes);
432-
++idx;
431+
++ idx;
433432
offset += pad.numBytes;
434433
}
435434
copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes);
@@ -446,7 +445,7 @@ public UTF8String rpad(int len, UTF8String pad) {
446445
*/
447446
public UTF8String lpad(int len, UTF8String pad) {
448447
int spaces = len - this.numChars(); // number of char need to pad
449-
if (spaces <= 0) {
448+
if (spaces <= 0 || pad.numChars() == 0) {
450449
// no padding at all, return the substring of the current string
451450
return substring(0, len);
452451
} else {
@@ -461,7 +460,7 @@ public UTF8String lpad(int len, UTF8String pad) {
461460
int idx = 0;
462461
while (idx < count) {
463462
copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes);
464-
++idx;
463+
++ idx;
465464
offset += pad.numBytes;
466465
}
467466
copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes);

unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,6 @@ public void pad() {
271271
assertEquals(fromString("hello?????"), fromString("hello").rpad(10, fromString("?????")));
272272
assertEquals(fromString("???????"), EMPTY_UTF8.rpad(7, fromString("?????")));
273273

274-
275274
assertEquals(fromString("数据砖"), fromString("数据砖头").lpad(3, fromString("????")));
276275
assertEquals(fromString("?数据砖头"), fromString("数据砖头").lpad(5, fromString("????")));
277276
assertEquals(fromString("??数据砖头"), fromString("数据砖头").lpad(6, fromString("????")));
@@ -289,6 +288,18 @@ public void pad() {
289288
assertEquals(
290289
fromString("数据砖头孙行者孙行者孙行"),
291290
fromString("数据砖头").rpad(12, fromString("孙行者")));
291+
292+
assertEquals(EMPTY_UTF8, fromString("数据砖头").lpad(-10, fromString("孙行者")));
293+
assertEquals(EMPTY_UTF8, fromString("数据砖头").lpad(-10, EMPTY_UTF8));
294+
assertEquals(fromString("数据砖头"), fromString("数据砖头").lpad(5, EMPTY_UTF8));
295+
assertEquals(fromString("数据砖"), fromString("数据砖头").lpad(3, EMPTY_UTF8));
296+
assertEquals(EMPTY_UTF8, EMPTY_UTF8.lpad(3, EMPTY_UTF8));
297+
298+
assertEquals(EMPTY_UTF8, fromString("数据砖头").rpad(-10, fromString("孙行者")));
299+
assertEquals(EMPTY_UTF8, fromString("数据砖头").rpad(-10, EMPTY_UTF8));
300+
assertEquals(fromString("数据砖头"), fromString("数据砖头").rpad(5, EMPTY_UTF8));
301+
assertEquals(fromString("数据砖"), fromString("数据砖头").rpad(3, EMPTY_UTF8));
302+
assertEquals(EMPTY_UTF8, EMPTY_UTF8.rpad(3, EMPTY_UTF8));
292303
}
293304

294305
@Test
Lines changed: 146 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,37 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
118
package org.apache.spark.unsafe.types
219

20+
import org.apache.commons.lang3.StringUtils
21+
322
import org.scalacheck.{Arbitrary, Gen}
423
import org.scalatest.prop.GeneratorDrivenPropertyChecks
24+
// scalastyle:off
525
import org.scalatest.{FunSuite, Matchers}
626

727
import org.apache.spark.unsafe.types.UTF8String.{fromString => toUTF8}
828

929
class UTF8StringPropertyChecks extends FunSuite with GeneratorDrivenPropertyChecks with Matchers {
30+
// scalastyle:on
1031

1132
test("toString") {
1233
forAll { (s: String) =>
13-
assert(s === toUTF8(s).toString())
34+
assert(toUTF8(s).toString() === s)
1435
}
1536
}
1637

@@ -42,41 +63,35 @@ class UTF8StringPropertyChecks extends FunSuite with GeneratorDrivenPropertyChec
4263

4364
test("toUpperCase") {
4465
forAll { (s: String) =>
45-
assert(s.toUpperCase === toUTF8(s).toUpperCase.toString)
66+
assert(toUTF8(s).toUpperCase === toUTF8(s.toUpperCase))
4667
}
4768
}
4869

4970
test("toLowerCase") {
5071
forAll { (s: String) =>
51-
assert(s.toLowerCase === toUTF8(s).toLowerCase.toString)
72+
assert(toUTF8(s).toLowerCase === toUTF8(s.toLowerCase))
5273
}
5374
}
5475

5576
test("compare") {
5677
forAll { (s1: String, s2: String) =>
57-
assert(Math.signum(s1.compareTo(s2)) === Math.signum(toUTF8(s1).compareTo(toUTF8(s2))))
78+
assert(Math.signum(toUTF8(s1).compareTo(toUTF8(s2))) === Math.signum(s1.compareTo(s2)))
5879
}
5980
}
6081

6182
test("substring") {
6283
forAll { (s: String) =>
63-
for (start <- 0 to s.length; end <- 0 to s.length) {
64-
withClue(s"start=$start, end=$end") {
65-
assert(s.substring(start, end) === toUTF8(s).substring(start, end).toString)
66-
}
84+
for (start <- 0 to s.length; end <- 0 to s.length; if start <= end) {
85+
assert(toUTF8(s).substring(start, end).toString === s.substring(start, end))
6786
}
6887
}
6988
}
7089

71-
// TODO: substringSQL
72-
7390
test("contains") {
7491
forAll { (s: String) =>
75-
for (start <- 0 to s.length; end <- 0 to s.length) {
92+
for (start <- 0 to s.length; end <- 0 to s.length; if start <= end) {
7693
val substring = s.substring(start, end)
77-
withClue(s"substring=$substring") {
78-
assert(s.contains(substring) === toUTF8(s).contains(toUTF8(substring)))
79-
}
94+
assert(toUTF8(s).contains(toUTF8(substring)) === s.contains(substring))
8095
}
8196
}
8297
}
@@ -86,48 +101,146 @@ class UTF8StringPropertyChecks extends FunSuite with GeneratorDrivenPropertyChec
86101
val randomString: Gen[String] = Arbitrary.arbString.arbitrary
87102

88103
test("trim, trimLeft, trimRight") {
104+
// lTrim and rTrim are both modified from java.lang.String.trim
105+
def lTrim(s: String): String = {
106+
var st = 0
107+
val array: Array[Char] = s.toCharArray
108+
while ((st < s.length) && (array(st) <= ' ')) {
109+
st += 1
110+
}
111+
if (st > 0) s.substring(st, s.length) else s
112+
}
113+
def rTrim(s: String): String = {
114+
var len = s.length
115+
val array: Array[Char] = s.toCharArray
116+
while ((len > 0) && (array(len - 1) <= ' ')) {
117+
len -= 1
118+
}
119+
if (len < s.length) s.substring(0, len) else s
120+
}
121+
89122
forAll(
90123
whitespaceString,
91124
randomString,
92125
whitespaceString
93126
) { (start: String, middle: String, end: String) =>
94127
val s = start + middle + end
95-
assert(s.trim() === toUTF8(s).trim().toString)
96-
assert(s.stripMargin === toUTF8(s).trimLeft().toString)
97-
assert(s.reverse.stripMargin.reverse === toUTF8(s).trimRight().toString)
128+
assert(toUTF8(s).trim() === toUTF8(s.trim()))
129+
assert(toUTF8(s).trimLeft() === toUTF8(lTrim(s)))
130+
assert(toUTF8(s).trimRight() === toUTF8(rTrim(s)))
98131
}
99132
}
100133

101134
test("reverse") {
102-
forAll() { (s: String) =>
103-
assert(s.reverse === toUTF8(s).reverse.toString)
135+
forAll { (s: String) =>
136+
assert(toUTF8(s).reverse === toUTF8(s.reverse))
104137
}
105138
}
106139

107-
// TODO: repeat
108-
// TODO: indexOf
109-
// TODO: lpad
110-
// TODO: rpad
140+
test("indexOf") {
141+
forAll { (s: String) =>
142+
for (start <- 0 to s.length; end <- 0 to s.length; if start <= end) {
143+
val substring = s.substring(start, end)
144+
assert(toUTF8(s).indexOf(toUTF8(substring), 0) === s.indexOf(substring))
145+
}
146+
}
147+
}
148+
149+
val randomInt = Gen.choose(-100, 100)
150+
151+
test("repeat") {
152+
def repeat(str: String, times: Int): String = {
153+
if (times > 0) str * times else ""
154+
}
155+
// ScalaCheck always generating too large repeat times which might hang the test forever.
156+
forAll(randomString, randomInt) { (s: String, times: Int) =>
157+
assert(toUTF8(s).repeat(times) === toUTF8(repeat(s, times)))
158+
}
159+
}
160+
161+
test("lpad, rpad") {
162+
def padding(origin: String, pad: String, length: Int, isLPad: Boolean): String = {
163+
if (length <= 0) return ""
164+
if (length <= origin.length) {
165+
if (length <= 0) "" else origin.substring(0, length)
166+
} else {
167+
if (pad.length == 0) return origin
168+
val toPad = length - origin.length
169+
val partPad = if (toPad % pad.length == 0) "" else pad.substring(0, toPad % pad.length)
170+
if (isLPad) {
171+
pad * (toPad / pad.length) + partPad + origin
172+
} else {
173+
origin + pad * (toPad / pad.length) + partPad
174+
}
175+
}
176+
}
177+
178+
forAll (
179+
randomString,
180+
randomString,
181+
randomInt
182+
) { (s: String, pad: String, length: Int) =>
183+
assert(toUTF8(s).lpad(length, toUTF8(pad)) ===
184+
toUTF8(padding(s, pad, length, true)))
185+
assert(toUTF8(s).rpad(length, toUTF8(pad)) ===
186+
toUTF8(padding(s, pad, length, false)))
187+
}
188+
}
189+
190+
val nullalbeSeq = Gen.listOf(Gen.oneOf[String](null: String, randomString))
111191

112192
test("concat") {
113-
forAll() { (inputs: Seq[String]) =>
114-
// TODO: test case where at least one of the inputs is null
115-
assert(inputs.mkString === UTF8String.concat(inputs.map(toUTF8): _*).toString)
193+
def concat(orgin: Seq[String]): String =
194+
if (orgin.exists(_ == null)) null else orgin.mkString
195+
196+
forAll { (inputs: Seq[String]) =>
197+
assert(UTF8String.concat(inputs.map(toUTF8): _*) === toUTF8(inputs.mkString))
198+
}
199+
forAll (nullalbeSeq) { (inputs: Seq[String]) =>
200+
assert(UTF8String.concat(inputs.map(toUTF8): _*) === toUTF8(concat(inputs)))
116201
}
117202
}
118203

119204
test("concatWs") {
120-
forAll() { (sep: String, inputs: Seq[String]) =>
121-
// TODO: handle case where at least one of the inputs is null
122-
assert(
123-
inputs.mkString(sep) === UTF8String.concatWs(toUTF8(sep), inputs.map(toUTF8): _*).toString)
205+
def concatWs(sep: String, inputs: Seq[String]): String = {
206+
if (sep == null) return null
207+
inputs.filter(_ != null).mkString(sep)
208+
}
209+
210+
forAll { (sep: String, inputs: Seq[String]) =>
211+
assert(UTF8String.concatWs(toUTF8(sep), inputs.map(toUTF8): _*) ===
212+
toUTF8(inputs.mkString(sep)))
213+
}
214+
forAll(randomString, nullalbeSeq) {(sep: String, inputs: Seq[String]) =>
215+
assert(UTF8String.concatWs(toUTF8(sep), inputs.map(toUTF8): _*) ===
216+
toUTF8(concatWs(sep, inputs)))
124217
}
125218
}
126219

127-
// TODO: split
220+
// TODO: enable this when we find a proper way to generate valid patterns
221+
ignore("split") {
222+
forAll { (s: String, pattern: String, limit: Int) =>
223+
assert(toUTF8(s).split(toUTF8(pattern), limit) ===
224+
s.split(pattern, limit).map(toUTF8(_)))
225+
}
226+
}
128227

129-
// TODO: levenshteinDistance that tests against StringUtils' implementation
228+
test("levenshteinDistance") {
229+
forAll { (one: String, another: String) =>
230+
assert(toUTF8(one).levenshteinDistance(toUTF8(another)) ===
231+
StringUtils.getLevenshteinDistance(one, another))
232+
}
233+
}
130234

131-
// TODO: equals(), hashCode(), and compare()
235+
test("hashCode") {
236+
forAll { (s: String) =>
237+
assert(toUTF8(s).hashCode() === toUTF8(s).hashCode())
238+
}
239+
}
132240

241+
test("equals") {
242+
forAll { (one: String, another: String) =>
243+
assert(toUTF8(one).equals(toUTF8(another)) === one.equals(another))
244+
}
245+
}
133246
}

0 commit comments

Comments
 (0)