|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one or more |
| 3 | + * contributor license agreements. See the NOTICE file distributed with |
| 4 | + * this work for additional information regarding copyright ownership. |
| 5 | + * The ASF licenses this file to You under the Apache License, Version 2.0 |
| 6 | + * (the "License"); you may not use this file except in compliance with |
| 7 | + * the License. You may obtain a copy of the License at |
| 8 | + * |
| 9 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | + * |
| 11 | + * Unless required by applicable law or agreed to in writing, software |
| 12 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | + * See the License for the specific language governing permissions and |
| 15 | + * limitations under the License. |
| 16 | + */ |
| 17 | + |
| 18 | +package org.apache.spark.sql |
| 19 | + |
| 20 | +import org.apache.spark.sql.functions._ |
| 21 | +import org.apache.spark.sql.types.Decimal |
| 22 | + |
| 23 | + |
/**
 * Tests for the string functions exposed through the DataFrame DSL
 * (`org.apache.spark.sql.functions`) and through SQL expressions (`selectExpr`).
 *
 * Each test builds a small DataFrame, evaluates a function through both entry points and
 * compares the result with a literal expected answer via `checkAnswer`.
 */
class StringFunctionsSuite extends QueryTest {

  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
  import ctx.implicits._

  test("string concat") {
    // Column "c" is null; the expected answer shows the null being skipped.
    val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")

    checkAnswer(
      df.select(concat($"a", $"b", $"c")),
      Row("ab"))

    checkAnswer(
      df.selectExpr("concat(a, b, c)"),
      Row("ab"))
  }

  test("string Levenshtein distance") {
    val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r")
    checkAnswer(df.select(levenshtein("l", "r")), Seq(Row(3), Row(1)))
    checkAnswer(df.selectExpr("levenshtein(l, r)"), Seq(Row(3), Row(1)))
  }

  test("string ascii function") {
    // Column "b" is the empty string, for which 0 is expected.
    val df = Seq(("abc", "")).toDF("a", "b")
    checkAnswer(
      df.select(ascii($"a"), ascii("b")),
      Row(97, 0))

    checkAnswer(
      df.selectExpr("ascii(a)", "ascii(b)"),
      Row(97, 0))
  }

  test("string base64/unbase64 function") {
    // "AQIDBA==" is the base64 form of the bytes 1, 2, 3, 4.
    val bytes = Array[Byte](1, 2, 3, 4)
    val df = Seq((bytes, "AQIDBA==")).toDF("a", "b")
    checkAnswer(
      df.select(base64("a"), base64($"a"), unbase64("b"), unbase64($"b")),
      Row("AQIDBA==", "AQIDBA==", bytes, bytes))

    checkAnswer(
      df.selectExpr("base64(a)", "unbase64(b)"),
      Row("AQIDBA==", bytes))
  }

  test("string encode/decode function") {
    // UTF-8 encoding of the four-character string used below (3 bytes per character).
    val bytes = Array[Byte](-27, -92, -89, -27, -115, -125, -28, -72, -106, -25, -107, -116)
    // scalastyle:off
    // non ascii characters are not allowed in the code, so we disable the scalastyle here.
    val df = Seq(("大千世界", "utf-8", bytes)).toDF("a", "b", "c")
    checkAnswer(
      df.select(
        encode($"a", "utf-8"),
        encode("a", "utf-8"),
        decode($"c", "utf-8"),
        decode("c", "utf-8")),
      Row(bytes, bytes, "大千世界", "大千世界"))

    checkAnswer(
      df.selectExpr("encode(a, 'utf-8')", "decode(c, 'utf-8')"),
      Row(bytes, "大千世界"))
    // scalastyle:on
  }

  test("string trim functions") {
    val df = Seq((" example ", "")).toDF("a", "b")

    checkAnswer(
      df.select(ltrim($"a"), rtrim($"a"), trim($"a")),
      Row("example ", " example", "example"))

    checkAnswer(
      df.selectExpr("ltrim(a)", "rtrim(a)", "trim(a)"),
      Row("example ", " example", "example"))
  }

  test("string formatString function") {
    val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c")

    // The format is supplied once as a column and once as a literal pattern.
    checkAnswer(
      df.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")),
      Row("aa123cc", "aa123cc"))

    checkAnswer(
      df.selectExpr("printf(a, b, c)"),
      Row("aa123cc"))
  }

  test("string instr function") {
    val df = Seq(("aaads", "aa", "zz")).toDF("a", "b", "c")

    checkAnswer(
      df.select(instr($"a", $"b"), instr("a", "b")),
      Row(1, 1))

    checkAnswer(
      df.selectExpr("instr(a, b)"),
      Row(1))
  }

  test("string locate function") {
    val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d")

    // NOTE(review): with an explicit start position of 1 the expected index is 2, while the
    // two-argument form expects 1. This looks like the start position is treated as a
    // 0-based offset -- confirm against the intended `locate` semantics.
    checkAnswer(
      df.select(
        locate($"b", $"a"), locate("b", "a"), locate($"b", $"a", 1),
        locate("b", "a", 1), locate($"b", $"a", $"d"), locate("b", "a", "d")),
      Row(1, 1, 2, 2, 2, 2))

    checkAnswer(
      df.selectExpr("locate(b, a)", "locate(b, a, d)"),
      Row(1, 2))
  }

