|
17 | 17 |
|
18 | 18 | package org.apache.spark.sql |
19 | 19 |
|
| 20 | +import java.util.Locale |
| 21 | + |
| 22 | +import scala.collection.JavaConverters._ |
| 23 | + |
20 | 24 | import org.apache.hadoop.io.{LongWritable, Text} |
21 | 25 | import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} |
22 | 26 | import org.scalatest.Matchers._ |
@@ -390,11 +394,67 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { |
390 | 394 | checkAnswer(df.filter($"b".isin("z", "y")), |
391 | 395 | df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y")) |
392 | 396 |
|
| 397 | + // Auto casting should work with mixture of different types in collections |
| 398 | + checkAnswer(df.filter($"a".isin(1.toShort, "2")), |
| 399 | + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) |
| 400 | + checkAnswer(df.filter($"a".isin("3", 2.toLong)), |
| 401 | + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) |
| 402 | + checkAnswer(df.filter($"a".isin(3, "1")), |
| 403 | + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) |
| 404 | + |
393 | 405 | val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") |
394 | 406 |
|
395 | | - intercept[AnalysisException] { |
| 407 | + val e = intercept[AnalysisException] { |
396 | 408 | df2.filter($"a".isin($"b")) |
397 | 409 | } |
| 410 | + Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") |
| 411 | + .foreach { s => |
| 412 | + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) |
| 413 | + } |
| 414 | + } |
| 415 | + |
| 416 | + test("isInCollection: Scala Collection") { |
| 417 | + val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") |
| 418 | + // Test with different types of collections |
| 419 | + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))), |
| 420 | + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) |
| 421 | + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet)), |
| 422 | + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) |
| 423 | + checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).toArray)), |
| 424 | + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) |
| 425 | + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList)), |
| 426 | + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) |
| 427 | + |
| 428 | + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") |
| 429 | + |
| 430 | + val e = intercept[AnalysisException] { |
| 431 | + df2.filter($"a".isInCollection(Seq($"b"))) |
| 432 | + } |
| 433 | + Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") |
| 434 | + .foreach { s => |
| 435 | + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) |
| 436 | + } |
| 437 | + } |
| 438 | + |
| 439 | + test("isInCollection: Java Collection") { |
| 440 | + val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") |
| 441 | + // Test with different types of collections |
| 442 | + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).asJava)), |
| 443 | + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) |
| 444 | + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet.asJava)), |
| 445 | + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) |
| 446 | + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList.asJava)), |
| 447 | + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) |
| 448 | + |
| 449 | + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") |
| 450 | + |
| 451 | + val e = intercept[AnalysisException] { |
| 452 | + df2.filter($"a".isInCollection(Seq($"b").asJava)) |
| 453 | + } |
| 454 | + Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") |
| 455 | + .foreach { s => |
| 456 | + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) |
| 457 | + } |
398 | 458 | } |
399 | 459 |
|
400 | 460 | test("&&") { |
|
0 commit comments