From b803fc80efec026784b87c468b2597e5efbb6708 Mon Sep 17 00:00:00 2001 From: ravipesala Date: Thu, 11 Sep 2014 15:53:45 +0530 Subject: [PATCH 01/21] Add CACHE TABLE AS SELECT ... This feature allows user to add cache table from the select query. Example : ADD CACHE TABLE AS SELECT * FROM TEST_TABLE. Spark takes this type of SQL as command and it does eager caching. It can be executed from SQLContext and HiveContext. Signed-off-by: ravipesala --- .../apache/spark/sql/catalyst/SqlParser.scala | 9 +++++++- .../sql/catalyst/plans/logical/commands.scala | 5 +++++ .../spark/sql/execution/SparkStrategies.scala | 2 ++ .../apache/spark/sql/execution/commands.scala | 21 +++++++++++++++++++ .../apache/spark/sql/CachedTableSuite.scala | 16 ++++++++++++++ .../org/apache/spark/sql/hive/HiveQl.scala | 13 +++++++++++- 6 files changed, 64 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index ca69531c69a7..c0f1314a2349 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -127,6 +127,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val SUBSTRING = Keyword("SUBSTRING") protected val SQRT = Keyword("SQRT") protected val ABS = Keyword("ABS") + protected val ADD = Keyword("ADD") // Use reflection to find the reserved words defined in this class. protected val reservedWords = @@ -151,7 +152,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)} | UNION ~ opt(DISTINCT) ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } ) - | insert | cache + | insert | cache | addCache ) protected lazy val select: Parser[LogicalPlan] = @@ -181,6 +182,12 @@ class SqlParser extends StandardTokenParsers with PackratParsers { val overwrite: Boolean = o.getOrElse("") == "OVERWRITE" InsertIntoTable(r, Map[String, Option[String]](), s, overwrite) } + + protected lazy val addCache: Parser[LogicalPlan] = + ADD ~ CACHE ~ TABLE ~> ident ~ AS ~ select <~ opt(";") ^^ { + case tableName ~ as ~ s => + CacheTableAsSelectCommand(tableName,s) + } protected lazy val cache: Parser[LogicalPlan] = (CACHE ^^^ true | UNCACHE ^^^ false) ~ TABLE ~ ident ^^ { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index a01809c1fc5e..948fd53a3d8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -75,3 +75,8 @@ case class DescribeCommand( AttributeReference("data_type", StringType, nullable = false)(), AttributeReference("comment", StringType, nullable = false)()) } + +/** + * Returned for the "ADD CACHE TABLE tableName AS SELECT .." command. + */ +case class CacheTableAsSelectCommand(tableName: String, plan: LogicalPlan) extends Command diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 7943d6e1b6fb..aa205d278b5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -305,6 +305,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { Seq(execution.ExplainCommand(logicalPlan, plan.output, extended)(context)) case logical.CacheCommand(tableName, cache) => Seq(execution.CacheCommand(tableName, cache)(context)) + case logical.CacheTableAsSelectCommand(tableName,plan) => + Seq(execution.CacheTableAsSelectCommand(tableName, plan)(context)) case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 94543fc95b47..5bbf783ea875 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -166,3 +166,24 @@ case class DescribeCommand(child: SparkPlan, output: Seq[Attribute])( child.output.map(field => Row(field.name, field.dataType.toString, null)) } } + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class CacheTableAsSelectCommand(tableName: String,plan: LogicalPlan)( + @transient context: SQLContext) + extends LeafNode with Command { + + override protected[sql] lazy val sideEffectResult = { + context.catalog.registerTable(None, tableName, sqlContext.executePlan(plan).analyzed) + context.cacheTable(tableName) + //It does the caching eager. + //TODO : Does it really require to collect? + context.table(tableName).collect + Seq.empty[Row] + } + + override def output: Seq[Attribute] = Seq.empty + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index befef46d9397..a5bc74170833 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -119,4 +119,20 @@ class CachedTableSuite extends QueryTest { } assert(!TestSQLContext.isCached("testData"), "Table 'testData' should not be cached") } + + test("ADD CACHE TABLE tableName AS SELECT Star Table") { + TestSQLContext.sql("ADD CACHE TABLE testCacheTable AS SELECT * FROM testData") + TestSQLContext.sql("SELECT * FROM testCacheTable WHERE key = 1").collect() + TestSQLContext.uncacheTable("testCacheTable") + } + + test("'ADD CACHE TABLE tableName AS SELECT ..'") { + TestSQLContext.sql("ADD CACHE TABLE testCacheTable AS SELECT * FROM testData") + TestSQLContext.table("testCacheTable").queryExecution.executedPlan match { + case _: InMemoryColumnarTableScan => // Found evidence of caching + case _ => fail(s"Table 'testCacheTable' should be cached") + } + assert(TestSQLContext.isCached("testCacheTable"), "Table 'testCacheTable' should be cached") + TestSQLContext.uncacheTable("testCacheTable") + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index c98287c6aa66..104a81680a05 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -216,6 +216,17 @@ private[hive] object HiveQl { /** Returns a LogicalPlan for a given HiveQL string. */ def parseSql(sql: String): LogicalPlan = { + if (sql.trim.toLowerCase.startsWith("add cache table")) { + sql.trim.drop(16).split(" ").toSeq match { + case Seq(tableName,as, xs@_*) => CacheTableAsSelectCommand(tableName,processSql(sql.trim.drop(16+tableName.length()+as.length()+1))) + } + } else { + processSql(sql) + } + } + + /** Returns a LogicalPlan for a given HiveQL string. */ + def processSql(sql : String) = { try { if (sql.trim.toLowerCase.startsWith("set")) { // Split in two parts since we treat the part before the first "=" @@ -233,7 +244,7 @@ private[hive] object HiveQl { } else if (sql.trim.toLowerCase.startsWith("uncache table")) { CacheCommand(sql.trim.drop(14).trim, false) } else if (sql.trim.toLowerCase.startsWith("add jar")) { - AddJar(sql.trim.drop(8).trim) + NativeCommand(sql) } else if (sql.trim.toLowerCase.startsWith("add file")) { AddFile(sql.trim.drop(9)) } else if (sql.trim.toLowerCase.startsWith("dfs")) { From 4e858d83b0020a1701ed65eac7047ee2978329db Mon Sep 17 00:00:00 2001 From: ravipesala Date: Sat, 13 Sep 2014 18:06:49 +0530 Subject: [PATCH 02/21] Updated parser to support add cache table command --- .../org/apache/spark/sql/hive/HiveQl.scala | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 104a81680a05..ae2256cda412 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -214,19 +214,9 @@ private[hive] object HiveQl { */ def getAst(sql: String): ASTNode = ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql)) + /** Returns a LogicalPlan for a given HiveQL string. */ - def parseSql(sql: String): LogicalPlan = { - if (sql.trim.toLowerCase.startsWith("add cache table")) { - sql.trim.drop(16).split(" ").toSeq match { - case Seq(tableName,as, xs@_*) => CacheTableAsSelectCommand(tableName,processSql(sql.trim.drop(16+tableName.length()+as.length()+1))) - } - } else { - processSql(sql) - } - } - - /** Returns a LogicalPlan for a given HiveQL string. */ - def processSql(sql : String) = { + def parseSql(sql : String) = { try { if (sql.trim.toLowerCase.startsWith("set")) { // Split in two parts since we treat the part before the first "=" @@ -254,14 +244,12 @@ private[hive] object HiveQl { } else if (sql.trim.startsWith("!")) { ShellCommand(sql.drop(1)) } else { - val tree = getAst(sql) - if (nativeCommands contains tree.getText) { - NativeCommand(sql) + if (sql.trim.toLowerCase.startsWith("add cache table")) { + sql.trim.drop(16).split(" ").toSeq match { + case Seq(tableName,as, xs@_*) => CacheTableAsSelectCommand(tableName,createPlan(sql.trim.drop(16+tableName.length()+as.length()+1))) + } } else { - nodeToPlan(tree) match { - case NativePlaceholder => NativeCommand(sql) - case other => other - } + createPlan(sql) } } } catch { @@ -274,6 +262,18 @@ private[hive] object HiveQl { } } + def createPlan(sql: String) ={ + val tree = getAst(sql) + if (nativeCommands contains tree.getText) { + NativeCommand(sql) + } else { + nodeToPlan(tree) match { + case NativePlaceholder => NativeCommand(sql) + case other => other + } + } + } + def parseDdl(ddl: String): Seq[Attribute] = { val tree = try { From 13c8e27c33e8934bbd6fb458536675e97c3d8798 Mon Sep 17 00:00:00 2001 From: ravipesala Date: Sat, 13 Sep 2014 22:45:10 +0530 Subject: [PATCH 03/21] Updated parser to support add cache table command --- sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index ae2256cda412..deecd5d95e79 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -216,7 +216,7 @@ private[hive] object HiveQl { /** Returns a LogicalPlan for a given HiveQL string. */ - def parseSql(sql : String) = { + def parseSql(sql : String): LogicalPlan = { try { if (sql.trim.toLowerCase.startsWith("set")) { // Split in two parts since we treat the part before the first "=" From 7459ce36775126f4c0636585c1d29f30ab35fd06 Mon Sep 17 00:00:00 2001 From: ravipesala Date: Sat, 13 Sep 2014 23:09:28 +0530 Subject: [PATCH 04/21] Added comment --- sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index deecd5d95e79..6f945d90673c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -262,6 +262,7 @@ private[hive] object HiveQl { } } + /** Creates LogicalPlan for a given HiveQL string. */ def createPlan(sql: String) ={ val tree = getAst(sql) if (nativeCommands contains tree.getText) { From 6758f808d14ec7a3da0953f7720f7f5b9a4e8a85 Mon Sep 17 00:00:00 2001 From: ravipesala Date: Sat, 13 Sep 2014 23:37:25 +0530 Subject: [PATCH 05/21] Changed style --- sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 6f945d90673c..b3258b62d592 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -216,7 +216,7 @@ private[hive] object HiveQl { /** Returns a LogicalPlan for a given HiveQL string. */ - def parseSql(sql : String): LogicalPlan = { + def parseSql(sql: String): LogicalPlan = { try { if (sql.trim.toLowerCase.startsWith("set")) { // Split in two parts since we treat the part before the first "=" From a523ceaf159733dabcef84c7adc1463546679f65 Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Sat, 13 Sep 2014 12:34:20 -0700 Subject: [PATCH 06/21] [SQL] [Docs] typo fixes * Fixed random typo * Added in missing description for DecimalType Author: Nicholas Chammas Closes #2367 from nchammas/patch-1 and squashes the following commits: aa528be [Nicholas Chammas] doc fix for SQL DecimalType 3247ac1 [Nicholas Chammas] [SQL] [Docs] typo fixes --- docs/sql-programming-guide.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 3159d52787d5..8d41fdec699e 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -918,7 +918,6 @@ options. ## Migration Guide for Shark User ### Scheduling -s To set a [Fair Scheduler](job-scheduling.html#fair-scheduler-pools) pool for a JDBC client session, users can set the `spark.sql.thriftserver.scheduler.pool` variable: @@ -1110,7 +1109,7 @@ evaluated by the SQL execution engine. A full list of the functions supported c The range of numbers is from `-9223372036854775808` to `9223372036854775807`. - `FloatType`: Represents 4-byte single-precision floating point numbers. - `DoubleType`: Represents 8-byte double-precision floating point numbers. - - `DecimalType`: + - `DecimalType`: Represents arbitrary-precision signed decimal numbers. Backed internally by `java.math.BigDecimal`. A `BigDecimal` consists of an arbitrary precision integer unscaled value and a 32-bit integer scale. * String type - `StringType`: Represents character string values. * Binary type From 184cd51c4207c23726da97f907f2d912a5a44845 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 13 Sep 2014 12:35:40 -0700 Subject: [PATCH 07/21] [SPARK-3481][SQL] Removes the evil MINOR HACK This is a follow up of #2352. Now we can finally remove the evil "MINOR HACK", which covered up the eldest bug in the history of Spark SQL (see details [here](https://github.com/apache/spark/pull/2352#issuecomment-55440621)). Author: Cheng Lian Closes #2377 from liancheng/remove-evil-minor-hack and squashes the following commits: 0869c78 [Cheng Lian] Removes the evil MINOR HACK --- .../spark/sql/hive/execution/HiveComparisonTest.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 671c3b162f87..79cc7a3fcc7d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -250,9 +250,9 @@ abstract class HiveComparisonTest } try { - // MINOR HACK: You must run a query before calling reset the first time. - TestHive.sql("SHOW TABLES") - if (reset) { TestHive.reset() } + if (reset) { + TestHive.reset() + } val hiveCacheFiles = queryList.zipWithIndex.map { case (queryString, i) => From 74049249abb952ad061c0e221c22ff894a9e9c8d Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 13 Sep 2014 15:08:30 -0700 Subject: [PATCH 08/21] [SPARK-3294][SQL] Eliminates boxing costs from in-memory columnar storage This is a major refactoring of the in-memory columnar storage implementation, aims to eliminate boxing costs from critical paths (building/accessing column buffers) as much as possible. The basic idea is to refactor all major interfaces into a row-based form and use them together with `SpecificMutableRow`. The difficult part is how to adapt all compression schemes, esp. `RunLengthEncoding` and `DictionaryEncoding`, to this design. Since in-memory compression is disabled by default for now, and this PR should be strictly better than before no matter in-memory compression is enabled or not, maybe I'll finish that part in another PR. **UPDATE** This PR also took the chance to optimize `HiveTableScan` by 1. leveraging `SpecificMutableRow` to avoid boxing cost, and 1. building specific `Writable` unwrapper functions a head of time to avoid per row pattern matching and branching costs. TODO - [x] Benchmark - [ ] ~~Eliminate boxing costs in `RunLengthEncoding`~~ (left to future PRs) - [ ] ~~Eliminate boxing costs in `DictionaryEncoding` (seems not easy to do without specializing `DictionaryEncoding` for every supported column type)~~ (left to future PRs) ## Micro benchmark The benchmark uses a 10 million line CSV table consists of bytes, shorts, integers, longs, floats and doubles, measures the time to build the in-memory version of this table, and the time to scan the whole in-memory table. Benchmark code can be found [here](https://gist.github.com/liancheng/fe70a148de82e77bd2c8#file-hivetablescanbenchmark-scala). Script used to generate the input table can be found [here](https://gist.github.com/liancheng/fe70a148de82e77bd2c8#file-tablegen-scala). Speedup: - Hive table scanning + column buffer building: **18.74%** The original benchmark uses 1K as in-memory batch size, when increased to 10K, it can be 28.32% faster. - In-memory table scanning: **7.95%** Before: | Building | Scanning ------- | -------- | -------- 1 | 16472 | 525 2 | 16168 | 530 3 | 16386 | 529 4 | 16184 | 538 5 | 16209 | 521 Average | 16283.8 | 528.6 After: | Building | Scanning ------- | -------- | -------- 1 | 13124 | 458 2 | 13260 | 529 3 | 12981 | 463 4 | 13214 | 483 5 | 13583 | 500 Average | 13232.4 | 486.6 Author: Cheng Lian Closes #2327 from liancheng/prevent-boxing/unboxing and squashes the following commits: 4419fe4 [Cheng Lian] Addressing comments e5d2cf2 [Cheng Lian] Bug fix: should call setNullAt when field value is null to avoid NPE 8b8552b [Cheng Lian] Only checks for partition batch pruning flag once 489f97b [Cheng Lian] Bug fix: TableReader.fillObject uses wrong ordinals 97bbc4e [Cheng Lian] Optimizes hive.TableReader by by providing specific Writable unwrappers a head of time 3dc1f94 [Cheng Lian] Minor changes to eliminate row object creation 5b39cb9 [Cheng Lian] Lowers log level of compression scheme details f2a7890 [Cheng Lian] Use SpecificMutableRow in InMemoryColumnarTableScan to avoid boxing 9cf30b0 [Cheng Lian] Added row based ColumnType.append/extract 456c366 [Cheng Lian] Made compression decoder row based edac3cd [Cheng Lian] Makes ColumnAccessor.extractSingle row based 8216936 [Cheng Lian] Removes boxing cost in IntDelta and LongDelta by providing specialized implementations b70d519 [Cheng Lian] Made some in-memory columnar storage interfaces row-based --- .../catalyst/expressions/SpecificRow.scala | 2 +- .../spark/sql/columnar/ColumnAccessor.scala | 8 +- .../spark/sql/columnar/ColumnBuilder.scala | 27 +- .../spark/sql/columnar/ColumnStats.scala | 16 +- .../spark/sql/columnar/ColumnType.scala | 178 ++++++++++-- .../columnar/InMemoryColumnarTableScan.scala | 92 +++--- .../sql/columnar/NullableColumnAccessor.scala | 4 +- .../sql/columnar/NullableColumnBuilder.scala | 8 +- .../CompressibleColumnAccessor.scala | 7 +- .../CompressibleColumnBuilder.scala | 24 +- .../compression/CompressionScheme.scala | 16 +- .../compression/compressionSchemes.scala | 264 +++++++++++------- .../spark/sql/columnar/ColumnStatsSuite.scala | 2 +- .../spark/sql/columnar/ColumnTypeSuite.scala | 11 +- .../columnar/InMemoryColumnarQuerySuite.scala | 2 +- .../NullableColumnAccessorSuite.scala | 4 +- .../columnar/NullableColumnBuilderSuite.scala | 4 +- .../columnar/PartitionBatchPruningSuite.scala | 6 +- .../compression/BooleanBitSetSuite.scala | 7 +- .../compression/DictionaryEncodingSuite.scala | 9 +- .../compression/IntegralDeltaSuite.scala | 9 +- .../compression/RunLengthEncodingSuite.scala | 9 +- .../apache/spark/sql/hive/TableReader.scala | 119 ++++---- .../sql/hive/execution/HiveQuerySuite.scala | 18 +- 24 files changed, 554 insertions(+), 292 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala index 088f11ee4aa5..9cbab3d5d0d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala @@ -171,7 +171,7 @@ final class MutableByte extends MutableValue { } final class MutableAny extends MutableValue { - var value: Any = 0 + var value: Any = _ def boxed = if (isNull) null else value def update(v: Any) = value = { isNull = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala index 42a5a9a84f36..c9faf0852142 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala @@ -50,11 +50,13 @@ private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType]( def hasNext = buffer.hasRemaining - def extractTo(row: MutableRow, ordinal: Int) { - columnType.setField(row, ordinal, extractSingle(buffer)) + def extractTo(row: MutableRow, ordinal: Int): Unit = { + extractSingle(row, ordinal) } - def extractSingle(buffer: ByteBuffer): JvmType = columnType.extract(buffer) + def extractSingle(row: MutableRow, ordinal: Int): Unit = { + columnType.extract(buffer, row, ordinal) + } protected def underlyingBuffer = buffer } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index b3ec5ded2242..2e61a981375a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -68,10 +68,9 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType]( buffer.order(ByteOrder.nativeOrder()).putInt(columnType.typeId) } - override def appendFrom(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - buffer = ensureFreeSpace(buffer, columnType.actualSize(field)) - columnType.append(field, buffer) + override def appendFrom(row: Row, ordinal: Int): Unit = { + buffer = ensureFreeSpace(buffer, columnType.actualSize(row, ordinal)) + columnType.append(row, ordinal, buffer) } override def build() = { @@ -142,16 +141,16 @@ private[sql] object ColumnBuilder { useCompression: Boolean = false): ColumnBuilder = { val builder = (typeId match { - case INT.typeId => new IntColumnBuilder - case LONG.typeId => new LongColumnBuilder - case FLOAT.typeId => new FloatColumnBuilder - case DOUBLE.typeId => new DoubleColumnBuilder - case BOOLEAN.typeId => new BooleanColumnBuilder - case BYTE.typeId => new ByteColumnBuilder - case SHORT.typeId => new ShortColumnBuilder - case STRING.typeId => new StringColumnBuilder - case BINARY.typeId => new BinaryColumnBuilder - case GENERIC.typeId => new GenericColumnBuilder + case INT.typeId => new IntColumnBuilder + case LONG.typeId => new LongColumnBuilder + case FLOAT.typeId => new FloatColumnBuilder + case DOUBLE.typeId => new DoubleColumnBuilder + case BOOLEAN.typeId => new BooleanColumnBuilder + case BYTE.typeId => new ByteColumnBuilder + case SHORT.typeId => new ShortColumnBuilder + case STRING.typeId => new StringColumnBuilder + case BINARY.typeId => new BinaryColumnBuilder + case GENERIC.typeId => new GenericColumnBuilder case TIMESTAMP.typeId => new TimestampColumnBuilder }).asInstanceOf[ColumnBuilder] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index fc343ccb995c..203a714e03c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -69,7 +69,7 @@ private[sql] class ByteColumnStats extends ColumnStats { var lower = Byte.MaxValue var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getByte(ordinal) if (value > upper) upper = value @@ -87,7 +87,7 @@ private[sql] class ShortColumnStats extends ColumnStats { var lower = Short.MaxValue var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getShort(ordinal) if (value > upper) upper = value @@ -105,7 +105,7 @@ private[sql] class LongColumnStats extends ColumnStats { var lower = Long.MaxValue var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getLong(ordinal) if (value > upper) upper = value @@ -123,7 +123,7 @@ private[sql] class DoubleColumnStats extends ColumnStats { var lower = Double.MaxValue var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getDouble(ordinal) if (value > upper) upper = value @@ -141,7 +141,7 @@ private[sql] class FloatColumnStats extends ColumnStats { var lower = Float.MaxValue var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getFloat(ordinal) if (value > upper) upper = value @@ -159,7 +159,7 @@ private[sql] class IntColumnStats extends ColumnStats { var lower = Int.MaxValue var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getInt(ordinal) if (value > upper) upper = value @@ -177,7 +177,7 @@ private[sql] class StringColumnStats extends ColumnStats { var lower: String = null var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getString(ordinal) if (upper == null || value.compareTo(upper) > 0) upper = value @@ -195,7 +195,7 @@ private[sql] class TimestampColumnStats extends ColumnStats { var lower: Timestamp = null var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row(ordinal).asInstanceOf[Timestamp] if (upper == null || value.compareTo(upper) > 0) upper = value diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 9a6160011587..198b5756676a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -18,11 +18,10 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer +import java.sql.Timestamp import scala.reflect.runtime.universe.TypeTag -import java.sql.Timestamp - import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.catalyst.types._ @@ -46,16 +45,33 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( */ def extract(buffer: ByteBuffer): JvmType + /** + * Extracts a value out of the buffer at the buffer's current position and stores in + * `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs whenever + * possible. + */ + def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + setField(row, ordinal, extract(buffer)) + } + /** * Appends the given value v of type T into the given ByteBuffer. */ - def append(v: JvmType, buffer: ByteBuffer) + def append(v: JvmType, buffer: ByteBuffer): Unit + + /** + * Appends `row(ordinal)` of type T into the given ByteBuffer. Subclasses should override this + * method to avoid boxing/unboxing costs whenever possible. + */ + def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + append(getField(row, ordinal), buffer) + } /** - * Returns the size of the value. This is used to calculate the size of variable length types - * such as byte arrays and strings. + * Returns the size of the value `row(ordinal)`. This is used to calculate the size of variable + * length types such as byte arrays and strings. */ - def actualSize(v: JvmType): Int = defaultSize + def actualSize(row: Row, ordinal: Int): Int = defaultSize /** * Returns `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs @@ -67,7 +83,15 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( * Sets `row(ordinal)` to `field`. Subclasses should override this method to avoid boxing/unboxing * costs whenever possible. */ - def setField(row: MutableRow, ordinal: Int, value: JvmType) + def setField(row: MutableRow, ordinal: Int, value: JvmType): Unit + + /** + * Copies `from(fromOrdinal)` to `to(toOrdinal)`. Subclasses should override this method to avoid + * boxing/unboxing costs whenever possible. + */ + def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to(toOrdinal) = from(fromOrdinal) + } /** * Creates a duplicated copy of the value. @@ -90,119 +114,205 @@ private[sql] abstract class NativeColumnType[T <: NativeType]( } private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) { - def append(v: Int, buffer: ByteBuffer) { + def append(v: Int, buffer: ByteBuffer): Unit = { buffer.putInt(v) } + override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.putInt(row.getInt(ordinal)) + } + def extract(buffer: ByteBuffer) = { buffer.getInt() } - override def setField(row: MutableRow, ordinal: Int, value: Int) { + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setInt(ordinal, buffer.getInt()) + } + + override def setField(row: MutableRow, ordinal: Int, value: Int): Unit = { row.setInt(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getInt(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setInt(toOrdinal, from.getInt(fromOrdinal)) + } } private[sql] object LONG extends NativeColumnType(LongType, 1, 8) { - override def append(v: Long, buffer: ByteBuffer) { + override def append(v: Long, buffer: ByteBuffer): Unit = { buffer.putLong(v) } + override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.putLong(row.getLong(ordinal)) + } + override def extract(buffer: ByteBuffer) = { buffer.getLong() } - override def setField(row: MutableRow, ordinal: Int, value: Long) { + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setLong(ordinal, buffer.getLong()) + } + + override def setField(row: MutableRow, ordinal: Int, value: Long): Unit = { row.setLong(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getLong(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setLong(toOrdinal, from.getLong(fromOrdinal)) + } } private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) { - override def append(v: Float, buffer: ByteBuffer) { + override def append(v: Float, buffer: ByteBuffer): Unit = { buffer.putFloat(v) } + override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.putFloat(row.getFloat(ordinal)) + } + override def extract(buffer: ByteBuffer) = { buffer.getFloat() } - override def setField(row: MutableRow, ordinal: Int, value: Float) { + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setFloat(ordinal, buffer.getFloat()) + } + + override def setField(row: MutableRow, ordinal: Int, value: Float): Unit = { row.setFloat(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getFloat(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setFloat(toOrdinal, from.getFloat(fromOrdinal)) + } } private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) { - override def append(v: Double, buffer: ByteBuffer) { + override def append(v: Double, buffer: ByteBuffer): Unit = { buffer.putDouble(v) } + override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.putDouble(row.getDouble(ordinal)) + } + override def extract(buffer: ByteBuffer) = { buffer.getDouble() } - override def setField(row: MutableRow, ordinal: Int, value: Double) { + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setDouble(ordinal, buffer.getDouble()) + } + + override def setField(row: MutableRow, ordinal: Int, value: Double): Unit = { row.setDouble(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getDouble(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setDouble(toOrdinal, from.getDouble(fromOrdinal)) + } } private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) { - override def append(v: Boolean, buffer: ByteBuffer) { - buffer.put(if (v) 1.toByte else 0.toByte) + override def append(v: Boolean, buffer: ByteBuffer): Unit = { + buffer.put(if (v) 1: Byte else 0: Byte) + } + + override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.put(if (row.getBoolean(ordinal)) 1: Byte else 0: Byte) } override def extract(buffer: ByteBuffer) = buffer.get() == 1 - override def setField(row: MutableRow, ordinal: Int, value: Boolean) { + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setBoolean(ordinal, buffer.get() == 1) + } + + override def setField(row: MutableRow, ordinal: Int, value: Boolean): Unit = { row.setBoolean(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getBoolean(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setBoolean(toOrdinal, from.getBoolean(fromOrdinal)) + } } private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) { - override def append(v: Byte, buffer: ByteBuffer) { + override def append(v: Byte, buffer: ByteBuffer): Unit = { buffer.put(v) } + override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.put(row.getByte(ordinal)) + } + override def extract(buffer: ByteBuffer) = { buffer.get() } - override def setField(row: MutableRow, ordinal: Int, value: Byte) { + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setByte(ordinal, buffer.get()) + } + + override def setField(row: MutableRow, ordinal: Int, value: Byte): Unit = { row.setByte(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getByte(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setByte(toOrdinal, from.getByte(fromOrdinal)) + } } private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) { - override def append(v: Short, buffer: ByteBuffer) { + override def append(v: Short, buffer: ByteBuffer): Unit = { buffer.putShort(v) } + override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.putShort(row.getShort(ordinal)) + } + override def extract(buffer: ByteBuffer) = { buffer.getShort() } - override def setField(row: MutableRow, ordinal: Int, value: Short) { + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setShort(ordinal, buffer.getShort()) + } + + override def setField(row: MutableRow, ordinal: Int, value: Short): Unit = { row.setShort(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getShort(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setShort(toOrdinal, from.getShort(fromOrdinal)) + } } private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { - override def actualSize(v: String): Int = v.getBytes("utf-8").length + 4 + override def actualSize(row: Row, ordinal: Int): Int = { + row.getString(ordinal).getBytes("utf-8").length + 4 + } - override def append(v: String, buffer: ByteBuffer) { + override def append(v: String, buffer: ByteBuffer): Unit = { val stringBytes = v.getBytes("utf-8") buffer.putInt(stringBytes.length).put(stringBytes, 0, stringBytes.length) } @@ -214,11 +324,15 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { new String(stringBytes, "utf-8") } - override def setField(row: MutableRow, ordinal: Int, value: String) { + override def setField(row: MutableRow, ordinal: Int, value: String): Unit = { row.setString(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getString(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setString(toOrdinal, from.getString(fromOrdinal)) + } } private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 8, 12) { @@ -228,7 +342,7 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 8, 12) { timestamp } - override def append(v: Timestamp, buffer: ByteBuffer) { + override def append(v: Timestamp, buffer: ByteBuffer): Unit = { buffer.putLong(v.getTime).putInt(v.getNanos) } @@ -236,7 +350,7 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 8, 12) { row(ordinal).asInstanceOf[Timestamp] } - override def setField(row: MutableRow, ordinal: Int, value: Timestamp) { + override def setField(row: MutableRow, ordinal: Int, value: Timestamp): Unit = { row(ordinal) = value } } @@ -246,9 +360,11 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( defaultSize: Int) extends ColumnType[T, Array[Byte]](typeId, defaultSize) { - override def actualSize(v: Array[Byte]) = v.length + 4 + override def actualSize(row: Row, ordinal: Int) = { + getField(row, ordinal).length + 4 + } - override def append(v: Array[Byte], buffer: ByteBuffer) { + override def append(v: Array[Byte], buffer: ByteBuffer): Unit = { buffer.putInt(v.length).put(v, 0, v.length) } @@ -261,7 +377,7 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( } private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](9, 16) { - override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]) { + override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = { row(ordinal) = value } @@ -272,7 +388,7 @@ private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](9, 16) { // serialized first before appending to the column `ByteBuffer`, and is also extracted as serialized // byte array. private[sql] object GENERIC extends ByteArrayColumnType[DataType](10, 16) { - override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]) { + override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = { row(ordinal) = SparkSqlSerializer.deserialize[Any](value) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 6eab2f23c18e..8a3612cdf19b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -52,7 +52,7 @@ private[sql] case class InMemoryRelation( // As in Spark, the actual work of caching is lazy. if (_cachedColumnBuffers == null) { val output = child.output - val cached = child.execute().mapPartitions { baseIterator => + val cached = child.execute().mapPartitions { rowIterator => new Iterator[CachedBatch] { def next() = { val columnBuilders = output.map { attribute => @@ -61,11 +61,9 @@ private[sql] case class InMemoryRelation( ColumnBuilder(columnType.typeId, initialBufferSize, attribute.name, useCompression) }.toArray - var row: Row = null var rowCount = 0 - - while (baseIterator.hasNext && rowCount < batchSize) { - row = baseIterator.next() + while (rowIterator.hasNext && rowCount < batchSize) { + val row = rowIterator.next() var i = 0 while (i < row.length) { columnBuilders(i).appendFrom(row, i) @@ -80,7 +78,7 @@ private[sql] case class InMemoryRelation( CachedBatch(columnBuilders.map(_.build()), stats) } - def hasNext = baseIterator.hasNext + def hasNext = rowIterator.hasNext } }.cache() @@ -182,6 +180,7 @@ private[sql] case class InMemoryColumnarTableScan( } } + // Accumulators used for testing purposes val readPartitions = sparkContext.accumulator(0) val readBatches = sparkContext.accumulator(0) @@ -191,40 +190,36 @@ private[sql] case class InMemoryColumnarTableScan( readPartitions.setValue(0) readBatches.setValue(0) - relation.cachedColumnBuffers.mapPartitions { iterator => + relation.cachedColumnBuffers.mapPartitions { cachedBatchIterator => val partitionFilter = newPredicate( partitionFilters.reduceOption(And).getOrElse(Literal(true)), relation.partitionStatistics.schema) - // Find the ordinals of the requested columns. If none are requested, use the first. - val requestedColumns = if (attributes.isEmpty) { - Seq(0) + // Find the ordinals and data types of the requested columns. If none are requested, use the + // narrowest (the field with minimum default element size). + val (requestedColumnIndices, requestedColumnDataTypes) = if (attributes.isEmpty) { + val (narrowestOrdinal, narrowestDataType) = + relation.output.zipWithIndex.map { case (a, ordinal) => + ordinal -> a.dataType + } minBy { case (_, dataType) => + ColumnType(dataType).defaultSize + } + Seq(narrowestOrdinal) -> Seq(narrowestDataType) } else { - attributes.map(a => relation.output.indexWhere(_.exprId == a.exprId)) + attributes.map { a => + relation.output.indexWhere(_.exprId == a.exprId) -> a.dataType + }.unzip } - val rows = iterator - // Skip pruned batches - .filter { cachedBatch => - if (inMemoryPartitionPruningEnabled && !partitionFilter(cachedBatch.stats)) { - def statsString = relation.partitionStatistics.schema - .zip(cachedBatch.stats) - .map { case (a, s) => s"${a.name}: $s" } - .mkString(", ") - logInfo(s"Skipping partition based on stats $statsString") - false - } else { - readBatches += 1 - true - } - } - // Build column accessors - .map { cachedBatch => - requestedColumns.map(cachedBatch.buffers(_)).map(ColumnAccessor(_)) - } - // Extract rows via column accessors - .flatMap { columnAccessors => - val nextRow = new GenericMutableRow(columnAccessors.length) + val nextRow = new SpecificMutableRow(requestedColumnDataTypes) + + def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]) = { + val rows = cacheBatches.flatMap { cachedBatch => + // Build column accessors + val columnAccessors = + requestedColumnIndices.map(cachedBatch.buffers(_)).map(ColumnAccessor(_)) + + // Extract rows via column accessors new Iterator[Row] { override def next() = { var i = 0 @@ -235,15 +230,38 @@ private[sql] case class InMemoryColumnarTableScan( nextRow } - override def hasNext = columnAccessors.head.hasNext + override def hasNext = columnAccessors(0).hasNext } } - if (rows.hasNext) { - readPartitions += 1 + if (rows.hasNext) { + readPartitions += 1 + } + + rows } - rows + // Do partition batch pruning if enabled + val cachedBatchesToScan = + if (inMemoryPartitionPruningEnabled) { + cachedBatchIterator.filter { cachedBatch => + if (!partitionFilter(cachedBatch.stats)) { + def statsString = relation.partitionStatistics.schema + .zip(cachedBatch.stats) + .map { case (a, s) => s"${a.name}: $s" } + .mkString(", ") + logInfo(s"Skipping partition based on stats $statsString") + false + } else { + readBatches += 1 + true + } + } + } else { + cachedBatchIterator + } + + cachedBatchesToRows(cachedBatchesToScan) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala index b7f8826861a2..965782a40031 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala @@ -29,7 +29,7 @@ private[sql] trait NullableColumnAccessor extends ColumnAccessor { private var nextNullIndex: Int = _ private var pos: Int = 0 - abstract override protected def initialize() { + abstract override protected def initialize(): Unit = { nullsBuffer = underlyingBuffer.duplicate().order(ByteOrder.nativeOrder()) nullCount = nullsBuffer.getInt() nextNullIndex = if (nullCount > 0) nullsBuffer.getInt() else -1 @@ -39,7 +39,7 @@ private[sql] trait NullableColumnAccessor extends ColumnAccessor { super.initialize() } - abstract override def extractTo(row: MutableRow, ordinal: Int) { + abstract override def extractTo(row: MutableRow, ordinal: Int): Unit = { if (pos == nextNullIndex) { seenNulls += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala index a72970eef7aa..f1f494ac26d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala @@ -40,7 +40,11 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder { protected var nullCount: Int = _ private var pos: Int = _ - abstract override def initialize(initialSize: Int, columnName: String, useCompression: Boolean) { + abstract override def initialize( + initialSize: Int, + columnName: String, + useCompression: Boolean): Unit = { + nulls = ByteBuffer.allocate(1024) nulls.order(ByteOrder.nativeOrder()) pos = 0 @@ -48,7 +52,7 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder { super.initialize(initialSize, columnName, useCompression) } - abstract override def appendFrom(row: Row, ordinal: Int) { + abstract override def appendFrom(row: Row, ordinal: Int): Unit = { columnStats.gatherStats(row, ordinal) if (row.isNullAt(ordinal)) { nulls = ColumnBuilder.ensureFreeSpace(nulls, 4) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala index b4120a3d4368..27ac5f4dbdbb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.columnar.compression -import java.nio.ByteBuffer - +import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar.{ColumnAccessor, NativeColumnAccessor} @@ -34,5 +33,7 @@ private[sql] trait CompressibleColumnAccessor[T <: NativeType] extends ColumnAcc abstract override def hasNext = super.hasNext || decoder.hasNext - override def extractSingle(buffer: ByteBuffer): T#JvmType = decoder.next() + override def extractSingle(row: MutableRow, ordinal: Int): Unit = { + decoder.next(row, ordinal) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala index a5826bb033e4..628d9cec41d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala @@ -48,12 +48,16 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType] var compressionEncoders: Seq[Encoder[T]] = _ - abstract override def initialize(initialSize: Int, columnName: String, useCompression: Boolean) { + abstract override def initialize( + initialSize: Int, + columnName: String, + useCompression: Boolean): Unit = { + compressionEncoders = if (useCompression) { - schemes.filter(_.supports(columnType)).map(_.encoder[T]) + schemes.filter(_.supports(columnType)).map(_.encoder[T](columnType)) } else { - Seq(PassThrough.encoder) + Seq(PassThrough.encoder(columnType)) } super.initialize(initialSize, columnName, useCompression) } @@ -62,17 +66,15 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType] encoder.compressionRatio < 0.8 } - private def gatherCompressibilityStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - + private def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { var i = 0 while (i < compressionEncoders.length) { - compressionEncoders(i).gatherCompressibilityStats(field, columnType) + compressionEncoders(i).gatherCompressibilityStats(row, ordinal) i += 1 } } - abstract override def appendFrom(row: Row, ordinal: Int) { + abstract override def appendFrom(row: Row, ordinal: Int): Unit = { super.appendFrom(row, ordinal) if (!row.isNullAt(ordinal)) { gatherCompressibilityStats(row, ordinal) @@ -84,7 +86,7 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType] val typeId = nonNullBuffer.getInt() val encoder: Encoder[T] = { val candidate = compressionEncoders.minBy(_.compressionRatio) - if (isWorthCompressing(candidate)) candidate else PassThrough.encoder + if (isWorthCompressing(candidate)) candidate else PassThrough.encoder(columnType) } // Header = column type ID + null count + null positions @@ -104,7 +106,7 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType] .putInt(nullCount) .put(nulls) - logInfo(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}") - encoder.compress(nonNullBuffer, compressedBuffer, columnType) + logDebug(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}") + encoder.compress(nonNullBuffer, compressedBuffer) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala index 7797f7517789..acb06cb5376b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.columnar.compression -import java.nio.{ByteOrder, ByteBuffer} +import java.nio.{ByteBuffer, ByteOrder} +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar.{ColumnType, NativeColumnType} private[sql] trait Encoder[T <: NativeType] { - def gatherCompressibilityStats(value: T#JvmType, columnType: NativeColumnType[T]) {} + def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {} def compressedSize: Int @@ -33,17 +35,21 @@ private[sql] trait Encoder[T <: NativeType] { if (uncompressedSize > 0) compressedSize.toDouble / uncompressedSize else 1.0 } - def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]): ByteBuffer + def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer } -private[sql] trait Decoder[T <: NativeType] extends Iterator[T#JvmType] +private[sql] trait Decoder[T <: NativeType] { + def next(row: MutableRow, ordinal: Int): Unit + + def hasNext: Boolean +} private[sql] trait CompressionScheme { def typeId: Int def supports(columnType: ColumnType[_, _]): Boolean - def encoder[T <: NativeType]: Encoder[T] + def encoder[T <: NativeType](columnType: NativeColumnType[T]): Encoder[T] def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala index 8cf9ec74ca2d..29edcf17242c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala @@ -23,7 +23,8 @@ import scala.collection.mutable import scala.reflect.ClassTag import scala.reflect.runtime.universe.runtimeMirror -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar._ import org.apache.spark.util.Utils @@ -33,18 +34,20 @@ private[sql] case object PassThrough extends CompressionScheme { override def supports(columnType: ColumnType[_, _]) = true - override def encoder[T <: NativeType] = new this.Encoder[T] + override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + new this.Encoder[T](columnType) + } override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { new this.Decoder(buffer, columnType) } - class Encoder[T <: NativeType] extends compression.Encoder[T] { + class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { override def uncompressedSize = 0 override def compressedSize = 0 - override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]) = { + override def compress(from: ByteBuffer, to: ByteBuffer) = { // Writes compression type ID and copies raw contents to.putInt(PassThrough.typeId).put(from).rewind() to @@ -54,7 +57,9 @@ private[sql] case object PassThrough extends CompressionScheme { class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) extends compression.Decoder[T] { - override def next() = columnType.extract(buffer) + override def next(row: MutableRow, ordinal: Int): Unit = { + columnType.extract(buffer, row, ordinal) + } override def hasNext = buffer.hasRemaining } @@ -63,7 +68,9 @@ private[sql] case object PassThrough extends CompressionScheme { private[sql] case object RunLengthEncoding extends CompressionScheme { override val typeId = 1 - override def encoder[T <: NativeType] = new this.Encoder[T] + override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + new this.Encoder[T](columnType) + } override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { new this.Decoder(buffer, columnType) @@ -74,24 +81,25 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { case _ => false } - class Encoder[T <: NativeType] extends compression.Encoder[T] { + class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { private var _uncompressedSize = 0 private var _compressedSize = 0 // Using `MutableRow` to store the last value to avoid boxing/unboxing cost. - private val lastValue = new GenericMutableRow(1) + private val lastValue = new SpecificMutableRow(Seq(columnType.dataType)) private var lastRun = 0 override def uncompressedSize = _uncompressedSize override def compressedSize = _compressedSize - override def gatherCompressibilityStats(value: T#JvmType, columnType: NativeColumnType[T]) { - val actualSize = columnType.actualSize(value) + override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { + val value = columnType.getField(row, ordinal) + val actualSize = columnType.actualSize(row, ordinal) _uncompressedSize += actualSize if (lastValue.isNullAt(0)) { - columnType.setField(lastValue, 0, value) + columnType.copyField(row, ordinal, lastValue, 0) lastRun = 1 _compressedSize += actualSize + 4 } else { @@ -99,37 +107,40 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { lastRun += 1 } else { _compressedSize += actualSize + 4 - columnType.setField(lastValue, 0, value) + columnType.copyField(row, ordinal, lastValue, 0) lastRun = 1 } } } - override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]) = { + override def compress(from: ByteBuffer, to: ByteBuffer) = { to.putInt(RunLengthEncoding.typeId) if (from.hasRemaining) { - var currentValue = columnType.extract(from) + val currentValue = new SpecificMutableRow(Seq(columnType.dataType)) var currentRun = 1 + val value = new SpecificMutableRow(Seq(columnType.dataType)) + + columnType.extract(from, currentValue, 0) while (from.hasRemaining) { - val value = columnType.extract(from) + columnType.extract(from, value, 0) - if (value == currentValue) { + if (value.head == currentValue.head) { currentRun += 1 } else { // Writes current run - columnType.append(currentValue, to) + columnType.append(currentValue, 0, to) to.putInt(currentRun) // Resets current run - currentValue = value + columnType.copyField(value, 0, currentValue, 0) currentRun = 1 } } // Writes the last run - columnType.append(currentValue, to) + columnType.append(currentValue, 0, to) to.putInt(currentRun) } @@ -145,7 +156,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { private var valueCount = 0 private var currentValue: T#JvmType = _ - override def next() = { + override def next(row: MutableRow, ordinal: Int): Unit = { if (valueCount == run) { currentValue = columnType.extract(buffer) run = buffer.getInt() @@ -154,7 +165,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { valueCount += 1 } - currentValue + columnType.setField(row, ordinal, currentValue) } override def hasNext = valueCount < run || buffer.hasRemaining @@ -171,14 +182,16 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { new this.Decoder(buffer, columnType) } - override def encoder[T <: NativeType] = new this.Encoder[T] + override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + new this.Encoder[T](columnType) + } override def supports(columnType: ColumnType[_, _]) = columnType match { case INT | LONG | STRING => true case _ => false } - class Encoder[T <: NativeType] extends compression.Encoder[T] { + class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { // Size of the input, uncompressed, in bytes. Note that we only count until the dictionary // overflows. private var _uncompressedSize = 0 @@ -200,9 +213,11 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { // to store dictionary element count. private var dictionarySize = 4 - override def gatherCompressibilityStats(value: T#JvmType, columnType: NativeColumnType[T]) { + override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { + val value = columnType.getField(row, ordinal) + if (!overflow) { - val actualSize = columnType.actualSize(value) + val actualSize = columnType.actualSize(row, ordinal) count += 1 _uncompressedSize += actualSize @@ -221,7 +236,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { } } - override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]) = { + override def compress(from: ByteBuffer, to: ByteBuffer) = { if (overflow) { throw new IllegalStateException( "Dictionary encoding should not be used because of dictionary overflow.") @@ -264,7 +279,9 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { } } - override def next() = dictionary(buffer.getShort()) + override def next(row: MutableRow, ordinal: Int): Unit = { + columnType.setField(row, ordinal, dictionary(buffer.getShort())) + } override def hasNext = buffer.hasRemaining } @@ -279,25 +296,20 @@ private[sql] case object BooleanBitSet extends CompressionScheme { new this.Decoder(buffer).asInstanceOf[compression.Decoder[T]] } - override def encoder[T <: NativeType] = (new this.Encoder).asInstanceOf[compression.Encoder[T]] + override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + (new this.Encoder).asInstanceOf[compression.Encoder[T]] + } override def supports(columnType: ColumnType[_, _]) = columnType == BOOLEAN class Encoder extends compression.Encoder[BooleanType.type] { private var _uncompressedSize = 0 - override def gatherCompressibilityStats( - value: Boolean, - columnType: NativeColumnType[BooleanType.type]) { - + override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { _uncompressedSize += BOOLEAN.defaultSize } - override def compress( - from: ByteBuffer, - to: ByteBuffer, - columnType: NativeColumnType[BooleanType.type]) = { - + override def compress(from: ByteBuffer, to: ByteBuffer) = { to.putInt(BooleanBitSet.typeId) // Total element count (1 byte per Boolean value) .putInt(from.remaining) @@ -349,7 +361,7 @@ private[sql] case object BooleanBitSet extends CompressionScheme { private var visited: Int = 0 - override def next(): Boolean = { + override def next(row: MutableRow, ordinal: Int): Unit = { val bit = visited % BITS_PER_LONG visited += 1 @@ -357,123 +369,167 @@ private[sql] case object BooleanBitSet extends CompressionScheme { currentWord = buffer.getLong() } - ((currentWord >> bit) & 1) != 0 + row.setBoolean(ordinal, ((currentWord >> bit) & 1) != 0) } override def hasNext: Boolean = visited < count } } -private[sql] sealed abstract class IntegralDelta[I <: IntegralType] extends CompressionScheme { +private[sql] case object IntDelta extends CompressionScheme { + override def typeId: Int = 4 + override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { - new this.Decoder(buffer, columnType.asInstanceOf[NativeColumnType[I]]) - .asInstanceOf[compression.Decoder[T]] + new Decoder(buffer, INT).asInstanceOf[compression.Decoder[T]] } - override def encoder[T <: NativeType] = (new this.Encoder).asInstanceOf[compression.Encoder[T]] - - /** - * Computes `delta = x - y`, returns `(true, delta)` if `delta` can fit into a single byte, or - * `(false, 0: Byte)` otherwise. - */ - protected def byteSizedDelta(x: I#JvmType, y: I#JvmType): (Boolean, Byte) + override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + (new Encoder).asInstanceOf[compression.Encoder[T]] + } - /** - * Simply computes `x + delta` - */ - protected def addDelta(x: I#JvmType, delta: Byte): I#JvmType + override def supports(columnType: ColumnType[_, _]) = columnType == INT - class Encoder extends compression.Encoder[I] { - private var _compressedSize: Int = 0 + class Encoder extends compression.Encoder[IntegerType.type] { + protected var _compressedSize: Int = 0 + protected var _uncompressedSize: Int = 0 - private var _uncompressedSize: Int = 0 + override def compressedSize = _compressedSize + override def uncompressedSize = _uncompressedSize - private var prev: I#JvmType = _ + private var prevValue: Int = _ - private var initial = true + override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { + val value = row.getInt(ordinal) + val delta = value - prevValue - override def gatherCompressibilityStats(value: I#JvmType, columnType: NativeColumnType[I]) { - _uncompressedSize += columnType.defaultSize + _compressedSize += 1 - if (initial) { - initial = false - _compressedSize += 1 + columnType.defaultSize - } else { - val (smallEnough, _) = byteSizedDelta(value, prev) - _compressedSize += (if (smallEnough) 1 else 1 + columnType.defaultSize) + // If this is the first integer to be compressed, or the delta is out of byte range, then give + // up compressing this integer. + if (_uncompressedSize == 0 || delta <= Byte.MinValue || delta > Byte.MaxValue) { + _compressedSize += INT.defaultSize } - prev = value + _uncompressedSize += INT.defaultSize + prevValue = value } - override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[I]) = { + override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = { to.putInt(typeId) if (from.hasRemaining) { - var prev = columnType.extract(from) + var prev = from.getInt() to.put(Byte.MinValue) - columnType.append(prev, to) + to.putInt(prev) while (from.hasRemaining) { - val current = columnType.extract(from) - val (smallEnough, delta) = byteSizedDelta(current, prev) + val current = from.getInt() + val delta = current - prev prev = current - if (smallEnough) { - to.put(delta) + if (Byte.MinValue < delta && delta <= Byte.MaxValue) { + to.put(delta.toByte) } else { to.put(Byte.MinValue) - columnType.append(current, to) + to.putInt(current) } } } - to.rewind() - to + to.rewind().asInstanceOf[ByteBuffer] } - - override def uncompressedSize = _uncompressedSize - - override def compressedSize = _compressedSize } - class Decoder(buffer: ByteBuffer, columnType: NativeColumnType[I]) - extends compression.Decoder[I] { + class Decoder(buffer: ByteBuffer, columnType: NativeColumnType[IntegerType.type]) + extends compression.Decoder[IntegerType.type] { + + private var prev: Int = _ - private var prev: I#JvmType = _ + override def hasNext: Boolean = buffer.hasRemaining - override def next() = { + override def next(row: MutableRow, ordinal: Int): Unit = { val delta = buffer.get() - prev = if (delta > Byte.MinValue) addDelta(prev, delta) else columnType.extract(buffer) - prev + prev = if (delta > Byte.MinValue) prev + delta else buffer.getInt() + row.setInt(ordinal, prev) } - - override def hasNext = buffer.hasRemaining } } -private[sql] case object IntDelta extends IntegralDelta[IntegerType.type] { - override val typeId = 4 +private[sql] case object LongDelta extends CompressionScheme { + override def typeId: Int = 5 - override def supports(columnType: ColumnType[_, _]) = columnType == INT + override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { + new Decoder(buffer, LONG).asInstanceOf[compression.Decoder[T]] + } + + override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + (new Encoder).asInstanceOf[compression.Encoder[T]] + } - override protected def addDelta(x: Int, delta: Byte) = x + delta + override def supports(columnType: ColumnType[_, _]) = columnType == LONG + + class Encoder extends compression.Encoder[LongType.type] { + protected var _compressedSize: Int = 0 + protected var _uncompressedSize: Int = 0 + + override def compressedSize = _compressedSize + override def uncompressedSize = _uncompressedSize + + private var prevValue: Long = _ + + override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { + val value = row.getLong(ordinal) + val delta = value - prevValue + + _compressedSize += 1 - override protected def byteSizedDelta(x: Int, y: Int): (Boolean, Byte) = { - val delta = x - y - if (math.abs(delta) <= Byte.MaxValue) (true, delta.toByte) else (false, 0: Byte) + // If this is the first long integer to be compressed, or the delta is out of byte range, then + // give up compressing this long integer. + if (_uncompressedSize == 0 || delta <= Byte.MinValue || delta > Byte.MaxValue) { + _compressedSize += LONG.defaultSize + } + + _uncompressedSize += LONG.defaultSize + prevValue = value + } + + override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = { + to.putInt(typeId) + + if (from.hasRemaining) { + var prev = from.getLong() + to.put(Byte.MinValue) + to.putLong(prev) + + while (from.hasRemaining) { + val current = from.getLong() + val delta = current - prev + prev = current + + if (Byte.MinValue < delta && delta <= Byte.MaxValue) { + to.put(delta.toByte) + } else { + to.put(Byte.MinValue) + to.putLong(current) + } + } + } + + to.rewind().asInstanceOf[ByteBuffer] + } } -} -private[sql] case object LongDelta extends IntegralDelta[LongType.type] { - override val typeId = 5 + class Decoder(buffer: ByteBuffer, columnType: NativeColumnType[LongType.type]) + extends compression.Decoder[LongType.type] { - override def supports(columnType: ColumnType[_, _]) = columnType == LONG + private var prev: Long = _ - override protected def addDelta(x: Long, delta: Byte) = x + delta + override def hasNext: Boolean = buffer.hasRemaining - override protected def byteSizedDelta(x: Long, y: Long): (Boolean, Byte) = { - val delta = x - y - if (math.abs(delta) <= Byte.MaxValue) (true, delta.toByte) else (false, 0: Byte) + override def next(row: MutableRow, ordinal: Int): Unit = { + val delta = buffer.get() + prev = if (delta > Byte.MinValue) prev + delta else buffer.getLong() + row.setLong(ordinal, prev) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index cde91ceb68c9..0cdbb3167ce3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -35,7 +35,7 @@ class ColumnStatsSuite extends FunSuite { def testColumnStats[T <: NativeType, U <: ColumnStats]( columnStatsClass: Class[U], columnType: NativeColumnType[T], - initialStatistics: Row) { + initialStatistics: Row): Unit = { val columnStatsName = columnStatsClass.getSimpleName diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 75f653f3280b..4fb1ecf1d532 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -23,6 +23,7 @@ import java.sql.Timestamp import org.scalatest.FunSuite import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.execution.SparkSqlSerializer @@ -46,10 +47,12 @@ class ColumnTypeSuite extends FunSuite with Logging { def checkActualSize[T <: DataType, JvmType]( columnType: ColumnType[T, JvmType], value: JvmType, - expected: Int) { + expected: Int): Unit = { assertResult(expected, s"Wrong actualSize for $columnType") { - columnType.actualSize(value) + val row = new GenericMutableRow(1) + columnType.setField(row, 0, value) + columnType.actualSize(row, 0) } } @@ -147,7 +150,7 @@ class ColumnTypeSuite extends FunSuite with Logging { def testNativeColumnType[T <: NativeType]( columnType: NativeColumnType[T], putter: (ByteBuffer, T#JvmType) => Unit, - getter: (ByteBuffer) => T#JvmType) { + getter: (ByteBuffer) => T#JvmType): Unit = { testColumnType[T, T#JvmType](columnType, putter, getter) } @@ -155,7 +158,7 @@ class ColumnTypeSuite extends FunSuite with Logging { def testColumnType[T <: DataType, JvmType]( columnType: ColumnType[T, JvmType], putter: (ByteBuffer, JvmType) => Unit, - getter: (ByteBuffer) => JvmType) { + getter: (ByteBuffer) => JvmType): Unit = { val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE) val seq = (0 until 4).map(_ => makeRandomValue(columnType)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 0e3c67f5eed2..c1278248ef65 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.columnar import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{SQLConf, QueryTest, TestData} +import org.apache.spark.sql.{QueryTest, TestData} class InMemoryColumnarQuerySuite extends QueryTest { import org.apache.spark.sql.TestData._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index 3baa6f8ec0c8..6c9a9ab6c341 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -45,7 +45,9 @@ class NullableColumnAccessorSuite extends FunSuite { testNullableColumnAccessor(_) } - def testNullableColumnAccessor[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) { + def testNullableColumnAccessor[T <: DataType, JvmType]( + columnType: ColumnType[T, JvmType]): Unit = { + val typeName = columnType.getClass.getSimpleName.stripSuffix("$") val nullRow = makeNullRow(1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index a77262534a35..f54a21eb4fbb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -41,7 +41,9 @@ class NullableColumnBuilderSuite extends FunSuite { testNullableColumnBuilder(_) } - def testNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) { + def testNullableColumnBuilder[T <: DataType, JvmType]( + columnType: ColumnType[T, JvmType]): Unit = { + val typeName = columnType.getClass.getSimpleName.stripSuffix("$") test(s"$typeName column builder: empty column") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 5d2fd4959197..69e0adbd3ee0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -28,7 +28,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be val originalColumnBatchSize = columnBatchSize val originalInMemoryPartitionPruning = inMemoryPartitionPruning - override protected def beforeAll() { + override protected def beforeAll(): Unit = { // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch setConf(SQLConf.COLUMN_BATCH_SIZE, "10") val rawData = sparkContext.makeRDD(1 to 100, 5).map(IntegerData) @@ -38,7 +38,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") } - override protected def afterAll() { + override protected def afterAll(): Unit = { setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString) setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString) } @@ -76,7 +76,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be filter: String, expectedQueryResult: Seq[Int], expectedReadPartitions: Int, - expectedReadBatches: Int) { + expectedReadBatches: Int): Unit = { test(filter) { val query = sql(s"SELECT * FROM intData WHERE $filter") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala index e01cc8b4d20f..d9e488e0ffd1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.columnar.compression import org.scalatest.FunSuite import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar.{NoopColumnStats, BOOLEAN} import org.apache.spark.sql.columnar.ColumnarTestUtils._ @@ -72,10 +73,14 @@ class BooleanBitSetSuite extends FunSuite { buffer.rewind().position(headerSize + 4) val decoder = BooleanBitSet.decoder(buffer, BOOLEAN) + val mutableRow = new GenericMutableRow(1) if (values.nonEmpty) { values.foreach { assert(decoder.hasNext) - assertResult(_, "Wrong decoded value")(decoder.next()) + assertResult(_, "Wrong decoded value") { + decoder.next(mutableRow, 0) + mutableRow.getBoolean(0) + } } } assert(!decoder.hasNext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala index d2969d906c94..1cdb909146d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala @@ -21,6 +21,7 @@ import java.nio.ByteBuffer import org.scalatest.FunSuite +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ @@ -67,7 +68,7 @@ class DictionaryEncodingSuite extends FunSuite { val buffer = builder.build() val headerSize = CompressionScheme.columnHeaderSize(buffer) // 4 extra bytes for dictionary size - val dictionarySize = 4 + values.map(columnType.actualSize).sum + val dictionarySize = 4 + rows.map(columnType.actualSize(_, 0)).sum // 2 bytes for each `Short` val compressedSize = 4 + dictionarySize + 2 * inputSeq.length // 4 extra bytes for compression scheme type ID @@ -97,11 +98,15 @@ class DictionaryEncodingSuite extends FunSuite { buffer.rewind().position(headerSize + 4) val decoder = DictionaryEncoding.decoder(buffer, columnType) + val mutableRow = new GenericMutableRow(1) if (inputSeq.nonEmpty) { inputSeq.foreach { i => assert(decoder.hasNext) - assertResult(values(i), "Wrong decoded value")(decoder.next()) + assertResult(values(i), "Wrong decoded value") { + decoder.next(mutableRow, 0) + columnType.getField(mutableRow, 0) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala index 322f447c2484..73f31c023334 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala @@ -31,7 +31,7 @@ class IntegralDeltaSuite extends FunSuite { def testIntegralDelta[I <: IntegralType]( columnStats: ColumnStats, columnType: NativeColumnType[I], - scheme: IntegralDelta[I]) { + scheme: CompressionScheme) { def skeleton(input: Seq[I#JvmType]) { // ------------- @@ -96,10 +96,15 @@ class IntegralDeltaSuite extends FunSuite { buffer.rewind().position(headerSize + 4) val decoder = scheme.decoder(buffer, columnType) + val mutableRow = new GenericMutableRow(1) + if (input.nonEmpty) { input.foreach{ assert(decoder.hasNext) - assertResult(_, "Wrong decoded value")(decoder.next()) + assertResult(_, "Wrong decoded value") { + decoder.next(mutableRow, 0) + columnType.getField(mutableRow, 0) + } } } assert(!decoder.hasNext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala index 218c09ac2636..4ce2552112c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.columnar.compression import org.scalatest.FunSuite +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ @@ -57,7 +58,7 @@ class RunLengthEncodingSuite extends FunSuite { // Compression scheme ID + compressed contents val compressedSize = 4 + inputRuns.map { case (index, _) => // 4 extra bytes each run for run length - columnType.actualSize(values(index)) + 4 + columnType.actualSize(rows(index), 0) + 4 }.sum // 4 extra bytes for compression scheme type ID @@ -80,11 +81,15 @@ class RunLengthEncodingSuite extends FunSuite { buffer.rewind().position(headerSize + 4) val decoder = RunLengthEncoding.decoder(buffer, columnType) + val mutableRow = new GenericMutableRow(1) if (inputSeq.nonEmpty) { inputSeq.foreach { i => assert(decoder.hasNext) - assertResult(values(i), "Wrong decoded value")(decoder.next()) + assertResult(values(i), "Wrong decoded value") { + decoder.next(mutableRow, 0) + columnType.getField(mutableRow, 0) + } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 329f80cad471..84fafcde63d0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -25,16 +25,14 @@ import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc} import org.apache.hadoop.hive.serde2.Deserializer import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector - +import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} import org.apache.spark.SerializableWritable import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} - -import org.apache.spark.sql.catalyst.expressions.{Attribute, Row, GenericMutableRow, Literal, Cast} -import org.apache.spark.sql.catalyst.types.DataType +import org.apache.spark.sql.catalyst.expressions._ /** * A trait for subclasses that handle table scans. @@ -108,12 +106,12 @@ class HadoopTableReader( val hadoopRDD = createHadoopRdd(tableDesc, inputPathStr, ifc) val attrsWithIndex = attributes.zipWithIndex - val mutableRow = new GenericMutableRow(attrsWithIndex.length) + val mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) + val deserializedHadoopRDD = hadoopRDD.mapPartitions { iter => val hconf = broadcastedHiveConf.value.value val deserializer = deserializerClass.newInstance() deserializer.initialize(hconf, tableDesc.getProperties) - HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow) } @@ -164,33 +162,32 @@ class HadoopTableReader( val tableDesc = relation.tableDesc val broadcastedHiveConf = _broadcastedHiveConf val localDeserializer = partDeserializer - val mutableRow = new GenericMutableRow(attributes.length) - - // split the attributes (output schema) into 2 categories: - // (partition keys, ordinal), (normal attributes, ordinal), the ordinal mean the - // index of the attribute in the output Row. - val (partitionKeys, attrs) = attributes.zipWithIndex.partition(attr => { - relation.partitionKeys.indexOf(attr._1) >= 0 - }) - - def fillPartitionKeys(parts: Array[String], row: GenericMutableRow) = { - partitionKeys.foreach { case (attr, ordinal) => - // get partition key ordinal for a given attribute - val partOridinal = relation.partitionKeys.indexOf(attr) - row(ordinal) = Cast(Literal(parts(partOridinal)), attr.dataType).eval(null) + val mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) + + // Splits all attributes into two groups, partition key attributes and those that are not. + // Attached indices indicate the position of each attribute in the output schema. + val (partitionKeyAttrs, nonPartitionKeyAttrs) = + attributes.zipWithIndex.partition { case (attr, _) => + relation.partitionKeys.contains(attr) + } + + def fillPartitionKeys(rawPartValues: Array[String], row: MutableRow) = { + partitionKeyAttrs.foreach { case (attr, ordinal) => + val partOrdinal = relation.partitionKeys.indexOf(attr) + row(ordinal) = Cast(Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null) } } - // fill the partition key for the given MutableRow Object + + // Fill all partition keys to the given MutableRow object fillPartitionKeys(partValues, mutableRow) - val hivePartitionRDD = createHadoopRdd(tableDesc, inputPathStr, ifc) - hivePartitionRDD.mapPartitions { iter => + createHadoopRdd(tableDesc, inputPathStr, ifc).mapPartitions { iter => val hconf = broadcastedHiveConf.value.value val deserializer = localDeserializer.newInstance() deserializer.initialize(hconf, partProps) - // fill the non partition key attributes - HadoopTableReader.fillObject(iter, deserializer, attrs, mutableRow) + // fill the non partition key attributes + HadoopTableReader.fillObject(iter, deserializer, nonPartitionKeyAttrs, mutableRow) } }.toSeq @@ -257,38 +254,64 @@ private[hive] object HadoopTableReader extends HiveInspectors { } /** - * Transform the raw data(Writable object) into the Row object for an iterable input - * @param iter Iterable input which represented as Writable object - * @param deserializer Deserializer associated with the input writable object - * @param attrs Represents the row attribute names and its zero-based position in the MutableRow - * @param row reusable MutableRow object - * - * @return Iterable Row object that transformed from the given iterable input. + * Transform all given raw `Writable`s into `Row`s. + * + * @param iterator Iterator of all `Writable`s to be transformed + * @param deserializer The `Deserializer` associated with the input `Writable` + * @param nonPartitionKeyAttrs Attributes that should be filled together with their corresponding + * positions in the output schema + * @param mutableRow A reusable `MutableRow` that should be filled + * @return An `Iterator[Row]` transformed from `iterator` */ def fillObject( - iter: Iterator[Writable], + iterator: Iterator[Writable], deserializer: Deserializer, - attrs: Seq[(Attribute, Int)], - row: GenericMutableRow): Iterator[Row] = { + nonPartitionKeyAttrs: Seq[(Attribute, Int)], + mutableRow: MutableRow): Iterator[Row] = { + val soi = deserializer.getObjectInspector().asInstanceOf[StructObjectInspector] - // get the field references according to the attributes(output of the reader) required - val fieldRefs = attrs.map { case (attr, idx) => (soi.getStructFieldRef(attr.name), idx) } + val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { case (attr, ordinal) => + soi.getStructFieldRef(attr.name) -> ordinal + }.unzip + + // Builds specific unwrappers ahead of time according to object inspector types to avoid pattern + // matching and branching costs per row. + val unwrappers: Seq[(Any, MutableRow, Int) => Unit] = fieldRefs.map { + _.getFieldObjectInspector match { + case oi: BooleanObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setBoolean(ordinal, oi.get(value)) + case oi: ByteObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setByte(ordinal, oi.get(value)) + case oi: ShortObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setShort(ordinal, oi.get(value)) + case oi: IntObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setInt(ordinal, oi.get(value)) + case oi: LongObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setLong(ordinal, oi.get(value)) + case oi: FloatObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value)) + case oi: DoubleObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) + case oi => + (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrapData(value, oi) + } + } // Map each tuple to a row object - iter.map { value => + iterator.map { value => val raw = deserializer.deserialize(value) - var idx = 0; - while (idx < fieldRefs.length) { - val fieldRef = fieldRefs(idx)._1 - val fieldIdx = fieldRefs(idx)._2 - val fieldValue = soi.getStructFieldData(raw, fieldRef) - - row(fieldIdx) = unwrapData(fieldValue, fieldRef.getFieldObjectInspector()) - - idx += 1 + var i = 0 + while (i < fieldRefs.length) { + val fieldValue = soi.getStructFieldData(raw, fieldRefs(i)) + if (fieldValue == null) { + mutableRow.setNullAt(fieldOrdinals(i)) + } else { + unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i)) + } + i += 1 } - row: Row + mutableRow: Row } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 6bf8d18a5c32..8c8a8b124ac6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -295,8 +295,16 @@ class HiveQuerySuite extends HiveComparisonTest { "SELECT (CASE WHEN key > 2 THEN 3 WHEN 2 > key THEN 2 ELSE 0 END) FROM src WHERE key < 15") test("implement identity function using case statement") { - val actual = sql("SELECT (CASE key WHEN key THEN key END) FROM src").collect().toSet - val expected = sql("SELECT key FROM src").collect().toSet + val actual = sql("SELECT (CASE key WHEN key THEN key END) FROM src") + .map { case Row(i: Int) => i } + .collect() + .toSet + + val expected = sql("SELECT key FROM src") + .map { case Row(i: Int) => i } + .collect() + .toSet + assert(actual === expected) } @@ -559,9 +567,9 @@ class HiveQuerySuite extends HiveComparisonTest { val testVal = "test.val.0" val nonexistentKey = "nonexistent" val KV = "([^=]+)=([^=]*)".r - def collectResults(rdd: SchemaRDD): Set[(String, String)] = - rdd.collect().map { - case Row(key: String, value: String) => key -> value + def collectResults(rdd: SchemaRDD): Set[(String, String)] = + rdd.collect().map { + case Row(key: String, value: String) => key -> value case Row(KV(key, value)) => key -> value }.toSet clear() From 0f8c4edf4e750e3d11da27cc22c40b0489da7f37 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sat, 13 Sep 2014 16:08:04 -0700 Subject: [PATCH 09/21] [SQL] Decrease partitions when testing Author: Michael Armbrust Closes #2164 from marmbrus/shufflePartitions and squashes the following commits: 0da1e8c [Michael Armbrust] test hax ef2d985 [Michael Armbrust] more test hacks. 2dabae3 [Michael Armbrust] more test fixes 0bdbf21 [Michael Armbrust] Make parquet tests less order dependent b42eeab [Michael Armbrust] increase test parallelism 80453d5 [Michael Armbrust] Decrease partitions when testing --- .../spark/sql/test/TestSQLContext.scala | 9 +- .../spark/sql/parquet/ParquetQuerySuite.scala | 142 +++++------------- .../org/apache/spark/sql/hive/TestHive.scala | 7 +- 3 files changed, 51 insertions(+), 107 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala index f2389f8f0591..265b67737c47 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -18,8 +18,13 @@ package org.apache.spark.sql.test import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.{SQLConf, SQLContext} /** A SQLContext that can be used for local testing. */ object TestSQLContext - extends SQLContext(new SparkContext("local", "TestSQLContext", new SparkConf())) + extends SQLContext(new SparkContext("local[2]", "TestSQLContext", new SparkConf())) { + + /** Fewer partitions to speed up testing. */ + override private[spark] def numShufflePartitions: Int = + getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index b0a06cd3ca09..08f7358446b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -58,8 +58,7 @@ case class AllDataTypes( doubleField: Double, shortField: Short, byteField: Byte, - booleanField: Boolean, - binaryField: Array[Byte]) + booleanField: Boolean) case class AllDataTypesWithNonPrimitiveType( stringField: String, @@ -70,13 +69,14 @@ case class AllDataTypesWithNonPrimitiveType( shortField: Short, byteField: Byte, booleanField: Boolean, - binaryField: Array[Byte], array: Seq[Int], arrayContainsNull: Seq[Option[Int]], map: Map[Int, Long], mapValueContainsNull: Map[Int, Option[Long]], data: Data) +case class BinaryData(binaryData: Array[Byte]) + class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { TestData // Load test data tables. @@ -108,26 +108,26 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA test("Read/Write All Types") { val tempDir = getTempFilePath("parquetTest").getCanonicalPath val range = (0 to 255) - TestSQLContext.sparkContext.parallelize(range) - .map(x => AllDataTypes(s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0, - (0 to x).map(_.toByte).toArray)) - .saveAsParquetFile(tempDir) - val result = parquetFile(tempDir).collect() - range.foreach { - i => - assert(result(i).getString(0) == s"$i", s"row $i String field did not match, got ${result(i).getString(0)}") - assert(result(i).getInt(1) === i) - assert(result(i).getLong(2) === i.toLong) - assert(result(i).getFloat(3) === i.toFloat) - assert(result(i).getDouble(4) === i.toDouble) - assert(result(i).getShort(5) === i.toShort) - assert(result(i).getByte(6) === i.toByte) - assert(result(i).getBoolean(7) === (i % 2 == 0)) - assert(result(i)(8) === (0 to i).map(_.toByte).toArray) - } + val data = sparkContext.parallelize(range) + .map(x => AllDataTypes(s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0)) + + data.saveAsParquetFile(tempDir) + + checkAnswer( + parquetFile(tempDir), + data.toSchemaRDD.collect().toSeq) } - test("Treat binary as string") { + test("read/write binary data") { + // Since equality for Array[Byte] is broken we test this separately. + val tempDir = getTempFilePath("parquetTest").getCanonicalPath + sparkContext.parallelize(BinaryData("test".getBytes("utf8")) :: Nil).saveAsParquetFile(tempDir) + parquetFile(tempDir) + .map(r => new String(r(0).asInstanceOf[Array[Byte]], "utf8")) + .collect().toSeq == Seq("test") + } + + ignore("Treat binary as string") { val oldIsParquetBinaryAsString = TestSQLContext.isParquetBinaryAsString // Create the test file. @@ -142,37 +142,16 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA StructField("c2", BinaryType, false) :: Nil) val schemaRDD1 = applySchema(rowRDD, schema) schemaRDD1.saveAsParquetFile(path) - val resultWithBinary = parquetFile(path).collect - range.foreach { - i => - assert(resultWithBinary(i).getInt(0) === i) - assert(resultWithBinary(i)(1) === s"val_$i".getBytes) - } - - TestSQLContext.setConf(SQLConf.PARQUET_BINARY_AS_STRING, "true") - // This ParquetRelation always use Parquet types to derive output. - val parquetRelation = new ParquetRelation( - path.toString, - Some(TestSQLContext.sparkContext.hadoopConfiguration), - TestSQLContext) { - override val output = - ParquetTypesConverter.convertToAttributes( - ParquetTypesConverter.readMetaData(new Path(path), conf).getFileMetaData.getSchema, - TestSQLContext.isParquetBinaryAsString) - } - val schemaRDD = new SchemaRDD(TestSQLContext, parquetRelation) - val resultWithString = schemaRDD.collect - range.foreach { - i => - assert(resultWithString(i).getInt(0) === i) - assert(resultWithString(i)(1) === s"val_$i") - } + checkAnswer( + parquetFile(path).select('c1, 'c2.cast(StringType)), + schemaRDD1.select('c1, 'c2.cast(StringType)).collect().toSeq) - schemaRDD.registerTempTable("tmp") + setConf(SQLConf.PARQUET_BINARY_AS_STRING, "true") + parquetFile(path).printSchema() checkAnswer( - sql("SELECT c1, c2 FROM tmp WHERE c2 = 'val_5' OR c2 = 'val_7'"), - (5, "val_5") :: - (7, "val_7") :: Nil) + parquetFile(path), + schemaRDD1.select('c1, 'c2.cast(StringType)).collect().toSeq) + // Set it back. TestSQLContext.setConf(SQLConf.PARQUET_BINARY_AS_STRING, oldIsParquetBinaryAsString.toString) @@ -275,34 +254,19 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA test("Read/Write All Types with non-primitive type") { val tempDir = getTempFilePath("parquetTest").getCanonicalPath val range = (0 to 255) - TestSQLContext.sparkContext.parallelize(range) + val data = sparkContext.parallelize(range) .map(x => AllDataTypesWithNonPrimitiveType( s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0, - (0 to x).map(_.toByte).toArray, (0 until x), (0 until x).map(Option(_).filter(_ % 3 == 0)), (0 until x).map(i => i -> i.toLong).toMap, (0 until x).map(i => i -> Option(i.toLong)).toMap + (x -> None), Data((0 until x), Nested(x, s"$x")))) - .saveAsParquetFile(tempDir) - val result = parquetFile(tempDir).collect() - range.foreach { - i => - assert(result(i).getString(0) == s"$i", s"row $i String field did not match, got ${result(i).getString(0)}") - assert(result(i).getInt(1) === i) - assert(result(i).getLong(2) === i.toLong) - assert(result(i).getFloat(3) === i.toFloat) - assert(result(i).getDouble(4) === i.toDouble) - assert(result(i).getShort(5) === i.toShort) - assert(result(i).getByte(6) === i.toByte) - assert(result(i).getBoolean(7) === (i % 2 == 0)) - assert(result(i)(8) === (0 to i).map(_.toByte).toArray) - assert(result(i)(9) === (0 until i)) - assert(result(i)(10) === (0 until i).map(i => if (i % 3 == 0) i else null)) - assert(result(i)(11) === (0 until i).map(i => i -> i.toLong).toMap) - assert(result(i)(12) === (0 until i).map(i => i -> i.toLong).toMap + (i -> null)) - assert(result(i)(13) === new GenericRow(Array[Any]((0 until i), new GenericRow(Array[Any](i, s"$i"))))) - } + data.saveAsParquetFile(tempDir) + + checkAnswer( + parquetFile(tempDir), + data.toSchemaRDD.collect().toSeq) } test("self-join parquet files") { @@ -399,23 +363,6 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } } - test("Saving case class RDD table to file and reading it back in") { - val file = getTempFilePath("parquet") - val path = file.toString - val rdd = TestSQLContext.sparkContext.parallelize((1 to 100)) - .map(i => TestRDDEntry(i, s"val_$i")) - rdd.saveAsParquetFile(path) - val readFile = parquetFile(path) - readFile.registerTempTable("tmpx") - val rdd_copy = sql("SELECT * FROM tmpx").collect() - val rdd_orig = rdd.collect() - for(i <- 0 to 99) { - assert(rdd_copy(i).apply(0) === rdd_orig(i).key, s"key error in line $i") - assert(rdd_copy(i).apply(1) === rdd_orig(i).value, s"value error in line $i") - } - Utils.deleteRecursively(file) - } - test("Read a parquet file instead of a directory") { val file = getTempFilePath("parquet") val path = file.toString @@ -448,32 +395,19 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA sql("INSERT OVERWRITE INTO dest SELECT * FROM source").collect() val rdd_copy1 = sql("SELECT * FROM dest").collect() assert(rdd_copy1.size === 100) - assert(rdd_copy1(0).apply(0) === 1) - assert(rdd_copy1(0).apply(1) === "val_1") - // TODO: why does collecting break things? It seems InsertIntoParquet::execute() is - // executed twice otherwise?! + sql("INSERT INTO dest SELECT * FROM source") - val rdd_copy2 = sql("SELECT * FROM dest").collect() + val rdd_copy2 = sql("SELECT * FROM dest").collect().sortBy(_.getInt(0)) assert(rdd_copy2.size === 200) - assert(rdd_copy2(0).apply(0) === 1) - assert(rdd_copy2(0).apply(1) === "val_1") - assert(rdd_copy2(99).apply(0) === 100) - assert(rdd_copy2(99).apply(1) === "val_100") - assert(rdd_copy2(100).apply(0) === 1) - assert(rdd_copy2(100).apply(1) === "val_1") Utils.deleteRecursively(dirname) } test("Insert (appending) to same table via Scala API") { - // TODO: why does collecting break things? It seems InsertIntoParquet::execute() is - // executed twice otherwise?! sql("INSERT INTO testsource SELECT * FROM testsource") val double_rdd = sql("SELECT * FROM testsource").collect() assert(double_rdd != null) assert(double_rdd.size === 30) - for(i <- (0 to 14)) { - assert(double_rdd(i) === double_rdd(i+15), s"error: lines $i and ${i+15} to not match") - } + // let's restore the original test data Utils.deleteRecursively(ParquetTestData.testDir) ParquetTestData.writeFile() diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index a3bfd3a8f1fd..70fb15259e7d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -35,12 +35,13 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical.{CacheCommand, LogicalPlan, NativeCommand} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.hive._ +import org.apache.spark.sql.SQLConf /* Implicit conversions */ import scala.collection.JavaConversions._ object TestHive - extends TestHiveContext(new SparkContext("local", "TestSQLContext", new SparkConf())) + extends TestHiveContext(new SparkContext("local[2]", "TestSQLContext", new SparkConf())) /** * A locally running test instance of Spark's Hive execution engine. @@ -90,6 +91,10 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { override def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val logical = plan } + /** Fewer partitions to speed up testing. */ + override private[spark] def numShufflePartitions: Int = + getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt + /** * Returns the value of specified environmental variable as a [[java.io.File]] after checking * to ensure it exists From 2aea0da84c58a179917311290083456dfa043db7 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 13 Sep 2014 16:22:04 -0700 Subject: [PATCH 10/21] [SPARK-3030] [PySpark] Reuse Python worker Reuse Python worker to avoid the overhead of fork() Python process for each tasks. It also tracks the broadcasts for each worker, avoid sending repeated broadcasts. This can reduce the time for dummy task from 22ms to 13ms (-40%). It can help to reduce the latency for Spark Streaming. For a job with broadcast (43M after compress): ``` b = sc.broadcast(set(range(30000000))) print sc.parallelize(range(24000), 100).filter(lambda x: x in b.value).count() ``` It will finish in 281s without reused worker, and it will finish in 65s with reused worker(4 CPUs). After reusing the worker, it can save about 9 seconds for transfer and deserialize the broadcast for each tasks. It's enabled by default, could be disabled by `spark.python.worker.reuse = false`. Author: Davies Liu Closes #2259 from davies/reuse-worker and squashes the following commits: f11f617 [Davies Liu] Merge branch 'master' into reuse-worker 3939f20 [Davies Liu] fix bug in serializer in mllib cf1c55e [Davies Liu] address comments 3133a60 [Davies Liu] fix accumulator with reused worker 760ab1f [Davies Liu] do not reuse worker if there are any exceptions 7abb224 [Davies Liu] refactor: sychronized with itself ac3206e [Davies Liu] renaming 8911f44 [Davies Liu] synchronized getWorkerBroadcasts() 6325fc1 [Davies Liu] bugfix: bid >= 0 e0131a2 [Davies Liu] fix name of config 583716e [Davies Liu] only reuse completed and not interrupted worker ace2917 [Davies Liu] kill python worker after timeout 6123d0f [Davies Liu] track broadcasts for each worker 8d2f08c [Davies Liu] reuse python worker --- .../scala/org/apache/spark/SparkEnv.scala | 8 ++ .../apache/spark/api/python/PythonRDD.scala | 58 ++++++++++--- .../api/python/PythonWorkerFactory.scala | 85 ++++++++++++++++--- docs/configuration.md | 10 +++ python/pyspark/daemon.py | 38 ++++----- python/pyspark/mllib/_common.py | 12 ++- python/pyspark/serializers.py | 4 + python/pyspark/tests.py | 35 ++++++++ python/pyspark/worker.py | 9 +- 9 files changed, 208 insertions(+), 51 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index dd95e406f2a8..009ed6477584 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -108,6 +108,14 @@ class SparkEnv ( pythonWorkers.get(key).foreach(_.stopWorker(worker)) } } + + private[spark] + def releasePythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) { + synchronized { + val key = (pythonExec, envVars) + pythonWorkers.get(key).foreach(_.releaseWorker(worker)) + } + } } object SparkEnv extends Logging { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index ae8010300a50..ca8eef5f99ed 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -23,6 +23,7 @@ import java.nio.charset.Charset import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} import scala.collection.JavaConversions._ +import scala.collection.mutable import scala.language.existentials import scala.reflect.ClassTag import scala.util.{Try, Success, Failure} @@ -52,6 +53,7 @@ private[spark] class PythonRDD( extends RDD[Array[Byte]](parent) { val bufferSize = conf.getInt("spark.buffer.size", 65536) + val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true) override def getPartitions = parent.partitions @@ -63,19 +65,26 @@ private[spark] class PythonRDD( val localdir = env.blockManager.diskBlockManager.localDirs.map( f => f.getPath()).mkString(",") envVars += ("SPARK_LOCAL_DIRS" -> localdir) // it's also used in monitor thread + if (reuse_worker) { + envVars += ("SPARK_REUSE_WORKER" -> "1") + } val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap) // Start a thread to feed the process input from our parent's iterator val writerThread = new WriterThread(env, worker, split, context) + var complete_cleanly = false context.addTaskCompletionListener { context => writerThread.shutdownOnTaskCompletion() - - // Cleanup the worker socket. This will also cause the Python worker to exit. - try { - worker.close() - } catch { - case e: Exception => logWarning("Failed to close worker socket", e) + if (reuse_worker && complete_cleanly) { + env.releasePythonWorker(pythonExec, envVars.toMap, worker) + } else { + try { + worker.close() + } catch { + case e: Exception => + logWarning("Failed to close worker socket", e) + } } } @@ -133,6 +142,7 @@ private[spark] class PythonRDD( stream.readFully(update) accumulator += Collections.singletonList(update) } + complete_cleanly = true null } } catch { @@ -195,11 +205,26 @@ private[spark] class PythonRDD( PythonRDD.writeUTF(include, dataOut) } // Broadcast variables - dataOut.writeInt(broadcastVars.length) + val oldBids = PythonRDD.getWorkerBroadcasts(worker) + val newBids = broadcastVars.map(_.id).toSet + // number of different broadcasts + val cnt = oldBids.diff(newBids).size + newBids.diff(oldBids).size + dataOut.writeInt(cnt) + for (bid <- oldBids) { + if (!newBids.contains(bid)) { + // remove the broadcast from worker + dataOut.writeLong(- bid - 1) // bid >= 0 + oldBids.remove(bid) + } + } for (broadcast <- broadcastVars) { - dataOut.writeLong(broadcast.id) - dataOut.writeInt(broadcast.value.length) - dataOut.write(broadcast.value) + if (!oldBids.contains(broadcast.id)) { + // send new broadcast + dataOut.writeLong(broadcast.id) + dataOut.writeInt(broadcast.value.length) + dataOut.write(broadcast.value) + oldBids.add(broadcast.id) + } } dataOut.flush() // Serialized command: @@ -207,17 +232,18 @@ private[spark] class PythonRDD( dataOut.write(command) // Data values PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut) + dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) dataOut.flush() } catch { case e: Exception if context.isCompleted || context.isInterrupted => logDebug("Exception thrown after task completion (likely due to cleanup)", e) + worker.shutdownOutput() case e: Exception => // We must avoid throwing exceptions here, because the thread uncaught exception handler // will kill the whole executor (see org.apache.spark.executor.Executor). _exception = e - } finally { - Try(worker.shutdownOutput()) // kill Python worker process + worker.shutdownOutput() } } } @@ -278,6 +304,14 @@ private object SpecialLengths { private[spark] object PythonRDD extends Logging { val UTF8 = Charset.forName("UTF-8") + // remember the broadcasts sent to each worker + private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]() + private def getWorkerBroadcasts(worker: Socket) = { + synchronized { + workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]()) + } + } + /** * Adapter for calling SparkContext#runJob from Python. * diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 4c4796f6c59b..71bdf0fe1b91 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -40,7 +40,10 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String var daemon: Process = null val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) var daemonPort: Int = 0 - var daemonWorkers = new mutable.WeakHashMap[Socket, Int]() + val daemonWorkers = new mutable.WeakHashMap[Socket, Int]() + val idleWorkers = new mutable.Queue[Socket]() + var lastActivity = 0L + new MonitorThread().start() var simpleWorkers = new mutable.WeakHashMap[Socket, Process]() @@ -51,6 +54,11 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String def create(): Socket = { if (useDaemon) { + synchronized { + if (idleWorkers.size > 0) { + return idleWorkers.dequeue() + } + } createThroughDaemon() } else { createSimpleWorker() @@ -199,9 +207,44 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } } + /** + * Monitor all the idle workers, kill them after timeout. + */ + private class MonitorThread extends Thread(s"Idle Worker Monitor for $pythonExec") { + + setDaemon(true) + + override def run() { + while (true) { + synchronized { + if (lastActivity + IDLE_WORKER_TIMEOUT_MS < System.currentTimeMillis()) { + cleanupIdleWorkers() + lastActivity = System.currentTimeMillis() + } + } + Thread.sleep(10000) + } + } + } + + private def cleanupIdleWorkers() { + while (idleWorkers.length > 0) { + val worker = idleWorkers.dequeue() + try { + // the worker will exit after closing the socket + worker.close() + } catch { + case e: Exception => + logWarning("Failed to close worker socket", e) + } + } + } + private def stopDaemon() { synchronized { if (useDaemon) { + cleanupIdleWorkers() + // Request shutdown of existing daemon by sending SIGTERM if (daemon != null) { daemon.destroy() @@ -220,23 +263,43 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } def stopWorker(worker: Socket) { - if (useDaemon) { - if (daemon != null) { - daemonWorkers.get(worker).foreach { pid => - // tell daemon to kill worker by pid - val output = new DataOutputStream(daemon.getOutputStream) - output.writeInt(pid) - output.flush() - daemon.getOutputStream.flush() + synchronized { + if (useDaemon) { + if (daemon != null) { + daemonWorkers.get(worker).foreach { pid => + // tell daemon to kill worker by pid + val output = new DataOutputStream(daemon.getOutputStream) + output.writeInt(pid) + output.flush() + daemon.getOutputStream.flush() + } } + } else { + simpleWorkers.get(worker).foreach(_.destroy()) } - } else { - simpleWorkers.get(worker).foreach(_.destroy()) } worker.close() } + + def releaseWorker(worker: Socket) { + if (useDaemon) { + synchronized { + lastActivity = System.currentTimeMillis() + idleWorkers.enqueue(worker) + } + } else { + // Cleanup the worker socket. This will also cause the Python worker to exit. + try { + worker.close() + } catch { + case e: Exception => + logWarning("Failed to close worker socket", e) + } + } + } } private object PythonWorkerFactory { val PROCESS_WAIT_TIMEOUT_MS = 10000 + val IDLE_WORKER_TIMEOUT_MS = 60000 // kill idle workers after 1 minute } diff --git a/docs/configuration.md b/docs/configuration.md index 36178efb9710..af16489a4428 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -206,6 +206,16 @@ Apart from these, the following properties are also available, and may be useful used during aggregation goes above this amount, it will spill the data into disks. + + spark.python.worker.reuse + true + + Reuse Python worker or not. If yes, it will use a fixed number of Python workers, + does not need to fork() a Python process for every tasks. It will be very useful + if there is large broadcast, then the broadcast will not be needed to transfered + from JVM to Python worker for every task. + + spark.executorEnv.[EnvironmentVariableName] (none) diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 15445abf6714..64d6202acb27 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -23,6 +23,7 @@ import sys import traceback import time +import gc from errno import EINTR, ECHILD, EAGAIN from socket import AF_INET, SOCK_STREAM, SOMAXCONN from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN @@ -46,17 +47,6 @@ def worker(sock): signal.signal(SIGCHLD, SIG_DFL) signal.signal(SIGTERM, SIG_DFL) - # Blocks until the socket is closed by draining the input stream - # until it raises an exception or returns EOF. - def waitSocketClose(sock): - try: - while True: - # Empty string is returned upon EOF (and only then). - if sock.recv(4096) == '': - return - except: - pass - # Read the socket using fdopen instead of socket.makefile() because the latter # seems to be very slow; note that we need to dup() the file descriptor because # otherwise writes also cause a seek that makes us miss data on the read side. @@ -64,17 +54,13 @@ def waitSocketClose(sock): outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) exit_code = 0 try: - # Acknowledge that the fork was successful - write_int(os.getpid(), outfile) - outfile.flush() worker_main(infile, outfile) except SystemExit as exc: - exit_code = exc.code + exit_code = compute_real_exit_code(exc.code) finally: outfile.flush() - # The Scala side will close the socket upon task completion. - waitSocketClose(sock) - os._exit(compute_real_exit_code(exit_code)) + if exit_code: + os._exit(exit_code) # Cleanup zombie children @@ -111,6 +97,8 @@ def handle_sigterm(*args): signal.signal(SIGTERM, handle_sigterm) # Gracefully exit on SIGTERM signal.signal(SIGHUP, SIG_IGN) # Don't die on SIGHUP + reuse = os.environ.get("SPARK_REUSE_WORKER") + # Initialization complete try: while True: @@ -163,7 +151,19 @@ def handle_sigterm(*args): # in child process listen_sock.close() try: - worker(sock) + # Acknowledge that the fork was successful + outfile = sock.makefile("w") + write_int(os.getpid(), outfile) + outfile.flush() + outfile.close() + while True: + worker(sock) + if not reuse: + # wait for closing + while sock.recv(1024): + pass + break + gc.collect() except: traceback.print_exc() os._exit(1) diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index bb60d3d0c846..68f603361672 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -21,7 +21,7 @@ from numpy import ndarray, float64, int64, int32, array_equal, array from pyspark import SparkContext, RDD from pyspark.mllib.linalg import SparseVector -from pyspark.serializers import Serializer +from pyspark.serializers import FramedSerializer """ @@ -451,18 +451,16 @@ def _serialize_rating(r): return ba -class RatingDeserializer(Serializer): +class RatingDeserializer(FramedSerializer): - def loads(self, stream): - length = struct.unpack("!i", stream.read(4))[0] - ba = stream.read(length) - res = ndarray(shape=(3, ), buffer=ba, dtype=float64, offset=4) + def loads(self, string): + res = ndarray(shape=(3, ), buffer=string, dtype=float64, offset=4) return int(res[0]), int(res[1]), res[2] def load_stream(self, stream): while True: try: - yield self.loads(stream) + yield self._read_with_length(stream) except struct.error: return except EOFError: diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index a5f9341e819a..ec3c6f055441 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -144,6 +144,8 @@ def _write_with_length(self, obj, stream): def _read_with_length(self, stream): length = read_int(stream) + if length == SpecialLengths.END_OF_DATA_SECTION: + raise EOFError obj = stream.read(length) if obj == "": raise EOFError @@ -438,6 +440,8 @@ def __init__(self, use_unicode=False): def loads(self, stream): length = read_int(stream) + if length == SpecialLengths.END_OF_DATA_SECTION: + raise EOFError s = stream.read(length) return s.decode("utf-8") if self.use_unicode else s diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index b687d695b01c..747cd1767de7 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1222,11 +1222,46 @@ def run(): except OSError: self.fail("daemon had been killed") + # run a normal job + rdd = self.sc.parallelize(range(100), 1) + self.assertEqual(100, rdd.map(str).count()) + def test_fd_leak(self): N = 1100 # fd limit is 1024 by default rdd = self.sc.parallelize(range(N), N) self.assertEquals(N, rdd.count()) + def test_after_exception(self): + def raise_exception(_): + raise Exception() + rdd = self.sc.parallelize(range(100), 1) + self.assertRaises(Exception, lambda: rdd.foreach(raise_exception)) + self.assertEqual(100, rdd.map(str).count()) + + def test_after_jvm_exception(self): + tempFile = tempfile.NamedTemporaryFile(delete=False) + tempFile.write("Hello World!") + tempFile.close() + data = self.sc.textFile(tempFile.name, 1) + filtered_data = data.filter(lambda x: True) + self.assertEqual(1, filtered_data.count()) + os.unlink(tempFile.name) + self.assertRaises(Exception, lambda: filtered_data.count()) + + rdd = self.sc.parallelize(range(100), 1) + self.assertEqual(100, rdd.map(str).count()) + + def test_accumulator_when_reuse_worker(self): + from pyspark.accumulators import INT_ACCUMULATOR_PARAM + acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) + self.sc.parallelize(range(100), 20).foreach(lambda x: acc1.add(x)) + self.assertEqual(sum(range(100)), acc1.value) + + acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) + self.sc.parallelize(range(100), 20).foreach(lambda x: acc2.add(x)) + self.assertEqual(sum(range(100)), acc2.value) + self.assertEqual(sum(range(100)), acc1.value) + class TestSparkSubmit(unittest.TestCase): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 6805063e0679..61b8a74d060e 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -69,9 +69,14 @@ def main(infile, outfile): ser = CompressedSerializer(pickleSer) for _ in range(num_broadcast_variables): bid = read_long(infile) - value = ser._read_with_length(infile) - _broadcastRegistry[bid] = Broadcast(bid, value) + if bid >= 0: + value = ser._read_with_length(infile) + _broadcastRegistry[bid] = Broadcast(bid, value) + else: + bid = - bid - 1 + _broadcastRegistry.remove(bid) + _accumulatorRegistry.clear() command = pickleSer._read_with_length(infile) (func, deserializer, serializer) = command init_time = time.time() From 4e3fbe8cdb6c6291e219195abb272f3c81f0ed63 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 13 Sep 2014 22:31:21 -0700 Subject: [PATCH 11/21] [SPARK-3463] [PySpark] aggregate and show spilled bytes in Python Aggregate the number of bytes spilled into disks during aggregation or sorting, show them in Web UI. ![spilled](https://cloud.githubusercontent.com/assets/40902/4209758/4b995562-386d-11e4-97c1-8e838ee1d4e3.png) This patch is blocked by SPARK-3465. (It includes a fix for that). Author: Davies Liu Closes #2336 from davies/metrics and squashes the following commits: e37df38 [Davies Liu] remove outdated comments 1245eb7 [Davies Liu] remove the temporary fix ebd2f43 [Davies Liu] Merge branch 'master' into metrics 7e4ad04 [Davies Liu] Merge branch 'master' into metrics fbe9029 [Davies Liu] show spilled bytes in Python in web ui --- .../apache/spark/api/python/PythonRDD.scala | 4 ++++ python/pyspark/shuffle.py | 19 ++++++++++++++++--- python/pyspark/tests.py | 15 ++++++++------- python/pyspark/worker.py | 14 ++++++++++---- 4 files changed, 38 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index ca8eef5f99ed..d5002fa02992 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -124,6 +124,10 @@ private[spark] class PythonRDD( val total = finishTime - startTime logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish)) + val memoryBytesSpilled = stream.readLong() + val diskBytesSpilled = stream.readLong() + context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled + context.taskMetrics.diskBytesSpilled += diskBytesSpilled read() case SpecialLengths.PYTHON_EXCEPTION_THROWN => // Signals that an exception has been thrown in python diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 49829f5280a5..ce597cbe91e1 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -68,6 +68,11 @@ def _get_local_dirs(sub): return [os.path.join(d, "python", str(os.getpid()), sub) for d in dirs] +# global stats +MemoryBytesSpilled = 0L +DiskBytesSpilled = 0L + + class Aggregator(object): """ @@ -313,10 +318,12 @@ def _spill(self): It will dump the data in batch for better performance. """ + global MemoryBytesSpilled, DiskBytesSpilled path = self._get_spill_dir(self.spills) if not os.path.exists(path): os.makedirs(path) + used_memory = get_used_memory() if not self.pdata: # The data has not been partitioned, it will iterator the # dataset once, write them into different files, has no @@ -334,6 +341,7 @@ def _spill(self): self.serializer.dump_stream([(k, v)], streams[h]) for s in streams: + DiskBytesSpilled += s.tell() s.close() self.data.clear() @@ -346,9 +354,11 @@ def _spill(self): # dump items in batch self.serializer.dump_stream(self.pdata[i].iteritems(), f) self.pdata[i].clear() + DiskBytesSpilled += os.path.getsize(p) self.spills += 1 gc.collect() # release the memory as much as possible + MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 def iteritems(self): """ Return all merged items as iterator """ @@ -462,7 +472,6 @@ def __init__(self, memory_limit, serializer=None): self.memory_limit = memory_limit self.local_dirs = _get_local_dirs("sort") self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024) - self._spilled_bytes = 0 def _get_path(self, n): """ Choose one directory for spill by number n """ @@ -476,6 +485,7 @@ def sorted(self, iterator, key=None, reverse=False): Sort the elements in iterator, do external sort when the memory goes above the limit. """ + global MemoryBytesSpilled, DiskBytesSpilled batch = 10 chunks, current_chunk = [], [] iterator = iter(iterator) @@ -486,15 +496,18 @@ def sorted(self, iterator, key=None, reverse=False): if len(chunk) < batch: break - if get_used_memory() > self.memory_limit: + used_memory = get_used_memory() + if used_memory > self.memory_limit: # sort them inplace will save memory current_chunk.sort(key=key, reverse=reverse) path = self._get_path(len(chunks)) with open(path, 'w') as f: self.serializer.dump_stream(current_chunk, f) - self._spilled_bytes += os.path.getsize(path) chunks.append(self.serializer.load_stream(open(path))) current_chunk = [] + gc.collect() + MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + DiskBytesSpilled += os.path.getsize(path) elif not chunks: batch = min(batch * 2, 10000) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 747cd1767de7..f3309a20fcff 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -46,6 +46,7 @@ CloudPickleSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter from pyspark.sql import SQLContext, IntegerType +from pyspark import shuffle _have_scipy = False _have_numpy = False @@ -138,17 +139,17 @@ def test_external_sort(self): random.shuffle(l) sorter = ExternalSorter(1) self.assertEquals(sorted(l), list(sorter.sorted(l))) - self.assertGreater(sorter._spilled_bytes, 0) - last = sorter._spilled_bytes + self.assertGreater(shuffle.DiskBytesSpilled, 0) + last = shuffle.DiskBytesSpilled self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) - self.assertGreater(sorter._spilled_bytes, last) - last = sorter._spilled_bytes + self.assertGreater(shuffle.DiskBytesSpilled, last) + last = shuffle.DiskBytesSpilled self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) - self.assertGreater(sorter._spilled_bytes, last) - last = sorter._spilled_bytes + self.assertGreater(shuffle.DiskBytesSpilled, last) + last = shuffle.DiskBytesSpilled self.assertEquals(sorted(l, key=lambda x: -x, reverse=True), list(sorter.sorted(l, key=lambda x: -x, reverse=True))) - self.assertGreater(sorter._spilled_bytes, last) + self.assertGreater(shuffle.DiskBytesSpilled, last) def test_external_sort_in_rdd(self): conf = SparkConf().set("spark.python.worker.memory", "1m") diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 61b8a74d060e..252176ac65fe 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -23,16 +23,14 @@ import time import socket import traceback -# CloudPickler needs to be imported so that depicklers are registered using the -# copy_reg module. + from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry -from pyspark.cloudpickle import CloudPickler from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, write_int, read_long, \ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ CompressedSerializer - +from pyspark import shuffle pickleSer = PickleSerializer() utf8_deserializer = UTF8Deserializer() @@ -52,6 +50,11 @@ def main(infile, outfile): if split_index == -1: # for unit tests return + # initialize global state + shuffle.MemoryBytesSpilled = 0 + shuffle.DiskBytesSpilled = 0 + _accumulatorRegistry.clear() + # fetch name of workdir spark_files_dir = utf8_deserializer.loads(infile) SparkFiles._root_directory = spark_files_dir @@ -97,6 +100,9 @@ def main(infile, outfile): exit(-1) finish_time = time.time() report_times(outfile, boot_time, init_time, finish_time) + write_long(shuffle.MemoryBytesSpilled, outfile) + write_long(shuffle.DiskBytesSpilled, outfile) + # Mark the beginning of the accumulators section of the output write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) write_int(len(_accumulatorRegistry), outfile) From eebc0c17f039d5a281aa4fef07d255daca3b8862 Mon Sep 17 00:00:00 2001 From: ravipesala Date: Thu, 11 Sep 2014 15:53:45 +0530 Subject: [PATCH 12/21] Add CACHE TABLE AS SELECT ... This feature allows user to add cache table from the select query. Example : ADD CACHE TABLE AS SELECT * FROM TEST_TABLE. Spark takes this type of SQL as command and it does eager caching. It can be executed from SQLContext and HiveContext. Signed-off-by: ravipesala --- .../apache/spark/sql/catalyst/SqlParser.scala | 9 +++++++- .../sql/catalyst/plans/logical/commands.scala | 5 +++++ .../spark/sql/execution/SparkStrategies.scala | 2 ++ .../apache/spark/sql/execution/commands.scala | 21 +++++++++++++++++++ .../apache/spark/sql/CachedTableSuite.scala | 16 ++++++++++++++ .../org/apache/spark/sql/hive/HiveQl.scala | 13 +++++++++++- 6 files changed, 64 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index ca69531c69a7..c0f1314a2349 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -127,6 +127,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val SUBSTRING = Keyword("SUBSTRING") protected val SQRT = Keyword("SQRT") protected val ABS = Keyword("ABS") + protected val ADD = Keyword("ADD") // Use reflection to find the reserved words defined in this class. protected val reservedWords = @@ -151,7 +152,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)} | UNION ~ opt(DISTINCT) ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } ) - | insert | cache + | insert | cache | addCache ) protected lazy val select: Parser[LogicalPlan] = @@ -181,6 +182,12 @@ class SqlParser extends StandardTokenParsers with PackratParsers { val overwrite: Boolean = o.getOrElse("") == "OVERWRITE" InsertIntoTable(r, Map[String, Option[String]](), s, overwrite) } + + protected lazy val addCache: Parser[LogicalPlan] = + ADD ~ CACHE ~ TABLE ~> ident ~ AS ~ select <~ opt(";") ^^ { + case tableName ~ as ~ s => + CacheTableAsSelectCommand(tableName,s) + } protected lazy val cache: Parser[LogicalPlan] = (CACHE ^^^ true | UNCACHE ^^^ false) ~ TABLE ~ ident ^^ { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index a01809c1fc5e..948fd53a3d8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -75,3 +75,8 @@ case class DescribeCommand( AttributeReference("data_type", StringType, nullable = false)(), AttributeReference("comment", StringType, nullable = false)()) } + +/** + * Returned for the "ADD CACHE TABLE tableName AS SELECT .." command. + */ +case class CacheTableAsSelectCommand(tableName: String, plan: LogicalPlan) extends Command diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 7943d6e1b6fb..aa205d278b5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -305,6 +305,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { Seq(execution.ExplainCommand(logicalPlan, plan.output, extended)(context)) case logical.CacheCommand(tableName, cache) => Seq(execution.CacheCommand(tableName, cache)(context)) + case logical.CacheTableAsSelectCommand(tableName,plan) => + Seq(execution.CacheTableAsSelectCommand(tableName, plan)(context)) case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 94543fc95b47..5bbf783ea875 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -166,3 +166,24 @@ case class DescribeCommand(child: SparkPlan, output: Seq[Attribute])( child.output.map(field => Row(field.name, field.dataType.toString, null)) } } + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class CacheTableAsSelectCommand(tableName: String,plan: LogicalPlan)( + @transient context: SQLContext) + extends LeafNode with Command { + + override protected[sql] lazy val sideEffectResult = { + context.catalog.registerTable(None, tableName, sqlContext.executePlan(plan).analyzed) + context.cacheTable(tableName) + //It does the caching eager. + //TODO : Does it really require to collect? + context.table(tableName).collect + Seq.empty[Row] + } + + override def output: Seq[Attribute] = Seq.empty + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index befef46d9397..a5bc74170833 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -119,4 +119,20 @@ class CachedTableSuite extends QueryTest { } assert(!TestSQLContext.isCached("testData"), "Table 'testData' should not be cached") } + + test("ADD CACHE TABLE tableName AS SELECT Star Table") { + TestSQLContext.sql("ADD CACHE TABLE testCacheTable AS SELECT * FROM testData") + TestSQLContext.sql("SELECT * FROM testCacheTable WHERE key = 1").collect() + TestSQLContext.uncacheTable("testCacheTable") + } + + test("'ADD CACHE TABLE tableName AS SELECT ..'") { + TestSQLContext.sql("ADD CACHE TABLE testCacheTable AS SELECT * FROM testData") + TestSQLContext.table("testCacheTable").queryExecution.executedPlan match { + case _: InMemoryColumnarTableScan => // Found evidence of caching + case _ => fail(s"Table 'testCacheTable' should be cached") + } + assert(TestSQLContext.isCached("testCacheTable"), "Table 'testCacheTable' should be cached") + TestSQLContext.uncacheTable("testCacheTable") + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 21ecf17028db..c62fd04e7da1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -216,6 +216,17 @@ private[hive] object HiveQl { /** Returns a LogicalPlan for a given HiveQL string. */ def parseSql(sql: String): LogicalPlan = { + if (sql.trim.toLowerCase.startsWith("add cache table")) { + sql.trim.drop(16).split(" ").toSeq match { + case Seq(tableName,as, xs@_*) => CacheTableAsSelectCommand(tableName,processSql(sql.trim.drop(16+tableName.length()+as.length()+1))) + } + } else { + processSql(sql) + } + } + + /** Returns a LogicalPlan for a given HiveQL string. */ + def processSql(sql : String) = { try { if (sql.trim.toLowerCase.startsWith("set")) { // Split in two parts since we treat the part before the first "=" @@ -233,7 +244,7 @@ private[hive] object HiveQl { } else if (sql.trim.toLowerCase.startsWith("uncache table")) { CacheCommand(sql.trim.drop(14).trim, false) } else if (sql.trim.toLowerCase.startsWith("add jar")) { - AddJar(sql.trim.drop(8).trim) + NativeCommand(sql) } else if (sql.trim.toLowerCase.startsWith("add file")) { AddFile(sql.trim.drop(9)) } else if (sql.trim.toLowerCase.startsWith("dfs")) { From b5276b22c8e0c271e98f445079ea2e3cf61db6dc Mon Sep 17 00:00:00 2001 From: ravipesala Date: Sat, 13 Sep 2014 18:06:49 +0530 Subject: [PATCH 13/21] Updated parser to support add cache table command --- .../org/apache/spark/sql/hive/HiveQl.scala | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index c62fd04e7da1..a9b12ce13764 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -214,19 +214,9 @@ private[hive] object HiveQl { */ def getAst(sql: String): ASTNode = ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql)) + /** Returns a LogicalPlan for a given HiveQL string. */ - def parseSql(sql: String): LogicalPlan = { - if (sql.trim.toLowerCase.startsWith("add cache table")) { - sql.trim.drop(16).split(" ").toSeq match { - case Seq(tableName,as, xs@_*) => CacheTableAsSelectCommand(tableName,processSql(sql.trim.drop(16+tableName.length()+as.length()+1))) - } - } else { - processSql(sql) - } - } - - /** Returns a LogicalPlan for a given HiveQL string. */ - def processSql(sql : String) = { + def parseSql(sql : String) = { try { if (sql.trim.toLowerCase.startsWith("set")) { // Split in two parts since we treat the part before the first "=" @@ -254,14 +244,12 @@ private[hive] object HiveQl { } else if (sql.trim.startsWith("!")) { ShellCommand(sql.drop(1)) } else { - val tree = getAst(sql) - if (nativeCommands contains tree.getText) { - NativeCommand(sql) + if (sql.trim.toLowerCase.startsWith("add cache table")) { + sql.trim.drop(16).split(" ").toSeq match { + case Seq(tableName,as, xs@_*) => CacheTableAsSelectCommand(tableName,createPlan(sql.trim.drop(16+tableName.length()+as.length()+1))) + } } else { - nodeToPlan(tree) match { - case NativePlaceholder => NativeCommand(sql) - case other => other - } + createPlan(sql) } } } catch { @@ -274,6 +262,18 @@ private[hive] object HiveQl { } } + def createPlan(sql: String) ={ + val tree = getAst(sql) + if (nativeCommands contains tree.getText) { + NativeCommand(sql) + } else { + nodeToPlan(tree) match { + case NativePlaceholder => NativeCommand(sql) + case other => other + } + } + } + def parseDdl(ddl: String): Seq[Attribute] = { val tree = try { From dc3389557d3c14ccbc713a745fcb1a0c97bf8726 Mon Sep 17 00:00:00 2001 From: ravipesala Date: Sat, 13 Sep 2014 22:45:10 +0530 Subject: [PATCH 14/21] Updated parser to support add cache table command --- sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index a9b12ce13764..f49fc88f4b75 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -216,7 +216,7 @@ private[hive] object HiveQl { /** Returns a LogicalPlan for a given HiveQL string. */ - def parseSql(sql : String) = { + def parseSql(sql : String): LogicalPlan = { try { if (sql.trim.toLowerCase.startsWith("set")) { // Split in two parts since we treat the part before the first "=" From aaf5b59ea71a9ccdc33a8cda7ee33c3341020c4d Mon Sep 17 00:00:00 2001 From: ravipesala Date: Sat, 13 Sep 2014 23:09:28 +0530 Subject: [PATCH 15/21] Added comment --- sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index f49fc88f4b75..46fbfd595c0d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -262,6 +262,7 @@ private[hive] object HiveQl { } } + /** Creates LogicalPlan for a given HiveQL string. */ def createPlan(sql: String) ={ val tree = getAst(sql) if (nativeCommands contains tree.getText) { From 724b9db63258936bf0d00cda44ca4d4ea4ff2dc5 Mon Sep 17 00:00:00 2001 From: ravipesala Date: Sat, 13 Sep 2014 23:37:25 +0530 Subject: [PATCH 16/21] Changed style --- sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 46fbfd595c0d..b8c5d903c5c3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -216,7 +216,7 @@ private[hive] object HiveQl { /** Returns a LogicalPlan for a given HiveQL string. */ - def parseSql(sql : String): LogicalPlan = { + def parseSql(sql: String): LogicalPlan = { try { if (sql.trim.toLowerCase.startsWith("set")) { // Split in two parts since we treat the part before the first "=" From e3265d0773515821b1a908bb94025ac79807e325 Mon Sep 17 00:00:00 2001 From: ravipesala Date: Mon, 15 Sep 2014 03:16:09 +0530 Subject: [PATCH 17/21] Updated the code as per the comments by Admin in pull request. --- .../apache/spark/sql/catalyst/SqlParser.scala | 20 ++++++++------- .../spark/sql/execution/SparkStrategies.scala | 2 +- .../apache/spark/sql/execution/commands.scala | 12 ++++----- .../apache/spark/sql/CachedTableSuite.scala | 12 +++------ .../org/apache/spark/sql/hive/HiveQl.scala | 25 +++++++++---------- 5 files changed, 33 insertions(+), 38 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index c0f1314a2349..861ab81523d3 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -152,7 +152,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)} | UNION ~ opt(DISTINCT) ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } ) - | insert | cache | addCache + | insert | cache | unCache ) protected lazy val select: Parser[LogicalPlan] = @@ -182,17 +182,19 @@ class SqlParser extends StandardTokenParsers with PackratParsers { val overwrite: Boolean = o.getOrElse("") == "OVERWRITE" InsertIntoTable(r, Map[String, Option[String]](), s, overwrite) } - - protected lazy val addCache: Parser[LogicalPlan] = - ADD ~ CACHE ~ TABLE ~> ident ~ AS ~ select <~ opt(";") ^^ { - case tableName ~ as ~ s => - CacheTableAsSelectCommand(tableName,s) - } protected lazy val cache: Parser[LogicalPlan] = - (CACHE ^^^ true | UNCACHE ^^^ false) ~ TABLE ~ ident ^^ { - case doCache ~ _ ~ tableName => CacheCommand(tableName, doCache) + CACHE ~ TABLE ~> ident ~ opt(AS) ~ opt(select) <~ opt(";") ^^ { + case tableName ~ None ~ None => + CacheCommand(tableName, true) + case tableName ~ as ~ Some(plan) => + CacheTableAsSelectCommand(tableName,plan) } + + protected lazy val unCache: Parser[LogicalPlan] = + UNCACHE ~ TABLE ~> ident <~ opt(";") ^^ { + case tableName => CacheCommand(tableName, false) + } protected lazy val projections: Parser[Seq[Expression]] = repsep(projection, ",") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index aa205d278b5b..2b84fc988911 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -306,7 +306,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.CacheCommand(tableName, cache) => Seq(execution.CacheCommand(tableName, cache)(context)) case logical.CacheTableAsSelectCommand(tableName,plan) => - Seq(execution.CacheTableAsSelectCommand(tableName, plan)(context)) + Seq(execution.CacheTableAsSelectCommand(tableName, plan)) case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 5bbf783ea875..1535291f7908 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -171,16 +171,14 @@ case class DescribeCommand(child: SparkPlan, output: Seq[Attribute])( * :: DeveloperApi :: */ @DeveloperApi -case class CacheTableAsSelectCommand(tableName: String,plan: LogicalPlan)( - @transient context: SQLContext) +case class CacheTableAsSelectCommand(tableName: String, plan: LogicalPlan) extends LeafNode with Command { override protected[sql] lazy val sideEffectResult = { - context.catalog.registerTable(None, tableName, sqlContext.executePlan(plan).analyzed) - context.cacheTable(tableName) - //It does the caching eager. - //TODO : Does it really require to collect? - context.table(tableName).collect + sqlContext.catalog.registerTable(None, tableName, sqlContext.executePlan(plan).analyzed) + sqlContext.cacheTable(tableName) + // It does the caching eager. + sqlContext.table(tableName).count Seq.empty[Row] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index a5bc74170833..5c9a4750eb45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -120,18 +120,14 @@ class CachedTableSuite extends QueryTest { assert(!TestSQLContext.isCached("testData"), "Table 'testData' should not be cached") } - test("ADD CACHE TABLE tableName AS SELECT Star Table") { - TestSQLContext.sql("ADD CACHE TABLE testCacheTable AS SELECT * FROM testData") + test("CACHE TABLE tableName AS SELECT Star Table") { + TestSQLContext.sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") TestSQLContext.sql("SELECT * FROM testCacheTable WHERE key = 1").collect() TestSQLContext.uncacheTable("testCacheTable") } - test("'ADD CACHE TABLE tableName AS SELECT ..'") { - TestSQLContext.sql("ADD CACHE TABLE testCacheTable AS SELECT * FROM testData") - TestSQLContext.table("testCacheTable").queryExecution.executedPlan match { - case _: InMemoryColumnarTableScan => // Found evidence of caching - case _ => fail(s"Table 'testCacheTable' should be cached") - } + test("'CACHE TABLE tableName AS SELECT ..'") { + TestSQLContext.sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") assert(TestSQLContext.isCached("testCacheTable"), "Table 'testCacheTable' should be cached") TestSQLContext.uncacheTable("testCacheTable") } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index b8c5d903c5c3..6fa5e11df357 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -214,7 +214,6 @@ private[hive] object HiveQl { */ def getAst(sql: String): ASTNode = ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql)) - /** Returns a LogicalPlan for a given HiveQL string. */ def parseSql(sql: String): LogicalPlan = { try { @@ -230,11 +229,17 @@ private[hive] object HiveQl { SetCommand(Some(key), Some(value)) } } else if (sql.trim.toLowerCase.startsWith("cache table")) { - CacheCommand(sql.trim.drop(12).trim, true) + sql.trim.drop(12).trim.split(" ").toSeq match { + case Seq(tableName) => + CacheCommand(tableName, true) + case Seq(tableName,as, select@_*) => + CacheTableAsSelectCommand(tableName, + createPlan(sql.trim.drop(12 + tableName.length() + as.length() + 2))) + } } else if (sql.trim.toLowerCase.startsWith("uncache table")) { CacheCommand(sql.trim.drop(14).trim, false) } else if (sql.trim.toLowerCase.startsWith("add jar")) { - NativeCommand(sql) + AddJar(sql.trim.drop(8).trim) } else if (sql.trim.toLowerCase.startsWith("add file")) { AddFile(sql.trim.drop(9)) } else if (sql.trim.toLowerCase.startsWith("dfs")) { @@ -244,13 +249,7 @@ private[hive] object HiveQl { } else if (sql.trim.startsWith("!")) { ShellCommand(sql.drop(1)) } else { - if (sql.trim.toLowerCase.startsWith("add cache table")) { - sql.trim.drop(16).split(" ").toSeq match { - case Seq(tableName,as, xs@_*) => CacheTableAsSelectCommand(tableName,createPlan(sql.trim.drop(16+tableName.length()+as.length()+1))) - } - } else { - createPlan(sql) - } + createPlan(sql) } } catch { case e: Exception => throw new ParseException(sql, e) @@ -261,9 +260,9 @@ private[hive] object HiveQl { """.stripMargin) } } - + /** Creates LogicalPlan for a given HiveQL string. */ - def createPlan(sql: String) ={ + def createPlan(sql: String) = { val tree = getAst(sql) if (nativeCommands contains tree.getText) { NativeCommand(sql) @@ -1109,7 +1108,7 @@ private[hive] object HiveQl { case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) => HiveGenericUdtf(functionName, attributes, children.map(nodeToExpr)) - + case a: ASTNode => throw new NotImplementedError( s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText}, tree: From d8b37b25cb893bf130d403011425161ae89dd187 Mon Sep 17 00:00:00 2001 From: ravipesala Date: Mon, 15 Sep 2014 11:32:55 +0530 Subject: [PATCH 18/21] Updated as per the comments by Admin --- .../org/apache/spark/sql/catalyst/SqlParser.scala | 12 +++--------- .../spark/sql/catalyst/plans/logical/commands.scala | 2 +- .../apache/spark/sql/execution/SparkStrategies.scala | 2 +- 3 files changed, 5 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index be86fbdcaf3c..144cd1d22a24 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -182,18 +182,12 @@ class SqlParser extends StandardTokenParsers with PackratParsers { val overwrite: Boolean = o.getOrElse("") == "OVERWRITE" InsertIntoTable(r, Map[String, Option[String]](), s, overwrite) } - - protected lazy val addCache: Parser[LogicalPlan] = - ADD ~ CACHE ~ TABLE ~> ident ~ AS ~ select <~ opt(";") ^^ { - case tableName ~ as ~ s => - CacheTableAsSelectCommand(tableName,s) - } protected lazy val cache: Parser[LogicalPlan] = - CACHE ~ TABLE ~> ident ~ opt(AS) ~ opt(select) <~ opt(";") ^^ { - case tableName ~ None ~ None => + CACHE ~ TABLE ~> ident ~ opt(AS ~ select) <~ opt(";") ^^ { + case tableName ~ None => CacheCommand(tableName, true) - case tableName ~ as ~ Some(plan) => + case tableName ~ Some(as ~ plan) => CacheTableAsSelectCommand(tableName,plan) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index 948fd53a3d8c..8366639fa0e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -77,6 +77,6 @@ case class DescribeCommand( } /** - * Returned for the "ADD CACHE TABLE tableName AS SELECT .." command. + * Returned for the "CACHE TABLE tableName AS SELECT .." command. */ case class CacheTableAsSelectCommand(tableName: String, plan: LogicalPlan) extends Command diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 2b84fc988911..45687d960404 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -305,7 +305,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { Seq(execution.ExplainCommand(logicalPlan, plan.output, extended)(context)) case logical.CacheCommand(tableName, cache) => Seq(execution.CacheCommand(tableName, cache)(context)) - case logical.CacheTableAsSelectCommand(tableName,plan) => + case logical.CacheTableAsSelectCommand(tableName, plan) => Seq(execution.CacheTableAsSelectCommand(tableName, plan)) case _ => Nil } From 8c9993cb2786a5c23bdb2328eb46a28823e1f9c6 Mon Sep 17 00:00:00 2001 From: ravipesala Date: Mon, 15 Sep 2014 11:38:24 +0530 Subject: [PATCH 19/21] Changed the style --- .../main/scala/org/apache/spark/sql/catalyst/SqlParser.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 144cd1d22a24..7c7116cd2737 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -188,7 +188,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { case tableName ~ None => CacheCommand(tableName, true) case tableName ~ Some(as ~ plan) => - CacheTableAsSelectCommand(tableName,plan) + CacheTableAsSelectCommand(tableName, plan) } protected lazy val unCache: Parser[LogicalPlan] = From fb1759bc4f4db17a321041c2167d86d431b0132e Mon Sep 17 00:00:00 2001 From: ravipesala Date: Mon, 15 Sep 2014 11:56:54 +0530 Subject: [PATCH 20/21] Updated as per Admin comments --- .../main/scala/org/apache/spark/sql/catalyst/SqlParser.scala | 1 - .../test/scala/org/apache/spark/sql/CachedTableSuite.scala | 1 + .../src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 5 ++--- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 7c7116cd2737..068cb49ef6d3 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -127,7 +127,6 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val SUBSTRING = Keyword("SUBSTRING") protected val SQRT = Keyword("SQRT") protected val ABS = Keyword("ABS") - protected val ADD = Keyword("ADD") // Use reflection to find the reserved words defined in this class. protected val reservedWords = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 5c9a4750eb45..591592841e9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -123,6 +123,7 @@ class CachedTableSuite extends QueryTest { test("CACHE TABLE tableName AS SELECT Star Table") { TestSQLContext.sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") TestSQLContext.sql("SELECT * FROM testCacheTable WHERE key = 1").collect() + assert(TestSQLContext.isCached("testCacheTable"), "Table 'testCacheTable' should be cached") TestSQLContext.uncacheTable("testCacheTable") } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 4db533cdf496..7977494f9b6e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -213,7 +213,6 @@ private[hive] object HiveQl { * Returns the AST for the given SQL string. */ def getAst(sql: String): ASTNode = ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql)) - /** Returns a LogicalPlan for a given HiveQL string. */ def parseSql(sql: String): LogicalPlan = { @@ -240,7 +239,7 @@ private[hive] object HiveQl { } else if (sql.trim.toLowerCase.startsWith("uncache table")) { CacheCommand(sql.trim.drop(14).trim, false) } else if (sql.trim.toLowerCase.startsWith("add jar")) { - NativeCommand(sql) + AddJar(sql.trim.drop(8).trim) } else if (sql.trim.toLowerCase.startsWith("add file")) { AddFile(sql.trim.drop(9)) } else if (sql.trim.toLowerCase.startsWith("dfs")) { @@ -1109,7 +1108,7 @@ private[hive] object HiveQl { case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) => HiveGenericUdtf(functionName, attributes, children.map(nodeToExpr)) - + case a: ASTNode => throw new NotImplementedError( s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText}, tree: From 394d5ca28fd39a5785b6eca7f6c476701df31702 Mon Sep 17 00:00:00 2001 From: ravipesala Date: Mon, 15 Sep 2014 12:00:30 +0530 Subject: [PATCH 21/21] Changed style --- sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 7977494f9b6e..86d2aad71607 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -213,7 +213,7 @@ private[hive] object HiveQl { * Returns the AST for the given SQL string. */ def getAst(sql: String): ASTNode = ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql)) - + /** Returns a LogicalPlan for a given HiveQL string. */ def parseSql(sql: String): LogicalPlan = { try {