From 2a04ab3deaa989738ef77b9e70dd00bba6ae4d1e Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Thu, 13 Nov 2014 16:59:46 -0800
Subject: [PATCH 1/4] Simplify implementation of InSet.

---
 .../sql/catalyst/expressions/predicates.scala    |  4 ++--
 .../spark/sql/catalyst/optimizer/Optimizer.scala |  2 +-
 .../expressions/ExpressionEvaluationSuite.scala  | 14 +++++++-------
 .../sql/catalyst/optimizer/OptimizeInSuite.scala |  3 +--
 4 files changed, 11 insertions(+), 12 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 1e22b2d03c672..94b6fb084d38a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -99,10 +99,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
  * Optimized version of In clause, when all filter values of In clause are
  * static.
  */
-case class InSet(value: Expression, hset: HashSet[Any], child: Seq[Expression])
+case class InSet(value: Expression, hset: Set[Any])
   extends Predicate {
 
-  def children = child
+  def children = value :: Nil
 
   def nullable = true // TODO: Figure out correct nullability semantics of IN.
   override def toString = s"$value INSET ${hset.mkString("(", ",", ")")}"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index a4aa322fc52d8..f164a6c68a0de 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -289,7 +289,7 @@ object OptimizeIn extends Rule[LogicalPlan] {
     case q: LogicalPlan => q transformExpressionsDown {
       case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) =>
         val hSet = list.map(e => e.eval(null))
-        InSet(v, HashSet() ++ hSet, v +: list)
+        InSet(v, HashSet() ++ hSet)
     }
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 918996f11da2c..2f57be94a80fb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -158,13 +158,13 @@ class ExpressionEvaluationSuite extends FunSuite {
     val nl = Literal(null)
     val s = Seq(one, two)
     val nullS = Seq(one, two, null)
-    checkEvaluation(InSet(one, hS, one +: s), true)
-    checkEvaluation(InSet(two, hS, two +: s), true)
-    checkEvaluation(InSet(two, nS, two +: nullS), true)
-    checkEvaluation(InSet(nl, nS, nl +: nullS), true)
-    checkEvaluation(InSet(three, hS, three +: s), false)
-    checkEvaluation(InSet(three, nS, three +: nullS), false)
-    checkEvaluation(InSet(one, hS, one +: s) && InSet(two, hS, two +: s), true)
+    checkEvaluation(InSet(one, hS), true)
+    checkEvaluation(InSet(two, hS), true)
+    checkEvaluation(InSet(two, nS), true)
+    checkEvaluation(InSet(nl, nS), true)
+    checkEvaluation(InSet(three, hS), false)
+    checkEvaluation(InSet(three, nS), false)
+    checkEvaluation(InSet(one, hS) && InSet(two, hS), true)
   }
 
   test("MaxOf") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
index 97a78ec971c39..017b180c574b4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
@@ -52,8 +52,7 @@ class OptimizeInSuite extends PlanTest {
     val optimized = Optimize(originalQuery.analyze)
     val correctAnswer =
       testRelation
-        .where(InSet(UnresolvedAttribute("a"), HashSet[Any]()+1+2,
-          UnresolvedAttribute("a") +: Seq(Literal(1),Literal(2))))
+        .where(InSet(UnresolvedAttribute("a"), HashSet[Any]()+1+2))
         .analyze
 
     comparePlans(optimized, correctAnswer)
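
The point of the simplification shows up at every call site above: InSet no longer drags along a redundant child: Seq[Expression], because the candidate values are already captured in hset and only value needs to be a child expression. A minimal sketch of the resulting behavior, reusing the Literal and HashSet pieces the updated tests exercise; eval(null) is the same no-input-row trick OptimizeIn itself uses on literals:

    import scala.collection.immutable.HashSet
    import org.apache.spark.sql.catalyst.expressions._

    val one = Literal(1)
    val hS = HashSet[Any](1, 2)

    InSet(one, hS).eval(null)   // true: 1 is a member of the set
    InSet(one, hS).children     // value :: Nil -- no second copy of the values
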
From 416f167cb58edc088c449ea65f327fe4f8ed9e74 Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Thu, 13 Nov 2014 17:00:36 -0800
Subject: [PATCH 2/4] Support for IN in data sources API.

---
 .../org/apache/spark/sql/sources/DataSourceStrategy.scala | 2 ++
 .../main/scala/org/apache/spark/sql/sources/filters.scala | 1 +
 .../org/apache/spark/sql/sources/FilteredScanSuite.scala  | 7 +++++++
 3 files changed, 10 insertions(+)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
index 9b8c6a56b94b4..954e86822de17 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
@@ -108,5 +108,7 @@ private[sql] object DataSourceStrategy extends Strategy {
     case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => LessThanOrEqual(a.name, v)
     case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => GreaterThanOrEqual(a.name, v)
+
+    case expressions.InSet(a: Attribute, set) => In(a.name, set.toArray)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
index e72a2aeb8f310..4a9fefc12b9ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
@@ -24,3 +24,4 @@ case class GreaterThan(attribute: String, value: Any) extends Filter
 case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter
 case class LessThan(attribute: String, value: Any) extends Filter
 case class LessThanOrEqual(attribute: String, value: Any) extends Filter
+case class In(attribute: String, values: Array[Any]) extends Filter
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
index 8b2f1591d5bf3..939b3c0c66de7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
@@ -51,6 +51,7 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL
       case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v
       case GreaterThan("a", v: Int) => (a: Int) => a > v
       case GreaterThanOrEqual("a", v: Int) => (a: Int) => a >= v
+      case In("a", values) => (a: Int) => values.map(_.asInstanceOf[Int]).toSet.contains(a)
     }
 
     def eval(a: Int) = !filterFunctions.map(_(a)).contains(false)
@@ -121,6 +122,10 @@ class FilteredScanSuite extends DataSourceTest {
     "SELECT * FROM oneToTenFiltered WHERE a = 1",
     Seq(1).map(i => Row(i, i * 2)).toSeq)
 
+  sqlTest(
+    "SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)",
+    Seq(1,3,5).map(i => Row(i, i * 2)).toSeq)
+
   sqlTest(
     "SELECT * FROM oneToTenFiltered WHERE A = 1",
     Seq(1).map(i => Row(i, i * 2)).toSeq)
@@ -150,6 +155,8 @@ class FilteredScanSuite extends DataSourceTest {
 
   testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1 AND a < 10", 8)
 
+  testPushDown("SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)", 3)
+
   testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0)
 
   testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10)
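
With the In filter case class added and DataSourceStrategy translating InSet(a, set) into In(a.name, set.toArray), a relation can evaluate set membership at the source. Below is a rough sketch of an implementation consuming the new filter, modeled on SimpleFilteredScan from the test above; the RangeScan name and one-column schema are illustrative only, not part of the patch, and unmatched filters can safely be ignored because Spark re-applies every predicate above the scan:

    import org.apache.spark.sql._
    import org.apache.spark.sql.sources._

    case class RangeScan(from: Int, to: Int)(@transient val sqlContext: SQLContext)
      extends PrunedFilteredScan {

      override def schema =
        StructType(StructField("a", IntegerType, nullable = false) :: Nil)

      override def buildScan(requiredColumns: Array[String], filters: Array[Filter]) = {
        // Turn each recognized pushed-down filter into a predicate on column `a`.
        val predicates = filters.collect {
          case In("a", values) =>
            (a: Int) => values.map(_.asInstanceOf[Int]).toSet.contains(a)
        }
        sqlContext.sparkContext
          .parallelize(from to to)
          .filter(a => predicates.forall(_(a)))
          .map(a => Row(a))
      }
    }
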
From 99c0e6b1672ed8ec6fb40d9f90f887592b7eac46 Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Thu, 13 Nov 2014 17:01:02 -0800
Subject: [PATCH 3/4] Add support for sizeInBytes.

---
 .../apache/spark/sql/sources/LogicalRelation.scala |  3 +--
 .../org/apache/spark/sql/sources/interfaces.scala  | 11 ++++++++++-
 2 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala
index 82a2cf8402f8f..4d87f6817dcb9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala
@@ -41,8 +41,7 @@ private[sql] case class LogicalRelation(relation: BaseRelation)
   }
 
   @transient override lazy val statistics = Statistics(
-    // TODO: Allow datasources to provide statistics as well.
-    sizeInBytes = BigInt(relation.sqlContext.defaultSizeInBytes)
+    sizeInBytes = BigInt(relation.sizeInBytes)
   )
 
   /** Used to lookup original attribute capitalization */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index ac3bf9d8e1a21..4ea8363a7ff49 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.sources
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, SQLContext, StructType}
+import org.apache.spark.sql.{SQLConf, Row, SQLContext, StructType}
 import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute}
 
 /**
@@ -53,6 +53,15 @@ trait RelationProvider {
 abstract class BaseRelation {
   def sqlContext: SQLContext
   def schema: StructType
+
+  /**
+   * Returns an estimated size of this relation in bytes. This information is used by the planner
+   * to decide when it is safe to broadcast a relation and can be overridden by sources that
+   * know the size ahead of time. By default, the system will assume that tables are too
+   * large to broadcast. This method will be called multiple times during query planning
+   * and thus should not perform expensive operations for each invocation.
+   */
+  def sizeInBytes = sqlContext.getConf(SQLConf.DEFAULT_SIZE_IN_BYTES).toLong
 }
 
 /**
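
The sizeInBytes hook exists mainly for join planning: the default is deliberately large so that tables of unknown size are never broadcast, while a source that knows it is small can advertise that and become a broadcast candidate. A minimal sketch of an override, assuming the same 1.2-era imports as the sketch above; the table and the 64 KB figure are made up, and the override line is the only point being made:

    import org.apache.spark.sql._
    import org.apache.spark.sql.sources._

    case class SmallLookupTable()(@transient val sqlContext: SQLContext)
      extends BaseRelation {

      override def schema =
        StructType(StructField("key", IntegerType, nullable = false) :: Nil)

      // Known ahead of time to be tiny: report a size below the broadcast
      // threshold so the planner may choose to broadcast this relation.
      override def sizeInBytes = 64L * 1024
    }
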
From 9a5e17166c5f8c75f067846ec5f515db0857f1ea Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Thu, 13 Nov 2014 17:03:23 -0800
Subject: [PATCH 4/4] Use method instead of configuration directly

---
 .../main/scala/org/apache/spark/sql/sources/interfaces.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 4ea8363a7ff49..861638b1e99b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -61,7 +61,7 @@ abstract class BaseRelation {
    * large to broadcast. This method will be called multiple times during query planning
    * and thus should not perform expensive operations for each invocation.
    */
-  def sizeInBytes = sqlContext.getConf(SQLConf.DEFAULT_SIZE_IN_BYTES).toLong
+  def sizeInBytes = sqlContext.defaultSizeInBytes
 }
 
 /**
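
The last patch is a cleanup rather than a behavior change: defaultSizeInBytes wraps the same configuration read, so the fallback logic lives in one place instead of being duplicated in BaseRelation. Roughly, assuming the 1.2-era SQLConf (paraphrased from memory, so the exact default expression may differ):

    // Inside SQLConf: fall back to a value just above the broadcast
    // threshold, i.e. assume unknown tables are too big to broadcast.
    private[spark] def defaultSizeInBytes: Long =
      getConf(DEFAULT_SIZE_IN_BYTES, (autoBroadcastJoinThreshold + 1).toString).toLong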