Skip to content

Commit 900bc1f

Browse files
committed
[SPARK-24371][SQL] Added isInCollection in DataFrame API for Scala and Java.
## What changes were proposed in this pull request? Implemented **`isInCollection`** in DataFrame API for both Scala and Java, so users can do ```scala val profileDF = Seq( Some(1), Some(2), Some(3), Some(4), Some(5), Some(6), Some(7), None ).toDF("profileID") val validUsers: Seq[Any] = Seq(6, 7.toShort, 8L, "3") val result = profileDF.withColumn("isValid", $"profileID".isInCollection(validUsers)) result.show(10) """ +---------+-------+ |profileID|isValid| +---------+-------+ | 1| false| | 2| false| | 3| true| | 4| false| | 5| false| | 6| true| | 7| true| | null| null| +---------+-------+ """.stripMargin ``` ## How was this patch tested? Several unit tests are added. Author: DB Tsai <[email protected]> Closes #21416 from dbtsai/optimize-set.
1 parent aca65c6 commit 900bc1f

File tree

4 files changed

+81
-3
lines changed

4 files changed

+81
-3
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import scala.collection.immutable.HashSet
2121
import scala.collection.mutable.{ArrayBuffer, Stack}
2222

2323
import org.apache.spark.sql.catalyst.analysis._
24-
import org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts
2524
import org.apache.spark.sql.catalyst.expressions._
2625
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
2726
import org.apache.spark.sql.catalyst.expressions.aggregate._

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ abstract class LogicalPlan
7878
schema.map { field =>
7979
resolve(field.name :: Nil, resolver).map {
8080
case a: AttributeReference => a
81-
case other => sys.error(s"can not handle nested schema yet... plan $this")
81+
case _ => sys.error(s"can not handle nested schema yet... plan $this")
8282
}.getOrElse {
8383
throw new AnalysisException(
8484
s"Unable to resolve ${field.name} given [${output.map(_.name).mkString(", ")}]")

sql/core/src/main/scala/org/apache/spark/sql/Column.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.sql
1919

20+
import scala.collection.JavaConverters._
2021
import scala.language.implicitConversions
2122

2223
import org.apache.spark.annotation.InterfaceStability
@@ -786,6 +787,24 @@ class Column(val expr: Expression) extends Logging {
786787
@scala.annotation.varargs
787788
/**
 * A boolean expression that is evaluated to true if the value of this expression is contained
 * by the evaluated values of the arguments.
 *
 * @group expr_ops
 */
def isin(list: Any*): Column = withExpr {
  val literalExprs = list.map(value => lit(value).expr)
  In(expr, literalExprs)
}
788789

790+
/**
791+
* A boolean expression that is evaluated to true if the value of this expression is contained
792+
* by the provided collection.
793+
*
794+
* @group expr_ops
795+
* @since 2.4.0
796+
*/
797+
/**
 * A boolean expression that is evaluated to true if the value of this expression is contained
 * by the provided collection. Materialises the collection as a `Seq` and delegates to the
 * varargs `isin` overload.
 *
 * @group expr_ops
 * @since 2.4.0
 */
def isInCollection(values: scala.collection.Iterable[_]): Column = {
  val elems: Seq[Any] = values.toSeq
  isin(elems: _*)
}
798+
799+
/**
800+
* A boolean expression that is evaluated to true if the value of this expression is contained
801+
* by the provided collection.
802+
*
803+
* @group java_expr_ops
804+
* @since 2.4.0
805+
*/
806+
/**
 * A boolean expression that is evaluated to true if the value of this expression is contained
 * by the provided collection. Java-friendly overload: converts the `java.lang.Iterable` to a
 * Scala collection and delegates to the Scala overload.
 *
 * @group java_expr_ops
 * @since 2.4.0
 */
def isInCollection(values: java.lang.Iterable[_]): Column = {
  val scalaValues = values.asScala
  isInCollection(scalaValues)
}
807+
789808
/**
790809
* SQL like expression. Returns a boolean column based on a SQL LIKE match.
791810
*

sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717

1818
package org.apache.spark.sql
1919

20+
import java.util.Locale
21+
22+
import scala.collection.JavaConverters._
23+
2024
import org.apache.hadoop.io.{LongWritable, Text}
2125
import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat}
2226
import org.scalatest.Matchers._
@@ -390,11 +394,67 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
390394
checkAnswer(df.filter($"b".isin("z", "y")),
391395
df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y"))
392396

397+
// Auto casting should work with mixture of different types in collections
398+
checkAnswer(df.filter($"a".isin(1.toShort, "2")),
399+
df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2))
400+
checkAnswer(df.filter($"a".isin("3", 2.toLong)),
401+
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2))
402+
checkAnswer(df.filter($"a".isin(3, "1")),
403+
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))
404+
393405
val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b")
394406

395-
intercept[AnalysisException] {
407+
val e = intercept[AnalysisException] {
396408
df2.filter($"a".isin($"b"))
397409
}
410+
Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were")
411+
.foreach { s =>
412+
assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT)))
413+
}
414+
}
415+
416+
test("isInCollection: Scala Collection") {
417+
val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
418+
// Test with different types of collections
419+
checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))),
420+
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))
421+
checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet)),
422+
df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2))
423+
checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).toArray)),
424+
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2))
425+
checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList)),
426+
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))
427+
428+
val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b")
429+
430+
val e = intercept[AnalysisException] {
431+
df2.filter($"a".isInCollection(Seq($"b")))
432+
}
433+
Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were")
434+
.foreach { s =>
435+
assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT)))
436+
}
437+
}
438+
439+
test("isInCollection: Java Collection") {
440+
val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
441+
// Test with different types of collections
442+
checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).asJava)),
443+
df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2))
444+
checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet.asJava)),
445+
df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2))
446+
checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList.asJava)),
447+
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))
448+
449+
val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b")
450+
451+
val e = intercept[AnalysisException] {
452+
df2.filter($"a".isInCollection(Seq($"b").asJava))
453+
}
454+
Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were")
455+
.foreach { s =>
456+
assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT)))
457+
}
398458
}
399459

400460
test("&&") {

0 commit comments

Comments (0)