  test("string padding functions") {
    val df = Seq(("hi", 5, "??")).toDF("a", "b", "c")

    // A target length of 1 is shorter than the input, so "h" is expected from both sides.
    checkAnswer(
      df.select(
        lpad($"a", $"b", $"c"), rpad("a", "b", "c"),
        lpad($"a", 1, $"c"), rpad("a", 1, "c")),
      Row("???hi", "hi???", "h", "h"))

    checkAnswer(
      df.selectExpr("lpad(a, b, c)", "rpad(a, b, c)", "lpad(a, 1, c)", "rpad(a, 1, c)"),
      Row("???hi", "hi???", "h", "h"))
  }

  test("string repeat function") {
    val df = Seq(("hi", 2)).toDF("a", "b")

    checkAnswer(
      df.select(
        repeat($"a", 2), repeat("a", 2), repeat($"a", $"b"), repeat("a", "b")),
      Row("hihi", "hihi", "hihi", "hihi"))

    checkAnswer(
      df.selectExpr("repeat(a, 2)", "repeat(a, b)"),
      Row("hihi", "hihi"))
  }

  test("string reverse function") {
    val df = Seq(("hi", "hhhi")).toDF("a", "b")

    checkAnswer(
      df.select(reverse($"a"), reverse("b")),
      Row("ih", "ihhh"))

    checkAnswer(
      df.selectExpr("reverse(b)"),
      Row("ihhh"))
  }

  test("string space function") {
    val df = Seq((2, 3)).toDF("a", "b")

    // space(n) yields n blanks: two for column "a" (2) and three for column "b" (3).
    checkAnswer(
      df.select(space($"a"), space("b")),
      Row("  ", "   "))

    checkAnswer(
      df.selectExpr("space(b)"),
      Row("   "))
  }

  test("string split function") {
    val df = Seq(("aa2bb3cc", "[1-9]+")).toDF("a", "b")

    checkAnswer(
      df.select(
        split($"a", "[1-9]+"),
        split("a", "[1-9]+")),
      Row(Seq("aa", "bb", "cc"), Seq("aa", "bb", "cc")))

    checkAnswer(
      df.selectExpr("split(a, '[1-9]+')"),
      Row(Seq("aa", "bb", "cc")))
  }

  test("string / binary length function") {
    val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123)).toDF("a", "b", "c")
    checkAnswer(
      df.select(length($"a"), length("a"), length($"b"), length("b")),
      Row(3, 3, 4, 4))

    checkAnswer(
      df.selectExpr("length(a)", "length(b)"),
      Row(3, 4))

    // An int argument is unacceptable. Building the projection is enough to trigger the
    // failure, so no expected answer is supplied.
    // NOTE(review): relies on DataFrame analysis being eager -- confirm for this context.
    intercept[AnalysisException] {
      df.selectExpr("length(c)")
    }
  }

  test("number format function") {
    val tuple =
      ("aa", 1.toByte, 2.toShort,
        3.13223f, 4, 5L, 6.48173d, Decimal(7.128381))
    val df =
      Seq(tuple)
        .toDF(
          "a", // string "aa"
          "b", // byte 1
          "c", // short 2
          "d", // float 3.13223f
          "e", // integer 4
          "f", // long 5L
          "g", // double 6.48173d
          "h") // decimal 7.128381

    checkAnswer(
      df.select(
        format_number($"f", 4),
        format_number("f", 4)),
      Row("5.0000", "5.0000"))

    // In the SQL expressions below the second argument is column "e" (integer 4),
    // i.e. four decimal places.
    checkAnswer(
      df.selectExpr("format_number(b, e)"), // convert the 1st argument to integer
      Row("1.0000"))

    checkAnswer(
      df.selectExpr("format_number(c, e)"), // convert the 1st argument to integer
      Row("2.0000"))

    checkAnswer(
      df.selectExpr("format_number(d, e)"), // convert the 1st argument to double
      Row("3.1322"))

    checkAnswer(
      df.selectExpr("format_number(e, e)"), // not convert anything
      Row("4.0000"))

    checkAnswer(
      df.selectExpr("format_number(f, e)"), // not convert anything
      Row("5.0000"))

    checkAnswer(
      df.selectExpr("format_number(g, e)"), // not convert anything
      Row("6.4817"))

    checkAnswer(
      df.selectExpr("format_number(h, e)"), // not convert anything
      Row("7.1284"))

    // Invalid argument types: building the projection is enough to trigger the failure,
    // so no expected answer is supplied.
    // NOTE(review): relies on DataFrame analysis being eager -- confirm for this context.
    intercept[AnalysisException] {
      df.selectExpr("format_number(a, e)") // string type of the 1st argument is unacceptable
    }

    intercept[AnalysisException] {
      df.selectExpr("format_number(e, g)") // double type of the 2nd argument is unacceptable
    }
  }
}
0 commit comments