Skip to content

Commit 5b47901

Browse files
committed
Remove sealed, more filter types
1 parent fab154a commit 5b47901

File tree

3 files changed

+45
-5
lines changed

3 files changed

+45
-5
lines changed

sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,5 +93,20 @@ private[sql] object DataSourceStrategy extends Strategy {
9393

9494
protected def selectFilters(filters: Seq[Expression]): Seq[Filter] = filters.collect {
9595
case expressions.EqualTo(a: Attribute, Literal(v, _)) => EqualTo(a.name, v)
96+
case expressions.EqualTo(Literal(v, _), a: Attribute) => EqualTo(a.name, v)
97+
98+
case expressions.GreaterThan(a: Attribute, Literal(v, _)) => GreaterThan(a.name, v)
99+
case expressions.GreaterThan(Literal(v, _), a: Attribute) => LessThan(a.name, v)
100+
101+
case expressions.LessThan(a: Attribute, Literal(v, _)) => LessThan(a.name, v)
102+
case expressions.LessThan(Literal(v, _), a: Attribute) => GreaterThan(a.name, v)
103+
104+
case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) =>
105+
GreaterThanOrEqual(a.name, v)
106+
case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) =>
107+
LessThanOrEqual(a.name, v)
108+
109+
case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => LessThanOrEqual(a.name, v)
110+
case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => GreaterThanOrEqual(a.name, v)
96111
}
97112
}

sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717

1818
package org.apache.spark.sql.sources
1919

20-
abstract sealed class Filter
20+
abstract class Filter
2121

2222
case class EqualTo(attribute: String, value: Any) extends Filter
23+
case class GreaterThan(attribute: String, value: Any) extends Filter
24+
case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter
25+
case class LessThan(attribute: String, value: Any) extends Filter
26+
case class LessThanOrEqual(attribute: String, value: Any) extends Filter

sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,17 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL
4545

4646
FiltersPushed.list = filters
4747

48-
val filter = filters.collect {
48+
val filterFunctions = filters.collect {
4949
case EqualTo("a", v) => (a: Int) => a == v
50-
}.headOption.getOrElse((_: Int) => true)
50+
case LessThan("a", v: Int) => (a: Int) => a < v
51+
case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v
52+
case GreaterThan("a", v: Int) => (a: Int) => a > v
53+
case GreaterThanOrEqual("a", v: Int) => (a: Int) => a >= v
54+
}
55+
56+
def eval(a: Int) = !filterFunctions.map(_(a)).contains(false)
5157

52-
sqlContext.sparkContext.parallelize(from to to).filter(filter).map(i =>
58+
sqlContext.sparkContext.parallelize(from to to).filter(eval).map(i =>
5359
Row.fromSeq(rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty)))
5460
}
5561
}
@@ -128,8 +134,23 @@ class FilteredScanSuite extends DataSourceTest {
128134
testPushDown("SELECT b FROM oneToTenFiltered WHERE A = 1", 1)
129135
testPushDown("SELECT a, b FROM oneToTenFiltered WHERE A = 1", 1)
130136
testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 1", 1)
137+
testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 = a", 1)
138+
139+
testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1", 9)
140+
testPushDown("SELECT * FROM oneToTenFiltered WHERE a >= 2", 9)
141+
142+
testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 < a", 9)
143+
testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 <= a", 9)
144+
145+
testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 > a", 0)
146+
testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 >= a", 2)
147+
148+
testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 1", 0)
149+
testPushDown("SELECT * FROM oneToTenFiltered WHERE a <= 2", 2)
150+
151+
testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1 AND a < 10", 8)
152+
131153
testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0)
132-
testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 5", 10)
133154
testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10)
134155

135156
def testPushDown(sqlString: String, expectedCount: Int): Unit = {

0 commit comments

Comments
 (0)